Mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00
[RLlib] Add docstrings for agents/dqn (#10710)
parent 34bb61dabc
commit 4ccfd07a61
11 changed files with 309 additions and 116 deletions
@@ -1,22 +1,38 @@
"""
Distributed Prioritized Experience Replay (Ape-X)
=================================================

This file defines a DQN trainer using the Ape-X architecture.

Ape-X uses a single GPU learner and many CPU workers for experience collection.
Experience collection can scale to hundreds of CPU workers due to the
distributed prioritization of experience prior to storage in replay buffers.

Detailed documentation:
https://docs.ray.io/en/latest/rllib-algorithms.html#distributed-prioritized-experience-replay-ape-x
"""  # noqa: E501

import collections
import copy
from typing import Tuple

import ray
from ray.rllib.agents.dqn.dqn import DQNTrainer, \
    DEFAULT_CONFIG as DQN_CONFIG, calculate_rr_weights
from ray.rllib.agents.dqn.dqn import DEFAULT_CONFIG as DQN_CONFIG
from ray.rllib.agents.dqn.dqn import DQNTrainer, calculate_rr_weights
from ray.rllib.agents.dqn.learner_thread import LearnerThread
from ray.rllib.execution.common import STEPS_TRAINED_COUNTER, \
    _get_shared_metrics, _get_global_vars
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.concurrency_ops import Concurrently, Enqueue, Dequeue
from ray.rllib.execution.replay_ops import StoreToReplayBuffer, Replay
from ray.rllib.execution.train_ops import UpdateTargetNetwork
from ray.rllib.execution.common import (STEPS_TRAINED_COUNTER,
                                        _get_global_vars, _get_shared_metrics)
from ray.rllib.execution.concurrency_ops import Concurrently, Dequeue, Enqueue
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_buffer import ReplayActor
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.train_ops import UpdateTargetNetwork
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.actors import create_colocated
from ray.rllib.utils.typing import SampleBatchType
from ray.util.iter import LocalIterator

# yapf: disable
# __sphinx_doc_begin__

@@ -53,14 +69,15 @@ APEX_DEFAULT_CONFIG = merge_dicts(

# Update worker weights as they finish generating experiences.
class UpdateWorkerWeights:
    def __init__(self, learner_thread, workers, max_weight_sync_delay):
    def __init__(self, learner_thread: LearnerThread, workers: WorkerSet,
                 max_weight_sync_delay: int):
        self.learner_thread = learner_thread
        self.workers = workers
        self.steps_since_update = collections.defaultdict(int)
        self.max_weight_sync_delay = max_weight_sync_delay
        self.weights = None

    def __call__(self, item: ("ActorHandle", SampleBatchType)):
    def __call__(self, item: Tuple["ActorHandle", SampleBatchType]):
        actor, batch = item
        self.steps_since_update[actor] += batch.count
        if self.steps_since_update[actor] >= self.max_weight_sync_delay:

@@ -77,7 +94,8 @@ class UpdateWorkerWeights:
            metrics.counters["num_weight_syncs"] += 1


def apex_execution_plan(workers: WorkerSet, config: dict):
def apex_execution_plan(workers: WorkerSet,
                        config: dict) -> LocalIterator[dict]:
    # Create a number of replay buffer actors.
    num_replay_buffer_shards = config["optimizer"]["num_replay_buffer_shards"]
    replay_actors = create_colocated(ReplayActor, [

@@ -97,7 +115,7 @@ def apex_execution_plan(workers: WorkerSet, config: dict):
    learner_thread.start()

    # Update experience priorities post learning.
    def update_prio_and_stats(item: ("ActorHandle", dict, int)):
    def update_prio_and_stats(item: Tuple["ActorHandle", dict, int]) -> None:
        actor, prio_dict, count = item
        actor.update_priorities.remote(prio_dict)
        metrics = _get_shared_metrics()

@@ -155,7 +173,7 @@ def apex_execution_plan(workers: WorkerSet, config: dict):
        [store_op, replay_op, update_op], mode="async", output_indexes=[2])

    # Add in extra replay and learner metrics to the training result.
    def add_apex_metrics(result):
    def add_apex_metrics(result: dict) -> dict:
        replay_stats = ray.get(replay_actors[0].stats.remote(
            config["optimizer"].get("debug")))
        exploration_infos = workers.foreach_trainable_policy(
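The Ape-X docstring above describes one learner thread fed by many CPU rollout workers and several sharded replay actors. The snippet below is an illustrative usage sketch (not part of this commit) assuming the RLlib 1.x trainer API, where ApexTrainer is exported from ray.rllib.agents.dqn; names and config values are examples only.

# Hypothetical usage sketch for the Ape-X trainer documented above.
# Assumes the Ray/RLlib 1.x API (ray.rllib.agents.dqn.ApexTrainer).
import ray
from ray.rllib.agents.dqn import ApexTrainer

ray.init()

config = {
    "env": "CartPole-v0",
    "num_workers": 4,      # CPU rollout workers collecting experience.
    "num_gpus": 0,         # Set to 1 to place the learner on a GPU.
    "optimizer": {
        # One replay shard per ReplayActor created in apex_execution_plan().
        "num_replay_buffer_shards": 2,
    },
    "learning_starts": 1000,
    "target_network_update_freq": 5000,
}

trainer = ApexTrainer(config=config)
for _ in range(10):
    result = trainer.train()
    print(result["episode_reward_mean"])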
@@ -1,6 +1,12 @@
"""Tensorflow model for DQN"""

from typing import List

import gym
from ray.rllib.models.tf.layers import NoisyLayer
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import ModelConfigDict, TensorType

tf1, tf, tfv = try_import_tf()

@@ -20,23 +26,23 @@ class DistributionalQTFModel(TFModelV2):

    def __init__(
            self,
            obs_space,
            action_space,
            num_outputs,
            model_config,
            name,
            obs_space: gym.spaces.Space,
            action_space: gym.spaces.Space,
            num_outputs: int,
            model_config: ModelConfigDict,
            name: str,
            q_hiddens=(256, ),
            dueling=False,
            num_atoms=1,
            use_noisy=False,
            v_min=-10.0,
            v_max=10.0,
            sigma0=0.5,
            dueling: bool = False,
            num_atoms: int = 1,
            use_noisy: bool = False,
            v_min: float = -10.0,
            v_max: float = 10.0,
            sigma0: float = 0.5,
            # TODO(sven): Move `add_layer_norm` into ModelCatalog as
            # generic option, then error if we use ParameterNoise as
            # Exploration type and do not have any LayerNorm layers in
            # the net.
            add_layer_norm=False):
            add_layer_norm: bool = False):
        """Initialize variables of this model.

        Extra model kwargs:

@@ -60,7 +66,6 @@ class DistributionalQTFModel(TFModelV2):
        only defines the layers for the Q head. Those layers for forward()
        should be defined in subclasses of DistributionalQModel.
        """

        super(DistributionalQTFModel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name)

@@ -68,7 +73,8 @@ class DistributionalQTFModel(TFModelV2):
        self.model_out = tf.keras.layers.Input(
            shape=(num_outputs, ), name="model_out")

        def build_action_value(prefix, model_out):
        def build_action_value(prefix: str,
                               model_out: TensorType) -> List[TensorType]:
            if q_hiddens:
                action_out = model_out
                for i in range(len(q_hiddens)):

@@ -129,7 +135,8 @@ class DistributionalQTFModel(TFModelV2):
                dist = tf.expand_dims(tf.ones_like(action_scores), -1)
            return [action_scores, logits, dist]

        def build_state_score(prefix, model_out):
        def build_state_score(prefix: str,
                              model_out: TensorType) -> TensorType:
            state_out = model_out
            for i in range(len(q_hiddens)):
                if use_noisy:

@@ -163,7 +170,8 @@ class DistributionalQTFModel(TFModelV2):
            self.state_value_head = tf.keras.Model(self.model_out, state_out)
            self.register_variables(self.state_value_head.variables)

    def get_q_value_distributions(self, model_out):
    def get_q_value_distributions(self,
                                  model_out: TensorType) -> List[TensorType]:
        """Returns distributional values for Q(s, a) given a state embedding.

        Override this in your custom model to customize the Q output head.

@@ -175,10 +183,8 @@ class DistributionalQTFModel(TFModelV2):
            (action_scores, logits, dist) if num_atoms == 1, otherwise
            (action_scores, z, support_logits_per_action, logits, dist)
        """

        return self.q_value_head(model_out)

    def get_state_value(self, model_out):
    def get_state_value(self, model_out: TensorType) -> TensorType:
        """Returns the state value prediction for the given state embedding."""

        return self.state_value_head(model_out)
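Per the docstring above, DistributionalQTFModel only builds the Q head; the base network that produces `model_out` is supplied by a subclass's forward(). Below is an illustrative sketch (not part of this commit) of such a subclass; the class name, layer sizes, and registration step are assumptions, following the custom-model pattern of RLlib 1.x.

# Hypothetical subclass of DistributionalQTFModel: the parent supplies the
# Q head(s); this class only produces the `model_out` embedding that
# get_q_value_distributions() consumes. Names/sizes are illustrative.
from ray.rllib.agents.dqn.distributional_q_tf_model import \
    DistributionalQTFModel
from ray.rllib.models import ModelCatalog
from ray.rllib.utils.framework import try_import_tf

tf1, tf, tfv = try_import_tf()


class MyDistributionalQModel(DistributionalQTFModel):
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, **kw):
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name, **kw)
        # Small fully connected base network producing the embedding.
        inputs = tf.keras.layers.Input(shape=obs_space.shape)
        hidden = tf.keras.layers.Dense(64, activation="relu")(inputs)
        out = tf.keras.layers.Dense(num_outputs, activation=None)(hidden)
        self.base_model = tf.keras.Model(inputs, out)
        self.register_variables(self.base_model.variables)

    def forward(self, input_dict, state, seq_lens):
        return self.base_model(input_dict["obs"]), state


# Assumed registration step so the DQN config can refer to the model by name:
ModelCatalog.register_custom_model("my_dist_q_model", MyDistributionalQModel)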
@@ -1,8 +1,19 @@
"""
Deep Q-Networks (DQN, Rainbow, Parametric DQN)
==============================================

This file defines the distributed Trainer class for the Deep Q-Networks
algorithm. See `dqn_[tf|torch]_policy.py` for the definition of the policies.

Detailed documentation:
https://docs.ray.io/en/latest/rllib-algorithms.html#deep-q-networks-dqn-rainbow-parametric-dqn
"""  # noqa: E501

import logging
from typing import Type
from typing import List, Optional, Type

from ray.rllib.agents.dqn.dqn_tf_policy import DQNTFPolicy
from ray.rllib.agents.dqn.simple_q_tf_policy import SimpleQTFPolicy
from ray.rllib.agents.dqn.dqn_torch_policy import DQNTorchPolicy
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.evaluation.worker_set import WorkerSet

@@ -158,6 +169,16 @@ def validate_config(config: TrainerConfigDict) -> None:

def execution_plan(workers: WorkerSet,
                   config: TrainerConfigDict) -> LocalIterator[dict]:
    """Execution plan of the DQN algorithm. Defines the distributed dataflow.

    Args:
        workers (WorkerSet): The WorkerSet for training the Polic(y/ies)
            of the Trainer.
        config (TrainerConfigDict): The trainer's configuration dict.

    Returns:
        LocalIterator[dict]: A local iterator over training metrics.
    """
    if config.get("prioritized_replay"):
        prio_args = {
            "prioritized_replay_alpha": config["prioritized_replay_alpha"],

@@ -222,7 +243,8 @@ def execution_plan(workers: WorkerSet,
    return StandardMetricsReporting(train_op, workers, config)


def calculate_rr_weights(config: TrainerConfigDict):
def calculate_rr_weights(config: TrainerConfigDict) -> List[float]:
    """Calculate the round robin weights for the rollout and train steps"""
    if not config["training_intensity"]:
        return [1, 1]
    # e.g., 32 / 4 -> native ratio of 8.0

@@ -234,23 +256,22 @@ def calculate_rr_weights(config: TrainerConfigDict):
    return weights


def get_policy_class(config: TrainerConfigDict) -> Type[Policy]:
def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
    """Policy class picker function. Class is chosen based on DL-framework.

    Args:
        config (TrainerConfigDict): The trainer's configuration dict.

    Returns:
        Optional[Type[Policy]]: The Policy class to use with DQNTrainer.
            If None, use `default_policy` provided in build_trainer().
    """
    if config["framework"] == "torch":
        from ray.rllib.agents.dqn.dqn_torch_policy import DQNTorchPolicy
        return DQNTorchPolicy
    else:
        return DQNTFPolicy


def get_simple_policy_class(config: TrainerConfigDict) -> Type[Policy]:
    if config["framework"] == "torch":
        from ray.rllib.agents.dqn.simple_q_torch_policy import \
            SimpleQTorchPolicy
        return SimpleQTorchPolicy
    else:
        return SimpleQTFPolicy


# Build a generic off-policy trainer. Other trainers (such as DDPGTrainer)
# may build on top of it.
GenericOffPolicyTrainer = build_trainer(
    name="GenericOffPolicyAlgorithm",
    default_policy=None,

@@ -259,8 +280,7 @@ GenericOffPolicyTrainer = build_trainer(
    validate_config=validate_config,
    execution_plan=execution_plan)

# Build a DQN trainer, which uses the framework specific Policy
# determined in `get_policy_class()` above.
DQNTrainer = GenericOffPolicyTrainer.with_updates(
    name="DQN", default_policy=DQNTFPolicy, default_config=DEFAULT_CONFIG)

SimpleQTrainer = DQNTrainer.with_updates(
    default_policy=SimpleQTFPolicy, get_policy_class=get_simple_policy_class)
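The calculate_rr_weights docstring and its comment ("e.g., 32 / 4 -> native ratio of 8.0") describe how the rollout op and the train op are interleaved round-robin when a training intensity is configured. The standalone sketch below illustrates that arithmetic; it mirrors the idea, not necessarily the exact body of calculate_rr_weights().

# Standalone sketch of the round-robin weighting idea referenced above.
def rr_weights_sketch(train_batch_size, rollout_fragment_length,
                      training_intensity=None):
    if not training_intensity:
        # Equal weighting: one rollout step per train step.
        return [1, 1]
    # Without weighting, each rollout of `rollout_fragment_length` env steps
    # is matched by one train batch of `train_batch_size` steps.
    native_ratio = train_batch_size / rollout_fragment_length  # 32 / 4 -> 8.0
    # Scale the train op so that (steps trained / steps sampled) approaches
    # the requested training intensity.
    return [1, training_intensity / native_ratio]


print(rr_weights_sketch(32, 4))       # -> [1, 1]
print(rr_weights_sketch(32, 4, 16))   # -> [1, 2.0]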
@@ -1,3 +1,5 @@
"""Tensorflow policy class used for DQN"""

from typing import Dict

import gym

@@ -33,18 +35,18 @@ PRIO_WEIGHTS = "weights"

class QLoss:
    def __init__(self,
                 q_t_selected,
                 q_logits_t_selected,
                 q_tp1_best,
                 q_dist_tp1_best,
                 importance_weights,
                 rewards,
                 done_mask,
                 gamma=0.99,
                 n_step=1,
                 num_atoms=1,
                 v_min=-10.0,
                 v_max=10.0):
                 q_t_selected: TensorType,
                 q_logits_t_selected: TensorType,
                 q_tp1_best: TensorType,
                 q_dist_tp1_best: TensorType,
                 importance_weights: TensorType,
                 rewards: TensorType,
                 done_mask: TensorType,
                 gamma: float = 0.99,
                 n_step: int = 1,
                 num_atoms: int = 1,
                 v_min: float = -10.0,
                 v_max: float = 10.0):

        if num_atoms > 1:
            # Distributional Q-learning which corresponds to an entropy loss

@@ -110,6 +112,11 @@ class QLoss:


class ComputeTDErrorMixin:
    """Assign the `compute_td_error` method to the DQNTFPolicy

    This allows us to prioritize on the worker side.
    """

    def __init__(self):
        @make_tf_callable(self.get_session(), dynamic_shape=True)
        def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,

@@ -130,10 +137,22 @@ class ComputeTDErrorMixin:
        self.compute_td_error = compute_td_error


def build_q_model(policy: Policy, obs_space: gym.Space,
                  action_space: gym.Space,
def build_q_model(policy: Policy, obs_space: gym.spaces.Space,
                  action_space: gym.spaces.Space,
                  config: TrainerConfigDict) -> ModelV2:
    """Build q_model and target_q_model for DQN

    Args:
        policy (Policy): The Policy, which will use the model for optimization.
        obs_space (gym.spaces.Space): The policy's observation space.
        action_space (gym.spaces.Space): The policy's action space.
        config (TrainerConfigDict):

    Returns:
        ModelV2: The Model for the Policy to use.
            Note: The target q model will not be returned, just assigned to
            `policy.target_q_model`.
    """
    if not isinstance(action_space, gym.spaces.Discrete):
        raise UnsupportedSpaceException(
            "Action space {} is not supported for DQN.".format(action_space))

@@ -206,6 +225,16 @@ def get_distribution_inputs_and_class(policy: Policy,

def build_q_losses(policy: Policy, model, _,
                   train_batch: SampleBatch) -> TensorType:
    """Constructs the loss for DQNTFPolicy.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        train_batch (SampleBatch): The training data.

    Returns:
        TensorType: A single loss tensor.
    """
    config = policy.config
    # q network evaluation
    q_t, q_logits_t, q_dist_t = compute_q_values(

@@ -300,8 +329,8 @@ def setup_mid_mixins(policy: Policy, obs_space, action_space, config) -> None:
    ComputeTDErrorMixin.__init__(policy)


def setup_late_mixins(policy: Policy, obs_space: gym.Space,
                      action_space: gym.Space,
def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space,
                      action_space: gym.spaces.Space,
                      config: TrainerConfigDict) -> None:
    TargetNetworkMixin.__init__(policy, obs_space, action_space, config)
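The ComputeTDErrorMixin docstring above explains that TD errors are computed on the worker side so they can be used as replay priorities. The numpy sketch below shows the standard 1-step, non-distributional TD error that a loss like QLoss is built around; it is framework-agnostic illustration, not a quote of the policy's TensorFlow code.

# Numpy sketch of the (non-distributional) TD error underlying QLoss.
# Priorities for prioritized replay are typically |td_error| + epsilon.
import numpy as np


def td_error_sketch(q_t_selected, q_tp1_best, rewards, done_mask,
                    gamma=0.99, n_step=1):
    # Mask out the bootstrap term for terminal transitions.
    q_tp1_best_masked = (1.0 - done_mask) * q_tp1_best
    # Target: r + gamma^n * max_a' Q_target(s', a')
    q_t_target = rewards + gamma**n_step * q_tp1_best_masked
    return q_t_selected - q_t_target


td = td_error_sketch(
    q_t_selected=np.array([1.0, 0.5]),
    q_tp1_best=np.array([2.0, 0.0]),
    rewards=np.array([1.0, 1.0]),
    done_mask=np.array([0.0, 1.0]))
new_priorities = np.abs(td) + 1e-6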
@@ -1,7 +1,12 @@
"""PyTorch model for DQN"""

from typing import Sequence
import gym
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.modules.noisy_layer import NoisyLayer
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import ModelConfigDict

torch, nn = try_import_torch()

@@ -12,29 +17,29 @@ class DQNTorchModel(TorchModelV2, nn.Module):

    def __init__(
            self,
            obs_space,
            action_space,
            num_outputs,
            model_config,
            name,
            obs_space: gym.spaces.Space,
            action_space: gym.spaces.Space,
            num_outputs: int,
            model_config: ModelConfigDict,
            name: str,
            *,
            q_hiddens=(256, ),
            dueling=False,
            dueling_activation="relu",
            num_atoms=1,
            use_noisy=False,
            v_min=-10.0,
            v_max=10.0,
            sigma0=0.5,
            q_hiddens: Sequence[int] = (256, ),
            dueling: bool = False,
            dueling_activation: str = "relu",
            num_atoms: int = 1,
            use_noisy: bool = False,
            v_min: float = -10.0,
            v_max: float = 10.0,
            sigma0: float = 0.5,
            # TODO(sven): Move `add_layer_norm` into ModelCatalog as
            # generic option, then error if we use ParameterNoise as
            # Exploration type and do not have any LayerNorm layers in
            # the net.
            add_layer_norm=False):
            add_layer_norm: bool = False):
        """Initialize variables of this model.

        Extra model kwargs:
            q_hiddens (List[int]): List of layer-sizes after(!) the
            q_hiddens (Sequence[int]): List of layer-sizes after(!) the
                Advantages(A)/Value(V)-split. Hence, each of the A- and V-
                branches will have this structure of Dense layers. To define
                the NN before this A/V-split, use - as always -
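The q_hiddens docstring above refers to the advantage/value split used when dueling=True. The sketch below shows the standard dueling aggregation (Wang et al., 2016) that such a head implements; it is illustrative and not a quote of DQNTorchModel's code.

# Torch sketch of the standard dueling aggregation behind the A/V split:
# Q(s, a) = V(s) + A(s, a) - mean_a A(s, a)
import torch


def dueling_q_values(value, advantages):
    # value: [batch, 1], advantages: [batch, num_actions]
    return value + advantages - advantages.mean(dim=1, keepdim=True)


v = torch.tensor([[0.5]])
a = torch.tensor([[1.0, 2.0, 3.0]])
print(dueling_q_values(v, a))  # tensor([[-0.5000, 0.5000, 1.5000]])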
@@ -1,3 +1,5 @@
"""PyTorch policy class used for DQN"""

from typing import Dict, List, Tuple

import gym

@@ -31,13 +33,13 @@ if nn:

class QLoss:
    def __init__(self,
                 q_t_selected,
                 q_logits_t_selected,
                 q_tp1_best,
                 q_probs_tp1_best,
                 importance_weights,
                 rewards,
                 done_mask,
                 q_t_selected: TensorType,
                 q_logits_t_selected: TensorType,
                 q_tp1_best: TensorType,
                 q_probs_tp1_best: TensorType,
                 importance_weights: TensorType,
                 rewards: TensorType,
                 done_mask: TensorType,
                 gamma=0.99,
                 n_step=1,
                 num_atoms=1,

@@ -103,6 +105,11 @@ class QLoss:


class ComputeTDErrorMixin:
    """Assign the `compute_td_error` method to the DQNTorchPolicy

    This allows us to prioritize on the worker side.
    """

    def __init__(self):
        def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,
                             importance_weights):

@@ -122,9 +129,22 @@ class ComputeTDErrorMixin:


def build_q_model_and_distribution(
        policy: Policy, obs_space: gym.Space, action_space: gym.Space,
        policy: Policy, obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: TrainerConfigDict) -> Tuple[ModelV2, TorchDistributionWrapper]:
    """Build q_model and target_q_model for DQN

    Args:
        policy (Policy): The policy, which will use the model for optimization.
        obs_space (gym.spaces.Space): The policy's observation space.
        action_space (gym.spaces.Space): The policy's action space.
        config (TrainerConfigDict):

    Returns:
        (q_model, TorchCategorical)
            Note: The target q model will not be returned, just assigned to
            `policy.target_q_model`.
    """
    if not isinstance(action_space, gym.spaces.Discrete):
        raise UnsupportedSpaceException(
            "Action space {} is not supported for DQN.".format(action_space))

@@ -204,6 +224,16 @@ def get_distribution_inputs_and_class(

def build_q_losses(policy: Policy, model, _,
                   train_batch: SampleBatch) -> TensorType:
    """Constructs the loss for DQNTorchPolicy.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        train_batch (SampleBatch): The training data.

    Returns:
        TensorType: A single loss tensor.
    """
    config = policy.config
    # Q-network evaluation.
    q_t, q_logits_t, q_probs_t = compute_q_values(

@@ -286,7 +316,8 @@ def setup_early_mixins(policy: Policy, obs_space, action_space,
    LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])


def after_init(policy: Policy, obs_space: gym.Space, action_space: gym.Space,
def after_init(policy: Policy, obs_space: gym.spaces.Space,
               action_space: gym.spaces.Space,
               config: TrainerConfigDict) -> None:
    ComputeTDErrorMixin.__init__(policy)
    TargetNetworkMixin.__init__(policy, obs_space, action_space, config)
@@ -1,5 +1,5 @@
import queue
import threading
from six.moves import queue

from ray.rllib.evaluation.metrics import get_learner_stats
from ray.rllib.policy.policy import LEARNER_STATS_KEY
@@ -1,8 +1,11 @@
"""
Simple Q (simple_q)
===================
Simple Q-Learning
=================

This file defines the distributed Trainer class for the simple Q learning.
This module provides a basic implementation of the DQN algorithm without any
optimizations.

This file defines the distributed Trainer class for the Simple Q algorithm.
See `simple_q_[tf|torch]_policy.py` for the definition of the policy loss.
"""

@@ -99,7 +102,7 @@ def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
        config (TrainerConfigDict): The trainer's configuration dict.

    Returns:
        Optional[Type[Policy]]: The Policy class to use with PGTrainer.
        Optional[Type[Policy]]: The Policy class to use with SimpleQTrainer.
            If None, use `default_policy` provided in build_trainer().
    """
    if config["framework"] == "torch":

@@ -108,6 +111,16 @@ def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:

def execution_plan(workers: WorkerSet,
                   config: TrainerConfigDict) -> LocalIterator[dict]:
    """Execution plan of the Simple Q algorithm. Defines the distributed dataflow.

    Args:
        workers (WorkerSet): The WorkerSet for training the Polic(y/ies)
            of the Trainer.
        config (TrainerConfigDict): The trainer's configuration dict.

    Returns:
        LocalIterator[dict]: A local iterator over training metrics.
    """
    local_replay_buffer = LocalReplayBuffer(
        num_shards=1,
        learning_starts=config["learning_starts"],
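The Simple Q docstring above describes a basic DQN without prioritized replay, n-step returns, or a distributional head. Below is a minimal usage sketch (not part of this commit); the import path follows the dqn.py diff above, which defines SimpleQTrainer, though other Ray versions may export it from a different module.

# Minimal usage sketch for the Simple Q trainer described above.
import ray
from ray.rllib.agents.dqn.dqn import SimpleQTrainer

ray.init()
trainer = SimpleQTrainer(config={
    "env": "CartPole-v0",
    "framework": "tf",  # "torch" selects the torch policy via the
                        # framework-based policy picker shown above.
    "num_workers": 0,
})
for _ in range(5):
    print(trainer.train()["episode_reward_mean"])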
@@ -1,4 +1,4 @@
"""Basic example of a DQN policy without any optimizations."""
"""TensorFlow policy class used for Simple Q-Learning"""

import logging
from typing import List, Tuple, Type

@@ -11,6 +11,7 @@ from ray.rllib.models.tf.tf_action_dist import (Categorical,
                                                TFActionDistribution)
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
from ray.rllib.policy import Policy
from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.policy.tf_policy_template import build_tf_policy

@@ -28,8 +29,14 @@ Q_TARGET_SCOPE = "target_q_func"


class TargetNetworkMixin:
    def __init__(self, obs_space: gym.Space, action_space: gym.Space,
                 config: TrainerConfigDict):
    """Assign the `update_target` method to the SimpleQTFPolicy

    The function is called every `target_network_update_freq` steps by the
    master learner.
    """

    def __init__(self, obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space, config: TrainerConfigDict):
        @make_tf_callable(self.get_session())
        def do_update():
            # update_target_fn will be called periodically to copy Q network to

@@ -50,10 +57,24 @@ class TargetNetworkMixin:
        return self.q_func_vars + self.target_q_func_vars


def build_q_models(policy: Policy, obs_space: gym.Space,
                   action_space: gym.Space,
def build_q_models(policy: Policy, obs_space: gym.spaces.Space,
                   action_space: gym.spaces.Space,
                   config: TrainerConfigDict) -> ModelV2:
    """Build q_model and target_q_model for Simple Q learning

    Note that this function works for both Tensorflow and PyTorch.

    Args:
        policy (Policy): The Policy, which will use the model for optimization.
        obs_space (gym.spaces.Space): The policy's observation space.
        action_space (gym.spaces.Space): The policy's action space.
        config (TrainerConfigDict):

    Returns:
        ModelV2: The Model for the Policy to use.
            Note: The target q model will not be returned, just assigned to
            `policy.target_q_model`.
    """
    if not isinstance(action_space, gym.spaces.Discrete):
        raise UnsupportedSpaceException(
            "Action space {} is not supported for DQN.".format(action_space))

@@ -88,6 +109,7 @@ def get_distribution_inputs_and_class(
        explore=True,
        is_training=True,
        **kwargs) -> Tuple[TensorType, type, List[TensorType]]:
    """Build the action distribution"""
    q_vals = compute_q_values(policy, q_model, obs_batch, explore, is_training)
    q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals

@@ -100,6 +122,17 @@ def get_distribution_inputs_and_class(
def build_q_losses(policy: Policy, model: ModelV2,
                   dist_class: Type[TFActionDistribution],
                   train_batch: SampleBatch) -> TensorType:
    """Constructs the loss for SimpleQTFPolicy.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[ActionDistribution]): The action distribution class.
        train_batch (SampleBatch): The training data.

    Returns:
        TensorType: A single loss tensor.
    """
    # q network evaluation
    q_t = compute_q_values(
        policy,

@@ -156,13 +189,23 @@ def compute_q_values(policy: Policy,
    return model_out


def setup_late_mixins(policy: Policy, obs_space: gym.Space,
                      action_space: gym.Space,
def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space,
                      action_space: gym.spaces.Space,
                      config: TrainerConfigDict) -> None:
    """Call all mixin classes' constructors before SimpleQTFPolicy initialization.

    Args:
        policy (Policy): The Policy object.
        obs_space (gym.spaces.Space): The Policy's observation space.
        action_space (gym.spaces.Space): The Policy's action space.
        config (TrainerConfigDict): The Policy's config.
    """
    TargetNetworkMixin.__init__(policy, obs_space, action_space, config)


SimpleQTFPolicy = build_tf_policy(
# Build a child class of `DynamicTFPolicy`, given the custom functions defined
# above.
SimpleQTFPolicy: DynamicTFPolicy = build_tf_policy(
    name="SimpleQTFPolicy",
    get_default_config=lambda: ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG,
    make_model=build_q_models,
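The TargetNetworkMixin docstring above says update_target is invoked every `target_network_update_freq` steps to copy the online Q network into the target network. The sketch below illustrates that periodic hard sync in plain Python; the real mixin assigns TF variables via make_tf_callable, so this is an illustration of the idea, not the policy's code.

# Standalone sketch of a periodic hard target-network sync.
import numpy as np


def maybe_update_target(step, q_vars, target_q_vars,
                        target_network_update_freq=500):
    """Hard-copy the online Q weights into the target net every N steps."""
    if step > 0 and step % target_network_update_freq == 0:
        for name, value in q_vars.items():
            target_q_vars[name] = value.copy()


q_vars = {"w": np.array([1.0, 2.0])}
target_q_vars = {"w": np.zeros(2)}
maybe_update_target(500, q_vars, target_q_vars)
print(target_q_vars["w"])  # [1. 2.]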
@@ -1,14 +1,15 @@
"""Basic example of a DQN policy without any optimizations."""
"""PyTorch policy class used for Simple Q-Learning"""

import logging
from typing import Dict
from typing import Dict, Tuple

import gym
import ray
from ray.rllib.agents.dqn.simple_q_tf_policy import (
    build_q_models, compute_q_values, get_distribution_inputs_and_class)
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \
    TorchDistributionWrapper
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy_template import build_torch_policy

@@ -24,8 +25,14 @@ logger = logging.getLogger(__name__)


class TargetNetworkMixin:
    def __init__(self, obs_space: gym.Space, action_space: gym.Space,
                 config: TrainerConfigDict):
    """Assign the `update_target` method to the SimpleQTorchPolicy

    The function is called every `target_network_update_freq` steps by the
    master learner.
    """

    def __init__(self, obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space, config: TrainerConfigDict):
        def do_update():
            # Update_target_fn will be called periodically to copy Q network to
            # target Q network.

@@ -36,15 +43,27 @@ class TargetNetworkMixin:
        self.update_target = do_update


def build_q_model_and_distribution(policy: Policy, obs_space: gym.Space,
                                   action_space: gym.Space,
                                   config: TrainerConfigDict) -> ModelV2:
def build_q_model_and_distribution(
        policy: Policy, obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: TrainerConfigDict) -> Tuple[ModelV2, TorchDistributionWrapper]:
    return build_q_models(policy, obs_space, action_space, config), \
        TorchCategorical


def build_q_losses(policy: Policy, model, dist_class,
                   train_batch: SampleBatch) -> TensorType:
    """Constructs the loss for SimpleQTorchPolicy.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[ActionDistribution]): The action distribution class.
        train_batch (SampleBatch): The training data.

    Returns:
        TensorType: A single loss tensor.
    """
    # q network evaluation
    q_t = compute_q_values(
        policy,

@@ -89,13 +108,22 @@ def build_q_losses(policy: Policy, model, dist_class,

def extra_action_out_fn(policy: Policy, input_dict, state_batches, model,
                        action_dist) -> Dict[str, TensorType]:
    """Adds q-values to action out dict."""
    """Adds q-values to the action out dict."""
    return {"q_values": policy.q_values}


def setup_late_mixins(policy: Policy, obs_space: gym.Space,
                      action_space: gym.Space,
def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space,
                      action_space: gym.spaces.Space,
                      config: TrainerConfigDict) -> None:
    """Call all mixin classes' constructors before SimpleQTorchPolicy
    initialization.

    Args:
        policy (Policy): The Policy object.
        obs_space (gym.spaces.Space): The Policy's observation space.
        action_space (gym.spaces.Space): The Policy's action space.
        config (TrainerConfigDict): The Policy's config.
    """
    TargetNetworkMixin.__init__(policy, obs_space, action_space, config)
@@ -91,7 +91,7 @@ def appo_surrogate_loss(
    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[ActionDistribution]: The action distr. class.
        dist_class (Type[ActionDistribution]): The action distr. class.
        train_batch (SampleBatch): The training data.

    Returns: