ray/rllib/agents/dqn/tests/test_apex_dqn.py
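
"""Unit tests for RLlib's Ape-X DQN (APEX) trainer.

Covers running with zero rollout workers, the per-worker epsilon-greedy
exploration values, and the learning-rate schedule.
"""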

import pytest
import unittest

import ray
import ray.rllib.agents.dqn.apex as apex
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, LEARNER_STATS_KEY
from ray.rllib.utils.test_utils import (
    check,
    check_compute_single_action,
    check_train_results,
    framework_iterator,
)


class TestApexDQN(unittest.TestCase):
    def setUp(self):
        ray.init(num_cpus=4)

    def tearDown(self):
        ray.shutdown()

    def test_apex_zero_workers(self):
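        """Test that APEX-DQN also runs with num_rollout_workers=0.

        With zero remote workers, sampling presumably falls back to the local
        worker; this test only checks that a single train() call completes and
        reports valid results on each framework.
        """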
        config = (
            apex.ApexConfig()
            .rollouts(num_rollout_workers=0)
            .resources(num_gpus=0)
            .training(
                replay_buffer_config={
                    "learning_starts": 1000,
                },
                optimizer={
                    "num_replay_buffer_shards": 1,
                },
            )
            .reporting(
                min_sample_timesteps_per_reporting=100,
                min_time_s_per_reporting=1,
            )
        )

        for _ in framework_iterator(config):
            trainer = config.build(env="CartPole-v0")
            results = trainer.train()
            check_train_results(results)
            print(results)
            trainer.stop()

    def test_apex_dqn_compilation_and_per_worker_epsilon_values(self):
        """Test whether an APEX-DQNTrainer can be built on all frameworks."""
        config = (
            apex.ApexConfig()
            .rollouts(num_rollout_workers=3)
            .resources(num_gpus=0)
            .training(
                replay_buffer_config={
                    "learning_starts": 1000,
                },
                optimizer={
                    "num_replay_buffer_shards": 1,
                },
            )
            .reporting(
                min_sample_timesteps_per_reporting=100,
                min_time_s_per_reporting=1,
            )
        )

        for _ in framework_iterator(config, with_eager_tracing=True):
            trainer = config.build(env="CartPole-v0")

            # Test per-worker epsilon distribution.
            infos = trainer.workers.foreach_policy(
                lambda p, _: p.get_exploration_state()
            )
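            # Note (assumption): the expected values below appear to follow the
            # per-worker exploration schedule from the Ape-X paper,
            # epsilon_i = 0.4 ** (1 + 7 * i / (N - 1)) for the N=3 remote
            # workers (i = 0, 1, 2); the prepended 0.0 is the local worker,
            # which uses epsilon=0.0 (no random exploration).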
            expected = [0.4, 0.016190862, 0.00065536]
            check([i["cur_epsilon"] for i in infos], [0.0] + expected)

            check_compute_single_action(trainer)

            for i in range(2):
                results = trainer.train()
                check_train_results(results)
                print(results)

            # Test again per-worker epsilon distribution
            # (should not have changed).
            infos = trainer.workers.foreach_policy(
                lambda p, _: p.get_exploration_state()
            )
            check([i["cur_epsilon"] for i in infos], [0.0] + expected)

            trainer.stop()

    def test_apex_lr_schedule(self):
        config = (
            apex.ApexConfig()
            .rollouts(
                num_rollout_workers=1,
                rollout_fragment_length=5,
            )
            .resources(num_gpus=0)
            .training(
                train_batch_size=10,
                optimizer={
                    "num_replay_buffer_shards": 1,
                    # This makes sure the learning rate schedule is checked
                    # every 10 timesteps.
                    "max_weight_sync_delay": 10,
                },
                replay_buffer_config={
                    "no_local_replay_buffer": True,
                    "type": "MultiAgentPrioritizedReplayBuffer",
                    "learning_starts": 10,
                    "capacity": 100,
                    "prioritized_replay_alpha": 0.6,
                    # Beta parameter for sampling from the prioritized replay buffer.
                    "prioritized_replay_beta": 0.4,
                    # Epsilon to add to the TD errors when updating priorities.
                    "prioritized_replay_eps": 1e-6,
                },
                # Initial lr; doesn't really matter because of the schedule below.
                lr=0.2,
                lr_schedule=[[0, 0.2], [100, 0.001]],
            )
            .reporting(
                min_sample_timesteps_per_reporting=10,
                # 0 metrics reporting delay: this makes sure the timestep,
                # which the lr depends on, is updated after each worker rollout.
                min_time_s_per_reporting=0,
            )
        )
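
        # Assumption for the checks below: RLlib interpolates lr_schedule
        # linearly between the given [timestep, lr] points and keeps the last
        # value afterwards. With min_sample_timesteps_per_reporting=10, each
        # train() call covers roughly 10 sampled timesteps, so after 5 calls
        # (~50 timesteps) the lr should still be around 0.1, and once >= 100
        # timesteps have been sampled it should have annealed to 0.001.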

        def _step_n_times(trainer, n: int):
            """Step trainer n times.

            Returns:
                learning rate at the end of the execution.
            """
            for _ in range(n):
                results = trainer.train()
            return results["info"][LEARNER_INFO][DEFAULT_POLICY_ID][LEARNER_STATS_KEY][
                "cur_lr"
            ]

        for _ in framework_iterator(config):
            trainer = config.build(env="CartPole-v0")

            lr = _step_n_times(trainer, 5)  # ~50 timesteps.
            # Still close to the initial lr of 0.2.
            self.assertGreaterEqual(lr, 0.1)

            lr = _step_n_times(trainer, 20)  # ~200 more timesteps.
            # Annealed to the final lr of 0.001.
            self.assertLessEqual(lr, 0.0011)

            trainer.stop()


if __name__ == "__main__":
    import sys

    sys.exit(pytest.main(["-v", __file__]))