Fix multi discrete (#4338)
* Revert "Revert "[wingman -> rllib] IMPALA MultiDiscrete changes (#3967)" (#4332)"
This reverts commit 3c41cb9b60.
* Fix a bug with log rhos for vtrace
* Reformat
* lint
This commit is contained in:
parent 490d896f41
commit 2202a81773
8 changed files with 653 additions and 130 deletions
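For orientation, a minimal sketch of how the new list-based entry point is exercised (assumptions: TF1 graph mode, a MultiDiscrete([3, 5]) action space, and illustrative T=10, B=4; this is not part of the diff):

import numpy as np
import tensorflow as tf
from ray.rllib.agents.impala import vtrace

T, B, sizes = 10, 4, [3, 5]  # time steps, batch size, sub-action sizes (illustrative)
# One [T, B, n_i] logits tensor and one [T, B] int32 action tensor per sub-action.
behaviour_logits = [tf.constant(np.random.randn(T, B, n), tf.float32) for n in sizes]
target_logits = [tf.constant(np.random.randn(T, B, n), tf.float32) for n in sizes]
actions = [tf.constant(np.random.randint(0, n, size=(T, B)), tf.int32) for n in sizes]

out = vtrace.multi_from_logits(
    behaviour_policy_logits=behaviour_logits,
    target_policy_logits=target_logits,
    actions=actions,
    discounts=tf.fill([T, B], 0.99),
    rewards=tf.constant(np.random.randn(T, B), tf.float32),
    values=tf.constant(np.random.randn(T, B), tf.float32),
    bootstrap_value=tf.constant(np.random.randn(B), tf.float32))

with tf.Session() as sess:
    vs, pg_adv = sess.run([out.vs, out.pg_advantages])  # both [T, B]

With a plain Discrete space the old single-tensor from_logits wrapper (kept for the tests) simply forwards one-element lists to the same code path.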
@@ -410,3 +410,6 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
--stop='{"timesteps_total": 40000}' \
--ray-object-store-memory=500000000 \
--config '{"num_workers": 1, "num_gpus": 0, "num_envs_per_worker": 64, "sample_batch_size": 50, "train_batch_size": 50, "learner_queue_size": 1}'

docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
python /ray/python/ray/rllib/agents/impala/vtrace_test.py
@@ -20,6 +20,12 @@ Importance Weighted Actor-Learner Architectures"
by Espeholt, Soyer, Munos et al.

See https://arxiv.org/abs/1802.01561 for the full paper.

In addition to the original paper's code, changes have been made
to support MultiDiscrete action spaces. The behaviour_policy_logits,
target_policy_logits and actions parameters of the entry-point
multi_from_logits method accept lists of tensors instead of just
tensors.
"""

from __future__ import absolute_import
@@ -41,29 +47,48 @@ VTraceReturns = collections.namedtuple('VTraceReturns', 'vs pg_advantages')


def log_probs_from_logits_and_actions(policy_logits, actions):
return multi_log_probs_from_logits_and_actions([policy_logits],
[actions])[0]


def multi_log_probs_from_logits_and_actions(policy_logits, actions):
"""Computes action log-probs from policy logits and actions.

In the notation used throughout documentation and comments, T refers to the
time dimension ranging from 0 to T-1. B refers to the batch size and
NUM_ACTIONS refers to the number of actions.
ACTION_SPACE refers to the list of numbers each representing a number of
actions.

Args:
policy_logits: A float32 tensor of shape [T, B, NUM_ACTIONS] with
un-normalized log-probabilities parameterizing a softmax policy.
actions: An int32 tensor of shape [T, B] with actions.
policy_logits: A list with length of ACTION_SPACE of float32
tensors of shapes
[T, B, ACTION_SPACE[0]],
...,
[T, B, ACTION_SPACE[-1]]
with un-normalized log-probabilities parameterizing a softmax policy.
actions: A list with length of ACTION_SPACE of int32
tensors of shapes
[T, B],
...,
[T, B]
with actions.

Returns:
A float32 tensor of shape [T, B] corresponding to the sampling log
probability of the chosen action w.r.t. the policy.
A list with length of ACTION_SPACE of float32
tensors of shapes
[T, B],
...,
[T, B]
corresponding to the sampling log probability
of the chosen action w.r.t. the policy.
"""
policy_logits = tf.convert_to_tensor(policy_logits, dtype=tf.float32)
actions = tf.convert_to_tensor(actions, dtype=tf.int32)

policy_logits.shape.assert_has_rank(3)
actions.shape.assert_has_rank(2)
log_probs = []
for i in range(len(policy_logits)):
log_probs.append(-tf.nn.sparse_softmax_cross_entropy_with_logits(
logits=policy_logits[i], labels=actions[i]))

return -tf.nn.sparse_softmax_cross_entropy_with_logits(
logits=policy_logits, labels=actions)
return log_probs


def from_logits(behaviour_policy_logits,

@@ -76,6 +101,39 @@ def from_logits(behaviour_policy_logits,
clip_rho_threshold=1.0,
clip_pg_rho_threshold=1.0,
name='vtrace_from_logits'):
"""multi_from_logits wrapper used only for tests"""

res = multi_from_logits(
[behaviour_policy_logits], [target_policy_logits], [actions],
discounts,
rewards,
values,
bootstrap_value,
clip_rho_threshold=clip_rho_threshold,
clip_pg_rho_threshold=clip_pg_rho_threshold,
name=name)

return VTraceFromLogitsReturns(
vs=res.vs,
pg_advantages=res.pg_advantages,
log_rhos=res.log_rhos,
behaviour_action_log_probs=tf.squeeze(
res.behaviour_action_log_probs, axis=0),
target_action_log_probs=tf.squeeze(
res.target_action_log_probs, axis=0),
)


def multi_from_logits(behaviour_policy_logits,
target_policy_logits,
actions,
discounts,
rewards,
values,
bootstrap_value,
clip_rho_threshold=1.0,
clip_pg_rho_threshold=1.0,
name='vtrace_from_logits'):
r"""V-trace for softmax policies.

Calculates V-trace actor critic targets for softmax policies as described in

@@ -90,16 +148,30 @@ def from_logits(behaviour_policy_logits,

In the notation used throughout documentation and comments, T refers to the
time dimension ranging from 0 to T-1. B refers to the batch size and
NUM_ACTIONS refers to the number of actions.
ACTION_SPACE refers to the list of numbers each representing a number of
actions.

Args:
behaviour_policy_logits: A float32 tensor of shape [T, B, NUM_ACTIONS] with
un-normalized log-probabilities parametrizing the softmax behaviour
behaviour_policy_logits: A list with length of ACTION_SPACE of float32
tensors of shapes
[T, B, ACTION_SPACE[0]],
...,
[T, B, ACTION_SPACE[-1]]
with un-normalized log-probabilities parameterizing the softmax behaviour
policy.
target_policy_logits: A float32 tensor of shape [T, B, NUM_ACTIONS] with
un-normalized log-probabilities parametrizing the softmax target policy.
actions: An int32 tensor of shape [T, B] of actions sampled from the
behaviour policy.
target_policy_logits: A list with length of ACTION_SPACE of float32
tensors of shapes
[T, B, ACTION_SPACE[0]],
...,
[T, B, ACTION_SPACE[-1]]
with un-normalized log-probabilities parameterizing the softmax target
policy.
actions: A list with length of ACTION_SPACE of int32
tensors of shapes
[T, B],
...,
[T, B]
with actions sampled from the behaviour policy.
discounts: A float32 tensor of shape [T, B] with the discount encountered
when following the behaviour policy.
rewards: A float32 tensor of shape [T, B] with the rewards generated by

@@ -128,17 +200,19 @@ def from_logits(behaviour_policy_logits,
target_action_log_probs: A float32 tensor of shape [T, B] containing
target policy action probabilities (log \pi(a_t)).
"""
behaviour_policy_logits = tf.convert_to_tensor(
behaviour_policy_logits, dtype=tf.float32)
target_policy_logits = tf.convert_to_tensor(
target_policy_logits, dtype=tf.float32)
actions = tf.convert_to_tensor(actions, dtype=tf.int32)

# Make sure tensor ranks are as expected.
# The rest will be checked by from_action_log_probs.
behaviour_policy_logits.shape.assert_has_rank(3)
target_policy_logits.shape.assert_has_rank(3)
actions.shape.assert_has_rank(2)
for i in range(len(behaviour_policy_logits)):
behaviour_policy_logits[i] = tf.convert_to_tensor(
behaviour_policy_logits[i], dtype=tf.float32)
target_policy_logits[i] = tf.convert_to_tensor(
target_policy_logits[i], dtype=tf.float32)
actions[i] = tf.convert_to_tensor(actions[i], dtype=tf.int32)

# Make sure tensor ranks are as expected.
# The rest will be checked by from_action_log_probs.
behaviour_policy_logits[i].shape.assert_has_rank(3)
target_policy_logits[i].shape.assert_has_rank(3)
actions[i].shape.assert_has_rank(2)

with tf.name_scope(
name,

@@ -146,11 +220,14 @@ def from_logits(behaviour_policy_logits,
behaviour_policy_logits, target_policy_logits, actions,
discounts, rewards, values, bootstrap_value
]):
target_action_log_probs = log_probs_from_logits_and_actions(
target_action_log_probs = multi_log_probs_from_logits_and_actions(
target_policy_logits, actions)
behaviour_action_log_probs = log_probs_from_logits_and_actions(
behaviour_action_log_probs = multi_log_probs_from_logits_and_actions(
behaviour_policy_logits, actions)
log_rhos = target_action_log_probs - behaviour_action_log_probs

log_rhos = get_log_rhos(target_action_log_probs,
behaviour_action_log_probs)

vtrace_returns = from_importance_weights(
log_rhos=log_rhos,
discounts=discounts,

@@ -159,6 +236,7 @@ def from_logits(behaviour_policy_logits,
bootstrap_value=bootstrap_value,
clip_rho_threshold=clip_rho_threshold,
clip_pg_rho_threshold=clip_pg_rho_threshold)

return VTraceFromLogitsReturns(
log_rhos=log_rhos,
behaviour_action_log_probs=behaviour_action_log_probs,

@@ -183,13 +261,13 @@ def from_importance_weights(log_rhos,
by Espeholt, Soyer, Munos et al.

In the notation used throughout documentation and comments, T refers to the
time dimension ranging from 0 to T-1. B refers to the batch size and
NUM_ACTIONS refers to the number of actions. This code also supports the
case where all tensors have the same number of additional dimensions, e.g.,
`rewards` is [T, B, C], `values` is [T, B, C], `bootstrap_value` is [B, C].
time dimension ranging from 0 to T-1. B refers to the batch size. This code
also supports the case where all tensors have the same number of additional
dimensions, e.g., `rewards` is [T, B, C], `values` is [T, B, C],
`bootstrap_value` is [B, C].

Args:
log_rhos: A float32 tensor of shape [T, B, NUM_ACTIONS] representing the
log_rhos: A float32 tensor of shape [T, B] representing the
log importance sampling weights, i.e.
log(target_policy(a) / behaviour_policy(a)). V-trace performs operations
on rhos in log-space for numerical stability.

@@ -246,6 +324,14 @@ def from_importance_weights(log_rhos,
if clip_rho_threshold is not None:
clipped_rhos = tf.minimum(
clip_rho_threshold, rhos, name='clipped_rhos')

tf.summary.histogram('clipped_rhos_1000', tf.minimum(1000.0, rhos))
tf.summary.scalar(
'num_of_clipped_rhos',
tf.reduce_sum(
tf.cast(
tf.equal(clipped_rhos, clip_rho_threshold), tf.int32)))
tf.summary.scalar('size_of_clipped_rhos', tf.size(clipped_rhos))
else:
clipped_rhos = rhos

@@ -298,3 +384,12 @@ def from_importance_weights(log_rhos,
return VTraceReturns(
vs=tf.stop_gradient(vs),
pg_advantages=tf.stop_gradient(pg_advantages))


def get_log_rhos(target_action_log_probs, behaviour_action_log_probs):
"""With the selected log_probs for multi-discrete actions of behaviour
and target policies we compute the log_rhos for calculating the vtrace."""
t = tf.stack(target_action_log_probs)
b = tf.stack(behaviour_action_log_probs)
log_rhos = tf.reduce_sum(t - b, axis=0)
return log_rhos
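The new get_log_rhos helper sums the per-sub-action log-prob differences, which is the log importance weight of the joint action under a factorized (per-dimension independent) policy. A small NumPy sketch of the same arithmetic, with made-up probabilities (not taken from the code above):

import numpy as np

# Per-sub-action log-probs, shape [num_sub_actions][T, B]; T = B = 1 here.
target_logp = [np.log([[0.5]]), np.log([[0.2]])]
behaviour_logp = [np.log([[0.25]]), np.log([[0.4]])]

# Same reduction as get_log_rhos: stack over sub-actions, subtract, sum over axis 0.
log_rhos = np.sum(np.stack(target_logp) - np.stack(behaviour_logp), axis=0)

# Equals log((0.5 * 0.2) / (0.25 * 0.4)) = log(1.0) = 0 for the joint action.
assert np.allclose(log_rhos, 0.0)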
@@ -6,19 +6,19 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import gym

import ray
import numpy as np
import tensorflow as tf
from ray.rllib.agents.impala import vtrace
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
LearningRateSchedule
from ray.rllib.models.action_dist import MultiCategorical
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.utils.annotations import override
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.explained_variance import explained_variance
from ray.rllib.models.action_dist import Categorical


class VTraceLoss(object):

@@ -45,12 +45,20 @@ class VTraceLoss(object):
handle episode cut boundaries.

Args:
actions: An int32 tensor of shape [T, B, NUM_ACTIONS].
actions: An int32 tensor of shape [T, B, ACTION_SPACE].
actions_logp: A float32 tensor of shape [T, B].
actions_entropy: A float32 tensor of shape [T, B].
dones: A bool tensor of shape [T, B].
behaviour_logits: A float32 tensor of shape [T, B, NUM_ACTIONS].
target_logits: A float32 tensor of shape [T, B, NUM_ACTIONS].
behaviour_logits: A list with length of ACTION_SPACE of float32
tensors of shapes
[T, B, ACTION_SPACE[0]],
...,
[T, B, ACTION_SPACE[-1]]
target_logits: A list with length of ACTION_SPACE of float32
tensors of shapes
[T, B, ACTION_SPACE[0]],
...,
[T, B, ACTION_SPACE[-1]]
discount: A float32 scalar.
rewards: A float32 tensor of shape [T, B].
values: A float32 tensor of shape [T, B].

@@ -60,10 +68,10 @@ class VTraceLoss(object):

# Compute vtrace on the CPU for better perf.
with tf.device("/cpu:0"):
self.vtrace_returns = vtrace.from_logits(
self.vtrace_returns = vtrace.multi_from_logits(
behaviour_policy_logits=behaviour_logits,
target_policy_logits=target_logits,
actions=tf.cast(actions, tf.int32),
actions=tf.unstack(tf.cast(actions, tf.int32), axis=2),
discounts=tf.to_float(~dones) * discount,
rewards=rewards,
values=values,

@@ -101,6 +109,20 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
"Must use `truncate_episodes` batch mode with V-trace."
self.config = config
self.sess = tf.get_default_session()
self.grads = None

if isinstance(action_space, gym.spaces.Discrete):
is_multidiscrete = False
actions_shape = [None]
output_hidden_shape = [action_space.n]
elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete):
is_multidiscrete = True
actions_shape = [None, len(action_space.nvec)]
output_hidden_shape = action_space.nvec.astype(np.int32)
else:
raise UnsupportedSpaceException(
"Action space {} is not supported for IMPALA.".format(
action_space))

# Create input placeholders
if existing_inputs:

@@ -109,22 +131,21 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
existing_state_in = existing_inputs[7:-1]
existing_seq_lens = existing_inputs[-1]
else:
if isinstance(action_space, gym.spaces.Discrete):
ac_size = action_space.n
actions = tf.placeholder(tf.int64, [None], name="ac")
else:
raise UnsupportedSpaceException(
"Action space {} is not supported for IMPALA.".format(
action_space))
actions = tf.placeholder(tf.int64, actions_shape, name="ac")
dones = tf.placeholder(tf.bool, [None], name="dones")
rewards = tf.placeholder(tf.float32, [None], name="rewards")
behaviour_logits = tf.placeholder(
tf.float32, [None, ac_size], name="behaviour_logits")
tf.float32, [None, sum(output_hidden_shape)],
name="behaviour_logits")
observations = tf.placeholder(
tf.float32, [None] + list(observation_space.shape))
existing_state_in = None
existing_seq_lens = None

# Unpack behaviour logits
unpacked_behaviour_logits = tf.split(
behaviour_logits, output_hidden_shape, axis=1)

# Setup the policy
dist_class, logit_dim = ModelCatalog.get_action_dist(
action_space, self.config["model"])

@@ -143,12 +164,30 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
self.config["model"],
state_in=existing_state_in,
seq_lens=existing_seq_lens)
action_dist = dist_class(self.model.outputs)
unpacked_outputs = tf.split(
self.model.outputs, output_hidden_shape, axis=1)

dist_inputs = unpacked_outputs if is_multidiscrete else \
self.model.outputs
action_dist = dist_class(dist_inputs)

values = self.model.value_function()
self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
tf.get_variable_scope().name)

def to_batches(tensor):
def make_time_major(tensor, drop_last=False):
"""Swaps batch and trajectory axis.
Args:
tensor: A tensor or list of tensors to reshape.
drop_last: A bool indicating whether to drop the last
trajectory item.
Returns:
res: A tensor with swapped axes or a list of tensors with
swapped axes.
"""
if isinstance(tensor, list):
return [make_time_major(t, drop_last) for t in tensor]

if self.model.state_init:
B = tf.shape(self.model.seq_lens)[0]
T = tf.shape(tensor)[0] // B

@@ -159,11 +198,16 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
B = tf.shape(tensor)[0] // T
rs = tf.reshape(tensor,
tf.concat([[B, T], tf.shape(tensor)[1:]], axis=0))

# swap B and T axes
return tf.transpose(
res = tf.transpose(
rs,
[1, 0] + list(range(2, 1 + int(tf.shape(tensor).shape[0]))))

if drop_last:
return res[:-1]
return res

if self.model.state_in:
max_seq_len = tf.reduce_max(self.model.seq_lens) - 1
mask = tf.sequence_mask(self.model.seq_lens, max_seq_len)

@@ -171,31 +215,52 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
else:
mask = tf.ones_like(rewards, dtype=tf.bool)

# Prepare actions for loss
loss_actions = actions if is_multidiscrete else tf.expand_dims(
actions, axis=1)

# Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc.
self.loss = VTraceLoss(
actions=to_batches(actions)[:-1],
actions_logp=to_batches(action_dist.logp(actions))[:-1],
actions_entropy=to_batches(action_dist.entropy())[:-1],
dones=to_batches(dones)[:-1],
behaviour_logits=to_batches(behaviour_logits)[:-1],
target_logits=to_batches(self.model.outputs)[:-1],
actions=make_time_major(loss_actions, drop_last=True),
actions_logp=make_time_major(
action_dist.logp(actions), drop_last=True),
actions_entropy=make_time_major(
action_dist.entropy(), drop_last=True),
dones=make_time_major(dones, drop_last=True),
behaviour_logits=make_time_major(
unpacked_behaviour_logits, drop_last=True),
target_logits=make_time_major(unpacked_outputs, drop_last=True),
discount=config["gamma"],
rewards=to_batches(rewards)[:-1],
values=to_batches(values)[:-1],
bootstrap_value=to_batches(values)[-1],
valid_mask=to_batches(mask)[:-1],
rewards=make_time_major(rewards, drop_last=True),
values=make_time_major(values, drop_last=True),
bootstrap_value=make_time_major(values)[-1],
valid_mask=make_time_major(mask, drop_last=True),
vf_loss_coeff=self.config["vf_loss_coeff"],
entropy_coeff=self.config["entropy_coeff"],
clip_rho_threshold=self.config["vtrace_clip_rho_threshold"],
clip_pg_rho_threshold=self.config["vtrace_clip_pg_rho_threshold"])

# KL divergence between worker and learner logits for debugging
model_dist = Categorical(self.model.outputs)
behaviour_dist = Categorical(behaviour_logits)
self.KLs = model_dist.kl(behaviour_dist)
self.mean_KL = tf.reduce_mean(self.KLs)
self.max_KL = tf.reduce_max(self.KLs)
self.median_KL = tf.contrib.distributions.percentile(self.KLs, 50.0)
model_dist = MultiCategorical(unpacked_outputs)
behaviour_dist = MultiCategorical(unpacked_behaviour_logits)

kls = model_dist.kl(behaviour_dist)
if len(kls) > 1:
self.KL_stats = {}

for i, kl in enumerate(kls):
self.KL_stats.update({
"mean_KL_{}".format(i): tf.reduce_mean(kl),
"max_KL_{}".format(i): tf.reduce_max(kl),
"median_KL_{}".format(i): tf.contrib.distributions.
percentile(kl, 50.0),
})
else:
self.KL_stats = {
"mean_KL": tf.reduce_mean(kls[0]),
"max_KL": tf.reduce_max(kls[0]),
"median_KL": tf.contrib.distributions.percentile(kls[0], 50.0),
}

# Initialize TFPolicyGraph
loss_in = [

@@ -231,7 +296,7 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
self.sess.run(tf.global_variables_initializer())

self.stats_fetches = {
"stats": {
"stats": dict({
"cur_lr": tf.cast(self.cur_lr, tf.float64),
"policy_loss": self.loss.pi_loss,
"entropy": self.loss.entropy,

@@ -240,11 +305,8 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
"vf_loss": self.loss.vf_loss,
"vf_explained_var": explained_variance(
tf.reshape(self.loss.vtrace_returns.vs, [-1]),
tf.reshape(to_batches(values)[:-1], [-1])),
"mean_KL": self.mean_KL,
"max_KL": self.max_KL,
"median_KL": self.median_KL,
},
tf.reshape(make_time_major(values, drop_last=True), [-1])),
}, **self.KL_stats),
}

@override(TFPolicyGraph)
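make_time_major replaces the old to_batches helper: it reshapes flat [B * T, ...] rollout tensors into time-major [T, B, ...] and, with drop_last=True, trims the final step so the V-trace inputs line up. A NumPy sketch of the same reshuffle (shapes are illustrative; the no-RNN case, where T comes from the configured rollout length, is assumed):

import numpy as np

B, T = 2, 4
flat = np.arange(B * T)            # flat [B * T] rollout, batch-major
rs = flat.reshape(B, T)            # [B, T]
time_major = rs.swapaxes(0, 1)     # [T, B], same as tf.transpose(rs, [1, 0])
dropped = time_major[:-1]          # drop_last=True -> [T - 1, B]

assert time_major.shape == (T, B) and dropped.shape == (T - 1, B)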
268 python/ray/rllib/agents/impala/vtrace_test.py Normal file
@@ -0,0 +1,268 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for V-trace.

For details and theory see:

"IMPALA: Scalable Distributed Deep-RL with
Importance Weighted Actor-Learner Architectures"
by Espeholt, Soyer, Munos et al.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import numpy as np
import tensorflow as tf
import vtrace


def _shaped_arange(*shape):
"""Runs np.arange, converts to float and reshapes."""
return np.arange(np.prod(shape), dtype=np.float32).reshape(*shape)


def _softmax(logits):
"""Applies softmax non-linearity on inputs."""
return np.exp(logits) / np.sum(np.exp(logits), axis=-1, keepdims=True)


def _ground_truth_calculation(discounts, log_rhos, rewards, values,
bootstrap_value, clip_rho_threshold,
clip_pg_rho_threshold):
"""Calculates the ground truth for V-trace in Python/Numpy."""
vs = []
seq_len = len(discounts)
rhos = np.exp(log_rhos)
cs = np.minimum(rhos, 1.0)
clipped_rhos = rhos
if clip_rho_threshold:
clipped_rhos = np.minimum(rhos, clip_rho_threshold)
clipped_pg_rhos = rhos
if clip_pg_rho_threshold:
clipped_pg_rhos = np.minimum(rhos, clip_pg_rho_threshold)

# This is a very inefficient way to calculate the V-trace ground truth.
# We calculate it this way because it is close to the mathematical notation
# of
# V-trace.
# v_s = V(x_s)
#       + \sum^{T-1}_{t=s} \gamma^{t-s}
#         * \prod_{i=s}^{t-1} c_i
#         * \rho_t (r_t + \gamma V(x_{t+1}) - V(x_t))
# Note that when we take the product over c_i, we write `s:t` as the
# notation
# of the paper is inclusive of the `t-1`, but Python is exclusive.
# Also note that np.prod([]) == 1.
values_t_plus_1 = np.concatenate(
[values, bootstrap_value[None, :]], axis=0)
for s in range(seq_len):
v_s = np.copy(values[s])  # Very important copy.
for t in range(s, seq_len):
v_s += (np.prod(discounts[s:t], axis=0) * np.prod(cs[s:t], axis=0)
* clipped_rhos[t] * (rewards[t] + discounts[t] *
values_t_plus_1[t + 1] - values[t]))
vs.append(v_s)
vs = np.stack(vs, axis=0)
pg_advantages = (clipped_pg_rhos * (rewards + discounts * np.concatenate(
[vs[1:], bootstrap_value[None, :]], axis=0) - values))

return vtrace.VTraceReturns(vs=vs, pg_advantages=pg_advantages)


class LogProbsFromLogitsAndActionsTest(tf.test.TestCase,
parameterized.TestCase):
@parameterized.named_parameters(('Batch1', 1), ('Batch2', 2))
def test_log_probs_from_logits_and_actions(self, batch_size):
"""Tests log_probs_from_logits_and_actions."""
seq_len = 7
num_actions = 3

policy_logits = _shaped_arange(seq_len, batch_size, num_actions) + 10
actions = np.random.randint(
0, num_actions - 1, size=(seq_len, batch_size), dtype=np.int32)

action_log_probs_tensor = vtrace.log_probs_from_logits_and_actions(
policy_logits, actions)

# Ground Truth
# Using broadcasting to create a mask that indexes action logits
action_index_mask = actions[..., None] == np.arange(num_actions)

def index_with_mask(array, mask):
return array[mask].reshape(*array.shape[:-1])

# Note: Normally log(softmax) is not a good idea because it's not
# numerically stable. However, in this test we have well-behaved
# values.
ground_truth_v = index_with_mask(
np.log(_softmax(policy_logits)), action_index_mask)

with self.test_session() as session:
self.assertAllClose(ground_truth_v,
session.run(action_log_probs_tensor))


class VtraceTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(('Batch1', 1), ('Batch5', 5))
def test_vtrace(self, batch_size):
"""Tests V-trace against ground truth data calculated in python."""
seq_len = 5

# Create log_rhos such that rho will span from near-zero to above the
# clipping thresholds. In particular, calculate log_rhos in
# [-2.5, 2.5),
# so that rho is in approx [0.08, 12.2).
log_rhos = _shaped_arange(seq_len, batch_size) / (batch_size * seq_len)
log_rhos = 5 * (log_rhos - 0.5)  # [0.0, 1.0) -> [-2.5, 2.5).
values = {
'log_rhos': log_rhos,
# T, B where B_i: [0.9 / (i+1)] * T
'discounts': np.array([[0.9 / (b + 1) for b in range(batch_size)]
for _ in range(seq_len)]),
'rewards': _shaped_arange(seq_len, batch_size),
'values': _shaped_arange(seq_len, batch_size) / batch_size,
'bootstrap_value': _shaped_arange(batch_size) + 1.0,
'clip_rho_threshold': 3.7,
'clip_pg_rho_threshold': 2.2,
}

output = vtrace.from_importance_weights(**values)

with self.test_session() as session:
output_v = session.run(output)

ground_truth_v = _ground_truth_calculation(**values)
for a, b in zip(ground_truth_v, output_v):
self.assertAllClose(a, b)

@parameterized.named_parameters(('Batch1', 1), ('Batch2', 2))
def test_vtrace_from_logits(self, batch_size):
"""Tests V-trace calculated from logits."""
seq_len = 5
num_actions = 3
clip_rho_threshold = None  # No clipping.
clip_pg_rho_threshold = None  # No clipping.

# Intentionally leaving shapes unspecified to test if V-trace can
# deal with that.
placeholders = {
# T, B, NUM_ACTIONS
'behaviour_policy_logits': tf.placeholder(
dtype=tf.float32, shape=[None, None, None]),
# T, B, NUM_ACTIONS
'target_policy_logits': tf.placeholder(
dtype=tf.float32, shape=[None, None, None]),
'actions': tf.placeholder(dtype=tf.int32, shape=[None, None]),
'discounts': tf.placeholder(dtype=tf.float32, shape=[None, None]),
'rewards': tf.placeholder(dtype=tf.float32, shape=[None, None]),
'values': tf.placeholder(dtype=tf.float32, shape=[None, None]),
'bootstrap_value': tf.placeholder(dtype=tf.float32, shape=[None]),
}

from_logits_output = vtrace.from_logits(
clip_rho_threshold=clip_rho_threshold,
clip_pg_rho_threshold=clip_pg_rho_threshold,
**placeholders)

target_log_probs = vtrace.log_probs_from_logits_and_actions(
placeholders['target_policy_logits'], placeholders['actions'])
behaviour_log_probs = vtrace.log_probs_from_logits_and_actions(
placeholders['behaviour_policy_logits'], placeholders['actions'])
log_rhos = target_log_probs - behaviour_log_probs
ground_truth = (log_rhos, behaviour_log_probs, target_log_probs)

values = {
'behaviour_policy_logits': _shaped_arange(seq_len, batch_size,
num_actions),
'target_policy_logits': _shaped_arange(seq_len, batch_size,
num_actions),
'actions': np.random.randint(
0, num_actions - 1, size=(seq_len, batch_size)),
'discounts': np.array(  # T, B where B_i: [0.9 / (i+1)] * T
[[0.9 / (b + 1) for b in range(batch_size)]
for _ in range(seq_len)]),
'rewards': _shaped_arange(seq_len, batch_size),
'values': _shaped_arange(seq_len, batch_size) / batch_size,
'bootstrap_value': _shaped_arange(batch_size) + 1.0,  # B
}

feed_dict = {placeholders[k]: v for k, v in values.items()}
with self.test_session() as session:
from_logits_output_v = session.run(
from_logits_output, feed_dict=feed_dict)
(ground_truth_log_rhos, ground_truth_behaviour_action_log_probs,
ground_truth_target_action_log_probs) = session.run(
ground_truth, feed_dict=feed_dict)

# Calculate V-trace using the ground truth logits.
from_iw = vtrace.from_importance_weights(
log_rhos=ground_truth_log_rhos,
discounts=values['discounts'],
rewards=values['rewards'],
values=values['values'],
bootstrap_value=values['bootstrap_value'],
clip_rho_threshold=clip_rho_threshold,
clip_pg_rho_threshold=clip_pg_rho_threshold)

with self.test_session() as session:
from_iw_v = session.run(from_iw)

self.assertAllClose(from_iw_v.vs, from_logits_output_v.vs)
self.assertAllClose(from_iw_v.pg_advantages,
from_logits_output_v.pg_advantages)
self.assertAllClose(ground_truth_behaviour_action_log_probs,
from_logits_output_v.behaviour_action_log_probs)
self.assertAllClose(ground_truth_target_action_log_probs,
from_logits_output_v.target_action_log_probs)
self.assertAllClose(ground_truth_log_rhos,
from_logits_output_v.log_rhos)

def test_higher_rank_inputs_for_importance_weights(self):
"""Checks support for additional dimensions in inputs."""
placeholders = {
'log_rhos': tf.placeholder(
dtype=tf.float32, shape=[None, None, 1]),
'discounts': tf.placeholder(
dtype=tf.float32, shape=[None, None, 1]),
'rewards': tf.placeholder(
dtype=tf.float32, shape=[None, None, 42]),
'values': tf.placeholder(dtype=tf.float32, shape=[None, None, 42]),
'bootstrap_value': tf.placeholder(
dtype=tf.float32, shape=[None, 42])
}
output = vtrace.from_importance_weights(**placeholders)
self.assertEqual(output.vs.shape.as_list()[-1], 42)

def test_inconsistent_rank_inputs_for_importance_weights(self):
"""Test one of many possible errors in shape of inputs."""
placeholders = {
'log_rhos': tf.placeholder(
dtype=tf.float32, shape=[None, None, 1]),
'discounts': tf.placeholder(
dtype=tf.float32, shape=[None, None, 1]),
'rewards': tf.placeholder(
dtype=tf.float32, shape=[None, None, 42]),
'values': tf.placeholder(dtype=tf.float32, shape=[None, None, 42]),
# Should be [None, 42].
'bootstrap_value': tf.placeholder(dtype=tf.float32, shape=[None])
}
with self.assertRaisesRegexp(ValueError, 'must have rank 2'):
vtrace.from_importance_weights(**placeholders)


if __name__ == '__main__':
tf.test.main()
@@ -6,6 +6,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf
import logging
import gym

@@ -17,7 +18,7 @@ from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.explained_variance import explained_variance
from ray.rllib.models.action_dist import Categorical
from ray.rllib.models.action_dist import MultiCategorical
from ray.rllib.evaluation.postprocessing import compute_advantages

logger = logging.getLogger(__name__)

@@ -29,7 +30,7 @@ class PPOSurrogateLoss(object):
Arguments:
prev_actions_logp: A float32 tensor of shape [T, B].
actions_logp: A float32 tensor of shape [T, B].
actions_kl: A float32 tensor of shape [T, B].
action_kl: A float32 tensor of shape [T, B].
actions_entropy: A float32 tensor of shape [T, B].
values: A float32 tensor of shape [T, B].
valid_mask: A bool tensor of valid RNN input elements (#2992).

@@ -104,7 +105,7 @@ class VTraceSurrogateLoss(object):
actions: An int32 tensor of shape [T, B, NUM_ACTIONS].
prev_actions_logp: A float32 tensor of shape [T, B].
actions_logp: A float32 tensor of shape [T, B].
actions_kl: A float32 tensor of shape [T, B].
action_kl: A float32 tensor of shape [T, B].
actions_entropy: A float32 tensor of shape [T, B].
dones: A bool tensor of shape [T, B].
behaviour_logits: A float32 tensor of shape [T, B, NUM_ACTIONS].

@@ -118,10 +119,10 @@ class VTraceSurrogateLoss(object):

# Compute vtrace on the CPU for better perf.
with tf.device("/cpu:0"):
self.vtrace_returns = vtrace.from_logits(
self.vtrace_returns = vtrace.multi_from_logits(
behaviour_policy_logits=behaviour_logits,
target_policy_logits=target_logits,
actions=tf.cast(actions, tf.int32),
actions=tf.unstack(tf.cast(actions, tf.int32), axis=2),
discounts=tf.to_float(~dones) * discount,
rewards=rewards,
values=values,

@@ -166,6 +167,21 @@ class AsyncPPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
"Must use `truncate_episodes` batch mode with V-trace."
self.config = config
self.sess = tf.get_default_session()
self.grads = None

if isinstance(action_space, gym.spaces.Discrete):
is_multidiscrete = False
output_hidden_shape = [action_space.n]
elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete):
is_multidiscrete = True
output_hidden_shape = action_space.nvec.astype(np.int32)
elif self.config["vtrace"]:
raise UnsupportedSpaceException(
"Action space {} is not supported for APPO + VTrace.",
format(action_space))
else:
is_multidiscrete = False
output_hidden_shape = 1

# Policy network model
dist_class, logit_dim = ModelCatalog.get_action_dist(

@@ -186,11 +202,6 @@ class AsyncPPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
existing_seq_lens = existing_inputs[-1]
else:
actions = ModelCatalog.get_action_placeholder(action_space)
if (not isinstance(action_space, gym.spaces.Discrete)
and self.config["vtrace"]):
raise UnsupportedSpaceException(
"Action space {} is not supported with vtrace.".format(
action_space))
dones = tf.placeholder(tf.bool, [None], name="dones")
rewards = tf.placeholder(tf.float32, [None], name="rewards")
behaviour_logits = tf.placeholder(

@@ -199,6 +210,7 @@ class AsyncPPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
tf.float32, [None] + list(observation_space.shape))
existing_state_in = None
existing_seq_lens = None

if not self.config["vtrace"]:
adv_ph = tf.placeholder(
tf.float32, name="advantages", shape=(None, ))

@@ -206,7 +218,13 @@ class AsyncPPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
tf.float32, name="value_targets", shape=(None, ))
self.observations = observations

# Unpack behaviour logits
unpacked_behaviour_logits = tf.split(
behaviour_logits, output_hidden_shape, axis=1)

# Setup the policy
dist_class, logit_dim = ModelCatalog.get_action_dist(
action_space, self.config["model"])
prev_actions = ModelCatalog.get_action_placeholder(action_space)
prev_rewards = tf.placeholder(tf.float32, [None], name="prev_reward")
self.model = ModelCatalog.get_model(

@@ -214,6 +232,7 @@ class AsyncPPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
"obs": observations,
"prev_actions": prev_actions,
"prev_rewards": prev_rewards,
"is_training": self._get_is_training_placeholder(),
},
observation_space,
action_space,

@@ -221,16 +240,35 @@ class AsyncPPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
self.config["model"],
state_in=existing_state_in,
seq_lens=existing_seq_lens)
unpacked_outputs = tf.split(
self.model.outputs, output_hidden_shape, axis=1)

action_dist = dist_class(self.model.outputs)
prev_action_dist = dist_class(behaviour_logits)
dist_inputs = unpacked_outputs if is_multidiscrete else \
self.model.outputs
prev_dist_inputs = unpacked_behaviour_logits if is_multidiscrete else \
behaviour_logits

action_dist = dist_class(dist_inputs)
prev_action_dist = dist_class(prev_dist_inputs)

values = self.model.value_function()
self.value_function = values
self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
tf.get_variable_scope().name)

def to_batches(tensor):
def make_time_major(tensor, drop_last=False):
"""Swaps batch and trajectory axis.
Args:
tensor: A tensor or list of tensors to reshape.
drop_last: A bool indicating whether to drop the last
trajectory item.
Returns:
res: A tensor with swapped axes or a list of tensors with
swapped axes.
"""
if isinstance(tensor, list):
return [make_time_major(t, drop_last) for t in tensor]

if self.model.state_init:
B = tf.shape(self.model.seq_lens)[0]
T = tf.shape(tensor)[0] // B

@@ -241,11 +279,16 @@ class AsyncPPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
B = tf.shape(tensor)[0] // T
rs = tf.reshape(tensor,
tf.concat([[B, T], tf.shape(tensor)[1:]], axis=0))

# swap B and T axes
return tf.transpose(
res = tf.transpose(
rs,
[1, 0] + list(range(2, 1 + int(tf.shape(tensor).shape[0]))))

if drop_last:
return res[:-1]
return res

if self.model.state_in:
max_seq_len = tf.reduce_max(self.model.seq_lens) - 1
mask = tf.sequence_mask(self.model.seq_lens, max_seq_len)

@@ -256,21 +299,30 @@ class AsyncPPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
# Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc.
if self.config["vtrace"]:
logger.info("Using V-Trace surrogate loss (vtrace=True)")

# Prepare actions for loss
loss_actions = actions if is_multidiscrete else tf.expand_dims(
actions, axis=1)

self.loss = VTraceSurrogateLoss(
actions=to_batches(actions)[:-1],
prev_actions_logp=to_batches(
prev_action_dist.logp(actions))[:-1],
actions_logp=to_batches(action_dist.logp(actions))[:-1],
actions=make_time_major(loss_actions, drop_last=True),
prev_actions_logp=make_time_major(
prev_action_dist.logp(actions), drop_last=True),
actions_logp=make_time_major(
action_dist.logp(actions), drop_last=True),
action_kl=prev_action_dist.kl(action_dist),
actions_entropy=to_batches(action_dist.entropy())[:-1],
dones=to_batches(dones)[:-1],
behaviour_logits=to_batches(behaviour_logits)[:-1],
target_logits=to_batches(self.model.outputs)[:-1],
actions_entropy=make_time_major(
action_dist.entropy(), drop_last=True),
dones=make_time_major(dones, drop_last=True),
behaviour_logits=make_time_major(
unpacked_behaviour_logits, drop_last=True),
target_logits=make_time_major(
unpacked_outputs, drop_last=True),
discount=config["gamma"],
rewards=to_batches(rewards)[:-1],
values=to_batches(values)[:-1],
bootstrap_value=to_batches(values)[-1],
valid_mask=to_batches(mask)[:-1],
rewards=make_time_major(rewards, drop_last=True),
values=make_time_major(values, drop_last=True),
bootstrap_value=make_time_major(values)[-1],
valid_mask=make_time_major(mask, drop_last=True),
vf_loss_coeff=self.config["vf_loss_coeff"],
entropy_coeff=self.config["entropy_coeff"],
clip_rho_threshold=self.config["vtrace_clip_rho_threshold"],

@@ -280,25 +332,41 @@ class AsyncPPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
else:
logger.info("Using PPO surrogate loss (vtrace=False)")
self.loss = PPOSurrogateLoss(
prev_actions_logp=to_batches(prev_action_dist.logp(actions)),
actions_logp=to_batches(action_dist.logp(actions)),
prev_actions_logp=make_time_major(
prev_action_dist.logp(actions)),
actions_logp=make_time_major(action_dist.logp(actions)),
action_kl=prev_action_dist.kl(action_dist),
actions_entropy=to_batches(action_dist.entropy()),
values=to_batches(values),
valid_mask=to_batches(mask),
advantages=to_batches(adv_ph),
value_targets=to_batches(value_targets),
actions_entropy=make_time_major(action_dist.entropy()),
values=make_time_major(values),
valid_mask=make_time_major(mask),
advantages=make_time_major(adv_ph),
value_targets=make_time_major(value_targets),
vf_loss_coeff=self.config["vf_loss_coeff"],
entropy_coeff=self.config["entropy_coeff"],
clip_param=self.config["clip_param"])

# KL divergence between worker and learner logits for debugging
model_dist = Categorical(self.model.outputs)
behaviour_dist = Categorical(behaviour_logits)
self.KLs = model_dist.kl(behaviour_dist)
self.mean_KL = tf.reduce_mean(self.KLs)
self.max_KL = tf.reduce_max(self.KLs)
self.median_KL = tf.contrib.distributions.percentile(self.KLs, 50.0)
model_dist = MultiCategorical(unpacked_outputs)
behaviour_dist = MultiCategorical(unpacked_behaviour_logits)

kls = model_dist.kl(behaviour_dist)
if len(kls) > 1:
self.KL_stats = {}

for i, kl in enumerate(kls):
self.KL_stats.update({
"mean_KL_{}".format(i): tf.reduce_mean(kl),
"max_KL_{}".format(i): tf.reduce_max(kl),
"median_KL_{}".format(i): tf.contrib.distributions.
percentile(kl, 50.0),
})
else:
self.KL_stats = {
"mean_KL": tf.reduce_mean(kls[0]),
"max_KL": tf.reduce_max(kls[0]),
"median_KL": tf.contrib.distributions.percentile(kls[0], 50.0),
}

# Initialize TFPolicyGraph
loss_in = [
("actions", actions),

@@ -335,12 +403,10 @@ class AsyncPPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):

self.sess.run(tf.global_variables_initializer())

if self.config["vtrace"]:
values_batched = to_batches(values)[:-1]
else:
values_batched = to_batches(values)
values_batched = make_time_major(
values, drop_last=self.config["vtrace"])
self.stats_fetches = {
"stats": {
"stats": dict({
"cur_lr": tf.cast(self.cur_lr, tf.float64),
"policy_loss": self.loss.pi_loss,
"entropy": self.loss.entropy,

@@ -350,12 +416,8 @@ class AsyncPPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
"vf_explained_var": explained_variance(
tf.reshape(self.loss.value_targets, [-1]),
tf.reshape(values_batched, [-1])),
"mean_KL": self.mean_KL,
"max_KL": self.max_KL,
"median_KL": self.median_KL,
},
}, **self.KL_stats),
}
self.stats_fetches["kl"] = self.loss.mean_kl

def optimizer(self):
if self.config["opt_type"] == "adam":

@@ -662,7 +662,7 @@ class PolicyEvaluator(EvaluatorInterface):
return policy_map, preprocessors

def __del__(self):
if isinstance(self.sampler, AsyncSampler):
if hasattr(self, "sampler") and isinstance(self.sampler, AsyncSampler):
self.sampler.shutdown = True
@@ -114,6 +114,31 @@ class Categorical(ActionDistribution):
return tf.squeeze(tf.multinomial(self.inputs, 1), axis=1)


class MultiCategorical(ActionDistribution):
"""Categorical distribution for MultiDiscrete action spaces."""

def __init__(self, inputs):
self.cats = [Categorical(input_) for input_ in inputs]
self.sample_op = self._build_sample_op()

def logp(self, actions):
# If tensor is provided, unstack it into list
if isinstance(actions, tf.Tensor):
actions = tf.unstack(actions, axis=1)
logps = tf.stack(
[cat.logp(act) for cat, act in zip(self.cats, actions)])
return tf.reduce_sum(logps, axis=0)

def entropy(self):
return tf.stack([cat.entropy() for cat in self.cats], axis=1)

def kl(self, other):
return [cat.kl(oth_cat) for cat, oth_cat in zip(self.cats, other.cats)]

def _build_sample_op(self):
return tf.stack([cat.sample() for cat in self.cats], axis=1)


class DiagGaussian(ActionDistribution):
"""Action distribution where each vector element is a gaussian.
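A short sketch of how the new MultiCategorical composes per-dimension Categoricals (assumptions: TF1 graph mode, a two-dimensional MultiDiscrete with sizes 3 and 5, and random illustrative values; not an excerpt from the diff):

import numpy as np
import tensorflow as tf
from ray.rllib.models.action_dist import MultiCategorical

B, sizes = 4, [3, 5]
# One [B, n_i] logits tensor per sub-action, as produced by tf.split on the flat model output.
inputs = [tf.constant(np.random.randn(B, n), tf.float32) for n in sizes]
dist = MultiCategorical(inputs)

actions = tf.constant(np.random.randint(0, 3, size=(B, len(sizes))), tf.int32)  # [B, 2]
logp = dist.logp(actions)   # [B]; sum of the per-dimension Categorical log-probs
entropy = dist.entropy()    # [B, 2]; one entropy column per sub-action
sample = dist.sample()      # [B, 2]; one sampled index per sub-action

Note that kl() deliberately returns a list (one tensor per sub-action) rather than a sum, which is what lets the policy graphs report per-dimension KL_stats above.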
@@ -12,8 +12,8 @@ from ray.tune.registry import RLLIB_MODEL, RLLIB_PREPROCESSOR, \
_global_registry

from ray.rllib.models.extra_spaces import Simplex
from ray.rllib.models.action_dist import (Categorical, Deterministic,
DiagGaussian,
from ray.rllib.models.action_dist import (Categorical, MultiCategorical,
Deterministic, DiagGaussian,
MultiActionDistribution, Dirichlet)
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.models.fcnet import FullyConnectedNetwork

@@ -136,6 +136,9 @@ class ModelCatalog(object):
input_lens=input_lens), sum(input_lens)
elif isinstance(action_space, Simplex):
return Dirichlet, action_space.shape[0]
elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete):
return MultiCategorical, sum(action_space.nvec)

raise NotImplementedError("Unsupported args: {} {}".format(
action_space, dist_type))

@@ -171,6 +174,11 @@ class ModelCatalog(object):
elif isinstance(action_space, Simplex):
return tf.placeholder(
tf.float32, shape=(None, action_space.shape[0]), name="action")
elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete):
return tf.placeholder(
tf.as_dtype(action_space.dtype),
shape=(None, len(action_space.nvec)),
name="action")
else:
raise NotImplementedError("action space {}"
" not supported".format(action_space))
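Taken together, the catalog changes mean a MultiDiscrete space maps to one flat logit vector of size sum(nvec) (later unpacked with tf.split) and an action placeholder with one column per sub-action. A small shape sketch (assumption: gym's MultiDiscrete([3, 5]); illustrative only, not from the test suite):

import gym
import numpy as np

space = gym.spaces.MultiDiscrete([3, 5])
logit_dim = int(sum(space.nvec))         # 8 logits in the flat network output
action_cols = len(space.nvec)            # action placeholder shape is (None, 2)
sub_sizes = space.nvec.astype(np.int32)  # tf.split sizes used to unpack the logits

assert logit_dim == 8 and action_cols == 2 and list(sub_sizes) == [3, 5]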