mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
52 lines
1.4 KiB
Python
52 lines
1.4 KiB
Python
import gym
|
|
|
|
from ray.rllib.utils.annotations import PublicAPI
|
|
|
|
|
|
@PublicAPI
class FlexDict(gym.spaces.Dict):
    """Gym Dictionary with arbitrary keys updatable after instantiation.

    Unlike a plain ``gym.spaces.Dict``, sub-spaces can be added or replaced
    after construction through item assignment. Every value stored (at
    construction or via ``space[key] = ...``) is checked to be an instance
    of ``gym.spaces.Space``.

    Example:
        space = FlexDict({})
        space['key'] = spaces.Box(low=-1.0, high=1.0, shape=(4,))

    See also: documentation for gym.spaces.Dict
    """

    def __init__(self, spaces=None, **spaces_kwargs):
        """Create the space from a mapping or from keyword arguments.

        Note: ``gym.spaces.Dict.__init__`` is intentionally not called; the
        attributes it would set are assigned directly below.

        Args:
            spaces: Optional mapping of key -> gym.spaces.Space. The mapping
                is stored by reference (not copied), so later item
                assignment on this space also mutates the object passed in.
            **spaces_kwargs: Alternative way to pass sub-spaces as
                ``key=space`` pairs; only used when ``spaces`` is None.

        Raises:
            AssertionError: If both ``spaces`` and keyword sub-spaces are
                given, or if any value is not a gym.spaces.Space.
        """
        err = "Use either Dict(spaces=dict(...)) or Dict(foo=x, bar=z)"
        # Explicit raise instead of ``assert`` so the check is not stripped
        # when Python runs with -O; the exception type is unchanged.
        if spaces is not None and spaces_kwargs:
            raise AssertionError(err)

        if spaces is None:
            spaces = spaces_kwargs

        self.spaces = spaces
        for space in spaces.values():
            self.assertSpace(space)

        # None for shape and dtype, since it'll require special handling
        self.np_random = None
        self.shape = None
        self.dtype = None
        self.seed()

    def assertSpace(self, space):
        """Validate that ``space`` is a gym.spaces.Space.

        Raises:
            AssertionError: If ``space`` is not a gym.spaces.Space instance.
                Raised explicitly so the check also runs under ``python -O``.
        """
        err = "Values of the dict should be instances of gym.Space"
        if not isinstance(space, gym.spaces.Space):
            raise AssertionError(err)

    def sample(self):
        """Return a plain dict mapping each key to a sample of its sub-space."""
        return {k: space.sample() for k, space in self.spaces.items()}

    def __getitem__(self, key):
        """Return the sub-space stored under ``key`` (KeyError if absent)."""
        return self.spaces[key]

    def __setitem__(self, key, space):
        """Add or replace the sub-space under ``key`` after validating it."""
        self.assertSpace(space)
        self.spaces[key] = space

    def __repr__(self):
        """Render as ``FlexDict(key1:space1, key2:space2, ...)``."""
        return (
            "FlexDict("
            + ", ".join([str(k) + ":" + str(s) for k, s in self.spaces.items()])
            + ")"
        )
|