"""Tests for the environment pre-checks in ``ray.rllib.utils.pre_checks.env``.

Each test deliberately breaks one aspect of an environment (spaces, ``reset``,
``step``/``poll``, sampling or ``contains`` helpers) and asserts that the
corresponding check function raises the expected error.
"""
import logging
import sys
import unittest
from unittest.mock import Mock, MagicMock

import numpy as np
import pytest

from ray.rllib.env.multi_agent_env import make_multi_agent, MultiAgentEnvWrapper
from ray.rllib.examples.env.random_env import RandomEnv
from ray.rllib.utils.pre_checks.env import (
    check_env,
    check_gym_environments,
    check_multiagent_environments,
    check_base_env,
)


class TestGymCheckEnv(unittest.TestCase):
    """Checks run against a ``RandomEnv`` (or a bare ``Mock``) instance."""

    @pytest.fixture(autouse=True)
    def inject_fixtures(self, caplog):
        """Raise the captured log level to CRITICAL for every test."""
        caplog.set_level(logging.CRITICAL)

    def test_has_observation_and_action_space(self):
        """An env lacking either space attribute raises AttributeError."""
        # `spec=[]` yields a mock with no attributes at all.
        env = Mock(spec=[])
        with pytest.raises(AttributeError, match="Env must have observation_space."):
            check_gym_environments(env)
        env = Mock(spec=["observation_space"])
        with pytest.raises(AttributeError, match="Env must have action_space."):
            check_gym_environments(env)
        del env

    def test_obs_and_action_spaces_are_gym_spaces(self):
        """Spaces that are plain strings are rejected with ValueError."""
        env = RandomEnv()
        observation_space = env.observation_space
        env.observation_space = "not a gym space"
        with pytest.raises(ValueError, match="Observation space must be a gym.space"):
            check_env(env)
        # Restore the valid observation space so only the action space is bad.
        env.observation_space = observation_space
        env.action_space = "not an action space"
        with pytest.raises(ValueError, match="Action space must be a gym.space"):
            check_env(env)
        del env

    def test_sampled_observation_contained(self):
        """A sampled observation outside the observation space is rejected."""
        env = RandomEnv()
        # check for observation that is out of bounds
        error = ".*A sampled observation from your env wasn't contained .*"
        env.observation_space.sample = MagicMock(return_value=5)
        with pytest.raises(ValueError, match=error):
            check_env(env)
        # check for observation that is in bounds, but the wrong type
        env.observation_space.sample = MagicMock(return_value=float(1))
        with pytest.raises(ValueError, match=error):
            check_env(env)
        del env

    def test_sampled_action_contained(self):
        """A sampled action outside the action space is rejected."""
        env = RandomEnv()
        # check for action that is out of bounds
        error = ".*A sampled action from your env wasn't contained .*"
        env.action_space.sample = MagicMock(return_value=5)
        with pytest.raises(ValueError, match=error):
            check_env(env)
        # check for action that is in bounds, but the wrong type
        env.action_space.sample = MagicMock(return_value=float(1))
        with pytest.raises(ValueError, match=error):
            check_env(env)
        del env

    def test_reset(self):
        """Observations returned by ``reset`` must lie in the observation space."""
        reset = MagicMock(return_value=5)
        env = RandomEnv()
        env.reset = reset
        # check reset with out of bounds fails
        # NOTE: `match` is a regex, so the `()` below is an empty group rather
        # than literal parentheses; the pattern still matches the message.
        error = ".*The observation collected from env.reset().*"
        with pytest.raises(ValueError, match=error):
            check_env(env)
        # check reset with obs of incorrect type fails
        reset = MagicMock(return_value=float(1))
        env.reset = reset
        with pytest.raises(ValueError, match=error):
            check_env(env)
        del env

    def test_step(self):
        """Each element of the tuple returned by ``step`` is validated."""
        # check step that returns an out of bounds obs fails
        step = MagicMock(return_value=(5, 5, True, {}))
        env = RandomEnv()
        env.step = step
        error = ".*The observation collected from env.step.*"
        with pytest.raises(ValueError, match=error):
            check_env(env)

        # check step that returns obs of incorrect type fails
        step = MagicMock(return_value=(float(1), 5, True, {}))
        env.step = step
        with pytest.raises(ValueError, match=error):
            check_env(env)

        # check step that returns reward of non float/int fails
        step = MagicMock(return_value=(1, "Not a valid reward", True, {}))
        env.step = step
        error = "Your step function must return a reward that is integer or " "float."
        with pytest.raises(AssertionError, match=error):
            check_env(env)

        # check step that returns a non bool fails
        step = MagicMock(return_value=(1, float(5), "not a valid done signal", {}))
        env.step = step
        error = "Your step function must return a done that is a boolean."
        with pytest.raises(AssertionError, match=error):
            check_env(env)

        # check step that returns a non dict fails
        step = MagicMock(return_value=(1, float(5), True, "not a valid env info"))
        env.step = step
        error = "Your step function must return a info that is a dict."
        with pytest.raises(AssertionError, match=error):
            check_env(env)
        del env


