"""Adapted from VTraceTFPolicy to use the PPO surrogate loss. Keep in sync with changes to VTraceTFPolicy.""" import numpy as np import logging import gym from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping import ray.rllib.agents.impala.vtrace_torch as vtrace from ray.rllib.agents.impala.vtrace_torch_policy import make_time_major, \ choose_optimizer from ray.rllib.agents.ppo.appo_tf_policy import build_appo_model, \ postprocess_trajectory from ray.rllib.agents.ppo.ppo_torch_policy import ValueNetworkMixin, \ KLCoeffMixin from ray.rllib.evaluation.postprocessing import Postprocessing from ray.rllib.models.torch.torch_action_dist import TorchCategorical from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.torch_policy import LearningRateSchedule from ray.rllib.policy.torch_policy_template import build_torch_policy from ray.rllib.utils.explained_variance import explained_variance from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.torch_ops import global_norm, sequence_mask torch, nn = try_import_torch() logger = logging.getLogger(__name__) class PPOSurrogateLoss: """Loss used when V-trace is disabled. Arguments: prev_actions_logp: A float32 tensor of shape [T, B]. actions_logp: A float32 tensor of shape [T, B]. action_kl: A float32 tensor of shape [T, B]. actions_entropy: A float32 tensor of shape [T, B]. values: A float32 tensor of shape [T, B]. valid_mask: A bool tensor of valid RNN input elements (#2992). advantages: A float32 tensor of shape [T, B]. value_targets: A float32 tensor of shape [T, B]. vf_loss_coeff (float): Coefficient of the value function loss. entropy_coeff (float): Coefficient of the entropy regularizer. clip_param (float): Clip parameter. cur_kl_coeff (float): Coefficient for KL loss. use_kl_loss (bool): If true, use KL loss. """ def __init__(self, prev_actions_logp, actions_logp, action_kl, actions_entropy, values, valid_mask, advantages, value_targets, vf_loss_coeff=0.5, entropy_coeff=0.01, clip_param=0.3, cur_kl_coeff=None, use_kl_loss=False): if valid_mask is not None: num_valid = torch.sum(valid_mask) def reduce_mean_valid(t): return torch.sum(t * valid_mask) / num_valid else: def reduce_mean_valid(t): return torch.mean(t) logp_ratio = torch.exp(actions_logp - prev_actions_logp) surrogate_loss = torch.min( advantages * logp_ratio, advantages * torch.clamp(logp_ratio, 1 - clip_param, 1 + clip_param)) self.mean_kl = reduce_mean_valid(action_kl) self.pi_loss = -reduce_mean_valid(surrogate_loss) # The baseline loss delta = values - value_targets self.value_targets = value_targets self.vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0)) # The entropy loss self.entropy = reduce_mean_valid(actions_entropy) # The summed weighted loss self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff - self.entropy * entropy_coeff) # Optional additional KL Loss if use_kl_loss: self.total_loss += cur_kl_coeff * self.mean_kl class VTraceSurrogateLoss: def __init__(self, actions, prev_actions_logp, actions_logp, old_policy_actions_logp, action_kl, actions_entropy, dones, behaviour_logits, old_policy_behaviour_logits, target_logits, discount, rewards, values, bootstrap_value, dist_class, model, valid_mask, vf_loss_coeff=0.5, entropy_coeff=0.01, clip_rho_threshold=1.0, clip_pg_rho_threshold=1.0, clip_param=0.3, cur_kl_coeff=None, use_kl_loss=False): """APPO Loss, with IS modifications and V-trace for Advantage Estimation VTraceLoss takes tensors of shape [T, B, ...], where `B` is the batch_size. 


class VTraceSurrogateLoss:
    def __init__(self,
                 actions,
                 prev_actions_logp,
                 actions_logp,
                 old_policy_actions_logp,
                 action_kl,
                 actions_entropy,
                 dones,
                 behaviour_logits,
                 old_policy_behaviour_logits,
                 target_logits,
                 discount,
                 rewards,
                 values,
                 bootstrap_value,
                 dist_class,
                 model,
                 valid_mask,
                 vf_loss_coeff=0.5,
                 entropy_coeff=0.01,
                 clip_rho_threshold=1.0,
                 clip_pg_rho_threshold=1.0,
                 clip_param=0.3,
                 cur_kl_coeff=None,
                 use_kl_loss=False):
        """APPO loss, with IS modifications and V-trace for advantage estimation.

        VTraceLoss takes tensors of shape [T, B, ...], where `B` is the
        batch_size. The reason we need to know `B` is for V-trace to properly
        handle episode cut boundaries.

        Arguments:
            actions: An int|float32 tensor of shape [T, B, logit_dim].
            prev_actions_logp: A float32 tensor of shape [T, B].
            actions_logp: A float32 tensor of shape [T, B].
            old_policy_actions_logp: A float32 tensor of shape [T, B].
            action_kl: A float32 tensor of shape [T, B].
            actions_entropy: A float32 tensor of shape [T, B].
            dones: A bool tensor of shape [T, B].
            behaviour_logits: A float32 tensor of shape [T, B, logit_dim].
            old_policy_behaviour_logits: A float32 tensor of shape
                [T, B, logit_dim].
            target_logits: A float32 tensor of shape [T, B, logit_dim].
            discount: A float32 scalar.
            rewards: A float32 tensor of shape [T, B].
            values: A float32 tensor of shape [T, B].
            bootstrap_value: A float32 tensor of shape [B].
            dist_class: Action distribution class for the logits.
            model: Backing ModelV2 instance.
            valid_mask: A bool tensor of valid RNN input elements (#2992).
            vf_loss_coeff (float): Coefficient of the value function loss.
            entropy_coeff (float): Coefficient of the entropy regularizer.
            clip_param (float): Clip parameter.
            cur_kl_coeff (float): Coefficient for KL loss.
            use_kl_loss (bool): If true, use KL loss.
        """
        if valid_mask is not None:
            num_valid = torch.sum(valid_mask)

            def reduce_mean_valid(t):
                return torch.sum(t * valid_mask) / num_valid

        else:

            def reduce_mean_valid(t):
                return torch.mean(t)

        # Compute V-trace on the CPU for better perf.
        self.vtrace_returns = vtrace.multi_from_logits(
            behaviour_policy_logits=behaviour_logits,
            target_policy_logits=old_policy_behaviour_logits,
            actions=torch.unbind(actions, dim=2),
            discounts=(1.0 - dones.float()) * discount,
            rewards=rewards,
            values=values,
            bootstrap_value=bootstrap_value,
            dist_class=dist_class,
            model=model,
            clip_rho_threshold=clip_rho_threshold,
            clip_pg_rho_threshold=clip_pg_rho_threshold)

        self.is_ratio = torch.clamp(
            torch.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0)
        logp_ratio = self.is_ratio * torch.exp(actions_logp -
                                               prev_actions_logp)

        advantages = self.vtrace_returns.pg_advantages
        surrogate_loss = torch.min(
            advantages * logp_ratio,
            advantages * torch.clamp(logp_ratio, 1 - clip_param,
                                     1 + clip_param))

        self.mean_kl = reduce_mean_valid(action_kl)
        self.pi_loss = -reduce_mean_valid(surrogate_loss)

        # The baseline (value function) loss.
        delta = values - self.vtrace_returns.vs
        self.value_targets = self.vtrace_returns.vs
        self.vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0))

        # The entropy loss.
        self.entropy = reduce_mean_valid(actions_entropy)

        # The summed weighted loss.
        self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff -
                           self.entropy * entropy_coeff)

        # Optional additional KL loss.
        if use_kl_loss:
            self.total_loss += cur_kl_coeff * self.mean_kl
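

# Note on the two ratios in VTraceSurrogateLoss (descriptive comment, not in
# the original module): `is_ratio` = clamp(pi_behaviour / pi_target, 0, 2)
# corrects for the lag between the rollout workers' behaviour policy and the
# target network, while exp(actions_logp - prev_actions_logp) is the usual
# PPO ratio of the learner's current policy to the behaviour policy. Their
# product is effectively pi_current / pi_target, with the behaviour-policy
# term clamped for numerical stability, and the PPO clipping is then applied
# to that product.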


def build_appo_surrogate_loss(policy, model, dist_class, train_batch):
    model_out, _ = model.from_batch(train_batch)
    action_dist = dist_class(model_out, model)

    if isinstance(policy.action_space, gym.spaces.Discrete):
        is_multidiscrete = False
        output_hidden_shape = [policy.action_space.n]
    elif isinstance(policy.action_space,
                    gym.spaces.multi_discrete.MultiDiscrete):
        is_multidiscrete = True
        output_hidden_shape = policy.action_space.nvec.astype(np.int32)
    else:
        is_multidiscrete = False
        output_hidden_shape = 1

    def _make_time_major(*args, **kw):
        return make_time_major(policy, train_batch.get("seq_lens"), *args,
                               **kw)

    actions = train_batch[SampleBatch.ACTIONS]
    dones = train_batch[SampleBatch.DONES]
    rewards = train_batch[SampleBatch.REWARDS]
    behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]

    target_model_out, _ = policy.target_model.from_batch(train_batch)
    old_policy_behaviour_logits = target_model_out.detach()

    unpacked_behaviour_logits = torch.split(
        behaviour_logits, output_hidden_shape, dim=1)
    unpacked_old_policy_behaviour_logits = torch.split(
        old_policy_behaviour_logits, output_hidden_shape, dim=1)
    unpacked_outputs = torch.split(model_out, output_hidden_shape, dim=1)
    old_policy_action_dist = dist_class(old_policy_behaviour_logits, model)
    prev_action_dist = dist_class(behaviour_logits, policy.model)
    values = policy.model.value_function()

    policy.model_vars = policy.model.variables()
    policy.target_model_vars = policy.target_model.variables()

    if policy.is_recurrent():
        max_seq_len = torch.max(train_batch["seq_lens"]) - 1
        mask = sequence_mask(train_batch["seq_lens"], max_seq_len)
        mask = torch.reshape(mask, [-1])
    else:
        mask = torch.ones_like(rewards)

    if policy.config["vtrace"]:
        logger.debug("Using V-Trace surrogate loss (vtrace=True)")

        # Prepare actions for loss.
        loss_actions = actions if is_multidiscrete else torch.unsqueeze(
            actions, dim=1)

        # Prepare KL for loss.
        mean_kl = _make_time_major(
            old_policy_action_dist.multi_kl(action_dist), drop_last=True)

        policy.loss = VTraceSurrogateLoss(
            actions=_make_time_major(loss_actions, drop_last=True),
            prev_actions_logp=_make_time_major(
                prev_action_dist.logp(actions), drop_last=True),
            actions_logp=_make_time_major(
                action_dist.logp(actions), drop_last=True),
            old_policy_actions_logp=_make_time_major(
                old_policy_action_dist.logp(actions), drop_last=True),
            action_kl=torch.mean(mean_kl, dim=0)
            if is_multidiscrete else mean_kl,
            actions_entropy=_make_time_major(
                action_dist.multi_entropy(), drop_last=True),
            dones=_make_time_major(dones, drop_last=True),
            behaviour_logits=_make_time_major(
                unpacked_behaviour_logits, drop_last=True),
            old_policy_behaviour_logits=_make_time_major(
                unpacked_old_policy_behaviour_logits, drop_last=True),
            target_logits=_make_time_major(unpacked_outputs, drop_last=True),
            discount=policy.config["gamma"],
            rewards=_make_time_major(rewards, drop_last=True),
            values=_make_time_major(values, drop_last=True),
            bootstrap_value=_make_time_major(values)[-1],
            dist_class=TorchCategorical if is_multidiscrete else dist_class,
            model=policy.model,
            valid_mask=_make_time_major(mask, drop_last=True),
            vf_loss_coeff=policy.config["vf_loss_coeff"],
            entropy_coeff=policy.config["entropy_coeff"],
            clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"],
            clip_pg_rho_threshold=policy.config[
                "vtrace_clip_pg_rho_threshold"],
            clip_param=policy.config["clip_param"],
            cur_kl_coeff=policy.kl_coeff,
            use_kl_loss=policy.config["use_kl_loss"])
    else:
        logger.debug("Using PPO surrogate loss (vtrace=False)")

        # Prepare KL for loss.
        mean_kl = _make_time_major(prev_action_dist.multi_kl(action_dist))

        policy.loss = PPOSurrogateLoss(
            prev_actions_logp=_make_time_major(
                prev_action_dist.logp(actions)),
            actions_logp=_make_time_major(action_dist.logp(actions)),
            action_kl=torch.mean(mean_kl, dim=0)
            if is_multidiscrete else mean_kl,
            actions_entropy=_make_time_major(action_dist.multi_entropy()),
            values=_make_time_major(values),
            valid_mask=_make_time_major(mask),
            advantages=_make_time_major(
                train_batch[Postprocessing.ADVANTAGES]),
            value_targets=_make_time_major(
                train_batch[Postprocessing.VALUE_TARGETS]),
            vf_loss_coeff=policy.config["vf_loss_coeff"],
            entropy_coeff=policy.config["entropy_coeff"],
            clip_param=policy.config["clip_param"],
            cur_kl_coeff=policy.kl_coeff,
            use_kl_loss=policy.config["use_kl_loss"])

    return policy.loss.total_loss
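

# Shape sketch for the tensors handed to the losses above (descriptive
# comment with assumed toy sizes, not part of the original module):
# `_make_time_major` reshapes the flattened [B * T, ...] train-batch tensors
# into time-major [T, B, ...] form, e.g. with rollout_fragment_length T=50
# and 4 sequences per batch, `rewards` goes from [200] to [50, 4]. With
# `drop_last=True` the final timestep is trimmed off ([49, 4]); the value
# prediction at that dropped timestep is what gets passed separately as
# `bootstrap_value=_make_time_major(values)[-1]`.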


def stats(policy, train_batch):
    values_batched = make_time_major(
        policy,
        train_batch.get("seq_lens"),
        policy.model.value_function(),
        drop_last=policy.config["vtrace"])

    stats_dict = {
        "cur_lr": policy.cur_lr,
        "policy_loss": policy.loss.pi_loss,
        "entropy": policy.loss.entropy,
        "var_gnorm": global_norm(policy.model.trainable_variables()),
        "vf_loss": policy.loss.vf_loss,
        "vf_explained_var": explained_variance(
            torch.reshape(policy.loss.value_targets, [-1]),
            torch.reshape(values_batched, [-1]),
            framework="torch"),
    }

    if policy.config["vtrace"]:
        is_stat_mean = torch.mean(policy.loss.is_ratio, [0, 1])
        is_stat_var = torch.var(policy.loss.is_ratio, [0, 1])
        stats_dict.update({"mean_IS": is_stat_mean})
        stats_dict.update({"var_IS": is_stat_var})

    if policy.config["use_kl_loss"]:
        stats_dict.update({"kl": policy.loss.mean_kl})
        stats_dict.update({"KL_Coeff": policy.kl_coeff})

    return stats_dict


class TargetNetworkMixin:
    def __init__(self, obs_space, action_space, config):
        def do_update():
            # Called periodically to copy the learner model's weights into
            # the target model (the stable "old" policy used by the loss).
            assert len(self.model_vars) == len(self.target_model_vars), \
                (self.model_vars, self.target_model_vars)
            self.target_model.load_state_dict(self.model.state_dict())

        self.update_target = do_update


def add_values(policy, input_dict, state_batches, model, action_dist):
    out = {}
    if not policy.config["vtrace"]:
        out[SampleBatch.VF_PREDS] = policy.model.value_function()
    return out


def setup_early_mixins(policy, obs_space, action_space, config):
    LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])


def setup_late_mixins(policy, obs_space, action_space, config):
    KLCoeffMixin.__init__(policy, config)
    ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
    TargetNetworkMixin.__init__(policy, obs_space, action_space, config)


AsyncPPOTorchPolicy = build_torch_policy(
    name="AsyncPPOTorchPolicy",
    loss_fn=build_appo_surrogate_loss,
    stats_fn=stats,
    postprocess_fn=postprocess_trajectory,
    extra_action_out_fn=add_values,
    extra_grad_process_fn=apply_grad_clipping,
    optimizer_fn=choose_optimizer,
    before_init=setup_early_mixins,
    after_init=setup_late_mixins,
    make_model=build_appo_model,
    mixins=[
        LearningRateSchedule, KLCoeffMixin, TargetNetworkMixin,
        ValueNetworkMixin
    ],
    get_batch_divisibility_req=lambda p: p.config["rollout_fragment_length"])
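

# Minimal usage sketch (an assumption, not part of the original module): this
# policy class is normally not instantiated directly; RLlib's APPO trainer
# selects it when the PyTorch backend is enabled. The exact config key for
# choosing the backend depends on the Ray release, so the lines below are
# indicative only.
#
#   from ray.rllib.agents.ppo import APPOTrainer
#   trainer = APPOTrainer(env="CartPole-v0", config={"use_pytorch": True})
#   print(trainer.train())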