# ray/rllib/agents/ppo/ppo_tf_policy.py
"""
TensorFlow policy class used for PPO.
"""
import gym
import logging
from typing import Dict, List, Optional, Type, Union
import ray
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.evaluation.postprocessing import compute_advantages, \
Postprocessing
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy import LearningRateSchedule, \
EntropyCoeffSchedule
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.utils.framework import try_import_tf, get_variable
from ray.rllib.utils.tf_ops import explained_variance, make_tf_callable
from ray.rllib.utils.typing import AgentID, LocalOptimizer, ModelGradients, \
TensorType, TrainerConfigDict

tf1, tf, tfv = try_import_tf()

logger = logging.getLogger(__name__)


def ppo_surrogate_loss(
policy: Policy, model: ModelV2, dist_class: Type[TFActionDistribution],
train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
"""Constructs the loss for Proximal Policy Objective.
Args:
policy (Policy): The Policy to calculate the loss for.
model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[TFActionDistribution]): The action distr. class.
train_batch (SampleBatch): The training data.
Returns:
Union[TensorType, List[TensorType]]: A single loss tensor or a list
of loss tensors.
"""
logits, state = model.from_batch(train_batch)
curr_action_dist = dist_class(logits, model)
# RNN case: Mask away 0-padded chunks at end of time axis.
if state:
# Derive max_seq_len from the data itself, not from the seq_lens
# tensor. This is in case e.g. seq_lens=[2, 3], but the data is still
# 0-padded up to T=5 (as it's the case for attention nets).
B = tf.shape(train_batch["seq_lens"])[0]
max_seq_len = tf.shape(logits)[0] // B
mask = tf.sequence_mask(train_batch["seq_lens"], max_seq_len)
mask = tf.reshape(mask, [-1])
def reduce_mean_valid(t):
return tf.reduce_mean(tf.boolean_mask(t, mask))
# non-RNN case: No masking.
else:
mask = None
reduce_mean_valid = tf.reduce_mean
prev_action_dist = dist_class(train_batch[SampleBatch.ACTION_DIST_INPUTS],
model)
logp_ratio = tf.exp(
curr_action_dist.logp(train_batch[SampleBatch.ACTIONS]) -
train_batch[SampleBatch.ACTION_LOGP])
action_kl = prev_action_dist.kl(curr_action_dist)
mean_kl = reduce_mean_valid(action_kl)
curr_entropy = curr_action_dist.entropy()
mean_entropy = reduce_mean_valid(curr_entropy)
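    # PPO's clipped surrogate objective (maximized, hence the minus sign in
    # `mean_policy_loss` below):
    #   L^CLIP = E[min(r_t * A_t, clip(r_t, 1 - eps, 1 + eps) * A_t)],
    # where r_t = pi_theta(a_t|s_t) / pi_theta_old(a_t|s_t) (`logp_ratio`),
    # A_t are the advantages, and eps is `config["clip_param"]`.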
surrogate_loss = tf.minimum(
train_batch[Postprocessing.ADVANTAGES] * logp_ratio,
train_batch[Postprocessing.ADVANTAGES] * tf.clip_by_value(
logp_ratio, 1 - policy.config["clip_param"],
1 + policy.config["clip_param"]))
mean_policy_loss = reduce_mean_valid(-surrogate_loss)
if policy.config["use_gae"]:
prev_value_fn_out = train_batch[SampleBatch.VF_PREDS]
value_fn_out = model.value_function()
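        # Clipped value loss: take the elementwise maximum of the plain
        # squared error (vf_loss1) and the squared error of a prediction
        # clipped to within `vf_clip_param` of the value estimate recorded
        # at sampling time (vf_loss2).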
vf_loss1 = tf.math.square(value_fn_out -
train_batch[Postprocessing.VALUE_TARGETS])
vf_clipped = prev_value_fn_out + tf.clip_by_value(
value_fn_out - prev_value_fn_out, -policy.config["vf_clip_param"],
policy.config["vf_clip_param"])
vf_loss2 = tf.math.square(vf_clipped -
train_batch[Postprocessing.VALUE_TARGETS])
vf_loss = tf.maximum(vf_loss1, vf_loss2)
mean_vf_loss = reduce_mean_valid(vf_loss)
total_loss = reduce_mean_valid(
-surrogate_loss + policy.kl_coeff * action_kl +
policy.config["vf_loss_coeff"] * vf_loss -
policy.entropy_coeff * curr_entropy)
else:
mean_vf_loss = tf.constant(0.0)
total_loss = reduce_mean_valid(-surrogate_loss +
policy.kl_coeff * action_kl -
policy.entropy_coeff * curr_entropy)
# Store stats in policy for stats_fn.
policy._total_loss = total_loss
policy._mean_policy_loss = mean_policy_loss
policy._mean_vf_loss = mean_vf_loss
policy._mean_entropy = mean_entropy
policy._mean_kl = mean_kl
    return total_loss


def kl_and_loss_stats(policy: Policy,
train_batch: SampleBatch) -> Dict[str, TensorType]:
"""Stats function for PPO. Returns a dict with important KL and loss stats.
Args:
policy (Policy): The Policy to generate stats for.
train_batch (SampleBatch): The SampleBatch (already) used for training.
Returns:
Dict[str, TensorType]: The stats dict.
"""
return {
"cur_kl_coeff": tf.cast(policy.kl_coeff, tf.float64),
"cur_lr": tf.cast(policy.cur_lr, tf.float64),
"total_loss": policy._total_loss,
"policy_loss": policy._mean_policy_loss,
"vf_loss": policy._mean_vf_loss,
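        # Fraction of the value targets' variance explained by the VF
        # predictions: 1 - Var[targets - preds] / Var[targets].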
"vf_explained_var": explained_variance(
train_batch[Postprocessing.VALUE_TARGETS],
policy.model.value_function()),
"kl": policy._mean_kl,
"entropy": policy._mean_entropy,
"entropy_coeff": tf.cast(policy.entropy_coeff, tf.float64),
    }


def vf_preds_fetches(policy: Policy) -> Dict[str, TensorType]:
"""Defines extra fetches per action computation.
Args:
policy (Policy): The Policy to perform the extra action fetch on.
Returns:
Dict[str, TensorType]: Dict with extra tf fetches to perform per
action computation.
"""
# Return value function outputs. VF estimates will hence be added to the
# SampleBatches produced by the sampler(s) to generate the train batches
# going into the loss function.
return {
SampleBatch.VF_PREDS: policy.model.value_function(),
    }


def postprocess_ppo_gae(
policy: Policy,
sample_batch: SampleBatch,
other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
"""Postprocesses a trajectory and returns the processed trajectory.
The trajectory contains only data from one episode and from one agent.
- If `config.batch_mode=truncate_episodes` (default), sample_batch may
contain a truncated (at-the-end) episode, in case the
`config.rollout_fragment_length` was reached by the sampler.
- If `config.batch_mode=complete_episodes`, sample_batch will contain
exactly one episode (no matter how long).
New columns can be added to sample_batch and existing ones may be altered.
Args:
policy (Policy): The Policy used to generate the trajectory
(`sample_batch`)
sample_batch (SampleBatch): The SampleBatch to postprocess.
        other_agent_batches (Optional[Dict[AgentID, SampleBatch]]): Optional
dict of AgentIDs mapping to other agents' trajectory data (from the
same episode). NOTE: The other agents use the same policy.
episode (Optional[MultiAgentEpisode]): Optional multi-agent episode
object in which the agents operated.
Returns:
SampleBatch: The postprocessed, modified SampleBatch (or a new one).
"""
# Trajectory is actually complete -> last r=0.0.
if sample_batch[SampleBatch.DONES][-1]:
last_r = 0.0
# Trajectory has been truncated -> last r=VF estimate of last obs.
else:
next_state = []
for i in range(policy.num_state_tensors()):
next_state.append(sample_batch["state_out_{}".format(i)][-1])
last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1],
sample_batch[SampleBatch.ACTIONS][-1],
sample_batch[SampleBatch.REWARDS][-1],
*next_state)
# Adds the policy logits, VF preds, and advantages to the batch,
# using GAE ("generalized advantage estimation") or not.
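    # With GAE: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t),
    # A_t = sum_l (gamma * lambda)^l * delta_{t+l}, and the value targets
    # are A_t + V(s_t). Without GAE, plain discounted returns are used.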
batch = compute_advantages(
sample_batch,
last_r,
policy.config["gamma"],
policy.config["lambda"],
use_gae=policy.config["use_gae"])
    return batch


def compute_and_clip_gradients(policy: Policy, optimizer: LocalOptimizer,
loss: TensorType) -> ModelGradients:
"""Gradients computing function (from loss tensor, using local optimizer).
Args:
policy (Policy): The Policy object that generated the loss tensor and
that holds the given local optimizer.
optimizer (LocalOptimizer): The tf (local) optimizer object to
calculate the gradients with.
loss (TensorType): The loss tensor for which gradients should be
calculated.
Returns:
        ModelGradients: List of the (possibly clipped) gradient and variable
            tuples.
"""
# Compute the gradients.
variables = policy.model.trainable_variables()
grads_and_vars = optimizer.compute_gradients(loss, variables)
# Clip by global norm, if necessary.
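    # `tf.clip_by_global_norm` rescales all gradients jointly so that their
    # combined L2 norm does not exceed `config["grad_clip"]`.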
if policy.config["grad_clip"] is not None:
grads = [g for (g, v) in grads_and_vars]
policy.grads, _ = tf.clip_by_global_norm(grads,
policy.config["grad_clip"])
clipped_grads_and_vars = list(zip(policy.grads, variables))
return clipped_grads_and_vars
else:
        return grads_and_vars


class KLCoeffMixin:
"""Assigns the `update_kl()` method to the PPOPolicy.
This is used in PPO's execution plan (see ppo.py) for updating the KL
coefficient after each learning step based on `config.kl_target` and
the measured KL value (from the train_batch).
"""
def __init__(self, config):
        # The current KL coefficient value (as a Python float).
self.kl_coeff_val = config["kl_coeff"]
        # The same KL coefficient as a tf Variable (for in-graph ops).
self.kl_coeff = get_variable(
float(self.kl_coeff_val),
tf_name="kl_coeff",
trainable=False,
framework=config["framework"])
        # Constant target KL value.
self.kl_target = config["kl_target"]
def update_kl(self, sampled_kl):
        # Adapt the KL coefficient based on the recently measured KL value.
if sampled_kl > 2.0 * self.kl_target:
self.kl_coeff_val *= 1.5
elif sampled_kl < 0.5 * self.kl_target:
self.kl_coeff_val *= 0.5
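        # Example (assuming kl_target=0.01): a sampled KL of 0.03 (> 2x the
        # target) scales the coefficient by 1.5; a sampled KL of 0.004
        # (< 0.5x the target) halves it; anything in between leaves it as is.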
# Update the tf Variable (via session call).
self.kl_coeff.load(self.kl_coeff_val, session=self.get_session())
        # Return the updated KL coefficient value.
        return self.kl_coeff_val


class ValueNetworkMixin:
"""Assigns the `_value()` method to the PPOPolicy.
This way, Policy can call `_value()` to get the current VF estimate on a
single(!) observation (as done in `postprocess_trajectory_fn`).
Note: When doing this, an actual forward pass is being performed.
This is different from only calling `model.value_function()`, where
the result of the most recent forward pass is being used to return an
already calculated tensor.
"""
def __init__(self, obs_space, action_space, config):
# When doing GAE, we need the value function estimate on the
# observation.
if config["use_gae"]:
@make_tf_callable(self.get_session())
def value(ob, prev_action, prev_reward, *state):
model_out, _ = self.model({
SampleBatch.CUR_OBS: tf.convert_to_tensor([ob]),
SampleBatch.PREV_ACTIONS: tf.convert_to_tensor(
[prev_action]),
SampleBatch.PREV_REWARDS: tf.convert_to_tensor(
[prev_reward]),
"is_training": tf.convert_to_tensor([False]),
}, [tf.convert_to_tensor([s]) for s in state],
tf.convert_to_tensor([1]))
# [0] = remove the batch dim.
return self.model.value_function()[0]
# When not doing GAE, we do not require the value function's output.
else:
@make_tf_callable(self.get_session())
def value(ob, prev_action, prev_reward, *state):
return tf.constant(0.0)
        self._value = value


def setup_config(policy: Policy, obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict) -> None:
"""Executed before Policy is "initialized" (at beginning of constructor).
Args:
policy (Policy): The Policy object.
obs_space (gym.spaces.Space): The Policy's observation space.
action_space (gym.spaces.Space): The Policy's action space.
config (TrainerConfigDict): The Policy's config.
"""
# Auto set the model option for VF layer sharing.
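    # If True, the value function branch shares hidden layers with the
    # policy network; if False, it gets its own separate layers.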
config["model"]["vf_share_layers"] = config["vf_share_layers"]
def setup_mixins(policy: Policy, obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict) -> None:
"""Call mixin classes' constructors before Policy's loss initialization.
Args:
policy (Policy): The Policy object.
obs_space (gym.spaces.Space): The Policy's observation space.
action_space (gym.spaces.Space): The Policy's action space.
config (TrainerConfigDict): The Policy's config.
"""
ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
KLCoeffMixin.__init__(policy, config)
EntropyCoeffSchedule.__init__(policy, config["entropy_coeff"],
config["entropy_coeff_schedule"])
    LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])


# Build a child class of `DynamicTFPolicy`, given the custom functions defined
# above.
PPOTFPolicy = build_tf_policy(
name="PPOTFPolicy",
loss_fn=ppo_surrogate_loss,
get_default_config=lambda: ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG,
postprocess_fn=postprocess_ppo_gae,
stats_fn=kl_and_loss_stats,
gradients_fn=compute_and_clip_gradients,
extra_action_fetches_fn=vf_preds_fetches,
before_init=setup_config,
before_loss_init=setup_mixins,
mixins=[
LearningRateSchedule, EntropyCoeffSchedule, KLCoeffMixin,
ValueNetworkMixin
])
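

# A minimal usage sketch: PPOTFPolicy is normally not instantiated directly
# but used through the PPO Trainer, e.g.:
#
#   import ray
#   from ray.rllib.agents.ppo import PPOTrainer
#
#   ray.init()
#   trainer = PPOTrainer(env="CartPole-v0", config={"framework": "tf"})
#   print(trainer.train())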