import unittest

import numpy as np
import pytest

import ray
import ray.rllib.agents.dqn.apex as apex
from ray.rllib.utils.test_utils import framework_iterator

|
class TestApex(unittest.TestCase):
    """Integration tests for the APEX-DQN Trainer."""

    def setUp(self):
        # 4 CPUs: enough for the local worker, 3 rollout workers, and the
        # replay-buffer actor(s) used by APEX.
        ray.init(num_cpus=4)

    def tearDown(self):
        ray.shutdown()

    def test_apex_compilation_and_per_worker_epsilon_values(self):
        """Test whether an APEX-DQNTrainer can be built on all frameworks."""
        config = apex.APEX_DEFAULT_CONFIG.copy()
        config["num_workers"] = 3
        config["prioritized_replay"] = True
        config["optimizer"]["num_replay_buffer_shards"] = 1
        num_iterations = 1

        for _ in framework_iterator(config, ("torch", "tf", "eager")):
            plain_config = config.copy()
            trainer = apex.ApexTrainer(config=plain_config, env="CartPole-v0")

            # Test per-worker epsilon distribution.
            # foreach_policy returns one entry per worker: the local worker
            # first (epsilon=1.0), then the 3 remote workers, whose epsilons
            # decrease per-worker (values pinned below).
            infos = trainer.workers.foreach_policy(
                lambda p, _: p.get_exploration_info())
            eps = [i["cur_epsilon"] for i in infos]
            assert np.allclose(
                eps, [1.0, 0.016190862, 0.00065536, 2.6527108e-05])

            # Loop index is unused -> use `_` (was `i`, shadowing nothing
            # but signalling a value that is never read).
            for _ in range(num_iterations):
                results = trainer.train()
                print(results)

            # Release the trainer's actors/resources before the next
            # framework run; otherwise workers leak across iterations.
            trainer.stop()
|
|
if __name__ == "__main__":
|
|
|
|
import sys
|
|
|
|
sys.exit(pytest.main(["-v", __file__]))
|