ray/rllib/algorithms/pg/pg_tf_policy.py

"""
TensorFlow policy class used for PG.
"""
import logging
from typing import Dict, List, Type, Union, Optional, Tuple
import ray
from ray.rllib.evaluation.episode import Episode
from ray.rllib.policy.dynamic_tf_policy_v2 import DynamicTFPolicyV2
from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2
from ray.rllib.algorithms.pg.utils import post_process_advantages
from ray.rllib.utils.typing import AgentID
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import (
TFPolicyV2Type,
)
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import TensorType
tf1, tf, tfv = try_import_tf()
logger = logging.getLogger(__name__)


# We need this builder function because we want to share the same
# custom logic between TF1 dynamic and TF2 eager policies.
def get_pg_tf_policy(name: str, base: TFPolicyV2Type) -> TFPolicyV2Type:
    """Constructs a PGTFPolicy from either the dynamic or the eager TF base policy.

    Args:
        name: Name for the resulting policy class.
        base: Base class for this policy. DynamicTFPolicyV2 or EagerTFPolicyV2.

    Returns:
        A TF Policy to be used with PGTrainer.
    """

    class PGTFPolicy(
        base,
    ):
        def __init__(
            self,
            obs_space,
            action_space,
            config,
            existing_model=None,
            existing_inputs=None,
        ):
            # First things first: enable eager execution if necessary.
            base.enable_eager_execution_if_necessary()

            # Merge the user-provided config on top of the default PG config.
            config = dict(ray.rllib.algorithms.pg.PGConfig().to_dict(), **config)

            # Initialize base class.
            base.__init__(
                self,
                obs_space,
                action_space,
                config,
                existing_inputs=existing_inputs,
                existing_model=existing_model,
            )

            # Note: this is a bit ugly, but loss and optimizer initialization must
            # happen after all the MixIns are initialized.
            self.maybe_initialize_optimizer_and_loss()

        @override(base)
        def loss(
            self,
            model: ModelV2,
            dist_class: Type[ActionDistribution],
            train_batch: SampleBatch,
        ) -> Union[TensorType, List[TensorType]]:
            """The basic policy gradients loss function.

            Calculates the vanilla policy gradient loss based on:
            L = -E[ log(pi(a|s)) * A]

            Args:
                model: The Model to calculate the loss for.
                dist_class: The action distribution class.
                train_batch: The training data.

            Returns:
                Union[TensorType, List[TensorType]]: A single loss tensor or a list
                of loss tensors.
            """
            # Pass the training data through our model to get distribution parameters.
            dist_inputs, _ = model(train_batch)

            # Create an action distribution object.
            action_dist = dist_class(dist_inputs, model)

            # Calculate the vanilla PG loss based on:
            # L = -E[ log(pi(a|s)) * A]
            loss = -tf.reduce_mean(
                action_dist.logp(train_batch[SampleBatch.ACTIONS])
                * tf.cast(train_batch[Postprocessing.ADVANTAGES], dtype=tf.float32)
            )

            self.policy_loss = loss

            return loss

        @override(base)
        def postprocess_trajectory(
            self,
            sample_batch: SampleBatch,
            other_agent_batches: Optional[
                Dict[AgentID, Tuple["Policy", SampleBatch]]
            ] = None,
            episode: Optional["Episode"] = None,
        ) -> SampleBatch:
            sample_batch = super().postprocess_trajectory(
                sample_batch, other_agent_batches, episode
            )
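            # Fill in the ADVANTAGES column via `post_process_advantages`. Vanilla
            # PG uses no value function, so (as an assumption about the helper in
            # `ray.rllib.algorithms.pg.utils`) these advantages amount to the
            # discounted returns of the trajectory.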
            return post_process_advantages(
                self, sample_batch, other_agent_batches, episode
            )

        @override(base)
        def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
            """Returns the calculated loss in a stats dict.

            Args:
                train_batch: The data used for training.

            Returns:
                Dict[str, TensorType]: The stats dict.
            """
            return {
                "policy_loss": self.policy_loss,
            }

    PGTFPolicy.__name__ = name
    PGTFPolicy.__qualname__ = name

    return PGTFPolicy


PGTF1Policy = get_pg_tf_policy("PGTF1Policy", DynamicTFPolicyV2)
PGTF2Policy = get_pg_tf_policy("PGTF2Policy", EagerTFPolicyV2)
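

# Usage sketch (a minimal, hedged example): these policy classes are normally
# selected by the PG algorithm itself (PGTF1Policy for framework="tf",
# PGTF2Policy for framework="tf2") rather than instantiated directly.
# Assuming Ray RLlib with TensorFlow is installed and Gym provides
# "CartPole-v1", training could look like:
#
#     from ray.rllib.algorithms.pg import PGConfig
#
#     config = (
#         PGConfig()
#         .environment("CartPole-v1")
#         .framework("tf")  # or "tf2" for the eager policy
#     )
#     algo = config.build()
#     print(algo.train())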