2020-04-23 09:09:22 +02:00
|
|
|
import functools
|
2020-04-15 13:25:16 +02:00
|
|
|
from math import log
|
2019-08-06 18:13:16 +00:00
|
|
|
import numpy as np
|
2021-04-16 09:16:24 +02:00
|
|
|
import tree # pip install dm_tree
|
2020-11-12 03:16:12 -08:00
|
|
|
import gym
|
2019-08-06 18:13:16 +00:00
|
|
|
|
2019-04-12 11:39:14 -07:00
|
|
|
from ray.rllib.models.action_dist import ActionDistribution
|
2020-10-27 10:00:24 +01:00
|
|
|
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
2019-04-12 11:39:14 -07:00
|
|
|
from ray.rllib.utils.annotations import override
|
2020-04-23 09:09:22 +02:00
|
|
|
from ray.rllib.utils.framework import try_import_torch
|
2020-04-15 13:25:16 +02:00
|
|
|
from ray.rllib.utils.numpy import SMALL_NUMBER, MIN_LOG_NN_OUTPUT, \
|
|
|
|
MAX_LOG_NN_OUTPUT
|
2020-05-27 10:21:30 +02:00
|
|
|
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
|
2020-04-15 13:25:16 +02:00
|
|
|
from ray.rllib.utils.torch_ops import atanh
|
2020-11-12 03:16:12 -08:00
|
|
|
from ray.rllib.utils.typing import TensorType, List, Union, \
|
|
|
|
Tuple, ModelConfigDict
|
2019-12-30 15:27:32 -05:00
|
|
|
|
|
|
|
torch, nn = try_import_torch()
|
2019-04-12 11:39:14 -07:00
|
|
|
|
|
|
|
|
|
|
|
class TorchDistributionWrapper(ActionDistribution):
    """Base wrapper around a `torch.distributions` object.

    Subclasses are expected to create the wrapped distribution and store it
    under `self.dist`; the generic methods below simply delegate to it.
    """

    @override(ActionDistribution)
    def __init__(self, inputs: List[TensorType], model: TorchModelV2):
        """Stores `inputs` (as a torch tensor) and resets the sample cache.

        Args:
            inputs: The distribution inputs. Anything that is not already a
                torch.Tensor is converted via `torch.from_numpy`.
            model: The model that produced `inputs`. If it is a
                TorchModelV2, freshly converted inputs are moved to the
                device of the model's first parameter.
        """
        needs_conversion = not isinstance(inputs, torch.Tensor)
        if needs_conversion:
            inputs = torch.from_numpy(inputs)
        # Only freshly converted inputs are moved to the model's device.
        if needs_conversion and isinstance(model, TorchModelV2):
            target_device = next(model.parameters()).device
            inputs = inputs.to(target_device)
        super().__init__(inputs, model)
        # Cache for the most recently drawn sample (see `sample()`).
        self.last_sample = None

    @override(ActionDistribution)
    def logp(self, actions: TensorType) -> TensorType:
        """Returns the wrapped distribution's log-prob of `actions`."""
        log_probs = self.dist.log_prob(actions)
        return log_probs

    @override(ActionDistribution)
    def entropy(self) -> TensorType:
        """Returns the entropy of the wrapped distribution."""
        wrapped = self.dist
        return wrapped.entropy()

    @override(ActionDistribution)
    def kl(self, other: ActionDistribution) -> TensorType:
        """Returns KL(self.dist || other.dist)."""
        kl_divergence = torch.distributions.kl.kl_divergence
        return kl_divergence(self.dist, other.dist)

    @override(ActionDistribution)
    def sample(self) -> TensorType:
        """Draws a sample, remembers it in `last_sample` and returns it."""
        drawn = self.dist.sample()
        self.last_sample = drawn
        return self.last_sample

    @override(ActionDistribution)
    def sampled_action_logp(self) -> TensorType:
        """Returns the log-prob of the last drawn sample.

        Requires that a sample has been drawn before.
        """
        assert self.last_sample is not None
        most_recent = self.last_sample
        return self.logp(most_recent)
