import unittest import ray from ray.rllib.agents.registry import get_trainer_class from ray.rllib.examples.env.multi_agent import MultiAgentCartPole, \ MultiAgentMountainCar from ray.rllib.utils.test_utils import framework_iterator from ray.tune import register_env def check_support_multiagent(alg, config): register_env("multi_agent_mountaincar", lambda _: MultiAgentMountainCar({"num_agents": 2})) register_env("multi_agent_cartpole", lambda _: MultiAgentCartPole({"num_agents": 2})) config["log_level"] = "ERROR" for fw in framework_iterator(config): if fw in ["tf2", "tfe"] and \ alg in ["A3C", "APEX", "APEX_DDPG", "IMPALA"]: continue if alg in ["DDPG", "APEX_DDPG", "SAC"]: a = get_trainer_class(alg)( config=config, env="multi_agent_mountaincar") else: a = get_trainer_class(alg)( config=config, env="multi_agent_cartpole") print(a.train()) a.stop() class TestSupportedMultiAgentPG(unittest.TestCase): @classmethod def setUpClass(cls) -> None: ray.init(num_cpus=4) @classmethod def tearDownClass(cls) -> None: ray.shutdown() def test_a3c_multiagent(self): check_support_multiagent("A3C", { "num_workers": 1, "optimizer": { "grads_per_step": 1 } }) def test_impala_multiagent(self): check_support_multiagent("IMPALA", {"num_gpus": 0}) def test_pg_multiagent(self): check_support_multiagent("PG", {"num_workers": 1, "optimizer": {}}) def test_ppo_multiagent(self): check_support_multiagent( "PPO", { "num_workers": 1, "num_sgd_iter": 1, "train_batch_size": 10, "rollout_fragment_length": 10, "sgd_minibatch_size": 1, }) class TestSupportedMultiAgentOffPolicy(unittest.TestCase): @classmethod def setUpClass(cls) -> None: ray.init(num_cpus=6) @classmethod def tearDownClass(cls) -> None: ray.shutdown() def test_apex_multiagent(self): check_support_multiagent( "APEX", { "num_workers": 2, "timesteps_per_iteration": 100, "num_gpus": 0, "buffer_size": 1000, "min_iter_time_s": 1, "learning_starts": 10, "target_network_update_freq": 100, "optimizer": { "num_replay_buffer_shards": 1, }, }) def test_apex_ddpg_multiagent(self): check_support_multiagent( "APEX_DDPG", { "num_workers": 2, "timesteps_per_iteration": 100, "buffer_size": 1000, "num_gpus": 0, "min_iter_time_s": 1, "learning_starts": 10, "target_network_update_freq": 100, "use_state_preprocessor": True, }) def test_ddpg_multiagent(self): check_support_multiagent( "DDPG", { "timesteps_per_iteration": 1, "buffer_size": 1000, "use_state_preprocessor": True, "learning_starts": 500, }) def test_dqn_multiagent(self): check_support_multiagent("DQN", { "timesteps_per_iteration": 1, "buffer_size": 1000, }) def test_sac_multiagent(self): check_support_multiagent("SAC", { "num_workers": 0, "buffer_size": 1000, "normalize_actions": False, }) if __name__ == "__main__": import pytest import sys # One can specify the specific TestCase class to run. # None for all unittest.TestCase classes in this file. class_ = sys.argv[1] if len(sys.argv) > 1 else None sys.exit( pytest.main( ["-v", __file__ + ("" if class_ is None else "::" + class_)]))