from gym.spaces import Box, Dict, Discrete, Tuple, MultiDiscrete
import numpy as np
import unittest
import traceback

import ray
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.agents.registry import get_agent_class
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole, \
    MultiAgentMountainCar
from ray.rllib.examples.env.random_env import RandomEnv
from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork as FCNetV2
from ray.rllib.models.tf.visionnet_v2 import VisionNetwork as VisionNetV2
from ray.rllib.models.torch.visionnet import VisionNetwork as TorchVisionNetV2
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFCNetV2
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.tune.registry import register_env

tf = try_import_tf()

# Action spaces that check_support() pairs with every observation space.
ACTION_SPACES_TO_TEST = {
    "discrete": Discrete(5),
    "vector": Box(-1.0, 1.0, (5, ), dtype=np.float32),
    "vector2": Box(-1.0, 1.0, (5, 5), dtype=np.float32),
    "multidiscrete": MultiDiscrete([1, 2, 3, 4]),
    "tuple": Tuple(
        [Discrete(2),
         Discrete(3),
         Box(-1.0, 1.0, (5, ), dtype=np.float32)]),
    "dict": Dict({
        "action_choice": Discrete(3),
        "parameters": Box(-1.0, 1.0, (1, ), dtype=np.float32),
        "yet_another_nested_dict": Dict({
            "a": Tuple([Discrete(2), Discrete(3)])
        })
    }),
}

# Observation spaces that check_support() pairs with every action space.
OBSERVATION_SPACES_TO_TEST = {
    "discrete": Discrete(5),
    "vector": Box(-1.0, 1.0, (5, ), dtype=np.float32),
    "vector2": Box(-1.0, 1.0, (5, 5), dtype=np.float32),
    "image": Box(-1.0, 1.0, (84, 84, 1), dtype=np.float32),
    "atari": Box(-1.0, 1.0, (210, 160, 3), dtype=np.float32),
    "tuple": Tuple([Discrete(10),
                    Box(-1.0, 1.0, (5, ), dtype=np.float32)]),
    "dict": Dict({
        "task": Discrete(10),
        "position": Box(-1.0, 1.0, (5, ), dtype=np.float32),
    }),
}


def _expected_default_model(o_name, use_torch):
    """Return the model class check_support() asserts for ``o_name``.

    Args:
        o_name: Key into OBSERVATION_SPACES_TO_TEST.
        use_torch: Whether the torch variants of the models are expected.

    Returns:
        A vision network class for "atari"/"image", a fully connected
        network class for "vector"/"vector2", or None when no model class
        is asserted for this observation space.
    """
    if o_name in ["atari", "image"]:
        return TorchVisionNetV2 if use_torch else VisionNetV2
    if o_name in ["vector", "vector2"]:
        return TorchFCNetV2 if use_torch else FCNetV2
    return None


