2020-04-28 14:59:16 +02:00
|
|
|
from gym.spaces import Box, Dict, Discrete, Tuple, MultiDiscrete
|
2018-03-06 08:31:02 +00:00
|
|
|
import numpy as np
|
2020-02-19 21:18:45 +01:00
|
|
|
import unittest
|
2018-01-24 11:03:43 -08:00
|
|
|
|
|
|
|
import ray
|
2021-02-08 12:05:16 +01:00
|
|
|
from ray.rllib.agents.registry import get_trainer_class
|
2020-05-01 22:59:34 +02:00
|
|
|
from ray.rllib.examples.env.random_env import RandomEnv
|
2020-05-18 17:26:40 +02:00
|
|
|
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork as FCNetV2
|
|
|
|
from ray.rllib.models.tf.visionnet import VisionNetwork as VisionNetV2
|
2020-03-02 19:53:19 +01:00
|
|
|
from ray.rllib.models.torch.visionnet import VisionNetwork as TorchVisionNetV2
|
|
|
|
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFCNetV2
|
2018-01-24 11:03:43 -08:00
|
|
|
from ray.rllib.utils.error import UnsupportedSpaceException
|
2020-05-27 16:19:13 +02:00
|
|
|
from ray.rllib.utils.test_utils import framework_iterator
|
2018-01-24 11:03:43 -08:00
|
|
|
|
|
|
|
# Action spaces exercised by `check_support`, keyed by a short label.
# Insertion order is significant: `check_support` zips these keys, in order,
# with the keys of OBSERVATION_SPACES_TO_TEST to form test combinations.
ACTION_SPACES_TO_TEST = dict(
    discrete=Discrete(5),
    vector=Box(-1.0, 1.0, (5, ), dtype=np.float32),
    vector2=Box(-1.0, 1.0, (5, 5), dtype=np.float32),
    # Integer-valued Box (not Discrete/MultiDiscrete).
    int_actions=Box(0, 3, (2, 3), dtype=np.int32),
    multidiscrete=MultiDiscrete([1, 2, 3, 4]),
    tuple=Tuple([
        Discrete(2),
        Discrete(3),
        Box(-1.0, 1.0, (5, ), dtype=np.float32),
    ]),
    # Nested composite space: a Dict containing another Dict holding a Tuple.
    dict=Dict({
        "action_choice": Discrete(3),
        "parameters": Box(-1.0, 1.0, (1, ), dtype=np.float32),
        "yet_another_nested_dict": Dict({
            "a": Tuple([Discrete(2), Discrete(3)]),
        }),
    }),
)
|
|
|
|
|
|
|
|
# Observation spaces exercised by `check_support`, keyed by a short label.
# Paired in order with ACTION_SPACES_TO_TEST; any keys beyond the number of
# action spaces are tested against the "discrete" action space.
# For most algorithms, `check_support` asserts that "image"/"atari" get a
# vision network and "vector"/"vector2" a fully connected network.
OBSERVATION_SPACES_TO_TEST = dict(
    discrete=Discrete(5),
    vector=Box(-1.0, 1.0, (5, ), dtype=np.float32),
    vector2=Box(-1.0, 1.0, (5, 5), dtype=np.float32),
    image=Box(-1.0, 1.0, (84, 84, 1), dtype=np.float32),
    atari=Box(-1.0, 1.0, (210, 160, 3), dtype=np.float32),
    tuple=Tuple([Discrete(10), Box(-1.0, 1.0, (5, ), dtype=np.float32)]),
    dict=Dict({
        "task": Discrete(10),
        "position": Box(-1.0, 1.0, (5, ), dtype=np.float32),
    }),
)
|
|
|
|
|
|
|
|
|
2020-06-05 08:34:21 +02:00
|
|
|
def check_support(alg, config, train=True, check_bounds=False, tfe=False):
    """Build (and optionally train) `alg` on every test action/obs space.

    For each framework ("tf" and "torch", plus "tf2" and "tfe" when `tfe` is
    True), a trainer is built on RandomEnv for every (action space,
    observation space) pair obtained by zipping ACTION_SPACES_TO_TEST with
    OBSERVATION_SPACES_TO_TEST. Any leftover observation spaces are paired
    with the "discrete" action space. A combination that the trainer rejects
    with UnsupportedSpaceException is reported as "unsupported" rather than
    failing the test.

    Args:
        alg: Trainer name passed to `get_trainer_class`, e.g. "PPO".
        config: Trainer config dict. Mutated in place: "log_level",
            "train_batch_size", "rollout_fragment_length" and "env_config"
            are overwritten. "framework" is read per iteration and is
            presumably set by `framework_iterator`.
        train: If True, run one `train()` call per combination.
        check_bounds: Forwarded to RandomEnv as `check_action_bounds`.
        tfe: If True, also test the "tf2" and "tfe" frameworks.
    """
    config["log_level"] = "ERROR"
    # NOTE: these override any values the caller put in `config`.
    config["train_batch_size"] = 10
    config["rollout_fragment_length"] = 10

    def _do_check(alg, config, a_name, o_name):
        """Build (and maybe train) one trainer for one space combination."""
        fw = config["framework"]
        action_space = ACTION_SPACES_TO_TEST[a_name]
        obs_space = OBSERVATION_SPACES_TO_TEST[o_name]
        print("=== Testing {} (fw={}) A={} S={} ===".format(
            alg, fw, action_space, obs_space))
        config.update(
            dict(
                env_config=dict(
                    action_space=action_space,
                    observation_space=obs_space,
                    reward_space=Box(1.0, 1.0, shape=(), dtype=np.float32),
                    p_done=1.0,
                    check_action_bounds=check_bounds)))
        stat = "ok"

        try:
            a = get_trainer_class(alg)(config=config, env=RandomEnv)
        except UnsupportedSpaceException:
            stat = "unsupported"
        else:
            # Stop the trainer even when a model assertion or `train()`
            # fails, so a failing combination does not leave it running in
            # the ray instance shared by the rest of the test class.
            try:
                # These algorithms are not checked for default model class.
                if alg not in ["DDPG", "ES", "ARS", "SAC"]:
                    if o_name in ["atari", "image"]:
                        if fw == "torch":
                            assert isinstance(a.get_policy().model,
                                              TorchVisionNetV2)
                        else:
                            assert isinstance(a.get_policy().model,
                                              VisionNetV2)
                    elif o_name in ["vector", "vector2"]:
                        if fw == "torch":
                            assert isinstance(a.get_policy().model,
                                              TorchFCNetV2)
                        else:
                            assert isinstance(a.get_policy().model, FCNetV2)
                if train:
                    a.train()
            finally:
                a.stop()
        print(stat)

    frameworks = ("tf", "torch")
    if tfe:
        frameworks += ("tf2", "tfe")
    for _ in framework_iterator(config, frameworks=frameworks):
        # Zip through action- and obs-spaces.
        for a_name, o_name in zip(ACTION_SPACES_TO_TEST.keys(),
                                  OBSERVATION_SPACES_TO_TEST.keys()):
            _do_check(alg, config, a_name, o_name)
        # Do the remaining obs spaces (those without a zip partner).
        assert len(OBSERVATION_SPACES_TO_TEST) >= len(ACTION_SPACES_TO_TEST)
        leftover_obs = list(
            OBSERVATION_SPACES_TO_TEST)[len(ACTION_SPACES_TO_TEST):]
        for o_name in leftover_obs:
            _do_check(alg, config, "discrete", o_name)
