2018-01-24 11:03:43 -08:00
|
|
|
import unittest
|
|
|
|
import traceback
|
|
|
|
|
|
|
|
import gym
|
2018-10-20 15:21:22 -07:00
|
|
|
from gym.spaces import Box, Discrete, Tuple, Dict
|
2018-01-24 11:03:43 -08:00
|
|
|
from gym.envs.registration import EnvSpec
|
2018-03-06 08:31:02 +00:00
|
|
|
import numpy as np
|
2018-07-07 13:29:20 -07:00
|
|
|
import sys
|
2018-01-24 11:03:43 -08:00
|
|
|
|
|
|
|
import ray
|
2018-12-21 03:44:34 +09:00
|
|
|
from ray.rllib.agents.registry import get_agent_class
|
2019-03-02 13:37:16 -08:00
|
|
|
from ray.rllib.tests.test_multi_agent_env import (MultiCartpole,
|
|
|
|
MultiMountainCar)
|
2018-01-24 11:03:43 -08:00
|
|
|
from ray.rllib.utils.error import UnsupportedSpaceException
|
|
|
|
from ray.tune.registry import register_env
|
|
|
|
|
|
|
|
# Action spaces every algorithm is tried against. The key is the short label
# that check_support() stores in its results under `a_name`.
ACTION_SPACES_TO_TEST = {
    # Single categorical choice.
    "discrete": Discrete(5),
    # Bounded continuous vector.
    "vector": Box(low=-1.0, high=1.0, shape=(5, ), dtype=np.float32),
    # Composite action mixing two discrete parts and one continuous part.
    "tuple": Tuple([
        Discrete(2),
        Discrete(3),
        Box(low=-1.0, high=1.0, shape=(5, ), dtype=np.float32),
    ]),
}
|
|
|
|
|
|
|
|
# Observation spaces every algorithm is tried against. The key is the short
# label that check_support() stores in its results under `o_name`.
OBSERVATION_SPACES_TO_TEST = {
    "discrete": Discrete(5),
    "vector": Box(low=-1.0, high=1.0, shape=(5, ), dtype=np.float32),
    # Single-channel 84x84 observation.
    "image": Box(low=-1.0, high=1.0, shape=(84, 84, 1), dtype=np.float32),
    # Three-channel 210x160 observation.
    "atari": Box(low=-1.0, high=1.0, shape=(210, 160, 3), dtype=np.float32),
    "tuple": Tuple([
        Discrete(10),
        Box(low=-1.0, high=1.0, shape=(5, ), dtype=np.float32),
    ]),
    "dict": Dict({
        "task": Discrete(10),
        "position": Box(
            low=-1.0, high=1.0, shape=(5, ), dtype=np.float32),
    }),
}
|
|
|
|
|
|
|
|
|
2018-10-20 15:21:22 -07:00
|
|
|
def make_stub_env(action_space, obs_space, check_action_bounds):
    """Build a minimal gym.Env class bound to the given spaces.

    Args:
        action_space: Space exposed as the env's `action_space`.
        obs_space: Space exposed as the env's `observation_space`; all
            observations are drawn from it with `sample()`.
        check_action_bounds: When truthy, `step()` rejects any action for
            which `action_space.contains(action)` is false.

    Returns:
        The `StubEnv` class itself (not an instance).
    """

    class StubEnv(gym.Env):
        """Environment whose observations are random samples."""

        def __init__(self):
            self.action_space = action_space
            self.observation_space = obs_space
            self.spec = EnvSpec("StubEnv-v0")

        def reset(self):
            """Return a random observation."""
            return self.observation_space.sample()

        def step(self, action):
            """Validate `action` and return (obs, 1, True, {}).

            Raises:
                ValueError: If bounds checking is enabled and the action is
                    not contained in the action space, or if the action
                    space is a Tuple and the action has a different number
                    of components.
            """
            space = self.action_space

            def illegal_action():
                # Both validation failures report the same message.
                return ValueError("Illegal action for {}: {}".format(
                    space, action))

            # The checks run in sequence: the arity check is only reached
            # once the (optional) bounds check has passed.
            if check_action_bounds and not space.contains(action):
                raise illegal_action()
            if isinstance(space, Tuple):
                if len(action) != len(space.spaces):
                    raise illegal_action()

            next_obs = self.observation_space.sample()
            return next_obs, 1, True, {}

    return StubEnv
