2019-05-20 16:46:05 -07:00
|
|
|
"""Adapted from VTraceTFPolicy to use the PPO surrogate loss.
Keep in sync with changes to VTraceTFPolicy."""
|
2019-03-29 12:44:23 -07:00
|
|
|
|
|
|
|
from __future__ import absolute_import
|
|
|
|
from __future__ import division
|
|
|
|
from __future__ import print_function
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import logging
|
|
|
|
import gym
|
|
|
|
|
|
|
|
from ray.rllib.agents.impala import vtrace
|
2019-07-07 15:06:41 -07:00
|
|
|
from ray.rllib.agents.impala.vtrace_policy import _make_time_major, \
|
2019-08-23 02:21:11 -04:00
|
|
|
BEHAVIOUR_LOGITS, clip_gradients, validate_config, choose_optimizer
|
2019-05-18 00:23:11 -07:00
|
|
|
from ray.rllib.evaluation.postprocessing import Postprocessing
|
2019-07-27 02:08:16 -07:00
|
|
|
from ray.rllib.models.tf.tf_action_dist import Categorical
|
2019-05-20 16:46:05 -07:00
|
|
|
from ray.rllib.policy.sample_batch import SampleBatch
|
2019-03-29 12:44:23 -07:00
|
|
|
from ray.rllib.evaluation.postprocessing import compute_advantages
|
2019-05-10 20:36:18 -07:00
|
|
|
from ray.rllib.utils import try_import_tf
|
2019-07-29 15:02:32 -07:00
|
|
|
from ray.rllib.policy.tf_policy_template import build_tf_policy
|
2019-10-31 15:16:02 -07:00
|
|
|
from ray.rllib.policy.tf_policy import LearningRateSchedule, TFPolicy
|
2019-08-23 02:21:11 -04:00
|
|
|
from ray.rllib.agents.ppo.ppo_policy import KLCoeffMixin, ValueNetworkMixin
|
2019-07-29 15:02:32 -07:00
|
|
|
from ray.rllib.models import ModelCatalog
|
2019-10-31 15:16:02 -07:00
|
|
|
from ray.rllib.utils.annotations import override
|
2019-07-29 15:02:32 -07:00
|
|
|
from ray.rllib.utils.explained_variance import explained_variance
|
2019-08-23 02:21:11 -04:00
|
|
|
from ray.rllib.utils.tf_ops import make_tf_callable
|
2019-05-10 20:36:18 -07:00
|
|
|
|
|
|
|
# TensorFlow module handle, obtained through RLlib's optional-import helper.
tf = try_import_tf()

# Variable-scope names for the learner model and its target-network copy
# (used as the `name=` of the two models built in build_appo_model).
POLICY_SCOPE, TARGET_POLICY_SCOPE = "func", "target_func"

logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
class PPOSurrogateLoss(object):
    """Clipped PPO surrogate loss, used when V-trace is disabled.

    Every per-step input is a time-major tensor; only the entries selected
    by ``valid_mask`` contribute to the averaged loss terms.

    Arguments:
        prev_actions_logp: A float32 tensor of shape [T, B].
        actions_logp: A float32 tensor of shape [T, B].
        action_kl: A float32 tensor of shape [T, B].
        actions_entropy: A float32 tensor of shape [T, B].
        values: A float32 tensor of shape [T, B].
        valid_mask: A bool tensor of valid RNN input elements (#2992).
        advantages: A float32 tensor of shape [T, B].
        value_targets: A float32 tensor of shape [T, B].
        vf_loss_coeff (float): Coefficient of the value function loss.
        entropy_coeff (float): Coefficient of the entropy regularizer.
        clip_param (float): Clip parameter.
        cur_kl_coeff (float): Coefficient for KL loss.
        use_kl_loss (bool): If true, use KL loss.

    After construction the instance exposes ``mean_kl``, ``pi_loss``,
    ``value_targets``, ``vf_loss``, ``entropy`` and ``total_loss``.
    """

    def __init__(self,
                 prev_actions_logp,
                 actions_logp,
                 action_kl,
                 actions_entropy,
                 values,
                 valid_mask,
                 advantages,
                 value_targets,
                 vf_loss_coeff=0.5,
                 entropy_coeff=0.01,
                 clip_param=0.3,
                 cur_kl_coeff=None,
                 use_kl_loss=False):

        def _masked_mean(x):
            # Mean over only the entries that valid_mask keeps.
            return tf.reduce_mean(tf.boolean_mask(x, valid_mask))

        # Probability ratio pi_new(a|s) / pi_prev(a|s), and its clipped form.
        ratio = tf.exp(actions_logp - prev_actions_logp)
        clipped_ratio = tf.clip_by_value(ratio, 1 - clip_param,
                                         1 + clip_param)
        surrogate = tf.minimum(advantages * ratio, advantages * clipped_ratio)

        self.mean_kl = _masked_mean(action_kl)
        self.pi_loss = -_masked_mean(surrogate)

        # Baseline (value-function) regression against the given targets.
        self.value_targets = value_targets
        self.vf_loss = 0.5 * _masked_mean(tf.square(values - value_targets))

        # Entropy bonus (subtracted from the loss below).
        self.entropy = _masked_mean(actions_entropy)

        loss = (self.pi_loss + self.vf_loss * vf_loss_coeff -
                self.entropy * entropy_coeff)
        if use_kl_loss:
            # Optional KL penalty weighted by the adaptive coefficient.
            loss += cur_kl_coeff * self.mean_kl
        self.total_loss = loss
|
|
|
|
|
2019-03-29 12:44:23 -07:00
|
|
|
|
|
|
|
class VTraceSurrogateLoss(object):
    def __init__(self,
                 actions,
                 prev_actions_logp,
                 actions_logp,
                 old_policy_actions_logp,
                 action_kl,
                 actions_entropy,
                 dones,
                 behaviour_logits,
                 old_policy_behaviour_logits,
                 target_logits,
                 discount,
                 rewards,
                 values,
                 bootstrap_value,
                 dist_class,
                 model,
                 valid_mask,
                 vf_loss_coeff=0.5,
                 entropy_coeff=0.01,
                 clip_rho_threshold=1.0,
                 clip_pg_rho_threshold=1.0,
                 clip_param=0.3,
                 cur_kl_coeff=None,
                 use_kl_loss=False):
        """APPO loss: PPO-style clipped surrogate with V-trace advantages.

        All per-step inputs are time-major ([T, B, ...]); `B` must be known
        so that V-trace can respect episode cut boundaries.

        V-trace returns are computed between the behaviour logits and the
        target-network logits (``old_policy_behaviour_logits``). The policy
        ratio is additionally reweighted by an importance ratio between the
        behaviour policy and the target network, clipped to [0, 2].

        Arguments:
            actions: An int|float32 tensor of shape [T, B, logit_dim].
            prev_actions_logp: A float32 tensor of shape [T, B].
            actions_logp: A float32 tensor of shape [T, B].
            old_policy_actions_logp: A float32 tensor of shape [T, B].
            action_kl: A float32 tensor of shape [T, B].
            actions_entropy: A float32 tensor of shape [T, B].
            dones: A bool tensor of shape [T, B].
            behaviour_logits: A float32 tensor of shape [T, B, logit_dim].
            old_policy_behaviour_logits: A float32 tensor of shape
                [T, B, logit_dim].
            target_logits: A float32 tensor of shape [T, B, logit_dim].
                NOTE: accepted for interface compatibility but not read by
                this implementation.
            discount: A float32 scalar.
            rewards: A float32 tensor of shape [T, B].
            values: A float32 tensor of shape [T, B].
            bootstrap_value: A float32 tensor of shape [B].
            dist_class: action distribution class for logits.
            model: backing ModelV2 instance
            valid_mask: A bool tensor of valid RNN input elements (#2992).
            vf_loss_coeff (float): Coefficient of the value function loss.
            entropy_coeff (float): Coefficient of the entropy regularizer.
            clip_rho_threshold: V-trace clipping threshold for rho.
            clip_pg_rho_threshold: V-trace clipping threshold for the
                policy-gradient rho.
            clip_param (float): Clip parameter.
            cur_kl_coeff (float): Coefficient for KL loss.
            use_kl_loss (bool): If true, use KL loss.
        """

        def _masked_mean(x):
            # Mean over only the entries that valid_mask keeps.
            return tf.reduce_mean(tf.boolean_mask(x, valid_mask))

        # V-trace is placed on the CPU for better performance.
        with tf.device("/cpu:0"):
            self.vtrace_returns = vtrace.multi_from_logits(
                behaviour_policy_logits=behaviour_logits,
                target_policy_logits=old_policy_behaviour_logits,
                actions=tf.unstack(actions, axis=2),
                discounts=tf.to_float(~dones) * discount,
                rewards=rewards,
                values=values,
                bootstrap_value=bootstrap_value,
                dist_class=dist_class,
                model=model,
                clip_rho_threshold=tf.cast(clip_rho_threshold, tf.float32),
                clip_pg_rho_threshold=tf.cast(clip_pg_rho_threshold,
                                              tf.float32))

        # Importance weight of the behaviour policy w.r.t. the target
        # network, clipped to [0, 2], folded into the PPO ratio.
        self.is_ratio = tf.clip_by_value(
            tf.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0)
        ratio = self.is_ratio * tf.exp(actions_logp - prev_actions_logp)

        pg_adv = self.vtrace_returns.pg_advantages
        clipped_ratio = tf.clip_by_value(ratio, 1 - clip_param,
                                         1 + clip_param)
        surrogate = tf.minimum(pg_adv * ratio, pg_adv * clipped_ratio)

        self.mean_kl = _masked_mean(action_kl)
        self.pi_loss = -_masked_mean(surrogate)

        # Baseline regression towards the V-trace value targets.
        vs = self.vtrace_returns.vs
        self.value_targets = vs
        self.vf_loss = 0.5 * _masked_mean(tf.square(values - vs))

        # Entropy bonus (subtracted from the loss below).
        self.entropy = _masked_mean(actions_entropy)

        loss = (self.pi_loss + self.vf_loss * vf_loss_coeff -
                self.entropy * entropy_coeff)
        if use_kl_loss:
            # Optional KL penalty weighted by the adaptive coefficient.
            loss += cur_kl_coeff * self.mean_kl
        self.total_loss = loss
