ray/rllib/algorithms/pg/pg_torch_policy.py

"""
PyTorch policy class used for PG.
"""
from typing import Dict, List, Type, Union

import ray
from ray.rllib.algorithms.pg.utils import post_process_advantages
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.policy import Policy
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import TensorType

torch, _ = try_import_torch()


def pg_torch_loss(
    policy: Policy,
    model: ModelV2,
    dist_class: Type[TorchDistributionWrapper],
    train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
"""The basic policy gradients loss function.
Args:
policy (Policy): The Policy to calculate the loss for.
model (ModelV2): The Model to calculate the loss for.
dist_class (Type[ActionDistribution]: The action distr. class.
train_batch (SampleBatch): The training data.
Returns:
Union[TensorType, List[TensorType]]: A single loss tensor or a list
of loss tensors.
"""
    # Pass the training data through our model to get distribution parameters.
    dist_inputs, _ = model(train_batch)

    # Create an action distribution object.
    action_dist = dist_class(dist_inputs, model)

    # Calculate the vanilla PG loss based on:
    # L = -E[ log(pi(a|s)) * A ]
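    #
    # For intuition, a hypothetical worked instance of this formula on toy
    # tensors (values invented for illustration only, assuming a categorical
    # policy over two actions) would look like:
    #
    #   logits = torch.tensor([[0.5, -0.5], [0.1, 0.2]])
    #   actions = torch.tensor([0, 1])
    #   advantages = torch.tensor([1.0, -0.5])
    #   log_probs = torch.distributions.Categorical(logits=logits).log_prob(actions)
    #   loss = -torch.mean(log_probs * advantages)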
    log_probs = action_dist.logp(train_batch[SampleBatch.ACTIONS])

    # Final policy loss.
    policy_loss = -torch.mean(log_probs * train_batch[Postprocessing.ADVANTAGES])

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["policy_loss"] = policy_loss

    return policy_loss


def pg_loss_stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
    """Returns the calculated loss in a stats dict.

    Args:
        policy (Policy): The Policy object.
        train_batch (SampleBatch): The data used for training.

    Returns:
        Dict[str, TensorType]: The stats dict.
    """
    return {
        "policy_loss": torch.mean(torch.stack(policy.get_tower_stats("policy_loss"))),
    }


# Build a child class of `TorchPolicy`, given the extra options:
# - trajectory post-processing function (to calculate advantages)
# - PG loss function
PGTorchPolicy = build_policy_class(
    name="PGTorchPolicy",
    framework="torch",
    get_default_config=lambda: ray.rllib.algorithms.pg.DEFAULT_CONFIG,
    loss_fn=pg_torch_loss,
    stats_fn=pg_loss_stats,
    postprocess_fn=post_process_advantages,
)
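

# Hedged usage sketch (an assumption, not part of the original module): with
# the old policy stack in Ray 2.x, this torch policy is picked up automatically
# when the PG algorithm is configured with the torch framework, e.g.:
if __name__ == "__main__":
    from ray.rllib.algorithms.pg import PGConfig

    # Build a PG algorithm on CartPole and run one training iteration.
    algo = PGConfig().environment("CartPole-v1").framework("torch").build()
    print(algo.train()["episode_reward_mean"])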