mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
61 lines
1.8 KiB
Python
61 lines
1.8 KiB
Python
import gym
|
|
import numpy as np
|
|
|
|
from ray.rllib.utils.annotations import PublicAPI
|
|
|
|
|
|
@PublicAPI
class Simplex(gym.Space):
    """Represents a d - 1 dimensional Simplex in R^d.

    That is, all coordinates are in [0, 1] and sum to 1.
    The dimension d of the simplex is assumed to be shape[-1].

    Additionally one can specify the underlying distribution of
    the simplex as a Dirichlet distribution by providing concentration
    parameters. By default, sampling is uniform, i.e. concentration is
    all 1s.

    Example usage:
    self.action_space = spaces.Simplex(shape=(3, 4))
        --> 3 independent 4d Dirichlet with uniform concentration

    Args:
        shape: Tuple or list giving the shape of the space. The last entry
            is the simplex dimension d; any leading entries index
            independent simplex points.
        concentration: Optional 1-D array-like of length d holding the
            Dirichlet concentration parameters. The same vector is used
            for every simplex point in the space. Defaults to all 1s
            (uniform sampling).
        dtype: NumPy dtype that sampled points are cast to.
    """

    def __init__(self, shape, concentration=None, dtype=np.float32):
        assert isinstance(
            shape, (tuple, list)
        ), f"shape must be a tuple or list, got {type(shape)}"

        super().__init__(shape, dtype)
        self.dim = self.shape[-1]

        if concentration is not None:
            # `np.random.dirichlet` takes a single 1-D alpha vector of
            # length d and repeats the draw over the leading dims via
            # `size`, so the concentration must have shape (d,).
            concentration = np.asarray(concentration)
            assert concentration.shape == (
                self.dim,
            ), f"{concentration.shape} vs {(self.dim,)}"
            self.concentration = concentration
        else:
            self.concentration = np.array([1] * self.dim)

    def sample(self):
        """Draw a random point from the space.

        Returns:
            An array of shape `self.shape` and dtype `self.dtype`. Each
            slice along the last axis is an independent Dirichlet draw
            parameterized by `self.concentration`.
        """
        # NOTE(review): this draws from NumPy's global random state rather
        # than a per-space RNG, so per-space seeding presumably has no
        # effect here -- confirm before relying on reproducibility.
        return np.random.dirichlet(self.concentration, size=self.shape[:-1]).astype(
            self.dtype
        )

    def contains(self, x):
        """Return True if `x` is a valid point of this space.

        `x` (converted with `np.asarray`) must have exactly `self.shape`,
        have no negative entries, and each slice along the last axis must
        sum to 1 (within `np.allclose` default tolerances).
        """
        x = np.asarray(x)
        if x.shape != self.shape:
            return False
        # Non-negativity plus unit sum implies every coordinate is in [0, 1].
        if not np.all(x >= 0):
            return False
        return bool(np.allclose(np.sum(x, axis=-1), np.ones_like(x[..., 0])))

    def to_jsonable(self, sample_n):
        """Convert a batch of samples into nested Python lists."""
        return np.array(sample_n).tolist()

    def from_jsonable(self, sample_n):
        """Convert nested lists back into a list of NumPy arrays."""
        return [np.asarray(sample) for sample in sample_n]

    def __repr__(self):
        return "Simplex({}; {})".format(self.shape, self.concentration)

    def __eq__(self, other):
        """Two Simplex spaces are equal if their shapes and concentrations match."""
        if not isinstance(other, Simplex):
            # Let Python fall back to its default comparison instead of
            # raising AttributeError on `other.concentration`.
            return NotImplemented
        # Compare shapes first so `np.allclose` only ever sees
        # concentrations of the same length (d).
        return self.shape == other.shape and np.allclose(
            self.concentration, other.concentration
        )