ray/rllib/algorithms/a3c/a3c_tf_policy.py

197 lines
6.6 KiB
Python

"""Note: Keep in sync with changes to VTraceTFPolicy."""
from typing import Dict, List, Optional, Type, Union
import ray
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.postprocessing import (
compute_gae_for_sample_batch,
Postprocessing,
)
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
from ray.rllib.policy.dynamic_tf_policy_v2 import DynamicTFPolicyV2
from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_mixins import (
compute_gradients,
EntropyCoeffSchedule,
LearningRateSchedule,
ValueNetworkMixin,
)
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.tf_utils import explained_variance
from ray.rllib.utils.typing import (
AgentID,
LocalOptimizer,
ModelGradients,
TensorType,
TFPolicyV2Type,
)
# Import TensorFlow through RLlib's helper. Only `tf` is used by the code in
# this module as far as visible here; `tf1` and `tfv` are presumably the TF1
# compat module and the TF major version (TODO confirm against
# ray.rllib.utils.framework.try_import_tf).
tf1, tf, tfv = try_import_tf()
# We need this builder function because we want to share the same
# custom logic between TF1 dynamic and TF2 eager policies.
def get_a3c_tf_policy(name: str, base: TFPolicyV2Type) -> TFPolicyV2Type:
    """Construct an A3CTFPolicy inheriting either dynamic or eager base policies.

    Args:
        name: Name to assign to the generated class (both ``__name__`` and
            ``__qualname__``).
        base: Base class for this policy. DynamicTFPolicyV2 or EagerTFPolicyV2.

    Returns:
        A TF Policy class to be used with A3C.
    """

    class A3CTFPolicy(
        ValueNetworkMixin, LearningRateSchedule, EntropyCoeffSchedule, base
    ):
        def __init__(
            self,
            obs_space,
            action_space,
            config,
            existing_model=None,
            existing_inputs=None,
        ):
            # First things first, enable eager execution if necessary.
            base.enable_eager_execution_if_necessary()

            # Start from the A3C default config and let the user-provided
            # `config` override any of its keys.
            config = dict(
                ray.rllib.algorithms.a3c.a3c.A3CConfig().to_dict(), **config
            )

            # Initialize base class.
            base.__init__(
                self,
                obs_space,
                action_space,
                config,
                existing_inputs=existing_inputs,
                existing_model=existing_model,
            )

            # Initialize the mixins. All of them read from `self.config`
            # (available after the base-class init above), so they are all
            # configured from the same source.
            ValueNetworkMixin.__init__(self, self.config)
            LearningRateSchedule.__init__(
                self, self.config["lr"], self.config["lr_schedule"]
            )
            EntropyCoeffSchedule.__init__(
                self,
                self.config["entropy_coeff"],
                self.config["entropy_coeff_schedule"],
            )

            # Note: this is a bit ugly, but loss and optimizer initialization
            # must happen after all the MixIns are initialized.
            self.maybe_initialize_optimizer_and_loss()

        @override(base)
        def loss(
            self,
            model: Union[ModelV2, "tf.keras.Model"],
            dist_class: Type[TFActionDistribution],
            train_batch: SampleBatch,
        ) -> Union[TensorType, List[TensorType]]:
            """Compute the A3C loss over the valid timesteps of `train_batch`.

            total = pi_loss + vf_loss_coeff * vf_loss - entropy_coeff * entropy

            where
              pi_loss = -sum(logp(actions) * advantages),
              vf_loss = 0.5 * sum((V - value_targets) ** 2), or 0 when
                        `use_critic` is False,
              entropy = sum(action_dist.entropy()).

            Side effects: stores `pi_loss`, `vf_loss`, `entropy_loss` and
            `total_loss` on `self`; `stats_fn` reads them.

            Returns:
                The total loss tensor.
            """
            model_out, _ = model(train_batch)
            action_dist = dist_class(model_out, model)

            if self.is_recurrent():
                # Mask out the zero-padded timesteps of each sequence.
                max_seq_len = tf.reduce_max(train_batch[SampleBatch.SEQ_LENS])
                valid_mask = tf.sequence_mask(
                    train_batch[SampleBatch.SEQ_LENS], max_seq_len
                )
                valid_mask = tf.reshape(valid_mask, [-1])
            else:
                # Every timestep is valid. `tf.boolean_mask` documents a
                # boolean mask, so build one explicitly instead of relying on
                # a float tensor of ones being treated as "nonzero".
                valid_mask = tf.ones_like(
                    train_batch[SampleBatch.REWARDS], dtype=tf.bool
                )

            log_prob = action_dist.logp(train_batch[SampleBatch.ACTIONS])
            vf = model.value_function()

            # The "policy gradients" loss.
            self.pi_loss = -tf.reduce_sum(
                tf.boolean_mask(
                    log_prob * train_batch[Postprocessing.ADVANTAGES], valid_mask
                )
            )

            delta = tf.boolean_mask(
                vf - train_batch[Postprocessing.VALUE_TARGETS], valid_mask
            )

            if self.config.get("use_critic", True):
                # Compute a value function loss.
                self.vf_loss = 0.5 * tf.reduce_sum(tf.math.square(delta))
            else:
                # Ignore the value function.
                self.vf_loss = tf.constant(0.0)

            self.entropy_loss = tf.reduce_sum(
                tf.boolean_mask(action_dist.entropy(), valid_mask)
            )

            self.total_loss = (
                self.pi_loss
                + self.vf_loss * self.config["vf_loss_coeff"]
                - self.entropy_loss * self.entropy_coeff
            )
            return self.total_loss

        @override(base)
        def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
            """Return learner stats.

            Reads the loss terms stored on `self` by `loss`, so `loss` must
            have run first.
            """
            return {
                "cur_lr": tf.cast(self.cur_lr, tf.float64),
                "entropy_coeff": tf.cast(self.entropy_coeff, tf.float64),
                "policy_loss": self.pi_loss,
                "policy_entropy": self.entropy_loss,
                "var_gnorm": tf.linalg.global_norm(
                    list(self.model.trainable_variables())
                ),
                "vf_loss": self.vf_loss,
            }

        @override(base)
        def grad_stats_fn(
            self, train_batch: SampleBatch, grads: ModelGradients
        ) -> Dict[str, TensorType]:
            """Return the global gradient norm and the explained variance of
            the value function with respect to the value targets."""
            return {
                "grad_gnorm": tf.linalg.global_norm(grads),
                "vf_explained_var": explained_variance(
                    train_batch[Postprocessing.VALUE_TARGETS],
                    self.model.value_function(),
                ),
            }

        @override(base)
        def postprocess_trajectory(
            self,
            sample_batch: SampleBatch,
            other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
            episode: Optional[Episode] = None,
        ):
            """Run base-class postprocessing, then compute GAE.

            The GAE step uses `compute_gae_for_sample_batch`, which presumably
            adds the ADVANTAGES / VALUE_TARGETS columns that `loss` reads
            (confirm in ray.rllib.evaluation.postprocessing).
            """
            # Forward every argument so cooperative base implementations see
            # the same inputs this override received.
            sample_batch = super().postprocess_trajectory(
                sample_batch, other_agent_batches, episode
            )
            return compute_gae_for_sample_batch(
                self, sample_batch, other_agent_batches, episode
            )

        @override(base)
        def compute_gradients_fn(
            self, optimizer: LocalOptimizer, loss: TensorType
        ) -> ModelGradients:
            """Delegate gradient computation to the shared TF mixin helper."""
            return compute_gradients(self, optimizer, loss)

    A3CTFPolicy.__name__ = name
    A3CTFPolicy.__qualname__ = name

    return A3CTFPolicy
# Concrete policy classes built from the shared builder: the TF1 dynamic
# variant first, then the TF2 eager variant.
A3CTF1Policy, A3CTF2Policy = (
    get_a3c_tf_policy("A3CTF1Policy", DynamicTFPolicyV2),
    get_a3c_tf_policy("A3CTF2Policy", EagerTFPolicyV2),
)
@Deprecated(
    old="rllib.algorithms.a3c.a3c_tf_policy.postprocess_advantages",
    new="rllib.evaluation.postprocessing.compute_gae_for_sample_batch",
    error=True,
)
def postprocess_advantages(*args, **kwargs):
    """Removed API; use `compute_gae_for_sample_batch` instead.

    This stub does no work of its own. It exists only so that the
    `Deprecated(..., error=True)` wrapper can intercept calls; that wrapper
    presumably raises on use (confirm in ray.rllib.utils.deprecation).
    """