ray/rllib/models/torch/torch_action_dist.py
2020-06-03 19:06:06 +02:00

420 lines
15 KiB
Python

import functools
from math import log
import numpy as np
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.utils import try_import_tree
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.numpy import SMALL_NUMBER, MIN_LOG_NN_OUTPUT, \
MAX_LOG_NN_OUTPUT
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
from ray.rllib.utils.torch_ops import atanh
# Resolve the optional third-party dependencies through RLlib's import
# helpers instead of plain `import` statements.
# NOTE(review): presumably these helpers tolerate a missing package and
# return placeholders rather than raising -- confirm in `ray.rllib.utils`.
torch, nn = try_import_torch()
tree = try_import_tree()
class TorchDistributionWrapper(ActionDistribution):
    """Adapter exposing a `torch.distributions` object as an RLlib dist.

    The generic methods below delegate to `self.dist`. This base class does
    not create `self.dist` itself; concrete subclasses assign it in their
    constructors.
    """

    @override(ActionDistribution)
    def __init__(self, inputs, model):
        """Converts `inputs` to a tensor and calls the parent constructor.

        Args:
            inputs: The distribution inputs. Anything that is not already a
                `torch.Tensor` is passed through `torch.Tensor(...)`.
            model: Forwarded unchanged to the parent constructor.
        """
        already_tensor = isinstance(inputs, torch.Tensor)
        tensor_inputs = inputs if already_tensor else torch.Tensor(inputs)
        super().__init__(tensor_inputs, model)
        # Most recent sample; filled in by `sample()`.
        self.last_sample = None

    @override(ActionDistribution)
    def logp(self, actions):
        """Returns `self.dist.log_prob(actions)`."""
        wrapped = self.dist
        return wrapped.log_prob(actions)

    @override(ActionDistribution)
    def entropy(self):
        """Returns the entropy reported by the wrapped distribution."""
        wrapped = self.dist
        return wrapped.entropy()

    @override(ActionDistribution)
    def kl(self, other):
        """Returns the KL divergence from `self.dist` to `other.dist`."""
        kl_divergence = torch.distributions.kl.kl_divergence
        return kl_divergence(self.dist, other.dist)

    @override(ActionDistribution)
    def sample(self):
        """Draws a sample, remembers it in `last_sample` and returns it."""
        drawn = self.dist.sample()
        self.last_sample = drawn
        return drawn

    @override(ActionDistribution)
    def sampled_action_logp(self):
        """Returns the log-prob of the sample stored in `last_sample`.

        A sampling method must have been called first; otherwise the
        assertion below fails.
        """
        remembered = self.last_sample
        assert remembered is not None
        return self.logp(remembered)
class TorchCategorical(TorchDistributionWrapper):
    """Wrapper class for PyTorch Categorical distribution."""

    @override(ActionDistribution)
    def __init__(self, inputs, model=None, temperature=1.0):
        """Builds a Categorical distribution that uses `inputs` as logits.

        Args:
            inputs: The logits. Left untouched; a rescaled copy is used when
                `temperature` differs from 1.0.
            model: Forwarded to the parent constructor.
            temperature (float): Divisor applied to the logits before the
                distribution is built. Must be > 0.0.
        """
        if temperature != 1.0:
            assert temperature > 0.0, \
                "Categorical `temperature` must be > 0.0!"
            # Divide out-of-place. An in-place `/=` would silently rescale
            # the caller's tensor (e.g. the model output, which may be used
            # elsewhere) and can fail on tensors autograd still needs.
            inputs = inputs / temperature
        super().__init__(inputs, model)
        self.dist = torch.distributions.categorical.Categorical(
            logits=self.inputs)

    @override(ActionDistribution)
    def deterministic_sample(self):
        """Returns the arg-max of the probabilities along dim 1.

        The result is also stored in `last_sample`.
        """
        self.last_sample = self.dist.probs.argmax(dim=1)
        return self.last_sample

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(action_space, model_config):
        """Returns `action_space.n`, i.e. one model output per category."""
        return action_space.n
