[RLlib] QMix TrainerConfig objects. (#24775)

Sven Mika 2022-05-13 18:50:28 +02:00 committed by GitHub
parent ffcbb30552
commit 8fe3fd8f7b
9 changed files with 246 additions and 166 deletions


@@ -1,14 +1,15 @@
from ray.rllib.algorithms.cql.cql import CQLTrainer, CQL_DEFAULT_CONFIG
from ray.rllib.algorithms.cql.cql_tf_policy import CQLTFPolicy
from ray.rllib.algorithms.cql.cql_torch_policy import CQLTorchPolicy
from ray.rllib.utils.deprecation import deprecation_warning
__all__ = [
"CQL_DEFAULT_CONFIG",
"CQLTFPolicy",
"CQLTorchPolicy",
"CQLTrainer",
]
from ray.rllib.utils.deprecation import deprecation_warning
deprecation_warning(
"ray.rllib.agents.dreamer", "ray.rllib.algorithms.dreamer", error=False
"ray.rllib.agents.cql", "ray.rllib.algorithms.cql", error=False
)


@@ -104,9 +104,9 @@ class SimpleQConfig(TrainerConfig):
>>> .exploration(exploration_config=explore_config)
"""
def __init__(self):
def __init__(self, trainer_class=None):
"""Initializes a SimpleQConfig instance."""
super().__init__(trainer_class=SimpleQTrainer)
super().__init__(trainer_class=trainer_class or SimpleQTrainer)
# Simple Q specific
# fmt: off
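The SimpleQConfig change above is what enables subclassing: the base config now accepts an optional trainer_class and falls back to SimpleQTrainer, so QMixConfig can reuse the SimpleQ defaults while binding QMixTrainer. A minimal, self-contained sketch of that pattern (stub class names are illustrative, not RLlib code):

class _BaseTrainerStub:
    pass

class _ChildTrainerStub(_BaseTrainerStub):
    pass

class _BaseConfigStub:
    def __init__(self, trainer_class=None):
        # Fall back to the base trainer unless a subclass passes its own.
        self.trainer_class = trainer_class or _BaseTrainerStub

class _ChildConfigStub(_BaseConfigStub):
    def __init__(self):
        # Same shape as QMixConfig.__init__: bind the subclass's trainer.
        super().__init__(trainer_class=_ChildTrainerStub)

assert _BaseConfigStub().trainer_class is _BaseTrainerStub
assert _ChildConfigStub().trainer_class is _ChildTrainerStub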


@@ -1,3 +1,3 @@
from ray.rllib.agents.qmix.qmix import QMixTrainer, DEFAULT_CONFIG
from ray.rllib.agents.qmix.qmix import QMixConfig, QMixTrainer, DEFAULT_CONFIG
__all__ = ["QMixTrainer", "DEFAULT_CONFIG"]
__all__ = ["QMixConfig", "QMixTrainer", "DEFAULT_CONFIG"]


@@ -1,7 +1,6 @@
from typing import Type
from typing import Optional, Type
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.dqn.simple_q import SimpleQTrainer
from ray.rllib.agents.dqn.simple_q import SimpleQConfig, SimpleQTrainer
from ray.rllib.agents.qmix.qmix_policy import QMixTorchPolicy
from ray.rllib.execution.rollout_ops import (
synchronous_parallel_sample,
@@ -12,7 +11,7 @@ from ray.rllib.execution.train_ops import (
)
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, Deprecated
from ray.rllib.utils.metrics import (
LAST_TARGET_UPDATE_TS,
NUM_AGENT_STEPS_SAMPLED,
@@ -23,118 +22,174 @@ from ray.rllib.utils.metrics import (
from ray.rllib.utils.replay_buffers.utils import sample_min_n_steps_from_buffer
from ray.rllib.utils.typing import ResultDict, TrainerConfigDict
# fmt: off
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
# === QMix ===
# Mixing network. Either "qmix", "vdn", or None
"mixer": "qmix",
# Size of the mixing network embedding
"mixing_embed_dim": 32,
# Whether to use Double_Q learning
"double_q": True,
# Optimize over complete episodes by default.
"batch_mode": "complete_episodes",
# === Exploration Settings ===
"exploration_config": {
# The Exploration class to use.
"type": "EpsilonGreedy",
# Config for the Exploration class' constructor:
"initial_epsilon": 1.0,
"final_epsilon": 0.01,
# Timesteps over which to anneal epsilon.
"epsilon_timesteps": 40000,
class QMixConfig(SimpleQConfig):
"""Defines a configuration class from which a QMixTrainer can be built.
# For soft_q, use:
# "exploration_config" = {
# "type": "SoftQ"
# "temperature": [float, e.g. 1.0]
# }
},
Example:
>>> from ray.rllib.examples.env.two_step_game import TwoStepGame
>>> from ray.rllib.agents.qmix import QMixConfig
>>> config = QMixConfig().training(gamma=0.9, lr=0.01)\
... .resources(num_gpus=0)\
... .rollouts(num_rollout_workers=4)
>>> print(config.to_dict())
>>> # Build a Trainer object from the config and run 1 training iteration.
>>> trainer = config.build(env=TwoStepGame)
>>> trainer.train()
# === Evaluation ===
# Evaluate with epsilon=0 every `evaluation_interval` training iterations.
# The evaluation stats will be reported under the "evaluation" metric key.
# Note that evaluation is currently not parallelized, and that for Ape-X
# metrics are already only reported for the lowest epsilon workers.
"evaluation_interval": None,
# Number of episodes to run per evaluation period.
"evaluation_duration": 10,
# Switch to greedy actions in evaluation workers.
"evaluation_config": {
"explore": False,
},
Example:
>>> from ray.rllib.examples.env.two_step_game import TwoStepGame
>>> from ray.rllib.agents.qmix import QMixConfig
>>> from ray import tune
>>> config = QMixConfig()
>>> # Print out some default values.
>>> print(config.optim_alpha)
>>> # Update the config object.
>>> config.training(lr=tune.grid_search([0.001, 0.0001]), optim_alpha=0.97)
>>> # Set the config object's env.
>>> config.environment(env=TwoStepGame)
>>> # Use to_dict() to get the old-style python config dict
>>> # when running with tune.
>>> tune.run(
... "QMix",
... stop={"episode_reward_mean": 200},
... config=config.to_dict(),
... )
"""
# Minimum env sampling timesteps to accumulate within a single `train()` call. This
# value does not affect learning, only the number of times `Trainer.step_attempt()`
# is called by `Trainer.train()`. If, after one `step_attempt()`, the env sampling
# timestep count has not been reached, n more `step_attempt()` calls will be made
# until the minimum timesteps have been executed. Set to 0 for no minimum timesteps.
"min_sample_timesteps_per_reporting": 1000,
# Update the target network every `target_network_update_freq` steps.
"target_network_update_freq": 500,
def __init__(self):
"""Initializes a PPOConfig instance."""
super().__init__(trainer_class=QMixTrainer)
# === Replay buffer ===
"replay_buffer_config": {
# Use the new ReplayBuffer API here
"_enable_replay_buffer_api": True,
"type": "SimpleReplayBuffer",
# Size of the replay buffer in batches (not timesteps!).
"capacity": 1000,
"learning_starts": 1000,
},
# fmt: off
# __sphinx_doc_begin__
# QMix specific settings:
self.mixer = "qmix"
self.mixing_embed_dim = 32
self.double_q = True
self.target_network_update_freq = 500
self.replay_buffer_config = {
# Use the new ReplayBuffer API here
"_enable_replay_buffer_api": True,
"type": "SimpleReplayBuffer",
# Size of the replay buffer in batches (not timesteps!).
"capacity": 1000,
"learning_starts": 1000,
}
self.optim_alpha = 0.99
self.optim_eps = 0.00001
self.grad_norm_clipping = 10
self.worker_side_prioritization = False
# === Optimization ===
# Learning rate for RMSProp optimizer
"lr": 0.0005,
# RMSProp alpha
"optim_alpha": 0.99,
# RMSProp epsilon
"optim_eps": 0.00001,
# If not None, clip gradients during optimization at this value
"grad_norm_clipping": 10,
# Update the replay buffer with this many samples at once. Note that
# this setting applies per-worker if num_workers > 1.
"rollout_fragment_length": 4,
# Minimum batch size used for training (in timesteps). With the default buffer
# (ReplayBuffer), this means sampling from the buffer (entire-episode SampleBatches)
# as many times as required to reach at least this number of timesteps.
"train_batch_size": 32,
# Override some of TrainerConfig's default values with QMix-specific values.
self.num_workers = 0
self.min_time_s_per_reporting = 1
self.model = {
"lstm_cell_size": 64,
"max_seq_len": 999999,
}
self.framework_str = "torch"
self.lr = 0.0005
self.rollout_fragment_length = 4
self.train_batch_size = 32
self.batch_mode = "complete_episodes"
self.exploration_config = {
# The Exploration class to use.
"type": "EpsilonGreedy",
# Config for the Exploration class' constructor:
"initial_epsilon": 1.0,
"final_epsilon": 0.01,
# Timesteps over which to anneal epsilon.
"epsilon_timesteps": 40000,
# === Parallelism ===
# Number of workers for collecting samples with. This only makes sense
# to increase if your environment is particularly slow to sample, or if
# you"re using the Async or Ape-X optimizers.
"num_workers": 0,
# Whether to compute priorities on workers.
"worker_side_prioritization": False,
# Prevent reporting frequency from going lower than this time span.
"min_time_s_per_reporting": 1,
# For soft_q, use:
# "exploration_config" = {
# "type": "SoftQ"
# "temperature": [float, e.g. 1.0]
# }
}
# === Model ===
"model": {
"lstm_cell_size": 64,
"max_seq_len": 999999,
},
# Only torch supported so far.
"framework": "torch",
# Evaluate with epsilon=0 every `evaluation_interval` training iterations.
# The evaluation stats will be reported under the "evaluation" metric key.
# Note that evaluation is currently not parallelized, and that for Ape-X
# metrics are already only reported for the lowest epsilon workers.
self.evaluation_interval = None
self.evaluation_duration = 10
self.evaluation_config = {
"explore": False,
}
self.min_sample_timesteps_per_reporting = 1000
# __sphinx_doc_end__
# fmt: on
# Deprecated keys:
# Use `replay_buffer_config.learning_starts` instead.
"learning_starts": DEPRECATED_VALUE,
# Use `replay_buffer_config.capacity` instead.
"buffer_size": DEPRECATED_VALUE,
})
# __sphinx_doc_end__
# fmt: on
# Deprecated keys:
self.learning_starts = DEPRECATED_VALUE
self.buffer_size = DEPRECATED_VALUE
@override(SimpleQConfig)
def training(
self,
*,
mixer: Optional[str] = None,
mixing_embed_dim: Optional[int] = None,
double_q: Optional[bool] = None,
target_network_update_freq: Optional[int] = None,
replay_buffer_config: Optional[dict] = None,
optim_alpha: Optional[float] = None,
optim_eps: Optional[float] = None,
grad_norm_clipping: Optional[float] = None,
worker_side_prioritization: Optional[bool] = None,
**kwargs,
) -> "QMixConfig":
"""Sets the training related configuration.
Args:
mixer: Mixing network. Either "qmix", "vdn", or None.
mixing_embed_dim: Size of the mixing network embedding.
double_q: Whether to use double Q-learning.
target_network_update_freq: Update the target network every
`target_network_update_freq` sample steps.
replay_buffer_config: The replay buffer's config dict; see the defaults set in
`QMixConfig.__init__` for the available keys.
optim_alpha: RMSProp alpha.
optim_eps: RMSProp epsilon.
grad_norm_clipping: If not None, clip gradients during optimization at
this value.
worker_side_prioritization: Whether to compute priorities for the replay
buffer on the worker side.
Returns:
This updated TrainerConfig object.
"""
# Pass kwargs onto super's `training()` method.
super().training(**kwargs)
if mixer is not None:
self.mixer = mixer
if mixing_embed_dim is not None:
self.mixing_embed_dim = mixing_embed_dim
if double_q is not None:
self.double_q = double_q
if target_network_update_freq is not None:
self.target_network_update_freq = target_network_update_freq
if replay_buffer_config is not None:
self.replay_buffer_config = replay_buffer_config
if optim_alpha is not None:
self.optim_alpha = optim_alpha
if optim_eps is not None:
self.optim_eps = optim_eps
if grad_norm_clipping is not None:
self.grad_norm_clipping = grad_norm_clipping
if worker_side_prioritization is not None:
self.worker_side_prioritization = worker_side_prioritization
return self
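A brief usage sketch of the training() method above (hedged; it mirrors the class docstring examples): because every argument defaults to None and only non-None values are written back, a call overrides exactly the keys it names and leaves the remaining defaults from __init__ untouched.

from ray.rllib.agents.qmix import QMixConfig

config = QMixConfig()
# Override only the mixer and gradient clipping; everything else keeps its
# __init__ default (e.g. mixing_embed_dim stays 32).
config.training(mixer="vdn", grad_norm_clipping=20.0)
assert config.mixer == "vdn" and config.mixing_embed_dim == 32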
class QMixTrainer(SimpleQTrainer):
@classmethod
@override(SimpleQTrainer)
def get_default_config(cls) -> TrainerConfigDict:
return DEFAULT_CONFIG
return QMixConfig().to_dict()
@override(SimpleQTrainer)
def validate_config(self, config: TrainerConfigDict) -> None:
@@ -219,3 +274,20 @@ class QMixTrainer(SimpleQTrainer):
# Return all collected metrics for the iteration.
return train_results
# Deprecated: Use ray.rllib.agents.qmix.qmix.QMixConfig instead!
class _deprecated_default_config(dict):
def __init__(self):
super().__init__(QMixConfig().to_dict())
@Deprecated(
old="ray.rllib.agents.qmix.qmix.DEFAULT_CONFIG",
new="ray.rllib.agents.qmix.qmix.QMixConfig(...)",
error=False,
)
def __getitem__(self, item):
return super().__getitem__(item)
DEFAULT_CONFIG = _deprecated_default_config()
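The _deprecated_default_config wrapper above keeps old-style DEFAULT_CONFIG dict access working while warning on every key lookup. A simplified, stand-alone sketch of the same idea (plain warnings.warn instead of RLlib's Deprecated decorator; names are illustrative):

import warnings

class _LegacyConfigDictStub(dict):
    """Dict that warns whenever a key is read, mimicking the wrapper above."""

    def __getitem__(self, item):
        warnings.warn(
            "Old-style config dict access is deprecated; build a config "
            "object (e.g. QMixConfig) instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return super().__getitem__(item)

LEGACY_DEFAULTS = _LegacyConfigDictStub({"mixer": "qmix", "double_q": True})
LEGACY_DEFAULTS["mixer"]  # Still returns "qmix", but emits a DeprecationWarning.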


@@ -153,7 +153,6 @@ class QMixLoss(nn.Module):
return loss, mask, masked_td_error, chosen_action_qvals, targets
# TODO(sven): Make this a TorchPolicy child via `build_policy_class`.
class QMixTorchPolicy(TorchPolicy):
"""QMix impl. Assumes homogeneous agents for now.
@@ -177,9 +176,6 @@ class QMixTorchPolicy(TorchPolicy):
self.h_size = config["model"]["lstm_cell_size"]
self.has_env_global_state = False
self.has_action_mask = False
self.device = (
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
agent_obs_space = obs_space.original_space.spaces[0]
if isinstance(agent_obs_space, gym.spaces.Dict):
@@ -218,7 +214,9 @@ class QMixTorchPolicy(TorchPolicy):
framework="torch",
name="model",
default_model=RNNModel,
).to(self.device)
)
super().__init__(obs_space, action_space, config, model=self.model)
self.target_model = ModelCatalog.get_model_v2(
agent_obs_space,
@@ -230,8 +228,6 @@ class QMixTorchPolicy(TorchPolicy):
default_model=RNNModel,
).to(self.device)
super().__init__(obs_space, action_space, config, model=self.model)
self.exploration = self._create_exploration()
# Setup the mixer network.


@@ -4,7 +4,7 @@ import unittest
import ray
from ray.tune import register_env
from ray.rllib.agents.qmix import QMixTrainer
from ray.rllib.agents.qmix import QMixConfig
from ray.rllib.env.multi_agent_env import MultiAgentEnv
@@ -95,18 +95,19 @@ class TestQMix(unittest.TestCase):
),
)
trainer = QMixTrainer(
env="action_mask_test",
config={
"num_envs_per_worker": 5, # test with vectorization on
"env_config": {
"avail_actions": [3, 4, 8],
},
"framework": "torch",
},
)
config = QMixConfig()\
.framework(framework="torch")\
.environment(
env="action_mask_test",
env_config={"avail_actions": [3, 4, 8]},
)\
.rollouts(num_envs_per_worker=5) # Test with vectorization on.
trainer = config.build()
for _ in range(4):
trainer.train() # OK if it doesn't trip the action assertion error
assert trainer.train()["episode_reward_mean"] == 30.0
trainer.stop()
ray.shutdown()


@@ -245,7 +245,6 @@ class MultiAgentEnv(gym.Env):
# fmt: off
# __grouping_doc_begin__
@ExperimentalAPI
def with_agent_groups(
self,
groups: Dict[str, List[AgentID]],
@@ -265,16 +264,17 @@ class MultiAgentEnv(gym.Env):
Agent grouping is required to leverage algorithms such as Q-Mix.
This API is experimental.
Args:
groups: Mapping from group id to a list of the agent ids
of group members. If an agent id is not present in any group
value, it will be left ungrouped.
value, it will be left ungrouped. The group id becomes a new agent ID
in the final environment.
obs_space: Optional observation space for the grouped
env. Must be a tuple space.
env. Must be a tuple space. If not provided, will infer this to be a
Tuple of the n individual agents' spaces (n=number of agents in the group).
act_space: Optional action space for the grouped env.
Must be a tuple space.
Must be a tuple space. If not provided, will infer this to be a Tuple
of the n individual agents' spaces (n=number of agents in the group).
Examples:
>>> from ray.rllib.env.multi_agent_env import MultiAgentEnv
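A hedged sketch of agent grouping as used elsewhere in this diff for the TwoStepGame env (the agent IDs 0 and 1 and the Discrete(6) observation space are assumptions taken from the example script further down; per the docstring above, the wrapper would infer the Tuple spaces if they were omitted):

from gym.spaces import Discrete, Tuple
from ray.tune import register_env
from ray.rllib.examples.env.two_step_game import TwoStepGame

# Fold both TwoStepGame agents into one logical "group_1" agent whose
# observation/action spaces are Tuples of the individual agents' spaces.
grouping = {"group_1": [0, 1]}
obs_space = Tuple([Discrete(6), Discrete(6)])
act_space = Tuple([TwoStepGame.action_space, TwoStepGame.action_space])

register_env(
    "grouped_twostep",
    lambda cfg: TwoStepGame(cfg).with_agent_groups(
        grouping, obs_space=obs_space, act_space=act_space
    ),
)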


@@ -1,6 +1,9 @@
from collections import OrderedDict
import gym
from typing import Dict, List, Optional
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.utils.typing import AgentID
# info key for the individual rewards of an agent, for example:
# info: {
@@ -27,21 +30,35 @@ class GroupAgentsWrapper(MultiAgentEnv):
This API is experimental.
"""
def __init__(self, env, groups, obs_space=None, act_space=None):
"""Wrap an existing multi-agent env to group agents together.
def __init__(
self,
env: MultiAgentEnv,
groups: Dict[str, List[AgentID]],
obs_space: Optional[gym.Space] = None,
act_space: Optional[gym.Space] = None,
):
"""Wrap an existing MultiAgentEnv to group agent ID together.
See MultiAgentEnv.with_agent_groups() for usage info.
See `MultiAgentEnv.with_agent_groups()` for more detailed usage info.
Args:
env (MultiAgentEnv): env to wrap
groups (dict): Grouping spec as documented in MultiAgentEnv.
obs_space (Space): Optional observation space for the grouped
env. Must be a tuple space.
act_space (Space): Optional action space for the grouped env.
Must be a tuple space.
env: The env to wrap and whose agent IDs to group into new agents.
groups: Mapping from group id to a list of the agent ids
of group members. If an agent id is not present in any group
value, it will be left ungrouped. The group id becomes a new agent ID
in the final environment.
obs_space: Optional observation space for the grouped
env. Must be a tuple space. If not provided, will infer this to be a
Tuple of the n individual agents' spaces (n=number of agents in the group).
act_space: Optional action space for the grouped env.
Must be a tuple space. If not provided, will infer this to be a Tuple
of the n individual agents' spaces (n=number of agents in the group).
"""
super().__init__()
self.env = env
# Inherit wrapped env's `_skip_env_checking` flag.
if hasattr(self.env, "_skip_env_checking"):
self._skip_env_checking = self.env._skip_env_checking
self.groups = groups
self.agent_id_to_group = {}
for group_id, agent_ids in groups.items():


@@ -16,6 +16,7 @@ import os
import ray
from ray import tune
from ray.tune import register_env
from ray.rllib.agents.qmix import QMixConfig
from ray.rllib.env.multi_agent_env import ENV_STATE
from ray.rllib.examples.env.two_step_game import TwoStepGame
from ray.rllib.policy.policy import PolicySpec
@@ -110,10 +111,11 @@ if __name__ == "__main__":
obs_space = Discrete(6)
act_space = TwoStepGame.action_space
config = {
"learning_starts": 100,
"env": TwoStepGame,
"env_config": {
"actions_are_logits": True,
},
"learning_starts": 100,
"multiagent": {
"policies": {
"pol1": PolicySpec(
@@ -133,31 +135,29 @@ if __name__ == "__main__":
# Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
"num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
}
group = False
elif args.run == "QMIX":
config = {
"rollout_fragment_length": 4,
"train_batch_size": 32,
"exploration_config": {
config = QMixConfig()\
.training(mixer=args.mixer, train_batch_size=32)\
.rollouts(num_rollout_workers=0, rollout_fragment_length=4)\
.exploration(exploration_config={
"final_epsilon": 0.0,
},
"num_workers": 0,
"mixer": args.mixer,
"env_config": {
"separate_state_space": True,
"one_hot_state_encoding": True,
},
# Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
"num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
}
group = True
})\
.environment(
env="grouped_twostep",
env_config={
"separate_state_space": True,
"one_hot_state_encoding": True,
}
)\
.resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
config = config.to_dict()
else:
config = {
"env": TwoStepGame,
# Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
"num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
"framework": args.framework,
}
group = False
stop = {
"episode_reward_mean": args.stop_reward,
@@ -165,13 +165,6 @@ if __name__ == "__main__":
"training_iteration": args.stop_iters,
}
config = dict(
config,
**{
"env": "grouped_twostep" if group else TwoStepGame,
}
)
results = tune.run(args.run, stop=stop, config=config, verbose=2)
if args.as_test: