from typing import Type

from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.dqn.simple_q import SimpleQTrainer
from ray.rllib.agents.qmix.qmix_policy import QMixTorchPolicy
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_ops import (
    SimpleReplayBuffer,
    Replay,
    StoreToReplayBuffer,
)
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TrainerConfigDict
from ray.util.iter import LocalIterator

# fmt: off
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
    # === QMix ===
    # Mixing network. Either "qmix", "vdn", or None.
    "mixer": "qmix",
    # Size of the mixing network embedding.
    "mixing_embed_dim": 32,
    # Whether to use double Q-learning.
    "double_q": True,
    # Optimize over complete episodes by default.
    "batch_mode": "complete_episodes",

    # === Exploration Settings ===
    "exploration_config": {
        # The Exploration class to use.
        "type": "EpsilonGreedy",
        # Config for the Exploration class' constructor:
        "initial_epsilon": 1.0,
        "final_epsilon": 0.01,
        # Timesteps over which to anneal epsilon.
        "epsilon_timesteps": 40000,

        # For soft-Q exploration, use instead:
        # "exploration_config": {
        #     "type": "SoftQ",
        #     "temperature": [float, e.g. 1.0],
        # },
    },

    # === Evaluation ===
    # Evaluate with epsilon=0 every `evaluation_interval` training iterations.
    # The evaluation stats will be reported under the "evaluation" metric key.
    # Note that evaluation is currently not parallelized, and that for Ape-X
    # metrics are already only reported for the lowest epsilon workers.
    "evaluation_interval": None,
    # Number of episodes to run per evaluation period.
    "evaluation_duration": 10,
    # Switch to greedy actions in evaluation workers.
    "evaluation_config": {
        "explore": False,
    },

    # Number of env steps to optimize for before returning.
    "timesteps_per_iteration": 1000,
    # Update the target network every `target_network_update_freq` steps.
    "target_network_update_freq": 500,

    # === Replay buffer ===
    # Size of the replay buffer in batches (not timesteps!).
    "buffer_size": 1000,
    "replay_buffer_config": {
        "no_local_replay_buffer": True,
    },
    # === Optimization ===
    # Learning rate for the RMSProp optimizer.
    "lr": 0.0005,
    # RMSProp alpha.
    "optim_alpha": 0.99,
    # RMSProp epsilon.
    "optim_eps": 0.00001,
    # If not None, clip gradients during optimization at this value.
    "grad_norm_clipping": 10,
    # How many steps of the model to sample before learning starts.
    "learning_starts": 1000,
    # Update the replay buffer with this many samples at once. Note that
    # this setting applies per-worker if num_workers > 1.
    "rollout_fragment_length": 4,
    # Size of a batch sampled from the replay buffer for training. Note that
    # if async_updates is set, then each worker returns gradients for a
    # batch of this size.
    "train_batch_size": 32,

    # === Parallelism ===
    # Number of workers to use for collecting samples. This only makes sense
    # to increase if your environment is particularly slow to sample, or if
    # you're using the Async or Ape-X optimizers.
    "num_workers": 0,
    # Whether to compute priorities on workers.
    "worker_side_prioritization": False,
    # Prevent reporting frequency from going lower than this time span.
    "min_time_s_per_reporting": 1,

    # === Model ===
    "model": {
        "lstm_cell_size": 64,
        "max_seq_len": 999999,
    },
    # Only torch is supported so far.
    "framework": "torch",
    # Experimental flag.
    # If True, the execution plan API will not be used. Instead,
    # a Trainer's `training_iteration` method will be called as-is each
    # training iteration.
    "_disable_execution_plan_api": False,
})
# __sphinx_doc_end__
# fmt: on
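# A minimal sketch (not part of the original file) of how these defaults are
# typically customized: copy DEFAULT_CONFIG and override individual keys
# before handing the dict to the trainer. The values below are illustrative
# assumptions, not recommendations:
#
#     config = DEFAULT_CONFIG.copy()
#     config["mixer"] = "vdn"        # use VDN mixing instead of QMIX
#     config["num_workers"] = 2      # collect samples with 2 rollout workers
#     config["lr"] = 0.001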
class QMixTrainer(SimpleQTrainer):
    @classmethod
    @override(SimpleQTrainer)
    def get_default_config(cls) -> TrainerConfigDict:
        return DEFAULT_CONFIG

    @override(SimpleQTrainer)
    def validate_config(self, config: TrainerConfigDict) -> None:
        # Call super's validation method.
        super().validate_config(config)

        if config["framework"] != "torch":
            raise ValueError(
                "Only `framework=torch` supported so far for QMixTrainer!"
            )

    @override(SimpleQTrainer)
    def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
        return QMixTorchPolicy

    @staticmethod
    @override(SimpleQTrainer)
    def execution_plan(
        workers: WorkerSet, config: TrainerConfigDict, **kwargs
    ) -> LocalIterator[dict]:
        assert (
            len(kwargs) == 0
        ), "QMIX execution_plan does NOT take any additional parameters"

        rollouts = ParallelRollouts(workers, mode="bulk_sync")
        replay_buffer = SimpleReplayBuffer(config["buffer_size"])

        # Store all collected sample batches in the local replay buffer.
        store_op = rollouts.for_each(StoreToReplayBuffer(local_buffer=replay_buffer))

        # Replay from the buffer, concatenate into train batches, train one
        # step, and periodically update the target network.
        train_op = (
            Replay(local_buffer=replay_buffer)
            .combine(
                ConcatBatches(
                    min_batch_size=config["train_batch_size"],
                    count_steps_by=config["multiagent"]["count_steps_by"],
                )
            )
            .for_each(TrainOneStep(workers))
            .for_each(
                UpdateTargetNetwork(workers, config["target_network_update_freq"])
            )
        )

        # Alternate between storing and training, reporting only the results
        # of the training op (output index 1).
        merged_op = Concurrently(
            [store_op, train_op], mode="round_robin", output_indexes=[1]
        )

        return StandardMetricsReporting(merged_op, workers, config)
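

# ---------------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the original file): builds
# and trains a QMixTrainer on a grouped multi-agent environment. The env name
# "my_grouped_env" is hypothetical; QMIX expects the environment's agents to
# be grouped (e.g. via `MultiAgentEnv.with_agent_groups()`) and the env to be
# registered with `ray.tune.registry.register_env()` beforehand.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import ray

    ray.init()

    config = DEFAULT_CONFIG.copy()
    config["num_workers"] = 0  # sample in the trainer process only

    trainer = QMixTrainer(config=config, env="my_grouped_env")
    for i in range(3):
        result = trainer.train()
        print(i, result["episode_reward_mean"])

    trainer.stop()
    ray.shutdown()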