|
|
|
|
|
|
|
|
|
2018-10-20 15:21:22 -07:00
|
|
|
def check_support(alg, config, stats, check_bounds=False):
    """Train `alg` once on every action/observation space combination.

    For each pair taken from ACTION_SPACES_TO_TEST and
    OBSERVATION_SPACES_TO_TEST a stub env is registered as "stub_env", an
    agent is built on it and `train()` is called once.

    Args:
        alg: Algorithm name passed to `get_agent_class`.
        config: Config dict handed to the agent constructor.
        stats: Dict updated in place; `stats[alg, a_name, o_name]` is set to
            "ok", "unsupported" (UnsupportedSpaceException was raised) or
            "ERROR" (any other exception, which is printed with traceback).
        check_bounds: Forwarded to `make_stub_env` as `check_action_bounds`.
    """
    for a_name, action_space in ACTION_SPACES_TO_TEST.items():
        for o_name, obs_space in OBSERVATION_SPACES_TO_TEST.items():
            print("=== Testing", alg, action_space, obs_space, "===")

            # Re-register under the same name so each run sees the env
            # class built for the current pair of spaces.
            env_cls = make_stub_env(action_space, obs_space, check_bounds)
            register_env("stub_env", lambda _cfg: env_cls())

            agent = None
            outcome = "ok"
            try:
                agent_cls = get_agent_class(alg)
                agent = agent_cls(config=config, env="stub_env")
                agent.train()
            except UnsupportedSpaceException:
                outcome = "unsupported"
            except Exception as err:
                outcome = "ERROR"
                print(err)
                print(traceback.format_exc())
            finally:
                # Best-effort cleanup: a failure to stop must not mask the
                # outcome recorded above.
                if agent:
                    try:
                        agent.stop()
                    except Exception as err:
                        print("Ignoring error stopping agent", err)

            print(outcome)
            print()
            stats[alg, a_name, o_name] = outcome
|
|
|
|
|
|
|
|
|
2018-11-14 14:14:07 -08:00
|
|
|
def check_support_multiagent(alg, config):
    """Run one `train()` call of `alg` on a multi-agent environment.

    Both "multi_mountaincar" and "multi_cartpole" are registered on every
    call, each env being constructed with the argument 2 (presumably the
    number of agents -- confirm against test_multi_agent_env). Algorithms
    whose name contains "DDPG" train on "multi_mountaincar"; all others
    train on "multi_cartpole".

    Args:
        alg: Algorithm name passed to `get_agent_class`.
        config: Config dict handed to the agent constructor.

    Any exception from building or training the agent propagates; the agent
    is stopped whether or not training succeeds.
    """
    register_env("multi_mountaincar", lambda _: MultiMountainCar(2))
    register_env("multi_cartpole", lambda _: MultiCartpole(2))

    env_name = "multi_mountaincar" if "DDPG" in alg else "multi_cartpole"
    agent = get_agent_class(alg)(config=config, env=env_name)
    try:
        agent.train()
    finally:
        agent.stop()
