2018-11-27 23:35:19 -08:00
|
|
|
from collections import OrderedDict
|
2017-09-30 13:11:20 -07:00
|
|
|
import cv2
|
2018-10-21 23:43:57 -07:00
|
|
|
import logging
|
2017-09-30 13:11:20 -07:00
|
|
|
import numpy as np
|
2018-01-05 21:32:41 -08:00
|
|
|
import gym
|
2020-07-24 12:01:46 -07:00
|
|
|
from typing import Any, List
|
2018-01-05 21:32:41 -08:00
|
|
|
|
2019-01-23 21:27:26 -08:00
|
|
|
from ray.rllib.utils.annotations import override, PublicAPI
|
2020-06-06 03:22:19 -07:00
|
|
|
from ray.rllib.utils.spaces.repeated import Repeated
|
2020-11-12 03:16:12 -08:00
|
|
|
from ray.rllib.utils.typing import TensorType
|
2018-12-08 16:28:58 -08:00
|
|
|
|
2018-01-05 21:32:41 -08:00
|
|
|
# Observation shapes that get_preprocessor() recognizes as Atari inputs:
# raw pixel frames and the 128-element RAM vector.
ATARI_OBS_SHAPE = (210, 160, 3)
ATARI_RAM_OBS_SHAPE = (128,)

# A Preprocessor validates env observations against its observation space
# only once every OBS_VALIDATION_INTERVAL calls to check_shape(), rather than
# on every call.
OBS_VALIDATION_INTERVAL = 100

logger = logging.getLogger(__name__)
|
|
|
|
|
2017-08-22 03:51:49 +02:00
|
|
|
|
2019-01-23 21:27:26 -08:00
|
|
|
@PublicAPI
class Preprocessor:
    """Defines an abstract observation preprocessor function.

    Subclasses implement `_init_shape()` and `transform()`; `write()` may be
    overridden for more efficient in-place flattening.

    Attributes:
        shape (List[int]): Shape of the preprocessed output.
    """

    @PublicAPI
    def __init__(self, obs_space: gym.Space, options: dict = None):
        """Initializes the preprocessor for the given observation space.

        Args:
            obs_space: The raw observation space to preprocess.
            options: Model config dict. If None or empty, a copy of
                MODEL_DEFAULTS (from ray.rllib.models.catalog) is used.
        """
        legacy_patch_shapes(obs_space)
        self._obs_space = obs_space
        if not options:
            # Function-scope import, presumably to avoid a circular import
            # with the catalog module -- kept as in the original.
            from ray.rllib.models.catalog import MODEL_DEFAULTS
            self._options = MODEL_DEFAULTS.copy()
        else:
            self._options = options
        self.shape = self._init_shape(obs_space, self._options)
        # np.prod: np.product is deprecated and removed in NumPy 2.0.
        self._size = int(np.prod(self.shape))
        # Number of check_shape() calls so far; drives the periodic check.
        self._i = 0

    @PublicAPI
    def _init_shape(self, obs_space: gym.Space, options: dict) -> List[int]:
        """Returns the shape after preprocessing."""
        raise NotImplementedError

    @PublicAPI
    def transform(self, observation: TensorType) -> np.ndarray:
        """Returns the preprocessed observation."""
        raise NotImplementedError

    def write(self, observation: TensorType, array: np.ndarray,
              offset: int) -> None:
        """Alternative to transform for more efficient flattening.

        Writes the transformed observation into
        `array[offset:offset + self.size]` in place.
        """
        array[offset:offset + self._size] = self.transform(observation)

    def check_shape(self, observation: Any) -> None:
        """Checks that `observation` lies within this preprocessor's space.

        The check only runs on every OBS_VALIDATION_INTERVAL-th call; the
        call counter is advanced on every call.

        Raises:
            ValueError: If the observation is outside the observation space,
                or if the space's `contains()` raises AttributeError (e.g.
                a Python list given to a Box/MultiBinary/MultiDiscrete space).
        """
        if self._i % OBS_VALIDATION_INTERVAL == 0:
            # Box spaces can validate list input once converted to an array.
            if type(observation) is list and isinstance(
                    self._obs_space, gym.spaces.Box):
                observation = np.array(observation)
            try:
                in_space = self._obs_space.contains(observation)
            except AttributeError as e:
                raise ValueError(
                    "Observation for a Box/MultiBinary/MultiDiscrete space "
                    "should be an np.array, not a Python list.",
                    observation) from e
            if not in_space:
                # Format the message; previously the values were passed as
                # extra exception args and the "{}" placeholders stayed raw.
                raise ValueError(
                    "Observation ({}) outside given space ({})!".format(
                        observation, self._obs_space))
        self._i += 1

    @property
    @PublicAPI
    def size(self) -> int:
        """Total number of elements in the preprocessed output."""
        return self._size

    @property
    @PublicAPI
    def observation_space(self) -> gym.Space:
        """A float32 Box in [-1, 1] with this preprocessor's output shape."""
        obs_space = gym.spaces.Box(-1., 1., self.shape, dtype=np.float32)
        # Stash the unwrapped space so that we can unwrap dict and tuple spaces
        # automatically in modelv2.py
        classes = (DictFlatteningPreprocessor, OneHotPreprocessor,
                   RepeatedValuesPreprocessor, TupleFlatteningPreprocessor)
        if isinstance(self, classes):
            obs_space.original_space = self._obs_space
        return obs_space
