ray/rllib/tests/test_supported_spaces.py
2022-05-17 13:43:49 +02:00

247 lines
8.2 KiB
Python

from gym.spaces import Box, Dict, Discrete, Tuple, MultiDiscrete
import numpy as np
import unittest
import ray
from ray.rllib.agents.registry import get_trainer_class
from ray.rllib.examples.env.random_env import RandomEnv
from ray.rllib.models.tf.complex_input_net import ComplexInputNetwork as ComplexNet
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork as FCNet
from ray.rllib.models.tf.visionnet import VisionNetwork as VisionNet
from ray.rllib.models.torch.complex_input_net import (
ComplexInputNetwork as TorchComplexNet,
)
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFCNet
from ray.rllib.models.torch.visionnet import VisionNetwork as TorchVisionNet
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.test_utils import framework_iterator
# Action spaces exercised by `check_support`, keyed by a short name.
# Insertion order matters: `check_support` zips these keys with the keys of
# OBSERVATION_SPACES_TO_TEST to form the (action, observation) pairs it tests.
ACTION_SPACES_TO_TEST = {
    "discrete": Discrete(5),
    "vector1d": Box(-1.0, 1.0, (5,), dtype=np.float32),
    # NOTE(review): despite its name, this entry has the same 1-D shape (5,)
    # as "vector1d" -- confirm whether a 2-D shape was intended.
    "vector2d": Box(-1.0, 1.0, (5,), dtype=np.float32),
    # Integer-valued Box with a 2-D shape.
    "int_actions": Box(0, 3, (2, 3), dtype=np.int32),
    "multidiscrete": MultiDiscrete([1, 2, 3, 4]),
    # Mixed discrete/continuous components.
    "tuple": Tuple([Discrete(2), Discrete(3), Box(-1.0, 1.0, (5,), dtype=np.float32)]),
    # Nested structure: Dict containing a Dict that contains a Tuple.
    "dict": Dict(
        {
            "action_choice": Discrete(3),
            "parameters": Box(-1.0, 1.0, (1,), dtype=np.float32),
            "yet_another_nested_dict": Dict({"a": Tuple([Discrete(2), Discrete(3)])}),
        }
    ),
}
# Observation spaces exercised by `check_support`, keyed by a short name.
# Insertion order matters: the keys are paired positionally with the keys of
# ACTION_SPACES_TO_TEST, and `check_support` asserts that this dict has at
# least as many entries as ACTION_SPACES_TO_TEST.
OBSERVATION_SPACES_TO_TEST = {
    "discrete": Discrete(5),
    # `check_support` expects a fully connected net for this key.
    "vector1d": Box(-1.0, 1.0, (5,), dtype=np.float32),
    # `check_support` accepts either a complex-input net or a fully
    # connected net for this key.
    "vector2d": Box(-1.0, 1.0, (5, 5), dtype=np.float32),
    # `check_support` expects a vision net for this key.
    "image": Box(-1.0, 1.0, (84, 84, 1), dtype=np.float32),
    # Larger 3-channel box of shape (240, 320, 3).
    "vizdoomgym": Box(-1.0, 1.0, (240, 320, 3), dtype=np.float32),
    "tuple": Tuple([Discrete(10), Box(-1.0, 1.0, (5,), dtype=np.float32)]),
    "dict": Dict(
        {
            "task": Discrete(10),
            "position": Box(-1.0, 1.0, (5,), dtype=np.float32),
        }
    ),
}
def check_support(alg, config, train=True, check_bounds=False, tfe=False):
    """Build (and optionally train) `alg` on `RandomEnv` for each test space.

    For every framework under test, the keys of ACTION_SPACES_TO_TEST and
    OBSERVATION_SPACES_TO_TEST are paired positionally and a trainer is
    built for each pair. Any observation spaces left over after pairing are
    tested together with the first action space.

    A combination that raises `UnsupportedSpaceException` (directly or
    wrapped in a `RayActorError`) is reported as "unsupported" rather than
    failing the test.

    Note:
        `config` is mutated in place: `log_level`, `train_batch_size` and
        `rollout_fragment_length` are overwritten (even if the caller set
        them) and `env_config` is replaced for every tested pair.

    Args:
        alg: Name of the algorithm, resolved via `get_trainer_class`.
        config: Trainer config dict (modified in place, see note above).
        train: Whether to call `train()` once on each built trainer.
        check_bounds: Forwarded to the env as `check_action_bounds`.
        tfe: Whether to additionally test the "tf2" and "tfe" frameworks.

    Raises:
        ray.exceptions.RayActorError: If trainer construction fails for a
            reason other than an unsupported space.
        AssertionError: If the built policy's model is not of the expected
            type for the observation space, or if there are fewer
            observation spaces than action spaces.
    """
    config["log_level"] = "ERROR"
    config["train_batch_size"] = 10
    config["rollout_fragment_length"] = 10

    def _wraps_unsupported_space(error):
        """Return True if `error` carries an UnsupportedSpaceException.

        The wrapped exception is looked up at index 2 of `error.args` or,
        failing that, at index 2 of `error.args[0].args`.
        """
        args = error.args
        # Index 2 only exists when there are at least three args.
        if len(args) >= 3 and isinstance(args[2], UnsupportedSpaceException):
            return True
        # Guard the nested lookup so that an unrelated error shape does not
        # raise IndexError/AttributeError and hide the original error.
        nested_args = getattr(args[0], "args", ()) if args else ()
        return len(nested_args) >= 3 and isinstance(
            nested_args[2], UnsupportedSpaceException
        )

    def _do_check(alg, config, a_name, o_name):
        """Build `alg` for one (action, observation) pair and print the result."""
        fw = config["framework"]
        action_space = ACTION_SPACES_TO_TEST[a_name]
        obs_space = OBSERVATION_SPACES_TO_TEST[o_name]
        print(
            "=== Testing {} (fw={}) A={} S={} ===".format(
                alg, fw, action_space, obs_space
            )
        )
        config.update(
            dict(
                env_config=dict(
                    action_space=action_space,
                    observation_space=obs_space,
                    reward_space=Box(1.0, 1.0, shape=(), dtype=np.float32),
                    p_done=1.0,
                    check_action_bounds=check_bounds,
                )
            )
        )
        stat = "ok"
        try:
            a = get_trainer_class(alg)(config=config, env=RandomEnv)
        except ray.exceptions.RayActorError as e:
            if not _wraps_unsupported_space(e):
                raise
            stat = "unsupported"
        except UnsupportedSpaceException:
            stat = "unsupported"
        else:
            # Always stop the trainer, even when an assertion or `train()`
            # fails, so that it is not left running for subsequent checks.
            try:
                if alg not in ["DDPG", "ES", "ARS", "SAC"]:
                    # 2D (image) input: Expect VisionNet.
                    if o_name in ["atari", "image"]:
                        if fw == "torch":
                            assert isinstance(a.get_policy().model, TorchVisionNet)
                        else:
                            assert isinstance(a.get_policy().model, VisionNet)
                    # 1D input: Expect FCNet.
                    elif o_name == "vector1d":
                        if fw == "torch":
                            assert isinstance(a.get_policy().model, TorchFCNet)
                        else:
                            assert isinstance(a.get_policy().model, FCNet)
                    # Could be either one: ComplexNet (if disabled Preprocessor)
                    # or FCNet (w/ Preprocessor).
                    elif o_name == "vector2d":
                        if fw == "torch":
                            assert isinstance(
                                a.get_policy().model, (TorchComplexNet, TorchFCNet)
                            )
                        else:
                            assert isinstance(
                                a.get_policy().model, (ComplexNet, FCNet)
                            )
                if train:
                    a.train()
            finally:
                a.stop()
        print(stat)

    frameworks = ("tf", "torch")
    if tfe:
        frameworks += ("tf2", "tfe")

    # The leftover-observation handling below relies on there being at least
    # as many observation spaces as action spaces.
    assert len(OBSERVATION_SPACES_TO_TEST) >= len(ACTION_SPACES_TO_TEST)
    fixed_action_key = next(iter(ACTION_SPACES_TO_TEST.keys()))
    remaining_obs_keys = list(OBSERVATION_SPACES_TO_TEST)[
        len(ACTION_SPACES_TO_TEST):
    ]

    # `_do_check` reads config["framework"]; presumably `framework_iterator`
    # sets it on each iteration -- verify against ray.rllib.utils.test_utils.
    for _ in framework_iterator(config, frameworks=frameworks):
        # Zip through action- and obs-spaces.
        for a_name, o_name in zip(
            ACTION_SPACES_TO_TEST.keys(), OBSERVATION_SPACES_TO_TEST.keys()
        ):
            _do_check(alg, config, a_name, o_name)
        # Do the remaining obs spaces.
        for o_name in remaining_obs_keys:
            _do_check(alg, config, fixed_action_key, o_name)
