[RLlib] QMix TrainerConfig objects. (#24775)
This commit is contained in:
parent
ffcbb30552
commit
8fe3fd8f7b
9 changed files with 246 additions and 166 deletions
@@ -1,14 +1,15 @@
from ray.rllib.algorithms.cql.cql import CQLTrainer, CQL_DEFAULT_CONFIG
from ray.rllib.algorithms.cql.cql_tf_policy import CQLTFPolicy
from ray.rllib.algorithms.cql.cql_torch_policy import CQLTorchPolicy
from ray.rllib.utils.deprecation import deprecation_warning

__all__ = [
    "CQL_DEFAULT_CONFIG",
    "CQLTFPolicy",
    "CQLTorchPolicy",
    "CQLTrainer",
]

from ray.rllib.utils.deprecation import deprecation_warning

deprecation_warning(
-    "ray.rllib.agents.dreamer", "ray.rllib.algorithms.dreamer", error=False
+    "ray.rllib.agents.cql", "ray.rllib.algorithms.cql", error=False
)
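The corrected call above is the whole backward-compatibility story for the old module path. As a rough sketch of what it amounts to at import time (an assumption here is that `error=False` makes `deprecation_warning` log a warning rather than raise; with `error=True` it would raise):

# Sketch, not part of the diff: what the module-level call above does.
from ray.rllib.utils.deprecation import deprecation_warning

deprecation_warning(
    "ray.rllib.agents.cql",       # old, deprecated module path
    "ray.rllib.algorithms.cql",   # new module path to use instead
    error=False,                  # assumed: warn only, do not raise
)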
@@ -104,9 +104,9 @@ class SimpleQConfig(TrainerConfig):
        >>> .exploration(exploration_config=explore_config)
    """

-    def __init__(self):
+    def __init__(self, trainer_class=None):
        """Initializes a SimpleQConfig instance."""
-        super().__init__(trainer_class=SimpleQTrainer)
+        super().__init__(trainer_class=trainer_class or SimpleQTrainer)

        # Simple Q specific
        # fmt: off
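The new `trainer_class` argument exists so that subclasses of SimpleQConfig (such as the QMixConfig added below) can route their own trainer class into the base constructor instead of always getting SimpleQTrainer. A minimal sketch of that pattern; the subclass names and the extra attribute are hypothetical, for illustration only:

from ray.rllib.agents.dqn.simple_q import SimpleQConfig, SimpleQTrainer


class MyQVariantTrainer(SimpleQTrainer):  # hypothetical trainer subclass
    pass


class MyQVariantConfig(SimpleQConfig):  # hypothetical config subclass
    def __init__(self):
        # The new trainer_class parameter lets this value reach TrainerConfig
        # instead of being overwritten with SimpleQTrainer.
        super().__init__(trainer_class=MyQVariantTrainer)
        self.my_extra_setting = 42  # hypothetical extra default


config = MyQVariantConfig()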
@@ -1,3 +1,3 @@
-from ray.rllib.agents.qmix.qmix import QMixTrainer, DEFAULT_CONFIG
+from ray.rllib.agents.qmix.qmix import QMixConfig, QMixTrainer, DEFAULT_CONFIG

-__all__ = ["QMixTrainer", "DEFAULT_CONFIG"]
+__all__ = ["QMixConfig", "QMixTrainer", "DEFAULT_CONFIG"]
@@ -1,7 +1,6 @@
-from typing import Type
+from typing import Optional, Type

-from ray.rllib.agents.trainer import with_common_config
-from ray.rllib.agents.dqn.simple_q import SimpleQTrainer
+from ray.rllib.agents.dqn.simple_q import SimpleQConfig, SimpleQTrainer
from ray.rllib.agents.qmix.qmix_policy import QMixTorchPolicy
from ray.rllib.execution.rollout_ops import (
    synchronous_parallel_sample,
@@ -12,7 +11,7 @@ from ray.rllib.execution.train_ops import (
)
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
-from ray.rllib.utils.deprecation import DEPRECATED_VALUE
+from ray.rllib.utils.deprecation import DEPRECATED_VALUE, Deprecated
from ray.rllib.utils.metrics import (
    LAST_TARGET_UPDATE_TS,
    NUM_AGENT_STEPS_SAMPLED,
@@ -23,118 +22,174 @@ from ray.rllib.utils.metrics import (
from ray.rllib.utils.replay_buffers.utils import sample_min_n_steps_from_buffer
from ray.rllib.utils.typing import ResultDict, TrainerConfigDict

-# fmt: off
-# __sphinx_doc_begin__
-DEFAULT_CONFIG = with_common_config({
-    # === QMix ===
-    # Mixing network. Either "qmix", "vdn", or None
-    "mixer": "qmix",
-    # Size of the mixing network embedding
-    "mixing_embed_dim": 32,
-    # Whether to use Double_Q learning
-    "double_q": True,
-    # Optimize over complete episodes by default.
-    "batch_mode": "complete_episodes",
-
-    # === Exploration Settings ===
-    "exploration_config": {
-        # The Exploration class to use.
-        "type": "EpsilonGreedy",
-        # Config for the Exploration class' constructor:
-        "initial_epsilon": 1.0,
-        "final_epsilon": 0.01,
-        # Timesteps over which to anneal epsilon.
-        "epsilon_timesteps": 40000,
+class QMixConfig(SimpleQConfig):
+    """Defines a configuration class from which a QMixTrainer can be built.

-        # For soft_q, use:
-        # "exploration_config" = {
-        #   "type": "SoftQ"
-        #   "temperature": [float, e.g. 1.0]
-        # }
-    },
+    Example:
+        >>> from ray.rllib.examples.env.two_step_game import TwoStepGame
+        >>> from ray.rllib.agents.qmix import QMixConfig
+        >>> config = QMixConfig().training(gamma=0.9, lr=0.01)\
+        ...     .resources(num_gpus=0)\
+        ...     .rollouts(num_workers=4)
+        >>> print(config.to_dict())
+        >>> # Build a Trainer object from the config and run 1 training iteration.
+        >>> trainer = config.build(env=TwoStepGame)
+        >>> trainer.train()

-    # === Evaluation ===
-    # Evaluate with epsilon=0 every `evaluation_interval` training iterations.
-    # The evaluation stats will be reported under the "evaluation" metric key.
-    # Note that evaluation is currently not parallelized, and that for Ape-X
-    # metrics are already only reported for the lowest epsilon workers.
-    "evaluation_interval": None,
-    # Number of episodes to run per evaluation period.
-    "evaluation_duration": 10,
-    # Switch to greedy actions in evaluation workers.
-    "evaluation_config": {
-        "explore": False,
-    },
+    Example:
+        >>> from ray.rllib.examples.env.two_step_game import TwoStepGame
+        >>> from ray.rllib.agents.qmix import QMixConfig
+        >>> from ray import tune
+        >>> config = QMixConfig()
+        >>> # Print out some default values.
+        >>> print(config.optim_alpha)
+        >>> # Update the config object.
+        >>> config.training(lr=tune.grid_search([0.001, 0.0001]), optim_alpha=0.97)
+        >>> # Set the config object's env.
+        >>> config.environment(env=TwoStepGame)
+        >>> # Use to_dict() to get the old-style python config dict
+        >>> # when running with tune.
+        >>> tune.run(
+        ...     "QMix",
+        ...     stop={"episode_reward_mean": 200},
+        ...     config=config.to_dict(),
+        ... )
+    """

-    # Minimum env sampling timesteps to accumulate within a single `train()` call. This
-    # value does not affect learning, only the number of times `Trainer.step_attempt()`
-    # is called by `Trainer.train()`. If - after one `step_attempt()` - the env sampling
-    # timestep count has not been reached, will perform n more `step_attempt()` calls
-    # until the minimum timesteps have been executed. Set to 0 for no minimum timesteps.
-    "min_sample_timesteps_per_reporting": 1000,
-    # Update the target network every `target_network_update_freq` steps.
-    "target_network_update_freq": 500,
+    def __init__(self):
+        """Initializes a QMixConfig instance."""
+        super().__init__(trainer_class=QMixTrainer)

-    # === Replay buffer ===
-    "replay_buffer_config": {
-        # Use the new ReplayBuffer API here
-        "_enable_replay_buffer_api": True,
-        "type": "SimpleReplayBuffer",
-        # Size of the replay buffer in batches (not timesteps!).
-        "capacity": 1000,
-        "learning_starts": 1000,
-    },
+        # fmt: off
+        # __sphinx_doc_begin__
+        # QMix specific settings:
+        self.mixer = "qmix"
+        self.mixing_embed_dim = 32
+        self.double_q = True
+        self.target_network_update_freq = 500
+        self.replay_buffer_config = {
+            # Use the new ReplayBuffer API here
+            "_enable_replay_buffer_api": True,
+            "type": "SimpleReplayBuffer",
+            # Size of the replay buffer in batches (not timesteps!).
+            "capacity": 1000,
+            "learning_starts": 1000,
+        }
+        self.optim_alpha = 0.99
+        self.optim_eps = 0.00001
+        self.grad_norm_clipping = 10
+        self.worker_side_prioritization = False

-    # === Optimization ===
-    # Learning rate for RMSProp optimizer
-    "lr": 0.0005,
-    # RMSProp alpha
-    "optim_alpha": 0.99,
-    # RMSProp epsilon
-    "optim_eps": 0.00001,
-    # If not None, clip gradients during optimization at this value
-    "grad_norm_clipping": 10,
-    # Update the replay buffer with this many samples at once. Note that
-    # this setting applies per-worker if num_workers > 1.
-    "rollout_fragment_length": 4,
-    # Minimum batch size used for training (in timesteps). With the default buffer
-    # (ReplayBuffer) this means, sampling from the buffer (entire-episode SampleBatches)
-    # as many times as is required to reach at least this number of timesteps.
-    "train_batch_size": 32,
+        # Override some of TrainerConfig's default values with QMix-specific values.
+        self.num_workers = 0
+        self.min_time_s_per_reporting = 1
+        self.model = {
+            "lstm_cell_size": 64,
+            "max_seq_len": 999999,
+        }
+        self.framework_str = "torch"
+        self.lr = 0.0005
+        self.rollout_fragment_length = 4
+        self.train_batch_size = 32
+        self.batch_mode = "complete_episodes"
+        self.exploration_config = {
+            # The Exploration class to use.
+            "type": "EpsilonGreedy",
+            # Config for the Exploration class' constructor:
+            "initial_epsilon": 1.0,
+            "final_epsilon": 0.01,
+            # Timesteps over which to anneal epsilon.
+            "epsilon_timesteps": 40000,

-    # === Parallelism ===
-    # Number of workers for collecting samples with. This only makes sense
-    # to increase if your environment is particularly slow to sample, or if
-    # you're using the Async or Ape-X optimizers.
-    "num_workers": 0,
-    # Whether to compute priorities on workers.
-    "worker_side_prioritization": False,
-    # Prevent reporting frequency from going lower than this time span.
-    "min_time_s_per_reporting": 1,
+            # For soft_q, use:
+            # "exploration_config" = {
+            #   "type": "SoftQ"
+            #   "temperature": [float, e.g. 1.0]
+            # }
+        }

-    # === Model ===
-    "model": {
-        "lstm_cell_size": 64,
-        "max_seq_len": 999999,
-    },
-    # Only torch supported so far.
-    "framework": "torch",
+        # Evaluate with epsilon=0 every `evaluation_interval` training iterations.
+        # The evaluation stats will be reported under the "evaluation" metric key.
+        # Note that evaluation is currently not parallelized, and that for Ape-X
+        # metrics are already only reported for the lowest epsilon workers.
+        self.evaluation_interval = None
+        self.evaluation_duration = 10
+        self.evaluation_config = {
+            "explore": False,
+        }
+        self.min_sample_timesteps_per_reporting = 1000
+        # __sphinx_doc_end__
+        # fmt: on

-    # Deprecated keys:
-    # Use `replay_buffer_config.learning_starts` instead.
-    "learning_starts": DEPRECATED_VALUE,
-    # Use `replay_buffer_config.capacity` instead.
-    "buffer_size": DEPRECATED_VALUE,
-})
-# __sphinx_doc_end__
-# fmt: on
+        # Deprecated keys:
+        self.learning_starts = DEPRECATED_VALUE
+        self.buffer_size = DEPRECATED_VALUE
+    @override(SimpleQConfig)
+    def training(
+        self,
+        *,
+        mixer: Optional[str] = None,
+        mixing_embed_dim: Optional[int] = None,
+        double_q: Optional[bool] = None,
+        target_network_update_freq: Optional[int] = None,
+        replay_buffer_config: Optional[dict] = None,
+        optim_alpha: Optional[float] = None,
+        optim_eps: Optional[float] = None,
+        grad_norm_clipping: Optional[float] = None,
+        worker_side_prioritization: Optional[bool] = None,
+        **kwargs,
+    ) -> "QMixConfig":
+        """Sets the training related configuration.
+
+        Args:
+            mixer: Mixing network. Either "qmix", "vdn", or None.
+            mixing_embed_dim: Size of the mixing network embedding.
+            double_q: Whether to use Double_Q learning.
+            target_network_update_freq: Update the target network every
+                `target_network_update_freq` sample steps.
+            replay_buffer_config: Replay buffer config dict.
+            optim_alpha: RMSProp alpha.
+            optim_eps: RMSProp epsilon.
+            grad_norm_clipping: If not None, clip gradients during optimization at
+                this value.
+            worker_side_prioritization: Whether to compute priorities for the replay
+                buffer on the worker side.
+
+        Returns:
+            This updated TrainerConfig object.
+        """
+        # Pass kwargs onto super's `training()` method.
+        super().training(**kwargs)
+
+        if mixer is not None:
+            self.mixer = mixer
+        if mixing_embed_dim is not None:
+            self.mixing_embed_dim = mixing_embed_dim
+        if double_q is not None:
+            self.double_q = double_q
+        if target_network_update_freq is not None:
+            self.target_network_update_freq = target_network_update_freq
+        if replay_buffer_config is not None:
+            self.replay_buffer_config = replay_buffer_config
+        if optim_alpha is not None:
+            self.optim_alpha = optim_alpha
+        if optim_eps is not None:
+            self.optim_eps = optim_eps
+        if grad_norm_clipping is not None:
+            self.grad_norm_clipping = grad_norm_clipping
+        if worker_side_prioritization is not None:
+            self.worker_side_prioritization = worker_side_prioritization
+
+        return self


class QMixTrainer(SimpleQTrainer):
    @classmethod
    @override(SimpleQTrainer)
    def get_default_config(cls) -> TrainerConfigDict:
-        return DEFAULT_CONFIG
+        return QMixConfig().to_dict()

    @override(SimpleQTrainer)
    def validate_config(self, config: TrainerConfigDict) -> None:
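Taken together, configuration now reads as a fluent chain instead of a dict merge, and `QMixTrainer.get_default_config()` simply returns `QMixConfig().to_dict()`. A usage sketch based on the signatures above; TwoStepGame is just a convenient example env here, and the chosen values are illustrative rather than recommended:

from ray.rllib.examples.env.two_step_game import TwoStepGame
from ray.rllib.agents.qmix import QMixConfig

config = (
    QMixConfig()
    .training(mixer="vdn", double_q=True, optim_alpha=0.97, lr=0.0005)
    .rollouts(num_rollout_workers=0, rollout_fragment_length=4)
    .environment(env=TwoStepGame)
)
# The dict form stays available for tune or older call sites.
config_dict = config.to_dict()

trainer = config.build()
print(trainer.train())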
@@ -219,3 +274,20 @@ class QMixTrainer(SimpleQTrainer):

        # Return all collected metrics for the iteration.
        return train_results


+# Deprecated: Use ray.rllib.agents.qmix.qmix.QMixConfig instead!
+class _deprecated_default_config(dict):
+    def __init__(self):
+        super().__init__(QMixConfig().to_dict())
+
+    @Deprecated(
+        old="ray.rllib.agents.qmix.qmix.DEFAULT_CONFIG",
+        new="ray.rllib.agents.qmix.qmix.QMixConfig(...)",
+        error=False,
+    )
+    def __getitem__(self, item):
+        return super().__getitem__(item)
+
+
+DEFAULT_CONFIG = _deprecated_default_config()
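The wrapper above keeps old dict-style access working while steering users toward the config class. A small sketch of the intended behavior; the assumption is that `@Deprecated(error=False)` logs a warning on each indexed access rather than raising:

from ray.rllib.agents.qmix.qmix import DEFAULT_CONFIG, QMixConfig

# Old-style access still works, but (assumed) logs a deprecation warning,
# because _deprecated_default_config.__getitem__ is wrapped with @Deprecated.
old_lr = DEFAULT_CONFIG["lr"]

# New-style access, no warning.
new_lr = QMixConfig().to_dict()["lr"]
assert old_lr == new_lr  # both values derive from QMixConfig's defaults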
@@ -153,7 +153,6 @@ class QMixLoss(nn.Module):
        return loss, mask, masked_td_error, chosen_action_qvals, targets


-# TODO(sven): Make this a TorchPolicy child via `build_policy_class`.
class QMixTorchPolicy(TorchPolicy):
    """QMix impl. Assumes homogeneous agents for now.

@@ -177,9 +176,6 @@ class QMixTorchPolicy(TorchPolicy):
        self.h_size = config["model"]["lstm_cell_size"]
        self.has_env_global_state = False
        self.has_action_mask = False
-        self.device = (
-            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-        )

        agent_obs_space = obs_space.original_space.spaces[0]
        if isinstance(agent_obs_space, gym.spaces.Dict):

@@ -218,7 +214,9 @@ class QMixTorchPolicy(TorchPolicy):
            framework="torch",
            name="model",
            default_model=RNNModel,
-        ).to(self.device)
+        )
+
+        super().__init__(obs_space, action_space, config, model=self.model)

        self.target_model = ModelCatalog.get_model_v2(
            agent_obs_space,

@@ -230,8 +228,6 @@ class QMixTorchPolicy(TorchPolicy):
            default_model=RNNModel,
        ).to(self.device)

-        super().__init__(obs_space, action_space, config, model=self.model)
-
        self.exploration = self._create_exploration()

        # Setup the mixer network.
@@ -4,7 +4,7 @@ import unittest

import ray
from ray.tune import register_env
-from ray.rllib.agents.qmix import QMixTrainer
+from ray.rllib.agents.qmix import QMixConfig
from ray.rllib.env.multi_agent_env import MultiAgentEnv

@@ -95,18 +95,19 @@ class TestQMix(unittest.TestCase):
            ),
        )

-        trainer = QMixTrainer(
-            env="action_mask_test",
-            config={
-                "num_envs_per_worker": 5,  # test with vectorization on
-                "env_config": {
-                    "avail_actions": [3, 4, 8],
-                },
-                "framework": "torch",
-            },
-        )
+        config = QMixConfig()\
+            .framework(framework="torch")\
+            .environment(
+                env="action_mask_test",
+                env_config={"avail_actions": [3, 4, 8]},
+            )\
+            .rollouts(num_envs_per_worker=5)  # Test with vectorization on.
+
+        trainer = config.build()

        for _ in range(4):
            trainer.train()  # OK if it doesn't trip the action assertion error

        assert trainer.train()["episode_reward_mean"] == 30.0
        trainer.stop()
        ray.shutdown()
rllib/env/multi_agent_env.py: 12 changed lines (vendored)

@@ -245,7 +245,6 @@ class MultiAgentEnv(gym.Env):

    # fmt: off
    # __grouping_doc_begin__
    @ExperimentalAPI
    def with_agent_groups(
        self,
        groups: Dict[str, List[AgentID]],

@@ -265,16 +264,17 @@ class MultiAgentEnv(gym.Env):

        Agent grouping is required to leverage algorithms such as Q-Mix.

-        This API is experimental.

        Args:
            groups: Mapping from group id to a list of the agent ids
                of group members. If an agent id is not present in any group
-                value, it will be left ungrouped.
+                value, it will be left ungrouped. The group id becomes a new agent ID
+                in the final environment.
            obs_space: Optional observation space for the grouped
-                env. Must be a tuple space.
+                env. Must be a tuple space. If not provided, will infer this to be a
+                Tuple of the n individual agents' spaces (n=num agents in a group).
            act_space: Optional action space for the grouped env.
-                Must be a tuple space.
+                Must be a tuple space. If not provided, will infer this to be a Tuple
+                of the n individual agents' spaces (n=num agents in a group).

        Examples:
            >>> from ray.rllib.env.multi_agent_env import MultiAgentEnv
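To make the updated argument descriptions concrete, here is a brief hedged sketch of a grouping call; the env object and agent IDs are hypothetical, and the omitted spaces fall back to the inferred Tuple spaces described above:

# Hypothetical MultiAgentEnv instance `my_env` with agents "a1".."a4"; not part of the diff.
grouped_env = my_env.with_agent_groups(
    groups={"group_1": ["a1", "a2"], "group_2": ["a3", "a4"]},
    # obs_space / act_space omitted: per the docstring above, they are inferred
    # as Tuples of the member agents' individual spaces.
)
# "group_1" and "group_2" now act as the agent IDs of the grouped env.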
rllib/env/wrappers/group_agents_wrapper.py: 35 changed lines (vendored)

@@ -1,6 +1,9 @@
from collections import OrderedDict
+import gym
+from typing import Dict, List, Optional

from ray.rllib.env.multi_agent_env import MultiAgentEnv
+from ray.rllib.utils.typing import AgentID

# info key for the individual rewards of an agent, for example:
# info: {

@@ -27,21 +30,35 @@ class GroupAgentsWrapper(MultiAgentEnv):
    This API is experimental.
    """

-    def __init__(self, env, groups, obs_space=None, act_space=None):
-        """Wrap an existing multi-agent env to group agents together.
+    def __init__(
+        self,
+        env: MultiAgentEnv,
+        groups: Dict[str, List[AgentID]],
+        obs_space: Optional[gym.Space] = None,
+        act_space: Optional[gym.Space] = None,
+    ):
+        """Wrap an existing MultiAgentEnv to group agent IDs together.

-        See MultiAgentEnv.with_agent_groups() for usage info.
+        See `MultiAgentEnv.with_agent_groups()` for more detailed usage info.

        Args:
-            env (MultiAgentEnv): env to wrap
-            groups (dict): Grouping spec as documented in MultiAgentEnv.
-            obs_space (Space): Optional observation space for the grouped
-                env. Must be a tuple space.
-            act_space (Space): Optional action space for the grouped env.
-                Must be a tuple space.
+            env: The env to wrap and whose agent IDs to group into new agents.
+            groups: Mapping from group id to a list of the agent ids
+                of group members. If an agent id is not present in any group
+                value, it will be left ungrouped. The group id becomes a new agent ID
+                in the final environment.
+            obs_space: Optional observation space for the grouped
+                env. Must be a tuple space. If not provided, will infer this to be a
+                Tuple of the n individual agents' spaces (n=num agents in a group).
+            act_space: Optional action space for the grouped env.
+                Must be a tuple space. If not provided, will infer this to be a Tuple
+                of the n individual agents' spaces (n=num agents in a group).
        """
        super().__init__()
        self.env = env
+        # Inherit wrapped env's `_skip_env_checking` flag.
+        if hasattr(self.env, "_skip_env_checking"):
+            self._skip_env_checking = self.env._skip_env_checking
        self.groups = groups
        self.agent_id_to_group = {}
        for group_id, agent_ids in groups.items():
@@ -16,6 +16,7 @@ import os

import ray
from ray import tune
from ray.tune import register_env
+from ray.rllib.agents.qmix import QMixConfig
from ray.rllib.env.multi_agent_env import ENV_STATE
from ray.rllib.examples.env.two_step_game import TwoStepGame
from ray.rllib.policy.policy import PolicySpec

@@ -110,10 +111,11 @@ if __name__ == "__main__":
        obs_space = Discrete(6)
        act_space = TwoStepGame.action_space
        config = {
-            "learning_starts": 100,
+            "env": TwoStepGame,
            "env_config": {
                "actions_are_logits": True,
            },
+            "learning_starts": 100,
            "multiagent": {
                "policies": {
                    "pol1": PolicySpec(

@@ -133,31 +135,29 @@ if __name__ == "__main__":
            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        }
-        group = False
    elif args.run == "QMIX":
-        config = {
-            "rollout_fragment_length": 4,
-            "train_batch_size": 32,
-            "exploration_config": {
+        config = QMixConfig()\
+            .training(mixer=args.mixer, train_batch_size=32)\
+            .rollouts(num_rollout_workers=0, rollout_fragment_length=4)\
+            .exploration(exploration_config={
                "final_epsilon": 0.0,
-            },
-            "num_workers": 0,
-            "mixer": args.mixer,
-            "env_config": {
-                "separate_state_space": True,
-                "one_hot_state_encoding": True,
-            },
-            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
-            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
-        }
-        group = True
+            })\
+            .environment(
+                env="grouped_twostep",
+                env_config={
+                    "separate_state_space": True,
+                    "one_hot_state_encoding": True,
+                }
+            )\
+            .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
+        config = config.to_dict()
    else:
        config = {
+            "env": TwoStepGame,
            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
            "framework": args.framework,
        }
-        group = False

    stop = {
        "episode_reward_mean": args.stop_reward,

@@ -165,13 +165,6 @@ if __name__ == "__main__":
        "training_iteration": args.stop_iters,
    }

-    config = dict(
-        config,
-        **{
-            "env": "grouped_twostep" if group else TwoStepGame,
-        }
-    )
-
    results = tune.run(args.run, stop=stop, config=config, verbose=2)

    if args.as_test:
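The QMIX branch above targets env="grouped_twostep", which the example script registers elsewhere in the file (that part is not shown in this diff). A hedged sketch of what such a registration typically looks like for the two-step game; the grouping and the Tuple spaces below are assumptions, not taken from the diff:

# Assumed registration of "grouped_twostep" (not part of this diff).
from gym.spaces import Discrete, Tuple
from ray.tune import register_env
from ray.rllib.examples.env.two_step_game import TwoStepGame

grouping = {"group_1": [0, 1]}  # both two-step-game agents form one group
obs_space = Tuple([Discrete(6), Discrete(6)])  # assumed per-agent observation spaces
act_space = Tuple([TwoStepGame.action_space, TwoStepGame.action_space])

register_env(
    "grouped_twostep",
    lambda env_config: TwoStepGame(env_config).with_agent_groups(
        grouping, obs_space=obs_space, act_space=act_space
    ),
)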