from typing import Optional, Type

from ray.rllib.algorithms.qmix.qmix_policy import QMixTorchPolicy
from ray.rllib.algorithms.simple_q.simple_q import SimpleQ, SimpleQConfig
from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
from ray.rllib.execution.train_ops import (
    multi_gpu_train_one_step,
    train_one_step,
)
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import (
    DEPRECATED_VALUE,
    Deprecated,
    deprecation_warning,
)
from ray.rllib.utils.metrics import (
    LAST_TARGET_UPDATE_TS,
    NUM_AGENT_STEPS_SAMPLED,
    NUM_ENV_STEPS_SAMPLED,
    NUM_TARGET_UPDATES,
    SYNCH_WORKER_WEIGHTS_TIMER,
)
from ray.rllib.utils.replay_buffers.utils import (
    sample_min_n_steps_from_buffer,
    update_priorities_in_replay_buffer,
)
from ray.rllib.utils.typing import AlgorithmConfigDict, ResultDict


class QMixConfig(SimpleQConfig):
    """Defines a configuration class from which QMix can be built.

    Example:
        >>> from ray.rllib.examples.env.two_step_game import TwoStepGame
        >>> from ray.rllib.algorithms.qmix import QMixConfig
        >>> config = QMixConfig().training(gamma=0.9, lr=0.01)\
        ...     .resources(num_gpus=0)\
        ...     .rollouts(num_workers=4)
        >>> print(config.to_dict())
        >>> # Build an Algorithm object from the config and run 1 training iteration.
        >>> algo = config.build(env=TwoStepGame)
        >>> algo.train()

    Example:
        >>> from ray.rllib.examples.env.two_step_game import TwoStepGame
        >>> from ray.rllib.algorithms.qmix import QMixConfig
        >>> from ray import tune
        >>> config = QMixConfig()
        >>> # Print out some default values.
        >>> print(config.optim_alpha)
        >>> # Update the config object.
        >>> config.training(lr=tune.grid_search([0.001, 0.0001]), optim_alpha=0.97)
        >>> # Set the config object's env.
        >>> config.environment(env=TwoStepGame)
        >>> # Use to_dict() to get the old-style python config dict
        >>> # when running with tune.
        >>> tune.run(
        ...     "QMix",
        ...     stop={"episode_reward_mean": 200},
        ...     config=config.to_dict(),
        ... )
    """

    def __init__(self):
        """Initializes a QMixConfig instance."""
        super().__init__(algo_class=QMix)

        # fmt: off
        # __sphinx_doc_begin__
        # QMix-specific settings:
        self.mixer = "qmix"
        self.mixing_embed_dim = 32
        self.double_q = True
        self.optim_alpha = 0.99
        self.optim_eps = 0.00001
        self.grad_clip = 10

        # Override some of AlgorithmConfig's default values with QMix-specific values.
        # .training()
        self.lr = 0.0005
        self.train_batch_size = 32
        self.target_network_update_freq = 500
        self.replay_buffer_config = {
            "type": "ReplayBuffer",
            # Specify prioritized replay by supplying a buffer type that supports
            # prioritization, for example: MultiAgentPrioritizedReplayBuffer.
            "prioritized_replay": DEPRECATED_VALUE,
            # Size of the replay buffer in batches.
            "capacity": 1000,
            # Choosing `fragments` here makes it so that the buffer stores entire
            # batches, instead of sequences, episodes or timesteps.
            "storage_unit": "fragments",
            "learning_starts": 1000,
            # Whether to compute priorities on workers.
            "worker_side_prioritization": False,
        }
        self.model = {
            "lstm_cell_size": 64,
            "max_seq_len": 999999,
        }

        # .framework()
        self.framework_str = "torch"

        # .rollouts()
        self.num_workers = 0
        self.rollout_fragment_length = 4
        self.batch_mode = "complete_episodes"

        # .reporting()
        self.min_time_s_per_iteration = 1
        self.min_sample_timesteps_per_iteration = 1000

        # .exploration()
        self.exploration_config = {
            # The Exploration class to use.
"type": "EpsilonGreedy", # Config for the Exploration class' constructor: "initial_epsilon": 1.0, "final_epsilon": 0.01, # Timesteps over which to anneal epsilon. "epsilon_timesteps": 40000, # For soft_q, use: # "exploration_config" = { # "type": "SoftQ" # "temperature": [float, e.g. 1.0] # } } # .evaluation() # Evaluate with epsilon=0 every `evaluation_interval` training iterations. # The evaluation stats will be reported under the "evaluation" metric key. # Note that evaluation is currently not parallelized, and that for Ape-X # metrics are already only reported for the lowest epsilon workers. self.evaluation_interval = None self.evaluation_duration = 10 self.evaluation_config = { "explore": False, } # __sphinx_doc_end__ # fmt: on self.worker_side_prioritization = DEPRECATED_VALUE @override(SimpleQConfig) def training( self, *, mixer: Optional[str] = None, mixing_embed_dim: Optional[int] = None, double_q: Optional[bool] = None, target_network_update_freq: Optional[int] = None, replay_buffer_config: Optional[dict] = None, optim_alpha: Optional[float] = None, optim_eps: Optional[float] = None, grad_norm_clipping: Optional[float] = None, grad_clip: Optional[float] = None, **kwargs, ) -> "QMixConfig": """Sets the training related configuration. Args: mixer: Mixing network. Either "qmix", "vdn", or None. mixing_embed_dim: Size of the mixing network embedding. double_q: Whether to use Double_Q learning. target_network_update_freq: Update the target network every `target_network_update_freq` sample steps. replay_buffer_config: optim_alpha: RMSProp alpha. optim_eps: RMSProp epsilon. grad_clip: If not None, clip gradients during optimization at this value. grad_norm_clipping: Depcrecated in favor of grad_clip Returns: This updated AlgorithmConfig object. """ # Pass kwargs onto super's `training()` method. super().training(**kwargs) if grad_norm_clipping is not None: deprecation_warning( old="grad_norm_clipping", new="grad_clip", help="Parameter `grad_norm_clipping` has been " "deprecated in favor of grad_clip in QMix. " "This is now the same parameter as in other " "algorithms. `grad_clip` will be overwritten by " "`grad_norm_clipping={}`".format(grad_norm_clipping), error=False, ) grad_clip = grad_norm_clipping if mixer is not None: self.mixer = mixer if mixing_embed_dim is not None: self.mixing_embed_dim = mixing_embed_dim if double_q is not None: self.double_q = double_q if target_network_update_freq is not None: self.target_network_update_freq = target_network_update_freq if replay_buffer_config is not None: self.replay_buffer_config = replay_buffer_config if optim_alpha is not None: self.optim_alpha = optim_alpha if optim_eps is not None: self.optim_eps = optim_eps if grad_clip is not None: self.grad_clip = grad_clip return self class QMix(SimpleQ): @classmethod @override(SimpleQ) def get_default_config(cls) -> AlgorithmConfigDict: return QMixConfig().to_dict() @override(SimpleQ) def validate_config(self, config: AlgorithmConfigDict) -> None: # Call super's validation method. super().validate_config(config) if config["framework"] != "torch": raise ValueError("Only `framework=torch` supported so far for QMix!") @override(SimpleQ) def get_default_policy_class(self, config: AlgorithmConfigDict) -> Type[Policy]: return QMixTorchPolicy @override(SimpleQ) def training_step(self) -> ResultDict: """QMIX training iteration function. - Sample n MultiAgentBatches from n workers synchronously. - Store new samples in the replay buffer. - Sample one training MultiAgentBatch from the replay buffer. 
        - Learn on the training batch.
        - Update the target network every `target_network_update_freq` sample steps.
        - Return all collected training metrics for the iteration.

        Returns:
            The results dict from executing the training iteration.
        """
        # Sample n batches from n workers.
        new_sample_batches = synchronous_parallel_sample(
            worker_set=self.workers, concat=False
        )

        for batch in new_sample_batches:
            # Update sampling step counters.
            self._counters[NUM_ENV_STEPS_SAMPLED] += batch.env_steps()
            self._counters[NUM_AGENT_STEPS_SAMPLED] += batch.agent_steps()
            # Store new samples in the replay buffer.
            self.local_replay_buffer.add(batch)

        # Sample n batches from the replay buffer until the total number of timesteps
        # reaches `train_batch_size`.
        train_batch = sample_min_n_steps_from_buffer(
            replay_buffer=self.local_replay_buffer,
            min_steps=self.config["train_batch_size"],
            count_by_agent_steps=self._by_agent_steps,
        )
        if train_batch is None:
            return {}

        # Learn on the training batch.
        # Use simple optimizer (only for multi-agent or tf-eager; all other
        # cases should use the multi-GPU optimizer, even if only using 1 GPU).
        if self.config.get("simple_optimizer") is True:
            train_results = train_one_step(self, train_batch)
        else:
            train_results = multi_gpu_train_one_step(self, train_batch)

        # TODO: Move training steps counter update outside of `train_one_step()`
        #  method.
        # # Update train step counters.
        # self._counters[NUM_ENV_STEPS_TRAINED] += train_batch.env_steps()
        # self._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps()

        # Update target network every `target_network_update_freq` sample steps.
        cur_ts = self._counters[
            NUM_AGENT_STEPS_SAMPLED if self._by_agent_steps else NUM_ENV_STEPS_SAMPLED
        ]
        last_update = self._counters[LAST_TARGET_UPDATE_TS]
        if cur_ts - last_update >= self.config["target_network_update_freq"]:
            to_update = self.workers.local_worker().get_policies_to_train()
            self.workers.local_worker().foreach_policy_to_train(
                lambda p, pid: pid in to_update and p.update_target()
            )
            self._counters[NUM_TARGET_UPDATES] += 1
            self._counters[LAST_TARGET_UPDATE_TS] = cur_ts

        update_priorities_in_replay_buffer(
            self.local_replay_buffer, self.config, train_batch, train_results
        )

        # Update weights and global_vars - after learning on the local worker - on all
        # remote workers.
        global_vars = {
            "timestep": self._counters[NUM_ENV_STEPS_SAMPLED],
        }
        with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
            self.workers.sync_weights(global_vars=global_vars)

        # Return all collected metrics for the iteration.
        return train_results


# Deprecated: Use ray.rllib.algorithms.qmix.qmix.QMixConfig instead!
class _deprecated_default_config(dict):
    def __init__(self):
        super().__init__(QMixConfig().to_dict())

    @Deprecated(
        old="ray.rllib.algorithms.qmix.qmix.DEFAULT_CONFIG",
        new="ray.rllib.algorithms.qmix.qmix.QMixConfig(...)",
        error=False,
    )
    def __getitem__(self, item):
        return super().__getitem__(item)


DEFAULT_CONFIG = _deprecated_default_config()
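

# Usage sketch (illustrative only, not part of the module's API): it mirrors the
# doctest examples in `QMixConfig`'s docstring above and additionally exercises the
# QMix-specific `training()` parameters (`mixer`, `mixing_embed_dim`, `double_q`)
# defined in this file. The `TwoStepGame` import path and building directly with
# `env=TwoStepGame` are taken from those docstring examples; whether they resolve
# depends on the installed RLlib version.
if __name__ == "__main__":
    from ray.rllib.examples.env.two_step_game import TwoStepGame

    config = (
        QMixConfig()
        # QMix-specific training settings (see `QMixConfig.training()` above).
        .training(mixer="qmix", mixing_embed_dim=32, double_q=True, lr=0.0005)
        # Run CPU-only, as in the docstring example.
        .resources(num_gpus=0)
    )
    # Build the Algorithm and run a single training iteration.
    algo = config.build(env=TwoStepGame)
    print(algo.train())
    algo.stop()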