[RLlib] Upgrade gym 0.23 (#24171)

Sven Mika 2022-05-23 08:18:44 +02:00 committed by GitHub
parent c03d0432f3
commit 09886d7ab8
31 changed files with 133 additions and 367 deletions


@@ -80,7 +80,7 @@ _____________________________________________________________________
 # __get_q_values_dqn_start__
 # Get a reference to the model through the policy
 import numpy as np
-from ray.rllib.agents.dqn import DQNTrainer
+from ray.rllib.algorithms.dqn import DQNTrainer

 trainer = DQNTrainer(env="CartPole-v0", config={"framework": "tf2"})
 model = trainer.get_policy().model


@@ -39,7 +39,7 @@ __[Full Ray Enhancement Proposal, REP-001: Serve Pipeline](https://github.com/ra
 ## Concepts

-- **Deployment**: Scalable, upgradeable group of actors managed by Ray Serve. __[See docs for detail](https://docs.ray.io/en/master/serve/core-apis.html#core-api-deployments)__
+- **Deployment**: Scalable, upgradeable group of actors managed by Ray Serve. __[See docs for detail](https://docs.ray.io/en/master/serve/package-ref.html#deployment-api)__
 - **DeploymentNode**: Smallest unit in a graph, created by calling `.bind()` on a serve decorated class or function, backed by a Deployment.


@@ -10,7 +10,7 @@ from ray._private.client_mode_hook import enable_client_mode, client_mode_should
 @pytest.mark.skip(reason="KV store is not working properly.")
 def test_rllib_integration(ray_start_regular_shared):
     with ray_start_client_server():
-        import ray.rllib.agents.dqn as dqn
+        import ray.rllib.algorithms.dqn as dqn

         # Confirming the behavior of this context manager.
         # (Client mode hook not yet enabled.)


@@ -29,7 +29,7 @@ virtualenv
 ## setup.py extras
 dm_tree
 flask
-gym==0.21.0; python_version >= '3.7'
+gym>=0.21.0,<0.24.0; python_version >= '3.7'
 gym==0.19.0; python_version < '3.7'
 lz4
 scikit-image
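Since the requirement is now a range rather than an exact pin, a quick runtime sanity check can catch mismatched installs. A minimal sketch (using `pkg_resources` and `packaging`, the same libraries the Trainer-side version gate further below relies on):

```python
import pkg_resources
from packaging import version

# The newly supported range from this commit: >= 0.21.0, < 0.24.0.
v = version.parse(pkg_resources.get_distribution("gym").version)
assert version.parse("0.21.0") <= v < version.parse("0.24.0"), f"unsupported gym {v}"
```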


@@ -4,7 +4,7 @@
 # ---------------------
 # Atari
 autorom[accept-rom-license]
-gym[atari]==0.21.0; python_version >= '3.7'
+gym>=0.21.0,<0.24.0; python_version >= '3.7'
 gym[atari]==0.19.0; python_version < '3.7'
 # Kaggle envs.
 kaggle_environments==1.7.11


@@ -11,7 +11,7 @@ freezegun==1.1.0
 gluoncv==0.10.1.post0
 gpy==1.10.0
 autorom[accept-rom-license]
-gym[atari]==0.21.0; python_version >= '3.7'
+gym>=0.21.0,<0.24.0; python_version >= '3.7'
 gym[atari]==0.19.0; python_version < '3.7'
 h5py==3.1.0
 hpbandster==0.7.4


@@ -7,7 +7,7 @@ python:
   pip_packages:
   - pytest
   - awscli
-  - gym==0.21.0
+  - gym
   conda_packages: []

 post_build_cmds:


@@ -7,7 +7,7 @@ debian_packages:
 python:
   pip_packages:
-  - gym[atari]==0.21.0
+  - gym[atari]
   - pytest
   - tensorflow
   conda_packages: []
@@ -15,7 +15,7 @@ python:
 post_build_cmds:
   - 'rm -r wrk || true && git clone https://github.com/wg/wrk.git /tmp/wrk && cd /tmp/wrk && make -j && sudo cp wrk /usr/local/bin'
   - pip3 install pytest || true
-  - pip3 install -U ray[all] gym[atari]==0.21.0 autorom[accept-rom-license]
+  - pip3 install -U ray[all] gym[atari] autorom[accept-rom-license]
   - pip3 install ray[all]
   # TODO (Alex): Ideally we would install all the dependencies from the new
   # version too, but pip won't be able to find the new version of ray-cpp.


@@ -7,7 +7,7 @@ debian_packages:
 python:
   pip_packages:
-  - gym[atari]==0.21.0
+  - gym[atari]
   - pygame
   - pytest
   - tensorflow
@@ -18,7 +18,7 @@ post_build_cmds:
   - 'rm -r wrk || true && git clone https://github.com/wg/wrk.git /tmp/wrk && cd /tmp/wrk && make -j && sudo cp wrk /usr/local/bin'
   - pip3 install numpy==1.19 || true
   - pip3 install pytest || true
-  - pip3 install -U ray[all] gym[atari]==0.21.0 autorom[accept-rom-license]
+  - pip3 install -U ray[all] gym[atari] autorom[accept-rom-license]
   - pip3 install ray[all]
   # TODO (Alex): Ideally we would install all the dependencies from the new
   # version too, but pip won't be able to find the new version of ray-cpp.


@@ -8,7 +8,7 @@ python:
   # These dependencies should be handled by requirements_rllib.txt and
   # requirements_ml_docker.txt
   pip_packages:
-  - gym==0.21.0
+  - gym
   conda_packages: []

 post_build_cmds:


@@ -8,8 +8,8 @@ python:
   - pytest
   - awscli
   - gsutil
+  - gym
   - gcsfs
-  - gym==0.21.0
   - pyarrow>=6.0.1,<7.0.0
   conda_packages: []


@@ -8,8 +8,8 @@ python:
   - pytest
   - awscli
   - gsutil
+  - gym
   - gcsfs
-  - gym==0.21.0
   - pyarrow>=6.0.1,<7.0.0
   conda_packages: []


@@ -1318,12 +1318,6 @@ sh_test(
     data = glob(["examples/serving/*.py"]),
 )

-py_test(
-    name = "env/tests/test_record_env_wrapper",
-    tags = ["team:ml", "env"],
-    size = "small",
-    srcs = ["env/tests/test_record_env_wrapper.py"]
-)
-
 py_test(
     name = "env/tests/test_remote_worker_envs",
@@ -2818,13 +2812,13 @@ py_test(
     args = ["--as-test", "--framework=torch"],
 )

-py_test(
-    name = "examples/remote_base_env_with_custom_api",
-    tags = ["team:ml", "examples", "examples_R"],
-    size = "medium",
-    srcs = ["examples/remote_base_env_with_custom_api.py"],
-    args = ["--stop-iters=3"]
-)
+# py_test(
+#     name = "examples/remote_base_env_with_custom_api",
+#     tags = ["team:ml", "examples", "examples_R"],
+#     size = "medium",
+#     srcs = ["examples/remote_base_env_with_custom_api.py"],
+#     args = ["--stop-iters=3"]
+# )

 py_test(
     name = "examples/restore_1_of_n_agents_from_checkpoint",


@@ -75,10 +75,11 @@ class _MockTrainer(Trainer):
         self.info = info
         self.restored = True

+    @staticmethod
     @override(Trainer)
-    def _register_if_needed(self, env_object, config):
+    def _get_env_id_and_creator(env_specifier, config):
         # No env to register.
-        pass
+        return None, None

     def set_info(self, info):
         self.info = info


@@ -8,7 +8,9 @@ import logging
 import math
 import numpy as np
 import os
+from packaging import version
 import pickle
+import pkg_resources
 import tempfile
 import time
 from typing import (
@@ -30,7 +32,6 @@ from ray.exceptions import RayError
 from ray.rllib.agents.callbacks import DefaultCallbacks
 from ray.rllib.agents.trainer_config import TrainerConfig
 from ray.rllib.env.env_context import EnvContext
-from ray.rllib.env.multi_agent_env import MultiAgentEnv
 from ray.rllib.env.utils import gym_env_creator
 from ray.rllib.evaluation.episode import Episode
 from ray.rllib.evaluation.metrics import (
@@ -94,7 +95,7 @@ from ray.rllib.utils.typing import (
     TrainerConfigDict,
 )
 from ray.tune.logger import Logger, UnifiedLogger
-from ray.tune.registry import ENV_CREATOR, register_env, _global_registry
+from ray.tune.registry import ENV_CREATOR, _global_registry
 from ray.tune.resources import Resources
 from ray.tune.result import DEFAULT_RESULTS_DIR
 from ray.tune.trainable import Trainable
@@ -220,19 +221,13 @@ class Trainer(Trainable):
         if isinstance(config, TrainerConfig):
             config = config.to_dict()

-        # Convert `env` provided in config into a string:
-        # - If `env` is a string: `self._env_id` = `env`.
-        # - If `env` is a class: `self._env_id` = `env.__name__` -> Already
-        #   register it with a auto-generated env creator.
-        # - If `env` is None: `self._env_id` is None.
-        self._env_id: Optional[str] = self._register_if_needed(
+        # Convert `env` provided in config into a concrete env creator callable, which
+        # takes an EnvContext (config dict) as arg and returning an RLlib supported Env
+        # type (e.g. a gym.Env).
+        self._env_id, self.env_creator = self._get_env_id_and_creator(
             env or config.get("env"), config
         )
-        # The env creator callable, taking an EnvContext (config dict)
-        # as arg and returning an RLlib supported Env type (e.g. a gym.Env).
-        self.env_creator: Optional[EnvCreator] = None

         # Placeholder for a local replay buffer instance.
         self.local_replay_buffer = None
@@ -310,10 +305,6 @@ class Trainer(Trainable):
         # Validate the framework settings in config.
         self.validate_framework(self.config)

-        # Setup the self.env_creator callable (to be passed
-        # e.g. to RolloutWorkers' c'tors).
-        self.env_creator = self._get_env_creator_from_env_id(self._env_id)
-
         # Set Trainer's seed after we have - if necessary - enabled
         # tf eager-execution.
         update_global_seed_if_necessary(self.config["framework"], self.config["seed"])
@@ -466,8 +457,9 @@ class Trainer(Trainable):
             self.config["evaluation_config"] = eval_config

-            env_id = self._register_if_needed(eval_config.get("env"), eval_config)
-            env_creator = self._get_env_creator_from_env_id(env_id)
+            env_id, env_creator = self._get_env_id_and_creator(
+                eval_config.get("env"), eval_config
+            )

             # Create a separate evaluation worker set for evaluation.
             # If evaluation_num_workers=0, use the evaluation set's local
@@ -1541,37 +1533,87 @@ class Trainer(Trainable):
         """Pre-evaluation callback."""
         pass

-    def _get_env_creator_from_env_id(self, env_id: Optional[str] = None) -> EnvCreator:
-        """Returns an env creator callable, given an `env_id` (e.g. "CartPole-v0").
+    @staticmethod
+    def _get_env_id_and_creator(
+        env_specifier: Union[str, EnvType, None], config: PartialTrainerConfigDict
+    ) -> Tuple[Optional[str], EnvCreator]:
+        """Returns env_id and creator callable given original env id from config.

         Args:
-            env_id: An already tune registered env ID, a known gym env name,
-                or None (if no env is used).
+            env_specifier: An env class, an already tune registered env ID, a known
+                gym env name, or None (if no env is used).
+            config: The Trainer's (maybe partial) config dict.

         Returns:
+            Tuple consisting of a) env ID string and b) env creator callable.
         """
-        if env_id:
+        # Environment is specified via a string.
+        if isinstance(env_specifier, str):
             # An already registered env.
-            if _global_registry.contains(ENV_CREATOR, env_id):
-                return _global_registry.get(ENV_CREATOR, env_id)
+            if _global_registry.contains(ENV_CREATOR, env_specifier):
+                return env_specifier, _global_registry.get(ENV_CREATOR, env_specifier)

             # A class path specifier.
-            elif "." in env_id:
+            elif "." in env_specifier:

                 def env_creator_from_classpath(env_context):
                     try:
-                        env_obj = from_config(env_id, env_context)
+                        env_obj = from_config(env_specifier, env_context)
                     except ValueError:
-                        raise EnvError(ERR_MSG_INVALID_ENV_DESCRIPTOR.format(env_id))
+                        raise EnvError(
+                            ERR_MSG_INVALID_ENV_DESCRIPTOR.format(env_specifier)
+                        )
                     return env_obj

-                return env_creator_from_classpath
+                return env_specifier, env_creator_from_classpath

             # Try gym/PyBullet/Vizdoom.
             else:
-                return functools.partial(gym_env_creator, env_descriptor=env_id)
+                return env_specifier, functools.partial(
+                    gym_env_creator, env_descriptor=env_specifier
+                )
+
+        elif isinstance(env_specifier, type):
+            env_id = env_specifier  # .__name__
+
+            if config.get("remote_worker_envs"):
+                # Check gym version (0.22 or higher?).
+                # If > 0.21, can't perform auto-wrapping of the given class as this
+                # would lead to a pickle error.
+                gym_version = pkg_resources.get_distribution("gym").version
+                if version.parse(gym_version) >= version.parse("0.22"):
+                    raise ValueError(
+                        "Cannot specify a gym.Env class via `config.env` while setting "
+                        "`config.remote_worker_env=True` AND your gym version is >= "
+                        "0.22! Try installing an older version of gym or set `config."
+                        "remote_worker_env=False`."
+                    )
+
+                @ray.remote(num_cpus=1)
+                class _wrapper(env_specifier):
+                    # Add convenience `_get_spaces` and `_is_multi_agent`
+                    # methods:
+                    def _get_spaces(self):
+                        return self.observation_space, self.action_space
+
+                    def _is_multi_agent(self):
+                        from ray.rllib.env.multi_agent_env import MultiAgentEnv
+
+                        return isinstance(self, MultiAgentEnv)
+
+                return env_id, lambda cfg: _wrapper.remote(cfg)
+            else:
+                return env_id, lambda cfg: env_specifier(cfg)

         # No env -> Env creator always returns None.
-        else:
-            return lambda env_config: None
+        elif env_specifier is None:
+            return None, lambda env_config: None
+
+        else:
+            raise ValueError(
+                "{} is an invalid env specifier. ".format(env_specifier)
+                + "You can specify a custom env as either a class "
+                '(e.g., YourEnvCls) or a registered env id (e.g., "your_env").'
+            )

     def _sync_filters_if_needed(self, workers: WorkerSet):
         if self.config.get("observation_filter", "NoFilter") != "NoFilter":
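For orientation outside the diff: a minimal sketch of calling the new (private) helper. The env id, empty config dict, and `EnvContext` arguments here are illustrative, not part of this commit.

```python
from ray.rllib.agents.trainer import Trainer
from ray.rllib.env.env_context import EnvContext

# "CartPole-v0" is neither tune-registered nor a class path, so the helper
# falls through to the gym branch and returns a partial of gym_env_creator.
env_id, env_creator = Trainer._get_env_id_and_creator("CartPole-v0", config={})
env = env_creator(EnvContext({}, worker_index=0))  # -> a gym.Env instance
```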
@@ -1753,16 +1795,6 @@ class Trainer(Trainable):
         if model_config is None:
             config["model"] = model_config = {}

-        # Monitor should be replaced by `record_env`.
-        if config.get("monitor", DEPRECATED_VALUE) != DEPRECATED_VALUE:
-            deprecation_warning("monitor", "record_env", error=False)
-            config["record_env"] = config.get("monitor", False)
-        # Empty string would fail some if-blocks checking for this setting.
-        # Set to True instead, meaning: use default output dir to store
-        # the videos.
-        if config.get("record_env") == "":
-            config["record_env"] = True
-
         # Use DefaultCallbacks class, if callbacks is None.
         if config["callbacks"] is None:
             config["callbacks"] = DefaultCallbacks
@@ -2149,38 +2181,6 @@ class Trainer(Trainable):
             kwargs["local_replay_buffer"] = self.local_replay_buffer
         return kwargs

-    def _register_if_needed(
-        self, env_object: Union[str, EnvType, None], config
-    ) -> Optional[str]:
-        if isinstance(env_object, str):
-            return env_object
-        elif isinstance(env_object, type):
-            name = env_object.__name__
-
-            if config.get("remote_worker_envs"):
-
-                @ray.remote(num_cpus=0)
-                class _wrapper(env_object):
-                    # Add convenience `_get_spaces` and `_is_multi_agent`
-                    # methods.
-                    def _get_spaces(self):
-                        return self.observation_space, self.action_space
-
-                    def _is_multi_agent(self):
-                        return isinstance(self, MultiAgentEnv)
-
-                register_env(name, lambda cfg: _wrapper.remote(cfg))
-            else:
-                register_env(name, lambda cfg: env_object(cfg))
-            return name
-        elif env_object is None:
-            return None
-
-        raise ValueError(
-            "{} is an invalid env specification. ".format(env_object)
-            + "You can specify a custom env as either a class "
-            '(e.g., YourEnvCls) or a registered env id (e.g., "your_env").'
-        )
-
     def _step_context(trainer):
         class StepCtx:
             def __enter__(self):


@@ -108,7 +108,6 @@ class TrainerConfig:
         self.action_space = None
         self.env_task_fn = None
         self.render_env = False
-        self.record_env = False
         self.clip_rewards = None
         self.normalize_actions = True
         self.clip_actions = False
@@ -458,7 +457,6 @@ class TrainerConfig:
         action_space: Optional[gym.spaces.Space] = None,
         env_task_fn: Optional[Callable[[ResultDict, EnvType, EnvContext], Any]] = None,
         render_env: Optional[bool] = None,
-        record_env: Optional[bool] = None,
         clip_rewards: Optional[Union[bool, float]] = None,
         normalize_actions: Optional[bool] = None,
         clip_actions: Optional[bool] = None,
@@ -489,11 +487,6 @@ class TrainerConfig:
                 `render()` method which either:
                 a) handles window generation and rendering itself (returning True) or
                 b) returns a numpy uint8 image of shape [height x width x 3 (RGB)].
-            record_env: If True, stores videos in this relative directory inside the
-                default output dir (~/ray_results/...). Alternatively, you can
-                specify an absolute path (str), in which the env recordings should be
-                stored instead. Set to False for not recording anything.
-                Note: This setting replaces the deprecated `monitor` key.
             clip_rewards: Whether to clip rewards during Policy's postprocessing.
                 None (default): Clip for Atari only (r=sign(r)).
                 True: r=sign(r): Fixed rewards -1.0, 1.0, or 0.0.
@@ -524,8 +517,6 @@ class TrainerConfig:
             self.env_task_fn = env_task_fn
         if render_env is not None:
             self.render_env = render_env
-        if record_env is not None:
-            self.record_env = record_env
         if clip_rewards is not None:
             self.clip_rewards = clip_rewards
         if normalize_actions is not None:


@@ -168,7 +168,7 @@ class AlphaZeroPolicyWrapperClass(AlphaZeroPolicy):
     model = ModelCatalog.get_model_v2(
         obs_space, action_space, action_space.n, config["model"], "torch"
     )
-    env_creator = Trainer._get_env_creator_from_env_id(None, config["env"])
+    _, env_creator = Trainer._get_env_id_and_creator(config["env"], config)
     if config["ranked_rewards"]["enable"]:
         # if r2 is enabled, the env is wrapped to include a rewards buffer
         # used to normalize rewards


@@ -355,8 +355,6 @@ class ARSTrainer(Trainer):
         # Validate our config dict.
         self.validate_config(self.config)

-        # Generate `self.env_creator` callable to create an env instance.
-        self.env_creator = self._get_env_creator_from_env_id(self._env_id)
         # Generate the local env.
         env_context = EnvContext(self.config["env_config"] or {}, worker_index=0)
         env = self.env_creator(env_context)


@@ -362,8 +362,6 @@ class ESTrainer(Trainer):
         # Call super's validation method.
         self.validate_config(self.config)

-        # Generate `self.env_creator` callable to create an env instance.
-        self.env_creator = self._get_env_creator_from_env_id(self._env_id)
         # Generate the local env.
         env_context = EnvContext(self.config["env_config"] or {}, worker_index=0)
         env = self.env_creator(env_context)


@@ -1,6 +1,6 @@
 import gym
 import logging
-from typing import Callable, Dict, List, Tuple, Type, Optional, Union, Set
+from typing import Callable, Dict, List, Tuple, Optional, Union, Set, Type

 from ray.rllib.env.base_env import BaseEnv
 from ray.rllib.utils.annotations import (


@@ -1,100 +0,0 @@
-import glob
-import gym
-import numpy as np
-import os
-import shutil
-import unittest
-
-from ray.rllib.env.utils import VideoMonitor, record_env_wrapper
-from ray.rllib.examples.env.mock_env import MockEnv2
-from ray.rllib.examples.env.multi_agent import BasicMultiAgent
-from ray.rllib.utils.test_utils import check
-
-
-class TestRecordEnvWrapper(unittest.TestCase):
-    def test_wrap_gym_env(self):
-        record_env_dir = os.popen("mktemp -d").read()[:-1]
-        print(f"tmp dir for videos={record_env_dir}")
-
-        if not os.path.exists(record_env_dir):
-            sys.exit(1)
-
-        num_steps_per_episode = 10
-        wrapped = record_env_wrapper(
-            env=MockEnv2(num_steps_per_episode),
-            record_env=record_env_dir,
-            log_dir="",
-            policy_config={
-                "in_evaluation": False,
-            },
-        )
-        # Non MultiAgentEnv: Wrapper's type is wrappers.Monitor.
-        self.assertTrue(isinstance(wrapped, gym.wrappers.Monitor))
-        self.assertFalse(isinstance(wrapped, VideoMonitor))
-
-        wrapped.reset()
-        # Expect one video file to have been produced in the tmp dir.
-        os.chdir(record_env_dir)
-        ls = glob.glob("*.mp4")
-        self.assertTrue(len(ls) == 1)
-
-        # 10 steps for a complete episode.
-        for i in range(num_steps_per_episode):
-            wrapped.step(0)
-        # Another episode.
-        wrapped.reset()
-        for i in range(num_steps_per_episode):
-            wrapped.step(0)
-        # Expect another video file to have been produced (2nd episode).
-        ls = glob.glob("*.mp4")
-        self.assertTrue(len(ls) == 2)
-
-        # MockEnv2 returns a reward of 100.0 every step.
-        # So total reward is 1000.0 per episode (10 steps).
-        check(
-            np.array([100.0, 100.0]) * num_steps_per_episode,
-            wrapped.get_episode_rewards(),
-        )
-
-        # Erase all generated files and the temp path just in case,
-        # as to not disturb further CI-tests.
-        shutil.rmtree(record_env_dir)
-
-    def test_wrap_multi_agent_env(self):
-        record_env_dir = os.popen("mktemp -d").read()[:-1]
-        print(f"tmp dir for videos={record_env_dir}")
-
-        if not os.path.exists(record_env_dir):
-            sys.exit(1)
-
-        wrapped = record_env_wrapper(
-            env=BasicMultiAgent(3),
-            record_env=record_env_dir,
-            log_dir="",
-            policy_config={
-                "in_evaluation": False,
-            },
-        )
-        # Type is VideoMonitor.
-        self.assertTrue(isinstance(wrapped, gym.wrappers.Monitor))
-        self.assertTrue(isinstance(wrapped, VideoMonitor))
-
-        wrapped.reset()
-
-        # BasicMultiAgent is hardcoded to run 25-step episodes.
-        for i in range(25):
-            wrapped.step({0: 0, 1: 0, 2: 0})
-
-        # Expect one video file to have been produced in the tmp dir.
-        os.chdir(record_env_dir)
-        ls = glob.glob("*.mp4")
-        self.assertTrue(len(ls) == 1)
-
-        # However VideoMonitor's _after_step is overwritten to not
-        # use stats_recorder. So nothing to verify here, except that
-        # it runs fine.
-
-
-if __name__ == "__main__":
-    import pytest
-    import sys
-
-    sys.exit(pytest.main(["-v", __file__]))


@@ -7,10 +7,8 @@ import unittest
 import ray
 from ray.rllib.algorithms.pg import pg
 from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv
-from ray.rllib.examples.env.random_env import RandomEnv, RandomMultiAgentEnv
-from ray.rllib.examples.remote_base_env_with_custom_api import (
-    NonVectorizedEnvToBeVectorizedIntoRemoteBaseEnv,
-)
+# from ray.rllib.examples.env.random_env import RandomEnv
 from ray import tune
@@ -55,17 +53,19 @@ class TestRemoteWorkerEnvSetting(unittest.TestCase):
         trainer.stop()

         # Using class directly.
-        config["env"] = RandomEnv
-        trainer = pg.PGTrainer(config=config)
-        print(trainer.train())
-        trainer.stop()
+        # This doesn't work anymore as of gym==0.23
+        # config["env"] = RandomEnv
+        # trainer = pg.PGTrainer(config=config)
+        # print(trainer.train())
+        # trainer.stop()

         # Using class directly: Sub-class of gym.Env,
         # which implements its own API.
-        config["env"] = NonVectorizedEnvToBeVectorizedIntoRemoteBaseEnv
-        trainer = pg.PGTrainer(config=config)
-        print(trainer.train())
-        trainer.stop()
+        # This doesn't work anymore as of gym==0.23
+        # config["env"] = NonVectorizedEnvToBeVectorizedIntoRemoteBaseEnv
+        # trainer = pg.PGTrainer(config=config)
+        # print(trainer.train())
+        # trainer.stop()

     def test_remote_worker_env_multi_agent(self):
         config = pg.DEFAULT_CONFIG.copy()
@@ -85,10 +85,11 @@ class TestRemoteWorkerEnvSetting(unittest.TestCase):
         trainer.stop()

         # Using class directly.
-        config["env"] = RandomMultiAgentEnv
-        trainer = pg.PGTrainer(config=config)
-        print(trainer.train())
-        trainer.stop()
+        # This doesn't work anymore as of gym==0.23.
+        # config["env"] = RandomMultiAgentEnv
+        # trainer = pg.PGTrainer(config=config)
+        # print(trainer.train())
+        # trainer.stop()


 if __name__ == "__main__":

rllib/env/utils.py

@@ -1,10 +1,6 @@
 import gym
-from gym import wrappers
-import os

 from ray.rllib.env.env_context import EnvContext
-from ray.rllib.env.multi_agent_env import MultiAgentEnv
-from ray.rllib.utils import add_mixins
 from ray.rllib.utils.error import ERR_MSG_INVALID_ENV_DESCRIPTOR, EnvError
@@ -54,48 +50,3 @@ def gym_env_creator(env_context: EnvContext, env_descriptor: str) -> gym.Env:
         return gym.make(env_descriptor, **env_context)
     except gym.error.Error:
         raise EnvError(ERR_MSG_INVALID_ENV_DESCRIPTOR.format(env_descriptor))
-
-
-class VideoMonitor(wrappers.Monitor):
-    # Same as original method, but doesn't use the StatsRecorder as it will
-    # try to add up multi-agent rewards dicts, which throws errors.
-    def _after_step(self, observation, reward, done, info):
-        if not self.enabled:
-            return done
-
-        # Use done["__all__"] b/c this is a multi-agent dict.
-        if done["__all__"] and self.env_semantics_autoreset:
-            # For envs with BlockingReset wrapping VNCEnv, this observation
-            # will be the first one of the new episode
-            self.reset_video_recorder()
-            self.episode_id += 1
-            self._flush()
-
-        # Record video
-        self.video_recorder.capture_frame()
-
-        return done
-
-
-def record_env_wrapper(env, record_env, log_dir, policy_config):
-    if record_env:
-        path_ = record_env if isinstance(record_env, str) else log_dir
-        # Relative path: Add logdir here, otherwise, this would
-        # not work for non-local workers.
-        if not os.path.isabs(path_):
-            path_ = os.path.join(log_dir, path_)
-        print(f"Setting the path for recording to {path_}")
-        wrapper_cls = (
-            VideoMonitor if isinstance(env, MultiAgentEnv) else wrappers.Monitor
-        )
-        if isinstance(env, MultiAgentEnv):
-            wrapper_cls = add_mixins(wrapper_cls, [MultiAgentEnv], reversed=True)
-        env = wrapper_cls(
-            env,
-            path_,
-            resume=True,
-            force=True,
-            video_callable=lambda _: True,
-            mode="evaluation" if policy_config["in_evaluation"] else "training",
-        )
-    return env
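`VideoMonitor` and `record_env_wrapper` were built on `gym.wrappers.Monitor`, which newer gym versions drop, so this commit removes recording support rather than porting it. For users who still need videos, a sketch of the gym >= 0.22 replacement (the folder and trigger are illustrative; nothing here is wired into RLlib by this commit):

```python
import gym
from gym.wrappers import RecordVideo  # Monitor's successor in gym >= 0.22

env = RecordVideo(
    gym.make("CartPole-v0"),
    video_folder="/tmp/videos",  # illustrative output dir
    episode_trigger=lambda episode_id: True,  # record every episode
)
obs = env.reset()
done = False
while not done:
    # Old 4-tuple step API, valid for the gym range this commit targets.
    obs, reward, done, info = env.step(env.action_space.sample())
env.close()
```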


@@ -4,7 +4,6 @@ import argparse
 import collections
 import copy
 import gym
-from gym import wrappers as gym_wrappers
 import json
 import os
 from pathlib import Path
@@ -339,7 +338,6 @@ def run(args, parser):
         deprecation_warning(old="--no-render", new="--render", error=False)
         args.render = False
     config["render_env"] = args.render
-    config["record_env"] = args.video_dir

     ray.init(local_mode=args.local_mode)
@@ -354,12 +352,6 @@ def run(args, parser):
     num_steps = int(args.steps)
     num_episodes = int(args.episodes)

-    # Determine the video output directory.
-    video_dir = None
-    # Allow user to specify a video output path.
-    if args.video_dir:
-        video_dir = os.path.expanduser(args.video_dir)
-
     # Do the actual rollout.
     with RolloutSaver(
         args.out,
@@ -369,9 +361,7 @@ def run(args, parser):
         target_episodes=num_episodes,
         save_info=args.save_info,
     ) as saver:
-        rollout(
-            agent, args.env, num_steps, num_episodes, saver, not args.render, video_dir
-        )
+        rollout(agent, args.env, num_steps, num_episodes, saver, not args.render)
     agent.stop()
@@ -406,7 +396,6 @@ def rollout(
     num_episodes=0,
     saver=None,
     no_render=True,
-    video_dir=None,
 ):
     policy_agent_mapping = default_policy_agent_mapping
@@ -473,13 +462,6 @@ def rollout(
         for p, m in policy_map.items()
     }

-    # If monitoring has been requested, manually wrap our environment with a
-    # gym monitor, which is set to record every episode.
-    if video_dir:
-        env = gym_wrappers.Monitor(
-            env=env, directory=video_dir, video_callable=lambda _: True, force=True
-        )
-
     steps = 0
     episodes = 0
     while keep_going(steps, num_steps, episodes, num_episodes):


@@ -27,7 +27,6 @@ from ray.rllib.env.base_env import BaseEnv, convert_to_base_env
 from ray.rllib.env.env_context import EnvContext
 from ray.rllib.env.multi_agent_env import MultiAgentEnv
 from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
-from ray.rllib.env.utils import record_env_wrapper
 from ray.rllib.env.wrappers.atari_wrappers import wrap_deepmind, is_atari
 from ray.rllib.evaluation.sampler import AsyncSampler, SyncSampler
 from ray.rllib.evaluation.metrics import RolloutMetrics
@@ -233,7 +232,6 @@ class RolloutWorker(ParallelIteratorWorker):
         worker_index: int = 0,
         num_workers: int = 0,
         recreated_worker: bool = False,
-        record_env: Union[bool, str] = False,
         log_dir: Optional[str] = None,
         log_level: Optional[str] = None,
         callbacks: Type["DefaultCallbacks"] = None,
@@ -253,7 +251,6 @@ class RolloutWorker(ParallelIteratorWorker):
         fake_sampler: bool = False,
         spaces: Optional[Dict[PolicyID, Tuple[Space, Space]]] = None,
         policy=None,
-        monitor_path=None,
         disable_env_checking=False,
     ):
         """Initializes a RolloutWorker instance.
@@ -332,10 +329,6 @@ class RolloutWorker(ParallelIteratorWorker):
                 `recreate_failed_workers=True` and one of the original workers (or an
                 already recreated one) has failed. They don't differ from original
                 workers other than the value of this flag (`self.recreated_worker`).
-            record_env: Write out episode stats and videos
-                using gym.wrappers.Monitor to this directory if specified. If
-                True, use the default output dir in ~/ray_results/.... If
-                False, do not record anything.
             log_dir: Directory where logs can be placed.
             log_level: Set the root log level on creation.
             callbacks: Custom sub-class of
@@ -374,7 +367,6 @@ class RolloutWorker(ParallelIteratorWorker):
                 to (obs_space, action_space)-tuples. This is used in case no
                 Env is created on this RolloutWorker.
             policy: Obsoleted arg. Use `policy_spec` instead.
-            monitor_path: Obsoleted arg. Use `record_env` instead.
             disable_env_checking: If True, disables the env checking module that
                 validates the properties of the passed environment.
         """
@@ -395,10 +387,6 @@ class RolloutWorker(ParallelIteratorWorker):
                 for pid, spec in policy_spec.copy().items()
             }

-        if monitor_path is not None:
-            deprecation_warning("monitor_path", "record_env", error=False)
-            record_env = monitor_path
-
         self._original_kwargs: dict = locals().copy()
         del self._original_kwargs["self"]
@@ -490,7 +478,7 @@ class RolloutWorker(ParallelIteratorWorker):
         # 1) Create the env using the user provided env_creator. This may
         #    return a gym.Env (incl. MultiAgentEnv), an already vectorized
         #    VectorEnv, BaseEnv, ExternalEnv, or an ActorHandle (remote env).
-        # 2) Wrap - if applicable - with Atari/recording/rendering wrappers.
+        # 2) Wrap - if applicable - with Atari/rendering wrappers.
         # 3) Seed the env, if necessary.
         # 4) Vectorize the existing single env by creating more clones of
         #    this env and wrapping it with the RLlib BaseEnv class.
@@ -541,14 +529,12 @@ class RolloutWorker(ParallelIteratorWorker):
                     env = wrap_deepmind(
                         env, dim=model_config.get("dim"), framestack=use_framestack
                     )
-                    env = record_env_wrapper(env, record_env, log_dir, policy_config)
                     return env

-            # gym.Env -> Wrap with gym Monitor.
             else:

                 def wrap(env):
-                    return record_env_wrapper(env, record_env, log_dir, policy_config)
+                    return env

         # Wrap env through the correct wrapper.
         self.env: EnvType = wrap(self.env)


@@ -4,7 +4,6 @@ from gym.spaces import Box, Discrete
 import numpy as np
 import os
 import random
-import tempfile
 import time
 import unittest
@@ -12,7 +11,6 @@ import ray
 from ray.rllib.algorithms.pg import PGTrainer
 from ray.rllib.agents.a3c import A2CTrainer
 from ray.rllib.env.multi_agent_env import MultiAgentEnv
-from ray.rllib.env.utils import VideoMonitor
 from ray.rllib.evaluation.rollout_worker import RolloutWorker
 from ray.rllib.evaluation.metrics import collect_metrics
 from ray.rllib.evaluation.postprocessing import compute_advantages
@@ -376,10 +374,8 @@ class TestRolloutWorker(unittest.TestCase):
                 curframe = inspect.currentframe()
                 called_from_check = any(
-                    [
-                        frame[3] == "check_gym_environments"
-                        for frame in inspect.getouterframes(curframe, 2)
-                    ]
+                    frame[3] == "check_gym_environments"
+                    for frame in inspect.getouterframes(curframe, 2)
                 )
                 # Check, whether the action is immutable.
                 if action.flags.writeable and not called_from_check:
@@ -825,15 +821,12 @@ class TestRolloutWorker(unittest.TestCase):
             policy_config={
                 "in_evaluation": False,
             },
-            record_env=tempfile.gettempdir(),
         )
         # Make sure we can properly sample from the wrapped env.
         ev.sample()
         # Make sure the resulting environment is indeed still an
-        # instance of MultiAgentEnv and VideoMonitor.
         self.assertTrue(isinstance(ev.env.unwrapped, MultiAgentEnv))
         self.assertTrue(isinstance(ev.env, gym.Env))
-        self.assertTrue(isinstance(ev.env, VideoMonitor))
         ev.stop()

     def test_no_training(self):
def test_no_training(self): def test_no_training(self):


@@ -654,7 +654,6 @@ class WorkerSet:
             worker_index=worker_index,
             num_workers=num_workers,
             recreated_worker=recreated_worker,
-            record_env=config["record_env"],
             log_dir=self._logdir,
             log_level=config["log_level"],
             callbacks=config["callbacks"],


@@ -1,12 +1,3 @@
-# ---------------
-# IMPORTANT NOTE:
-# ---------------
-# A recent bug in openAI gym prevents RLlib's "record_env" option
-# from recording videos properly. Instead, the produced mp4 files
-# have a size of 1kb and are corrupted.
-# A simple fix for this is described here:
-# https://github.com/openai/gym/issues/1925
-
 import argparse
 import gym
 import numpy as np
@@ -115,13 +106,6 @@ if __name__ == "__main__":
         # Special evaluation config. Keys specified here will override
         # the same keys in the main config, but only for evaluation.
         "evaluation_config": {
-            # Store videos in this relative directory here inside
-            # the default output dir (~/ray_results/...).
-            # Alternatively, you can specify an absolute path.
-            # Set to True for using the default output dir (~/ray_results/...).
-            # Set to False for not recording anything.
-            "record_env": "videos",
-            # "record_env": "/Users/xyz/my_videos/",
             # Render the env while evaluating.
             # Note that this will always only render the 1st RolloutWorker's
             # env and only the 1st sub-env in a vectorized env.


@@ -46,6 +46,11 @@ parser.add_argument(
 parser.add_argument(
     "--stop-reward", type=float, default=180.0, help="Reward at which we stop training."
 )
+parser.add_argument(
+    "--local-mode",
+    action="store_true",
+    help="Init Ray in local mode for easier debugging.",
+)


 class NonVectorizedEnvToBeVectorizedIntoRemoteBaseEnv(TaskSettableEnv):
@@ -96,7 +101,7 @@ class TaskSettingCallback(DefaultCallbacks):
 if __name__ == "__main__":
     args = parser.parse_args()
-    ray.init(num_cpus=6)
+    ray.init(num_cpus=6, local_mode=args.local_mode)

     config = {
         # Specify your custom (single, non-vectorized) env directly as a
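Usage note: with the new flag, the example can be debugged serially via `python rllib/examples/remote_base_env_with_custom_api.py --local-mode` (the path is assumed from the BUILD entry above); `local_mode=True` makes Ray run all tasks and actors in a single process.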


@@ -6,15 +6,14 @@ import ray
 from ray import tune
 from ray.rllib.agents.registry import get_trainer_class
-from ray.rllib.examples.env.repeat_after_me_env import RepeatAfterMeEnv
+# from ray.rllib.examples.env.repeat_after_me_env import RepeatAfterMeEnv
 from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole

-envs = {"RepeatAfterMeEnv": RepeatAfterMeEnv, "StatelessCartPole": StatelessCartPole}
-
 config = {
     "name": "RNNSAC_example",
     "local_dir": str(Path(__file__).parent / "example_out"),
-    "checkpoint_freq": 1,
+    "checkpoint_at_end": True,
     "keep_checkpoints_num": 1,
     "checkpoint_score_attr": "episode_reward_mean",
     "stop": {
@@ -29,11 +28,8 @@ config = {
         "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
         "framework": "torch",
         "num_workers": 4,
-        "num_envs_per_worker": 1,
-        "num_cpus_per_worker": 1,
-        "log_level": "INFO",
-        # "env": envs["RepeatAfterMeEnv"],
-        "env": envs["StatelessCartPole"],
+        # "env": RepeatAfterMeEnv,
+        "env": StatelessCartPole,
         "horizon": 1000,
         "gamma": 0.95,
         "batch_mode": "complete_episodes",
@@ -102,7 +98,7 @@ if __name__ == "__main__":
     eps = 0
     ep_reward = 0
     while eps < 10:
-        action, state, info_trainer = agent.compute_action(
+        action, state, info_trainer = agent.compute_single_action(
             obs,
             state=state,
             prev_action=prev_action,
@@ -115,7 +111,7 @@ if __name__ == "__main__":
             ep_reward += reward
             try:
                 env.render()
-            except (NotImplementedError, ImportError):
+            except Exception:
                 pass
             if done:
                 eps += 1
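`compute_action` was deprecated in favor of `compute_single_action`. A minimal sketch of the renamed call outside this RNN example (trainer construction and env are illustrative):

```python
import gym
from ray.rllib.algorithms.pg import PGTrainer  # import path as used elsewhere in this commit

trainer = PGTrainer(env="CartPole-v0", config={"framework": "torch", "num_workers": 0})
env = gym.make("CartPole-v0")
obs = env.reset()
# Replaces the deprecated trainer.compute_action(obs).
action = trainer.compute_single_action(obs)
```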


@@ -29,7 +29,6 @@ class TestGymCheckEnv(unittest.TestCase):
         env = Mock(spec=["observation_space"])
         with pytest.raises(AttributeError, match="Env must have action_space."):
             check_gym_environments(env)
-        del env

     def test_obs_and_action_spaces_are_gym_spaces(self):
         env = RandomEnv()
@@ -41,7 +40,6 @@ class TestGymCheckEnv(unittest.TestCase):
         env.action_space = "not an action space"
         with pytest.raises(ValueError, match="Action space must be a gym.space"):
             check_env(env)
-        del env

     def test_reset(self):
         reset = MagicMock(return_value=5)
@@ -56,7 +54,6 @@ class TestGymCheckEnv(unittest.TestCase):
         env.reset = reset
         with pytest.raises(ValueError, match=error):
             check_env(env)
-        del env

     def test_step(self):
         step = MagicMock(return_value=(5, 5, True, {}))
@@ -92,7 +89,6 @@ class TestGymCheckEnv(unittest.TestCase):
         error = "Your step function must return a info that is a dict."
         with pytest.raises(ValueError, match=error):
             check_env(env)
-        del env


 class TestCheckMultiAgentEnv(unittest.TestCase):
@@ -104,7 +100,6 @@ class TestCheckMultiAgentEnv(unittest.TestCase):
         env = RandomEnv()
         with pytest.raises(ValueError, match="The passed env is not"):
             check_multiagent_environments(env)
-        del env

     def test_check_env_reset_incorrect_error(self):
         reset = MagicMock(return_value=5)
@@ -119,7 +114,6 @@ class TestCheckMultiAgentEnv(unittest.TestCase):
         env.reset = lambda *_: bad_obs
         with pytest.raises(ValueError, match="The observation collected from env"):
             check_env(env)
-        del env

     def test_check_incorrect_space_contains_functions_error(self):
         def bad_contains_function(self, x):
@@ -131,7 +125,6 @@ class TestCheckMultiAgentEnv(unittest.TestCase):
             ValueError, match="Your observation_space_contains function has some"
         ):
             check_env(env)
-        del env
         env = make_multi_agent("CartPole-v1")({"num_agents": 2})
         bad_action = {0: 2, 1: 2}
         env.action_space_sample = lambda *_: bad_action
@@ -178,7 +171,6 @@ class TestCheckMultiAgentEnv(unittest.TestCase):
             ValueError, match="The action collected from action_space_sample"
         ):
             check_env(env)
-        del env
         env = make_multi_agent("CartPole-v1")({"num_agents": 2})
         bad_obs = {
             0: np.array([np.inf, np.inf, np.inf, np.inf]),
@@ -206,7 +198,6 @@ class TestCheckBaseEnv:
         env = RandomEnv()
         with pytest.raises(ValueError, match="The passed env is not"):
             check_base_env(env)
-        del env

     def test_check_env_reset_incorrect_error(self):
         reset = MagicMock(return_value=5)
@@ -244,7 +235,6 @@ class TestCheckBaseEnv:
             ValueError, match="The observation collected from try_reset"
         ):
             check_env(env)
-        del env

     def test_check_space_contains_functions_errors(self):
         def bad_contains_function(self, x):
@@ -258,14 +248,12 @@ class TestCheckBaseEnv:
         ):
             check_env(env)
-        del env
         env = self._make_base_env()
         env.action_space_contains = bad_contains_function
         with pytest.raises(
             ValueError, match="Your action_space_contains function has some error"
         ):
             check_env(env)
-        del env

     def test_bad_sample_function(self):
         env = self._make_base_env()
@@ -276,7 +264,6 @@ class TestCheckBaseEnv:
         ):
             check_env(env)
-        del env
         env = self._make_base_env()
         bad_obs = {
             0: {