from typing import Type from ray.rllib.agents.trainer import with_common_config from ray.rllib.agents.dqn.simple_q import SimpleQTrainer from ray.rllib.agents.qmix.qmix_policy import QMixTorchPolicy from ray.rllib.execution.rollout_ops import ( synchronous_parallel_sample, ) from ray.rllib.execution.train_ops import ( multi_gpu_train_one_step, train_one_step, ) from ray.rllib.policy.policy import Policy from ray.rllib.utils.annotations import override from ray.rllib.utils.deprecation import DEPRECATED_VALUE from ray.rllib.utils.metrics import ( LAST_TARGET_UPDATE_TS, NUM_AGENT_STEPS_SAMPLED, NUM_ENV_STEPS_SAMPLED, NUM_TARGET_UPDATES, SYNCH_WORKER_WEIGHTS_TIMER, ) from ray.rllib.utils.replay_buffers.utils import sample_min_n_steps_from_buffer from ray.rllib.utils.typing import ResultDict, TrainerConfigDict # fmt: off # __sphinx_doc_begin__ DEFAULT_CONFIG = with_common_config({ # === QMix === # Mixing network. Either "qmix", "vdn", or None "mixer": "qmix", # Size of the mixing network embedding "mixing_embed_dim": 32, # Whether to use Double_Q learning "double_q": True, # Optimize over complete episodes by default. "batch_mode": "complete_episodes", # === Exploration Settings === "exploration_config": { # The Exploration class to use. "type": "EpsilonGreedy", # Config for the Exploration class' constructor: "initial_epsilon": 1.0, "final_epsilon": 0.01, # Timesteps over which to anneal epsilon. "epsilon_timesteps": 40000, # For soft_q, use: # "exploration_config" = { # "type": "SoftQ" # "temperature": [float, e.g. 1.0] # } }, # === Evaluation === # Evaluate with epsilon=0 every `evaluation_interval` training iterations. # The evaluation stats will be reported under the "evaluation" metric key. # Note that evaluation is currently not parallelized, and that for Ape-X # metrics are already only reported for the lowest epsilon workers. "evaluation_interval": None, # Number of episodes to run per evaluation period. "evaluation_duration": 10, # Switch to greedy actions in evaluation workers. "evaluation_config": { "explore": False, }, # Minimum env sampling timesteps to accumulate within a single `train()` call. This # value does not affect learning, only the number of times `Trainer.step_attempt()` # is called by `Trauber.train()`. If - after one `step_attempt()`, the env sampling # timestep count has not been reached, will perform n more `step_attempt()` calls # until the minimum timesteps have been executed. Set to 0 for no minimum timesteps. "min_sample_timesteps_per_reporting": 1000, # Update the target network every `target_network_update_freq` steps. "target_network_update_freq": 500, # === Replay buffer === "replay_buffer_config": { # Use the new ReplayBuffer API here "_enable_replay_buffer_api": True, "type": "SimpleReplayBuffer", # Size of the replay buffer in batches (not timesteps!). "capacity": 1000, "learning_starts": 1000, }, # === Optimization === # Learning rate for RMSProp optimizer "lr": 0.0005, # RMSProp alpha "optim_alpha": 0.99, # RMSProp epsilon "optim_eps": 0.00001, # If not None, clip gradients during optimization at this value "grad_norm_clipping": 10, # Update the replay buffer with this many samples at once. Note that # this setting applies per-worker if num_workers > 1. "rollout_fragment_length": 4, # Minimum batch size used for training (in timesteps). With the default buffer # (ReplayBuffer) this means, sampling from the buffer (entire-episode SampleBatches) # as many times as is required to reach at least this number of timesteps. "train_batch_size": 32, # === Parallelism === # Number of workers for collecting samples with. This only makes sense # to increase if your environment is particularly slow to sample, or if # you"re using the Async or Ape-X optimizers. "num_workers": 0, # Whether to compute priorities on workers. "worker_side_prioritization": False, # Prevent reporting frequency from going lower than this time span. "min_time_s_per_reporting": 1, # === Model === "model": { "lstm_cell_size": 64, "max_seq_len": 999999, }, # Only torch supported so far. "framework": "torch", # Deprecated keys: # Use `replay_buffer_config.learning_starts` instead. "learning_starts": DEPRECATED_VALUE, # Use `replay_buffer_config.capacity` instead. "buffer_size": DEPRECATED_VALUE, }) # __sphinx_doc_end__ # fmt: on class QMixTrainer(SimpleQTrainer): @classmethod @override(SimpleQTrainer) def get_default_config(cls) -> TrainerConfigDict: return DEFAULT_CONFIG @override(SimpleQTrainer) def validate_config(self, config: TrainerConfigDict) -> None: # Call super's validation method. super().validate_config(config) if config["framework"] != "torch": raise ValueError("Only `framework=torch` supported so far for QMixTrainer!") @override(SimpleQTrainer) def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]: return QMixTorchPolicy @override(SimpleQTrainer) def training_iteration(self) -> ResultDict: """QMIX training iteration function. - Sample n MultiAgentBatches from n workers synchronously. - Store new samples in the replay buffer. - Sample one training MultiAgentBatch from the replay buffer. - Learn on the training batch. - Update the target network every `target_network_update_freq` steps. - Return all collected training metrics for the iteration. Returns: The results dict from executing the training iteration. """ # Sample n batches from n workers. new_sample_batches = synchronous_parallel_sample( worker_set=self.workers, concat=False ) for batch in new_sample_batches: # Update counters. self._counters[NUM_ENV_STEPS_SAMPLED] += batch.env_steps() self._counters[NUM_AGENT_STEPS_SAMPLED] += batch.agent_steps() # Store new samples in the replay buffer. self.local_replay_buffer.add(batch) # Sample n batches from replay buffer until the total number of timesteps # reaches `train_batch_size`. train_batch = sample_min_n_steps_from_buffer( replay_buffer=self.local_replay_buffer, min_steps=self.config["train_batch_size"], count_by_agent_steps=self._by_agent_steps, ) if train_batch is None: return {} # Learn on the training batch. # Use simple optimizer (only for multi-agent or tf-eager; all other # cases should use the multi-GPU optimizer, even if only using 1 GPU) if self.config.get("simple_optimizer") is True: train_results = train_one_step(self, train_batch) else: train_results = multi_gpu_train_one_step(self, train_batch) # TODO: Move training steps counter update outside of `train_one_step()` method. # # Update train step counters. # self._counters[NUM_ENV_STEPS_TRAINED] += train_batch.env_steps() # self._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps() # Update target network every `target_network_update_freq` steps. cur_ts = self._counters[NUM_ENV_STEPS_SAMPLED] last_update = self._counters[LAST_TARGET_UPDATE_TS] if cur_ts - last_update >= self.config["target_network_update_freq"]: to_update = self.workers.local_worker().get_policies_to_train() self.workers.local_worker().foreach_policy_to_train( lambda p, pid: pid in to_update and p.update_target() ) self._counters[NUM_TARGET_UPDATES] += 1 self._counters[LAST_TARGET_UPDATE_TS] = cur_ts # Update weights and global_vars - after learning on the local worker - on all # remote workers. global_vars = { "timestep": self._counters[NUM_ENV_STEPS_SAMPLED], } # Update remote workers' weights and global vars after learning on local worker. with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]: self.workers.sync_weights(global_vars=global_vars) # Return all collected metrics for the iteration. return train_results