mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
126 lines
3.8 KiB
Python
126 lines
3.8 KiB
Python
import unittest
|
|
|
|
import ray
|
|
from ray.rllib.agents.registry import get_agent_class
|
|
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole, \
|
|
MultiAgentMountainCar
|
|
from ray.rllib.utils.test_utils import framework_iterator
|
|
from ray.tune import register_env
|
|
|
|
|
|
def check_support_multiagent(alg, config):
|
|
register_env("multi_agent_mountaincar",
|
|
lambda _: MultiAgentMountainCar({"num_agents": 2}))
|
|
register_env("multi_agent_cartpole",
|
|
lambda _: MultiAgentCartPole({"num_agents": 2}))
|
|
config["log_level"] = "ERROR"
|
|
for _ in framework_iterator(config, frameworks=("torch", "tf")):
|
|
if alg in ["DDPG", "APEX_DDPG", "SAC"]:
|
|
a = get_agent_class(alg)(
|
|
config=config, env="multi_agent_mountaincar")
|
|
else:
|
|
a = get_agent_class(alg)(config=config, env="multi_agent_cartpole")
|
|
|
|
print(a.train())
|
|
a.stop()
|
|
|
|
|
|
class TestSupportedMultiAgentPG(unittest.TestCase):
|
|
@classmethod
|
|
def setUpClass(cls) -> None:
|
|
ray.init(num_cpus=4)
|
|
|
|
@classmethod
|
|
def tearDownClass(cls) -> None:
|
|
ray.shutdown()
|
|
|
|
def test_a3c_multiagent(self):
|
|
check_support_multiagent("A3C", {
|
|
"num_workers": 1,
|
|
"optimizer": {
|
|
"grads_per_step": 1
|
|
}
|
|
})
|
|
|
|
def test_impala_multiagent(self):
|
|
check_support_multiagent("IMPALA", {"num_gpus": 0})
|
|
|
|
def test_pg_multiagent(self):
|
|
check_support_multiagent("PG", {"num_workers": 1, "optimizer": {}})
|
|
|
|
def test_ppo_multiagent(self):
|
|
check_support_multiagent(
|
|
"PPO", {
|
|
"num_workers": 1,
|
|
"num_sgd_iter": 1,
|
|
"train_batch_size": 10,
|
|
"rollout_fragment_length": 10,
|
|
"sgd_minibatch_size": 1,
|
|
})
|
|
|
|
|
|
class TestSupportedMultiAgentOffPolicy(unittest.TestCase):
|
|
@classmethod
|
|
def setUpClass(cls) -> None:
|
|
ray.init(num_cpus=4)
|
|
|
|
@classmethod
|
|
def tearDownClass(cls) -> None:
|
|
ray.shutdown()
|
|
|
|
def test_apex_multiagent(self):
|
|
check_support_multiagent(
|
|
"APEX", {
|
|
"num_workers": 2,
|
|
"timesteps_per_iteration": 100,
|
|
"num_gpus": 0,
|
|
"buffer_size": 1000,
|
|
"min_iter_time_s": 1,
|
|
"learning_starts": 10,
|
|
"target_network_update_freq": 100,
|
|
})
|
|
|
|
def test_apex_ddpg_multiagent(self):
|
|
check_support_multiagent(
|
|
"APEX_DDPG", {
|
|
"num_workers": 2,
|
|
"timesteps_per_iteration": 100,
|
|
"buffer_size": 1000,
|
|
"num_gpus": 0,
|
|
"min_iter_time_s": 1,
|
|
"learning_starts": 10,
|
|
"target_network_update_freq": 100,
|
|
"use_state_preprocessor": True,
|
|
})
|
|
|
|
def test_ddpg_multiagent(self):
|
|
check_support_multiagent(
|
|
"DDPG", {
|
|
"timesteps_per_iteration": 1,
|
|
"buffer_size": 1000,
|
|
"use_state_preprocessor": True,
|
|
"learning_starts": 500,
|
|
})
|
|
|
|
def test_dqn_multiagent(self):
|
|
check_support_multiagent("DQN", {
|
|
"timesteps_per_iteration": 1,
|
|
"buffer_size": 1000,
|
|
})
|
|
|
|
def test_sac_multiagent(self):
|
|
check_support_multiagent("SAC", {
|
|
"num_workers": 0,
|
|
"buffer_size": 1000,
|
|
"normalize_actions": False,
|
|
})
|
|
|
|
|
|
if __name__ == "__main__":
|
|
import pytest
|
|
import sys
|
|
# One can specify the specific TestCase class to run.
|
|
# None for all unittest.TestCase classes in this file.
|
|
class_ = sys.argv[1] if len(sys.argv) > 0 else None
|
|
sys.exit(pytest.main(
|
|
["-v", __file__ + ("" if class_ is None else "::" + class_)]))
|