def check_support(alg, config, stats, check_bounds=False, name=None):
    """Build and train ``alg`` once per action/observation space pairing.

    For each pairing a trainer is created on RandomEnv and ``train()`` is
    called once. The outcome is recorded in ``stats`` as one of "ok",
    "skip", "unsupported" or "ERROR". All pairings are attempted even when
    some fail.

    Args:
        alg: Algorithm name handed to ``get_agent_class``.
        config: Trainer config dict. It is modified in place: "log_level"
            is set to "ERROR" and "env_config" is overwritten for every
            pairing.
        stats: Dict that receives the outcomes, keyed by
            ``(name or alg, action_space_name, observation_space_name)``.
        check_bounds: Forwarded to the env as ``check_action_bounds``.
        name: Optional label used instead of ``alg`` in the ``stats`` key.

    Raises:
        Exception: The first exception (other than
            UnsupportedSpaceException) caught while building or training,
            re-raised after every pairing has been attempted.
    """
    covered_a = set()
    covered_o = set()
    config["log_level"] = "ERROR"
    first_error = None
    use_torch = config.get("use_pytorch", False)

    for a_name, action_space in ACTION_SPACES_TO_TEST.items():
        for o_name, obs_space in OBSERVATION_SPACES_TO_TEST.items():
            print("=== Testing {} (torch={}) A={} S={} ===".format(
                alg, use_torch, action_space, obs_space))
            config.update(
                dict(
                    env_config=dict(
                        action_space=action_space,
                        observation_space=obs_space,
                        reward_space=Box(
                            1.0, 1.0, shape=(), dtype=np.float32),
                        p_done=1.0,
                        check_action_bounds=check_bounds)))
            stat = "ok"
            a = None
            try:
                if a_name in covered_a and o_name in covered_o:
                    # Speed up tests by avoiding the full grid: both spaces
                    # already took part in a successful run.
                    stat = "skip"
                else:
                    a = get_agent_class(alg)(config=config, env=RandomEnv)
                    # These algorithms are exempt from the model-class
                    # check; presumably they do not use the default
                    # ModelV2 networks -- TODO confirm.
                    if alg not in ["DDPG", "ES", "ARS", "SAC"]:
                        expected = _expected_default_model(o_name, use_torch)
                        if expected is not None:
                            assert isinstance(a.get_policy().model, expected)
                    a.train()
                    # Only mark the spaces as covered once training worked.
                    covered_a.add(a_name)
                    covered_o.add(o_name)
            except UnsupportedSpaceException:
                stat = "unsupported"
            except Exception as e:
                stat = "ERROR"
                print(e)
                print(traceback.format_exc())
                # Remember only the first failure; keep testing the rest.
                if first_error is None:
                    first_error = e
            finally:
                if a is not None:
                    # Best effort: a failing stop() must not hide the
                    # result of the pairing that was just tested.
                    try:
                        a.stop()
                    except Exception as e:
                        print("Ignoring error stopping agent", e)
            print(stat)
            print()
            stats[name or alg, a_name, o_name] = stat

    # If anything happened, raise error.
    if first_error is not None:
        raise first_error


def check_support_multiagent(alg, config):
    """Build ``alg`` on a two-agent env and run one training iteration.

    Algorithms whose name contains "DDPG" run on the registered
    "multi_agent_mountaincar" env, all others on "multi_agent_cartpole".

    Args:
        alg: Algorithm name handed to ``get_agent_class``.
        config: Trainer config dict; "log_level" is set to "ERROR" in place.
    """
    register_env("multi_agent_mountaincar",
                 lambda _: MultiAgentMountainCar({"num_agents": 2}))
    register_env("multi_agent_cartpole",
                 lambda _: MultiAgentCartPole({"num_agents": 2}))
    config["log_level"] = "ERROR"
    if "DDPG" in alg:
        env_name = "multi_agent_mountaincar"
    else:
        env_name = "multi_agent_cartpole"
    a = get_agent_class(alg)(config=config, env=env_name)
    try:
        a.train()
    finally:
        # Always release the trainer, even when train() raised.
        a.stop()