|
2019-04-12 11:39:14 -07:00
|
|
|
|
|
|
|
|
|
|
|
class TorchCategorical(TorchDistributionWrapper):
    """Wrapper class for PyTorch Categorical distribution."""

    @override(ActionDistribution)
    def __init__(self,
                 inputs: List[TensorType],
                 model: TorchModelV2 = None,
                 temperature: float = 1.0):
        """Builds a Categorical from (optionally temperature-scaled) logits.

        Args:
            inputs: The logits parameterizing the distribution.
            model: The model that produced `inputs` (may be None).
            temperature: Positive divisor applied to the logits before the
                distribution is built. 1.0 leaves the logits untouched.

        Raises:
            AssertionError: If `temperature` is not > 0.0.
        """
        if temperature != 1.0:
            assert temperature > 0.0, \
                "Categorical `temperature` must be > 0.0!"
            # Divide out-of-place: an in-place `/=` would silently modify
            # the caller's tensor/array (e.g. the model's output).
            inputs = inputs / temperature
        super().__init__(inputs, model)
        self.dist = torch.distributions.categorical.Categorical(
            logits=self.inputs)

    @override(ActionDistribution)
    def deterministic_sample(self) -> TensorType:
        """Returns (and caches) the index of the most probable category.

        The argmax is taken over dim 1 of the distribution's probs.
        """
        self.last_sample = self.dist.probs.argmax(dim=1)
        return self.last_sample

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(
            action_space: gym.Space,
            model_config: ModelConfigDict) -> Union[int, np.ndarray]:
        """Returns the number of logits needed: `action_space.n`."""
        return action_space.n
|
2019-04-12 11:39:14 -07:00
|
|
|
|
|
|
|
|
2020-03-04 09:41:40 +01:00
|
|
|
class TorchMultiCategorical(TorchDistributionWrapper):
    """MultiCategorical distribution for MultiDiscrete action spaces."""

    @override(TorchDistributionWrapper)
    def __init__(self,
                 inputs: List[TensorType],
                 model: TorchModelV2,
                 input_lens: Union[List[int], np.ndarray, Tuple[int, ...]],
                 action_space=None):
        """Splits `inputs` along dim 1 into one Categorical per component.

        Args:
            inputs: The concatenated logits of all components.
            model: The model that produced `inputs`.
            input_lens: Number of logits belonging to each component.
            action_space: Optional action space. If it is a gym Box, samples
                are reshaped to the Box's shape and actions passed to
                `logp` are flattened accordingly.
        """
        super().__init__(inputs, model)
        # If input_lens is np.ndarray or list, force-make it a tuple.
        inputs_split = self.inputs.split(tuple(input_lens), dim=1)
        self.cats = [
            torch.distributions.categorical.Categorical(logits=input_)
            for input_ in inputs_split
        ]
        # Used in case we are dealing with an Int Box.
        self.action_space = action_space

    def _to_action_space_shape(self, sample_: TensorType) -> TensorType:
        """Reshapes stacked component samples to the Box shape, if any."""
        if isinstance(self.action_space, gym.spaces.Box):
            sample_ = torch.reshape(sample_,
                                    [-1] + list(self.action_space.shape))
        return sample_

    @override(TorchDistributionWrapper)
    def sample(self) -> TensorType:
        """Samples every component and stacks the results along dim 1."""
        arr = [cat.sample() for cat in self.cats]
        sample_ = self._to_action_space_shape(torch.stack(arr, dim=1))
        self.last_sample = sample_
        return sample_

    @override(ActionDistribution)
    def deterministic_sample(self) -> TensorType:
        """Takes the argmax of every component, stacked along dim 1."""
        arr = [torch.argmax(cat.probs, -1) for cat in self.cats]
        sample_ = self._to_action_space_shape(torch.stack(arr, dim=1))
        self.last_sample = sample_
        return sample_

    @override(TorchDistributionWrapper)
    def logp(self, actions: TensorType) -> TensorType:
        """Returns the summed log-prob of all action components.

        Args:
            actions: Either a single tensor (unbound along dim 1 into the
                per-component actions) or an iterable holding one action
                tensor per component.
        """
        # If tensor is provided, unstack it into list.
        if isinstance(actions, torch.Tensor):
            if isinstance(self.action_space, gym.spaces.Box):
                # Flatten Box-shaped actions to one column per component.
                actions = torch.reshape(
                    actions, [-1, int(np.prod(self.action_space.shape))])
            actions = torch.unbind(actions, dim=1)
        logps = torch.stack(
            [cat.log_prob(act) for cat, act in zip(self.cats, actions)])
        return torch.sum(logps, dim=0)

    @override(ActionDistribution)
    def multi_entropy(self) -> TensorType:
        """Returns the per-component entropies, stacked along dim 1."""
        return torch.stack([cat.entropy() for cat in self.cats], dim=1)

    @override(TorchDistributionWrapper)
    def entropy(self) -> TensorType:
        """Returns the sum of all component entropies."""
        return torch.sum(self.multi_entropy(), dim=1)

    @override(ActionDistribution)
    def multi_kl(self, other: ActionDistribution) -> TensorType:
        """Returns the per-component KL divergences, stacked along dim 1.

        Components are paired with `other.cats` in order.
        """
        return torch.stack(
            [
                torch.distributions.kl.kl_divergence(cat, oth_cat)
                for cat, oth_cat in zip(self.cats, other.cats)
            ],
            dim=1,
        )

    @override(TorchDistributionWrapper)
    def kl(self, other: ActionDistribution) -> TensorType:
        """Returns the sum of all component KL divergences."""
        return torch.sum(self.multi_kl(other), dim=1)

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(
            action_space: gym.Space,
            model_config: ModelConfigDict) -> Union[int, np.ndarray]:
        """Returns the total number of logits the model must output.

        For an integer Box (all entries sharing the same low and high) this
        is `prod(shape) * (high - low + 1)`; otherwise the sum over
        `action_space.nvec`.
        """
        # Int Box.
        if isinstance(action_space, gym.spaces.Box):
            assert action_space.dtype.name.startswith("int")
            low_ = np.min(action_space.low)
            high_ = np.max(action_space.high)
            assert np.all(action_space.low == low_)
            assert np.all(action_space.high == high_)
            # This value used to be computed but never returned, so the
            # Box case silently yielded None.
            return np.prod(action_space.shape) * (high_ - low_ + 1)
        # MultiDiscrete space.
        return np.sum(action_space.nvec)
