Mirror of https://github.com/vale981/ray, synced 2025-03-06 10:31:39 -05:00
[rllib] More efficient tuple flattening. (#4416)
* More efficient tuple flattening.
* Preprocessor.write uses transform by default.
* lint
* to array
* Update test_catalog.py
* Update test_catalog.py
parent a275af337e
commit c68eea6134
2 changed files with 35 additions and 11 deletions
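For context (illustrative only, not part of the commit): the change replaces "transform each tuple element, then np.concatenate the pieces" with a single preallocated flat buffer that each element preprocessor writes into at a running offset. A minimal numpy sketch of the two approaches, using made-up flatteners in place of RLlib's preprocessor classes:

import numpy as np

# Toy stand-ins for per-element preprocessors: each turns one observation
# element into a 1-D float array of known size.
def flatten_box(x):            # e.g. a Box(3) component
    return np.asarray(x, dtype=np.float64).ravel()

def flatten_one_hot(i, n=4):   # e.g. a Discrete(4) component
    out = np.zeros(n)
    out[i] = 1.0
    return out

obs = ([1.0, 2.0, 3.0], 2)     # a tuple observation
flatteners = [flatten_box, flatten_one_hot]
sizes = [3, 4]

# Old approach: one temporary array per element, then a concatenating copy.
old = np.concatenate([f(o) for o, f in zip(obs, flatteners)])

# New approach: allocate the flat vector once and write each slice in place,
# advancing an offset -- the pattern the new Preprocessor.write() uses below.
new = np.zeros(sum(sizes))
offset = 0
for o, f, s in zip(obs, flatteners, sizes):
    new[offset:offset + s] = f(o)
    offset += s

assert np.array_equal(old, new)  # same flattened vector, fewer temporaries

The concatenate version builds a temporary per element plus a final copy; the write version fills one buffer in place, which is what the transform/write pair in the diff below does.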
@@ -30,6 +30,7 @@ class Preprocessor(object):
         self._obs_space = obs_space
         self._options = options or {}
         self.shape = self._init_shape(obs_space, options)
+        self._size = int(np.product(self.shape))
 
     @PublicAPI
     def _init_shape(self, obs_space, options):
@@ -41,10 +42,14 @@ class Preprocessor(object):
         """Returns the preprocessed observation."""
         raise NotImplementedError
 
+    def write(self, observation, array, offset):
+        """Alternative to transform for more efficient flattening."""
+        array[offset:offset + self._size] = self.transform(observation)
+
     @property
     @PublicAPI
     def size(self):
-        return int(np.product(self.shape))
+        return self._size
 
     @property
     @PublicAPI
@@ -123,6 +128,10 @@ class OneHotPreprocessor(Preprocessor):
         arr[observation] = 1
         return arr
 
+    @override(Preprocessor)
+    def write(self, observation, array, offset):
+        array[offset + observation] = 1
+
 
 class NoPreprocessor(Preprocessor):
     @override(Preprocessor)
@@ -133,6 +142,11 @@ class NoPreprocessor(Preprocessor):
     def transform(self, observation):
         return observation
 
+    @override(Preprocessor)
+    def write(self, observation, array, offset):
+        array[offset:offset + self._size] = np.array(
+            observation, copy=False).ravel()
+
 
 class TupleFlatteningPreprocessor(Preprocessor):
     """Preprocesses each tuple element, then flattens it all into a vector.
@@ -155,11 +169,16 @@ class TupleFlatteningPreprocessor(Preprocessor):
 
     @override(Preprocessor)
     def transform(self, observation):
+        array = np.zeros(self.shape)
+        self.write(observation, array, 0)
+        return array
+
+    @override(Preprocessor)
+    def write(self, observation, array, offset):
         assert len(observation) == len(self.preprocessors), observation
-        return np.concatenate([
-            np.reshape(p.transform(o), [p.size])
-            for (o, p) in zip(observation, self.preprocessors)
-        ])
+        for o, p in zip(observation, self.preprocessors):
+            p.write(o, array, offset)
+            offset += p.size
 
 
 class DictFlatteningPreprocessor(Preprocessor):
@@ -182,14 +201,19 @@ class DictFlatteningPreprocessor(Preprocessor):
 
     @override(Preprocessor)
     def transform(self, observation):
+        array = np.zeros(self.shape)
+        self.write(observation, array, 0)
+        return array
+
+    @override(Preprocessor)
+    def write(self, observation, array, offset):
         if not isinstance(observation, OrderedDict):
             observation = OrderedDict(sorted(list(observation.items())))
         assert len(observation) == len(self.preprocessors), \
             (len(observation), len(self.preprocessors))
-        return np.concatenate([
-            np.reshape(p.transform(o), [p.size])
-            for (o, p) in zip(observation.values(), self.preprocessors)
-        ])
+        for o, p in zip(observation.values(), self.preprocessors):
+            p.write(o, array, offset)
+            offset += p.size
 
 
 @PublicAPI
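The remaining hunk is the "Update test_catalog.py" part of the change. Since Preprocessor.__init__ now computes self._size = int(np.product(self.shape)) up front, the stub preprocessors in the test can no longer return None from _init_shape and switch to a real shape, [1]. A quick check of why (np.prod is the current spelling of the np.product alias used in the diff; this snippet is illustrative, not from the commit):

import numpy as np

print(int(np.prod([1])))   # 1 -- the new stub shape yields a valid size

try:
    int(np.prod(None))     # the old `return None` stub cannot be turned
except Exception as exc:   # into an integer size at construction time
    print(type(exc).__name__, exc)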
@@ -16,12 +16,12 @@ from ray.rllib.models.visionnet import VisionNetwork
 
 class CustomPreprocessor(Preprocessor):
     def _init_shape(self, obs_space, options):
-        return None
+        return [1]
 
 
 class CustomPreprocessor2(Preprocessor):
     def _init_shape(self, obs_space, options):
-        return None
+        return [1]
 
 
 class CustomModel(Model):
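A note on the commit-message bullet "Preprocessor.write uses transform by default": the base-class write() added above copies self.transform(observation) into the target slice, so subclasses that only implement transform() still work when a flattening preprocessor calls write() on them. A toy illustration with a hypothetical subclass (not RLlib code):

import numpy as np

class Base:
    """Mimics the new default Preprocessor.write() from the diff above."""

    def __init__(self, size):
        self._size = size

    def transform(self, observation):
        raise NotImplementedError

    def write(self, observation, array, offset):
        # Default: fall back to transform() and copy into the target slice.
        array[offset:offset + self._size] = self.transform(observation)

class ScalePreprocessor(Base):
    """Hypothetical subclass that only overrides transform()."""

    def transform(self, observation):
        return np.asarray(observation) * 2.0

p = ScalePreprocessor(size=3)
buf = np.zeros(5)
p.write([1.0, 2.0, 3.0], buf, offset=1)  # no write() override needed
print(buf)  # [0. 2. 4. 6. 0.]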