# ray/rllib/agents/qmix/qmix.py
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.dqn.dqn import GenericOffPolicyTrainer
from ray.rllib.agents.qmix.qmix_policy import QMixTorchPolicy
from ray.rllib.execution.replay_ops import SimpleReplayBuffer, Replay, \
    StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.concurrency_ops import Concurrently

# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
    # === QMix ===
    # Mixing network. Either "qmix", "vdn", or None.
    "mixer": "qmix",
    # Size of the mixing network embedding.
    "mixing_embed_dim": 32,
    # Whether to use Double Q-learning.
    "double_q": True,
    # Optimize over complete episodes by default.
    "batch_mode": "complete_episodes",

    # === Exploration Settings ===
    "exploration_config": {
        # The Exploration class to use.
        "type": "EpsilonGreedy",
        # Config for the Exploration class' constructor:
        "initial_epsilon": 1.0,
        "final_epsilon": 0.02,
        "epsilon_timesteps": 10000,  # Timesteps over which to anneal epsilon.

        # For soft_q, use:
        # "exploration_config" = {
        #   "type": "SoftQ"
        #   "temperature": [float, e.g. 1.0]
        # }
    },

    # === Evaluation ===
    # Evaluate with epsilon=0 every `evaluation_interval` training iterations.
    # The evaluation stats will be reported under the "evaluation" metric key.
    # Note that evaluation is currently not parallelized, and that for Ape-X
    # metrics are already only reported for the lowest-epsilon workers.
    "evaluation_interval": None,
    # Number of episodes to run per evaluation period.
    "evaluation_num_episodes": 10,
    # Switch to greedy actions in evaluation workers.
    "evaluation_config": {
        "explore": False,
    },
    # Number of env steps to optimize for before returning.
    "timesteps_per_iteration": 1000,
    # Update the target network every `target_network_update_freq` steps.
    "target_network_update_freq": 500,

    # === Replay buffer ===
    # Size of the replay buffer in batches (not timesteps!).
    "buffer_size": 1000,

    # === Optimization ===
    # Learning rate for the RMSProp optimizer.
    "lr": 0.0005,
    # RMSProp alpha.
    "optim_alpha": 0.99,
    # RMSProp epsilon.
    "optim_eps": 0.00001,
    # If not None, clip gradients during optimization at this value.
    "grad_norm_clipping": 10,
    # How many steps of the model to sample before learning starts.
    "learning_starts": 1000,
    # Update the replay buffer with this many samples at once. Note that
    # this setting applies per-worker if num_workers > 1.
    "rollout_fragment_length": 4,
    # Size of a batch sampled from the replay buffer for training. Note that
    # if async_updates is set, then each worker returns gradients for a
    # batch of this size.
    "train_batch_size": 32,

    # === Parallelism ===
    # Number of workers used for collecting samples. This only makes sense
    # to increase if your environment is particularly slow to sample, or if
    # you're using the Async or Ape-X optimizers.
    "num_workers": 0,
    # Whether to use a distribution of epsilons across workers for exploration.
    "per_worker_exploration": False,
    # Whether to compute priorities on workers.
    "worker_side_prioritization": False,
    # Prevent iterations from going lower than this time span.
    "min_iter_time_s": 1,

    # === Model ===
    "model": {
        "lstm_cell_size": 64,
        "max_seq_len": 999999,
    },
})
# __sphinx_doc_end__
# yapf: enable
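

# A minimal override sketch (not part of the original module): derive a config
# from the defaults above that uses the additive VDN mixer with soft-Q
# exploration. `VDN_SOFTQ_CONFIG` is a hypothetical name and the values are
# illustrative, not tuned.
VDN_SOFTQ_CONFIG = dict(
    DEFAULT_CONFIG,
    **{
        "mixer": "vdn",
        "exploration_config": {
            "type": "SoftQ",
            "temperature": 1.0,
        },
    })
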
def execution_plan(workers, config):
    # Collect sample batches from all rollout workers in parallel.
    rollouts = ParallelRollouts(workers, mode="bulk_sync")

    # Episode-based replay buffer; its size is measured in batches, not
    # timesteps (see `buffer_size` above).
    replay_buffer = SimpleReplayBuffer(config["buffer_size"])

    # (1) Store incoming rollouts into the replay buffer.
    store_op = rollouts \
        .for_each(StoreToReplayBuffer(local_buffer=replay_buffer))

    # (2) Replay stored batches, concatenate them up to `train_batch_size`,
    # train on the result, and periodically update the target network.
    train_op = Replay(local_buffer=replay_buffer) \
        .combine(
            ConcatBatches(
                min_batch_size=config["train_batch_size"],
                count_steps_by=config["multiagent"]["count_steps_by"],
            )) \
        .for_each(TrainOneStep(workers)) \
        .for_each(UpdateTargetNetwork(
            workers, config["target_network_update_freq"]))

    # Alternate between (1) and (2) round-robin, reporting only the training
    # op's outputs (output_indexes=[1]) as the iteration results.
    merged_op = Concurrently(
        [store_op, train_op], mode="round_robin", output_indexes=[1])

    return StandardMetricsReporting(merged_op, workers, config)

QMixTrainer = GenericOffPolicyTrainer.with_updates(
    name="QMIX",
    default_config=DEFAULT_CONFIG,
    default_policy=QMixTorchPolicy,
    get_policy_class=None,
    execution_plan=execution_plan)
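

# Example usage (a sketch, not part of the original module). "grouped_env" is
# a placeholder env id: QMIX expects a multi-agent env whose agents have been
# grouped (e.g. via MultiAgentEnv.with_agent_groups) so that observations and
# actions are Tuple spaces, and that env must be registered with tune first.
if __name__ == "__main__":
    import ray
    from ray import tune

    ray.init()
    tune.run(
        "QMIX",
        stop={"timesteps_total": 50000},
        config={
            "env": "grouped_env",  # Placeholder; register your grouped env.
            "mixer": "qmix",
            "num_workers": 0,
        },
    )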