|
2020-03-04 09:41:40 +01:00
|
|
|
|
|
|
|
|
2019-04-12 11:39:14 -07:00
|
|
|
class TorchDiagGaussian(TorchDistributionWrapper):
    """Diagonal Gaussian built on top of PyTorch's Normal distribution.

    The per-dimension results of the wrapped Normal (log-prob, entropy, KL)
    are summed over the last axis.
    """

    @override(ActionDistribution)
    def __init__(self, inputs: List[TensorType], model: TorchModelV2):
        """Splits `inputs` along dim 1 into mean and log(std) halves."""
        super().__init__(inputs, model)
        loc, log_scale = torch.chunk(self.inputs, 2, dim=1)
        scale = torch.exp(log_scale)
        self.dist = torch.distributions.normal.Normal(loc, scale)

    @override(ActionDistribution)
    def deterministic_sample(self) -> TensorType:
        """Returns (and caches) the mean of the wrapped Normal."""
        center = self.dist.mean
        self.last_sample = center
        return self.last_sample

    @override(TorchDistributionWrapper)
    def logp(self, actions: TensorType) -> TensorType:
        """Log-prob of `actions`, summed over the last axis."""
        per_dim = super().logp(actions)
        return per_dim.sum(-1)

    @override(TorchDistributionWrapper)
    def entropy(self) -> TensorType:
        """Entropy, summed over the last axis."""
        per_dim = super().entropy()
        return per_dim.sum(-1)

    @override(TorchDistributionWrapper)
    def kl(self, other: ActionDistribution) -> TensorType:
        """KL divergence to `other`, summed over the last axis."""
        per_dim = super().kl(other)
        return per_dim.sum(-1)

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(
            action_space: gym.Space,
            model_config: ModelConfigDict) -> Union[int, np.ndarray]:
        """Two outputs (mean and log-std) per action dimension."""
        num_action_dims = np.prod(action_space.shape)
        return num_action_dims * 2
