import gym from gym.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple import numpy as np import unittest from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.preprocessors import DictFlatteningPreprocessor, \ get_preprocessor, NoPreprocessor, TupleFlatteningPreprocessor, \ OneHotPreprocessor, AtariRamPreprocessor, GenericPixelPreprocessor from ray.rllib.utils.test_utils import check class TestPreprocessors(unittest.TestCase): def test_gym_preprocessors(self): p1 = ModelCatalog.get_preprocessor(gym.make("CartPole-v0")) self.assertEqual(type(p1), NoPreprocessor) p2 = ModelCatalog.get_preprocessor(gym.make("FrozenLake-v0")) self.assertEqual(type(p2), OneHotPreprocessor) p3 = ModelCatalog.get_preprocessor(gym.make("MsPacman-ram-v0")) self.assertEqual(type(p3), AtariRamPreprocessor) p4 = ModelCatalog.get_preprocessor(gym.make("MsPacmanNoFrameskip-v4")) self.assertEqual(type(p4), GenericPixelPreprocessor) def test_tuple_preprocessor(self): class TupleEnv: def __init__(self): self.observation_space = Tuple( [Discrete(5), Box(0, 5, shape=(3, ), dtype=np.float32)]) pp = ModelCatalog.get_preprocessor(TupleEnv()) self.assertTrue(isinstance(pp, TupleFlatteningPreprocessor)) self.assertEqual(pp.shape, (8, )) self.assertEqual( list(pp.transform((0, np.array([1, 2, 3])))), [float(x) for x in [1, 0, 0, 0, 0, 1, 2, 3]]) def test_dict_flattening_preprocessor(self): space = Dict({ "a": Discrete(2), "b": Tuple([Discrete(3), Box(-1.0, 1.0, (4, ))]), }) pp = get_preprocessor(space)(space) self.assertTrue(isinstance(pp, DictFlatteningPreprocessor)) self.assertEqual(pp.shape, (9, )) check( pp.transform({ "a": 1, "b": (1, np.array([0.0, -0.5, 0.1, 0.6])) }), [0.0, 1.0, 0.0, 1.0, 0.0, 0.0, -0.5, 0.1, 0.6]) def test_one_hot_preprocessor(self): space = Discrete(5) pp = get_preprocessor(space)(space) self.assertTrue(isinstance(pp, OneHotPreprocessor)) self.assertTrue(pp.shape == (5, )) check(pp.transform(3), [0.0, 0.0, 0.0, 1.0, 0.0]) check(pp.transform(0), [1.0, 0.0, 0.0, 0.0, 0.0]) space = MultiDiscrete([2, 3, 4]) pp = get_preprocessor(space)(space) self.assertTrue(isinstance(pp, OneHotPreprocessor)) self.assertTrue(pp.shape == (9, )) check( pp.transform(np.array([1, 2, 0])), [0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0]) check( pp.transform(np.array([0, 1, 3])), [1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0]) def test_nested_multidiscrete_one_hot_preprocessor(self): space = Tuple((MultiDiscrete([2, 3, 4]), )) pp = get_preprocessor(space)(space) self.assertTrue(pp.shape == (9, )) check( pp.transform((np.array([1, 2, 0]), )), [0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0]) check( pp.transform((np.array([0, 1, 3]), )), [1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0]) if __name__ == "__main__": import pytest import sys sys.exit(pytest.main(["-v", __file__]))