"""Note: Keep in sync with changes to VTraceTFPolicy.""" from typing import Dict, List, Optional, Type, Union import ray from ray.rllib.evaluation.episode import Episode from ray.rllib.evaluation.postprocessing import ( compute_gae_for_sample_batch, Postprocessing, ) from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.tf.tf_action_dist import TFActionDistribution from ray.rllib.policy.dynamic_tf_policy_v2 import DynamicTFPolicyV2 from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2 from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.tf_mixins import ( compute_gradients, EntropyCoeffSchedule, LearningRateSchedule, ValueNetworkMixin, ) from ray.rllib.utils.annotations import override from ray.rllib.utils.deprecation import Deprecated from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.tf_utils import explained_variance from ray.rllib.utils.typing import ( AgentID, LocalOptimizer, ModelGradients, TensorType, TFPolicyV2Type, ) tf1, tf, tfv = try_import_tf() # We need this builder function because we want to share the same # custom logics between TF1 dynamic and TF2 eager policies. def get_a3c_tf_policy(name: str, base: TFPolicyV2Type) -> TFPolicyV2Type: """Construct a A3CTFPolicy inheriting either dynamic or eager base policies. Args: base: Base class for this policy. DynamicTFPolicyV2 or EagerTFPolicyV2. Returns: A TF Policy to be used with MAML. """ class A3CTFPolicy( ValueNetworkMixin, LearningRateSchedule, EntropyCoeffSchedule, base ): def __init__( self, obs_space, action_space, config, existing_model=None, existing_inputs=None, ): # First thing first, enable eager execution if necessary. base.enable_eager_execution_if_necessary() config = dict(ray.rllib.algorithms.a3c.a3c.A3CConfig().to_dict(), **config) # Initialize base class. base.__init__( self, obs_space, action_space, config, existing_inputs=existing_inputs, existing_model=existing_model, ) ValueNetworkMixin.__init__(self, self.config) LearningRateSchedule.__init__( self, self.config["lr"], self.config["lr_schedule"] ) EntropyCoeffSchedule.__init__( self, config["entropy_coeff"], config["entropy_coeff_schedule"] ) # Note: this is a bit ugly, but loss and optimizer initialization must # happen after all the MixIns are initialized. self.maybe_initialize_optimizer_and_loss() @override(base) def loss( self, model: Union[ModelV2, "tf.keras.Model"], dist_class: Type[TFActionDistribution], train_batch: SampleBatch, ) -> Union[TensorType, List[TensorType]]: model_out, _ = model(train_batch) action_dist = dist_class(model_out, model) if self.is_recurrent(): max_seq_len = tf.reduce_max(train_batch[SampleBatch.SEQ_LENS]) valid_mask = tf.sequence_mask( train_batch[SampleBatch.SEQ_LENS], max_seq_len ) valid_mask = tf.reshape(valid_mask, [-1]) else: valid_mask = tf.ones_like(train_batch[SampleBatch.REWARDS]) log_prob = action_dist.logp(train_batch[SampleBatch.ACTIONS]) vf = model.value_function() # The "policy gradients" loss self.pi_loss = -tf.reduce_sum( tf.boolean_mask( log_prob * train_batch[Postprocessing.ADVANTAGES], valid_mask ) ) delta = tf.boolean_mask( vf - train_batch[Postprocessing.VALUE_TARGETS], valid_mask ) # Compute a value function loss. if self.config.get("use_critic", True): self.vf_loss = 0.5 * tf.reduce_sum(tf.math.square(delta)) # Ignore the value function. 
            else:
                self.vf_loss = tf.constant(0.0)

            self.entropy_loss = tf.reduce_sum(
                tf.boolean_mask(action_dist.entropy(), valid_mask)
            )

            self.total_loss = (
                self.pi_loss
                + self.vf_loss * self.config["vf_loss_coeff"]
                - self.entropy_loss * self.entropy_coeff
            )

            return self.total_loss

        @override(base)
        def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
            return {
                "cur_lr": tf.cast(self.cur_lr, tf.float64),
                "entropy_coeff": tf.cast(self.entropy_coeff, tf.float64),
                "policy_loss": self.pi_loss,
                "policy_entropy": self.entropy_loss,
                "var_gnorm": tf.linalg.global_norm(
                    list(self.model.trainable_variables())
                ),
                "vf_loss": self.vf_loss,
            }

        @override(base)
        def grad_stats_fn(
            self, train_batch: SampleBatch, grads: ModelGradients
        ) -> Dict[str, TensorType]:
            return {
                "grad_gnorm": tf.linalg.global_norm(grads),
                "vf_explained_var": explained_variance(
                    train_batch[Postprocessing.VALUE_TARGETS],
                    self.model.value_function(),
                ),
            }

        @override(base)
        def postprocess_trajectory(
            self,
            sample_batch: SampleBatch,
            other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
            episode: Optional[Episode] = None,
        ):
            sample_batch = super().postprocess_trajectory(sample_batch)
            return compute_gae_for_sample_batch(
                self, sample_batch, other_agent_batches, episode
            )

        @override(base)
        def compute_gradients_fn(
            self, optimizer: LocalOptimizer, loss: TensorType
        ) -> ModelGradients:
            return compute_gradients(self, optimizer, loss)

    A3CTFPolicy.__name__ = name
    A3CTFPolicy.__qualname__ = name

    return A3CTFPolicy


A3CTF1Policy = get_a3c_tf_policy("A3CTF1Policy", DynamicTFPolicyV2)
A3CTF2Policy = get_a3c_tf_policy("A3CTF2Policy", EagerTFPolicyV2)


@Deprecated(
    old="rllib.algorithms.a3c.a3c_tf_policy.postprocess_advantages",
    new="rllib.evaluation.postprocessing.compute_gae_for_sample_batch",
    error=True,
)
def postprocess_advantages(*args, **kwargs):
    pass
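

# Usage sketch: the two policy classes above are normally selected by the A3C
# algorithm itself based on the configured framework, rather than instantiated
# directly. The lines below are a minimal sketch assuming an RLlib version that
# still ships the A3C algorithm and a registered "CartPole-v1" Gym environment.
#
#   from ray.rllib.algorithms.a3c import A3CConfig
#
#   algo = A3CConfig().environment("CartPole-v1").framework("tf2").build()
#   policy = algo.get_policy()  # -> an A3CTF2Policy instance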