class TorchMultiCategorical(TorchDistributionWrapper):
    """Several independent Categoricals, one per MultiDiscrete component."""

    @override(TorchDistributionWrapper)
    def __init__(self, inputs, model, input_lens):
        """Splits `inputs` along dim 1 and builds one Categorical per piece.

        Args:
            inputs: Concatenated logits of all components.
            model: Forwarded to the parent constructor.
            input_lens: Number of logits belonging to each component.
        """
        super().__init__(inputs, model)
        # `split` wants a tuple, so coerce lists / arrays first.
        split_sizes = tuple(input_lens)
        categoricals = []
        for logits_chunk in self.inputs.split(split_sizes, dim=1):
            categoricals.append(
                torch.distributions.categorical.Categorical(
                    logits=logits_chunk))
        self.cats = categoricals

    @override(TorchDistributionWrapper)
    def sample(self):
        """Samples every component and stacks the results along dim 1."""
        draws = []
        for cat in self.cats:
            draws.append(cat.sample())
        stacked = torch.stack(draws, dim=1)
        self.last_sample = stacked
        return stacked

    @override(ActionDistribution)
    def deterministic_sample(self):
        """Stacks the arg-max category of every component along dim 1."""
        best = []
        for cat in self.cats:
            best.append(cat.probs.argmax(dim=-1))
        stacked = torch.stack(best, dim=1)
        self.last_sample = stacked
        return stacked

    @override(TorchDistributionWrapper)
    def logp(self, actions):
        """Returns the summed log-probability over all components.

        Args:
            actions: Either a tensor holding one column per component
                (unbound along dim 1) or an iterable with one entry per
                component.
        """
        components = actions
        if isinstance(components, torch.Tensor):
            components = torch.unbind(components, dim=1)
        per_cat_logps = []
        for cat, act in zip(self.cats, components):
            per_cat_logps.append(cat.log_prob(act))
        return torch.stack(per_cat_logps).sum(dim=0)

    @override(ActionDistribution)
    def multi_entropy(self):
        """Returns the per-component entropies stacked along dim 1."""
        entropies = []
        for cat in self.cats:
            entropies.append(cat.entropy())
        return torch.stack(entropies, dim=1)

    @override(TorchDistributionWrapper)
    def entropy(self):
        """Returns the sum of the per-component entropies."""
        return self.multi_entropy().sum(dim=1)

    @override(ActionDistribution)
    def multi_kl(self, other):
        """Returns the per-component KL divergences stacked along dim 1."""
        kl_divergence = torch.distributions.kl.kl_divergence
        divergences = []
        for own_cat, other_cat in zip(self.cats, other.cats):
            divergences.append(kl_divergence(own_cat, other_cat))
        return torch.stack(divergences, dim=1)

    @override(TorchDistributionWrapper)
    def kl(self, other):
        """Returns the sum of the per-component KL divergences."""
        return self.multi_kl(other).sum(dim=1)

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(action_space, model_config):
        """Returns the total number of logits, i.e. the sum over `nvec`."""
        total_logits = np.sum(action_space.nvec)
        return total_logits
class TorchDiagGaussian(TorchDistributionWrapper):
    """Normal distribution whose per-dimension terms are summed up."""

    @override(ActionDistribution)
    def __init__(self, inputs, model):
        """Builds a Normal from `inputs`, chunked in two along dim 1.

        The first chunk is used as the mean, the second as log(std).
        """
        super().__init__(inputs, model)
        loc, log_scale = torch.chunk(self.inputs, 2, dim=1)
        scale = torch.exp(log_scale)
        self.dist = torch.distributions.normal.Normal(loc, scale)

    @override(ActionDistribution)
    def deterministic_sample(self):
        """Returns the mean and remembers it in `last_sample`."""
        mode = self.dist.mean
        self.last_sample = mode
        return mode

    @override(TorchDistributionWrapper)
    def logp(self, actions):
        """Returns the log-prob of `actions`, summed over the last axis."""
        per_dim = super().logp(actions)
        return per_dim.sum(-1)

    @override(TorchDistributionWrapper)
    def entropy(self):
        """Returns the entropy, summed over the last axis."""
        per_dim = super().entropy()
        return per_dim.sum(-1)

    @override(TorchDistributionWrapper)
    def kl(self, other):
        """Returns the KL divergence to `other`, summed over the last axis."""
        per_dim = super().kl(other)
        return per_dim.sum(-1)

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(action_space, model_config):
        """Returns twice the flat action size (mean plus log-std)."""
        flat_action_size = np.prod(action_space.shape)
        return flat_action_size * 2
