[RLlib] Put env-checker on critical path. (#22191)

Avnish Narayan 2022-02-17 05:06:14 -08:00 committed by GitHub
parent e03606f0b3
commit 740def0a13
25 changed files with 260 additions and 185 deletions


@@ -87,6 +87,7 @@ class TestDDPG(unittest.TestCase):
# Test against all frameworks.
for _ in framework_iterator(core_config):
config = core_config.copy()
+ config["seed"] = 42
# Default OUNoise setup.
trainer = ddpg.DDPGTrainer(config=config, env="Pendulum-v1")
# Setting explore=False should always return the same action.
@@ -142,6 +143,7 @@ class TestDDPG(unittest.TestCase):
"""Tests DDPG loss function results across all frameworks."""
config = ddpg.DEFAULT_CONFIG.copy()
# Run locally.
+ config["seed"] = 42
config["num_workers"] = 0
config["learning_starts"] = 0
config["twin_q"] = True


@@ -36,6 +36,7 @@ torch, _ = try_import_torch()
class SimpleEnv(Env):
def __init__(self, config):
+ self._skip_env_checking = True
if config.get("simplex_actions", False):
self.action_space = Simplex((2,))
else:
@@ -168,6 +169,7 @@ class TestSAC(unittest.TestCase):
"""Tests SAC loss function results across all frameworks."""
config = sac.DEFAULT_CONFIG.copy()
# Run locally.
+ config["seed"] = 42
config["num_workers"] = 0
config["learning_starts"] = 0
config["twin_q"] = False


@@ -628,6 +628,9 @@ COMMON_CONFIG: TrainerConfigDict = {
# training iteration.
"_disable_execution_plan_api": False,
+ # If True, disable the environment pre-checking module.
+ "disable_env_checking": False,
# === Deprecated keys ===
# Uses the sync samples optimizer instead of the multi-gpu one. This is
# usually slower, but you might want to try it if you run into issues with
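For reference, a rough sketch of how the new key would be used from a training script; the PPOTrainer/CartPole choice here is only an illustration, not part of this commit:

import ray.rllib.agents.ppo as ppo

config = ppo.DEFAULT_CONFIG.copy()
# Assumption: opting out of the new on-path env check, e.g. for a legacy env.
config["disable_env_checking"] = True
trainer = ppo.PPOTrainer(config=config, env="CartPole-v1")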


@@ -5,7 +5,7 @@ import numpy as np
from gym.spaces import Discrete, Dict, Box
- class CartPole:
+ class CartPole(gym.Env):
"""
Wrapper for gym CartPole environment where the reward
is accumulated to the end

rllib/env/base_env.py

@@ -204,7 +204,7 @@ class BaseEnv:
Returns:
All agent ids for each the environment.
"""
- return {_DUMMY_AGENT_ID}
+ return {}
@PublicAPI
def try_render(self, env_id: Optional[EnvID] = None) -> None:
@@ -313,7 +313,7 @@ class BaseEnv:
True if the observations are contained within their respective
spaces. False otherwise.
"""
- self._space_contains(self.observation_space, x)
+ return self._space_contains(self.observation_space, x)
@PublicAPI
def action_space_contains(self, x: MultiEnvDict) -> bool:
@@ -340,8 +340,15 @@ class BaseEnv:
"""
agents = set(self.get_agent_ids())
for multi_agent_dict in x.values():
- for agent_id, obs in multi_agent_dict:
- if (agent_id not in agents) or (not space[agent_id].contains(obs)):
+ for agent_id, obs in multi_agent_dict.items():
+ # this is for the case where we have a single agent
+ # and we're checking a Vector env thats been converted to
+ # a BaseEnv
+ if agent_id == _DUMMY_AGENT_ID:
+ if not space.contains(obs):
+ return False
+ # for the MultiAgent env case
+ elif (agent_id not in agents) or (not space[agent_id].contains(obs)):
return False
return True


@@ -178,7 +178,11 @@ class MultiAgentEnv(gym.Env):
if agent_ids is None:
agent_ids = self.get_agent_ids()
samples = self.action_space.sample()
- return {agent_id: samples[agent_id] for agent_id in agent_ids}
+ return {
+ agent_id: samples[agent_id]
+ for agent_id in agent_ids
+ if agent_id != "__all__"
+ }
logger.warning("action_space_sample() has not been implemented")
del agent_ids
return {}
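A rough illustration of what the filtered comprehension above is meant to do, assuming a Dict action space keyed by agent ids (the agent names are made up):

from gym.spaces import Dict, Discrete

action_space = Dict({"agent_0": Discrete(2), "agent_1": Discrete(2)})
samples = action_space.sample()
agent_ids = {"agent_0", "agent_1", "__all__"}
# "__all__" is filtered out so only real agents end up in the sampled dict.
print({aid: samples[aid] for aid in agent_ids if aid != "__all__"})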


@@ -1,9 +1,9 @@
import logging
import gym
import numpy as np
- from typing import Callable, List, Optional, Tuple, Union
+ from typing import Callable, List, Optional, Tuple, Union, Set
- from ray.rllib.env.base_env import BaseEnv
+ from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID
from ray.rllib.utils.annotations import Deprecated, override, PublicAPI
from ray.rllib.utils.typing import (
EnvActionType,
@@ -12,6 +12,7 @@ from ray.rllib.utils.typing import (
EnvObsType,
EnvType,
MultiEnvDict,
+ AgentID,
)
logger = logging.getLogger(__name__)
@@ -355,3 +356,20 @@ class VectorEnvWrapper(BaseEnv):
@PublicAPI
def action_space(self) -> gym.Space:
return self._action_space
+ @override(BaseEnv)
+ @PublicAPI
+ def action_space_sample(self, agent_id: list = None) -> MultiEnvDict:
+ del agent_id
+ return {0: {_DUMMY_AGENT_ID: self._action_space.sample()}}
+ @override(BaseEnv)
+ @PublicAPI
+ def observation_space_sample(self, agent_id: list = None) -> MultiEnvDict:
+ del agent_id
+ return {0: {_DUMMY_AGENT_ID: self._observation_space.sample()}}
+ @override(BaseEnv)
+ @PublicAPI
+ def get_agent_ids(self) -> Set[AgentID]:
+ return {_DUMMY_AGENT_ID}
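These overrides key every sample by the dummy agent id, which is the shape the env checker expects for vectorized single-agent envs. A minimal sketch, using the same convert_to_base_env path the new test at the bottom of this commit exercises:

import gym
from ray.rllib.env.base_env import convert_to_base_env

# A plain gym.Env gets wrapped into a (single-env) VectorEnvWrapper here.
base_env = convert_to_base_env(gym.make("CartPole-v0"))
print(base_env.get_agent_ids())             # {_DUMMY_AGENT_ID}
print(base_env.observation_space_sample())  # {0: {_DUMMY_AGENT_ID: <obs>}}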


@@ -55,6 +55,8 @@ class GroupAgentsWrapper(MultiAgentEnv):
self.observation_space = obs_space
if act_space is not None:
self.action_space = act_space
+ for group_id in groups.keys():
+ self._agent_ids.add(group_id)
def seed(self, seed=None):
if not hasattr(self.env, "seed"):


@@ -9,7 +9,7 @@ class OpenSpielEnv(MultiAgentEnv):
def __init__(self, env):
super().__init__()
self.env = env
+ self._skip_env_checking = True
# Agent IDs are ints, starting from 0.
self.num_agents = self.env.num_players()
# Store the open-spiel game type.


@@ -70,6 +70,8 @@ class PettingZooEnv(MultiAgentEnv):
super().__init__()
self.env = env
env.reset()
+ self._skip_env_checking = True  # TODO avnishn - remove this after making
+ # petting zoo env compatible with check_env
# Get first observation space, assuming all agents have equal space
self.observation_space = self.env.observation_space(self.env.agents[0])
@@ -94,6 +96,7 @@ class PettingZooEnv(MultiAgentEnv):
"SuperSuit's pad_action_space wrapper can help (usage: "
"`supersuit.aec_wrappers.pad_action_space(env)`"
)
+ self._agent_ids = set(self.env.agents)
def reset(self):
self.env.reset()


@@ -1,6 +1,6 @@
import copy
import gym
- from gym.spaces import Box, Discrete, MultiDiscrete, Space
+ from gym.spaces import Discrete, MultiDiscrete, Space
import logging
import numpy as np
import platform
@@ -25,11 +25,9 @@ from ray import ObjectRef
from ray import cloudpickle as pickle
from ray.rllib.env.base_env import BaseEnv, convert_to_base_env
from ray.rllib.env.env_context import EnvContext
- from ray.rllib.env.external_env import ExternalEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
from ray.rllib.env.utils import record_env_wrapper
- from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.env.wrappers.atari_wrappers import wrap_deepmind, is_atari
from ray.rllib.evaluation.sampler import AsyncSampler, SyncSampler
from ray.rllib.evaluation.metrics import RolloutMetrics
@@ -43,11 +41,11 @@ from ray.rllib.policy.sample_batch import MultiAgentBatch, DEFAULT_POLICY_ID
from ray.rllib.policy.policy import Policy, PolicySpec
from ray.rllib.policy.policy_map import PolicyMap
from ray.rllib.policy.torch_policy import TorchPolicy
- from ray.rllib.utils import force_list, merge_dicts
+ from ray.rllib.utils import force_list, merge_dicts, check_env
from ray.rllib.utils.annotations import Deprecated, DeveloperAPI, ExperimentalAPI
from ray.rllib.utils.debug import summarize, update_global_seed_if_necessary
from ray.rllib.utils.deprecation import deprecation_warning
- from ray.rllib.utils.error import EnvError, ERR_MSG_NO_GPUS, HOWTO_CHANGE_CONFIG
+ from ray.rllib.utils.error import ERR_MSG_NO_GPUS, HOWTO_CHANGE_CONFIG
from ray.rllib.utils.filter import get_filter, Filter
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.sgd import do_minibatch_sgd
@@ -251,6 +249,7 @@ class RolloutWorker(ParallelIteratorWorker):
spaces: Optional[Dict[PolicyID, Tuple[Space, Space]]] = None,
policy=None,
monitor_path=None,
+ disable_env_checking=False,
):
"""Initializes a RolloutWorker instance.
@@ -365,6 +364,8 @@ class RolloutWorker(ParallelIteratorWorker):
Env is created on this RolloutWorker.
policy: Obsoleted arg. Use `policy_spec` instead.
monitor_path: Obsoleted arg. Use `record_env` instead.
+ disable_env_checking: If True, disables the env checking module that
+ validates the properties of the passed environment.
"""
# Deprecated args.
@@ -463,6 +464,7 @@ class RolloutWorker(ParallelIteratorWorker):
self.last_batch: Optional[SampleBatchType] = None
self.global_vars: Optional[dict] = None
self.fake_sampler: bool = fake_sampler
+ self._disable_env_checking: bool = disable_env_checking
# Update the global seed for numpy/random/tf-eager/torch if we are not
# the local worker, otherwise, this was already done in the Trainer
@@ -492,8 +494,19 @@ class RolloutWorker(ParallelIteratorWorker):
if self.env is not None:
# Validate environment (general validation function).
- _validate_env(self.env, env_context=self.env_context)
- # Custom validation function given.
+ if not self._disable_env_checking:
+ logger.warning(
+ "We've added a module for checking environments that "
+ "are used in experiments. It will cause your "
+ "environment to fail if your environment is not set up"
+ "correctly. You can disable check env by setting "
+ "`disable_env_checking` to True in your experiment config "
+ "dictionary. You can run the environment checking module "
+ "standalone by calling ray.rllib.utils.check_env(env)."
+ )
+ check_env(self.env)
+ # Custom validation function given, typically a function attribute of the
+ # algorithm trainer.
if validate_env is not None:
validate_env(self.env, self.env_context)
# We can't auto-wrap a BaseEnv.
@@ -1717,6 +1730,8 @@ class RolloutWorker(ParallelIteratorWorker):
def _get_make_sub_env_fn(
self, env_creator, env_context, validate_env, env_wrapper, seed
):
+ disable_env_checking = self._disable_env_checking
def _make_sub_env_local(vector_index):
# Used to created additional environments during environment
# vectorization.
@@ -1727,7 +1742,17 @@ class RolloutWorker(ParallelIteratorWorker):
# Create the sub-env.
env = env_creator(env_ctx)
# Validate first.
- _validate_env(env, env_context=env_ctx)
+ if not disable_env_checking:
+ logger.warning(
+ "We've added a module for checking environments that "
+ "are used in experiments. It will cause your "
+ "environment to fail if your environment is not set up"
+ "correctly. You can disable check env by setting "
+ "`disable_env_checking` to True in your experiment config "
+ "dictionary. You can run the environment checking module "
+ "standalone by calling ray.rllib.utils.check_env(env)."
+ )
+ check_env(env)
# Custom validation function given by user.
if validate_env is not None:
validate_env(env, env_ctx)
@@ -1870,52 +1895,3 @@ def _determine_spaces_for_multi_agent_dict(
action_space=act_space
)
return multi_agent_dict
- def _validate_env(env: EnvType, env_context: EnvContext = None):
- # Base message for checking the env for vector-index=0
- msg = f"Validating sub-env at vector index={env_context.vector_index} ..."
- allowed_types = [gym.Env, ExternalEnv, VectorEnv, BaseEnv, ray.actor.ActorHandle]
- if not any(isinstance(env, tpe) for tpe in allowed_types):
- # Allow this as a special case (assumed gym.Env).
- # TODO: Disallow this early-out. Everything should conform to a few
- # supported classes, i.e. gym.Env/MultiAgentEnv/etc...
- if hasattr(env, "observation_space") and hasattr(env, "action_space"):
- logger.warning(msg + f" (warning; invalid env-type={type(env)})")
- return
- else:
- logger.warning(msg + " (NOT OK)")
- raise EnvError(
- "Returned env should be an instance of gym.Env (incl. "
- "MultiAgentEnv), ExternalEnv, VectorEnv, or BaseEnv. "
- f"The provided env creator function returned {env} "
- f"(type={type(env)})."
- )
- # Do some test runs with the provided env.
- if isinstance(env, gym.Env) and not isinstance(env, MultiAgentEnv):
- # Make sure the gym.Env has the two space attributes properly set.
- assert hasattr(env, "observation_space") and hasattr(env, "action_space")
- # Get a dummy observation by resetting the env.
- dummy_obs = env.reset()
- # Convert lists to np.ndarrays.
- if type(dummy_obs) is list and isinstance(env.observation_space, Box):
- dummy_obs = np.array(dummy_obs)
- # Ignore float32/float64 diffs.
- if (
- isinstance(env.observation_space, Box)
- and env.observation_space.dtype != dummy_obs.dtype
- ):
- dummy_obs = dummy_obs.astype(env.observation_space.dtype)
- # Check, if observation is ok (part of the observation space). If not,
- # error.
- if not env.observation_space.contains(dummy_obs):
- logger.warning(msg + " (NOT OK)")
- raise EnvError(
- f"Env's `observation_space` {env.observation_space} does not "
- f"contain returned observation after a reset ({dummy_obs})!"
- )
- # Log that everything is ok.
- logger.info(msg + " (ok)")
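As the warning text added above points out, the checker can also be run standalone, outside of any RolloutWorker; a minimal sketch:

import gym
from ray.rllib.utils import check_env

# Runs the same pre-checks that RolloutWorker now performs on the critical path.
check_env(gym.make("CartPole-v0"))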


@@ -74,6 +74,7 @@ class EchoPolicy(Policy):
class EpisodeEnv(MultiAgentEnv):
def __init__(self, episode_length, num):
super().__init__()
+ self._skip_env_checking = True
self.agents = [MockEnv3(episode_length) for _ in range(num)]
self.dones = set()
self.observation_space = self.agents[0].observation_space


@@ -591,6 +591,7 @@ class WorkerSet:
fake_sampler=config["fake_sampler"],
extra_python_environs=extra_python_environs,
spaces=spaces,
+ disable_env_checking=config["disable_env_checking"],
)
return worker


@@ -9,7 +9,7 @@ class ActionMaskEnv(RandomEnv):
def __init__(self, config):
super().__init__(config)
+ self._skip_env_checking = True
# Masking only works for Discrete actions.
assert isinstance(self.action_space, Discrete)
# Add action_mask to observations.


@@ -34,6 +34,7 @@ class DebugCounterEnv(gym.Env):
class MultiAgentDebugCounterEnv(MultiAgentEnv):
def __init__(self, config):
super().__init__()
+ self._skip_env_checking = True
self.num_agents = config["num_agents"]
self.base_episode_len = config.get("base_episode_len", 103)
# Actions are always:


@@ -27,6 +27,7 @@ class BasicMultiAgent(MultiAgentEnv):
def __init__(self, num):
super().__init__()
self.agents = [MockEnv(25) for _ in range(num)]
+ self._agent_ids = set(range(num))
self.dones = set()
self.observation_space = gym.spaces.Discrete(2)
self.action_space = gym.spaces.Discrete(2)
@@ -59,6 +60,7 @@ class EarlyDoneMultiAgent(MultiAgentEnv):
def __init__(self):
super().__init__()
self.agents = [MockEnv(3), MockEnv(5)]
+ self._agent_ids = set(range(len(self.agents)))
self.dones = set()
self.last_obs = {}
self.last_rew = {}
@@ -77,7 +79,7 @@ class EarlyDoneMultiAgent(MultiAgentEnv):
self.i = 0
for i, a in enumerate(self.agents):
self.last_obs[i] = a.reset()
- self.last_rew[i] = None
+ self.last_rew[i] = 0
self.last_done[i] = False
self.last_info[i] = {}
obs_dict = {self.i: self.last_obs[self.i]}
@@ -111,6 +113,7 @@ class FlexAgentsMultiAgent(MultiAgentEnv):
def __init__(self):
super().__init__()
self.agents = {}
+ self._agent_ids = set()
self.agentID = 0
self.dones = set()
self.observation_space = gym.spaces.Discrete(2)
@@ -121,15 +124,16 @@ class FlexAgentsMultiAgent(MultiAgentEnv):
# Spawn a new agent into the current episode.
agentID = self.agentID
self.agents[agentID] = MockEnv(25)
+ self._agent_ids.add(agentID)
self.agentID += 1
return agentID
def reset(self):
self.agents = {}
+ self._agent_ids = set()
self.spawn()
self.resetted = True
self.dones = set()
obs = {}
for i, a in self.agents.items():
obs[i] = a.reset()
@@ -175,6 +179,7 @@ class RoundRobinMultiAgent(MultiAgentEnv):
else:
# Observations are all zeros
self.agents = [MockEnv(5) for _ in range(num)]
+ self._agent_ids = set(range(num))
self.dones = set()
self.last_obs = {}
self.last_rew = {}
@@ -194,7 +199,7 @@ class RoundRobinMultiAgent(MultiAgentEnv):
self.i = 0
for i, a in enumerate(self.agents):
self.last_obs[i] = a.reset()
- self.last_rew[i] = None
+ self.last_rew[i] = 0
self.last_done[i] = False
self.last_info[i] = {}
obs_dict = {self.i: self.last_obs[self.i]}


@@ -40,6 +40,7 @@ class ParametricActionsCartPole(gym.Env):
"cart": self.wrapped.observation_space,
}
)
+ self._skip_env_checking = True
def update_avail_actions(self):
self.action_assignments = np.array(
@@ -114,6 +115,7 @@ class ParametricActionsCartPoleNoEmbeddings(gym.Env):
"cart": self.wrapped.observation_space,
}
)
+ self._skip_env_checking = True
def reset(self):
return {


@@ -13,11 +13,12 @@ class TwoStepGame(MultiAgentEnv):
self.state = None
self.agent_1 = 0
self.agent_2 = 1
+ self._skip_env_checking = True
# MADDPG emits action logits instead of actual discrete actions
self.actions_are_logits = env_config.get("actions_are_logits", False)
self.one_hot_state_encoding = env_config.get("one_hot_state_encoding", False)
self.with_state = env_config.get("separate_state_space", False)
+ self._agent_ids = {0, 1}
if not self.one_hot_state_encoding:
self.observation_space = Discrete(6)
self.with_state = False
@@ -113,6 +114,8 @@ class TwoStepGameWithGroupedAgents(MultiAgentEnv):
)
self.observation_space = self.env.observation_space
self.action_space = self.env.action_space
+ self._agent_ids = {"agents"}
+ self._skip_env_checking = True
def reset(self):
return self.env.reset()
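The pattern used across these example envs is either to declare `_agent_ids` so the checker can validate the per-agent dicts, or to set `_skip_env_checking = True` to opt out. A hedged sketch of a custom MultiAgentEnv taking the first route (the env itself is invented for illustration):

from gym.spaces import Discrete
from ray.rllib.env.multi_agent_env import MultiAgentEnv

class MyTwoAgentEnv(MultiAgentEnv):
    def __init__(self, config=None):
        super().__init__()
        # Advertise the agent ids so check_env can validate reset()/step() dicts.
        self._agent_ids = {0, 1}
        self.observation_space = Discrete(2)
        self.action_space = Discrete(2)

    def reset(self):
        return {0: 0, 1: 0}

    def step(self, action_dict):
        obs = {0: 1, 1: 1}
        rewards = {0: 0.0, 1: 0.0}
        dones = {0: True, 1: True, "__all__": True}
        return obs, rewards, dones, {0: {}, 1: {}}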


@@ -86,6 +86,7 @@ class WindyMazeEnv(gym.Env):
class HierarchicalWindyMazeEnv(MultiAgentEnv):
def __init__(self, env_config):
super().__init__()
+ self._skip_env_checking = True
self.flat_env = WindyMazeEnv(env_config)
def reset(self):


@@ -60,6 +60,7 @@ class NonVectorizedEnvToBeVectorizedIntoRemoteBaseEnv(TaskSettableEnv):
"""
def __init__(self, config=None):
+ super().__init__()
self.action_space = gym.spaces.Box(0, 1, shape=(1,))
self.observation_space = gym.spaces.Box(0, 1, shape=(2,))
self.task = 1


@@ -26,6 +26,7 @@ class FaultInjectEnv(gym.Env):
def __init__(self, config):
self.env = gym.make("CartPole-v0")
+ self._skip_env_checking = True
self.action_space = self.env.action_space
self.observation_space = self.env.observation_space
self.config = config


@@ -404,6 +404,7 @@ class NestedObservationSpacesTest(unittest.TestCase):
"use_lstm": test_lstm,
},
"framework": "tf",
+ "disable_env_checking": True,
},
)
# Skip first passes as they came from the TorchPolicy loss
@@ -436,6 +437,7 @@ class NestedObservationSpacesTest(unittest.TestCase):
"custom_model": "composite2",
},
"framework": "tf",
+ "disable_env_checking": True,
},
)
# Skip first passes as they came from the TorchPolicy loss
@@ -518,6 +520,7 @@ class NestedObservationSpacesTest(unittest.TestCase):
}[aid],
},
"framework": "tf",
+ "disable_env_checking": True,
},
)
# Skip first passes as they came from the TorchPolicy loss


@@ -1,8 +1,12 @@
"""Common pre-checks for all RLlib experiments."""
import logging
+ import numpy as np
from typing import TYPE_CHECKING, Set
import gym
+ from ray.actor import ActorHandle
+ from ray.rllib.utils.spaces.space_utils import convert_element_to_space_type
from ray.rllib.utils.typing import EnvType
if TYPE_CHECKING:
@@ -21,12 +25,38 @@ def check_env(env: EnvType) -> None:
ValueError: If env is not an instance of SUPPORTED_ENVIRONMENT_TYPES.
ValueError: See check_gym_env docstring for details.
"""
- from ray.rllib.env import BaseEnv, MultiAgentEnv, RemoteBaseEnv, VectorEnv
+ from ray.rllib.env import (
+ BaseEnv,
+ MultiAgentEnv,
+ RemoteBaseEnv,
+ VectorEnv,
+ ExternalMultiAgentEnv,
+ ExternalEnv,
+ )
- if not isinstance(env, (BaseEnv, gym.Env, MultiAgentEnv, RemoteBaseEnv, VectorEnv)):
+ if hasattr(env, "_skip_env_checking") and env._skip_env_checking:
+ # This is a work around for some environments that we already have in RLlb
+ # that we want to skip checking for now until we have the time to fix them.
+ logger.warning("Skipping env checking for this experiment")
+ return
+ if not isinstance(
+ env,
+ (
+ BaseEnv,
+ gym.Env,
+ MultiAgentEnv,
+ RemoteBaseEnv,
+ VectorEnv,
+ ExternalMultiAgentEnv,
+ ExternalEnv,
+ ActorHandle,
+ ),
+ ):
raise ValueError(
"Env must be one of the supported types: BaseEnv, gym.Env, "
- "MultiAgentEnv, VectorEnv, RemoteBaseEnv"
+ "MultiAgentEnv, VectorEnv, RemoteBaseEnv, ExternalMultiAgentEnv, "
+ f"ExternalEnv, but instead was a {type(env)}"
)
if isinstance(env, MultiAgentEnv):
@@ -37,7 +67,8 @@ def check_env(env: EnvType) -> None:
check_base_env(env)
else:
logger.warning(
- "Env checking isn't implemented for VectorEnvs or " "RemoteBaseEnvs."
+ "Env checking isn't implemented for VectorEnvs, RemoteBaseEnvs, "
+ "ExternalMultiAgentEnv,or ExternalEnvs or Environments that are Ray actors"
)
@@ -97,33 +128,11 @@ def check_gym_environments(env: gym.Env) -> None:
# check if sampled actions and observations are contained within their
# respective action and observation spaces.
- def contains_error(action_or_observation, sample, space):
- string_type = "observation" if not action_or_observation else "action"
- sample_type = get_type(sample)
- _space_type = space.dtype
- ret = (
- f"A sampled {string_type} from your env wasn't contained "
- f"within your env's {string_type} space. Its possible that "
- f"there was a type mismatch, or that one of the "
- f"sub-{string_type} was out of bounds:\n\nsampled_obs: "
- f"{sample}\nenv.{string_type}_space: {space}"
- f"\nsampled_obs's dtype: {sample_type}"
- f"\nenv.{sample_type}'s dtype: {_space_type}"
- )
- return ret
def get_type(var):
return var.dtype if hasattr(var, "dtype") else type(var)
sampled_action = env.action_space.sample()
sampled_observation = env.observation_space.sample()
- if not env.observation_space.contains(sampled_observation):
- raise ValueError(
- contains_error(False, sampled_observation, env.observation_space)
- )
- if not env.action_space.contains(sampled_action):
- raise ValueError(contains_error(True, sampled_action, env.action_space))
# check if observation generated from stepping the environment is
# contained within the observation space
reset_obs = env.reset()
@@ -140,6 +149,10 @@ def check_gym_environments(env: gym.Env) -> None:
f"{reset_obs_type}\n\n env.observation_space's dtype: "
f"{space_type}"
)
+ temp_sampled_reset_obs = convert_element_to_space_type(
+ reset_obs, sampled_observation
+ )
+ if not env.observation_space.contains(temp_sampled_reset_obs):
raise ValueError(error)
# check if env.step can run, and generates observations rewards, done
# signals and infos that are within their respective spaces and are of
@@ -157,6 +170,10 @@ def check_gym_environments(env: gym.Env) -> None:
f"\n\n next_obs's dtype: {next_obs_type}"
f"\n\n env.observation_space's dtype: {space_type}"
)
+ temp_sampled_next_obs = convert_element_to_space_type(
+ next_obs, sampled_observation
+ )
+ if not env.observation_space.contains(temp_sampled_next_obs):
raise ValueError(error)
_check_done(done)
_check_reward(reward)
@@ -222,10 +239,15 @@ def check_multiagent_environments(env: "MultiAgentEnv") -> None:
raise ValueError(error)
next_obs, reward, done, info = env.step(sampled_action)
- _check_if_element_multi_agent_dict(env, next_obs, "step(sampled_action)")
- _check_reward(reward)
- _check_done(done)
- _check_info(info)
+ _check_if_element_multi_agent_dict(env, next_obs, "step, next_obs")
+ _check_if_element_multi_agent_dict(env, reward, "step, reward")
+ _check_if_element_multi_agent_dict(env, done, "step, done")
+ _check_if_element_multi_agent_dict(env, info, "step, info")
+ _check_reward(
+ {"dummy_env_id": reward}, base_env=True, agent_ids=env.get_agent_ids()
+ )
+ _check_done({"dummy_env_id": done}, base_env=True, agent_ids=env.get_agent_ids())
+ _check_info({"dummy_env_id": info}, base_env=True, agent_ids=env.get_agent_ids())
if not env.observation_space_contains(next_obs):
error = (
_not_contained_error("env.step(sampled_action)", "observation")
@@ -255,7 +277,7 @@ def check_base_env(env: "BaseEnv") -> None:
env.observation_space_contains(reset_obs)
except Exception as e:
raise ValueError(
- "Your observation_space_contains function has some " "error "
+ "Your observation_space_contains function has some error "
) from e
if not env.observation_space_contains(reset_obs):
@@ -303,51 +325,87 @@ def check_base_env(env: "BaseEnv") -> None:
)
raise ValueError(error)
- _check_reward(reward, base_env=True)
- _check_done(done, base_env=True)
- _check_info(info, base_env=True)
+ _check_reward(reward, base_env=True, agent_ids=env.get_agent_ids())
+ _check_done(done, base_env=True, agent_ids=env.get_agent_ids())
+ _check_info(info, base_env=True, agent_ids=env.get_agent_ids())
- def _check_reward(reward, base_env=False):
+ def _check_reward(reward, base_env=False, agent_ids=None):
if base_env:
for _, multi_agent_dict in reward.items():
- for _, rew in multi_agent_dict.items():
- assert isinstance(rew, (float, int)), (
- "Your step function must return a rewards that are"
- f" integer or float. reward: {rew}"
- )
+ for agent_id, rew in multi_agent_dict.items():
+ if not (
+ np.isreal(rew) and not isinstance(rew, bool) and np.isscalar(rew)
+ ):
+ error = (
+ "Your step function must return rewards that are"
+ f" integer or float. reward: {rew}. Instead it was a "
+ f"{type(reward)}"
+ )
+ raise ValueError(error)
+ if not (agent_id in agent_ids or agent_id == "__all__"):
+ error = (
+ f"Your reward dictionary must have agent ids that belong to "
+ f"the environment. Agent_ids recieved from "
+ f"env.get_agent_ids() are: {agent_ids}"
+ )
+ raise ValueError(error)
- else:
- assert isinstance(
- reward, (float, int)
- ), "Your step function must return a reward that is integer or float."
+ elif not (
+ np.isreal(reward) and not isinstance(reward, bool) and np.isscalar(reward)
+ ):
+ error = (
+ "Your step function must return a reward that is integer or float. "
+ "Instead it was a {}".format(type(reward))
+ )
+ raise ValueError(error)
- def _check_done(done, base_env=False):
+ def _check_done(done, base_env=False, agent_ids=None):
if base_env:
for _, multi_agent_dict in done.items():
- for _, done_ in multi_agent_dict.items():
- assert isinstance(done_, bool), (
- "Your step function must return a done that is boolean. "
- f"element: {done_}"
- )
+ for agent_id, done_ in multi_agent_dict.items():
+ if not isinstance(done_, (bool, np.bool, np.bool_)):
+ raise ValueError(
+ "Your step function must return dones that are boolean. But "
+ f"instead was a {type(done)}"
+ )
+ if not (agent_id in agent_ids or agent_id == "__all__"):
+ error = (
+ f"Your dones dictionary must have agent ids that belong to "
+ f"the environment. Agent_ids recieved from "
+ f"env.get_agent_ids() are: {agent_ids}"
+ )
+ raise ValueError(error)
- else:
- assert isinstance(done, bool), (
- "Your step function must return a done that is a " "boolean."
- )
+ elif not isinstance(done, (bool, np.bool, np.bool_)):
+ error = (
+ "Your step function must return a done that is a boolean. But instead "
+ f"was a {type(done)}"
+ )
+ raise ValueError(error)
- def _check_info(info, base_env=False):
+ def _check_info(info, base_env=False, agent_ids=None):
if base_env:
for _, multi_agent_dict in info.items():
- for _, inf in multi_agent_dict.items():
- assert isinstance(inf, dict), (
- "Your step function must return a info that is a dict. "
- f"element: {inf}"
- )
+ for agent_id, inf in multi_agent_dict.items():
+ if not isinstance(inf, dict):
+ raise ValueError(
+ "Your step function must return infos that are a dict. "
+ f"instead was a {type(inf)}: element: {inf}"
+ )
+ if not (agent_id in agent_ids or agent_id == "__all__"):
+ error = (
+ f"Your dones dictionary must have agent ids that belong to "
+ f"the environment. Agent_ids recieved from "
+ f"env.get_agent_ids() are: {agent_ids}"
+ )
+ raise ValueError(error)
- else:
- assert isinstance(
- info, dict
- ), "Your step function must return a info that is a dict."
+ elif not isinstance(info, dict):
+ error = (
+ "Your step function must return a info that "
+ f"is a dict. element type: {type(info)}. element: {info}"
+ )
+ raise ValueError(error)
def _not_contained_error(func_name, _type):
@@ -398,6 +456,7 @@ def _check_if_element_multi_agent_dict(env, element, function_string, base_env=F
raise ValueError(error)
agent_ids: Set = env.get_agent_ids()
+ agent_ids.add("__all__")
if not all(k in agent_ids for k in element):
if base_env:
error = (
@@ -413,6 +472,8 @@ def _check_if_element_multi_agent_dict(env, element, function_string, base_env=F
f" that are not the names of the agents in the env. "
f"\nAgent_ids in this MultiAgentDict: "
f"{list(element.keys())}\nAgent_ids in this env:"
- f"{list(env.get_agent_ids())}"
+ f"{list(env.get_agent_ids())}. You likley need to add the private "
+ f"attribute `_agent_ids` to your env, which is a set containing the "
+ f"ids of agents supported by your env."
)
raise ValueError(error)
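The `_skip_env_checking` escape hatch honored at the top of check_env can be set on any env instance; a minimal sketch (the attribute name comes from this diff, the wrapped env is just an example):

import gym
from ray.rllib.utils import check_env

env = gym.make("CartPole-v0")
# Opt this particular env out of the pre-checks (logs a skip warning instead).
env._skip_env_checking = True
check_env(env)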


@@ -336,8 +336,9 @@ def convert_element_to_space_type(element: Any, sampled_element: Any) -> Any:
elem = elem.astype(s.dtype)
elif isinstance(s, int):
- if isinstance(elem, float):
+ if isinstance(elem, float) and elem.is_integer():
elem = int(elem)
return elem
return tree.map_structure(map_, element, sampled_element, check_types=False)
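The tightened condition only casts floats that are exactly integral; a rough usage sketch of the function with scalar elements:

from ray.rllib.utils.spaces.space_utils import convert_element_to_space_type

# 3.0 matches the int-typed sampled element and is cast to int; 3.7 is now left
# alone instead of being silently truncated.
print(convert_element_to_space_type(3.0, 5))  # -> 3
print(convert_element_to_space_type(3.7, 5))  # -> 3.7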


@@ -1,9 +1,12 @@
import logging
+ import gym
import numpy as np
import pytest
import unittest
from unittest.mock import Mock, MagicMock
+ from ray.rllib.env.base_env import convert_to_base_env
from ray.rllib.env.multi_agent_env import make_multi_agent, MultiAgentEnvWrapper
from ray.rllib.examples.env.random_env import RandomEnv
from ray.rllib.utils.pre_checks.env import (
@@ -40,31 +43,6 @@ class TestGymCheckEnv(unittest.TestCase):
check_env(env)
del env
- def test_sampled_observation_contained(self):
- env = RandomEnv()
- # check for observation that is out of bounds
- error = ".*A sampled observation from your env wasn't contained .*"
- env.observation_space.sample = MagicMock(return_value=5)
- with pytest.raises(ValueError, match=error):
- check_env(env)
- # check for observation that is in bounds, but the wrong type
- env.observation_space.sample = MagicMock(return_value=float(1))
- with pytest.raises(ValueError, match=error):
- check_env(env)
- del env
- def test_sampled_action_contained(self):
- env = RandomEnv()
- error = ".*A sampled action from your env wasn't contained .*"
- env.action_space.sample = MagicMock(return_value=5)
- with pytest.raises(ValueError, match=error):
- check_env(env)
- # check for observation that is in bounds, but the wrong type
- env.action_space.sample = MagicMock(return_value=float(1))
- with pytest.raises(ValueError, match=error):
- check_env(env)
- del env
def test_reset(self):
reset = MagicMock(return_value=5)
env = RandomEnv()
@@ -74,7 +52,7 @@ class TestGymCheckEnv(unittest.TestCase):
with pytest.raises(ValueError, match=error):
check_env(env)
# check reset with obs of incorrect type fails
- reset = MagicMock(return_value=float(1))
+ reset = MagicMock(return_value=float(0.1))
env.reset = reset
with pytest.raises(ValueError, match=error):
check_env(env)
@@ -89,7 +67,7 @@ class TestGymCheckEnv(unittest.TestCase):
check_env(env)
# check reset that returns obs of incorrect type fails
- step = MagicMock(return_value=(float(1), 5, True, {}))
+ step = MagicMock(return_value=(float(0.1), 5, True, {}))
env.step = step
with pytest.raises(ValueError, match=error):
check_env(env)
@@ -98,21 +76,21 @@ class TestGymCheckEnv(unittest.TestCase):
step = MagicMock(return_value=(1, "Not a valid reward", True, {}))
env.step = step
error = "Your step function must return a reward that is integer or " "float."
- with pytest.raises(AssertionError, match=error):
+ with pytest.raises(ValueError, match=error):
check_env(env)
# check step that returns a non bool fails
step = MagicMock(return_value=(1, float(5), "not a valid done signal", {}))
env.step = step
error = "Your step function must return a done that is a boolean."
- with pytest.raises(AssertionError, match=error):
+ with pytest.raises(ValueError, match=error):
check_env(env)
# check step that returns a non dict fails
step = MagicMock(return_value=(1, float(5), True, "not a valid env info"))
env.step = step
error = "Your step function must return a info that is a dict."
- with pytest.raises(AssertionError, match=error):
+ with pytest.raises(ValueError, match=error):
check_env(env)
del env
@@ -176,24 +154,20 @@ class TestCheckMultiAgentEnv(unittest.TestCase):
with pytest.raises(ValueError, match="The element returned by step"):
check_env(env)
- step = MagicMock(return_value=(sampled_obs, "Not a reward", True, {}))
+ step = MagicMock(return_value=(sampled_obs, {0: "Not a reward"}, {0: True}, {}))
env.step = step
- with pytest.raises(
- AssertionError, match="Your step function must " "return a reward "
- ):
+ with pytest.raises(ValueError, match="Your step function must return rewards"):
check_env(env)
- step = MagicMock(return_value=(sampled_obs, 5, "Not a bool", {}))
+ step = MagicMock(return_value=(sampled_obs, {0: 5}, {0: "Not a bool"}, {}))
env.step = step
- with pytest.raises(
- AssertionError, match="Your step function must " "return a done"
- ):
+ with pytest.raises(ValueError, match="Your step function must return dones"):
check_env(env)
- step = MagicMock(return_value=(sampled_obs, 5, False, "Not a Dict"))
+ step = MagicMock(
+ return_value=(sampled_obs, {0: 5}, {0: False}, {0: "Not a Dict"})
+ )
env.step = step
- with pytest.raises(
- AssertionError, match="Your step function must " "return a info"
- ):
+ with pytest.raises(ValueError, match="Your step function must return infos"):
check_env(env)
def test_bad_sample_function(self):
@@ -345,29 +319,32 @@ class TestCheckBaseEnv:
poll = MagicMock(return_value=(good_obs, bad_reward, good_done, good_info, {}))
env.poll = poll
with pytest.raises(
- AssertionError, match="Your step function must " "return a rewards that are"
+ ValueError, match="Your step function must return rewards that are"
):
check_env(env)
bad_done = {0: {0: "not_done", 1: False}}
poll = MagicMock(return_value=(good_obs, good_reward, bad_done, good_info, {}))
env.poll = poll
with pytest.raises(
- AssertionError,
- match="Your step function must " "return a done that is " "boolean.",
+ ValueError,
+ match="Your step function must return dones that are boolean.",
):
check_env(env)
bad_info = {0: {0: "not_info", 1: {}}}
poll = MagicMock(return_value=(good_obs, good_reward, good_done, bad_info, {}))
env.poll = poll
with pytest.raises(
- AssertionError,
- match="Your step function must" " return a info that is a " "dict.",
+ ValueError,
+ match="Your step function must return infos that are a dict.",
):
check_env(env)
def test_check_correct_env(self):
env = self._make_base_env()
check_env(env)
+ env = gym.make("CartPole-v0")
+ env = convert_to_base_env(env)
+ check_env(env)
if __name__ == "__main__":