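"""Regression tests for RLlib's n-step reward adjustment and the trainer
evaluation options (evaluation_interval / evaluation_config)."""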
import gym
import unittest

import ray
from ray.rllib.agents.dqn import DQNTrainer
from ray.rllib.agents.a3c import A3CTrainer
from ray.rllib.agents.dqn.dqn_tf_policy import _adjust_nstep
from ray.tune.registry import register_env


class EvalTest(unittest.TestCase):
    def test_dqn_n_step(self):
        obs = [1, 2, 3, 4, 5, 6, 7]
        actions = ["a", "b", "a", "a", "a", "b", "a"]
        rewards = [10.0, 0.0, 100.0, 100.0, 100.0, 100.0, 100.0]
        new_obs = [2, 3, 4, 5, 6, 7, 8]
        dones = [0, 0, 0, 0, 0, 0, 1]
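        # _adjust_nstep works in place: with n_step=3 and gamma=0.9 it
        # folds the discounted next two rewards into each step's reward,
        # shifts new_obs/dones up to n-1 steps ahead, and leaves
        # obs/actions untouched (see the assertions below).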
        _adjust_nstep(3, 0.9, obs, actions, rewards, new_obs, dones)
        self.assertEqual(obs, [1, 2, 3, 4, 5, 6, 7])
        self.assertEqual(actions, ["a", "b", "a", "a", "a", "b", "a"])
        self.assertEqual(new_obs, [4, 5, 6, 7, 8, 8, 8])
        self.assertEqual(dones, [0, 0, 0, 0, 1, 1, 1])
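        # Spot check of the folded 3-step returns, e.g.:
        #   rewards[0] = 10.0 + 0.9 * 0.0 + 0.9 ** 2 * 100.0 = 91.0
        #   rewards[5] = 100.0 + 0.9 * 100.0 = 190.0 (one step to the end)
        #   rewards[6] = 100.0 (terminal step, nothing left to fold in)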
        self.assertEqual(rewards,
                         [91.0, 171.0, 271.0, 271.0, 271.0, 190.0, 100.0])

    def test_evaluation_option(self):
        def env_creator(env_config):
            return gym.make("CartPole-v0")

        agent_classes = [DQNTrainer, A3CTrainer]

        for agent_cls in agent_classes:
            ray.init(object_store_memory=1000 * 1024 * 1024)
            register_env("CartPoleWrapped-v0", env_creator)
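            # evaluation_interval=2: run evaluation on every second call
            # to train(); evaluation_num_episodes=2: evaluate for two
            # episodes each time. evaluation_config overrides the base
            # config (here gamma and env_config) for evaluation only.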
            agent = agent_cls(
                env="CartPoleWrapped-v0",
                config={
                    "evaluation_interval": 2,
                    "evaluation_num_episodes": 2,
                    "evaluation_config": {
                        "gamma": 0.98,
                        "env_config": {
                            "fake_arg": True
                        }
                    },
                })
            # Given evaluation_interval=2, r0 and r2 should not contain
            # evaluation metrics, while r1 and r3 should.
            r0 = agent.train()
            r1 = agent.train()
            r2 = agent.train()
            r3 = agent.train()

            self.assertTrue("evaluation" in r1)
            self.assertTrue("evaluation" in r3)
            self.assertFalse("evaluation" in r0)
            self.assertFalse("evaluation" in r2)
            self.assertTrue("episode_reward_mean" in r1["evaluation"])
            self.assertNotEqual(r1["evaluation"], r3["evaluation"])
            ray.shutdown()
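
# For reference, the expected values in test_dqn_n_step are reproduced by
# this minimal in-place fold (a sketch of the idea only, with a
# hypothetical helper name; not RLlib's actual _adjust_nstep):
def _nstep_reference_fold(n_step, gamma, rewards, new_obs, dones):
    for i in range(len(rewards)):
        for j in range(1, n_step):
            if i + j < len(rewards):
                # Reads only forward (not yet modified) entries, so an
                # ascending in-place pass is safe.
                new_obs[i] = new_obs[i + j]
                dones[i] = dones[i + j]
                rewards[i] += gamma ** j * rewards[i + j]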


if __name__ == "__main__":
    import pytest
    import sys
    sys.exit(pytest.main(["-v", __file__]))