class TestCheckMultiAgentEnv(unittest.TestCase):
    """Checks run against a 2-agent env built with ``make_multi_agent``."""

    @pytest.fixture(autouse=True)
    def inject_fixtures(self, caplog):
        """Raise the captured log level to CRITICAL for every test."""
        caplog.set_level(logging.CRITICAL)

    def test_check_env_not_correct_type_error(self):
        """Passing a ``RandomEnv`` to the multi-agent check raises ValueError."""
        env = RandomEnv()
        with pytest.raises(ValueError, match="The passed env is not"):
            check_multiagent_environments(env)
        del env

    def test_check_env_reset_incorrect_error(self):
        """``reset`` must return a per-agent dict of in-bounds observations."""
        reset = MagicMock(return_value=5)
        env = make_multi_agent("CartPole-v1")({"num_agents": 2})
        env.reset = reset
        with pytest.raises(ValueError, match="The element returned by reset"):
            check_env(env)
        bad_obs = {
            0: np.array([np.inf, np.inf, np.inf, np.inf]),
            1: np.array([np.inf, np.inf, np.inf, np.inf]),
        }
        env.reset = lambda *_: bad_obs
        with pytest.raises(ValueError, match="The observation collected from " "env"):
            check_env(env)
        del env

    def test_check_incorrect_space_contains_functions_error(self):
        """Errors raised inside the ``*_space_contains`` helpers are reported."""

        def bad_contains_function(*_args, **_kwargs):
            # Assigned as an instance attribute below, so it is NOT bound to
            # the env; accept any arguments so the intended ValueError is what
            # gets raised, rather than a TypeError from an arity mismatch.
            raise ValueError("This is a bad contains function")

        env = make_multi_agent("CartPole-v1")({"num_agents": 2})
        env.observation_space_contains = bad_contains_function
        with pytest.raises(
            ValueError, match="Your observation_space_contains " "function has some"
        ):
            check_env(env)
        del env

        env = make_multi_agent("CartPole-v1")({"num_agents": 2})
        bad_action = {0: 2, 1: 2}
        env.action_space_sample = lambda *_: bad_action
        with pytest.raises(
            ValueError, match="The action collected from " "action_space_sample"
        ):
            check_env(env)

        env.action_space_contains = bad_contains_function
        with pytest.raises(
            ValueError, match="Your action_space_contains " "function has some error"
        ):
            check_env(env)

    def test_check_env_step_incorrect_error(self):
        """Each element of the tuple returned by ``step`` is validated."""
        step = MagicMock(return_value=(5, 5, True, {}))
        env = make_multi_agent("CartPole-v1")({"num_agents": 2})
        # Grab a real observation dict before `step` is replaced by mocks.
        sampled_obs = env.reset()
        env.step = step
        with pytest.raises(ValueError, match="The element returned by step"):
            check_env(env)

        step = MagicMock(return_value=(sampled_obs, "Not a reward", True, {}))
        env.step = step
        with pytest.raises(
            AssertionError, match="Your step function must " "return a reward "
        ):
            check_env(env)

        step = MagicMock(return_value=(sampled_obs, 5, "Not a bool", {}))
        env.step = step
        with pytest.raises(
            AssertionError, match="Your step function must " "return a done"
        ):
            check_env(env)

        step = MagicMock(return_value=(sampled_obs, 5, False, "Not a Dict"))
        env.step = step
        with pytest.raises(
            AssertionError, match="Your step function must " "return a info"
        ):
            check_env(env)

    def test_bad_sample_function(self):
        """Samples from overridden ``*_space_sample`` helpers are validated."""
        env = make_multi_agent("CartPole-v1")({"num_agents": 2})
        bad_action = {0: 2, 1: 2}
        env.action_space_sample = lambda *_: bad_action
        with pytest.raises(
            ValueError, match="The action collected from " "action_space_sample"
        ):
            check_env(env)
        del env

        env = make_multi_agent("CartPole-v1")({"num_agents": 2})
        bad_obs = {
            0: np.array([np.inf, np.inf, np.inf, np.inf]),
            1: np.array([np.inf, np.inf, np.inf, np.inf]),
        }
        env.observation_space_sample = lambda *_: bad_obs
        with pytest.raises(
            ValueError,
            match="The observation collected from " "observation_space_sample",
        ):
            check_env(env)