class ModelSupportedSpaces(unittest.TestCase):
    """Checks which algorithms handle which action/observation spaces.

    Single-agent tests go through check_support(), multi-agent tests
    through check_support_multiagent(). Ray is started before and shut
    down after every test.
    """

    # Shared across all tests of this class; check_support() writes the
    # per-pairing outcomes into it.
    stats = {}

    def setUp(self):
        ray.init(num_cpus=4, ignore_reinit_error=True)

    def tearDown(self):
        ray.shutdown()

    def test_a3c(self):
        config = {"num_workers": 1, "optimizer": {"grads_per_step": 1}}
        check_support("A3C", config, self.stats, check_bounds=True)
        # Second pass with the same config, this time on torch.
        config["use_pytorch"] = True
        check_support("A3C", config, self.stats, check_bounds=True)

    def test_appo(self):
        check_support("APPO", {"num_gpus": 0, "vtrace": False}, self.stats)
        # Recorded under a separate name so both variants show up in stats.
        check_support(
            "APPO", {
                "num_gpus": 0,
                "vtrace": True
            },
            self.stats,
            name="APPO-vt")

    def test_ars(self):
        check_support(
            "ARS", {
                "num_workers": 1,
                "noise_size": 10000000,
                "num_rollouts": 1,
                "rollouts_used": 1
            }, self.stats)

    def test_ddpg(self):
        check_support(
            "DDPG", {
                "exploration_config": {
                    "ou_base_scale": 100.0
                },
                "timesteps_per_iteration": 1,
                "use_state_preprocessor": True,
            },
            self.stats,
            check_bounds=True)

    def test_dqn(self):
        config = {"timesteps_per_iteration": 1}
        check_support("DQN", config, self.stats)
        # Second pass with the same config, this time on torch.
        config["use_pytorch"] = True
        check_support("DQN", config, self.stats)

    def test_es(self):
        check_support(
            "ES", {
                "num_workers": 1,
                "noise_size": 10000000,
                "episodes_per_batch": 1,
                "train_batch_size": 1
            }, self.stats)

    def test_impala(self):
        check_support("IMPALA", {"num_gpus": 0}, self.stats)

    def test_ppo(self):
        config = {
            "num_workers": 1,
            "num_sgd_iter": 1,
            "train_batch_size": 10,
            "rollout_fragment_length": 10,
            "sgd_minibatch_size": 1,
        }
        check_support("PPO", config, self.stats, check_bounds=True)
        # Second pass with the same config, this time on torch.
        config["use_pytorch"] = True
        check_support("PPO", config, self.stats, check_bounds=True)

    def test_pg(self):
        config = {"num_workers": 1, "optimizer": {}}
        check_support("PG", config, self.stats, check_bounds=True)
        # Second pass with the same config, this time on torch.
        config["use_pytorch"] = True
        check_support("PG", config, self.stats, check_bounds=True)

    def test_sac(self):
        check_support("SAC", {}, self.stats, check_bounds=True)

    def test_a3c_multiagent(self):
        check_support_multiagent("A3C", {
            "num_workers": 1,
            "optimizer": {
                "grads_per_step": 1
            }
        })

    def test_apex_multiagent(self):
        check_support_multiagent(
            "APEX", {
                "num_workers": 2,
                "timesteps_per_iteration": 1000,
                "num_gpus": 0,
                "min_iter_time_s": 1,
                "learning_starts": 1000,
                "target_network_update_freq": 100,
            })

    def test_apex_ddpg_multiagent(self):
        check_support_multiagent(
            "APEX_DDPG", {
                "num_workers": 2,
                "timesteps_per_iteration": 1000,
                "num_gpus": 0,
                "min_iter_time_s": 1,
                "learning_starts": 1000,
                "target_network_update_freq": 100,
                "use_state_preprocessor": True,
            })

    def test_ddpg_multiagent(self):
        check_support_multiagent("DDPG", {
            "timesteps_per_iteration": 1,
            "use_state_preprocessor": True,
            "learning_starts": 500,
        })

    def test_dqn_multiagent(self):
        check_support_multiagent("DQN", {"timesteps_per_iteration": 1})

    def test_impala_multiagent(self):
        check_support_multiagent("IMPALA", {"num_gpus": 0})

    def test_pg_multiagent(self):
        check_support_multiagent("PG", {"num_workers": 1, "optimizer": {}})

    def test_ppo_multiagent(self):
        check_support_multiagent(
            "PPO", {
                "num_workers": 1,
                "num_sgd_iter": 1,
                "train_batch_size": 10,
                "rollout_fragment_length": 10,
                "sgd_minibatch_size": 1,
            })


if __name__ == "__main__":
    import pytest
    import sys

    # "--smoke" shrinks the grid to a single action space and two
    # observation spaces.
    # NOTE(review): these names are rebound in the __main__ module only.
    # pytest.main() below imports this file again under its own module
    # name, so the collected tests may still see the full grid -- confirm
    # that the smoke mode actually takes effect.
    if len(sys.argv) > 1 and sys.argv[1] == "--smoke":
        ACTION_SPACES_TO_TEST = {
            "discrete": Discrete(5),
        }
        OBSERVATION_SPACES_TO_TEST = {
            "vector": Box(0.0, 1.0, (5, ), dtype=np.float32),
            "atari": Box(0.0, 1.0, (210, 160, 3), dtype=np.float32),
        }

    sys.exit(pytest.main(["-v", __file__]))