class TestSupportedSpacesPG(unittest.TestCase):
    """Space-support checks for the A3C, APPO, IMPALA, PPO and PG trainers."""

    @classmethod
    def setUpClass(cls) -> None:
        # One Ray instance is shared by every test in this class.
        ray.init()

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_a3c(self):
        check_support(
            "A3C",
            {"num_workers": 1, "optimizer": {"grads_per_step": 1}},
            check_bounds=True,
        )

    def test_appo(self):
        # First pass (vtrace off) only builds the trainer; the second pass
        # (vtrace on) also runs one training step.
        for vtrace, do_train in ((False, False), (True, True)):
            check_support(
                "APPO", {"num_gpus": 0, "vtrace": vtrace}, train=do_train
            )

    def test_impala(self):
        impala_config = {"num_gpus": 0}
        check_support("IMPALA", impala_config)

    def test_ppo(self):
        check_support(
            "PPO",
            {
                "num_workers": 0,
                "train_batch_size": 100,
                "rollout_fragment_length": 10,
                "num_sgd_iter": 1,
                "sgd_minibatch_size": 10,
            },
            check_bounds=True,
            tfe=True,
        )

    def test_pg(self):
        # Build-only check (no training step).
        check_support(
            "PG",
            {"num_workers": 1, "optimizer": {}},
            train=False,
            check_bounds=True,
            tfe=True,
        )