|
|
|
|
|
|
|
|
|
|
|
|
def build_appo_model(policy, obs_space, action_space, config):
    """Create the learner model and an identically-shaped target model.

    Both models are attached to ``policy`` (``policy.model`` and
    ``policy.target_model``) under different variable scopes. Only the
    learner model is returned to the policy builder.
    """
    _, logit_dim = ModelCatalog.get_action_dist(action_space, config["model"])

    def _new_model(scope_name):
        # Same architecture and output size; only the scope name differs.
        return ModelCatalog.get_model_v2(
            obs_space,
            action_space,
            logit_dim,
            config["model"],
            name=scope_name,
            framework="tf")

    policy.model = _new_model(POLICY_SCOPE)
    policy.target_model = _new_model(TARGET_POLICY_SCOPE)
    return policy.model
|
|
|
|
|
2019-03-29 12:44:23 -07:00
|
|
|
|
2019-08-23 02:21:11 -04:00
|
|
|
def build_appo_surrogate_loss(policy, model, dist_class, train_batch):
    """Build the APPO loss graph and attach the loss object to the policy.

    When ``policy.config["vtrace"]`` is set, a VTraceSurrogateLoss is built
    (time-major tensors with the final step dropped, V-trace advantages).
    Otherwise a PPOSurrogateLoss is built from the advantages and value
    targets written into the batch by postprocess_trajectory.

    Side effects:
        Sets ``policy.loss``, ``policy.model_vars`` and
        ``policy.target_model_vars``. The variable lists are later synced
        by TargetNetworkMixin.

    Returns:
        The scalar total-loss tensor.
    """
    model_out, _ = model.from_batch(train_batch)
    action_dist = dist_class(model_out, model)

    # How the flat logits split into per-sub-action chunks.
    if isinstance(policy.action_space, gym.spaces.Discrete):
        is_multidiscrete = False
        output_hidden_shape = [policy.action_space.n]
    elif isinstance(policy.action_space,
                    gym.spaces.multi_discrete.MultiDiscrete):
        is_multidiscrete = True
        output_hidden_shape = policy.action_space.nvec.astype(np.int32)
    else:
        is_multidiscrete = False
        output_hidden_shape = 1

    def make_time_major(*args, **kw):
        return _make_time_major(policy, train_batch.get("seq_lens"), *args,
                                **kw)

    actions = train_batch[SampleBatch.ACTIONS]
    dones = train_batch[SampleBatch.DONES]
    rewards = train_batch[SampleBatch.REWARDS]
    behaviour_logits = train_batch[BEHAVIOUR_LOGITS]

    # Target-network logits act as "pi_old"; no gradient flows into them.
    target_model_out, _ = policy.target_model.from_batch(train_batch)
    old_policy_behaviour_logits = tf.stop_gradient(target_model_out)

    unpacked_behaviour_logits = tf.split(
        behaviour_logits, output_hidden_shape, axis=1)
    unpacked_old_policy_behaviour_logits = tf.split(
        old_policy_behaviour_logits, output_hidden_shape, axis=1)
    unpacked_outputs = tf.split(model_out, output_hidden_shape, axis=1)
    old_policy_action_dist = dist_class(old_policy_behaviour_logits, model)
    prev_action_dist = dist_class(behaviour_logits, policy.model)
    values = policy.model.value_function()

    # Read by TargetNetworkMixin's update op.
    policy.model_vars = policy.model.variables()
    policy.target_model_vars = policy.target_model.variables()

    if policy.is_recurrent():
        # The mask must have exactly one entry per (padded) row of the batch
        # so that it reshapes and drops its last step in lockstep with the
        # data tensors it is paired with. Assumes sequences are padded to
        # max(seq_lens) (TODO confirm against the padding code). A length of
        # max(seq_lens) - 1 would leave the mask one step short and silently
        # exclude the final valid step of the longest sequences.
        max_seq_len = tf.reduce_max(train_batch["seq_lens"])
        mask = tf.sequence_mask(train_batch["seq_lens"], max_seq_len)
        mask = tf.reshape(mask, [-1])
    else:
        # Both loss classes expect a bool `valid_mask` for tf.boolean_mask.
        mask = tf.ones_like(rewards, dtype=tf.bool)

    if policy.config["vtrace"]:
        logger.debug("Using V-Trace surrogate loss (vtrace=True)")

        # Prepare actions for loss
        loss_actions = actions if is_multidiscrete else tf.expand_dims(
            actions, axis=1)

        # Prepare KL for Loss
        mean_kl = make_time_major(
            old_policy_action_dist.multi_kl(action_dist), drop_last=True)

        policy.loss = VTraceSurrogateLoss(
            actions=make_time_major(loss_actions, drop_last=True),
            prev_actions_logp=make_time_major(
                prev_action_dist.logp(actions), drop_last=True),
            actions_logp=make_time_major(
                action_dist.logp(actions), drop_last=True),
            old_policy_actions_logp=make_time_major(
                old_policy_action_dist.logp(actions), drop_last=True),
            action_kl=tf.reduce_mean(mean_kl, axis=0)
            if is_multidiscrete else mean_kl,
            actions_entropy=make_time_major(
                action_dist.multi_entropy(), drop_last=True),
            dones=make_time_major(dones, drop_last=True),
            behaviour_logits=make_time_major(
                unpacked_behaviour_logits, drop_last=True),
            old_policy_behaviour_logits=make_time_major(
                unpacked_old_policy_behaviour_logits, drop_last=True),
            target_logits=make_time_major(unpacked_outputs, drop_last=True),
            discount=policy.config["gamma"],
            rewards=make_time_major(rewards, drop_last=True),
            values=make_time_major(values, drop_last=True),
            bootstrap_value=make_time_major(values)[-1],
            dist_class=Categorical if is_multidiscrete else dist_class,
            model=policy.model,
            valid_mask=make_time_major(mask, drop_last=True),
            vf_loss_coeff=policy.config["vf_loss_coeff"],
            entropy_coeff=policy.config["entropy_coeff"],
            clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"],
            clip_pg_rho_threshold=policy.config[
                "vtrace_clip_pg_rho_threshold"],
            clip_param=policy.config["clip_param"],
            cur_kl_coeff=policy.kl_coeff,
            use_kl_loss=policy.config["use_kl_loss"])
    else:
        logger.debug("Using PPO surrogate loss (vtrace=False)")

        # Prepare KL for Loss
        mean_kl = make_time_major(prev_action_dist.multi_kl(action_dist))

        policy.loss = PPOSurrogateLoss(
            prev_actions_logp=make_time_major(prev_action_dist.logp(actions)),
            actions_logp=make_time_major(action_dist.logp(actions)),
            action_kl=tf.reduce_mean(mean_kl, axis=0)
            if is_multidiscrete else mean_kl,
            actions_entropy=make_time_major(action_dist.multi_entropy()),
            values=make_time_major(values),
            valid_mask=make_time_major(mask),
            advantages=make_time_major(train_batch[Postprocessing.ADVANTAGES]),
            value_targets=make_time_major(
                train_batch[Postprocessing.VALUE_TARGETS]),
            vf_loss_coeff=policy.config["vf_loss_coeff"],
            entropy_coeff=policy.config["entropy_coeff"],
            clip_param=policy.config["clip_param"],
            cur_kl_coeff=policy.kl_coeff,
            use_kl_loss=policy.config["use_kl_loss"])

    return policy.loss.total_loss
