[RLlib] Retry agents -> algorithms. with proper doc changes this time. (#24797)
commit 68a9a33386 · parent d40fa391a5
59 changed files with 430 additions and 248 deletions
@@ -466,7 +466,7 @@ HalfCheetah 13000 ~15000
 Model-Agnostic Meta-Learning (MAML)
 -----------------------------------
 |pytorch| |tensorflow|
-`[paper] <https://arxiv.org/abs/1703.03400>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/maml/maml.py>`__
+`[paper] <https://arxiv.org/abs/1703.03400>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/algorithms/maml/maml.py>`__

 RLlib's MAML implementation is a meta-learning method for learning and quick adaptation across different tasks for continuous control. Code here is adapted from https://github.com/jonasrothfuss, which outperforms vanilla MAML and avoids computation of the higher order gradients during the meta-update step. MAML is evaluated on custom environments that are described in greater detail `here <https://github.com/ray-project/ray/blob/master/rllib/env/apis/task_settable_env.py>`__.

@@ -476,7 +476,7 @@ Tuned examples: HalfCheetahRandDirecEnv (`Env <https://github.com/ray-project/ra

 **MAML-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):

-.. literalinclude:: ../../../rllib/agents/maml/maml.py
+.. literalinclude:: ../../../rllib/algorithms/maml/maml.py
     :language: python
     :start-after: __sphinx_doc_begin__
     :end-before: __sphinx_doc_end__

@@ -486,7 +486,7 @@ Tuned examples: HalfCheetahRandDirecEnv (`Env <https://github.com/ray-project/ra
 Model-Based Meta-Policy-Optimization (MB-MPO)
 ---------------------------------------------
 |pytorch|
-`[paper] <https://arxiv.org/pdf/1809.05214.pdf>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/mbmpo/mbmpo.py>`__
+`[paper] <https://arxiv.org/pdf/1809.05214.pdf>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/algorithms/mbmpo/mbmpo.py>`__

 RLlib's MBMPO implementation is a Dyna-styled model-based RL method that learns based on the predictions of an ensemble of transition-dynamics models. Similar to MAML, MBMPO meta-learns an optimal policy by treating each dynamics model as a different task. Code here is adapted from https://github.com/jonasrothfuss/model_ensemble_meta_learning. Similar to the original paper, MBMPO is evaluated on MuJoCo, with the horizon set to 200 instead of the default 1000.

@@ -510,7 +510,7 @@ Hopper 620 ~650

 **MBMPO-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):

-.. literalinclude:: ../../../rllib/agents/mbmpo/mbmpo.py
+.. literalinclude:: ../../../rllib/algorithms/mbmpo/mbmpo.py
     :language: python
     :start-after: __sphinx_doc_begin__
     :end-before: __sphinx_doc_end__

@@ -539,7 +539,7 @@ Cheetah-Run 640 ~800

 **Dreamer-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):

-.. literalinclude:: ../../../rllib/agents/dreamer/dreamer.py
+.. literalinclude:: ../../../rllib/algorithms/dreamer/dreamer.py
     :language: python
     :start-after: __sphinx_doc_begin__
     :end-before: __sphinx_doc_end__

@@ -567,7 +567,7 @@ RecSim environment wrapper: `Google RecSim <https://github.com/ray-project/ray/b
 Conservative Q-Learning (CQL)
 -----------------------------------
 |pytorch| |tensorflow|
-`[paper] <https://arxiv.org/abs/2006.04779>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/cql/cql.py>`__
+`[paper] <https://arxiv.org/abs/2006.04779>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/algorithms/cql/cql.py>`__

 In offline RL, the algorithm has no access to an environment, but can only sample from a fixed dataset of pre-collected state-action-reward tuples.
 In particular, CQL (Conservative Q-Learning) is an offline RL algorithm that mitigates the overestimation of Q-values outside the dataset distribution via

@@ -581,7 +581,7 @@ Tuned examples: `HalfCheetah Random <https://github.com/ray-project/ray/blob/mas

 **CQL-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):

-.. literalinclude:: ../../../rllib/agents/cql/cql.py
+.. literalinclude:: ../../../rllib/algorithms/cql/cql.py
     :language: python
     :start-after: __sphinx_doc_begin__
     :end-before: __sphinx_doc_end__
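To make the offline setting above concrete, a minimal sketch of training CQL from a pre-collected dataset under the relocated module path; the dataset path, environment, and framework choice here are illustrative placeholders, not values taken from this commit:

import ray.rllib.algorithms.cql as cql

config = cql.CQL_DEFAULT_CONFIG.copy()
config["framework"] = "torch"                    # placeholder choice; TF is also supported per the docs above
config["env"] = "Pendulum-v1"                    # placeholder env; in the offline setting it mainly supplies the spaces
config["input"] = "/path/to/offline_data.json"   # hypothetical path to a pre-collected offline dataset

trainer = cql.CQLTrainer(config=config)
print(trainer.train())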
@@ -636,7 +636,7 @@ Monotonic Advantage Re-Weighted Imitation Learning (MARWIL)
 -----------------------------------------------------------
 |pytorch| |tensorflow|
 `[paper] <http://papers.nips.cc/paper/7866-exponentially-weighted-imitation-learning-for-batched-historical-data>`__
-`[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/marwil/marwil.py>`__
+`[implementation] <https://github.com/ray-project/ray/blob/master/rllib/algorithms/marwil/marwil.py>`__

 MARWIL is a hybrid imitation learning and policy gradient algorithm suitable for training on batched historical data.
 When the ``beta`` hyperparameter is set to zero, the MARWIL objective reduces to vanilla imitation learning (see `BC`_).

@@ -646,7 +646,7 @@ Tuned examples: `CartPole-v0 <https://github.com/ray-project/ray/blob/master/rll

 **MARWIL-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):

-.. literalinclude:: ../../../rllib/agents/marwil/marwil.py
+.. literalinclude:: ../../../rllib/algorithms/marwil/marwil.py
     :language: python
     :start-after: __sphinx_doc_begin__
     :end-before: __sphinx_doc_end__

@@ -658,7 +658,7 @@ Behavior Cloning (BC; derived from MARWIL implementation)
 ---------------------------------------------------------
 |pytorch| |tensorflow|
 `[paper] <http://papers.nips.cc/paper/7866-exponentially-weighted-imitation-learning-for-batched-historical-data>`__
-`[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/marwil/bc.py>`__
+`[implementation] <https://github.com/ray-project/ray/blob/master/rllib/algorithms/marwil/bc.py>`__

 Our behavioral cloning implementation is directly derived from our `MARWIL`_ implementation,
 with the only difference being the ``beta`` parameter force-set to 0.0. This makes

@@ -669,7 +669,7 @@ Tuned examples: `CartPole-v0 <https://github.com/ray-project/ray/blob/master/rll

 **BC-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):

-.. literalinclude:: ../../../rllib/agents/marwil/bc.py
+.. literalinclude:: ../../../rllib/algorithms/marwil/bc.py
     :language: python
     :start-after: __sphinx_doc_begin__
     :end-before: __sphinx_doc_end__
rllib/BUILD (14 changed lines)

@@ -647,7 +647,7 @@ py_test(
     name = "test_alpha_star",
     tags = ["team:ml", "trainers_dir"],
     size = "large",
-    srcs = ["agents/alpha_star/tests/test_alpha_star.py"]
+    srcs = ["algorithms/alpha_star/tests/test_alpha_star.py"]
 )

 # APEXTrainer (DQN)

@@ -687,7 +687,7 @@ py_test(
     name = "test_cql",
     tags = ["team:ml", "trainers_dir"],
     size = "medium",
-    srcs = ["agents/cql/tests/test_cql.py"]
+    srcs = ["algorithms/cql/tests/test_cql.py"]
 )

 # DDPGTrainer

@@ -711,7 +711,7 @@ py_test(
     name = "test_dreamer",
     tags = ["team:ml", "trainers_dir"],
     size = "small",
-    srcs = ["agents/dreamer/tests/test_dreamer.py"]
+    srcs = ["algorithms/dreamer/tests/test_dreamer.py"]
 )

 # ES

@@ -743,7 +743,7 @@ py_test(
     size = "large",
     # Include the json data file.
     data = ["tests/data/cartpole/large.json"],
-    srcs = ["agents/marwil/tests/test_marwil.py"]
+    srcs = ["algorithms/marwil/tests/test_marwil.py"]
 )

 # BCTrainer (sub-type of MARWIL)

@@ -753,7 +753,7 @@ py_test(
     size = "large",
     # Include the json data file.
     data = ["tests/data/cartpole/large.json"],
-    srcs = ["agents/marwil/tests/test_bc.py"]
+    srcs = ["algorithms/marwil/tests/test_bc.py"]
 )

 # MAMLTrainer

@@ -761,7 +761,7 @@ py_test(
     name = "test_maml",
     tags = ["team:ml", "trainers_dir"],
     size = "medium",
-    srcs = ["agents/maml/tests/test_maml.py"]
+    srcs = ["algorithms/maml/tests/test_maml.py"]
 )

 # MBMPOTrainer

@@ -769,7 +769,7 @@ py_test(
     name = "test_mbmpo",
     tags = ["team:ml", "trainers_dir"],
     size = "medium",
-    srcs = ["agents/mbmpo/tests/test_mbmpo.py"]
+    srcs = ["algorithms/mbmpo/tests/test_mbmpo.py"]
 )

 # PGTrainer
rllib/agents/alpha_star/__init__.py

@@ -1,4 +1,4 @@
-from ray.rllib.agents.alpha_star.alpha_star import (
+from ray.rllib.algorithms.alpha_star.alpha_star import (
     AlphaStarConfig,
     AlphaStarTrainer,
     DEFAULT_CONFIG,

@@ -9,3 +9,10 @@ __all__ = [
     "AlphaStarTrainer",
     "DEFAULT_CONFIG",
 ]
+
+
+from ray.rllib.utils.deprecation import deprecation_warning
+
+deprecation_warning(
+    "ray.rllib.agents.alpha_star", "ray.rllib.algorithms.alpha_star", error=False
+)
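The shim above keeps the old import path importable while steering users toward the new one. A minimal sketch of the migration, assuming a Ray build that includes this commit:

# Old path: still resolves through the rllib/agents/... shim, but the module-level
# deprecation_warning(..., error=False) call now fires at import time.
from ray.rllib.agents.alpha_star import AlphaStarTrainer  # deprecated location

# New, preferred path after this commit (same class, just relocated):
from ray.rllib.algorithms.alpha_star import AlphaStarTrainer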
rllib/agents/cql/__init__.py

@@ -1,8 +1,13 @@
-from ray.rllib.agents.cql.cql import CQLTrainer, CQL_DEFAULT_CONFIG
-from ray.rllib.agents.cql.cql_torch_policy import CQLTorchPolicy
+from ray.rllib.algorithms.cql.cql import CQLTrainer, CQL_DEFAULT_CONFIG
+from ray.rllib.algorithms.cql.cql_tf_policy import CQLTFPolicy
+from ray.rllib.algorithms.cql.cql_torch_policy import CQLTorchPolicy
+from ray.rllib.utils.deprecation import deprecation_warning

 __all__ = [
     "CQL_DEFAULT_CONFIG",
+    "CQLTFPolicy",
     "CQLTorchPolicy",
     "CQLTrainer",
 ]
+
+deprecation_warning("ray.rllib.agents.cql", "ray.rllib.algorithms.cql", error=False)
rllib/agents/dqn/simple_q.py

@@ -100,9 +100,9 @@ class SimpleQConfig(TrainerConfig):
         >>> .exploration(exploration_config=explore_config)
     """

-    def __init__(self):
+    def __init__(self, trainer_class=None):
         """Initializes a SimpleQConfig instance."""
-        super().__init__(trainer_class=SimpleQTrainer)
+        super().__init__(trainer_class=trainer_class or SimpleQTrainer)

         # Simple Q specific
         # fmt: off
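The new optional trainer_class argument lets subclasses route their own Trainer type through SimpleQConfig's constructor; the QMix changes later in this commit rely on exactly that. A toy, self-contained sketch of the pattern (stand-in classes that mirror the RLlib names, not the real ones):

class TrainerConfig:
    def __init__(self, trainer_class=None):
        self.trainer_class = trainer_class

class SimpleQTrainer: ...
class QMixTrainer: ...

class SimpleQConfig(TrainerConfig):
    def __init__(self, trainer_class=None):
        # Fall back to SimpleQTrainer when no subclass overrides the trainer.
        super().__init__(trainer_class=trainer_class or SimpleQTrainer)

class QMixConfig(SimpleQConfig):
    def __init__(self):
        super().__init__(trainer_class=QMixTrainer)

assert SimpleQConfig().trainer_class is SimpleQTrainer
assert QMixConfig().trainer_class is QMixTrainer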
rllib/agents/dreamer/__init__.py

@@ -1,4 +1,4 @@
-from ray.rllib.agents.dreamer.dreamer import (
+from ray.rllib.algorithms.dreamer.dreamer import (
     DREAMERConfig,
     DREAMERTrainer,
     DEFAULT_CONFIG,

@@ -9,3 +9,10 @@ __all__ = [
     "DREAMERTrainer",
     "DEFAULT_CONFIG",
 ]
+
+
+from ray.rllib.utils.deprecation import deprecation_warning
+
+deprecation_warning(
+    "ray.rllib.agents.dreamer", "ray.rllib.algorithms.dreamer", error=False
+)
rllib/agents/maml/__init__.py

@@ -1,6 +1,10 @@
-from ray.rllib.agents.maml.maml import MAMLTrainer, DEFAULT_CONFIG
+from ray.rllib.algorithms.maml.maml import MAMLTrainer, DEFAULT_CONFIG

 __all__ = [
     "MAMLTrainer",
     "DEFAULT_CONFIG",
 ]
+
+from ray.rllib.utils.deprecation import deprecation_warning
+
+deprecation_warning("ray.rllib.agents.maml", "ray.rllib.algorithms.maml", error=False)
rllib/agents/marwil/__init__.py

@@ -1,7 +1,7 @@
-from ray.rllib.agents.marwil.bc import BCTrainer, BC_DEFAULT_CONFIG
-from ray.rllib.agents.marwil.marwil import MARWILTrainer, DEFAULT_CONFIG
-from ray.rllib.agents.marwil.marwil_tf_policy import MARWILTFPolicy
-from ray.rllib.agents.marwil.marwil_torch_policy import MARWILTorchPolicy
+from ray.rllib.algorithms.marwil.bc import BCTrainer, BC_DEFAULT_CONFIG
+from ray.rllib.algorithms.marwil.marwil import MARWILTrainer, DEFAULT_CONFIG
+from ray.rllib.algorithms.marwil.marwil_tf_policy import MARWILTFPolicy
+from ray.rllib.algorithms.marwil.marwil_torch_policy import MARWILTorchPolicy

 __all__ = [
     "BCTrainer",

@@ -11,3 +11,10 @@ __all__ = [
     "MARWILTorchPolicy",
     "MARWILTrainer",
 ]
+
+
+from ray.rllib.utils.deprecation import deprecation_warning
+
+deprecation_warning(
+    "ray.rllib.agents.marwil", "ray.rllib.algorithms.marwil", error=False
+)
rllib/agents/mbmpo/__init__.py

@@ -1,6 +1,11 @@
-from ray.rllib.agents.mbmpo.mbmpo import MBMPOTrainer, DEFAULT_CONFIG
+from ray.rllib.algorithms.mbmpo.mbmpo import MBMPOTrainer, DEFAULT_CONFIG

 __all__ = [
     "MBMPOTrainer",
     "DEFAULT_CONFIG",
 ]
+
+
+from ray.rllib.utils.deprecation import deprecation_warning
+
+deprecation_warning("ray.rllib.agents.mbmpo", "ray.rllib.algorithms.mbmpo", error=False)
rllib/agents/qmix/__init__.py

@@ -1,3 +1,3 @@
-from ray.rllib.agents.qmix.qmix import QMixTrainer, DEFAULT_CONFIG
+from ray.rllib.agents.qmix.qmix import QMixConfig, QMixTrainer, DEFAULT_CONFIG

-__all__ = ["QMixTrainer", "DEFAULT_CONFIG"]
+__all__ = ["QMixConfig", "QMixTrainer", "DEFAULT_CONFIG"]
rllib/agents/qmix/qmix.py

@@ -1,7 +1,6 @@
-from typing import Type
+from typing import Optional, Type

-from ray.rllib.agents.trainer import with_common_config
-from ray.rllib.agents.dqn.simple_q import SimpleQTrainer
+from ray.rllib.agents.dqn.simple_q import SimpleQConfig, SimpleQTrainer
 from ray.rllib.agents.qmix.qmix_policy import QMixTorchPolicy
 from ray.rllib.execution.rollout_ops import (
     synchronous_parallel_sample,

@@ -12,7 +11,7 @@ from ray.rllib.execution.train_ops import (
 )
 from ray.rllib.policy.policy import Policy
 from ray.rllib.utils.annotations import override
-from ray.rllib.utils.deprecation import DEPRECATED_VALUE
+from ray.rllib.utils.deprecation import DEPRECATED_VALUE, Deprecated
 from ray.rllib.utils.metrics import (
     LAST_TARGET_UPDATE_TS,
     NUM_AGENT_STEPS_SAMPLED,
@ -23,118 +22,174 @@ from ray.rllib.utils.metrics import (
|
||||||
from ray.rllib.utils.replay_buffers.utils import sample_min_n_steps_from_buffer
|
from ray.rllib.utils.replay_buffers.utils import sample_min_n_steps_from_buffer
|
||||||
from ray.rllib.utils.typing import ResultDict, TrainerConfigDict
|
from ray.rllib.utils.typing import ResultDict, TrainerConfigDict
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
# __sphinx_doc_begin__
|
|
||||||
DEFAULT_CONFIG = with_common_config({
|
|
||||||
# === QMix ===
|
|
||||||
# Mixing network. Either "qmix", "vdn", or None
|
|
||||||
"mixer": "qmix",
|
|
||||||
# Size of the mixing network embedding
|
|
||||||
"mixing_embed_dim": 32,
|
|
||||||
# Whether to use Double_Q learning
|
|
||||||
"double_q": True,
|
|
||||||
# Optimize over complete episodes by default.
|
|
||||||
"batch_mode": "complete_episodes",
|
|
||||||
|
|
||||||
# === Exploration Settings ===
|
class QMixConfig(SimpleQConfig):
|
||||||
"exploration_config": {
|
"""Defines a configuration class from which a QMixTrainer can be built.
|
||||||
# The Exploration class to use.
|
|
||||||
"type": "EpsilonGreedy",
|
|
||||||
# Config for the Exploration class' constructor:
|
|
||||||
"initial_epsilon": 1.0,
|
|
||||||
"final_epsilon": 0.01,
|
|
||||||
# Timesteps over which to anneal epsilon.
|
|
||||||
"epsilon_timesteps": 40000,
|
|
||||||
|
|
||||||
# For soft_q, use:
|
Example:
|
||||||
# "exploration_config" = {
|
>>> from ray.rllib.examples.env.two_step_game import TwoStepGame
|
||||||
# "type": "SoftQ"
|
>>> from ray.rllib.agents.qmix import QMixConfig
|
||||||
# "temperature": [float, e.g. 1.0]
|
>>> config = QMixConfig().training(gamma=0.9, lr=0.01, kl_coeff=0.3)\
|
||||||
# }
|
... .resources(num_gpus=0)\
|
||||||
},
|
... .rollouts(num_workers=4)
|
||||||
|
>>> print(config.to_dict())
|
||||||
|
>>> # Build a Trainer object from the config and run 1 training iteration.
|
||||||
|
>>> trainer = config.build(env=TwoStepGame)
|
||||||
|
>>> trainer.train()
|
||||||
|
|
||||||
# === Evaluation ===
|
Example:
|
||||||
# Evaluate with epsilon=0 every `evaluation_interval` training iterations.
|
>>> from ray.rllib.examples.env.two_step_game import TwoStepGame
|
||||||
# The evaluation stats will be reported under the "evaluation" metric key.
|
>>> from ray.rllib.agents.qmix import QMixConfig
|
||||||
# Note that evaluation is currently not parallelized, and that for Ape-X
|
>>> from ray import tune
|
||||||
# metrics are already only reported for the lowest epsilon workers.
|
>>> config = QMixConfig()
|
||||||
"evaluation_interval": None,
|
>>> # Print out some default values.
|
||||||
# Number of episodes to run per evaluation period.
|
>>> print(config.optim_alpha)
|
||||||
"evaluation_duration": 10,
|
>>> # Update the config object.
|
||||||
# Switch to greedy actions in evaluation workers.
|
>>> config.training(lr=tune.grid_search([0.001, 0.0001]), optim_alpha=0.97)
|
||||||
"evaluation_config": {
|
>>> # Set the config object's env.
|
||||||
"explore": False,
|
>>> config.environment(env=TwoStepGame)
|
||||||
},
|
>>> # Use to_dict() to get the old-style python config dict
|
||||||
|
>>> # when running with tune.
|
||||||
|
>>> tune.run(
|
||||||
|
... "QMix",
|
||||||
|
... stop={"episode_reward_mean": 200},
|
||||||
|
... config=config.to_dict(),
|
||||||
|
... )
|
||||||
|
"""
|
||||||
|
|
||||||
# Minimum env sampling timesteps to accumulate within a single `train()` call. This
|
def __init__(self):
|
||||||
# value does not affect learning, only the number of times `Trainer.step_attempt()`
|
"""Initializes a PPOConfig instance."""
|
||||||
# is called by `Trainer.train()`. If, after one `step_attempt()`, the env sampling
|
super().__init__(trainer_class=QMixTrainer)
|
||||||
# timestep count has not been reached, will perform n more `step_attempt()` calls
|
|
||||||
# until the minimum timesteps have been executed. Set to 0 for no minimum timesteps.
|
|
||||||
"min_sample_timesteps_per_reporting": 1000,
|
|
||||||
# Update the target network every `target_network_update_freq` steps.
|
|
||||||
"target_network_update_freq": 500,
|
|
||||||
|
|
||||||
# === Replay buffer ===
|
# fmt: off
|
||||||
"replay_buffer_config": {
|
# __sphinx_doc_begin__
|
||||||
# Use the new ReplayBuffer API here
|
# QMix specific settings:
|
||||||
"_enable_replay_buffer_api": True,
|
self.mixer = "qmix"
|
||||||
"type": "SimpleReplayBuffer",
|
self.mixing_embed_dim = 32
|
||||||
# Size of the replay buffer in batches (not timesteps!).
|
self.double_q = True
|
||||||
"capacity": 1000,
|
self.target_network_update_freq = 500
|
||||||
"learning_starts": 1000,
|
self.replay_buffer_config = {
|
||||||
},
|
# Use the new ReplayBuffer API here
|
||||||
|
"_enable_replay_buffer_api": True,
|
||||||
|
"type": "SimpleReplayBuffer",
|
||||||
|
# Size of the replay buffer in batches (not timesteps!).
|
||||||
|
"capacity": 1000,
|
||||||
|
"learning_starts": 1000,
|
||||||
|
}
|
||||||
|
self.optim_alpha = 0.99
|
||||||
|
self.optim_eps = 0.00001
|
||||||
|
self.grad_norm_clipping = 10
|
||||||
|
self.worker_side_prioritization = False
|
||||||
|
|
||||||
# === Optimization ===
|
# Override some of TrainerConfig's default values with QMix-specific values.
|
||||||
# Learning rate for RMSProp optimizer
|
self.num_workers = 0
|
||||||
"lr": 0.0005,
|
self.min_time_s_per_reporting = 1
|
||||||
# RMSProp alpha
|
self.model = {
|
||||||
"optim_alpha": 0.99,
|
"lstm_cell_size": 64,
|
||||||
# RMSProp epsilon
|
"max_seq_len": 999999,
|
||||||
"optim_eps": 0.00001,
|
}
|
||||||
# If not None, clip gradients during optimization at this value
|
self.framework_str = "torch"
|
||||||
"grad_norm_clipping": 10,
|
self.lr = 0.0005
|
||||||
# Update the replay buffer with this many samples at once. Note that
|
self.rollout_fragment_length = 4
|
||||||
# this setting applies per-worker if num_workers > 1.
|
self.train_batch_size = 32
|
||||||
"rollout_fragment_length": 4,
|
self.batch_mode = "complete_episodes"
|
||||||
# Minimum batch size used for training (in timesteps). With the default buffer
|
self.exploration_config = {
|
||||||
# (ReplayBuffer) this means, sampling from the buffer (entire-episode SampleBatches)
|
# The Exploration class to use.
|
||||||
# as many times as is required to reach at least this number of timesteps.
|
"type": "EpsilonGreedy",
|
||||||
"train_batch_size": 32,
|
# Config for the Exploration class' constructor:
|
||||||
|
"initial_epsilon": 1.0,
|
||||||
|
"final_epsilon": 0.01,
|
||||||
|
# Timesteps over which to anneal epsilon.
|
||||||
|
"epsilon_timesteps": 40000,
|
||||||
|
|
||||||
# === Parallelism ===
|
# For soft_q, use:
|
||||||
# Number of workers for collecting samples with. This only makes sense
|
# "exploration_config" = {
|
||||||
# to increase if your environment is particularly slow to sample, or if
|
# "type": "SoftQ"
|
||||||
# you"re using the Async or Ape-X optimizers.
|
# "temperature": [float, e.g. 1.0]
|
||||||
"num_workers": 0,
|
# }
|
||||||
# Whether to compute priorities on workers.
|
}
|
||||||
"worker_side_prioritization": False,
|
|
||||||
# Prevent reporting frequency from going lower than this time span.
|
|
||||||
"min_time_s_per_reporting": 1,
|
|
||||||
|
|
||||||
# === Model ===
|
# Evaluate with epsilon=0 every `evaluation_interval` training iterations.
|
||||||
"model": {
|
# The evaluation stats will be reported under the "evaluation" metric key.
|
||||||
"lstm_cell_size": 64,
|
# Note that evaluation is currently not parallelized, and that for Ape-X
|
||||||
"max_seq_len": 999999,
|
# metrics are already only reported for the lowest epsilon workers.
|
||||||
},
|
self.evaluation_interval = None
|
||||||
# Only torch supported so far.
|
self.evaluation_duration = 10
|
||||||
"framework": "torch",
|
self.evaluation_config = {
|
||||||
|
"explore": False,
|
||||||
|
}
|
||||||
|
self.min_sample_timesteps_per_reporting = 1000
|
||||||
|
# __sphinx_doc_end__
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
# Deprecated keys:
|
# Deprecated keys:
|
||||||
# Use `replay_buffer_config.learning_starts` instead.
|
self.learning_starts = DEPRECATED_VALUE
|
||||||
"learning_starts": DEPRECATED_VALUE,
|
self.buffer_size = DEPRECATED_VALUE
|
||||||
# Use `replay_buffer_config.capacity` instead.
|
|
||||||
"buffer_size": DEPRECATED_VALUE,
|
@override(SimpleQConfig)
|
||||||
})
|
def training(
|
||||||
# __sphinx_doc_end__
|
self,
|
||||||
# fmt: on
|
*,
|
||||||
|
mixer: Optional[str] = None,
|
||||||
|
mixing_embed_dim: Optional[int] = None,
|
||||||
|
double_q: Optional[bool] = None,
|
||||||
|
target_network_update_freq: Optional[int] = None,
|
||||||
|
replay_buffer_config: Optional[dict] = None,
|
||||||
|
optim_alpha: Optional[float] = None,
|
||||||
|
optim_eps: Optional[float] = None,
|
||||||
|
grad_norm_clipping: Optional[float] = None,
|
||||||
|
worker_side_prioritization: Optional[bool] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> "QMixConfig":
|
||||||
|
"""Sets the training related configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mixer: Mixing network. Either "qmix", "vdn", or None.
|
||||||
|
mixing_embed_dim: Size of the mixing network embedding.
|
||||||
|
double_q: Whether to use Double_Q learning.
|
||||||
|
target_network_update_freq: Update the target network every
|
||||||
|
`target_network_update_freq` sample steps.
|
||||||
|
replay_buffer_config:
|
||||||
|
optim_alpha: RMSProp alpha.
|
||||||
|
optim_eps: RMSProp epsilon.
|
||||||
|
grad_norm_clipping: If not None, clip gradients during optimization at
|
||||||
|
this value.
|
||||||
|
worker_side_prioritization: Whether to compute priorities for the replay
|
||||||
|
buffer on worker side.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
This updated TrainerConfig object.
|
||||||
|
"""
|
||||||
|
# Pass kwargs onto super's `training()` method.
|
||||||
|
super().training(**kwargs)
|
||||||
|
|
||||||
|
if mixer is not None:
|
||||||
|
self.mixer = mixer
|
||||||
|
if mixing_embed_dim is not None:
|
||||||
|
self.mixing_embed_dim = mixing_embed_dim
|
||||||
|
if double_q is not None:
|
||||||
|
self.double_q = double_q
|
||||||
|
if target_network_update_freq is not None:
|
||||||
|
self.target_network_update_freq = target_network_update_freq
|
||||||
|
if replay_buffer_config is not None:
|
||||||
|
self.replay_buffer_config = replay_buffer_config
|
||||||
|
if optim_alpha is not None:
|
||||||
|
self.optim_alpha = optim_alpha
|
||||||
|
if optim_eps is not None:
|
||||||
|
self.optim_eps = optim_eps
|
||||||
|
if grad_norm_clipping is not None:
|
||||||
|
self.grad_norm_clipping = grad_norm_clipping
|
||||||
|
if worker_side_prioritization is not None:
|
||||||
|
self.worker_side_prioritization = worker_side_prioritization
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
class QMixTrainer(SimpleQTrainer):
|
class QMixTrainer(SimpleQTrainer):
|
||||||
@classmethod
|
@classmethod
|
||||||
@override(SimpleQTrainer)
|
@override(SimpleQTrainer)
|
||||||
def get_default_config(cls) -> TrainerConfigDict:
|
def get_default_config(cls) -> TrainerConfigDict:
|
||||||
return DEFAULT_CONFIG
|
return QMixConfig().to_dict()
|
||||||
|
|
||||||
@override(SimpleQTrainer)
|
@override(SimpleQTrainer)
|
||||||
def validate_config(self, config: TrainerConfigDict) -> None:
|
def validate_config(self, config: TrainerConfigDict) -> None:
|
||||||
|
@@ -219,3 +274,20 @@ class QMixTrainer(SimpleQTrainer):

         # Return all collected metrics for the iteration.
         return train_results
+
+
+# Deprecated: Use ray.rllib.agents.qmix.qmix.QMixConfig instead!
+class _deprecated_default_config(dict):
+    def __init__(self):
+        super().__init__(QMixConfig().to_dict())
+
+    @Deprecated(
+        old="ray.rllib.agents.qmix.qmix.DEFAULT_CONFIG",
+        new="ray.rllib.agents.qmix.qmix.QMixConfig(...)",
+        error=False,
+    )
+    def __getitem__(self, item):
+        return super().__getitem__(item)
+
+
+DEFAULT_CONFIG = _deprecated_default_config()
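A short sketch of what the shim above preserves for downstream code, assuming a Ray build that includes this commit (the "qmix" default comes from the new QMixConfig shown earlier in this file):

from ray.rllib.agents.qmix.qmix import DEFAULT_CONFIG, QMixConfig

# Old dict-style access keeps working, but each lookup now goes through the
# @Deprecated-wrapped __getitem__ and logs a deprecation warning.
print(DEFAULT_CONFIG["mixer"])                    # "qmix", via the deprecated path
print(QMixConfig().to_dict()["mixer"])            # preferred: derive the dict from the config object
print(QMixConfig().training(mixer="vdn").mixer)   # or stay on the fluent config API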
@ -153,7 +153,6 @@ class QMixLoss(nn.Module):
|
||||||
return loss, mask, masked_td_error, chosen_action_qvals, targets
|
return loss, mask, masked_td_error, chosen_action_qvals, targets
|
||||||
|
|
||||||
|
|
||||||
# TODO(sven): Make this a TorchPolicy child via `build_policy_class`.
|
|
||||||
class QMixTorchPolicy(TorchPolicy):
|
class QMixTorchPolicy(TorchPolicy):
|
||||||
"""QMix impl. Assumes homogeneous agents for now.
|
"""QMix impl. Assumes homogeneous agents for now.
|
||||||
|
|
||||||
|
@ -177,9 +176,6 @@ class QMixTorchPolicy(TorchPolicy):
|
||||||
self.h_size = config["model"]["lstm_cell_size"]
|
self.h_size = config["model"]["lstm_cell_size"]
|
||||||
self.has_env_global_state = False
|
self.has_env_global_state = False
|
||||||
self.has_action_mask = False
|
self.has_action_mask = False
|
||||||
self.device = (
|
|
||||||
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
||||||
)
|
|
||||||
|
|
||||||
agent_obs_space = obs_space.original_space.spaces[0]
|
agent_obs_space = obs_space.original_space.spaces[0]
|
||||||
if isinstance(agent_obs_space, gym.spaces.Dict):
|
if isinstance(agent_obs_space, gym.spaces.Dict):
|
||||||
|
@ -218,7 +214,9 @@ class QMixTorchPolicy(TorchPolicy):
|
||||||
framework="torch",
|
framework="torch",
|
||||||
name="model",
|
name="model",
|
||||||
default_model=RNNModel,
|
default_model=RNNModel,
|
||||||
).to(self.device)
|
)
|
||||||
|
|
||||||
|
super().__init__(obs_space, action_space, config, model=self.model)
|
||||||
|
|
||||||
self.target_model = ModelCatalog.get_model_v2(
|
self.target_model = ModelCatalog.get_model_v2(
|
||||||
agent_obs_space,
|
agent_obs_space,
|
||||||
|
@ -230,8 +228,6 @@ class QMixTorchPolicy(TorchPolicy):
|
||||||
default_model=RNNModel,
|
default_model=RNNModel,
|
||||||
).to(self.device)
|
).to(self.device)
|
||||||
|
|
||||||
super().__init__(obs_space, action_space, config, model=self.model)
|
|
||||||
|
|
||||||
self.exploration = self._create_exploration()
|
self.exploration = self._create_exploration()
|
||||||
|
|
||||||
# Setup the mixer network.
|
# Setup the mixer network.
|
||||||
|
|
|
@ -4,7 +4,7 @@ import unittest
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
from ray.tune import register_env
|
from ray.tune import register_env
|
||||||
from ray.rllib.agents.qmix import QMixTrainer
|
from ray.rllib.agents.qmix import QMixConfig
|
||||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||||
|
|
||||||
|
|
||||||
|
@ -95,18 +95,21 @@ class TestQMix(unittest.TestCase):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer = QMixTrainer(
|
config = (
|
||||||
env="action_mask_test",
|
QMixConfig()
|
||||||
config={
|
.framework(framework="torch")
|
||||||
"num_envs_per_worker": 5, # test with vectorization on
|
.environment(
|
||||||
"env_config": {
|
env="action_mask_test",
|
||||||
"avail_actions": [3, 4, 8],
|
env_config={"avail_actions": [3, 4, 8]},
|
||||||
},
|
)
|
||||||
"framework": "torch",
|
.rollouts(num_envs_per_worker=5)
|
||||||
},
|
) # Test with vectorization on.
|
||||||
)
|
|
||||||
|
trainer = config.build()
|
||||||
|
|
||||||
for _ in range(4):
|
for _ in range(4):
|
||||||
trainer.train() # OK if it doesn't trip the action assertion error
|
trainer.train() # OK if it doesn't trip the action assertion error
|
||||||
|
|
||||||
assert trainer.train()["episode_reward_mean"] == 30.0
|
assert trainer.train()["episode_reward_mean"] == 30.0
|
||||||
trainer.stop()
|
trainer.stop()
|
||||||
ray.shutdown()
|
ray.shutdown()
|
||||||
|
|
|
@ -18,7 +18,10 @@ def _import_a3c():
|
||||||
|
|
||||||
|
|
||||||
def _import_alpha_star():
|
def _import_alpha_star():
|
||||||
from ray.rllib.agents.alpha_star.alpha_star import AlphaStarTrainer, DEFAULT_CONFIG
|
from ray.rllib.algorithms.alpha_star.alpha_star import (
|
||||||
|
AlphaStarTrainer,
|
||||||
|
DEFAULT_CONFIG,
|
||||||
|
)
|
||||||
|
|
||||||
return AlphaStarTrainer, DEFAULT_CONFIG
|
return AlphaStarTrainer, DEFAULT_CONFIG
|
||||||
|
|
||||||
|
@ -42,7 +45,7 @@ def _import_appo():
|
||||||
|
|
||||||
|
|
||||||
def _import_ars():
|
def _import_ars():
|
||||||
from ray.rllib.agents import ars
|
from ray.rllib.algorithms import ars
|
||||||
|
|
||||||
return ars.ARSTrainer, ars.DEFAULT_CONFIG
|
return ars.ARSTrainer, ars.DEFAULT_CONFIG
|
||||||
|
|
||||||
|
@ -60,13 +63,13 @@ def _import_bandit_linucb():
|
||||||
|
|
||||||
|
|
||||||
def _import_bc():
|
def _import_bc():
|
||||||
from ray.rllib.agents import marwil
|
from ray.rllib.algorithms import marwil
|
||||||
|
|
||||||
return marwil.BCTrainer, marwil.DEFAULT_CONFIG
|
return marwil.BCTrainer, marwil.DEFAULT_CONFIG
|
||||||
|
|
||||||
|
|
||||||
def _import_cql():
|
def _import_cql():
|
||||||
from ray.rllib.agents import cql
|
from ray.rllib.algorithms import cql
|
||||||
|
|
||||||
return cql.CQLTrainer, cql.CQL_DEFAULT_CONFIG
|
return cql.CQLTrainer, cql.CQL_DEFAULT_CONFIG
|
||||||
|
|
||||||
|
@ -90,13 +93,13 @@ def _import_dqn():
|
||||||
|
|
||||||
|
|
||||||
def _import_dreamer():
|
def _import_dreamer():
|
||||||
from ray.rllib.agents import dreamer
|
from ray.rllib.algorithms import dreamer
|
||||||
|
|
||||||
return dreamer.DREAMERTrainer, dreamer.DEFAULT_CONFIG
|
return dreamer.DREAMERTrainer, dreamer.DEFAULT_CONFIG
|
||||||
|
|
||||||
|
|
||||||
def _import_es():
|
def _import_es():
|
||||||
from ray.rllib.agents import es
|
from ray.rllib.algorithms import es
|
||||||
|
|
||||||
return es.ESTrainer, es.DEFAULT_CONFIG
|
return es.ESTrainer, es.DEFAULT_CONFIG
|
||||||
|
|
||||||
|
@ -114,19 +117,19 @@ def _import_maddpg():
|
||||||
|
|
||||||
|
|
||||||
def _import_maml():
|
def _import_maml():
|
||||||
from ray.rllib.agents import maml
|
from ray.rllib.algorithms import maml
|
||||||
|
|
||||||
return maml.MAMLTrainer, maml.DEFAULT_CONFIG
|
return maml.MAMLTrainer, maml.DEFAULT_CONFIG
|
||||||
|
|
||||||
|
|
||||||
def _import_marwil():
|
def _import_marwil():
|
||||||
from ray.rllib.agents import marwil
|
from ray.rllib.algorithms import marwil
|
||||||
|
|
||||||
return marwil.MARWILTrainer, marwil.DEFAULT_CONFIG
|
return marwil.MARWILTrainer, marwil.DEFAULT_CONFIG
|
||||||
|
|
||||||
|
|
||||||
def _import_mbmpo():
|
def _import_mbmpo():
|
||||||
from ray.rllib.agents import mbmpo
|
from ray.rllib.algorithms import mbmpo
|
||||||
|
|
||||||
return mbmpo.MBMPOTrainer, mbmpo.DEFAULT_CONFIG
|
return mbmpo.MBMPOTrainer, mbmpo.DEFAULT_CONFIG
|
||||||
|
|
||||||
|
|
|
@ -10,7 +10,7 @@ import unittest
|
||||||
import ray
|
import ray
|
||||||
import ray.rllib.agents.a3c as a3c
|
import ray.rllib.agents.a3c as a3c
|
||||||
import ray.rllib.agents.dqn as dqn
|
import ray.rllib.agents.dqn as dqn
|
||||||
from ray.rllib.agents.marwil import BCTrainer
|
from ray.rllib.algorithms.marwil import BCTrainer
|
||||||
import ray.rllib.agents.pg as pg
|
import ray.rllib.agents.pg as pg
|
||||||
from ray.rllib.agents.trainer import COMMON_CONFIG
|
from ray.rllib.agents.trainer import COMMON_CONFIG
|
||||||
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
|
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
|
||||||
|
|
11
rllib/algorithms/alpha_star/__init__.py
Normal file
11
rllib/algorithms/alpha_star/__init__.py
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
from ray.rllib.algorithms.alpha_star.alpha_star import (
|
||||||
|
AlphaStarConfig,
|
||||||
|
AlphaStarTrainer,
|
||||||
|
DEFAULT_CONFIG,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AlphaStarConfig",
|
||||||
|
"AlphaStarTrainer",
|
||||||
|
"DEFAULT_CONFIG",
|
||||||
|
]
|
|
@ -9,8 +9,8 @@ from typing import Any, Dict, Optional, Type
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
from ray.actor import ActorHandle
|
from ray.actor import ActorHandle
|
||||||
from ray.rllib.agents.alpha_star.distributed_learners import DistributedLearners
|
from ray.rllib.algorithms.alpha_star.distributed_learners import DistributedLearners
|
||||||
from ray.rllib.agents.alpha_star.league_builder import AlphaStarLeagueBuilder
|
from ray.rllib.algorithms.alpha_star.league_builder import AlphaStarLeagueBuilder
|
||||||
from ray.rllib.agents.trainer import Trainer
|
from ray.rllib.agents.trainer import Trainer
|
||||||
import ray.rllib.agents.ppo.appo as appo
|
import ray.rllib.agents.ppo.appo as appo
|
||||||
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
||||||
|
@ -48,7 +48,7 @@ class AlphaStarConfig(appo.APPOConfig):
|
||||||
"""Defines a configuration class from which an AlphaStarTrainer can be built.
|
"""Defines a configuration class from which an AlphaStarTrainer can be built.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> from ray.rllib.agents.alpha_star import AlphaStarConfig
|
>>> from ray.rllib.algorithms.alpha_star import AlphaStarConfig
|
||||||
>>> config = AlphaStarConfig().training(lr=0.0003, train_batch_size=512)\
|
>>> config = AlphaStarConfig().training(lr=0.0003, train_batch_size=512)\
|
||||||
... .resources(num_gpus=4)\
|
... .resources(num_gpus=4)\
|
||||||
... .rollouts(num_rollout_workers=64)
|
... .rollouts(num_rollout_workers=64)
|
||||||
|
@ -58,7 +58,7 @@ class AlphaStarConfig(appo.APPOConfig):
|
||||||
>>> trainer.train()
|
>>> trainer.train()
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> from ray.rllib.agents.alpha_star import AlphaStarConfig
|
>>> from ray.rllib.algorithms.alpha_star import AlphaStarConfig
|
||||||
>>> from ray import tune
|
>>> from ray import tune
|
||||||
>>> config = AlphaStarConfig()
|
>>> config = AlphaStarConfig()
|
||||||
>>> # Print out some default values.
|
>>> # Print out some default values.
|
||||||
|
@ -168,7 +168,7 @@ class AlphaStarConfig(appo.APPOConfig):
|
||||||
to be used for league building logic. All other keys (that are not
|
to be used for league building logic. All other keys (that are not
|
||||||
`type`) will be used as constructor kwargs on the given class to
|
`type`) will be used as constructor kwargs on the given class to
|
||||||
construct the LeagueBuilder instance. See the
|
construct the LeagueBuilder instance. See the
|
||||||
`ray.rllib.agents.alpha_star.league_builder::AlphaStarLeagueBuilder`
|
`ray.rllib.algorithms.alpha_star.league_builder::AlphaStarLeagueBuilder`
|
||||||
(used by default by this algo) as an example.
|
(used by default by this algo) as an example.
|
||||||
max_num_policies_to_train: The maximum number of trainable policies for this
|
max_num_policies_to_train: The maximum number of trainable policies for this
|
||||||
Trainer. Each trainable policy will exist as a independent remote actor,
|
Trainer. Each trainable policy will exist as a independent remote actor,
|
||||||
|
@ -584,8 +584,8 @@ class _deprecated_default_config(dict):
|
||||||
super().__init__(AlphaStarConfig().to_dict())
|
super().__init__(AlphaStarConfig().to_dict())
|
||||||
|
|
||||||
@Deprecated(
|
@Deprecated(
|
||||||
old="ray.rllib.agents.alpha_star.alpha_star.DEFAULT_CONFIG",
|
old="ray.rllib.algorithms.alpha_star.alpha_star.DEFAULT_CONFIG",
|
||||||
new="ray.rllib.agents.alpha_star.alpha_star.AlphaStarConfig(...)",
|
new="ray.rllib.algorithms.alpha_star.alpha_star.AlphaStarConfig(...)",
|
||||||
error=False,
|
error=False,
|
||||||
)
|
)
|
||||||
def __getitem__(self, item):
|
def __getitem__(self, item):
|
|
@ -2,7 +2,7 @@ import pyspiel
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
import ray.rllib.agents.alpha_star as alpha_star
|
import ray.rllib.algorithms.alpha_star as alpha_star
|
||||||
from ray.rllib.env.wrappers.open_spiel import OpenSpielEnv
|
from ray.rllib.env.wrappers.open_spiel import OpenSpielEnv
|
||||||
from ray.rllib.utils.test_utils import (
|
from ray.rllib.utils.test_utils import (
|
||||||
check_compute_single_action,
|
check_compute_single_action,
|
8
rllib/algorithms/cql/__init__.py
Normal file
8
rllib/algorithms/cql/__init__.py
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
from ray.rllib.algorithms.cql.cql import CQLTrainer, CQL_DEFAULT_CONFIG
|
||||||
|
from ray.rllib.algorithms.cql.cql_torch_policy import CQLTorchPolicy
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"CQL_DEFAULT_CONFIG",
|
||||||
|
"CQLTorchPolicy",
|
||||||
|
"CQLTrainer",
|
||||||
|
]
|
|
@ -2,8 +2,8 @@ import logging
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from typing import Type
|
from typing import Type
|
||||||
|
|
||||||
from ray.rllib.agents.cql.cql_tf_policy import CQLTFPolicy
|
from ray.rllib.algorithms.cql.cql_tf_policy import CQLTFPolicy
|
||||||
from ray.rllib.agents.cql.cql_torch_policy import CQLTorchPolicy
|
from ray.rllib.algorithms.cql.cql_torch_policy import CQLTorchPolicy
|
||||||
from ray.rllib.agents.sac.sac import SACTrainer, DEFAULT_CONFIG as SAC_CONFIG
|
from ray.rllib.agents.sac.sac import SACTrainer, DEFAULT_CONFIG as SAC_CONFIG
|
||||||
from ray.rllib.execution.train_ops import (
|
from ray.rllib.execution.train_ops import (
|
||||||
multi_gpu_train_one_step,
|
multi_gpu_train_one_step,
|
|
@ -411,7 +411,7 @@ def apply_gradients_fn(policy, optimizer, grads_and_vars):
|
||||||
CQLTFPolicy = build_tf_policy(
|
CQLTFPolicy = build_tf_policy(
|
||||||
name="CQLTFPolicy",
|
name="CQLTFPolicy",
|
||||||
loss_fn=cql_loss,
|
loss_fn=cql_loss,
|
||||||
get_default_config=lambda: ray.rllib.agents.cql.cql.CQL_DEFAULT_CONFIG,
|
get_default_config=lambda: ray.rllib.algorithms.cql.cql.CQL_DEFAULT_CONFIG,
|
||||||
validate_spaces=validate_spaces,
|
validate_spaces=validate_spaces,
|
||||||
stats_fn=cql_stats,
|
stats_fn=cql_stats,
|
||||||
postprocess_fn=postprocess_trajectory,
|
postprocess_fn=postprocess_trajectory,
|
|
@ -387,7 +387,7 @@ CQLTorchPolicy = build_policy_class(
|
||||||
name="CQLTorchPolicy",
|
name="CQLTorchPolicy",
|
||||||
framework="torch",
|
framework="torch",
|
||||||
loss_fn=cql_loss,
|
loss_fn=cql_loss,
|
||||||
get_default_config=lambda: ray.rllib.agents.cql.cql.CQL_DEFAULT_CONFIG,
|
get_default_config=lambda: ray.rllib.algorithms.cql.cql.CQL_DEFAULT_CONFIG,
|
||||||
stats_fn=cql_stats,
|
stats_fn=cql_stats,
|
||||||
postprocess_fn=postprocess_trajectory,
|
postprocess_fn=postprocess_trajectory,
|
||||||
extra_grad_process_fn=apply_grad_clipping,
|
extra_grad_process_fn=apply_grad_clipping,
|
|
@ -4,7 +4,7 @@ import os
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
import ray.rllib.agents.cql as cql
|
import ray.rllib.algorithms.cql as cql
|
||||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||||
from ray.rllib.utils.test_utils import (
|
from ray.rllib.utils.test_utils import (
|
||||||
check_compute_single_action,
|
check_compute_single_action,
|
11
rllib/algorithms/dreamer/__init__.py
Normal file
11
rllib/algorithms/dreamer/__init__.py
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
from ray.rllib.algorithms.dreamer.dreamer import (
|
||||||
|
DREAMERConfig,
|
||||||
|
DREAMERTrainer,
|
||||||
|
DEFAULT_CONFIG,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DREAMERConfig",
|
||||||
|
"DREAMERTrainer",
|
||||||
|
"DEFAULT_CONFIG",
|
||||||
|
]
|
|
@ -4,12 +4,12 @@ import random
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from ray.rllib.agents.trainer_config import TrainerConfig
|
from ray.rllib.agents.trainer_config import TrainerConfig
|
||||||
from ray.rllib.agents.dreamer.dreamer_torch_policy import DreamerTorchPolicy
|
from ray.rllib.algorithms.dreamer.dreamer_torch_policy import DreamerTorchPolicy
|
||||||
from ray.rllib.agents.trainer import Trainer
|
from ray.rllib.agents.trainer import Trainer
|
||||||
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, _get_shared_metrics
|
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, _get_shared_metrics
|
||||||
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
|
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
|
||||||
from ray.rllib.evaluation.metrics import collect_metrics
|
from ray.rllib.evaluation.metrics import collect_metrics
|
||||||
from ray.rllib.agents.dreamer.dreamer_model import DreamerModel
|
from ray.rllib.algorithms.dreamer.dreamer_model import DreamerModel
|
||||||
from ray.rllib.execution.rollout_ops import (
|
from ray.rllib.execution.rollout_ops import (
|
||||||
ParallelRollouts,
|
ParallelRollouts,
|
||||||
synchronous_parallel_sample,
|
synchronous_parallel_sample,
|
||||||
|
@ -31,7 +31,7 @@ class DREAMERConfig(TrainerConfig):
|
||||||
"""Defines a PPOTrainer configuration class from which a PPOTrainer can be built.
|
"""Defines a PPOTrainer configuration class from which a PPOTrainer can be built.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> from ray.rllib.agents.dreamer import DREAMERConfig
|
>>> from ray.rllib.algorithms.dreamer import DREAMERConfig
|
||||||
>>> config = DREAMERConfig().training(gamma=0.9, lr=0.01)\
|
>>> config = DREAMERConfig().training(gamma=0.9, lr=0.01)\
|
||||||
... .resources(num_gpus=0)\
|
... .resources(num_gpus=0)\
|
||||||
... .rollouts(num_rollout_workers=4)
|
... .rollouts(num_rollout_workers=4)
|
||||||
|
@ -42,7 +42,7 @@ class DREAMERConfig(TrainerConfig):
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> from ray import tune
|
>>> from ray import tune
|
||||||
>>> from ray.rllib.agents.dreamer import DREAMERConfig
|
>>> from ray.rllib.algorithms.dreamer import DREAMERConfig
|
||||||
>>> config = DREAMERConfig()
|
>>> config = DREAMERConfig()
|
||||||
>>> # Print out some default values.
|
>>> # Print out some default values.
|
||||||
>>> print(config.clip_param)
|
>>> print(config.clip_param)
|
||||||
|
@ -418,14 +418,14 @@ class DREAMERTrainer(Trainer):
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
# Deprecated: Use ray.rllib.agents.dreamer.DREAMERConfig instead!
|
# Deprecated: Use ray.rllib.algorithms.dreamer.DREAMERConfig instead!
|
||||||
class _deprecated_default_config(dict):
|
class _deprecated_default_config(dict):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(DREAMERConfig().to_dict())
|
super().__init__(DREAMERConfig().to_dict())
|
||||||
|
|
||||||
@Deprecated(
|
@Deprecated(
|
||||||
old="ray.rllib.agents.dreamer.dreamer.DEFAULT_CONFIG",
|
old="ray.rllib.algorithms.dreamer.dreamer.DEFAULT_CONFIG",
|
||||||
new="ray.rllib.agents.dreamer.dreamer.DREAMERConfig(...)",
|
new="ray.rllib.algorithms.dreamer.dreamer.DREAMERConfig(...)",
|
||||||
error=False,
|
error=False,
|
||||||
)
|
)
|
||||||
def __getitem__(self, item):
|
def __getitem__(self, item):
|
|
@ -8,7 +8,7 @@ from ray.rllib.utils.framework import TensorType
|
||||||
torch, nn = try_import_torch()
|
torch, nn = try_import_torch()
|
||||||
if torch:
|
if torch:
|
||||||
from torch import distributions as td
|
from torch import distributions as td
|
||||||
from ray.rllib.agents.dreamer.utils import (
|
from ray.rllib.algorithms.dreamer.utils import (
|
||||||
Linear,
|
Linear,
|
||||||
Conv2d,
|
Conv2d,
|
||||||
ConvTranspose2d,
|
ConvTranspose2d,
|
|
@ -4,7 +4,7 @@ import numpy as np
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
from ray.rllib.agents.dreamer.utils import FreezeParameters
|
from ray.rllib.algorithms.dreamer.utils import FreezeParameters
|
||||||
from ray.rllib.evaluation.episode import Episode
|
from ray.rllib.evaluation.episode import Episode
|
||||||
from ray.rllib.models.catalog import ModelCatalog
|
from ray.rllib.models.catalog import ModelCatalog
|
||||||
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
||||||
|
@ -284,7 +284,7 @@ def preprocess_episode(
|
||||||
DreamerTorchPolicy = build_policy_class(
|
DreamerTorchPolicy = build_policy_class(
|
||||||
name="DreamerTorchPolicy",
|
name="DreamerTorchPolicy",
|
||||||
framework="torch",
|
framework="torch",
|
||||||
get_default_config=lambda: ray.rllib.agents.dreamer.dreamer.DEFAULT_CONFIG,
|
get_default_config=lambda: ray.rllib.algorithms.dreamer.dreamer.DEFAULT_CONFIG,
|
||||||
action_sampler_fn=action_sampler_fn,
|
action_sampler_fn=action_sampler_fn,
|
||||||
postprocess_fn=preprocess_episode,
|
postprocess_fn=preprocess_episode,
|
||||||
loss_fn=dreamer_loss,
|
loss_fn=dreamer_loss,
|
|
@ -2,7 +2,7 @@ from gym.spaces import Box
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
import ray.rllib.agents.dreamer as dreamer
|
import ray.rllib.algorithms.dreamer as dreamer
|
||||||
from ray.rllib.examples.env.random_env import RandomEnv
|
from ray.rllib.examples.env.random_env import RandomEnv
|
||||||
from ray.rllib.utils.test_utils import framework_iterator
|
from ray.rllib.utils.test_utils import framework_iterator
|
||||||
|
|
6
rllib/algorithms/maml/__init__.py
Normal file
6
rllib/algorithms/maml/__init__.py
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
from ray.rllib.algorithms.maml.maml import MAMLTrainer, DEFAULT_CONFIG
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"MAMLTrainer",
|
||||||
|
"DEFAULT_CONFIG",
|
||||||
|
]
|
|
@ -4,8 +4,8 @@ from typing import Type
|
||||||
|
|
||||||
from ray.rllib.utils.sgd import standardized
|
from ray.rllib.utils.sgd import standardized
|
||||||
from ray.rllib.agents import with_common_config
|
from ray.rllib.agents import with_common_config
|
||||||
from ray.rllib.agents.maml.maml_tf_policy import MAMLTFPolicy
|
from ray.rllib.algorithms.maml.maml_tf_policy import MAMLTFPolicy
|
||||||
from ray.rllib.agents.maml.maml_torch_policy import MAMLTorchPolicy
|
from ray.rllib.algorithms.maml.maml_torch_policy import MAMLTorchPolicy
|
||||||
from ray.rllib.agents.trainer import Trainer
|
from ray.rllib.agents.trainer import Trainer
|
||||||
from ray.rllib.evaluation.metrics import get_learner_stats
|
from ray.rllib.evaluation.metrics import get_learner_stats
|
||||||
from ray.rllib.evaluation.worker_set import WorkerSet
|
from ray.rllib.evaluation.worker_set import WorkerSet
|
|
@ -442,7 +442,7 @@ def setup_mixins(policy, obs_space, action_space, config):
|
||||||
|
|
||||||
MAMLTFPolicy = build_tf_policy(
|
MAMLTFPolicy = build_tf_policy(
|
||||||
name="MAMLTFPolicy",
|
name="MAMLTFPolicy",
|
||||||
get_default_config=lambda: ray.rllib.agents.maml.maml.DEFAULT_CONFIG,
|
get_default_config=lambda: ray.rllib.algorithms.maml.maml.DEFAULT_CONFIG,
|
||||||
loss_fn=maml_loss,
|
loss_fn=maml_loss,
|
||||||
stats_fn=maml_stats,
|
stats_fn=maml_stats,
|
||||||
optimizer_fn=maml_optimizer_fn,
|
optimizer_fn=maml_optimizer_fn,
|
|
@ -382,7 +382,7 @@ def setup_mixins(policy, obs_space, action_space, config):
|
||||||
MAMLTorchPolicy = build_policy_class(
|
MAMLTorchPolicy = build_policy_class(
|
||||||
name="MAMLTorchPolicy",
|
name="MAMLTorchPolicy",
|
||||||
framework="torch",
|
framework="torch",
|
||||||
get_default_config=lambda: ray.rllib.agents.maml.maml.DEFAULT_CONFIG,
|
get_default_config=lambda: ray.rllib.algorithms.maml.maml.DEFAULT_CONFIG,
|
||||||
loss_fn=maml_loss,
|
loss_fn=maml_loss,
|
||||||
stats_fn=maml_stats,
|
stats_fn=maml_stats,
|
||||||
optimizer_fn=maml_optimizer_fn,
|
optimizer_fn=maml_optimizer_fn,
|
|
@ -1,7 +1,7 @@
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
import ray.rllib.agents.maml as maml
|
import ray.rllib.algorithms.maml as maml
|
||||||
from ray.rllib.utils.test_utils import (
|
from ray.rllib.utils.test_utils import (
|
||||||
check_compute_single_action,
|
check_compute_single_action,
|
||||||
check_train_results,
|
check_train_results,
|
rllib/algorithms/marwil/__init__.py (new file, +13 lines)

@@ -0,0 +1,13 @@
+from ray.rllib.algorithms.marwil.bc import BCTrainer, BC_DEFAULT_CONFIG
+from ray.rllib.algorithms.marwil.marwil import MARWILTrainer, DEFAULT_CONFIG
+from ray.rllib.algorithms.marwil.marwil_tf_policy import MARWILTFPolicy
+from ray.rllib.algorithms.marwil.marwil_torch_policy import MARWILTorchPolicy
+
+__all__ = [
+    "BCTrainer",
+    "BC_DEFAULT_CONFIG",
+    "DEFAULT_CONFIG",
+    "MARWILTFPolicy",
+    "MARWILTorchPolicy",
+    "MARWILTrainer",
+]
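
Similarly for MARWIL/BC (a sketch, not part of the diff): the module-style import used by the updated tests below resolves both trainers from the new location:

.. code-block:: python

    # Sketch only: module-style import of the relocated MARWIL package.
    import ray.rllib.algorithms.marwil as marwil

    # BC is MARWIL's behavior-cloning special case (its default config sets beta=0.0).
    bc_config = marwil.BC_DEFAULT_CONFIG.copy()
    marwil_config = marwil.DEFAULT_CONFIG.copy()
    print(marwil.BCTrainer.__name__, marwil.MARWILTrainer.__name__)
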
@@ -1,4 +1,4 @@
-from ray.rllib.agents.marwil.marwil import (
+from ray.rllib.algorithms.marwil.marwil import (
     MARWILTrainer,
     DEFAULT_CONFIG as MARWIL_CONFIG,
 )
@@ -1,7 +1,7 @@
 from typing import Type
 
 from ray.rllib.agents.trainer import Trainer, with_common_config
-from ray.rllib.agents.marwil.marwil_tf_policy import MARWILTFPolicy
+from ray.rllib.algorithms.marwil.marwil_tf_policy import MARWILTFPolicy
 from ray.rllib.execution.buffers.multi_agent_replay_buffer import MultiAgentReplayBuffer
 from ray.rllib.execution.rollout_ops import (
     synchronous_parallel_sample,
@@ -110,7 +110,9 @@ class MARWILTrainer(Trainer):
     @override(Trainer)
     def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
         if config["framework"] == "torch":
-            from ray.rllib.agents.marwil.marwil_torch_policy import MARWILTorchPolicy
+            from ray.rllib.algorithms.marwil.marwil_torch_policy import (
+                MARWILTorchPolicy,
+            )
 
             return MARWILTorchPolicy
         else:
@@ -232,7 +232,7 @@ def setup_mixins(
 
 MARWILTFPolicy = build_tf_policy(
     name="MARWILTFPolicy",
-    get_default_config=lambda: ray.rllib.agents.marwil.marwil.DEFAULT_CONFIG,
+    get_default_config=lambda: ray.rllib.algorithms.marwil.marwil.DEFAULT_CONFIG,
     loss_fn=marwil_loss,
     stats_fn=stats,
     postprocess_fn=postprocess_advantages,
@@ -3,7 +3,7 @@ from typing import Dict
 
 import ray
 from ray.rllib.agents.a3c.a3c_torch_policy import ValueNetworkMixin
-from ray.rllib.agents.marwil.marwil_tf_policy import postprocess_advantages
+from ray.rllib.algorithms.marwil.marwil_tf_policy import postprocess_advantages
 from ray.rllib.evaluation.postprocessing import Postprocessing
 from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
@@ -113,7 +113,7 @@ MARWILTorchPolicy = build_policy_class(
     name="MARWILTorchPolicy",
     framework="torch",
     loss_fn=marwil_loss,
-    get_default_config=lambda: ray.rllib.agents.marwil.marwil.DEFAULT_CONFIG,
+    get_default_config=lambda: ray.rllib.algorithms.marwil.marwil.DEFAULT_CONFIG,
     stats_fn=stats,
     postprocess_fn=postprocess_advantages,
     extra_grad_process_fn=apply_grad_clipping,
@@ -3,7 +3,7 @@ from pathlib import Path
 import unittest
 
 import ray
-import ray.rllib.agents.marwil as marwil
+import ray.rllib.algorithms.marwil as marwil
 from ray.rllib.utils.framework import try_import_tf
 from ray.rllib.utils.test_utils import (
     check_compute_single_action,
@@ -4,7 +4,7 @@ from pathlib import Path
 import unittest
 
 import ray
-import ray.rllib.agents.marwil as marwil
+import ray.rllib.algorithms.marwil as marwil
 from ray.rllib.evaluation.postprocessing import compute_advantages
 from ray.rllib.offline import JsonReader
 from ray.rllib.utils.framework import try_import_tf, try_import_torch
rllib/algorithms/mbmpo/__init__.py (new file, +6 lines)

@@ -0,0 +1,6 @@
+from ray.rllib.algorithms.mbmpo.mbmpo import MBMPOTrainer, DEFAULT_CONFIG
+
+__all__ = [
+    "MBMPOTrainer",
+    "DEFAULT_CONFIG",
+]
@@ -4,9 +4,9 @@ from typing import List, Type
 
 import ray
 from ray.rllib.agents import with_common_config
-from ray.rllib.agents.mbmpo.mbmpo_torch_policy import MBMPOTorchPolicy
-from ray.rllib.agents.mbmpo.model_ensemble import DynamicsEnsembleCustomModel
-from ray.rllib.agents.mbmpo.utils import calculate_gae_advantages, MBMPOExploration
+from ray.rllib.algorithms.mbmpo.mbmpo_torch_policy import MBMPOTorchPolicy
+from ray.rllib.algorithms.mbmpo.model_ensemble import DynamicsEnsembleCustomModel
+from ray.rllib.algorithms.mbmpo.utils import calculate_gae_advantages, MBMPOExploration
 from ray.rllib.agents.trainer import Trainer
 from ray.rllib.env.env_context import EnvContext
 from ray.rllib.env.wrappers.model_vector_env import model_vector_env
@@ -5,7 +5,7 @@ from typing import Tuple, Type
 
 import ray
 from ray.rllib.agents.a3c.a3c_torch_policy import vf_preds_fetches
-from ray.rllib.agents.maml.maml_torch_policy import (
+from ray.rllib.algorithms.maml.maml_torch_policy import (
     setup_mixins,
     maml_loss,
     maml_stats,
@@ -121,7 +121,7 @@ def make_model_and_action_dist(
 MBMPOTorchPolicy = build_policy_class(
     name="MBMPOTorchPolicy",
     framework="torch",
-    get_default_config=lambda: ray.rllib.agents.mbmpo.mbmpo.DEFAULT_CONFIG,
+    get_default_config=lambda: ray.rllib.algorithms.mbmpo.mbmpo.DEFAULT_CONFIG,
     make_model_and_action_dist=make_model_and_action_dist,
     loss_fn=maml_loss,
     stats_fn=maml_stats,
@@ -1,7 +1,7 @@
 import unittest
 
 import ray
-import ray.rllib.agents.mbmpo as mbmpo
+import ray.rllib.algorithms.mbmpo as mbmpo
 from ray.rllib.utils.test_utils import (
     check_compute_single_action,
     check_train_results,
rllib/env/multi_agent_env.py (vendored, 12 lines changed)
@@ -245,7 +245,6 @@ class MultiAgentEnv(gym.Env):
 
     # fmt: off
     # __grouping_doc_begin__
-    @ExperimentalAPI
     def with_agent_groups(
         self,
         groups: Dict[str, List[AgentID]],
@@ -265,16 +264,17 @@ class MultiAgentEnv(gym.Env):
 
         Agent grouping is required to leverage algorithms such as Q-Mix.
 
-        This API is experimental.
-
         Args:
             groups: Mapping from group id to a list of the agent ids
                 of group members. If an agent id is not present in any group
-                value, it will be left ungrouped.
+                value, it will be left ungrouped. The group id becomes a new agent ID
+                in the final environment.
             obs_space: Optional observation space for the grouped
-                env. Must be a tuple space.
+                env. Must be a tuple space. If not provided, will infer this to be a
+                Tuple of n individual agents spaces (n=num agents in a group).
             act_space: Optional action space for the grouped env.
-                Must be a tuple space.
+                Must be a tuple space. If not provided, will infer this to be a Tuple
+                of n individual agents spaces (n=num agents in a group).
 
         Examples:
             >>> from ray.rllib.env.multi_agent_env import MultiAgentEnv
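
The grouping semantics described in this docstring can be exercised directly. A minimal sketch (not part of the diff), using the ``TwoStepGame`` example env that the example script later in this commit also groups; the agent IDs ``0``/``1``, the empty ``env_config``, and the ``Discrete(6)`` per-agent observation space are assumptions taken from that example:

.. code-block:: python

    # Sketch only: group TwoStepGame's two agents into a single "group_1" agent.
    from gym.spaces import Discrete, Tuple
    from ray.rllib.examples.env.two_step_game import TwoStepGame

    grouped_env = TwoStepGame({}).with_agent_groups(
        groups={"group_1": [0, 1]},
        # If omitted, these spaces are inferred as a Tuple of the individual
        # agents' spaces (n=2 here), per the docstring above.
        obs_space=Tuple([Discrete(6), Discrete(6)]),
        act_space=Tuple([TwoStepGame.action_space, TwoStepGame.action_space]),
    )
    obs = grouped_env.reset()  # grouped obs, e.g. {"group_1": [obs_agent_0, obs_agent_1]}
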
rllib/env/wrappers/group_agents_wrapper.py (vendored, 35 lines changed)
@@ -1,6 +1,9 @@
 from collections import OrderedDict
+import gym
+from typing import Dict, List, Optional
 
 from ray.rllib.env.multi_agent_env import MultiAgentEnv
+from ray.rllib.utils.typing import AgentID
 
 # info key for the individual rewards of an agent, for example:
 # info: {
@@ -27,21 +30,35 @@ class GroupAgentsWrapper(MultiAgentEnv):
     This API is experimental.
     """
 
-    def __init__(self, env, groups, obs_space=None, act_space=None):
-        """Wrap an existing multi-agent env to group agents together.
+    def __init__(
+        self,
+        env: MultiAgentEnv,
+        groups: Dict[str, List[AgentID]],
+        obs_space: Optional[gym.Space] = None,
+        act_space: Optional[gym.Space] = None,
+    ):
+        """Wrap an existing MultiAgentEnv to group agent ID together.
 
-        See MultiAgentEnv.with_agent_groups() for usage info.
+        See `MultiAgentEnv.with_agent_groups()` for more detailed usage info.
 
         Args:
-            env (MultiAgentEnv): env to wrap
-            groups (dict): Grouping spec as documented in MultiAgentEnv.
-            obs_space (Space): Optional observation space for the grouped
-                env. Must be a tuple space.
-            act_space (Space): Optional action space for the grouped env.
-                Must be a tuple space.
+            env: The env to wrap and whose agent IDs to group into new agents.
+            groups: Mapping from group id to a list of the agent ids
+                of group members. If an agent id is not present in any group
+                value, it will be left ungrouped. The group id becomes a new agent ID
+                in the final environment.
+            obs_space: Optional observation space for the grouped
+                env. Must be a tuple space. If not provided, will infer this to be a
+                Tuple of n individual agents spaces (n=num agents in a group).
+            act_space: Optional action space for the grouped env.
+                Must be a tuple space. If not provided, will infer this to be a Tuple
+                of n individual agents spaces (n=num agents in a group).
         """
         super().__init__()
         self.env = env
+        # Inherit wrapped env's `_skip_env_checking` flag.
+        if hasattr(self.env, "_skip_env_checking"):
+            self._skip_env_checking = self.env._skip_env_checking
         self.groups = groups
         self.agent_id_to_group = {}
         for group_id, agent_ids in groups.items():
@@ -20,7 +20,9 @@ import os
 
 import ray
 from ray import tune
-from ray.rllib.agents.maml.maml_torch_policy import KLCoeffMixin as TorchKLCoeffMixin
+from ray.rllib.algorithms.maml.maml_torch_policy import (
+    KLCoeffMixin as TorchKLCoeffMixin,
+)
 from ray.rllib.agents.ppo.ppo import PPOTrainer
 from ray.rllib.agents.ppo.ppo_tf_policy import (
     PPOTFPolicy,
@@ -16,6 +16,7 @@ import os
 import ray
 from ray import tune
 from ray.tune import register_env
+from ray.rllib.agents.qmix import QMixConfig
 from ray.rllib.env.multi_agent_env import ENV_STATE
 from ray.rllib.examples.env.two_step_game import TwoStepGame
 from ray.rllib.policy.policy import PolicySpec
@@ -110,10 +111,11 @@ if __name__ == "__main__":
         obs_space = Discrete(6)
         act_space = TwoStepGame.action_space
         config = {
-            "learning_starts": 100,
+            "env": TwoStepGame,
             "env_config": {
                 "actions_are_logits": True,
             },
+            "learning_starts": 100,
             "multiagent": {
                 "policies": {
                     "pol1": PolicySpec(
@@ -133,31 +135,33 @@ if __name__ == "__main__":
             # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
             "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
         }
-        group = False
     elif args.run == "QMIX":
-        config = {
-            "rollout_fragment_length": 4,
-            "train_batch_size": 32,
-            "exploration_config": {
-                "final_epsilon": 0.0,
-            },
-            "num_workers": 0,
-            "mixer": args.mixer,
-            "env_config": {
-                "separate_state_space": True,
-                "one_hot_state_encoding": True,
-            },
-            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
-            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
-        }
-        group = True
+        config = (
+            QMixConfig()
+            .training(mixer=args.mixer, train_batch_size=32)
+            .rollouts(num_rollout_workers=0, rollout_fragment_length=4)
+            .exploration(
+                exploration_config={
+                    "final_epsilon": 0.0,
+                }
+            )
+            .environment(
+                env="grouped_twostep",
+                env_config={
+                    "separate_state_space": True,
+                    "one_hot_state_encoding": True,
+                },
+            )
+            .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
+        )
+        config = config.to_dict()
     else:
         config = {
+            "env": TwoStepGame,
             # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
             "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
             "framework": args.framework,
         }
-        group = False
 
     stop = {
         "episode_reward_mean": args.stop_reward,
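
The chained ``QMixConfig`` calls above can also be written statement by statement before converting back to a legacy dict for ``tune.run()``. A sketch with the same settings (not part of the diff; ``mixer="qmix"`` stands in for ``args.mixer``, and ``"grouped_twostep"`` is the env name the example registers via ``register_env()``):

.. code-block:: python

    # Sketch only: the same QMixConfig settings built up incrementally.
    import os
    from ray.rllib.agents.qmix import QMixConfig

    config = QMixConfig()
    config.training(mixer="qmix", train_batch_size=32)
    config.rollouts(num_rollout_workers=0, rollout_fragment_length=4)
    config.exploration(exploration_config={"final_epsilon": 0.0})
    config.environment(
        env="grouped_twostep",  # Assumed to be registered by the example script.
        env_config={"separate_state_space": True, "one_hot_state_encoding": True},
    )
    config.resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
    config_dict = config.to_dict()  # Legacy dict form accepted by tune.run().
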
@@ -165,13 +169,6 @@
         "training_iteration": args.stop_iters,
     }
 
-    config = dict(
-        config,
-        **{
-            "env": "grouped_twostep" if group else TwoStepGame,
-        }
-    )
-
     results = tune.run(args.run, stop=stop, config=config, verbose=2)
 
     if args.as_test:
@@ -6,7 +6,7 @@ import tree  # pip install dm_tree
 import unittest
 
 import ray
-from ray.rllib.agents.marwil import BCTrainer
+from ray.rllib.algorithms.marwil import BCTrainer
 from ray.rllib.agents.pg import PGTrainer, DEFAULT_CONFIG
 from ray.rllib.examples.env.random_env import RandomEnv
 from ray.rllib.offline.json_reader import JsonReader
@@ -32,7 +32,7 @@ multi-agent-cartpole-alpha-star:
         # No league-building needed.
         league_builder_config:
-            type: ray.rllib.agents.alpha_star.league_builder.NoLeagueBuilder
+            type: ray.rllib.algorithms.alpha_star.league_builder.NoLeagueBuilder
 
         multiagent:
             policies: ["p0", "p1", "p2", "p3"]