2020-05-27 16:19:13 +02:00
|
|
|
import unittest
|
|
|
|
|
|
|
|
import ray
|
2021-02-08 12:05:16 +01:00
|
|
|
from ray.rllib.agents.registry import get_trainer_class
|
2020-05-27 16:19:13 +02:00
|
|
|
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole, MultiAgentMountainCar
|
2021-10-07 23:57:53 +02:00
|
|
|
from ray.rllib.policy.policy import PolicySpec
|
2021-09-30 16:39:05 +02:00
|
|
|
from ray.rllib.utils.test_utils import check_train_results, framework_iterator
|
2020-05-27 16:19:13 +02:00
|
|
|
from ray.tune import register_env
|
|
|
|
|
|
|
|
|
|
|
|
def check_support_multiagent(alg, config):
    """Train `alg` for one iteration on a 2-agent env under each framework.

    Registers the two multi-agent test environments, injects a two-policy
    multi-agent setup into `config`, then, for every framework yielded by
    `framework_iterator`, builds the trainer, runs a single `train()` call,
    validates the results with `check_train_results` and stops the trainer.

    Args:
        alg: Name of the algorithm, resolved via `get_trainer_class`.
        config: Trainer config dict. NOTE: it is modified in place (the
            "multiagent" key is set/overwritten here).
    """
    register_env(
        "multi_agent_mountaincar", lambda _: MultiAgentMountainCar({"num_agents": 2})
    )
    register_env(
        "multi_agent_cartpole", lambda _: MultiAgentCartPole({"num_agents": 2})
    )

    # Simulate a simple multi-agent setup.
    policies = {
        "policy_0": PolicySpec(config={"gamma": 0.99}),
        "policy_1": PolicySpec(config={"gamma": 0.95}),
    }
    policy_ids = list(policies.keys())

    def policy_mapping_fn(agent_id, episode, worker, **kwargs):
        # `agent_id` is used directly as a list index.
        # NOTE(review): assumes the envs use integer agent ids 0 and 1 - confirm.
        pol_id = policy_ids[agent_id]
        return pol_id

    config["multiagent"] = {
        "policies": policies,
        "policy_mapping_fn": policy_mapping_fn,
    }

    for fw in framework_iterator(config):
        # These algorithm/framework combinations are skipped.
        # NOTE(review): presumably not supported in eager mode - confirm.
        if fw in ["tf2", "tfe"] and alg in ["A3C", "APEX", "APEX_DDPG", "IMPALA"]:
            continue
        # DDPG, APEX_DDPG and SAC run on the mountain-car env, all other
        # algorithms on the cart-pole env.
        if alg in ["DDPG", "APEX_DDPG", "SAC"]:
            a = get_trainer_class(alg)(config=config, env="multi_agent_mountaincar")
        else:
            a = get_trainer_class(alg)(config=config, env="multi_agent_cartpole")

        # Always stop the trainer, even if training or the result check fails,
        # so a failing case does not leave its resources allocated for the
        # remaining frameworks/tests sharing the same Ray instance.
        try:
            results = a.train()
            check_train_results(results)
            print(results)
        finally:
            a.stop()
|
2020-05-27 16:19:13 +02:00
|
|
|
|
|
|
|
|
2020-07-02 13:06:34 +02:00
|
|
|
class TestSupportedMultiAgentPG(unittest.TestCase):
    """Multi-agent smoke tests for A3C, IMPALA, PG and PPO."""

    @classmethod
    def setUpClass(cls) -> None:
        """Initialize Ray (num_cpus=4) once for all tests of this class."""
        ray.init(num_cpus=4)

    @classmethod
    def tearDownClass(cls) -> None:
        """Shut Ray down after the last test of this class."""
        ray.shutdown()

    def test_a3c_multiagent(self):
        """Run the multi-agent check for A3C."""
        a3c_config = {"num_workers": 1, "optimizer": {"grads_per_step": 1}}
        check_support_multiagent("A3C", a3c_config)

    def test_impala_multiagent(self):
        """Run the multi-agent check for IMPALA."""
        impala_config = {"num_gpus": 0}
        check_support_multiagent("IMPALA", impala_config)

    def test_pg_multiagent(self):
        """Run the multi-agent check for PG."""
        pg_config = {"num_workers": 1, "optimizer": {}}
        check_support_multiagent("PG", pg_config)

    def test_ppo_multiagent(self):
        """Run the multi-agent check for PPO with very small batch settings."""
        ppo_config = {
            "num_workers": 1,
            "num_sgd_iter": 1,
            "train_batch_size": 10,
            "rollout_fragment_length": 10,
            "sgd_minibatch_size": 1,
        }
        check_support_multiagent("PPO", ppo_config)
|
|
|
|
|
|
|
|
|
|
|
|
class TestSupportedMultiAgentOffPolicy(unittest.TestCase):
    """Multi-agent smoke tests for APEX, APEX_DDPG, DDPG, DQN and SAC."""

    @classmethod
    def setUpClass(cls) -> None:
        """Initialize Ray (num_cpus=6) once for all tests of this class."""
        ray.init(num_cpus=6)

    @classmethod
    def tearDownClass(cls) -> None:
        """Shut Ray down after the last test of this class."""
        ray.shutdown()

    def test_apex_multiagent(self):
        """Run the multi-agent check for APEX."""
        apex_config = {
            "num_workers": 2,
            "min_sample_timesteps_per_iteration": 100,
            "num_gpus": 0,
            "replay_buffer_config": {
                "capacity": 1000,
                "learning_starts": 10,
            },
            "min_time_s_per_iteration": 1,
            "target_network_update_freq": 100,
            "optimizer": {
                "num_replay_buffer_shards": 1,
            },
        }
        check_support_multiagent("APEX", apex_config)

    def test_apex_ddpg_multiagent(self):
        """Run the multi-agent check for APEX_DDPG."""
        apex_ddpg_config = {
            "num_workers": 2,
            "min_sample_timesteps_per_iteration": 100,
            "replay_buffer_config": {
                "capacity": 1000,
                "learning_starts": 10,
            },
            "num_gpus": 0,
            "min_time_s_per_iteration": 1,
            "target_network_update_freq": 100,
            "use_state_preprocessor": True,
        }
        check_support_multiagent("APEX_DDPG", apex_ddpg_config)

    def test_ddpg_multiagent(self):
        """Run the multi-agent check for DDPG."""
        ddpg_config = {
            "min_sample_timesteps_per_iteration": 1,
            "replay_buffer_config": {
                "capacity": 1000,
                "learning_starts": 500,
            },
            "use_state_preprocessor": True,
        }
        check_support_multiagent("DDPG", ddpg_config)

    def test_dqn_multiagent(self):
        """Run the multi-agent check for DQN."""
        dqn_config = {
            "min_sample_timesteps_per_iteration": 1,
            "replay_buffer_config": {
                "capacity": 1000,
            },
        }
        check_support_multiagent("DQN", dqn_config)

    def test_sac_multiagent(self):
        """Run the multi-agent check for SAC."""
        sac_config = {
            "num_workers": 0,
            "replay_buffer_config": {
                "capacity": 1000,
            },
            "normalize_actions": False,
        }
        check_support_multiagent("SAC", sac_config)
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    import pytest
    import sys

    # An optional first command-line argument names the single TestCase
    # class to run; without it, all unittest.TestCase classes in this file
    # are collected.
    target = __file__
    if len(sys.argv) > 1:
        target = target + "::" + sys.argv[1]

    sys.exit(pytest.main(["-v", target]))
|