# ray/rllib/algorithms/appo/appo_tf_policy.py
"""
TensorFlow policy class used for APPO.
Adapted from VTraceTFPolicy to use the PPO surrogate loss.
Keep in sync with changes to VTraceTFPolicy.
"""
import numpy as np
import logging
import gym
from typing import Dict, List, Optional, Type, Union
import ray
from ray.rllib.algorithms.appo.utils import make_appo_models
from ray.rllib.algorithms.impala import vtrace_tf as vtrace
from ray.rllib.algorithms.impala.impala_tf_policy import (
_make_time_major,
VTraceClipGradients,
VTraceOptimizer,
)
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.postprocessing import (
compute_gae_for_sample_batch,
Postprocessing,
)
from ray.rllib.models.tf.tf_action_dist import Categorical
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.dynamic_tf_policy_v2 import DynamicTFPolicyV2
from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2
from ray.rllib.policy.tf_mixins import (
EntropyCoeffSchedule,
LearningRateSchedule,
KLCoeffMixin,
ValueNetworkMixin,
GradStatsMixin,
)
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
from ray.rllib.utils.annotations import (
override,
)
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.tf_utils import explained_variance, make_tf_callable
from ray.rllib.utils.typing import TensorType
tf1, tf, tfv = try_import_tf()
logger = logging.getLogger(__name__)
class TargetNetworkMixin:
"""Target NN is updated by master learner via the `update_target` method.
Updates happen every `trainer.update_target_frequency` steps. All worker
batches are importance sampled wrt the target network to ensure a more
stable pi_old in PPO.
"""
def __init__(self, obs_space, action_space, config):
@make_tf_callable(self.get_session())
def do_update():
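            # Copy every online-model variable into its target-model
            # counterpart and group all assignments into a single op.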
assign_ops = []
assert len(self.model_vars) == len(self.target_model_vars)
for var, var_target in zip(self.model_vars, self.target_model_vars):
assign_ops.append(var_target.assign(var))
return tf.group(*assign_ops)
self.update_target = do_update
@property
def model_vars(self):
if not hasattr(self, "_model_vars"):
self._model_vars = self.model.variables()
return self._model_vars
@property
def target_model_vars(self):
if not hasattr(self, "_target_model_vars"):
self._target_model_vars = self.target_model.variables()
return self._target_model_vars
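
# A minimal usage sketch (illustrative only; the surrounding loop and variable
# names are assumptions, not part of this module): the APPO algorithm
# periodically syncs the target network on the learner, roughly like:
#
#     if train_batches_since_sync >= config["target_update_frequency"]:
#         local_worker.foreach_policy(lambda p, pid: p.update_target())
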
# We need this builder function because we want to share the same custom
# logic between TF1 dynamic and TF2 eager policies.
def get_appo_tf_policy(name: str, base: type) -> type:
"""Construct an APPOTFPolicy inheriting either dynamic or eager base policies.
Args:
base: Base class for this policy. DynamicTFPolicyV2 or EagerTFPolicyV2.
Returns:
A TF Policy to be used with Impala.
"""
class APPOTFPolicy(
VTraceClipGradients,
VTraceOptimizer,
LearningRateSchedule,
KLCoeffMixin,
EntropyCoeffSchedule,
ValueNetworkMixin,
TargetNetworkMixin,
GradStatsMixin,
base,
):
def __init__(
self,
obs_space,
action_space,
config,
existing_model=None,
existing_inputs=None,
):
# First thing first, enable eager execution if necessary.
base.enable_eager_execution_if_necessary()
config = dict(
ray.rllib.algorithms.appo.appo.APPOConfig().to_dict(), **config
)
            # Although these MixIn __init__ calls are no-ops, we invoke them
            # explicitly here for clarity, before base.__init__ (which will
            # use the make_model() override) runs further below.
VTraceClipGradients.__init__(self)
VTraceOptimizer.__init__(self)
LearningRateSchedule.__init__(self, config["lr"], config["lr_schedule"])
# Initialize base class.
base.__init__(
self,
obs_space,
action_space,
config,
existing_inputs=existing_inputs,
existing_model=existing_model,
)
EntropyCoeffSchedule.__init__(
self, config["entropy_coeff"], config["entropy_coeff_schedule"]
)
ValueNetworkMixin.__init__(self, config)
KLCoeffMixin.__init__(self, config)
GradStatsMixin.__init__(self)
# Note: this is a bit ugly, but loss and optimizer initialization must
# happen after all the MixIns are initialized.
self.maybe_initialize_optimizer_and_loss()
            # Initialize TargetNetworkMixin ops after loss initialization.
TargetNetworkMixin.__init__(self, obs_space, action_space, config)
@override(base)
def make_model(self) -> ModelV2:
return make_appo_models(self)
@override(base)
def loss(
self,
model: Union[ModelV2, "tf.keras.Model"],
dist_class: Type[TFActionDistribution],
train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
model_out, _ = model(train_batch)
action_dist = dist_class(model_out, model)
if isinstance(self.action_space, gym.spaces.Discrete):
is_multidiscrete = False
output_hidden_shape = [self.action_space.n]
elif isinstance(self.action_space, gym.spaces.multi_discrete.MultiDiscrete):
is_multidiscrete = True
output_hidden_shape = self.action_space.nvec.astype(np.int32)
else:
is_multidiscrete = False
output_hidden_shape = 1
# TODO: (sven) deprecate this when trajectory view API gets activated.
def make_time_major(*args, **kw):
return _make_time_major(
self, train_batch.get(SampleBatch.SEQ_LENS), *args, **kw
)
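            # `make_time_major` reshapes the flat [B*T, ...] batch tensors into
            # time-major [T, B, ...] tensors (optionally dropping the last
            # timestep), which is the layout the V-trace computation expects.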
actions = train_batch[SampleBatch.ACTIONS]
dones = train_batch[SampleBatch.DONES]
rewards = train_batch[SampleBatch.REWARDS]
behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]
target_model_out, _ = self.target_model(train_batch)
prev_action_dist = dist_class(behaviour_logits, self.model)
values = self.model.value_function()
values_time_major = make_time_major(values)
if self.is_recurrent():
max_seq_len = tf.reduce_max(train_batch[SampleBatch.SEQ_LENS])
mask = tf.sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len)
mask = tf.reshape(mask, [-1])
                mask = make_time_major(
                    mask,
                    drop_last=self.config["vtrace"]
                    and self.config["vtrace_drop_last_ts"],
                )
def reduce_mean_valid(t):
return tf.reduce_mean(tf.boolean_mask(t, mask))
else:
reduce_mean_valid = tf.reduce_mean
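            # For recurrent models, padded timesteps beyond each sequence's
            # actual length are masked out of every mean computed below; for
            # feed-forward models, a plain mean over all timesteps is used.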
if self.config["vtrace"]:
drop_last = self.config["vtrace_drop_last_ts"]
logger.debug(
"Using V-Trace surrogate loss (vtrace=True; "
f"drop_last={drop_last})"
)
# Prepare actions for loss.
loss_actions = (
actions if is_multidiscrete else tf.expand_dims(actions, axis=1)
)
old_policy_behaviour_logits = tf.stop_gradient(target_model_out)
old_policy_action_dist = dist_class(old_policy_behaviour_logits, model)
                # Prepare KL for the loss.
mean_kl = make_time_major(
old_policy_action_dist.multi_kl(action_dist), drop_last=drop_last
)
unpacked_behaviour_logits = tf.split(
behaviour_logits, output_hidden_shape, axis=1
)
unpacked_old_policy_behaviour_logits = tf.split(
old_policy_behaviour_logits, output_hidden_shape, axis=1
)
# Compute vtrace on the CPU for better perf.
with tf.device("/cpu:0"):
vtrace_returns = vtrace.multi_from_logits(
behaviour_policy_logits=make_time_major(
unpacked_behaviour_logits, drop_last=drop_last
),
target_policy_logits=make_time_major(
unpacked_old_policy_behaviour_logits, drop_last=drop_last
),
actions=tf.unstack(
make_time_major(loss_actions, drop_last=drop_last), axis=2
),
discounts=tf.cast(
~make_time_major(
tf.cast(dones, tf.bool), drop_last=drop_last
),
tf.float32,
)
* self.config["gamma"],
rewards=make_time_major(rewards, drop_last=drop_last),
values=values_time_major[:-1]
if drop_last
else values_time_major,
bootstrap_value=values_time_major[-1],
dist_class=Categorical if is_multidiscrete else dist_class,
model=model,
clip_rho_threshold=tf.cast(
self.config["vtrace_clip_rho_threshold"], tf.float32
),
clip_pg_rho_threshold=tf.cast(
self.config["vtrace_clip_pg_rho_threshold"], tf.float32
),
)
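                # `vtrace_returns.vs` are the V-trace corrected value targets
                # and `vtrace_returns.pg_advantages` the importance-corrected
                # policy-gradient advantages used by the surrogate loss below.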
actions_logp = make_time_major(
action_dist.logp(actions), drop_last=drop_last
)
prev_actions_logp = make_time_major(
prev_action_dist.logp(actions), drop_last=drop_last
)
old_policy_actions_logp = make_time_major(
old_policy_action_dist.logp(actions), drop_last=drop_last
)
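                # Decoupled importance ratios: `is_ratio` corrects from the
                # (possibly stale) behaviour policy to the target ("old")
                # policy and is clipped to [0.0, 2.0] for stability; multiplied
                # by exp(logp_new - logp_behaviour), it yields an estimate of
                # pi_new / pi_old, i.e. the PPO surrogate ratio w.r.t. the
                # target network.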
is_ratio = tf.clip_by_value(
tf.math.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0
)
logp_ratio = is_ratio * tf.exp(actions_logp - prev_actions_logp)
self._is_ratio = is_ratio
advantages = vtrace_returns.pg_advantages
surrogate_loss = tf.minimum(
advantages * logp_ratio,
advantages
* tf.clip_by_value(
logp_ratio,
1 - self.config["clip_param"],
1 + self.config["clip_param"],
),
)
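                # Standard PPO clipped surrogate objective:
                # min(r * A, clip(r, 1 - eps, 1 + eps) * A), maximized by
                # minimizing its negated mean (`mean_policy_loss` below).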
action_kl = (
tf.reduce_mean(mean_kl, axis=0) if is_multidiscrete else mean_kl
)
mean_kl_loss = reduce_mean_valid(action_kl)
mean_policy_loss = -reduce_mean_valid(surrogate_loss)
# The value function loss.
if drop_last:
delta = values_time_major[:-1] - vtrace_returns.vs
else:
delta = values_time_major - vtrace_returns.vs
value_targets = vtrace_returns.vs
mean_vf_loss = 0.5 * reduce_mean_valid(tf.math.square(delta))
# The entropy loss.
                actions_entropy = make_time_major(
                    action_dist.multi_entropy(), drop_last=drop_last
                )
mean_entropy = reduce_mean_valid(actions_entropy)
else:
logger.debug("Using PPO surrogate loss (vtrace=False)")
                # Prepare KL for the loss.
mean_kl = make_time_major(prev_action_dist.multi_kl(action_dist))
logp_ratio = tf.math.exp(
make_time_major(action_dist.logp(actions))
- make_time_major(prev_action_dist.logp(actions))
)
advantages = make_time_major(train_batch[Postprocessing.ADVANTAGES])
surrogate_loss = tf.minimum(
advantages * logp_ratio,
advantages
* tf.clip_by_value(
logp_ratio,
1 - self.config["clip_param"],
1 + self.config["clip_param"],
),
)
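                # Same PPO clipped surrogate as above, but with GAE advantages
                # computed on the workers (see `postprocess_trajectory`).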
action_kl = (
tf.reduce_mean(mean_kl, axis=0) if is_multidiscrete else mean_kl
)
mean_kl_loss = reduce_mean_valid(action_kl)
mean_policy_loss = -reduce_mean_valid(surrogate_loss)
# The value function loss.
value_targets = make_time_major(
train_batch[Postprocessing.VALUE_TARGETS]
)
delta = values_time_major - value_targets
mean_vf_loss = 0.5 * reduce_mean_valid(tf.math.square(delta))
# The entropy loss.
mean_entropy = reduce_mean_valid(
make_time_major(action_dist.multi_entropy())
)
# The summed weighted loss.
total_loss = mean_policy_loss - mean_entropy * self.entropy_coeff
# Optional KL loss.
if self.config["use_kl_loss"]:
total_loss += self.kl_coeff * mean_kl_loss
            # Add the vf loss here, unless a separate vf optimizer/network is
            # used, in which case it is returned as its own loss term below.
loss_wo_vf = total_loss
if not self.config["_separate_vf_optimizer"]:
total_loss += mean_vf_loss * self.config["vf_loss_coeff"]
# Store stats in policy for stats_fn.
self._total_loss = total_loss
self._loss_wo_vf = loss_wo_vf
self._mean_policy_loss = mean_policy_loss
# Backward compatibility: Deprecate policy._mean_kl.
self._mean_kl_loss = self._mean_kl = mean_kl_loss
self._mean_vf_loss = mean_vf_loss
self._mean_entropy = mean_entropy
self._value_targets = value_targets
# Return one total loss or two losses: vf vs rest (policy + kl).
if self.config["_separate_vf_optimizer"]:
return loss_wo_vf, mean_vf_loss
else:
return total_loss
@override(base)
def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
values_batched = _make_time_major(
self,
train_batch.get(SampleBatch.SEQ_LENS),
self.model.value_function(),
drop_last=self.config["vtrace"] and self.config["vtrace_drop_last_ts"],
)
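            # Re-fetch the value predictions time-major (trimmed like the loss
            # terms if the last timestep was dropped) so they line up with the
            # stored value targets for the explained-variance stat.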
stats_dict = {
"cur_lr": tf.cast(self.cur_lr, tf.float64),
"total_loss": self._total_loss,
"policy_loss": self._mean_policy_loss,
"entropy": self._mean_entropy,
"var_gnorm": tf.linalg.global_norm(self.model.trainable_variables()),
"vf_loss": self._mean_vf_loss,
"vf_explained_var": explained_variance(
tf.reshape(self._value_targets, [-1]),
tf.reshape(values_batched, [-1]),
),
"entropy_coeff": tf.cast(self.entropy_coeff, tf.float64),
}
if self.config["vtrace"]:
is_stat_mean, is_stat_var = tf.nn.moments(self._is_ratio, [0, 1])
stats_dict["mean_IS"] = is_stat_mean
stats_dict["var_IS"] = is_stat_var
if self.config["use_kl_loss"]:
stats_dict["kl"] = self._mean_kl_loss
stats_dict["KL_Coeff"] = self.kl_coeff
return stats_dict
@override(base)
def postprocess_trajectory(
self,
sample_batch: SampleBatch,
other_agent_batches: Optional[SampleBatch] = None,
episode: Optional["Episode"] = None,
):
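            # Without V-trace, advantages and value targets are computed on the
            # workers via GAE, exactly as in vanilla PPO; with V-trace, they
            # are computed inside the loss instead.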
if not self.config["vtrace"]:
sample_batch = compute_gae_for_sample_batch(
self, sample_batch, other_agent_batches, episode
)
return sample_batch
@override(base)
def extra_action_out_fn(self) -> Dict[str, TensorType]:
extra_action_fetches = super().extra_action_out_fn()
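            # Without V-trace, value predictions must be fetched during
            # rollouts so that GAE can be computed in postprocess_trajectory.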
if not self.config["vtrace"]:
extra_action_fetches[SampleBatch.VF_PREDS] = self.model.value_function()
return extra_action_fetches
@override(base)
def get_batch_divisibility_req(self) -> int:
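            # Train batches must be divisible by the rollout fragment length so
            # they can be sliced into complete time-major sequences.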
return self.config["rollout_fragment_length"]
APPOTFPolicy.__name__ = name
APPOTFPolicy.__qualname__ = name
return APPOTFPolicy
APPOTF1Policy = get_appo_tf_policy("APPOTF1Policy", DynamicTFPolicyV2)
APPOTF2Policy = get_appo_tf_policy("APPOTF2Policy", EagerTFPolicyV2)
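
# A minimal end-to-end sketch of how these policy classes are typically
# obtained (illustrative only; assumes the standard APPOConfig API and a Gym
# env id, neither of which is defined in this module):
#
#     from ray.rllib.algorithms.appo import APPOConfig
#
#     config = APPOConfig().framework("tf")
#     algo = config.build(env="CartPole-v1")
#     policy = algo.get_policy()  # -> APPOTF1Policy under framework="tf"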