
Exploration API (+EpsilonGreedy sub-class).

* Exploration API (+EpsilonGreedy sub-class).
* Cleanup/LINT.
* Add `deterministic` to generic Trainer config (NOTE: this is still ignored by most Agents).
* Add `error` option to deprecation_warning().
* WIP.
* Bug fix: Get exploration-info for tf framework. Bug fix: Properly deprecate some DQN config keys.
* WIP.
* LINT.
* WIP.
* Split PerWorkerEpsilonGreedy out of EpsilonGreedy. Docstrings.
* Fix bug in sampler.py in case Policy has self.exploration = None
* Update rllib/agents/dqn/dqn.py Co-Authored-By: Eric Liang <ekhliang@gmail.com>
* WIP.
* Update rllib/agents/trainer.py Co-Authored-By: Eric Liang <ekhliang@gmail.com>
* WIP.
* Change requests.
* LINT
* In tune/utils/util.py::deep_update() Only keep deep_updat'ing if both original and value are dicts. If value is not a dict, set
* Completely obsolete syn_replay_optimizer.py's parameters schedule_max_timesteps AND beta_annealing_fraction (replaced with prioritized_replay_beta_annealing_timesteps).
* Update rllib/evaluation/worker_set.py Co-Authored-By: Eric Liang <ekhliang@gmail.com>
* Review fixes.
* Fix default value for DQN's exploration spec.
* LINT
* Fix recursion bug (wrong parent c'tor).
* Do not pass timestep to get_exploration_info.
* Update tf_policy.py
* Fix some remaining issues with test cases and remove more deprecated DQN/APEX exploration configs.
* Bug fix tf-action-dist
* DDPG incompatibility bug fix with new DQN exploration handling (which is imported by DDPG).
* Switch off exploration when getting action probs from off-policy-estimator's policy.
* LINT
* Fix test_checkpoint_restore.py.
* Deprecate all SAC exploration (unused) configs.
* Properly use `model.last_output()` everywhere. Instead of `model._last_output`.
* WIP.
* Take out set_epsilon from multi-agent-env test (not needed, decays anyway).
* WIP.
* Trigger re-test (flaky checkpoint-restore test).
* WIP.
* WIP.
* Add test case for deterministic action sampling in PPO.
* bug fix.
* Added deterministic test cases for different Agents.
* Fix problem with TupleActions in dynamic-tf-policy.
* Separate supported_spaces tests so they can be run separately for easier debugging.
* LINT.
* Fix autoregressive_action_dist.py test case.
* Re-test.
* Fix.
* Remove duplicate py_test rule from bazel.
* LINT.
* WIP.
* WIP.
* SAC fix.
* SAC fix.
* WIP.
* WIP.
* WIP.
* FIX 2 examples tests.
* WIP.
* WIP.
* WIP.
* WIP.
* WIP.
* Fix.
* LINT.
* Renamed test file.
* WIP.
* Add unittest.main.
* Make action_dist_class mandatory.
* fix
* FIX.
* WIP.
* WIP.
* Fix.
* Fix.
* Fix explorations test case (contextlib cannot find its own nullcontext??).
* Force torch to be installed for QMIX.
* LINT.
* Fix determine_tests_to_run.py.
* Fix determine_tests_to_run.py.
* WIP
* Add Random exploration component to tests (fixed issue with "static-graph randomness" via py_function).
* Add Random exploration component to tests (fixed issue with "static-graph randomness" via py_function).
* Rename some stuff.
* Rename some stuff.
* WIP.
* WIP.
* Fix SAC.
* Fix SAC.
* Fix strange tf-error in ray core tests.
* Fix strange ray-core tf-error in test_memory_scheduling test case.
* Fix test_io.py.
* LINT.
* Update SAC yaml files' config.

Co-authored-by: Eric Liang <ekhliang@gmail.com>
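The headline change above is the new Exploration API with an EpsilonGreedy sub-class (plus PerWorkerEpsilonGreedy). As a rough illustration only, a DQN config using it would presumably look something like the sketch below; the key names (`exploration_config`, `initial_epsilon`, `final_epsilon`, `epsilon_timesteps`) are assumptions based on RLlib's documented exploration API rather than something stated on this page:

    # Hypothetical exploration setup; key names assumed from RLlib docs.
    config = {
        "exploration_config": {
            "type": "EpsilonGreedy",
            "initial_epsilon": 1.0,
            "final_epsilon": 0.02,
            # Number of timesteps over which epsilon is annealed.
            "epsilon_timesteps": 10000,
        },
    }
    trainer = get_agent_class("DQN")(config=config, env="CartPole-v0")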
import gym
from gym.spaces import Box, Discrete, Tuple, Dict, MultiDiscrete
from gym.envs.registration import EnvSpec
import numpy as np
import sys
import unittest
import traceback

import ray
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.agents.registry import get_agent_class
from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork as FCNetV2
from ray.rllib.models.tf.visionnet_v2 import VisionNetwork as VisionNetV2
from ray.rllib.tests.test_multi_agent_env import MultiCartpole, \
    MultiMountainCar
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.tune.registry import register_env

tf = try_import_tf()

ACTION_SPACES_TO_TEST = {
    "discrete": Discrete(5),
    "vector": Box(-1.0, 1.0, (5, ), dtype=np.float32),
    "vector2": Box(-1.0, 1.0, (
        5,
        5,
    ), dtype=np.float32),
    "multidiscrete": MultiDiscrete([1, 2, 3, 4]),
    "tuple": Tuple(
        [Discrete(2),
         Discrete(3),
         Box(-1.0, 1.0, (5, ), dtype=np.float32)]),
}

OBSERVATION_SPACES_TO_TEST = {
    "discrete": Discrete(5),
    "vector": Box(-1.0, 1.0, (5, ), dtype=np.float32),
    "vector2": Box(-1.0, 1.0, (5, 5), dtype=np.float32),
    "image": Box(-1.0, 1.0, (84, 84, 1), dtype=np.float32),
    "atari": Box(-1.0, 1.0, (210, 160, 3), dtype=np.float32),
    "tuple": Tuple([Discrete(10),
                    Box(-1.0, 1.0, (5, ), dtype=np.float32)]),
    "dict": Dict({
        "task": Discrete(10),
        "position": Box(-1.0, 1.0, (5, ), dtype=np.float32),
    }),
}


def make_stub_env(action_space, obs_space, check_action_bounds):
    """Returns a stub gym.Env class with the given action/observation spaces.

    Each episode lasts exactly one step; if `check_action_bounds` is True,
    `step()` raises a ValueError for actions outside the action space.
    """

    class StubEnv(gym.Env):
        def __init__(self):
            self.action_space = action_space
            self.observation_space = obs_space
            self.spec = EnvSpec("StubEnv-v0")

        def reset(self):
            sample = self.observation_space.sample()
            return sample

        def step(self, action):
            if check_action_bounds and not self.action_space.contains(action):
                raise ValueError("Illegal action for {}: {}".format(
                    self.action_space, action))
            if (isinstance(self.action_space, Tuple)
                    and len(action) != len(self.action_space.spaces)):
                raise ValueError("Illegal action for {}: {}".format(
                    self.action_space, action))
            return self.observation_space.sample(), 1, True, {}

    return StubEnv


def check_support(alg, config, stats, check_bounds=False, name=None):
    """Runs `alg` against combinations of action/observation spaces.

    Records an "ok"/"skip"/"unsupported"/"ERROR" entry per combination in
    `stats` and re-raises the first unexpected error at the end.
    """
    covered_a = set()
    covered_o = set()
    config["log_level"] = "ERROR"
    first_error = None
    for a_name, action_space in ACTION_SPACES_TO_TEST.items():
        for o_name, obs_space in OBSERVATION_SPACES_TO_TEST.items():
            print("=== Testing {} A={} S={} ===".format(
                alg, action_space, obs_space))
            stub_env = make_stub_env(action_space, obs_space, check_bounds)
            # Bind the current stub class via a default arg so the lambda is
            # not rebound when "stub_env" is re-registered in a later
            # iteration.
            register_env("stub_env", lambda c, env_cls=stub_env: env_cls())
            stat = "ok"
            a = None
            try:
                if a_name in covered_a and o_name in covered_o:
                    stat = "skip"  # speed up tests by avoiding full grid
                else:
                    a = get_agent_class(alg)(config=config, env="stub_env")
                    if alg not in ["DDPG", "ES", "ARS", "SAC"]:
                        if o_name in ["atari", "image"]:
                            assert isinstance(a.get_policy().model,
                                              VisionNetV2)
                        elif o_name in ["vector", "vector2"]:
                            assert isinstance(a.get_policy().model, FCNetV2)
                    a.train()
                    covered_a.add(a_name)
                    covered_o.add(o_name)
            except UnsupportedSpaceException:
                stat = "unsupported"
            except Exception as e:
                stat = "ERROR"
                print(e)
                print(traceback.format_exc())
                first_error = first_error if first_error is not None else e
            finally:
                if a:
                    try:
                        a.stop()
                    except Exception as e:
                        print("Ignoring error stopping agent", e)
            print(stat)
            print()
            stats[name or alg, a_name, o_name] = stat

    # If anything went wrong, re-raise the first error encountered.
    if first_error is not None:
        raise first_error


def check_support_multiagent(alg, config):
    register_env("multi_mountaincar", lambda _: MultiMountainCar(2))
    register_env("multi_cartpole", lambda _: MultiCartpole(2))
    config["log_level"] = "ERROR"
    # DDPG-based algos need a continuous action space, hence the
    # mountain-car env; everything else trains on the (discrete) cartpole.
    if "DDPG" in alg:
        a = get_agent_class(alg)(config=config, env="multi_mountaincar")
    else:
        a = get_agent_class(alg)(config=config, env="multi_cartpole")
    try:
        a.train()
    finally:
        a.stop()


class ModelSupportedSpaces(unittest.TestCase):
    stats = {}

    def setUp(self):
        ray.init(num_cpus=4, ignore_reinit_error=True)

    def tearDown(self):
        ray.shutdown()

    def test_a3c(self):
        check_support(
            "A3C", {
                "num_workers": 1,
                "optimizer": {
                    "grads_per_step": 1
                }
            },
            self.stats,
            check_bounds=True)

    def test_appo(self):
        check_support("APPO", {"num_gpus": 0, "vtrace": False}, self.stats)
        check_support(
            "APPO", {
                "num_gpus": 0,
                "vtrace": True
            },
            self.stats,
            name="APPO-vt")

    def test_ars(self):
        check_support(
            "ARS", {
                "num_workers": 1,
                "noise_size": 10000000,
                "num_rollouts": 1,
                "rollouts_used": 1
            }, self.stats)

    def test_ddpg(self):
        check_support(
            "DDPG", {
                "exploration_ou_noise_scale": 100.0,
                "timesteps_per_iteration": 1,
                "use_state_preprocessor": True,
            },
            self.stats,
            check_bounds=True)

    def test_dqn(self):
        check_support("DQN", {"timesteps_per_iteration": 1}, self.stats)

    def test_es(self):
        check_support(
            "ES", {
                "num_workers": 1,
                "noise_size": 10000000,
                "episodes_per_batch": 1,
                "train_batch_size": 1
            }, self.stats)

    def test_impala(self):
        check_support("IMPALA", {"num_gpus": 0}, self.stats)

    def test_ppo(self):
        check_support(
            "PPO", {
                "num_workers": 1,
                "num_sgd_iter": 1,
                "train_batch_size": 10,
                "sample_batch_size": 10,
                "sgd_minibatch_size": 1,
            },
            self.stats,
            check_bounds=True)

    def test_pg(self):
        check_support(
            "PG", {
                "num_workers": 1,
                "optimizer": {}
            },
            self.stats,
            check_bounds=True)

    def test_sac(self):
        check_support("SAC", {}, self.stats, check_bounds=True)

    def test_a3c_multiagent(self):
        check_support_multiagent("A3C", {
            "num_workers": 1,
            "optimizer": {
                "grads_per_step": 1
            }
        })

    def test_apex_multiagent(self):
        check_support_multiagent(
            "APEX", {
                "num_workers": 2,
                "timesteps_per_iteration": 1000,
                "num_gpus": 0,
                "min_iter_time_s": 1,
                "learning_starts": 1000,
                "target_network_update_freq": 100,
            })

    def test_apex_ddpg_multiagent(self):
        check_support_multiagent(
            "APEX_DDPG", {
                "num_workers": 2,
                "timesteps_per_iteration": 1000,
                "num_gpus": 0,
                "min_iter_time_s": 1,
                "learning_starts": 1000,
                "target_network_update_freq": 100,
                "use_state_preprocessor": True,
            })

    def test_ddpg_multiagent(self):
        check_support_multiagent("DDPG", {
            "timesteps_per_iteration": 1,
            "use_state_preprocessor": True,
        })

    def test_dqn_multiagent(self):
        check_support_multiagent("DQN", {"timesteps_per_iteration": 1})

    def test_impala_multiagent(self):
        check_support_multiagent("IMPALA", {"num_gpus": 0})

    def test_pg_multiagent(self):
        check_support_multiagent("PG", {"num_workers": 1, "optimizer": {}})

    def test_ppo_multiagent(self):
        check_support_multiagent(
            "PPO", {
                "num_workers": 1,
                "num_sgd_iter": 1,
                "train_batch_size": 10,
                "sample_batch_size": 10,
                "sgd_minibatch_size": 1,
            })


if __name__ == "__main__":
    if len(sys.argv) > 1 and sys.argv[1] == "--smoke":
        # Smoke-test mode: restrict the grid to a single action space and
        # two observation spaces to keep runtime short.
        ACTION_SPACES_TO_TEST = {
            "discrete": Discrete(5),
        }
        OBSERVATION_SPACES_TO_TEST = {
            "vector": Box(0.0, 1.0, (5, ), dtype=np.float32),
            "atari": Box(0.0, 1.0, (210, 160, 3), dtype=np.float32),
        }
        # Drop the flag so unittest.main() does not reject the unknown
        # command line option.
        sys.argv.pop(1)
    unittest.main(verbosity=2)
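For quick, targeted experimentation outside of unittest, the `check_support` helper above can also be driven directly. A minimal sketch (reusing the helper and the same PG config as `test_pg`; not part of the file itself):

    # Exercise a single algorithm against the full space grid.
    ray.init(num_cpus=4)
    stats = {}
    check_support(
        "PG", {"num_workers": 1, "optimizer": {}}, stats, check_bounds=True)
    print(stats)
    ray.shutdown()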