# Mirror of https://github.com/vale981/ray
# Synced 2025-03-09 12:56:46 -04:00
# 78 lines, 2.9 KiB, Python
import gym
|
|
from gym.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple
|
|
import numpy as np
|
|
import unittest
|
|
|
|
from ray.rllib.models.catalog import ModelCatalog
|
|
from ray.rllib.models.preprocessors import DictFlatteningPreprocessor, \
|
|
get_preprocessor, NoPreprocessor, TupleFlatteningPreprocessor, \
|
|
OneHotPreprocessor, AtariRamPreprocessor, GenericPixelPreprocessor
|
|
from ray.rllib.utils.test_utils import check
|
|
|
|
|
|
class TestPreprocessors(unittest.TestCase):
    """Tests for preprocessor selection and the flattened output it produces."""

    def test_gym_preprocessors(self):
        """Each listed gym env id must yield exactly the expected class."""
        expected_by_env = (
            ("CartPole-v0", NoPreprocessor),
            ("FrozenLake-v0", OneHotPreprocessor),
            ("MsPacman-ram-v0", AtariRamPreprocessor),
            ("MsPacmanNoFrameskip-v4", GenericPixelPreprocessor),
        )
        for env_id, expected_cls in expected_by_env:
            # subTest keeps checking the remaining envs after a failure and
            # names the offending env id in the report.
            with self.subTest(env=env_id):
                pp = ModelCatalog.get_preprocessor(gym.make(env_id))
                # Exact type match on purpose (not isinstance).
                self.assertIs(type(pp), expected_cls)

    def test_tuple_preprocessor(self):
        """A Tuple(Discrete(5), Box(3)) observation flattens to 8 values."""

        class TupleEnv:
            """Minimal stand-in that only exposes an observation space."""

            def __init__(self):
                self.observation_space = Tuple(
                    [Discrete(5),
                     Box(0, 5, shape=(3, ), dtype=np.float32)])

        pp = ModelCatalog.get_preprocessor(TupleEnv())
        self.assertIsInstance(pp, TupleFlatteningPreprocessor)
        self.assertEqual(pp.shape, (8, ))
        # Expected layout: 5 one-hot slots for the Discrete value 0, followed
        # by the 3 Box values.
        self.assertEqual(
            list(pp.transform((0, np.array([1, 2, 3])))),
            [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0])

    def test_dict_flattening_preprocessor(self):
        """A nested Dict space flattens to a single 9-element vector."""
        space = Dict({
            "a": Discrete(2),
            "b": Tuple([Discrete(3), Box(-1.0, 1.0, (4, ))]),
        })
        pp = get_preprocessor(space)(space)
        self.assertIsInstance(pp, DictFlatteningPreprocessor)
        self.assertEqual(pp.shape, (9, ))
        # Expected layout: 2 slots for "a", then 3 slots for the Discrete
        # part of "b", then the 4 Box values of "b".
        check(
            pp.transform({
                "a": 1,
                "b": (1, np.array([0.0, -0.5, 0.1, 0.6]))
            }), [0.0, 1.0, 0.0, 1.0, 0.0, 0.0, -0.5, 0.1, 0.6])

    def test_one_hot_preprocessor(self):
        """Discrete and MultiDiscrete spaces both get one-hot encoded."""
        space = Discrete(5)
        pp = get_preprocessor(space)(space)
        self.assertIsInstance(pp, OneHotPreprocessor)
        self.assertEqual(pp.shape, (5, ))
        check(pp.transform(3), [0.0, 0.0, 0.0, 1.0, 0.0])
        check(pp.transform(0), [1.0, 0.0, 0.0, 0.0, 0.0])

        # MultiDiscrete([2, 3, 4]): the expected vectors below are the three
        # per-component one-hot segments (2 + 3 + 4 = 9 slots) back to back.
        space = MultiDiscrete([2, 3, 4])
        pp = get_preprocessor(space)(space)
        self.assertIsInstance(pp, OneHotPreprocessor)
        self.assertEqual(pp.shape, (9, ))
        check(
            pp.transform(np.array([1, 2, 0])),
            [0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0])
        check(
            pp.transform(np.array([0, 1, 3])),
            [1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0])
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run this file's tests verbosely under pytest and
    # use pytest's return value as the process exit status.
    import sys

    import pytest

    exit_status = pytest.main(["-v", __file__])
    sys.exit(exit_status)