import gym
from gym.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple
import numpy as np
import unittest

import ray
import ray.rllib.agents.ppo as ppo
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.preprocessors import (
    DictFlatteningPreprocessor,
    get_preprocessor,
    NoPreprocessor,
    TupleFlatteningPreprocessor,
    OneHotPreprocessor,
    AtariRamPreprocessor,
    GenericPixelPreprocessor,
)
from ray.rllib.utils.test_utils import (
    check,
    check_compute_single_action,
    check_train_results,
    framework_iterator,
)


class TestPreprocessors(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        ray.init()

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()
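
    # Trains PPO for one iteration on a RandomEnv with a nested Dict/Tuple
    # observation space while the preprocessor API is disabled, i.e. the
    # model receives the raw (unflattened) observation structures.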
    def test_preprocessing_disabled(self):
        config = ppo.DEFAULT_CONFIG.copy()
        config["seed"] = 42
        config["env"] = "ray.rllib.examples.env.random_env.RandomEnv"
        config["env_config"] = {
            "config": {
                "observation_space": Dict(
                    {
                        "a": Discrete(5),
                        "b": Dict(
                            {
                                "ba": Discrete(4),
                                "bb": Box(-1.0, 1.0, (2, 3), dtype=np.float32),
                            }
                        ),
                        "c": Tuple((MultiDiscrete([2, 3]), Discrete(1))),
                        "d": Box(-1.0, 1.0, (1,), dtype=np.int32),
                    }
                ),
            },
        }
        # Set this to True to enforce no preprocessors being used.
        # Complex observations now arrive directly in the model as
        # structures of batches, e.g. {"a": tensor, "b": [tensor, tensor]}
        # for obs-space=Dict(a=..., b=Tuple(..., ...)).
        config["_disable_preprocessor_api"] = True
        # Speed things up a little.
        config["train_batch_size"] = 100
        config["sgd_minibatch_size"] = 10
        config["rollout_fragment_length"] = 5
        config["num_sgd_iter"] = 1

        num_iterations = 1
        # Only supported for tf so far.
        for _ in framework_iterator(config):
            trainer = ppo.PPOTrainer(config=config)
            for i in range(num_iterations):
                results = trainer.train()
                check_train_results(results)
                print(results)
            check_compute_single_action(trainer)
            trainer.stop()
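
    # Sanity-checks which preprocessor ModelCatalog.get_preprocessor picks for
    # a few standard Gym envs: CartPole needs none, FrozenLake's Discrete
    # space is one-hot encoded, and the Atari RAM/pixel variants map to their
    # specialized preprocessors.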
    def test_gym_preprocessors(self):
        p1 = ModelCatalog.get_preprocessor(gym.make("CartPole-v0"))
        self.assertEqual(type(p1), NoPreprocessor)

        p2 = ModelCatalog.get_preprocessor(gym.make("FrozenLake-v1"))
        self.assertEqual(type(p2), OneHotPreprocessor)

        p3 = ModelCatalog.get_preprocessor(gym.make("MsPacman-ram-v0"))
        self.assertEqual(type(p3), AtariRamPreprocessor)

        p4 = ModelCatalog.get_preprocessor(gym.make("MsPacmanNoFrameskip-v4"))
        self.assertEqual(type(p4), GenericPixelPreprocessor)
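
    # Flattening a Tuple(Discrete(5), Box(3,)) observation should yield a
    # vector of length 5 + 3 = 8: the one-hot encoded Discrete component
    # followed by the Box values.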
    def test_tuple_preprocessor(self):
        class TupleEnv:
            def __init__(self):
                self.observation_space = Tuple(
                    [Discrete(5), Box(0, 5, shape=(3,), dtype=np.float32)]
                )

        pp = ModelCatalog.get_preprocessor(TupleEnv())
        self.assertTrue(isinstance(pp, TupleFlatteningPreprocessor))
        self.assertEqual(pp.shape, (8,))
        self.assertEqual(
            list(pp.transform((0, np.array([1, 2, 3], np.float32)))),
            [float(x) for x in [1, 0, 0, 0, 0, 1, 2, 3]],
        )
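
    # A Dict(a=Discrete(2), b=Tuple(Discrete(3), Box(4,))) space flattens to
    # 2 + 3 + 4 = 9 values: one-hot(a), one-hot(b[0]), then the Box values.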
    def test_dict_flattening_preprocessor(self):
        space = Dict(
            {
                "a": Discrete(2),
                "b": Tuple([Discrete(3), Box(-1.0, 1.0, (4,))]),
            }
        )
        pp = get_preprocessor(space)(space)
        self.assertTrue(isinstance(pp, DictFlatteningPreprocessor))
        self.assertEqual(pp.shape, (9,))
        check(
            pp.transform(
                {"a": 1, "b": (1, np.array([0.0, -0.5, 0.1, 0.6], np.float32))}
            ),
            [0.0, 1.0, 0.0, 1.0, 0.0, 0.0, -0.5, 0.1, 0.6],
        )
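
    # One-hot encoding of Discrete(5) gives 5-dim vectors; for
    # MultiDiscrete([2, 3, 4]) each sub-space is one-hot encoded and
    # concatenated, giving 2 + 3 + 4 = 9 dims.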
    def test_one_hot_preprocessor(self):
        space = Discrete(5)
        pp = get_preprocessor(space)(space)
        self.assertTrue(isinstance(pp, OneHotPreprocessor))
        self.assertTrue(pp.shape == (5,))
        check(pp.transform(3), [0.0, 0.0, 0.0, 1.0, 0.0])
        check(pp.transform(0), [1.0, 0.0, 0.0, 0.0, 0.0])

        space = MultiDiscrete([2, 3, 4])
        pp = get_preprocessor(space)(space)
        self.assertTrue(isinstance(pp, OneHotPreprocessor))
        self.assertTrue(pp.shape == (9,))
        check(
            pp.transform(np.array([1, 2, 0])),
            [0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0],
        )
        check(
            pp.transform(np.array([0, 1, 3])),
            [1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0],
        )
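
    # Same MultiDiscrete one-hot encoding as above, but with the space nested
    # inside a Tuple; the flattened size stays 9.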
    def test_nested_multidiscrete_one_hot_preprocessor(self):
        space = Tuple((MultiDiscrete([2, 3, 4]),))
        pp = get_preprocessor(space)(space)
        self.assertTrue(pp.shape == (9,))
        check(
            pp.transform((np.array([1, 2, 0]),)),
            [0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0],
        )
        check(
            pp.transform((np.array([0, 1, 3]),)),
            [1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0],
        )


if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))