mirror of
https://github.com/vale981/ray
synced 2025-03-04 17:41:43 -05:00
[RLlib] Retry agents -> algorithms. with proper doc changes this time. (#24797)
This commit is contained in:
parent
d40fa391a5
commit
68a9a33386
59 changed files with 430 additions and 248 deletions
|
@ -466,7 +466,7 @@ HalfCheetah 13000 ~15000
|
|||
Model-Agnostic Meta-Learning (MAML)
|
||||
-----------------------------------
|
||||
|pytorch| |tensorflow|
|
||||
`[paper] <https://arxiv.org/abs/1703.03400>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/maml/maml.py>`__
|
||||
`[paper] <https://arxiv.org/abs/1703.03400>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/algorithms/maml/maml.py>`__
|
||||
|
||||
RLlib's MAML implementation is a meta-learning method for fast adaptation across different continuous-control tasks. The code is adapted from https://github.com/jonasrothfuss; it outperforms vanilla MAML and avoids computing higher-order gradients during the meta-update step. MAML is evaluated on custom environments that are described in greater detail `here <https://github.com/ray-project/ray/blob/master/rllib/env/apis/task_settable_env.py>`__.
|
||||
|
||||
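A minimal, hypothetical usage sketch of the renamed module (only the import path comes from this commit; everything else, including the environment name, is an assumption and must be registered separately as a TaskSettableEnv):

from ray.rllib.algorithms.maml import MAMLTrainer, DEFAULT_CONFIG  # old agents.maml path still works, but warns

config = DEFAULT_CONFIG.copy()
config["framework"] = "torch"

# Placeholder env: MAML needs a task-settable env such as the tuned
# HalfCheetahRandDirecEnv example, assumed registered under this name.
trainer = MAMLTrainer(config=config, env="HalfCheetahRandDirecEnv")
print(trainer.train())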
|
@ -476,7 +476,7 @@ Tuned examples: HalfCheetahRandDirecEnv (`Env <https://github.com/ray-project/ra
|
|||
|
||||
**MAML-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):
|
||||
|
||||
.. literalinclude:: ../../../rllib/agents/maml/maml.py
|
||||
.. literalinclude:: ../../../rllib/algorithms/maml/maml.py
|
||||
:language: python
|
||||
:start-after: __sphinx_doc_begin__
|
||||
:end-before: __sphinx_doc_end__
|
||||
|
@ -486,7 +486,7 @@ Tuned examples: HalfCheetahRandDirecEnv (`Env <https://github.com/ray-project/ra
|
|||
Model-Based Meta-Policy-Optimization (MB-MPO)
|
||||
---------------------------------------------
|
||||
|pytorch|
|
||||
`[paper] <https://arxiv.org/pdf/1809.05214.pdf>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/mbmpo/mbmpo.py>`__
|
||||
`[paper] <https://arxiv.org/pdf/1809.05214.pdf>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/algorithms/mbmpo/mbmpo.py>`__
|
||||
|
||||
RLlib's MBMPO implementation is a Dyna-style, model-based RL method that learns based on the predictions of an ensemble of transition-dynamics models. Similar to MAML, MBMPO meta-learns an optimal policy by treating each dynamics model as a different task. The code is adapted from https://github.com/jonasrothfuss/model_ensemble_meta_learning. As in the original paper, MBMPO is evaluated on MuJoCo, with the horizon set to 200 instead of the default 1000.
|
||||
|
||||
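A minimal, hypothetical sketch running MB-MPO through Ray Tune's algorithm registry (MB-MPO is torch-only; the env string is a placeholder for a registered, task-settable MuJoCo wrapper):

from ray import tune

tune.run(
    "MBMPO",  # resolves through RLlib's registry, which now imports from algorithms/
    stop={"training_iteration": 1},
    config={
        "env": "HalfCheetahMBMPOEnv",  # placeholder, assumed registered elsewhere
        "framework": "torch",
        "horizon": 200,  # matches the evaluation setup noted above
    },
)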
|
@ -510,7 +510,7 @@ Hopper 620 ~650
|
|||
|
||||
**MBMPO-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):
|
||||
|
||||
.. literalinclude:: ../../../rllib/agents/mbmpo/mbmpo.py
|
||||
.. literalinclude:: ../../../rllib/algorithms/mbmpo/mbmpo.py
|
||||
:language: python
|
||||
:start-after: __sphinx_doc_begin__
|
||||
:end-before: __sphinx_doc_end__
|
||||
|
@ -539,7 +539,7 @@ Cheetah-Run 640 ~800
|
|||
|
||||
**Dreamer-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):
|
||||
|
||||
.. literalinclude:: ../../../rllib/agents/dreamer/dreamer.py
|
||||
.. literalinclude:: ../../../rllib/algorithms/dreamer/dreamer.py
|
||||
:language: python
|
||||
:start-after: __sphinx_doc_begin__
|
||||
:end-before: __sphinx_doc_end__
|
||||
|
@ -567,7 +567,7 @@ RecSim environment wrapper: `Google RecSim <https://github.com/ray-project/ray/b
|
|||
Conservative Q-Learning (CQL)
|
||||
-----------------------------------
|
||||
|pytorch| |tensorflow|
|
||||
`[paper] <https://arxiv.org/abs/2006.04779>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/cql/cql.py>`__
|
||||
`[paper] <https://arxiv.org/abs/2006.04779>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/algorithms/cql/cql.py>`__
|
||||
|
||||
In offline RL, the algorithm has no access to an environment, but can only sample from a fixed dataset of pre-collected state-action-reward tuples.
|
||||
In particular, CQL (Conservative Q-Learning) is an offline RL algorithm that mitigates the overestimation of Q-values outside the dataset distribution via
|
||||
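A minimal, hypothetical sketch of training CQL purely from a pre-collected dataset with the new import path (the JSON path and env name are placeholders; the env is only used to provide observation/action spaces):

from ray.rllib.algorithms.cql import CQLTrainer, CQL_DEFAULT_CONFIG

config = CQL_DEFAULT_CONFIG.copy()
config["framework"] = "torch"
config["env"] = "Pendulum-v1"  # placeholder; no env stepping happens in offline RL
config["input"] = "/tmp/pendulum-expert.json"  # placeholder offline dataset

trainer = CQLTrainer(config=config)
print(trainer.train())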
|
@ -581,7 +581,7 @@ Tuned examples: `HalfCheetah Random <https://github.com/ray-project/ray/blob/mas
|
|||
|
||||
**CQL-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):
|
||||
|
||||
.. literalinclude:: ../../../rllib/agents/cql/cql.py
|
||||
.. literalinclude:: ../../../rllib/algorithms/cql/cql.py
|
||||
:language: python
|
||||
:start-after: __sphinx_doc_begin__
|
||||
:end-before: __sphinx_doc_end__
|
||||
|
@ -636,7 +636,7 @@ Monotonic Advantage Re-Weighted Imitation Learning (MARWIL)
|
|||
-----------------------------------------------------------
|
||||
|pytorch| |tensorflow|
|
||||
`[paper] <http://papers.nips.cc/paper/7866-exponentially-weighted-imitation-learning-for-batched-historical-data>`__
|
||||
`[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/marwil/marwil.py>`__
|
||||
`[implementation] <https://github.com/ray-project/ray/blob/master/rllib/algorithms/marwil/marwil.py>`__
|
||||
|
||||
MARWIL is a hybrid imitation learning and policy gradient algorithm suitable for training on batched historical data.
|
||||
When the ``beta`` hyperparameter is set to zero, the MARWIL objective reduces to vanilla imitation learning (see `BC`_).
|
||||
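A minimal, hypothetical sketch with the new import path, showing the role of ``beta`` (the dataset path is a placeholder):

from ray.rllib.algorithms.marwil import MARWILTrainer, DEFAULT_CONFIG

config = DEFAULT_CONFIG.copy()
config["env"] = "CartPole-v0"
config["input"] = "/tmp/cartpole-historic-data.json"  # placeholder batched data
config["beta"] = 1.0  # beta=0.0 reduces MARWIL to plain behavior cloning (see BC below)

trainer = MARWILTrainer(config=config)
print(trainer.train())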
|
@ -646,7 +646,7 @@ Tuned examples: `CartPole-v0 <https://github.com/ray-project/ray/blob/master/rll
|
|||
|
||||
**MARWIL-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):
|
||||
|
||||
.. literalinclude:: ../../../rllib/agents/marwil/marwil.py
|
||||
.. literalinclude:: ../../../rllib/algorithms/marwil/marwil.py
|
||||
:language: python
|
||||
:start-after: __sphinx_doc_begin__
|
||||
:end-before: __sphinx_doc_end__
|
||||
|
@ -658,7 +658,7 @@ Behavior Cloning (BC; derived from MARWIL implementation)
|
|||
---------------------------------------------------------
|
||||
|pytorch| |tensorflow|
|
||||
`[paper] <http://papers.nips.cc/paper/7866-exponentially-weighted-imitation-learning-for-batched-historical-data>`__
|
||||
`[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/marwil/bc.py>`__
|
||||
`[implementation] <https://github.com/ray-project/ray/blob/master/rllib/algorithms/marwil/bc.py>`__
|
||||
|
||||
Our behavioral cloning implementation is directly derived from our `MARWIL`_ implementation,
|
||||
with the only difference being the ``beta`` parameter force-set to 0.0. This makes
|
||||
|
@ -669,7 +669,7 @@ Tuned examples: `CartPole-v0 <https://github.com/ray-project/ray/blob/master/rll
|
|||
|
||||
**BC-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):
|
||||
|
||||
.. literalinclude:: ../../../rllib/agents/marwil/bc.py
|
||||
.. literalinclude:: ../../../rllib/algorithms/marwil/bc.py
|
||||
:language: python
|
||||
:start-after: __sphinx_doc_begin__
|
||||
:end-before: __sphinx_doc_end__
|
||||
|
|
14
rllib/BUILD
|
@ -647,7 +647,7 @@ py_test(
|
|||
name = "test_alpha_star",
|
||||
tags = ["team:ml", "trainers_dir"],
|
||||
size = "large",
|
||||
srcs = ["agents/alpha_star/tests/test_alpha_star.py"]
|
||||
srcs = ["algorithms/alpha_star/tests/test_alpha_star.py"]
|
||||
)
|
||||
|
||||
# APEXTrainer (DQN)
|
||||
|
@ -687,7 +687,7 @@ py_test(
|
|||
name = "test_cql",
|
||||
tags = ["team:ml", "trainers_dir"],
|
||||
size = "medium",
|
||||
srcs = ["agents/cql/tests/test_cql.py"]
|
||||
srcs = ["algorithms/cql/tests/test_cql.py"]
|
||||
)
|
||||
|
||||
# DDPGTrainer
|
||||
|
@ -711,7 +711,7 @@ py_test(
|
|||
name = "test_dreamer",
|
||||
tags = ["team:ml", "trainers_dir"],
|
||||
size = "small",
|
||||
srcs = ["agents/dreamer/tests/test_dreamer.py"]
|
||||
srcs = ["algorithms/dreamer/tests/test_dreamer.py"]
|
||||
)
|
||||
|
||||
# ES
|
||||
|
@ -743,7 +743,7 @@ py_test(
|
|||
size = "large",
|
||||
# Include the json data file.
|
||||
data = ["tests/data/cartpole/large.json"],
|
||||
srcs = ["agents/marwil/tests/test_marwil.py"]
|
||||
srcs = ["algorithms/marwil/tests/test_marwil.py"]
|
||||
)
|
||||
|
||||
# BCTrainer (sub-type of MARWIL)
|
||||
|
@ -753,7 +753,7 @@ py_test(
|
|||
size = "large",
|
||||
# Include the json data file.
|
||||
data = ["tests/data/cartpole/large.json"],
|
||||
srcs = ["agents/marwil/tests/test_bc.py"]
|
||||
srcs = ["algorithms/marwil/tests/test_bc.py"]
|
||||
)
|
||||
|
||||
# MAMLTrainer
|
||||
|
@ -761,7 +761,7 @@ py_test(
|
|||
name = "test_maml",
|
||||
tags = ["team:ml", "trainers_dir"],
|
||||
size = "medium",
|
||||
srcs = ["agents/maml/tests/test_maml.py"]
|
||||
srcs = ["algorithms/maml/tests/test_maml.py"]
|
||||
)
|
||||
|
||||
# MBMPOTrainer
|
||||
|
@ -769,7 +769,7 @@ py_test(
|
|||
name = "test_mbmpo",
|
||||
tags = ["team:ml", "trainers_dir"],
|
||||
size = "medium",
|
||||
srcs = ["agents/mbmpo/tests/test_mbmpo.py"]
|
||||
srcs = ["algorithms/mbmpo/tests/test_mbmpo.py"]
|
||||
)
|
||||
|
||||
# PGTrainer
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from ray.rllib.agents.alpha_star.alpha_star import (
|
||||
from ray.rllib.algorithms.alpha_star.alpha_star import (
|
||||
AlphaStarConfig,
|
||||
AlphaStarTrainer,
|
||||
DEFAULT_CONFIG,
|
||||
|
@ -9,3 +9,10 @@ __all__ = [
|
|||
"AlphaStarTrainer",
|
||||
"DEFAULT_CONFIG",
|
||||
]
|
||||
|
||||
|
||||
from ray.rllib.utils.deprecation import deprecation_warning
|
||||
|
||||
deprecation_warning(
|
||||
"ray.rllib.agents.alpha_star", "ray.rllib.algorithms.alpha_star", error=False
|
||||
)
|
||||
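A hypothetical session illustrating the backward-compatibility shim above: the old module re-exports the new symbols and, because ``error=False``, only logs a deprecation warning instead of raising:

import ray.rllib.agents.alpha_star as old_mod      # triggers deprecation_warning(...) at import time
import ray.rllib.algorithms.alpha_star as new_mod  # new, canonical location

# The old module simply re-exports the new symbols, so they are identical objects.
assert old_mod.AlphaStarTrainer is new_mod.AlphaStarTrainer
assert old_mod.DEFAULT_CONFIG == new_mod.DEFAULT_CONFIG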
|
|
|
@ -1,8 +1,13 @@
|
|||
from ray.rllib.agents.cql.cql import CQLTrainer, CQL_DEFAULT_CONFIG
|
||||
from ray.rllib.agents.cql.cql_torch_policy import CQLTorchPolicy
|
||||
from ray.rllib.algorithms.cql.cql import CQLTrainer, CQL_DEFAULT_CONFIG
|
||||
from ray.rllib.algorithms.cql.cql_tf_policy import CQLTFPolicy
|
||||
from ray.rllib.algorithms.cql.cql_torch_policy import CQLTorchPolicy
|
||||
from ray.rllib.utils.deprecation import deprecation_warning
|
||||
|
||||
__all__ = [
|
||||
"CQL_DEFAULT_CONFIG",
|
||||
"CQLTFPolicy",
|
||||
"CQLTorchPolicy",
|
||||
"CQLTrainer",
|
||||
]
|
||||
|
||||
deprecation_warning("ray.rllib.agents.cql", "ray.rllib.algorithms.cql", error=False)
|
||||
|
|
|
@ -100,9 +100,9 @@ class SimpleQConfig(TrainerConfig):
|
|||
>>> .exploration(exploration_config=explore_config)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, trainer_class=None):
|
||||
"""Initializes a SimpleQConfig instance."""
|
||||
super().__init__(trainer_class=SimpleQTrainer)
|
||||
super().__init__(trainer_class=trainer_class or SimpleQTrainer)
|
||||
|
||||
# Simple Q specific
|
||||
# fmt: off
|
||||
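A short sketch of why the new ``trainer_class`` argument matters: subclass configs can now reuse ``SimpleQConfig.__init__`` while binding their own trainer class (the ``QMixConfig`` added further down in this commit does exactly that with ``QMixTrainer``). The subclass below is purely illustrative:

from ray.rllib.agents.dqn.simple_q import SimpleQConfig, SimpleQTrainer

class MySimpleQVariantConfig(SimpleQConfig):
    def __init__(self):
        # Pass the (here unchanged) trainer class through to the base config;
        # before this change, SimpleQConfig was hard-wired to SimpleQTrainer.
        super().__init__(trainer_class=SimpleQTrainer)

print(MySimpleQVariantConfig().to_dict()["train_batch_size"])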
|
|
|
@ -1,4 +1,4 @@
|
|||
from ray.rllib.agents.dreamer.dreamer import (
|
||||
from ray.rllib.algorithms.dreamer.dreamer import (
|
||||
DREAMERConfig,
|
||||
DREAMERTrainer,
|
||||
DEFAULT_CONFIG,
|
||||
|
@ -9,3 +9,10 @@ __all__ = [
|
|||
"DREAMERTrainer",
|
||||
"DEFAULT_CONFIG",
|
||||
]
|
||||
|
||||
|
||||
from ray.rllib.utils.deprecation import deprecation_warning
|
||||
|
||||
deprecation_warning(
|
||||
"ray.rllib.agents.dreamer", "ray.rllib.algorithms.dreamer", error=False
|
||||
)
|
||||
|
|
|
@ -1,6 +1,10 @@
|
|||
from ray.rllib.agents.maml.maml import MAMLTrainer, DEFAULT_CONFIG
|
||||
from ray.rllib.algorithms.maml.maml import MAMLTrainer, DEFAULT_CONFIG
|
||||
|
||||
__all__ = [
|
||||
"MAMLTrainer",
|
||||
"DEFAULT_CONFIG",
|
||||
]
|
||||
|
||||
from ray.rllib.utils.deprecation import deprecation_warning
|
||||
|
||||
deprecation_warning("ray.rllib.agents.maml", "ray.rllib.algorithms.maml", error=False)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from ray.rllib.agents.marwil.bc import BCTrainer, BC_DEFAULT_CONFIG
|
||||
from ray.rllib.agents.marwil.marwil import MARWILTrainer, DEFAULT_CONFIG
|
||||
from ray.rllib.agents.marwil.marwil_tf_policy import MARWILTFPolicy
|
||||
from ray.rllib.agents.marwil.marwil_torch_policy import MARWILTorchPolicy
|
||||
from ray.rllib.algorithms.marwil.bc import BCTrainer, BC_DEFAULT_CONFIG
|
||||
from ray.rllib.algorithms.marwil.marwil import MARWILTrainer, DEFAULT_CONFIG
|
||||
from ray.rllib.algorithms.marwil.marwil_tf_policy import MARWILTFPolicy
|
||||
from ray.rllib.algorithms.marwil.marwil_torch_policy import MARWILTorchPolicy
|
||||
|
||||
__all__ = [
|
||||
"BCTrainer",
|
||||
|
@ -11,3 +11,10 @@ __all__ = [
|
|||
"MARWILTorchPolicy",
|
||||
"MARWILTrainer",
|
||||
]
|
||||
|
||||
|
||||
from ray.rllib.utils.deprecation import deprecation_warning
|
||||
|
||||
deprecation_warning(
|
||||
"ray.rllib.agents.marwil", "ray.rllib.algorithms.marwil", error=False
|
||||
)
|
||||
|
|
|
@ -1,6 +1,11 @@
|
|||
from ray.rllib.agents.mbmpo.mbmpo import MBMPOTrainer, DEFAULT_CONFIG
|
||||
from ray.rllib.algorithms.mbmpo.mbmpo import MBMPOTrainer, DEFAULT_CONFIG
|
||||
|
||||
__all__ = [
|
||||
"MBMPOTrainer",
|
||||
"DEFAULT_CONFIG",
|
||||
]
|
||||
|
||||
|
||||
from ray.rllib.utils.deprecation import deprecation_warning
|
||||
|
||||
deprecation_warning("ray.rllib.agents.mbmpo", "ray.rllib.algorithms.mbmpo", error=False)
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
from ray.rllib.agents.qmix.qmix import QMixTrainer, DEFAULT_CONFIG
|
||||
from ray.rllib.agents.qmix.qmix import QMixConfig, QMixTrainer, DEFAULT_CONFIG
|
||||
|
||||
__all__ = ["QMixTrainer", "DEFAULT_CONFIG"]
|
||||
__all__ = ["QMixConfig", "QMixTrainer", "DEFAULT_CONFIG"]
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
from typing import Type
|
||||
from typing import Optional, Type
|
||||
|
||||
from ray.rllib.agents.trainer import with_common_config
|
||||
from ray.rllib.agents.dqn.simple_q import SimpleQTrainer
|
||||
from ray.rllib.agents.dqn.simple_q import SimpleQConfig, SimpleQTrainer
|
||||
from ray.rllib.agents.qmix.qmix_policy import QMixTorchPolicy
|
||||
from ray.rllib.execution.rollout_ops import (
|
||||
synchronous_parallel_sample,
|
||||
|
@ -12,7 +11,7 @@ from ray.rllib.execution.train_ops import (
|
|||
)
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
|
||||
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, Deprecated
|
||||
from ray.rllib.utils.metrics import (
|
||||
LAST_TARGET_UPDATE_TS,
|
||||
NUM_AGENT_STEPS_SAMPLED,
|
||||
|
@ -23,118 +22,174 @@ from ray.rllib.utils.metrics import (
|
|||
from ray.rllib.utils.replay_buffers.utils import sample_min_n_steps_from_buffer
|
||||
from ray.rllib.utils.typing import ResultDict, TrainerConfigDict
|
||||
|
||||
# fmt: off
|
||||
# __sphinx_doc_begin__
|
||||
DEFAULT_CONFIG = with_common_config({
|
||||
# === QMix ===
|
||||
# Mixing network. Either "qmix", "vdn", or None
|
||||
"mixer": "qmix",
|
||||
# Size of the mixing network embedding
|
||||
"mixing_embed_dim": 32,
|
||||
# Whether to use Double_Q learning
|
||||
"double_q": True,
|
||||
# Optimize over complete episodes by default.
|
||||
"batch_mode": "complete_episodes",
|
||||
|
||||
# === Exploration Settings ===
|
||||
"exploration_config": {
|
||||
# The Exploration class to use.
|
||||
"type": "EpsilonGreedy",
|
||||
# Config for the Exploration class' constructor:
|
||||
"initial_epsilon": 1.0,
|
||||
"final_epsilon": 0.01,
|
||||
# Timesteps over which to anneal epsilon.
|
||||
"epsilon_timesteps": 40000,
|
||||
class QMixConfig(SimpleQConfig):
|
||||
"""Defines a configuration class from which a QMixTrainer can be built.
|
||||
|
||||
# For soft_q, use:
|
||||
# "exploration_config" = {
|
||||
# "type": "SoftQ"
|
||||
# "temperature": [float, e.g. 1.0]
|
||||
# }
|
||||
},
|
||||
Example:
|
||||
>>> from ray.rllib.examples.env.two_step_game import TwoStepGame
|
||||
>>> from ray.rllib.agents.qmix import QMixConfig
|
||||
>>> config = QMixConfig().training(gamma=0.9, lr=0.01)\
|
||||
... .resources(num_gpus=0)\
|
||||
... .rollouts(num_workers=4)
|
||||
>>> print(config.to_dict())
|
||||
>>> # Build a Trainer object from the config and run 1 training iteration.
|
||||
>>> trainer = config.build(env=TwoStepGame)
|
||||
>>> trainer.train()
|
||||
|
||||
# === Evaluation ===
|
||||
# Evaluate with epsilon=0 every `evaluation_interval` training iterations.
|
||||
# The evaluation stats will be reported under the "evaluation" metric key.
|
||||
# Note that evaluation is currently not parallelized, and that for Ape-X
|
||||
# metrics are already only reported for the lowest epsilon workers.
|
||||
"evaluation_interval": None,
|
||||
# Number of episodes to run per evaluation period.
|
||||
"evaluation_duration": 10,
|
||||
# Switch to greedy actions in evaluation workers.
|
||||
"evaluation_config": {
|
||||
"explore": False,
|
||||
},
|
||||
Example:
|
||||
>>> from ray.rllib.examples.env.two_step_game import TwoStepGame
|
||||
>>> from ray.rllib.agents.qmix import QMixConfig
|
||||
>>> from ray import tune
|
||||
>>> config = QMixConfig()
|
||||
>>> # Print out some default values.
|
||||
>>> print(config.optim_alpha)
|
||||
>>> # Update the config object.
|
||||
>>> config.training(lr=tune.grid_search([0.001, 0.0001]), optim_alpha=0.97)
|
||||
>>> # Set the config object's env.
|
||||
>>> config.environment(env=TwoStepGame)
|
||||
>>> # Use to_dict() to get the old-style python config dict
|
||||
>>> # when running with tune.
|
||||
>>> tune.run(
|
||||
... "QMix",
|
||||
... stop={"episode_reward_mean": 200},
|
||||
... config=config.to_dict(),
|
||||
... )
|
||||
"""
|
||||
|
||||
# Minimum env sampling timesteps to accumulate within a single `train()` call. This
|
||||
# value does not affect learning, only the number of times `Trainer.step_attempt()`
|
||||
# is called by `Trainer.train()`. If, after one `step_attempt()`, the env sampling
|
||||
# timestep count has not been reached, will perform n more `step_attempt()` calls
|
||||
# until the minimum timesteps have been executed. Set to 0 for no minimum timesteps.
|
||||
"min_sample_timesteps_per_reporting": 1000,
|
||||
# Update the target network every `target_network_update_freq` steps.
|
||||
"target_network_update_freq": 500,
|
||||
def __init__(self):
|
||||
"""Initializes a PPOConfig instance."""
|
||||
super().__init__(trainer_class=QMixTrainer)
|
||||
|
||||
# === Replay buffer ===
|
||||
"replay_buffer_config": {
|
||||
# Use the new ReplayBuffer API here
|
||||
"_enable_replay_buffer_api": True,
|
||||
"type": "SimpleReplayBuffer",
|
||||
# Size of the replay buffer in batches (not timesteps!).
|
||||
"capacity": 1000,
|
||||
"learning_starts": 1000,
|
||||
},
|
||||
# fmt: off
|
||||
# __sphinx_doc_begin__
|
||||
# QMix specific settings:
|
||||
self.mixer = "qmix"
|
||||
self.mixing_embed_dim = 32
|
||||
self.double_q = True
|
||||
self.target_network_update_freq = 500
|
||||
self.replay_buffer_config = {
|
||||
# Use the new ReplayBuffer API here
|
||||
"_enable_replay_buffer_api": True,
|
||||
"type": "SimpleReplayBuffer",
|
||||
# Size of the replay buffer in batches (not timesteps!).
|
||||
"capacity": 1000,
|
||||
"learning_starts": 1000,
|
||||
}
|
||||
self.optim_alpha = 0.99
|
||||
self.optim_eps = 0.00001
|
||||
self.grad_norm_clipping = 10
|
||||
self.worker_side_prioritization = False
|
||||
|
||||
# === Optimization ===
|
||||
# Learning rate for RMSProp optimizer
|
||||
"lr": 0.0005,
|
||||
# RMSProp alpha
|
||||
"optim_alpha": 0.99,
|
||||
# RMSProp epsilon
|
||||
"optim_eps": 0.00001,
|
||||
# If not None, clip gradients during optimization at this value
|
||||
"grad_norm_clipping": 10,
|
||||
# Update the replay buffer with this many samples at once. Note that
|
||||
# this setting applies per-worker if num_workers > 1.
|
||||
"rollout_fragment_length": 4,
|
||||
# Minimum batch size used for training (in timesteps). With the default buffer
|
||||
# (ReplayBuffer) this means, sampling from the buffer (entire-episode SampleBatches)
|
||||
# as many times as is required to reach at least this number of timesteps.
|
||||
"train_batch_size": 32,
|
||||
# Override some of TrainerConfig's default values with QMix-specific values.
|
||||
self.num_workers = 0
|
||||
self.min_time_s_per_reporting = 1
|
||||
self.model = {
|
||||
"lstm_cell_size": 64,
|
||||
"max_seq_len": 999999,
|
||||
}
|
||||
self.framework_str = "torch"
|
||||
self.lr = 0.0005
|
||||
self.rollout_fragment_length = 4
|
||||
self.train_batch_size = 32
|
||||
self.batch_mode = "complete_episodes"
|
||||
self.exploration_config = {
|
||||
# The Exploration class to use.
|
||||
"type": "EpsilonGreedy",
|
||||
# Config for the Exploration class' constructor:
|
||||
"initial_epsilon": 1.0,
|
||||
"final_epsilon": 0.01,
|
||||
# Timesteps over which to anneal epsilon.
|
||||
"epsilon_timesteps": 40000,
|
||||
|
||||
# === Parallelism ===
|
||||
# Number of workers for collecting samples with. This only makes sense
|
||||
# to increase if your environment is particularly slow to sample, or if
|
||||
# you"re using the Async or Ape-X optimizers.
|
||||
"num_workers": 0,
|
||||
# Whether to compute priorities on workers.
|
||||
"worker_side_prioritization": False,
|
||||
# Prevent reporting frequency from going lower than this time span.
|
||||
"min_time_s_per_reporting": 1,
|
||||
# For soft_q, use:
|
||||
# "exploration_config" = {
|
||||
# "type": "SoftQ"
|
||||
# "temperature": [float, e.g. 1.0]
|
||||
# }
|
||||
}
|
||||
|
||||
# === Model ===
|
||||
"model": {
|
||||
"lstm_cell_size": 64,
|
||||
"max_seq_len": 999999,
|
||||
},
|
||||
# Only torch supported so far.
|
||||
"framework": "torch",
|
||||
# Evaluate with epsilon=0 every `evaluation_interval` training iterations.
|
||||
# The evaluation stats will be reported under the "evaluation" metric key.
|
||||
# Note that evaluation is currently not parallelized, and that for Ape-X
|
||||
# metrics are already only reported for the lowest epsilon workers.
|
||||
self.evaluation_interval = None
|
||||
self.evaluation_duration = 10
|
||||
self.evaluation_config = {
|
||||
"explore": False,
|
||||
}
|
||||
self.min_sample_timesteps_per_reporting = 1000
|
||||
# __sphinx_doc_end__
|
||||
# fmt: on
|
||||
|
||||
# Deprecated keys:
|
||||
# Use `replay_buffer_config.learning_starts` instead.
|
||||
"learning_starts": DEPRECATED_VALUE,
|
||||
# Use `replay_buffer_config.capacity` instead.
|
||||
"buffer_size": DEPRECATED_VALUE,
|
||||
})
|
||||
# __sphinx_doc_end__
|
||||
# fmt: on
|
||||
# Deprecated keys:
|
||||
self.learning_starts = DEPRECATED_VALUE
|
||||
self.buffer_size = DEPRECATED_VALUE
|
||||
|
||||
@override(SimpleQConfig)
|
||||
def training(
|
||||
self,
|
||||
*,
|
||||
mixer: Optional[str] = None,
|
||||
mixing_embed_dim: Optional[int] = None,
|
||||
double_q: Optional[bool] = None,
|
||||
target_network_update_freq: Optional[int] = None,
|
||||
replay_buffer_config: Optional[dict] = None,
|
||||
optim_alpha: Optional[float] = None,
|
||||
optim_eps: Optional[float] = None,
|
||||
grad_norm_clipping: Optional[float] = None,
|
||||
worker_side_prioritization: Optional[bool] = None,
|
||||
**kwargs,
|
||||
) -> "QMixConfig":
|
||||
"""Sets the training related configuration.
|
||||
|
||||
Args:
|
||||
mixer: Mixing network. Either "qmix", "vdn", or None.
|
||||
mixing_embed_dim: Size of the mixing network embedding.
|
||||
double_q: Whether to use Double_Q learning.
|
||||
target_network_update_freq: Update the target network every
|
||||
`target_network_update_freq` sample steps.
|
||||
replay_buffer_config:
|
||||
optim_alpha: RMSProp alpha.
|
||||
optim_eps: RMSProp epsilon.
|
||||
grad_norm_clipping: If not None, clip gradients during optimization at
|
||||
this value.
|
||||
worker_side_prioritization: Whether to compute priorities for the replay
|
||||
buffer on worker side.
|
||||
|
||||
Returns:
|
||||
This updated TrainerConfig object.
|
||||
"""
|
||||
# Pass kwargs onto super's `training()` method.
|
||||
super().training(**kwargs)
|
||||
|
||||
if mixer is not None:
|
||||
self.mixer = mixer
|
||||
if mixing_embed_dim is not None:
|
||||
self.mixing_embed_dim = mixing_embed_dim
|
||||
if double_q is not None:
|
||||
self.double_q = double_q
|
||||
if target_network_update_freq is not None:
|
||||
self.target_network_update_freq = target_network_update_freq
|
||||
if replay_buffer_config is not None:
|
||||
self.replay_buffer_config = replay_buffer_config
|
||||
if optim_alpha is not None:
|
||||
self.optim_alpha = optim_alpha
|
||||
if optim_eps is not None:
|
||||
self.optim_eps = optim_eps
|
||||
if grad_norm_clipping is not None:
|
||||
self.grad_norm_clipping = grad_norm_clipping
|
||||
if worker_side_prioritization is not None:
|
||||
self.worker_side_prioritization = worker_side_prioritization
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class QMixTrainer(SimpleQTrainer):
|
||||
@classmethod
|
||||
@override(SimpleQTrainer)
|
||||
def get_default_config(cls) -> TrainerConfigDict:
|
||||
return DEFAULT_CONFIG
|
||||
return QMixConfig().to_dict()
|
||||
|
||||
@override(SimpleQTrainer)
|
||||
def validate_config(self, config: TrainerConfigDict) -> None:
|
||||
|
@ -219,3 +274,20 @@ class QMixTrainer(SimpleQTrainer):
|
|||
|
||||
# Return all collected metrics for the iteration.
|
||||
return train_results
|
||||
|
||||
|
||||
# Deprecated: Use ray.rllib.agents.qmix.qmix.QMixConfig instead!
|
||||
class _deprecated_default_config(dict):
|
||||
def __init__(self):
|
||||
super().__init__(QMixConfig().to_dict())
|
||||
|
||||
@Deprecated(
|
||||
old="ray.rllib.agents.qmix.qmix.DEFAULT_CONFIG",
|
||||
new="ray.rllib.agents.qmix.qmix.QMixConfig(...)",
|
||||
error=False,
|
||||
)
|
||||
def __getitem__(self, item):
|
||||
return super().__getitem__(item)
|
||||
|
||||
|
||||
DEFAULT_CONFIG = _deprecated_default_config()
|
||||
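Two short, hypothetical sketches tying the new QMixConfig pieces above together: the builder-style methods only override what is explicitly passed, and the ``_deprecated_default_config`` shim keeps old-style ``DEFAULT_CONFIG`` lookups working while logging a deprecation message:

from ray.rllib.agents.qmix.qmix import DEFAULT_CONFIG, QMixConfig

# Builder-style config: unspecified fields keep the defaults set in __init__().
config = (
    QMixConfig()
    .training(mixer="vdn", double_q=False, optim_alpha=0.97)
    .rollouts(num_rollout_workers=0)
)
print(config.mixer, config.double_q, config.optim_alpha)  # -> vdn False 0.97

# Old-style dict access still works; each __getitem__ call goes through the
# @Deprecated wrapper (error=False) and points users at QMixConfig instead.
assert DEFAULT_CONFIG["lr"] == QMixConfig().to_dict()["lr"]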
|
|
|
@ -153,7 +153,6 @@ class QMixLoss(nn.Module):
|
|||
return loss, mask, masked_td_error, chosen_action_qvals, targets
|
||||
|
||||
|
||||
# TODO(sven): Make this a TorchPolicy child via `build_policy_class`.
|
||||
class QMixTorchPolicy(TorchPolicy):
|
||||
"""QMix impl. Assumes homogeneous agents for now.
|
||||
|
||||
|
@ -177,9 +176,6 @@ class QMixTorchPolicy(TorchPolicy):
|
|||
self.h_size = config["model"]["lstm_cell_size"]
|
||||
self.has_env_global_state = False
|
||||
self.has_action_mask = False
|
||||
self.device = (
|
||||
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
)
|
||||
|
||||
agent_obs_space = obs_space.original_space.spaces[0]
|
||||
if isinstance(agent_obs_space, gym.spaces.Dict):
|
||||
|
@ -218,7 +214,9 @@ class QMixTorchPolicy(TorchPolicy):
|
|||
framework="torch",
|
||||
name="model",
|
||||
default_model=RNNModel,
|
||||
).to(self.device)
|
||||
)
|
||||
|
||||
super().__init__(obs_space, action_space, config, model=self.model)
|
||||
|
||||
self.target_model = ModelCatalog.get_model_v2(
|
||||
agent_obs_space,
|
||||
|
@ -230,8 +228,6 @@ class QMixTorchPolicy(TorchPolicy):
|
|||
default_model=RNNModel,
|
||||
).to(self.device)
|
||||
|
||||
super().__init__(obs_space, action_space, config, model=self.model)
|
||||
|
||||
self.exploration = self._create_exploration()
|
||||
|
||||
# Setup the mixer network.
|
||||
|
|
|
@ -4,7 +4,7 @@ import unittest
|
|||
|
||||
import ray
|
||||
from ray.tune import register_env
|
||||
from ray.rllib.agents.qmix import QMixTrainer
|
||||
from ray.rllib.agents.qmix import QMixConfig
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
|
||||
|
||||
|
@ -95,18 +95,21 @@ class TestQMix(unittest.TestCase):
|
|||
),
|
||||
)
|
||||
|
||||
trainer = QMixTrainer(
|
||||
env="action_mask_test",
|
||||
config={
|
||||
"num_envs_per_worker": 5, # test with vectorization on
|
||||
"env_config": {
|
||||
"avail_actions": [3, 4, 8],
|
||||
},
|
||||
"framework": "torch",
|
||||
},
|
||||
)
|
||||
config = (
|
||||
QMixConfig()
|
||||
.framework(framework="torch")
|
||||
.environment(
|
||||
env="action_mask_test",
|
||||
env_config={"avail_actions": [3, 4, 8]},
|
||||
)
|
||||
.rollouts(num_envs_per_worker=5)
|
||||
) # Test with vectorization on.
|
||||
|
||||
trainer = config.build()
|
||||
|
||||
for _ in range(4):
|
||||
trainer.train() # OK if it doesn't trip the action assertion error
|
||||
|
||||
assert trainer.train()["episode_reward_mean"] == 30.0
|
||||
trainer.stop()
|
||||
ray.shutdown()
|
||||
|
|
|
@ -18,7 +18,10 @@ def _import_a3c():
|
|||
|
||||
|
||||
def _import_alpha_star():
|
||||
from ray.rllib.agents.alpha_star.alpha_star import AlphaStarTrainer, DEFAULT_CONFIG
|
||||
from ray.rllib.algorithms.alpha_star.alpha_star import (
|
||||
AlphaStarTrainer,
|
||||
DEFAULT_CONFIG,
|
||||
)
|
||||
|
||||
return AlphaStarTrainer, DEFAULT_CONFIG
|
||||
|
||||
|
@ -42,7 +45,7 @@ def _import_appo():
|
|||
|
||||
|
||||
def _import_ars():
|
||||
from ray.rllib.agents import ars
|
||||
from ray.rllib.algorithms import ars
|
||||
|
||||
return ars.ARSTrainer, ars.DEFAULT_CONFIG
|
||||
|
||||
|
@ -60,13 +63,13 @@ def _import_bandit_linucb():
|
|||
|
||||
|
||||
def _import_bc():
|
||||
from ray.rllib.agents import marwil
|
||||
from ray.rllib.algorithms import marwil
|
||||
|
||||
return marwil.BCTrainer, marwil.DEFAULT_CONFIG
|
||||
|
||||
|
||||
def _import_cql():
|
||||
from ray.rllib.agents import cql
|
||||
from ray.rllib.algorithms import cql
|
||||
|
||||
return cql.CQLTrainer, cql.CQL_DEFAULT_CONFIG
|
||||
|
||||
|
@ -90,13 +93,13 @@ def _import_dqn():
|
|||
|
||||
|
||||
def _import_dreamer():
|
||||
from ray.rllib.agents import dreamer
|
||||
from ray.rllib.algorithms import dreamer
|
||||
|
||||
return dreamer.DREAMERTrainer, dreamer.DEFAULT_CONFIG
|
||||
|
||||
|
||||
def _import_es():
|
||||
from ray.rllib.agents import es
|
||||
from ray.rllib.algorithms import es
|
||||
|
||||
return es.ESTrainer, es.DEFAULT_CONFIG
|
||||
|
||||
|
@ -114,19 +117,19 @@ def _import_maddpg():
|
|||
|
||||
|
||||
def _import_maml():
|
||||
from ray.rllib.agents import maml
|
||||
from ray.rllib.algorithms import maml
|
||||
|
||||
return maml.MAMLTrainer, maml.DEFAULT_CONFIG
|
||||
|
||||
|
||||
def _import_marwil():
|
||||
from ray.rllib.agents import marwil
|
||||
from ray.rllib.algorithms import marwil
|
||||
|
||||
return marwil.MARWILTrainer, marwil.DEFAULT_CONFIG
|
||||
|
||||
|
||||
def _import_mbmpo():
|
||||
from ray.rllib.agents import mbmpo
|
||||
from ray.rllib.algorithms import mbmpo
|
||||
|
||||
return mbmpo.MBMPOTrainer, mbmpo.DEFAULT_CONFIG
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ import unittest
|
|||
import ray
|
||||
import ray.rllib.agents.a3c as a3c
|
||||
import ray.rllib.agents.dqn as dqn
|
||||
from ray.rllib.agents.marwil import BCTrainer
|
||||
from ray.rllib.algorithms.marwil import BCTrainer
|
||||
import ray.rllib.agents.pg as pg
|
||||
from ray.rllib.agents.trainer import COMMON_CONFIG
|
||||
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
|
||||
|
|
11
rllib/algorithms/alpha_star/__init__.py
Normal file
|
@ -0,0 +1,11 @@
|
|||
from ray.rllib.algorithms.alpha_star.alpha_star import (
|
||||
AlphaStarConfig,
|
||||
AlphaStarTrainer,
|
||||
DEFAULT_CONFIG,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AlphaStarConfig",
|
||||
"AlphaStarTrainer",
|
||||
"DEFAULT_CONFIG",
|
||||
]
|
|
@ -9,8 +9,8 @@ from typing import Any, Dict, Optional, Type
|
|||
|
||||
import ray
|
||||
from ray.actor import ActorHandle
|
||||
from ray.rllib.agents.alpha_star.distributed_learners import DistributedLearners
|
||||
from ray.rllib.agents.alpha_star.league_builder import AlphaStarLeagueBuilder
|
||||
from ray.rllib.algorithms.alpha_star.distributed_learners import DistributedLearners
|
||||
from ray.rllib.algorithms.alpha_star.league_builder import AlphaStarLeagueBuilder
|
||||
from ray.rllib.agents.trainer import Trainer
|
||||
import ray.rllib.agents.ppo.appo as appo
|
||||
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
||||
|
@ -48,7 +48,7 @@ class AlphaStarConfig(appo.APPOConfig):
|
|||
"""Defines a configuration class from which an AlphaStarTrainer can be built.
|
||||
|
||||
Example:
|
||||
>>> from ray.rllib.agents.alpha_star import AlphaStarConfig
|
||||
>>> from ray.rllib.algorithms.alpha_star import AlphaStarConfig
|
||||
>>> config = AlphaStarConfig().training(lr=0.0003, train_batch_size=512)\
|
||||
... .resources(num_gpus=4)\
|
||||
... .rollouts(num_rollout_workers=64)
|
||||
|
@ -58,7 +58,7 @@ class AlphaStarConfig(appo.APPOConfig):
|
|||
>>> trainer.train()
|
||||
|
||||
Example:
|
||||
>>> from ray.rllib.agents.alpha_star import AlphaStarConfig
|
||||
>>> from ray.rllib.algorithms.alpha_star import AlphaStarConfig
|
||||
>>> from ray import tune
|
||||
>>> config = AlphaStarConfig()
|
||||
>>> # Print out some default values.
|
||||
|
@ -168,7 +168,7 @@ class AlphaStarConfig(appo.APPOConfig):
|
|||
to be used for league building logic. All other keys (that are not
|
||||
`type`) will be used as constructor kwargs on the given class to
|
||||
construct the LeagueBuilder instance. See the
|
||||
`ray.rllib.agents.alpha_star.league_builder::AlphaStarLeagueBuilder`
|
||||
`ray.rllib.algorithms.alpha_star.league_builder::AlphaStarLeagueBuilder`
|
||||
(used by default by this algo) as an example.
|
||||
max_num_policies_to_train: The maximum number of trainable policies for this
|
||||
Trainer. Each trainable policy will exist as an independent remote actor,
|
||||
|
@ -584,8 +584,8 @@ class _deprecated_default_config(dict):
|
|||
super().__init__(AlphaStarConfig().to_dict())
|
||||
|
||||
@Deprecated(
|
||||
old="ray.rllib.agents.alpha_star.alpha_star.DEFAULT_CONFIG",
|
||||
new="ray.rllib.agents.alpha_star.alpha_star.AlphaStarConfig(...)",
|
||||
old="ray.rllib.algorithms.alpha_star.alpha_star.DEFAULT_CONFIG",
|
||||
new="ray.rllib.algorithms.alpha_star.alpha_star.AlphaStarConfig(...)",
|
||||
error=False,
|
||||
)
|
||||
def __getitem__(self, item):
|
|
@ -2,7 +2,7 @@ import pyspiel
|
|||
import unittest
|
||||
|
||||
import ray
|
||||
import ray.rllib.agents.alpha_star as alpha_star
|
||||
import ray.rllib.algorithms.alpha_star as alpha_star
|
||||
from ray.rllib.env.wrappers.open_spiel import OpenSpielEnv
|
||||
from ray.rllib.utils.test_utils import (
|
||||
check_compute_single_action,
|
8
rllib/algorithms/cql/__init__.py
Normal file
|
@ -0,0 +1,8 @@
|
|||
from ray.rllib.algorithms.cql.cql import CQLTrainer, CQL_DEFAULT_CONFIG
|
||||
from ray.rllib.algorithms.cql.cql_torch_policy import CQLTorchPolicy
|
||||
|
||||
__all__ = [
|
||||
"CQL_DEFAULT_CONFIG",
|
||||
"CQLTorchPolicy",
|
||||
"CQLTrainer",
|
||||
]
|
|
@ -2,8 +2,8 @@ import logging
|
|||
import numpy as np
|
||||
from typing import Type
|
||||
|
||||
from ray.rllib.agents.cql.cql_tf_policy import CQLTFPolicy
|
||||
from ray.rllib.agents.cql.cql_torch_policy import CQLTorchPolicy
|
||||
from ray.rllib.algorithms.cql.cql_tf_policy import CQLTFPolicy
|
||||
from ray.rllib.algorithms.cql.cql_torch_policy import CQLTorchPolicy
|
||||
from ray.rllib.agents.sac.sac import SACTrainer, DEFAULT_CONFIG as SAC_CONFIG
|
||||
from ray.rllib.execution.train_ops import (
|
||||
multi_gpu_train_one_step,
|
|
@ -411,7 +411,7 @@ def apply_gradients_fn(policy, optimizer, grads_and_vars):
|
|||
CQLTFPolicy = build_tf_policy(
|
||||
name="CQLTFPolicy",
|
||||
loss_fn=cql_loss,
|
||||
get_default_config=lambda: ray.rllib.agents.cql.cql.CQL_DEFAULT_CONFIG,
|
||||
get_default_config=lambda: ray.rllib.algorithms.cql.cql.CQL_DEFAULT_CONFIG,
|
||||
validate_spaces=validate_spaces,
|
||||
stats_fn=cql_stats,
|
||||
postprocess_fn=postprocess_trajectory,
|
|
@ -387,7 +387,7 @@ CQLTorchPolicy = build_policy_class(
|
|||
name="CQLTorchPolicy",
|
||||
framework="torch",
|
||||
loss_fn=cql_loss,
|
||||
get_default_config=lambda: ray.rllib.agents.cql.cql.CQL_DEFAULT_CONFIG,
|
||||
get_default_config=lambda: ray.rllib.algorithms.cql.cql.CQL_DEFAULT_CONFIG,
|
||||
stats_fn=cql_stats,
|
||||
postprocess_fn=postprocess_trajectory,
|
||||
extra_grad_process_fn=apply_grad_clipping,
|
|
@ -4,7 +4,7 @@ import os
|
|||
import unittest
|
||||
|
||||
import ray
|
||||
import ray.rllib.agents.cql as cql
|
||||
import ray.rllib.algorithms.cql as cql
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
from ray.rllib.utils.test_utils import (
|
||||
check_compute_single_action,
|
11
rllib/algorithms/dreamer/__init__.py
Normal file
|
@ -0,0 +1,11 @@
|
|||
from ray.rllib.algorithms.dreamer.dreamer import (
|
||||
DREAMERConfig,
|
||||
DREAMERTrainer,
|
||||
DEFAULT_CONFIG,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DREAMERConfig",
|
||||
"DREAMERTrainer",
|
||||
"DEFAULT_CONFIG",
|
||||
]
|
|
@ -4,12 +4,12 @@ import random
|
|||
from typing import Optional
|
||||
|
||||
from ray.rllib.agents.trainer_config import TrainerConfig
|
||||
from ray.rllib.agents.dreamer.dreamer_torch_policy import DreamerTorchPolicy
|
||||
from ray.rllib.algorithms.dreamer.dreamer_torch_policy import DreamerTorchPolicy
|
||||
from ray.rllib.agents.trainer import Trainer
|
||||
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, _get_shared_metrics
|
||||
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
|
||||
from ray.rllib.evaluation.metrics import collect_metrics
|
||||
from ray.rllib.agents.dreamer.dreamer_model import DreamerModel
|
||||
from ray.rllib.algorithms.dreamer.dreamer_model import DreamerModel
|
||||
from ray.rllib.execution.rollout_ops import (
|
||||
ParallelRollouts,
|
||||
synchronous_parallel_sample,
|
||||
|
@ -31,7 +31,7 @@ class DREAMERConfig(TrainerConfig):
|
|||
"""Defines a PPOTrainer configuration class from which a PPOTrainer can be built.
|
||||
|
||||
Example:
|
||||
>>> from ray.rllib.agents.dreamer import DREAMERConfig
|
||||
>>> from ray.rllib.algorithms.dreamer import DREAMERConfig
|
||||
>>> config = DREAMERConfig().training(gamma=0.9, lr=0.01)\
|
||||
... .resources(num_gpus=0)\
|
||||
... .rollouts(num_rollout_workers=4)
|
||||
|
@ -42,7 +42,7 @@ class DREAMERConfig(TrainerConfig):
|
|||
|
||||
Example:
|
||||
>>> from ray import tune
|
||||
>>> from ray.rllib.agents.dreamer import DREAMERConfig
|
||||
>>> from ray.rllib.algorithms.dreamer import DREAMERConfig
|
||||
>>> config = DREAMERConfig()
|
||||
>>> # Print out some default values.
|
||||
>>> print(config.clip_param)
|
||||
|
@ -418,14 +418,14 @@ class DREAMERTrainer(Trainer):
|
|||
return results
|
||||
|
||||
|
||||
# Deprecated: Use ray.rllib.agents.dreamer.DREAMERConfig instead!
|
||||
# Deprecated: Use ray.rllib.algorithms.dreamer.DREAMERConfig instead!
|
||||
class _deprecated_default_config(dict):
|
||||
def __init__(self):
|
||||
super().__init__(DREAMERConfig().to_dict())
|
||||
|
||||
@Deprecated(
|
||||
old="ray.rllib.agents.dreamer.dreamer.DEFAULT_CONFIG",
|
||||
new="ray.rllib.agents.dreamer.dreamer.DREAMERConfig(...)",
|
||||
old="ray.rllib.algorithms.dreamer.dreamer.DEFAULT_CONFIG",
|
||||
new="ray.rllib.algorithms.dreamer.dreamer.DREAMERConfig(...)",
|
||||
error=False,
|
||||
)
|
||||
def __getitem__(self, item):
|
|
@ -8,7 +8,7 @@ from ray.rllib.utils.framework import TensorType
|
|||
torch, nn = try_import_torch()
|
||||
if torch:
|
||||
from torch import distributions as td
|
||||
from ray.rllib.agents.dreamer.utils import (
|
||||
from ray.rllib.algorithms.dreamer.utils import (
|
||||
Linear,
|
||||
Conv2d,
|
||||
ConvTranspose2d,
|
|
@ -4,7 +4,7 @@ import numpy as np
|
|||
from typing import Dict, Optional
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.dreamer.utils import FreezeParameters
|
||||
from ray.rllib.algorithms.dreamer.utils import FreezeParameters
|
||||
from ray.rllib.evaluation.episode import Episode
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
||||
|
@ -284,7 +284,7 @@ def preprocess_episode(
|
|||
DreamerTorchPolicy = build_policy_class(
|
||||
name="DreamerTorchPolicy",
|
||||
framework="torch",
|
||||
get_default_config=lambda: ray.rllib.agents.dreamer.dreamer.DEFAULT_CONFIG,
|
||||
get_default_config=lambda: ray.rllib.algorithms.dreamer.dreamer.DEFAULT_CONFIG,
|
||||
action_sampler_fn=action_sampler_fn,
|
||||
postprocess_fn=preprocess_episode,
|
||||
loss_fn=dreamer_loss,
|
|
@ -2,7 +2,7 @@ from gym.spaces import Box
|
|||
import unittest
|
||||
|
||||
import ray
|
||||
import ray.rllib.agents.dreamer as dreamer
|
||||
import ray.rllib.algorithms.dreamer as dreamer
|
||||
from ray.rllib.examples.env.random_env import RandomEnv
|
||||
from ray.rllib.utils.test_utils import framework_iterator
|
||||
|
6
rllib/algorithms/maml/__init__.py
Normal file
|
@ -0,0 +1,6 @@
|
|||
from ray.rllib.algorithms.maml.maml import MAMLTrainer, DEFAULT_CONFIG
|
||||
|
||||
__all__ = [
|
||||
"MAMLTrainer",
|
||||
"DEFAULT_CONFIG",
|
||||
]
|
|
@ -4,8 +4,8 @@ from typing import Type
|
|||
|
||||
from ray.rllib.utils.sgd import standardized
|
||||
from ray.rllib.agents import with_common_config
|
||||
from ray.rllib.agents.maml.maml_tf_policy import MAMLTFPolicy
|
||||
from ray.rllib.agents.maml.maml_torch_policy import MAMLTorchPolicy
|
||||
from ray.rllib.algorithms.maml.maml_tf_policy import MAMLTFPolicy
|
||||
from ray.rllib.algorithms.maml.maml_torch_policy import MAMLTorchPolicy
|
||||
from ray.rllib.agents.trainer import Trainer
|
||||
from ray.rllib.evaluation.metrics import get_learner_stats
|
||||
from ray.rllib.evaluation.worker_set import WorkerSet
|
|
@ -442,7 +442,7 @@ def setup_mixins(policy, obs_space, action_space, config):
|
|||
|
||||
MAMLTFPolicy = build_tf_policy(
|
||||
name="MAMLTFPolicy",
|
||||
get_default_config=lambda: ray.rllib.agents.maml.maml.DEFAULT_CONFIG,
|
||||
get_default_config=lambda: ray.rllib.algorithms.maml.maml.DEFAULT_CONFIG,
|
||||
loss_fn=maml_loss,
|
||||
stats_fn=maml_stats,
|
||||
optimizer_fn=maml_optimizer_fn,
|
|
@ -382,7 +382,7 @@ def setup_mixins(policy, obs_space, action_space, config):
|
|||
MAMLTorchPolicy = build_policy_class(
|
||||
name="MAMLTorchPolicy",
|
||||
framework="torch",
|
||||
get_default_config=lambda: ray.rllib.agents.maml.maml.DEFAULT_CONFIG,
|
||||
get_default_config=lambda: ray.rllib.algorithms.maml.maml.DEFAULT_CONFIG,
|
||||
loss_fn=maml_loss,
|
||||
stats_fn=maml_stats,
|
||||
optimizer_fn=maml_optimizer_fn,
|
|
@ -1,7 +1,7 @@
|
|||
import unittest
|
||||
|
||||
import ray
|
||||
import ray.rllib.agents.maml as maml
|
||||
import ray.rllib.algorithms.maml as maml
|
||||
from ray.rllib.utils.test_utils import (
|
||||
check_compute_single_action,
|
||||
check_train_results,
|
13
rllib/algorithms/marwil/__init__.py
Normal file
|
@ -0,0 +1,13 @@
|
|||
from ray.rllib.algorithms.marwil.bc import BCTrainer, BC_DEFAULT_CONFIG
|
||||
from ray.rllib.algorithms.marwil.marwil import MARWILTrainer, DEFAULT_CONFIG
|
||||
from ray.rllib.algorithms.marwil.marwil_tf_policy import MARWILTFPolicy
|
||||
from ray.rllib.algorithms.marwil.marwil_torch_policy import MARWILTorchPolicy
|
||||
|
||||
__all__ = [
|
||||
"BCTrainer",
|
||||
"BC_DEFAULT_CONFIG",
|
||||
"DEFAULT_CONFIG",
|
||||
"MARWILTFPolicy",
|
||||
"MARWILTorchPolicy",
|
||||
"MARWILTrainer",
|
||||
]
|
|
@ -1,4 +1,4 @@
|
|||
from ray.rllib.agents.marwil.marwil import (
|
||||
from ray.rllib.algorithms.marwil.marwil import (
|
||||
MARWILTrainer,
|
||||
DEFAULT_CONFIG as MARWIL_CONFIG,
|
||||
)
|
|
@ -1,7 +1,7 @@
|
|||
from typing import Type
|
||||
|
||||
from ray.rllib.agents.trainer import Trainer, with_common_config
|
||||
from ray.rllib.agents.marwil.marwil_tf_policy import MARWILTFPolicy
|
||||
from ray.rllib.algorithms.marwil.marwil_tf_policy import MARWILTFPolicy
|
||||
from ray.rllib.execution.buffers.multi_agent_replay_buffer import MultiAgentReplayBuffer
|
||||
from ray.rllib.execution.rollout_ops import (
|
||||
synchronous_parallel_sample,
|
||||
|
@ -110,7 +110,9 @@ class MARWILTrainer(Trainer):
|
|||
@override(Trainer)
|
||||
def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
|
||||
if config["framework"] == "torch":
|
||||
from ray.rllib.agents.marwil.marwil_torch_policy import MARWILTorchPolicy
|
||||
from ray.rllib.algorithms.marwil.marwil_torch_policy import (
|
||||
MARWILTorchPolicy,
|
||||
)
|
||||
|
||||
return MARWILTorchPolicy
|
||||
else:
|
|
@ -232,7 +232,7 @@ def setup_mixins(
|
|||
|
||||
MARWILTFPolicy = build_tf_policy(
|
||||
name="MARWILTFPolicy",
|
||||
get_default_config=lambda: ray.rllib.agents.marwil.marwil.DEFAULT_CONFIG,
|
||||
get_default_config=lambda: ray.rllib.algorithms.marwil.marwil.DEFAULT_CONFIG,
|
||||
loss_fn=marwil_loss,
|
||||
stats_fn=stats,
|
||||
postprocess_fn=postprocess_advantages,
|
|
@ -3,7 +3,7 @@ from typing import Dict
|
|||
|
||||
import ray
|
||||
from ray.rllib.agents.a3c.a3c_torch_policy import ValueNetworkMixin
|
||||
from ray.rllib.agents.marwil.marwil_tf_policy import postprocess_advantages
|
||||
from ray.rllib.algorithms.marwil.marwil_tf_policy import postprocess_advantages
|
||||
from ray.rllib.evaluation.postprocessing import Postprocessing
|
||||
from ray.rllib.policy.policy_template import build_policy_class
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
|
@ -113,7 +113,7 @@ MARWILTorchPolicy = build_policy_class(
|
|||
name="MARWILTorchPolicy",
|
||||
framework="torch",
|
||||
loss_fn=marwil_loss,
|
||||
get_default_config=lambda: ray.rllib.agents.marwil.marwil.DEFAULT_CONFIG,
|
||||
get_default_config=lambda: ray.rllib.algorithms.marwil.marwil.DEFAULT_CONFIG,
|
||||
stats_fn=stats,
|
||||
postprocess_fn=postprocess_advantages,
|
||||
extra_grad_process_fn=apply_grad_clipping,
|
|
@ -3,7 +3,7 @@ from pathlib import Path
|
|||
import unittest
|
||||
|
||||
import ray
|
||||
import ray.rllib.agents.marwil as marwil
|
||||
import ray.rllib.algorithms.marwil as marwil
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.test_utils import (
|
||||
check_compute_single_action,
|
|
@ -4,7 +4,7 @@ from pathlib import Path
|
|||
import unittest
|
||||
|
||||
import ray
|
||||
import ray.rllib.agents.marwil as marwil
|
||||
import ray.rllib.algorithms.marwil as marwil
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages
|
||||
from ray.rllib.offline import JsonReader
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
6
rllib/algorithms/mbmpo/__init__.py
Normal file
|
@ -0,0 +1,6 @@
|
|||
from ray.rllib.algorithms.mbmpo.mbmpo import MBMPOTrainer, DEFAULT_CONFIG
|
||||
|
||||
__all__ = [
|
||||
"MBMPOTrainer",
|
||||
"DEFAULT_CONFIG",
|
||||
]
|
|
@ -4,9 +4,9 @@ from typing import List, Type
|
|||
|
||||
import ray
|
||||
from ray.rllib.agents import with_common_config
|
||||
from ray.rllib.agents.mbmpo.mbmpo_torch_policy import MBMPOTorchPolicy
|
||||
from ray.rllib.agents.mbmpo.model_ensemble import DynamicsEnsembleCustomModel
|
||||
from ray.rllib.agents.mbmpo.utils import calculate_gae_advantages, MBMPOExploration
|
||||
from ray.rllib.algorithms.mbmpo.mbmpo_torch_policy import MBMPOTorchPolicy
|
||||
from ray.rllib.algorithms.mbmpo.model_ensemble import DynamicsEnsembleCustomModel
|
||||
from ray.rllib.algorithms.mbmpo.utils import calculate_gae_advantages, MBMPOExploration
|
||||
from ray.rllib.agents.trainer import Trainer
|
||||
from ray.rllib.env.env_context import EnvContext
|
||||
from ray.rllib.env.wrappers.model_vector_env import model_vector_env
|
|
@ -5,7 +5,7 @@ from typing import Tuple, Type
|
|||
|
||||
import ray
|
||||
from ray.rllib.agents.a3c.a3c_torch_policy import vf_preds_fetches
|
||||
from ray.rllib.agents.maml.maml_torch_policy import (
|
||||
from ray.rllib.algorithms.maml.maml_torch_policy import (
|
||||
setup_mixins,
|
||||
maml_loss,
|
||||
maml_stats,
|
||||
|
@ -121,7 +121,7 @@ def make_model_and_action_dist(
|
|||
MBMPOTorchPolicy = build_policy_class(
|
||||
name="MBMPOTorchPolicy",
|
||||
framework="torch",
|
||||
get_default_config=lambda: ray.rllib.agents.mbmpo.mbmpo.DEFAULT_CONFIG,
|
||||
get_default_config=lambda: ray.rllib.algorithms.mbmpo.mbmpo.DEFAULT_CONFIG,
|
||||
make_model_and_action_dist=make_model_and_action_dist,
|
||||
loss_fn=maml_loss,
|
||||
stats_fn=maml_stats,
|
|
@ -1,7 +1,7 @@
|
|||
import unittest
|
||||
|
||||
import ray
|
||||
import ray.rllib.agents.mbmpo as mbmpo
|
||||
import ray.rllib.algorithms.mbmpo as mbmpo
|
||||
from ray.rllib.utils.test_utils import (
|
||||
check_compute_single_action,
|
||||
check_train_results,
|
12
rllib/env/multi_agent_env.py
vendored
|
@ -245,7 +245,6 @@ class MultiAgentEnv(gym.Env):
|
|||
|
||||
# fmt: off
|
||||
# __grouping_doc_begin__
|
||||
@ExperimentalAPI
|
||||
def with_agent_groups(
|
||||
self,
|
||||
groups: Dict[str, List[AgentID]],
|
||||
|
@ -265,16 +264,17 @@ class MultiAgentEnv(gym.Env):
|
|||
|
||||
Agent grouping is required to leverage algorithms such as Q-Mix.
|
||||
|
||||
This API is experimental.
|
||||
|
||||
Args:
|
||||
groups: Mapping from group id to a list of the agent ids
|
||||
of group members. If an agent id is not present in any group
|
||||
value, it will be left ungrouped.
|
||||
value, it will be left ungrouped. The group id becomes a new agent ID
|
||||
in the final environment.
|
||||
obs_space: Optional observation space for the grouped
|
||||
env. Must be a tuple space.
|
||||
env. Must be a tuple space. If not provided, will infer this to be a
|
||||
Tuple of n individual agent spaces (n=num agents in a group).
|
||||
act_space: Optional action space for the grouped env.
|
||||
Must be a tuple space.
|
||||
Must be a tuple space. If not provided, will infer this to be a Tuple
|
||||
of n individual agent spaces (n=num agents in a group).
|
||||
|
||||
Examples:
|
||||
>>> from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
|
|
35
rllib/env/wrappers/group_agents_wrapper.py
vendored
|
@ -1,6 +1,9 @@
|
|||
from collections import OrderedDict
|
||||
import gym
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
from ray.rllib.utils.typing import AgentID
|
||||
|
||||
# info key for the individual rewards of an agent, for example:
|
||||
# info: {
|
||||
|
@ -27,21 +30,35 @@ class GroupAgentsWrapper(MultiAgentEnv):
|
|||
This API is experimental.
|
||||
"""
|
||||
|
||||
def __init__(self, env, groups, obs_space=None, act_space=None):
|
||||
"""Wrap an existing multi-agent env to group agents together.
|
||||
def __init__(
|
||||
self,
|
||||
env: MultiAgentEnv,
|
||||
groups: Dict[str, List[AgentID]],
|
||||
obs_space: Optional[gym.Space] = None,
|
||||
act_space: Optional[gym.Space] = None,
|
||||
):
|
||||
"""Wrap an existing MultiAgentEnv to group agent ID together.
|
||||
|
||||
See MultiAgentEnv.with_agent_groups() for usage info.
|
||||
See `MultiAgentEnv.with_agent_groups()` for more detailed usage info.
|
||||
|
||||
Args:
|
||||
env (MultiAgentEnv): env to wrap
|
||||
groups (dict): Grouping spec as documented in MultiAgentEnv.
|
||||
obs_space (Space): Optional observation space for the grouped
|
||||
env. Must be a tuple space.
|
||||
act_space (Space): Optional action space for the grouped env.
|
||||
Must be a tuple space.
|
||||
env: The env to wrap and whose agent IDs to group into new agents.
|
||||
groups: Mapping from group id to a list of the agent ids
|
||||
of group members. If an agent id is not present in any group
|
||||
value, it will be left ungrouped. The group id becomes a new agent ID
|
||||
in the final environment.
|
||||
obs_space: Optional observation space for the grouped
|
||||
env. Must be a tuple space. If not provided, will infer this to be a
|
||||
Tuple of n individual agent spaces (n=num agents in a group).
|
||||
act_space: Optional action space for the grouped env.
|
||||
Must be a tuple space. If not provided, will infer this to be a Tuple
|
||||
of n individual agent spaces (n=num agents in a group).
|
||||
"""
|
||||
super().__init__()
|
||||
self.env = env
|
||||
# Inherit wrapped env's `_skip_env_checking` flag.
|
||||
if hasattr(self.env, "_skip_env_checking"):
|
||||
self._skip_env_checking = self.env._skip_env_checking
|
||||
self.groups = groups
|
||||
self.agent_id_to_group = {}
|
||||
for group_id, agent_ids in groups.items():
|
||||
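A minimal grouping sketch (hypothetical, built on RLlib's ``make_multi_agent`` helper) matching the updated docstrings above; both per-agent observations end up under the single new "group_1" agent ID:

from ray.rllib.env.multi_agent_env import make_multi_agent

# Two independent CartPole agents with integer agent IDs 0 and 1.
MultiCartPole = make_multi_agent("CartPole-v0")
env = MultiCartPole({"num_agents": 2})

# Group both agents under one new agent ID. Per the docstring above, the
# grouped obs/act spaces default to Tuples of the member spaces when not
# passed explicitly.
grouped = env.with_agent_groups({"group_1": [0, 1]})
print(grouped.reset())  # e.g. {"group_1": [obs_of_agent_0, obs_of_agent_1]}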
|
|
|
@ -20,7 +20,9 @@ import os
|
|||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.rllib.agents.maml.maml_torch_policy import KLCoeffMixin as TorchKLCoeffMixin
|
||||
from ray.rllib.algorithms.maml.maml_torch_policy import (
|
||||
KLCoeffMixin as TorchKLCoeffMixin,
|
||||
)
|
||||
from ray.rllib.agents.ppo.ppo import PPOTrainer
|
||||
from ray.rllib.agents.ppo.ppo_tf_policy import (
|
||||
PPOTFPolicy,
|
||||
|
|
|
@ -16,6 +16,7 @@ import os
|
|||
import ray
|
||||
from ray import tune
|
||||
from ray.tune import register_env
|
||||
from ray.rllib.agents.qmix import QMixConfig
|
||||
from ray.rllib.env.multi_agent_env import ENV_STATE
|
||||
from ray.rllib.examples.env.two_step_game import TwoStepGame
|
||||
from ray.rllib.policy.policy import PolicySpec
|
||||
|
@ -110,10 +111,11 @@ if __name__ == "__main__":
|
|||
obs_space = Discrete(6)
|
||||
act_space = TwoStepGame.action_space
|
||||
config = {
|
||||
"learning_starts": 100,
|
||||
"env": TwoStepGame,
|
||||
"env_config": {
|
||||
"actions_are_logits": True,
|
||||
},
|
||||
"learning_starts": 100,
|
||||
"multiagent": {
|
||||
"policies": {
|
||||
"pol1": PolicySpec(
|
||||
|
@ -133,31 +135,33 @@ if __name__ == "__main__":
|
|||
# Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
|
||||
"num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
|
||||
}
|
||||
group = False
|
||||
elif args.run == "QMIX":
|
||||
config = {
|
||||
"rollout_fragment_length": 4,
|
||||
"train_batch_size": 32,
|
||||
"exploration_config": {
|
||||
"final_epsilon": 0.0,
|
||||
},
|
||||
"num_workers": 0,
|
||||
"mixer": args.mixer,
|
||||
"env_config": {
|
||||
"separate_state_space": True,
|
||||
"one_hot_state_encoding": True,
|
||||
},
|
||||
# Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
|
||||
"num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
|
||||
}
|
||||
group = True
|
||||
config = (
|
||||
QMixConfig()
|
||||
.training(mixer=args.mixer, train_batch_size=32)
|
||||
.rollouts(num_rollout_workers=0, rollout_fragment_length=4)
|
||||
.exploration(
|
||||
exploration_config={
|
||||
"final_epsilon": 0.0,
|
||||
}
|
||||
)
|
||||
.environment(
|
||||
env="grouped_twostep",
|
||||
env_config={
|
||||
"separate_state_space": True,
|
||||
"one_hot_state_encoding": True,
|
||||
},
|
||||
)
|
||||
.resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
|
||||
)
|
||||
config = config.to_dict()
|
||||
else:
|
||||
config = {
|
||||
"env": TwoStepGame,
|
||||
# Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
|
||||
"num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
|
||||
"framework": args.framework,
|
||||
}
|
||||
group = False
|
||||
|
||||
stop = {
|
||||
"episode_reward_mean": args.stop_reward,
|
||||
|
@ -165,13 +169,6 @@ if __name__ == "__main__":
|
|||
"training_iteration": args.stop_iters,
|
||||
}
|
||||
|
||||
config = dict(
|
||||
config,
|
||||
**{
|
||||
"env": "grouped_twostep" if group else TwoStepGame,
|
||||
}
|
||||
)
|
||||
|
||||
results = tune.run(args.run, stop=stop, config=config, verbose=2)
|
||||
|
||||
if args.as_test:
|
||||
|
|
|
@ -6,7 +6,7 @@ import tree # pip install dm_tree
|
|||
import unittest
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.marwil import BCTrainer
|
||||
from ray.rllib.algorithms.marwil import BCTrainer
|
||||
from ray.rllib.agents.pg import PGTrainer, DEFAULT_CONFIG
|
||||
from ray.rllib.examples.env.random_env import RandomEnv
|
||||
from ray.rllib.offline.json_reader import JsonReader
|
||||
|
|
|
@ -32,7 +32,7 @@ multi-agent-cartpole-alpha-star:
|
|||
|
||||
# No league-building needed.
|
||||
league_builder_config:
|
||||
type: ray.rllib.agents.alpha_star.league_builder.NoLeagueBuilder
|
||||
type: ray.rllib.algorithms.alpha_star.league_builder.NoLeagueBuilder
|
||||
|
||||
multiagent:
|
||||
policies: ["p0", "p1", "p2", "p3"]
|
||||
|
|