|
|
|
|
|
|
|
|
|
2019-08-23 02:21:11 -04:00
|
|
|
def stats(policy, train_batch):
    """Collect learner metrics from the loss object built for this policy.

    Always reports learning rate, policy/value/entropy losses, the global
    norm of the model's trainable variables and the value function's
    explained variance. Importance-sampling moments are added when V-trace
    is on, and the KL term and coefficient when the KL loss is enabled.
    """
    loss = policy.loss
    # Align value predictions with the loss tensors (final step dropped
    # under V-trace).
    values_batched = _make_time_major(
        policy,
        train_batch.get("seq_lens"),
        policy.model.value_function(),
        drop_last=policy.config["vtrace"])

    result = {
        "cur_lr": tf.cast(policy.cur_lr, tf.float64),
        "policy_loss": loss.pi_loss,
        "entropy": loss.entropy,
        "var_gnorm": tf.global_norm(policy.model.trainable_variables()),
        "vf_loss": loss.vf_loss,
        "vf_explained_var": explained_variance(
            tf.reshape(loss.value_targets, [-1]),
            tf.reshape(values_batched, [-1])),
    }

    if policy.config["vtrace"]:
        mean_is, var_is = tf.nn.moments(loss.is_ratio, [0, 1])
        result["mean_IS"] = mean_is
        result["var_IS"] = var_is

    if policy.config["use_kl_loss"]:
        result["kl"] = loss.mean_kl
        result["KL_Coeff"] = policy.kl_coeff

    return result
