[RLlib] Retry agents -> algorithms. with proper doc changes this time. (#24797)

Jun Gong 2022-05-16 00:45:32 -07:00 committed by GitHub
parent d40fa391a5
commit 68a9a33386
59 changed files with 430 additions and 248 deletions

View file

@@ -466,7 +466,7 @@ HalfCheetah 13000 ~15000
 Model-Agnostic Meta-Learning (MAML)
 -----------------------------------
 |pytorch| |tensorflow|
-`[paper] <https://arxiv.org/abs/1703.03400>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/maml/maml.py>`__
+`[paper] <https://arxiv.org/abs/1703.03400>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/algorithms/maml/maml.py>`__
 RLlib's MAML implementation is a meta-learning method for learning and quick adaptation across different tasks for continuous control. Code here is adapted from https://github.com/jonasrothfuss, which outperforms vanilla MAML and avoids computation of the higher order gradients during the meta-update step. MAML is evaluated on custom environments that are described in greater detail `here <https://github.com/ray-project/ray/blob/master/rllib/env/apis/task_settable_env.py>`__.
@@ -476,7 +476,7 @@ Tuned examples: HalfCheetahRandDirecEnv (`Env <https://github.com/ray-project/ra
 **MAML-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):
-.. literalinclude:: ../../../rllib/agents/maml/maml.py
+.. literalinclude:: ../../../rllib/algorithms/maml/maml.py
    :language: python
    :start-after: __sphinx_doc_begin__
    :end-before: __sphinx_doc_end__
@@ -486,7 +486,7 @@ Tuned examples: HalfCheetahRandDirecEnv (`Env <https://github.com/ray-project/ra
 Model-Based Meta-Policy-Optimization (MB-MPO)
 ---------------------------------------------
 |pytorch|
-`[paper] <https://arxiv.org/pdf/1809.05214.pdf>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/mbmpo/mbmpo.py>`__
+`[paper] <https://arxiv.org/pdf/1809.05214.pdf>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/algorithms/mbmpo/mbmpo.py>`__
 RLlib's MBMPO implementation is a Dyna-styled model-based RL method that learns based on the predictions of an ensemble of transition-dynamics models. Similar to MAML, MBMPO meta-learns an optimal policy by treating each dynamics model as a different task. Code here is adapted from https://github.com/jonasrothfuss/model_ensemble_meta_learning. Similar to the original paper, MBMPO is evaluated on MuJoCo, with the horizon set to 200 instead of the default 1000.
@@ -510,7 +510,7 @@ Hopper 620 ~650
 **MBMPO-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):
-.. literalinclude:: ../../../rllib/agents/mbmpo/mbmpo.py
+.. literalinclude:: ../../../rllib/algorithms/mbmpo/mbmpo.py
    :language: python
    :start-after: __sphinx_doc_begin__
    :end-before: __sphinx_doc_end__
@@ -539,7 +539,7 @@ Cheetah-Run 640 ~800
 **Dreamer-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):
-.. literalinclude:: ../../../rllib/agents/dreamer/dreamer.py
+.. literalinclude:: ../../../rllib/algorithms/dreamer/dreamer.py
    :language: python
    :start-after: __sphinx_doc_begin__
    :end-before: __sphinx_doc_end__
@@ -567,7 +567,7 @@ RecSim environment wrapper: `Google RecSim <https://github.com/ray-project/ray/b
 Conservative Q-Learning (CQL)
 -----------------------------------
 |pytorch| |tensorflow|
-`[paper] <https://arxiv.org/abs/2006.04779>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/cql/cql.py>`__
+`[paper] <https://arxiv.org/abs/2006.04779>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/algorithms/cql/cql.py>`__
 In offline RL, the algorithm has no access to an environment, but can only sample from a fixed dataset of pre-collected state-action-reward tuples.
 In particular, CQL (Conservative Q-Learning) is an offline RL algorithm that mitigates the overestimation of Q-values outside the dataset distribution via
@@ -581,7 +581,7 @@ Tuned examples: `HalfCheetah Random <https://github.com/ray-project/ray/blob/mas
 **CQL-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):
-.. literalinclude:: ../../../rllib/agents/cql/cql.py
+.. literalinclude:: ../../../rllib/algorithms/cql/cql.py
    :language: python
    :start-after: __sphinx_doc_begin__
    :end-before: __sphinx_doc_end__
@@ -636,7 +636,7 @@ Monotonic Advantage Re-Weighted Imitation Learning (MARWIL)
 -----------------------------------------------------------
 |pytorch| |tensorflow|
 `[paper] <http://papers.nips.cc/paper/7866-exponentially-weighted-imitation-learning-for-batched-historical-data>`__
-`[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/marwil/marwil.py>`__
+`[implementation] <https://github.com/ray-project/ray/blob/master/rllib/algorithms/marwil/marwil.py>`__
 MARWIL is a hybrid imitation learning and policy gradient algorithm suitable for training on batched historical data.
 When the ``beta`` hyperparameter is set to zero, the MARWIL objective reduces to vanilla imitation learning (see `BC`_).
@@ -646,7 +646,7 @@ Tuned examples: `CartPole-v0 <https://github.com/ray-project/ray/blob/master/rll
 **MARWIL-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):
-.. literalinclude:: ../../../rllib/agents/marwil/marwil.py
+.. literalinclude:: ../../../rllib/algorithms/marwil/marwil.py
    :language: python
    :start-after: __sphinx_doc_begin__
    :end-before: __sphinx_doc_end__
@@ -658,7 +658,7 @@ Behavior Cloning (BC; derived from MARWIL implementation)
 ---------------------------------------------------------
 |pytorch| |tensorflow|
 `[paper] <http://papers.nips.cc/paper/7866-exponentially-weighted-imitation-learning-for-batched-historical-data>`__
-`[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/marwil/bc.py>`__
+`[implementation] <https://github.com/ray-project/ray/blob/master/rllib/algorithms/marwil/bc.py>`__
 Our behavioral cloning implementation is directly derived from our `MARWIL`_ implementation,
 with the only difference being the ``beta`` parameter force-set to 0.0. This makes
@@ -669,7 +669,7 @@ Tuned examples: `CartPole-v0 <https://github.com/ray-project/ray/blob/master/rll
 **BC-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):
-.. literalinclude:: ../../../rllib/agents/marwil/bc.py
+.. literalinclude:: ../../../rllib/algorithms/marwil/bc.py
    :language: python
    :start-after: __sphinx_doc_begin__
    :end-before: __sphinx_doc_end__
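The documentation changes above only re-point the `[implementation]` links and `literalinclude` paths from `rllib/agents/...` to `rllib/algorithms/...`. For user code this means the canonical import location moves as well; a minimal, hedged sketch (assuming a Ray build that contains this commit; the old `ray.rllib.agents.*` paths keep working through the compatibility shims further down, they just emit deprecation warnings):

# Old, now-deprecated locations:
#   from ray.rllib.agents.maml import MAMLTrainer
#   from ray.rllib.agents.marwil import BCTrainer
# New package layout referenced by the docs above:
from ray.rllib.algorithms.maml import MAMLTrainer  # noqa: F401
from ray.rllib.algorithms.marwil import BCTrainer, MARWILTrainer  # noqa: F401
from ray.rllib.algorithms.cql import CQLTrainer  # noqa: F401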

View file

@@ -647,7 +647,7 @@ py_test(
     name = "test_alpha_star",
     tags = ["team:ml", "trainers_dir"],
     size = "large",
-    srcs = ["agents/alpha_star/tests/test_alpha_star.py"]
+    srcs = ["algorithms/alpha_star/tests/test_alpha_star.py"]
 )
 # APEXTrainer (DQN)
@@ -687,7 +687,7 @@ py_test(
     name = "test_cql",
     tags = ["team:ml", "trainers_dir"],
     size = "medium",
-    srcs = ["agents/cql/tests/test_cql.py"]
+    srcs = ["algorithms/cql/tests/test_cql.py"]
 )
 # DDPGTrainer
@@ -711,7 +711,7 @@ py_test(
     name = "test_dreamer",
     tags = ["team:ml", "trainers_dir"],
     size = "small",
-    srcs = ["agents/dreamer/tests/test_dreamer.py"]
+    srcs = ["algorithms/dreamer/tests/test_dreamer.py"]
 )
 # ES
@@ -743,7 +743,7 @@ py_test(
     size = "large",
     # Include the json data file.
     data = ["tests/data/cartpole/large.json"],
-    srcs = ["agents/marwil/tests/test_marwil.py"]
+    srcs = ["algorithms/marwil/tests/test_marwil.py"]
 )
 # BCTrainer (sub-type of MARWIL)
@@ -753,7 +753,7 @@ py_test(
     size = "large",
     # Include the json data file.
     data = ["tests/data/cartpole/large.json"],
-    srcs = ["agents/marwil/tests/test_bc.py"]
+    srcs = ["algorithms/marwil/tests/test_bc.py"]
 )
 # MAMLTrainer
@@ -761,7 +761,7 @@ py_test(
     name = "test_maml",
     tags = ["team:ml", "trainers_dir"],
     size = "medium",
-    srcs = ["agents/maml/tests/test_maml.py"]
+    srcs = ["algorithms/maml/tests/test_maml.py"]
 )
 # MBMPOTrainer
@@ -769,7 +769,7 @@ py_test(
     name = "test_mbmpo",
     tags = ["team:ml", "trainers_dir"],
     size = "medium",
-    srcs = ["agents/mbmpo/tests/test_mbmpo.py"]
+    srcs = ["algorithms/mbmpo/tests/test_mbmpo.py"]
 )
 # PGTrainer

View file

@@ -1,4 +1,4 @@
-from ray.rllib.agents.alpha_star.alpha_star import (
+from ray.rllib.algorithms.alpha_star.alpha_star import (
     AlphaStarConfig,
     AlphaStarTrainer,
     DEFAULT_CONFIG,
@@ -9,3 +9,10 @@ __all__ = [
     "AlphaStarTrainer",
     "DEFAULT_CONFIG",
 ]
+
+from ray.rllib.utils.deprecation import deprecation_warning
+
+deprecation_warning(
+    "ray.rllib.agents.alpha_star", "ray.rllib.algorithms.alpha_star", error=False
+)

View file

@@ -1,8 +1,13 @@
-from ray.rllib.agents.cql.cql import CQLTrainer, CQL_DEFAULT_CONFIG
-from ray.rllib.agents.cql.cql_torch_policy import CQLTorchPolicy
+from ray.rllib.algorithms.cql.cql import CQLTrainer, CQL_DEFAULT_CONFIG
+from ray.rllib.algorithms.cql.cql_tf_policy import CQLTFPolicy
+from ray.rllib.algorithms.cql.cql_torch_policy import CQLTorchPolicy
+from ray.rllib.utils.deprecation import deprecation_warning
 
 __all__ = [
     "CQL_DEFAULT_CONFIG",
+    "CQLTFPolicy",
     "CQLTorchPolicy",
     "CQLTrainer",
 ]
+
+deprecation_warning("ray.rllib.agents.cql", "ray.rllib.algorithms.cql", error=False)
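Each of the moved packages keeps a shim `__init__.py` like this one under the old `ray.rllib.agents.<algo>` location: it re-exports the public names from the new `ray.rllib.algorithms.<algo>` package and calls `deprecation_warning(..., error=False)` at import time. A hedged, user-side sketch of the resulting behavior (assuming a Ray build that includes this commit):

import ray.rllib.agents.cql as old_cql        # still works, logs a deprecation warning
import ray.rllib.algorithms.cql as new_cql    # canonical location going forward

# The shim only re-exports, so both modules expose the very same objects.
assert old_cql.CQLTrainer is new_cql.CQLTrainer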

View file

@@ -100,9 +100,9 @@ class SimpleQConfig(TrainerConfig):
         >>> .exploration(exploration_config=explore_config)
     """
 
-    def __init__(self):
+    def __init__(self, trainer_class=None):
         """Initializes a SimpleQConfig instance."""
-        super().__init__(trainer_class=SimpleQTrainer)
+        super().__init__(trainer_class=trainer_class or SimpleQTrainer)
 
         # Simple Q specific
         # fmt: off
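Allowing `trainer_class` to be passed through `SimpleQConfig.__init__` is what lets derived config classes (such as the new `QMixConfig` further down) reuse SimpleQ's defaults while binding their own Trainer class. A small illustrative sketch with a hypothetical subclass name, not part of this commit:

from ray.rllib.agents.dqn.simple_q import SimpleQConfig, SimpleQTrainer

class MySimpleQVariantConfig(SimpleQConfig):  # hypothetical subclass, for illustration only
    def __init__(self, trainer_class=None):
        # Falls back to SimpleQTrainer when no more specific Trainer is supplied.
        super().__init__(trainer_class=trainer_class or SimpleQTrainer)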

View file

@@ -1,4 +1,4 @@
-from ray.rllib.agents.dreamer.dreamer import (
+from ray.rllib.algorithms.dreamer.dreamer import (
     DREAMERConfig,
     DREAMERTrainer,
     DEFAULT_CONFIG,
@@ -9,3 +9,10 @@ __all__ = [
     "DREAMERTrainer",
     "DEFAULT_CONFIG",
 ]
+
+from ray.rllib.utils.deprecation import deprecation_warning
+
+deprecation_warning(
+    "ray.rllib.agents.dreamer", "ray.rllib.algorithms.dreamer", error=False
+)

View file

@@ -1,6 +1,10 @@
-from ray.rllib.agents.maml.maml import MAMLTrainer, DEFAULT_CONFIG
+from ray.rllib.algorithms.maml.maml import MAMLTrainer, DEFAULT_CONFIG
 
 __all__ = [
     "MAMLTrainer",
     "DEFAULT_CONFIG",
 ]
+
+from ray.rllib.utils.deprecation import deprecation_warning
+
+deprecation_warning("ray.rllib.agents.maml", "ray.rllib.algorithms.maml", error=False)

View file

@@ -1,7 +1,7 @@
-from ray.rllib.agents.marwil.bc import BCTrainer, BC_DEFAULT_CONFIG
-from ray.rllib.agents.marwil.marwil import MARWILTrainer, DEFAULT_CONFIG
-from ray.rllib.agents.marwil.marwil_tf_policy import MARWILTFPolicy
-from ray.rllib.agents.marwil.marwil_torch_policy import MARWILTorchPolicy
+from ray.rllib.algorithms.marwil.bc import BCTrainer, BC_DEFAULT_CONFIG
+from ray.rllib.algorithms.marwil.marwil import MARWILTrainer, DEFAULT_CONFIG
+from ray.rllib.algorithms.marwil.marwil_tf_policy import MARWILTFPolicy
+from ray.rllib.algorithms.marwil.marwil_torch_policy import MARWILTorchPolicy
 
 __all__ = [
     "BCTrainer",
@@ -11,3 +11,10 @@ __all__ = [
     "MARWILTorchPolicy",
     "MARWILTrainer",
 ]
+
+from ray.rllib.utils.deprecation import deprecation_warning
+
+deprecation_warning(
+    "ray.rllib.agents.marwil", "ray.rllib.algorithms.marwil", error=False
+)

View file

@@ -1,6 +1,11 @@
-from ray.rllib.agents.mbmpo.mbmpo import MBMPOTrainer, DEFAULT_CONFIG
+from ray.rllib.algorithms.mbmpo.mbmpo import MBMPOTrainer, DEFAULT_CONFIG
 
 __all__ = [
     "MBMPOTrainer",
     "DEFAULT_CONFIG",
 ]
+
+from ray.rllib.utils.deprecation import deprecation_warning
+
+deprecation_warning("ray.rllib.agents.mbmpo", "ray.rllib.algorithms.mbmpo", error=False)

View file

@@ -1,3 +1,3 @@
-from ray.rllib.agents.qmix.qmix import QMixTrainer, DEFAULT_CONFIG
+from ray.rllib.agents.qmix.qmix import QMixConfig, QMixTrainer, DEFAULT_CONFIG
 
-__all__ = ["QMixTrainer", "DEFAULT_CONFIG"]
+__all__ = ["QMixConfig", "QMixTrainer", "DEFAULT_CONFIG"]

View file

@@ -1,7 +1,6 @@
-from typing import Type
+from typing import Optional, Type
 
-from ray.rllib.agents.trainer import with_common_config
-from ray.rllib.agents.dqn.simple_q import SimpleQTrainer
+from ray.rllib.agents.dqn.simple_q import SimpleQConfig, SimpleQTrainer
 from ray.rllib.agents.qmix.qmix_policy import QMixTorchPolicy
 from ray.rllib.execution.rollout_ops import (
     synchronous_parallel_sample,
@@ -12,7 +11,7 @@ from ray.rllib.execution.train_ops import (
 )
 from ray.rllib.policy.policy import Policy
 from ray.rllib.utils.annotations import override
-from ray.rllib.utils.deprecation import DEPRECATED_VALUE
+from ray.rllib.utils.deprecation import DEPRECATED_VALUE, Deprecated
 from ray.rllib.utils.metrics import (
     LAST_TARGET_UPDATE_TS,
     NUM_AGENT_STEPS_SAMPLED,
@@ -23,118 +22,174 @@ from ray.rllib.utils.metrics import (
 from ray.rllib.utils.replay_buffers.utils import sample_min_n_steps_from_buffer
 from ray.rllib.utils.typing import ResultDict, TrainerConfigDict
 
-# fmt: off
-# __sphinx_doc_begin__
-DEFAULT_CONFIG = with_common_config({
-    # === QMix ===
-    # Mixing network. Either "qmix", "vdn", or None
-    "mixer": "qmix",
-    # Size of the mixing network embedding
-    "mixing_embed_dim": 32,
-    # Whether to use Double_Q learning
-    "double_q": True,
-    # Optimize over complete episodes by default.
-    "batch_mode": "complete_episodes",
-
-    # === Exploration Settings ===
-    "exploration_config": {
-        # The Exploration class to use.
-        "type": "EpsilonGreedy",
-        # Config for the Exploration class' constructor:
-        "initial_epsilon": 1.0,
-        "final_epsilon": 0.01,
-        # Timesteps over which to anneal epsilon.
-        "epsilon_timesteps": 40000,
-
-        # For soft_q, use:
-        # "exploration_config" = {
-        #   "type": "SoftQ"
-        #   "temperature": [float, e.g. 1.0]
-        # }
-    },
-
-    # === Evaluation ===
-    # Evaluate with epsilon=0 every `evaluation_interval` training iterations.
-    # The evaluation stats will be reported under the "evaluation" metric key.
-    # Note that evaluation is currently not parallelized, and that for Ape-X
-    # metrics are already only reported for the lowest epsilon workers.
-    "evaluation_interval": None,
-    # Number of episodes to run per evaluation period.
-    "evaluation_duration": 10,
-    # Switch to greedy actions in evaluation workers.
-    "evaluation_config": {
-        "explore": False,
-    },
-
-    # Minimum env sampling timesteps to accumulate within a single `train()` call. This
-    # value does not affect learning, only the number of times `Trainer.step_attempt()`
-    # is called by `Trauber.train()`. If - after one `step_attempt()`, the env sampling
-    # timestep count has not been reached, will perform n more `step_attempt()` calls
-    # until the minimum timesteps have been executed. Set to 0 for no minimum timesteps.
-    "min_sample_timesteps_per_reporting": 1000,
-    # Update the target network every `target_network_update_freq` steps.
-    "target_network_update_freq": 500,
-
-    # === Replay buffer ===
-    "replay_buffer_config": {
-        # Use the new ReplayBuffer API here
-        "_enable_replay_buffer_api": True,
-        "type": "SimpleReplayBuffer",
-        # Size of the replay buffer in batches (not timesteps!).
-        "capacity": 1000,
-        "learning_starts": 1000,
-    },
-
-    # === Optimization ===
-    # Learning rate for RMSProp optimizer
-    "lr": 0.0005,
-    # RMSProp alpha
-    "optim_alpha": 0.99,
-    # RMSProp epsilon
-    "optim_eps": 0.00001,
-    # If not None, clip gradients during optimization at this value
-    "grad_norm_clipping": 10,
-    # Update the replay buffer with this many samples at once. Note that
-    # this setting applies per-worker if num_workers > 1.
-    "rollout_fragment_length": 4,
-    # Minimum batch size used for training (in timesteps). With the default buffer
-    # (ReplayBuffer) this means, sampling from the buffer (entire-episode SampleBatches)
-    # as many times as is required to reach at least this number of timesteps.
-    "train_batch_size": 32,
-
-    # === Parallelism ===
-    # Number of workers for collecting samples with. This only makes sense
-    # to increase if your environment is particularly slow to sample, or if
-    # you"re using the Async or Ape-X optimizers.
-    "num_workers": 0,
-    # Whether to compute priorities on workers.
-    "worker_side_prioritization": False,
-    # Prevent reporting frequency from going lower than this time span.
-    "min_time_s_per_reporting": 1,
-
-    # === Model ===
-    "model": {
-        "lstm_cell_size": 64,
-        "max_seq_len": 999999,
-    },
-    # Only torch supported so far.
-    "framework": "torch",
-
-    # Deprecated keys:
-    # Use `replay_buffer_config.learning_starts` instead.
-    "learning_starts": DEPRECATED_VALUE,
-    # Use `replay_buffer_config.capacity` instead.
-    "buffer_size": DEPRECATED_VALUE,
-})
-# __sphinx_doc_end__
-# fmt: on
+
+class QMixConfig(SimpleQConfig):
+    """Defines a configuration class from which a QMixTrainer can be built.
+
+    Example:
+        >>> from ray.rllib.examples.env.two_step_game import TwoStepGame
+        >>> from ray.rllib.agents.qmix import QMixConfig
+        >>> config = QMixConfig().training(gamma=0.9, lr=0.01, kl_coeff=0.3)\
+        ...     .resources(num_gpus=0)\
+        ...     .rollouts(num_workers=4)
+        >>> print(config.to_dict())
+        >>> # Build a Trainer object from the config and run 1 training iteration.
+        >>> trainer = config.build(env=TwoStepGame)
+        >>> trainer.train()
+
+    Example:
+        >>> from ray.rllib.examples.env.two_step_game import TwoStepGame
+        >>> from ray.rllib.agents.qmix import QMixConfig
+        >>> from ray import tune
+        >>> config = QMixConfig()
+        >>> # Print out some default values.
+        >>> print(config.optim_alpha)
+        >>> # Update the config object.
+        >>> config.training(lr=tune.grid_search([0.001, 0.0001]), optim_alpha=0.97)
+        >>> # Set the config object's env.
+        >>> config.environment(env=TwoStepGame)
+        >>> # Use to_dict() to get the old-style python config dict
+        >>> # when running with tune.
+        >>> tune.run(
+        ...     "QMix",
+        ...     stop={"episode_reward_mean": 200},
+        ...     config=config.to_dict(),
+        ... )
+    """
+
+    def __init__(self):
+        """Initializes a QMixConfig instance."""
+        super().__init__(trainer_class=QMixTrainer)
+
+        # fmt: off
+        # __sphinx_doc_begin__
+        # QMix specific settings:
+        self.mixer = "qmix"
+        self.mixing_embed_dim = 32
+        self.double_q = True
+        self.target_network_update_freq = 500
+        self.replay_buffer_config = {
+            # Use the new ReplayBuffer API here
+            "_enable_replay_buffer_api": True,
+            "type": "SimpleReplayBuffer",
+            # Size of the replay buffer in batches (not timesteps!).
+            "capacity": 1000,
+            "learning_starts": 1000,
+        }
+        self.optim_alpha = 0.99
+        self.optim_eps = 0.00001
+        self.grad_norm_clipping = 10
+        self.worker_side_prioritization = False
+
+        # Override some of TrainerConfig's default values with QMix-specific values.
+        self.num_workers = 0
+        self.min_time_s_per_reporting = 1
+        self.model = {
+            "lstm_cell_size": 64,
+            "max_seq_len": 999999,
+        }
+        self.framework_str = "torch"
+        self.lr = 0.0005
+        self.rollout_fragment_length = 4
+        self.train_batch_size = 32
+        self.batch_mode = "complete_episodes"
+        self.exploration_config = {
+            # The Exploration class to use.
+            "type": "EpsilonGreedy",
+            # Config for the Exploration class' constructor:
+            "initial_epsilon": 1.0,
+            "final_epsilon": 0.01,
+            # Timesteps over which to anneal epsilon.
+            "epsilon_timesteps": 40000,
+
+            # For soft_q, use:
+            # "exploration_config" = {
+            #   "type": "SoftQ"
+            #   "temperature": [float, e.g. 1.0]
+            # }
+        }
+
+        # Evaluate with epsilon=0 every `evaluation_interval` training iterations.
+        # The evaluation stats will be reported under the "evaluation" metric key.
+        # Note that evaluation is currently not parallelized, and that for Ape-X
+        # metrics are already only reported for the lowest epsilon workers.
+        self.evaluation_interval = None
+        self.evaluation_duration = 10
+        self.evaluation_config = {
+            "explore": False,
+        }
+        self.min_sample_timesteps_per_reporting = 1000
+        # __sphinx_doc_end__
+        # fmt: on
+
+        # Deprecated keys:
+        self.learning_starts = DEPRECATED_VALUE
+        self.buffer_size = DEPRECATED_VALUE
+
+    @override(SimpleQConfig)
+    def training(
+        self,
+        *,
+        mixer: Optional[str] = None,
+        mixing_embed_dim: Optional[int] = None,
+        double_q: Optional[bool] = None,
+        target_network_update_freq: Optional[int] = None,
+        replay_buffer_config: Optional[dict] = None,
+        optim_alpha: Optional[float] = None,
+        optim_eps: Optional[float] = None,
+        grad_norm_clipping: Optional[float] = None,
+        worker_side_prioritization: Optional[bool] = None,
+        **kwargs,
+    ) -> "QMixConfig":
+        """Sets the training related configuration.
+
+        Args:
+            mixer: Mixing network. Either "qmix", "vdn", or None.
+            mixing_embed_dim: Size of the mixing network embedding.
+            double_q: Whether to use Double_Q learning.
+            target_network_update_freq: Update the target network every
+                `target_network_update_freq` sample steps.
+            replay_buffer_config:
+            optim_alpha: RMSProp alpha.
+            optim_eps: RMSProp epsilon.
+            grad_norm_clipping: If not None, clip gradients during optimization at
+                this value.
+            worker_side_prioritization: Whether to compute priorities for the replay
+                buffer on worker side.
+
+        Returns:
+            This updated TrainerConfig object.
+        """
+        # Pass kwargs onto super's `training()` method.
+        super().training(**kwargs)
+
+        if mixer is not None:
+            self.mixer = mixer
+        if mixing_embed_dim is not None:
+            self.mixing_embed_dim = mixing_embed_dim
+        if double_q is not None:
+            self.double_q = double_q
+        if target_network_update_freq is not None:
+            self.target_network_update_freq = target_network_update_freq
+        if replay_buffer_config is not None:
+            self.replay_buffer_config = replay_buffer_config
+        if optim_alpha is not None:
+            self.optim_alpha = optim_alpha
+        if optim_eps is not None:
+            self.optim_eps = optim_eps
+        if grad_norm_clipping is not None:
+            self.grad_norm_clipping = grad_norm_clipping
+        if worker_side_prioritization is not None:
+            self.worker_side_prioritization = worker_side_prioritization
+
+        return self
 
 
 class QMixTrainer(SimpleQTrainer):
     @classmethod
     @override(SimpleQTrainer)
     def get_default_config(cls) -> TrainerConfigDict:
-        return DEFAULT_CONFIG
+        return QMixConfig().to_dict()
 
     @override(SimpleQTrainer)
     def validate_config(self, config: TrainerConfigDict) -> None:
@@ -219,3 +274,20 @@ class QMixTrainer(SimpleQTrainer):
 
         # Return all collected metrics for the iteration.
         return train_results
+
+
+# Deprecated: Use ray.rllib.agents.qmix.qmix.QMixConfig instead!
+class _deprecated_default_config(dict):
+    def __init__(self):
+        super().__init__(QMixConfig().to_dict())
+
+    @Deprecated(
+        old="ray.rllib.agents.qmix.qmix.DEFAULT_CONFIG",
+        new="ray.rllib.agents.qmix.qmix.QMixConfig(...)",
+        error=False,
+    )
+    def __getitem__(self, item):
+        return super().__getitem__(item)
+
+
+DEFAULT_CONFIG = _deprecated_default_config()
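A short usage sketch for the new config class, based on the docstring examples above (hedged: it only builds and reads the config object, it does not start training):

from ray.rllib.agents.qmix import QMixConfig

config = (
    QMixConfig()
    .training(mixer="vdn", train_batch_size=32)
    .rollouts(num_rollout_workers=0, rollout_fragment_length=4)
    .framework("torch")
)
print(config.to_dict()["mixer"])  # -> "vdn"

# The legacy dict is still importable, but item access now goes through the
# deprecation wrapper defined at the end of the file above:
from ray.rllib.agents.qmix.qmix import DEFAULT_CONFIG
print(DEFAULT_CONFIG["mixer"])  # logs a deprecation warning, returns the default "qmix"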

View file

@@ -153,7 +153,6 @@ class QMixLoss(nn.Module):
         return loss, mask, masked_td_error, chosen_action_qvals, targets
 
-# TODO(sven): Make this a TorchPolicy child via `build_policy_class`.
 class QMixTorchPolicy(TorchPolicy):
     """QMix impl. Assumes homogeneous agents for now.
@@ -177,9 +176,6 @@ class QMixTorchPolicy(TorchPolicy):
         self.h_size = config["model"]["lstm_cell_size"]
         self.has_env_global_state = False
         self.has_action_mask = False
-        self.device = (
-            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-        )
 
         agent_obs_space = obs_space.original_space.spaces[0]
         if isinstance(agent_obs_space, gym.spaces.Dict):
@@ -218,7 +214,9 @@ class QMixTorchPolicy(TorchPolicy):
             framework="torch",
             name="model",
             default_model=RNNModel,
-        ).to(self.device)
+        )
+
+        super().__init__(obs_space, action_space, config, model=self.model)
 
         self.target_model = ModelCatalog.get_model_v2(
             agent_obs_space,
@@ -230,8 +228,6 @@ class QMixTorchPolicy(TorchPolicy):
             default_model=RNNModel,
         ).to(self.device)
 
-        super().__init__(obs_space, action_space, config, model=self.model)
-
         self.exploration = self._create_exploration()
 
         # Setup the mixer network.

View file

@@ -4,7 +4,7 @@ import unittest
 import ray
 from ray.tune import register_env
-from ray.rllib.agents.qmix import QMixTrainer
+from ray.rllib.agents.qmix import QMixConfig
 from ray.rllib.env.multi_agent_env import MultiAgentEnv
@@ -95,18 +95,21 @@ class TestQMix(unittest.TestCase):
             ),
         )
 
-        trainer = QMixTrainer(
-            env="action_mask_test",
-            config={
-                "num_envs_per_worker": 5,  # test with vectorization on
-                "env_config": {
-                    "avail_actions": [3, 4, 8],
-                },
-                "framework": "torch",
-            },
-        )
+        config = (
+            QMixConfig()
+            .framework(framework="torch")
+            .environment(
+                env="action_mask_test",
+                env_config={"avail_actions": [3, 4, 8]},
+            )
+            .rollouts(num_envs_per_worker=5)
+        )  # Test with vectorization on.
+        trainer = config.build()
+
         for _ in range(4):
             trainer.train()  # OK if it doesn't trip the action assertion error
         assert trainer.train()["episode_reward_mean"] == 30.0
         trainer.stop()
         ray.shutdown()

View file

@@ -18,7 +18,10 @@ def _import_a3c():
 def _import_alpha_star():
-    from ray.rllib.agents.alpha_star.alpha_star import AlphaStarTrainer, DEFAULT_CONFIG
+    from ray.rllib.algorithms.alpha_star.alpha_star import (
+        AlphaStarTrainer,
+        DEFAULT_CONFIG,
+    )
 
     return AlphaStarTrainer, DEFAULT_CONFIG
@@ -42,7 +45,7 @@ def _import_appo():
 def _import_ars():
-    from ray.rllib.agents import ars
+    from ray.rllib.algorithms import ars
 
     return ars.ARSTrainer, ars.DEFAULT_CONFIG
@@ -60,13 +63,13 @@ def _import_bandit_linucb():
 def _import_bc():
-    from ray.rllib.agents import marwil
+    from ray.rllib.algorithms import marwil
 
     return marwil.BCTrainer, marwil.DEFAULT_CONFIG
 
 
 def _import_cql():
-    from ray.rllib.agents import cql
+    from ray.rllib.algorithms import cql
 
     return cql.CQLTrainer, cql.CQL_DEFAULT_CONFIG
@@ -90,13 +93,13 @@ def _import_dqn():
 def _import_dreamer():
-    from ray.rllib.agents import dreamer
+    from ray.rllib.algorithms import dreamer
 
     return dreamer.DREAMERTrainer, dreamer.DEFAULT_CONFIG
 
 
 def _import_es():
-    from ray.rllib.agents import es
+    from ray.rllib.algorithms import es
 
     return es.ESTrainer, es.DEFAULT_CONFIG
@@ -114,19 +117,19 @@ def _import_maddpg():
 def _import_maml():
-    from ray.rllib.agents import maml
+    from ray.rllib.algorithms import maml
 
     return maml.MAMLTrainer, maml.DEFAULT_CONFIG
 
 
 def _import_marwil():
-    from ray.rllib.agents import marwil
+    from ray.rllib.algorithms import marwil
 
     return marwil.MARWILTrainer, marwil.DEFAULT_CONFIG
 
 
 def _import_mbmpo():
-    from ray.rllib.agents import mbmpo
+    from ray.rllib.algorithms import mbmpo
 
     return mbmpo.MBMPOTrainer, mbmpo.DEFAULT_CONFIG
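These `_import_*` helpers keep the trainer registry lazy, so importing RLlib does not pull in every algorithm. A hedged sketch of how the registry is typically consumed; the `get_trainer_class` helper is assumed to live in this same `ray.rllib.agents.registry` module (it is not shown in the diff):

from ray.rllib.agents.registry import get_trainer_class

trainer_cls = get_trainer_class("MAML")  # resolved via _import_maml() above
print(trainer_cls.__name__)  # -> "MAMLTrainer"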

View file

@@ -10,7 +10,7 @@ import unittest
 import ray
 import ray.rllib.agents.a3c as a3c
 import ray.rllib.agents.dqn as dqn
-from ray.rllib.agents.marwil import BCTrainer
+from ray.rllib.algorithms.marwil import BCTrainer
 import ray.rllib.agents.pg as pg
 from ray.rllib.agents.trainer import COMMON_CONFIG
 from ray.rllib.examples.env.multi_agent import MultiAgentCartPole

View file

@@ -0,0 +1,11 @@
+from ray.rllib.algorithms.alpha_star.alpha_star import (
+    AlphaStarConfig,
+    AlphaStarTrainer,
+    DEFAULT_CONFIG,
+)
+
+__all__ = [
+    "AlphaStarConfig",
+    "AlphaStarTrainer",
+    "DEFAULT_CONFIG",
+]

View file

@@ -9,8 +9,8 @@ from typing import Any, Dict, Optional, Type
 import ray
 from ray.actor import ActorHandle
-from ray.rllib.agents.alpha_star.distributed_learners import DistributedLearners
-from ray.rllib.agents.alpha_star.league_builder import AlphaStarLeagueBuilder
+from ray.rllib.algorithms.alpha_star.distributed_learners import DistributedLearners
+from ray.rllib.algorithms.alpha_star.league_builder import AlphaStarLeagueBuilder
 from ray.rllib.agents.trainer import Trainer
 import ray.rllib.agents.ppo.appo as appo
 from ray.rllib.evaluation.rollout_worker import RolloutWorker
@@ -48,7 +48,7 @@ class AlphaStarConfig(appo.APPOConfig):
     """Defines a configuration class from which an AlphaStarTrainer can be built.
 
     Example:
-        >>> from ray.rllib.agents.alpha_star import AlphaStarConfig
+        >>> from ray.rllib.algorithms.alpha_star import AlphaStarConfig
         >>> config = AlphaStarConfig().training(lr=0.0003, train_batch_size=512)\
         ...     .resources(num_gpus=4)\
        ...     .rollouts(num_rollout_workers=64)
@@ -58,7 +58,7 @@ class AlphaStarConfig(appo.APPOConfig):
         >>> trainer.train()
 
     Example:
-        >>> from ray.rllib.agents.alpha_star import AlphaStarConfig
+        >>> from ray.rllib.algorithms.alpha_star import AlphaStarConfig
         >>> from ray import tune
         >>> config = AlphaStarConfig()
         >>> # Print out some default values.
@@ -168,7 +168,7 @@ class AlphaStarConfig(appo.APPOConfig):
                 to be used for league building logic. All other keys (that are not
                 `type`) will be used as constructor kwargs on the given class to
                 construct the LeagueBuilder instance. See the
-                `ray.rllib.agents.alpha_star.league_builder::AlphaStarLeagueBuilder`
+                `ray.rllib.algorithms.alpha_star.league_builder::AlphaStarLeagueBuilder`
                 (used by default by this algo) as an example.
             max_num_policies_to_train: The maximum number of trainable policies for this
                 Trainer. Each trainable policy will exist as a independent remote actor,
@@ -584,8 +584,8 @@ class _deprecated_default_config(dict):
         super().__init__(AlphaStarConfig().to_dict())
 
     @Deprecated(
-        old="ray.rllib.agents.alpha_star.alpha_star.DEFAULT_CONFIG",
-        new="ray.rllib.agents.alpha_star.alpha_star.AlphaStarConfig(...)",
+        old="ray.rllib.algorithms.alpha_star.alpha_star.DEFAULT_CONFIG",
+        new="ray.rllib.algorithms.alpha_star.alpha_star.AlphaStarConfig(...)",
         error=False,
     )
     def __getitem__(self, item):

View file

@@ -2,7 +2,7 @@ import pyspiel
 import unittest
 
 import ray
-import ray.rllib.agents.alpha_star as alpha_star
+import ray.rllib.algorithms.alpha_star as alpha_star
 from ray.rllib.env.wrappers.open_spiel import OpenSpielEnv
 from ray.rllib.utils.test_utils import (
     check_compute_single_action,

View file

@@ -0,0 +1,8 @@
+from ray.rllib.algorithms.cql.cql import CQLTrainer, CQL_DEFAULT_CONFIG
+from ray.rllib.algorithms.cql.cql_torch_policy import CQLTorchPolicy
+
+__all__ = [
+    "CQL_DEFAULT_CONFIG",
+    "CQLTorchPolicy",
+    "CQLTrainer",
+]

View file

@@ -2,8 +2,8 @@ import logging
 import numpy as np
 from typing import Type
 
-from ray.rllib.agents.cql.cql_tf_policy import CQLTFPolicy
-from ray.rllib.agents.cql.cql_torch_policy import CQLTorchPolicy
+from ray.rllib.algorithms.cql.cql_tf_policy import CQLTFPolicy
+from ray.rllib.algorithms.cql.cql_torch_policy import CQLTorchPolicy
 from ray.rllib.agents.sac.sac import SACTrainer, DEFAULT_CONFIG as SAC_CONFIG
 from ray.rllib.execution.train_ops import (
     multi_gpu_train_one_step,

View file

@@ -411,7 +411,7 @@ def apply_gradients_fn(policy, optimizer, grads_and_vars):
 CQLTFPolicy = build_tf_policy(
     name="CQLTFPolicy",
     loss_fn=cql_loss,
-    get_default_config=lambda: ray.rllib.agents.cql.cql.CQL_DEFAULT_CONFIG,
+    get_default_config=lambda: ray.rllib.algorithms.cql.cql.CQL_DEFAULT_CONFIG,
     validate_spaces=validate_spaces,
     stats_fn=cql_stats,
     postprocess_fn=postprocess_trajectory,

View file

@@ -387,7 +387,7 @@ CQLTorchPolicy = build_policy_class(
     name="CQLTorchPolicy",
     framework="torch",
     loss_fn=cql_loss,
-    get_default_config=lambda: ray.rllib.agents.cql.cql.CQL_DEFAULT_CONFIG,
+    get_default_config=lambda: ray.rllib.algorithms.cql.cql.CQL_DEFAULT_CONFIG,
     stats_fn=cql_stats,
     postprocess_fn=postprocess_trajectory,
     extra_grad_process_fn=apply_grad_clipping,

View file

@@ -4,7 +4,7 @@ import os
 import unittest
 
 import ray
-import ray.rllib.agents.cql as cql
+import ray.rllib.algorithms.cql as cql
 from ray.rllib.utils.framework import try_import_tf, try_import_torch
 from ray.rllib.utils.test_utils import (
     check_compute_single_action,

View file

@@ -0,0 +1,11 @@
+from ray.rllib.algorithms.dreamer.dreamer import (
+    DREAMERConfig,
+    DREAMERTrainer,
+    DEFAULT_CONFIG,
+)
+
+__all__ = [
+    "DREAMERConfig",
+    "DREAMERTrainer",
+    "DEFAULT_CONFIG",
+]

View file

@@ -4,12 +4,12 @@ import random
 from typing import Optional
 
 from ray.rllib.agents.trainer_config import TrainerConfig
-from ray.rllib.agents.dreamer.dreamer_torch_policy import DreamerTorchPolicy
+from ray.rllib.algorithms.dreamer.dreamer_torch_policy import DreamerTorchPolicy
 from ray.rllib.agents.trainer import Trainer
 from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, _get_shared_metrics
 from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
 from ray.rllib.evaluation.metrics import collect_metrics
-from ray.rllib.agents.dreamer.dreamer_model import DreamerModel
+from ray.rllib.algorithms.dreamer.dreamer_model import DreamerModel
 from ray.rllib.execution.rollout_ops import (
     ParallelRollouts,
     synchronous_parallel_sample,
@@ -31,7 +31,7 @@ class DREAMERConfig(TrainerConfig):
     """Defines a PPOTrainer configuration class from which a PPOTrainer can be built.
 
     Example:
-        >>> from ray.rllib.agents.dreamer import DREAMERConfig
+        >>> from ray.rllib.algorithms.dreamer import DREAMERConfig
         >>> config = DREAMERConfig().training(gamma=0.9, lr=0.01)\
         ...     .resources(num_gpus=0)\
         ...     .rollouts(num_rollout_workers=4)
@@ -42,7 +42,7 @@ class DREAMERConfig(TrainerConfig):
     Example:
         >>> from ray import tune
-        >>> from ray.rllib.agents.dreamer import DREAMERConfig
+        >>> from ray.rllib.algorithms.dreamer import DREAMERConfig
         >>> config = DREAMERConfig()
         >>> # Print out some default values.
         >>> print(config.clip_param)
@@ -418,14 +418,14 @@ class DREAMERTrainer(Trainer):
         return results
 
-# Deprecated: Use ray.rllib.agents.dreamer.DREAMERConfig instead!
+# Deprecated: Use ray.rllib.algorithms.dreamer.DREAMERConfig instead!
 class _deprecated_default_config(dict):
     def __init__(self):
         super().__init__(DREAMERConfig().to_dict())
 
     @Deprecated(
-        old="ray.rllib.agents.dreamer.dreamer.DEFAULT_CONFIG",
-        new="ray.rllib.agents.dreamer.dreamer.DREAMERConfig(...)",
+        old="ray.rllib.algorithms.dreamer.dreamer.DEFAULT_CONFIG",
+        new="ray.rllib.algorithms.dreamer.dreamer.DREAMERConfig(...)",
         error=False,
     )
     def __getitem__(self, item):

View file

@@ -8,7 +8,7 @@ from ray.rllib.utils.framework import TensorType
 torch, nn = try_import_torch()
 if torch:
     from torch import distributions as td
-from ray.rllib.agents.dreamer.utils import (
+from ray.rllib.algorithms.dreamer.utils import (
     Linear,
     Conv2d,
     ConvTranspose2d,

View file

@@ -4,7 +4,7 @@ import numpy as np
 from typing import Dict, Optional
 
 import ray
-from ray.rllib.agents.dreamer.utils import FreezeParameters
+from ray.rllib.algorithms.dreamer.utils import FreezeParameters
 from ray.rllib.evaluation.episode import Episode
 from ray.rllib.models.catalog import ModelCatalog
 from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
@@ -284,7 +284,7 @@ def preprocess_episode(
 DreamerTorchPolicy = build_policy_class(
     name="DreamerTorchPolicy",
     framework="torch",
-    get_default_config=lambda: ray.rllib.agents.dreamer.dreamer.DEFAULT_CONFIG,
+    get_default_config=lambda: ray.rllib.algorithms.dreamer.dreamer.DEFAULT_CONFIG,
     action_sampler_fn=action_sampler_fn,
     postprocess_fn=preprocess_episode,
     loss_fn=dreamer_loss,

View file

@@ -2,7 +2,7 @@ from gym.spaces import Box
 import unittest
 
 import ray
-import ray.rllib.agents.dreamer as dreamer
+import ray.rllib.algorithms.dreamer as dreamer
 from ray.rllib.examples.env.random_env import RandomEnv
 from ray.rllib.utils.test_utils import framework_iterator

View file

@@ -0,0 +1,6 @@
+from ray.rllib.algorithms.maml.maml import MAMLTrainer, DEFAULT_CONFIG
+
+__all__ = [
+    "MAMLTrainer",
+    "DEFAULT_CONFIG",
+]

View file

@@ -4,8 +4,8 @@ from typing import Type
 from ray.rllib.utils.sgd import standardized
 from ray.rllib.agents import with_common_config
-from ray.rllib.agents.maml.maml_tf_policy import MAMLTFPolicy
-from ray.rllib.agents.maml.maml_torch_policy import MAMLTorchPolicy
+from ray.rllib.algorithms.maml.maml_tf_policy import MAMLTFPolicy
+from ray.rllib.algorithms.maml.maml_torch_policy import MAMLTorchPolicy
 from ray.rllib.agents.trainer import Trainer
 from ray.rllib.evaluation.metrics import get_learner_stats
 from ray.rllib.evaluation.worker_set import WorkerSet

View file

@@ -442,7 +442,7 @@ def setup_mixins(policy, obs_space, action_space, config):
 MAMLTFPolicy = build_tf_policy(
     name="MAMLTFPolicy",
-    get_default_config=lambda: ray.rllib.agents.maml.maml.DEFAULT_CONFIG,
+    get_default_config=lambda: ray.rllib.algorithms.maml.maml.DEFAULT_CONFIG,
     loss_fn=maml_loss,
     stats_fn=maml_stats,
     optimizer_fn=maml_optimizer_fn,

View file

@@ -382,7 +382,7 @@ def setup_mixins(policy, obs_space, action_space, config):
 MAMLTorchPolicy = build_policy_class(
     name="MAMLTorchPolicy",
     framework="torch",
-    get_default_config=lambda: ray.rllib.agents.maml.maml.DEFAULT_CONFIG,
+    get_default_config=lambda: ray.rllib.algorithms.maml.maml.DEFAULT_CONFIG,
     loss_fn=maml_loss,
     stats_fn=maml_stats,
     optimizer_fn=maml_optimizer_fn,

View file

@@ -1,7 +1,7 @@
 import unittest
 
 import ray
-import ray.rllib.agents.maml as maml
+import ray.rllib.algorithms.maml as maml
 from ray.rllib.utils.test_utils import (
     check_compute_single_action,
     check_train_results,

View file

@@ -0,0 +1,13 @@
+from ray.rllib.algorithms.marwil.bc import BCTrainer, BC_DEFAULT_CONFIG
+from ray.rllib.algorithms.marwil.marwil import MARWILTrainer, DEFAULT_CONFIG
+from ray.rllib.algorithms.marwil.marwil_tf_policy import MARWILTFPolicy
+from ray.rllib.algorithms.marwil.marwil_torch_policy import MARWILTorchPolicy
+
+__all__ = [
+    "BCTrainer",
+    "BC_DEFAULT_CONFIG",
+    "DEFAULT_CONFIG",
+    "MARWILTFPolicy",
+    "MARWILTorchPolicy",
+    "MARWILTrainer",
+]

View file

@@ -1,4 +1,4 @@
-from ray.rllib.agents.marwil.marwil import (
+from ray.rllib.algorithms.marwil.marwil import (
     MARWILTrainer,
     DEFAULT_CONFIG as MARWIL_CONFIG,
 )

View file

@@ -1,7 +1,7 @@
 from typing import Type
 
 from ray.rllib.agents.trainer import Trainer, with_common_config
-from ray.rllib.agents.marwil.marwil_tf_policy import MARWILTFPolicy
+from ray.rllib.algorithms.marwil.marwil_tf_policy import MARWILTFPolicy
 from ray.rllib.execution.buffers.multi_agent_replay_buffer import MultiAgentReplayBuffer
 from ray.rllib.execution.rollout_ops import (
     synchronous_parallel_sample,
@@ -110,7 +110,9 @@ class MARWILTrainer(Trainer):
     @override(Trainer)
     def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
         if config["framework"] == "torch":
-            from ray.rllib.agents.marwil.marwil_torch_policy import MARWILTorchPolicy
+            from ray.rllib.algorithms.marwil.marwil_torch_policy import (
+                MARWILTorchPolicy,
+            )
 
             return MARWILTorchPolicy
         else:

View file

@@ -232,7 +232,7 @@ def setup_mixins(
 MARWILTFPolicy = build_tf_policy(
     name="MARWILTFPolicy",
-    get_default_config=lambda: ray.rllib.agents.marwil.marwil.DEFAULT_CONFIG,
+    get_default_config=lambda: ray.rllib.algorithms.marwil.marwil.DEFAULT_CONFIG,
     loss_fn=marwil_loss,
     stats_fn=stats,
     postprocess_fn=postprocess_advantages,

View file

@@ -3,7 +3,7 @@ from typing import Dict
 
 import ray
 from ray.rllib.agents.a3c.a3c_torch_policy import ValueNetworkMixin
-from ray.rllib.agents.marwil.marwil_tf_policy import postprocess_advantages
+from ray.rllib.algorithms.marwil.marwil_tf_policy import postprocess_advantages
 from ray.rllib.evaluation.postprocessing import Postprocessing
 from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
@@ -113,7 +113,7 @@ MARWILTorchPolicy = build_policy_class(
     name="MARWILTorchPolicy",
     framework="torch",
     loss_fn=marwil_loss,
-    get_default_config=lambda: ray.rllib.agents.marwil.marwil.DEFAULT_CONFIG,
+    get_default_config=lambda: ray.rllib.algorithms.marwil.marwil.DEFAULT_CONFIG,
     stats_fn=stats,
     postprocess_fn=postprocess_advantages,
     extra_grad_process_fn=apply_grad_clipping,

View file

@@ -3,7 +3,7 @@ from pathlib import Path
 import unittest
 
 import ray
-import ray.rllib.agents.marwil as marwil
+import ray.rllib.algorithms.marwil as marwil
 from ray.rllib.utils.framework import try_import_tf
 from ray.rllib.utils.test_utils import (
     check_compute_single_action,

View file

@@ -4,7 +4,7 @@ from pathlib import Path
 import unittest
 
 import ray
-import ray.rllib.agents.marwil as marwil
+import ray.rllib.algorithms.marwil as marwil
 from ray.rllib.evaluation.postprocessing import compute_advantages
 from ray.rllib.offline import JsonReader
 from ray.rllib.utils.framework import try_import_tf, try_import_torch

View file

@@ -0,0 +1,6 @@
+from ray.rllib.algorithms.mbmpo.mbmpo import MBMPOTrainer, DEFAULT_CONFIG
+
+__all__ = [
+    "MBMPOTrainer",
+    "DEFAULT_CONFIG",
+]

View file

@@ -4,9 +4,9 @@ from typing import List, Type
 
 import ray
 from ray.rllib.agents import with_common_config
-from ray.rllib.agents.mbmpo.mbmpo_torch_policy import MBMPOTorchPolicy
-from ray.rllib.agents.mbmpo.model_ensemble import DynamicsEnsembleCustomModel
-from ray.rllib.agents.mbmpo.utils import calculate_gae_advantages, MBMPOExploration
+from ray.rllib.algorithms.mbmpo.mbmpo_torch_policy import MBMPOTorchPolicy
+from ray.rllib.algorithms.mbmpo.model_ensemble import DynamicsEnsembleCustomModel
+from ray.rllib.algorithms.mbmpo.utils import calculate_gae_advantages, MBMPOExploration
 from ray.rllib.agents.trainer import Trainer
 from ray.rllib.env.env_context import EnvContext
 from ray.rllib.env.wrappers.model_vector_env import model_vector_env

View file

@@ -5,7 +5,7 @@ from typing import Tuple, Type
 
 import ray
 from ray.rllib.agents.a3c.a3c_torch_policy import vf_preds_fetches
-from ray.rllib.agents.maml.maml_torch_policy import (
+from ray.rllib.algorithms.maml.maml_torch_policy import (
     setup_mixins,
     maml_loss,
     maml_stats,
@@ -121,7 +121,7 @@ def make_model_and_action_dist(
 MBMPOTorchPolicy = build_policy_class(
     name="MBMPOTorchPolicy",
     framework="torch",
-    get_default_config=lambda: ray.rllib.agents.mbmpo.mbmpo.DEFAULT_CONFIG,
+    get_default_config=lambda: ray.rllib.algorithms.mbmpo.mbmpo.DEFAULT_CONFIG,
     make_model_and_action_dist=make_model_and_action_dist,
     loss_fn=maml_loss,
     stats_fn=maml_stats,

View file

@@ -1,7 +1,7 @@
 import unittest
 
 import ray
-import ray.rllib.agents.mbmpo as mbmpo
+import ray.rllib.algorithms.mbmpo as mbmpo
 from ray.rllib.utils.test_utils import (
     check_compute_single_action,
     check_train_results,

View file

@@ -245,7 +245,6 @@ class MultiAgentEnv(gym.Env):
     # fmt: off
     # __grouping_doc_begin__
-    @ExperimentalAPI
     def with_agent_groups(
         self,
         groups: Dict[str, List[AgentID]],
@@ -265,16 +264,17 @@ class MultiAgentEnv(gym.Env):
         Agent grouping is required to leverage algorithms such as Q-Mix.
 
-        This API is experimental.
-
         Args:
             groups: Mapping from group id to a list of the agent ids
                 of group members. If an agent id is not present in any group
-                value, it will be left ungrouped.
+                value, it will be left ungrouped. The group id becomes a new agent ID
+                in the final environment.
             obs_space: Optional observation space for the grouped
-                env. Must be a tuple space.
+                env. Must be a tuple space. If not provided, will infer this to be a
+                Tuple of n individual agents spaces (n=num agents in a group).
             act_space: Optional action space for the grouped env.
-                Must be a tuple space.
+                Must be a tuple space. If not provided, will infer this to be a Tuple
+                of n individual agents spaces (n=num agents in a group).
 
         Examples:
             >>> from ray.rllib.env.multi_agent_env import MultiAgentEnv

View file

@@ -1,6 +1,9 @@
 from collections import OrderedDict
+import gym
+from typing import Dict, List, Optional
 
 from ray.rllib.env.multi_agent_env import MultiAgentEnv
+from ray.rllib.utils.typing import AgentID
 
 # info key for the individual rewards of an agent, for example:
 # info: {
@@ -27,21 +30,35 @@ class GroupAgentsWrapper(MultiAgentEnv):
     This API is experimental.
     """
 
-    def __init__(self, env, groups, obs_space=None, act_space=None):
-        """Wrap an existing multi-agent env to group agents together.
+    def __init__(
+        self,
+        env: MultiAgentEnv,
+        groups: Dict[str, List[AgentID]],
+        obs_space: Optional[gym.Space] = None,
+        act_space: Optional[gym.Space] = None,
+    ):
+        """Wrap an existing MultiAgentEnv to group agent ID together.
 
-        See MultiAgentEnv.with_agent_groups() for usage info.
+        See `MultiAgentEnv.with_agent_groups()` for more detailed usage info.
 
         Args:
-            env (MultiAgentEnv): env to wrap
-            groups (dict): Grouping spec as documented in MultiAgentEnv.
-            obs_space (Space): Optional observation space for the grouped
-                env. Must be a tuple space.
-            act_space (Space): Optional action space for the grouped env.
-                Must be a tuple space.
+            env: The env to wrap and whose agent IDs to group into new agents.
+            groups: Mapping from group id to a list of the agent ids
+                of group members. If an agent id is not present in any group
+                value, it will be left ungrouped. The group id becomes a new agent ID
+                in the final environment.
+            obs_space: Optional observation space for the grouped
+                env. Must be a tuple space. If not provided, will infer this to be a
+                Tuple of n individual agents spaces (n=num agents in a group).
+            act_space: Optional action space for the grouped env.
+                Must be a tuple space. If not provided, will infer this to be a Tuple
+                of n individual agents spaces (n=num agents in a group).
         """
         super().__init__()
         self.env = env
+        # Inherit wrapped env's `_skip_env_checking` flag.
+        if hasattr(self.env, "_skip_env_checking"):
+            self._skip_env_checking = self.env._skip_env_checking
         self.groups = groups
         self.agent_id_to_group = {}
         for group_id, agent_ids in groups.items():
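In practice the grouping described in these docstrings is applied through `MultiAgentEnv.with_agent_groups()`, which wraps the env in this `GroupAgentsWrapper`. A hedged sketch using RLlib's TwoStepGame example env (the agent IDs and the Discrete(6) observation space follow the standard two-agent example; treat the exact values as assumptions):

from gym.spaces import Discrete, Tuple
from ray.tune import register_env
from ray.rllib.examples.env.two_step_game import TwoStepGame

# Merge both agents (IDs 0 and 1) into a single "group_1" agent whose
# observations and actions are Tuples of the individual agents' spaces.
grouping = {"group_1": [0, 1]}
obs_space = Tuple([Discrete(6), Discrete(6)])
act_space = Tuple([TwoStepGame.action_space, TwoStepGame.action_space])

register_env(
    "grouped_twostep",
    lambda cfg: TwoStepGame(cfg).with_agent_groups(
        grouping, obs_space=obs_space, act_space=act_space
    ),
)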

View file

@@ -20,7 +20,9 @@ import os
 import ray
 from ray import tune
-from ray.rllib.agents.maml.maml_torch_policy import KLCoeffMixin as TorchKLCoeffMixin
+from ray.rllib.algorithms.maml.maml_torch_policy import (
+    KLCoeffMixin as TorchKLCoeffMixin,
+)
 from ray.rllib.agents.ppo.ppo import PPOTrainer
 from ray.rllib.agents.ppo.ppo_tf_policy import (
     PPOTFPolicy,

View file

@@ -16,6 +16,7 @@ import os
 import ray
 from ray import tune
 from ray.tune import register_env
+from ray.rllib.agents.qmix import QMixConfig
 from ray.rllib.env.multi_agent_env import ENV_STATE
 from ray.rllib.examples.env.two_step_game import TwoStepGame
 from ray.rllib.policy.policy import PolicySpec
@@ -110,10 +111,11 @@ if __name__ == "__main__":
         obs_space = Discrete(6)
         act_space = TwoStepGame.action_space
         config = {
-            "learning_starts": 100,
+            "env": TwoStepGame,
             "env_config": {
                 "actions_are_logits": True,
             },
+            "learning_starts": 100,
             "multiagent": {
                 "policies": {
                     "pol1": PolicySpec(
@@ -133,31 +135,33 @@ if __name__ == "__main__":
             # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
         }
-        group = False
     elif args.run == "QMIX":
-        config = {
-            "rollout_fragment_length": 4,
-            "train_batch_size": 32,
-            "exploration_config": {
-                "final_epsilon": 0.0,
-            },
-            "num_workers": 0,
-            "mixer": args.mixer,
-            "env_config": {
-                "separate_state_space": True,
-                "one_hot_state_encoding": True,
-            },
-            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
-            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
-        }
-        group = True
+        config = (
+            QMixConfig()
+            .training(mixer=args.mixer, train_batch_size=32)
+            .rollouts(num_rollout_workers=0, rollout_fragment_length=4)
+            .exploration(
+                exploration_config={
+                    "final_epsilon": 0.0,
+                }
+            )
+            .environment(
+                env="grouped_twostep",
+                env_config={
+                    "separate_state_space": True,
+                    "one_hot_state_encoding": True,
+                },
+            )
+            .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
+        )
+        config = config.to_dict()
     else:
         config = {
+            "env": TwoStepGame,
             # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
             "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
             "framework": args.framework,
         }
-        group = False
 
     stop = {
         "episode_reward_mean": args.stop_reward,
@@ -165,13 +169,6 @@ if __name__ == "__main__":
         "training_iteration": args.stop_iters,
     }
 
-    config = dict(
-        config,
-        **{
-            "env": "grouped_twostep" if group else TwoStepGame,
-        }
-    )
-
     results = tune.run(args.run, stop=stop, config=config, verbose=2)
 
     if args.as_test:

View file

@@ -6,7 +6,7 @@ import tree  # pip install dm_tree
 import unittest
 
 import ray
-from ray.rllib.agents.marwil import BCTrainer
+from ray.rllib.algorithms.marwil import BCTrainer
 from ray.rllib.agents.pg import PGTrainer, DEFAULT_CONFIG
 from ray.rllib.examples.env.random_env import RandomEnv
 from ray.rllib.offline.json_reader import JsonReader

View file

@@ -32,7 +32,7 @@ multi-agent-cartpole-alpha-star:
     # No league-building needed.
     league_builder_config:
-      type: ray.rllib.agents.alpha_star.league_builder.NoLeagueBuilder
+      type: ray.rllib.algorithms.alpha_star.league_builder.NoLeagueBuilder
 
     multiagent:
       policies: ["p0", "p1", "p2", "p3"]