ray/rllib/agents/ppo/ppo_torch_policy.py

267 lines
9.8 KiB
Python

"""
PyTorch policy class used for PPO.
"""
import gym
import logging
from typing import Dict, List, Type, Union
import ray
from ray.rllib.agents.ppo.ppo_tf_policy import setup_config
from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
Postprocessing
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy import EntropyCoeffSchedule, \
LearningRateSchedule
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_ops import apply_grad_clipping, \
explained_variance, sequence_mask
from ray.rllib.utils.typing import TensorType, TrainerConfigDict
# Resolve torch / torch.nn through RLlib's helper rather than a hard
# `import torch` (presumably so this module can be imported without torch
# installed — confirm in ray.rllib.utils.framework).
torch, nn = try_import_torch()
# Module-level logger.
logger = logging.getLogger(__name__)
def ppo_surrogate_loss(
        policy: Policy, model: ModelV2,
        dist_class: Type[TorchDistributionWrapper],
        train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
    """Builds the PPO clipped-surrogate loss for one train batch.

    The returned loss is the mean (over valid timesteps in the RNN case, over
    all rows otherwise) of::

        -surrogate + kl_coeff * KL + vf_loss_coeff * vf_loss
        - entropy_coeff * entropy

    As a side effect, summary tensors are stored on ``policy``
    (``_total_loss``, ``_mean_policy_loss``, ``_mean_vf_loss``,
    ``_vf_explained_var``, ``_mean_entropy``, ``_mean_kl``). These are read
    back by ``kl_and_loss_stats``.

    Args:
        policy (Policy): The Policy to calculate the loss for. Its config
            must provide "clip_param", "use_critic", "vf_loss_coeff" (and
            "vf_clip_param" when "use_critic" is truthy). It must also expose
            ``kl_coeff`` and ``entropy_coeff`` attributes.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[TorchDistributionWrapper]): The action
            distribution class.
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    logits, state = model(train_batch)
    curr_action_dist = dist_class(logits, model)

    if state:
        # RNN case: sequences are zero-padded at the end of the time axis,
        # so only timesteps within each sequence's real length are averaged.
        seq_lens = train_batch["seq_lens"]
        max_seq_len = logits.shape[0] // len(seq_lens)
        valid = torch.reshape(
            sequence_mask(
                seq_lens, max_seq_len, time_major=model.is_time_major()),
            [-1])
        num_valid = torch.sum(valid)

        def masked_mean(t):
            return torch.sum(t[valid]) / num_valid
    else:
        # Non-RNN case: every row is a real sample, so use a plain mean.
        masked_mean = torch.mean

    prev_action_dist = dist_class(train_batch[SampleBatch.ACTION_DIST_INPUTS],
                                  model)

    # pi_new(a|s) / pi_old(a|s), computed in log-space for stability.
    logp_ratio = torch.exp(
        curr_action_dist.logp(train_batch[SampleBatch.ACTIONS]) -
        train_batch[SampleBatch.ACTION_LOGP])

    action_kl = prev_action_dist.kl(curr_action_dist)
    mean_kl = masked_mean(action_kl)

    curr_entropy = curr_action_dist.entropy()
    mean_entropy = masked_mean(curr_entropy)

    # Clipped surrogate objective: elementwise min of the unclipped and
    # ratio-clipped advantage-weighted terms.
    advantages = train_batch[Postprocessing.ADVANTAGES]
    clip = policy.config["clip_param"]
    surrogate_loss = torch.min(
        advantages * logp_ratio,
        advantages * torch.clamp(logp_ratio, 1 - clip, 1 + clip))
    mean_policy_loss = masked_mean(-surrogate_loss)

    if not policy.config["use_critic"]:
        # No critic: the value-function term contributes nothing.
        vf_loss = mean_vf_loss = 0.0
    else:
        vf_targets = train_batch[Postprocessing.VALUE_TARGETS]
        prev_value_fn_out = train_batch[SampleBatch.VF_PREDS]
        value_fn_out = model.value_function()
        vf_clip = policy.config["vf_clip_param"]
        unclipped_sq_err = torch.pow(value_fn_out - vf_targets, 2.0)
        # Keep the new prediction within +/- vf_clip of the sampled one.
        vf_clipped = prev_value_fn_out + torch.clamp(
            value_fn_out - prev_value_fn_out, -vf_clip, vf_clip)
        clipped_sq_err = torch.pow(vf_clipped - vf_targets, 2.0)
        # Pessimistic choice: elementwise max of the two squared errors.
        vf_loss = torch.max(unclipped_sq_err, clipped_sq_err)
        mean_vf_loss = masked_mean(vf_loss)

    # Summed strictly left-to-right, exactly as the loss formula reads.
    total_loss = masked_mean(
        -surrogate_loss
        + policy.kl_coeff * action_kl
        + policy.config["vf_loss_coeff"] * vf_loss
        - policy.entropy_coeff * curr_entropy)

    # Stash summaries on the policy for `kl_and_loss_stats`.
    policy._total_loss = total_loss
    policy._mean_policy_loss = mean_policy_loss
    policy._mean_vf_loss = mean_vf_loss
    policy._vf_explained_var = explained_variance(
        train_batch[Postprocessing.VALUE_TARGETS], model.value_function())
    policy._mean_entropy = mean_entropy
    policy._mean_kl = mean_kl
    return total_loss
def kl_and_loss_stats(policy: Policy,
                      train_batch: SampleBatch) -> Dict[str, TensorType]:
    """Collects PPO's KL and loss statistics for reporting.

    Reads the coefficients and learning rate currently held by ``policy``,
    plus the summary attributes (``_total_loss``, ``_mean_kl``, ...) that
    ``ppo_surrogate_loss`` stores on it. It therefore has to run after the
    loss has been computed, otherwise those attributes do not exist yet.

    Args:
        policy (Policy): The Policy to generate stats for.
        train_batch (SampleBatch): The SampleBatch (already) used for
            training. Not read by this function.

    Returns:
        Dict[str, TensorType]: The stats dict, mapping each stat name to the
            corresponding policy attribute.
    """
    # (reported stat name, policy attribute holding it), in report order.
    stat_sources = (
        ("cur_kl_coeff", "kl_coeff"),
        ("cur_lr", "cur_lr"),
        ("total_loss", "_total_loss"),
        ("policy_loss", "_mean_policy_loss"),
        ("vf_loss", "_mean_vf_loss"),
        ("vf_explained_var", "_vf_explained_var"),
        ("kl", "_mean_kl"),
        ("entropy", "_mean_entropy"),
        ("entropy_coeff", "entropy_coeff"),
    )
    return {name: getattr(policy, attr) for name, attr in stat_sources}
def vf_preds_fetches(
        policy: Policy, input_dict: Dict[str, TensorType],
        state_batches: List[TensorType], model: ModelV2,
        action_dist: TorchDistributionWrapper) -> Dict[str, TensorType]:
    """Returns the extra outputs to record alongside each computed action.

    Only the model's value-function output is returned. Per the original
    intent, these VF estimates end up in the SampleBatches produced by the
    sampler(s), and from there in the train batches fed to the loss.

    Args:
        policy (Policy): The Policy to perform the extra action fetch on.
            Not read here.
        input_dict (Dict[str, TensorType]): The input dict used for the
            action-computing forward pass. Not read here.
        state_batches (List[TensorType]): List of state tensors (empty for
            non-RNNs). Not read here.
        model (ModelV2): The Model object of the Policy. Its
            `value_function()` supplies the returned value.
        action_dist (TorchDistributionWrapper): The instantiated distribution
            object, resulting from the model's outputs and the given
            distribution class. Not read here.

    Returns:
        Dict[str, TensorType]: Dict with the extra fetches to perform per
            action computation (here: only SampleBatch.VF_PREDS).
    """
    extra_fetches = {SampleBatch.VF_PREDS: model.value_function()}
    return extra_fetches
class KLCoeffMixin:
    """Gives the PPOPolicy an adaptive KL coefficient via `update_kl()`.

    This is used in PPO's execution plan (see ppo.py) to rescale the KL
    coefficient after each learning step. The rescaling is based on
    `config.kl_target` and the KL value measured on the train_batch.
    """

    def __init__(self, config):
        # Adaptive KL-penalty coefficient; starts at config["kl_coeff"] and
        # is rescaled in place by `update_kl()`.
        self.kl_coeff = config["kl_coeff"]
        # Fixed KL value that `update_kl()` steers the measured KL towards.
        self.kl_target = config["kl_target"]

    def update_kl(self, sampled_kl):
        """Rescales `self.kl_coeff` based on a freshly measured KL value.

        The coefficient is updated according to where `sampled_kl` falls:

        - above twice `self.kl_target`: the coefficient is multiplied by 1.5;
        - below half of `self.kl_target`: the coefficient is halved;
        - in between (boundaries included): it is left unchanged.

        Args:
            sampled_kl: The KL divergence measured after a learning step.

        Returns:
            The (possibly updated) KL coefficient.
        """
        upper_bound = 2.0 * self.kl_target
        lower_bound = 0.5 * self.kl_target
        if sampled_kl > upper_bound:
            self.kl_coeff *= 1.5
        elif sampled_kl < lower_bound:
            self.kl_coeff *= 0.5
        return self.kl_coeff
class ValueNetworkMixin:
    """Assigns the `_value()` method to the PPOPolicy.

    This way, Policy can call `_value()` to get the current VF estimate on a
    single(!) observation. The caller is presumably the postprocessing
    function, e.g. `compute_gae_for_sample_batch`; confirm there.

    Note: When doing this, an actual forward pass is being performed.
    This is different from only calling `model.value_function()`, where
    the result of the most recent forward pass is being used to return an
    already calculated tensor.
    """

    def __init__(self, obs_space, action_space, config):
        """Installs `self._value` according to `config["use_gae"]`.

        Args:
            obs_space: The policy's observation space (not used here).
            action_space: The policy's action space (not used here).
            config: The policy config. Only the "use_gae" key is read.
        """
        # When doing GAE, we need the value function estimate on the
        # observation.
        if config["use_gae"]:
            # Input dict is provided to us automatically via the Model's
            # requirements. It's a single-timestep (last one in trajectory)
            # input_dict.
            def value(**input_dict):
                input_dict = SampleBatch(input_dict)
                input_dict = self._lazy_tensor_dict(input_dict)
                # The forward pass is run only for its side effect. It
                # refreshes what `value_function()` returns below; the model
                # outputs themselves are not needed.
                self.model(input_dict)
                # [0] = remove the batch dim.
                return self.model.value_function()[0].item()
        # When not doing GAE, we do not require the value function's output.
        else:
            def value(*args, **kwargs):
                return 0.0

        self._value = value
def setup_mixins(policy: Policy, obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space,
                 config: TrainerConfigDict) -> None:
    """Runs each mixin's constructor on ``policy``.

    Registered as the `before_loss_init` hook of `PPOTorchPolicy`. The
    mixins' attributes (e.g. `_value`, `kl_coeff`) must exist before the
    loss is first built.

    Args:
        policy (Policy): The Policy object the mixins are initialized on.
        obs_space (gym.spaces.Space): The Policy's observation space.
        action_space (gym.spaces.Space): The Policy's action space.
        config (TrainerConfigDict): The Policy's config. The keys read here
            are "entropy_coeff", "entropy_coeff_schedule", "lr" and
            "lr_schedule", plus whatever the mixin constructors read.
    """
    # The original initialization order is kept unchanged.
    ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
    KLCoeffMixin.__init__(policy, config)
    EntropyCoeffSchedule.__init__(
        policy,
        config["entropy_coeff"],
        config["entropy_coeff_schedule"],
    )
    LearningRateSchedule.__init__(
        policy,
        config["lr"],
        config["lr_schedule"],
    )
# Assemble the PPO torch policy class from the hook functions and mixins
# defined in this module.
PPOTorchPolicy = build_policy_class(
    name="PPOTorchPolicy",
    framework="torch",
    get_default_config=lambda: ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG,
    # Construction-time hooks and mixins (mixin order is significant: it
    # determines the resulting class's base-class order).
    before_init=setup_config,
    before_loss_init=setup_mixins,
    mixins=[
        LearningRateSchedule, EntropyCoeffSchedule, KLCoeffMixin,
        ValueNetworkMixin
    ],
    # Sampling-side hooks.
    extra_action_out_fn=vf_preds_fetches,
    postprocess_fn=compute_gae_for_sample_batch,
    # Learning-side hooks.
    loss_fn=ppo_surrogate_loss,
    stats_fn=kl_and_loss_stats,
    extra_grad_process_fn=apply_grad_clipping,
)