|
|
|
|
|
2017-08-22 03:51:49 +02:00
|
|
|
|
2018-10-16 15:55:11 -07:00
|
|
|
class GenericPixelPreprocessor(Preprocessor):
    """Generic image preprocessor.

    Note: for Atari games, use config {"preprocessor_pref": "deepmind"}
    instead for deepmind-style Atari preprocessing.
    """

    @override(Preprocessor)
    def _init_shape(self, obs_space: gym.Space, options: dict) -> List[int]:
        """Reads the "grayscale", "zero_mean" and "dim" model options.

        Returns:
            (dim, dim, 1) when grayscale is enabled, else (dim, dim, 3).
        """
        self._grayscale = options.get("grayscale")
        self._zero_mean = options.get("zero_mean")
        self._dim = options.get("dim")
        if self._grayscale:
            shape = (self._dim, self._dim, 1)
        else:
            shape = (self._dim, self._dim, 3)
        return shape

    @override(Preprocessor)
    def transform(self, observation: TensorType) -> np.ndarray:
        """Downsamples images from (210, 160, 3) by the configured factor.

        Crops 25 rows off the top and bottom, resizes to (dim, dim),
        optionally averages channels to grayscale, and returns float32 values
        rescaled by (x - 128) / 128 (zero_mean) or x / 255 otherwise.
        """
        self.check_shape(observation)
        scaled = observation[25:-25, :, :]
        if self._dim < 84:
            scaled = cv2.resize(scaled, (84, 84))
        # OpenAI: Resize by half, then down to 42x42 (essentially mipmapping).
        # If we resize directly we lose pixels that, when mapped to 42x42,
        # aren't close enough to the pixel boundary.
        scaled = cv2.resize(scaled, (self._dim, self._dim))
        if self._grayscale:
            scaled = scaled.mean(2)
            scaled = scaled.astype(np.float32)
            # Rescale needed for maintaining 1 channel
            scaled = np.reshape(scaled, [self._dim, self._dim, 1])
        else:
            # Convert to float before rescaling: with integer (e.g. uint8)
            # frames, `scaled - 128` would wrap around and the in-place
            # `scaled *= 1.0 / 255.0` would raise a casting error.
            scaled = scaled.astype(np.float32)
        if self._zero_mean:
            scaled = (scaled - 128) / 128
        else:
            scaled *= 1.0 / 255.0
        return scaled
