[RLlib] Retry agents -> algorithms, with proper doc changes this time. (#24797)

Jun Gong 2022-05-16 00:45:32 -07:00 committed by GitHub
parent d40fa391a5
commit 68a9a33386
59 changed files with 430 additions and 248 deletions

View file

@ -466,7 +466,7 @@ HalfCheetah 13000 ~15000
Model-Agnostic Meta-Learning (MAML)
-----------------------------------
|pytorch| |tensorflow|
`[paper] <https://arxiv.org/abs/1703.03400>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/maml/maml.py>`__
`[paper] <https://arxiv.org/abs/1703.03400>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/algorithms/maml/maml.py>`__
RLlib's MAML implementation is a meta-learning method for learning and quickly adapting across different continuous-control tasks. The code here is adapted from https://github.com/jonasrothfuss; it outperforms vanilla MAML and avoids computing higher-order gradients during the meta-update step. MAML is evaluated on custom environments that are described in greater detail `here <https://github.com/ray-project/ray/blob/master/rllib/env/apis/task_settable_env.py>`__.
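Not part of the diff: MAML expects environments that implement RLlib's ``TaskSettableEnv`` API linked above. The toy ``RandomGoalEnv`` below, its goal range, and its 10-step episodes are made-up illustrations of that interface, not an RLlib example:

```python
import gym
import numpy as np
from ray.rllib.env.apis.task_settable_env import TaskSettableEnv


class RandomGoalEnv(TaskSettableEnv):
    """Toy 1-D goal-reaching env; each 'task' is a different goal position."""

    def __init__(self, config=None):
        self.observation_space = gym.spaces.Box(-10.0, 10.0, shape=(1,))
        self.action_space = gym.spaces.Box(-1.0, 1.0, shape=(1,))
        self.goal = 0.0
        self.pos = 0.0
        self.t = 0

    def sample_tasks(self, n_tasks):
        # MAML samples a batch of tasks (here: goal positions) per meta-update.
        return list(np.random.uniform(-5.0, 5.0, size=(n_tasks,)))

    def set_task(self, task):
        self.goal = float(task)

    def get_task(self):
        return self.goal

    def reset(self):
        self.pos, self.t = 0.0, 0
        return np.array([self.pos], dtype=np.float32)

    def step(self, action):
        self.pos += float(np.clip(action[0], -1.0, 1.0))
        self.t += 1
        reward = -abs(self.pos - self.goal)
        done = self.t >= 10
        return np.array([self.pos], dtype=np.float32), reward, done, {}
```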
@ -476,7 +476,7 @@ Tuned examples: HalfCheetahRandDirecEnv (`Env <https://github.com/ray-project/ra
**MAML-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):
.. literalinclude:: ../../../rllib/agents/maml/maml.py
.. literalinclude:: ../../../rllib/algorithms/maml/maml.py
:language: python
:start-after: __sphinx_doc_begin__
:end-before: __sphinx_doc_end__
@ -486,7 +486,7 @@ Tuned examples: HalfCheetahRandDirecEnv (`Env <https://github.com/ray-project/ra
Model-Based Meta-Policy-Optimization (MB-MPO)
---------------------------------------------
|pytorch|
`[paper] <https://arxiv.org/pdf/1809.05214.pdf>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/mbmpo/mbmpo.py>`__
`[paper] <https://arxiv.org/pdf/1809.05214.pdf>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/algorithms/mbmpo/mbmpo.py>`__
RLlib's MBMPO implementation is a Dyna-style model-based RL method that learns based on the predictions of an ensemble of transition-dynamics models. Similar to MAML, MBMPO meta-learns an optimal policy by treating each dynamics model as a different task. The code here is adapted from https://github.com/jonasrothfuss/model_ensemble_meta_learning. As in the original paper, MBMPO is evaluated on MuJoCo, with the horizon set to 200 instead of the default 1000.
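A brief, hedged sketch (not part of the diff) of the new import path for MB-MPO; the commented-out trainer line names a hypothetical task-settable env you would provide yourself:

```python
from ray.rllib.algorithms.mbmpo import MBMPOTrainer, DEFAULT_CONFIG

config = DEFAULT_CONFIG.copy()
# MB-MPO is PyTorch-only; horizon=200 mirrors the evaluation setup above.
config["framework"] = "torch"
config["horizon"] = 200

# Like MAML, MBMPO expects a task-settable env; plug in your own here:
# trainer = MBMPOTrainer(config=config, env=MyTaskSettableEnv)  # hypothetical env
```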
@ -510,7 +510,7 @@ Hopper 620 ~650
**MBMPO-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):
.. literalinclude:: ../../../rllib/agents/mbmpo/mbmpo.py
.. literalinclude:: ../../../rllib/algorithms/mbmpo/mbmpo.py
:language: python
:start-after: __sphinx_doc_begin__
:end-before: __sphinx_doc_end__
@ -539,7 +539,7 @@ Cheetah-Run 640 ~800
**Dreamer-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):
.. literalinclude:: ../../../rllib/agents/dreamer/dreamer.py
.. literalinclude:: ../../../rllib/algorithms/dreamer/dreamer.py
:language: python
:start-after: __sphinx_doc_begin__
:end-before: __sphinx_doc_end__
@ -567,7 +567,7 @@ RecSim environment wrapper: `Google RecSim <https://github.com/ray-project/ray/b
Conservative Q-Learning (CQL)
-----------------------------------
|pytorch| |tensorflow|
`[paper] <https://arxiv.org/abs/2006.04779>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/cql/cql.py>`__
`[paper] <https://arxiv.org/abs/2006.04779>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/algorithms/cql/cql.py>`__
In offline RL, the algorithm has no access to an environment, but can only sample from a fixed dataset of pre-collected state-action-reward tuples.
In particular, CQL (Conservative Q-Learning) is an offline RL algorithm that mitigates the overestimation of Q-values outside the dataset distribution via
@ -581,7 +581,7 @@ Tuned examples: `HalfCheetah Random <https://github.com/ray-project/ray/blob/mas
**CQL-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):
.. literalinclude:: ../../../rllib/agents/cql/cql.py
.. literalinclude:: ../../../rllib/algorithms/cql/cql.py
:language: python
:start-after: __sphinx_doc_begin__
:end-before: __sphinx_doc_end__
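The offline-RL setup described above can be sketched roughly as follows (not part of the diff; the dataset path is a placeholder and ``Pendulum-v1`` is just an assumed continuous-control env):

```python
import ray
import ray.rllib.algorithms.cql as cql

ray.init()

config = cql.CQL_DEFAULT_CONFIG.copy()
config["framework"] = "torch"
# Offline RL: read pre-collected SampleBatches from JSON instead of sampling
# an environment (placeholder path, point it at your own dataset).
config["input"] = "/path/to/offline_data.json"
# The env is still needed to define observation/action spaces.
config["env"] = "Pendulum-v1"

trainer = cql.CQLTrainer(config=config)
results = trainer.train()
```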
@ -636,7 +636,7 @@ Monotonic Advantage Re-Weighted Imitation Learning (MARWIL)
-----------------------------------------------------------
|pytorch| |tensorflow|
`[paper] <http://papers.nips.cc/paper/7866-exponentially-weighted-imitation-learning-for-batched-historical-data>`__
`[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/marwil/marwil.py>`__
`[implementation] <https://github.com/ray-project/ray/blob/master/rllib/algorithms/marwil/marwil.py>`__
MARWIL is a hybrid imitation learning and policy gradient algorithm suitable for training on batched historical data.
When the ``beta`` hyperparameter is set to zero, the MARWIL objective reduces to vanilla imitation learning (see `BC`_).
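A short, hedged sketch of the ``beta`` trade-off just described (not part of the diff; the dataset path is a placeholder):

```python
from ray.rllib.algorithms.marwil import MARWILTrainer, DEFAULT_CONFIG

config = DEFAULT_CONFIG.copy()
config["beta"] = 1.0  # try 0.0 to recover plain behavior cloning (BC)
config["input"] = "/path/to/historic_data.json"  # placeholder offline dataset
config["env"] = "CartPole-v0"  # matches the tuned CartPole example

trainer = MARWILTrainer(config=config)
```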
@ -646,7 +646,7 @@ Tuned examples: `CartPole-v0 <https://github.com/ray-project/ray/blob/master/rll
**MARWIL-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):
.. literalinclude:: ../../../rllib/agents/marwil/marwil.py
.. literalinclude:: ../../../rllib/algorithms/marwil/marwil.py
:language: python
:start-after: __sphinx_doc_begin__
:end-before: __sphinx_doc_end__
@ -658,7 +658,7 @@ Behavior Cloning (BC; derived from MARWIL implementation)
---------------------------------------------------------
|pytorch| |tensorflow|
`[paper] <http://papers.nips.cc/paper/7866-exponentially-weighted-imitation-learning-for-batched-historical-data>`__
`[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/marwil/bc.py>`__
`[implementation] <https://github.com/ray-project/ray/blob/master/rllib/algorithms/marwil/bc.py>`__
Our behavioral cloning implementation is directly derived from our `MARWIL`_ implementation,
with the only difference being the ``beta`` parameter force-set to 0.0. This makes
@ -669,7 +669,7 @@ Tuned examples: `CartPole-v0 <https://github.com/ray-project/ray/blob/master/rll
**BC-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):
.. literalinclude:: ../../../rllib/agents/marwil/bc.py
.. literalinclude:: ../../../rllib/algorithms/marwil/bc.py
:language: python
:start-after: __sphinx_doc_begin__
:end-before: __sphinx_doc_end__
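To make the BC/MARWIL relationship explicit, a tiny check (assuming the exported defaults match the description above):

```python
from ray.rllib.algorithms.marwil import BCTrainer, BC_DEFAULT_CONFIG

# BC is MARWIL with `beta` pinned to 0.0, now exposed under the new path.
assert BC_DEFAULT_CONFIG["beta"] == 0.0
print(BCTrainer.__name__)
```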

View file

@ -647,7 +647,7 @@ py_test(
name = "test_alpha_star",
tags = ["team:ml", "trainers_dir"],
size = "large",
srcs = ["agents/alpha_star/tests/test_alpha_star.py"]
srcs = ["algorithms/alpha_star/tests/test_alpha_star.py"]
)
# APEXTrainer (DQN)
@ -687,7 +687,7 @@ py_test(
name = "test_cql",
tags = ["team:ml", "trainers_dir"],
size = "medium",
srcs = ["agents/cql/tests/test_cql.py"]
srcs = ["algorithms/cql/tests/test_cql.py"]
)
# DDPGTrainer
@ -711,7 +711,7 @@ py_test(
name = "test_dreamer",
tags = ["team:ml", "trainers_dir"],
size = "small",
srcs = ["agents/dreamer/tests/test_dreamer.py"]
srcs = ["algorithms/dreamer/tests/test_dreamer.py"]
)
# ES
@ -743,7 +743,7 @@ py_test(
size = "large",
# Include the json data file.
data = ["tests/data/cartpole/large.json"],
srcs = ["agents/marwil/tests/test_marwil.py"]
srcs = ["algorithms/marwil/tests/test_marwil.py"]
)
# BCTrainer (sub-type of MARWIL)
@ -753,7 +753,7 @@ py_test(
size = "large",
# Include the json data file.
data = ["tests/data/cartpole/large.json"],
srcs = ["agents/marwil/tests/test_bc.py"]
srcs = ["algorithms/marwil/tests/test_bc.py"]
)
# MAMLTrainer
@ -761,7 +761,7 @@ py_test(
name = "test_maml",
tags = ["team:ml", "trainers_dir"],
size = "medium",
srcs = ["agents/maml/tests/test_maml.py"]
srcs = ["algorithms/maml/tests/test_maml.py"]
)
# MBMPOTrainer
@ -769,7 +769,7 @@ py_test(
name = "test_mbmpo",
tags = ["team:ml", "trainers_dir"],
size = "medium",
srcs = ["agents/mbmpo/tests/test_mbmpo.py"]
srcs = ["algorithms/mbmpo/tests/test_mbmpo.py"]
)
# PGTrainer

View file

@ -1,4 +1,4 @@
from ray.rllib.agents.alpha_star.alpha_star import (
from ray.rllib.algorithms.alpha_star.alpha_star import (
AlphaStarConfig,
AlphaStarTrainer,
DEFAULT_CONFIG,
@ -9,3 +9,10 @@ __all__ = [
"AlphaStarTrainer",
"DEFAULT_CONFIG",
]
from ray.rllib.utils.deprecation import deprecation_warning
deprecation_warning(
"ray.rllib.agents.alpha_star", "ray.rllib.algorithms.alpha_star", error=False
)
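A sketch (not part of the diff) of what this shim buys users: the old import path still resolves to the very same classes and only logs a deprecation warning, since ``error=False`` keeps it non-fatal:

```python
# Importing via the old path triggers the deprecation warning above,
# but both names point at the identical class object.
from ray.rllib.agents.alpha_star import AlphaStarTrainer as OldPathTrainer
from ray.rllib.algorithms.alpha_star import AlphaStarTrainer

assert OldPathTrainer is AlphaStarTrainer
```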

View file

@ -1,8 +1,13 @@
from ray.rllib.agents.cql.cql import CQLTrainer, CQL_DEFAULT_CONFIG
from ray.rllib.agents.cql.cql_torch_policy import CQLTorchPolicy
from ray.rllib.algorithms.cql.cql import CQLTrainer, CQL_DEFAULT_CONFIG
from ray.rllib.algorithms.cql.cql_tf_policy import CQLTFPolicy
from ray.rllib.algorithms.cql.cql_torch_policy import CQLTorchPolicy
from ray.rllib.utils.deprecation import deprecation_warning
__all__ = [
"CQL_DEFAULT_CONFIG",
"CQLTFPolicy",
"CQLTorchPolicy",
"CQLTrainer",
]
deprecation_warning("ray.rllib.agents.cql", "ray.rllib.algorithms.cql", error=False)

View file

@ -100,9 +100,9 @@ class SimpleQConfig(TrainerConfig):
>>> .exploration(exploration_config=explore_config)
"""
def __init__(self):
def __init__(self, trainer_class=None):
"""Initializes a SimpleQConfig instance."""
super().__init__(trainer_class=SimpleQTrainer)
super().__init__(trainer_class=trainer_class or SimpleQTrainer)
# Simple Q specific
# fmt: off
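A sketch (not part of the diff) of why ``trainer_class`` was added: subclasses such as ``QMixConfig`` further down in this change can now route their own Trainer class through ``SimpleQConfig.__init__``. The names below are hypothetical:

```python
from ray.rllib.agents.dqn.simple_q import SimpleQConfig, SimpleQTrainer


class MySimpleQVariantTrainer(SimpleQTrainer):  # hypothetical Trainer subclass
    pass


class MySimpleQVariantConfig(SimpleQConfig):  # hypothetical config subclass
    def __init__(self):
        # Route the custom Trainer class through the new `trainer_class` arg.
        super().__init__(trainer_class=MySimpleQVariantTrainer)


config = MySimpleQVariantConfig()
# config.build(env="CartPole-v0") would now construct MySimpleQVariantTrainer.
```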

View file

@ -1,4 +1,4 @@
from ray.rllib.agents.dreamer.dreamer import (
from ray.rllib.algorithms.dreamer.dreamer import (
DREAMERConfig,
DREAMERTrainer,
DEFAULT_CONFIG,
@ -9,3 +9,10 @@ __all__ = [
"DREAMERTrainer",
"DEFAULT_CONFIG",
]
from ray.rllib.utils.deprecation import deprecation_warning
deprecation_warning(
"ray.rllib.agents.dreamer", "ray.rllib.algorithms.dreamer", error=False
)

View file

@ -1,6 +1,10 @@
from ray.rllib.agents.maml.maml import MAMLTrainer, DEFAULT_CONFIG
from ray.rllib.algorithms.maml.maml import MAMLTrainer, DEFAULT_CONFIG
__all__ = [
"MAMLTrainer",
"DEFAULT_CONFIG",
]
from ray.rllib.utils.deprecation import deprecation_warning
deprecation_warning("ray.rllib.agents.maml", "ray.rllib.algorithms.maml", error=False)

View file

@ -1,7 +1,7 @@
from ray.rllib.agents.marwil.bc import BCTrainer, BC_DEFAULT_CONFIG
from ray.rllib.agents.marwil.marwil import MARWILTrainer, DEFAULT_CONFIG
from ray.rllib.agents.marwil.marwil_tf_policy import MARWILTFPolicy
from ray.rllib.agents.marwil.marwil_torch_policy import MARWILTorchPolicy
from ray.rllib.algorithms.marwil.bc import BCTrainer, BC_DEFAULT_CONFIG
from ray.rllib.algorithms.marwil.marwil import MARWILTrainer, DEFAULT_CONFIG
from ray.rllib.algorithms.marwil.marwil_tf_policy import MARWILTFPolicy
from ray.rllib.algorithms.marwil.marwil_torch_policy import MARWILTorchPolicy
__all__ = [
"BCTrainer",
@ -11,3 +11,10 @@ __all__ = [
"MARWILTorchPolicy",
"MARWILTrainer",
]
from ray.rllib.utils.deprecation import deprecation_warning
deprecation_warning(
"ray.rllib.agents.marwil", "ray.rllib.algorithms.marwil", error=False
)

View file

@ -1,6 +1,11 @@
from ray.rllib.agents.mbmpo.mbmpo import MBMPOTrainer, DEFAULT_CONFIG
from ray.rllib.algorithms.mbmpo.mbmpo import MBMPOTrainer, DEFAULT_CONFIG
__all__ = [
"MBMPOTrainer",
"DEFAULT_CONFIG",
]
from ray.rllib.utils.deprecation import deprecation_warning
deprecation_warning("ray.rllib.agents.mbmpo", "ray.rllib.algorithms.mbmpo", error=False)

View file

@ -1,3 +1,3 @@
from ray.rllib.agents.qmix.qmix import QMixTrainer, DEFAULT_CONFIG
from ray.rllib.agents.qmix.qmix import QMixConfig, QMixTrainer, DEFAULT_CONFIG
__all__ = ["QMixTrainer", "DEFAULT_CONFIG"]
__all__ = ["QMixConfig", "QMixTrainer", "DEFAULT_CONFIG"]

View file

@ -1,7 +1,6 @@
from typing import Type
from typing import Optional, Type
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.dqn.simple_q import SimpleQTrainer
from ray.rllib.agents.dqn.simple_q import SimpleQConfig, SimpleQTrainer
from ray.rllib.agents.qmix.qmix_policy import QMixTorchPolicy
from ray.rllib.execution.rollout_ops import (
synchronous_parallel_sample,
@ -12,7 +11,7 @@ from ray.rllib.execution.train_ops import (
)
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, Deprecated
from ray.rllib.utils.metrics import (
LAST_TARGET_UPDATE_TS,
NUM_AGENT_STEPS_SAMPLED,
@ -23,118 +22,174 @@ from ray.rllib.utils.metrics import (
from ray.rllib.utils.replay_buffers.utils import sample_min_n_steps_from_buffer
from ray.rllib.utils.typing import ResultDict, TrainerConfigDict
# fmt: off
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
# === QMix ===
# Mixing network. Either "qmix", "vdn", or None
"mixer": "qmix",
# Size of the mixing network embedding
"mixing_embed_dim": 32,
# Whether to use Double_Q learning
"double_q": True,
# Optimize over complete episodes by default.
"batch_mode": "complete_episodes",
# === Exploration Settings ===
"exploration_config": {
# The Exploration class to use.
"type": "EpsilonGreedy",
# Config for the Exploration class' constructor:
"initial_epsilon": 1.0,
"final_epsilon": 0.01,
# Timesteps over which to anneal epsilon.
"epsilon_timesteps": 40000,
class QMixConfig(SimpleQConfig):
"""Defines a configuration class from which a QMixTrainer can be built.
# For soft_q, use:
# "exploration_config" = {
# "type": "SoftQ"
# "temperature": [float, e.g. 1.0]
# }
},
Example:
>>> from ray.rllib.examples.env.two_step_game import TwoStepGame
>>> from ray.rllib.agents.qmix import QMixConfig
>>> config = QMixConfig().training(gamma=0.9, lr=0.01, train_batch_size=32)\
... .resources(num_gpus=0)\
... .rollouts(num_workers=4)
>>> print(config.to_dict())
>>> # Build a Trainer object from the config and run 1 training iteration.
>>> trainer = config.build(env=TwoStepGame)
>>> trainer.train()
# === Evaluation ===
# Evaluate with epsilon=0 every `evaluation_interval` training iterations.
# The evaluation stats will be reported under the "evaluation" metric key.
# Note that evaluation is currently not parallelized, and that for Ape-X
# metrics are already only reported for the lowest epsilon workers.
"evaluation_interval": None,
# Number of episodes to run per evaluation period.
"evaluation_duration": 10,
# Switch to greedy actions in evaluation workers.
"evaluation_config": {
"explore": False,
},
Example:
>>> from ray.rllib.examples.env.two_step_game import TwoStepGame
>>> from ray.rllib.agents.qmix import QMixConfig
>>> from ray import tune
>>> config = QMixConfig()
>>> # Print out some default values.
>>> print(config.optim_alpha)
>>> # Update the config object.
>>> config.training(lr=tune.grid_search([0.001, 0.0001]), optim_alpha=0.97)
>>> # Set the config object's env.
>>> config.environment(env=TwoStepGame)
>>> # Use to_dict() to get the old-style python config dict
>>> # when running with tune.
>>> tune.run(
... "QMix",
... stop={"episode_reward_mean": 200},
... config=config.to_dict(),
... )
"""
# Minimum env sampling timesteps to accumulate within a single `train()` call. This
# value does not affect learning, only the number of times `Trainer.step_attempt()`
# is called by `Trainer.train()`. If - after one `step_attempt()` - the env sampling
# timestep count has not been reached, will perform n more `step_attempt()` calls
# until the minimum timesteps have been executed. Set to 0 for no minimum timesteps.
"min_sample_timesteps_per_reporting": 1000,
# Update the target network every `target_network_update_freq` steps.
"target_network_update_freq": 500,
def __init__(self):
"""Initializes a PPOConfig instance."""
super().__init__(trainer_class=QMixTrainer)
# === Replay buffer ===
"replay_buffer_config": {
# Use the new ReplayBuffer API here
"_enable_replay_buffer_api": True,
"type": "SimpleReplayBuffer",
# Size of the replay buffer in batches (not timesteps!).
"capacity": 1000,
"learning_starts": 1000,
},
# fmt: off
# __sphinx_doc_begin__
# QMix specific settings:
self.mixer = "qmix"
self.mixing_embed_dim = 32
self.double_q = True
self.target_network_update_freq = 500
self.replay_buffer_config = {
# Use the new ReplayBuffer API here
"_enable_replay_buffer_api": True,
"type": "SimpleReplayBuffer",
# Size of the replay buffer in batches (not timesteps!).
"capacity": 1000,
"learning_starts": 1000,
}
self.optim_alpha = 0.99
self.optim_eps = 0.00001
self.grad_norm_clipping = 10
self.worker_side_prioritization = False
# === Optimization ===
# Learning rate for RMSProp optimizer
"lr": 0.0005,
# RMSProp alpha
"optim_alpha": 0.99,
# RMSProp epsilon
"optim_eps": 0.00001,
# If not None, clip gradients during optimization at this value
"grad_norm_clipping": 10,
# Update the replay buffer with this many samples at once. Note that
# this setting applies per-worker if num_workers > 1.
"rollout_fragment_length": 4,
# Minimum batch size used for training (in timesteps). With the default buffer
# (ReplayBuffer) this means sampling from the buffer (entire-episode SampleBatches)
# as many times as is required to reach at least this number of timesteps.
"train_batch_size": 32,
# Override some of TrainerConfig's default values with QMix-specific values.
self.num_workers = 0
self.min_time_s_per_reporting = 1
self.model = {
"lstm_cell_size": 64,
"max_seq_len": 999999,
}
self.framework_str = "torch"
self.lr = 0.0005
self.rollout_fragment_length = 4
self.train_batch_size = 32
self.batch_mode = "complete_episodes"
self.exploration_config = {
# The Exploration class to use.
"type": "EpsilonGreedy",
# Config for the Exploration class' constructor:
"initial_epsilon": 1.0,
"final_epsilon": 0.01,
# Timesteps over which to anneal epsilon.
"epsilon_timesteps": 40000,
# === Parallelism ===
# Number of workers for collecting samples with. This only makes sense
# to increase if your environment is particularly slow to sample, or if
# you"re using the Async or Ape-X optimizers.
"num_workers": 0,
# Whether to compute priorities on workers.
"worker_side_prioritization": False,
# Prevent reporting frequency from going lower than this time span.
"min_time_s_per_reporting": 1,
# For soft_q, use:
# "exploration_config" = {
# "type": "SoftQ"
# "temperature": [float, e.g. 1.0]
# }
}
# === Model ===
"model": {
"lstm_cell_size": 64,
"max_seq_len": 999999,
},
# Only torch supported so far.
"framework": "torch",
# Evaluate with epsilon=0 every `evaluation_interval` training iterations.
# The evaluation stats will be reported under the "evaluation" metric key.
# Note that evaluation is currently not parallelized, and that for Ape-X
# metrics are already only reported for the lowest epsilon workers.
self.evaluation_interval = None
self.evaluation_duration = 10
self.evaluation_config = {
"explore": False,
}
self.min_sample_timesteps_per_reporting = 1000
# __sphinx_doc_end__
# fmt: on
# Deprecated keys:
# Use `replay_buffer_config.learning_starts` instead.
"learning_starts": DEPRECATED_VALUE,
# Use `replay_buffer_config.capacity` instead.
"buffer_size": DEPRECATED_VALUE,
})
# __sphinx_doc_end__
# fmt: on
# Deprecated keys:
self.learning_starts = DEPRECATED_VALUE
self.buffer_size = DEPRECATED_VALUE
@override(SimpleQConfig)
def training(
self,
*,
mixer: Optional[str] = None,
mixing_embed_dim: Optional[int] = None,
double_q: Optional[bool] = None,
target_network_update_freq: Optional[int] = None,
replay_buffer_config: Optional[dict] = None,
optim_alpha: Optional[float] = None,
optim_eps: Optional[float] = None,
grad_norm_clipping: Optional[float] = None,
worker_side_prioritization: Optional[bool] = None,
**kwargs,
) -> "QMixConfig":
"""Sets the training related configuration.
Args:
mixer: Mixing network. Either "qmix", "vdn", or None.
mixing_embed_dim: Size of the mixing network embedding.
double_q: Whether to use Double_Q learning.
target_network_update_freq: Update the target network every
`target_network_update_freq` sample steps.
replay_buffer_config: Replay buffer config dict (see the defaults set in `__init__`).
optim_alpha: RMSProp alpha.
optim_eps: RMSProp epsilon.
grad_norm_clipping: If not None, clip gradients during optimization at
this value.
worker_side_prioritization: Whether to compute priorities for the replay
buffer on worker side.
Returns:
This updated TrainerConfig object.
"""
# Pass kwargs onto super's `training()` method.
super().training(**kwargs)
if mixer is not None:
self.mixer = mixer
if mixing_embed_dim is not None:
self.mixing_embed_dim = mixing_embed_dim
if double_q is not None:
self.double_q = double_q
if target_network_update_freq is not None:
self.target_network_update_freq = target_network_update_freq
if replay_buffer_config is not None:
self.replay_buffer_config = replay_buffer_config
if optim_alpha is not None:
self.optim_alpha = optim_alpha
if optim_eps is not None:
self.optim_eps = optim_eps
if grad_norm_clipping is not None:
self.grad_norm_clipping = grad_norm_clipping
if worker_side_prioritization is not None:
self.worker_side_prioritization = worker_side_prioritization
return self
class QMixTrainer(SimpleQTrainer):
@classmethod
@override(SimpleQTrainer)
def get_default_config(cls) -> TrainerConfigDict:
return DEFAULT_CONFIG
return QMixConfig().to_dict()
@override(SimpleQTrainer)
def validate_config(self, config: TrainerConfigDict) -> None:
@ -219,3 +274,20 @@ class QMixTrainer(SimpleQTrainer):
# Return all collected metrics for the iteration.
return train_results
# Deprecated: Use ray.rllib.agents.qmix.qmix.QMixConfig instead!
class _deprecated_default_config(dict):
def __init__(self):
super().__init__(QMixConfig().to_dict())
@Deprecated(
old="ray.rllib.agents.qmix.qmix.DEFAULT_CONFIG",
new="ray.rllib.agents.qmix.qmix.QMixConfig(...)",
error=False,
)
def __getitem__(self, item):
return super().__getitem__(item)
DEFAULT_CONFIG = _deprecated_default_config()
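A sketch (not part of the diff) of the shim's behavior: old-style dict access keeps working, but each key read logs a deprecation warning pointing at ``QMixConfig``:

```python
from ray.rllib.agents.qmix.qmix import DEFAULT_CONFIG, QMixConfig

print(DEFAULT_CONFIG["mixer"])          # "qmix", plus a deprecation warning
print(QMixConfig().to_dict()["mixer"])  # same value, no warning
```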

View file

@ -153,7 +153,6 @@ class QMixLoss(nn.Module):
return loss, mask, masked_td_error, chosen_action_qvals, targets
# TODO(sven): Make this a TorchPolicy child via `build_policy_class`.
class QMixTorchPolicy(TorchPolicy):
"""QMix impl. Assumes homogeneous agents for now.
@ -177,9 +176,6 @@ class QMixTorchPolicy(TorchPolicy):
self.h_size = config["model"]["lstm_cell_size"]
self.has_env_global_state = False
self.has_action_mask = False
self.device = (
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
agent_obs_space = obs_space.original_space.spaces[0]
if isinstance(agent_obs_space, gym.spaces.Dict):
@ -218,7 +214,9 @@ class QMixTorchPolicy(TorchPolicy):
framework="torch",
name="model",
default_model=RNNModel,
).to(self.device)
)
super().__init__(obs_space, action_space, config, model=self.model)
self.target_model = ModelCatalog.get_model_v2(
agent_obs_space,
@ -230,8 +228,6 @@ class QMixTorchPolicy(TorchPolicy):
default_model=RNNModel,
).to(self.device)
super().__init__(obs_space, action_space, config, model=self.model)
self.exploration = self._create_exploration()
# Setup the mixer network.

View file

@ -4,7 +4,7 @@ import unittest
import ray
from ray.tune import register_env
from ray.rllib.agents.qmix import QMixTrainer
from ray.rllib.agents.qmix import QMixConfig
from ray.rllib.env.multi_agent_env import MultiAgentEnv
@ -95,18 +95,21 @@ class TestQMix(unittest.TestCase):
),
)
trainer = QMixTrainer(
env="action_mask_test",
config={
"num_envs_per_worker": 5, # test with vectorization on
"env_config": {
"avail_actions": [3, 4, 8],
},
"framework": "torch",
},
)
config = (
QMixConfig()
.framework(framework="torch")
.environment(
env="action_mask_test",
env_config={"avail_actions": [3, 4, 8]},
)
.rollouts(num_envs_per_worker=5)
) # Test with vectorization on.
trainer = config.build()
for _ in range(4):
trainer.train() # OK if it doesn't trip the action assertion error
assert trainer.train()["episode_reward_mean"] == 30.0
trainer.stop()
ray.shutdown()

View file

@ -18,7 +18,10 @@ def _import_a3c():
def _import_alpha_star():
from ray.rllib.agents.alpha_star.alpha_star import AlphaStarTrainer, DEFAULT_CONFIG
from ray.rllib.algorithms.alpha_star.alpha_star import (
AlphaStarTrainer,
DEFAULT_CONFIG,
)
return AlphaStarTrainer, DEFAULT_CONFIG
@ -42,7 +45,7 @@ def _import_appo():
def _import_ars():
from ray.rllib.agents import ars
from ray.rllib.algorithms import ars
return ars.ARSTrainer, ars.DEFAULT_CONFIG
@ -60,13 +63,13 @@ def _import_bandit_linucb():
def _import_bc():
from ray.rllib.agents import marwil
from ray.rllib.algorithms import marwil
return marwil.BCTrainer, marwil.DEFAULT_CONFIG
def _import_cql():
from ray.rllib.agents import cql
from ray.rllib.algorithms import cql
return cql.CQLTrainer, cql.CQL_DEFAULT_CONFIG
@ -90,13 +93,13 @@ def _import_dqn():
def _import_dreamer():
from ray.rllib.agents import dreamer
from ray.rllib.algorithms import dreamer
return dreamer.DREAMERTrainer, dreamer.DEFAULT_CONFIG
def _import_es():
from ray.rllib.agents import es
from ray.rllib.algorithms import es
return es.ESTrainer, es.DEFAULT_CONFIG
@ -114,19 +117,19 @@ def _import_maddpg():
def _import_maml():
from ray.rllib.agents import maml
from ray.rllib.algorithms import maml
return maml.MAMLTrainer, maml.DEFAULT_CONFIG
def _import_marwil():
from ray.rllib.agents import marwil
from ray.rllib.algorithms import marwil
return marwil.MARWILTrainer, marwil.DEFAULT_CONFIG
def _import_mbmpo():
from ray.rllib.agents import mbmpo
from ray.rllib.algorithms import mbmpo
return mbmpo.MBMPOTrainer, mbmpo.DEFAULT_CONFIG

View file

@ -10,7 +10,7 @@ import unittest
import ray
import ray.rllib.agents.a3c as a3c
import ray.rllib.agents.dqn as dqn
from ray.rllib.agents.marwil import BCTrainer
from ray.rllib.algorithms.marwil import BCTrainer
import ray.rllib.agents.pg as pg
from ray.rllib.agents.trainer import COMMON_CONFIG
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole

View file

@ -0,0 +1,11 @@
from ray.rllib.algorithms.alpha_star.alpha_star import (
AlphaStarConfig,
AlphaStarTrainer,
DEFAULT_CONFIG,
)
__all__ = [
"AlphaStarConfig",
"AlphaStarTrainer",
"DEFAULT_CONFIG",
]

View file

@ -9,8 +9,8 @@ from typing import Any, Dict, Optional, Type
import ray
from ray.actor import ActorHandle
from ray.rllib.agents.alpha_star.distributed_learners import DistributedLearners
from ray.rllib.agents.alpha_star.league_builder import AlphaStarLeagueBuilder
from ray.rllib.algorithms.alpha_star.distributed_learners import DistributedLearners
from ray.rllib.algorithms.alpha_star.league_builder import AlphaStarLeagueBuilder
from ray.rllib.agents.trainer import Trainer
import ray.rllib.agents.ppo.appo as appo
from ray.rllib.evaluation.rollout_worker import RolloutWorker
@ -48,7 +48,7 @@ class AlphaStarConfig(appo.APPOConfig):
"""Defines a configuration class from which an AlphaStarTrainer can be built.
Example:
>>> from ray.rllib.agents.alpha_star import AlphaStarConfig
>>> from ray.rllib.algorithms.alpha_star import AlphaStarConfig
>>> config = AlphaStarConfig().training(lr=0.0003, train_batch_size=512)\
... .resources(num_gpus=4)\
... .rollouts(num_rollout_workers=64)
@ -58,7 +58,7 @@ class AlphaStarConfig(appo.APPOConfig):
>>> trainer.train()
Example:
>>> from ray.rllib.agents.alpha_star import AlphaStarConfig
>>> from ray.rllib.algorithms.alpha_star import AlphaStarConfig
>>> from ray import tune
>>> config = AlphaStarConfig()
>>> # Print out some default values.
@ -168,7 +168,7 @@ class AlphaStarConfig(appo.APPOConfig):
to be used for league building logic. All other keys (that are not
`type`) will be used as constructor kwargs on the given class to
construct the LeagueBuilder instance. See the
`ray.rllib.agents.alpha_star.league_builder::AlphaStarLeagueBuilder`
`ray.rllib.algorithms.alpha_star.league_builder::AlphaStarLeagueBuilder`
(used by default by this algo) as an example.
max_num_policies_to_train: The maximum number of trainable policies for this
Trainer. Each trainable policy will exist as an independent remote actor,
@ -584,8 +584,8 @@ class _deprecated_default_config(dict):
super().__init__(AlphaStarConfig().to_dict())
@Deprecated(
old="ray.rllib.agents.alpha_star.alpha_star.DEFAULT_CONFIG",
new="ray.rllib.agents.alpha_star.alpha_star.AlphaStarConfig(...)",
old="ray.rllib.algorithms.alpha_star.alpha_star.DEFAULT_CONFIG",
new="ray.rllib.algorithms.alpha_star.alpha_star.AlphaStarConfig(...)",
error=False,
)
def __getitem__(self, item):

View file

@ -2,7 +2,7 @@ import pyspiel
import unittest
import ray
import ray.rllib.agents.alpha_star as alpha_star
import ray.rllib.algorithms.alpha_star as alpha_star
from ray.rllib.env.wrappers.open_spiel import OpenSpielEnv
from ray.rllib.utils.test_utils import (
check_compute_single_action,

View file

@ -0,0 +1,8 @@
from ray.rllib.algorithms.cql.cql import CQLTrainer, CQL_DEFAULT_CONFIG
from ray.rllib.algorithms.cql.cql_torch_policy import CQLTorchPolicy
__all__ = [
"CQL_DEFAULT_CONFIG",
"CQLTorchPolicy",
"CQLTrainer",
]

View file

@ -2,8 +2,8 @@ import logging
import numpy as np
from typing import Type
from ray.rllib.agents.cql.cql_tf_policy import CQLTFPolicy
from ray.rllib.agents.cql.cql_torch_policy import CQLTorchPolicy
from ray.rllib.algorithms.cql.cql_tf_policy import CQLTFPolicy
from ray.rllib.algorithms.cql.cql_torch_policy import CQLTorchPolicy
from ray.rllib.agents.sac.sac import SACTrainer, DEFAULT_CONFIG as SAC_CONFIG
from ray.rllib.execution.train_ops import (
multi_gpu_train_one_step,

View file

@ -411,7 +411,7 @@ def apply_gradients_fn(policy, optimizer, grads_and_vars):
CQLTFPolicy = build_tf_policy(
name="CQLTFPolicy",
loss_fn=cql_loss,
get_default_config=lambda: ray.rllib.agents.cql.cql.CQL_DEFAULT_CONFIG,
get_default_config=lambda: ray.rllib.algorithms.cql.cql.CQL_DEFAULT_CONFIG,
validate_spaces=validate_spaces,
stats_fn=cql_stats,
postprocess_fn=postprocess_trajectory,

View file

@ -387,7 +387,7 @@ CQLTorchPolicy = build_policy_class(
name="CQLTorchPolicy",
framework="torch",
loss_fn=cql_loss,
get_default_config=lambda: ray.rllib.agents.cql.cql.CQL_DEFAULT_CONFIG,
get_default_config=lambda: ray.rllib.algorithms.cql.cql.CQL_DEFAULT_CONFIG,
stats_fn=cql_stats,
postprocess_fn=postprocess_trajectory,
extra_grad_process_fn=apply_grad_clipping,

View file

@ -4,7 +4,7 @@ import os
import unittest
import ray
import ray.rllib.agents.cql as cql
import ray.rllib.algorithms.cql as cql
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.test_utils import (
check_compute_single_action,

View file

@ -0,0 +1,11 @@
from ray.rllib.algorithms.dreamer.dreamer import (
DREAMERConfig,
DREAMERTrainer,
DEFAULT_CONFIG,
)
__all__ = [
"DREAMERConfig",
"DREAMERTrainer",
"DEFAULT_CONFIG",
]

View file

@ -4,12 +4,12 @@ import random
from typing import Optional
from ray.rllib.agents.trainer_config import TrainerConfig
from ray.rllib.agents.dreamer.dreamer_torch_policy import DreamerTorchPolicy
from ray.rllib.algorithms.dreamer.dreamer_torch_policy import DreamerTorchPolicy
from ray.rllib.agents.trainer import Trainer
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, _get_shared_metrics
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.agents.dreamer.dreamer_model import DreamerModel
from ray.rllib.algorithms.dreamer.dreamer_model import DreamerModel
from ray.rllib.execution.rollout_ops import (
ParallelRollouts,
synchronous_parallel_sample,
@ -31,7 +31,7 @@ class DREAMERConfig(TrainerConfig):
"""Defines a PPOTrainer configuration class from which a PPOTrainer can be built.
Example:
>>> from ray.rllib.agents.dreamer import DREAMERConfig
>>> from ray.rllib.algorithms.dreamer import DREAMERConfig
>>> config = DREAMERConfig().training(gamma=0.9, lr=0.01)\
... .resources(num_gpus=0)\
... .rollouts(num_rollout_workers=4)
@ -42,7 +42,7 @@ class DREAMERConfig(TrainerConfig):
Example:
>>> from ray import tune
>>> from ray.rllib.agents.dreamer import DREAMERConfig
>>> from ray.rllib.algorithms.dreamer import DREAMERConfig
>>> config = DREAMERConfig()
>>> # Print out some default values.
>>> print(config.clip_param)
@ -418,14 +418,14 @@ class DREAMERTrainer(Trainer):
return results
# Deprecated: Use ray.rllib.agents.dreamer.DREAMERConfig instead!
# Deprecated: Use ray.rllib.algorithms.dreamer.DREAMERConfig instead!
class _deprecated_default_config(dict):
def __init__(self):
super().__init__(DREAMERConfig().to_dict())
@Deprecated(
old="ray.rllib.agents.dreamer.dreamer.DEFAULT_CONFIG",
new="ray.rllib.agents.dreamer.dreamer.DREAMERConfig(...)",
old="ray.rllib.algorithms.dreamer.dreamer.DEFAULT_CONFIG",
new="ray.rllib.algorithms.dreamer.dreamer.DREAMERConfig(...)",
error=False,
)
def __getitem__(self, item):

View file

@ -8,7 +8,7 @@ from ray.rllib.utils.framework import TensorType
torch, nn = try_import_torch()
if torch:
from torch import distributions as td
from ray.rllib.agents.dreamer.utils import (
from ray.rllib.algorithms.dreamer.utils import (
Linear,
Conv2d,
ConvTranspose2d,

View file

@ -4,7 +4,7 @@ import numpy as np
from typing import Dict, Optional
import ray
from ray.rllib.agents.dreamer.utils import FreezeParameters
from ray.rllib.algorithms.dreamer.utils import FreezeParameters
from ray.rllib.evaluation.episode import Episode
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
@ -284,7 +284,7 @@ def preprocess_episode(
DreamerTorchPolicy = build_policy_class(
name="DreamerTorchPolicy",
framework="torch",
get_default_config=lambda: ray.rllib.agents.dreamer.dreamer.DEFAULT_CONFIG,
get_default_config=lambda: ray.rllib.algorithms.dreamer.dreamer.DEFAULT_CONFIG,
action_sampler_fn=action_sampler_fn,
postprocess_fn=preprocess_episode,
loss_fn=dreamer_loss,

View file

@ -2,7 +2,7 @@ from gym.spaces import Box
import unittest
import ray
import ray.rllib.agents.dreamer as dreamer
import ray.rllib.algorithms.dreamer as dreamer
from ray.rllib.examples.env.random_env import RandomEnv
from ray.rllib.utils.test_utils import framework_iterator

View file

@ -0,0 +1,6 @@
from ray.rllib.algorithms.maml.maml import MAMLTrainer, DEFAULT_CONFIG
__all__ = [
"MAMLTrainer",
"DEFAULT_CONFIG",
]

View file

@ -4,8 +4,8 @@ from typing import Type
from ray.rllib.utils.sgd import standardized
from ray.rllib.agents import with_common_config
from ray.rllib.agents.maml.maml_tf_policy import MAMLTFPolicy
from ray.rllib.agents.maml.maml_torch_policy import MAMLTorchPolicy
from ray.rllib.algorithms.maml.maml_tf_policy import MAMLTFPolicy
from ray.rllib.algorithms.maml.maml_torch_policy import MAMLTorchPolicy
from ray.rllib.agents.trainer import Trainer
from ray.rllib.evaluation.metrics import get_learner_stats
from ray.rllib.evaluation.worker_set import WorkerSet

View file

@ -442,7 +442,7 @@ def setup_mixins(policy, obs_space, action_space, config):
MAMLTFPolicy = build_tf_policy(
name="MAMLTFPolicy",
get_default_config=lambda: ray.rllib.agents.maml.maml.DEFAULT_CONFIG,
get_default_config=lambda: ray.rllib.algorithms.maml.maml.DEFAULT_CONFIG,
loss_fn=maml_loss,
stats_fn=maml_stats,
optimizer_fn=maml_optimizer_fn,

View file

@ -382,7 +382,7 @@ def setup_mixins(policy, obs_space, action_space, config):
MAMLTorchPolicy = build_policy_class(
name="MAMLTorchPolicy",
framework="torch",
get_default_config=lambda: ray.rllib.agents.maml.maml.DEFAULT_CONFIG,
get_default_config=lambda: ray.rllib.algorithms.maml.maml.DEFAULT_CONFIG,
loss_fn=maml_loss,
stats_fn=maml_stats,
optimizer_fn=maml_optimizer_fn,

View file

@ -1,7 +1,7 @@
import unittest
import ray
import ray.rllib.agents.maml as maml
import ray.rllib.algorithms.maml as maml
from ray.rllib.utils.test_utils import (
check_compute_single_action,
check_train_results,

View file

@ -0,0 +1,13 @@
from ray.rllib.algorithms.marwil.bc import BCTrainer, BC_DEFAULT_CONFIG
from ray.rllib.algorithms.marwil.marwil import MARWILTrainer, DEFAULT_CONFIG
from ray.rllib.algorithms.marwil.marwil_tf_policy import MARWILTFPolicy
from ray.rllib.algorithms.marwil.marwil_torch_policy import MARWILTorchPolicy
__all__ = [
"BCTrainer",
"BC_DEFAULT_CONFIG",
"DEFAULT_CONFIG",
"MARWILTFPolicy",
"MARWILTorchPolicy",
"MARWILTrainer",
]

View file

@ -1,4 +1,4 @@
from ray.rllib.agents.marwil.marwil import (
from ray.rllib.algorithms.marwil.marwil import (
MARWILTrainer,
DEFAULT_CONFIG as MARWIL_CONFIG,
)

View file

@ -1,7 +1,7 @@
from typing import Type
from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.agents.marwil.marwil_tf_policy import MARWILTFPolicy
from ray.rllib.algorithms.marwil.marwil_tf_policy import MARWILTFPolicy
from ray.rllib.execution.buffers.multi_agent_replay_buffer import MultiAgentReplayBuffer
from ray.rllib.execution.rollout_ops import (
synchronous_parallel_sample,
@ -110,7 +110,9 @@ class MARWILTrainer(Trainer):
@override(Trainer)
def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
if config["framework"] == "torch":
from ray.rllib.agents.marwil.marwil_torch_policy import MARWILTorchPolicy
from ray.rllib.algorithms.marwil.marwil_torch_policy import (
MARWILTorchPolicy,
)
return MARWILTorchPolicy
else:

View file

@ -232,7 +232,7 @@ def setup_mixins(
MARWILTFPolicy = build_tf_policy(
name="MARWILTFPolicy",
get_default_config=lambda: ray.rllib.agents.marwil.marwil.DEFAULT_CONFIG,
get_default_config=lambda: ray.rllib.algorithms.marwil.marwil.DEFAULT_CONFIG,
loss_fn=marwil_loss,
stats_fn=stats,
postprocess_fn=postprocess_advantages,

View file

@ -3,7 +3,7 @@ from typing import Dict
import ray
from ray.rllib.agents.a3c.a3c_torch_policy import ValueNetworkMixin
from ray.rllib.agents.marwil.marwil_tf_policy import postprocess_advantages
from ray.rllib.algorithms.marwil.marwil_tf_policy import postprocess_advantages
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch
@ -113,7 +113,7 @@ MARWILTorchPolicy = build_policy_class(
name="MARWILTorchPolicy",
framework="torch",
loss_fn=marwil_loss,
get_default_config=lambda: ray.rllib.agents.marwil.marwil.DEFAULT_CONFIG,
get_default_config=lambda: ray.rllib.algorithms.marwil.marwil.DEFAULT_CONFIG,
stats_fn=stats,
postprocess_fn=postprocess_advantages,
extra_grad_process_fn=apply_grad_clipping,

View file

@ -3,7 +3,7 @@ from pathlib import Path
import unittest
import ray
import ray.rllib.agents.marwil as marwil
import ray.rllib.algorithms.marwil as marwil
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.test_utils import (
check_compute_single_action,

View file

@ -4,7 +4,7 @@ from pathlib import Path
import unittest
import ray
import ray.rllib.agents.marwil as marwil
import ray.rllib.algorithms.marwil as marwil
from ray.rllib.evaluation.postprocessing import compute_advantages
from ray.rllib.offline import JsonReader
from ray.rllib.utils.framework import try_import_tf, try_import_torch

View file

@ -0,0 +1,6 @@
from ray.rllib.algorithms.mbmpo.mbmpo import MBMPOTrainer, DEFAULT_CONFIG
__all__ = [
"MBMPOTrainer",
"DEFAULT_CONFIG",
]

View file

@ -4,9 +4,9 @@ from typing import List, Type
import ray
from ray.rllib.agents import with_common_config
from ray.rllib.agents.mbmpo.mbmpo_torch_policy import MBMPOTorchPolicy
from ray.rllib.agents.mbmpo.model_ensemble import DynamicsEnsembleCustomModel
from ray.rllib.agents.mbmpo.utils import calculate_gae_advantages, MBMPOExploration
from ray.rllib.algorithms.mbmpo.mbmpo_torch_policy import MBMPOTorchPolicy
from ray.rllib.algorithms.mbmpo.model_ensemble import DynamicsEnsembleCustomModel
from ray.rllib.algorithms.mbmpo.utils import calculate_gae_advantages, MBMPOExploration
from ray.rllib.agents.trainer import Trainer
from ray.rllib.env.env_context import EnvContext
from ray.rllib.env.wrappers.model_vector_env import model_vector_env

View file

@ -5,7 +5,7 @@ from typing import Tuple, Type
import ray
from ray.rllib.agents.a3c.a3c_torch_policy import vf_preds_fetches
from ray.rllib.agents.maml.maml_torch_policy import (
from ray.rllib.algorithms.maml.maml_torch_policy import (
setup_mixins,
maml_loss,
maml_stats,
@ -121,7 +121,7 @@ def make_model_and_action_dist(
MBMPOTorchPolicy = build_policy_class(
name="MBMPOTorchPolicy",
framework="torch",
get_default_config=lambda: ray.rllib.agents.mbmpo.mbmpo.DEFAULT_CONFIG,
get_default_config=lambda: ray.rllib.algorithms.mbmpo.mbmpo.DEFAULT_CONFIG,
make_model_and_action_dist=make_model_and_action_dist,
loss_fn=maml_loss,
stats_fn=maml_stats,

View file

@ -1,7 +1,7 @@
import unittest
import ray
import ray.rllib.agents.mbmpo as mbmpo
import ray.rllib.algorithms.mbmpo as mbmpo
from ray.rllib.utils.test_utils import (
check_compute_single_action,
check_train_results,

View file

@ -245,7 +245,6 @@ class MultiAgentEnv(gym.Env):
# fmt: off
# __grouping_doc_begin__
@ExperimentalAPI
def with_agent_groups(
self,
groups: Dict[str, List[AgentID]],
@ -265,16 +264,17 @@ class MultiAgentEnv(gym.Env):
Agent grouping is required to leverage algorithms such as Q-Mix.
This API is experimental.
Args:
groups: Mapping from group id to a list of the agent ids
of group members. If an agent id is not present in any group
value, it will be left ungrouped.
value, it will be left ungrouped. The group id becomes a new agent ID
in the final environment.
obs_space: Optional observation space for the grouped
env. Must be a tuple space.
env. Must be a tuple space. If not provided, will infer this to be a
Tuple of n individual agent spaces (n=num agents in a group).
act_space: Optional action space for the grouped env.
Must be a tuple space.
Must be a tuple space. If not provided, will infer this to be a Tuple
of n individual agent spaces (n=num agents in a group).
Examples:
>>> from ray.rllib.env.multi_agent_env import MultiAgentEnv
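A hedged usage sketch (not part of the diff) of ``with_agent_groups``, mirroring the TwoStepGame/QMix example elsewhere in this change; the agent IDs ``0``/``1`` and the ``Discrete(6)`` observation space are assumptions taken from that example:

```python
from gym.spaces import Discrete, Tuple
from ray.tune import register_env
from ray.rllib.examples.env.two_step_game import TwoStepGame


def make_grouped_env(env_config):
    # Group both TwoStepGame agents into one "logical" agent, as QMix requires.
    grouping = {"group_1": [0, 1]}  # assumed agent IDs, per the RLlib example
    obs_space = Tuple([Discrete(6), Discrete(6)])
    act_space = Tuple([TwoStepGame.action_space, TwoStepGame.action_space])
    return TwoStepGame(env_config).with_agent_groups(
        grouping, obs_space=obs_space, act_space=act_space
    )


register_env("grouped_twostep", make_grouped_env)
```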

View file

@ -1,6 +1,9 @@
from collections import OrderedDict
import gym
from typing import Dict, List, Optional
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.utils.typing import AgentID
# info key for the individual rewards of an agent, for example:
# info: {
@ -27,21 +30,35 @@ class GroupAgentsWrapper(MultiAgentEnv):
This API is experimental.
"""
def __init__(self, env, groups, obs_space=None, act_space=None):
"""Wrap an existing multi-agent env to group agents together.
def __init__(
self,
env: MultiAgentEnv,
groups: Dict[str, List[AgentID]],
obs_space: Optional[gym.Space] = None,
act_space: Optional[gym.Space] = None,
):
"""Wrap an existing MultiAgentEnv to group agent ID together.
See MultiAgentEnv.with_agent_groups() for usage info.
See `MultiAgentEnv.with_agent_groups()` for more detailed usage info.
Args:
env (MultiAgentEnv): env to wrap
groups (dict): Grouping spec as documented in MultiAgentEnv.
obs_space (Space): Optional observation space for the grouped
env. Must be a tuple space.
act_space (Space): Optional action space for the grouped env.
Must be a tuple space.
env: The env to wrap and whose agent IDs to group into new agents.
groups: Mapping from group id to a list of the agent ids
of group members. If an agent id is not present in any group
value, it will be left ungrouped. The group id becomes a new agent ID
in the final environment.
obs_space: Optional observation space for the grouped
env. Must be a tuple space. If not provided, will infer this to be a
Tuple of n individual agent spaces (n=num agents in a group).
act_space: Optional action space for the grouped env.
Must be a tuple space. If not provided, will infer this to be a Tuple
of n individual agent spaces (n=num agents in a group).
"""
super().__init__()
self.env = env
# Inherit wrapped env's `_skip_env_checking` flag.
if hasattr(self.env, "_skip_env_checking"):
self._skip_env_checking = self.env._skip_env_checking
self.groups = groups
self.agent_id_to_group = {}
for group_id, agent_ids in groups.items():

View file

@ -20,7 +20,9 @@ import os
import ray
from ray import tune
from ray.rllib.agents.maml.maml_torch_policy import KLCoeffMixin as TorchKLCoeffMixin
from ray.rllib.algorithms.maml.maml_torch_policy import (
KLCoeffMixin as TorchKLCoeffMixin,
)
from ray.rllib.agents.ppo.ppo import PPOTrainer
from ray.rllib.agents.ppo.ppo_tf_policy import (
PPOTFPolicy,

View file

@ -16,6 +16,7 @@ import os
import ray
from ray import tune
from ray.tune import register_env
from ray.rllib.agents.qmix import QMixConfig
from ray.rllib.env.multi_agent_env import ENV_STATE
from ray.rllib.examples.env.two_step_game import TwoStepGame
from ray.rllib.policy.policy import PolicySpec
@ -110,10 +111,11 @@ if __name__ == "__main__":
obs_space = Discrete(6)
act_space = TwoStepGame.action_space
config = {
"learning_starts": 100,
"env": TwoStepGame,
"env_config": {
"actions_are_logits": True,
},
"learning_starts": 100,
"multiagent": {
"policies": {
"pol1": PolicySpec(
@ -133,31 +135,33 @@ if __name__ == "__main__":
# Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
"num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
}
group = False
elif args.run == "QMIX":
config = {
"rollout_fragment_length": 4,
"train_batch_size": 32,
"exploration_config": {
"final_epsilon": 0.0,
},
"num_workers": 0,
"mixer": args.mixer,
"env_config": {
"separate_state_space": True,
"one_hot_state_encoding": True,
},
# Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
"num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
}
group = True
config = (
QMixConfig()
.training(mixer=args.mixer, train_batch_size=32)
.rollouts(num_rollout_workers=0, rollout_fragment_length=4)
.exploration(
exploration_config={
"final_epsilon": 0.0,
}
)
.environment(
env="grouped_twostep",
env_config={
"separate_state_space": True,
"one_hot_state_encoding": True,
},
)
.resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
)
config = config.to_dict()
else:
config = {
"env": TwoStepGame,
# Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
"num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
"framework": args.framework,
}
group = False
stop = {
"episode_reward_mean": args.stop_reward,
@ -165,13 +169,6 @@ if __name__ == "__main__":
"training_iteration": args.stop_iters,
}
config = dict(
config,
**{
"env": "grouped_twostep" if group else TwoStepGame,
}
)
results = tune.run(args.run, stop=stop, config=config, verbose=2)
if args.as_test:

View file

@ -6,7 +6,7 @@ import tree # pip install dm_tree
import unittest
import ray
from ray.rllib.agents.marwil import BCTrainer
from ray.rllib.algorithms.marwil import BCTrainer
from ray.rllib.agents.pg import PGTrainer, DEFAULT_CONFIG
from ray.rllib.examples.env.random_env import RandomEnv
from ray.rllib.offline.json_reader import JsonReader

View file

@ -32,7 +32,7 @@ multi-agent-cartpole-alpha-star:
# No league-building needed.
league_builder_config:
type: ray.rllib.agents.alpha_star.league_builder.NoLeagueBuilder
type: ray.rllib.algorithms.alpha_star.league_builder.NoLeagueBuilder
multiagent:
policies: ["p0", "p1", "p2", "p3"]