import functools
import gym
from math import log
import numpy as np
import tree  # pip install dm_tree
from typing import Optional

from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils import MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT, SMALL_NUMBER
from ray.rllib.utils.annotations import override, DeveloperAPI, ExperimentalAPI
from ray.rllib.utils.framework import try_import_tf, try_import_tfp
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
from ray.rllib.utils.typing import TensorType, List, Union, Tuple, ModelConfigDict

tf1, tf, tfv = try_import_tf()
tfp = try_import_tfp()


@DeveloperAPI
class TFActionDistribution(ActionDistribution):
    """TF-specific extensions for building action distributions."""

    @override(ActionDistribution)
    def __init__(self, inputs: List[TensorType], model: ModelV2):
        super().__init__(inputs, model)
        self.sample_op = self._build_sample_op()
        self.sampled_action_logp_op = self.logp(self.sample_op)

    def _build_sample_op(self) -> TensorType:
        """Implement this instead of sample(), to enable op reuse.

        This is needed since the sample op is non-deterministic and is shared
        between sample() and sampled_action_logp().
        """
        raise NotImplementedError

    @override(ActionDistribution)
    def sample(self) -> TensorType:
        """Draws a sample from the action distribution."""
        return self.sample_op

    @override(ActionDistribution)
    def sampled_action_logp(self) -> TensorType:
        """Returns the log probability of the sampled action."""
        return self.sampled_action_logp_op
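
# Illustrative sketch (not part of the library) of the op-reuse contract
# above: a subclass implements `_build_sample_op()` once, and `sample()` and
# `sampled_action_logp()` then both refer to that same stochastic tensor, so
# the logged log-prob always matches the action that was actually sampled.
# `MyBernoulli` is a hypothetical example class:
#
#   class MyBernoulli(TFActionDistribution):
#       def _build_sample_op(self):
#           u = tf.random.uniform(tf.shape(self.inputs))
#           return tf.cast(u < tf.sigmoid(self.inputs), tf.float32)
#
#       def logp(self, x):
#           return -tf.reduce_sum(
#               tf.nn.sigmoid_cross_entropy_with_logits(
#                   labels=tf.cast(x, tf.float32), logits=self.inputs
#               ),
#               axis=-1,
#           )
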

@DeveloperAPI
class Categorical(TFActionDistribution):
    """Categorical distribution for discrete action spaces."""

    def __init__(
        self, inputs: List[TensorType], model: ModelV2 = None, temperature: float = 1.0
    ):
        assert temperature > 0.0, "Categorical `temperature` must be > 0.0!"
        # Allow softmax formula w/ temperature != 1.0:
        # Divide inputs by temperature.
        super().__init__(inputs / temperature, model)

    @override(ActionDistribution)
    def deterministic_sample(self) -> TensorType:
        return tf.math.argmax(self.inputs, axis=1)

    @override(ActionDistribution)
    def logp(self, x: TensorType) -> TensorType:
        return -tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=self.inputs, labels=tf.cast(x, tf.int32)
        )

    @override(ActionDistribution)
    def entropy(self) -> TensorType:
        a0 = self.inputs - tf.reduce_max(self.inputs, axis=1, keepdims=True)
        ea0 = tf.exp(a0)
        z0 = tf.reduce_sum(ea0, axis=1, keepdims=True)
        p0 = ea0 / z0
        return tf.reduce_sum(p0 * (tf.math.log(z0) - a0), axis=1)

    @override(ActionDistribution)
    def kl(self, other: ActionDistribution) -> TensorType:
        a0 = self.inputs - tf.reduce_max(self.inputs, axis=1, keepdims=True)
        a1 = other.inputs - tf.reduce_max(other.inputs, axis=1, keepdims=True)
        ea0 = tf.exp(a0)
        ea1 = tf.exp(a1)
        z0 = tf.reduce_sum(ea0, axis=1, keepdims=True)
        z1 = tf.reduce_sum(ea1, axis=1, keepdims=True)
        p0 = ea0 / z0
        return tf.reduce_sum(p0 * (a0 - tf.math.log(z0) - a1 + tf.math.log(z1)), axis=1)

    @override(TFActionDistribution)
    def _build_sample_op(self) -> TensorType:
        return tf.squeeze(tf.random.categorical(self.inputs, 1), axis=1)

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(action_space, model_config):
        return action_space.n
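
# Usage sketch (illustrative only; the logits below are made up): with logits
# for three discrete actions, a higher `temperature` flattens the implied
# softmax and a lower one sharpens it; `deterministic_sample()` is a plain
# argmax over the (temperature-scaled) logits.
#
#   logits = tf.constant([[1.0, 2.0, 0.5]])  # [batch=1, num_actions=3]
#   dist = Categorical(logits, model=None, temperature=0.5)
#   a = dist.deterministic_sample()          # -> [1] (the argmax)
#   logp = dist.logp(a)                      # log pi(a=1 | logits / 0.5)
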

@DeveloperAPI
class MultiCategorical(TFActionDistribution):
    """MultiCategorical distribution for MultiDiscrete action spaces."""

    def __init__(
        self,
        inputs: List[TensorType],
        model: ModelV2,
        input_lens: Union[List[int], np.ndarray, Tuple[int, ...]],
        action_space=None,
    ):
        # Skip TFActionDistribution init.
        ActionDistribution.__init__(self, inputs, model)
        self.cats = [
            Categorical(input_, model)
            for input_ in tf.split(inputs, input_lens, axis=1)
        ]
        self.action_space = action_space
        if self.action_space is None:
            self.action_space = gym.spaces.MultiDiscrete(
                [c.inputs.shape[1] for c in self.cats]
            )
        self.sample_op = self._build_sample_op()
        self.sampled_action_logp_op = self.logp(self.sample_op)

    @override(ActionDistribution)
    def deterministic_sample(self) -> TensorType:
        sample_ = tf.stack([cat.deterministic_sample() for cat in self.cats], axis=1)
        if isinstance(self.action_space, gym.spaces.Box):
            return tf.cast(
                tf.reshape(sample_, [-1] + list(self.action_space.shape)),
                self.action_space.dtype,
            )
        return sample_

    @override(ActionDistribution)
    def logp(self, actions: TensorType) -> TensorType:
        # If tensor is provided, unstack it into list.
        if isinstance(actions, tf.Tensor):
            if isinstance(self.action_space, gym.spaces.Box):
                actions = tf.reshape(
                    actions, [-1, int(np.prod(self.action_space.shape))]
                )
            elif isinstance(self.action_space, gym.spaces.MultiDiscrete):
                actions.set_shape((None, len(self.cats)))
            actions = tf.unstack(tf.cast(actions, tf.int32), axis=1)
        logps = tf.stack([cat.logp(act) for cat, act in zip(self.cats, actions)])
        return tf.reduce_sum(logps, axis=0)

    @override(ActionDistribution)
    def multi_entropy(self) -> TensorType:
        return tf.stack([cat.entropy() for cat in self.cats], axis=1)

    @override(ActionDistribution)
    def entropy(self) -> TensorType:
        return tf.reduce_sum(self.multi_entropy(), axis=1)

    @override(ActionDistribution)
    def multi_kl(self, other: ActionDistribution) -> TensorType:
        return tf.stack(
            [cat.kl(oth_cat) for cat, oth_cat in zip(self.cats, other.cats)], axis=1
        )

    @override(ActionDistribution)
    def kl(self, other: ActionDistribution) -> TensorType:
        return tf.reduce_sum(self.multi_kl(other), axis=1)

    @override(TFActionDistribution)
    def _build_sample_op(self) -> TensorType:
        sample_op = tf.stack([cat.sample() for cat in self.cats], axis=1)
        if isinstance(self.action_space, gym.spaces.Box):
            return tf.cast(
                tf.reshape(sample_op, [-1] + list(self.action_space.shape)),
                dtype=self.action_space.dtype,
            )
        return sample_op

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(
        action_space: gym.Space, model_config: ModelConfigDict
    ) -> Union[int, np.ndarray]:
        # Int Box.
        if isinstance(action_space, gym.spaces.Box):
            assert action_space.dtype.name.startswith("int")
            low_ = np.min(action_space.low)
            high_ = np.max(action_space.high)
            assert np.all(action_space.low == low_)
            assert np.all(action_space.high == high_)
            # Each of the prod(shape) sub-actions needs (high - low + 1) logits.
            return np.prod(action_space.shape, dtype=np.int32) * (high_ - low_ + 1)
        # MultiDiscrete space.
        else:
            # nvec is already integer, so no casting needed.
            return np.sum(action_space.nvec)
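
# Usage sketch (illustrative only): for MultiDiscrete([3, 5]) the model emits
# 3 + 5 = 8 logits; `input_lens=[3, 5]` splits them into one Categorical per
# sub-action, and logp/entropy/kl sum over the two components.
#
#   logits = tf.random.normal([4, 8])  # [batch=4, 3 + 5]
#   dist = MultiCategorical(logits, model=None, input_lens=[3, 5])
#   a = dist.sample()                  # shape [4, 2]
#   logp = dist.logp(a)                # shape [4]
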

@ExperimentalAPI
class SlateMultiCategorical(Categorical):
    """MultiCategorical distribution for MultiDiscrete action spaces.

    The action space must be uniform, meaning all nvec items have the same
    size, e.g. MultiDiscrete([10, 10, 10]), where 10 is the number of
    candidates to pick from and 3 is the slate size (pick 3 out of 10). When
    picking candidates, no candidate may be picked more than once.
    """

    def __init__(
        self,
        inputs: List[TensorType],
        model: ModelV2 = None,
        temperature: float = 1.0,
        action_space: Optional[gym.spaces.MultiDiscrete] = None,
        all_slates=None,
    ):
        assert temperature > 0.0, "Categorical `temperature` must be > 0.0!"
        # Allow softmax formula w/ temperature != 1.0:
        # Divide inputs by temperature.
        super().__init__(inputs / temperature, model)
        self.action_space = action_space
        # Assert uniformness of the action space (all discrete buckets have
        # the same size).
        assert isinstance(self.action_space, gym.spaces.MultiDiscrete) and all(
            n == self.action_space.nvec[0] for n in self.action_space.nvec
        )
        self.all_slates = all_slates

    @override(ActionDistribution)
    def deterministic_sample(self) -> TensorType:
        # Get a sample from the underlying Categorical (batch of ints).
        sample = super().deterministic_sample()
        # Use the sampled ints to pick the actual slates.
        return tf.gather(self.all_slates, sample)

    @override(ActionDistribution)
    def logp(self, x: TensorType) -> TensorType:
        # TODO: Implement.
        return tf.ones_like(self.inputs[:, 0])
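
# Usage sketch (illustrative only): `all_slates` enumerates every valid slate
# (e.g. all permutations of k candidate indices out of n); the underlying
# Categorical then scores one logit per enumerated slate, and sampling gathers
# the corresponding index tuple.
#
#   import itertools
#   all_slates = tf.constant(list(itertools.permutations(range(4), 2)))  # [12, 2]
#   space = gym.spaces.MultiDiscrete([4, 4])
#   logits = tf.random.normal([1, 12])  # one logit per slate
#   dist = SlateMultiCategorical(
#       logits, action_space=space, all_slates=all_slates
#   )
#   slate = dist.deterministic_sample()  # e.g. [[2, 0]]
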

@DeveloperAPI
class GumbelSoftmax(TFActionDistribution):
    """GumbelSoftmax distribution (for differentiable sampling of discrete actions).

    The Gumbel Softmax distribution [1] (also known as the Concrete [2]
    distribution) is a close cousin of the relaxed one-hot categorical
    distribution, whose tfp implementation we will use here plus
    adjusted `sample_...` and `log_prob` methods. See discussion at [0].

    [0] https://stackoverflow.com/questions/56226133/
    soft-actor-critic-with-discrete-action-space

    [1] Categorical Reparametrization with Gumbel-Softmax (Jang et al., 2017):
    https://arxiv.org/abs/1611.01144
    [2] The Concrete Distribution: A Continuous Relaxation of Discrete Random
    Variables (Maddison et al., 2017): https://arxiv.org/abs/1611.00712
    """

    def __init__(
        self, inputs: List[TensorType], model: ModelV2 = None, temperature: float = 1.0
    ):
        """Initializes a GumbelSoftmax distribution.

        Args:
            temperature: Temperature parameter. For low temperatures,
                the expected value approaches a categorical random variable.
                For high temperatures, the expected value approaches a uniform
                distribution.
        """
        assert temperature >= 0.0
        self.dist = tfp.distributions.RelaxedOneHotCategorical(
            temperature=temperature, logits=inputs
        )
        self.probs = tf.nn.softmax(self.dist._distribution.logits)
        super().__init__(inputs, model)

    @override(ActionDistribution)
    def deterministic_sample(self) -> TensorType:
        # Return the dist object's prob values.
        return self.probs

    @override(ActionDistribution)
    def logp(self, x: TensorType) -> TensorType:
        # Override since the implementation of tfp.RelaxedOneHotCategorical
        # yields positive values.
        if x.shape != self.dist.logits.shape:
            values = tf.one_hot(
                x, self.dist.logits.shape.as_list()[-1], dtype=tf.float32
            )
            assert values.shape == self.dist.logits.shape, (
                values.shape,
                self.dist.logits.shape,
            )
            # Continue with the one-hot encoded values.
            x = values

        # [0]'s implementation (see line below) seems to be an approximation
        # to the actual Gumbel Softmax density.
        return -tf.reduce_sum(
            -x * tf.nn.log_softmax(self.dist.logits, axis=-1), axis=-1
        )

    @override(TFActionDistribution)
    def _build_sample_op(self) -> TensorType:
        return self.dist.sample()

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(
        action_space: gym.Space, model_config: ModelConfigDict
    ) -> Union[int, np.ndarray]:
        return action_space.n
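
# Eager-mode usage sketch (illustrative only): unlike Categorical, samples are
# relaxed one-hot vectors that stay differentiable w.r.t. the logits, which is
# what permits pathwise gradients through discrete action choices.
#
#   logits = tf.Variable([[1.0, 0.0, -1.0]])
#   with tf.GradientTape() as tape:
#       dist = GumbelSoftmax(logits, model=None, temperature=1.0)
#       y = dist.sample()  # soft one-hot vector, sums to 1
#       loss = tf.reduce_sum(y * tf.constant([0.0, 1.0, 2.0]))
#   grads = tape.gradient(loss, [logits])  # well-defined gradients
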

@DeveloperAPI
class DiagGaussian(TFActionDistribution):
    """Action distribution where each vector element is a Gaussian.

    The first half of the input vector defines the Gaussian means, and the
    second half the Gaussian standard deviations.
    """

    def __init__(
        self,
        inputs: List[TensorType],
        model: ModelV2,
        *,
        action_space: Optional[gym.spaces.Space] = None
    ):
        mean, log_std = tf.split(inputs, 2, axis=1)
        self.mean = mean
        self.log_std = log_std
        self.std = tf.exp(log_std)
        # Remember to squeeze action samples in case action space is Box(shape=()).
        self.zero_action_dim = action_space and action_space.shape == ()
        super().__init__(inputs, model)

    @override(ActionDistribution)
    def deterministic_sample(self) -> TensorType:
        return self.mean

    @override(ActionDistribution)
    def logp(self, x: TensorType) -> TensorType:
        # Cover case where action space is Box(shape=()).
        if int(tf.shape(x).shape[0]) == 1:
            x = tf.expand_dims(x, axis=1)
        return (
            -0.5
            * tf.reduce_sum(
                tf.math.square((tf.cast(x, tf.float32) - self.mean) / self.std), axis=1
            )
            - 0.5 * np.log(2.0 * np.pi) * tf.cast(tf.shape(x)[1], tf.float32)
            - tf.reduce_sum(self.log_std, axis=1)
        )

    @override(ActionDistribution)
    def kl(self, other: ActionDistribution) -> TensorType:
        assert isinstance(other, DiagGaussian)
        return tf.reduce_sum(
            other.log_std
            - self.log_std
            + (tf.math.square(self.std) + tf.math.square(self.mean - other.mean))
            / (2.0 * tf.math.square(other.std))
            - 0.5,
            axis=1,
        )

    @override(ActionDistribution)
    def entropy(self) -> TensorType:
        return tf.reduce_sum(self.log_std + 0.5 * np.log(2.0 * np.pi * np.e), axis=1)

    @override(TFActionDistribution)
    def _build_sample_op(self) -> TensorType:
        sample = self.mean + self.std * tf.random.normal(tf.shape(self.mean))
        if self.zero_action_dim:
            return tf.squeeze(sample, axis=-1)
        return sample

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(
        action_space: gym.Space, model_config: ModelConfigDict
    ) -> Union[int, np.ndarray]:
        return np.prod(action_space.shape, dtype=np.int32) * 2
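
# Usage sketch (illustrative only): for a Box action space with d dims, the
# model emits 2 * d outputs; the first d are the means, the last d the
# log-stds. `logp()` above is the diagonal-Gaussian density
#   log N(x; mu, sigma) = -0.5 * sum(((x - mu) / sigma)^2)
#                         - 0.5 * d * log(2 * pi) - sum(log(sigma)).
#
#   out = tf.constant([[0.1, -0.2, -1.0, -1.0]])  # d=2: means | log-stds
#   dist = DiagGaussian(out, None)
#   a = dist.deterministic_sample()               # -> [[0.1, -0.2]]
#   logp = dist.logp(a)                           # highest-density point
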

@DeveloperAPI
class SquashedGaussian(TFActionDistribution):
    """A tanh-squashed Gaussian distribution defined by: mean, std, low, high.

    The distribution will never return low or high exactly, but
    `low`+SMALL_NUMBER or `high`-SMALL_NUMBER respectively.
    """

    def __init__(
        self,
        inputs: List[TensorType],
        model: ModelV2,
        low: float = -1.0,
        high: float = 1.0,
    ):
        """Parameterizes the distribution via `inputs`.

        Args:
            low: The lowest possible sampling value
                (excluding this value).
            high: The highest possible sampling value
                (excluding this value).
        """
        assert tfp is not None
        mean, log_std = tf.split(inputs, 2, axis=-1)
        # Clip `scale` values (coming from NN) to reasonable values.
        log_std = tf.clip_by_value(log_std, MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT)
        std = tf.exp(log_std)
        self.distr = tfp.distributions.Normal(loc=mean, scale=std)
        assert np.all(np.less(low, high))
        self.low = low
        self.high = high
        super().__init__(inputs, model)

    @override(ActionDistribution)
    def deterministic_sample(self) -> TensorType:
        mean = self.distr.mean()
        return self._squash(mean)

    @override(TFActionDistribution)
    def _build_sample_op(self) -> TensorType:
        return self._squash(self.distr.sample())

    @override(ActionDistribution)
    def logp(self, x: TensorType) -> TensorType:
        # Unsquash values (from [low, high] to (-inf, inf)).
        unsquashed_values = tf.cast(self._unsquash(x), self.inputs.dtype)
        # Get log prob of unsquashed values from our Normal.
        log_prob_gaussian = self.distr.log_prob(unsquashed_values)
        # For numerical stability, clamp the log-probs, only then sum up.
        log_prob_gaussian = tf.clip_by_value(log_prob_gaussian, -100, 100)
        log_prob_gaussian = tf.reduce_sum(log_prob_gaussian, axis=-1)
        # Get log-prob for squashed Gaussian.
        unsquashed_values_tanhd = tf.math.tanh(unsquashed_values)
        log_prob = log_prob_gaussian - tf.reduce_sum(
            tf.math.log(1 - unsquashed_values_tanhd ** 2 + SMALL_NUMBER), axis=-1
        )
        return log_prob

    def sample_logp(self):
        z = self.distr.sample()
        actions = self._squash(z)
        return actions, tf.reduce_sum(
            self.distr.log_prob(z) - tf.math.log(1 - actions * actions + SMALL_NUMBER),
            axis=-1,
        )

    @override(ActionDistribution)
    def entropy(self) -> TensorType:
        raise ValueError("Entropy not defined for SquashedGaussian!")

    @override(ActionDistribution)
    def kl(self, other: ActionDistribution) -> TensorType:
        raise ValueError("KL not defined for SquashedGaussian!")

    def _squash(self, raw_values: TensorType) -> TensorType:
        # Returned values are within [low, high] (including `low` and `high`).
        squashed = ((tf.math.tanh(raw_values) + 1.0) / 2.0) * (
            self.high - self.low
        ) + self.low
        return tf.clip_by_value(squashed, self.low, self.high)

    def _unsquash(self, values: TensorType) -> TensorType:
        normed_values = (values - self.low) / (self.high - self.low) * 2.0 - 1.0
        # Stabilize input to atanh.
        safe_normed_values = tf.clip_by_value(
            normed_values, -1.0 + SMALL_NUMBER, 1.0 - SMALL_NUMBER
        )
        unsquashed = tf.math.atanh(safe_normed_values)
        return unsquashed

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(
        action_space: gym.Space, model_config: ModelConfigDict
    ) -> Union[int, np.ndarray]:
        return np.prod(action_space.shape, dtype=np.int32) * 2
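
# Note on the log-prob above (a sketch of the standard tanh change of
# variables): with z ~ Normal(mu, sigma) and a = squash(z), the density is
#   log p(a) = log N(z; mu, sigma) - sum_i log(1 - tanh(z_i)^2 + eps),
# which is exactly what `logp()` computes after `_unsquash()`-ing `a`.
#
#   out = tf.constant([[0.0, 0.0]])  # mean=0 | log-std=0
#   dist = SquashedGaussian(out, None, low=-1.0, high=1.0)
#   a = dist.sample()                # always within (-1, 1)
#   logp = dist.logp(a)              # tanh-corrected log-prob
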

@DeveloperAPI
class Beta(TFActionDistribution):
    """
    A Beta distribution is defined on the interval [0, 1] and parameterized by
    shape parameters alpha and beta (also called concentration parameters).

    PDF(x; alpha, beta) = x**(alpha - 1) (1 - x)**(beta - 1) / Z
        with Z = Gamma(alpha) Gamma(beta) / Gamma(alpha + beta)
        and Gamma(n) = (n - 1)!
    """

    def __init__(
        self,
        inputs: List[TensorType],
        model: ModelV2,
        low: float = 0.0,
        high: float = 1.0,
    ):
        # Stabilize input parameters (possibly coming from a linear layer).
        inputs = tf.clip_by_value(inputs, log(SMALL_NUMBER), -log(SMALL_NUMBER))
        # Softplus, shifted by +1.0, to force both concentration params > 1.
        inputs = tf.math.log(tf.math.exp(inputs) + 1.0) + 1.0
        self.low = low
        self.high = high
        alpha, beta = tf.split(inputs, 2, axis=-1)
        # Note: concentration0==beta, concentration1=alpha (!)
        self.dist = tfp.distributions.Beta(concentration1=alpha, concentration0=beta)
        super().__init__(inputs, model)

    @override(ActionDistribution)
    def deterministic_sample(self) -> TensorType:
        mean = self.dist.mean()
        return self._squash(mean)

    @override(TFActionDistribution)
    def _build_sample_op(self) -> TensorType:
        return self._squash(self.dist.sample())

    @override(ActionDistribution)
    def logp(self, x: TensorType) -> TensorType:
        unsquashed_values = self._unsquash(x)
        return tf.math.reduce_sum(self.dist.log_prob(unsquashed_values), axis=-1)

    def _squash(self, raw_values: TensorType) -> TensorType:
        return raw_values * (self.high - self.low) + self.low

    def _unsquash(self, values: TensorType) -> TensorType:
        return (values - self.low) / (self.high - self.low)

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(
        action_space: gym.Space, model_config: ModelConfigDict
    ) -> Union[int, np.ndarray]:
        return np.prod(action_space.shape, dtype=np.int32) * 2
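
# Usage sketch (illustrative only): the raw model outputs are clipped and
# softplus-shifted above, so any real-valued pair maps to valid concentration
# parameters alpha, beta > 1; samples then live in [low, high].
#
#   out = tf.constant([[0.5, -0.5]])  # -> alpha | beta (after the shift)
#   dist = Beta(out, None, low=0.0, high=1.0)
#   a = dist.sample()                 # in [0, 1]
#   logp = dist.logp(a)
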

@DeveloperAPI
class Deterministic(TFActionDistribution):
    """Action distribution that returns the input values directly.

    This is similar to DiagGaussian with standard deviation zero (thus only
    requiring the "mean" values as NN output).
    """

    @override(ActionDistribution)
    def deterministic_sample(self) -> TensorType:
        return self.inputs

    @override(ActionDistribution)
    def logp(self, x: TensorType) -> TensorType:
        return tf.zeros_like(self.inputs)

    @override(TFActionDistribution)
    def _build_sample_op(self) -> TensorType:
        return self.inputs

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(
        action_space: gym.Space, model_config: ModelConfigDict
    ) -> Union[int, np.ndarray]:
        return np.prod(action_space.shape, dtype=np.int32)


@DeveloperAPI
class MultiActionDistribution(TFActionDistribution):
    """Action distribution that operates on a set of actions.

    Args:
        inputs (Tensor list): A list of tensors from which to compute samples.
    """

    def __init__(
        self, inputs, model, *, child_distributions, input_lens, action_space, **kwargs
    ):
        # Skip TFActionDistribution init (the child distributions own the
        # sample ops).
        ActionDistribution.__init__(self, inputs, model)

        self.action_space_struct = get_base_struct_from_space(action_space)

        self.input_lens = np.array(input_lens, dtype=np.int32)
        split_inputs = tf.split(inputs, self.input_lens, axis=1)
        self.flat_child_distributions = tree.map_structure(
            lambda dist, input_: dist(input_, model, **kwargs),
            child_distributions,
            split_inputs,
        )

    @override(ActionDistribution)
    def logp(self, x):
        # Single tensor input (all merged).
        if isinstance(x, (tf.Tensor, np.ndarray)):
            split_indices = []
            for dist in self.flat_child_distributions:
                if isinstance(dist, Categorical):
                    split_indices.append(1)
                elif (
                    isinstance(dist, MultiCategorical) and dist.action_space is not None
                ):
                    split_indices.append(np.prod(dist.action_space.shape))
                else:
                    sample = dist.sample()
                    # Cover Box(shape=()) case.
                    if len(sample.shape) == 1:
                        split_indices.append(1)
                    else:
                        split_indices.append(tf.shape(sample)[1])
            split_x = tf.split(x, split_indices, axis=1)
        # Structured or flattened (by single action component) input.
        else:
            split_x = tree.flatten(x)

        def map_(val, dist):
            # Remove extra categorical dimension.
            if isinstance(dist, Categorical):
                val = tf.cast(
                    tf.squeeze(val, axis=-1) if len(val.shape) > 1 else val, tf.int32
                )
            return dist.logp(val)

        # Remove extra categorical dimension and take the logp of each
        # component.
        flat_logps = tree.map_structure(map_, split_x, self.flat_child_distributions)

        return functools.reduce(lambda a, b: a + b, flat_logps)

    @override(ActionDistribution)
    def kl(self, other):
        kl_list = [
            d.kl(o)
            for d, o in zip(
                self.flat_child_distributions, other.flat_child_distributions
            )
        ]
        return functools.reduce(lambda a, b: a + b, kl_list)

    @override(ActionDistribution)
    def entropy(self):
        entropy_list = [d.entropy() for d in self.flat_child_distributions]
        return functools.reduce(lambda a, b: a + b, entropy_list)

    @override(ActionDistribution)
    def sample(self):
        child_distributions = tree.unflatten_as(
            self.action_space_struct, self.flat_child_distributions
        )
        return tree.map_structure(lambda s: s.sample(), child_distributions)

    @override(ActionDistribution)
    def deterministic_sample(self):
        child_distributions = tree.unflatten_as(
            self.action_space_struct, self.flat_child_distributions
        )
        return tree.map_structure(
            lambda s: s.deterministic_sample(), child_distributions
        )

    @override(TFActionDistribution)
    def sampled_action_logp(self):
        p = self.flat_child_distributions[0].sampled_action_logp()
        for c in self.flat_child_distributions[1:]:
            p += c.sampled_action_logp()
        return p

    @override(ActionDistribution)
    def required_model_output_shape(self, action_space, model_config):
        return np.sum(self.input_lens, dtype=np.int32)
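
# Usage sketch (illustrative only): for a Tuple(Discrete(3), Box(-1, 1, (2,)))
# action space, pass one child distribution class per component plus the
# matching `input_lens`; logp/entropy/kl then sum over the components.
#
#   space = gym.spaces.Tuple(
#       [gym.spaces.Discrete(3), gym.spaces.Box(-1.0, 1.0, (2,))]
#   )
#   inputs = tf.random.normal([4, 3 + 4])  # 3 logits + 2*2 Gaussian params
#   dist = MultiActionDistribution(
#       inputs,
#       model=None,
#       child_distributions=[Categorical, DiagGaussian],
#       input_lens=[3, 4],
#       action_space=space,
#   )
#   a = dist.sample()  # -> (int tensor, float tensor)
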

@DeveloperAPI
class Dirichlet(TFActionDistribution):
    """Dirichlet distribution for continuous actions that are between
    [0,1] and sum to 1.

    e.g. actions that represent resource allocation."""

    def __init__(self, inputs: List[TensorType], model: ModelV2):
        """Input is a tensor of logits. The exponential of logits is used to
        parametrize the Dirichlet distribution as all parameters need to be
        positive. An arbitrarily small epsilon is added to the concentration
        parameters so they never end up exactly zero due to numerical error.

        See issue #4440 for more details.
        """
        self.epsilon = 1e-7
        concentration = tf.exp(inputs) + self.epsilon
        self.dist = tf1.distributions.Dirichlet(
            concentration=concentration,
            validate_args=True,
            allow_nan_stats=False,
        )
        super().__init__(concentration, model)

    @override(ActionDistribution)
    def deterministic_sample(self) -> TensorType:
        return tf.nn.softmax(self.dist.concentration)

    @override(ActionDistribution)
    def logp(self, x: TensorType) -> TensorType:
        # Support of Dirichlet are positive real numbers. x is already
        # an array of positive numbers, but we clip to avoid zeros due to
        # numerical errors.
        x = tf.maximum(x, self.epsilon)
        x = x / tf.reduce_sum(x, axis=-1, keepdims=True)
        return self.dist.log_prob(x)

    @override(ActionDistribution)
    def entropy(self) -> TensorType:
        return self.dist.entropy()

    @override(ActionDistribution)
    def kl(self, other: ActionDistribution) -> TensorType:
        return self.dist.kl_divergence(other.dist)

    @override(TFActionDistribution)
    def _build_sample_op(self) -> TensorType:
        return self.dist.sample()

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(
        action_space: gym.Space, model_config: ModelConfigDict
    ) -> Union[int, np.ndarray]:
        return np.prod(action_space.shape, dtype=np.int32)
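
# Usage sketch (illustrative only): exp(logits) + epsilon gives strictly
# positive concentration parameters, so samples are valid probability vectors
# (non-negative, summing to 1), e.g. fractional resource allocations.
#
#   logits = tf.constant([[0.0, 0.0, 0.0]])  # uniform concentration
#   dist = Dirichlet(logits, None)
#   a = dist.sample()                        # sums to 1 along axis -1
#   logp = dist.logp(a)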