mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
238 lines
8.1 KiB
Python
238 lines
8.1 KiB
Python
from gym.spaces import Box, Dict, Discrete, Tuple, MultiDiscrete
|
|
import numpy as np
|
|
import unittest
|
|
|
|
import ray
|
|
from ray.rllib.agents.registry import get_trainer_class
|
|
from ray.rllib.examples.env.random_env import RandomEnv
|
|
from ray.rllib.models.tf.complex_input_net import ComplexInputNetwork as ComplexNet
|
|
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork as FCNet
|
|
from ray.rllib.models.tf.visionnet import VisionNetwork as VisionNet
|
|
from ray.rllib.models.torch.complex_input_net import (
|
|
ComplexInputNetwork as TorchComplexNet,
|
|
)
|
|
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFCNet
|
|
from ray.rllib.models.torch.visionnet import VisionNetwork as TorchVisionNet
|
|
from ray.rllib.utils.error import UnsupportedSpaceException
|
|
from ray.rllib.utils.test_utils import framework_iterator
|
|
|
|
# Action spaces each trainer is exercised with, keyed by a short name that is
# printed in the test log. check_support() pairs these names positionally with
# the keys of OBSERVATION_SPACES_TO_TEST, so insertion order matters.
ACTION_SPACES_TO_TEST = dict(
    discrete=Discrete(5),
    vector1d=Box(-1.0, 1.0, (5,), dtype=np.float32),
    # NOTE(review): despite its name, "vector2d" has the same 1D (5,) shape as
    # "vector1d"; confirm whether a 2D shape was intended here.
    vector2d=Box(-1.0, 1.0, (5,), dtype=np.float32),
    int_actions=Box(0, 3, (2, 3), dtype=np.int32),
    multidiscrete=MultiDiscrete([1, 2, 3, 4]),
    tuple=Tuple(
        [Discrete(2), Discrete(3), Box(-1.0, 1.0, (5,), dtype=np.float32)]
    ),
    # Nested composite space: a Dict containing another Dict of a Tuple.
    dict=Dict(
        {
            "action_choice": Discrete(3),
            "parameters": Box(-1.0, 1.0, (1,), dtype=np.float32),
            "yet_another_nested_dict": Dict({"a": Tuple([Discrete(2), Discrete(3)])}),
        }
    ),
)
|
|
|
|
# Observation spaces each trainer is exercised with. check_support() pairs
# these positionally with ACTION_SPACES_TO_TEST; any names left over after
# that pairing are tested with the first action space.
OBSERVATION_SPACES_TO_TEST = dict(
    discrete=Discrete(5),
    vector1d=Box(-1.0, 1.0, (5,), dtype=np.float32),
    vector2d=Box(-1.0, 1.0, (5, 5), dtype=np.float32),
    image=Box(-1.0, 1.0, (84, 84, 1), dtype=np.float32),
    vizdoomgym=Box(-1.0, 1.0, (240, 320, 3), dtype=np.float32),
    tuple=Tuple([Discrete(10), Box(-1.0, 1.0, (5,), dtype=np.float32)]),
    dict=Dict(
        {
            "task": Discrete(10),
            "position": Box(-1.0, 1.0, (5,), dtype=np.float32),
        }
    ),
)
|
|
|
|
|
|
def _is_unsupported_space_error(err):
    """Return True if `err` carries an UnsupportedSpaceException.

    The wrapped exception is looked for at ``err.args[2]`` and, failing that,
    at ``err.args[0].args[2]``. Every index is bounds-checked first, so an
    error with any other layout yields False (letting the caller re-raise it)
    instead of raising an IndexError/AttributeError that would hide the
    original failure.
    """
    args = err.args
    # Index 2 only exists when there are at least three positional args.
    if len(args) >= 3 and isinstance(args[2], UnsupportedSpaceException):
        return True
    # Fallback: the first arg may itself be an exception with its own args.
    inner_args = getattr(args[0], "args", ()) if args else ()
    return len(inner_args) >= 3 and isinstance(
        inner_args[2], UnsupportedSpaceException
    )


