mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
[RLlib] Upgrade gym 0.23 (#24171)
This commit is contained in:
parent
c03d0432f3
commit
09886d7ab8
31 changed files with 133 additions and 367 deletions
|
@ -80,7 +80,7 @@ _____________________________________________________________________
|
|||
# __get_q_values_dqn_start__
|
||||
# Get a reference to the model through the policy
|
||||
import numpy as np
|
||||
from ray.rllib.agents.dqn import DQNTrainer
|
||||
from ray.rllib.algorithms.dqn import DQNTrainer
|
||||
|
||||
trainer = DQNTrainer(env="CartPole-v0", config={"framework": "tf2"})
|
||||
model = trainer.get_policy().model
|
||||
|
|
|
@ -39,7 +39,7 @@ __[Full Ray Enhancement Proposal, REP-001: Serve Pipeline](https://github.com/ra
|
|||
|
||||
## Concepts
|
||||
|
||||
- **Deployment**: Scalable, upgradeable group of actors managed by Ray Serve. __[See docs for detail](https://docs.ray.io/en/master/serve/core-apis.html#core-api-deployments)__
|
||||
- **Deployment**: Scalable, upgradeable group of actors managed by Ray Serve. __[See docs for detail](https://docs.ray.io/en/master/serve/package-ref.html#deployment-api)__
|
||||
|
||||
- **DeploymentNode**: Smallest unit in a graph, created by calling `.bind()` on a serve decorated class or function, backed by a Deployment.
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ from ray._private.client_mode_hook import enable_client_mode, client_mode_should
|
|||
@pytest.mark.skip(reason="KV store is not working properly.")
|
||||
def test_rllib_integration(ray_start_regular_shared):
|
||||
with ray_start_client_server():
|
||||
import ray.rllib.agents.dqn as dqn
|
||||
import ray.rllib.algorithms.dqn as dqn
|
||||
|
||||
# Confirming the behavior of this context manager.
|
||||
# (Client mode hook not yet enabled.)
|
||||
|
|
|
@ -29,7 +29,7 @@ virtualenv
|
|||
## setup.py extras
|
||||
dm_tree
|
||||
flask
|
||||
gym==0.21.0; python_version >= '3.7'
|
||||
gym>=0.21.0,<0.24.0; python_version >= '3.7'
|
||||
gym==0.19.0; python_version < '3.7'
|
||||
lz4
|
||||
scikit-image
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# ---------------------
|
||||
# Atari
|
||||
autorom[accept-rom-license]
|
||||
gym[atari]==0.21.0; python_version >= '3.7'
|
||||
gym>=0.21.0,<0.24.0; python_version >= '3.7'
|
||||
gym[atari]==0.19.0; python_version < '3.7'
|
||||
# Kaggle envs.
|
||||
kaggle_environments==1.7.11
|
||||
|
|
|
@ -11,7 +11,7 @@ freezegun==1.1.0
|
|||
gluoncv==0.10.1.post0
|
||||
gpy==1.10.0
|
||||
autorom[accept-rom-license]
|
||||
gym[atari]==0.21.0; python_version >= '3.7'
|
||||
gym>=0.21.0,<0.24.0; python_version >= '3.7'
|
||||
gym[atari]==0.19.0; python_version < '3.7'
|
||||
h5py==3.1.0
|
||||
hpbandster==0.7.4
|
||||
|
|
|
@ -7,7 +7,7 @@ python:
|
|||
pip_packages:
|
||||
- pytest
|
||||
- awscli
|
||||
- gym==0.21.0
|
||||
- gym
|
||||
conda_packages: []
|
||||
|
||||
post_build_cmds:
|
||||
|
|
|
@ -7,7 +7,7 @@ debian_packages:
|
|||
|
||||
python:
|
||||
pip_packages:
|
||||
- gym[atari]==0.21.0
|
||||
- gym[atari]
|
||||
- pytest
|
||||
- tensorflow
|
||||
conda_packages: []
|
||||
|
@ -15,7 +15,7 @@ python:
|
|||
post_build_cmds:
|
||||
- 'rm -r wrk || true && git clone https://github.com/wg/wrk.git /tmp/wrk && cd /tmp/wrk && make -j && sudo cp wrk /usr/local/bin'
|
||||
- pip3 install pytest || true
|
||||
- pip3 install -U ray[all] gym[atari]==0.21.0 autorom[accept-rom-license]
|
||||
- pip3 install -U ray[all] gym[atari] autorom[accept-rom-license]
|
||||
- pip3 install ray[all]
|
||||
# TODO (Alex): Ideally we would install all the dependencies from the new
|
||||
# version too, but pip won't be able to find the new version of ray-cpp.
|
||||
|
|
|
@ -7,7 +7,7 @@ debian_packages:
|
|||
|
||||
python:
|
||||
pip_packages:
|
||||
- gym[atari]==0.21.0
|
||||
- gym[atari]
|
||||
- pygame
|
||||
- pytest
|
||||
- tensorflow
|
||||
|
@ -18,7 +18,7 @@ post_build_cmds:
|
|||
- 'rm -r wrk || true && git clone https://github.com/wg/wrk.git /tmp/wrk && cd /tmp/wrk && make -j && sudo cp wrk /usr/local/bin'
|
||||
- pip3 install numpy==1.19 || true
|
||||
- pip3 install pytest || true
|
||||
- pip3 install -U ray[all] gym[atari]==0.21.0 autorom[accept-rom-license]
|
||||
- pip3 install -U ray[all] gym[atari] autorom[accept-rom-license]
|
||||
- pip3 install ray[all]
|
||||
# TODO (Alex): Ideally we would install all the dependencies from the new
|
||||
# version too, but pip won't be able to find the new version of ray-cpp.
|
||||
|
|
|
@ -8,7 +8,7 @@ python:
|
|||
# These dependencies should be handled by requirements_rllib.txt and
|
||||
# requirements_ml_docker.txt
|
||||
pip_packages:
|
||||
- gym==0.21.0
|
||||
- gym
|
||||
conda_packages: []
|
||||
|
||||
post_build_cmds:
|
||||
|
|
|
@ -8,8 +8,8 @@ python:
|
|||
- pytest
|
||||
- awscli
|
||||
- gsutil
|
||||
- gym
|
||||
- gcsfs
|
||||
- gym==0.21.0
|
||||
- pyarrow>=6.0.1,<7.0.0
|
||||
conda_packages: []
|
||||
|
||||
|
|
|
@ -8,8 +8,8 @@ python:
|
|||
- pytest
|
||||
- awscli
|
||||
- gsutil
|
||||
- gym
|
||||
- gcsfs
|
||||
- gym==0.21.0
|
||||
- pyarrow>=6.0.1,<7.0.0
|
||||
conda_packages: []
|
||||
|
||||
|
|
20
rllib/BUILD
20
rllib/BUILD
|
@ -1318,12 +1318,6 @@ sh_test(
|
|||
data = glob(["examples/serving/*.py"]),
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "env/tests/test_record_env_wrapper",
|
||||
tags = ["team:ml", "env"],
|
||||
size = "small",
|
||||
srcs = ["env/tests/test_record_env_wrapper.py"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "env/tests/test_remote_worker_envs",
|
||||
|
@ -2818,13 +2812,13 @@ py_test(
|
|||
args = ["--as-test", "--framework=torch"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "examples/remote_base_env_with_custom_api",
|
||||
tags = ["team:ml", "examples", "examples_R"],
|
||||
size = "medium",
|
||||
srcs = ["examples/remote_base_env_with_custom_api.py"],
|
||||
args = ["--stop-iters=3"]
|
||||
)
|
||||
# py_test(
|
||||
# name = "examples/remote_base_env_with_custom_api",
|
||||
# tags = ["team:ml", "examples", "examples_R"],
|
||||
# size = "medium",
|
||||
# srcs = ["examples/remote_base_env_with_custom_api.py"],
|
||||
# args = ["--stop-iters=3"]
|
||||
# )
|
||||
|
||||
py_test(
|
||||
name = "examples/restore_1_of_n_agents_from_checkpoint",
|
||||
|
|
|
@ -75,10 +75,11 @@ class _MockTrainer(Trainer):
|
|||
self.info = info
|
||||
self.restored = True
|
||||
|
||||
@staticmethod
|
||||
@override(Trainer)
|
||||
def _register_if_needed(self, env_object, config):
|
||||
def _get_env_id_and_creator(env_specifier, config):
|
||||
# No env to register.
|
||||
pass
|
||||
return None, None
|
||||
|
||||
def set_info(self, info):
|
||||
self.info = info
|
||||
|
|
|
@ -8,7 +8,9 @@ import logging
|
|||
import math
|
||||
import numpy as np
|
||||
import os
|
||||
from packaging import version
|
||||
import pickle
|
||||
import pkg_resources
|
||||
import tempfile
|
||||
import time
|
||||
from typing import (
|
||||
|
@ -30,7 +32,6 @@ from ray.exceptions import RayError
|
|||
from ray.rllib.agents.callbacks import DefaultCallbacks
|
||||
from ray.rllib.agents.trainer_config import TrainerConfig
|
||||
from ray.rllib.env.env_context import EnvContext
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
from ray.rllib.env.utils import gym_env_creator
|
||||
from ray.rllib.evaluation.episode import Episode
|
||||
from ray.rllib.evaluation.metrics import (
|
||||
|
@ -94,7 +95,7 @@ from ray.rllib.utils.typing import (
|
|||
TrainerConfigDict,
|
||||
)
|
||||
from ray.tune.logger import Logger, UnifiedLogger
|
||||
from ray.tune.registry import ENV_CREATOR, register_env, _global_registry
|
||||
from ray.tune.registry import ENV_CREATOR, _global_registry
|
||||
from ray.tune.resources import Resources
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
from ray.tune.trainable import Trainable
|
||||
|
@ -220,19 +221,13 @@ class Trainer(Trainable):
|
|||
if isinstance(config, TrainerConfig):
|
||||
config = config.to_dict()
|
||||
|
||||
# Convert `env` provided in config into a string:
|
||||
# - If `env` is a string: `self._env_id` = `env`.
|
||||
# - If `env` is a class: `self._env_id` = `env.__name__` -> Already
|
||||
# register it with a auto-generated env creator.
|
||||
# - If `env` is None: `self._env_id` is None.
|
||||
self._env_id: Optional[str] = self._register_if_needed(
|
||||
# Convert `env` provided in config into a concrete env creator callable, which
|
||||
# takes an EnvContext (config dict) as arg and returning an RLlib supported Env
|
||||
# type (e.g. a gym.Env).
|
||||
self._env_id, self.env_creator = self._get_env_id_and_creator(
|
||||
env or config.get("env"), config
|
||||
)
|
||||
|
||||
# The env creator callable, taking an EnvContext (config dict)
|
||||
# as arg and returning an RLlib supported Env type (e.g. a gym.Env).
|
||||
self.env_creator: Optional[EnvCreator] = None
|
||||
|
||||
# Placeholder for a local replay buffer instance.
|
||||
self.local_replay_buffer = None
|
||||
|
||||
|
@ -310,10 +305,6 @@ class Trainer(Trainable):
|
|||
# Validate the framework settings in config.
|
||||
self.validate_framework(self.config)
|
||||
|
||||
# Setup the self.env_creator callable (to be passed
|
||||
# e.g. to RolloutWorkers' c'tors).
|
||||
self.env_creator = self._get_env_creator_from_env_id(self._env_id)
|
||||
|
||||
# Set Trainer's seed after we have - if necessary - enabled
|
||||
# tf eager-execution.
|
||||
update_global_seed_if_necessary(self.config["framework"], self.config["seed"])
|
||||
|
@ -466,8 +457,9 @@ class Trainer(Trainable):
|
|||
|
||||
self.config["evaluation_config"] = eval_config
|
||||
|
||||
env_id = self._register_if_needed(eval_config.get("env"), eval_config)
|
||||
env_creator = self._get_env_creator_from_env_id(env_id)
|
||||
env_id, env_creator = self._get_env_id_and_creator(
|
||||
eval_config.get("env"), eval_config
|
||||
)
|
||||
|
||||
# Create a separate evaluation worker set for evaluation.
|
||||
# If evaluation_num_workers=0, use the evaluation set's local
|
||||
|
@ -1541,37 +1533,87 @@ class Trainer(Trainable):
|
|||
"""Pre-evaluation callback."""
|
||||
pass
|
||||
|
||||
def _get_env_creator_from_env_id(self, env_id: Optional[str] = None) -> EnvCreator:
|
||||
"""Returns an env creator callable, given an `env_id` (e.g. "CartPole-v0").
|
||||
@staticmethod
|
||||
def _get_env_id_and_creator(
|
||||
env_specifier: Union[str, EnvType, None], config: PartialTrainerConfigDict
|
||||
) -> Tuple[Optional[str], EnvCreator]:
|
||||
"""Returns env_id and creator callable given original env id from config.
|
||||
|
||||
Args:
|
||||
env_id: An already tune registered env ID, a known gym env name,
|
||||
or None (if no env is used).
|
||||
env_specifier: An env class, an already tune registered env ID, a known
|
||||
gym env name, or None (if no env is used).
|
||||
config: The Trainer's (maybe partial) config dict.
|
||||
|
||||
Returns:
|
||||
Tuple consisting of a) env ID string and b) env creator callable.
|
||||
"""
|
||||
if env_id:
|
||||
# Environment is specified via a string.
|
||||
if isinstance(env_specifier, str):
|
||||
# An already registered env.
|
||||
if _global_registry.contains(ENV_CREATOR, env_id):
|
||||
return _global_registry.get(ENV_CREATOR, env_id)
|
||||
if _global_registry.contains(ENV_CREATOR, env_specifier):
|
||||
return env_specifier, _global_registry.get(ENV_CREATOR, env_specifier)
|
||||
|
||||
# A class path specifier.
|
||||
elif "." in env_id:
|
||||
elif "." in env_specifier:
|
||||
|
||||
def env_creator_from_classpath(env_context):
|
||||
try:
|
||||
env_obj = from_config(env_id, env_context)
|
||||
env_obj = from_config(env_specifier, env_context)
|
||||
except ValueError:
|
||||
raise EnvError(ERR_MSG_INVALID_ENV_DESCRIPTOR.format(env_id))
|
||||
raise EnvError(
|
||||
ERR_MSG_INVALID_ENV_DESCRIPTOR.format(env_specifier)
|
||||
)
|
||||
return env_obj
|
||||
|
||||
return env_creator_from_classpath
|
||||
return env_specifier, env_creator_from_classpath
|
||||
# Try gym/PyBullet/Vizdoom.
|
||||
else:
|
||||
return functools.partial(gym_env_creator, env_descriptor=env_id)
|
||||
return env_specifier, functools.partial(
|
||||
gym_env_creator, env_descriptor=env_specifier
|
||||
)
|
||||
|
||||
elif isinstance(env_specifier, type):
|
||||
env_id = env_specifier # .__name__
|
||||
|
||||
if config.get("remote_worker_envs"):
|
||||
# Check gym version (0.22 or higher?).
|
||||
# If > 0.21, can't perform auto-wrapping of the given class as this
|
||||
# would lead to a pickle error.
|
||||
gym_version = pkg_resources.get_distribution("gym").version
|
||||
if version.parse(gym_version) >= version.parse("0.22"):
|
||||
raise ValueError(
|
||||
"Cannot specify a gym.Env class via `config.env` while setting "
|
||||
"`config.remote_worker_env=True` AND your gym version is >= "
|
||||
"0.22! Try installing an older version of gym or set `config."
|
||||
"remote_worker_env=False`."
|
||||
)
|
||||
|
||||
@ray.remote(num_cpus=1)
|
||||
class _wrapper(env_specifier):
|
||||
# Add convenience `_get_spaces` and `_is_multi_agent`
|
||||
# methods:
|
||||
def _get_spaces(self):
|
||||
return self.observation_space, self.action_space
|
||||
|
||||
def _is_multi_agent(self):
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
|
||||
return isinstance(self, MultiAgentEnv)
|
||||
|
||||
return env_id, lambda cfg: _wrapper.remote(cfg)
|
||||
else:
|
||||
return env_id, lambda cfg: env_specifier(cfg)
|
||||
|
||||
# No env -> Env creator always returns None.
|
||||
elif env_specifier is None:
|
||||
return None, lambda env_config: None
|
||||
|
||||
else:
|
||||
return lambda env_config: None
|
||||
raise ValueError(
|
||||
"{} is an invalid env specifier. ".format(env_specifier)
|
||||
+ "You can specify a custom env as either a class "
|
||||
'(e.g., YourEnvCls) or a registered env id (e.g., "your_env").'
|
||||
)
|
||||
|
||||
def _sync_filters_if_needed(self, workers: WorkerSet):
|
||||
if self.config.get("observation_filter", "NoFilter") != "NoFilter":
|
||||
|
@ -1753,16 +1795,6 @@ class Trainer(Trainable):
|
|||
if model_config is None:
|
||||
config["model"] = model_config = {}
|
||||
|
||||
# Monitor should be replaced by `record_env`.
|
||||
if config.get("monitor", DEPRECATED_VALUE) != DEPRECATED_VALUE:
|
||||
deprecation_warning("monitor", "record_env", error=False)
|
||||
config["record_env"] = config.get("monitor", False)
|
||||
# Empty string would fail some if-blocks checking for this setting.
|
||||
# Set to True instead, meaning: use default output dir to store
|
||||
# the videos.
|
||||
if config.get("record_env") == "":
|
||||
config["record_env"] = True
|
||||
|
||||
# Use DefaultCallbacks class, if callbacks is None.
|
||||
if config["callbacks"] is None:
|
||||
config["callbacks"] = DefaultCallbacks
|
||||
|
@ -2149,38 +2181,6 @@ class Trainer(Trainable):
|
|||
kwargs["local_replay_buffer"] = self.local_replay_buffer
|
||||
return kwargs
|
||||
|
||||
def _register_if_needed(
|
||||
self, env_object: Union[str, EnvType, None], config
|
||||
) -> Optional[str]:
|
||||
if isinstance(env_object, str):
|
||||
return env_object
|
||||
elif isinstance(env_object, type):
|
||||
name = env_object.__name__
|
||||
|
||||
if config.get("remote_worker_envs"):
|
||||
|
||||
@ray.remote(num_cpus=0)
|
||||
class _wrapper(env_object):
|
||||
# Add convenience `_get_spaces` and `_is_multi_agent`
|
||||
# methods.
|
||||
def _get_spaces(self):
|
||||
return self.observation_space, self.action_space
|
||||
|
||||
def _is_multi_agent(self):
|
||||
return isinstance(self, MultiAgentEnv)
|
||||
|
||||
register_env(name, lambda cfg: _wrapper.remote(cfg))
|
||||
else:
|
||||
register_env(name, lambda cfg: env_object(cfg))
|
||||
return name
|
||||
elif env_object is None:
|
||||
return None
|
||||
raise ValueError(
|
||||
"{} is an invalid env specification. ".format(env_object)
|
||||
+ "You can specify a custom env as either a class "
|
||||
'(e.g., YourEnvCls) or a registered env id (e.g., "your_env").'
|
||||
)
|
||||
|
||||
def _step_context(trainer):
|
||||
class StepCtx:
|
||||
def __enter__(self):
|
||||
|
|
|
@ -108,7 +108,6 @@ class TrainerConfig:
|
|||
self.action_space = None
|
||||
self.env_task_fn = None
|
||||
self.render_env = False
|
||||
self.record_env = False
|
||||
self.clip_rewards = None
|
||||
self.normalize_actions = True
|
||||
self.clip_actions = False
|
||||
|
@ -458,7 +457,6 @@ class TrainerConfig:
|
|||
action_space: Optional[gym.spaces.Space] = None,
|
||||
env_task_fn: Optional[Callable[[ResultDict, EnvType, EnvContext], Any]] = None,
|
||||
render_env: Optional[bool] = None,
|
||||
record_env: Optional[bool] = None,
|
||||
clip_rewards: Optional[Union[bool, float]] = None,
|
||||
normalize_actions: Optional[bool] = None,
|
||||
clip_actions: Optional[bool] = None,
|
||||
|
@ -489,11 +487,6 @@ class TrainerConfig:
|
|||
`render()` method which either:
|
||||
a) handles window generation and rendering itself (returning True) or
|
||||
b) returns a numpy uint8 image of shape [height x width x 3 (RGB)].
|
||||
record_env: If True, stores videos in this relative directory inside the
|
||||
default output dir (~/ray_results/...). Alternatively, you can
|
||||
specify an absolute path (str), in which the env recordings should be
|
||||
stored instead. Set to False for not recording anything.
|
||||
Note: This setting replaces the deprecated `monitor` key.
|
||||
clip_rewards: Whether to clip rewards during Policy's postprocessing.
|
||||
None (default): Clip for Atari only (r=sign(r)).
|
||||
True: r=sign(r): Fixed rewards -1.0, 1.0, or 0.0.
|
||||
|
@ -524,8 +517,6 @@ class TrainerConfig:
|
|||
self.env_task_fn = env_task_fn
|
||||
if render_env is not None:
|
||||
self.render_env = render_env
|
||||
if record_env is not None:
|
||||
self.record_env = record_env
|
||||
if clip_rewards is not None:
|
||||
self.clip_rewards = clip_rewards
|
||||
if normalize_actions is not None:
|
||||
|
|
|
@ -168,7 +168,7 @@ class AlphaZeroPolicyWrapperClass(AlphaZeroPolicy):
|
|||
model = ModelCatalog.get_model_v2(
|
||||
obs_space, action_space, action_space.n, config["model"], "torch"
|
||||
)
|
||||
env_creator = Trainer._get_env_creator_from_env_id(None, config["env"])
|
||||
_, env_creator = Trainer._get_env_id_and_creator(config["env"], config)
|
||||
if config["ranked_rewards"]["enable"]:
|
||||
# if r2 is enabled, tne env is wrapped to include a rewards buffer
|
||||
# used to normalize rewards
|
||||
|
|
|
@ -355,8 +355,6 @@ class ARSTrainer(Trainer):
|
|||
# Validate our config dict.
|
||||
self.validate_config(self.config)
|
||||
|
||||
# Generate `self.env_creator` callable to create an env instance.
|
||||
self.env_creator = self._get_env_creator_from_env_id(self._env_id)
|
||||
# Generate the local env.
|
||||
env_context = EnvContext(self.config["env_config"] or {}, worker_index=0)
|
||||
env = self.env_creator(env_context)
|
||||
|
|
|
@ -362,8 +362,6 @@ class ESTrainer(Trainer):
|
|||
# Call super's validation method.
|
||||
self.validate_config(self.config)
|
||||
|
||||
# Generate `self.env_creator` callable to create an env instance.
|
||||
self.env_creator = self._get_env_creator_from_env_id(self._env_id)
|
||||
# Generate the local env.
|
||||
env_context = EnvContext(self.config["env_config"] or {}, worker_index=0)
|
||||
env = self.env_creator(env_context)
|
||||
|
|
2
rllib/env/multi_agent_env.py
vendored
2
rllib/env/multi_agent_env.py
vendored
|
@ -1,6 +1,6 @@
|
|||
import gym
|
||||
import logging
|
||||
from typing import Callable, Dict, List, Tuple, Type, Optional, Union, Set
|
||||
from typing import Callable, Dict, List, Tuple, Optional, Union, Set, Type
|
||||
|
||||
from ray.rllib.env.base_env import BaseEnv
|
||||
from ray.rllib.utils.annotations import (
|
||||
|
|
100
rllib/env/tests/test_record_env_wrapper.py
vendored
100
rllib/env/tests/test_record_env_wrapper.py
vendored
|
@ -1,100 +0,0 @@
|
|||
import glob
|
||||
import gym
|
||||
import numpy as np
|
||||
import os
|
||||
import shutil
|
||||
import unittest
|
||||
|
||||
from ray.rllib.env.utils import VideoMonitor, record_env_wrapper
|
||||
from ray.rllib.examples.env.mock_env import MockEnv2
|
||||
from ray.rllib.examples.env.multi_agent import BasicMultiAgent
|
||||
from ray.rllib.utils.test_utils import check
|
||||
|
||||
|
||||
class TestRecordEnvWrapper(unittest.TestCase):
|
||||
def test_wrap_gym_env(self):
|
||||
record_env_dir = os.popen("mktemp -d").read()[:-1]
|
||||
print(f"tmp dir for videos={record_env_dir}")
|
||||
|
||||
if not os.path.exists(record_env_dir):
|
||||
sys.exit(1)
|
||||
|
||||
num_steps_per_episode = 10
|
||||
wrapped = record_env_wrapper(
|
||||
env=MockEnv2(num_steps_per_episode),
|
||||
record_env=record_env_dir,
|
||||
log_dir="",
|
||||
policy_config={
|
||||
"in_evaluation": False,
|
||||
},
|
||||
)
|
||||
# Non MultiAgentEnv: Wrapper's type is wrappers.Monitor.
|
||||
self.assertTrue(isinstance(wrapped, gym.wrappers.Monitor))
|
||||
self.assertFalse(isinstance(wrapped, VideoMonitor))
|
||||
|
||||
wrapped.reset()
|
||||
# Expect one video file to have been produced in the tmp dir.
|
||||
os.chdir(record_env_dir)
|
||||
ls = glob.glob("*.mp4")
|
||||
self.assertTrue(len(ls) == 1)
|
||||
# 10 steps for a complete episode.
|
||||
for i in range(num_steps_per_episode):
|
||||
wrapped.step(0)
|
||||
# Another episode.
|
||||
wrapped.reset()
|
||||
for i in range(num_steps_per_episode):
|
||||
wrapped.step(0)
|
||||
# Expect another video file to have been produced (2nd episode).
|
||||
ls = glob.glob("*.mp4")
|
||||
self.assertTrue(len(ls) == 2)
|
||||
|
||||
# MockEnv2 returns a reward of 100.0 every step.
|
||||
# So total reward is 1000.0 per episode (10 steps).
|
||||
check(
|
||||
np.array([100.0, 100.0]) * num_steps_per_episode,
|
||||
wrapped.get_episode_rewards(),
|
||||
)
|
||||
# Erase all generated files and the temp path just in case,
|
||||
# as to not disturb further CI-tests.
|
||||
shutil.rmtree(record_env_dir)
|
||||
|
||||
def test_wrap_multi_agent_env(self):
|
||||
record_env_dir = os.popen("mktemp -d").read()[:-1]
|
||||
print(f"tmp dir for videos={record_env_dir}")
|
||||
|
||||
if not os.path.exists(record_env_dir):
|
||||
sys.exit(1)
|
||||
|
||||
wrapped = record_env_wrapper(
|
||||
env=BasicMultiAgent(3),
|
||||
record_env=record_env_dir,
|
||||
log_dir="",
|
||||
policy_config={
|
||||
"in_evaluation": False,
|
||||
},
|
||||
)
|
||||
# Type is VideoMonitor.
|
||||
self.assertTrue(isinstance(wrapped, gym.wrappers.Monitor))
|
||||
self.assertTrue(isinstance(wrapped, VideoMonitor))
|
||||
|
||||
wrapped.reset()
|
||||
|
||||
# BasicMultiAgent is hardcoded to run 25-step episodes.
|
||||
for i in range(25):
|
||||
wrapped.step({0: 0, 1: 0, 2: 0})
|
||||
|
||||
# Expect one video file to have been produced in the tmp dir.
|
||||
os.chdir(record_env_dir)
|
||||
ls = glob.glob("*.mp4")
|
||||
self.assertTrue(len(ls) == 1)
|
||||
|
||||
# However VideoMonitor's _after_step is overwritten to not
|
||||
# use stats_recorder. So nothing to verify here, except that
|
||||
# it runs fine.
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
33
rllib/env/tests/test_remote_worker_envs.py
vendored
33
rllib/env/tests/test_remote_worker_envs.py
vendored
|
@ -7,10 +7,8 @@ import unittest
|
|||
import ray
|
||||
from ray.rllib.algorithms.pg import pg
|
||||
from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv
|
||||
from ray.rllib.examples.env.random_env import RandomEnv, RandomMultiAgentEnv
|
||||
from ray.rllib.examples.remote_base_env_with_custom_api import (
|
||||
NonVectorizedEnvToBeVectorizedIntoRemoteBaseEnv,
|
||||
)
|
||||
|
||||
# from ray.rllib.examples.env.random_env import RandomEnv
|
||||
from ray import tune
|
||||
|
||||
|
||||
|
@ -55,17 +53,19 @@ class TestRemoteWorkerEnvSetting(unittest.TestCase):
|
|||
trainer.stop()
|
||||
|
||||
# Using class directly.
|
||||
config["env"] = RandomEnv
|
||||
trainer = pg.PGTrainer(config=config)
|
||||
print(trainer.train())
|
||||
trainer.stop()
|
||||
# This doesn't work anymore as of gym==0.23
|
||||
# config["env"] = RandomEnv
|
||||
# trainer = pg.PGTrainer(config=config)
|
||||
# print(trainer.train())
|
||||
# trainer.stop()
|
||||
|
||||
# Using class directly: Sub-class of gym.Env,
|
||||
# which implements its own API.
|
||||
config["env"] = NonVectorizedEnvToBeVectorizedIntoRemoteBaseEnv
|
||||
trainer = pg.PGTrainer(config=config)
|
||||
print(trainer.train())
|
||||
trainer.stop()
|
||||
# This doesn't work anymore as of gym==0.23
|
||||
# config["env"] = NonVectorizedEnvToBeVectorizedIntoRemoteBaseEnv
|
||||
# trainer = pg.PGTrainer(config=config)
|
||||
# print(trainer.train())
|
||||
# trainer.stop()
|
||||
|
||||
def test_remote_worker_env_multi_agent(self):
|
||||
config = pg.DEFAULT_CONFIG.copy()
|
||||
|
@ -85,10 +85,11 @@ class TestRemoteWorkerEnvSetting(unittest.TestCase):
|
|||
trainer.stop()
|
||||
|
||||
# Using class directly.
|
||||
config["env"] = RandomMultiAgentEnv
|
||||
trainer = pg.PGTrainer(config=config)
|
||||
print(trainer.train())
|
||||
trainer.stop()
|
||||
# This doesn't work anymore as of gym==0.23.
|
||||
# config["env"] = RandomMultiAgentEnv
|
||||
# trainer = pg.PGTrainer(config=config)
|
||||
# print(trainer.train())
|
||||
# trainer.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
49
rllib/env/utils.py
vendored
49
rllib/env/utils.py
vendored
|
@ -1,10 +1,6 @@
|
|||
import gym
|
||||
from gym import wrappers
|
||||
import os
|
||||
|
||||
from ray.rllib.env.env_context import EnvContext
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
from ray.rllib.utils import add_mixins
|
||||
from ray.rllib.utils.error import ERR_MSG_INVALID_ENV_DESCRIPTOR, EnvError
|
||||
|
||||
|
||||
|
@ -54,48 +50,3 @@ def gym_env_creator(env_context: EnvContext, env_descriptor: str) -> gym.Env:
|
|||
return gym.make(env_descriptor, **env_context)
|
||||
except gym.error.Error:
|
||||
raise EnvError(ERR_MSG_INVALID_ENV_DESCRIPTOR.format(env_descriptor))
|
||||
|
||||
|
||||
class VideoMonitor(wrappers.Monitor):
|
||||
# Same as original method, but doesn't use the StatsRecorder as it will
|
||||
# try to add up multi-agent rewards dicts, which throws errors.
|
||||
def _after_step(self, observation, reward, done, info):
|
||||
if not self.enabled:
|
||||
return done
|
||||
|
||||
# Use done["__all__"] b/c this is a multi-agent dict.
|
||||
if done["__all__"] and self.env_semantics_autoreset:
|
||||
# For envs with BlockingReset wrapping VNCEnv, this observation
|
||||
# will be the first one of the new episode
|
||||
self.reset_video_recorder()
|
||||
self.episode_id += 1
|
||||
self._flush()
|
||||
|
||||
# Record video
|
||||
self.video_recorder.capture_frame()
|
||||
|
||||
return done
|
||||
|
||||
|
||||
def record_env_wrapper(env, record_env, log_dir, policy_config):
|
||||
if record_env:
|
||||
path_ = record_env if isinstance(record_env, str) else log_dir
|
||||
# Relative path: Add logdir here, otherwise, this would
|
||||
# not work for non-local workers.
|
||||
if not os.path.isabs(path_):
|
||||
path_ = os.path.join(log_dir, path_)
|
||||
print(f"Setting the path for recording to {path_}")
|
||||
wrapper_cls = (
|
||||
VideoMonitor if isinstance(env, MultiAgentEnv) else wrappers.Monitor
|
||||
)
|
||||
if isinstance(env, MultiAgentEnv):
|
||||
wrapper_cls = add_mixins(wrapper_cls, [MultiAgentEnv], reversed=True)
|
||||
env = wrapper_cls(
|
||||
env,
|
||||
path_,
|
||||
resume=True,
|
||||
force=True,
|
||||
video_callable=lambda _: True,
|
||||
mode="evaluation" if policy_config["in_evaluation"] else "training",
|
||||
)
|
||||
return env
|
||||
|
|
|
@ -4,7 +4,6 @@ import argparse
|
|||
import collections
|
||||
import copy
|
||||
import gym
|
||||
from gym import wrappers as gym_wrappers
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
@ -339,7 +338,6 @@ def run(args, parser):
|
|||
deprecation_warning(old="--no-render", new="--render", error=False)
|
||||
args.render = False
|
||||
config["render_env"] = args.render
|
||||
config["record_env"] = args.video_dir
|
||||
|
||||
ray.init(local_mode=args.local_mode)
|
||||
|
||||
|
@ -354,12 +352,6 @@ def run(args, parser):
|
|||
num_steps = int(args.steps)
|
||||
num_episodes = int(args.episodes)
|
||||
|
||||
# Determine the video output directory.
|
||||
video_dir = None
|
||||
# Allow user to specify a video output path.
|
||||
if args.video_dir:
|
||||
video_dir = os.path.expanduser(args.video_dir)
|
||||
|
||||
# Do the actual rollout.
|
||||
with RolloutSaver(
|
||||
args.out,
|
||||
|
@ -369,9 +361,7 @@ def run(args, parser):
|
|||
target_episodes=num_episodes,
|
||||
save_info=args.save_info,
|
||||
) as saver:
|
||||
rollout(
|
||||
agent, args.env, num_steps, num_episodes, saver, not args.render, video_dir
|
||||
)
|
||||
rollout(agent, args.env, num_steps, num_episodes, saver, not args.render)
|
||||
agent.stop()
|
||||
|
||||
|
||||
|
@ -406,7 +396,6 @@ def rollout(
|
|||
num_episodes=0,
|
||||
saver=None,
|
||||
no_render=True,
|
||||
video_dir=None,
|
||||
):
|
||||
policy_agent_mapping = default_policy_agent_mapping
|
||||
|
||||
|
@ -473,13 +462,6 @@ def rollout(
|
|||
for p, m in policy_map.items()
|
||||
}
|
||||
|
||||
# If monitoring has been requested, manually wrap our environment with a
|
||||
# gym monitor, which is set to record every episode.
|
||||
if video_dir:
|
||||
env = gym_wrappers.Monitor(
|
||||
env=env, directory=video_dir, video_callable=lambda _: True, force=True
|
||||
)
|
||||
|
||||
steps = 0
|
||||
episodes = 0
|
||||
while keep_going(steps, num_steps, episodes, num_episodes):
|
||||
|
|
|
@ -27,7 +27,6 @@ from ray.rllib.env.base_env import BaseEnv, convert_to_base_env
|
|||
from ray.rllib.env.env_context import EnvContext
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
|
||||
from ray.rllib.env.utils import record_env_wrapper
|
||||
from ray.rllib.env.wrappers.atari_wrappers import wrap_deepmind, is_atari
|
||||
from ray.rllib.evaluation.sampler import AsyncSampler, SyncSampler
|
||||
from ray.rllib.evaluation.metrics import RolloutMetrics
|
||||
|
@ -233,7 +232,6 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
worker_index: int = 0,
|
||||
num_workers: int = 0,
|
||||
recreated_worker: bool = False,
|
||||
record_env: Union[bool, str] = False,
|
||||
log_dir: Optional[str] = None,
|
||||
log_level: Optional[str] = None,
|
||||
callbacks: Type["DefaultCallbacks"] = None,
|
||||
|
@ -253,7 +251,6 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
fake_sampler: bool = False,
|
||||
spaces: Optional[Dict[PolicyID, Tuple[Space, Space]]] = None,
|
||||
policy=None,
|
||||
monitor_path=None,
|
||||
disable_env_checking=False,
|
||||
):
|
||||
"""Initializes a RolloutWorker instance.
|
||||
|
@ -332,10 +329,6 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
`recreate_failed_workers=True` and one of the original workers (or an
|
||||
already recreated one) has failed. They don't differ from original
|
||||
workers other than the value of this flag (`self.recreated_worker`).
|
||||
record_env: Write out episode stats and videos
|
||||
using gym.wrappers.Monitor to this directory if specified. If
|
||||
True, use the default output dir in ~/ray_results/.... If
|
||||
False, do not record anything.
|
||||
log_dir: Directory where logs can be placed.
|
||||
log_level: Set the root log level on creation.
|
||||
callbacks: Custom sub-class of
|
||||
|
@ -374,7 +367,6 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
to (obs_space, action_space)-tuples. This is used in case no
|
||||
Env is created on this RolloutWorker.
|
||||
policy: Obsoleted arg. Use `policy_spec` instead.
|
||||
monitor_path: Obsoleted arg. Use `record_env` instead.
|
||||
disable_env_checking: If True, disables the env checking module that
|
||||
validates the properties of the passed environment.
|
||||
"""
|
||||
|
@ -395,10 +387,6 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
for pid, spec in policy_spec.copy().items()
|
||||
}
|
||||
|
||||
if monitor_path is not None:
|
||||
deprecation_warning("monitor_path", "record_env", error=False)
|
||||
record_env = monitor_path
|
||||
|
||||
self._original_kwargs: dict = locals().copy()
|
||||
del self._original_kwargs["self"]
|
||||
|
||||
|
@ -490,7 +478,7 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
# 1) Create the env using the user provided env_creator. This may
|
||||
# return a gym.Env (incl. MultiAgentEnv), an already vectorized
|
||||
# VectorEnv, BaseEnv, ExternalEnv, or an ActorHandle (remote env).
|
||||
# 2) Wrap - if applicable - with Atari/recording/rendering wrappers.
|
||||
# 2) Wrap - if applicable - with Atari/rendering wrappers.
|
||||
# 3) Seed the env, if necessary.
|
||||
# 4) Vectorize the existing single env by creating more clones of
|
||||
# this env and wrapping it with the RLlib BaseEnv class.
|
||||
|
@ -541,14 +529,12 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
env = wrap_deepmind(
|
||||
env, dim=model_config.get("dim"), framestack=use_framestack
|
||||
)
|
||||
env = record_env_wrapper(env, record_env, log_dir, policy_config)
|
||||
return env
|
||||
|
||||
# gym.Env -> Wrap with gym Monitor.
|
||||
else:
|
||||
|
||||
def wrap(env):
|
||||
return record_env_wrapper(env, record_env, log_dir, policy_config)
|
||||
return env
|
||||
|
||||
# Wrap env through the correct wrapper.
|
||||
self.env: EnvType = wrap(self.env)
|
||||
|
|
|
@ -4,7 +4,6 @@ from gym.spaces import Box, Discrete
|
|||
import numpy as np
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
import time
|
||||
import unittest
|
||||
|
||||
|
@ -12,7 +11,6 @@ import ray
|
|||
from ray.rllib.algorithms.pg import PGTrainer
|
||||
from ray.rllib.agents.a3c import A2CTrainer
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
from ray.rllib.env.utils import VideoMonitor
|
||||
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
||||
from ray.rllib.evaluation.metrics import collect_metrics
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages
|
||||
|
@ -376,10 +374,8 @@ class TestRolloutWorker(unittest.TestCase):
|
|||
|
||||
curframe = inspect.currentframe()
|
||||
called_from_check = any(
|
||||
[
|
||||
frame[3] == "check_gym_environments"
|
||||
for frame in inspect.getouterframes(curframe, 2)
|
||||
]
|
||||
frame[3] == "check_gym_environments"
|
||||
for frame in inspect.getouterframes(curframe, 2)
|
||||
)
|
||||
# Check, whether the action is immutable.
|
||||
if action.flags.writeable and not called_from_check:
|
||||
|
@ -825,15 +821,12 @@ class TestRolloutWorker(unittest.TestCase):
|
|||
policy_config={
|
||||
"in_evaluation": False,
|
||||
},
|
||||
record_env=tempfile.gettempdir(),
|
||||
)
|
||||
# Make sure we can properly sample from the wrapped env.
|
||||
ev.sample()
|
||||
# Make sure the resulting environment is indeed still an
|
||||
# instance of MultiAgentEnv and VideoMonitor.
|
||||
self.assertTrue(isinstance(ev.env.unwrapped, MultiAgentEnv))
|
||||
self.assertTrue(isinstance(ev.env, gym.Env))
|
||||
self.assertTrue(isinstance(ev.env, VideoMonitor))
|
||||
ev.stop()
|
||||
|
||||
def test_no_training(self):
|
||||
|
|
|
@ -654,7 +654,6 @@ class WorkerSet:
|
|||
worker_index=worker_index,
|
||||
num_workers=num_workers,
|
||||
recreated_worker=recreated_worker,
|
||||
record_env=config["record_env"],
|
||||
log_dir=self._logdir,
|
||||
log_level=config["log_level"],
|
||||
callbacks=config["callbacks"],
|
||||
|
|
|
@ -1,12 +1,3 @@
|
|||
# ---------------
|
||||
# IMPORTANT NOTE:
|
||||
# ---------------
|
||||
# A recent bug in openAI gym prevents RLlib's "record_env" option
|
||||
# from recording videos properly. Instead, the produced mp4 files
|
||||
# have a size of 1kb and are corrupted.
|
||||
# A simple fix for this is described here:
|
||||
# https://github.com/openai/gym/issues/1925
|
||||
|
||||
import argparse
|
||||
import gym
|
||||
import numpy as np
|
||||
|
@ -115,13 +106,6 @@ if __name__ == "__main__":
|
|||
# Special evaluation config. Keys specified here will override
|
||||
# the same keys in the main config, but only for evaluation.
|
||||
"evaluation_config": {
|
||||
# Store videos in this relative directory here inside
|
||||
# the default output dir (~/ray_results/...).
|
||||
# Alternatively, you can specify an absolute path.
|
||||
# Set to True for using the default output dir (~/ray_results/...).
|
||||
# Set to False for not recording anything.
|
||||
"record_env": "videos",
|
||||
# "record_env": "/Users/xyz/my_videos/",
|
||||
# Render the env while evaluating.
|
||||
# Note that this will always only render the 1st RolloutWorker's
|
||||
# env and only the 1st sub-env in a vectorized env.
|
||||
|
|
|
@ -46,6 +46,11 @@ parser.add_argument(
|
|||
parser.add_argument(
|
||||
"--stop-reward", type=float, default=180.0, help="Reward at which we stop training."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local-mode",
|
||||
action="store_true",
|
||||
help="Init Ray in local mode for easier debugging.",
|
||||
)
|
||||
|
||||
|
||||
class NonVectorizedEnvToBeVectorizedIntoRemoteBaseEnv(TaskSettableEnv):
|
||||
|
@ -96,7 +101,7 @@ class TaskSettingCallback(DefaultCallbacks):
|
|||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
ray.init(num_cpus=6)
|
||||
ray.init(num_cpus=6, local_mode=args.local_mode)
|
||||
|
||||
config = {
|
||||
# Specify your custom (single, non-vectorized) env directly as a
|
||||
|
|
|
@ -6,15 +6,14 @@ import ray
|
|||
from ray import tune
|
||||
from ray.rllib.agents.registry import get_trainer_class
|
||||
|
||||
from ray.rllib.examples.env.repeat_after_me_env import RepeatAfterMeEnv
|
||||
# from ray.rllib.examples.env.repeat_after_me_env import RepeatAfterMeEnv
|
||||
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
|
||||
|
||||
envs = {"RepeatAfterMeEnv": RepeatAfterMeEnv, "StatelessCartPole": StatelessCartPole}
|
||||
|
||||
config = {
|
||||
"name": "RNNSAC_example",
|
||||
"local_dir": str(Path(__file__).parent / "example_out"),
|
||||
"checkpoint_freq": 1,
|
||||
"checkpoint_at_end": True,
|
||||
"keep_checkpoints_num": 1,
|
||||
"checkpoint_score_attr": "episode_reward_mean",
|
||||
"stop": {
|
||||
|
@ -29,11 +28,8 @@ config = {
|
|||
"num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
|
||||
"framework": "torch",
|
||||
"num_workers": 4,
|
||||
"num_envs_per_worker": 1,
|
||||
"num_cpus_per_worker": 1,
|
||||
"log_level": "INFO",
|
||||
# "env": envs["RepeatAfterMeEnv"],
|
||||
"env": envs["StatelessCartPole"],
|
||||
# "env": RepeatAfterMeEnv,
|
||||
"env": StatelessCartPole,
|
||||
"horizon": 1000,
|
||||
"gamma": 0.95,
|
||||
"batch_mode": "complete_episodes",
|
||||
|
@ -102,7 +98,7 @@ if __name__ == "__main__":
|
|||
eps = 0
|
||||
ep_reward = 0
|
||||
while eps < 10:
|
||||
action, state, info_trainer = agent.compute_action(
|
||||
action, state, info_trainer = agent.compute_single_action(
|
||||
obs,
|
||||
state=state,
|
||||
prev_action=prev_action,
|
||||
|
@ -115,7 +111,7 @@ if __name__ == "__main__":
|
|||
ep_reward += reward
|
||||
try:
|
||||
env.render()
|
||||
except (NotImplementedError, ImportError):
|
||||
except Exception:
|
||||
pass
|
||||
if done:
|
||||
eps += 1
|
||||
|
|
|
@ -29,7 +29,6 @@ class TestGymCheckEnv(unittest.TestCase):
|
|||
env = Mock(spec=["observation_space"])
|
||||
with pytest.raises(AttributeError, match="Env must have action_space."):
|
||||
check_gym_environments(env)
|
||||
del env
|
||||
|
||||
def test_obs_and_action_spaces_are_gym_spaces(self):
|
||||
env = RandomEnv()
|
||||
|
@ -41,7 +40,6 @@ class TestGymCheckEnv(unittest.TestCase):
|
|||
env.action_space = "not an action space"
|
||||
with pytest.raises(ValueError, match="Action space must be a gym.space"):
|
||||
check_env(env)
|
||||
del env
|
||||
|
||||
def test_reset(self):
|
||||
reset = MagicMock(return_value=5)
|
||||
|
@ -56,7 +54,6 @@ class TestGymCheckEnv(unittest.TestCase):
|
|||
env.reset = reset
|
||||
with pytest.raises(ValueError, match=error):
|
||||
check_env(env)
|
||||
del env
|
||||
|
||||
def test_step(self):
|
||||
step = MagicMock(return_value=(5, 5, True, {}))
|
||||
|
@ -92,7 +89,6 @@ class TestGymCheckEnv(unittest.TestCase):
|
|||
error = "Your step function must return a info that is a dict."
|
||||
with pytest.raises(ValueError, match=error):
|
||||
check_env(env)
|
||||
del env
|
||||
|
||||
|
||||
class TestCheckMultiAgentEnv(unittest.TestCase):
|
||||
|
@ -104,7 +100,6 @@ class TestCheckMultiAgentEnv(unittest.TestCase):
|
|||
env = RandomEnv()
|
||||
with pytest.raises(ValueError, match="The passed env is not"):
|
||||
check_multiagent_environments(env)
|
||||
del env
|
||||
|
||||
def test_check_env_reset_incorrect_error(self):
|
||||
reset = MagicMock(return_value=5)
|
||||
|
@ -119,7 +114,6 @@ class TestCheckMultiAgentEnv(unittest.TestCase):
|
|||
env.reset = lambda *_: bad_obs
|
||||
with pytest.raises(ValueError, match="The observation collected from env"):
|
||||
check_env(env)
|
||||
del env
|
||||
|
||||
def test_check_incorrect_space_contains_functions_error(self):
|
||||
def bad_contains_function(self, x):
|
||||
|
@ -131,7 +125,6 @@ class TestCheckMultiAgentEnv(unittest.TestCase):
|
|||
ValueError, match="Your observation_space_contains function has some"
|
||||
):
|
||||
check_env(env)
|
||||
del env
|
||||
env = make_multi_agent("CartPole-v1")({"num_agents": 2})
|
||||
bad_action = {0: 2, 1: 2}
|
||||
env.action_space_sample = lambda *_: bad_action
|
||||
|
@ -178,7 +171,6 @@ class TestCheckMultiAgentEnv(unittest.TestCase):
|
|||
ValueError, match="The action collected from action_space_sample"
|
||||
):
|
||||
check_env(env)
|
||||
del env
|
||||
env = make_multi_agent("CartPole-v1")({"num_agents": 2})
|
||||
bad_obs = {
|
||||
0: np.array([np.inf, np.inf, np.inf, np.inf]),
|
||||
|
@ -206,7 +198,6 @@ class TestCheckBaseEnv:
|
|||
env = RandomEnv()
|
||||
with pytest.raises(ValueError, match="The passed env is not"):
|
||||
check_base_env(env)
|
||||
del env
|
||||
|
||||
def test_check_env_reset_incorrect_error(self):
|
||||
reset = MagicMock(return_value=5)
|
||||
|
@ -244,7 +235,6 @@ class TestCheckBaseEnv:
|
|||
ValueError, match="The observation collected from try_reset"
|
||||
):
|
||||
check_env(env)
|
||||
del env
|
||||
|
||||
def test_check_space_contains_functions_errors(self):
|
||||
def bad_contains_function(self, x):
|
||||
|
@ -258,14 +248,12 @@ class TestCheckBaseEnv:
|
|||
):
|
||||
check_env(env)
|
||||
|
||||
del env
|
||||
env = self._make_base_env()
|
||||
env.action_space_contains = bad_contains_function
|
||||
with pytest.raises(
|
||||
ValueError, match="Your action_space_contains function has some error"
|
||||
):
|
||||
check_env(env)
|
||||
del env
|
||||
|
||||
def test_bad_sample_function(self):
|
||||
env = self._make_base_env()
|
||||
|
@ -276,7 +264,6 @@ class TestCheckBaseEnv:
|
|||
):
|
||||
check_env(env)
|
||||
|
||||
del env
|
||||
env = self._make_base_env()
|
||||
bad_obs = {
|
||||
0: {
|
||||
|
|
Loading…
Add table
Reference in a new issue