# ray/rllib/agents/dqn/r2d2_torch_policy.py
"""PyTorch policy class used for R2D2."""

from typing import Dict, Tuple

import gym
import ray
from ray.rllib.agents.dqn.dqn_tf_policy import (PRIO_WEIGHTS,
                                                postprocess_nstep_and_prio)
from ray.rllib.agents.dqn.dqn_torch_policy import adam_optimizer, \
    build_q_model_and_distribution, compute_q_values
from ray.rllib.agents.dqn.r2d2_tf_policy import \
    get_distribution_inputs_and_class
from ray.rllib.agents.dqn.simple_q_torch_policy import TargetNetworkMixin
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_action_dist import \
    TorchDistributionWrapper
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy import LearningRateSchedule
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_ops import apply_grad_clipping, FLOAT_MIN, \
    huber_loss, sequence_mask
from ray.rllib.utils.typing import TensorType, TrainerConfigDict

torch, nn = try_import_torch()
F = None
if nn:
    F = nn.functional


def build_r2d2_model_and_distribution(
        policy: Policy, obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: TrainerConfigDict) -> \
        Tuple[ModelV2, TorchDistributionWrapper]:
    """Build q_model and target_model for R2D2.

    Args:
        policy (Policy): The policy, which will use the model for optimization.
        obs_space (gym.spaces.Space): The policy's observation space.
        action_space (gym.spaces.Space): The policy's action space.
        config (TrainerConfigDict): The Trainer's config dict.

    Returns:
        (q_model, TorchCategorical)

    Note: The target q model will not be returned, just assigned to
        `policy.target_model`.
    """
    # Create the policy's models and action dist class.
    model, distribution_cls = build_q_model_and_distribution(
        policy, obs_space, action_space, config)

    # Assert correct model type by checking the init state to be present.
    # For attention nets: These don't necessarily publish their init state via
    # Model.get_initial_state, but may only use the trajectory view API
    # (view_requirements).
    assert (model.get_initial_state() != [] or
            model.view_requirements.get("state_in_0") is not None), \
        "R2D2 requires its model to be a recurrent one! Try using " \
        "`model.use_lstm` or `model.use_attention` in your config " \
        "to auto-wrap your model with an LSTM- or attention net."

    return model, distribution_cls
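

# A minimal sketch (not part of the original module) of config settings that
# satisfy the recurrence check above, using the standard RLlib model config
# keys mentioned in the assertion message:
#
#     config["model"]["use_lstm"] = True       # or: ["use_attention"] = True
#     config["model"]["max_seq_len"] = 20      # length of the stored sequences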


def r2d2_loss(policy: Policy, model, _,
              train_batch: SampleBatch) -> TensorType:
    """Constructs the loss for R2D2TorchPolicy.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        train_batch (SampleBatch): The training data.

    Returns:
        TensorType: A single loss tensor.
    """
    target_model = policy.target_models[model]
    config = policy.config

    # Construct internal state inputs.
    i = 0
    state_batches = []
    while "state_in_{}".format(i) in train_batch:
        state_batches.append(train_batch["state_in_{}".format(i)])
        i += 1
    assert state_batches

    # Q-network evaluation (at t).
    q, _, _, _ = compute_q_values(
        policy,
        model,
        train_batch,
        state_batches=state_batches,
        seq_lens=train_batch.get(SampleBatch.SEQ_LENS),
        explore=False,
        is_training=True)

    # Target Q-network evaluation (at t+1).
    q_target, _, _, _ = compute_q_values(
        policy,
        target_model,
        train_batch,
        state_batches=state_batches,
        seq_lens=train_batch.get(SampleBatch.SEQ_LENS),
        explore=False,
        is_training=True)

    actions = train_batch[SampleBatch.ACTIONS].long()
    dones = train_batch[SampleBatch.DONES].float()
    rewards = train_batch[SampleBatch.REWARDS]
    weights = train_batch[PRIO_WEIGHTS]

    B = state_batches[0].shape[0]
    T = q.shape[0] // B
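    # Note: The batch arrives time-flattened with shape [B*T, ...]; B is the
    # number of sampled sequences, T their (padded) maximum length.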

    # Q scores for actions which we know were selected in the given state.
    one_hot_selection = F.one_hot(actions, policy.action_space.n)
    q_selected = torch.sum(
        torch.where(q > FLOAT_MIN, q, torch.tensor(0.0, device=q.device)) *
        one_hot_selection, 1)
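
    # Double-Q: pick the argmax action with the online net, but evaluate it
    # with the target net below; otherwise argmax directly over the target net.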
if config["double_q"]:
best_actions = torch.argmax(q, dim=1)
else:
best_actions = torch.argmax(q_target, dim=1)
best_actions_one_hot = F.one_hot(best_actions, policy.action_space.n)
q_target_best = torch.sum(
torch.where(q_target > FLOAT_MIN, q_target,
torch.tensor(0.0, device=q_target.device)) *
best_actions_one_hot,
dim=1)
if config["num_atoms"] > 1:
raise ValueError("Distributional R2D2 not supported yet!")
else:
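        # Shift the target Q-values by one timestep (q_target_best[1:]) so
        # that index t holds Q^(t+1); the last step gets a 0.0 placeholder and
        # `dones` zeroes out terminal transitions.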
        q_target_best_masked_tp1 = (1.0 - dones) * torch.cat([
            q_target_best[1:],
            torch.tensor([0.0], device=q_target_best.device)
        ])
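
        # n-step target: with the h-function enabled (see `h_function` /
        # `h_inverse` below), target_t = h(r_t + gamma^n_step * h^-1(Q^(t+1)));
        # otherwise the plain target_t = r_t + gamma^n_step * Q^(t+1).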
if config["use_h_function"]:
h_inv = h_inverse(q_target_best_masked_tp1,
config["h_function_epsilon"])
target = h_function(
rewards + config["gamma"]**config["n_step"] * h_inv,
config["h_function_epsilon"])
else:
target = rewards + \
config["gamma"] ** config["n_step"] * q_target_best_masked_tp1

        # Seq-mask all loss-related terms.
        seq_mask = sequence_mask(train_batch[SampleBatch.SEQ_LENS], T)[:, :-1]
        # Also mask out the burn-in period at the beginning of each sequence.
        burn_in = policy.config["burn_in"]
        if 0 < burn_in < T:
            seq_mask[:, :burn_in] = False

        num_valid = torch.sum(seq_mask)

        def reduce_mean_valid(t):
            return torch.sum(t[seq_mask]) / num_valid

        # Make sure to use the correct time indices:
        # Q(t) - [gamma * r + Q^(t+1)]
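        # The last column is dropped ([:, :-1]): its bootstrap value Q^(t+1)
        # was zero-padded above and lies outside the sampled sequence.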
        q_selected = q_selected.reshape([B, T])[:, :-1]
        td_error = q_selected - target.reshape([B, T])[:, :-1].detach()
        td_error = td_error * seq_mask
        weights = weights.reshape([B, T])[:, :-1]

        policy._total_loss = reduce_mean_valid(weights * huber_loss(td_error))
        policy._td_error = td_error.reshape([-1])
        policy._loss_stats = {
            "mean_q": reduce_mean_valid(q_selected),
            "min_q": torch.min(q_selected),
            "max_q": torch.max(q_selected),
            "mean_td_error": reduce_mean_valid(td_error),
        }

    return policy._total_loss


def h_function(x, epsilon=1.0):
    """h-function to normalize target Qs, described in the paper [1].

    h(x) = sign(x) * [sqrt(abs(x) + 1) - 1] + epsilon * x

    Used in [1] in combination with h_inverse:
        targets = h(r + gamma * h_inverse(Q^))
    """
    return torch.sign(x) * (torch.sqrt(torch.abs(x) + 1.0) - 1.0) + epsilon * x


def h_inverse(x, epsilon=1.0):
    """Inverse of the above h-function, described in the paper [1].

    If x > 0.0:
        h-1(x) = [2eps * x + (2eps + 1) - sqrt(4eps x + (2eps + 1)^2)] /
                 (2 * eps^2)

    If x < 0.0:
        h-1(x) = [2eps * x - (2eps + 1) + sqrt(-4eps x + (2eps + 1)^2)] /
                 (2 * eps^2)
    """
    two_epsilon = epsilon * 2
    if_x_pos = (two_epsilon * x + (two_epsilon + 1.0) -
                torch.sqrt(4.0 * epsilon * x +
                           (two_epsilon + 1.0)**2)) / (2.0 * epsilon**2)
    if_x_neg = (two_epsilon * x - (two_epsilon + 1.0) +
                torch.sqrt(-4.0 * epsilon * x +
                           (two_epsilon + 1.0)**2)) / (2.0 * epsilon**2)
    return torch.where(x < 0.0, if_x_neg, if_x_pos)
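

# A quick sanity check of the pair above (illustrative only; not executed as
# part of this module): with epsilon = 1.0,
#
#     x = torch.tensor([-3.0, 0.0, 3.0])
#     h_inverse(h_function(x))  # -> tensor([-3., 0., 3.]) (up to float error)
#
# i.e. h_inverse recovers the raw returns from the h-rescaled targets.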


class ComputeTDErrorMixin:
    """Assign the `compute_td_error` method to the R2D2TorchPolicy.

    This allows us to prioritize on the worker side.
    """

    def __init__(self):
        def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,
                             importance_weights):
            input_dict = self._lazy_tensor_dict({SampleBatch.CUR_OBS: obs_t})
            input_dict[SampleBatch.ACTIONS] = act_t
            input_dict[SampleBatch.REWARDS] = rew_t
            input_dict[SampleBatch.NEXT_OBS] = obs_tp1
            input_dict[SampleBatch.DONES] = done_mask
            input_dict[PRIO_WEIGHTS] = importance_weights

            # Do a forward pass on the loss to update the TD-error attribute.
            r2d2_loss(self, self.model, None, input_dict)

            return self._td_error

        self.compute_td_error = compute_td_error


def build_q_stats(policy: Policy, batch) -> Dict[str, TensorType]:
    return dict({
        "cur_lr": policy.cur_lr,
    }, **policy._loss_stats)


def setup_early_mixins(policy: Policy, obs_space, action_space,
                       config: TrainerConfigDict) -> None:
    LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])


def before_loss_init(policy: Policy, obs_space: gym.spaces.Space,
                     action_space: gym.spaces.Space,
                     config: TrainerConfigDict) -> None:
    ComputeTDErrorMixin.__init__(policy)
    TargetNetworkMixin.__init__(policy)


def grad_process_and_td_error_fn(policy: Policy,
                                 optimizer: "torch.optim.Optimizer",
                                 loss: TensorType) -> Dict[str, TensorType]:
    # Clip grads if configured.
    return apply_grad_clipping(policy, optimizer, loss)


def extra_action_out_fn(policy: Policy, input_dict, state_batches, model,
                        action_dist) -> Dict[str, TensorType]:
    return {"q_values": policy.q_values}


R2D2TorchPolicy = build_policy_class(
    name="R2D2TorchPolicy",
    framework="torch",
    loss_fn=r2d2_loss,
    get_default_config=lambda: ray.rllib.agents.dqn.r2d2.DEFAULT_CONFIG,
    make_model_and_action_dist=build_r2d2_model_and_distribution,
    action_distribution_fn=get_distribution_inputs_and_class,
    stats_fn=build_q_stats,
    postprocess_fn=postprocess_nstep_and_prio,
    optimizer_fn=adam_optimizer,
    extra_grad_process_fn=grad_process_and_td_error_fn,
    extra_learn_fetches_fn=lambda policy: {"td_error": policy._td_error},
    extra_action_out_fn=extra_action_out_fn,
    before_init=setup_early_mixins,
    before_loss_init=before_loss_init,
    mixins=[
        TargetNetworkMixin,
        ComputeTDErrorMixin,
        LearningRateSchedule,
    ])
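

# Illustrative usage sketch (not part of this module): the policy built above
# is normally consumed via the R2D2 trainer in ray.rllib.agents.dqn.r2d2,
# whose DEFAULT_CONFIG is referenced in `get_default_config` above, e.g.:
#
#     from ray.rllib.agents.dqn.r2d2 import DEFAULT_CONFIG, R2D2Trainer
#
#     config = DEFAULT_CONFIG.copy()
#     config["framework"] = "torch"
#     config["model"]["use_lstm"] = True
#     trainer = R2D2Trainer(config=config, env="CartPole-v0")
#     trainer.train()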