[RLlib] Put env-checker on critical path. (#22191)

Avnish Narayan 2022-02-17 05:06:14 -08:00 committed by GitHub
parent e03606f0b3
commit 740def0a13
25 changed files with 260 additions and 185 deletions


@ -87,6 +87,7 @@ class TestDDPG(unittest.TestCase):
# Test against all frameworks.
for _ in framework_iterator(core_config):
config = core_config.copy()
config["seed"] = 42
# Default OUNoise setup.
trainer = ddpg.DDPGTrainer(config=config, env="Pendulum-v1")
# Setting explore=False should always return the same action.
@ -142,6 +143,7 @@ class TestDDPG(unittest.TestCase):
"""Tests DDPG loss function results across all frameworks."""
config = ddpg.DEFAULT_CONFIG.copy()
# Run locally.
config["seed"] = 42
config["num_workers"] = 0
config["learning_starts"] = 0
config["twin_q"] = True


@ -36,6 +36,7 @@ torch, _ = try_import_torch()
class SimpleEnv(Env):
def __init__(self, config):
self._skip_env_checking = True
if config.get("simplex_actions", False):
self.action_space = Simplex((2,))
else:
@ -168,6 +169,7 @@ class TestSAC(unittest.TestCase):
"""Tests SAC loss function results across all frameworks."""
config = sac.DEFAULT_CONFIG.copy()
# Run locally.
config["seed"] = 42
config["num_workers"] = 0
config["learning_starts"] = 0
config["twin_q"] = False


@ -628,6 +628,9 @@ COMMON_CONFIG: TrainerConfigDict = {
# training iteration.
"_disable_execution_plan_api": False,
# If True, disable the environment pre-checking module.
"disable_env_checking": False,
# === Deprecated keys ===
# Uses the sync samples optimizer instead of the multi-gpu one. This is
# usually slower, but you might want to try it if you run into issues with
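
For context, a minimal sketch of how the new `disable_env_checking` key can be used from a Trainer config. This is a hedged example and not part of the diff; the PPOTrainer and CartPole-v1 choices are illustrative.

import ray
from ray.rllib.agents.ppo import PPOTrainer

ray.init()
config = {
    "num_workers": 0,
    # Opt out of the environment pre-checking module added by this commit.
    "disable_env_checking": True,
}
trainer = PPOTrainer(config=config, env="CartPole-v1")
print(trainer.train()["episode_reward_mean"])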


@ -5,7 +5,7 @@ import numpy as np
from gym.spaces import Discrete, Dict, Box
class CartPole:
class CartPole(gym.Env):
"""
Wrapper for gym CartPole environment where the reward
is accumulated to the end

rllib/env/base_env.py

@ -204,7 +204,7 @@ class BaseEnv:
Returns:
All agent ids for each the environment.
"""
return {_DUMMY_AGENT_ID}
return {}
@PublicAPI
def try_render(self, env_id: Optional[EnvID] = None) -> None:
@ -313,7 +313,7 @@ class BaseEnv:
True if the observations are contained within their respective
spaces. False otherwise.
"""
self._space_contains(self.observation_space, x)
return self._space_contains(self.observation_space, x)
@PublicAPI
def action_space_contains(self, x: MultiEnvDict) -> bool:
@ -340,8 +340,15 @@ class BaseEnv:
"""
agents = set(self.get_agent_ids())
for multi_agent_dict in x.values():
for agent_id, obs in multi_agent_dict:
if (agent_id not in agents) or (not space[agent_id].contains(obs)):
for agent_id, obs in multi_agent_dict.items():
# This is for the case where we have a single agent and
# we're checking a VectorEnv that's been converted to a BaseEnv.
if agent_id == _DUMMY_AGENT_ID:
if not space.contains(obs):
return False
# for the MultiAgent env case
elif (agent_id not in agents) or (not space[agent_id].contains(obs)):
return False
return True


@ -178,7 +178,11 @@ class MultiAgentEnv(gym.Env):
if agent_ids is None:
agent_ids = self.get_agent_ids()
samples = self.action_space.sample()
return {agent_id: samples[agent_id] for agent_id in agent_ids}
return {
agent_id: samples[agent_id]
for agent_id in agent_ids
if agent_id != "__all__"
}
logger.warning("action_space_sample() has not been implemented")
del agent_ids
return {}


@ -1,9 +1,9 @@
import logging
import gym
import numpy as np
from typing import Callable, List, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union, Set
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID
from ray.rllib.utils.annotations import Deprecated, override, PublicAPI
from ray.rllib.utils.typing import (
EnvActionType,
@ -12,6 +12,7 @@ from ray.rllib.utils.typing import (
EnvObsType,
EnvType,
MultiEnvDict,
AgentID,
)
logger = logging.getLogger(__name__)
@ -355,3 +356,20 @@ class VectorEnvWrapper(BaseEnv):
@PublicAPI
def action_space(self) -> gym.Space:
return self._action_space
@override(BaseEnv)
@PublicAPI
def action_space_sample(self, agent_id: list = None) -> MultiEnvDict:
del agent_id
return {0: {_DUMMY_AGENT_ID: self._action_space.sample()}}
@override(BaseEnv)
@PublicAPI
def observation_space_sample(self, agent_id: list = None) -> MultiEnvDict:
del agent_id
return {0: {_DUMMY_AGENT_ID: self._observation_space.sample()}}
@override(BaseEnv)
@PublicAPI
def get_agent_ids(self) -> Set[AgentID]:
return {_DUMMY_AGENT_ID}


@ -55,6 +55,8 @@ class GroupAgentsWrapper(MultiAgentEnv):
self.observation_space = obs_space
if act_space is not None:
self.action_space = act_space
for group_id in groups.keys():
self._agent_ids.add(group_id)
def seed(self, seed=None):
if not hasattr(self.env, "seed"):


@ -9,7 +9,7 @@ class OpenSpielEnv(MultiAgentEnv):
def __init__(self, env):
super().__init__()
self.env = env
self._skip_env_checking = True
# Agent IDs are ints, starting from 0.
self.num_agents = self.env.num_players()
# Store the open-spiel game type.


@ -70,6 +70,8 @@ class PettingZooEnv(MultiAgentEnv):
super().__init__()
self.env = env
env.reset()
self._skip_env_checking = True # TODO avnishn - remove this after making
# petting zoo env compatible with check_env
# Get first observation space, assuming all agents have equal space
self.observation_space = self.env.observation_space(self.env.agents[0])
@ -94,6 +96,7 @@ class PettingZooEnv(MultiAgentEnv):
"SuperSuit's pad_action_space wrapper can help (usage: "
"`supersuit.aec_wrappers.pad_action_space(env)`"
)
self._agent_ids = set(self.env.agents)
def reset(self):
self.env.reset()
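
The `_skip_env_checking` flag set above (and in several example envs in this commit) is the opt-out hook that `check_env` looks for before running its validation. Below is a minimal, hypothetical sketch of an env using it; the class name and spaces are illustrative, not from this diff.

import gym
from gym.spaces import Discrete

class LegacyEnv(gym.Env):
    """Toy env that opts out of RLlib's new env checker."""

    def __init__(self, config=None):
        # check_env() sees this attribute, logs a warning, and returns early.
        self._skip_env_checking = True
        self.observation_space = Discrete(2)
        self.action_space = Discrete(2)

    def reset(self):
        return 0

    def step(self, action):
        # obs, reward, done, info
        return 0, 0.0, True, {}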


@ -1,6 +1,6 @@
import copy
import gym
from gym.spaces import Box, Discrete, MultiDiscrete, Space
from gym.spaces import Discrete, MultiDiscrete, Space
import logging
import numpy as np
import platform
@ -25,11 +25,9 @@ from ray import ObjectRef
from ray import cloudpickle as pickle
from ray.rllib.env.base_env import BaseEnv, convert_to_base_env
from ray.rllib.env.env_context import EnvContext
from ray.rllib.env.external_env import ExternalEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
from ray.rllib.env.utils import record_env_wrapper
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.env.wrappers.atari_wrappers import wrap_deepmind, is_atari
from ray.rllib.evaluation.sampler import AsyncSampler, SyncSampler
from ray.rllib.evaluation.metrics import RolloutMetrics
@ -43,11 +41,11 @@ from ray.rllib.policy.sample_batch import MultiAgentBatch, DEFAULT_POLICY_ID
from ray.rllib.policy.policy import Policy, PolicySpec
from ray.rllib.policy.policy_map import PolicyMap
from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.utils import force_list, merge_dicts
from ray.rllib.utils import force_list, merge_dicts, check_env
from ray.rllib.utils.annotations import Deprecated, DeveloperAPI, ExperimentalAPI
from ray.rllib.utils.debug import summarize, update_global_seed_if_necessary
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.error import EnvError, ERR_MSG_NO_GPUS, HOWTO_CHANGE_CONFIG
from ray.rllib.utils.error import ERR_MSG_NO_GPUS, HOWTO_CHANGE_CONFIG
from ray.rllib.utils.filter import get_filter, Filter
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.sgd import do_minibatch_sgd
@ -251,6 +249,7 @@ class RolloutWorker(ParallelIteratorWorker):
spaces: Optional[Dict[PolicyID, Tuple[Space, Space]]] = None,
policy=None,
monitor_path=None,
disable_env_checking=False,
):
"""Initializes a RolloutWorker instance.
@ -365,6 +364,8 @@ class RolloutWorker(ParallelIteratorWorker):
Env is created on this RolloutWorker.
policy: Obsoleted arg. Use `policy_spec` instead.
monitor_path: Obsoleted arg. Use `record_env` instead.
disable_env_checking: If True, disables the env checking module that
validates the properties of the passed environment.
"""
# Deprecated args.
@ -463,6 +464,7 @@ class RolloutWorker(ParallelIteratorWorker):
self.last_batch: Optional[SampleBatchType] = None
self.global_vars: Optional[dict] = None
self.fake_sampler: bool = fake_sampler
self._disable_env_checking: bool = disable_env_checking
# Update the global seed for numpy/random/tf-eager/torch if we are not
# the local worker, otherwise, this was already done in the Trainer
@ -492,8 +494,19 @@ class RolloutWorker(ParallelIteratorWorker):
if self.env is not None:
# Validate environment (general validation function).
_validate_env(self.env, env_context=self.env_context)
# Custom validation function given.
if not self._disable_env_checking:
logger.warning(
"We've added a module for checking environments that "
"are used in experiments. It will cause your "
"environment to fail if your environment is not set up"
"correctly. You can disable check env by setting "
"`disable_env_checking` to True in your experiment config "
"dictionary. You can run the environment checking module "
"standalone by calling ray.rllib.utils.check_env(env)."
)
check_env(self.env)
# Custom validation function given, typically a function attribute of the
# algorithm trainer.
if validate_env is not None:
validate_env(self.env, self.env_context)
# We can't auto-wrap a BaseEnv.
@ -1717,6 +1730,8 @@ class RolloutWorker(ParallelIteratorWorker):
def _get_make_sub_env_fn(
self, env_creator, env_context, validate_env, env_wrapper, seed
):
disable_env_checking = self._disable_env_checking
def _make_sub_env_local(vector_index):
# Used to create additional environments during environment
# vectorization.
@ -1727,7 +1742,17 @@ class RolloutWorker(ParallelIteratorWorker):
# Create the sub-env.
env = env_creator(env_ctx)
# Validate first.
_validate_env(env, env_context=env_ctx)
if not disable_env_checking:
logger.warning(
"We've added a module for checking environments that "
"are used in experiments. It will cause your "
"environment to fail if your environment is not set up"
"correctly. You can disable check env by setting "
"`disable_env_checking` to True in your experiment config "
"dictionary. You can run the environment checking module "
"standalone by calling ray.rllib.utils.check_env(env)."
)
check_env(env)
# Custom validation function given by user.
if validate_env is not None:
validate_env(env, env_ctx)
@ -1870,52 +1895,3 @@ def _determine_spaces_for_multi_agent_dict(
action_space=act_space
)
return multi_agent_dict
def _validate_env(env: EnvType, env_context: EnvContext = None):
# Base message for checking the env for vector-index=0
msg = f"Validating sub-env at vector index={env_context.vector_index} ..."
allowed_types = [gym.Env, ExternalEnv, VectorEnv, BaseEnv, ray.actor.ActorHandle]
if not any(isinstance(env, tpe) for tpe in allowed_types):
# Allow this as a special case (assumed gym.Env).
# TODO: Disallow this early-out. Everything should conform to a few
# supported classes, i.e. gym.Env/MultiAgentEnv/etc...
if hasattr(env, "observation_space") and hasattr(env, "action_space"):
logger.warning(msg + f" (warning; invalid env-type={type(env)})")
return
else:
logger.warning(msg + " (NOT OK)")
raise EnvError(
"Returned env should be an instance of gym.Env (incl. "
"MultiAgentEnv), ExternalEnv, VectorEnv, or BaseEnv. "
f"The provided env creator function returned {env} "
f"(type={type(env)})."
)
# Do some test runs with the provided env.
if isinstance(env, gym.Env) and not isinstance(env, MultiAgentEnv):
# Make sure the gym.Env has the two space attributes properly set.
assert hasattr(env, "observation_space") and hasattr(env, "action_space")
# Get a dummy observation by resetting the env.
dummy_obs = env.reset()
# Convert lists to np.ndarrays.
if type(dummy_obs) is list and isinstance(env.observation_space, Box):
dummy_obs = np.array(dummy_obs)
# Ignore float32/float64 diffs.
if (
isinstance(env.observation_space, Box)
and env.observation_space.dtype != dummy_obs.dtype
):
dummy_obs = dummy_obs.astype(env.observation_space.dtype)
# Check, if observation is ok (part of the observation space). If not,
# error.
if not env.observation_space.contains(dummy_obs):
logger.warning(msg + " (NOT OK)")
raise EnvError(
f"Env's `observation_space` {env.observation_space} does not "
f"contain returned observation after a reset ({dummy_obs})!"
)
# Log that everything is ok.
logger.info(msg + " (ok)")
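
As the new warning text notes, the checker can also be run standalone outside of a RolloutWorker. A minimal sketch, assuming a plain gym env ("CartPole-v0" is an illustrative choice):

import gym
from ray.rllib.utils import check_env

env = gym.make("CartPole-v0")
# Raises ValueError if spaces, reset(), or step() are inconsistent.
check_env(env)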


@ -74,6 +74,7 @@ class EchoPolicy(Policy):
class EpisodeEnv(MultiAgentEnv):
def __init__(self, episode_length, num):
super().__init__()
self._skip_env_checking = True
self.agents = [MockEnv3(episode_length) for _ in range(num)]
self.dones = set()
self.observation_space = self.agents[0].observation_space


@ -591,6 +591,7 @@ class WorkerSet:
fake_sampler=config["fake_sampler"],
extra_python_environs=extra_python_environs,
spaces=spaces,
disable_env_checking=config["disable_env_checking"],
)
return worker


@ -9,7 +9,7 @@ class ActionMaskEnv(RandomEnv):
def __init__(self, config):
super().__init__(config)
self._skip_env_checking = True
# Masking only works for Discrete actions.
assert isinstance(self.action_space, Discrete)
# Add action_mask to observations.


@ -34,6 +34,7 @@ class DebugCounterEnv(gym.Env):
class MultiAgentDebugCounterEnv(MultiAgentEnv):
def __init__(self, config):
super().__init__()
self._skip_env_checking = True
self.num_agents = config["num_agents"]
self.base_episode_len = config.get("base_episode_len", 103)
# Actions are always:


@ -27,6 +27,7 @@ class BasicMultiAgent(MultiAgentEnv):
def __init__(self, num):
super().__init__()
self.agents = [MockEnv(25) for _ in range(num)]
self._agent_ids = set(range(num))
self.dones = set()
self.observation_space = gym.spaces.Discrete(2)
self.action_space = gym.spaces.Discrete(2)
@ -59,6 +60,7 @@ class EarlyDoneMultiAgent(MultiAgentEnv):
def __init__(self):
super().__init__()
self.agents = [MockEnv(3), MockEnv(5)]
self._agent_ids = set(range(len(self.agents)))
self.dones = set()
self.last_obs = {}
self.last_rew = {}
@ -77,7 +79,7 @@ class EarlyDoneMultiAgent(MultiAgentEnv):
self.i = 0
for i, a in enumerate(self.agents):
self.last_obs[i] = a.reset()
self.last_rew[i] = None
self.last_rew[i] = 0
self.last_done[i] = False
self.last_info[i] = {}
obs_dict = {self.i: self.last_obs[self.i]}
@ -111,6 +113,7 @@ class FlexAgentsMultiAgent(MultiAgentEnv):
def __init__(self):
super().__init__()
self.agents = {}
self._agent_ids = set()
self.agentID = 0
self.dones = set()
self.observation_space = gym.spaces.Discrete(2)
@ -121,15 +124,16 @@ class FlexAgentsMultiAgent(MultiAgentEnv):
# Spawn a new agent into the current episode.
agentID = self.agentID
self.agents[agentID] = MockEnv(25)
self._agent_ids.add(agentID)
self.agentID += 1
return agentID
def reset(self):
self.agents = {}
self._agent_ids = set()
self.spawn()
self.resetted = True
self.dones = set()
obs = {}
for i, a in self.agents.items():
obs[i] = a.reset()
@ -175,6 +179,7 @@ class RoundRobinMultiAgent(MultiAgentEnv):
else:
# Observations are all zeros
self.agents = [MockEnv(5) for _ in range(num)]
self._agent_ids = set(range(num))
self.dones = set()
self.last_obs = {}
self.last_rew = {}
@ -194,7 +199,7 @@ class RoundRobinMultiAgent(MultiAgentEnv):
self.i = 0
for i, a in enumerate(self.agents):
self.last_obs[i] = a.reset()
self.last_rew[i] = None
self.last_rew[i] = 0
self.last_done[i] = False
self.last_info[i] = {}
obs_dict = {self.i: self.last_obs[self.i]}


@ -40,6 +40,7 @@ class ParametricActionsCartPole(gym.Env):
"cart": self.wrapped.observation_space,
}
)
self._skip_env_checking = True
def update_avail_actions(self):
self.action_assignments = np.array(
@ -114,6 +115,7 @@ class ParametricActionsCartPoleNoEmbeddings(gym.Env):
"cart": self.wrapped.observation_space,
}
)
self._skip_env_checking = True
def reset(self):
return {


@ -13,11 +13,12 @@ class TwoStepGame(MultiAgentEnv):
self.state = None
self.agent_1 = 0
self.agent_2 = 1
self._skip_env_checking = True
# MADDPG emits action logits instead of actual discrete actions
self.actions_are_logits = env_config.get("actions_are_logits", False)
self.one_hot_state_encoding = env_config.get("one_hot_state_encoding", False)
self.with_state = env_config.get("separate_state_space", False)
self._agent_ids = {0, 1}
if not self.one_hot_state_encoding:
self.observation_space = Discrete(6)
self.with_state = False
@ -113,6 +114,8 @@ class TwoStepGameWithGroupedAgents(MultiAgentEnv):
)
self.observation_space = self.env.observation_space
self.action_space = self.env.action_space
self._agent_ids = {"agents"}
self._skip_env_checking = True
def reset(self):
return self.env.reset()


@ -86,6 +86,7 @@ class WindyMazeEnv(gym.Env):
class HierarchicalWindyMazeEnv(MultiAgentEnv):
def __init__(self, env_config):
super().__init__()
self._skip_env_checking = True
self.flat_env = WindyMazeEnv(env_config)
def reset(self):


@ -60,6 +60,7 @@ class NonVectorizedEnvToBeVectorizedIntoRemoteBaseEnv(TaskSettableEnv):
"""
def __init__(self, config=None):
super().__init__()
self.action_space = gym.spaces.Box(0, 1, shape=(1,))
self.observation_space = gym.spaces.Box(0, 1, shape=(2,))
self.task = 1


@ -26,6 +26,7 @@ class FaultInjectEnv(gym.Env):
def __init__(self, config):
self.env = gym.make("CartPole-v0")
self._skip_env_checking = True
self.action_space = self.env.action_space
self.observation_space = self.env.observation_space
self.config = config


@ -404,6 +404,7 @@ class NestedObservationSpacesTest(unittest.TestCase):
"use_lstm": test_lstm,
},
"framework": "tf",
"disable_env_checking": True,
},
)
# Skip first passes as they came from the TorchPolicy loss
@ -436,6 +437,7 @@ class NestedObservationSpacesTest(unittest.TestCase):
"custom_model": "composite2",
},
"framework": "tf",
"disable_env_checking": True,
},
)
# Skip first passes as they came from the TorchPolicy loss
@ -518,6 +520,7 @@ class NestedObservationSpacesTest(unittest.TestCase):
}[aid],
},
"framework": "tf",
"disable_env_checking": True,
},
)
# Skip first passes as they came from the TorchPolicy loss


@ -1,8 +1,12 @@
"""Common pre-checks for all RLlib experiments."""
import logging
import numpy as np
from typing import TYPE_CHECKING, Set
import gym
from ray.actor import ActorHandle
from ray.rllib.utils.spaces.space_utils import convert_element_to_space_type
from ray.rllib.utils.typing import EnvType
if TYPE_CHECKING:
@ -21,12 +25,38 @@ def check_env(env: EnvType) -> None:
ValueError: If env is not an instance of SUPPORTED_ENVIRONMENT_TYPES.
ValueError: See check_gym_env docstring for details.
"""
from ray.rllib.env import BaseEnv, MultiAgentEnv, RemoteBaseEnv, VectorEnv
from ray.rllib.env import (
BaseEnv,
MultiAgentEnv,
RemoteBaseEnv,
VectorEnv,
ExternalMultiAgentEnv,
ExternalEnv,
)
if not isinstance(env, (BaseEnv, gym.Env, MultiAgentEnv, RemoteBaseEnv, VectorEnv)):
if hasattr(env, "_skip_env_checking") and env._skip_env_checking:
# This is a workaround for some environments that we already have in RLlib
# that we want to skip checking for now until we have the time to fix them.
logger.warning("Skipping env checking for this experiment")
return
if not isinstance(
env,
(
BaseEnv,
gym.Env,
MultiAgentEnv,
RemoteBaseEnv,
VectorEnv,
ExternalMultiAgentEnv,
ExternalEnv,
ActorHandle,
),
):
raise ValueError(
"Env must be one of the supported types: BaseEnv, gym.Env, "
"MultiAgentEnv, VectorEnv, RemoteBaseEnv"
"MultiAgentEnv, VectorEnv, RemoteBaseEnv, ExternalMultiAgentEnv, "
f"ExternalEnv, but instead was a {type(env)}"
)
if isinstance(env, MultiAgentEnv):
@ -37,7 +67,8 @@ def check_env(env: EnvType) -> None:
check_base_env(env)
else:
logger.warning(
"Env checking isn't implemented for VectorEnvs or " "RemoteBaseEnvs."
"Env checking isn't implemented for VectorEnvs, RemoteBaseEnvs, "
"ExternalMultiAgentEnv,or ExternalEnvs or Environments that are Ray actors"
)
@ -97,33 +128,11 @@ def check_gym_environments(env: gym.Env) -> None:
# check if sampled actions and observations are contained within their
# respective action and observation spaces.
def contains_error(action_or_observation, sample, space):
string_type = "observation" if not action_or_observation else "action"
sample_type = get_type(sample)
_space_type = space.dtype
ret = (
f"A sampled {string_type} from your env wasn't contained "
f"within your env's {string_type} space. Its possible that "
f"there was a type mismatch, or that one of the "
f"sub-{string_type} was out of bounds:\n\nsampled_obs: "
f"{sample}\nenv.{string_type}_space: {space}"
f"\nsampled_obs's dtype: {sample_type}"
f"\nenv.{sample_type}'s dtype: {_space_type}"
)
return ret
def get_type(var):
return var.dtype if hasattr(var, "dtype") else type(var)
sampled_action = env.action_space.sample()
sampled_observation = env.observation_space.sample()
if not env.observation_space.contains(sampled_observation):
raise ValueError(
contains_error(False, sampled_observation, env.observation_space)
)
if not env.action_space.contains(sampled_action):
raise ValueError(contains_error(True, sampled_action, env.action_space))
# check if observation generated from stepping the environment is
# contained within the observation space
reset_obs = env.reset()
@ -140,7 +149,11 @@ def check_gym_environments(env: gym.Env) -> None:
f"{reset_obs_type}\n\n env.observation_space's dtype: "
f"{space_type}"
)
raise ValueError(error)
temp_sampled_reset_obs = convert_element_to_space_type(
reset_obs, sampled_observation
)
if not env.observation_space.contains(temp_sampled_reset_obs):
raise ValueError(error)
# check if env.step can run, and generates observations rewards, done
# signals and infos that are within their respective spaces and are of
# the correct dtypes
@ -157,7 +170,11 @@ def check_gym_environments(env: gym.Env) -> None:
f"\n\n next_obs's dtype: {next_obs_type}"
f"\n\n env.observation_space's dtype: {space_type}"
)
raise ValueError(error)
temp_sampled_next_obs = convert_element_to_space_type(
next_obs, sampled_observation
)
if not env.observation_space.contains(temp_sampled_next_obs):
raise ValueError(error)
_check_done(done)
_check_reward(reward)
_check_info(info)
@ -222,10 +239,15 @@ def check_multiagent_environments(env: "MultiAgentEnv") -> None:
raise ValueError(error)
next_obs, reward, done, info = env.step(sampled_action)
_check_if_element_multi_agent_dict(env, next_obs, "step(sampled_action)")
_check_reward(reward)
_check_done(done)
_check_info(info)
_check_if_element_multi_agent_dict(env, next_obs, "step, next_obs")
_check_if_element_multi_agent_dict(env, reward, "step, reward")
_check_if_element_multi_agent_dict(env, done, "step, done")
_check_if_element_multi_agent_dict(env, info, "step, info")
_check_reward(
{"dummy_env_id": reward}, base_env=True, agent_ids=env.get_agent_ids()
)
_check_done({"dummy_env_id": done}, base_env=True, agent_ids=env.get_agent_ids())
_check_info({"dummy_env_id": info}, base_env=True, agent_ids=env.get_agent_ids())
if not env.observation_space_contains(next_obs):
error = (
_not_contained_error("env.step(sampled_action)", "observation")
@ -255,7 +277,7 @@ def check_base_env(env: "BaseEnv") -> None:
env.observation_space_contains(reset_obs)
except Exception as e:
raise ValueError(
"Your observation_space_contains function has some " "error "
"Your observation_space_contains function has some error "
) from e
if not env.observation_space_contains(reset_obs):
@ -303,51 +325,87 @@ def check_base_env(env: "BaseEnv") -> None:
)
raise ValueError(error)
_check_reward(reward, base_env=True)
_check_done(done, base_env=True)
_check_info(info, base_env=True)
_check_reward(reward, base_env=True, agent_ids=env.get_agent_ids())
_check_done(done, base_env=True, agent_ids=env.get_agent_ids())
_check_info(info, base_env=True, agent_ids=env.get_agent_ids())
def _check_reward(reward, base_env=False):
def _check_reward(reward, base_env=False, agent_ids=None):
if base_env:
for _, multi_agent_dict in reward.items():
for _, rew in multi_agent_dict.items():
assert isinstance(rew, (float, int)), (
"Your step function must return a rewards that are"
f" integer or float. reward: {rew}"
)
else:
assert isinstance(
reward, (float, int)
), "Your step function must return a reward that is integer or float."
for agent_id, rew in multi_agent_dict.items():
if not (
np.isreal(rew) and not isinstance(rew, bool) and np.isscalar(rew)
):
error = (
"Your step function must return rewards that are"
f" integer or float. reward: {rew}. Instead it was a "
f"{type(reward)}"
)
raise ValueError(error)
if not (agent_id in agent_ids or agent_id == "__all__"):
error = (
f"Your reward dictionary must have agent ids that belong to "
f"the environment. Agent_ids recieved from "
f"env.get_agent_ids() are: {agent_ids}"
)
raise ValueError(error)
elif not (
np.isreal(reward) and not isinstance(reward, bool) and np.isscalar(reward)
):
error = (
"Your step function must return a reward that is integer or float. "
"Instead it was a {}".format(type(reward))
)
raise ValueError(error)
def _check_done(done, base_env=False):
def _check_done(done, base_env=False, agent_ids=None):
if base_env:
for _, multi_agent_dict in done.items():
for _, done_ in multi_agent_dict.items():
assert isinstance(done_, bool), (
"Your step function must return a done that is boolean. "
f"element: {done_}"
)
else:
assert isinstance(done, bool), (
"Your step function must return a done that is a " "boolean."
for agent_id, done_ in multi_agent_dict.items():
if not isinstance(done_, (bool, np.bool, np.bool_)):
raise ValueError(
"Your step function must return dones that are boolean. But "
f"instead was a {type(done)}"
)
if not (agent_id in agent_ids or agent_id == "__all__"):
error = (
f"Your dones dictionary must have agent ids that belong to "
f"the environment. Agent_ids recieved from "
f"env.get_agent_ids() are: {agent_ids}"
)
raise ValueError(error)
elif not isinstance(done, (bool, np.bool, np.bool_)):
error = (
"Your step function must return a done that is a boolean. But instead "
f"was a {type(done)}"
)
raise ValueError(error)
def _check_info(info, base_env=False):
def _check_info(info, base_env=False, agent_ids=None):
if base_env:
for _, multi_agent_dict in info.items():
for _, inf in multi_agent_dict.items():
assert isinstance(inf, dict), (
"Your step function must return a info that is a dict. "
f"element: {inf}"
)
else:
assert isinstance(
info, dict
), "Your step function must return a info that is a dict."
for agent_id, inf in multi_agent_dict.items():
if not isinstance(inf, dict):
raise ValueError(
"Your step function must return infos that are a dict. "
f"instead was a {type(inf)}: element: {inf}"
)
if not (agent_id in agent_ids or agent_id == "__all__"):
error = (
f"Your dones dictionary must have agent ids that belong to "
f"the environment. Agent_ids recieved from "
f"env.get_agent_ids() are: {agent_ids}"
)
raise ValueError(error)
elif not isinstance(info, dict):
error = (
"Your step function must return a info that "
f"is a dict. element type: {type(info)}. element: {info}"
)
raise ValueError(error)
def _not_contained_error(func_name, _type):
@ -398,6 +456,7 @@ def _check_if_element_multi_agent_dict(env, element, function_string, base_env=F
raise ValueError(error)
agent_ids: Set = env.get_agent_ids()
agent_ids.add("__all__")
if not all(k in agent_ids for k in element):
if base_env:
error = (
@ -413,6 +472,8 @@ def _check_if_element_multi_agent_dict(env, element, function_string, base_env=F
f" that are not the names of the agents in the env. "
f"\nAgent_ids in this MultiAgentDict: "
f"{list(element.keys())}\nAgent_ids in this env:"
f"{list(env.get_agent_ids())}"
f"{list(env.get_agent_ids())}. You likley need to add the private "
f"attribute `_agent_ids` to your env, which is a set containing the "
f"ids of agents supported by your env."
)
raise ValueError(error)
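
The error above points users at the `_agent_ids` attribute. Below is a minimal, hypothetical sketch of a custom MultiAgentEnv declaring its agent ids so these checks can validate the keys of its obs/reward/done/info dicts. Agent names and spaces are illustrative; a full env would also provide the space-sampling helpers the checker calls.

from gym.spaces import Discrete
from ray.rllib.env.multi_agent_env import MultiAgentEnv

class TwoAgentEnv(MultiAgentEnv):
    def __init__(self, config=None):
        super().__init__()
        # Declare the agent ids the checker compares dict keys against.
        self._agent_ids = {"agent_0", "agent_1"}
        self.observation_space = Discrete(2)
        self.action_space = Discrete(2)

    def reset(self):
        return {aid: 0 for aid in self._agent_ids}

    def step(self, action_dict):
        obs = {aid: 0 for aid in action_dict}
        rew = {aid: 1.0 for aid in action_dict}
        done = {aid: False for aid in action_dict}
        done["__all__"] = False
        info = {aid: {} for aid in action_dict}
        return obs, rew, done, info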


@ -336,8 +336,9 @@ def convert_element_to_space_type(element: Any, sampled_element: Any) -> Any:
elem = elem.astype(s.dtype)
elif isinstance(s, int):
if isinstance(elem, float):
if isinstance(elem, float) and elem.is_integer():
elem = int(elem)
return elem
return tree.map_structure(map_, element, sampled_element, check_types=False)
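
The narrowed condition above means only whole-number floats are coerced down to int; non-integer floats are left as-is so the callers in check_gym_environments can surface a real containment error. A small hedged illustration of that behavior:

from ray.rllib.utils.spaces.space_utils import convert_element_to_space_type

# A whole-number float still gets matched to an int sample ...
assert convert_element_to_space_type(3.0, 1) == 3
# ... but a non-integer float is no longer silently truncated.
assert convert_element_to_space_type(0.1, 1) == 0.1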


@ -1,9 +1,12 @@
import logging
import gym
import numpy as np
import pytest
import unittest
from unittest.mock import Mock, MagicMock
from ray.rllib.env.base_env import convert_to_base_env
from ray.rllib.env.multi_agent_env import make_multi_agent, MultiAgentEnvWrapper
from ray.rllib.examples.env.random_env import RandomEnv
from ray.rllib.utils.pre_checks.env import (
@ -40,31 +43,6 @@ class TestGymCheckEnv(unittest.TestCase):
check_env(env)
del env
def test_sampled_observation_contained(self):
env = RandomEnv()
# check for observation that is out of bounds
error = ".*A sampled observation from your env wasn't contained .*"
env.observation_space.sample = MagicMock(return_value=5)
with pytest.raises(ValueError, match=error):
check_env(env)
# check for observation that is in bounds, but the wrong type
env.observation_space.sample = MagicMock(return_value=float(1))
with pytest.raises(ValueError, match=error):
check_env(env)
del env
def test_sampled_action_contained(self):
env = RandomEnv()
error = ".*A sampled action from your env wasn't contained .*"
env.action_space.sample = MagicMock(return_value=5)
with pytest.raises(ValueError, match=error):
check_env(env)
# check for observation that is in bounds, but the wrong type
env.action_space.sample = MagicMock(return_value=float(1))
with pytest.raises(ValueError, match=error):
check_env(env)
del env
def test_reset(self):
reset = MagicMock(return_value=5)
env = RandomEnv()
@ -74,7 +52,7 @@ class TestGymCheckEnv(unittest.TestCase):
with pytest.raises(ValueError, match=error):
check_env(env)
# check reset with obs of incorrect type fails
reset = MagicMock(return_value=float(1))
reset = MagicMock(return_value=float(0.1))
env.reset = reset
with pytest.raises(ValueError, match=error):
check_env(env)
@ -89,7 +67,7 @@ class TestGymCheckEnv(unittest.TestCase):
check_env(env)
# check step that returns obs of incorrect type fails
step = MagicMock(return_value=(float(1), 5, True, {}))
step = MagicMock(return_value=(float(0.1), 5, True, {}))
env.step = step
with pytest.raises(ValueError, match=error):
check_env(env)
@ -98,21 +76,21 @@ class TestGymCheckEnv(unittest.TestCase):
step = MagicMock(return_value=(1, "Not a valid reward", True, {}))
env.step = step
error = "Your step function must return a reward that is integer or " "float."
with pytest.raises(AssertionError, match=error):
with pytest.raises(ValueError, match=error):
check_env(env)
# check step that returns a non bool fails
step = MagicMock(return_value=(1, float(5), "not a valid done signal", {}))
env.step = step
error = "Your step function must return a done that is a boolean."
with pytest.raises(AssertionError, match=error):
with pytest.raises(ValueError, match=error):
check_env(env)
# check step that returns a non dict fails
step = MagicMock(return_value=(1, float(5), True, "not a valid env info"))
env.step = step
error = "Your step function must return a info that is a dict."
with pytest.raises(AssertionError, match=error):
with pytest.raises(ValueError, match=error):
check_env(env)
del env
@ -176,24 +154,20 @@ class TestCheckMultiAgentEnv(unittest.TestCase):
with pytest.raises(ValueError, match="The element returned by step"):
check_env(env)
step = MagicMock(return_value=(sampled_obs, "Not a reward", True, {}))
step = MagicMock(return_value=(sampled_obs, {0: "Not a reward"}, {0: True}, {}))
env.step = step
with pytest.raises(
AssertionError, match="Your step function must " "return a reward "
):
with pytest.raises(ValueError, match="Your step function must return rewards"):
check_env(env)
step = MagicMock(return_value=(sampled_obs, 5, "Not a bool", {}))
step = MagicMock(return_value=(sampled_obs, {0: 5}, {0: "Not a bool"}, {}))
env.step = step
with pytest.raises(
AssertionError, match="Your step function must " "return a done"
):
with pytest.raises(ValueError, match="Your step function must return dones"):
check_env(env)
step = MagicMock(return_value=(sampled_obs, 5, False, "Not a Dict"))
step = MagicMock(
return_value=(sampled_obs, {0: 5}, {0: False}, {0: "Not a Dict"})
)
env.step = step
with pytest.raises(
AssertionError, match="Your step function must " "return a info"
):
with pytest.raises(ValueError, match="Your step function must return infos"):
check_env(env)
def test_bad_sample_function(self):
@ -345,29 +319,32 @@ class TestCheckBaseEnv:
poll = MagicMock(return_value=(good_obs, bad_reward, good_done, good_info, {}))
env.poll = poll
with pytest.raises(
AssertionError, match="Your step function must " "return a rewards that are"
ValueError, match="Your step function must return rewards that are"
):
check_env(env)
bad_done = {0: {0: "not_done", 1: False}}
poll = MagicMock(return_value=(good_obs, good_reward, bad_done, good_info, {}))
env.poll = poll
with pytest.raises(
AssertionError,
match="Your step function must " "return a done that is " "boolean.",
ValueError,
match="Your step function must return dones that are boolean.",
):
check_env(env)
bad_info = {0: {0: "not_info", 1: {}}}
poll = MagicMock(return_value=(good_obs, good_reward, good_done, bad_info, {}))
env.poll = poll
with pytest.raises(
AssertionError,
match="Your step function must" " return a info that is a " "dict.",
ValueError,
match="Your step function must return infos that are a dict.",
):
check_env(env)
def test_check_correct_env(self):
env = self._make_base_env()
check_env(env)
env = gym.make("CartPole-v0")
env = convert_to_base_env(env)
check_env(env)
if __name__ == "__main__":