mirror of
https://github.com/vale981/ray
synced 2025-03-05 18:11:42 -05:00
[RLlib] Fix problem in preprocessing nested MultiDiscrete (#13308)
This commit is contained in:
parent
daf0bef285
commit
d11e62f9e6
2 changed files with 12 additions and 1 deletions
|
@ -174,7 +174,7 @@ class OneHotPreprocessor(Preprocessor):
|
|||
@override(Preprocessor)
|
||||
def write(self, observation: TensorType, array: np.ndarray,
|
||||
offset: int) -> None:
|
||||
array[offset + observation] = 1
|
||||
array[offset:offset + self.size] = self.transform(observation)
|
||||
|
||||
|
||||
class NoPreprocessor(Preprocessor):
|
||||
|
|
|
@ -71,6 +71,17 @@ class TestPreprocessors(unittest.TestCase):
|
|||
pp.transform(np.array([0, 1, 3])),
|
||||
[1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0])
|
||||
|
||||
def test_nested_multidiscrete_one_hot_preprocessor(self):
|
||||
space = Tuple((MultiDiscrete([2, 3, 4]), ))
|
||||
pp = get_preprocessor(space)(space)
|
||||
self.assertTrue(pp.shape == (9, ))
|
||||
check(
|
||||
pp.transform((np.array([1, 2, 0]), )),
|
||||
[0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0])
|
||||
check(
|
||||
pp.transform((np.array([0, 1, 3]), )),
|
||||
[1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
|
|
Loading…
Add table
Reference in a new issue