import pytest from ray.rllib.env.multi_agent_env import make_multi_agent from ray.rllib.tests.test_nested_observation_spaces import NestedMultiAgentEnv class TestMultiAgentEnv: def test_space_in_preferred_format(self): env = NestedMultiAgentEnv() spaces_in_preferred_format = env._check_if_space_maps_agent_id_to_sub_space() assert spaces_in_preferred_format, "Space is not in preferred " "format" env2 = make_multi_agent("CartPole-v1")() spaces_in_preferred_format = env2._check_if_space_maps_agent_id_to_sub_space() assert not spaces_in_preferred_format, ( "Space should not be in " "preferred format but is." ) def test_spaces_sample_contain_in_preferred_format(self): env = NestedMultiAgentEnv() # this environment has spaces that are in the preferred format # for multi-agent environments where the spaces are dict spaces # mapping agent-ids to sub-spaces obs = env.observation_space_sample() assert env.observation_space_contains(obs), ( "Observation space does " "not contain obs" ) action = env.action_space_sample() assert env.action_space_contains(action), ( "Action space does " "not contain action" ) def test_spaces_sample_contain_not_in_preferred_format(self): env = make_multi_agent("CartPole-v1")({"num_agents": 2}) # this environment has spaces that are not in the preferred format # for multi-agent environments where the spaces not in the preferred # format, users must override the observation_space_contains, # action_space_contains observation_space_sample, # and action_space_sample methods in order to do proper checks obs = env.observation_space_sample() assert env.observation_space_contains(obs), ( "Observation space does " "not contain obs" ) action = env.action_space_sample() assert env.action_space_contains(action), ( "Action space does " "not contain action" ) if __name__ == "__main__": import sys sys.exit(pytest.main(["-v", __file__]))