mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[RLlib] Check training_enabled on PolicyServer (#19007)
This commit is contained in:
parent
cbbd349df9
commit
edb338ff7c
2 changed files with 34 additions and 1 deletions
|
@@ -833,7 +833,8 @@ def _process_observations(
|
|||
sample_collector.add_init_obs(episode, agent_id, env_id,
|
||||
policy_id, episode.length - 1,
|
||||
filtered_obs)
|
||||
else:
|
||||
elif agent_infos is None or agent_infos.get(
|
||||
"training_enabled", True):
|
||||
# Add actions, rewards, next-obs to collectors.
|
||||
values_dict = {
|
||||
SampleBatch.T: episode.length - 1,
|
||||
|
|
|
@@ -708,6 +708,38 @@ class TestRolloutWorker(unittest.TestCase):
|
|||
self.assertTrue(isinstance(ev.env, VideoMonitor))
|
||||
ev.stop()
|
||||
|
||||
def test_no_training(self):
    """Exercise sampling with the env's "training_enabled" info flag on and off.

    With the flag set to True, sampling one complete 10-step episode is
    expected to produce a batch of 10 timesteps. With the flag set to
    False, the sampled result is expected to be a MultiAgentBatch that
    contains no policy batches.
    """

    class NoTrainingEnv(MockEnv):
        """MockEnv that adds a fixed "training_enabled" entry to every info dict."""

        def __init__(self, episode_length, training_enabled):
            super().__init__(episode_length)
            self.training_enabled = training_enabled

        def step(self, action):
            obs, rew, done, info = super().step(action)
            # Copy the parent's info and tag it with the configured flag.
            tagged_info = dict(info)
            tagged_info["training_enabled"] = self.training_enabled
            return obs, rew, done, tagged_info

    def build_worker(training_enabled):
        # Both scenarios use the same worker setup; only the flag differs.
        return RolloutWorker(
            env_creator=lambda _: NoTrainingEnv(10, training_enabled),
            policy_spec=MockPolicy,
            rollout_fragment_length=5,
            batch_mode="complete_episodes")

    # Flag on: the whole 10-step episode shows up in the sampled batch.
    worker = build_worker(True)
    samples = worker.sample()
    self.assertEqual(samples.count, 10)
    self.assertEqual(len(samples["obs"]), 10)
    worker.stop()

    # Flag off: an empty MultiAgentBatch (no per-policy batches) comes back.
    worker = build_worker(False)
    samples = worker.sample()
    self.assertTrue(isinstance(samples, MultiAgentBatch))
    self.assertEqual(len(samples.policy_batches), 0)
    worker.stop()
|
||||
|
||||
def sample_and_flush(self, ev):
|
||||
time.sleep(2)
|
||||
ev.sample()
|
||||
|
|
Loading…
Add table
Reference in a new issue