from gym import spaces
from gym.envs.registration import EnvSpec
import gym
import pickle
import unittest

import ray
from ray.rllib.agents.a3c import A2CTrainer
from ray.rllib.agents.pg import PGTrainer
from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy
from ray.rllib.env import MultiAgentEnv
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.models import ModelCatalog
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.rollout import rollout
from ray.rllib.tests.test_external_env import SimpleServing
from ray.tune.registry import register_env
from ray.rllib.utils import try_import_tf, try_import_torch

tf = try_import_tf()
_, nn = try_import_torch()

DICT_SPACE = spaces.Dict({
    "sensors": spaces.Dict({
        "position": spaces.Box(low=-100, high=100, shape=(3, )),
        "velocity": spaces.Box(low=-1, high=1, shape=(3, )),
        "front_cam": spaces.Tuple(
            (spaces.Box(low=0, high=1, shape=(10, 10, 3)),
             spaces.Box(low=0, high=1, shape=(10, 10, 3)))),
        "rear_cam": spaces.Box(low=0, high=1, shape=(10, 10, 3)),
    }),
    "inner_state": spaces.Dict({
        "charge": spaces.Discrete(100),
        "job_status": spaces.Dict({
            "task": spaces.Discrete(5),
            "progress": spaces.Box(low=0, high=100, shape=()),
        })
    })
})

DICT_SAMPLES = [DICT_SPACE.sample() for _ in range(10)]

TUPLE_SPACE = spaces.Tuple([
    spaces.Box(low=-100, high=100, shape=(3, )),
    spaces.Tuple((spaces.Box(low=0, high=1, shape=(10, 10, 3)),
                  spaces.Box(low=0, high=1, shape=(10, 10, 3)))),
    spaces.Discrete(5),
])

TUPLE_SAMPLES = [TUPLE_SPACE.sample() for _ in range(10)]


def one_hot(i, n):
    out = [0.0] * n
    out[i] = 1.0
    return out


class NestedDictEnv(gym.Env):
    def __init__(self):
        self.action_space = spaces.Discrete(2)
        self.observation_space = DICT_SPACE
        self._spec = EnvSpec("NestedDictEnv-v0")
        self.steps = 0

    def reset(self):
        self.steps = 0
        return DICT_SAMPLES[0]

    def step(self, action):
        self.steps += 1
        return DICT_SAMPLES[self.steps], 1, self.steps >= 5, {}


class NestedTupleEnv(gym.Env):
    def __init__(self):
        self.action_space = spaces.Discrete(2)
        self.observation_space = TUPLE_SPACE
        self._spec = EnvSpec("NestedTupleEnv-v0")
        self.steps = 0

    def reset(self):
        self.steps = 0
        return TUPLE_SAMPLES[0]

    def step(self, action):
        self.steps += 1
        return TUPLE_SAMPLES[self.steps], 1, self.steps >= 5, {}


class NestedMultiAgentEnv(MultiAgentEnv):
    def __init__(self):
        self.steps = 0

    def reset(self):
        return {
            "dict_agent": DICT_SAMPLES[0],
            "tuple_agent": TUPLE_SAMPLES[0],
        }

    def step(self, actions):
        self.steps += 1
        obs = {
            "dict_agent": DICT_SAMPLES[self.steps],
            "tuple_agent": TUPLE_SAMPLES[self.steps],
        }
        rew = {
            "dict_agent": 0,
            "tuple_agent": 0,
        }
        dones = {"__all__": self.steps >= 5}
        infos = {
            "dict_agent": {},
            "tuple_agent": {},
        }
        return obs, rew, dones, infos


class InvalidModel(TorchModelV2):
    def forward(self, input_dict, state, seq_lens):
        return "not", "valid"


class InvalidModel2(TFModelV2):
    def forward(self, input_dict, state, seq_lens):
        return tf.constant(0), tf.constant(0)

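# The following "spy" models write the nested observations they receive in
# forward() to Ray's internal KV store, so that the test cases further below
# can check that RLlib reconstructed the Dict/Tuple observations correctly.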
class TorchSpyModel(TorchModelV2, nn.Module):
    capture_index = 0

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        self.fc = FullyConnectedNetwork(
            obs_space.original_space.spaces["sensors"].spaces["position"],
            action_space, num_outputs, model_config, name)

    def forward(self, input_dict, state, seq_lens):
        pos = input_dict["obs"]["sensors"]["position"].numpy()
        front_cam = input_dict["obs"]["sensors"]["front_cam"][0].numpy()
        task = input_dict["obs"]["inner_state"]["job_status"]["task"].numpy()
        ray.experimental.internal_kv._internal_kv_put(
            "torch_spy_in_{}".format(TorchSpyModel.capture_index),
            pickle.dumps((pos, front_cam, task)),
            overwrite=True)
        TorchSpyModel.capture_index += 1
        return self.fc({
            "obs": input_dict["obs"]["sensors"]["position"]
        }, state, seq_lens)

    def value_function(self):
        return self.fc.value_function()


class DictSpyModel(TFModelV2):
    capture_index = 0

    def forward(self, input_dict, state, seq_lens):
        def spy(pos, front_cam, task):
            # TF runs this function in an isolated context, so we have to use
            # redis to communicate back to our suite
            ray.experimental.internal_kv._internal_kv_put(
                "d_spy_in_{}".format(DictSpyModel.capture_index),
                pickle.dumps((pos, front_cam, task)),
                overwrite=True)
            DictSpyModel.capture_index += 1
            return 0

        spy_fn = tf.py_func(
            spy, [
                input_dict["obs"]["sensors"]["position"],
                input_dict["obs"]["sensors"]["front_cam"][0],
                input_dict["obs"]["inner_state"]["job_status"]["task"]
            ],
            tf.int64,
            stateful=True)

        with tf.control_dependencies([spy_fn]):
            output = tf.layers.dense(input_dict["obs"]["sensors"]["position"],
                                     self.num_outputs)
        return output, []


class TupleSpyModel(TFModelV2):
    capture_index = 0

    def forward(self, input_dict, state, seq_lens):
        def spy(pos, cam, task):
            # TF runs this function in an isolated context, so we have to use
            # redis to communicate back to our suite
            ray.experimental.internal_kv._internal_kv_put(
                "t_spy_in_{}".format(TupleSpyModel.capture_index),
                pickle.dumps((pos, cam, task)),
                overwrite=True)
            TupleSpyModel.capture_index += 1
            return 0

        spy_fn = tf.py_func(
            spy, [
                input_dict["obs"][0],
                input_dict["obs"][1][0],
                input_dict["obs"][2],
            ],
            tf.int64,
            stateful=True)

        with tf.control_dependencies([spy_fn]):
            output = tf.layers.dense(input_dict["obs"][0], self.num_outputs)
        return output, []


class NestedSpacesTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        ray.init(num_cpus=5)

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def test_invalid_model(self):
        ModelCatalog.register_custom_model("invalid", InvalidModel)
        self.assertRaisesRegexp(
            ValueError, "optimizer got an empty parameter list",
            lambda: PGTrainer(
                env="CartPole-v0",
                config={
                    "model": {
                        "custom_model": "invalid",
                    },
                    "framework": "torch",
                }))

    def test_invalid_model2(self):
        ModelCatalog.register_custom_model("invalid2", InvalidModel2)
        self.assertRaisesRegexp(
            ValueError, "Expected output shape of",
            lambda: PGTrainer(
                env="CartPole-v0",
                config={
                    "model": {
                        "custom_model": "invalid2",
                    },
                    "framework": "tf",
                }))

    def do_test_nested_dict(self, make_env, test_lstm=False):
        ModelCatalog.register_custom_model("composite", DictSpyModel)
        register_env("nested", make_env)
        pg = PGTrainer(
            env="nested",
            config={
                "num_workers": 0,
                "rollout_fragment_length": 5,
                "train_batch_size": 5,
                "model": {
                    "custom_model": "composite",
                    "use_lstm": test_lstm,
                },
                "framework": "tf",
            })
        pg.train()

        # Check that the model sees the correct reconstructed observations
        for i in range(4):
            seen = pickle.loads(
                ray.experimental.internal_kv._internal_kv_get(
                    "d_spy_in_{}".format(i)))
            pos_i = DICT_SAMPLES[i]["sensors"]["position"].tolist()
            cam_i = DICT_SAMPLES[i]["sensors"]["front_cam"][0].tolist()
            task_i = one_hot(
                DICT_SAMPLES[i]["inner_state"]["job_status"]["task"], 5)
            self.assertEqual(seen[0][0].tolist(), pos_i)
            self.assertEqual(seen[1][0].tolist(), cam_i)
            self.assertEqual(seen[2][0].tolist(), task_i)

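    # Same check as do_test_nested_dict above, but for the nested Tuple space.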
    def do_test_nested_tuple(self, make_env):
        ModelCatalog.register_custom_model("composite2", TupleSpyModel)
        register_env("nested2", make_env)
        pg = PGTrainer(
            env="nested2",
            config={
                "num_workers": 0,
                "rollout_fragment_length": 5,
                "train_batch_size": 5,
                "model": {
                    "custom_model": "composite2",
                },
                "framework": "tf",
            })
        pg.train()

        # Check that the model sees the correct reconstructed observations
        for i in range(4):
            seen = pickle.loads(
                ray.experimental.internal_kv._internal_kv_get(
                    "t_spy_in_{}".format(i)))
            pos_i = TUPLE_SAMPLES[i][0].tolist()
            cam_i = TUPLE_SAMPLES[i][1][0].tolist()
            task_i = one_hot(TUPLE_SAMPLES[i][2], 5)
            self.assertEqual(seen[0][0].tolist(), pos_i)
            self.assertEqual(seen[1][0].tolist(), cam_i)
            self.assertEqual(seen[2][0].tolist(), task_i)

    def test_nested_dict_gym(self):
        self.do_test_nested_dict(lambda _: NestedDictEnv())

    def test_nested_dict_gym_lstm(self):
        self.do_test_nested_dict(lambda _: NestedDictEnv(), test_lstm=True)

    def test_nested_dict_vector(self):
        self.do_test_nested_dict(
            lambda _: VectorEnv.wrap(lambda i: NestedDictEnv()))

    def test_nested_dict_serving(self):
        self.do_test_nested_dict(lambda _: SimpleServing(NestedDictEnv()))

    def test_nested_dict_async(self):
        self.do_test_nested_dict(
            lambda _: BaseEnv.to_base_env(NestedDictEnv()))

    def test_nested_tuple_gym(self):
        self.do_test_nested_tuple(lambda _: NestedTupleEnv())

    def test_nested_tuple_vector(self):
        self.do_test_nested_tuple(
            lambda _: VectorEnv.wrap(lambda i: NestedTupleEnv()))

    def test_nested_tuple_serving(self):
        self.do_test_nested_tuple(lambda _: SimpleServing(NestedTupleEnv()))

    def test_nested_tuple_async(self):
        self.do_test_nested_tuple(
            lambda _: BaseEnv.to_base_env(NestedTupleEnv()))

    def test_multi_agent_complex_spaces(self):
        ModelCatalog.register_custom_model("dict_spy", DictSpyModel)
        ModelCatalog.register_custom_model("tuple_spy", TupleSpyModel)
        register_env("nested_ma", lambda _: NestedMultiAgentEnv())
        act_space = spaces.Discrete(2)
        pg = PGTrainer(
            env="nested_ma",
            config={
                "num_workers": 0,
                "rollout_fragment_length": 5,
                "train_batch_size": 5,
                "multiagent": {
                    "policies": {
                        "tuple_policy": (
                            PGTFPolicy, TUPLE_SPACE, act_space,
                            {"model": {"custom_model": "tuple_spy"}}),
                        "dict_policy": (
                            PGTFPolicy, DICT_SPACE, act_space,
                            {"model": {"custom_model": "dict_spy"}}),
                    },
                    "policy_mapping_fn": lambda a: {
                        "tuple_agent": "tuple_policy",
                        "dict_agent": "dict_policy"}[a],
                },
                "framework": "tf",
            })
        pg.train()

        for i in range(4):
            seen = pickle.loads(
                ray.experimental.internal_kv._internal_kv_get(
                    "d_spy_in_{}".format(i)))
            pos_i = DICT_SAMPLES[i]["sensors"]["position"].tolist()
            cam_i = DICT_SAMPLES[i]["sensors"]["front_cam"][0].tolist()
            task_i = one_hot(
                DICT_SAMPLES[i]["inner_state"]["job_status"]["task"], 5)
            self.assertEqual(seen[0][0].tolist(), pos_i)
            self.assertEqual(seen[1][0].tolist(), cam_i)
            self.assertEqual(seen[2][0].tolist(), task_i)

        for i in range(4):
            seen = pickle.loads(
                ray.experimental.internal_kv._internal_kv_get(
                    "t_spy_in_{}".format(i)))
            pos_i = TUPLE_SAMPLES[i][0].tolist()
            cam_i = TUPLE_SAMPLES[i][1][0].tolist()
            task_i = one_hot(TUPLE_SAMPLES[i][2], 5)
            self.assertEqual(seen[0][0].tolist(), pos_i)
            self.assertEqual(seen[1][0].tolist(), cam_i)
            self.assertEqual(seen[2][0].tolist(), task_i)

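    # Saving, restoring, and rolling out a trainer should also work with
    # nested observation spaces.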
    def test_rollout_dict_space(self):
        register_env("nested", lambda _: NestedDictEnv())

        agent = PGTrainer(env="nested", config={"framework": "tf"})
        agent.train()
        path = agent.save()
        agent.stop()

        # Test train works on restore
        agent2 = PGTrainer(env="nested", config={"framework": "tf"})
        agent2.restore(path)
        agent2.train()

        # Test rollout works on restore
        rollout(agent2, "nested", 100)

    def test_py_torch_model(self):
        ModelCatalog.register_custom_model("composite", TorchSpyModel)
        register_env("nested", lambda _: NestedDictEnv())
        a2c = A2CTrainer(
            env="nested",
            config={
                "num_workers": 0,
                "rollout_fragment_length": 5,
                "train_batch_size": 5,
                "model": {
                    "custom_model": "composite",
                },
                "framework": "torch",
            })
        a2c.train()

        # Check that the model sees the correct reconstructed observations
        for i in range(4):
            seen = pickle.loads(
                ray.experimental.internal_kv._internal_kv_get(
                    "torch_spy_in_{}".format(i)))
            pos_i = DICT_SAMPLES[i]["sensors"]["position"].tolist()
            cam_i = DICT_SAMPLES[i]["sensors"]["front_cam"][0].tolist()
            task_i = one_hot(
                DICT_SAMPLES[i]["inner_state"]["job_status"]["task"], 5)
            self.assertEqual(seen[0][0].tolist(), pos_i)
            self.assertEqual(seen[1][0].tolist(), cam_i)
            self.assertEqual(seen[2][0].tolist(), task_i)


if __name__ == "__main__":
    import pytest
    import sys
    sys.exit(pytest.main(["-v", __file__]))