def check_support(alg, config, train=True, check_bounds=False, tfe=False):
    """Check that trainer `alg` can be built (and optionally trained) on the test spaces.

    For every framework yielded by ``framework_iterator`` ("tf" and "torch",
    plus "tf2" and "tfe" when `tfe` is True), a trainer is built on
    ``RandomEnv`` for each positional pair of ACTION_SPACES_TO_TEST and
    OBSERVATION_SPACES_TO_TEST keys; leftover observation spaces are paired
    with the first action space. A construction failure caused by an
    UnsupportedSpaceException is reported as "unsupported"; any other error
    propagates and fails the calling test.

    Args:
        alg: Trainer name resolved via ``get_trainer_class``.
        config: Trainer config dict. Mutated in place: "log_level",
            "train_batch_size", "rollout_fragment_length" and "env_config" are
            overwritten, and the dict is handed to ``framework_iterator``
            (which presumably sets config["framework"] on each iteration).
        train: If True, call ``train()`` once on each successfully built
            trainer.
        check_bounds: Forwarded to the env config as "check_action_bounds".
        tfe: If True, also test the "tf2" and "tfe" frameworks.
    """
    config["log_level"] = "ERROR"
    config["train_batch_size"] = 10
    config["rollout_fragment_length"] = 10

    def _do_check(alg, config, a_name, o_name):
        """Build one trainer for (a_name, o_name), verify its model, then stop it."""
        fw = config["framework"]
        action_space = ACTION_SPACES_TO_TEST[a_name]
        obs_space = OBSERVATION_SPACES_TO_TEST[o_name]
        print(
            "=== Testing {} (fw={}) A={} S={} ===".format(
                alg, fw, action_space, obs_space
            )
        )
        config.update(
            dict(
                env_config=dict(
                    action_space=action_space,
                    observation_space=obs_space,
                    reward_space=Box(1.0, 1.0, shape=(), dtype=np.float32),
                    p_done=1.0,
                    check_action_bounds=check_bounds,
                )
            )
        )
        stat = "ok"

        try:
            a = get_trainer_class(alg)(config=config, env=RandomEnv)
        except ray.exceptions.RayActorError as e:
            # Remote workers wrap the space error; anything else is a real failure.
            if not _is_unsupported_space_error(e):
                raise
            stat = "unsupported"
        except UnsupportedSpaceException:
            stat = "unsupported"
        else:
            # Stop the trainer even if the model check or train() fails, so a
            # failing check does not leave the trainer running.
            try:
                if alg not in ("DDPG", "ES", "ARS", "SAC"):
                    is_torch = fw == "torch"
                    if o_name in ("atari", "image"):
                        # 2D (image) input: Expect VisionNet.
                        expected = TorchVisionNet if is_torch else VisionNet
                    elif o_name == "vector1d":
                        # 1D input: Expect FCNet.
                        expected = TorchFCNet if is_torch else FCNet
                    elif o_name == "vector2d":
                        # Could be either one: ComplexNet (if disabled
                        # Preprocessor) or FCNet (w/ Preprocessor).
                        expected = (
                            (TorchComplexNet, TorchFCNet)
                            if is_torch
                            else (ComplexNet, FCNet)
                        )
                    else:
                        expected = None
                    if expected is not None:
                        assert isinstance(a.get_policy().model, expected)
                if train:
                    a.train()
            finally:
                a.stop()

        print(stat)

    frameworks = ("tf", "torch")
    if tfe:
        frameworks += ("tf2", "tfe")

    # Every obs space must get tested: those beyond the zipped pairs are run
    # against the first action space.
    assert len(OBSERVATION_SPACES_TO_TEST) >= len(ACTION_SPACES_TO_TEST)
    fixed_action_key = next(iter(ACTION_SPACES_TO_TEST))
    remaining_obs_names = list(OBSERVATION_SPACES_TO_TEST)[
        len(ACTION_SPACES_TO_TEST):
    ]

    for _ in framework_iterator(config, frameworks=frameworks):
        # Zip through action- and obs-spaces.
        for a_name, o_name in zip(ACTION_SPACES_TO_TEST, OBSERVATION_SPACES_TO_TEST):
            _do_check(alg, config, a_name, o_name)
        # Do the remaining obs spaces.
        for o_name in remaining_obs_names:
            _do_check(alg, config, fixed_action_key, o_name)
