[RLlib] Put env-checker on critical path. (#22191)
Parent: e03606f0b3
Commit: 740def0a13
25 changed files with 260 additions and 185 deletions
@@ -87,6 +87,7 @@ class TestDDPG(unittest.TestCase):
# Test against all frameworks.
for _ in framework_iterator(core_config):
config = core_config.copy()
config["seed"] = 42
# Default OUNoise setup.
trainer = ddpg.DDPGTrainer(config=config, env="Pendulum-v1")
# Setting explore=False should always return the same action.
@@ -142,6 +143,7 @@ class TestDDPG(unittest.TestCase):
"""Tests DDPG loss function results across all frameworks."""
config = ddpg.DEFAULT_CONFIG.copy()
# Run locally.
config["seed"] = 42
config["num_workers"] = 0
config["learning_starts"] = 0
config["twin_q"] = True
@@ -36,6 +36,7 @@ torch, _ = try_import_torch()

class SimpleEnv(Env):
def __init__(self, config):
self._skip_env_checking = True
if config.get("simplex_actions", False):
self.action_space = Simplex((2,))
else:
@@ -168,6 +169,7 @@ class TestSAC(unittest.TestCase):
"""Tests SAC loss function results across all frameworks."""
config = sac.DEFAULT_CONFIG.copy()
# Run locally.
config["seed"] = 42
config["num_workers"] = 0
config["learning_starts"] = 0
config["twin_q"] = False
@@ -628,6 +628,9 @@ COMMON_CONFIG: TrainerConfigDict = {
# training iteration.
"_disable_execution_plan_api": False,

# If True, disable the environment pre-checking module.
"disable_env_checking": False,

# === Deprecated keys ===
# Uses the sync samples optimizer instead of the multi-gpu one. This is
# usually slower, but you might want to try it if you run into issues with
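
The new `disable_env_checking` key defaults to False, so the checker now runs for every experiment unless it is explicitly switched off. A minimal usage sketch (the choice of PPOTrainer and CartPole-v1 is an illustrative assumption, not part of this diff):

    # Hedged sketch; assumes Ray/RLlib at roughly this commit's API.
    import ray
    from ray.rllib.agents.ppo import PPOTrainer

    ray.init()
    config = {
        "env": "CartPole-v1",
        "num_workers": 0,
        # Opt out of the now-default environment pre-check.
        "disable_env_checking": True,
    }
    trainer = PPOTrainer(config=config)
    print(trainer.train()["episode_reward_mean"])
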
@@ -5,7 +5,7 @@ import numpy as np
from gym.spaces import Discrete, Dict, Box


-class CartPole:
+class CartPole(gym.Env):
"""
Wrapper for gym CartPole environment where the reward
is accumulated to the end

rllib/env/base_env.py (vendored, 15 changed lines)
@@ -204,7 +204,7 @@ class BaseEnv:
Returns:
All agent ids for each the environment.
"""
-return {_DUMMY_AGENT_ID}
+return {}

@PublicAPI
def try_render(self, env_id: Optional[EnvID] = None) -> None:
@@ -313,7 +313,7 @@ class BaseEnv:
True if the observations are contained within their respective
spaces. False otherwise.
"""
-self._space_contains(self.observation_space, x)
+return self._space_contains(self.observation_space, x)

@PublicAPI
def action_space_contains(self, x: MultiEnvDict) -> bool:
@@ -340,8 +340,15 @@ class BaseEnv:
"""
agents = set(self.get_agent_ids())
for multi_agent_dict in x.values():
for agent_id, obs in multi_agent_dict:
if (agent_id not in agents) or (not space[agent_id].contains(obs)):
for agent_id, obs in multi_agent_dict.items():
# this is for the case where we have a single agent
# and we're checking a Vector env thats been converted to
# a BaseEnv
if agent_id == _DUMMY_AGENT_ID:
if not space.contains(obs):
return False
# for the MultiAgent env case
elif (agent_id not in agents) or (not space[agent_id].contains(obs)):
return False

return True

rllib/env/multi_agent_env.py (vendored, 6 changed lines)
@@ -178,7 +178,11 @@ class MultiAgentEnv(gym.Env):
if agent_ids is None:
agent_ids = self.get_agent_ids()
samples = self.action_space.sample()
-return {agent_id: samples[agent_id] for agent_id in agent_ids}
+return {
+agent_id: samples[agent_id]
+for agent_id in agent_ids
+if agent_id != "__all__"
+}
logger.warning("action_space_sample() has not been implemented")
del agent_ids
return {}

rllib/env/vector_env.py (vendored, 22 changed lines)
@@ -1,9 +1,9 @@
import logging
import gym
import numpy as np
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union, Set

-from ray.rllib.env.base_env import BaseEnv
+from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID
from ray.rllib.utils.annotations import Deprecated, override, PublicAPI
from ray.rllib.utils.typing import (
EnvActionType,
@@ -12,6 +12,7 @@ from ray.rllib.utils.typing import (
EnvObsType,
EnvType,
MultiEnvDict,
AgentID,
)

logger = logging.getLogger(__name__)
@@ -355,3 +356,20 @@ class VectorEnvWrapper(BaseEnv):
@PublicAPI
def action_space(self) -> gym.Space:
return self._action_space
+
+@override(BaseEnv)
+@PublicAPI
+def action_space_sample(self, agent_id: list = None) -> MultiEnvDict:
+del agent_id
+return {0: {_DUMMY_AGENT_ID: self._action_space.sample()}}
+
+@override(BaseEnv)
+@PublicAPI
+def observation_space_sample(self, agent_id: list = None) -> MultiEnvDict:
+del agent_id
+return {0: {_DUMMY_AGENT_ID: self._observation_space.sample()}}
+
+@override(BaseEnv)
+@PublicAPI
+def get_agent_ids(self) -> Set[AgentID]:
+return {_DUMMY_AGENT_ID}
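
For orientation, the MultiEnvDict these new sampling helpers return is keyed by sub-env index first and agent id second. A rough sketch, mirroring the CartPole-to-BaseEnv conversion used in the test at the end of this diff (printed values are only examples):

    # Hedged sketch; assumes gym and RLlib at roughly this commit are installed.
    import gym
    from ray.rllib.env.base_env import convert_to_base_env, _DUMMY_AGENT_ID

    base_env = convert_to_base_env(gym.make("CartPole-v0"))
    print(base_env.get_agent_ids())             # {_DUMMY_AGENT_ID}
    print(base_env.action_space_sample())       # e.g. {0: {_DUMMY_AGENT_ID: 1}}
    print(base_env.observation_space_sample())  # e.g. {0: {_DUMMY_AGENT_ID: array([...])}}
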

rllib/env/wrappers/group_agents_wrapper.py (vendored, 2 changed lines)
@@ -55,6 +55,8 @@ class GroupAgentsWrapper(MultiAgentEnv):
self.observation_space = obs_space
if act_space is not None:
self.action_space = act_space
for group_id in groups.keys():
self._agent_ids.add(group_id)

def seed(self, seed=None):
if not hasattr(self.env, "seed"):

rllib/env/wrappers/open_spiel.py (vendored, 2 changed lines)
@@ -9,7 +9,7 @@ class OpenSpielEnv(MultiAgentEnv):
def __init__(self, env):
super().__init__()
self.env = env

self._skip_env_checking = True
# Agent IDs are ints, starting from 0.
self.num_agents = self.env.num_players()
# Store the open-spiel game type.

rllib/env/wrappers/pettingzoo_env.py (vendored, 3 changed lines)
@@ -70,6 +70,8 @@ class PettingZooEnv(MultiAgentEnv):
super().__init__()
self.env = env
env.reset()
self._skip_env_checking = True  # TODO avnishn - remove this after making
# petting zoo env compatible with check_env

# Get first observation space, assuming all agents have equal space
self.observation_space = self.env.observation_space(self.env.agents[0])

@@ -94,6 +96,7 @@ class PettingZooEnv(MultiAgentEnv):
"SuperSuit's pad_action_space wrapper can help (usage: "
"`supersuit.aec_wrappers.pad_action_space(env)`"
)
self._agent_ids = set(self.env.agents)

def reset(self):
self.env.reset()
@@ -1,6 +1,6 @@
import copy
import gym
-from gym.spaces import Box, Discrete, MultiDiscrete, Space
+from gym.spaces import Discrete, MultiDiscrete, Space
import logging
import numpy as np
import platform
@@ -25,11 +25,9 @@ from ray import ObjectRef
from ray import cloudpickle as pickle
from ray.rllib.env.base_env import BaseEnv, convert_to_base_env
from ray.rllib.env.env_context import EnvContext
from ray.rllib.env.external_env import ExternalEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
from ray.rllib.env.utils import record_env_wrapper
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.env.wrappers.atari_wrappers import wrap_deepmind, is_atari
from ray.rllib.evaluation.sampler import AsyncSampler, SyncSampler
from ray.rllib.evaluation.metrics import RolloutMetrics
@@ -43,11 +41,11 @@ from ray.rllib.policy.sample_batch import MultiAgentBatch, DEFAULT_POLICY_ID
from ray.rllib.policy.policy import Policy, PolicySpec
from ray.rllib.policy.policy_map import PolicyMap
from ray.rllib.policy.torch_policy import TorchPolicy
-from ray.rllib.utils import force_list, merge_dicts
+from ray.rllib.utils import force_list, merge_dicts, check_env
from ray.rllib.utils.annotations import Deprecated, DeveloperAPI, ExperimentalAPI
from ray.rllib.utils.debug import summarize, update_global_seed_if_necessary
from ray.rllib.utils.deprecation import deprecation_warning
-from ray.rllib.utils.error import EnvError, ERR_MSG_NO_GPUS, HOWTO_CHANGE_CONFIG
+from ray.rllib.utils.error import ERR_MSG_NO_GPUS, HOWTO_CHANGE_CONFIG
from ray.rllib.utils.filter import get_filter, Filter
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.sgd import do_minibatch_sgd
@@ -251,6 +249,7 @@ class RolloutWorker(ParallelIteratorWorker):
spaces: Optional[Dict[PolicyID, Tuple[Space, Space]]] = None,
policy=None,
monitor_path=None,
disable_env_checking=False,
):
"""Initializes a RolloutWorker instance.

@@ -365,6 +364,8 @@ class RolloutWorker(ParallelIteratorWorker):
Env is created on this RolloutWorker.
policy: Obsoleted arg. Use `policy_spec` instead.
monitor_path: Obsoleted arg. Use `record_env` instead.
disable_env_checking: If True, disables the env checking module that
validates the properties of the passed environment.
"""

# Deprecated args.

@@ -463,6 +464,7 @@ class RolloutWorker(ParallelIteratorWorker):
self.last_batch: Optional[SampleBatchType] = None
self.global_vars: Optional[dict] = None
self.fake_sampler: bool = fake_sampler
self._disable_env_checking: bool = disable_env_checking

# Update the global seed for numpy/random/tf-eager/torch if we are not
# the local worker, otherwise, this was already done in the Trainer
@@ -492,8 +494,19 @@ class RolloutWorker(ParallelIteratorWorker):

if self.env is not None:
# Validate environment (general validation function).
-_validate_env(self.env, env_context=self.env_context)
-# Custom validation function given.
+if not self._disable_env_checking:
+logger.warning(
+"We've added a module for checking environments that "
+"are used in experiments. It will cause your "
+"environment to fail if your environment is not set up"
+"correctly. You can disable check env by setting "
+"`disable_env_checking` to True in your experiment config "
+"dictionary. You can run the environment checking module "
+"standalone by calling ray.rllib.utils.check_env(env)."
+)
+check_env(self.env)
+# Custom validation function given, typically a function attribute of the
+# algorithm trainer.
if validate_env is not None:
validate_env(self.env, self.env_context)
# We can't auto-wrap a BaseEnv.
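
The warning above names the two escape hatches: the `disable_env_checking` config flag shown earlier, and running the checker by hand before training. A hedged sketch of the standalone call (the CartPole env is an illustrative assumption):

    # check_env is the pre-check this diff puts on the critical path.
    import gym
    from ray.rllib.utils import check_env

    check_env(gym.make("CartPole-v0"))  # raises ValueError if the env is malformed
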
@@ -1717,6 +1730,8 @@ class RolloutWorker(ParallelIteratorWorker):
def _get_make_sub_env_fn(
self, env_creator, env_context, validate_env, env_wrapper, seed
):
disable_env_checking = self._disable_env_checking

def _make_sub_env_local(vector_index):
# Used to created additional environments during environment
# vectorization.
@@ -1727,7 +1742,17 @@ class RolloutWorker(ParallelIteratorWorker):
# Create the sub-env.
env = env_creator(env_ctx)
# Validate first.
-_validate_env(env, env_context=env_ctx)
+if not disable_env_checking:
+logger.warning(
+"We've added a module for checking environments that "
+"are used in experiments. It will cause your "
+"environment to fail if your environment is not set up"
+"correctly. You can disable check env by setting "
+"`disable_env_checking` to True in your experiment config "
+"dictionary. You can run the environment checking module "
+"standalone by calling ray.rllib.utils.check_env(env)."
+)
+check_env(env)
# Custom validation function given by user.
if validate_env is not None:
validate_env(env, env_ctx)
@@ -1870,52 +1895,3 @@ def _determine_spaces_for_multi_agent_dict(
action_space=act_space
)
return multi_agent_dict
-
-
-def _validate_env(env: EnvType, env_context: EnvContext = None):
-# Base message for checking the env for vector-index=0
-msg = f"Validating sub-env at vector index={env_context.vector_index} ..."
-
-allowed_types = [gym.Env, ExternalEnv, VectorEnv, BaseEnv, ray.actor.ActorHandle]
-if not any(isinstance(env, tpe) for tpe in allowed_types):
-# Allow this as a special case (assumed gym.Env).
-# TODO: Disallow this early-out. Everything should conform to a few
-# supported classes, i.e. gym.Env/MultiAgentEnv/etc...
-if hasattr(env, "observation_space") and hasattr(env, "action_space"):
-logger.warning(msg + f" (warning; invalid env-type={type(env)})")
-return
-else:
-logger.warning(msg + " (NOT OK)")
-raise EnvError(
-"Returned env should be an instance of gym.Env (incl. "
-"MultiAgentEnv), ExternalEnv, VectorEnv, or BaseEnv. "
-f"The provided env creator function returned {env} "
-f"(type={type(env)})."
-)
-
-# Do some test runs with the provided env.
-if isinstance(env, gym.Env) and not isinstance(env, MultiAgentEnv):
-# Make sure the gym.Env has the two space attributes properly set.
-assert hasattr(env, "observation_space") and hasattr(env, "action_space")
-# Get a dummy observation by resetting the env.
-dummy_obs = env.reset()
-# Convert lists to np.ndarrays.
-if type(dummy_obs) is list and isinstance(env.observation_space, Box):
-dummy_obs = np.array(dummy_obs)
-# Ignore float32/float64 diffs.
-if (
-isinstance(env.observation_space, Box)
-and env.observation_space.dtype != dummy_obs.dtype
-):
-dummy_obs = dummy_obs.astype(env.observation_space.dtype)
-# Check, if observation is ok (part of the observation space). If not,
-# error.
-if not env.observation_space.contains(dummy_obs):
-logger.warning(msg + " (NOT OK)")
-raise EnvError(
-f"Env's `observation_space` {env.observation_space} does not "
-f"contain returned observation after a reset ({dummy_obs})!"
-)
-
-# Log that everything is ok.
-logger.info(msg + " (ok)")
@@ -74,6 +74,7 @@ class EchoPolicy(Policy):
class EpisodeEnv(MultiAgentEnv):
def __init__(self, episode_length, num):
super().__init__()
self._skip_env_checking = True
self.agents = [MockEnv3(episode_length) for _ in range(num)]
self.dones = set()
self.observation_space = self.agents[0].observation_space
@@ -591,6 +591,7 @@ class WorkerSet:
fake_sampler=config["fake_sampler"],
extra_python_environs=extra_python_environs,
spaces=spaces,
disable_env_checking=config["disable_env_checking"],
)

return worker

rllib/examples/env/action_mask_env.py (vendored, 2 changed lines)
@@ -9,7 +9,7 @@ class ActionMaskEnv(RandomEnv):

def __init__(self, config):
super().__init__(config)

self._skip_env_checking = True
# Masking only works for Discrete actions.
assert isinstance(self.action_space, Discrete)
# Add action_mask to observations.

rllib/examples/env/debug_counter_env.py (vendored, 1 changed line)
@@ -34,6 +34,7 @@ class DebugCounterEnv(gym.Env):
class MultiAgentDebugCounterEnv(MultiAgentEnv):
def __init__(self, config):
super().__init__()
self._skip_env_checking = True
self.num_agents = config["num_agents"]
self.base_episode_len = config.get("base_episode_len", 103)
# Actions are always:

rllib/examples/env/multi_agent.py (vendored, 11 changed lines)
@@ -27,6 +27,7 @@ class BasicMultiAgent(MultiAgentEnv):
def __init__(self, num):
super().__init__()
self.agents = [MockEnv(25) for _ in range(num)]
self._agent_ids = set(range(num))
self.dones = set()
self.observation_space = gym.spaces.Discrete(2)
self.action_space = gym.spaces.Discrete(2)

@@ -59,6 +60,7 @@ class EarlyDoneMultiAgent(MultiAgentEnv):
def __init__(self):
super().__init__()
self.agents = [MockEnv(3), MockEnv(5)]
self._agent_ids = set(range(len(self.agents)))
self.dones = set()
self.last_obs = {}
self.last_rew = {}

@@ -77,7 +79,7 @@ class EarlyDoneMultiAgent(MultiAgentEnv):
self.i = 0
for i, a in enumerate(self.agents):
self.last_obs[i] = a.reset()
-self.last_rew[i] = None
+self.last_rew[i] = 0
self.last_done[i] = False
self.last_info[i] = {}
obs_dict = {self.i: self.last_obs[self.i]}
@@ -111,6 +113,7 @@ class FlexAgentsMultiAgent(MultiAgentEnv):
def __init__(self):
super().__init__()
self.agents = {}
self._agent_ids = set()
self.agentID = 0
self.dones = set()
self.observation_space = gym.spaces.Discrete(2)

@@ -121,15 +124,16 @@ class FlexAgentsMultiAgent(MultiAgentEnv):
# Spawn a new agent into the current episode.
agentID = self.agentID
self.agents[agentID] = MockEnv(25)
self._agent_ids.add(agentID)
self.agentID += 1
return agentID

def reset(self):
self.agents = {}
self._agent_ids = set()
self.spawn()
self.resetted = True
self.dones = set()

obs = {}
for i, a in self.agents.items():
obs[i] = a.reset()

@@ -175,6 +179,7 @@ class RoundRobinMultiAgent(MultiAgentEnv):
else:
# Observations are all zeros
self.agents = [MockEnv(5) for _ in range(num)]
self._agent_ids = set(range(num))
self.dones = set()
self.last_obs = {}
self.last_rew = {}

@@ -194,7 +199,7 @@ class RoundRobinMultiAgent(MultiAgentEnv):
self.i = 0
for i, a in enumerate(self.agents):
self.last_obs[i] = a.reset()
-self.last_rew[i] = None
+self.last_rew[i] = 0
self.last_done[i] = False
self.last_info[i] = {}
obs_dict = {self.i: self.last_obs[self.i]}
@@ -40,6 +40,7 @@ class ParametricActionsCartPole(gym.Env):
"cart": self.wrapped.observation_space,
}
)
self._skip_env_checking = True

def update_avail_actions(self):
self.action_assignments = np.array(

@@ -114,6 +115,7 @@ class ParametricActionsCartPoleNoEmbeddings(gym.Env):
"cart": self.wrapped.observation_space,
}
)
self._skip_env_checking = True

def reset(self):
return {

rllib/examples/env/two_step_game.py (vendored, 5 changed lines)
@@ -13,11 +13,12 @@ class TwoStepGame(MultiAgentEnv):
self.state = None
self.agent_1 = 0
self.agent_2 = 1
self._skip_env_checking = True
# MADDPG emits action logits instead of actual discrete actions
self.actions_are_logits = env_config.get("actions_are_logits", False)
self.one_hot_state_encoding = env_config.get("one_hot_state_encoding", False)
self.with_state = env_config.get("separate_state_space", False)

self._agent_ids = {0, 1}
if not self.one_hot_state_encoding:
self.observation_space = Discrete(6)
self.with_state = False

@@ -113,6 +114,8 @@ class TwoStepGameWithGroupedAgents(MultiAgentEnv):
)
self.observation_space = self.env.observation_space
self.action_space = self.env.action_space
self._agent_ids = {"agents"}
self._skip_env_checking = True

def reset(self):
return self.env.reset()

rllib/examples/env/windy_maze_env.py (vendored, 1 changed line)
@@ -86,6 +86,7 @@ class WindyMazeEnv(gym.Env):
class HierarchicalWindyMazeEnv(MultiAgentEnv):
def __init__(self, env_config):
super().__init__()
self._skip_env_checking = True
self.flat_env = WindyMazeEnv(env_config)

def reset(self):
@@ -60,6 +60,7 @@ class NonVectorizedEnvToBeVectorizedIntoRemoteBaseEnv(TaskSettableEnv):
"""

def __init__(self, config=None):
super().__init__()
self.action_space = gym.spaces.Box(0, 1, shape=(1,))
self.observation_space = gym.spaces.Box(0, 1, shape=(2,))
self.task = 1
@@ -26,6 +26,7 @@ class FaultInjectEnv(gym.Env):

def __init__(self, config):
self.env = gym.make("CartPole-v0")
self._skip_env_checking = True
self.action_space = self.env.action_space
self.observation_space = self.env.observation_space
self.config = config
@@ -404,6 +404,7 @@ class NestedObservationSpacesTest(unittest.TestCase):
"use_lstm": test_lstm,
},
"framework": "tf",
"disable_env_checking": True,
},
)
# Skip first passes as they came from the TorchPolicy loss

@@ -436,6 +437,7 @@ class NestedObservationSpacesTest(unittest.TestCase):
"custom_model": "composite2",
},
"framework": "tf",
"disable_env_checking": True,
},
)
# Skip first passes as they came from the TorchPolicy loss

@@ -518,6 +520,7 @@ class NestedObservationSpacesTest(unittest.TestCase):
}[aid],
},
"framework": "tf",
"disable_env_checking": True,
},
)
# Skip first passes as they came from the TorchPolicy loss
@@ -1,8 +1,12 @@
"""Common pre-checks for all RLlib experiments."""
import logging
import numpy as np
from typing import TYPE_CHECKING, Set

import gym

from ray.actor import ActorHandle
from ray.rllib.utils.spaces.space_utils import convert_element_to_space_type
from ray.rllib.utils.typing import EnvType

if TYPE_CHECKING:
@@ -21,12 +25,38 @@ def check_env(env: EnvType) -> None:
ValueError: If env is not an instance of SUPPORTED_ENVIRONMENT_TYPES.
ValueError: See check_gym_env docstring for details.
"""
-from ray.rllib.env import BaseEnv, MultiAgentEnv, RemoteBaseEnv, VectorEnv
+from ray.rllib.env import (
+BaseEnv,
+MultiAgentEnv,
+RemoteBaseEnv,
+VectorEnv,
+ExternalMultiAgentEnv,
+ExternalEnv,
+)

-if not isinstance(env, (BaseEnv, gym.Env, MultiAgentEnv, RemoteBaseEnv, VectorEnv)):
+if hasattr(env, "_skip_env_checking") and env._skip_env_checking:
+# This is a work around for some environments that we already have in RLlb
+# that we want to skip checking for now until we have the time to fix them.
+logger.warning("Skipping env checking for this experiment")
+return
+
+if not isinstance(
+env,
+(
+BaseEnv,
+gym.Env,
+MultiAgentEnv,
+RemoteBaseEnv,
+VectorEnv,
+ExternalMultiAgentEnv,
+ExternalEnv,
+ActorHandle,
+),
+):
raise ValueError(
"Env must be one of the supported types: BaseEnv, gym.Env, "
-"MultiAgentEnv, VectorEnv, RemoteBaseEnv"
+"MultiAgentEnv, VectorEnv, RemoteBaseEnv, ExternalMultiAgentEnv, "
f"ExternalEnv, but instead was a {type(env)}"
)

if isinstance(env, MultiAgentEnv):
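
The `_skip_env_checking` flag checked above is the opt-out that the many example-env edits in this diff rely on. A hedged sketch of an env using it (class name and spaces are made up for illustration):

    # Any env object can set this private attribute to bypass check_env,
    # which then only logs a warning and returns.
    import gym
    from gym.spaces import Discrete

    class MyLegacyEnv(gym.Env):
        def __init__(self, config=None):
            self._skip_env_checking = True
            self.observation_space = Discrete(2)
            self.action_space = Discrete(2)

        def reset(self):
            return 0

        def step(self, action):
            return 0, 0.0, True, {}
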
@@ -37,7 +67,8 @@ def check_env(env: EnvType) -> None:
check_base_env(env)
else:
logger.warning(
-"Env checking isn't implemented for VectorEnvs or " "RemoteBaseEnvs."
+"Env checking isn't implemented for VectorEnvs, RemoteBaseEnvs, "
+"ExternalMultiAgentEnv,or ExternalEnvs or Environments that are Ray actors"
)
@@ -97,33 +128,11 @@ def check_gym_environments(env: gym.Env) -> None:
# check if sampled actions and observations are contained within their
# respective action and observation spaces.

def contains_error(action_or_observation, sample, space):
string_type = "observation" if not action_or_observation else "action"
sample_type = get_type(sample)
_space_type = space.dtype
ret = (
f"A sampled {string_type} from your env wasn't contained "
f"within your env's {string_type} space. Its possible that "
f"there was a type mismatch, or that one of the "
f"sub-{string_type} was out of bounds:\n\nsampled_obs: "
f"{sample}\nenv.{string_type}_space: {space}"
f"\nsampled_obs's dtype: {sample_type}"
f"\nenv.{sample_type}'s dtype: {_space_type}"
)
return ret

def get_type(var):
return var.dtype if hasattr(var, "dtype") else type(var)

sampled_action = env.action_space.sample()
sampled_observation = env.observation_space.sample()
if not env.observation_space.contains(sampled_observation):
raise ValueError(
contains_error(False, sampled_observation, env.observation_space)
)
if not env.action_space.contains(sampled_action):
raise ValueError(contains_error(True, sampled_action, env.action_space))

# check if observation generated from stepping the environment is
# contained within the observation space
reset_obs = env.reset()
@@ -140,7 +149,11 @@ def check_gym_environments(env: gym.Env) -> None:
f"{reset_obs_type}\n\n env.observation_space's dtype: "
f"{space_type}"
)
raise ValueError(error)
temp_sampled_reset_obs = convert_element_to_space_type(
reset_obs, sampled_observation
)
if not env.observation_space.contains(temp_sampled_reset_obs):
raise ValueError(error)
# check if env.step can run, and generates observations rewards, done
# signals and infos that are within their respective spaces and are of
# the correct dtypes
@@ -157,7 +170,11 @@ def check_gym_environments(env: gym.Env) -> None:
f"\n\n next_obs's dtype: {next_obs_type}"
f"\n\n env.observation_space's dtype: {space_type}"
)
raise ValueError(error)
temp_sampled_next_obs = convert_element_to_space_type(
next_obs, sampled_observation
)
if not env.observation_space.contains(temp_sampled_next_obs):
raise ValueError(error)
_check_done(done)
_check_reward(reward)
_check_info(info)
@@ -222,10 +239,15 @@ def check_multiagent_environments(env: "MultiAgentEnv") -> None:
raise ValueError(error)

next_obs, reward, done, info = env.step(sampled_action)
-_check_if_element_multi_agent_dict(env, next_obs, "step(sampled_action)")
-_check_reward(reward)
-_check_done(done)
-_check_info(info)
+_check_if_element_multi_agent_dict(env, next_obs, "step, next_obs")
+_check_if_element_multi_agent_dict(env, reward, "step, reward")
+_check_if_element_multi_agent_dict(env, done, "step, done")
+_check_if_element_multi_agent_dict(env, info, "step, info")
+_check_reward(
+{"dummy_env_id": reward}, base_env=True, agent_ids=env.get_agent_ids()
+)
+_check_done({"dummy_env_id": done}, base_env=True, agent_ids=env.get_agent_ids())
+_check_info({"dummy_env_id": info}, base_env=True, agent_ids=env.get_agent_ids())
if not env.observation_space_contains(next_obs):
error = (
_not_contained_error("env.step(sampled_action)", "observation")
@@ -255,7 +277,7 @@ def check_base_env(env: "BaseEnv") -> None:
env.observation_space_contains(reset_obs)
except Exception as e:
raise ValueError(
-"Your observation_space_contains function has some " "error "
+"Your observation_space_contains function has some error "
) from e

if not env.observation_space_contains(reset_obs):
@@ -303,51 +325,87 @@ def check_base_env(env: "BaseEnv") -> None:
)
raise ValueError(error)

-_check_reward(reward, base_env=True)
-_check_done(done, base_env=True)
-_check_info(info, base_env=True)
+_check_reward(reward, base_env=True, agent_ids=env.get_agent_ids())
+_check_done(done, base_env=True, agent_ids=env.get_agent_ids())
+_check_info(info, base_env=True, agent_ids=env.get_agent_ids())


-def _check_reward(reward, base_env=False):
+def _check_reward(reward, base_env=False, agent_ids=None):
if base_env:
for _, multi_agent_dict in reward.items():
for _, rew in multi_agent_dict.items():
assert isinstance(rew, (float, int)), (
"Your step function must return a rewards that are"
f" integer or float. reward: {rew}"
)
else:
assert isinstance(
reward, (float, int)
), "Your step function must return a reward that is integer or float."
for agent_id, rew in multi_agent_dict.items():
if not (
np.isreal(rew) and not isinstance(rew, bool) and np.isscalar(rew)
):
error = (
"Your step function must return rewards that are"
f" integer or float. reward: {rew}. Instead it was a "
f"{type(reward)}"
)
raise ValueError(error)
if not (agent_id in agent_ids or agent_id == "__all__"):
error = (
f"Your reward dictionary must have agent ids that belong to "
f"the environment. Agent_ids recieved from "
f"env.get_agent_ids() are: {agent_ids}"
)
raise ValueError(error)
elif not (
np.isreal(reward) and not isinstance(reward, bool) and np.isscalar(reward)
):
error = (
"Your step function must return a reward that is integer or float. "
"Instead it was a {}".format(type(reward))
)
raise ValueError(error)


-def _check_done(done, base_env=False):
+def _check_done(done, base_env=False, agent_ids=None):
if base_env:
for _, multi_agent_dict in done.items():
for _, done_ in multi_agent_dict.items():
assert isinstance(done_, bool), (
"Your step function must return a done that is boolean. "
f"element: {done_}"
)
else:
assert isinstance(done, bool), (
"Your step function must return a done that is a " "boolean."
for agent_id, done_ in multi_agent_dict.items():
if not isinstance(done_, (bool, np.bool, np.bool_)):
raise ValueError(
"Your step function must return dones that are boolean. But "
f"instead was a {type(done)}"
)
if not (agent_id in agent_ids or agent_id == "__all__"):
error = (
f"Your dones dictionary must have agent ids that belong to "
f"the environment. Agent_ids recieved from "
f"env.get_agent_ids() are: {agent_ids}"
)
raise ValueError(error)
elif not isinstance(done, (bool, np.bool, np.bool_)):
error = (
"Your step function must return a done that is a boolean. But instead "
f"was a {type(done)}"
)
raise ValueError(error)


-def _check_info(info, base_env=False):
+def _check_info(info, base_env=False, agent_ids=None):
if base_env:
for _, multi_agent_dict in info.items():
for _, inf in multi_agent_dict.items():
assert isinstance(inf, dict), (
"Your step function must return a info that is a dict. "
f"element: {inf}"
)
else:
assert isinstance(
info, dict
), "Your step function must return a info that is a dict."
for agent_id, inf in multi_agent_dict.items():
if not isinstance(inf, dict):
raise ValueError(
"Your step function must return infos that are a dict. "
f"instead was a {type(inf)}: element: {inf}"
)
if not (agent_id in agent_ids or agent_id == "__all__"):
error = (
f"Your dones dictionary must have agent ids that belong to "
f"the environment. Agent_ids recieved from "
f"env.get_agent_ids() are: {agent_ids}"
)
raise ValueError(error)
elif not isinstance(info, dict):
error = (
"Your step function must return a info that "
f"is a dict. element type: {type(info)}. element: {info}"
)
raise ValueError(error)


def _not_contained_error(func_name, _type):
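
Taken together, the rewritten helpers expect per-agent dicts keyed by the env's agent ids (plus the special "__all__" key for dones) and insist on scalar rewards, boolean dones, and dict infos. A hedged sketch of a step() return that would pass (agent ids and values are illustrative):

    # Hedged sketch of a well-formed MultiAgentEnv step() return under these checks.
    obs = {"agent_0": 0, "agent_1": 1}
    rewards = {"agent_0": 1.0, "agent_1": -0.5}                      # ints/floats, never bools
    dones = {"agent_0": False, "agent_1": False, "__all__": False}   # booleans only
    infos = {"agent_0": {}, "agent_1": {}}                           # one dict per agent
    # return obs, rewards, dones, infos
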
@@ -398,6 +456,7 @@ def _check_if_element_multi_agent_dict(env, element, function_string, base_env=F
raise ValueError(error)
agent_ids: Set = env.get_agent_ids()
agent_ids.add("__all__")

if not all(k in agent_ids for k in element):
if base_env:
error = (
@@ -413,6 +472,8 @@ def _check_if_element_multi_agent_dict(env, element, function_string, base_env=F
f" that are not the names of the agents in the env. "
f"\nAgent_ids in this MultiAgentDict: "
f"{list(element.keys())}\nAgent_ids in this env:"
-f"{list(env.get_agent_ids())}"
+f"{list(env.get_agent_ids())}. You likley need to add the private "
+f"attribute `_agent_ids` to your env, which is a set containing the "
+f"ids of agents supported by your env."
)
raise ValueError(error)
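
The extended error text points users at the `_agent_ids` attribute that many example envs in this diff now set. A hedged sketch of a custom MultiAgentEnv declaring it (class name, ids, and spaces are illustrative):

    # _agent_ids is the set backing MultiAgentEnv.get_agent_ids() at this commit.
    from gym.spaces import Discrete
    from ray.rllib.env.multi_agent_env import MultiAgentEnv

    class TwoAgentToyEnv(MultiAgentEnv):
        def __init__(self, config=None):
            super().__init__()
            self._agent_ids = {"agent_0", "agent_1"}
            self.observation_space = Discrete(2)
            self.action_space = Discrete(2)

        def reset(self):
            return {aid: 0 for aid in self._agent_ids}

        def step(self, action_dict):
            obs = {aid: 0 for aid in action_dict}
            rewards = {aid: 1.0 for aid in action_dict}
            dones = {aid: True for aid in action_dict}
            dones["__all__"] = True
            return obs, rewards, dones, {aid: {} for aid in action_dict}
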
@@ -336,8 +336,9 @@ def convert_element_to_space_type(element: Any, sampled_element: Any) -> Any:
elem = elem.astype(s.dtype)

elif isinstance(s, int):
-if isinstance(elem, float):
+if isinstance(elem, float) and elem.is_integer():
elem = int(elem)

return elem

return tree.map_structure(map_, element, sampled_element, check_types=False)
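
The added `elem.is_integer()` guard means only whole-number floats are coerced to the sampled int type; other floats are now left alone. A rough illustration of that intent (behaviour inferred from the hunk above, not separately verified):

    from ray.rllib.utils.spaces.space_utils import convert_element_to_space_type

    print(convert_element_to_space_type(1.0, 2))  # -> 1 (whole float coerced to int)
    print(convert_element_to_space_type(0.5, 2))  # -> 0.5 (no longer truncated to 0)
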
@@ -1,9 +1,12 @@
import logging

import gym
import numpy as np
import pytest
import unittest
from unittest.mock import Mock, MagicMock

from ray.rllib.env.base_env import convert_to_base_env
from ray.rllib.env.multi_agent_env import make_multi_agent, MultiAgentEnvWrapper
from ray.rllib.examples.env.random_env import RandomEnv
from ray.rllib.utils.pre_checks.env import (
@@ -40,31 +43,6 @@ class TestGymCheckEnv(unittest.TestCase):
check_env(env)
del env

-def test_sampled_observation_contained(self):
-env = RandomEnv()
-# check for observation that is out of bounds
-error = ".*A sampled observation from your env wasn't contained .*"
-env.observation_space.sample = MagicMock(return_value=5)
-with pytest.raises(ValueError, match=error):
-check_env(env)
-# check for observation that is in bounds, but the wrong type
-env.observation_space.sample = MagicMock(return_value=float(1))
-with pytest.raises(ValueError, match=error):
-check_env(env)
-del env
-
-def test_sampled_action_contained(self):
-env = RandomEnv()
-error = ".*A sampled action from your env wasn't contained .*"
-env.action_space.sample = MagicMock(return_value=5)
-with pytest.raises(ValueError, match=error):
-check_env(env)
-# check for observation that is in bounds, but the wrong type
-env.action_space.sample = MagicMock(return_value=float(1))
-with pytest.raises(ValueError, match=error):
-check_env(env)
-del env
-
def test_reset(self):
reset = MagicMock(return_value=5)
env = RandomEnv()
@@ -74,7 +52,7 @@ class TestGymCheckEnv(unittest.TestCase):
with pytest.raises(ValueError, match=error):
check_env(env)
# check reset with obs of incorrect type fails
-reset = MagicMock(return_value=float(1))
+reset = MagicMock(return_value=float(0.1))
env.reset = reset
with pytest.raises(ValueError, match=error):
check_env(env)
@@ -89,7 +67,7 @@ class TestGymCheckEnv(unittest.TestCase):
check_env(env)

# check reset that returns obs of incorrect type fails
-step = MagicMock(return_value=(float(1), 5, True, {}))
+step = MagicMock(return_value=(float(0.1), 5, True, {}))
env.step = step
with pytest.raises(ValueError, match=error):
check_env(env)
@@ -98,21 +76,21 @@ class TestGymCheckEnv(unittest.TestCase):
step = MagicMock(return_value=(1, "Not a valid reward", True, {}))
env.step = step
error = "Your step function must return a reward that is integer or " "float."
-with pytest.raises(AssertionError, match=error):
+with pytest.raises(ValueError, match=error):
check_env(env)

# check step that returns a non bool fails
step = MagicMock(return_value=(1, float(5), "not a valid done signal", {}))
env.step = step
error = "Your step function must return a done that is a boolean."
-with pytest.raises(AssertionError, match=error):
+with pytest.raises(ValueError, match=error):
check_env(env)

# check step that returns a non dict fails
step = MagicMock(return_value=(1, float(5), True, "not a valid env info"))
env.step = step
error = "Your step function must return a info that is a dict."
-with pytest.raises(AssertionError, match=error):
+with pytest.raises(ValueError, match=error):
check_env(env)
del env
@@ -176,24 +154,20 @@ class TestCheckMultiAgentEnv(unittest.TestCase):
with pytest.raises(ValueError, match="The element returned by step"):
check_env(env)

-step = MagicMock(return_value=(sampled_obs, "Not a reward", True, {}))
+step = MagicMock(return_value=(sampled_obs, {0: "Not a reward"}, {0: True}, {}))
env.step = step
-with pytest.raises(
-AssertionError, match="Your step function must " "return a reward "
-):
+with pytest.raises(ValueError, match="Your step function must return rewards"):
check_env(env)
-step = MagicMock(return_value=(sampled_obs, 5, "Not a bool", {}))
+step = MagicMock(return_value=(sampled_obs, {0: 5}, {0: "Not a bool"}, {}))
env.step = step
-with pytest.raises(
-AssertionError, match="Your step function must " "return a done"
-):
+with pytest.raises(ValueError, match="Your step function must return dones"):
check_env(env)

-step = MagicMock(return_value=(sampled_obs, 5, False, "Not a Dict"))
+step = MagicMock(
+return_value=(sampled_obs, {0: 5}, {0: False}, {0: "Not a Dict"})
+)
env.step = step
-with pytest.raises(
-AssertionError, match="Your step function must " "return a info"
-):
+with pytest.raises(ValueError, match="Your step function must return infos"):
check_env(env)

def test_bad_sample_function(self):
@@ -345,29 +319,32 @@ class TestCheckBaseEnv:
poll = MagicMock(return_value=(good_obs, bad_reward, good_done, good_info, {}))
env.poll = poll
with pytest.raises(
-AssertionError, match="Your step function must " "return a rewards that are"
+ValueError, match="Your step function must return rewards that are"
):
check_env(env)
bad_done = {0: {0: "not_done", 1: False}}
poll = MagicMock(return_value=(good_obs, good_reward, bad_done, good_info, {}))
env.poll = poll
with pytest.raises(
-AssertionError,
-match="Your step function must " "return a done that is " "boolean.",
+ValueError,
+match="Your step function must return dones that are boolean.",
):
check_env(env)
bad_info = {0: {0: "not_info", 1: {}}}
poll = MagicMock(return_value=(good_obs, good_reward, good_done, bad_info, {}))
env.poll = poll
with pytest.raises(
-AssertionError,
-match="Your step function must" " return a info that is a " "dict.",
+ValueError,
+match="Your step function must return infos that are a dict.",
):
check_env(env)

def test_check_correct_env(self):
env = self._make_base_env()
check_env(env)
env = gym.make("CartPole-v0")
env = convert_to_base_env(env)
check_env(env)


if __name__ == "__main__":