ray/rllib/tests/test_reproducibility.py
Sven 60d4d5e1aa Remove future imports (#6724)
* Remove all __future__ imports from RLlib.

* Remove (object) again from tf_run_builder.py::TFRunBuilder.

* Fix 2xLINT warnings.

* Fix broken appo_policy import (must be appo_tf_policy)

* Remove future imports from all other ray files (not just RLlib).

* Remove future imports from all other ray files (not just RLlib).

* Remove future import blocks that contain `unicode_literals` as well.
Revert appo_tf_policy.py to appo_policy.py (belongs to another PR).

* Add two empty lines before Schedule class.

* Put back __future__ imports into determine_tests_to_run.py. Fails otherwise on a py2/print related error.
2020-01-09 00:15:48 -08:00

64 lines
1.9 KiB
Python

import unittest
import ray
from ray.rllib.agents.dqn import DQNTrainer
from ray.tune.registry import register_env
import numpy as np
import gym
class TestReproducibility(unittest.TestCase):
    """Checks that DQN training is reproducible when the seed is fixed."""

    def testReproducingTrajectory(self):
        """Train DQN three times and compare the recorded reward extremes.

        Each trial runs 8 training iterations and records, per iteration,
        the reported `episode_reward_max` and `episode_reward_min`.
        Trials 0 and 1 share seed 666 and must yield identical trajectories;
        trial 2 uses seed 999 and is expected to differ in most entries.
        """

        class PickLargest(gym.Env):
            """Single-step env whose reward is the observation entry picked
            by the action index."""

            def __init__(self):
                self.observation_space = gym.spaces.Box(
                    low=float("-inf"), high=float("inf"), shape=(4, ))
                self.action_space = gym.spaces.Discrete(4)

            def reset(self, **kwargs):
                # Drawn from numpy's global RNG; presumably seeded through
                # the trainer's `seed` config -- confirm against RLlib.
                self.obs = np.random.randn(4)
                return self.obs

            def step(self, action):
                reward = self.obs[action]
                # done=True: every episode ends after exactly one step.
                return self.obs, reward, True, {}

        def env_creator(env_config):
            return PickLargest()

        # First two trials share a seed, the third one deviates.
        seeds = [666, 666, 999]
        trajs = []
        for seed in seeds:
            ray.init()
            try:
                register_env("PickLargest", env_creator)
                agent = DQNTrainer(env="PickLargest", config={"seed": seed})
                trajectory = []
                for _ in range(8):
                    r = agent.train()
                    trajectory.append(r["episode_reward_max"])
                    trajectory.append(r["episode_reward_min"])
                trajs.append(trajectory)
            finally:
                # Always tear ray down, even when training raised, so the
                # next ray.init() (here or in another test) starts clean.
                ray.shutdown()

        # trial0 and trial1 use same seed and thus
        # expect identical trajectories.
        # zip() would silently truncate, so pin the lengths first.
        self.assertEqual(len(trajs[0]), len(trajs[1]))
        mismatches = [
            (idx, v0, v1)
            for idx, (v0, v1) in enumerate(zip(trajs[0], trajs[1]))
            if v0 != v1
        ]
        self.assertFalse(
            mismatches,
            "same-seed trajectories differ at (index, trial0, trial1): "
            "{}".format(mismatches))

        # trial1 and trial2 use different seeds and thus
        # most rewards tend to be different.
        diff_cnt = sum(1 for v1, v2 in zip(trajs[1], trajs[2]) if v1 != v2)
        self.assertGreater(diff_cnt, 8)
# Run the tests in this module with verbose output when executed as a script.
if __name__ == "__main__":
    unittest.main(verbosity=2)