class TorchSquashedGaussian(TorchDistributionWrapper):
    """A tanh-squashed Gaussian distribution defined by: mean, std, low, high.

    Samples of the underlying Normal are passed through tanh and rescaled to
    the interval spanned by `low` and `high`; the result is clamped to
    [`low`, `high`].
    """

    def __init__(self, inputs, model, low=-1.0, high=1.0):
        """Parameterizes the distribution via `inputs`.

        Args:
            inputs: Chunked in two along the last axis; the first chunk is
                the mean, the second log(std) of the underlying Normal.
            model: Forwarded to the parent constructor.
            low (float): Lower end of the squashed value range.
            high (float): Upper end of the squashed value range. Must be
                greater than `low`.
        """
        super().__init__(inputs, model)
        loc, log_scale = torch.chunk(self.inputs, 2, dim=-1)
        # Keep the network-provided log(std) within a sane range before
        # exponentiating.
        log_scale = torch.clamp(log_scale, MIN_LOG_NN_OUTPUT,
                                MAX_LOG_NN_OUTPUT)
        scale = torch.exp(log_scale)
        self.dist = torch.distributions.normal.Normal(loc, scale)
        assert np.all(np.less(low, high))
        self.low = low
        self.high = high

    @override(ActionDistribution)
    def deterministic_sample(self):
        """Returns the squashed mean and remembers it in `last_sample`."""
        squashed_mean = self._squash(self.dist.mean)
        self.last_sample = squashed_mean
        return squashed_mean

    @override(TorchDistributionWrapper)
    def sample(self):
        """Returns a squashed, reparameterized sample.

        `rsample` is used (rather than `sample`) so that the result can be
        backpropagated through, e.g. inside a loss term.
        """
        gaussian_draw = self.dist.rsample()
        squashed_draw = self._squash(gaussian_draw)
        self.last_sample = squashed_draw
        return squashed_draw

    @override(ActionDistribution)
    def logp(self, x):
        """Returns the log-prob of the squashed values `x`.

        `x` is mapped back to the unbounded space, scored under the Normal
        (per-dimension values clamped to [-100, 100] for safety, then
        summed) and corrected by the summed log(1 - tanh^2) term.
        """
        raw = self._unsquash(x)
        gaussian_logp = torch.clamp(self.dist.log_prob(raw), -100, 100)
        gaussian_logp = gaussian_logp.sum(dim=-1)
        tanh_raw = torch.tanh(raw)
        tanh_correction = torch.sum(
            torch.log(1 - tanh_raw**2 + SMALL_NUMBER), dim=-1)
        return gaussian_logp - tanh_correction

    def _squash(self, raw_values):
        """Maps unbounded values into [`low`, `high`] (bounds included)."""
        value_range = self.high - self.low
        unit_interval = (torch.tanh(raw_values) + 1.0) / 2.0
        rescaled = unit_interval * value_range + self.low
        return torch.clamp(rescaled, self.low, self.high)

    def _unsquash(self, values):
        """Inverts `_squash`: maps values back to the unbounded space."""
        fraction = (values - self.low) / (self.high - self.low)
        normed = fraction * 2.0 - 1.0
        # atanh diverges at +/-1, so stay strictly inside the interval.
        limit = 1.0 - SMALL_NUMBER
        safe_normed = torch.clamp(normed, -1.0 + SMALL_NUMBER, limit)
        return atanh(safe_normed)

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(action_space, model_config):
        """Returns twice the flat action size (mean plus log-std)."""
        flat_action_size = np.prod(action_space.shape)
        return flat_action_size * 2
