ray/rllib/utils/spaces/space_utils.py
2020-05-27 10:21:30 +02:00

130 lines
4 KiB
Python

from gym.spaces import Tuple, Dict
import numpy as np
from ray.rllib.utils import try_import_tree
tree = try_import_tree()
def flatten_space(space):
    """Breaks a (possibly nested) gym.Space down into its leaf Spaces.

    Every Space that is neither a Tuple nor a Dict counts as a leaf
    ("primitive") and is emitted as-is. Tuple children are visited in
    iteration order, Dict children in the key order of `Dict.spaces`.

    Args:
        space (gym.Space): The Space to take apart. Arbitrarily nested
            Tuples/Dicts are supported; a primitive Space is valid, too.

    Returns:
        List[gym.Space]: All leaf Spaces found in `space`, in depth-first
            order. The list holds no Tuple or Dict Spaces.
    """

    def _iter_leaves(node):
        # Depth-first walk: containers get expanded, leaves get yielded.
        if isinstance(node, Tuple):
            for child in node:
                yield from _iter_leaves(child)
        elif isinstance(node, Dict):
            for key in node.spaces:
                yield from _iter_leaves(node[key])
        else:
            yield node

    return list(_iter_leaves(space))
def get_base_struct_from_space(space):
    """Mirrors a Tuple/Dict Space as a native python tuple/dict structure.

    Tuple Spaces become python tuples and Dict Spaces become python dicts
    (same keys); this is applied recursively to all nested Tuples/Dicts.
    Any other Space is kept as-is.

    Args:
        space (gym.Space): The Space whose structure should be mirrored.

    Returns:
        Union[dict,tuple,gym.Space]: A python struct shaped like `space`.
            Its leaves are the original "primitive" Space objects
            (e.g. Box, Discrete), not copies or samples.

    Examples:
        >>> get_base_struct_from_space(Dict({
        >>>     "a": Box(),
        >>>     "b": Tuple([Discrete(2), Discrete(3)])
        >>> }))
        >>> # Will return: {"a": Box(), "b": (Discrete(2), Discrete(3))}
    """
    if isinstance(space, Tuple):
        return tuple(get_base_struct_from_space(sub) for sub in space)
    if isinstance(space, Dict):
        return {
            key: get_base_struct_from_space(space[key])
            for key in space.spaces
        }
    # Primitive Space: nothing to unwrap.
    return space
def flatten_to_single_ndarray(input_):
    """Merges a (nested) list/tuple/dict of arrays into one 1D np.ndarray.

    Args:
        input_ (Union[List[np.ndarray],np.ndarray]): Either a list, tuple or
            dict struct holding ndarrays, or anything else (e.g. a single
            ndarray).

    Returns:
        np.ndarray: For list/tuple/dict input, the 1D concatenation of all
            (individually flattened) leaves of `input_`. Any other input is
            handed back unchanged.

    Examples:
        >>> flatten_to_single_ndarray([
        >>>     np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
        >>>     np.array([7, 8, 9]),
        >>> ])
        >>> # Will return:
        >>> # np.array([
        >>> #     1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0
        >>> # ])
    """
    # Only list/tuple/dict structs get merged; everything else passes
    # through untouched.
    if not isinstance(input_, (list, tuple, dict)):
        return input_

    # Make every leaf of the struct 1D, then glue the pieces together.
    flat_parts = [np.reshape(leaf, [-1]) for leaf in tree.flatten(input_)]
    return np.concatenate(flat_parts, axis=0).flatten()
def unbatch(batches_struct):
    """Turns a (nested) struct of batches into a list of structs.

    Input: Struct of different batches (each batch has size=3):
        {"a": [1, 2, 3], "b": ([4, 5, 6], [7.0, 8.0, 9.0])}
    Output: Batch (list) of structs (each of these structs representing a
        single action):
        [
            {"a": 1, "b": (4, 7.0)},  <- action 1
            {"a": 2, "b": (5, 8.0)},  <- action 2
            {"a": 3, "b": (6, 9.0)},  <- action 3
        ]

    Args:
        batches_struct (any): The struct of component batches. Each leaf
            of this struct is the batch of one single component (if the
            struct is a tuple/dict). `batches_struct` may also simply be a
            batch of primitives (non tuple/dict).

    Returns:
        List[struct[components]]: One item per row. Every item has the
            same structure as `batches_struct`, with each leaf batch
            replaced by that batch's entry for the row.
    """
    leaf_batches = tree.flatten(batches_struct)
    # The first leaf batch dictates how many rows are produced.
    num_rows = len(leaf_batches[0])
    return [
        tree.unflatten_as(
            batches_struct, [leaf[row] for leaf in leaf_batches])
        for row in range(num_rows)
    ]