|
|
|
|
|
|
class TestSupportedSpacesPG(unittest.TestCase):
    """Space-support checks for the A3C, APPO, IMPALA, PPO and PG trainers."""

    @classmethod
    def setUpClass(cls) -> None:
        # One Ray session (default resources) shared by every test in the class.
        ray.init()

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_a3c(self):
        check_support(
            "A3C",
            dict(num_workers=1, optimizer=dict(grads_per_step=1)),
            check_bounds=True,
        )

    def test_appo(self):
        # Without v-trace: build only. With v-trace: build and train.
        for use_vtrace in (False, True):
            check_support(
                "APPO",
                {"num_gpus": 0, "vtrace": use_vtrace},
                train=use_vtrace,
            )

    def test_impala(self):
        check_support("IMPALA", dict(num_gpus=0))

    def test_ppo(self):
        ppo_config = dict(
            num_workers=0,
            train_batch_size=100,
            rollout_fragment_length=10,
            num_sgd_iter=1,
            sgd_minibatch_size=10,
        )
        check_support("PPO", ppo_config, check_bounds=True, tfe=True)

    def test_pg(self):
        pg_config = dict(num_workers=1, optimizer={})
        check_support(
            "PG", pg_config, train=False, check_bounds=True, tfe=True
        )
|
|
|
|
|
|
class TestSupportedSpacesOffPolicy(unittest.TestCase):
    """Space-support checks for the DDPG, DQN and SAC trainers."""

    @classmethod
    def setUpClass(cls) -> None:
        # Shared Ray session for the class, capped at 4 CPUs.
        ray.init(num_cpus=4)

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_ddpg(self):
        ddpg_config = dict(
            exploration_config=dict(ou_base_scale=100.0),
            min_sample_timesteps_per_reporting=1,
            buffer_size=1000,
            use_state_preprocessor=True,
        )
        check_support("DDPG", ddpg_config, check_bounds=True)

    def test_dqn(self):
        dqn_config = dict(min_sample_timesteps_per_reporting=1, buffer_size=1000)
        check_support("DQN", dqn_config, tfe=True)

    def test_sac(self):
        sac_config = dict(buffer_size=1000)
        check_support("SAC", sac_config, check_bounds=True)
|
|
|
|
|
|
class TestSupportedSpacesEvolutionAlgos(unittest.TestCase):
    """Space-support checks for the ARS and ES trainers."""

    @classmethod
    def setUpClass(cls) -> None:
        # Shared Ray session for the class, capped at 4 CPUs.
        ray.init(num_cpus=4)

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    @staticmethod
    def _evo_config(**extra):
        """Return a fresh config with the settings ARS and ES share, plus `extra`.

        A new dict is built on every call because check_support mutates the
        config it receives.
        """
        cfg = {"num_workers": 1, "noise_size": 1500000}
        cfg.update(extra)
        return cfg

    def test_ars(self):
        check_support("ARS", self._evo_config(num_rollouts=1, rollouts_used=1))

    def test_es(self):
        check_support(
            "ES", self._evo_config(episodes_per_batch=1, train_batch_size=1)
        )
|
|
|
|
|
|
if __name__ == "__main__":
    import pytest
    import sys

    # An optional first CLI argument names a single TestCase class to run;
    # without it, pytest runs every TestCase class in this file.
    target = __file__
    if len(sys.argv) > 1:
        target += "::" + sys.argv[1]
    sys.exit(pytest.main(["-v", target]))
|