class TestSupportedSpacesOffPolicy(unittest.TestCase):
    """Space-support checks for the DDPG, DQN and SAC trainers."""

    @classmethod
    def setUpClass(cls) -> None:
        # One Ray instance (4 CPUs) is shared by every test in this class.
        ray.init(num_cpus=4)

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_ddpg(self):
        ddpg_config = {
            "exploration_config": {"ou_base_scale": 100.0},
            "min_sample_timesteps_per_reporting": 1,
            "replay_buffer_config": {
                "capacity": 1000,
            },
            "use_state_preprocessor": True,
        }
        check_support("DDPG", ddpg_config, check_bounds=True)

    def test_dqn(self):
        check_support(
            "DQN",
            {
                "min_sample_timesteps_per_reporting": 1,
                "replay_buffer_config": {
                    "capacity": 1000,
                },
            },
            tfe=True,
        )

    def test_sac(self):
        sac_config = {"replay_buffer_config": {"capacity": 1000}}
        check_support("SAC", sac_config, check_bounds=True)
class TestSupportedSpacesEvolutionAlgos(unittest.TestCase):
    """Space-support checks for the ARS and ES trainers."""

    @classmethod
    def setUpClass(cls) -> None:
        # One Ray instance (4 CPUs) is shared by every test in this class.
        ray.init(num_cpus=4)

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_ars(self):
        ars_config = {
            "num_workers": 1,
            "noise_size": 1500000,
            "num_rollouts": 1,
            "rollouts_used": 1,
        }
        check_support("ARS", ars_config)

    def test_es(self):
        es_config = {
            "num_workers": 1,
            "noise_size": 1500000,
            "episodes_per_batch": 1,
            "train_batch_size": 1,
        }
        check_support("ES", es_config)
if __name__ == "__main__":
    import pytest
    import sys

    # An optional first command-line argument names a single TestCase class
    # to run; without it, every unittest.TestCase class in this file is run.
    target = __file__
    if len(sys.argv) > 1:
        target += "::" + sys.argv[1]
    sys.exit(pytest.main(["-v", target]))