[rllib] More efficient tuple flattening. (#4416)

* More efficient tuple flattening.

* Preprocessor.write uses transform by default.

* lint

* to array

* Update test_catalog.py

* Update test_catalog.py
This commit is contained in:
Vlad Firoiu 2019-03-25 19:00:33 -04:00 committed by Eric Liang
parent a275af337e
commit c68eea6134
2 changed files with 35 additions and 11 deletions

View file

@ -30,6 +30,7 @@ class Preprocessor(object):
self._obs_space = obs_space
self._options = options or {}
self.shape = self._init_shape(obs_space, options)
self._size = int(np.product(self.shape))
@PublicAPI
def _init_shape(self, obs_space, options):
@ -41,10 +42,14 @@ class Preprocessor(object):
"""Returns the preprocessed observation."""
raise NotImplementedError
def write(self, observation, array, offset):
    """Write the preprocessed observation into ``array`` in place.

    Alternative to transform for more efficient flattening: the result of
    ``self.transform(observation)`` is stored in the ``self._size``
    elements of ``array`` that start at ``offset``.

    Args:
        observation: Raw observation, passed unchanged to ``transform``.
        array: Destination buffer; modified in place.
        offset: Index in ``array`` of the first element to write.
    """
    # Flatten before assigning: a multi-dimensional transform() result
    # cannot be broadcast into a 1-D slice of length self._size.
    flat = np.asarray(self.transform(observation)).ravel()
    array[offset:offset + self._size] = flat
@property
@PublicAPI
def size(self):
    """Return the number of elements in the preprocessed observation.

    The value is read from ``self._size``, which is assigned as
    ``int(np.product(self.shape))`` right after ``self.shape`` is set, so
    the product is not recomputed on every access.
    """
    return self._size
@property
@PublicAPI
@ -123,6 +128,10 @@ class OneHotPreprocessor(Preprocessor):
arr[observation] = 1
return arr
@override(Preprocessor)
def write(self, observation, array, offset):
    """Mark the one-hot position for ``observation`` in ``array`` in place.

    Only the element at ``offset + observation`` is assigned (set to 1);
    every other entry of ``array`` is left as it was.
    """
    hot_index = offset + observation
    array[hot_index] = 1
class NoPreprocessor(Preprocessor):
@override(Preprocessor)
@ -133,6 +142,11 @@ class NoPreprocessor(Preprocessor):
def transform(self, observation):
    """Return ``observation`` unchanged; no preprocessing is applied."""
    return observation
@override(Preprocessor)
def write(self, observation, array, offset):
    """Copy the flattened observation into ``array`` in place.

    Args:
        observation: Array-like observation; flattened before writing.
        array: Destination buffer; modified in place.
        offset: Index in ``array`` of the first element to write. The
            ``self._size`` elements starting there are overwritten.
    """
    # np.asarray only copies when it has to. np.array(..., copy=False)
    # raises on NumPy >= 2.0 whenever a copy is required (e.g. list input).
    flat = np.asarray(observation).ravel()
    array[offset:offset + self._size] = flat
class TupleFlatteningPreprocessor(Preprocessor):
"""Preprocesses each tuple element, then flattens it all into a vector.
@ -155,11 +169,16 @@ class TupleFlatteningPreprocessor(Preprocessor):
@override(Preprocessor)
def transform(self, observation):
    """Return the tuple observation flattened into a new array.

    A zero-filled array of shape ``self.shape`` is allocated and handed to
    ``self.write`` (starting at offset 0), which fills it in place.
    """
    flat = np.zeros(self.shape)
    self.write(observation, flat, 0)
    return flat
@override(Preprocessor)
def write(self, observation, array, offset):
    """Write every tuple element's preprocessed form into ``array``.

    Each element is delegated to the sub-preprocessor at the same position
    in ``self.preprocessors``. The elements are laid out back to back,
    starting at ``offset`` and advancing by each sub-preprocessor's
    ``size``.

    Args:
        observation: Sequence with one entry per sub-preprocessor.
        array: Destination buffer; modified in place.
        offset: Index in ``array`` where the first element starts.
    """
    assert len(observation) == len(self.preprocessors), observation
    # No early return here: the buffer is filled in place, so nothing is
    # concatenated or returned.
    for o, p in zip(observation, self.preprocessors):
        p.write(o, array, offset)
        offset += p.size
class DictFlatteningPreprocessor(Preprocessor):
@ -182,14 +201,19 @@ class DictFlatteningPreprocessor(Preprocessor):
@override(Preprocessor)
def transform(self, observation):
    """Return the dict observation flattened into a new array.

    Allocates zeros of shape ``self.shape`` and lets ``self.write`` fill
    the buffer in place from offset 0.
    """
    buffer = np.zeros(self.shape)
    self.write(observation, buffer, 0)
    return buffer
@override(Preprocessor)
def write(self, observation, array, offset):
    """Write every dict value's preprocessed form into ``array``.

    An observation that is not an ``OrderedDict`` is first re-ordered by
    sorted key. Values are then paired positionally with
    ``self.preprocessors`` and laid out back to back, starting at
    ``offset`` and advancing by each sub-preprocessor's ``size``.

    Args:
        observation: Mapping with one entry per sub-preprocessor.
        array: Destination buffer; modified in place.
        offset: Index in ``array`` where the first value starts.
    """
    if not isinstance(observation, OrderedDict):
        observation = OrderedDict(sorted(observation.items()))
    assert len(observation) == len(self.preprocessors), \
        (len(observation), len(self.preprocessors))
    # No early return here: the buffer is filled in place, so nothing is
    # concatenated or returned.
    for o, p in zip(observation.values(), self.preprocessors):
        p.write(o, array, offset)
        offset += p.size
@PublicAPI

View file

@ -16,12 +16,12 @@ from ray.rllib.models.visionnet import VisionNetwork
class CustomPreprocessor(Preprocessor):
    """Minimal Preprocessor subclass with a fixed one-element shape."""

    def _init_shape(self, obs_space, options):
        """Return the output shape; both arguments are ignored."""
        # A concrete shape is needed: the base class computes its size as
        # int(np.product(self.shape)), which cannot be derived from None.
        return [1]
class CustomPreprocessor2(Preprocessor):
    """Second minimal Preprocessor subclass with a one-element shape."""

    def _init_shape(self, obs_space, options):
        """Return the output shape; both arguments are ignored."""
        # Returning None would leave the base class without a shape to
        # take the product of when it computes its size.
        return [1]
class CustomModel(Model):