|
2020-04-09 23:04:21 +02:00
|
|
|
|
|
|
|
|
2020-04-15 13:25:16 +02:00
|
|
|
class TorchSquashedGaussian(TorchDistributionWrapper):
    """A tanh-squashed Gaussian distribution defined by: mean, std, low, high.

    Samples are drawn from a Normal, passed through tanh, rescaled to the
    interval [`low`, `high`] and clamped to it.
    """

    def __init__(self,
                 inputs: List[TensorType],
                 model: TorchModelV2,
                 low: float = -1.0,
                 high: float = 1.0):
        """Parameterizes the distribution via `inputs`.

        Args:
            inputs: Tensor that is split in half along the last axis into
                mean and log(std) of the underlying Normal.
            model: The model that produced `inputs`.
            low (float): Lower bound of the squashed value range.
            high (float): Upper bound of the squashed value range.

        Raises:
            AssertionError: If `low` is not strictly smaller than `high`.
        """
        super().__init__(inputs, model)
        # Split inputs into mean and log(std).
        mean, log_std = torch.chunk(self.inputs, 2, dim=-1)
        # Clip `scale` values (coming from NN) to reasonable values.
        log_std = torch.clamp(log_std, MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT)
        std = torch.exp(log_std)
        self.dist = torch.distributions.normal.Normal(mean, std)
        assert np.all(np.less(low, high))
        self.low = low
        self.high = high
        self.mean = mean
        self.std = std

    @override(ActionDistribution)
    def deterministic_sample(self) -> TensorType:
        """Returns (and caches) the squashed mean of the Normal."""
        self.last_sample = self._squash(self.dist.mean)
        return self.last_sample

    @override(TorchDistributionWrapper)
    def sample(self) -> TensorType:
        """Draws a reparameterized sample and squashes it."""
        # Use the reparameterization version of `dist.sample` to allow for
        # the results to be backprop'able e.g. in a loss term.
        normal_sample = self.dist.rsample()
        self.last_sample = self._squash(normal_sample)
        return self.last_sample

    @override(ActionDistribution)
    def logp(self, x: TensorType) -> TensorType:
        """Returns the log-prob of the (squashed) values `x`."""
        # Unsquash values (from [low,high] to ]-inf,inf[)
        unsquashed_values = self._unsquash(x)
        # Get log prob of unsquashed values from our Normal.
        log_prob_gaussian = self.dist.log_prob(unsquashed_values)
        # For safety reasons, clamp somehow, only then sum up.
        log_prob_gaussian = torch.clamp(log_prob_gaussian, -100, 100)
        log_prob_gaussian = torch.sum(log_prob_gaussian, dim=-1)
        # Get log-prob for squashed Gaussian.
        unsquashed_values_tanhd = torch.tanh(unsquashed_values)
        log_prob = log_prob_gaussian - torch.sum(
            torch.log(1 - unsquashed_values_tanhd**2 + SMALL_NUMBER), dim=-1)
        return log_prob

    def sample_logp(self):
        """Draws a reparameterized sample and returns it with its log-prob.

        Returns:
            Tuple of (squashed actions, log-probs summed over the last
            axis).
        """
        z = self.dist.rsample()
        actions = self._squash(z)
        # The tanh correction must use tanh(z), as `logp` does, and not the
        # rescaled `actions`: for bounds other than (-1, 1),
        # `1 - actions**2` can become negative and the log turns into NaN.
        tanh_z = torch.tanh(z)
        return actions, torch.sum(
            self.dist.log_prob(z) -
            torch.log(1 - tanh_z * tanh_z + SMALL_NUMBER),
            dim=-1)

    @override(TorchDistributionWrapper)
    def entropy(self) -> TensorType:
        """Not supported; always raises ValueError."""
        raise ValueError("Entropy not defined for SquashedGaussian!")

    @override(TorchDistributionWrapper)
    def kl(self, other: ActionDistribution) -> TensorType:
        """Not supported; always raises ValueError."""
        raise ValueError("KL not defined for SquashedGaussian!")

    def _squash(self, raw_values: TensorType) -> TensorType:
        """Maps unbounded values into [low, high] via tanh."""
        # Returned values are within [low, high] (including `low` and `high`).
        squashed = ((torch.tanh(raw_values) + 1.0) / 2.0) * \
            (self.high - self.low) + self.low
        return torch.clamp(squashed, self.low, self.high)

    def _unsquash(self, values: TensorType) -> TensorType:
        """Inverse of `_squash`: maps values in [low, high] back via atanh."""
        normed_values = (values - self.low) / (self.high - self.low) * 2.0 - \
            1.0
        # Stabilize input to atanh.
        save_normed_values = torch.clamp(normed_values, -1.0 + SMALL_NUMBER,
                                         1.0 - SMALL_NUMBER)
        unsquashed = atanh(save_normed_values)
        return unsquashed

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(
            action_space: gym.Space,
            model_config: ModelConfigDict) -> Union[int, np.ndarray]:
        """Two outputs (mean and log-std) per action dimension."""
        return np.prod(action_space.shape) * 2
