# ray/rllib/agents/dqn/dqn_torch_model.py

"""PyTorch model for DQN"""
from typing import Sequence
import gym
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.modules.noisy_layer import NoisyLayer
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import ModelConfigDict

torch, nn = try_import_torch()


class DQNTorchModel(TorchModelV2, nn.Module):
"""Extension of standard TorchModelV2 to provide dueling-Q functionality.
"""
def __init__(
self,
obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
num_outputs: int,
model_config: ModelConfigDict,
name: str,
*,
q_hiddens: Sequence[int] = (256, ),
dueling: bool = False,
dueling_activation: str = "relu",
num_atoms: int = 1,
use_noisy: bool = False,
v_min: float = -10.0,
v_max: float = 10.0,
sigma0: float = 0.5,
# TODO(sven): Move `add_layer_norm` into ModelCatalog as
# generic option, then error if we use ParameterNoise as
# Exploration type and do not have any LayerNorm layers in
# the net.
add_layer_norm: bool = False):
"""Initialize variables of this model.
Extra model kwargs:
q_hiddens (Sequence[int]): List of layer-sizes after(!) the
Advantages(A)/Value(V)-split. Hence, each of the A- and V-
branches will have this structure of Dense layers. To define
the NN before this A/V-split, use - as always -
config["model"]["fcnet_hiddens"].
            dueling (bool): Whether to build the advantage(A)/value(V) heads
                for dueling DQN. If True, Q-values are calculated as:
                Q = (A - mean[A]) + V. If False, the raw NN output is
                interpreted directly as Q-values.
dueling_activation (str): The activation to use for all dueling
layers (A- and V-branch). One of "relu", "tanh", "linear".
num_atoms (int): If >1, enables distributional DQN.
use_noisy (bool): Use noisy layers.
v_min (float): Min value support for distributional DQN.
v_max (float): Max value support for distributional DQN.
sigma0 (float): Initial value of noisy layers.
add_layer_norm (bool): Enable layer norm (for param noise).
"""
nn.Module.__init__(self)
super(DQNTorchModel, self).__init__(obs_space, action_space,
num_outputs, model_config, name)
self.dueling = dueling
self.num_atoms = num_atoms
self.v_min = v_min
self.v_max = v_max
self.sigma0 = sigma0
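
        # The layers below form the A- (advantage) and V- (value) branches
        # that sit on top of the shared trunk built from
        # config["model"]["fcnet_hiddens"]; `num_outputs` is that trunk's
        # embedding size and thus the input size of both branches.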
ins = num_outputs
advantage_module = nn.Sequential()
value_module = nn.Sequential()
# Dueling case: Build the shared (advantages and value) fc-network.
for i, n in enumerate(q_hiddens):
if use_noisy:
advantage_module.add_module(
"dueling_A_{}".format(i),
NoisyLayer(
ins,
n,
sigma0=self.sigma0,
activation=dueling_activation))
value_module.add_module(
"dueling_V_{}".format(i),
NoisyLayer(
ins,
n,
sigma0=self.sigma0,
activation=dueling_activation))
else:
advantage_module.add_module(
"dueling_A_{}".format(i),
SlimFC(ins, n, activation_fn=dueling_activation))
value_module.add_module(
"dueling_V_{}".format(i),
SlimFC(ins, n, activation_fn=dueling_activation))
# Add LayerNorm after each Dense.
if add_layer_norm:
advantage_module.add_module("LayerNorm_A_{}".format(i),
nn.LayerNorm(n))
value_module.add_module("LayerNorm_V_{}".format(i),
nn.LayerNorm(n))
ins = n

        # Actual Advantages output layer (nodes=num_actions * num_atoms).
if use_noisy:
advantage_module.add_module(
"A",
NoisyLayer(
ins,
self.action_space.n * self.num_atoms,
sigma0,
activation=None))
elif q_hiddens:
advantage_module.add_module(
"A",
SlimFC(
ins, action_space.n * self.num_atoms, activation_fn=None))
self.advantage_module = advantage_module

        # Value output layer (nodes=num_atoms, i.e. 1 if non-distributional).
if self.dueling:
if use_noisy:
value_module.add_module(
"V",
NoisyLayer(ins, self.num_atoms, sigma0, activation=None))
elif q_hiddens:
value_module.add_module(
"V", SlimFC(ins, self.num_atoms, activation_fn=None))
            self.value_module = value_module

def get_q_value_distributions(self, model_out):
"""Returns distributional values for Q(s, a) given a state embedding.
Override this in your custom model to customize the Q output head.
Args:
model_out (Tensor): Embedding from the model layers.
Returns:
(action_scores, logits, dist) if num_atoms == 1, otherwise
(action_scores, z, support_logits_per_action, logits, dist)
"""
action_scores = self.advantage_module(model_out)
if self.num_atoms > 1:
# Distributional Q-learning uses a discrete support z
# to represent the action value distribution
            z = torch.arange(
                0.0, self.num_atoms,
                dtype=torch.float32).to(action_scores.device)
z = self.v_min + \
z * (self.v_max - self.v_min) / float(self.num_atoms - 1)
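            # E.g. with num_atoms=51, v_min=-10.0, v_max=10.0 (the C51 setup),
            # z = [-10.0, -9.6, ..., 9.6, 10.0]: 51 evenly spaced atoms.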
support_logits_per_action = torch.reshape(
action_scores, shape=(-1, self.action_space.n, self.num_atoms))
support_prob_per_action = nn.functional.softmax(
support_logits_per_action, dim=-1)
action_scores = torch.sum(z * support_prob_per_action, dim=-1)
logits = support_logits_per_action
probs = support_prob_per_action
return action_scores, z, support_logits_per_action, logits, probs
else:
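            # Non-distributional case: return plain Q-values plus dummy
            # (all-ones) placeholders in place of per-atom logits/probs.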
logits = torch.unsqueeze(torch.ones_like(action_scores), -1)
            return action_scores, logits, logits

def get_state_value(self, model_out):
"""Returns the state value prediction for the given state embedding."""
return self.value_module(model_out)
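

# Minimal usage sketch (illustrative only, not part of RLlib's API surface):
# builds the model standalone with dummy gym spaces and a hypothetical
# 256-dim trunk embedding, then queries the Q- and state-value heads.
# Space shapes, batch size and `num_outputs` are assumptions for the example.
if __name__ == "__main__":
    import numpy as np

    obs_space = gym.spaces.Box(-1.0, 1.0, shape=(4, ), dtype=np.float32)
    action_space = gym.spaces.Discrete(2)
    model = DQNTorchModel(
        obs_space,
        action_space,
        num_outputs=256,  # assumed size of the shared trunk's output
        model_config={},
        name="dqn_head_sketch",
        q_hiddens=(256, ),
        dueling=True)

    # Fake batch of 32 trunk embeddings (normally produced by the base model).
    model_out = torch.randn(32, 256)
    action_scores, logits, probs = model.get_q_value_distributions(model_out)
    state_values = model.get_state_value(model_out)
    print(action_scores.shape)  # torch.Size([32, 2])
    print(state_values.shape)   # torch.Size([32, 1])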