2017-12-30 00:24:54 -08:00
|
|
|
from __future__ import absolute_import
|
|
|
|
from __future__ import division
|
|
|
|
from __future__ import print_function
|
|
|
|
|
|
|
|
import unittest
|
|
|
|
|
2019-01-29 21:19:53 -08:00
|
|
|
import ray
|
2019-04-07 00:36:18 -07:00
|
|
|
from ray.rllib.agents.dqn import DQNTrainer
|
2018-12-08 16:28:58 -08:00
|
|
|
from ray.rllib.agents.dqn.dqn_policy_graph import _adjust_nstep
|
2017-12-30 00:24:54 -08:00
|
|
|
|
|
|
|
|
2018-06-09 00:21:35 -07:00
|
|
|
class DQNTest(unittest.TestCase):
    """Tests for DQN n-step reward adjustment and periodic evaluation."""

    def testNStep(self):
        """Check _adjust_nstep with n_step=3 and gamma=0.9.

        The return value of _adjust_nstep is not used; the test inspects the
        argument lists afterwards, so the helper is expected to rewrite them
        in place.

        - Each reward becomes the discounted sum of up to 3 rewards,
          e.g. 10 + 0.9 * 0 + 0.81 * 100 = 91.
        - new_obs and dones are shifted to the step 3 ahead.
        - The tail is truncated at the final (done) transition, so the last
          entries repeat the terminal new_obs / done values and sum fewer
          rewards (190 = 100 + 0.9 * 100, then 100).
        """
        obs = [1, 2, 3, 4, 5, 6, 7]
        actions = ["a", "b", "a", "a", "a", "b", "a"]
        rewards = [10.0, 0.0, 100.0, 100.0, 100.0, 100.0, 100.0]
        new_obs = [2, 3, 4, 5, 6, 7, 8]
        dones = [0, 0, 0, 0, 0, 0, 1]
        _adjust_nstep(3, 0.9, obs, actions, rewards, new_obs, dones)
        # obs and actions must be left untouched.
        self.assertEqual(obs, [1, 2, 3, 4, 5, 6, 7])
        self.assertEqual(actions, ["a", "b", "a", "a", "a", "b", "a"])
        self.assertEqual(new_obs, [4, 5, 6, 7, 8, 8, 8])
        self.assertEqual(dones, [0, 0, 0, 0, 1, 1, 1])
        self.assertEqual(rewards,
                         [91.0, 171.0, 271.0, 271.0, 271.0, 190.0, 100.0])

    def testEvaluationOption(self):
        """Check that evaluation metrics refresh every 2 training iterations.

        With ``evaluation_interval=2`` the assertions below expect the
        "evaluation" entry of each train() result to be present from the
        first iteration. They also expect it to stay identical across two
        consecutive iterations and then change.
        """
        ray.init()
        # Always shut Ray down, even when an assertion fails, so Ray state
        # does not leak into later tests running in the same process.
        self.addCleanup(ray.shutdown)
        agent = DQNTrainer(
            env="CartPole-v0", config={"evaluation_interval": 2})
        r0, r1, r2, r3, r4 = [agent.train() for _ in range(5)]
        self.assertIn("evaluation", r0)
        self.assertIn("episode_reward_mean", r0["evaluation"])
        self.assertEqual(r0["evaluation"], r1["evaluation"])
        self.assertNotEqual(r1["evaluation"], r2["evaluation"])
        self.assertEqual(r2["evaluation"], r3["evaluation"])
        self.assertNotEqual(r3["evaluation"], r4["evaluation"])
|
|
|
|
|
2018-01-23 10:31:19 -08:00
|
|
|
|
2019-02-15 13:32:43 -08:00
|
|
|
def _run_tests():
    """Run every test in this module with verbose, per-test output."""
    unittest.main(verbosity=2)


if __name__ == "__main__":
    # Only run the suite when this file is executed directly, not on import.
    _run_tests()
|