ray/rllib/utils/spaces/space_utils.py

345 lines
12 KiB
Python

import gym
from gym.spaces import Tuple, Dict
import numpy as np
[RLlib] Upgrade gym version to 0.21 and deprecate pendulum-v0. (#19535) * Fix QMix, SAC, and MADDPA too. * Unpin gym and deprecate pendulum v0 Many tests in rllib depended on pendulum v0, however in gym 0.21, pendulum v0 was deprecated in favor of pendulum v1. This may change reward thresholds, so will have to potentially rerun all of the pendulum v1 benchmarks, or use another environment in favor. The same applies to frozen lake v0 and frozen lake v1 Lastly, all of the RLlib tests and have been moved to python 3.7 * Add gym installation based on python version. Pin python<= 3.6 to gym 0.19 due to install issues with atari roms in gym 0.20 * Reformatting * Fixing tests * Move atari-py install conditional to req.txt * migrate to new ale install method * Fix QMix, SAC, and MADDPA too. * Unpin gym and deprecate pendulum v0 Many tests in rllib depended on pendulum v0, however in gym 0.21, pendulum v0 was deprecated in favor of pendulum v1. This may change reward thresholds, so will have to potentially rerun all of the pendulum v1 benchmarks, or use another environment in favor. The same applies to frozen lake v0 and frozen lake v1 Lastly, all of the RLlib tests and have been moved to python 3.7 * Add gym installation based on python version. Pin python<= 3.6 to gym 0.19 due to install issues with atari roms in gym 0.20 Move atari-py install conditional to req.txt migrate to new ale install method Make parametric_actions_cartpole return float32 actions/obs Adding type conversions if obs/actions don't match space Add utils to make elements match gym space dtypes Co-authored-by: Jun Gong <jungong@anyscale.com> Co-authored-by: sven1977 <svenmika1977@gmail.com>
2021-11-03 08:24:00 -07:00
from ray.rllib.utils.annotations import DeveloperAPI
import tree # pip install dm_tree
from typing import Any, List, Optional, Union


def flatten_space(space: gym.Space) -> List[gym.Space]:
    """Flattens a gym.Space into its primitive components.

    Primitive components are any non Tuple/Dict spaces.

    Args:
        space (gym.Space): The gym.Space to flatten. This may be any
            supported type (including nested Tuples and Dicts).

    Returns:
        List[gym.Space]: The flattened list of primitive Spaces. This list
            does not contain Tuples or Dicts anymore.
    """

    def _helper_flatten(space_, return_list):
        from ray.rllib.utils.spaces.flexdict import FlexDict

        if isinstance(space_, Tuple):
            for s in space_:
                _helper_flatten(s, return_list)
        elif isinstance(space_, (Dict, FlexDict)):
            # Visit Dict keys in sorted order so the result is deterministic.
            for k in sorted(space_.spaces):
                _helper_flatten(space_[k], return_list)
        else:
            return_list.append(space_)

    ret = []
    _helper_flatten(space, ret)
    return ret
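
# Illustrative sketch (not part of the original module): the traversal order
# used by `flatten_space`, shown on plain Python tuples/dicts so that it runs
# without gym installed. `flatten_struct` and the string leaves are
# hypothetical stand-ins for real Space objects.

```python
def flatten_struct(struct):
    """Depth-first flatten of nested tuples/dicts (dict keys in sorted order)."""
    leaves = []

    def _walk(node):
        if isinstance(node, tuple):
            for child in node:
                _walk(child)
        elif isinstance(node, dict):
            for key in sorted(node):
                _walk(node[key])
        else:
            # Anything that is not a tuple/dict counts as a primitive leaf.
            leaves.append(node)

    _walk(struct)
    return leaves


# Strings stand in for primitive spaces such as Box or Discrete.
nested = {"b": ("Discrete(2)", "Discrete(3)"), "a": "Box"}
flat = flatten_struct(nested)
print(flat)
```

# Because keys are visited in sorted order, the leaf under "a" comes before
# the leaves under "b", regardless of insertion order.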


def get_base_struct_from_space(space):
    """Returns a Tuple/Dict Space as native (equally structured) py tuple/dict.

    Args:
        space (gym.Space): The Space to get the python struct for.

    Returns:
        Union[dict,tuple,gym.Space]: The struct equivalent to the given Space.
            Note that the returned struct still contains all original
            "primitive" Spaces (e.g. Box, Discrete).

    Examples:
        >>> get_base_struct_from_space(Dict({
        >>>     "a": Box(),
        >>>     "b": Tuple([Discrete(2), Discrete(3)])
        >>> }))
        >>> # Will return: {"a": Box(), "b": (Discrete(2), Discrete(3))}
    """

    def _helper_struct(space_):
        if isinstance(space_, Tuple):
            return tuple(_helper_struct(s) for s in space_)
        elif isinstance(space_, Dict):
            return {k: _helper_struct(space_[k]) for k in space_.spaces}
        else:
            return space_

    return _helper_struct(space)


def get_dummy_batch_for_space(
    space: gym.Space,
    batch_size: int = 32,
    fill_value: Union[float, int, str] = 0.0,
    time_size: Optional[int] = None,
    time_major: bool = False,
) -> np.ndarray:
    """Returns batched dummy data (using `batch_size`) for the given `space`.

    Note: The returned batch will not pass a `space.contains(batch)` test
    as an additional batch dimension has to be added as dim=0.

    Args:
        space (gym.Space): The space to get a dummy batch for.
        batch_size (int): The required batch size (B). Note that this can also
            be 0 (only if `time_size` is None!), which will result in a
            non-batched sample for the given space (no batch dim).
        fill_value (Union[float, int, str]): The value to fill the batch with
            or "random" for random values.
        time_size (Optional[int]): If not None, add an optional time axis
            of `time_size` size to the returned batch.
        time_major (bool): If True AND `time_size` is not None, return batch
            as shape [T x B x ...], otherwise as [B x T x ...]. If `time_size`
            is None, ignore this setting and return [B x ...].

    Returns:
        The dummy batch of size `batch_size` matching the given space.
    """
    # Complex spaces. Perform recursive calls of this function.
    if isinstance(space, (gym.spaces.Dict, gym.spaces.Tuple)):
        return tree.map_structure(
            # Forward the time settings too, so that nested components get
            # the same leading dimensions as primitive spaces do.
            lambda s: get_dummy_batch_for_space(
                s, batch_size, fill_value, time_size, time_major
            ),
            get_base_struct_from_space(space),
        )
    # Primitive spaces: Box, Discrete, MultiDiscrete.
    # Random values: Use gym's sample() method.
    elif fill_value == "random":
        if time_size is not None:
            assert batch_size > 0 and time_size > 0
            if time_major:
                return np.array(
                    [
                        [space.sample() for _ in range(batch_size)]
                        for t in range(time_size)
                    ],
                    dtype=space.dtype,
                )
            else:
                return np.array(
                    [
                        [space.sample() for t in range(time_size)]
                        for _ in range(batch_size)
                    ],
                    dtype=space.dtype,
                )
        else:
            return np.array(
                [space.sample() for _ in range(batch_size)]
                if batch_size > 0
                else space.sample(),
                dtype=space.dtype,
            )
    # Fill value given: Use np.full.
    else:
        if time_size is not None:
            assert batch_size > 0 and time_size > 0
            if time_major:
                shape = [time_size, batch_size]
            else:
                shape = [batch_size, time_size]
        else:
            shape = [batch_size] if batch_size > 0 else []
        return np.full(
            shape + list(space.shape), fill_value=fill_value, dtype=space.dtype
        )
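
# Illustrative sketch (not part of the original module): how the leading
# batch/time dimensions of a dummy batch are derived in the fill-value branch
# of `get_dummy_batch_for_space`. `dummy_batch_shape` is a hypothetical helper
# that works on plain shape tuples, so it needs neither gym nor numpy.

```python
def dummy_batch_shape(space_shape, batch_size=32, time_size=None, time_major=False):
    """Returns the full shape of a dummy batch for a space of `space_shape`."""
    if time_size is not None:
        # A time axis requires a real batch axis as well.
        assert batch_size > 0 and time_size > 0
        lead = [time_size, batch_size] if time_major else [batch_size, time_size]
    else:
        # batch_size == 0 means "no batch dimension at all".
        lead = [batch_size] if batch_size > 0 else []
    return lead + list(space_shape)


print(dummy_batch_shape((4,)))                                   # [B, ...]
print(dummy_batch_shape((4,), batch_size=8, time_size=5))        # [B, T, ...]
print(dummy_batch_shape((4,), 8, time_size=5, time_major=True))  # [T, B, ...]
print(dummy_batch_shape((4,), batch_size=0))                     # unbatched
```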


def flatten_to_single_ndarray(input_):
    """Returns a single np.ndarray given a list/tuple of np.ndarrays.

    Args:
        input_ (Union[List[np.ndarray], np.ndarray]): The list of ndarrays or
            a single ndarray.

    Returns:
        np.ndarray: The result after concatenating all single arrays in input_.

    Examples:
        >>> flatten_to_single_ndarray([
        >>>     np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
        >>>     np.array([7, 8, 9]),
        >>> ])
        >>> # Will return:
        >>> # np.array([
        >>> #     1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0
        >>> # ])
    """
    # Concatenate complex inputs.
    if isinstance(input_, (list, tuple, dict)):
        expanded = []
        for in_ in tree.flatten(input_):
            expanded.append(np.reshape(in_, [-1]))
        input_ = np.concatenate(expanded, axis=0).flatten()
    return input_


def unbatch(batches_struct):
    """Converts input from (nested) struct of batches to batch of structs.

    Input: Struct of different batches (each batch has size=3):
        {"a": [1, 2, 3], "b": ([4, 5, 6], [7.0, 8.0, 9.0])}
    Output: Batch (list) of structs (each of these structs representing a
        single action):
        [
            {"a": 1, "b": (4, 7.0)},  <- action 1
            {"a": 2, "b": (5, 8.0)},  <- action 2
            {"a": 3, "b": (6, 9.0)},  <- action 3
        ]

    Args:
        batches_struct (any): The struct of component batches. Each leaf item
            in this struct represents the batch for a single component
            (in case struct is tuple/dict).
            Alternatively, `batches_struct` may also simply be a batch of
            primitives (non tuple/dict).

    Returns:
        List[struct[components]]: The list of rows. Each item
            in the returned list represents a single (maybe complex) struct.
    """
    flat_batches = tree.flatten(batches_struct)

    out = []
    for batch_pos in range(len(flat_batches[0])):
        out.append(
            tree.unflatten_as(
                batches_struct,
                [flat_batches[i][batch_pos] for i in range(len(flat_batches))],
            )
        )
    return out
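
# Illustrative sketch (not part of the original module): the "struct of
# batches -> batch of structs" transposition that `unbatch` performs, written
# without dm_tree. `unbatch_sketch` is a hypothetical helper. Assumption: only
# dicts and tuples are containers here and Python lists act as the leaf
# batches; in real RLlib usage the leaves would typically be np.ndarrays.

```python
def unbatch_sketch(batches_struct):
    """Turns a nested dict/tuple of equally long batches into a list of rows."""

    def _batch_len(node):
        # Follow the first branch down to a leaf batch and take its length.
        if isinstance(node, dict):
            return _batch_len(next(iter(node.values())))
        if isinstance(node, tuple):
            return _batch_len(node[0])
        return len(node)

    def _row(node, i):
        # Rebuild the same structure, keeping only item `i` of every leaf.
        if isinstance(node, dict):
            return {k: _row(v, i) for k, v in node.items()}
        if isinstance(node, tuple):
            return tuple(_row(v, i) for v in node)
        return node[i]

    return [_row(batches_struct, i) for i in range(_batch_len(batches_struct))]


rows = unbatch_sketch({"a": [1, 2, 3], "b": ([4, 5, 6], [7.0, 8.0, 9.0])})
for row in rows:
    print(row)
```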


def clip_action(action, action_space):
    """Clips all components in `action` according to the given Space.

    Only applies to Box components within the action space.

    Args:
        action (Any): The action to be clipped. This could be any complex
            action, e.g. a dict or tuple.
        action_space (Any): The action space struct,
            e.g. `{"a": Discrete(2)}` for a space: Dict({"a": Discrete(2)}).

    Returns:
        Any: The input action, but clipped by value according to the space's
            bounds.
    """

    def map_(a, s):
        if isinstance(s, gym.spaces.Box):
            a = np.clip(a, s.low, s.high)
        return a

    return tree.map_structure(map_, action, action_space)


def unsquash_action(action, action_space_struct):
    """Unsquashes all components in `action` according to the given Space.

    Inverse of `normalize_action()`. Useful for mapping policy action
    outputs (normalized between -1.0 and 1.0) to an env's action space.
    Unsquashing results in continuous action component values between the
    given Space's bounds (`low` and `high`). This only applies to Box
    components within the action space whose dtype is float32 or float64
    and that are bounded both below and above.

    Args:
        action (Any): The action to be unsquashed. This could be any complex
            action, e.g. a dict or tuple.
        action_space_struct (Any): The action space struct,
            e.g. `{"a": Box()}` for a space: Dict({"a": Box()}).

    Returns:
        Any: The input action, but unsquashed, according to the space's
            bounds. An unsquashed action is ready to be sent to the
            environment (`BaseEnv.send_actions([unsquashed actions])`).
    """

    def map_(a, s):
        if (
            isinstance(s, gym.spaces.Box)
            and (s.dtype == np.float32 or s.dtype == np.float64)
            and np.all(s.bounded_below)
            and np.all(s.bounded_above)
        ):
            # Assuming values are roughly between -1.0 and 1.0 ->
            # unsquash them to the given bounds.
            a = s.low + (a + 1.0) * (s.high - s.low) / 2.0
            # Clip to given bounds, just in case the squashed values were
            # outside [-1.0, 1.0].
            a = np.clip(a, s.low, s.high)
        return a

    return tree.map_structure(map_, action, action_space_struct)


def normalize_action(action, action_space_struct):
    """Normalizes all (Box) components in `action` to be in [-1.0, 1.0].

    Inverse of `unsquash_action()`. Useful for mapping an env's action
    (arbitrary bounded values) to a [-1.0, 1.0] interval.
    This only applies to Box components within the action space, whose
    dtype is float32 or float64.

    Args:
        action (Any): The action to be normalized. This could be any complex
            action, e.g. a dict or tuple.
        action_space_struct (Any): The action space struct,
            e.g. `{"a": Box()}` for a space: Dict({"a": Box()}).

    Returns:
        Any: The input action, but normalized, according to the space's
            bounds.
    """

    def map_(a, s):
        if isinstance(s, gym.spaces.Box) and (
            s.dtype == np.float32 or s.dtype == np.float64
        ):
            # Normalize values to be exactly between -1.0 and 1.0.
            a = ((a - s.low) * 2.0) / (s.high - s.low) - 1.0
        return a

    return tree.map_structure(map_, action, action_space_struct)
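
# Illustrative sketch (not part of the original module): the scalar arithmetic
# behind `unsquash_action` and `normalize_action` for a single bounded Box
# component. `unsquash` and `normalize` are hypothetical helpers operating on
# plain floats; `low` and `high` play the role of the Box bounds.

```python
def unsquash(a, low, high):
    """Maps a value from [-1.0, 1.0] to [low, high], then clips to the bounds."""
    value = low + (a + 1.0) * (high - low) / 2.0
    return min(max(value, low), high)


def normalize(a, low, high):
    """Maps a value from [low, high] to [-1.0, 1.0] (no clipping)."""
    return ((a - low) * 2.0) / (high - low) - 1.0


low, high = 0.0, 10.0
env_action = unsquash(0.5, low, high)        # policy output -> env action
policy_action = normalize(env_action, low, high)  # and back again
clipped = unsquash(1.5, low, high)           # out-of-range input is clipped
print(env_action, policy_action, clipped)
```

# Within [-1.0, 1.0] the two mappings are inverses of each other; outside that
# range the clipping step in `unsquash` makes the round trip lossy.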


@DeveloperAPI
def convert_element_to_space_type(element: Any, sampled_element: Any) -> Any:
    """Convert all the components of the element to match the space dtypes.

    Args:
        element: The element to be converted.
        sampled_element: An element sampled from a space to be matched
            to.

    Returns:
        The input element, but with all its components converted to match
        the space dtypes.
    """

    def map_(elem, s):
        if isinstance(s, np.ndarray):
            if not isinstance(elem, np.ndarray):
                assert isinstance(
                    elem, (float, int)
                ), f"ERROR: `elem` ({elem}) must be np.array, float or int!"
                # A python scalar can only stand in for a 0-dim sample.
                if s.shape == ():
                    elem = np.array(elem, dtype=s.dtype)
                else:
                    raise ValueError(
                        "Element should be of type np.ndarray but is instead "
                        "of type {}".format(type(elem))
                    )
            elif s.dtype != elem.dtype:
                elem = elem.astype(s.dtype)

        elif isinstance(s, int):
            # Only convert floats that hold an exact integer value.
            if isinstance(elem, float) and elem.is_integer():
                elem = int(elem)
        return elem

    return tree.map_structure(map_, element, sampled_element, check_types=False)
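
# Illustrative sketch (not part of the original module): the scalar branch of
# `convert_element_to_space_type`, i.e. what happens when the sampled element
# is a plain Python int. `coerce_scalar` is a hypothetical helper that needs
# neither numpy nor dm_tree.

```python
def coerce_scalar(elem, sampled):
    """Casts `elem` to int if `sampled` is an int and no precision is lost."""
    if isinstance(sampled, int):
        if isinstance(elem, float) and elem.is_integer():
            return int(elem)
    # Everything else (e.g. 2.5, or a non-int sample) is passed through as-is.
    return elem


exact = coerce_scalar(3.0, 0)     # whole-number float becomes an int
inexact = coerce_scalar(2.5, 0)   # fractional float is left untouched
print(type(exact).__name__, exact, type(inexact).__name__, inexact)
```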