|
|
|
|
|
|
|
|
|
2019-05-18 00:23:11 -07:00
|
|
|
def postprocess_trajectory(policy,
                           sample_batch,
                           other_agent_batches=None,
                           episode=None):
    """Prepare a sampled trajectory for the learner.

    Without V-trace, advantages and value targets are computed here, using
    a bootstrap value of 0.0 if the trajectory ended and otherwise the value
    estimate of the last next-observation. With V-trace the batch passes
    through unchanged. In both cases the next-observation column is dropped.
    """
    if policy.config["vtrace"]:
        batch = sample_batch
    else:
        if sample_batch["dones"][-1]:
            last_r = 0.0
        else:
            # Final RNN state of the trajectory, one single-row entry per
            # state tensor.
            next_state = [[sample_batch["state_out_{}".format(idx)][-1]]
                          for idx in range(policy.num_state_tensors())]
            last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1],
                                   sample_batch[SampleBatch.ACTIONS][-1],
                                   sample_batch[SampleBatch.REWARDS][-1],
                                   *next_state)
        batch = compute_advantages(
            sample_batch,
            last_r,
            policy.config["gamma"],
            policy.config["lambda"],
            use_gae=policy.config["use_gae"])
    del batch.data["new_obs"]  # not used, so save some bandwidth
    return batch
