mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[RLlib] Fix seeding issue (#10589)
This commit is contained in:
parent
34bda32054
commit
6ae9e76b81
2 changed files with 11 additions and 2 deletions
|
@ -388,9 +388,10 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
if not hasattr(self.env, "seed"):
|
||||
raise ValueError("Env doesn't support env.seed(): {}".format(
|
||||
logger.info("Env doesn't support env.seed(): {}".format(
|
||||
self.env))
|
||||
self.env.seed(seed)
|
||||
else:
|
||||
self.env.seed(seed)
|
||||
try:
|
||||
assert torch is not None
|
||||
torch.manual_seed(seed)
|
||||
|
|
|
@ -626,6 +626,14 @@ class TestRolloutWorker(unittest.TestCase):
|
|||
del os.environ["env_key_1"]
|
||||
del os.environ["env_key_2"]
|
||||
|
||||
def test_no_env_seed(self):
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: MockVectorEnv(episode_length=20, num_envs=8),
|
||||
policy=MockPolicy,
|
||||
seed=1)
|
||||
assert not hasattr(ev.env, "seed")
|
||||
ev.stop()
|
||||
|
||||
def sample_and_flush(self, ev):
|
||||
time.sleep(2)
|
||||
ev.sample()
|
||||
|
|
Loading…
Add table
Reference in a new issue