class TorchBeta(TorchDistributionWrapper):
    """Beta distribution rescaled from [0, 1] to [`low`, `high`].

    A Beta distribution is defined on the interval [0, 1] and parameterized
    by the shape (concentration) parameters alpha and beta:

    PDF(x; alpha, beta) = x**(alpha - 1) (1 - x)**(beta - 1) / Z
        with Z = Gamma(alpha) Gamma(beta) / Gamma(alpha + beta)
        and Gamma(n) = (n - 1)!
    """

    def __init__(self, inputs, model, low=0.0, high=1.0):
        """Derives alpha and beta from `inputs` and builds the Beta dist.

        Args:
            inputs: Raw parameters; clamped, passed through
                log(exp(x) + 1) + 1 and then chunked in two along the last
                axis into alpha and beta.
            model: Forwarded to the parent constructor.
            low (float): Value that a Beta sample of 0 is mapped to.
            high (float): Value that a Beta sample of 1 is mapped to.
        """
        super().__init__(inputs, model)
        # Bound the raw values (possibly coming from a linear layer) before
        # exponentiating them.
        lower_bound = log(SMALL_NUMBER)
        upper_bound = -log(SMALL_NUMBER)
        bounded = torch.clamp(self.inputs, lower_bound, upper_bound)
        self.inputs = torch.log(torch.exp(bounded) + 1.0) + 1.0
        self.low = low
        self.high = high
        alpha, beta = torch.chunk(self.inputs, 2, dim=-1)
        # torch naming: concentration1 is alpha, concentration0 is beta (!)
        self.dist = torch.distributions.Beta(
            concentration1=alpha, concentration0=beta)

    @override(ActionDistribution)
    def deterministic_sample(self):
        """Returns the rescaled mean and remembers it in `last_sample`."""
        rescaled_mean = self._squash(self.dist.mean)
        self.last_sample = rescaled_mean
        return rescaled_mean

    @override(TorchDistributionWrapper)
    def sample(self):
        """Returns a rescaled, reparameterized sample.

        `rsample` is used (rather than `sample`) so that the result can be
        backpropagated through, e.g. inside a loss term.
        """
        beta_draw = self.dist.rsample()
        rescaled_draw = self._squash(beta_draw)
        self.last_sample = rescaled_draw
        return rescaled_draw

    @override(ActionDistribution)
    def logp(self, x):
        """Returns the log-prob of `x`, summed over the last axis."""
        unit_values = self._unsquash(x)
        per_dim = self.dist.log_prob(unit_values)
        return torch.sum(per_dim, dim=-1)

    def _squash(self, raw_values):
        """Linearly maps values from [0, 1] to [`low`, `high`]."""
        value_range = self.high - self.low
        return raw_values * value_range + self.low

    def _unsquash(self, values):
        """Linearly maps values from [`low`, `high`] back to [0, 1]."""
        value_range = self.high - self.low
        return (values - self.low) / value_range

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(action_space, model_config):
        """Returns twice the flat action size (alpha plus beta)."""
        flat_action_size = np.prod(action_space.shape)
        return flat_action_size * 2
class TorchDeterministic(TorchDistributionWrapper):
    """Degenerate "distribution" that hands back its inputs as the action.

    Comparable to a DiagGaussian with zero standard deviation, which is why
    only the "mean" values are required as NN output.
    """

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(action_space, model_config):
        """Returns the flat action size (one output per action value)."""
        flat_action_size = np.prod(action_space.shape)
        return flat_action_size

    @override(ActionDistribution)
    def deterministic_sample(self):
        """Returns the inputs unchanged."""
        action = self.inputs
        return action

    @override(TorchDistributionWrapper)
    def sample(self):
        """Returns the same value as `deterministic_sample()`."""
        action = self.deterministic_sample()
        return action

    @override(TorchDistributionWrapper)
    def sampled_action_logp(self):
        """Returns the constant 0.0."""
        constant_logp = 0.0
        return constant_logp
