ray/rllib/utils/spaces/flexdict.py
Balaji Veeramani 7f1bacc7dc
[CI] Format Python code with Black (#21975)
See #21316 and #21311 for the motivation behind these changes.
2022-01-29 18:41:57 -08:00

52 lines
1.4 KiB
Python

import gym
from ray.rllib.utils.annotations import PublicAPI
@PublicAPI
class FlexDict(gym.spaces.Dict):
    """Gym Dictionary with arbitrary keys updatable after instantiation.

    Sub-spaces can be added or replaced through item assignment once the
    space has been created; every value is checked to be a gym.Space.

    Example:
        space = FlexDict({})
        space["key"] = gym.spaces.Box(low=0.0, high=1.0, shape=(4,))

    See also: documentation for gym.spaces.Dict
    """

    def __init__(self, spaces=None, **spaces_kwargs):
        """Create the space from a mapping or from keyword arguments.

        Args:
            spaces: Optional mapping from key to gym.Space. It is stored
                as-is (not copied), so later item assignment on this object
                also mutates the mapping that was passed in.
            **spaces_kwargs: Alternative way to pass the sub-spaces, one
                keyword per key. Only used when ``spaces`` is None.

        Raises:
            AssertionError: If both ``spaces`` and keyword arguments are
                given, or if any value is not a gym.Space instance.
        """
        err = "Use either Dict(spaces=dict(...)) or Dict(foo=x, bar=z)"
        # Explicit raise instead of `assert` so the check survives
        # `python -O`; AssertionError is kept for backward compatibility.
        if spaces is not None and spaces_kwargs:
            raise AssertionError(err)

        if spaces is None:
            spaces = spaces_kwargs

        self.spaces = spaces
        for space in spaces.values():
            self.assertSpace(space)

        # The parent constructor is intentionally not invoked here; the
        # attributes below are set by hand instead.
        # None for shape and dtype, since it'll require special handling
        self.np_random = None
        self.shape = None
        self.dtype = None
        # seed() is not defined in this class; presumably inherited from the
        # gym base space and expected to initialise np_random -- confirm
        # against the installed gym version.
        self.seed()

    def assertSpace(self, space):
        """Check that ``space`` is a gym.Space instance.

        Raises:
            AssertionError: If ``space`` is not an instance of
                gym.spaces.Space.
        """
        err = "Values of the dict should be instances of gym.Space"
        # Explicit raise (not `assert`) so validation is not stripped by -O.
        if not isinstance(space, gym.spaces.Space):
            raise AssertionError(err)

    def sample(self):
        """Return a dict holding one sample from each sub-space, by key."""
        return {k: space.sample() for k, space in self.spaces.items()}

    def __getitem__(self, key):
        """Return the sub-space stored under ``key``."""
        return self.spaces[key]

    def __setitem__(self, key, space):
        """Validate ``space`` and store it under ``key``.

        Raises:
            AssertionError: If ``space`` is not a gym.Space instance.
        """
        self.assertSpace(space)
        self.spaces[key] = space

    def __repr__(self):
        """Return ``FlexDict(key:space, ...)`` listing all sub-spaces."""
        entries = ", ".join(f"{k}:{s}" for k, s in self.spaces.items())
        return "FlexDict(" + entries + ")"