|
|
|
|
|
2020-04-15 13:25:16 +02:00
|
|
|
|
|
|
|
class TorchBeta(TorchDistributionWrapper):
    """
    A Beta distribution is defined on the interval [0, 1] and parameterized by
    shape parameters alpha and beta (also called concentration parameters).

    PDF(x; alpha, beta) = x**(alpha - 1) (1 - x)**(beta - 1) / Z
        with Z = Gamma(alpha) Gamma(beta) / Gamma(alpha + beta)
        and Gamma(n) = (n - 1)!

    Samples are linearly rescaled from [0, 1] to [`low`, `high`].
    """

    def __init__(self,
                 inputs: List[TensorType],
                 model: TorchModelV2,
                 low: float = 0.0,
                 high: float = 1.0):
        """Builds the Beta from `inputs`, split in half along the last axis.

        Args:
            inputs: Raw parameters; the first half yields alpha, the second
                half beta.
            model: The model that produced `inputs`.
            low: Lower end of the rescaled sample range.
            high: Upper end of the rescaled sample range.
        """
        super().__init__(inputs, model)
        # Stabilize input parameters (possibly coming from a linear layer).
        lower_bound = log(SMALL_NUMBER)
        upper_bound = -log(SMALL_NUMBER)
        self.inputs = torch.clamp(self.inputs, lower_bound, upper_bound)
        # log(exp(x) + 1) + 1 keeps both concentrations above 1.
        self.inputs = torch.log(torch.exp(self.inputs) + 1.0) + 1.0
        self.low = low
        self.high = high
        alpha, beta = torch.chunk(self.inputs, 2, dim=-1)
        # Note: concentration0==beta, concentration1=alpha (!)
        self.dist = torch.distributions.Beta(
            concentration1=alpha, concentration0=beta)

    @override(ActionDistribution)
    def deterministic_sample(self) -> TensorType:
        """Returns (and caches) the rescaled mean of the Beta."""
        rescaled_mean = self._squash(self.dist.mean)
        self.last_sample = rescaled_mean
        return self.last_sample

    @override(TorchDistributionWrapper)
    def sample(self) -> TensorType:
        """Draws a reparameterized sample and rescales it."""
        # `rsample` keeps the result differentiable so that it can be
        # used e.g. inside a loss term.
        raw = self.dist.rsample()
        self.last_sample = self._squash(raw)
        return self.last_sample

    @override(ActionDistribution)
    def logp(self, x: TensorType) -> TensorType:
        """Log-prob of the rescaled values `x`, summed over the last axis."""
        in_unit_interval = self._unsquash(x)
        per_dim = self.dist.log_prob(in_unit_interval)
        return torch.sum(per_dim, dim=-1)

    def _squash(self, raw_values: TensorType) -> TensorType:
        """Linearly maps values from [0, 1] to [low, high]."""
        span = self.high - self.low
        return raw_values * span + self.low

    def _unsquash(self, values: TensorType) -> TensorType:
        """Linearly maps values from [low, high] back to [0, 1]."""
        span = self.high - self.low
        return (values - self.low) / span

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(
            action_space: gym.Space,
            model_config: ModelConfigDict) -> Union[int, np.ndarray]:
        """Two outputs (alpha and beta) per action dimension."""
        num_action_dims = np.prod(action_space.shape)
        return num_action_dims * 2
