ray/rllib/utils/tests/test_check_env.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

335 lines
12 KiB
Python
Raw Normal View History

import logging
import gym
import numpy as np
import pytest
import unittest
from unittest.mock import Mock, MagicMock
from ray.rllib.env.base_env import convert_to_base_env
2022-01-18 07:34:06 -08:00
from ray.rllib.env.multi_agent_env import make_multi_agent, MultiAgentEnvWrapper
from ray.rllib.examples.env.random_env import RandomEnv
from ray.rllib.utils.pre_checks.env import (
check_env,
check_gym_environments,
2022-01-18 07:34:06 -08:00
check_multiagent_environments,
check_base_env,
)
class TestGymCheckEnv(unittest.TestCase):
@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
caplog.set_level(logging.CRITICAL)
def test_has_observation_and_action_space(self):
env = Mock(spec=[])
with pytest.raises(AttributeError, match="Env must have observation_space."):
check_gym_environments(env)
env = Mock(spec=["observation_space"])
with pytest.raises(AttributeError, match="Env must have action_space."):
check_gym_environments(env)
def test_obs_and_action_spaces_are_gym_spaces(self):
env = RandomEnv()
observation_space = env.observation_space
env.observation_space = "not a gym space"
with pytest.raises(ValueError, match="Observation space must be a gym.space"):
check_env(env)
env.observation_space = observation_space
env.action_space = "not an action space"
with pytest.raises(ValueError, match="Action space must be a gym.space"):
check_env(env)
def test_reset(self):
reset = MagicMock(return_value=5)
env = RandomEnv()
env.reset = reset
# check reset with out of bounds fails
error = ".*The observation collected from env.reset().*"
with pytest.raises(ValueError, match=error):
check_env(env)
# check reset with obs of incorrect type fails
reset = MagicMock(return_value=float(0.1))
env.reset = reset
with pytest.raises(ValueError, match=error):
check_env(env)
def test_step(self):
step = MagicMock(return_value=(5, 5, True, {}))
env = RandomEnv()
env.step = step
error = ".*The observation collected from env.step.*"
with pytest.raises(ValueError, match=error):
check_env(env)
# check reset that returns obs of incorrect type fails
step = MagicMock(return_value=(float(0.1), 5, True, {}))
env.step = step
with pytest.raises(ValueError, match=error):
check_env(env)
# check step that returns reward of non float/int fails
step = MagicMock(return_value=(1, "Not a valid reward", True, {}))
env.step = step
error = "Your step function must return a reward that is integer or float."
with pytest.raises(ValueError, match=error):
check_env(env)
# check step that returns a non bool fails
step = MagicMock(return_value=(1, float(5), "not a valid done signal", {}))
env.step = step
error = "Your step function must return a done that is a boolean."
with pytest.raises(ValueError, match=error):
check_env(env)
# check step that returns a non dict fails
step = MagicMock(return_value=(1, float(5), True, "not a valid env info"))
env.step = step
error = "Your step function must return a info that is a dict."
with pytest.raises(ValueError, match=error):
check_env(env)
class TestCheckMultiAgentEnv(unittest.TestCase):
@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
caplog.set_level(logging.CRITICAL)
def test_check_env_not_correct_type_error(self):
env = RandomEnv()
with pytest.raises(ValueError, match="The passed env is not"):
check_multiagent_environments(env)
def test_check_env_reset_incorrect_error(self):
reset = MagicMock(return_value=5)
env = make_multi_agent("CartPole-v1")({"num_agents": 2})
env.reset = reset
2022-01-18 07:34:06 -08:00
with pytest.raises(ValueError, match="The element returned by reset"):
check_env(env)
bad_obs = {
0: np.array([np.inf, np.inf, np.inf, np.inf]),
1: np.array([np.inf, np.inf, np.inf, np.inf]),
}
env.reset = lambda *_: bad_obs
with pytest.raises(ValueError, match="The observation collected from env"):
check_env(env)
2022-01-18 07:34:06 -08:00
def test_check_incorrect_space_contains_functions_error(self):
def bad_contains_function(self, x):
raise ValueError("This is a bad contains function")
env = make_multi_agent("CartPole-v1")({"num_agents": 2})
env.observation_space_contains = bad_contains_function
with pytest.raises(
ValueError, match="Your observation_space_contains function has some"
):
check_env(env)
env = make_multi_agent("CartPole-v1")({"num_agents": 2})
bad_action = {0: 2, 1: 2}
env.action_space_sample = lambda *_: bad_action
with pytest.raises(
ValueError, match="The action collected from action_space_sample"
):
check_env(env)
env.action_space_contains = bad_contains_function
with pytest.raises(
ValueError, match="Your action_space_contains function has some error"
):
check_env(env)
def test_check_env_step_incorrect_error(self):
step = MagicMock(return_value=(5, 5, True, {}))
env = make_multi_agent("CartPole-v1")({"num_agents": 2})
sampled_obs = env.reset()
env.step = step
2022-01-18 07:34:06 -08:00
with pytest.raises(ValueError, match="The element returned by step"):
check_env(env)
step = MagicMock(return_value=(sampled_obs, {0: "Not a reward"}, {0: True}, {}))
env.step = step
with pytest.raises(ValueError, match="Your step function must return rewards"):
check_env(env)
step = MagicMock(return_value=(sampled_obs, {0: 5}, {0: "Not a bool"}, {}))
env.step = step
with pytest.raises(ValueError, match="Your step function must return dones"):
check_env(env)
step = MagicMock(
return_value=(sampled_obs, {0: 5}, {0: False}, {0: "Not a Dict"})
)
env.step = step
with pytest.raises(ValueError, match="Your step function must return infos"):
check_env(env)
2022-01-18 07:34:06 -08:00
def test_bad_sample_function(self):
env = make_multi_agent("CartPole-v1")({"num_agents": 2})
bad_action = {0: 2, 1: 2}
env.action_space_sample = lambda *_: bad_action
with pytest.raises(
ValueError, match="The action collected from action_space_sample"
2022-01-18 07:34:06 -08:00
):
check_env(env)
env = make_multi_agent("CartPole-v1")({"num_agents": 2})
bad_obs = {
0: np.array([np.inf, np.inf, np.inf, np.inf]),
1: np.array([np.inf, np.inf, np.inf, np.inf]),
}
env.observation_space_sample = lambda *_: bad_obs
with pytest.raises(
ValueError,
match="The observation collected from observation_space_sample",
2022-01-18 07:34:06 -08:00
):
check_env(env)
class TestCheckBaseEnv:
def _make_base_env(self):
del self
num_envs = 2
sub_envs = [
make_multi_agent("CartPole-v1")({"num_agents": 2}) for _ in range(num_envs)
]
env = MultiAgentEnvWrapper(None, sub_envs, 2)
return env
def test_check_env_not_correct_type_error(self):
env = RandomEnv()
with pytest.raises(ValueError, match="The passed env is not"):
check_base_env(env)
def test_check_env_reset_incorrect_error(self):
reset = MagicMock(return_value=5)
env = self._make_base_env()
env.try_reset = reset
with pytest.raises(ValueError, match=("MultiEnvDict. Instead, it is of type")):
2022-01-18 07:34:06 -08:00
check_env(env)
obs_with_bad_agent_ids = {
2: np.array([np.inf, np.inf, np.inf, np.inf]),
1: np.array([np.inf, np.inf, np.inf, np.inf]),
}
obs_with_bad_env_ids = {"bad_env_id": obs_with_bad_agent_ids}
reset = MagicMock(return_value=obs_with_bad_env_ids)
env.try_reset = reset
with pytest.raises(ValueError, match="has dict keys that don't correspond to"):
2022-01-18 07:34:06 -08:00
check_env(env)
reset = MagicMock(return_value={0: obs_with_bad_agent_ids})
env.try_reset = reset
with pytest.raises(
ValueError,
match="The element returned by "
"try_reset has agent_ids that are"
" not the names of the agents",
):
check_env(env)
out_of_bounds_obs = {
0: {
0: np.array([np.inf, np.inf, np.inf, np.inf]),
1: np.array([np.inf, np.inf, np.inf, np.inf]),
}
}
env.try_reset = lambda *_: out_of_bounds_obs
with pytest.raises(
ValueError, match="The observation collected from try_reset"
2022-01-18 07:34:06 -08:00
):
check_env(env)
def test_check_space_contains_functions_errors(self):
def bad_contains_function(self, x):
raise ValueError("This is a bad contains function")
env = self._make_base_env()
env.observation_space_contains = bad_contains_function
with pytest.raises(
ValueError, match="Your observation_space_contains function has some"
2022-01-18 07:34:06 -08:00
):
check_env(env)
env = self._make_base_env()
env.action_space_contains = bad_contains_function
with pytest.raises(
ValueError, match="Your action_space_contains function has some error"
2022-01-18 07:34:06 -08:00
):
check_env(env)
def test_bad_sample_function(self):
env = self._make_base_env()
bad_action = {0: {0: 2, 1: 2}}
env.action_space_sample = lambda *_: bad_action
with pytest.raises(
ValueError, match="The action collected from action_space_sample"
2022-01-18 07:34:06 -08:00
):
check_env(env)
env = self._make_base_env()
bad_obs = {
0: {
0: np.array([np.inf, np.inf, np.inf, np.inf]),
1: np.array([np.inf, np.inf, np.inf, np.inf]),
}
}
env.observation_space_sample = lambda *_: bad_obs
with pytest.raises(
ValueError,
match="The observation collected from observation_space_sample",
2022-01-18 07:34:06 -08:00
):
check_env(env)
def test_check_env_step_incorrect_error(self):
good_reward = {0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}}
good_done = {0: {0: False, 1: False}, 1: {0: False, 1: False}}
good_info = {0: {0: {}, 1: {}}, 1: {0: {}, 1: {}}}
env = self._make_base_env()
bad_multi_env_dict_obs = {0: 1, 1: {0: np.zeros(4)}}
poll = MagicMock(
return_value=(bad_multi_env_dict_obs, good_reward, good_done, good_info, {})
)
env.poll = poll
with pytest.raises(
ValueError,
match="The element returned by step, "
"next_obs contains values that are not"
2022-01-18 07:34:06 -08:00
" MultiAgentDicts",
):
check_env(env)
bad_reward = {0: {0: "not_reward", 1: 1}}
good_obs = env.observation_space_sample()
poll = MagicMock(return_value=(good_obs, bad_reward, good_done, good_info, {}))
env.poll = poll
with pytest.raises(
ValueError, match="Your step function must return rewards that are"
2022-01-18 07:34:06 -08:00
):
check_env(env)
bad_done = {0: {0: "not_done", 1: False}}
poll = MagicMock(return_value=(good_obs, good_reward, bad_done, good_info, {}))
env.poll = poll
with pytest.raises(
ValueError,
match="Your step function must return dones that are boolean.",
2022-01-18 07:34:06 -08:00
):
check_env(env)
bad_info = {0: {0: "not_info", 1: {}}}
poll = MagicMock(return_value=(good_obs, good_reward, good_done, bad_info, {}))
env.poll = poll
with pytest.raises(
ValueError,
match="Your step function must return infos that are a dict.",
2022-01-18 07:34:06 -08:00
):
check_env(env)
def test_check_correct_env(self):
env = self._make_base_env()
check_env(env)
env = gym.make("CartPole-v0")
env = convert_to_base_env(env)
check_env(env)
2022-01-18 07:34:06 -08:00
if __name__ == "__main__":
pytest.main()