ray/rllib/utils/spaces/repeated.py
Balaji Veeramani 7f1bacc7dc
[CI] Format Python code with Black (#21975)
See #21316 and #21311 for the motivation behind these changes.
2022-01-29 18:41:57 -08:00

38 lines
1.1 KiB
Python

import gym
import numpy as np
from ray.rllib.utils.annotations import PublicAPI
@PublicAPI
class Repeated(gym.Space):
    """Represents a variable-length list of child spaces.

    A value of this space is a list (or a non-scalar numpy array) holding
    between 0 and ``max_len`` elements, each contained in ``child_space``.

    Example:
        self.observation_space = Repeated(
            gym.spaces.Box(low=-1.0, high=1.0, shape=(4,)), max_len=10)
        --> from 0 to 10 boxes of shape (4,)

    Note: ``sample()`` draws between 1 and ``max_len`` elements (never an
    empty list unless ``max_len < 1``), although ``contains()`` also accepts
    empty lists.

    See also: documentation for rllib.models.RepeatedValues, which shows how
    the lists are represented as batched input for ModelV2 classes.

    Args:
        child_space: The space every list element must belong to.
        max_len: Maximum number of elements a value may hold.
    """

    def __init__(self, child_space: gym.Space, max_len: int):
        super().__init__()
        self.child_space = child_space
        self.max_len = max_len

    def sample(self) -> list:
        """Draw a random list of samples from ``child_space``.

        The list length is drawn uniformly from ``[1, max_len]``. A space
        with ``max_len < 1`` can only hold the empty list, so ``[]`` is
        returned in that case.
        """
        if self.max_len < 1:
            # Drawing from [1, max_len + 1) would be an empty range and raise.
            return []
        rng = self.np_random
        # Depending on the gym/numpy version, ``np_random`` may be a
        # ``numpy.random.Generator`` (exposes ``integers``, no ``randint``)
        # or a legacy ``RandomState`` (exposes ``randint``). Both treat the
        # upper bound as exclusive, so the drawn range is [1, max_len].
        draw_int = getattr(rng, "integers", None) or rng.randint
        num_items = draw_int(1, self.max_len + 1)
        return [self.child_space.sample() for _ in range(num_items)]

    def contains(self, x) -> bool:
        """Return True if ``x`` is a valid value of this space.

        ``x`` must be a list or a non-scalar numpy array with at most
        ``max_len`` elements, each of which is contained in ``child_space``.
        """
        if isinstance(x, np.ndarray):
            # A 0-d array has no length; report "not contained" instead of
            # letting len() raise a TypeError.
            if x.ndim == 0:
                return False
        elif not isinstance(x, list):
            return False
        return len(x) <= self.max_len and all(
            self.child_space.contains(c) for c in x
        )

    def __repr__(self):
        return "Repeated({}, {})".format(self.child_space, self.max_len)