|
|
|
|
|
2020-04-15 13:25:16 +02:00
|
|
|
|
2020-04-09 23:04:21 +02:00
|
|
|
class TorchDeterministic(TorchDistributionWrapper):
    """Action distribution that returns the input values directly.

    This is similar to DiagGaussian with standard deviation zero (thus only
    requiring the "mean" values as NN output).
    """

    @override(ActionDistribution)
    def deterministic_sample(self) -> TensorType:
        """Returns the inputs unchanged."""
        return self.inputs

    @override(TorchDistributionWrapper)
    def sampled_action_logp(self) -> TensorType:
        """Returns a float32 zero per row (first dim) of the inputs."""
        # Allocate on the inputs' device so the result can be combined with
        # log-probs of other distributions living on that same device.
        return torch.zeros(
            (self.inputs.size()[0], ),
            dtype=torch.float32,
            device=self.inputs.device)

    @override(TorchDistributionWrapper)
    def sample(self) -> TensorType:
        """Same as `deterministic_sample()`."""
        return self.deterministic_sample()

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(
            action_space: gym.Space,
            model_config: ModelConfigDict) -> Union[int, np.ndarray]:
        """One output per action dimension."""
        return np.prod(action_space.shape)
|
2020-04-23 09:09:22 +02:00
|
|
|
|
|
|
|
|
|
|
|
class TorchMultiActionDistribution(TorchDistributionWrapper):
    """Action distribution that operates on multiple, possibly nested actions.
    """

    def __init__(self, inputs, model, *, child_distributions, input_lens,
                 action_space):
        """Initializes a TorchMultiActionDistribution object.

        Args:
            inputs (torch.Tensor): A single tensor of shape [BATCH, size].
            model (TorchModelV2): The TorchModelV2 object used to produce
                inputs for this distribution.
            child_distributions (any[torch.Tensor]): Any struct
                that contains the child distribution classes to use to
                instantiate the child distributions from `inputs`. This could
                be an already flattened list or a struct according to
                `action_space`.
            input_lens (any[int]): A flat list or a nested struct of input
                split lengths used to split `inputs`.
            action_space (Union[gym.spaces.Dict,gym.spaces.Tuple]): The complex
                and possibly nested action space.
        """
        if not isinstance(inputs, torch.Tensor):
            inputs = torch.from_numpy(inputs)
            if isinstance(model, TorchModelV2):
                inputs = inputs.to(next(model.parameters()).device)
        super().__init__(inputs, model)

        self.action_space_struct = get_base_struct_from_space(action_space)

        self.input_lens = tree.flatten(input_lens)
        flat_child_distributions = tree.flatten(child_distributions)
        split_inputs = torch.split(inputs, self.input_lens, dim=1)
        # Instantiate every child class with its own slice of `inputs`.
        self.flat_child_distributions = tree.map_structure(
            lambda dist, input_: dist(input_, model), flat_child_distributions,
            list(split_inputs))

    @staticmethod
    def _child_action_size(dist):
        """Returns how many columns of a merged action tensor `dist` uses."""
        if isinstance(dist, TorchCategorical):
            return 1
        # The width is probed by drawing a throw-away sample. Restore the
        # child's `last_sample` afterwards so that a later
        # `sampled_action_logp()` still refers to the real sampled action.
        previous_sample = getattr(dist, "last_sample", None)
        try:
            return dist.sample().size()[1]
        finally:
            dist.last_sample = previous_sample

    @override(ActionDistribution)
    def logp(self, x):
        """Returns the sum of all children's log-probs for the action `x`.

        Args:
            x: Either one merged tensor / np.ndarray (split along dim 1 into
                the children's parts) or a (possibly nested) struct with one
                entry per child distribution.
        """
        if isinstance(x, np.ndarray):
            x = torch.Tensor(x)
        # Single tensor input (all merged).
        if isinstance(x, torch.Tensor):
            split_indices = [
                self._child_action_size(dist)
                for dist in self.flat_child_distributions
            ]
            split_x = list(torch.split(x, split_indices, dim=1))
        # Structured or flattened (by single action component) input.
        else:
            split_x = tree.flatten(x)

        def map_(val, dist):
            # Remove extra categorical dimension.
            if isinstance(dist, TorchCategorical):
                val = torch.squeeze(val, dim=-1).int()
            return dist.logp(val)

        # Remove extra categorical dimension and take the logp of each
        # component.
        flat_logps = tree.map_structure(map_, split_x,
                                        self.flat_child_distributions)

        return functools.reduce(lambda a, b: a + b, flat_logps)

    @override(ActionDistribution)
    def kl(self, other):
        """Returns the sum of the pairwise KLs of all child distributions."""
        kl_list = [
            d.kl(o) for d, o in zip(self.flat_child_distributions,
                                    other.flat_child_distributions)
        ]
        return functools.reduce(lambda a, b: a + b, kl_list)

    @override(ActionDistribution)
    def entropy(self):
        """Returns the sum of the entropies of all child distributions."""
        entropy_list = [d.entropy() for d in self.flat_child_distributions]
        return functools.reduce(lambda a, b: a + b, entropy_list)

    @override(ActionDistribution)
    def sample(self):
        """Samples every child; result is structured like the action space."""
        child_distributions = tree.unflatten_as(self.action_space_struct,
                                                self.flat_child_distributions)
        return tree.map_structure(lambda s: s.sample(), child_distributions)

    @override(ActionDistribution)
    def deterministic_sample(self):
        """Deterministic sample of every child, structured like the space."""
        child_distributions = tree.unflatten_as(self.action_space_struct,
                                                self.flat_child_distributions)
        return tree.map_structure(lambda s: s.deterministic_sample(),
                                  child_distributions)

    @override(TorchDistributionWrapper)
    def sampled_action_logp(self):
        """Returns the sum of all children's `sampled_action_logp()`."""
        children = self.flat_child_distributions
        total = children[0].sampled_action_logp()
        for child in children[1:]:
            # Add out-of-place so the first child's tensor is not modified.
            total = total + child.sampled_action_logp()
        return total

    @override(ActionDistribution)
    def required_model_output_shape(self, action_space, model_config):
        """Returns the sum of all (flattened) input lengths."""
        return np.sum(self.input_lens)