|
|
|
|
|
|
|
|
|
2017-08-22 03:51:49 +02:00
|
|
|
class AtariRamPreprocessor(Preprocessor):
    """Rescales 128-element Atari RAM observations via (v - 128) / 128."""

    @override(Preprocessor)
    def _init_shape(self, obs_space: gym.Space, options: dict) -> List[int]:
        return (128, )

    @override(Preprocessor)
    def transform(self, observation: TensorType) -> np.ndarray:
        """Returns (observation - 128) / 128 as a float32 array.

        For values in [0, 255] (assumed for RAM bytes -- not checked here),
        the result lies in [-1, 1).
        """
        self.check_shape(observation)
        # Cast before subtracting: with unsigned integer (e.g. uint8) input,
        # `observation - 128` would wrap around for values below 128.
        return (np.asarray(observation, dtype=np.float32) - 128) / 128
|
|
|
|
|
|
|
|
|
2017-10-23 23:16:52 -07:00
|
|
|
class OneHotPreprocessor(Preprocessor):
    """One-hot preprocessor for Discrete and MultiDiscrete spaces.

    Examples:
        For a Discrete(3) space, observation 1 becomes
        np.array([0.0, 1.0, 0.0]).
        For a MultiDiscrete([2, 3]) space, observation [1, 2] becomes
        np.array([0.0, 1.0, 0.0, 0.0, 1.0]) (one one-hot segment per
        sub-action, concatenated).
    """

    @override(Preprocessor)
    def _init_shape(self, obs_space: gym.Space, options: dict) -> List[int]:
        """Returns (n,) for Discrete, (sum(nvec),) for MultiDiscrete."""
        if isinstance(obs_space, gym.spaces.Discrete):
            return (self._obs_space.n, )
        else:
            return (np.sum(self._obs_space.nvec), )

    @override(Preprocessor)
    def transform(self, observation: TensorType) -> np.ndarray:
        """Returns the float32 one-hot encoding of `observation`."""
        self.check_shape(observation)
        # self.shape was computed once in __init__ (options are ignored by
        # _init_shape here), so there is no need to recompute it per call.
        arr = np.zeros(self.shape, dtype=np.float32)
        if isinstance(self._obs_space, gym.spaces.Discrete):
            arr[observation] = 1
        else:
            nvec = np.asarray(self._obs_space.nvec).ravel()
            # Start index of each sub-action's segment (exclusive prefix sum
            # of nvec), computed in one pass instead of re-summing nvec[:i]
            # for every element.
            offsets = np.cumsum(nvec) - nvec
            arr[offsets + np.asarray(observation).ravel()] = 1
        return arr

    @override(Preprocessor)
    def write(self, observation: TensorType, array: np.ndarray,
              offset: int) -> None:
        """Writes the one-hot encoding into array[offset:offset + size]."""
        array[offset:offset + self.size] = self.transform(observation)
|
2019-03-25 19:00:33 -04:00
|
|
|
|
2017-10-23 23:16:52 -07:00
|
|
|
|
2017-08-22 03:51:49 +02:00
|
|
|
class NoPreprocessor(Preprocessor):
    """Pass-through preprocessor: observations are returned unchanged."""

    @override(Preprocessor)
    def _init_shape(self, obs_space: gym.Space, options: dict) -> List[int]:
        return self._obs_space.shape

    @override(Preprocessor)
    def transform(self, observation: TensorType) -> np.ndarray:
        """Periodically validates, then returns `observation` as-is."""
        self.check_shape(observation)
        return observation

    @override(Preprocessor)
    def write(self, observation: TensorType, array: np.ndarray,
              offset: int) -> None:
        """Writes the flattened observation into array[offset:offset+size]."""
        # np.asarray avoids a copy when possible; np.array(..., copy=False)
        # raises ValueError under NumPy >= 2.0 whenever a copy is required
        # (e.g. for list observations).
        array[offset:offset + self._size] = np.asarray(observation).ravel()

    @property
    @override(Preprocessor)
    def observation_space(self) -> gym.Space:
        """Returns the original, unmodified observation space."""
        return self._obs_space
