From edb338ff7c88f7ccd2372acf42f39c58a106ffae Mon Sep 17 00:00:00 2001 From: Antoine Galataud Date: Tue, 12 Oct 2021 16:21:02 +0200 Subject: [PATCH] [RLlib] Check `training_enabled` on PolicyServer (#19007) --- rllib/evaluation/sampler.py | 3 +- rllib/evaluation/tests/test_rollout_worker.py | 32 +++++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/rllib/evaluation/sampler.py b/rllib/evaluation/sampler.py index bcf7e314a..425738123 100644 --- a/rllib/evaluation/sampler.py +++ b/rllib/evaluation/sampler.py @@ -833,7 +833,8 @@ def _process_observations( sample_collector.add_init_obs(episode, agent_id, env_id, policy_id, episode.length - 1, filtered_obs) - else: + elif agent_infos is None or agent_infos.get( + "training_enabled", True): # Add actions, rewards, next-obs to collectors. values_dict = { SampleBatch.T: episode.length - 1, diff --git a/rllib/evaluation/tests/test_rollout_worker.py b/rllib/evaluation/tests/test_rollout_worker.py index 97e558173..a76615800 100644 --- a/rllib/evaluation/tests/test_rollout_worker.py +++ b/rllib/evaluation/tests/test_rollout_worker.py @@ -708,6 +708,38 @@ class TestRolloutWorker(unittest.TestCase): self.assertTrue(isinstance(ev.env, VideoMonitor)) ev.stop() + def test_no_training(self): + class NoTrainingEnv(MockEnv): + def __init__(self, episode_length, training_enabled): + super(NoTrainingEnv, self).__init__(episode_length) + self.training_enabled = training_enabled + + def step(self, action): + obs, rew, done, info = super(NoTrainingEnv, self).step(action) + return obs, rew, done, { + **info, "training_enabled": self.training_enabled + } + + ev = RolloutWorker( + env_creator=lambda _: NoTrainingEnv(10, True), + policy_spec=MockPolicy, + rollout_fragment_length=5, + batch_mode="complete_episodes") + batch = ev.sample() + self.assertEqual(batch.count, 10) + self.assertEqual(len(batch["obs"]), 10) + ev.stop() + + ev = RolloutWorker( + env_creator=lambda _: NoTrainingEnv(10, False), + policy_spec=MockPolicy, + rollout_fragment_length=5, + batch_mode="complete_episodes") + batch = ev.sample() + self.assertTrue(isinstance(batch, MultiAgentBatch)) + self.assertEqual(len(batch.policy_batches), 0) + ev.stop() + def sample_and_flush(self, ev): time.sleep(2) ev.sample()