[RLlib] Add docstrings for agents/dqn (#10710)

desktable 2020-09-15 03:37:07 -07:00 committed by GitHub
parent 34bb61dabc
commit 4ccfd07a61
11 changed files with 309 additions and 116 deletions

View file

@ -1,22 +1,38 @@
"""
Distributed Prioritized Experience Replay (Ape-X)
=================================================
This file defines a DQN trainer using the Ape-X architecture.
Ape-X uses a single GPU learner and many CPU workers for experience collection.
Experience collection can scale to hundreds of CPU workers due to the
distributed prioritization of experience prior to storage in replay buffers.
Detailed documentation:
https://docs.ray.io/en/latest/rllib-algorithms.html#distributed-prioritized-experience-replay-ape-x
""" # noqa: E501
import collections
import copy
from typing import Tuple
import ray
from ray.rllib.agents.dqn.dqn import DQNTrainer, \
DEFAULT_CONFIG as DQN_CONFIG, calculate_rr_weights
from ray.rllib.agents.dqn.dqn import DEFAULT_CONFIG as DQN_CONFIG
from ray.rllib.agents.dqn.dqn import DQNTrainer, calculate_rr_weights
from ray.rllib.agents.dqn.learner_thread import LearnerThread
from ray.rllib.execution.common import STEPS_TRAINED_COUNTER, \
_get_shared_metrics, _get_global_vars
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.concurrency_ops import Concurrently, Enqueue, Dequeue
from ray.rllib.execution.replay_ops import StoreToReplayBuffer, Replay
from ray.rllib.execution.train_ops import UpdateTargetNetwork
from ray.rllib.execution.common import (STEPS_TRAINED_COUNTER,
_get_global_vars, _get_shared_metrics)
from ray.rllib.execution.concurrency_ops import Concurrently, Dequeue, Enqueue
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_buffer import ReplayActor
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.train_ops import UpdateTargetNetwork
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.actors import create_colocated
from ray.rllib.utils.typing import SampleBatchType
from ray.util.iter import LocalIterator
# yapf: disable
# __sphinx_doc_begin__
@ -53,14 +69,15 @@ APEX_DEFAULT_CONFIG = merge_dicts(
# Update worker weights as they finish generating experiences.
class UpdateWorkerWeights:
def __init__(self, learner_thread, workers, max_weight_sync_delay):
def __init__(self, learner_thread: LearnerThread, workers: WorkerSet,
max_weight_sync_delay: int):
self.learner_thread = learner_thread
self.workers = workers
self.steps_since_update = collections.defaultdict(int)
self.max_weight_sync_delay = max_weight_sync_delay
self.weights = None
def __call__(self, item: ("ActorHandle", SampleBatchType)):
def __call__(self, item: Tuple["ActorHandle", SampleBatchType]):
actor, batch = item
self.steps_since_update[actor] += batch.count
if self.steps_since_update[actor] >= self.max_weight_sync_delay:
@ -77,7 +94,8 @@ class UpdateWorkerWeights:
metrics.counters["num_weight_syncs"] += 1
def apex_execution_plan(workers: WorkerSet, config: dict):
def apex_execution_plan(workers: WorkerSet,
config: dict) -> LocalIterator[dict]:
# Create a number of replay buffer actors.
num_replay_buffer_shards = config["optimizer"]["num_replay_buffer_shards"]
replay_actors = create_colocated(ReplayActor, [
@ -97,7 +115,7 @@ def apex_execution_plan(workers: WorkerSet, config: dict):
learner_thread.start()
# Update experience priorities post learning.
def update_prio_and_stats(item: ("ActorHandle", dict, int)):
def update_prio_and_stats(item: Tuple["ActorHandle", dict, int]) -> None:
actor, prio_dict, count = item
actor.update_priorities.remote(prio_dict)
metrics = _get_shared_metrics()
@ -155,7 +173,7 @@ def apex_execution_plan(workers: WorkerSet, config: dict):
[store_op, replay_op, update_op], mode="async", output_indexes=[2])
# Add in extra replay and learner metrics to the training result.
def add_apex_metrics(result):
def add_apex_metrics(result: dict) -> dict:
replay_stats = ray.get(replay_actors[0].stats.remote(
config["optimizer"].get("debug")))
exploration_infos = workers.foreach_trainable_policy(

View file

@ -1,6 +1,12 @@
"""Tensorflow model for DQN"""
from typing import List
import gym
from ray.rllib.models.tf.layers import NoisyLayer
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import ModelConfigDict, TensorType
tf1, tf, tfv = try_import_tf()
@ -20,23 +26,23 @@ class DistributionalQTFModel(TFModelV2):
def __init__(
self,
obs_space,
action_space,
num_outputs,
model_config,
name,
obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
num_outputs: int,
model_config: ModelConfigDict,
name: str,
q_hiddens=(256, ),
dueling=False,
num_atoms=1,
use_noisy=False,
v_min=-10.0,
v_max=10.0,
sigma0=0.5,
dueling: bool = False,
num_atoms: int = 1,
use_noisy: bool = False,
v_min: float = -10.0,
v_max: float = 10.0,
sigma0: float = 0.5,
# TODO(sven): Move `add_layer_norm` into ModelCatalog as
# generic option, then error if we use ParameterNoise as
# Exploration type and do not have any LayerNorm layers in
# the net.
add_layer_norm=False):
add_layer_norm: bool = False):
"""Initialize variables of this model.
Extra model kwargs:
@ -60,7 +66,6 @@ class DistributionalQTFModel(TFModelV2):
only defines the layers for the Q head. Those layers for forward()
should be defined in subclasses of DistributionalQModel.
"""
super(DistributionalQTFModel, self).__init__(
obs_space, action_space, num_outputs, model_config, name)
@ -68,7 +73,8 @@ class DistributionalQTFModel(TFModelV2):
self.model_out = tf.keras.layers.Input(
shape=(num_outputs, ), name="model_out")
def build_action_value(prefix, model_out):
def build_action_value(prefix: str,
model_out: TensorType) -> List[TensorType]:
if q_hiddens:
action_out = model_out
for i in range(len(q_hiddens)):
@ -129,7 +135,8 @@ class DistributionalQTFModel(TFModelV2):
dist = tf.expand_dims(tf.ones_like(action_scores), -1)
return [action_scores, logits, dist]
def build_state_score(prefix, model_out):
def build_state_score(prefix: str,
model_out: TensorType) -> TensorType:
state_out = model_out
for i in range(len(q_hiddens)):
if use_noisy:
@ -163,7 +170,8 @@ class DistributionalQTFModel(TFModelV2):
self.state_value_head = tf.keras.Model(self.model_out, state_out)
self.register_variables(self.state_value_head.variables)
def get_q_value_distributions(self, model_out):
def get_q_value_distributions(self,
model_out: TensorType) -> List[TensorType]:
"""Returns distributional values for Q(s, a) given a state embedding.
Override this in your custom model to customize the Q output head.
@ -175,10 +183,8 @@ class DistributionalQTFModel(TFModelV2):
(action_scores, logits, dist) if num_atoms == 1, otherwise
(action_scores, z, support_logits_per_action, logits, dist)
"""
return self.q_value_head(model_out)
def get_state_value(self, model_out):
def get_state_value(self, model_out: TensorType) -> TensorType:
"""Returns the state value prediction for the given state embedding."""
return self.state_value_head(model_out)
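Because this class only builds the Q/V heads and leaves `forward()` to subclasses, a short illustrative subclass may help readers of this diff. The sketch below is not from the commit (the class name, layer sizes, and the flat-Box observation assumption are the editor's); it follows the standard ModelV2 + ModelCatalog pattern and the `register_variables` convention visible above.

from ray.rllib.agents.dqn.distributional_q_tf_model import DistributionalQTFModel
from ray.rllib.models import ModelCatalog
from ray.rllib.utils.framework import try_import_tf

tf1, tf, tfv = try_import_tf()


class MyDistributionalQModel(DistributionalQTFModel):
    """Custom state-embedding net; the Q/V heads come from the parent class."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, **kw):
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name, **kw)
        # Assumes a flat Box observation space.
        inputs = tf.keras.layers.Input(shape=obs_space.shape, name="obs")
        hidden = tf.keras.layers.Dense(256, activation="relu")(inputs)
        embedding = tf.keras.layers.Dense(num_outputs, name="model_out")(hidden)
        self.base_model = tf.keras.Model(inputs, embedding)
        self.register_variables(self.base_model.variables)

    def forward(self, input_dict, state, seq_lens):
        # Must return a [B, num_outputs] embedding; get_q_value_distributions()
        # and get_state_value() above consume it.
        return self.base_model(input_dict["obs"]), state


ModelCatalog.register_custom_model("my_distributional_q", MyDistributionalQModel)
# Used via: DQNTrainer(config={"model": {"custom_model": "my_distributional_q"}}, ...)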

View file

@ -1,8 +1,19 @@
"""
Deep Q-Networks (DQN, Rainbow, Parametric DQN)
==============================================
This file defines the distributed Trainer class for the Deep Q-Networks
algorithm. See `dqn_[tf|torch]_policy.py` for the definition of the policies.
Detailed documentation:
https://docs.ray.io/en/latest/rllib-algorithms.html#deep-q-networks-dqn-rainbow-parametric-dqn
""" # noqa: E501
import logging
from typing import Type
from typing import List, Optional, Type
from ray.rllib.agents.dqn.dqn_tf_policy import DQNTFPolicy
from ray.rllib.agents.dqn.simple_q_tf_policy import SimpleQTFPolicy
from ray.rllib.agents.dqn.dqn_torch_policy import DQNTorchPolicy
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.evaluation.worker_set import WorkerSet
@ -158,6 +169,16 @@ def validate_config(config: TrainerConfigDict) -> None:
def execution_plan(workers: WorkerSet,
config: TrainerConfigDict) -> LocalIterator[dict]:
"""Execution plan of the DQN algorithm. Defines the distributed dataflow.
Args:
workers (WorkerSet): The WorkerSet for training the Trainer's policy (or policies).
config (TrainerConfigDict): The trainer's configuration dict.
Returns:
LocalIterator[dict]: A local iterator over training metrics.
"""
if config.get("prioritized_replay"):
prio_args = {
"prioritized_replay_alpha": config["prioritized_replay_alpha"],
@ -222,7 +243,8 @@ def execution_plan(workers: WorkerSet,
return StandardMetricsReporting(train_op, workers, config)
def calculate_rr_weights(config: TrainerConfigDict):
def calculate_rr_weights(config: TrainerConfigDict) -> List[float]:
"""Calculate the round robin weights for the rollout and train steps"""
if not config["training_intensity"]:
return [1, 1]
# e.g., 32 / 4 -> native ratio of 8.0
@ -234,23 +256,22 @@ def calculate_rr_weights(config: TrainerConfigDict):
return weights
def get_policy_class(config: TrainerConfigDict) -> Type[Policy]:
def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
"""Policy class picker function. Class is chosen based on DL-framework.
Args:
config (TrainerConfigDict): The trainer's configuration dict.
Returns:
Optional[Type[Policy]]: The Policy class to use with DQNTrainer.
If None, use `default_policy` provided in build_trainer().
"""
if config["framework"] == "torch":
from ray.rllib.agents.dqn.dqn_torch_policy import DQNTorchPolicy
return DQNTorchPolicy
else:
return DQNTFPolicy
def get_simple_policy_class(config: TrainerConfigDict) -> Type[Policy]:
if config["framework"] == "torch":
from ray.rllib.agents.dqn.simple_q_torch_policy import \
SimpleQTorchPolicy
return SimpleQTorchPolicy
else:
return SimpleQTFPolicy
# Build a generic off-policy trainer. Other trainers (such as DDPGTrainer)
# may build on top of it.
GenericOffPolicyTrainer = build_trainer(
name="GenericOffPolicyAlgorithm",
default_policy=None,
@ -259,8 +280,7 @@ GenericOffPolicyTrainer = build_trainer(
validate_config=validate_config,
execution_plan=execution_plan)
# Build a DQN trainer, which uses the framework specific Policy
# determined in `get_policy_class()` above.
DQNTrainer = GenericOffPolicyTrainer.with_updates(
name="DQN", default_policy=DQNTFPolicy, default_config=DEFAULT_CONFIG)
SimpleQTrainer = DQNTrainer.with_updates(
default_policy=SimpleQTFPolicy, get_policy_class=get_simple_policy_class)

View file

@ -1,3 +1,5 @@
"""Tensorflow policy class used for DQN"""
from typing import Dict
import gym
@ -33,18 +35,18 @@ PRIO_WEIGHTS = "weights"
class QLoss:
def __init__(self,
q_t_selected,
q_logits_t_selected,
q_tp1_best,
q_dist_tp1_best,
importance_weights,
rewards,
done_mask,
gamma=0.99,
n_step=1,
num_atoms=1,
v_min=-10.0,
v_max=10.0):
q_t_selected: TensorType,
q_logits_t_selected: TensorType,
q_tp1_best: TensorType,
q_dist_tp1_best: TensorType,
importance_weights: TensorType,
rewards: TensorType,
done_mask: TensorType,
gamma: float = 0.99,
n_step: int = 1,
num_atoms: int = 1,
v_min: float = -10.0,
v_max: float = 10.0):
if num_atoms > 1:
# Distributional Q-learning which corresponds to an entropy loss
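In rough terms (editor's summary, scalar n-step case), the loss this class assembles is

    y_t = \sum_{k=0}^{n-1} \gamma^k r_{t+k} + \gamma^n (1 - d_{t+n}) \max_{a'} Q_{\theta^-}(s_{t+n}, a')
    \mathcal{L} = \mathbb{E}\big[ w_t \cdot \ell\big( Q_\theta(s_t, a_t) - y_t \big) \big]

with w_t the prioritized-replay importance weights, \ell a Huber-style penalty on the TD error, and \theta^- the target-network parameters. When num_atoms > 1 (distributional/C51), Q(s, a) becomes a categorical distribution over num_atoms atoms spanning [v_min, v_max]; the target distribution is projected back onto that fixed support, and the loss is the cross-entropy between projection and prediction — the "entropy loss" the comment above refers to.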
@ -110,6 +112,11 @@ class QLoss:
class ComputeTDErrorMixin:
"""Assign the `compute_td_error` method to the DQNTFPolicy
This allows us to prioritize on the worker side.
"""
def __init__(self):
@make_tf_callable(self.get_session(), dynamic_shape=True)
def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,
@ -130,10 +137,22 @@ class ComputeTDErrorMixin:
self.compute_td_error = compute_td_error
def build_q_model(policy: Policy, obs_space: gym.Space,
action_space: gym.Space,
def build_q_model(policy: Policy, obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict) -> ModelV2:
"""Build q_model and target_q_model for DQN
Args:
policy (Policy): The Policy, which will use the model for optimization.
obs_space (gym.spaces.Space): The policy's observation space.
action_space (gym.spaces.Space): The policy's action space.
config (TrainerConfigDict): The Policy's config.
Returns:
ModelV2: The Model for the Policy to use.
Note: The target q model will not be returned, just assigned to
`policy.target_q_model`.
"""
if not isinstance(action_space, gym.spaces.Discrete):
raise UnsupportedSpaceException(
"Action space {} is not supported for DQN.".format(action_space))
@ -206,6 +225,16 @@ def get_distribution_inputs_and_class(policy: Policy,
def build_q_losses(policy: Policy, model, _,
train_batch: SampleBatch) -> TensorType:
"""Constructs the loss for DQNTFPolicy.
Args:
policy (Policy): The Policy to calculate the loss for.
model (ModelV2): The Model to calculate the loss for.
train_batch (SampleBatch): The training data.
Returns:
TensorType: A single loss tensor.
"""
config = policy.config
# q network evaluation
q_t, q_logits_t, q_dist_t = compute_q_values(
@ -300,8 +329,8 @@ def setup_mid_mixins(policy: Policy, obs_space, action_space, config) -> None:
ComputeTDErrorMixin.__init__(policy)
def setup_late_mixins(policy: Policy, obs_space: gym.Space,
action_space: gym.Space,
def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict) -> None:
TargetNetworkMixin.__init__(policy, obs_space, action_space, config)

View file

@ -1,7 +1,12 @@
"""PyTorch model for DQN"""
from typing import Sequence
import gym
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.modules.noisy_layer import NoisyLayer
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import ModelConfigDict
torch, nn = try_import_torch()
@ -12,29 +17,29 @@ class DQNTorchModel(TorchModelV2, nn.Module):
def __init__(
self,
obs_space,
action_space,
num_outputs,
model_config,
name,
obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
num_outputs: int,
model_config: ModelConfigDict,
name: str,
*,
q_hiddens=(256, ),
dueling=False,
dueling_activation="relu",
num_atoms=1,
use_noisy=False,
v_min=-10.0,
v_max=10.0,
sigma0=0.5,
q_hiddens: Sequence[int] = (256, ),
dueling: bool = False,
dueling_activation: str = "relu",
num_atoms: int = 1,
use_noisy: bool = False,
v_min: float = -10.0,
v_max: float = 10.0,
sigma0: float = 0.5,
# TODO(sven): Move `add_layer_norm` into ModelCatalog as
# generic option, then error if we use ParameterNoise as
# Exploration type and do not have any LayerNorm layers in
# the net.
add_layer_norm=False):
add_layer_norm: bool = False):
"""Initialize variables of this model.
Extra model kwargs:
q_hiddens (List[int]): List of layer-sizes after(!) the
q_hiddens (Sequence[int]): List of layer-sizes after(!) the
Advantages(A)/Value(V)-split. Hence, each of the A- and V-
branches will have this structure of Dense layers. To define
the NN before this A/V-split, use - as always -

View file

@ -1,3 +1,5 @@
"""PyTorch policy class used for DQN"""
from typing import Dict, List, Tuple
import gym
@ -31,13 +33,13 @@ if nn:
class QLoss:
def __init__(self,
q_t_selected,
q_logits_t_selected,
q_tp1_best,
q_probs_tp1_best,
importance_weights,
rewards,
done_mask,
q_t_selected: TensorType,
q_logits_t_selected: TensorType,
q_tp1_best: TensorType,
q_probs_tp1_best: TensorType,
importance_weights: TensorType,
rewards: TensorType,
done_mask: TensorType,
gamma=0.99,
n_step=1,
num_atoms=1,
@ -103,6 +105,11 @@ class QLoss:
class ComputeTDErrorMixin:
"""Assign the `compute_td_error` method to the DQNTorchPolicy
This allows us to prioritize on the worker side.
"""
def __init__(self):
def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,
importance_weights):
@ -122,9 +129,22 @@ class ComputeTDErrorMixin:
def build_q_model_and_distribution(
policy: Policy, obs_space: gym.Space, action_space: gym.Space,
policy: Policy, obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict) -> Tuple[ModelV2, TorchDistributionWrapper]:
"""Build q_model and target_q_model for DQN
Args:
policy (Policy): The policy, which will use the model for optimization.
obs_space (gym.spaces.Space): The policy's observation space.
action_space (gym.spaces.Space): The policy's action space.
config (TrainerConfigDict): The Policy's config.
Returns:
Tuple[ModelV2, TorchDistributionWrapper]: The Q-model and the TorchCategorical action distribution class.
Note: The target q model will not be returned, just assigned to
`policy.target_q_model`.
"""
if not isinstance(action_space, gym.spaces.Discrete):
raise UnsupportedSpaceException(
"Action space {} is not supported for DQN.".format(action_space))
@ -204,6 +224,16 @@ def get_distribution_inputs_and_class(
def build_q_losses(policy: Policy, model, _,
train_batch: SampleBatch) -> TensorType:
"""Constructs the loss for DQNTorchPolicy.
Args:
policy (Policy): The Policy to calculate the loss for.
model (ModelV2): The Model to calculate the loss for.
train_batch (SampleBatch): The training data.
Returns:
TensorType: A single loss tensor.
"""
config = policy.config
# Q-network evaluation.
q_t, q_logits_t, q_probs_t = compute_q_values(
@ -286,7 +316,8 @@ def setup_early_mixins(policy: Policy, obs_space, action_space,
LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
def after_init(policy: Policy, obs_space: gym.Space, action_space: gym.Space,
def after_init(policy: Policy, obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict) -> None:
ComputeTDErrorMixin.__init__(policy)
TargetNetworkMixin.__init__(policy, obs_space, action_space, config)

View file

@ -1,5 +1,5 @@
import queue
import threading
from six.moves import queue
from ray.rllib.evaluation.metrics import get_learner_stats
from ray.rllib.policy.policy import LEARNER_STATS_KEY

View file

@ -1,8 +1,11 @@
"""
Simple Q (simple_q)
===================
Simple Q-Learning
=================
This file defines the distributed Trainer class for the simple Q learning.
This module provides a basic implementation of the DQN algorithm without any
optimizations.
This file defines the distributed Trainer class for the Simple Q algorithm.
See `simple_q_[tf|torch]_policy.py` for the definition of the policy loss.
"""
@ -99,7 +102,7 @@ def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
config (TrainerConfigDict): The trainer's configuration dict.
Returns:
Optional[Type[Policy]]: The Policy class to use with PGTrainer.
Optional[Type[Policy]]: The Policy class to use with SimpleQTrainer.
If None, use `default_policy` provided in build_trainer().
"""
if config["framework"] == "torch":
@ -108,6 +111,16 @@ def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
def execution_plan(workers: WorkerSet,
config: TrainerConfigDict) -> LocalIterator[dict]:
"""Execution plan of the Simple Q algorithm. Defines the distributed dataflow.
Args:
workers (WorkerSet): The WorkerSet for training the Trainer's policy (or policies).
config (TrainerConfigDict): The trainer's configuration dict.
Returns:
LocalIterator[dict]: A local iterator over training metrics.
"""
local_replay_buffer = LocalReplayBuffer(
num_shards=1,
learning_starts=config["learning_starts"],

View file

@ -1,4 +1,4 @@
"""Basic example of a DQN policy without any optimizations."""
"""TensorFlow policy class used for Simple Q-Learning"""
import logging
from typing import List, Tuple, Type
@ -11,6 +11,7 @@ from ray.rllib.models.tf.tf_action_dist import (Categorical,
TFActionDistribution)
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
from ray.rllib.policy import Policy
from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.policy.tf_policy_template import build_tf_policy
@ -28,8 +29,14 @@ Q_TARGET_SCOPE = "target_q_func"
class TargetNetworkMixin:
def __init__(self, obs_space: gym.Space, action_space: gym.Space,
config: TrainerConfigDict):
"""Assign the `update_target` method to the SimpleQTFPolicy
The function is called every `target_network_update_freq` steps by the
master learner.
"""
def __init__(self, obs_space: gym.spaces.Space,
action_space: gym.spaces.Space, config: TrainerConfigDict):
@make_tf_callable(self.get_session())
def do_update():
# update_target_fn will be called periodically to copy Q network to
@ -50,10 +57,24 @@ class TargetNetworkMixin:
return self.q_func_vars + self.target_q_func_vars
def build_q_models(policy: Policy, obs_space: gym.Space,
action_space: gym.Space,
def build_q_models(policy: Policy, obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict) -> ModelV2:
"""Build q_model and target_q_model for Simple Q learning
Note that this function works for both Tensorflow and PyTorch.
Args:
policy (Policy): The Policy, which will use the model for optimization.
obs_space (gym.spaces.Space): The policy's observation space.
action_space (gym.spaces.Space): The policy's action space.
config (TrainerConfigDict): The Policy's config.
Returns:
ModelV2: The Model for the Policy to use.
Note: The target q model will not be returned, just assigned to
`policy.target_q_model`.
"""
if not isinstance(action_space, gym.spaces.Discrete):
raise UnsupportedSpaceException(
"Action space {} is not supported for DQN.".format(action_space))
@ -88,6 +109,7 @@ def get_distribution_inputs_and_class(
explore=True,
is_training=True,
**kwargs) -> Tuple[TensorType, type, List[TensorType]]:
"""Build the action distribution"""
q_vals = compute_q_values(policy, q_model, obs_batch, explore, is_training)
q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals
@ -100,6 +122,17 @@ def get_distribution_inputs_and_class(
def build_q_losses(policy: Policy, model: ModelV2,
dist_class: Type[TFActionDistribution],
train_batch: SampleBatch) -> TensorType:
"""Constructs the loss for SimpleQTFPolicy.
Args:
policy (Policy): The Policy to calculate the loss for.
model (ModelV2): The Model to calculate the loss for.
dist_class (Type[ActionDistribution]): The action distribution class.
train_batch (SampleBatch): The training data.
Returns:
TensorType: A single loss tensor.
"""
# q network evaluation
q_t = compute_q_values(
policy,
@ -156,13 +189,23 @@ def compute_q_values(policy: Policy,
return model_out
def setup_late_mixins(policy: Policy, obs_space: gym.Space,
action_space: gym.Space,
def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict) -> None:
"""Call all mixin classes' constructors before SimpleQTFPolicy initialization.
Args:
policy (Policy): The Policy object.
obs_space (gym.spaces.Space): The Policy's observation space.
action_space (gym.spaces.Space): The Policy's action space.
config (TrainerConfigDict): The Policy's config.
"""
TargetNetworkMixin.__init__(policy, obs_space, action_space, config)
SimpleQTFPolicy = build_tf_policy(
# Build a child class of `DynamicTFPolicy`, given the custom functions defined
# above.
SimpleQTFPolicy: DynamicTFPolicy = build_tf_policy(
name="SimpleQTFPolicy",
get_default_config=lambda: ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG,
make_model=build_q_models,

View file

@ -1,14 +1,15 @@
"""Basic example of a DQN policy without any optimizations."""
"""PyTorch policy class used for Simple Q-Learning"""
import logging
from typing import Dict
from typing import Dict, Tuple
import gym
import ray
from ray.rllib.agents.dqn.simple_q_tf_policy import (
build_q_models, compute_q_values, get_distribution_inputs_and_class)
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \
TorchDistributionWrapper
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy_template import build_torch_policy
@ -24,8 +25,14 @@ logger = logging.getLogger(__name__)
class TargetNetworkMixin:
def __init__(self, obs_space: gym.Space, action_space: gym.Space,
config: TrainerConfigDict):
"""Assign the `update_target` method to the SimpleQTorchPolicy
The function is called every `target_network_update_freq` steps by the
master learner.
"""
def __init__(self, obs_space: gym.spaces.Space,
action_space: gym.spaces.Space, config: TrainerConfigDict):
def do_update():
# Update_target_fn will be called periodically to copy Q network to
# target Q network.
@ -36,15 +43,27 @@ class TargetNetworkMixin:
self.update_target = do_update
def build_q_model_and_distribution(policy: Policy, obs_space: gym.Space,
action_space: gym.Space,
config: TrainerConfigDict) -> ModelV2:
def build_q_model_and_distribution(
policy: Policy, obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict) -> Tuple[ModelV2, TorchDistributionWrapper]:
return build_q_models(policy, obs_space, action_space, config), \
TorchCategorical
def build_q_losses(policy: Policy, model, dist_class,
train_batch: SampleBatch) -> TensorType:
"""Constructs the loss for SimpleQTorchPolicy.
Args:
policy (Policy): The Policy to calculate the loss for.
model (ModelV2): The Model to calculate the loss for.
dist_class (Type[ActionDistribution]): The action distribution class.
train_batch (SampleBatch): The training data.
Returns:
TensorType: A single loss tensor.
"""
# q network evaluation
q_t = compute_q_values(
policy,
@ -89,13 +108,22 @@ def build_q_losses(policy: Policy, model, dist_class,
def extra_action_out_fn(policy: Policy, input_dict, state_batches, model,
action_dist) -> Dict[str, TensorType]:
"""Adds q-values to action out dict."""
"""Adds q-values to the action out dict."""
return {"q_values": policy.q_values}
def setup_late_mixins(policy: Policy, obs_space: gym.Space,
action_space: gym.Space,
def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict) -> None:
"""Call all mixin classes' constructors before SimpleQTorchPolicy
initialization.
Args:
policy (Policy): The Policy object.
obs_space (gym.spaces.Space): The Policy's observation space.
action_space (gym.spaces.Space): The Policy's action space.
config (TrainerConfigDict): The Policy's config.
"""
TargetNetworkMixin.__init__(policy, obs_space, action_space, config)

View file

@ -91,7 +91,7 @@ def appo_surrogate_loss(
Args:
policy (Policy): The Policy to calculate the loss for.
model (ModelV2): The Model to calculate the loss for.
dist_class (Type[ActionDistribution]: The action distr. class.
dist_class (Type[ActionDistribution]): The action distr. class.
train_batch (SampleBatch): The training data.
Returns: