# ray/rllib/tests/test_eager_support.py


import unittest
import ray
from ray import tune
from ray.rllib.algorithms.registry import get_algorithm_class
from ray.rllib.utils.framework import try_import_tf
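
tf1, tf, tfv = try_import_tf()


def check_support(alg, config, test_eager=False, test_trace=True):
    """Run `alg` for one training iteration with the "tfe" (tf-eager) framework.

    Each supported action type is tried: continuous (Pendulum-v1) and discrete
    (CartPole-v0). `test_eager` runs with eager_tracing=False, `test_trace`
    with eager_tracing=True. Note that `config` is modified in place.
    """
    config["framework"] = "tfe"
    config["log_level"] = "ERROR"
    # Test both continuous and discrete actions.
    for cont in [True, False]:
        if cont and alg in ["DQN", "APEX", "SimpleQ"]:
            continue
        elif not cont and alg in ["DDPG", "APEX_DDPG", "TD3"]:
            continue

        if cont: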
            config["env"] = "Pendulum-v1"
        else:
            config["env"] = "CartPole-v0"
        a = get_algorithm_class(alg)
        if test_eager:
            print("tf-eager: alg={} cont.act={}".format(alg, cont))
            config["eager_tracing"] = False
            tune.run(a, config=config, stop={"training_iteration": 1}, verbose=1)
        if test_trace:
            config["eager_tracing"] = True
            print("tf-eager-tracing: alg={} cont.act={}".format(alg, cont))
            tune.run(a, config=config, stop={"training_iteration": 1}, verbose=1)
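

# Hypothetical helper, not part of RLlib or of the original test: a sketch of
# the skip/env-selection rule inside check_support(), factored out so it can
# be checked without starting Ray. It assumes the same algorithm lists as
# above: DQN-style algorithms (DQN, APEX, SimpleQ) only support discrete
# actions, while DDPG-style algorithms (DDPG, APEX_DDPG, TD3) only support
# continuous actions.
def _env_for(alg, cont):
    """Return the env id to test `alg` on, or None if this variant is skipped."""
    if cont and alg in ["DQN", "APEX", "SimpleQ"]:
        # Discrete-only algorithm: no continuous-action run.
        return None
    if not cont and alg in ["DDPG", "APEX_DDPG", "TD3"]:
        # Continuous-only algorithm: no discrete-action run.
        return None
    # Continuous actions -> Pendulum-v1, discrete actions -> CartPole-v0.
    return "Pendulum-v1" if cont else "CartPole-v0"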


class TestEagerSupportPG(unittest.TestCase):
    def setUp(self):
        ray.init(num_cpus=4)

    def tearDown(self):
        ray.shutdown()

    def test_simple_q(self):
        check_support(
            "SimpleQ",
            {"num_workers": 0, "replay_buffer_config": {"learning_starts": 0}},
        )

    def test_dqn(self):
        check_support(
            "DQN", {"num_workers": 0, "replay_buffer_config": {"learning_starts": 0}}
        )

    def test_ddpg(self):
        check_support("DDPG", {"num_workers": 0})

    # TODO(sven): Add these once APEX_DDPG supports eager.
    # def test_apex_ddpg(self):
    #     check_support("APEX_DDPG", {"num_workers": 1})

    def test_td3(self):
        check_support("TD3", {"num_workers": 0})

    def test_a2c(self):
        check_support("A2C", {"num_workers": 0})

    def test_a3c(self):
        check_support("A3C", {"num_workers": 1})

    def test_pg(self):
        check_support("PG", {"num_workers": 0})

    def test_ppo(self):
        check_support("PPO", {"num_workers": 0})

    def test_appo(self):
        check_support("APPO", {"num_workers": 1, "num_gpus": 0})

    def test_impala(self):
        check_support("IMPALA", {"num_workers": 1, "num_gpus": 0}, test_eager=True)


class TestEagerSupportOffPolicy(unittest.TestCase):
    def setUp(self):
        ray.init(num_cpus=4)

    def tearDown(self):
        ray.shutdown()

    def test_simple_q(self):
        check_support(
            "SimpleQ",
            {"num_workers": 0, "replay_buffer_config": {"learning_starts": 0}},
        )

    def test_dqn(self):
        check_support(
            "DQN", {"num_workers": 0, "replay_buffer_config": {"learning_starts": 0}}
        )

    def test_ddpg(self):
        check_support("DDPG", {"num_workers": 0})

    # def test_apex_ddpg(self):
    #     check_support("APEX_DDPG", {"num_workers": 1})

    def test_td3(self):
        check_support("TD3", {"num_workers": 0})

    def test_apex_dqn(self):
        check_support(
            "APEX",
            {
                "num_workers": 2,
                "replay_buffer_config": {"learning_starts": 0},
                "num_gpus": 0,
                "min_time_s_per_iteration": 1,
                "min_sample_timesteps_per_iteration": 100,
                "optimizer": {
                    "num_replay_buffer_shards": 1,
                },
            },
        )

    def test_sac(self):
        check_support(
            "SAC", {"num_workers": 0, "replay_buffer_config": {"learning_starts": 0}}
        )


if __name__ == "__main__":
    import sys

    # Don't test anything under tf 2.x (everything runs eagerly there anyway).
    # TODO: (sven) remove entire file in the future.
    if tfv == 2:
        print("\tskip due to tf==2.x")
        sys.exit(0)

    # Optionally pass a single TestCase class name as the first CLI argument;
    # with no argument, all unittest.TestCase classes in this file are run.
    import pytest

    class_ = sys.argv[1] if len(sys.argv) > 1 else None
    sys.exit(pytest.main(["-v", __file__ + ("" if class_ is None else "::" + class_)]))