|
|
|
|
|
|
|
|
|
|
|
|
def add_values_and_logits(policy):
    """Extra tensors fetched alongside each action computation.

    The raw model output is always recorded as the behaviour logits. Value
    predictions are added only when V-trace is off, because
    postprocess_trajectory then needs them to compute advantages.
    """
    fetches = {BEHAVIOUR_LOGITS: policy.model.last_output()}
    if policy.config["vtrace"]:
        return fetches
    fetches[SampleBatch.VF_PREDS] = policy.model.value_function()
    return fetches
|
|
|
|
|
|
|
|
|
2019-07-29 15:02:32 -07:00
|
|
|
class TargetNetworkMixin(object):
    """Mixin that maintains a target copy of the policy model's variables."""

    def __init__(self, obs_space, action_space, config):
        """Install ``self.update_target``, which copies model -> target vars.

        The target network serves as a more stable pi_old for the PPO ratio.
        Worker batches are importance-weighted against it in the loss. The
        copy is presumably triggered periodically by the trainer (e.g. every
        ``update_target_frequency`` steps); confirm in the trainer code.
        """

        @make_tf_callable(self.get_session())
        def do_update():
            assert len(self.model_vars) == len(self.target_model_vars)
            return tf.group(*[
                target_var.assign(source_var)
                for source_var, target_var in zip(self.model_vars,
                                                  self.target_model_vars)
            ])

        self.update_target = do_update

    @override(TFPolicy)
    def variables(self):
        # Expose both the learner and target variables.
        return self.model_vars + self.target_model_vars
|
|
|
|
|
2019-07-29 15:02:32 -07:00
|
|
|
|
|
|
|
def setup_mixins(policy, obs_space, action_space, config):
    """Initialize the mixins required before the loss graph is built.

    The loss reads ``policy.kl_coeff`` and stats reads ``policy.cur_lr``.
    postprocess_trajectory calls ``policy._value``, presumably provided by
    ValueNetworkMixin.
    """
    lr, lr_schedule = config["lr"], config["lr_schedule"]
    LearningRateSchedule.__init__(policy, lr, lr_schedule)
    KLCoeffMixin.__init__(policy, config)
    ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
|
2019-07-29 15:02:32 -07:00
|
|
|
|
|
|
|
|
|
|
|
def setup_late_mixins(policy, obs_space, action_space, config):
    """Initialize TargetNetworkMixin once the policy is fully built.

    Its update op reads ``policy.model_vars`` / ``policy.target_model_vars``,
    which are assigned while the loss is constructed.
    """
    TargetNetworkMixin.__init__(
        policy, obs_space, action_space, config)
|
|
|
|
|
|
|
|
|
|
|
|
# APPO policy class assembled from the functions defined in this module.
AsyncPPOTFPolicy = build_tf_policy(
    name="AsyncPPOTFPolicy",
    # Construction hooks.
    before_init=validate_config,
    make_model=build_appo_model,
    before_loss_init=setup_mixins,
    loss_fn=build_appo_surrogate_loss,
    after_init=setup_late_mixins,
    # Optimization.
    optimizer_fn=choose_optimizer,
    gradients_fn=clip_gradients,
    # Sampling, postprocessing and metrics.
    extra_action_fetches_fn=add_values_and_logits,
    postprocess_fn=postprocess_trajectory,
    stats_fn=stats,
    # Keep this order: it presumably fixes the generated class's bases/MRO.
    mixins=[
        LearningRateSchedule, KLCoeffMixin, TargetNetworkMixin,
        ValueNetworkMixin
    ],
    get_batch_divisibility_req=lambda p: p.config["sample_batch_size"])
|