"""
TensorFlow policy class used for PPO.
"""

import gym
import logging
from typing import Dict, List, Optional, Type, Union

import ray
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
    Postprocessing
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy import LearningRateSchedule, \
    EntropyCoeffSchedule
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
from ray.rllib.utils.framework import try_import_tf, get_variable
from ray.rllib.utils.tf_ops import explained_variance, make_tf_callable
from ray.rllib.utils.typing import AgentID, LocalOptimizer, ModelGradients, \
    TensorType, TrainerConfigDict

tf1, tf, tfv = try_import_tf()

logger = logging.getLogger(__name__)


def ppo_surrogate_loss(
        policy: Policy, model: Union[ModelV2, "tf.keras.Model"],
        dist_class: Type[TFActionDistribution],
        train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
    """Constructs the loss for Proximal Policy Objective.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (Union[ModelV2, tf.keras.Model]): The Model to calculate
            the loss for.
        dist_class (Type[TFActionDistribution]): The action distribution
            class.
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    if isinstance(model, tf.keras.Model):
        logits, state, extra_outs = model(train_batch)
        value_fn_out = extra_outs[SampleBatch.VF_PREDS]
    else:
        logits, state = model.from_batch(train_batch)
        value_fn_out = model.value_function()

    curr_action_dist = dist_class(logits, model)

    # RNN case: Mask away 0-padded chunks at end of time axis.
    if state:
        # Derive max_seq_len from the data itself, not from the seq_lens
        # tensor. This is in case e.g. seq_lens=[2, 3], but the data is still
        # 0-padded up to T=5 (as it's the case for attention nets).
        B = tf.shape(train_batch["seq_lens"])[0]
        max_seq_len = tf.shape(logits)[0] // B

        mask = tf.sequence_mask(train_batch["seq_lens"], max_seq_len)
        mask = tf.reshape(mask, [-1])

        def reduce_mean_valid(t):
            return tf.reduce_mean(tf.boolean_mask(t, mask))

    # non-RNN case: No masking.
    else:
        mask = None
        reduce_mean_valid = tf.reduce_mean

    prev_action_dist = dist_class(train_batch[SampleBatch.ACTION_DIST_INPUTS],
                                  model)

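    # Importance sampling ratio r(theta) = pi_theta(a|s) / pi_theta_old(a|s),
    # computed in log space (current logp minus the logp stored at sampling
    # time) for numerical stability.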
    logp_ratio = tf.exp(
        curr_action_dist.logp(train_batch[SampleBatch.ACTIONS]) -
        train_batch[SampleBatch.ACTION_LOGP])
    action_kl = prev_action_dist.kl(curr_action_dist)
    mean_kl = reduce_mean_valid(action_kl)

    curr_entropy = curr_action_dist.entropy()
    mean_entropy = reduce_mean_valid(curr_entropy)

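    # PPO's clipped surrogate objective, per timestep:
    #   min(r * A, clip(r, 1 - clip_param, 1 + clip_param) * A)
    # E.g. (hypothetical numbers) with clip_param=0.3, r=1.5, A=2.0 the
    # clipped term wins: min(3.0, 1.3 * 2.0) = 2.6; with A=-2.0 the unclipped
    # term wins: min(-3.0, -2.6) = -3.0.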
    surrogate_loss = tf.minimum(
        train_batch[Postprocessing.ADVANTAGES] * logp_ratio,
        train_batch[Postprocessing.ADVANTAGES] * tf.clip_by_value(
            logp_ratio, 1 - policy.config["clip_param"],
            1 + policy.config["clip_param"]))
    mean_policy_loss = reduce_mean_valid(-surrogate_loss)

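    # The value-function loss below mirrors the policy clipping: the new VF
    # prediction may move at most `vf_clip_param` away from the prediction
    # made at sampling time (VF_PREDS), and the elementwise max of the
    # clipped and unclipped squared errors is used.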
    # Compute a value function loss.
    if policy.config["use_critic"]:
        prev_value_fn_out = train_batch[SampleBatch.VF_PREDS]
        vf_loss1 = tf.math.square(value_fn_out -
                                  train_batch[Postprocessing.VALUE_TARGETS])
        vf_clipped = prev_value_fn_out + tf.clip_by_value(
            value_fn_out - prev_value_fn_out, -policy.config["vf_clip_param"],
            policy.config["vf_clip_param"])
        vf_loss2 = tf.math.square(vf_clipped -
                                  train_batch[Postprocessing.VALUE_TARGETS])
        vf_loss = tf.maximum(vf_loss1, vf_loss2)
        mean_vf_loss = reduce_mean_valid(vf_loss)
    # Ignore the value function.
    else:
        vf_loss = mean_vf_loss = tf.constant(0.0)

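    # Combine everything into the final objective (averaged over valid
    # timesteps):
    #   -surrogate + kl_coeff * KL(old || new)
    #   + vf_loss_coeff * vf_loss - entropy_coeff * entropy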
    total_loss = reduce_mean_valid(-surrogate_loss +
                                   policy.kl_coeff * action_kl +
                                   policy.config["vf_loss_coeff"] * vf_loss -
                                   policy.entropy_coeff * curr_entropy)

    # Store stats in policy for stats_fn.
    policy._total_loss = total_loss
    policy._mean_policy_loss = mean_policy_loss
    policy._mean_vf_loss = mean_vf_loss
    policy._mean_entropy = mean_entropy
    policy._mean_kl = mean_kl
    policy._value_fn_out = value_fn_out

    return total_loss


def kl_and_loss_stats(policy: Policy,
                      train_batch: SampleBatch) -> Dict[str, TensorType]:
    """Stats function for PPO. Returns a dict with important KL and loss stats.

    Args:
        policy (Policy): The Policy to generate stats for.
        train_batch (SampleBatch): The SampleBatch (already) used for training.

    Returns:
        Dict[str, TensorType]: The stats dict.
    """
    return {
        "cur_kl_coeff": tf.cast(policy.kl_coeff, tf.float64),
        "cur_lr": tf.cast(policy.cur_lr, tf.float64),
        "total_loss": policy._total_loss,
        "policy_loss": policy._mean_policy_loss,
        "vf_loss": policy._mean_vf_loss,
        "vf_explained_var": explained_variance(
            train_batch[Postprocessing.VALUE_TARGETS], policy._value_fn_out),
        "kl": policy._mean_kl,
        "entropy": policy._mean_entropy,
        "entropy_coeff": tf.cast(policy.entropy_coeff, tf.float64),
    }


# TODO: (sven) Deprecate once we only allow native keras models.
def vf_preds_fetches(policy: Policy) -> Dict[str, TensorType]:
    """Defines extra fetches per action computation.

    Args:
        policy (Policy): The Policy to perform the extra action fetch on.

    Returns:
        Dict[str, TensorType]: Dict with extra tf fetches to perform per
            action computation.
    """
    # Native keras models return their extra outputs (incl. VF predictions)
    # as a dict in the third return value of each call, so no extra fetches
    # are needed.
    if isinstance(policy.model, tf.keras.Model):
        return {}
    # Return value function outputs. VF estimates will hence be added to the
    # SampleBatches produced by the sampler(s) to generate the train batches
    # going into the loss function.
    return {
        SampleBatch.VF_PREDS: policy.model.value_function(),
    }


def compute_and_clip_gradients(policy: Policy, optimizer: LocalOptimizer,
                               loss: TensorType) -> ModelGradients:
    """Gradients computing function (from loss tensor, using local optimizer).

    Args:
        policy (Policy): The Policy object that generated the loss tensor and
            that holds the given local optimizer.
        optimizer (LocalOptimizer): The tf (local) optimizer object to
            calculate the gradients with.
        loss (TensorType): The loss tensor for which gradients should be
            calculated.

    Returns:
        ModelGradients: List of the (possibly clipped) gradient and variable
            tuples.
    """
    # Compute the gradients.
    variables = policy.model.trainable_variables
    if isinstance(policy.model, ModelV2):
        variables = variables()
    grads_and_vars = optimizer.compute_gradients(loss, variables)

    # Clip by global norm, if necessary.
    if policy.config["grad_clip"] is not None:
        # Defuse inf gradients (due to super large losses).
        grads = [g for (g, v) in grads_and_vars]
        grads, _ = tf.clip_by_global_norm(grads, policy.config["grad_clip"])
        # If the global_norm is inf -> All grads will be NaN. Stabilize this
        # here by setting them to 0.0. This will simply ignore destructive loss
        # calculations.
        policy.grads = [
            tf.where(tf.math.is_nan(g), tf.zeros_like(g), g) for g in grads
        ]
        clipped_grads_and_vars = list(zip(policy.grads, variables))
        return clipped_grads_and_vars
    else:
        return grads_and_vars


class KLCoeffMixin:
    """Assigns the `update_kl()` method to the PPOPolicy.

    This is used in PPO's execution plan (see ppo.py) for updating the KL
    coefficient after each learning step based on `config.kl_target` and
    the measured KL value (from the train_batch).
    """

    def __init__(self, config):
        # The current KL value (as python float).
        self.kl_coeff_val = config["kl_coeff"]
        # The current KL value (as tf Variable for in-graph operations).
        self.kl_coeff = get_variable(
            float(self.kl_coeff_val),
            tf_name="kl_coeff",
            trainable=False,
            framework=config["framework"])
        # Constant target value.
        self.kl_target = config["kl_target"]
        if self.framework == "tf":
            self._kl_coeff_placeholder = \
                tf1.placeholder(dtype=tf.float32, name="kl_coeff")
            self._kl_coeff_update = self.kl_coeff.assign(
                self._kl_coeff_placeholder, read_value=False)

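    # The multiplicative adaptation below is a heuristic in the spirit of the
    # adaptive-KL-penalty variant of PPO (the exact thresholds and factors
    # here are RLlib's choices). E.g., with (hypothetical) kl_target=0.01 and
    # a measured KL of 0.03 on every update, kl_coeff grows
    # 0.2 -> 0.3 -> 0.45 -> ... until the policy KL is reined back in.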
    def update_kl(self, sampled_kl):
        # Update the current KL value based on the recently measured value.
        # Increase.
        if sampled_kl > 2.0 * self.kl_target:
            self.kl_coeff_val *= 1.5
        # Decrease.
        elif sampled_kl < 0.5 * self.kl_target:
            self.kl_coeff_val *= 0.5
        # No change.
        else:
            return self.kl_coeff_val

        # Update the tf Variable (via session call for tf).
        if self.framework == "tf":
            self.get_session().run(
                self._kl_coeff_update,
                feed_dict={self._kl_coeff_placeholder: self.kl_coeff_val})
        else:
            self.kl_coeff.assign(self.kl_coeff_val, read_value=False)
        # Return the current KL value.
        return self.kl_coeff_val


class ValueNetworkMixin:
    """Assigns the `_value()` method to the PPOPolicy.

    This way, Policy can call `_value()` to get the current VF estimate on a
    single(!) observation (as done in `postprocess_trajectory_fn`).
    Note: When doing this, an actual forward pass is being performed.
    This is different from only calling `model.value_function()`, where
    the result of the most recent forward pass is being used to return an
    already calculated tensor.
    """

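    # Note: `_value()` is what the GAE postprocessing (see
    # `compute_gae_for_sample_batch`) calls to bootstrap the value of the
    # last observation when a rollout fragment is truncated mid-episode.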
    def __init__(self, obs_space, action_space, config):
        # When doing GAE, we need the value function estimate on the
        # observation.
        if config["use_gae"]:

            # Input dict is provided to us automatically via the Model's
            # requirements. It's a single-timestep (last one in trajectory)
            # input_dict.
            @make_tf_callable(self.get_session())
            def value(**input_dict):
                input_dict = SampleBatch(input_dict)
                if isinstance(self.model, tf.keras.Model):
                    _, _, extra_outs = self.model(input_dict)
                    return extra_outs[SampleBatch.VF_PREDS][0]
                else:
                    model_out, _ = self.model(input_dict)
                    # [0] = remove the batch dim.
                    return self.model.value_function()[0]

        # When not doing GAE, we do not require the value function's output.
        else:

            @make_tf_callable(self.get_session())
            def value(*args, **kwargs):
                return tf.constant(0.0)

        self._value = value


def setup_config(policy: Policy, obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space,
                 config: TrainerConfigDict) -> None:
    """Executed before Policy is "initialized" (at beginning of constructor).

    Args:
        policy (Policy): The Policy object.
        obs_space (gym.spaces.Space): The Policy's observation space.
        action_space (gym.spaces.Space): The Policy's action space.
        config (TrainerConfigDict): The Policy's config.
    """
    # Setting `vf_share_layers` in the top-level config is deprecated.
    # It's confusing as some users might (correctly!) set it in their
    # model config and then won't notice that it's silently overwritten
    # here.
    if config["vf_share_layers"] != DEPRECATED_VALUE:
        deprecation_warning(
            old="config[vf_share_layers]",
            new="config[model][vf_share_layers]",
            error=False,
        )
        config["model"]["vf_share_layers"] = config["vf_share_layers"]

    # If vf_share_layers is True, inform about the need to tune vf_loss_coeff.
    if config.get("model", {}).get("vf_share_layers") is True:
        logger.info(
            "`vf_share_layers=True` in your model. "
            "Therefore, remember to tune the value of `vf_loss_coeff`!")


def setup_mixins(policy: Policy, obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space,
                 config: TrainerConfigDict) -> None:
    """Call mixin classes' constructors before Policy's loss initialization.

    Args:
        policy (Policy): The Policy object.
        obs_space (gym.spaces.Space): The Policy's observation space.
        action_space (gym.spaces.Space): The Policy's action space.
        config (TrainerConfigDict): The Policy's config.
    """
    ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
    KLCoeffMixin.__init__(policy, config)
    EntropyCoeffSchedule.__init__(policy, config["entropy_coeff"],
                                  config["entropy_coeff_schedule"])
    LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])


def postprocess_ppo_gae(
        policy: Policy,
        sample_batch: SampleBatch,
        other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
        episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:

    # Stub serving backward compatibility.
    deprecation_warning(
        old="rllib.agents.ppo.ppo_tf_policy.postprocess_ppo_gae",
        new="rllib.evaluation.postprocessing.compute_gae_for_sample_batch",
        error=False)

    return compute_gae_for_sample_batch(policy, sample_batch,
                                        other_agent_batches, episode)


# Build a child class of `DynamicTFPolicy`, given the custom functions defined
# above.
PPOTFPolicy = build_tf_policy(
    name="PPOTFPolicy",
    loss_fn=ppo_surrogate_loss,
    get_default_config=lambda: ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG,
    postprocess_fn=compute_gae_for_sample_batch,
    stats_fn=kl_and_loss_stats,
    compute_gradients_fn=compute_and_clip_gradients,
    extra_action_out_fn=vf_preds_fetches,
    before_init=setup_config,
    before_loss_init=setup_mixins,
    mixins=[
        LearningRateSchedule, EntropyCoeffSchedule, KLCoeffMixin,
        ValueNetworkMixin
    ])
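
# Note: `PPOTFPolicy` is typically not instantiated directly; the PPO trainer
# (see `ray.rllib.agents.ppo.ppo`) selects it as its TF policy class and
# drives it through the execution plan referenced in the docstrings above.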