class TestCheckBaseEnv:
    """Checks run against a ``MultiAgentEnvWrapper`` holding two sub-envs."""

    def _make_base_env(self):
        """Return a ``MultiAgentEnvWrapper`` over two 2-agent CartPole-v1 envs."""
        # `self` is not needed; it is dropped explicitly.
        del self
        num_envs = 2
        sub_envs = [
            make_multi_agent("CartPole-v1")({"num_agents": 2}) for _ in range(num_envs)
        ]
        env = MultiAgentEnvWrapper(None, sub_envs, 2)
        return env

    def test_check_env_not_correct_type_error(self):
        """Passing a ``RandomEnv`` to the base-env check raises ValueError."""
        env = RandomEnv()
        with pytest.raises(ValueError, match="The passed env is not"):
            check_base_env(env)
        del env

    def test_check_env_reset_incorrect_error(self):
        """``try_reset`` output is validated for type, env ids, agent ids, bounds."""
        # Not a dict at all.
        reset = MagicMock(return_value=5)
        env = self._make_base_env()
        env.try_reset = reset
        with pytest.raises(
            ValueError, match=("MultiEnvDict. Instead, it is of" " type")
        ):
            check_env(env)

        obs_with_bad_agent_ids = {
            2: np.array([np.inf, np.inf, np.inf, np.inf]),
            1: np.array([np.inf, np.inf, np.inf, np.inf]),
        }
        # Outer key is not a valid env id.
        obs_with_bad_env_ids = {"bad_env_id": obs_with_bad_agent_ids}
        reset = MagicMock(return_value=obs_with_bad_env_ids)
        env.try_reset = reset
        with pytest.raises(
            ValueError, match="has dict keys that don't " "correspond to"
        ):
            check_env(env)

        # Valid env id, but the inner dict uses an unknown agent id (2).
        reset = MagicMock(return_value={0: obs_with_bad_agent_ids})
        env.try_reset = reset
        with pytest.raises(
            ValueError,
            match="The element returned by "
            "try_reset has agent_ids that are"
            " not the names of the agents",
        ):
            check_env(env)

        # Valid env and agent ids, but the observations themselves are bad.
        out_of_bounds_obs = {
            0: {
                0: np.array([np.inf, np.inf, np.inf, np.inf]),
                1: np.array([np.inf, np.inf, np.inf, np.inf]),
            }
        }
        env.try_reset = lambda *_: out_of_bounds_obs
        with pytest.raises(
            ValueError, match="The observation collected from " "try_reset"
        ):
            check_env(env)
        del env

    def test_check_space_contains_functions_errors(self):
        """Errors raised inside the ``*_space_contains`` helpers are reported."""

        def bad_contains_function(*_args, **_kwargs):
            # Assigned as an instance attribute below, so it is NOT bound to
            # the env; accept any arguments so the intended ValueError is what
            # gets raised, rather than a TypeError from an arity mismatch.
            raise ValueError("This is a bad contains function")

        env = self._make_base_env()
        env.observation_space_contains = bad_contains_function
        with pytest.raises(
            ValueError, match="Your observation_space_contains " "function has some"
        ):
            check_env(env)
        del env

        env = self._make_base_env()
        env.action_space_contains = bad_contains_function
        with pytest.raises(
            ValueError, match="Your action_space_contains " "function has some error"
        ):
            check_env(env)
        del env

    def test_bad_sample_function(self):
        """Samples from overridden ``*_space_sample`` helpers are validated."""
        env = self._make_base_env()
        bad_action = {0: {0: 2, 1: 2}}
        env.action_space_sample = lambda *_: bad_action
        with pytest.raises(
            ValueError, match="The action collected from " "action_space_sample"
        ):
            check_env(env)
        del env

        env = self._make_base_env()
        bad_obs = {
            0: {
                0: np.array([np.inf, np.inf, np.inf, np.inf]),
                1: np.array([np.inf, np.inf, np.inf, np.inf]),
            }
        }
        env.observation_space_sample = lambda *_: bad_obs
        with pytest.raises(
            ValueError,
            match="The observation collected from " "observation_space_sample",
        ):
            check_env(env)

    def test_check_env_step_incorrect_error(self):
        """Each element of the tuple returned by ``poll`` is validated."""
        good_reward = {0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}}
        good_done = {0: {0: False, 1: False}, 1: {0: False, 1: False}}
        good_info = {0: {0: {}, 1: {}}, 1: {0: {}, 1: {}}}

        env = self._make_base_env()
        # The value for env id 0 is an int instead of a per-agent dict.
        bad_multi_env_dict_obs = {0: 1, 1: {0: np.zeros(4)}}
        poll = MagicMock(
            return_value=(bad_multi_env_dict_obs, good_reward, good_done, good_info, {})
        )
        env.poll = poll
        with pytest.raises(
            ValueError,
            match="The element returned by step, "
            "next_obs has values that are not"
            " MultiAgentDicts",
        ):
            check_env(env)

        bad_reward = {0: {0: "not_reward", 1: 1}}
        good_obs = env.observation_space_sample()
        poll = MagicMock(return_value=(good_obs, bad_reward, good_done, good_info, {}))
        env.poll = poll
        with pytest.raises(
            AssertionError, match="Your step function must " "return a rewards that are"
        ):
            check_env(env)

        bad_done = {0: {0: "not_done", 1: False}}
        poll = MagicMock(return_value=(good_obs, good_reward, bad_done, good_info, {}))
        env.poll = poll
        with pytest.raises(
            AssertionError,
            match="Your step function must " "return a done that is " "boolean.",
        ):
            check_env(env)

        bad_info = {0: {0: "not_info", 1: {}}}
        poll = MagicMock(return_value=(good_obs, good_reward, good_done, bad_info, {}))
        env.poll = poll
        with pytest.raises(
            AssertionError,
            match="Your step function must" " return a info that is a " "dict.",
        ):
            check_env(env)

    def test_check_correct_env(self):
        """An unmodified base env passes ``check_env`` without raising."""
        env = self._make_base_env()
        check_env(env)


if __name__ == "__main__":
    # Run only this file and propagate pytest's exit status to the shell.
    sys.exit(pytest.main(["-v", __file__]))