""" TensorFlow policy class used for PPO. """ import gym import logging from typing import Dict, List, Optional, Type, Union import ray from ray.rllib.evaluation.episode import MultiAgentEpisode from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \ Postprocessing from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.tf.tf_action_dist import TFActionDistribution from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.tf_policy import LearningRateSchedule, \ EntropyCoeffSchedule from ray.rllib.policy.tf_policy_template import build_tf_policy from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning from ray.rllib.utils.framework import try_import_tf, get_variable from ray.rllib.utils.tf_ops import explained_variance, make_tf_callable from ray.rllib.utils.typing import AgentID, LocalOptimizer, ModelGradients, \ TensorType, TrainerConfigDict tf1, tf, tfv = try_import_tf() logger = logging.getLogger(__name__) def ppo_surrogate_loss( policy: Policy, model: Union[ModelV2, "tf.keras.Model"], dist_class: Type[TFActionDistribution], train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]: """Constructs the loss for Proximal Policy Objective. Args: policy (Policy): The Policy to calculate the loss for. model (Union[ModelV2, tf.keras.Model]): The Model to calculate the loss for. dist_class (Type[ActionDistribution]: The action distr. class. train_batch (SampleBatch): The training data. Returns: Union[TensorType, List[TensorType]]: A single loss tensor or a list of loss tensors. """ if isinstance(model, tf.keras.Model): logits, state, extra_outs = model(train_batch) value_fn_out = extra_outs[SampleBatch.VF_PREDS] else: logits, state = model.from_batch(train_batch) value_fn_out = model.value_function() curr_action_dist = dist_class(logits, model) # RNN case: Mask away 0-padded chunks at end of time axis. if state: # Derive max_seq_len from the data itself, not from the seq_lens # tensor. This is in case e.g. seq_lens=[2, 3], but the data is still # 0-padded up to T=5 (as it's the case for attention nets). B = tf.shape(train_batch["seq_lens"])[0] max_seq_len = tf.shape(logits)[0] // B mask = tf.sequence_mask(train_batch["seq_lens"], max_seq_len) mask = tf.reshape(mask, [-1]) def reduce_mean_valid(t): return tf.reduce_mean(tf.boolean_mask(t, mask)) # non-RNN case: No masking. else: mask = None reduce_mean_valid = tf.reduce_mean prev_action_dist = dist_class(train_batch[SampleBatch.ACTION_DIST_INPUTS], model) logp_ratio = tf.exp( curr_action_dist.logp(train_batch[SampleBatch.ACTIONS]) - train_batch[SampleBatch.ACTION_LOGP]) action_kl = prev_action_dist.kl(curr_action_dist) mean_kl = reduce_mean_valid(action_kl) curr_entropy = curr_action_dist.entropy() mean_entropy = reduce_mean_valid(curr_entropy) surrogate_loss = tf.minimum( train_batch[Postprocessing.ADVANTAGES] * logp_ratio, train_batch[Postprocessing.ADVANTAGES] * tf.clip_by_value( logp_ratio, 1 - policy.config["clip_param"], 1 + policy.config["clip_param"])) mean_policy_loss = reduce_mean_valid(-surrogate_loss) # Compute a value function loss. if policy.config["use_critic"]: prev_value_fn_out = train_batch[SampleBatch.VF_PREDS] vf_loss1 = tf.math.square(value_fn_out - train_batch[Postprocessing.VALUE_TARGETS]) vf_clipped = prev_value_fn_out + tf.clip_by_value( value_fn_out - prev_value_fn_out, -policy.config["vf_clip_param"], policy.config["vf_clip_param"]) vf_loss2 = tf.math.square(vf_clipped - train_batch[Postprocessing.VALUE_TARGETS]) vf_loss = tf.maximum(vf_loss1, vf_loss2) mean_vf_loss = reduce_mean_valid(vf_loss) # Ignore the value function. else: vf_loss = mean_vf_loss = tf.constant(0.0) total_loss = reduce_mean_valid(-surrogate_loss + policy.kl_coeff * action_kl + policy.config["vf_loss_coeff"] * vf_loss - policy.entropy_coeff * curr_entropy) # Store stats in policy for stats_fn. policy._total_loss = total_loss policy._mean_policy_loss = mean_policy_loss policy._mean_vf_loss = mean_vf_loss policy._mean_entropy = mean_entropy policy._mean_kl = mean_kl policy._value_fn_out = value_fn_out return total_loss def kl_and_loss_stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]: """Stats function for PPO. Returns a dict with important KL and loss stats. Args: policy (Policy): The Policy to generate stats for. train_batch (SampleBatch): The SampleBatch (already) used for training. Returns: Dict[str, TensorType]: The stats dict. """ return { "cur_kl_coeff": tf.cast(policy.kl_coeff, tf.float64), "cur_lr": tf.cast(policy.cur_lr, tf.float64), "total_loss": policy._total_loss, "policy_loss": policy._mean_policy_loss, "vf_loss": policy._mean_vf_loss, "vf_explained_var": explained_variance( train_batch[Postprocessing.VALUE_TARGETS], policy._value_fn_out), "kl": policy._mean_kl, "entropy": policy._mean_entropy, "entropy_coeff": tf.cast(policy.entropy_coeff, tf.float64), } # TODO: (sven) Deprecate once we only allow native keras models. def vf_preds_fetches(policy: Policy) -> Dict[str, TensorType]: """Defines extra fetches per action computation. Args: policy (Policy): The Policy to perform the extra action fetch on. Returns: Dict[str, TensorType]: Dict with extra tf fetches to perform per action computation. """ # Keras models return values for each call in third return argument # (dict). if isinstance(policy.model, tf.keras.Model): return {} # Return value function outputs. VF estimates will hence be added to the # SampleBatches produced by the sampler(s) to generate the train batches # going into the loss function. return { SampleBatch.VF_PREDS: policy.model.value_function(), } def compute_and_clip_gradients(policy: Policy, optimizer: LocalOptimizer, loss: TensorType) -> ModelGradients: """Gradients computing function (from loss tensor, using local optimizer). Args: policy (Policy): The Policy object that generated the loss tensor and that holds the given local optimizer. optimizer (LocalOptimizer): The tf (local) optimizer object to calculate the gradients with. loss (TensorType): The loss tensor for which gradients should be calculated. Returns: ModelGradients: List of the possibly clipped gradients- and variable tuples. """ # Compute the gradients. variables = policy.model.trainable_variables if isinstance(policy.model, ModelV2): variables = variables() grads_and_vars = optimizer.compute_gradients(loss, variables) # Clip by global norm, if necessary. if policy.config["grad_clip"] is not None: # Defuse inf gradients (due to super large losses). grads = [g for (g, v) in grads_and_vars] grads, _ = tf.clip_by_global_norm(grads, policy.config["grad_clip"]) # If the global_norm is inf -> All grads will be NaN. Stabilize this # here by setting them to 0.0. This will simply ignore destructive loss # calculations. policy.grads = [ tf.where(tf.math.is_nan(g), tf.zeros_like(g), g) for g in grads ] clipped_grads_and_vars = list(zip(policy.grads, variables)) return clipped_grads_and_vars else: return grads_and_vars class KLCoeffMixin: """Assigns the `update_kl()` method to the PPOPolicy. This is used in PPO's execution plan (see ppo.py) for updating the KL coefficient after each learning step based on `config.kl_target` and the measured KL value (from the train_batch). """ def __init__(self, config): # The current KL value (as python float). self.kl_coeff_val = config["kl_coeff"] # The current KL value (as tf Variable for in-graph operations). self.kl_coeff = get_variable( float(self.kl_coeff_val), tf_name="kl_coeff", trainable=False, framework=config["framework"]) # Constant target value. self.kl_target = config["kl_target"] if self.framework == "tf": self._kl_coeff_placeholder = \ tf1.placeholder(dtype=tf.float32, name="kl_coeff") self._kl_coeff_update = self.kl_coeff.assign( self._kl_coeff_placeholder, read_value=False) def update_kl(self, sampled_kl): # Update the current KL value based on the recently measured value. # Increase. if sampled_kl > 2.0 * self.kl_target: self.kl_coeff_val *= 1.5 # Decrease. elif sampled_kl < 0.5 * self.kl_target: self.kl_coeff_val *= 0.5 # No change. else: return self.kl_coeff_val # Update the tf Variable (via session call for tf). if self.framework == "tf": self.get_session().run( self._kl_coeff_update, feed_dict={self._kl_coeff_placeholder: self.kl_coeff_val}) else: self.kl_coeff.assign(self.kl_coeff_val, read_value=False) # Return the current KL value. return self.kl_coeff_val class ValueNetworkMixin: """Assigns the `_value()` method to the PPOPolicy. This way, Policy can call `_value()` to get the current VF estimate on a single(!) observation (as done in `postprocess_trajectory_fn`). Note: When doing this, an actual forward pass is being performed. This is different from only calling `model.value_function()`, where the result of the most recent forward pass is being used to return an already calculated tensor. """ def __init__(self, obs_space, action_space, config): # When doing GAE, we need the value function estimate on the # observation. if config["use_gae"]: # Input dict is provided to us automatically via the Model's # requirements. It's a single-timestep (last one in trajectory) # input_dict. @make_tf_callable(self.get_session()) def value(**input_dict): input_dict = SampleBatch(input_dict) if isinstance(self.model, tf.keras.Model): _, _, extra_outs = self.model(input_dict) return extra_outs[SampleBatch.VF_PREDS][0] else: model_out, _ = self.model(input_dict) # [0] = remove the batch dim. return self.model.value_function()[0] # When not doing GAE, we do not require the value function's output. else: @make_tf_callable(self.get_session()) def value(*args, **kwargs): return tf.constant(0.0) self._value = value def setup_config(policy: Policy, obs_space: gym.spaces.Space, action_space: gym.spaces.Space, config: TrainerConfigDict) -> None: """Executed before Policy is "initialized" (at beginning of constructor). Args: policy (Policy): The Policy object. obs_space (gym.spaces.Space): The Policy's observation space. action_space (gym.spaces.Space): The Policy's action space. config (TrainerConfigDict): The Policy's config. """ # Setting `vf_share_layers` in the top-level config is deprecated. # It's confusing as some users might (correctly!) set it in their # model config and then won't notice that it's silently overwritten # here. if config["vf_share_layers"] != DEPRECATED_VALUE: deprecation_warning( old="config[vf_share_layers]", new="config[model][vf_share_layers]", error=False, ) config["model"]["vf_share_layers"] = config["vf_share_layers"] # If vf_share_layers is True, inform about the need to tune vf_loss_coeff. if config.get("model", {}).get("vf_share_layers") is True: logger.info( "`vf_share_layers=True` in your model. " "Therefore, remember to tune the value of `vf_loss_coeff`!") def setup_mixins(policy: Policy, obs_space: gym.spaces.Space, action_space: gym.spaces.Space, config: TrainerConfigDict) -> None: """Call mixin classes' constructors before Policy's loss initialization. Args: policy (Policy): The Policy object. obs_space (gym.spaces.Space): The Policy's observation space. action_space (gym.spaces.Space): The Policy's action space. config (TrainerConfigDict): The Policy's config. """ ValueNetworkMixin.__init__(policy, obs_space, action_space, config) KLCoeffMixin.__init__(policy, config) EntropyCoeffSchedule.__init__(policy, config["entropy_coeff"], config["entropy_coeff_schedule"]) LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"]) def postprocess_ppo_gae( policy: Policy, sample_batch: SampleBatch, other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None, episode: Optional[MultiAgentEpisode] = None) -> SampleBatch: # Stub serving backward compatibility. deprecation_warning( old="rllib.agents.ppo.ppo_tf_policy.postprocess_ppo_gae", new="rllib.evaluation.postprocessing.compute_gae_for_sample_batch", error=False) return compute_gae_for_sample_batch(policy, sample_batch, other_agent_batches, episode) # Build a child class of `DynamicTFPolicy`, given the custom functions defined # above. PPOTFPolicy = build_tf_policy( name="PPOTFPolicy", loss_fn=ppo_surrogate_loss, get_default_config=lambda: ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG, postprocess_fn=compute_gae_for_sample_batch, stats_fn=kl_and_loss_stats, compute_gradients_fn=compute_and_clip_gradients, extra_action_out_fn=vf_preds_fetches, before_init=setup_config, before_loss_init=setup_mixins, mixins=[ LearningRateSchedule, EntropyCoeffSchedule, KLCoeffMixin, ValueNetworkMixin ])