|
|
|
|
|
|
|
|
|
2018-01-24 11:03:43 -08:00
|
|
|
class ModelSupportedSpaces(unittest.TestCase):
    """Checks which algorithms can train on which kinds of spaces."""

    def setUp(self):
        # Each test gets its own Ray runtime.
        ray.init(num_cpus=4)

    def tearDown(self):
        ray.shutdown()

    def testAll(self):
        """Try every algorithm on every space pair; fail on any "ERROR"."""
        results = {}

        # (algorithm, config, check_bounds), executed in this order.
        cases = [
            ("IMPALA", {"num_gpus": 0}, False),
            ("APPO", {"num_gpus": 0, "vtrace": False}, False),
            ("DDPG", {
                "noise_scale": 100.0,
                "timesteps_per_iteration": 1
            }, True),
            ("DQN", {"timesteps_per_iteration": 1}, False),
            ("A3C", {
                "num_workers": 1,
                "optimizer": {
                    "grads_per_step": 1
                }
            }, True),
            ("PPO", {
                "num_workers": 1,
                "num_sgd_iter": 1,
                "train_batch_size": 10,
                "sample_batch_size": 10,
                "sgd_minibatch_size": 1,
            }, True),
            ("ES", {
                "num_workers": 1,
                "noise_size": 10000000,
                "episodes_per_batch": 1,
                "train_batch_size": 1
            }, False),
            ("ARS", {
                "num_workers": 1,
                "noise_size": 10000000,
                "num_rollouts": 1,
                "rollouts_used": 1
            }, False),
            ("PG", {"num_workers": 1, "optimizer": {}}, True),
        ]
        for alg, config, bounded in cases:
            check_support(alg, config, results, check_bounds=bounded)

        # Print the full table, then fail if anything other than "ok" or
        # "unsupported" was recorded.
        num_unexpected_errors = 0
        for key in sorted(results):
            alg, a_name, o_name = key
            stat = results[key]
            if stat not in ("ok", "unsupported"):
                num_unexpected_errors += 1
            print(alg, "action_space", a_name, "obs_space", o_name, "result",
                  stat)
        self.assertEqual(num_unexpected_errors, 0)

    def testMultiAgent(self):
        """Run one training iteration of each algorithm on a multi-agent env."""
        # APEX and APEX_DDPG use identical settings; each call below still
        # receives its own dict.
        apex_settings = {
            "num_workers": 2,
            "timesteps_per_iteration": 1000,
            "num_gpus": 0,
            "min_iter_time_s": 1,
            "learning_starts": 1000,
            "target_network_update_freq": 100,
        }

        # (algorithm, config), executed in this order.
        cases = [
            ("APEX", dict(apex_settings)),
            ("APEX_DDPG", dict(apex_settings)),
            ("IMPALA", {"num_gpus": 0}),
            ("DQN", {"timesteps_per_iteration": 1}),
            ("A3C", {
                "num_workers": 1,
                "optimizer": {
                    "grads_per_step": 1
                }
            }),
            ("PPO", {
                "num_workers": 1,
                "num_sgd_iter": 1,
                "train_batch_size": 10,
                "sample_batch_size": 10,
                "sgd_minibatch_size": 1,
            }),
            ("PG", {"num_workers": 1, "optimizer": {}}),
            ("DDPG", {"timesteps_per_iteration": 1}),
        ]
        for alg, config in cases:
            check_support_multiagent(alg, config)
|
|
|
|
|
2018-01-24 11:03:43 -08:00
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point. Passing "--smoke" as the first argument shrinks the
    # module-level space tables to a small subset before the tests run.
    cli_args = sys.argv[1:]
    if cli_args and cli_args[0] == "--smoke":
        ACTION_SPACES_TO_TEST = {
            "discrete": Discrete(5),
        }
        OBSERVATION_SPACES_TO_TEST = {
            "vector": Box(0.0, 1.0, (5, ), dtype=np.float32),
            "atari": Box(0.0, 1.0, (210, 160, 3), dtype=np.float32),
        }
        # unittest.main() parses the command line itself and rejects options
        # it does not know, so the flag must not be forwarded to it.
        cli_args = cli_args[1:]

    # Without "--smoke" this is the same argv unittest would read by default.
    unittest.main(argv=[sys.argv[0]] + cli_args, verbosity=2)
|