# Patch notes (explanatory preamble placed ahead of the first "diff --git"
# header; it is not part of any hunk).
#
# 1. rllib/models/preprocessors.py -- OneHotPreprocessor.write()
#    Before: the method set the single cell array[offset + observation] to 1,
#            indexing the output buffer directly with the raw observation.
#    After:  the method copies self.transform(observation) into the slice
#            array[offset:offset + self.size], so write() emits the same
#            encoding that transform() produces.
#    NOTE(review): presumably the old form mis-encoded array-valued
#    (MultiDiscrete) observations when this preprocessor is nested inside a
#    Tuple preprocessor -- confirm against the Tuple preprocessor's write path.
#
# 2. rllib/models/tests/test_preprocessors.py
#    Adds test_nested_multidiscrete_one_hot_preprocessor, which wraps
#    MultiDiscrete([2, 3, 4]) in a one-element Tuple space and checks that:
#      * the preprocessor shape is (9, ), i.e. 2 + 3 + 4;
#      * the observation [1, 2, 0] has ones at flat indices 1, 4 and 5;
#      * the observation [0, 1, 3] has ones at flat indices 0, 3 and 8.
#    Each component therefore gets its own one-hot segment of the flat vector.
diff --git a/rllib/models/preprocessors.py b/rllib/models/preprocessors.py
index 2b0bcb092..44312a807 100644
--- a/rllib/models/preprocessors.py
+++ b/rllib/models/preprocessors.py
@@ -174,7 +174,7 @@ class OneHotPreprocessor(Preprocessor):
     @override(Preprocessor)
     def write(self, observation: TensorType, array: np.ndarray,
               offset: int) -> None:
-        array[offset + observation] = 1
+        array[offset:offset + self.size] = self.transform(observation)
 
 
 class NoPreprocessor(Preprocessor):
diff --git a/rllib/models/tests/test_preprocessors.py b/rllib/models/tests/test_preprocessors.py
index 5515b6fea..4ce7b73e7 100644
--- a/rllib/models/tests/test_preprocessors.py
+++ b/rllib/models/tests/test_preprocessors.py
@@ -71,6 +71,17 @@ class TestPreprocessors(unittest.TestCase):
             pp.transform(np.array([0, 1, 3])),
             [1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0])
 
+    def test_nested_multidiscrete_one_hot_preprocessor(self):
+        space = Tuple((MultiDiscrete([2, 3, 4]), ))
+        pp = get_preprocessor(space)(space)
+        self.assertTrue(pp.shape == (9, ))
+        check(
+            pp.transform((np.array([1, 2, 0]), )),
+            [0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0])
+        check(
+            pp.transform((np.array([0, 1, 3]), )),
+            [1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0])
+
 
 if __name__ == "__main__":
     import pytest