mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
74 lines
2.3 KiB
Python
74 lines
2.3 KiB
Python
import gym
|
|
import numpy as np
|
|
import unittest
|
|
|
|
import ray
|
|
from ray.rllib.algorithms.dqn import DQN
|
|
from ray.rllib.utils.test_utils import framework_iterator
|
|
from ray.tune.registry import register_env
|
|
|
|
|
|
class TestReproducibility(unittest.TestCase):
    """Checks that DQN reward trajectories repeat when the seed is fixed."""

    def test_reproducing_trajectory(self):
        """Train DQN three times per framework and compare reward trajectories.

        For each framework ("tf" and "torch"), three trials are run:
        trials 0 and 1 use seed 666, trial 2 uses seed 999. Every trial
        trains for 8 iterations and records ``episode_reward_max`` and
        ``episode_reward_min`` of each iteration (16 values per trial).

        The test asserts that trials 0 and 1 produce identical values, and
        that more than 8 of the 16 values differ between trials 1 and 2.
        """

        class PickLargest(gym.Env):
            """One-step env whose reward is the observation entry picked.

            The observation is a vector of 4 standard-normal samples and the
            action is an index into that vector; every episode terminates
            after a single step.
            """

            def __init__(self):
                self.observation_space = gym.spaces.Box(
                    low=float("-inf"), high=float("inf"), shape=(4,)
                )
                self.action_space = gym.spaces.Discrete(4)

            def reset(self, **kwargs):
                # Drawn from NumPy's global RNG.
                # NOTE(review): reproducibility presumably relies on the
                # "seed" config entry seeding that RNG -- confirm in RLlib.
                self.obs = np.random.randn(4)
                return self.obs

            def step(self, action):
                reward = self.obs[action]
                # `True` marks the episode as done after this single step.
                return self.obs, reward, True, {}

        def env_creator(env_config):
            return PickLargest()

        for fw in framework_iterator(frameworks=("tf", "torch")):
            trajs = []
            for trial in range(3):
                ray.init()
                # Guarantee that Ray is shut down even if building or
                # training the agent raises; otherwise the next `ray.init()`
                # (next trial or next test) would fail on a live runtime.
                try:
                    register_env("PickLargest", env_creator)
                    config = {
                        "seed": 666 if trial in [0, 1] else 999,
                        "min_time_s_per_iteration": 0,
                        "min_sample_timesteps_per_iteration": 100,
                        "framework": fw,
                    }
                    agent = DQN(config=config, env="PickLargest")

                    trajectory = []
                    for _ in range(8):
                        r = agent.train()
                        trajectory.append(r["episode_reward_max"])
                        trajectory.append(r["episode_reward_min"])
                    trajs.append(trajectory)
                finally:
                    ray.shutdown()

            # trial0 and trial1 use same seed and thus
            # expect identical trajectories.
            mismatches = [
                (idx, v0, v1)
                for idx, (v0, v1) in enumerate(zip(trajs[0], trajs[1]))
                if v0 != v1
            ]
            self.assertFalse(
                mismatches,
                msg=f"Same-seed trajectories differ at (index, v0, v1): "
                f"{mismatches}",
            )

            # trial1 and trial2 use different seeds and thus
            # most rewards tend to be different.
            diff_cnt = sum(1 for v1, v2 in zip(trajs[1], trajs[2]) if v1 != v2)
            self.assertGreater(diff_cnt, 8)
|
|
|
|
|
|
if __name__ == "__main__":
    # When executed as a script, run this file's tests through pytest in
    # verbose mode and use pytest's return code as the process exit status.
    import sys

    import pytest

    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)
|