[RLlib] Check training_enabled on PolicyServer (#19007)

This commit is contained in:
Antoine Galataud 2021-10-12 16:21:02 +02:00 committed by GitHub
parent cbbd349df9
commit edb338ff7c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 34 additions and 1 deletions

View file

@ -833,7 +833,8 @@ def _process_observations(
sample_collector.add_init_obs(episode, agent_id, env_id,
policy_id, episode.length - 1,
filtered_obs)
else:
elif agent_infos is None or agent_infos.get(
"training_enabled", True):
# Add actions, rewards, next-obs to collectors.
values_dict = {
SampleBatch.T: episode.length - 1,

View file

@ -708,6 +708,38 @@ class TestRolloutWorker(unittest.TestCase):
self.assertTrue(isinstance(ev.env, VideoMonitor))
ev.stop()
def test_no_training(self):
    """Check how RolloutWorker sampling reacts to "training_enabled".

    Two workers are sampled, each over an env that puts a fixed
    "training_enabled" value into the info dict of every step:

    * flag True: the sampled batch holds the full 10-step episode.
    * flag False: the result is a MultiAgentBatch with no policy batches.
    """

    class NoTrainingEnv(MockEnv):
        """MockEnv whose step() infos carry a "training_enabled" entry."""

        def __init__(self, episode_length, training_enabled):
            super().__init__(episode_length)
            self.training_enabled = training_enabled

        def step(self, action):
            obs, rew, done, info = super().step(action)
            # Copy the parent's info and add (or override) the flag.
            tagged_info = dict(info)
            tagged_info["training_enabled"] = self.training_enabled
            return obs, rew, done, tagged_info

    def make_worker(training_enabled):
        # Both scenarios share every setting except the env's flag.
        return RolloutWorker(
            env_creator=lambda _: NoTrainingEnv(10, training_enabled),
            policy_spec=MockPolicy,
            rollout_fragment_length=5,
            batch_mode="complete_episodes")

    # Flag set: all 10 steps of the episode are present in the batch.
    worker = make_worker(True)
    sampled = worker.sample()
    self.assertEqual(sampled.count, 10)
    self.assertEqual(len(sampled["obs"]), 10)
    worker.stop()

    # Flag cleared: sampling yields a MultiAgentBatch without any
    # per-policy batches.
    worker = make_worker(False)
    sampled = worker.sample()
    self.assertIsInstance(sampled, MultiAgentBatch)
    self.assertEqual(len(sampled.policy_batches), 0)
    worker.stop()
def sample_and_flush(self, ev):
time.sleep(2)
ev.sample()