ray/rllib/utils/pre_checks.py

150 lines
6.7 KiB
Python

"""Common pre-checks for all RLlib experiments."""
import logging
import gym
from ray.rllib.utils.typing import EnvType
logger = logging.getLogger(__name__)
def check_env(env: EnvType) -> None:
"""Run pre-checks on env that uncover common errors in environments.
Args:
env: Environment to be checked.
Raises:
ValueError: If env is not an instance of SUPPORTED_ENVIRONMENT_TYPES.
ValueError: See check_gym_env docstring for details.
"""
from ray.rllib.env import BaseEnv, MultiAgentEnv, RemoteBaseEnv, \
VectorEnv
if not isinstance(
env, (BaseEnv, gym.Env, MultiAgentEnv, RemoteBaseEnv, VectorEnv)):
raise ValueError(
"Env must be one of the supported types: BaseEnv, gym.Env, "
"MultiAgentEnv, VectorEnv, RemoteVectorEnv")
if isinstance(env, gym.Env):
check_gym_environments(env)
def check_gym_environments(env: gym.Env) -> None:
"""Checking for common errors in gym environments.
Args:
env: Environment to be checked.
Warning:
If env has no attribute spec with a sub attribute,
max_episode_steps.
Raises:
AttributeError: If env has no observation space.
AttributeError: If env has no action space.
ValueError: Observation space must be a gym.spaces.Space.
ValueError: Action space must be a gym.spaces.Space.
ValueError: Observation sampled from observation space must be
contained in the observation space.
ValueError: Action sampled from action space must be
contained in the observation space.
ValueError: If env cannot be resetted.
ValueError: If an observation collected from a call to env.reset().
is not contained in the observation_space.
ValueError: If env cannot be stepped via a call to env.step().
ValueError: If the observation collected from env.step() is not
contained in the observation_space.
AssertionError: If env.step() returns a reward that is not an
int or float.
AssertionError: IF env.step() returns a done that is not a bool.
AssertionError: If env.step() returns an env_info that is not a dict.
"""
# check that env has observation and action spaces
if not hasattr(env, "observation_space"):
raise AttributeError("Env must have observation_space.")
if not hasattr(env, "action_space"):
raise AttributeError("Env must have action_space.")
# check that observation and action spaces are gym.spaces
if not isinstance(env.observation_space, gym.spaces.Space):
raise ValueError("Observation space must be a gym.space")
if not isinstance(env.action_space, gym.spaces.Space):
raise ValueError("Action space must be a gym.space")
# raise a warning if there isn't a max_episode_steps attribute
if not hasattr(env, "spec") or not hasattr(env.spec, "max_episode_steps"):
logger.warning("Your env doesn't have a .spec.max_episode_steps "
"attribute. This is fine if you have set 'horizon' "
"in your config dictionary, or `soft_horizon`. "
"However, if you haven't, 'horizon' will default "
"to infinity, and your environment will not be "
"reset.")
# check if sampled actions and observations are contained within their
# respective action and observation spaces.
def contains_error(action_or_observation, sample, space):
string_type = "observation" if not action_or_observation else \
"action"
sample_type = get_type(sample)
_space_type = space.dtype
ret = (f"A sampled {string_type} from your env wasn't contained "
f"within your env's {string_type} space. Its possible that "
f"there was a type mismatch, or that one of the "
f"sub-{string_type} was out of bounds:\n\nsampled_obs: "
f"{sample}\nenv.{string_type}_space: {space}"
f"\nsampled_obs's dtype: {sample_type}"
f"\nenv.{sample_type}'s dtype: {_space_type}")
return ret
def get_type(var):
return var.dtype if hasattr(var, "dtype") else type(var)
sampled_action = env.action_space.sample()
sampled_observation = env.observation_space.sample()
if not env.observation_space.contains(sampled_observation):
raise ValueError(
contains_error(False, sampled_observation, env.observation_space))
if not env.action_space.contains(sampled_action):
raise ValueError(
contains_error(True, sampled_action, env.action_space))
# check if observation generated from stepping the environment is
# contained within the observation space
reset_obs = env.reset()
if not env.observation_space.contains(reset_obs):
reset_obs_type = get_type(reset_obs)
space_type = env.observation_space.dtype
error = (
f"The observation collected from env.reset() was not "
f"contained within your env's observation space. Its possible "
f"that There was a type mismatch, or that one of the "
f"sub-observations was out of bounds: \n\n reset_obs: "
f"{reset_obs}\n\n env.observation_space: "
f"{env.observation_space}\n\n reset_obs's dtype: "
f"{reset_obs_type}\n\n env.observation_space's dtype: "
f"{space_type}")
raise ValueError(error)
# check if env.step can run, and generates observations rewards, done
# signals and infos that are within their respective spaces and are of
# the correct dtypes
next_obs, reward, done, info = env.step(sampled_action)
if not env.observation_space.contains(next_obs):
next_obs_type = get_type(next_obs)
space_type = env.observation_space.dtype
error = (
f"The observation collected from env.step(sampled_action) was "
f"not contained within your env's observation space. Its "
f"possible that There was a type mismatch, or that one of the "
f"sub-observations was out of bounds:\n\n next_obs: {next_obs}"
f"\n\n env.observation_space: {env.observation_space}"
f"\n\n next_obs's dtype: {next_obs_type}"
f"\n\n env.observation_space's dtype: {space_type}")
raise ValueError(error)
assert isinstance(reward, (float, int)), \
"Your step function must return a reward that is integer or float."
assert isinstance(
done, bool), "Your step function must return a done that is a boolean."
assert isinstance(
info, dict), "Your step function must return a info that is a dict."