ray/rllib/utils/spaces/simplex.py
Balaji Veeramani 7f1bacc7dc
[CI] Format Python code with Black (#21975)
See #21316 and #21311 for the motivation behind these changes.
2022-01-29 18:41:57 -08:00

57 lines
1.7 KiB
Python

import numpy as np
import gym
from ray.rllib.utils.annotations import PublicAPI
@PublicAPI
class Simplex(gym.Space):
    """Represents a d - 1 dimensional Simplex in R^d.

    That is, all coordinates are in [0, 1] and sum to 1.
    The dimension d of the simplex is assumed to be shape[-1].

    Additionally one can specify the underlying distribution of
    the simplex as a Dirichlet distribution by providing concentration
    parameters. By default, sampling is uniform, i.e. concentration is
    all 1s. A single concentration vector of length d is shared by all
    independent simplices described by the leading dimensions of `shape`.

    Example usage:
    self.action_space = spaces.Simplex(shape=(3, 4))
        --> 3 independent 4d Dirichlet with uniform concentration
    """

    def __init__(self, shape, concentration=None, dtype=np.float32):
        """Initializes a Simplex space.

        Args:
            shape: Tuple or list giving the space's shape. The last entry
                is the simplex dimension d; leading entries describe the
                batch of independent simplices.
            concentration: Optional array-like of d strictly positive
                Dirichlet concentration parameters. Defaults to all 1s
                (uniform sampling over the simplex).
            dtype: The dtype of sampled points.

        Raises:
            ValueError: If `concentration` does not have shape (d,) or
                contains non-positive values.
        """
        assert type(shape) in [tuple, list]

        super().__init__(shape, dtype)
        self.dim = self.shape[-1]

        if concentration is not None:
            concentration = np.asarray(concentration)
            # np.random.dirichlet takes one 1-D alpha vector of length d;
            # the leading dims of `shape` are passed as `size` in sample().
            if concentration.shape != (self.dim,):
                raise ValueError(
                    "Simplex concentration must have shape ({},), got {}.".format(
                        self.dim, concentration.shape
                    )
                )
            if np.any(concentration <= 0):
                raise ValueError(
                    "Simplex concentration values must be > 0, got {}.".format(
                        concentration
                    )
                )
            self.concentration = concentration
        else:
            self.concentration = [1] * self.dim

    def sample(self):
        """Draws Dirichlet samples of shape `self.shape` cast to `self.dtype`."""
        # NOTE(review): this uses NumPy's global RNG rather than the space's
        # own RNG, so seeding the space itself may not affect sampling —
        # confirm whether that is intended.
        return np.random.dirichlet(self.concentration, size=self.shape[:-1]).astype(
            self.dtype
        )

    def contains(self, x):
        """Returns True iff `x` is a valid point of this space.

        `x` must have exactly `self.shape`, all coordinates must be
        non-negative, and every slice along the last axis must sum to 1
        (within `np.allclose` default tolerance).
        """
        if not isinstance(x, np.ndarray):
            try:
                x = np.asarray(x)
            except (TypeError, ValueError):
                # Not convertible to an array (e.g. ragged nesting).
                return False
        return bool(
            x.shape == self.shape
            and np.all(x >= 0)
            and np.allclose(np.sum(x, axis=-1), 1.0)
        )

    def to_jsonable(self, sample_n):
        """Converts a batch of samples into nested Python lists."""
        return np.array(sample_n).tolist()

    def from_jsonable(self, sample_n):
        """Converts nested lists back into a list of numpy arrays."""
        return [np.asarray(sample) for sample in sample_n]

    def __repr__(self):
        return "Simplex({}; {})".format(self.shape, self.concentration)

    def __eq__(self, other):
        # Check the type and shape first: a non-Simplex has no concentration,
        # and mismatched concentration lengths would fail to broadcast.
        return (
            isinstance(other, Simplex)
            and self.shape == other.shape
            and np.allclose(self.concentration, other.concentration)
        )