ray/rllib/models/torch/torch_action_dist.py
Sven 60d4d5e1aa Remove future imports (#6724)
* Remove all __future__ imports from RLlib.

* Remove (object) again from tf_run_builder.py::TFRunBuilder.

* Fix 2xLINT warnings.

* Fix broken appo_policy import (must be appo_tf_policy)

* Remove future imports from all other ray files (not just RLlib).

* Remove future import blocks that contain `unicode_literals` as well.
Revert appo_tf_policy.py to appo_policy.py (belongs to another PR).

* Add two empty lines before Schedule class.

* Put back __future__ imports into determine_tests_to_run.py. Fails otherwise on a py2/print related error.
2020-01-09 00:15:48 -08:00

58 lines
1.7 KiB
Python

import numpy as np

from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.utils.annotations import override
from ray.rllib.utils import try_import_torch

torch, nn = try_import_torch()


class TorchDistributionWrapper(ActionDistribution):
    """Wrapper class for torch.distributions."""

    @override(ActionDistribution)
    def logp(self, actions):
        return self.dist.log_prob(actions)

    @override(ActionDistribution)
    def entropy(self):
        return self.dist.entropy()

    @override(ActionDistribution)
    def kl(self, other):
        return torch.distributions.kl.kl_divergence(self.dist, other.dist)

    @override(ActionDistribution)
    def sample(self):
        return self.dist.sample()


class TorchCategorical(TorchDistributionWrapper):
    """Wrapper class for PyTorch Categorical distribution."""

    @override(ActionDistribution)
    def __init__(self, inputs, model):
        # `inputs` are the raw model outputs, interpreted as logits.
        self.dist = torch.distributions.categorical.Categorical(logits=inputs)

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(action_space, model_config):
        # One logit per discrete action.
        return action_space.n


class TorchDiagGaussian(TorchDistributionWrapper):
    """Wrapper class for PyTorch Normal distribution."""

    @override(ActionDistribution)
    def __init__(self, inputs, model):
        # The model output is split into mean and log-std halves.
        mean, log_std = torch.chunk(inputs, 2, dim=1)
        self.dist = torch.distributions.normal.Normal(mean, torch.exp(log_std))

    @override(TorchDistributionWrapper)
    def logp(self, actions):
        # Sum the per-dimension log-probs over the action dimensions.
        return TorchDistributionWrapper.logp(self, actions).sum(-1)

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(action_space, model_config):
        # Mean and log-std for each action dimension.
        return np.prod(action_space.shape) * 2
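

# --- Illustrative usage sketch (not part of the original file) ---
# A minimal, hedged example of how these wrappers are typically driven:
# the tensors below (`logits`, `gaussian_out`) are made-up model outputs,
# and `model` is passed as None purely for this sketch. Requires torch
# to be installed so that try_import_torch() returns the real module.
if __name__ == "__main__":
    # Categorical over 4 discrete actions, batch of 2.
    logits = torch.randn(2, 4)
    cat_dist = TorchCategorical(logits, None)
    a = cat_dist.sample()               # shape: (2,)
    logp_a = cat_dist.logp(a)           # per-sample log-probability
    ent = cat_dist.entropy()            # per-sample entropy

    # Diagonal Gaussian over a 3-dim continuous action space:
    # the model output holds [mean, log_std], i.e. 2 * 3 = 6 values per sample.
    gaussian_out = torch.randn(2, 6)
    diag_dist = TorchDiagGaussian(gaussian_out, None)
    a_cont = diag_dist.sample()         # shape: (2, 3)
    logp_cont = diag_dist.logp(a_cont)  # summed over action dims -> shape (2,)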