|
2018-11-14 14:14:07 -08:00
|
|
|
|
|
|
|
|
2020-07-08 16:12:20 +02:00
|
|
|
class TestSupportedSpacesPG(unittest.TestCase):
    """Space-support tests for A3C, APPO, IMPALA, PPO and PG."""

    @classmethod
    def setUpClass(cls) -> None:
        # One local ray instance is shared by every test in this class.
        ray.init(num_cpus=6)

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_a3c(self):
        a3c_config = dict(num_workers=1, optimizer=dict(grads_per_step=1))
        check_support("A3C", a3c_config, check_bounds=True)

    def test_appo(self):
        # Without V-trace only trainer construction is checked; with V-trace
        # the trainer is also trained.
        for use_vtrace, do_train in ((False, False), (True, True)):
            check_support(
                "APPO", {"num_gpus": 0, "vtrace": use_vtrace}, train=do_train)

    def test_impala(self):
        impala_config = {"num_gpus": 0}
        check_support("IMPALA", impala_config)

    def test_ppo(self):
        # NOTE(review): check_support overwrites train_batch_size and
        # rollout_fragment_length with 10, so those two entries have no
        # effect here.
        ppo_config = dict(
            num_workers=0,
            train_batch_size=100,
            rollout_fragment_length=10,
            num_sgd_iter=1,
            sgd_minibatch_size=10,
        )
        check_support("PPO", ppo_config, check_bounds=True, tfe=True)

    def test_pg(self):
        pg_config = dict(num_workers=1, optimizer=dict())
        check_support(
            "PG", pg_config, train=False, check_bounds=True, tfe=True)
|
|
|
|
|
|
|
|
|
|
|
|
class TestSupportedSpacesOffPolicy(unittest.TestCase):
    """Space-support tests for DDPG, DQN and SAC."""

    @classmethod
    def setUpClass(cls) -> None:
        # One local ray instance is shared by every test in this class.
        ray.init(num_cpus=4)

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_ddpg(self):
        ddpg_config = dict(
            exploration_config=dict(ou_base_scale=100.0),
            timesteps_per_iteration=1,
            buffer_size=1000,
            use_state_preprocessor=True,
        )
        check_support("DDPG", ddpg_config, check_bounds=True)

    def test_dqn(self):
        dqn_config = dict(timesteps_per_iteration=1, buffer_size=1000)
        check_support("DQN", dqn_config, tfe=True)

    def test_sac(self):
        sac_config = dict(buffer_size=1000)
        check_support("SAC", sac_config, check_bounds=True)
|
|
|
|
|
|
|
|
|
|
|
|
class TestSupportedSpacesEvolutionAlgos(unittest.TestCase):
    """Space-support tests for ARS and ES."""

    @classmethod
    def setUpClass(cls) -> None:
        # One local ray instance is shared by every test in this class.
        ray.init(num_cpus=4)

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_ars(self):
        ars_config = dict(
            num_workers=1,
            noise_size=1500000,
            num_rollouts=1,
            rollouts_used=1,
        )
        check_support("ARS", ars_config)

    def test_es(self):
        # NOTE(review): check_support overwrites train_batch_size with 10.
        es_config = dict(
            num_workers=1,
            noise_size=1500000,
            episodes_per_batch=1,
            train_batch_size=1,
        )
        check_support("ES", es_config)
|
2020-02-19 21:18:45 +01:00
|
|
|
|
2018-01-24 11:03:43 -08:00
|
|
|
|
|
|
|
if __name__ == "__main__":
    import pytest
    import sys

    # Run every TestCase in this file, or only the TestCase class whose name
    # is given as the first command-line argument.
    target = __file__
    if len(sys.argv) > 1:
        target += "::" + sys.argv[1]
    sys.exit(pytest.main(["-v", target]))
|