"""Note: Keep in sync with changes to VTraceTFPolicy.""" from typing import Optional, Dict import gym import ray from ray.rllib.agents.ppo.ppo_tf_policy import ValueNetworkMixin from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.evaluation.episode import Episode from ray.rllib.evaluation.postprocessing import ( compute_gae_for_sample_batch, Postprocessing, ) from ray.rllib.models.action_dist import ActionDistribution from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.policy.policy import Policy from ray.rllib.policy.tf_policy import LearningRateSchedule, EntropyCoeffSchedule from ray.rllib.policy.tf_policy_template import build_tf_policy from ray.rllib.utils.deprecation import Deprecated from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.tf_utils import explained_variance from ray.rllib.utils.typing import ( TrainerConfigDict, TensorType, PolicyID, LocalOptimizer, ModelGradients, ) tf1, tf, tfv = try_import_tf() @Deprecated( old="rllib.agents.a3c.a3c_tf_policy.postprocess_advantages", new="rllib.evaluation.postprocessing.compute_gae_for_sample_batch", error=False, ) def postprocess_advantages( policy: Policy, sample_batch: SampleBatch, other_agent_batches: Optional[Dict[PolicyID, SampleBatch]] = None, episode: Optional[Episode] = None, ) -> SampleBatch: return compute_gae_for_sample_batch( policy, sample_batch, other_agent_batches, episode ) class A3CLoss: def __init__( self, action_dist: ActionDistribution, actions: TensorType, advantages: TensorType, v_target: TensorType, vf: TensorType, valid_mask: TensorType, vf_loss_coeff: float = 0.5, entropy_coeff: float = 0.01, use_critic: bool = True, ): log_prob = action_dist.logp(actions) # The "policy gradients" loss self.pi_loss = -tf.reduce_sum( tf.boolean_mask(log_prob * advantages, valid_mask) ) delta = tf.boolean_mask(vf - v_target, valid_mask) # Compute a value function loss. if use_critic: self.vf_loss = 0.5 * tf.reduce_sum(tf.math.square(delta)) # Ignore the value function. 
else: self.vf_loss = tf.constant(0.0) self.entropy = tf.reduce_sum(tf.boolean_mask(action_dist.entropy(), valid_mask)) self.total_loss = ( self.pi_loss + self.vf_loss * vf_loss_coeff - self.entropy * entropy_coeff ) def actor_critic_loss( policy: Policy, model: ModelV2, dist_class: ActionDistribution, train_batch: SampleBatch, ) -> TensorType: model_out, _ = model(train_batch) action_dist = dist_class(model_out, model) if policy.is_recurrent(): max_seq_len = tf.reduce_max(train_batch[SampleBatch.SEQ_LENS]) mask = tf.sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len) mask = tf.reshape(mask, [-1]) else: mask = tf.ones_like(train_batch[SampleBatch.REWARDS]) policy.loss = A3CLoss( action_dist, train_batch[SampleBatch.ACTIONS], train_batch[Postprocessing.ADVANTAGES], train_batch[Postprocessing.VALUE_TARGETS], model.value_function(), mask, policy.config["vf_loss_coeff"], policy.entropy_coeff, policy.config.get("use_critic", True), ) return policy.loss.total_loss def add_value_function_fetch(policy: Policy) -> Dict[str, TensorType]: return {SampleBatch.VF_PREDS: policy.model.value_function()} def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]: return { "cur_lr": tf.cast(policy.cur_lr, tf.float64), "entropy_coeff": tf.cast(policy.entropy_coeff, tf.float64), "policy_loss": policy.loss.pi_loss, "policy_entropy": policy.loss.entropy, "var_gnorm": tf.linalg.global_norm(list(policy.model.trainable_variables())), "vf_loss": policy.loss.vf_loss, } def grad_stats( policy: Policy, train_batch: SampleBatch, grads: ModelGradients ) -> Dict[str, TensorType]: return { "grad_gnorm": tf.linalg.global_norm(grads), "vf_explained_var": explained_variance( train_batch[Postprocessing.VALUE_TARGETS], policy.model.value_function() ), } def clip_gradients( policy: Policy, optimizer: LocalOptimizer, loss: TensorType ) -> ModelGradients: grads_and_vars = optimizer.compute_gradients( loss, policy.model.trainable_variables() ) grads = [g for (g, v) in grads_and_vars] grads, _ = tf.clip_by_global_norm(grads, policy.config["grad_clip"]) clipped_grads = list(zip(grads, policy.model.trainable_variables())) return clipped_grads def setup_mixins( policy: Policy, obs_space: gym.spaces.Space, action_space: gym.spaces.Space, config: TrainerConfigDict, ) -> None: ValueNetworkMixin.__init__(policy, obs_space, action_space, config) LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"]) EntropyCoeffSchedule.__init__( policy, config["entropy_coeff"], config["entropy_coeff_schedule"] ) A3CTFPolicy = build_tf_policy( name="A3CTFPolicy", get_default_config=lambda: ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, loss_fn=actor_critic_loss, stats_fn=stats, grad_stats_fn=grad_stats, compute_gradients_fn=clip_gradients, postprocess_fn=compute_gae_for_sample_batch, extra_action_out_fn=add_value_function_fetch, before_loss_init=setup_mixins, mixins=[ValueNetworkMixin, LearningRateSchedule, EntropyCoeffSchedule], )
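

# --- Usage sketch (illustrative only, kept as comments so the module is unchanged) ---
# A3CTFPolicy is normally consumed through RLlib's A3C trainer rather than
# instantiated directly. The snippet below is a minimal sketch assuming the
# ray.rllib.agents.a3c trainer API that this file targets (A3CTrainer, the
# "framework"/"num_workers" config keys, and the "CartPole-v1" env name are
# assumptions for illustration, not part of this module):
#
#   import ray
#   from ray.rllib.agents.a3c import A3CTrainer
#
#   ray.init()
#   trainer = A3CTrainer(
#       env="CartPole-v1",
#       config={"framework": "tf", "num_workers": 2},
#   )
#   result = trainer.train()  # one training iteration; returns a result dict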