ray/rllib/models/jax/jax_action_dist.py
Balaji Veeramani 7f1bacc7dc
[CI] Format Python code with Black (#21975)
See #21316 and #21311 for the motivation behind these changes.
2022-01-29 18:41:57 -08:00

70 lines
2.4 KiB
Python

import time
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_jax, try_import_tfp
from ray.rllib.utils.typing import TensorType, List
# Obtain the `jax` and `flax` modules through RLlib's import helper rather
# than a plain `import`.
# NOTE(review): presumably both are None when JAX is not installed -- confirm
# against `try_import_jax`.
jax, flax = try_import_jax()
# Same pattern for TensorFlow Probability, whose JAX substrate provides the
# distribution objects used below.
tfp = try_import_tfp()
class JAXDistribution(ActionDistribution):
    """Wrapper class for JAX distributions.

    The methods below forward to `self.dist`, which this base class never
    assigns: concrete subclasses must set it to an object exposing
    `log_prob()`, `entropy()` and `kl_divergence()`.
    """

    @override(ActionDistribution)
    def __init__(self, inputs: List[TensorType], model: ModelV2):
        """Initialize the distribution and its sampling state.

        Args:
            inputs: Passed through to `ActionDistribution.__init__`
                (presumably stored there as `self.inputs`, which `sample()`
                reads as logits -- confirm against the base class).
            model: Passed through to `ActionDistribution.__init__`.
        """
        super().__init__(inputs, model)
        # Store the last sample here.
        self.last_sample = None
        # Seed the PRNG from the wall clock at microsecond resolution.
        # Whole seconds are too coarse: every distribution created within the
        # same second would get the same key and hence identical sampling
        # noise. Masking to 31 bits keeps the seed inside the signed 32-bit
        # range regardless of JAX's x64 setting.
        seed = int(time.time() * 1e6) & 0x7FFFFFFF
        self.prng_key = jax.random.PRNGKey(seed=seed)

    @override(ActionDistribution)
    def logp(self, actions: TensorType) -> TensorType:
        """Return `self.dist.log_prob(actions)`."""
        return self.dist.log_prob(actions)

    @override(ActionDistribution)
    def entropy(self) -> TensorType:
        """Return `self.dist.entropy()`."""
        return self.dist.entropy()

    @override(ActionDistribution)
    def kl(self, other: ActionDistribution) -> TensorType:
        """Return the KL divergence from this distribution to `other`.

        Args:
            other: Another distribution wrapper; only its `dist` attribute
                is read.
        """
        return self.dist.kl_divergence(other.dist)

    @override(ActionDistribution)
    def sample(self) -> TensorType:
        """Draw a categorical sample from `self.inputs` and remember it.

        Returns:
            The drawn sample, also stored in `self.last_sample`.
        """
        # Split the carried key into (next state, one-off subkey). A JAX key
        # must be consumed exactly once, so the key used for this draw is
        # never the one that gets split again on the next call.
        self.prng_key, subkey = jax.random.split(self.prng_key)
        self.last_sample = jax.random.categorical(subkey, self.inputs)
        return self.last_sample

    @override(ActionDistribution)
    def sampled_action_logp(self) -> TensorType:
        """Return the log-probability of the most recent sample.

        Requires that `sample()` (or a subclass method that sets
        `self.last_sample`) has been called first.
        """
        assert self.last_sample is not None
        return self.logp(self.last_sample)
class JAXCategorical(JAXDistribution):
    """Wrapper class for a JAX Categorical distribution."""

    @override(ActionDistribution)
    def __init__(self, inputs, model=None, temperature=1.0):
        """Build a Categorical distribution from (optionally tempered) logits.

        Args:
            inputs: Logits of the categorical distribution.
            model: Passed through to the parent constructor.
            temperature: Divisor applied to the logits; must be > 0.0.
                A value of 1.0 leaves the logits untouched.
        """
        if temperature != 1.0:
            assert temperature > 0.0, "Categorical `temperature` must be > 0.0!"
            # Divide out-of-place: an augmented `/=` would modify the caller's
            # buffer for mutable array types (e.g. NumPy) and fails outright
            # on integer arrays.
            inputs = inputs / temperature
        super().__init__(inputs, model)
        self.dist = tfp.experimental.substrates.jax.distributions.Categorical(
            logits=self.inputs
        )

    @override(ActionDistribution)
    def deterministic_sample(self):
        """Return the highest-logit category and remember it.

        Returns:
            The argmax over the last (category) axis of `self.inputs`, also
            stored in `self.last_sample`.
        """
        # Reduce over the last axis -- the same axis `jax.random.categorical`
        # samples over by default -- so greedy and stochastic sampling agree
        # for inputs of any rank (identical to axis=1 for 2-D inputs).
        self.last_sample = self.inputs.argmax(axis=-1)
        return self.last_sample

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(action_space, model_config):
        """Return the number of model outputs needed: `action_space.n`.

        Args:
            action_space: Space object; only its `n` attribute is read.
            model_config: Unused here.
        """
        return action_space.n