ray/rllib/examples/models/autoregressive_action_dist.py
Balaji Veeramani 7f1bacc7dc
[CI] Format Python code with Black (#21975)
See #21316 and #21311 for the motivation behind these changes.
2022-01-29 18:41:57 -08:00

148 lines
4.9 KiB
Python

from ray.rllib.models.tf.tf_action_dist import Categorical, ActionDistribution
from ray.rllib.models.torch.torch_action_dist import (
    TorchCategorical,
    TorchDistributionWrapper,
)
from ray.rllib.utils.framework import try_import_tf, try_import_torch

# Lazy framework imports so this example file can be imported even when only
# one of TF / PyTorch is installed (the try_import_* helpers return
# placeholders instead of raising — NOTE(review): per RLlib convention;
# confirm against ray.rllib.utils.framework).
tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()
class BinaryAutoregressiveDistribution(ActionDistribution):
    """Autoregressive action distribution P(a1, a2) = P(a1) * P(a2 | a1).

    TF implementation. The wrapped model exposes ``action_model``, which maps
    ``[context, a1_input]`` to a pair of logit tensors ``(a1_logits,
    a2_logits)``; ``a2_logits`` depends on the ``a1_input`` that is fed in.
    """

    def deterministic_sample(self):
        # Greedy-pick a1 first, then a2 conditioned on that a1.
        dist1 = self._a1_distribution()
        act1 = dist1.deterministic_sample()
        dist2 = self._a2_distribution(act1)
        act2 = dist2.deterministic_sample()
        # Cache the joint log-likelihood for sampled_action_logp().
        self._action_logp = dist1.logp(act1) + dist2.logp(act2)
        return (act1, act2)

    def sample(self):
        # Stochastic variant: sample a1, then a2 conditioned on it.
        dist1 = self._a1_distribution()
        act1 = dist1.sample()
        dist2 = self._a2_distribution(act1)
        act2 = dist2.sample()
        # Cache the joint log-likelihood for sampled_action_logp().
        self._action_logp = dist1.logp(act1) + dist2.logp(act2)
        return (act1, act2)

    def logp(self, actions):
        """Joint log-likelihood of a batch of (a1, a2) action pairs."""
        a1 = actions[:, 0]
        a2 = actions[:, 1]
        # a1 enters the conditional head as a float column vector.
        a1_input = tf.expand_dims(tf.cast(a1, tf.float32), 1)
        logits1, logits2 = self.model.action_model([self.inputs, a1_input])
        lp1 = Categorical(logits1).logp(a1)
        lp2 = Categorical(logits2).logp(a2)
        return lp1 + lp2

    def sampled_action_logp(self):
        # NOTE(review): despite the "logp" name this returns exp(logp), i.e.
        # the probability of the last sampled action. Kept as-is to preserve
        # the original example's behavior — confirm against the
        # ActionDistribution API before changing.
        return tf.exp(self._action_logp)

    def entropy(self):
        # Entropy of a1 plus entropy of a2 given a single sampled a1.
        dist1 = self._a1_distribution()
        dist2 = self._a2_distribution(dist1.sample())
        return dist1.entropy() + dist2.entropy()

    def kl(self, other):
        # KL of the a1 marginals plus KL of the a2 conditionals, both
        # conditioned on the same sampled a1.
        dist1 = self._a1_distribution()
        kl1 = dist1.kl(other._a1_distribution())
        act1 = dist1.sample()
        kl2 = self._a2_distribution(act1).kl(other._a2_distribution(act1))
        return kl1 + kl2

    def _a1_distribution(self):
        # Feed an all-zero dummy a1 input; only the a1 head's output is used.
        batch_size = tf.shape(self.inputs)[0]
        logits1, _ = self.model.action_model(
            [self.inputs, tf.zeros((batch_size, 1))]
        )
        return Categorical(logits1)

    def _a2_distribution(self, a1):
        # Condition the a2 head on the given a1 (as a float column vector).
        a1_input = tf.expand_dims(tf.cast(a1, tf.float32), 1)
        _, logits2 = self.model.action_model([self.inputs, a1_input])
        return Categorical(logits2)

    @staticmethod
    def required_model_output_shape(action_space, model_config):
        return 16  # controls model output feature vector size
class TorchBinaryAutoregressiveDistribution(TorchDistributionWrapper):
    """Autoregressive action distribution P(a1, a2) = P(a1) * P(a2 | a1).

    PyTorch implementation. The wrapped model exposes ``action_module``,
    which maps ``(context, a1_input)`` to a pair of logit tensors
    ``(a1_logits, a2_logits)``; ``a2_logits`` depends on the fed-in a1.
    """

    def deterministic_sample(self):
        # Greedy-pick a1 first, then a2 conditioned on that a1.
        dist1 = self._a1_distribution()
        act1 = dist1.deterministic_sample()
        dist2 = self._a2_distribution(act1)
        act2 = dist2.deterministic_sample()
        # Cache the joint log-likelihood for sampled_action_logp().
        self._action_logp = dist1.logp(act1) + dist2.logp(act2)
        return (act1, act2)

    def sample(self):
        # Stochastic variant: sample a1, then a2 conditioned on it.
        dist1 = self._a1_distribution()
        act1 = dist1.sample()
        dist2 = self._a2_distribution(act1)
        act2 = dist2.sample()
        # Cache the joint log-likelihood for sampled_action_logp().
        self._action_logp = dist1.logp(act1) + dist2.logp(act2)
        return (act1, act2)

    def logp(self, actions):
        """Joint log-likelihood of a batch of (a1, a2) action pairs."""
        a1 = actions[:, 0]
        a2 = actions[:, 1]
        # a1 enters the conditional head as a float column vector.
        a1_input = torch.unsqueeze(a1.float(), 1)
        logits1, logits2 = self.model.action_module(self.inputs, a1_input)
        lp1 = TorchCategorical(logits1).logp(a1)
        lp2 = TorchCategorical(logits2).logp(a2)
        return lp1 + lp2

    def sampled_action_logp(self):
        # NOTE(review): despite the "logp" name this returns exp(logp), i.e.
        # the probability of the last sampled action. Kept as-is to preserve
        # the original example's behavior — confirm against the
        # TorchDistributionWrapper API before changing.
        return torch.exp(self._action_logp)

    def entropy(self):
        # Entropy of a1 plus entropy of a2 given a single sampled a1.
        dist1 = self._a1_distribution()
        dist2 = self._a2_distribution(dist1.sample())
        return dist1.entropy() + dist2.entropy()

    def kl(self, other):
        # KL of the a1 marginals plus KL of the a2 conditionals, both
        # conditioned on the same sampled a1.
        dist1 = self._a1_distribution()
        kl1 = dist1.kl(other._a1_distribution())
        act1 = dist1.sample()
        kl2 = self._a2_distribution(act1).kl(other._a2_distribution(act1))
        return kl1 + kl2

    def _a1_distribution(self):
        # Feed an all-zero dummy a1 input (on the inputs' device); only the
        # a1 head's output is used.
        batch_size = self.inputs.shape[0]
        dummy_a1 = torch.zeros((batch_size, 1)).to(self.inputs.device)
        logits1, _ = self.model.action_module(self.inputs, dummy_a1)
        return TorchCategorical(logits1)

    def _a2_distribution(self, a1):
        # Condition the a2 head on the given a1 (as a float column vector).
        a1_input = torch.unsqueeze(a1.float(), 1)
        _, logits2 = self.model.action_module(self.inputs, a1_input)
        return TorchCategorical(logits2)

    @staticmethod
    def required_model_output_shape(action_space, model_config):
        return 16  # controls model output feature vector size