|
|
|
|
|
2018-01-05 21:32:41 -08:00
|
|
|
|
|
|
|
class TupleFlatteningPreprocessor(Preprocessor):
    """Preprocesses each tuple element, then flattens it all into a vector.

    RLlib models will unpack the flattened output before _build_layers_v2().
    """

    @override(Preprocessor)
    def _init_shape(self, obs_space: gym.Space, options: dict) -> List[int]:
        """Builds one sub-preprocessor per tuple slot; returns (total_size,)."""
        assert isinstance(self._obs_space, gym.spaces.Tuple)
        self.preprocessors = []
        for sub_space in self._obs_space.spaces:
            logger.debug("Creating sub-preprocessor for {}".format(sub_space))
            sub_prep = get_preprocessor(sub_space)(sub_space, self._options)
            self.preprocessors.append(sub_prep)
        return (sum(sub_prep.size for sub_prep in self.preprocessors), )

    @override(Preprocessor)
    def transform(self, observation: TensorType) -> np.ndarray:
        """Returns all tuple elements, preprocessed, in one float32 vector."""
        self.check_shape(observation)
        flat = np.zeros(self.shape, dtype=np.float32)
        self.write(observation, flat, 0)
        return flat

    @override(Preprocessor)
    def write(self, observation: TensorType, array: np.ndarray,
              offset: int) -> None:
        """Writes each element's encoding back-to-back starting at offset."""
        assert len(observation) == len(self.preprocessors), observation
        pos = offset
        for sub_obs, sub_prep in zip(observation, self.preprocessors):
            sub_prep.write(sub_obs, array, pos)
            pos += sub_prep.size
|
2018-01-05 21:32:41 -08:00
|
|
|
|
|
|
|
|
2018-10-20 15:21:22 -07:00
|
|
|
class DictFlatteningPreprocessor(Preprocessor):
    """Preprocesses each dict value, then flattens it all into a vector.

    RLlib models will unpack the flattened output before _build_layers_v2().
    """

    @override(Preprocessor)
    def _init_shape(self, obs_space: gym.Space, options: dict) -> List[int]:
        """Builds one sub-preprocessor per dict entry; returns (total_size,)."""
        assert isinstance(self._obs_space, gym.spaces.Dict)
        self.preprocessors = []
        for sub_space in self._obs_space.spaces.values():
            logger.debug("Creating sub-preprocessor for {}".format(sub_space))
            sub_prep = get_preprocessor(sub_space)(sub_space, self._options)
            self.preprocessors.append(sub_prep)
        return (sum(sub_prep.size for sub_prep in self.preprocessors), )

    @override(Preprocessor)
    def transform(self, observation: TensorType) -> np.ndarray:
        """Returns all dict values, preprocessed, in one float32 vector."""
        self.check_shape(observation)
        flat = np.zeros(self.shape, dtype=np.float32)
        self.write(observation, flat, 0)
        return flat

    @override(Preprocessor)
    def write(self, observation: TensorType, array: np.ndarray,
              offset: int) -> None:
        """Writes each value's encoding back-to-back starting at offset.

        Values are matched to sub-preprocessors by position: a plain dict is
        first sorted by key, while an OrderedDict is used in its own order.
        NOTE(review): this presumably relies on the space's key order being
        the sorted key order -- confirm for spaces built from OrderedDicts.
        """
        ordered = observation if isinstance(observation, OrderedDict) \
            else OrderedDict(sorted(observation.items()))
        assert len(ordered) == len(self.preprocessors), \
            (len(ordered), len(self.preprocessors))
        pos = offset
        for sub_obs, sub_prep in zip(ordered.values(), self.preprocessors):
            sub_prep.write(sub_obs, array, pos)
            pos += sub_prep.size
