from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.dqn.dqn import GenericOffPolicyTrainer
from ray.rllib.agents.qmix.qmix_policy import QMixTorchPolicy
from ray.rllib.execution.replay_ops import SimpleReplayBuffer, Replay, \
    StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.concurrency_ops import Concurrently

# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
    # === QMix ===
    # Mixing network. Either "qmix", "vdn", or None.
    "mixer": "qmix",
    # Size of the mixing network embedding.
    "mixing_embed_dim": 32,
    # Whether to use double Q-learning.
    "double_q": True,
    # Optimize over complete episodes by default.
    "batch_mode": "complete_episodes",

    # === Exploration Settings ===
    "exploration_config": {
        # The Exploration class to use.
        "type": "EpsilonGreedy",
        # Config for the Exploration class' constructor:
        "initial_epsilon": 1.0,
        "final_epsilon": 0.02,
        # Timesteps over which to anneal epsilon.
        "epsilon_timesteps": 10000,

        # For soft_q, use:
        # "exploration_config" = {
        #   "type": "SoftQ"
        #   "temperature": [float, e.g. 1.0]
        # }
    },

    # === Evaluation ===
    # Evaluate with epsilon=0 every `evaluation_interval` training iterations.
    # The evaluation stats will be reported under the "evaluation" metric key.
    # Note that evaluation is currently not parallelized, and that for Ape-X
    # metrics are already only reported for the lowest epsilon workers.
    "evaluation_interval": None,
    # Number of episodes to run per evaluation period.
    "evaluation_num_episodes": 10,
    # Switch to greedy actions in evaluation workers.
    "evaluation_config": {
        "explore": False,
    },

    # Number of env steps to optimize for before returning.
    "timesteps_per_iteration": 1000,
    # Update the target network every `target_network_update_freq` steps.
    "target_network_update_freq": 500,

    # === Replay buffer ===
    # Size of the replay buffer in batches (not timesteps!).
    "buffer_size": 1000,

    # === Optimization ===
    # Learning rate for the RMSProp optimizer.
    "lr": 0.0005,
    # RMSProp alpha.
    "optim_alpha": 0.99,
    # RMSProp epsilon.
    "optim_eps": 0.00001,
    # If not None, clip gradients during optimization at this value.
    "grad_norm_clipping": 10,
    # How many steps of the model to sample before learning starts.
    "learning_starts": 1000,
    # Update the replay buffer with this many samples at once. Note that
    # this setting applies per-worker if num_workers > 1.
    "rollout_fragment_length": 4,
    # Size of a batch sampled from the replay buffer for training. Note that
    # if async_updates is set, then each worker returns gradients for a
    # batch of this size.
    "train_batch_size": 32,

    # === Parallelism ===
    # Number of workers used to collect samples. This only makes sense
    # to increase if your environment is particularly slow to sample, or if
    # you're using the Async or Ape-X optimizers.
    "num_workers": 0,
    # Whether to use a distribution of epsilons across workers for exploration.
    "per_worker_exploration": False,
    # Whether to compute priorities on workers.
    "worker_side_prioritization": False,
    # Prevent iterations from going lower than this time span.
    "min_iter_time_s": 1,

    # === Model ===
    "model": {
        "lstm_cell_size": 64,
        "max_seq_len": 999999,
    },
})
# __sphinx_doc_end__
# yapf: enable


def execution_plan(workers, config):
    rollouts = ParallelRollouts(workers, mode="bulk_sync")
    replay_buffer = SimpleReplayBuffer(config["buffer_size"])

    # (1) Store sampled rollouts into the (batch-based) replay buffer.
    store_op = rollouts \
        .for_each(StoreToReplayBuffer(local_buffer=replay_buffer))

    # (2) Replay stored batches, train on them, and periodically update the
    # target network.
    train_op = Replay(local_buffer=replay_buffer) \
        .combine(
            ConcatBatches(
                min_batch_size=config["train_batch_size"],
                count_steps_by=config["multiagent"]["count_steps_by"],
            )) \
        .for_each(TrainOneStep(workers)) \
        .for_each(UpdateTargetNetwork(
            workers, config["target_network_update_freq"]))

    # Alternate between storing and training (only the training op's output
    # is reported as the iteration result).
    merged_op = Concurrently(
        [store_op, train_op], mode="round_robin", output_indexes=[1])

    return StandardMetricsReporting(merged_op, workers, config)


QMixTrainer = GenericOffPolicyTrainer.with_updates(
    name="QMIX",
    default_config=DEFAULT_CONFIG,
    default_policy=QMixTorchPolicy,
    get_policy_class=None,
    execution_plan=execution_plan)