import numpy as np

from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.utils.annotations import override
from ray.rllib.utils import try_import_torch

torch, nn = try_import_torch()


class TorchDistributionWrapper(ActionDistribution):
    """Wrapper class for torch.distributions."""

    @override(ActionDistribution)
    def __init__(self, inputs, model):
        if not isinstance(inputs, torch.Tensor):
            inputs = torch.Tensor(inputs)
        super().__init__(inputs, model)
        # Store the last sample here.
        self.last_sample = None

    @override(ActionDistribution)
    def logp(self, actions):
        return self.dist.log_prob(actions)

    @override(ActionDistribution)
    def entropy(self):
        return self.dist.entropy()

    @override(ActionDistribution)
    def kl(self, other):
        return torch.distributions.kl.kl_divergence(self.dist, other.dist)

    @override(ActionDistribution)
    def sample(self):
        self.last_sample = self.dist.sample()
        return self.last_sample

    @override(ActionDistribution)
    def sampled_action_logp(self):
        assert self.last_sample is not None, \
            "Must call `sample()` before `sampled_action_logp()`!"
        return self.logp(self.last_sample)
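

# A minimal sketch (hypothetical, for illustration only; not part of this
# module's API): a concrete subclass only has to set `self.dist` to a
# torch.distributions object; the wrapper above then provides logp(),
# entropy(), kl(), and sample() generically.
class _ExampleBernoulliWrapper(TorchDistributionWrapper):
    def __init__(self, inputs, model=None):
        super().__init__(inputs, model)
        # One independent Bernoulli per input logit.
        self.dist = torch.distributions.bernoulli.Bernoulli(
            logits=self.inputs)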


class TorchCategorical(TorchDistributionWrapper):
    """Wrapper class for PyTorch Categorical distribution."""

    @override(ActionDistribution)
    def __init__(self, inputs, model=None, temperature=1.0):
        if temperature != 1.0:
            assert temperature > 0.0, \
                "Categorical `temperature` must be > 0.0!"
            # Scale out-of-place so the caller's logits tensor is not
            # mutated (an in-place `/=` would also fail on leaf tensors
            # that require grad).
            inputs = inputs / temperature
        super().__init__(inputs, model)
        self.dist = torch.distributions.categorical.Categorical(
            logits=self.inputs)

    @override(ActionDistribution)
    def deterministic_sample(self):
        self.last_sample = self.dist.probs.argmax(dim=1)
        return self.last_sample

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(action_space, model_config):
        return action_space.n
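

# A minimal usage sketch (`_example_torch_categorical` is a hypothetical
# helper, not RLlib API): feeding a (batch, num_actions) logits tensor
# into TorchCategorical; all values are made up for illustration.
def _example_torch_categorical():
    # Two rows of logits over three discrete actions.
    logits = torch.tensor([[0.1, 0.9, 0.0],
                           [2.0, 0.0, 0.0]])
    dist = TorchCategorical(logits, model=None)
    action = dist.sample()                # shape: (2,), dtype: int64
    logp = dist.logp(action)              # shape: (2,)
    greedy = dist.deterministic_sample()  # per-row argmax of the probs
    return action, logp, greedy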


class TorchMultiCategorical(TorchDistributionWrapper):
    """MultiCategorical distribution for MultiDiscrete action spaces."""

    @override(TorchDistributionWrapper)
    def __init__(self, inputs, model, input_lens):
        super().__init__(inputs, model)
        # If input_lens is np.ndarray or list, force-make it a tuple.
        inputs_split = self.inputs.split(tuple(input_lens), dim=1)
        self.cats = [
            torch.distributions.categorical.Categorical(logits=input_)
            for input_ in inputs_split
        ]

    @override(TorchDistributionWrapper)
    def sample(self):
        arr = [cat.sample() for cat in self.cats]
        self.last_sample = torch.stack(arr, dim=1)
        return self.last_sample

    @override(ActionDistribution)
    def deterministic_sample(self):
        arr = [torch.argmax(cat.probs, -1) for cat in self.cats]
        self.last_sample = torch.stack(arr, dim=1)
        return self.last_sample

    @override(TorchDistributionWrapper)
    def logp(self, actions):
        # If a tensor is provided, unstack it into a list of
        # per-sub-action tensors first.
        if isinstance(actions, torch.Tensor):
            actions = torch.unbind(actions, dim=1)
        logps = torch.stack(
            [cat.log_prob(act) for cat, act in zip(self.cats, actions)])
        return torch.sum(logps, dim=0)

    @override(ActionDistribution)
    def multi_entropy(self):
        return torch.stack([cat.entropy() for cat in self.cats], dim=1)

    @override(TorchDistributionWrapper)
    def entropy(self):
        return torch.sum(self.multi_entropy(), dim=1)

    @override(ActionDistribution)
    def multi_kl(self, other):
        return torch.stack(
            [
                torch.distributions.kl.kl_divergence(cat, oth_cat)
                for cat, oth_cat in zip(self.cats, other.cats)
            ],
            dim=1)

    @override(TorchDistributionWrapper)
    def kl(self, other):
        return torch.sum(self.multi_kl(other), dim=1)

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(action_space, model_config):
        return np.sum(action_space.nvec)
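

# A minimal usage sketch (`_example_torch_multi_categorical` is a
# hypothetical helper, not RLlib API): a MultiDiscrete([3, 2]) action
# space needs 3 + 2 = 5 logits per row; `input_lens=[3, 2]` tells the
# wrapper where to split them.
def _example_torch_multi_categorical():
    # Batch of 2; columns 0-2 are logits for sub-action 0, columns 3-4
    # for sub-action 1.
    logits = torch.tensor([[0.1, 0.2, 0.7, 1.0, -1.0],
                           [0.5, 0.5, 0.5, 0.0, 0.0]])
    dist = TorchMultiCategorical(logits, model=None, input_lens=[3, 2])
    action = dist.sample()    # shape: (2, 2), one column per sub-action
    logp = dist.logp(action)  # shape: (2,), summed over sub-actions
    return action, logp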


class TorchDiagGaussian(TorchDistributionWrapper):
    """Wrapper class for PyTorch Normal distribution."""

    @override(ActionDistribution)
    def __init__(self, inputs, model):
        super().__init__(inputs, model)
        # The model outputs means and log-stds, concatenated along dim 1.
        # Use `self.inputs` (set by the base class) in case `inputs` had
        # to be converted to a tensor first.
        mean, log_std = torch.chunk(self.inputs, 2, dim=1)
        self.dist = torch.distributions.normal.Normal(mean, torch.exp(log_std))

    @override(ActionDistribution)
    def deterministic_sample(self):
        self.last_sample = self.dist.mean
        return self.last_sample

    @override(TorchDistributionWrapper)
    def logp(self, actions):
        # Sum per-dimension log-likelihoods into one joint logp per batch
        # item (the dimensions are independent).
        return super().logp(actions).sum(-1)

    @override(TorchDistributionWrapper)
    def entropy(self):
        return super().entropy().sum(-1)

    @override(TorchDistributionWrapper)
    def kl(self, other):
        return super().kl(other).sum(-1)

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(action_space, model_config):
        # One mean and one log-std output per action dimension.
        return np.prod(action_space.shape) * 2
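

# A minimal usage sketch (`_example_torch_diag_gaussian` is a hypothetical
# helper, not RLlib API): for a 2-dim continuous action, the model emits
# 2 * 2 = 4 outputs per row, interpreted as
# [mean_0, mean_1, log_std_0, log_std_1].
def _example_torch_diag_gaussian():
    inputs = torch.tensor([[0.0, 1.0, -1.0, -1.0],
                           [0.5, 0.5, 0.0, 0.0]])
    dist = TorchDiagGaussian(inputs, model=None)
    action = dist.sample()                     # shape: (2, 2)
    logp = dist.logp(action)                   # shape: (2,), joint logp
    mean_action = dist.deterministic_sample()  # the mean half of inputs
    return action, logp, mean_action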


class TorchDeterministic(TorchDistributionWrapper):
    """Action distribution that returns the input values directly.

    This is similar to DiagGaussian with standard deviation zero (thus only
    requiring the "mean" values as NN output).
    """

    @override(ActionDistribution)
    def deterministic_sample(self):
        return self.inputs

    @override(TorchDistributionWrapper)
    def sampled_action_logp(self):
        return 0.0

    @override(TorchDistributionWrapper)
    def sample(self):
        return self.deterministic_sample()

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(action_space, model_config):
        return np.prod(action_space.shape)
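

# A minimal usage sketch (`_example_torch_deterministic` is a hypothetical
# helper, not RLlib API): the deterministic "distribution" simply echoes
# the model output, as used by algorithms that act greedily.
def _example_torch_deterministic():
    inputs = torch.tensor([[0.25, -0.75]])
    dist = TorchDeterministic(inputs, model=None)
    assert torch.equal(dist.sample(), inputs)
    assert dist.sampled_action_logp() == 0.0
    return dist.sample()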