ray/rllib/agents/qmix/qmix.py

from typing import Type
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.dqn.simple_q import SimpleQTrainer
from ray.rllib.agents.qmix.qmix_policy import QMixTorchPolicy
from ray.rllib.execution.rollout_ops import (
    synchronous_parallel_sample,
)
from ray.rllib.execution.train_ops import (
    multi_gpu_train_one_step,
    train_one_step,
)
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.metrics import (
    LAST_TARGET_UPDATE_TS,
    NUM_AGENT_STEPS_SAMPLED,
    NUM_ENV_STEPS_SAMPLED,
    NUM_TARGET_UPDATES,
    SYNCH_WORKER_WEIGHTS_TIMER,
)
from ray.rllib.utils.replay_buffers.utils import sample_min_n_steps_from_buffer
from ray.rllib.utils.typing import ResultDict, TrainerConfigDict

# fmt: off
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
    # === QMix ===
    # Mixing network. Either "qmix", "vdn", or None
    "mixer": "qmix",
    # Size of the mixing network embedding
    "mixing_embed_dim": 32,
    # Whether to use Double_Q learning
    "double_q": True,
    # Optimize over complete episodes by default.
    "batch_mode": "complete_episodes",

    # === Exploration Settings ===
    "exploration_config": {
        # The Exploration class to use.
        "type": "EpsilonGreedy",
        # Config for the Exploration class' constructor:
        "initial_epsilon": 1.0,
        "final_epsilon": 0.01,
        # Timesteps over which to anneal epsilon.
        "epsilon_timesteps": 40000,

        # For soft_q, use:
        # "exploration_config" = {
        #   "type": "SoftQ"
        #   "temperature": [float, e.g. 1.0]
        # }
    },

    # === Evaluation ===
    # Evaluate with epsilon=0 every `evaluation_interval` training iterations.
    # The evaluation stats will be reported under the "evaluation" metric key.
    # Note that evaluation is currently not parallelized, and that for Ape-X
    # metrics are already only reported for the lowest epsilon workers.
    "evaluation_interval": None,
    # Number of episodes to run per evaluation period.
    "evaluation_duration": 10,
    # Switch to greedy actions in evaluation workers.
    "evaluation_config": {
        "explore": False,
    },

    # Minimum env sampling timesteps to accumulate within a single `train()` call.
    # This value does not affect learning, only the number of times
    # `Trainer.step_attempt()` is called by `Trainer.train()`. If - after one
    # `step_attempt()` - the env sampling timestep count has not been reached,
    # further `step_attempt()` calls are performed until the minimum timesteps
    # have been executed. Set to 0 for no minimum timesteps.
    "min_sample_timesteps_per_reporting": 1000,
    # Update the target network every `target_network_update_freq` steps.
    "target_network_update_freq": 500,

    # === Replay buffer ===
    "replay_buffer_config": {
        # Use the new ReplayBuffer API here
        "_enable_replay_buffer_api": True,
        "type": "SimpleReplayBuffer",
        # Size of the replay buffer in batches (not timesteps!).
        "capacity": 1000,
        "learning_starts": 1000,
    },

    # === Optimization ===
    # Learning rate for RMSProp optimizer
    "lr": 0.0005,
    # RMSProp alpha
    "optim_alpha": 0.99,
    # RMSProp epsilon
    "optim_eps": 0.00001,
    # If not None, clip gradients during optimization at this value
    "grad_norm_clipping": 10,
    # Update the replay buffer with this many samples at once. Note that
    # this setting applies per-worker if num_workers > 1.
    "rollout_fragment_length": 4,
    # Minimum batch size used for training (in timesteps). With the default buffer
    # (ReplayBuffer), this means sampling from the buffer (entire-episode
    # SampleBatches) as many times as is required to reach at least this number
    # of timesteps.
    "train_batch_size": 32,

    # === Parallelism ===
    # Number of workers for collecting samples with. This only makes sense
    # to increase if your environment is particularly slow to sample, or if
    # you're using the Async or Ape-X optimizers.
    "num_workers": 0,
    # Whether to compute priorities on workers.
    "worker_side_prioritization": False,
    # Prevent reporting frequency from going lower than this time span.
    "min_time_s_per_reporting": 1,

    # === Model ===
    "model": {
        "lstm_cell_size": 64,
        "max_seq_len": 999999,
    },
    # Only torch supported so far.
    "framework": "torch",

    # Deprecated keys:
    # Use `replay_buffer_config.learning_starts` instead.
    "learning_starts": DEPRECATED_VALUE,
    # Use `replay_buffer_config.capacity` instead.
    "buffer_size": DEPRECATED_VALUE,
})
# __sphinx_doc_end__
# fmt: on
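
# A minimal usage sketch for this config, assuming a grouped multi-agent env has
# been registered under the placeholder name "grouped_twostep" (QMIX requires
# agents to be grouped via `MultiAgentEnv.with_agent_groups`); kept as a comment
# so importing this module stays side-effect free:
#
#   from ray import tune
#
#   tune.run(
#       "QMIX",
#       stop={"timesteps_total": 50000},
#       config={
#           "env": "grouped_twostep",   # placeholder env name
#           "mixer": "vdn",             # swap the QMIX mixer for VDN
#           "framework": "torch",       # only torch is supported
#       },
#   )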


class QMixTrainer(SimpleQTrainer):
    @classmethod
    @override(SimpleQTrainer)
    def get_default_config(cls) -> TrainerConfigDict:
        return DEFAULT_CONFIG

    @override(SimpleQTrainer)
    def validate_config(self, config: TrainerConfigDict) -> None:
        # Call super's validation method.
        super().validate_config(config)

        if config["framework"] != "torch":
            raise ValueError(
                "Only `framework=torch` supported so far for QMixTrainer!"
            )

    @override(SimpleQTrainer)
    def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
        return QMixTorchPolicy

    @override(SimpleQTrainer)
    def training_iteration(self) -> ResultDict:
        """QMIX training iteration function.

        - Sample n MultiAgentBatches from n workers synchronously.
        - Store new samples in the replay buffer.
        - Sample one training MultiAgentBatch from the replay buffer.
        - Learn on the training batch.
        - Update the target network every `target_network_update_freq` steps.
        - Return all collected training metrics for the iteration.

        Returns:
            The results dict from executing the training iteration.
        """
        # Sample n batches from n workers.
        new_sample_batches = synchronous_parallel_sample(
            worker_set=self.workers, concat=False
        )

        for batch in new_sample_batches:
            # Update counters.
            self._counters[NUM_ENV_STEPS_SAMPLED] += batch.env_steps()
            self._counters[NUM_AGENT_STEPS_SAMPLED] += batch.agent_steps()
            # Store new samples in the replay buffer.
            self.local_replay_buffer.add(batch)

        # Sample n batches from the replay buffer until the total number of
        # timesteps reaches `train_batch_size`.
        train_batch = sample_min_n_steps_from_buffer(
            replay_buffer=self.local_replay_buffer,
            min_steps=self.config["train_batch_size"],
            count_by_agent_steps=self._by_agent_steps,
        )
        if train_batch is None:
            return {}

        # Learn on the training batch.
        # Use simple optimizer (only for multi-agent or tf-eager; all other
        # cases should use the multi-GPU optimizer, even if only using 1 GPU).
        if self.config.get("simple_optimizer") is True:
            train_results = train_one_step(self, train_batch)
        else:
            train_results = multi_gpu_train_one_step(self, train_batch)

        # TODO: Move training steps counter update outside of `train_one_step()` method.
        # # Update train step counters.
        # self._counters[NUM_ENV_STEPS_TRAINED] += train_batch.env_steps()
        # self._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps()

        # Update target network every `target_network_update_freq` steps.
        cur_ts = self._counters[NUM_ENV_STEPS_SAMPLED]
        last_update = self._counters[LAST_TARGET_UPDATE_TS]
        if cur_ts - last_update >= self.config["target_network_update_freq"]:
            to_update = self.workers.local_worker().get_policies_to_train()
            self.workers.local_worker().foreach_policy_to_train(
                lambda p, pid: pid in to_update and p.update_target()
            )
            self._counters[NUM_TARGET_UPDATES] += 1
            self._counters[LAST_TARGET_UPDATE_TS] = cur_ts

        # Update weights and global_vars - after learning on the local worker -
        # on all remote workers.
        global_vars = {
            "timestep": self._counters[NUM_ENV_STEPS_SAMPLED],
        }
        with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
            self.workers.sync_weights(global_vars=global_vars)

        # Return all collected metrics for the iteration.
        return train_results
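
# Illustrative driving loop, assuming the same placeholder grouped env name
# "grouped_twostep" from the sketch above; per the config comments,
# `Trainer.train()` repeatedly runs `step_attempt()` (and thus
# `training_iteration()`) until `min_sample_timesteps_per_reporting` and
# `min_time_s_per_reporting` are satisfied:
#
#   trainer = QMixTrainer(config={"env": "grouped_twostep", "framework": "torch"})
#   for _ in range(10):
#       results = trainer.train()
#       print(results["episode_reward_mean"])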