from typing import Type

from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.dqn.simple_q import SimpleQTrainer
from ray.rllib.agents.qmix.qmix_policy import QMixTorchPolicy
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_ops import (
    SimpleReplayBuffer,
    Replay,
    StoreToReplayBuffer,
)
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TrainerConfigDict
from ray.util.iter import LocalIterator

# fmt: off
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
    # === QMix ===
    # Mixing network. Either "qmix", "vdn", or None.
    "mixer": "qmix",
    # Size of the mixing network embedding.
    "mixing_embed_dim": 32,
    # Whether to use double Q-learning.
    "double_q": True,
    # Optimize over complete episodes by default.
    "batch_mode": "complete_episodes",

    # === Exploration Settings ===
    "exploration_config": {
        # The Exploration class to use.
        "type": "EpsilonGreedy",
        # Config for the Exploration class' constructor:
        "initial_epsilon": 1.0,
        "final_epsilon": 0.01,
        # Timesteps over which to anneal epsilon.
        "epsilon_timesteps": 40000,

        # For soft_q, use:
        # "exploration_config" = {
        #   "type": "SoftQ"
        #   "temperature": [float, e.g. 1.0]
        # }
    },

    # === Evaluation ===
    # Evaluate with epsilon=0 every `evaluation_interval` training iterations.
    # The evaluation stats will be reported under the "evaluation" metric key.
    # Note that evaluation is currently not parallelized, and that for Ape-X
    # metrics are already only reported for the lowest epsilon workers.
    "evaluation_interval": None,
    # Number of episodes to run per evaluation period.
    "evaluation_duration": 10,
    # Switch to greedy actions in evaluation workers.
    "evaluation_config": {
        "explore": False,
    },

    # Number of env steps to optimize for before returning.
    "timesteps_per_iteration": 1000,
    # Update the target network every `target_network_update_freq` steps.
    "target_network_update_freq": 500,

    # === Replay buffer ===
    # Size of the replay buffer in batches (not timesteps!).
    "buffer_size": 1000,
    "replay_buffer_config": {
        "no_local_replay_buffer": True,
    },

    # === Optimization ===
    # Learning rate for the RMSProp optimizer.
    "lr": 0.0005,
    # RMSProp alpha.
    "optim_alpha": 0.99,
    # RMSProp epsilon.
    "optim_eps": 0.00001,
    # If not None, clip gradients during optimization at this value.
    "grad_norm_clipping": 10,
    # How many steps of the model to sample before learning starts.
    "learning_starts": 1000,
    # Update the replay buffer with this many samples at once. Note that
    # this setting applies per-worker if num_workers > 1.
    "rollout_fragment_length": 4,
    # Size of a batch sampled from the replay buffer for training. Note that
    # if async_updates is set, then each worker returns gradients for a
    # batch of this size.
    "train_batch_size": 32,

    # === Parallelism ===
    # Number of workers for collecting samples with. This only makes sense
    # to increase if your environment is particularly slow to sample, or if
    # you're using the Async or Ape-X optimizers.
    "num_workers": 0,
    # Whether to compute priorities on workers.
    "worker_side_prioritization": False,
    # Prevent reporting frequency from going lower than this time span.
"min_time_s_per_reporting": 1, # === Model === "model": { "lstm_cell_size": 64, "max_seq_len": 999999, }, # Only torch supported so far. "framework": "torch", # Experimental flag. # If True, the execution plan API will not be used. Instead, # a Trainer's `training_iteration` method will be called as-is each # training iteration. "_disable_execution_plan_api": False, }) # __sphinx_doc_end__ # fmt: on class QMixTrainer(SimpleQTrainer): @classmethod @override(SimpleQTrainer) def get_default_config(cls) -> TrainerConfigDict: return DEFAULT_CONFIG @override(SimpleQTrainer) def validate_config(self, config: TrainerConfigDict) -> None: # Call super's validation method. super().validate_config(config) if config["framework"] != "torch": raise ValueError("Only `framework=torch` supported so far for QMixTrainer!") @override(SimpleQTrainer) def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]: return QMixTorchPolicy @staticmethod @override(SimpleQTrainer) def execution_plan( workers: WorkerSet, config: TrainerConfigDict, **kwargs ) -> LocalIterator[dict]: assert ( len(kwargs) == 0 ), "QMIX execution_plan does NOT take any additional parameters" rollouts = ParallelRollouts(workers, mode="bulk_sync") replay_buffer = SimpleReplayBuffer(config["buffer_size"]) store_op = rollouts.for_each(StoreToReplayBuffer(local_buffer=replay_buffer)) train_op = ( Replay(local_buffer=replay_buffer) .combine( ConcatBatches( min_batch_size=config["train_batch_size"], count_steps_by=config["multiagent"]["count_steps_by"], ) ) .for_each(TrainOneStep(workers)) .for_each( UpdateTargetNetwork(workers, config["target_network_update_freq"]) ) ) merged_op = Concurrently( [store_op, train_op], mode="round_robin", output_indexes=[1] ) return StandardMetricsReporting(merged_op, workers, config)