|
2020-11-11 18:45:28 +01:00
|
|
|
|
|
|
|
|
|
|
|
class TorchDirichlet(TorchDistributionWrapper):
    """Dirichlet distribution for continuous actions that are between
    [0,1] and sum to 1.

    e.g. actions that represent resource allocation."""

    def __init__(self, inputs, model):
        """Input is a tensor of logits. The exponential of logits is used to
        parametrize the Dirichlet distribution as all parameters need to be
        positive. An arbitrary small epsilon is added to the concentration
        parameters to keep them from becoming zero due to numerical error.

        See issue #4440 for more details.

        Args:
            inputs: Tensor of logits (its `.device` is read, so it must
                already be a torch tensor).
            model: The model that produced `inputs`.
        """
        self.epsilon = torch.tensor(1e-7).to(inputs.device)
        concentration = torch.exp(inputs) + self.epsilon
        self.dist = torch.distributions.dirichlet.Dirichlet(
            concentration=concentration,
            validate_args=True,
        )
        # Note: the parent stores the concentration (not the raw logits)
        # as `self.inputs`.
        super().__init__(concentration, model)

    @override(ActionDistribution)
    def deterministic_sample(self) -> TensorType:
        """Returns (and caches) the softmax of the concentration."""
        # Pass `dim` explicitly: relying on softmax's implicit dimension
        # choice is deprecated and picks dim 0 for some input ranks.
        self.last_sample = nn.functional.softmax(
            self.dist.concentration, dim=-1)
        return self.last_sample

    @override(ActionDistribution)
    def logp(self, x):
        """Returns the log-prob of `x` after clipping and re-normalizing."""
        # Support of Dirichlet are positive real numbers. x is already
        # an array of positive numbers, but we clip to avoid zeros due to
        # numerical errors.
        x = torch.max(x, self.epsilon)
        x = x / torch.sum(x, dim=-1, keepdim=True)
        return self.dist.log_prob(x)

    @override(ActionDistribution)
    def entropy(self):
        """Returns the entropy of the Dirichlet."""
        return self.dist.entropy()

    @override(ActionDistribution)
    def kl(self, other):
        """Returns KL(self.dist || other.dist)."""
        # torch distributions have no `kl_divergence` method; the KL is
        # computed via the registry function in `torch.distributions.kl`.
        return torch.distributions.kl.kl_divergence(self.dist, other.dist)

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(action_space, model_config):
        """One output (logit) per action dimension."""
        return np.prod(action_space.shape)
|