from ray.rllib.models.tf.tf_action_dist import Categorical, ActionDistribution
from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \
    TorchDistributionWrapper
from ray.rllib.utils.framework import try_import_tf, try_import_torch

tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()

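# --- Illustrative sketch (an assumption, not part of the original file). ---
# Both distributions below expect the model to expose a two-headed
# autoregressive branch: `action_model` (TF) / `action_module` (torch) maps
# (context, a1_vec) -> (a1_logits, a2_logits), where a1_vec carries the
# already-chosen a1. A minimal TF version of such a branch might look like
# this (layer names and sizes are made up for illustration):
def _example_tf_action_model(context_size=16, num_a1=2, num_a2=2):
    ctx_in = tf.keras.layers.Input(shape=(context_size, ), name="context")
    a1_in = tf.keras.layers.Input(shape=(1, ), name="a1_vec")
    # The a1 head sees only the context.
    a1_logits = tf.keras.layers.Dense(num_a1, name="a1_logits")(ctx_in)
    # The a2 head sees the context AND the chosen a1.
    a2_hidden = tf.keras.layers.Concatenate()([ctx_in, a1_in])
    a2_logits = tf.keras.layers.Dense(num_a2, name="a2_logits")(a2_hidden)
    return tf.keras.Model([ctx_in, a1_in], [a1_logits, a2_logits])
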
class BinaryAutoregressiveDistribution(ActionDistribution):
    """Action distribution P(a1, a2) = P(a1) * P(a2 | a1)"""

    def deterministic_sample(self):
        # First, deterministically pick a1 (argmax over the a1 logits).
        a1_dist = self._a1_distribution()
        a1 = a1_dist.deterministic_sample()

        # Then pick a2, conditioned on the chosen a1.
        a2_dist = self._a2_distribution(a1)
        a2 = a2_dist.deterministic_sample()
        # Cache the joint log-prob for `sampled_action_logp`.
        self._action_logp = a1_dist.logp(a1) + a2_dist.logp(a2)

        # Return the action tuple.
        return (a1, a2)

    def sample(self):
        # First, sample a1.
        a1_dist = self._a1_distribution()
        a1 = a1_dist.sample()

        # Sample a2 conditioned on a1.
        a2_dist = self._a2_distribution(a1)
        a2 = a2_dist.sample()
        self._action_logp = a1_dist.logp(a1) + a2_dist.logp(a2)

        # Return the action tuple.
        return (a1, a2)

    def logp(self, actions):
        # Joint log-prob via the chain rule:
        # logp(a1, a2) = logp(a1) + logp(a2 | a1).
        a1, a2 = actions[:, 0], actions[:, 1]
        a1_vec = tf.expand_dims(tf.cast(a1, tf.float32), 1)
        a1_logits, a2_logits = self.model.action_model([self.inputs, a1_vec])
        return (
            Categorical(a1_logits).logp(a1) + Categorical(a2_logits).logp(a2))

    def sampled_action_logp(self):
        # Return the log-prob cached by `sample()`/`deterministic_sample()`
        # (not its exp, which would be a probability, not a log-prob).
        return self._action_logp

    def entropy(self):
        # Approximate the joint entropy, conditioning the a2 term on a
        # single sampled a1.
        a1_dist = self._a1_distribution()
        a2_dist = self._a2_distribution(a1_dist.sample())
        return a1_dist.entropy() + a2_dist.entropy()

    def kl(self, other):
        a1_dist = self._a1_distribution()
        a1_terms = a1_dist.kl(other._a1_distribution())

        # Condition both a2 distributions on the same sampled a1.
        a1 = a1_dist.sample()
        a2_terms = self._a2_distribution(a1).kl(other._a2_distribution(a1))
        return a1_terms + a2_terms

    def _a1_distribution(self):
        BATCH = tf.shape(self.inputs)[0]
        # a1 does not depend on a2, so feed a dummy (zero) a1 input.
        a1_logits, _ = self.model.action_model(
            [self.inputs, tf.zeros((BATCH, 1))])
        a1_dist = Categorical(a1_logits)
        return a1_dist

    def _a2_distribution(self, a1):
        # Feed the chosen a1 back into the model to get the a2 logits.
        a1_vec = tf.expand_dims(tf.cast(a1, tf.float32), 1)
        _, a2_logits = self.model.action_model([self.inputs, a1_vec])
        a2_dist = Categorical(a2_logits)
        return a2_dist

    @staticmethod
    def required_model_output_shape(action_space, model_config):
        return 16  # controls model output feature vector size

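# --- Illustrative usage sketch (an assumption, not part of the original
# file). One way to plug the TF distribution into RLlib is to register it
# under a name and reference it from the model config; "autoreg_model" is a
# made-up name for a companion custom model providing the `action_model`
# branch sketched above (the torch class is registered analogously):
#
#   from ray.rllib.models import ModelCatalog
#   ModelCatalog.register_custom_action_dist(
#       "binary_autoreg_dist", BinaryAutoregressiveDistribution)
#   config = {
#       "model": {
#           "custom_model": "autoreg_model",
#           "custom_action_dist": "binary_autoreg_dist",
#       },
#   }
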
class TorchBinaryAutoregressiveDistribution(TorchDistributionWrapper):
    """Action distribution P(a1, a2) = P(a1) * P(a2 | a1)"""

    def deterministic_sample(self):
        # First, deterministically pick a1 (argmax over the a1 logits).
        a1_dist = self._a1_distribution()
        a1 = a1_dist.deterministic_sample()

        # Then pick a2, conditioned on the chosen a1.
        a2_dist = self._a2_distribution(a1)
        a2 = a2_dist.deterministic_sample()
        # Cache the joint log-prob for `sampled_action_logp`.
        self._action_logp = a1_dist.logp(a1) + a2_dist.logp(a2)

        # Return the action tuple.
        return (a1, a2)

    def sample(self):
        # First, sample a1.
        a1_dist = self._a1_distribution()
        a1 = a1_dist.sample()

        # Sample a2 conditioned on a1.
        a2_dist = self._a2_distribution(a1)
        a2 = a2_dist.sample()
        self._action_logp = a1_dist.logp(a1) + a2_dist.logp(a2)

        # Return the action tuple.
        return (a1, a2)

    def logp(self, actions):
        # Joint log-prob via the chain rule:
        # logp(a1, a2) = logp(a1) + logp(a2 | a1).
        a1, a2 = actions[:, 0], actions[:, 1]
        a1_vec = torch.unsqueeze(a1.float(), 1)
        a1_logits, a2_logits = self.model.action_module(self.inputs, a1_vec)
        return (TorchCategorical(a1_logits).logp(a1) +
                TorchCategorical(a2_logits).logp(a2))

    def sampled_action_logp(self):
        # Return the log-prob cached by `sample()`/`deterministic_sample()`
        # (not its exp, which would be a probability, not a log-prob).
        return self._action_logp

    def entropy(self):
        # Approximate the joint entropy, conditioning the a2 term on a
        # single sampled a1.
        a1_dist = self._a1_distribution()
        a2_dist = self._a2_distribution(a1_dist.sample())
        return a1_dist.entropy() + a2_dist.entropy()

    def kl(self, other):
        a1_dist = self._a1_distribution()
        a1_terms = a1_dist.kl(other._a1_distribution())

        # Condition both a2 distributions on the same sampled a1.
        a1 = a1_dist.sample()
        a2_terms = self._a2_distribution(a1).kl(other._a2_distribution(a1))
        return a1_terms + a2_terms

    def _a1_distribution(self):
        BATCH = self.inputs.shape[0]
        # a1 does not depend on a2, so feed a dummy (zero) a1 input.
        zeros = torch.zeros((BATCH, 1)).to(self.inputs.device)
        a1_logits, _ = self.model.action_module(self.inputs, zeros)
        a1_dist = TorchCategorical(a1_logits)
        return a1_dist

    def _a2_distribution(self, a1):
        # Feed the chosen a1 back into the model to get the a2 logits.
        a1_vec = torch.unsqueeze(a1.float(), 1)
        _, a2_logits = self.model.action_module(self.inputs, a1_vec)
        a2_dist = TorchCategorical(a2_logits)
        return a2_dist

    @staticmethod
    def required_model_output_shape(action_space, model_config):
        return 16  # controls model output feature vector size

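# --- Illustrative sketch (an assumption, not part of the original file). ---
# A minimal torch counterpart of the `action_module` branch assumed by
# TorchBinaryAutoregressiveDistribution; sizes and names are made up:
#
#   class ExampleTorchActionModule(nn.Module):
#       def __init__(self, context_size=16, num_a1=2, num_a2=2):
#           super().__init__()
#           self.a1_head = nn.Linear(context_size, num_a1)
#           self.a2_head = nn.Linear(context_size + 1, num_a2)
#
#       def forward(self, context, a1_vec):
#           # a1 logits from the context alone; a2 logits additionally
#           # conditioned on the chosen a1.
#           a1_logits = self.a1_head(context)
#           a2_logits = self.a2_head(torch.cat([context, a1_vec], dim=1))
#           return a1_logits, a2_logits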