[RLlib] Fix problem in preprocessing nested MultiDiscrete (#13308)

This commit is contained in:
Saeid 2021-01-21 15:36:11 +00:00 committed by GitHub
parent daf0bef285
commit d11e62f9e6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 12 additions and 1 deletions

View file

@ -174,7 +174,7 @@ class OneHotPreprocessor(Preprocessor):
@override(Preprocessor)
def write(self, observation: TensorType, array: np.ndarray,
offset: int) -> None:
array[offset + observation] = 1
array[offset:offset + self.size] = self.transform(observation)
class NoPreprocessor(Preprocessor):

View file

@ -71,6 +71,17 @@ class TestPreprocessors(unittest.TestCase):
pp.transform(np.array([0, 1, 3])),
[1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0])
def test_nested_multidiscrete_one_hot_preprocessor(self):
space = Tuple((MultiDiscrete([2, 3, 4]), ))
pp = get_preprocessor(space)(space)
self.assertTrue(pp.shape == (9, ))
check(
pp.transform((np.array([1, 2, 0]), )),
[0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0])
check(
pp.transform((np.array([0, 1, 3]), )),
[1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0])
if __name__ == "__main__":
import pytest