ray/rllib/models/tests/test_preprocessors.py
Balaji Veeramani 7f1bacc7dc
[CI] Format Python code with Black (#21975)
See #21316 and #21311 for the motivation behind these changes.
2022-01-29 18:41:57 -08:00

162 lines
5.4 KiB
Python

import gym
from gym.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple
import numpy as np
import unittest
import ray
import ray.rllib.agents.ppo as ppo
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.preprocessors import (
DictFlatteningPreprocessor,
get_preprocessor,
NoPreprocessor,
TupleFlatteningPreprocessor,
OneHotPreprocessor,
AtariRamPreprocessor,
GenericPixelPreprocessor,
)
from ray.rllib.utils.test_utils import (
check,
check_compute_single_action,
check_train_results,
framework_iterator,
)
class TestPreprocessors(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
ray.init()
@classmethod
def tearDownClass(cls) -> None:
ray.shutdown()
def test_preprocessing_disabled(self):
config = ppo.DEFAULT_CONFIG.copy()
config["seed"] = 42
config["env"] = "ray.rllib.examples.env.random_env.RandomEnv"
config["env_config"] = {
"config": {
"observation_space": Dict(
{
"a": Discrete(5),
"b": Dict(
{
"ba": Discrete(4),
"bb": Box(-1.0, 1.0, (2, 3), dtype=np.float32),
}
),
"c": Tuple((MultiDiscrete([2, 3]), Discrete(1))),
"d": Box(-1.0, 1.0, (1,), dtype=np.int32),
}
),
},
}
# Set this to True to enforce no preprocessors being used.
# Complex observations now arrive directly in the model as
# structures of batches, e.g. {"a": tensor, "b": [tensor, tensor]}
# for obs-space=Dict(a=..., b=Tuple(..., ...)).
config["_disable_preprocessor_api"] = True
# Speed things up a little.
config["train_batch_size"] = 100
config["sgd_minibatch_size"] = 10
config["rollout_fragment_length"] = 5
config["num_sgd_iter"] = 1
num_iterations = 1
# Only supported for tf so far.
for _ in framework_iterator(config):
trainer = ppo.PPOTrainer(config=config)
for i in range(num_iterations):
results = trainer.train()
check_train_results(results)
print(results)
check_compute_single_action(trainer)
trainer.stop()
def test_gym_preprocessors(self):
p1 = ModelCatalog.get_preprocessor(gym.make("CartPole-v0"))
self.assertEqual(type(p1), NoPreprocessor)
p2 = ModelCatalog.get_preprocessor(gym.make("FrozenLake-v1"))
self.assertEqual(type(p2), OneHotPreprocessor)
p3 = ModelCatalog.get_preprocessor(gym.make("MsPacman-ram-v0"))
self.assertEqual(type(p3), AtariRamPreprocessor)
p4 = ModelCatalog.get_preprocessor(gym.make("MsPacmanNoFrameskip-v4"))
self.assertEqual(type(p4), GenericPixelPreprocessor)
def test_tuple_preprocessor(self):
class TupleEnv:
def __init__(self):
self.observation_space = Tuple(
[Discrete(5), Box(0, 5, shape=(3,), dtype=np.float32)]
)
pp = ModelCatalog.get_preprocessor(TupleEnv())
self.assertTrue(isinstance(pp, TupleFlatteningPreprocessor))
self.assertEqual(pp.shape, (8,))
self.assertEqual(
list(pp.transform((0, np.array([1, 2, 3], np.float32)))),
[float(x) for x in [1, 0, 0, 0, 0, 1, 2, 3]],
)
def test_dict_flattening_preprocessor(self):
space = Dict(
{
"a": Discrete(2),
"b": Tuple([Discrete(3), Box(-1.0, 1.0, (4,))]),
}
)
pp = get_preprocessor(space)(space)
self.assertTrue(isinstance(pp, DictFlatteningPreprocessor))
self.assertEqual(pp.shape, (9,))
check(
pp.transform(
{"a": 1, "b": (1, np.array([0.0, -0.5, 0.1, 0.6], np.float32))}
),
[0.0, 1.0, 0.0, 1.0, 0.0, 0.0, -0.5, 0.1, 0.6],
)
def test_one_hot_preprocessor(self):
space = Discrete(5)
pp = get_preprocessor(space)(space)
self.assertTrue(isinstance(pp, OneHotPreprocessor))
self.assertTrue(pp.shape == (5,))
check(pp.transform(3), [0.0, 0.0, 0.0, 1.0, 0.0])
check(pp.transform(0), [1.0, 0.0, 0.0, 0.0, 0.0])
space = MultiDiscrete([2, 3, 4])
pp = get_preprocessor(space)(space)
self.assertTrue(isinstance(pp, OneHotPreprocessor))
self.assertTrue(pp.shape == (9,))
check(
pp.transform(np.array([1, 2, 0])),
[0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0],
)
check(
pp.transform(np.array([0, 1, 3])),
[1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0],
)
def test_nested_multidiscrete_one_hot_preprocessor(self):
space = Tuple((MultiDiscrete([2, 3, 4]),))
pp = get_preprocessor(space)(space)
self.assertTrue(pp.shape == (9,))
check(
pp.transform((np.array([1, 2, 0]),)),
[0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0],
)
check(
pp.transform((np.array([0, 1, 3]),)),
[1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0],
)
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))