ray/rllib/algorithms/pg/pg_torch_policy.py

"""
PyTorch policy class used for PG.
"""
from typing import Dict, List, Type, Union

import ray
from ray.rllib.algorithms.pg.utils import post_process_advantages
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.policy import Policy
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import TensorType

# try_import_torch() returns (torch, nn); nn is unused here.
torch, _ = try_import_torch()


def pg_torch_loss(
    policy: Policy,
    model: ModelV2,
    dist_class: Type[TorchDistributionWrapper],
    train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
"""The basic policy gradients loss function.
Args:
policy (Policy): The Policy to calculate the loss for.
model (ModelV2): The Model to calculate the loss for.
dist_class (Type[ActionDistribution]: The action distr. class.
train_batch (SampleBatch): The training data.
Returns:
Union[TensorType, List[TensorType]]: A single loss tensor or a list
of loss tensors.
"""
    # Pass the training data through our model to get distribution parameters.
    dist_inputs, _ = model(train_batch)

    # Create an action distribution object.
    action_dist = dist_class(dist_inputs, model)

    # Calculate the vanilla PG loss based on:
    # L = -E[ log(pi(a|s)) * A ]
    log_probs = action_dist.logp(train_batch[SampleBatch.ACTIONS])

    # Final policy loss.
    policy_loss = -torch.mean(log_probs * train_batch[Postprocessing.ADVANTAGES])

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["policy_loss"] = policy_loss

    return policy_loss
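
# The loss above is the classic REINFORCE estimator: scale each action's
# log-likelihood by its (post-processed) advantage and average. As a minimal,
# self-contained sketch (illustrative only, not RLlib API), the same
# computation on toy tensors for a 2-action discrete space:
#
#     toy_logits = torch.tensor([[2.0, 0.0], [0.0, 1.0]])
#     toy_actions = torch.tensor([0, 1])
#     toy_advantages = torch.tensor([1.5, -0.5])
#     toy_logp = torch.distributions.Categorical(
#         logits=toy_logits
#     ).log_prob(toy_actions)
#     toy_loss = -torch.mean(toy_logp * toy_advantages)
#
# Minimizing `toy_loss` raises the log-likelihood of actions with positive
# advantages and lowers it for those with negative advantages.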


def pg_loss_stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
    """Returns the calculated loss in a stats dict.

    Args:
        policy (Policy): The Policy object.
        train_batch (SampleBatch): The data used for training.

    Returns:
        Dict[str, TensorType]: The stats dict.
    """
    return {
        "policy_loss": torch.mean(torch.stack(policy.get_tower_stats("policy_loss"))),
    }
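
# With multi-GPU training, each model "tower" computes the loss on its own
# data shard and stores it in `tower_stats`; `get_tower_stats("policy_loss")`
# gathers one tensor per tower, and the stack+mean above reduces them to a
# single scalar for reporting.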


# Build a child class of `TorchPolicy`, given the extra options:
# - trajectory post-processing function (to calculate advantages)
# - PG loss function
PGTorchPolicy = build_policy_class(
    name="PGTorchPolicy",
    framework="torch",
    get_default_config=lambda: ray.rllib.algorithms.pg.DEFAULT_CONFIG,
    loss_fn=pg_torch_loss,
    stats_fn=pg_loss_stats,
    postprocess_fn=post_process_advantages,
)
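

# A hedged usage sketch (not part of the original module): exercising this
# policy through the high-level PG algorithm. `PGConfig` and the
# "CartPole-v1" environment id are assumptions here; PG shipped with RLlib
# in Ray 2.x but was moved out of the core library in later releases.
if __name__ == "__main__":
    from ray.rllib.algorithms.pg import PGConfig

    pg_config = PGConfig().environment("CartPole-v1").framework("torch")
    algo = pg_config.build()
    # Run a single training iteration and print the collected metrics.
    print(algo.train())
    algo.stop()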