|
2018-10-20 15:21:22 -07:00
|
|
|
|
|
|
|
|
2020-06-06 03:22:19 -07:00
|
|
|
class RepeatedValuesPreprocessor(Preprocessor):
    """Pads and batches the variable-length list value.

    Flat layout: slot 0 holds the list length, followed by `max_len`
    consecutive child-sized segments, one per list element. Segments past
    the list length are not written by write(), so they are zero in the
    output of transform().
    """

    @override(Preprocessor)
    def _init_shape(self, obs_space: gym.Space, options: dict) -> List[int]:
        """Builds the child preprocessor; returns (1 + child_size * max_len,)."""
        assert isinstance(self._obs_space, Repeated)
        child_space = obs_space.child_space
        self.child_preprocessor = get_preprocessor(child_space)(
            child_space, self._options)
        # The first slot encodes the list length.
        size = 1 + self.child_preprocessor.size * obs_space.max_len
        return (size, )

    @override(Preprocessor)
    def transform(self, observation: TensorType) -> np.ndarray:
        """Returns the padded, flattened float32 encoding of the list."""
        # float32 to match the Tuple/Dict flattening preprocessors and the
        # float32 Box returned by `observation_space` (was float64).
        array = np.zeros(self.shape, dtype=np.float32)
        if isinstance(observation, list):
            for elem in observation:
                self.child_preprocessor.check_shape(elem)
        # Non-list input: ValueError will be raised in write() below.
        self.write(observation, array, 0)
        return array

    @override(Preprocessor)
    def write(self, observation: TensorType, array: np.ndarray,
              offset: int) -> None:
        """Writes length + element encodings starting at `offset`.

        Raises:
            ValueError: If `observation` is not a list, or is longer than the
                space's `max_len`.
        """
        if not isinstance(observation, list):
            raise ValueError("Input for {} must be list type, got {}".format(
                self, observation))
        elif len(observation) > self._obs_space.max_len:
            raise ValueError("Input {} exceeds max len of space {}".format(
                observation, self._obs_space.max_len))
        # The first slot encodes the list length.
        array[offset] = len(observation)
        for i, elem in enumerate(observation):
            offset_i = offset + 1 + i * self.child_preprocessor.size
            self.child_preprocessor.write(elem, array, offset_i)
|
|
|
|
|
|
|
|
|
2019-01-23 21:27:26 -08:00
|
|
|
@PublicAPI
def get_preprocessor(space: gym.Space) -> type:
    """Returns an appropriate preprocessor class for the given space.

    Checks, in order: Discrete/MultiDiscrete -> OneHotPreprocessor; the raw
    Atari frame shape -> GenericPixelPreprocessor; the Atari RAM shape ->
    AtariRamPreprocessor; Tuple, Dict and Repeated spaces -> their flattening
    preprocessors; anything else -> NoPreprocessor.
    """
    # legacy_patch_shapes() returns the (possibly just patched) space.shape.
    obs_shape = legacy_patch_shapes(space)

    if isinstance(space, (gym.spaces.Discrete, gym.spaces.MultiDiscrete)):
        return OneHotPreprocessor
    if obs_shape == ATARI_OBS_SHAPE:
        return GenericPixelPreprocessor
    if obs_shape == ATARI_RAM_OBS_SHAPE:
        return AtariRamPreprocessor
    if isinstance(space, gym.spaces.Tuple):
        return TupleFlatteningPreprocessor
    if isinstance(space, gym.spaces.Dict):
        return DictFlatteningPreprocessor
    if isinstance(space, Repeated):
        return RepeatedValuesPreprocessor
    return NoPreprocessor
|
|
|
|
|
|
|
|
|
2020-07-24 12:01:46 -07:00
|
|
|
def legacy_patch_shapes(space: gym.Space) -> List[int]:
    """Assigns shapes to spaces that don't have shapes.

    This is only needed for older gym versions that don't set shapes properly
    for Tuple and Discrete spaces. A Discrete space gets shape (); a Tuple
    space gets the tuple of its sub-spaces' (recursively patched) shapes.
    Spaces that already have a `shape` attribute are left untouched.

    Returns:
        The space's `shape` attribute after any patching.
    """
    if hasattr(space, "shape"):
        return space.shape
    if isinstance(space, gym.spaces.Discrete):
        space.shape = ()
    elif isinstance(space, gym.spaces.Tuple):
        space.shape = tuple(
            legacy_patch_shapes(sub_space) for sub_space in space.spaces)
    return space.shape
|