From 4ccfd07a614c3b0620e75ccb47c327bb0f817299 Mon Sep 17 00:00:00 2001 From: desktable Date: Tue, 15 Sep 2020 03:37:07 -0700 Subject: [PATCH] [RLlib] Add docstrings for agents/dqn (#10710) --- rllib/agents/dqn/apex.py | 44 +++++++++---- rllib/agents/dqn/distributional_q_tf_model.py | 44 +++++++------ rllib/agents/dqn/dqn.py | 58 ++++++++++++------ rllib/agents/dqn/dqn_tf_policy.py | 61 ++++++++++++++----- rllib/agents/dqn/dqn_torch_model.py | 35 ++++++----- rllib/agents/dqn/dqn_torch_policy.py | 49 ++++++++++++--- rllib/agents/dqn/learner_thread.py | 2 +- rllib/agents/dqn/simple_q.py | 21 +++++-- rllib/agents/dqn/simple_q_tf_policy.py | 59 +++++++++++++++--- rllib/agents/dqn/simple_q_torch_policy.py | 50 +++++++++++---- rllib/agents/ppo/appo_tf_policy.py | 2 +- 11 files changed, 309 insertions(+), 116 deletions(-) diff --git a/rllib/agents/dqn/apex.py b/rllib/agents/dqn/apex.py index 2ea8ebf91..05577b5de 100644 --- a/rllib/agents/dqn/apex.py +++ b/rllib/agents/dqn/apex.py @@ -1,22 +1,38 @@ +""" +Distributed Prioritized Experience Replay (Ape-X) +================================================= + +This file defines a DQN trainer using the Ape-X architecture. + +Ape-X uses a single GPU learner and many CPU workers for experience collection. +Experience collection can scale to hundreds of CPU workers due to the +distributed prioritization of experience prior to storage in replay buffers. + +Detailed documentation: +https://docs.ray.io/en/latest/rllib-algorithms.html#distributed-prioritized-experience-replay-ape-x +""" # noqa: E501 + import collections import copy +from typing import Tuple import ray -from ray.rllib.agents.dqn.dqn import DQNTrainer, \ - DEFAULT_CONFIG as DQN_CONFIG, calculate_rr_weights +from ray.rllib.agents.dqn.dqn import DEFAULT_CONFIG as DQN_CONFIG +from ray.rllib.agents.dqn.dqn import DQNTrainer, calculate_rr_weights from ray.rllib.agents.dqn.learner_thread import LearnerThread -from ray.rllib.execution.common import STEPS_TRAINED_COUNTER, \ - _get_shared_metrics, _get_global_vars from ray.rllib.evaluation.worker_set import WorkerSet -from ray.rllib.execution.rollout_ops import ParallelRollouts -from ray.rllib.execution.concurrency_ops import Concurrently, Enqueue, Dequeue -from ray.rllib.execution.replay_ops import StoreToReplayBuffer, Replay -from ray.rllib.execution.train_ops import UpdateTargetNetwork +from ray.rllib.execution.common import (STEPS_TRAINED_COUNTER, + _get_global_vars, _get_shared_metrics) +from ray.rllib.execution.concurrency_ops import Concurrently, Dequeue, Enqueue from ray.rllib.execution.metric_ops import StandardMetricsReporting from ray.rllib.execution.replay_buffer import ReplayActor +from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer +from ray.rllib.execution.rollout_ops import ParallelRollouts +from ray.rllib.execution.train_ops import UpdateTargetNetwork from ray.rllib.utils import merge_dicts from ray.rllib.utils.actors import create_colocated from ray.rllib.utils.typing import SampleBatchType +from ray.util.iter import LocalIterator # yapf: disable # __sphinx_doc_begin__ @@ -53,14 +69,15 @@ APEX_DEFAULT_CONFIG = merge_dicts( # Update worker weights as they finish generating experiences. 
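# --- Illustrative usage sketch (assumptions noted in the comments) ---------
# A minimal way to try the Ape-X trainer this file defines. It assumes the
# `ApexTrainer` export from `ray.rllib.agents.dqn` in this code base; the
# environment name and worker/GPU counts are placeholder values, not tuned.
import ray
from ray.rllib.agents.dqn import ApexTrainer

ray.init()
apex_trainer = ApexTrainer(
    env="CartPole-v0",
    config={
        "num_workers": 3,                   # Ape-X normally uses many more CPU workers.
        "num_gpus": 0,                      # The single learner can optionally use a GPU.
        "learning_starts": 1000,            # Fill the replay buffers before learning.
        "target_network_update_freq": 500,  # Hard-sync the target network periodically.
    })
for _ in range(3):
    print(apex_trainer.train()["episode_reward_mean"])
# ---------------------------------------------------------------------------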
class UpdateWorkerWeights: - def __init__(self, learner_thread, workers, max_weight_sync_delay): + def __init__(self, learner_thread: LearnerThread, workers: WorkerSet, + max_weight_sync_delay: int): self.learner_thread = learner_thread self.workers = workers self.steps_since_update = collections.defaultdict(int) self.max_weight_sync_delay = max_weight_sync_delay self.weights = None - def __call__(self, item: ("ActorHandle", SampleBatchType)): + def __call__(self, item: Tuple["ActorHandle", SampleBatchType]): actor, batch = item self.steps_since_update[actor] += batch.count if self.steps_since_update[actor] >= self.max_weight_sync_delay: @@ -77,7 +94,8 @@ class UpdateWorkerWeights: metrics.counters["num_weight_syncs"] += 1 -def apex_execution_plan(workers: WorkerSet, config: dict): +def apex_execution_plan(workers: WorkerSet, + config: dict) -> LocalIterator[dict]: # Create a number of replay buffer actors. num_replay_buffer_shards = config["optimizer"]["num_replay_buffer_shards"] replay_actors = create_colocated(ReplayActor, [ @@ -97,7 +115,7 @@ def apex_execution_plan(workers: WorkerSet, config: dict): learner_thread.start() # Update experience priorities post learning. - def update_prio_and_stats(item: ("ActorHandle", dict, int)): + def update_prio_and_stats(item: Tuple["ActorHandle", dict, int]) -> None: actor, prio_dict, count = item actor.update_priorities.remote(prio_dict) metrics = _get_shared_metrics() @@ -155,7 +173,7 @@ def apex_execution_plan(workers: WorkerSet, config: dict): [store_op, replay_op, update_op], mode="async", output_indexes=[2]) # Add in extra replay and learner metrics to the training result. - def add_apex_metrics(result): + def add_apex_metrics(result: dict) -> dict: replay_stats = ray.get(replay_actors[0].stats.remote( config["optimizer"].get("debug"))) exploration_infos = workers.foreach_trainable_policy( diff --git a/rllib/agents/dqn/distributional_q_tf_model.py b/rllib/agents/dqn/distributional_q_tf_model.py index 169fa1ed6..8bd8b99ef 100644 --- a/rllib/agents/dqn/distributional_q_tf_model.py +++ b/rllib/agents/dqn/distributional_q_tf_model.py @@ -1,6 +1,12 @@ +"""Tensorflow model for DQN""" + +from typing import List + +import gym from ray.rllib.models.tf.layers import NoisyLayer from ray.rllib.models.tf.tf_modelv2 import TFModelV2 from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.typing import ModelConfigDict, TensorType tf1, tf, tfv = try_import_tf() @@ -20,23 +26,23 @@ class DistributionalQTFModel(TFModelV2): def __init__( self, - obs_space, - action_space, - num_outputs, - model_config, - name, + obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, + num_outputs: int, + model_config: ModelConfigDict, + name: str, q_hiddens=(256, ), - dueling=False, - num_atoms=1, - use_noisy=False, - v_min=-10.0, - v_max=10.0, - sigma0=0.5, + dueling: bool = False, + num_atoms: int = 1, + use_noisy: bool = False, + v_min: float = -10.0, + v_max: float = 10.0, + sigma0: float = 0.5, # TODO(sven): Move `add_layer_norm` into ModelCatalog as # generic option, then error if we use ParameterNoise as # Exploration type and do not have any LayerNorm layers in # the net. - add_layer_norm=False): + add_layer_norm: bool = False): """Initialize variables of this model. Extra model kwargs: @@ -60,7 +66,6 @@ class DistributionalQTFModel(TFModelV2): only defines the layers for the Q head. Those layers for forward() should be defined in subclasses of DistributionalQModel. 
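# --- Illustrative sketch of the `num_atoms` / `v_min` / `v_max` options -----
# With num_atoms > 1, a categorical (C51-style) head predicts a distribution
# over `num_atoms` fixed return atoms, and the scalar Q-value is the
# expectation over that support. Stand-alone NumPy illustration only, not the
# model's own code:
import numpy as np

num_atoms, v_min, v_max = 51, -10.0, 10.0

# Fixed support z_i = v_min + i * (v_max - v_min) / (num_atoms - 1).
z = np.linspace(v_min, v_max, num_atoms)

# Pretend per-atom logits for one (state, action) pair.
logits = np.random.randn(num_atoms)
probs = np.exp(logits) / np.exp(logits).sum()  # softmax over the atoms

# Scalar action value = expectation of the predicted return distribution.
q_value = float(np.sum(probs * z))
print(q_value)
# ---------------------------------------------------------------------------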
""" - super(DistributionalQTFModel, self).__init__( obs_space, action_space, num_outputs, model_config, name) @@ -68,7 +73,8 @@ class DistributionalQTFModel(TFModelV2): self.model_out = tf.keras.layers.Input( shape=(num_outputs, ), name="model_out") - def build_action_value(prefix, model_out): + def build_action_value(prefix: str, + model_out: TensorType) -> List[TensorType]: if q_hiddens: action_out = model_out for i in range(len(q_hiddens)): @@ -129,7 +135,8 @@ class DistributionalQTFModel(TFModelV2): dist = tf.expand_dims(tf.ones_like(action_scores), -1) return [action_scores, logits, dist] - def build_state_score(prefix, model_out): + def build_state_score(prefix: str, + model_out: TensorType) -> TensorType: state_out = model_out for i in range(len(q_hiddens)): if use_noisy: @@ -163,7 +170,8 @@ class DistributionalQTFModel(TFModelV2): self.state_value_head = tf.keras.Model(self.model_out, state_out) self.register_variables(self.state_value_head.variables) - def get_q_value_distributions(self, model_out): + def get_q_value_distributions(self, + model_out: TensorType) -> List[TensorType]: """Returns distributional values for Q(s, a) given a state embedding. Override this in your custom model to customize the Q output head. @@ -175,10 +183,8 @@ class DistributionalQTFModel(TFModelV2): (action_scores, logits, dist) if num_atoms == 1, otherwise (action_scores, z, support_logits_per_action, logits, dist) """ - return self.q_value_head(model_out) - def get_state_value(self, model_out): + def get_state_value(self, model_out: TensorType) -> TensorType: """Returns the state value prediction for the given state embedding.""" - return self.state_value_head(model_out) diff --git a/rllib/agents/dqn/dqn.py b/rllib/agents/dqn/dqn.py index 6a4ba288b..04362a399 100644 --- a/rllib/agents/dqn/dqn.py +++ b/rllib/agents/dqn/dqn.py @@ -1,8 +1,19 @@ +""" +Deep Q-Networks (DQN, Rainbow, Parametric DQN) +============================================== + +This file defines the distributed Trainer class for the Deep Q-Networks +algorithm. See `dqn_[tf|torch]_policy.py` for the definition of the policies. + +Detailed documentation: +https://docs.ray.io/en/latest/rllib-algorithms.html#deep-q-networks-dqn-rainbow-parametric-dqn +""" # noqa: E501 + import logging -from typing import Type +from typing import List, Optional, Type from ray.rllib.agents.dqn.dqn_tf_policy import DQNTFPolicy -from ray.rllib.agents.dqn.simple_q_tf_policy import SimpleQTFPolicy +from ray.rllib.agents.dqn.dqn_torch_policy import DQNTorchPolicy from ray.rllib.agents.trainer import with_common_config from ray.rllib.agents.trainer_template import build_trainer from ray.rllib.evaluation.worker_set import WorkerSet @@ -158,6 +169,16 @@ def validate_config(config: TrainerConfigDict) -> None: def execution_plan(workers: WorkerSet, config: TrainerConfigDict) -> LocalIterator[dict]: + """Execution plan of the DQN algorithm. Defines the distributed dataflow. + + Args: + workers (WorkerSet): The WorkerSet for training the Polic(y/ies) + of the Trainer. + config (TrainerConfigDict): The trainer's configuration dict. + + Returns: + LocalIterator[dict]: A local iterator over training metrics. 
+ """ if config.get("prioritized_replay"): prio_args = { "prioritized_replay_alpha": config["prioritized_replay_alpha"], @@ -222,7 +243,8 @@ def execution_plan(workers: WorkerSet, return StandardMetricsReporting(train_op, workers, config) -def calculate_rr_weights(config: TrainerConfigDict): +def calculate_rr_weights(config: TrainerConfigDict) -> List[float]: + """Calculate the round robin weights for the rollout and train steps""" if not config["training_intensity"]: return [1, 1] # e.g., 32 / 4 -> native ratio of 8.0 @@ -234,23 +256,22 @@ def calculate_rr_weights(config: TrainerConfigDict): return weights -def get_policy_class(config: TrainerConfigDict) -> Type[Policy]: +def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]: + """Policy class picker function. Class is chosen based on DL-framework. + + Args: + config (TrainerConfigDict): The trainer's configuration dict. + + Returns: + Optional[Type[Policy]]: The Policy class to use with DQNTrainer. + If None, use `default_policy` provided in build_trainer(). + """ if config["framework"] == "torch": - from ray.rllib.agents.dqn.dqn_torch_policy import DQNTorchPolicy return DQNTorchPolicy - else: - return DQNTFPolicy - - -def get_simple_policy_class(config: TrainerConfigDict) -> Type[Policy]: - if config["framework"] == "torch": - from ray.rllib.agents.dqn.simple_q_torch_policy import \ - SimpleQTorchPolicy - return SimpleQTorchPolicy - else: - return SimpleQTFPolicy +# Build a generic off-policy trainer. Other trainers (such as DDPGTrainer) +# may build on top of it. GenericOffPolicyTrainer = build_trainer( name="GenericOffPolicyAlgorithm", default_policy=None, @@ -259,8 +280,7 @@ GenericOffPolicyTrainer = build_trainer( validate_config=validate_config, execution_plan=execution_plan) +# Build a DQN trainer, which uses the framework specific Policy +# determined in `get_policy_class()` above. DQNTrainer = GenericOffPolicyTrainer.with_updates( name="DQN", default_policy=DQNTFPolicy, default_config=DEFAULT_CONFIG) - -SimpleQTrainer = DQNTrainer.with_updates( - default_policy=SimpleQTFPolicy, get_policy_class=get_simple_policy_class) diff --git a/rllib/agents/dqn/dqn_tf_policy.py b/rllib/agents/dqn/dqn_tf_policy.py index 177129f20..d1d6d4570 100644 --- a/rllib/agents/dqn/dqn_tf_policy.py +++ b/rllib/agents/dqn/dqn_tf_policy.py @@ -1,3 +1,5 @@ +"""Tensorflow policy class used for DQN""" + from typing import Dict import gym @@ -33,18 +35,18 @@ PRIO_WEIGHTS = "weights" class QLoss: def __init__(self, - q_t_selected, - q_logits_t_selected, - q_tp1_best, - q_dist_tp1_best, - importance_weights, - rewards, - done_mask, - gamma=0.99, - n_step=1, - num_atoms=1, - v_min=-10.0, - v_max=10.0): + q_t_selected: TensorType, + q_logits_t_selected: TensorType, + q_tp1_best: TensorType, + q_dist_tp1_best: TensorType, + importance_weights: TensorType, + rewards: TensorType, + done_mask: TensorType, + gamma: float = 0.99, + n_step: int = 1, + num_atoms: int = 1, + v_min: float = -10.0, + v_max: float = 10.0): if num_atoms > 1: # Distributional Q-learning which corresponds to an entropy loss @@ -110,6 +112,11 @@ class QLoss: class ComputeTDErrorMixin: + """Assign the `compute_td_error` method to the DQNTFPolicy + + This allows us to prioritize on the worker side. 
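# --- Worked example for `calculate_rr_weights()` ----------------------------
# Assumes the usual formula weights = [1, training_intensity / native_ratio]:
train_batch_size = 32
rollout_fragment_length = 4
training_intensity = 16

# Without any replay, one 4-step rollout feeds one 32-step train batch, so
# each sampled step would natively be trained on 8 times.
native_ratio = train_batch_size / rollout_fragment_length  # -> 8.0

# A training intensity of 16 asks for twice that, so the train op is run
# twice for every rollout op in the round-robin schedule.
weights = [1, training_intensity / native_ratio]  # -> [1, 2.0]
print(weights)
# ---------------------------------------------------------------------------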
+ """ + def __init__(self): @make_tf_callable(self.get_session(), dynamic_shape=True) def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask, @@ -130,10 +137,22 @@ class ComputeTDErrorMixin: self.compute_td_error = compute_td_error -def build_q_model(policy: Policy, obs_space: gym.Space, - action_space: gym.Space, +def build_q_model(policy: Policy, obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, config: TrainerConfigDict) -> ModelV2: + """Build q_model and target_q_model for DQN + Args: + policy (Policy): The Policy, which will use the model for optimization. + obs_space (gym.spaces.Space): The policy's observation space. + action_space (gym.spaces.Space): The policy's action space. + config (TrainerConfigDict): + + Returns: + ModelV2: The Model for the Policy to use. + Note: The target q model will not be returned, just assigned to + `policy.target_q_model`. + """ if not isinstance(action_space, gym.spaces.Discrete): raise UnsupportedSpaceException( "Action space {} is not supported for DQN.".format(action_space)) @@ -206,6 +225,16 @@ def get_distribution_inputs_and_class(policy: Policy, def build_q_losses(policy: Policy, model, _, train_batch: SampleBatch) -> TensorType: + """Constructs the loss for DQNTFPolicy. + + Args: + policy (Policy): The Policy to calculate the loss for. + model (ModelV2): The Model to calculate the loss for. + train_batch (SampleBatch): The training data. + + Returns: + TensorType: A single loss tensor. + """ config = policy.config # q network evaluation q_t, q_logits_t, q_dist_t = compute_q_values( @@ -300,8 +329,8 @@ def setup_mid_mixins(policy: Policy, obs_space, action_space, config) -> None: ComputeTDErrorMixin.__init__(policy) -def setup_late_mixins(policy: Policy, obs_space: gym.Space, - action_space: gym.Space, +def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, config: TrainerConfigDict) -> None: TargetNetworkMixin.__init__(policy, obs_space, action_space, config) diff --git a/rllib/agents/dqn/dqn_torch_model.py b/rllib/agents/dqn/dqn_torch_model.py index 2764f0bb7..ff15783d1 100644 --- a/rllib/agents/dqn/dqn_torch_model.py +++ b/rllib/agents/dqn/dqn_torch_model.py @@ -1,7 +1,12 @@ +"""PyTorch model for DQN""" + +from typing import Sequence +import gym from ray.rllib.models.torch.misc import SlimFC from ray.rllib.models.torch.modules.noisy_layer import NoisyLayer from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.typing import ModelConfigDict torch, nn = try_import_torch() @@ -12,29 +17,29 @@ class DQNTorchModel(TorchModelV2, nn.Module): def __init__( self, - obs_space, - action_space, - num_outputs, - model_config, - name, + obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, + num_outputs: int, + model_config: ModelConfigDict, + name: str, *, - q_hiddens=(256, ), - dueling=False, - dueling_activation="relu", - num_atoms=1, - use_noisy=False, - v_min=-10.0, - v_max=10.0, - sigma0=0.5, + q_hiddens: Sequence[int] = (256, ), + dueling: bool = False, + dueling_activation: str = "relu", + num_atoms: int = 1, + use_noisy: bool = False, + v_min: float = -10.0, + v_max: float = 10.0, + sigma0: float = 0.5, # TODO(sven): Move `add_layer_norm` into ModelCatalog as # generic option, then error if we use ParameterNoise as # Exploration type and do not have any LayerNorm layers in # the net. - add_layer_norm=False): + add_layer_norm: bool = False): """Initialize variables of this model. 
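# --- Stand-alone sketch of the n-step TD error ------------------------------
# The scalar (num_atoms == 1) branch of `QLoss` and the `compute_td_error`
# helper both reduce to the classic n-step TD error. A NumPy illustration
# mirroring the formulas rather than reusing the policy code:
import numpy as np

def n_step_td_error(q_t_selected, q_tp1_best, rewards, done_mask,
                    gamma=0.99, n_step=1):
    # Mask out the bootstrap term for terminal transitions.
    q_tp1_best_masked = (1.0 - done_mask) * q_tp1_best
    # Bellman target: r + gamma^n * max_a' Q_target(s_{t+n}, a').
    target = rewards + gamma ** n_step * q_tp1_best_masked
    # This difference is also what prioritized replay uses as the priority.
    return q_t_selected - target

print(n_step_td_error(
    q_t_selected=np.array([1.0, 2.5]),
    q_tp1_best=np.array([1.2, 0.0]),
    rewards=np.array([0.5, 1.0]),
    done_mask=np.array([0.0, 1.0])))
# ---------------------------------------------------------------------------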
Extra model kwargs: - q_hiddens (List[int]): List of layer-sizes after(!) the + q_hiddens (Sequence[int]): List of layer-sizes after(!) the Advantages(A)/Value(V)-split. Hence, each of the A- and V- branches will have this structure of Dense layers. To define the NN before this A/V-split, use - as always - diff --git a/rllib/agents/dqn/dqn_torch_policy.py b/rllib/agents/dqn/dqn_torch_policy.py index e400f6b24..78b0cba31 100644 --- a/rllib/agents/dqn/dqn_torch_policy.py +++ b/rllib/agents/dqn/dqn_torch_policy.py @@ -1,3 +1,5 @@ +"""PyTorch policy class used for DQN""" + from typing import Dict, List, Tuple import gym @@ -31,13 +33,13 @@ if nn: class QLoss: def __init__(self, - q_t_selected, - q_logits_t_selected, - q_tp1_best, - q_probs_tp1_best, - importance_weights, - rewards, - done_mask, + q_t_selected: TensorType, + q_logits_t_selected: TensorType, + q_tp1_best: TensorType, + q_probs_tp1_best: TensorType, + importance_weights: TensorType, + rewards: TensorType, + done_mask: TensorType, gamma=0.99, n_step=1, num_atoms=1, @@ -103,6 +105,11 @@ class QLoss: class ComputeTDErrorMixin: + """Assign the `compute_td_error` method to the DQNTorchPolicy + + This allows us to prioritize on the worker side. + """ + def __init__(self): def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask, importance_weights): @@ -122,9 +129,22 @@ class ComputeTDErrorMixin: def build_q_model_and_distribution( - policy: Policy, obs_space: gym.Space, action_space: gym.Space, + policy: Policy, obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, config: TrainerConfigDict) -> Tuple[ModelV2, TorchDistributionWrapper]: + """Build q_model and target_q_model for DQN + Args: + policy (Policy): The policy, which will use the model for optimization. + obs_space (gym.spaces.Space): The policy's observation space. + action_space (gym.spaces.Space): The policy's action space. + config (TrainerConfigDict): + + Returns: + (q_model, TorchCategorical) + Note: The target q model will not be returned, just assigned to + `policy.target_q_model`. + """ if not isinstance(action_space, gym.spaces.Discrete): raise UnsupportedSpaceException( "Action space {} is not supported for DQN.".format(action_space)) @@ -204,6 +224,16 @@ def get_distribution_inputs_and_class( def build_q_losses(policy: Policy, model, _, train_batch: SampleBatch) -> TensorType: + """Constructs the loss for DQNTorchPolicy. + + Args: + policy (Policy): The Policy to calculate the loss for. + model (ModelV2): The Model to calculate the loss for. + train_batch (SampleBatch): The training data. + + Returns: + TensorType: A single loss tensor. + """ config = policy.config # Q-network evaluation. 
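# --- Stand-alone sketch of the dueling head combination ---------------------
# The `dueling=True` option splits the head into a state-value branch V(s)
# and an advantage branch A(s, a). A PyTorch illustration of the standard way
# the two are recombined (not the model's own code):
import torch

def dueling_q_values(value, advantages):
    # Q(s, a) = V(s) + A(s, a) - mean_a A(s, a); subtracting the mean keeps
    # the advantage stream identifiable.
    return value + advantages - advantages.mean(dim=1, keepdim=True)

value = torch.tensor([[2.0]])                  # one state
advantages = torch.tensor([[1.0, 0.0, -1.0]])  # three actions
print(dueling_q_values(value, advantages))     # tensor([[3., 2., 1.]])
# ---------------------------------------------------------------------------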
q_t, q_logits_t, q_probs_t = compute_q_values( @@ -286,7 +316,8 @@ def setup_early_mixins(policy: Policy, obs_space, action_space, LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"]) -def after_init(policy: Policy, obs_space: gym.Space, action_space: gym.Space, +def after_init(policy: Policy, obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, config: TrainerConfigDict) -> None: ComputeTDErrorMixin.__init__(policy) TargetNetworkMixin.__init__(policy, obs_space, action_space, config) diff --git a/rllib/agents/dqn/learner_thread.py b/rllib/agents/dqn/learner_thread.py index 57d73aa7a..400c6d902 100644 --- a/rllib/agents/dqn/learner_thread.py +++ b/rllib/agents/dqn/learner_thread.py @@ -1,5 +1,5 @@ +import queue import threading -from six.moves import queue from ray.rllib.evaluation.metrics import get_learner_stats from ray.rllib.policy.policy import LEARNER_STATS_KEY diff --git a/rllib/agents/dqn/simple_q.py b/rllib/agents/dqn/simple_q.py index 443daf7f8..f2fbc59b8 100644 --- a/rllib/agents/dqn/simple_q.py +++ b/rllib/agents/dqn/simple_q.py @@ -1,8 +1,11 @@ """ -Simple Q (simple_q) -=================== +Simple Q-Learning +================= -This file defines the distributed Trainer class for the simple Q learning. +This module provides a basic implementation of the DQN algorithm without any +optimizations. + +This file defines the distributed Trainer class for the Simple Q algorithm. See `simple_q_[tf|torch]_policy.py` for the definition of the policy loss. """ @@ -99,7 +102,7 @@ def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]: config (TrainerConfigDict): The trainer's configuration dict. Returns: - Optional[Type[Policy]]: The Policy class to use with PGTrainer. + Optional[Type[Policy]]: The Policy class to use with SimpleQTrainer. If None, use `default_policy` provided in build_trainer(). """ if config["framework"] == "torch": @@ -108,6 +111,16 @@ def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]: def execution_plan(workers: WorkerSet, config: TrainerConfigDict) -> LocalIterator[dict]: + """Execution plan of the Simple Q algorithm. Defines the distributed dataflow. + + Args: + workers (WorkerSet): The WorkerSet for training the Polic(y/ies) + of the Trainer. + config (TrainerConfigDict): The trainer's configuration dict. + + Returns: + LocalIterator[dict]: A local iterator over training metrics. 
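# --- Illustrative usage sketch for the Simple Q trainer ---------------------
# Assumes a `SimpleQTrainer` object built at module level in simple_q.py;
# environment and config values are placeholders.
import ray
from ray.rllib.agents.dqn.simple_q import SimpleQTrainer

ray.init()
simple_q_trainer = SimpleQTrainer(
    env="CartPole-v0",
    config={
        "framework": "torch",    # or "tf"; get_policy_class() picks the policy
        "learning_starts": 500,  # fill the local replay buffer before training
    })
for _ in range(3):
    print(simple_q_trainer.train()["episode_reward_mean"])
# ---------------------------------------------------------------------------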
+ """ local_replay_buffer = LocalReplayBuffer( num_shards=1, learning_starts=config["learning_starts"], diff --git a/rllib/agents/dqn/simple_q_tf_policy.py b/rllib/agents/dqn/simple_q_tf_policy.py index 526980c1a..515e64eef 100644 --- a/rllib/agents/dqn/simple_q_tf_policy.py +++ b/rllib/agents/dqn/simple_q_tf_policy.py @@ -1,4 +1,4 @@ -"""Basic example of a DQN policy without any optimizations.""" +"""TensorFlow policy class used for Simple Q-Learning""" import logging from typing import List, Tuple, Type @@ -11,6 +11,7 @@ from ray.rllib.models.tf.tf_action_dist import (Categorical, TFActionDistribution) from ray.rllib.models.torch.torch_action_dist import TorchCategorical from ray.rllib.policy import Policy +from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.tf_policy import TFPolicy from ray.rllib.policy.tf_policy_template import build_tf_policy @@ -28,8 +29,14 @@ Q_TARGET_SCOPE = "target_q_func" class TargetNetworkMixin: - def __init__(self, obs_space: gym.Space, action_space: gym.Space, - config: TrainerConfigDict): + """Assign the `update_target` method to the SimpleQTFPolicy + + The function is called every `target_network_update_freq` steps by the + master learner. + """ + + def __init__(self, obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, config: TrainerConfigDict): @make_tf_callable(self.get_session()) def do_update(): # update_target_fn will be called periodically to copy Q network to @@ -50,10 +57,24 @@ class TargetNetworkMixin: return self.q_func_vars + self.target_q_func_vars -def build_q_models(policy: Policy, obs_space: gym.Space, - action_space: gym.Space, +def build_q_models(policy: Policy, obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, config: TrainerConfigDict) -> ModelV2: + """Build q_model and target_q_model for Simple Q learning + Note that this function works for both Tensorflow and PyTorch. + + Args: + policy (Policy): The Policy, which will use the model for optimization. + obs_space (gym.spaces.Space): The policy's observation space. + action_space (gym.spaces.Space): The policy's action space. + config (TrainerConfigDict): + + Returns: + ModelV2: The Model for the Policy to use. + Note: The target q model will not be returned, just assigned to + `policy.target_q_model`. + """ if not isinstance(action_space, gym.spaces.Discrete): raise UnsupportedSpaceException( "Action space {} is not supported for DQN.".format(action_space)) @@ -88,6 +109,7 @@ def get_distribution_inputs_and_class( explore=True, is_training=True, **kwargs) -> Tuple[TensorType, type, List[TensorType]]: + """Build the action distribution""" q_vals = compute_q_values(policy, q_model, obs_batch, explore, is_training) q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals @@ -100,6 +122,17 @@ def get_distribution_inputs_and_class( def build_q_losses(policy: Policy, model: ModelV2, dist_class: Type[TFActionDistribution], train_batch: SampleBatch) -> TensorType: + """Constructs the loss for SimpleQTFPolicy. + + Args: + policy (Policy): The Policy to calculate the loss for. + model (ModelV2): The Model to calculate the loss for. + dist_class (Type[ActionDistribution]): The action distribution class. + train_batch (SampleBatch): The training data. + + Returns: + TensorType: A single loss tensor. 
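# --- Stand-alone sketch of the hard target-network sync ---------------------
# `TargetNetworkMixin.update_target()` copies the Q-network weights verbatim
# into the target network every `target_network_update_freq` train steps.
# A PyTorch illustration of the same idea (the mixin itself assigns TF
# variables):
import torch.nn as nn

q_net = nn.Linear(4, 2)
target_q_net = nn.Linear(4, 2)

def update_target():
    # Copy every Q-network parameter into the target network.
    target_q_net.load_state_dict(q_net.state_dict())

target_network_update_freq = 500
for step in range(1, 2001):
    # ... one gradient step on q_net would happen here ...
    if step % target_network_update_freq == 0:
        update_target()
# ---------------------------------------------------------------------------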
+ """ # q network evaluation q_t = compute_q_values( policy, @@ -156,13 +189,23 @@ def compute_q_values(policy: Policy, return model_out -def setup_late_mixins(policy: Policy, obs_space: gym.Space, - action_space: gym.Space, +def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, config: TrainerConfigDict) -> None: + """Call all mixin classes' constructors before SimpleQTFPolicy initialization. + + Args: + policy (Policy): The Policy object. + obs_space (gym.spaces.Space): The Policy's observation space. + action_space (gym.spaces.Space): The Policy's action space. + config (TrainerConfigDict): The Policy's config. + """ TargetNetworkMixin.__init__(policy, obs_space, action_space, config) -SimpleQTFPolicy = build_tf_policy( +# Build a child class of `DynamicTFPolicy`, given the custom functions defined +# above. +SimpleQTFPolicy: DynamicTFPolicy = build_tf_policy( name="SimpleQTFPolicy", get_default_config=lambda: ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG, make_model=build_q_models, diff --git a/rllib/agents/dqn/simple_q_torch_policy.py b/rllib/agents/dqn/simple_q_torch_policy.py index fbdcc05ae..b9ec0f0c4 100644 --- a/rllib/agents/dqn/simple_q_torch_policy.py +++ b/rllib/agents/dqn/simple_q_torch_policy.py @@ -1,14 +1,15 @@ -"""Basic example of a DQN policy without any optimizations.""" +"""PyTorch policy class used for Simple Q-Learning""" import logging -from typing import Dict +from typing import Dict, Tuple import gym import ray from ray.rllib.agents.dqn.simple_q_tf_policy import ( build_q_models, compute_q_values, get_distribution_inputs_and_class) from ray.rllib.models.modelv2 import ModelV2 -from ray.rllib.models.torch.torch_action_dist import TorchCategorical +from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \ + TorchDistributionWrapper from ray.rllib.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.torch_policy_template import build_torch_policy @@ -24,8 +25,14 @@ logger = logging.getLogger(__name__) class TargetNetworkMixin: - def __init__(self, obs_space: gym.Space, action_space: gym.Space, - config: TrainerConfigDict): + """Assign the `update_target` method to the SimpleQTorchPolicy + + The function is called every `target_network_update_freq` steps by the + master learner. + """ + + def __init__(self, obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, config: TrainerConfigDict): def do_update(): # Update_target_fn will be called periodically to copy Q network to # target Q network. @@ -36,15 +43,27 @@ class TargetNetworkMixin: self.update_target = do_update -def build_q_model_and_distribution(policy: Policy, obs_space: gym.Space, - action_space: gym.Space, - config: TrainerConfigDict) -> ModelV2: +def build_q_model_and_distribution( + policy: Policy, obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, + config: TrainerConfigDict) -> Tuple[ModelV2, TorchDistributionWrapper]: return build_q_models(policy, obs_space, action_space, config), \ TorchCategorical def build_q_losses(policy: Policy, model, dist_class, train_batch: SampleBatch) -> TensorType: + """Constructs the loss for SimpleQTorchPolicy. + + Args: + policy (Policy): The Policy to calculate the loss for. + model (ModelV2): The Model to calculate the loss for. + dist_class (Type[ActionDistribution]): The action distribution class. + train_batch (SampleBatch): The training data. + + Returns: + TensorType: A single loss tensor. 
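# --- Stand-alone sketch of the Simple Q loss construction -------------------
# Mirrors the one-hot selection and TD target described in the docstrings;
# shapes and numbers below are made up for illustration.
import torch
import torch.nn.functional as F

gamma, num_actions = 0.99, 2
q_t = torch.tensor([[1.0, 3.0], [0.5, 0.2]])    # Q(s_t, .) for a batch of 2
q_tp1 = torch.tensor([[2.0, 1.0], [0.0, 4.0]])  # target net Q(s_{t+1}, .)
actions = torch.tensor([1, 0])
rewards = torch.tensor([1.0, 0.0])
dones = torch.tensor([0.0, 1.0])

# Select Q(s_t, a_t) with a one-hot mask, as the loss function does.
one_hot = F.one_hot(actions, num_actions).float()
q_t_selected = (q_t * one_hot).sum(dim=1)             # -> [3.0, 0.5]

# Greedy bootstrap value from the target network, zeroed at terminals.
q_tp1_best = q_tp1.max(dim=1).values * (1.0 - dones)  # -> [2.0, 0.0]

target = rewards + gamma * q_tp1_best
loss = F.smooth_l1_loss(q_t_selected, target)         # Huber loss on the TD error
print(loss)
# ---------------------------------------------------------------------------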
+ """ # q network evaluation q_t = compute_q_values( policy, @@ -89,13 +108,22 @@ def build_q_losses(policy: Policy, model, dist_class, def extra_action_out_fn(policy: Policy, input_dict, state_batches, model, action_dist) -> Dict[str, TensorType]: - """Adds q-values to action out dict.""" + """Adds q-values to the action out dict.""" return {"q_values": policy.q_values} -def setup_late_mixins(policy: Policy, obs_space: gym.Space, - action_space: gym.Space, +def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, config: TrainerConfigDict) -> None: + """Call all mixin classes' constructors before SimpleQTorchPolicy + initialization. + + Args: + policy (Policy): The Policy object. + obs_space (gym.spaces.Space): The Policy's observation space. + action_space (gym.spaces.Space): The Policy's action space. + config (TrainerConfigDict): The Policy's config. + """ TargetNetworkMixin.__init__(policy, obs_space, action_space, config) diff --git a/rllib/agents/ppo/appo_tf_policy.py b/rllib/agents/ppo/appo_tf_policy.py index e5ed2cef3..4cb122005 100644 --- a/rllib/agents/ppo/appo_tf_policy.py +++ b/rllib/agents/ppo/appo_tf_policy.py @@ -91,7 +91,7 @@ def appo_surrogate_loss( Args: policy (Policy): The Policy to calculate the loss for. model (ModelV2): The Model to calculate the loss for. - dist_class (Type[ActionDistribution]: The action distr. class. + dist_class (Type[ActionDistribution]): The action distr. class. train_batch (SampleBatch): The training data. Returns:
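# --- Illustrative sketch combining several documented DQN options -----------
# Config values are placeholders, not tuned settings.
import ray
from ray.rllib.agents.dqn import DQNTrainer

ray.init()
dqn_trainer = DQNTrainer(
    env="CartPole-v0",
    config={
        "num_atoms": 51,             # distributional (C51) head over [v_min, v_max]
        "noisy": True,               # NoisyLayer-based exploration
        "dueling": True,             # V/A dueling heads
        "double_q": True,            # double Q-learning targets
        "n_step": 3,                 # n-step returns in the TD target
        "prioritized_replay": True,  # TD-error-based priorities
    })
print(dqn_trainer.train()["episode_reward_mean"])
# ---------------------------------------------------------------------------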