import logging

import numpy as np

import ray
from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae, \
    setup_config
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy import EntropyCoeffSchedule, \
    LearningRateSchedule
from ray.rllib.policy.torch_policy_template import build_torch_policy
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_ops import convert_to_torch_tensor, \
    explained_variance, sequence_mask

torch, nn = try_import_torch()

logger = logging.getLogger(__name__)


class PPOLoss:
    def __init__(self,
                 dist_class,
                 model,
                 value_targets,
                 advantages,
                 actions,
                 prev_logits,
                 prev_actions_logp,
                 vf_preds,
                 curr_action_dist,
                 value_fn,
                 cur_kl_coeff,
                 valid_mask,
                 entropy_coeff=0,
                 clip_param=0.1,
                 vf_clip_param=0.1,
                 vf_loss_coeff=1.0,
                 use_gae=True):
        """Constructs the loss for Proximal Policy Objective.

        Arguments:
            dist_class: Action distribution class for logits.
            value_targets (Placeholder): Placeholder for target values; used
                for GAE.
            actions (Placeholder): Placeholder for actions taken from
                previous model evaluation.
            advantages (Placeholder): Placeholder for calculated advantages
                from previous model evaluation.
            prev_logits (Placeholder): Placeholder for logits output from
                previous model evaluation.
            prev_actions_logp (Placeholder): Placeholder for log-prob output
                from previous model evaluation.
            vf_preds (Placeholder): Placeholder for value function output
                from previous model evaluation.
            curr_action_dist (ActionDistribution): ActionDistribution of the
                current model.
            value_fn (Tensor): Current value function output Tensor.
            cur_kl_coeff (Variable): Variable holding the current PPO KL
                coefficient.
            valid_mask (Tensor): A bool mask of valid input elements (#2992).
            entropy_coeff (float): Coefficient of the entropy regularizer.
            clip_param (float): Clip parameter.
            vf_clip_param (float): Clip parameter for the value function.
            vf_loss_coeff (float): Coefficient of the value function loss.
            use_gae (bool): If true, use the Generalized Advantage Estimator.
        """
        if valid_mask is not None:
            num_valid = torch.sum(valid_mask)

            def reduce_mean_valid(t):
                return torch.sum(t[valid_mask]) / num_valid

        else:

            def reduce_mean_valid(t):
                return torch.mean(t)

        prev_dist = dist_class(prev_logits, model)

        # Make loss functions.
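        # The lines below implement PPO's clipped surrogate objective
        # (Schulman et al., 2017):
        #   L_clip = E[min(r_t * A_t, clip(r_t, 1 - eps, 1 + eps) * A_t)]
        # where r_t = pi_theta(a_t|s_t) / pi_theta_old(a_t|s_t) is computed in
        # log space as exp(curr_logp - prev_logp). The total loss also adds a
        # KL penalty scaled by the adaptive cur_kl_coeff, a clipped
        # value-function loss, and an entropy bonus.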
        logp_ratio = torch.exp(
            curr_action_dist.logp(actions) - prev_actions_logp)
        action_kl = prev_dist.kl(curr_action_dist)
        self.mean_kl = reduce_mean_valid(action_kl)

        curr_entropy = curr_action_dist.entropy()
        self.mean_entropy = reduce_mean_valid(curr_entropy)

        surrogate_loss = torch.min(
            advantages * logp_ratio,
            advantages * torch.clamp(logp_ratio, 1 - clip_param,
                                     1 + clip_param))
        self.mean_policy_loss = reduce_mean_valid(-surrogate_loss)

        if use_gae:
            vf_loss1 = torch.pow(value_fn - value_targets, 2.0)
            vf_clipped = vf_preds + torch.clamp(value_fn - vf_preds,
                                                -vf_clip_param, vf_clip_param)
            vf_loss2 = torch.pow(vf_clipped - value_targets, 2.0)
            vf_loss = torch.max(vf_loss1, vf_loss2)
            self.mean_vf_loss = reduce_mean_valid(vf_loss)
            loss = reduce_mean_valid(
                -surrogate_loss + cur_kl_coeff * action_kl +
                vf_loss_coeff * vf_loss - entropy_coeff * curr_entropy)
        else:
            self.mean_vf_loss = 0.0
            loss = reduce_mean_valid(-surrogate_loss +
                                     cur_kl_coeff * action_kl -
                                     entropy_coeff * curr_entropy)
        self.loss = loss


def ppo_surrogate_loss(policy, model, dist_class, train_batch):
    logits, state = model.from_batch(train_batch, is_training=True)
    action_dist = dist_class(logits, model)

    mask = None
    if state:
        max_seq_len = torch.max(train_batch["seq_lens"])
        mask = sequence_mask(train_batch["seq_lens"], max_seq_len)
        mask = torch.reshape(mask, [-1])

    policy.loss_obj = PPOLoss(
        dist_class,
        model,
        train_batch[Postprocessing.VALUE_TARGETS],
        train_batch[Postprocessing.ADVANTAGES],
        train_batch[SampleBatch.ACTIONS],
        train_batch[SampleBatch.ACTION_DIST_INPUTS],
        train_batch[SampleBatch.ACTION_LOGP],
        train_batch[SampleBatch.VF_PREDS],
        action_dist,
        model.value_function(),
        policy.kl_coeff,
        mask,
        entropy_coeff=policy.entropy_coeff,
        clip_param=policy.config["clip_param"],
        vf_clip_param=policy.config["vf_clip_param"],
        vf_loss_coeff=policy.config["vf_loss_coeff"],
        use_gae=policy.config["use_gae"],
    )

    return policy.loss_obj.loss


def kl_and_loss_stats(policy, train_batch):
    return {
        "cur_kl_coeff": policy.kl_coeff,
        "cur_lr": policy.cur_lr,
        "total_loss": policy.loss_obj.loss,
        "policy_loss": policy.loss_obj.mean_policy_loss,
        "vf_loss": policy.loss_obj.mean_vf_loss,
        "vf_explained_var": explained_variance(
            train_batch[Postprocessing.VALUE_TARGETS],
            policy.model.value_function()),
        "kl": policy.loss_obj.mean_kl,
        "entropy": policy.loss_obj.mean_entropy,
        "entropy_coeff": policy.entropy_coeff,
    }


def vf_preds_fetches(policy, input_dict, state_batches, model, action_dist):
    """Adds value function outputs to experience train_batches."""
    return {
        SampleBatch.VF_PREDS: policy.model.value_function(),
    }


class KLCoeffMixin:
    def __init__(self, config):
        # KL Coefficient.
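        # The KL coefficient starts at config["kl_coeff"] and is adapted by
        # update_kl() below: it is raised when the sampled KL divergence
        # exceeds 2x the target and lowered when it falls below 0.5x the
        # target, keeping each policy update close to the sampling policy.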
        self.kl_coeff = config["kl_coeff"]
        self.kl_target = config["kl_target"]

    def update_kl(self, sampled_kl):
        if sampled_kl > 2.0 * self.kl_target:
            self.kl_coeff *= 1.5
        elif sampled_kl < 0.5 * self.kl_target:
            self.kl_coeff *= 0.5
        return self.kl_coeff


class ValueNetworkMixin:
    def __init__(self, obs_space, action_space, config):
        if config["use_gae"]:

            def value(ob, prev_action, prev_reward, *state):
                model_out, _ = self.model({
                    SampleBatch.CUR_OBS: convert_to_torch_tensor(
                        np.asarray([ob]), self.device),
                    SampleBatch.PREV_ACTIONS: convert_to_torch_tensor(
                        np.asarray([prev_action]), self.device),
                    SampleBatch.PREV_REWARDS: convert_to_torch_tensor(
                        np.asarray([prev_reward]), self.device),
                    "is_training": False,
                }, [
                    convert_to_torch_tensor(np.asarray([s]), self.device)
                    for s in state
                ], convert_to_torch_tensor(np.asarray([1]), self.device))
                return self.model.value_function()[0]

        else:

            def value(ob, prev_action, prev_reward, *state):
                return 0.0

        self._value = value


def setup_mixins(policy, obs_space, action_space, config):
    ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
    KLCoeffMixin.__init__(policy, config)
    EntropyCoeffSchedule.__init__(policy, config["entropy_coeff"],
                                  config["entropy_coeff_schedule"])
    LearningRateSchedule.__init__(policy, config["lr"],
                                  config["lr_schedule"])


PPOTorchPolicy = build_torch_policy(
    name="PPOTorchPolicy",
    get_default_config=lambda: ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG,
    loss_fn=ppo_surrogate_loss,
    stats_fn=kl_and_loss_stats,
    extra_action_out_fn=vf_preds_fetches,
    postprocess_fn=postprocess_ppo_gae,
    extra_grad_process_fn=apply_grad_clipping,
    before_init=setup_config,
    after_init=setup_mixins,
    mixins=[
        LearningRateSchedule, EntropyCoeffSchedule, KLCoeffMixin,
        ValueNetworkMixin
    ])
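

# A minimal usage sketch (not part of the policy definition): this module is
# normally consumed through the PPO trainer rather than instantiated directly.
# The block below assumes a Ray ~0.8.x API in which the `use_pytorch=True`
# config flag routes the trainer to PPOTorchPolicy (newer Ray versions use
# `"framework": "torch"` instead); the environment and config values are
# illustrative only.
if __name__ == "__main__":
    from ray.rllib.agents.ppo import PPOTrainer

    ray.init()
    trainer = PPOTrainer(
        env="CartPole-v0",
        config={
            # Assumed Ray ~0.8.x key selecting the torch policy defined above.
            "use_pytorch": True,
            "num_workers": 1,
        })
    # Run one training iteration; the returned result dict contains the
    # learner stats reported by kl_and_loss_stats above (policy_loss,
    # vf_loss, kl, entropy, ...).
    print(trainer.train())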