From e73c37cc175f79be95e1bf2e05d28d073190089c Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Tue, 24 May 2022 12:53:53 +0200 Subject: [PATCH] [RLlib] MADDPG: Move into main `algorithms` folder and add proper unit and learning tests. (#24579) --- doc/source/rllib/rllib-algorithms.rst | 4 +- rllib/BUILD | 30 ++++++---- rllib/README.rst | 2 +- rllib/agents/maddpg/__init__.py | 20 ++++++- rllib/{agents => algorithms}/maddpg/README.md | 0 rllib/algorithms/maddpg/__init__.py | 3 + rllib/{agents => algorithms}/maddpg/maddpg.py | 11 ++-- .../maddpg/maddpg_tf_policy.py | 6 +- rllib/algorithms/maddpg/tests/test_maddpg.py | 57 +++++++++++++++++++ rllib/contrib/registry.py | 2 +- .../maddpg/two-step-game-maddpg.yaml | 4 +- 11 files changed, 111 insertions(+), 28 deletions(-) rename rllib/{agents => algorithms}/maddpg/README.md (100%) create mode 100644 rllib/algorithms/maddpg/__init__.py rename rllib/{agents => algorithms}/maddpg/maddpg.py (95%) rename rllib/{agents => algorithms}/maddpg/maddpg_tf_policy.py (98%) create mode 100644 rllib/algorithms/maddpg/tests/test_maddpg.py diff --git a/doc/source/rllib/rllib-algorithms.rst b/doc/source/rllib/rllib-algorithms.rst index 89b18e4cf..69efcd391 100644 --- a/doc/source/rllib/rllib-algorithms.rst +++ b/doc/source/rllib/rllib-algorithms.rst @@ -791,13 +791,13 @@ Tuned examples: `Two-step game `__ `[implementation] `__ MADDPG is a DDPG centralized/shared critic algorithm. Code here is adapted from https://github.com/openai/maddpg to integrate with RLlib multi-agent APIs. Please check `justinkterry/maddpg-rllib `__ for examples and more information. Note that the implementation here is based on OpenAI's, and is intended for use with the discrete MPE environments. Please also note that people typically find this method difficult to get to work, even with all applicable optimizations for their environment applied. This method should be viewed as for research purposes, and for reproducing the results of the paper introducing it. +`[paper] `__ `[implementation] `__ MADDPG is a DDPG centralized/shared critic algorithm. Code here is adapted from https://github.com/openai/maddpg to integrate with RLlib multi-agent APIs. Please check `justinkterry/maddpg-rllib `__ for examples and more information. Note that the implementation here is based on OpenAI's, and is intended for use with the discrete MPE environments. Please also note that people typically find this method difficult to get to work, even with all applicable optimizations for their environment applied. This method should be viewed as for research purposes, and for reproducing the results of the paper introducing it. **MADDPG-specific configs** (see also `common configs `__): Tuned examples: `Multi-Agent Particle Environment `__, `Two-step game `__ -.. literalinclude:: ../../../rllib/agents/maddpg/maddpg.py +.. literalinclude:: ../../../rllib/algorithms/maddpg/maddpg.py :language: python :start-after: __sphinx_doc_begin__ :end-before: __sphinx_doc_end__ diff --git a/rllib/BUILD b/rllib/BUILD index 2ebf81e90..11fac54a4 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -379,6 +379,17 @@ py_test( args = ["--yaml-dir=tuned_examples/impala"] ) +# MADDPG +py_test( + name = "learning_tests_two_step_game_maddpg", + main = "tests/run_regression_tests.py", + tags = ["team:ml", "tf_only", "no_tf_eager_tracing", "learning_tests", "learning_tests_discrete"], + size = "large", + srcs = ["tests/run_regression_tests.py"], + data = ["tuned_examples/maddpg/two-step-game-maddpg.yaml"], + args = ["--yaml-dir=tuned_examples/maddpg", "--framework=tf"] +) + # Working, but takes a long time to learn (>15min). # Removed due to Higher API conflicts with Pytorch-Import tests ## MB-MPO @@ -729,7 +740,7 @@ py_test( py_test( name = "test_dreamer", tags = ["team:ml", "trainers_dir"], - size = "small", + size = "medium", srcs = ["algorithms/dreamer/tests/test_dreamer.py"] ) @@ -775,6 +786,14 @@ py_test( srcs = ["algorithms/marwil/tests/test_bc.py"] ) +# MADDPGTrainer +py_test( + name = "test_maddpg", + tags = ["team:ml", "trainers_dir"], + size = "medium", + srcs = ["algorithms/maddpg/tests/test_maddpg.py"] +) + # MAMLTrainer py_test( name = "test_maml", @@ -2951,15 +2970,6 @@ py_test( args = ["--as-test", "--mixed-torch-tf", "--stop-reward=450.0"] ) -py_test( - name = "examples/two_step_game_maddpg", - main = "examples/two_step_game.py", - tags = ["team:ml", "examples", "examples_T"], - size = "medium", - srcs = ["examples/two_step_game.py"], - args = ["--as-test", "--stop-reward=7.1", "--run=MADDPG"] -) - py_test( name = "examples/two_step_game_pg_tf", main = "examples/two_step_game.py", diff --git a/rllib/README.rst b/rllib/README.rst index 439c4cb37..32acd1867 100644 --- a/rllib/README.rst +++ b/rllib/README.rst @@ -105,7 +105,7 @@ Multi-agent: - `Single-Player Alpha Zero (contrib/AlphaZero) `__ - `Parameter Sharing `__ - `QMIX Monotonic Value Factorisation (QMIX, VDN, IQN)) `__ -- `Multi-Agent Deep Deterministic Policy Gradient (contrib/MADDPG) `__ +- `Multi-Agent Deep Deterministic Policy Gradient (MADDPG) `__ - `Shared Critic Methods `__ Others: diff --git a/rllib/agents/maddpg/__init__.py b/rllib/agents/maddpg/__init__.py index 792350b8c..84d50c49b 100644 --- a/rllib/agents/maddpg/__init__.py +++ b/rllib/agents/maddpg/__init__.py @@ -1,3 +1,19 @@ -from ray.rllib.agents.maddpg.maddpg import MADDPGTrainer, DEFAULT_CONFIG +from ray.rllib.algorithms.maddpg.maddpg import ( + MADDPGTrainer, + MADDPGTFPolicy, + DEFAULT_CONFIG, +) -__all__ = ["MADDPGTrainer", "DEFAULT_CONFIG"] +__all__ = [ + "MADDPGTrainer", + "MADDPGTFPolicy", + "DEFAULT_CONFIG", +] + +from ray.rllib.utils.deprecation import deprecation_warning + +deprecation_warning( + "ray.rllib.agents.maddpg", + "ray.rllib.algorithms.maddpg", + error=False, +) diff --git a/rllib/agents/maddpg/README.md b/rllib/algorithms/maddpg/README.md similarity index 100% rename from rllib/agents/maddpg/README.md rename to rllib/algorithms/maddpg/README.md diff --git a/rllib/algorithms/maddpg/__init__.py b/rllib/algorithms/maddpg/__init__.py new file mode 100644 index 000000000..2ae788f1e --- /dev/null +++ b/rllib/algorithms/maddpg/__init__.py @@ -0,0 +1,3 @@ +from ray.rllib.algorithms.maddpg.maddpg import MADDPGTrainer, DEFAULT_CONFIG + +__all__ = ["MADDPGTrainer", "DEFAULT_CONFIG"] diff --git a/rllib/agents/maddpg/maddpg.py b/rllib/algorithms/maddpg/maddpg.py similarity index 95% rename from rllib/agents/maddpg/maddpg.py rename to rllib/algorithms/maddpg/maddpg.py index 85b186a8e..e63321586 100644 --- a/rllib/agents/maddpg/maddpg.py +++ b/rllib/algorithms/maddpg/maddpg.py @@ -12,12 +12,11 @@ and the README for how to run with the multi-agent particle envs. import logging from typing import Type -from ray.rllib.agents.maddpg.maddpg_tf_policy import MADDPGTFPolicy from ray.rllib.algorithms.dqn.dqn import DQNTrainer -from ray.rllib.agents.trainer import COMMON_CONFIG, with_common_config +from ray.rllib.algorithms.maddpg.maddpg_tf_policy import MADDPGTFPolicy +from ray.rllib.agents.trainer import with_common_config from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch -from ray.rllib.utils import merge_dicts from ray.rllib.utils.annotations import override from ray.rllib.utils.typing import TrainerConfigDict from ray.rllib.utils.deprecation import DEPRECATED_VALUE @@ -77,6 +76,8 @@ DEFAULT_CONFIG = with_common_config({ "capacity": int(1e6), # How many steps of the model to sample before learning starts. "learning_starts": 1024 * 25, + # Force lockstep replay mode for MADDPG. + "replay_mode": "lockstep", }, # Observation compression. Note that compression makes simulation slow in # MPE. @@ -86,10 +87,6 @@ DEFAULT_CONFIG = with_common_config({ # timesteps. Otherwise, the replay will proceed at the native ratio # determined by (train_batch_size / rollout_fragment_length). "training_intensity": None, - # Force lockstep replay mode for MADDPG. - "multiagent": merge_dicts(COMMON_CONFIG["multiagent"], { - "replay_mode": "lockstep", - }), # === Optimization === # Learning rate for the critic (Q-function) optimizer. diff --git a/rllib/agents/maddpg/maddpg_tf_policy.py b/rllib/algorithms/maddpg/maddpg_tf_policy.py similarity index 98% rename from rllib/agents/maddpg/maddpg_tf_policy.py rename to rllib/algorithms/maddpg/maddpg_tf_policy.py index 8bf5f93a8..6b02fc09d 100644 --- a/rllib/agents/maddpg/maddpg_tf_policy.py +++ b/rllib/algorithms/maddpg/maddpg_tf_policy.py @@ -43,7 +43,7 @@ class MADDPGPostprocessing: class MADDPGTFPolicy(MADDPGPostprocessing, TFPolicy): def __init__(self, obs_space, act_space, config): # _____ Initial Configuration - config = dict(ray.rllib.agents.maddpg.DEFAULT_CONFIG, **config) + config = dict(ray.rllib.algorithms.maddpg.maddpg.DEFAULT_CONFIG, **config) self.config = config self.global_step = tf1.train.get_or_create_global_step() @@ -69,11 +69,11 @@ class MADDPGTFPolicy(MADDPGPostprocessing, TFPolicy): ) obs_space_n = [ - _make_continuous_space(space) + _make_continuous_space(space or obs_space) for _, (_, space, _, _) in config["multiagent"]["policies"].items() ] act_space_n = [ - _make_continuous_space(space) + _make_continuous_space(space or act_space) for _, (_, _, space, _) in config["multiagent"]["policies"].items() ] diff --git a/rllib/algorithms/maddpg/tests/test_maddpg.py b/rllib/algorithms/maddpg/tests/test_maddpg.py new file mode 100644 index 000000000..c6181f782 --- /dev/null +++ b/rllib/algorithms/maddpg/tests/test_maddpg.py @@ -0,0 +1,57 @@ +import unittest + +import ray +import ray.rllib.algorithms.maddpg as maddpg +from ray.rllib.examples.env.two_step_game import TwoStepGame +from ray.rllib.policy.policy import PolicySpec +from ray.rllib.utils.test_utils import ( + check_train_results, + framework_iterator, +) + + +class TestMADDPG(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + ray.init() + + @classmethod + def tearDownClass(cls) -> None: + ray.shutdown() + + def test_maddpg_compilation(self): + """Test whether an MADDPGTrainer can be built with all frameworks.""" + config = maddpg.DEFAULT_CONFIG.copy() + config["env"] = TwoStepGame + config["env_config"] = { + "actions_are_logits": True, + } + config["multiagent"] = { + "policies": { + "pol1": PolicySpec( + config={"agent_id": 0}, + ), + "pol2": PolicySpec( + config={"agent_id": 1}, + ), + }, + "policy_mapping_fn": (lambda aid, **kwargs: "pol2" if aid else "pol1"), + } + + num_iterations = 1 + + # Only working for tf right now. + for _ in framework_iterator(config, frameworks="tf"): + trainer = maddpg.MADDPGTrainer(config) + for i in range(num_iterations): + results = trainer.train() + check_train_results(results) + print(results) + trainer.stop() + + +if __name__ == "__main__": + import pytest + import sys + + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/contrib/registry.py b/rllib/contrib/registry.py index 154dd36fb..9ff06adbb 100644 --- a/rllib/contrib/registry.py +++ b/rllib/contrib/registry.py @@ -17,7 +17,7 @@ def _import_alphazero(): def _import_maddpg(): - from ray.rllib.agents.maddpg import maddpg + from ray.rllib.algorithms.maddpg import maddpg return maddpg.MADDPGTrainer, maddpg.DEFAULT_CONFIG diff --git a/rllib/tuned_examples/maddpg/two-step-game-maddpg.yaml b/rllib/tuned_examples/maddpg/two-step-game-maddpg.yaml index 6cc9f7a15..9cee89a5f 100644 --- a/rllib/tuned_examples/maddpg/two-step-game-maddpg.yaml +++ b/rllib/tuned_examples/maddpg/two-step-game-maddpg.yaml @@ -1,8 +1,8 @@ -two-step-game-qmix-with-qmix-mixer: +two-step-game-maddpg: env: ray.rllib.examples.env.two_step_game.TwoStepGame run: MADDPG stop: - episode_reward_mean: 8.0 + episode_reward_mean: 7.2 timesteps_total: 20000 config: # MADDPG only supports tf for now.