class TorchMultiActionDistribution(TorchDistributionWrapper):
    """Action distribution that operates on multiple, possibly nested actions.
    """

    def __init__(self, inputs, model, *, child_distributions, input_lens,
                 action_space):
        """Initializes a TorchMultiActionDistribution object.

        Args:
            inputs (torch.Tensor): A single tensor of shape [BATCH, size].
            model (ModelV2): The ModelV2 object used to produce inputs for this
                distribution.
            child_distributions (any[type]): Any struct that contains the
                child distribution classes to use to instantiate the child
                distributions from `inputs`. This could be an already
                flattened list or a struct according to `action_space`.
            input_lens (any[int]): A flat list or a nested struct of input
                split lengths used to split `inputs`.
            action_space (Union[gym.spaces.Dict,gym.spaces.Tuple]): The complex
                and possibly nested action space.
        """
        if not isinstance(inputs, torch.Tensor):
            inputs = torch.Tensor(inputs)
        super().__init__(inputs, model)

        self.action_space_struct = get_base_struct_from_space(action_space)

        input_lens = tree.flatten(input_lens)
        flat_child_distributions = tree.flatten(child_distributions)
        split_inputs = torch.split(inputs, input_lens, dim=1)
        # Instantiate each child class on its own slice of `inputs`.
        self.flat_child_distributions = tree.map_structure(
            lambda dist, input_: dist(input_, model), flat_child_distributions,
            list(split_inputs))

    @staticmethod
    def _flat_action_width(dist):
        """Returns how many columns of a flat action tensor `dist` consumes.

        A TorchCategorical takes exactly one column. For any other child the
        width is read off a probe sample. Sampling overwrites the child's
        `last_sample`, so the previous value is restored afterwards;
        otherwise a later `sampled_action_logp()` would score the probe
        instead of the action that was actually handed out.
        """
        if isinstance(dist, TorchCategorical):
            return 1
        remembered = getattr(dist, "last_sample", None)
        try:
            return dist.sample().size()[1]
        finally:
            dist.last_sample = remembered

    @override(ActionDistribution)
    def logp(self, x):
        """Returns the summed log-probability of `x` over all children.

        Args:
            x: Either one flat tensor / ndarray with the columns of all
                action components concatenated along dim 1, or a (possibly
                nested) struct with one entry per child distribution.
        """
        if isinstance(x, np.ndarray):
            x = torch.Tensor(x)
        # Single tensor input (all merged).
        if isinstance(x, torch.Tensor):
            split_indices = [
                self._flat_action_width(dist)
                for dist in self.flat_child_distributions
            ]
            split_x = list(torch.split(x, split_indices, dim=1))
        # Structured or flattened (by single action component) input.
        else:
            split_x = tree.flatten(x)

        def map_(val, dist):
            # Remove extra categorical dimension.
            if isinstance(dist, TorchCategorical):
                val = torch.squeeze(val, dim=-1).int()
            return dist.logp(val)

        # Take the logp of each component and add them all up.
        flat_logps = tree.map_structure(map_, split_x,
                                        self.flat_child_distributions)
        return functools.reduce(lambda a, b: a + b, flat_logps)

    @override(ActionDistribution)
    def kl(self, other):
        """Returns the sum of the child-wise KL divergences to `other`."""
        kl_list = [
            d.kl(o) for d, o in zip(self.flat_child_distributions,
                                    other.flat_child_distributions)
        ]
        return functools.reduce(lambda a, b: a + b, kl_list)

    @override(ActionDistribution)
    def entropy(self):
        """Returns the sum of the children's entropies."""
        entropy_list = [d.entropy() for d in self.flat_child_distributions]
        return functools.reduce(lambda a, b: a + b, entropy_list)

    @override(ActionDistribution)
    def sample(self):
        """Samples every child; the result mirrors the action space struct."""
        child_distributions = tree.unflatten_as(self.action_space_struct,
                                                self.flat_child_distributions)
        return tree.map_structure(lambda s: s.sample(), child_distributions)

    @override(ActionDistribution)
    def deterministic_sample(self):
        """Like `sample()`, using each child's `deterministic_sample()`."""
        child_distributions = tree.unflatten_as(self.action_space_struct,
                                                self.flat_child_distributions)
        return tree.map_structure(lambda s: s.deterministic_sample(),
                                  child_distributions)

    @override(TorchDistributionWrapper)
    def sampled_action_logp(self):
        """Returns the sum of the children's `sampled_action_logp()`."""
        # Accumulate out-of-place (as `kl` and `entropy` do) so that the
        # tensor returned by the first child is never modified.
        logp_list = [
            d.sampled_action_logp() for d in self.flat_child_distributions
        ]
        return functools.reduce(lambda a, b: a + b, logp_list)