""" Deep Q-Networks (DQN, Rainbow, Parametric DQN) ============================================== This file defines the distributed Trainer class for the Deep Q-Networks algorithm. See `dqn_[tf|torch]_policy.py` for the definition of the policies. Detailed documentation: https://docs.ray.io/en/master/rllib-algorithms.html#deep-q-networks-dqn-rainbow-parametric-dqn """ # noqa: E501 import logging from typing import List, Optional, Type from ray.rllib.agents.dqn.dqn_tf_policy import DQNTFPolicy from ray.rllib.agents.dqn.dqn_torch_policy import DQNTorchPolicy from ray.rllib.agents.dqn.simple_q import ( SimpleQTrainer, DEFAULT_CONFIG as SIMPLEQ_DEFAULT_CONFIG, ) from ray.rllib.agents.trainer import Trainer from ray.rllib.evaluation.worker_set import WorkerSet from ray.rllib.execution.concurrency_ops import Concurrently from ray.rllib.execution.metric_ops import StandardMetricsReporting from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer from ray.rllib.execution.rollout_ops import ( ParallelRollouts, synchronous_parallel_sample, ) from ray.rllib.execution.train_ops import ( TrainOneStep, UpdateTargetNetwork, MultiGPUTrainOneStep, train_one_step, multi_gpu_train_one_step, ) from ray.rllib.policy.policy import Policy from ray.rllib.utils.annotations import override from ray.rllib.utils.replay_buffers.utils import update_priorities_in_replay_buffer from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY from ray.rllib.utils.typing import ( ResultDict, TrainerConfigDict, ) from ray.rllib.utils.metrics import ( NUM_ENV_STEPS_SAMPLED, NUM_AGENT_STEPS_SAMPLED, ) from ray.util.iter import LocalIterator from ray.rllib.utils.replay_buffers import MultiAgentPrioritizedReplayBuffer from ray.rllib.execution.buffers.multi_agent_replay_buffer import ( MultiAgentReplayBuffer as LegacyMultiAgentReplayBuffer, ) from ray.rllib.utils.deprecation import ( Deprecated, DEPRECATED_VALUE, ) from ray.rllib.utils.annotations import ExperimentalAPI from ray.rllib.utils.metrics import SYNCH_WORKER_WEIGHTS_TIMER from ray.rllib.execution.common import ( LAST_TARGET_UPDATE_TS, NUM_TARGET_UPDATES, ) logger = logging.getLogger(__name__) # fmt: off # __sphinx_doc_begin__ DEFAULT_CONFIG = Trainer.merge_trainer_configs( SIMPLEQ_DEFAULT_CONFIG, { # === Model === # Number of atoms for representing the distribution of return. When # this is greater than 1, distributional Q-learning is used. # the discrete supports are bounded by v_min and v_max "num_atoms": 1, "v_min": -10.0, "v_max": 10.0, # Whether to use noisy network "noisy": False, # control the initial value of noisy nets "sigma0": 0.5, # Whether to use dueling dqn "dueling": True, # Dense-layer setup for each the advantage branch and the value branch # in a dueling architecture. "hiddens": [256], # Whether to use double dqn "double_q": True, # N-step Q learning "n_step": 1, # === Replay buffer === # Deprecated, use capacity in replay_buffer_config instead. "buffer_size": DEPRECATED_VALUE, "replay_buffer_config": { # For now we don't use the new ReplayBuffer API here "_enable_replay_buffer_api": True, "type": "MultiAgentPrioritizedReplayBuffer", # Size of the replay buffer. Note that if async_updates is set, # then each worker will have a replay buffer of this size. "capacity": 50000, "prioritized_replay_alpha": 0.6, # Beta parameter for sampling from prioritized replay buffer. "prioritized_replay_beta": 0.4, # Epsilon to add to the TD errors when updating priorities. "prioritized_replay_eps": 1e-6, # The number of continuous environment steps to replay at once. This may # be set to greater than 1 to support recurrent models. "replay_sequence_length": 1, }, # Set this to True, if you want the contents of your buffer(s) to be # stored in any saved checkpoints as well. # Warnings will be created if: # - This is True AND restoring from a checkpoint that contains no buffer # data. # - This is False AND restoring from a checkpoint that does contain # buffer data. "store_buffer_in_checkpoints": False, # Callback to run before learning on a multi-agent batch of # experiences. "before_learn_on_batch": None, # The intensity with which to update the model (vs collecting samples # from the env). If None, uses the "natural" value of: # `train_batch_size` / (`rollout_fragment_length` x `num_workers` x # `num_envs_per_worker`). # If provided, will make sure that the ratio between ts inserted into # and sampled from the buffer matches the given value. # Example: # training_intensity=1000.0 # train_batch_size=250 rollout_fragment_length=1 # num_workers=1 (or 0) num_envs_per_worker=1 # -> natural value = 250 / 1 = 250.0 # -> will make sure that replay+train op will be executed 4x as # often as rollout+insert op (4 * 250 = 1000). # See: rllib/agents/dqn/dqn.py::calculate_rr_weights for further # details. "training_intensity": None, # === Parallelism === # Whether to compute priorities on workers. "worker_side_prioritization": False, # Experimental flag. # If True, the execution plan API will not be used. Instead, # a Trainer's `training_iteration` method will be called as-is each # training iteration. "_disable_execution_plan_api": True, }, _allow_unknown_configs=True, ) # __sphinx_doc_end__ # fmt: on def calculate_rr_weights(config: TrainerConfigDict) -> List[float]: """Calculate the round robin weights for the rollout and train steps""" if not config["training_intensity"]: return [1, 1] # Calculate the "native ratio" as: # [train-batch-size] / [size of env-rolled-out sampled data] # This is to set freshly rollout-collected data in relation to # the data we pull from the replay buffer (which also contains old # samples). native_ratio = config["train_batch_size"] / ( config["rollout_fragment_length"] * config["num_envs_per_worker"] * config["num_workers"] ) # Training intensity is specified in terms of # (steps_replayed / steps_sampled), so adjust for the native ratio. weights = [1, config["training_intensity"] / native_ratio] return weights class DQNTrainer(SimpleQTrainer): @classmethod @override(SimpleQTrainer) def get_default_config(cls) -> TrainerConfigDict: return DEFAULT_CONFIG @override(SimpleQTrainer) def validate_config(self, config: TrainerConfigDict) -> None: # Call super's validation method. super().validate_config(config) # Update effective batch size to include n-step adjusted_rollout_len = max(config["rollout_fragment_length"], config["n_step"]) config["rollout_fragment_length"] = adjusted_rollout_len @override(SimpleQTrainer) def get_default_policy_class( self, config: TrainerConfigDict ) -> Optional[Type[Policy]]: if config["framework"] == "torch": return DQNTorchPolicy else: return DQNTFPolicy @staticmethod @override(SimpleQTrainer) def execution_plan( workers: WorkerSet, config: TrainerConfigDict, **kwargs ) -> LocalIterator[dict]: assert ( "local_replay_buffer" in kwargs ), "DQN's execution plan requires a local replay buffer." # Assign to Trainer, so we can store the MultiAgentReplayBuffer's # data when we save checkpoints. local_replay_buffer = kwargs["local_replay_buffer"] rollouts = ParallelRollouts(workers, mode="bulk_sync") # We execute the following steps concurrently: # (1) Generate rollouts and store them in our local replay buffer. # Calling next() on store_op drives this. store_op = rollouts.for_each( StoreToReplayBuffer(local_buffer=local_replay_buffer) ) def update_prio(item): samples, info_dict = item prio_dict = {} for policy_id, info in info_dict.items(): # TODO(sven): This is currently structured differently for # torch/tf. Clean up these results/info dicts across # policies (note: fixing this in torch_policy.py will # break e.g. DDPPO!). td_error = info.get("td_error", info[LEARNER_STATS_KEY].get("td_error")) samples.policy_batches[policy_id].set_get_interceptor(None) batch_indices = samples.policy_batches[policy_id].get("batch_indexes") # In case the buffer stores sequences, TD-error could # already be calculated per sequence chunk. if len(batch_indices) != len(td_error): T = local_replay_buffer.replay_sequence_length assert ( len(batch_indices) > len(td_error) and len(batch_indices) % T == 0 ) batch_indices = batch_indices.reshape([-1, T])[:, 0] assert len(batch_indices) == len(td_error) prio_dict[policy_id] = (batch_indices, td_error) local_replay_buffer.update_priorities(prio_dict) return info_dict # (2) Read and train on experiences from the replay buffer. Every batch # returned from the LocalReplay() iterator is passed to TrainOneStep to # take a SGD step, and then we decide whether to update the target # network. post_fn = config.get("before_learn_on_batch") or (lambda b, *a: b) if config["simple_optimizer"]: train_step_op = TrainOneStep(workers) else: train_step_op = MultiGPUTrainOneStep( workers=workers, sgd_minibatch_size=config["train_batch_size"], num_sgd_iter=1, num_gpus=config["num_gpus"], _fake_gpus=config["_fake_gpus"], ) if ( type(local_replay_buffer) is LegacyMultiAgentReplayBuffer and config["replay_buffer_config"].get("prioritized_replay_alpha", 0.0) > 0.0 ) or isinstance(local_replay_buffer, MultiAgentPrioritizedReplayBuffer): update_prio_fn = update_prio else: def update_prio_fn(x): return x replay_op = ( Replay(local_buffer=local_replay_buffer) .for_each(lambda x: post_fn(x, workers, config)) .for_each(train_step_op) .for_each(update_prio_fn) .for_each( UpdateTargetNetwork(workers, config["target_network_update_freq"]) ) ) # Alternate deterministically between (1) and (2). # Only return the output of (2) since training metrics are not # available until (2) runs. train_op = Concurrently( [store_op, replay_op], mode="round_robin", output_indexes=[1], round_robin_weights=calculate_rr_weights(config), ) return StandardMetricsReporting(train_op, workers, config) @ExperimentalAPI def training_iteration(self) -> ResultDict: """DQN training iteration function. Each training iteration, we: - Sample (MultiAgentBatch) from workers. - Store new samples in replay buffer. - Sample training batch (MultiAgentBatch) from replay buffer. - Learn on training batch. - Update remote workers' new policy weights. - Update target network every target_network_update_freq steps. - Return all collected metrics for the iteration. Returns: The results dict from executing the training iteration. """ local_worker = self.workers.local_worker() train_results = {} # We alternate between storing new samples and sampling and training store_weight, sample_and_train_weight = calculate_rr_weights(self.config) for _ in range(store_weight): # (1) Sample (MultiAgentBatch) from workers new_sample_batch = synchronous_parallel_sample( worker_set=self.workers, concat=True ) # Update counters self._counters[NUM_AGENT_STEPS_SAMPLED] += new_sample_batch.agent_steps() self._counters[NUM_ENV_STEPS_SAMPLED] += new_sample_batch.env_steps() # (2) Store new samples in replay buffer self.local_replay_buffer.add_batch(new_sample_batch) for _ in range(sample_and_train_weight): # (3) Sample training batch (MultiAgentBatch) from replay buffer. train_batch = self.local_replay_buffer.replay() # Old-style replay buffers return None if learning has not started if not train_batch: continue # Postprocess batch before we learn on it post_fn = self.config.get("before_learn_on_batch") or (lambda b, *a: b) train_batch = post_fn(train_batch, self.workers, self.config) # (4) Learn on training batch. # Use simple optimizer (only for multi-agent or tf-eager; all other # cases should use the multi-GPU optimizer, even if only using 1 GPU) if self.config.get("simple_optimizer") is True: train_results = train_one_step(self, train_batch) else: train_results = multi_gpu_train_one_step(self, train_batch) # Update replay buffer priorities. update_priorities_in_replay_buffer( self.local_replay_buffer, self.config, train_batch, train_results, ) # (6) Update target network every target_network_update_freq steps cur_ts = self._counters[NUM_ENV_STEPS_SAMPLED] last_update = self._counters[LAST_TARGET_UPDATE_TS] if cur_ts - last_update >= self.config["target_network_update_freq"]: to_update = local_worker.get_policies_to_train() local_worker.foreach_policy_to_train( lambda p, pid: pid in to_update and p.update_target() ) self._counters[NUM_TARGET_UPDATES] += 1 self._counters[LAST_TARGET_UPDATE_TS] = cur_ts # Update remote workers's weights after learning on local worker if self.workers.remote_workers(): with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]: self.workers.sync_weights() # (7) Return all collected metrics for the iteration. return train_results @Deprecated( new="Sub-class directly from `DQNTrainer` and override its methods", error=False ) class GenericOffPolicyTrainer(DQNTrainer): pass