[RLlib] Upgrade gym 0.23 (#24171)
parent: c03d0432f3
commit: 09886d7ab8
31 changed files with 133 additions and 367 deletions
@@ -80,7 +80,7 @@ _____________________________________________________________________
 # __get_q_values_dqn_start__
 # Get a reference to the model through the policy
 import numpy as np
-from ray.rllib.agents.dqn import DQNTrainer
+from ray.rllib.algorithms.dqn import DQNTrainer

 trainer = DQNTrainer(env="CartPole-v0", config={"framework": "tf2"})
 model = trainer.get_policy().model
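A minimal sketch of how the fetched model can then be queried for Q-values (assuming the default DQN model class exposes get_q_value_distributions(); the observation values are made up):

    obs_batch = np.array([[0.1, 0.2, 0.3, 0.4]], dtype=np.float32)
    # Forward pass through the ModelV2: (input_dict, state, seq_lens).
    model_out, _ = model({"obs": obs_batch}, [], None)
    # For DQN-style models this yields the per-action Q-value tensors.
    q_values = model.get_q_value_distributions(model_out)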
@@ -39,7 +39,7 @@ __[Full Ray Enhancement Proposal, REP-001: Serve Pipeline](https://github.com/ra

 ## Concepts

-- **Deployment**: Scalable, upgradeable group of actors managed by Ray Serve. __[See docs for detail](https://docs.ray.io/en/master/serve/core-apis.html#core-api-deployments)__
+- **Deployment**: Scalable, upgradeable group of actors managed by Ray Serve. __[See docs for detail](https://docs.ray.io/en/master/serve/package-ref.html#deployment-api)__

 - **DeploymentNode**: Smallest unit in a graph, created by calling `.bind()` on a serve decorated class or function, backed by a Deployment.

@@ -10,7 +10,7 @@ from ray._private.client_mode_hook import enable_client_mode, client_mode_should
 @pytest.mark.skip(reason="KV store is not working properly.")
 def test_rllib_integration(ray_start_regular_shared):
 with ray_start_client_server():
-import ray.rllib.agents.dqn as dqn
+import ray.rllib.algorithms.dqn as dqn

 # Confirming the behavior of this context manager.
 # (Client mode hook not yet enabled.)
@@ -29,7 +29,7 @@ virtualenv
 ## setup.py extras
 dm_tree
 flask
-gym==0.21.0; python_version >= '3.7'
+gym>=0.21.0,<0.24.0; python_version >= '3.7'
 gym==0.19.0; python_version < '3.7'
 lz4
 scikit-image
@@ -4,7 +4,7 @@
 # ---------------------
 # Atari
 autorom[accept-rom-license]
-gym[atari]==0.21.0; python_version >= '3.7'
+gym>=0.21.0,<0.24.0; python_version >= '3.7'
 gym[atari]==0.19.0; python_version < '3.7'
 # Kaggle envs.
 kaggle_environments==1.7.11
@@ -11,7 +11,7 @@ freezegun==1.1.0
 gluoncv==0.10.1.post0
 gpy==1.10.0
 autorom[accept-rom-license]
-gym[atari]==0.21.0; python_version >= '3.7'
+gym>=0.21.0,<0.24.0; python_version >= '3.7'
 gym[atari]==0.19.0; python_version < '3.7'
 h5py==3.1.0
 hpbandster==0.7.4
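The requirement files above now share the open-ended pin gym>=0.21.0,<0.24.0 on Python 3.7+. For a quick local check against that range, one possible install command (not part of the commit) is:

    pip3 install "gym>=0.21.0,<0.24.0" "autorom[accept-rom-license]"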
@@ -7,7 +7,7 @@ python:
 pip_packages:
 - pytest
 - awscli
-- gym==0.21.0
+- gym
 conda_packages: []

 post_build_cmds:
@@ -7,7 +7,7 @@ debian_packages:

 python:
 pip_packages:
-- gym[atari]==0.21.0
+- gym[atari]
 - pytest
 - tensorflow
 conda_packages: []
@@ -15,7 +15,7 @@ python:
 post_build_cmds:
 - 'rm -r wrk || true && git clone https://github.com/wg/wrk.git /tmp/wrk && cd /tmp/wrk && make -j && sudo cp wrk /usr/local/bin'
 - pip3 install pytest || true
-- pip3 install -U ray[all] gym[atari]==0.21.0 autorom[accept-rom-license]
+- pip3 install -U ray[all] gym[atari] autorom[accept-rom-license]
 - pip3 install ray[all]
 # TODO (Alex): Ideally we would install all the dependencies from the new
 # version too, but pip won't be able to find the new version of ray-cpp.
@@ -7,7 +7,7 @@ debian_packages:

 python:
 pip_packages:
-- gym[atari]==0.21.0
+- gym[atari]
 - pygame
 - pytest
 - tensorflow
@@ -18,7 +18,7 @@ post_build_cmds:
 - 'rm -r wrk || true && git clone https://github.com/wg/wrk.git /tmp/wrk && cd /tmp/wrk && make -j && sudo cp wrk /usr/local/bin'
 - pip3 install numpy==1.19 || true
 - pip3 install pytest || true
-- pip3 install -U ray[all] gym[atari]==0.21.0 autorom[accept-rom-license]
+- pip3 install -U ray[all] gym[atari] autorom[accept-rom-license]
 - pip3 install ray[all]
 # TODO (Alex): Ideally we would install all the dependencies from the new
 # version too, but pip won't be able to find the new version of ray-cpp.
@@ -8,7 +8,7 @@ python:
 # These dependencies should be handled by requirements_rllib.txt and
 # requirements_ml_docker.txt
 pip_packages:
-- gym==0.21.0
+- gym
 conda_packages: []

 post_build_cmds:
@@ -8,8 +8,8 @@ python:
 - pytest
 - awscli
 - gsutil
+- gym
 - gcsfs
-- gym==0.21.0
 - pyarrow>=6.0.1,<7.0.0
 conda_packages: []

@@ -8,8 +8,8 @@ python:
 - pytest
 - awscli
 - gsutil
+- gym
 - gcsfs
-- gym==0.21.0
 - pyarrow>=6.0.1,<7.0.0
 conda_packages: []

rllib/BUILD (20 changed lines)
@@ -1318,12 +1318,6 @@ sh_test(
 data = glob(["examples/serving/*.py"]),
 )

-py_test(
-name = "env/tests/test_record_env_wrapper",
-tags = ["team:ml", "env"],
-size = "small",
-srcs = ["env/tests/test_record_env_wrapper.py"]
-)

 py_test(
 name = "env/tests/test_remote_worker_envs",
@@ -2818,13 +2812,13 @@ py_test(
 args = ["--as-test", "--framework=torch"],
 )

-py_test(
-name = "examples/remote_base_env_with_custom_api",
-tags = ["team:ml", "examples", "examples_R"],
-size = "medium",
-srcs = ["examples/remote_base_env_with_custom_api.py"],
-args = ["--stop-iters=3"]
-)
+# py_test(
+# name = "examples/remote_base_env_with_custom_api",
+# tags = ["team:ml", "examples", "examples_R"],
+# size = "medium",
+# srcs = ["examples/remote_base_env_with_custom_api.py"],
+# args = ["--stop-iters=3"]
+# )

 py_test(
 name = "examples/restore_1_of_n_agents_from_checkpoint",
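The dedicated test_record_env_wrapper target is removed, while env/tests/test_remote_worker_envs stays registered. Assuming a Bazel checkout of the repo, that remaining target can still be run the usual way:

    bazel test //rllib:env/tests/test_remote_worker_envs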
@@ -75,10 +75,11 @@ class _MockTrainer(Trainer):
 self.info = info
 self.restored = True

+@staticmethod
 @override(Trainer)
-def _register_if_needed(self, env_object, config):
+def _get_env_id_and_creator(env_specifier, config):
 # No env to register.
-pass
+return None, None

 def set_info(self, info):
 self.info = info
@@ -8,7 +8,9 @@ import logging
 import math
 import numpy as np
 import os
+from packaging import version
 import pickle
+import pkg_resources
 import tempfile
 import time
 from typing import (
@@ -30,7 +32,6 @@ from ray.exceptions import RayError
 from ray.rllib.agents.callbacks import DefaultCallbacks
 from ray.rllib.agents.trainer_config import TrainerConfig
 from ray.rllib.env.env_context import EnvContext
-from ray.rllib.env.multi_agent_env import MultiAgentEnv
 from ray.rllib.env.utils import gym_env_creator
 from ray.rllib.evaluation.episode import Episode
 from ray.rllib.evaluation.metrics import (
@@ -94,7 +95,7 @@ from ray.rllib.utils.typing import (
 TrainerConfigDict,
 )
 from ray.tune.logger import Logger, UnifiedLogger
-from ray.tune.registry import ENV_CREATOR, register_env, _global_registry
+from ray.tune.registry import ENV_CREATOR, _global_registry
 from ray.tune.resources import Resources
 from ray.tune.result import DEFAULT_RESULTS_DIR
 from ray.tune.trainable import Trainable
@@ -220,19 +221,13 @@ class Trainer(Trainable):
 if isinstance(config, TrainerConfig):
 config = config.to_dict()

-# Convert `env` provided in config into a string:
-# - If `env` is a string: `self._env_id` = `env`.
-# - If `env` is a class: `self._env_id` = `env.__name__` -> Already
-# register it with a auto-generated env creator.
-# - If `env` is None: `self._env_id` is None.
-self._env_id: Optional[str] = self._register_if_needed(
+# Convert `env` provided in config into a concrete env creator callable, which
+# takes an EnvContext (config dict) as arg and returning an RLlib supported Env
+# type (e.g. a gym.Env).
+self._env_id, self.env_creator = self._get_env_id_and_creator(
 env or config.get("env"), config
 )

-# The env creator callable, taking an EnvContext (config dict)
-# as arg and returning an RLlib supported Env type (e.g. a gym.Env).
-self.env_creator: Optional[EnvCreator] = None
-
 # Placeholder for a local replay buffer instance.
 self.local_replay_buffer = None

@@ -310,10 +305,6 @@ class Trainer(Trainable):
 # Validate the framework settings in config.
 self.validate_framework(self.config)

-# Setup the self.env_creator callable (to be passed
-# e.g. to RolloutWorkers' c'tors).
-self.env_creator = self._get_env_creator_from_env_id(self._env_id)
-
 # Set Trainer's seed after we have - if necessary - enabled
 # tf eager-execution.
 update_global_seed_if_necessary(self.config["framework"], self.config["seed"])
@@ -466,8 +457,9 @@ class Trainer(Trainable):

 self.config["evaluation_config"] = eval_config

-env_id = self._register_if_needed(eval_config.get("env"), eval_config)
-env_creator = self._get_env_creator_from_env_id(env_id)
+env_id, env_creator = self._get_env_id_and_creator(
+eval_config.get("env"), eval_config
+)

 # Create a separate evaluation worker set for evaluation.
 # If evaluation_num_workers=0, use the evaluation set's local
@@ -1541,37 +1533,87 @@ class Trainer(Trainable):
 """Pre-evaluation callback."""
 pass

-def _get_env_creator_from_env_id(self, env_id: Optional[str] = None) -> EnvCreator:
-"""Returns an env creator callable, given an `env_id` (e.g. "CartPole-v0").
+@staticmethod
+def _get_env_id_and_creator(
+env_specifier: Union[str, EnvType, None], config: PartialTrainerConfigDict
+) -> Tuple[Optional[str], EnvCreator]:
+"""Returns env_id and creator callable given original env id from config.

 Args:
-env_id: An already tune registered env ID, a known gym env name,
-or None (if no env is used).
+env_specifier: An env class, an already tune registered env ID, a known
+gym env name, or None (if no env is used).
+config: The Trainer's (maybe partial) config dict.

 Returns:
+Tuple consisting of a) env ID string and b) env creator callable.
 """
-if env_id:
+# Environment is specified via a string.
+if isinstance(env_specifier, str):
 # An already registered env.
-if _global_registry.contains(ENV_CREATOR, env_id):
-return _global_registry.get(ENV_CREATOR, env_id)
+if _global_registry.contains(ENV_CREATOR, env_specifier):
+return env_specifier, _global_registry.get(ENV_CREATOR, env_specifier)

 # A class path specifier.
-elif "." in env_id:
+elif "." in env_specifier:

 def env_creator_from_classpath(env_context):
 try:
-env_obj = from_config(env_id, env_context)
+env_obj = from_config(env_specifier, env_context)
 except ValueError:
-raise EnvError(ERR_MSG_INVALID_ENV_DESCRIPTOR.format(env_id))
+raise EnvError(
+ERR_MSG_INVALID_ENV_DESCRIPTOR.format(env_specifier)
+)
 return env_obj

-return env_creator_from_classpath
+return env_specifier, env_creator_from_classpath
 # Try gym/PyBullet/Vizdoom.
 else:
-return functools.partial(gym_env_creator, env_descriptor=env_id)
-# No env -> Env creator always returns None.
+return env_specifier, functools.partial(
+gym_env_creator, env_descriptor=env_specifier
+)
+
+elif isinstance(env_specifier, type):
+env_id = env_specifier  # .__name__
+
+if config.get("remote_worker_envs"):
+# Check gym version (0.22 or higher?).
+# If > 0.21, can't perform auto-wrapping of the given class as this
+# would lead to a pickle error.
+gym_version = pkg_resources.get_distribution("gym").version
+if version.parse(gym_version) >= version.parse("0.22"):
+raise ValueError(
+"Cannot specify a gym.Env class via `config.env` while setting "
+"`config.remote_worker_env=True` AND your gym version is >= "
+"0.22! Try installing an older version of gym or set `config."
+"remote_worker_env=False`."
+)
+
+@ray.remote(num_cpus=1)
+class _wrapper(env_specifier):
+# Add convenience `_get_spaces` and `_is_multi_agent`
+# methods:
+def _get_spaces(self):
+return self.observation_space, self.action_space
+
+def _is_multi_agent(self):
+from ray.rllib.env.multi_agent_env import MultiAgentEnv
+
+return isinstance(self, MultiAgentEnv)
+
+return env_id, lambda cfg: _wrapper.remote(cfg)
 else:
-return lambda env_config: None
+return env_id, lambda cfg: env_specifier(cfg)
+
+# No env -> Env creator always returns None.
+elif env_specifier is None:
+return None, lambda env_config: None
+
+else:
+raise ValueError(
+"{} is an invalid env specifier. ".format(env_specifier)
++ "You can specify a custom env as either a class "
+'(e.g., YourEnvCls) or a registered env id (e.g., "your_env").'
+)

 def _sync_filters_if_needed(self, workers: WorkerSet):
 if self.config.get("observation_filter", "NoFilter") != "NoFilter":
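Because the new helper is a @staticmethod, it can be exercised without a full Trainer. A rough sketch of the plain-gym-id branch (env id and empty config chosen purely for illustration):

    env_id, creator = Trainer._get_env_id_and_creator("CartPole-v1", {})
    # `creator` wraps gym_env_creator; calling it with an EnvContext builds the env.
    env = creator(EnvContext({}, worker_index=0))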
@@ -1753,16 +1795,6 @@ class Trainer(Trainable):
 if model_config is None:
 config["model"] = model_config = {}

-# Monitor should be replaced by `record_env`.
-if config.get("monitor", DEPRECATED_VALUE) != DEPRECATED_VALUE:
-deprecation_warning("monitor", "record_env", error=False)
-config["record_env"] = config.get("monitor", False)
-# Empty string would fail some if-blocks checking for this setting.
-# Set to True instead, meaning: use default output dir to store
-# the videos.
-if config.get("record_env") == "":
-config["record_env"] = True
-
 # Use DefaultCallbacks class, if callbacks is None.
 if config["callbacks"] is None:
 config["callbacks"] = DefaultCallbacks
@@ -2149,38 +2181,6 @@ class Trainer(Trainable):
 kwargs["local_replay_buffer"] = self.local_replay_buffer
 return kwargs

-def _register_if_needed(
-self, env_object: Union[str, EnvType, None], config
-) -> Optional[str]:
-if isinstance(env_object, str):
-return env_object
-elif isinstance(env_object, type):
-name = env_object.__name__
-
-if config.get("remote_worker_envs"):
-
-@ray.remote(num_cpus=0)
-class _wrapper(env_object):
-# Add convenience `_get_spaces` and `_is_multi_agent`
-# methods.
-def _get_spaces(self):
-return self.observation_space, self.action_space
-
-def _is_multi_agent(self):
-return isinstance(self, MultiAgentEnv)
-
-register_env(name, lambda cfg: _wrapper.remote(cfg))
-else:
-register_env(name, lambda cfg: env_object(cfg))
-return name
-elif env_object is None:
-return None
-raise ValueError(
-"{} is an invalid env specification. ".format(env_object)
-+ "You can specify a custom env as either a class "
-'(e.g., YourEnvCls) or a registered env id (e.g., "your_env").'
-)
-
 def _step_context(trainer):
 class StepCtx:
 def __enter__(self):
@@ -108,7 +108,6 @@ class TrainerConfig:
 self.action_space = None
 self.env_task_fn = None
 self.render_env = False
-self.record_env = False
 self.clip_rewards = None
 self.normalize_actions = True
 self.clip_actions = False
@@ -458,7 +457,6 @@ class TrainerConfig:
 action_space: Optional[gym.spaces.Space] = None,
 env_task_fn: Optional[Callable[[ResultDict, EnvType, EnvContext], Any]] = None,
 render_env: Optional[bool] = None,
-record_env: Optional[bool] = None,
 clip_rewards: Optional[Union[bool, float]] = None,
 normalize_actions: Optional[bool] = None,
 clip_actions: Optional[bool] = None,
@@ -489,11 +487,6 @@ class TrainerConfig:
 `render()` method which either:
 a) handles window generation and rendering itself (returning True) or
 b) returns a numpy uint8 image of shape [height x width x 3 (RGB)].
-record_env: If True, stores videos in this relative directory inside the
-default output dir (~/ray_results/...). Alternatively, you can
-specify an absolute path (str), in which the env recordings should be
-stored instead. Set to False for not recording anything.
-Note: This setting replaces the deprecated `monitor` key.
 clip_rewards: Whether to clip rewards during Policy's postprocessing.
 None (default): Clip for Atari only (r=sign(r)).
 True: r=sign(r): Fixed rewards -1.0, 1.0, or 0.0.
@@ -524,8 +517,6 @@ class TrainerConfig:
 self.env_task_fn = env_task_fn
 if render_env is not None:
 self.render_env = render_env
-if record_env is not None:
-self.record_env = record_env
 if clip_rewards is not None:
 self.clip_rewards = clip_rewards
 if normalize_actions is not None:
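With record_env dropped from the builder, only the remaining environment options are settable. A rough sketch of the post-change call, assuming this signature belongs to TrainerConfig.environment() (values illustrative):

    config = TrainerConfig().environment(
        env="CartPole-v1", render_env=False, clip_rewards=None
    )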
@@ -168,7 +168,7 @@ class AlphaZeroPolicyWrapperClass(AlphaZeroPolicy):
 model = ModelCatalog.get_model_v2(
 obs_space, action_space, action_space.n, config["model"], "torch"
 )
-env_creator = Trainer._get_env_creator_from_env_id(None, config["env"])
+_, env_creator = Trainer._get_env_id_and_creator(config["env"], config)
 if config["ranked_rewards"]["enable"]:
 # if r2 is enabled, tne env is wrapped to include a rewards buffer
 # used to normalize rewards
@@ -355,8 +355,6 @@ class ARSTrainer(Trainer):
 # Validate our config dict.
 self.validate_config(self.config)

-# Generate `self.env_creator` callable to create an env instance.
-self.env_creator = self._get_env_creator_from_env_id(self._env_id)
 # Generate the local env.
 env_context = EnvContext(self.config["env_config"] or {}, worker_index=0)
 env = self.env_creator(env_context)
@@ -362,8 +362,6 @@ class ESTrainer(Trainer):
 # Call super's validation method.
 self.validate_config(self.config)

-# Generate `self.env_creator` callable to create an env instance.
-self.env_creator = self._get_env_creator_from_env_id(self._env_id)
 # Generate the local env.
 env_context = EnvContext(self.config["env_config"] or {}, worker_index=0)
 env = self.env_creator(env_context)
rllib/env/multi_agent_env.py (2 changed lines, vendored)
@@ -1,6 +1,6 @@
 import gym
 import logging
-from typing import Callable, Dict, List, Tuple, Type, Optional, Union, Set
+from typing import Callable, Dict, List, Tuple, Optional, Union, Set, Type

 from ray.rllib.env.base_env import BaseEnv
 from ray.rllib.utils.annotations import (
rllib/env/tests/test_record_env_wrapper.py (100 changed lines, vendored; file deleted)
@@ -1,100 +0,0 @@
-import glob
-import gym
-import numpy as np
-import os
-import shutil
-import unittest
-
-from ray.rllib.env.utils import VideoMonitor, record_env_wrapper
-from ray.rllib.examples.env.mock_env import MockEnv2
-from ray.rllib.examples.env.multi_agent import BasicMultiAgent
-from ray.rllib.utils.test_utils import check
-
-
-class TestRecordEnvWrapper(unittest.TestCase):
-def test_wrap_gym_env(self):
-record_env_dir = os.popen("mktemp -d").read()[:-1]
-print(f"tmp dir for videos={record_env_dir}")
-
-if not os.path.exists(record_env_dir):
-sys.exit(1)
-
-num_steps_per_episode = 10
-wrapped = record_env_wrapper(
-env=MockEnv2(num_steps_per_episode),
-record_env=record_env_dir,
-log_dir="",
-policy_config={
-"in_evaluation": False,
-},
-)
-# Non MultiAgentEnv: Wrapper's type is wrappers.Monitor.
-self.assertTrue(isinstance(wrapped, gym.wrappers.Monitor))
-self.assertFalse(isinstance(wrapped, VideoMonitor))
-
-wrapped.reset()
-# Expect one video file to have been produced in the tmp dir.
-os.chdir(record_env_dir)
-ls = glob.glob("*.mp4")
-self.assertTrue(len(ls) == 1)
-# 10 steps for a complete episode.
-for i in range(num_steps_per_episode):
-wrapped.step(0)
-# Another episode.
-wrapped.reset()
-for i in range(num_steps_per_episode):
-wrapped.step(0)
-# Expect another video file to have been produced (2nd episode).
-ls = glob.glob("*.mp4")
-self.assertTrue(len(ls) == 2)
-
-# MockEnv2 returns a reward of 100.0 every step.
-# So total reward is 1000.0 per episode (10 steps).
-check(
-np.array([100.0, 100.0]) * num_steps_per_episode,
-wrapped.get_episode_rewards(),
-)
-# Erase all generated files and the temp path just in case,
-# as to not disturb further CI-tests.
-shutil.rmtree(record_env_dir)
-
-def test_wrap_multi_agent_env(self):
-record_env_dir = os.popen("mktemp -d").read()[:-1]
-print(f"tmp dir for videos={record_env_dir}")
-
-if not os.path.exists(record_env_dir):
-sys.exit(1)
-
-wrapped = record_env_wrapper(
-env=BasicMultiAgent(3),
-record_env=record_env_dir,
-log_dir="",
-policy_config={
-"in_evaluation": False,
-},
-)
-# Type is VideoMonitor.
-self.assertTrue(isinstance(wrapped, gym.wrappers.Monitor))
-self.assertTrue(isinstance(wrapped, VideoMonitor))
-
-wrapped.reset()
-
-# BasicMultiAgent is hardcoded to run 25-step episodes.
-for i in range(25):
-wrapped.step({0: 0, 1: 0, 2: 0})
-
-# Expect one video file to have been produced in the tmp dir.
-os.chdir(record_env_dir)
-ls = glob.glob("*.mp4")
-self.assertTrue(len(ls) == 1)
-
-# However VideoMonitor's _after_step is overwritten to not
-# use stats_recorder. So nothing to verify here, except that
-# it runs fine.
-
-
-if __name__ == "__main__":
-import pytest
-import sys
-
-sys.exit(pytest.main(["-v", __file__]))
rllib/env/tests/test_remote_worker_envs.py (33 changed lines, vendored)
@@ -7,10 +7,8 @@ import unittest
 import ray
 from ray.rllib.algorithms.pg import pg
 from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv
-from ray.rllib.examples.env.random_env import RandomEnv, RandomMultiAgentEnv
-from ray.rllib.examples.remote_base_env_with_custom_api import (
-NonVectorizedEnvToBeVectorizedIntoRemoteBaseEnv,
-)
+# from ray.rllib.examples.env.random_env import RandomEnv
 from ray import tune


@@ -55,17 +53,19 @@ class TestRemoteWorkerEnvSetting(unittest.TestCase):
 trainer.stop()

 # Using class directly.
-config["env"] = RandomEnv
-trainer = pg.PGTrainer(config=config)
-print(trainer.train())
-trainer.stop()
+# This doesn't work anymore as of gym==0.23
+# config["env"] = RandomEnv
+# trainer = pg.PGTrainer(config=config)
+# print(trainer.train())
+# trainer.stop()

 # Using class directly: Sub-class of gym.Env,
 # which implements its own API.
-config["env"] = NonVectorizedEnvToBeVectorizedIntoRemoteBaseEnv
-trainer = pg.PGTrainer(config=config)
-print(trainer.train())
-trainer.stop()
+# This doesn't work anymore as of gym==0.23
+# config["env"] = NonVectorizedEnvToBeVectorizedIntoRemoteBaseEnv
+# trainer = pg.PGTrainer(config=config)
+# print(trainer.train())
+# trainer.stop()

 def test_remote_worker_env_multi_agent(self):
 config = pg.DEFAULT_CONFIG.copy()
@@ -85,10 +85,11 @@ class TestRemoteWorkerEnvSetting(unittest.TestCase):
 trainer.stop()

 # Using class directly.
-config["env"] = RandomMultiAgentEnv
-trainer = pg.PGTrainer(config=config)
-print(trainer.train())
-trainer.stop()
+# This doesn't work anymore as of gym==0.23.
+# config["env"] = RandomMultiAgentEnv
+# trainer = pg.PGTrainer(config=config)
+# print(trainer.train())
+# trainer.stop()


 if __name__ == "__main__":
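The registered-env cases in this test remain active; outside of Bazel, one straightforward way to run them from a source checkout is:

    pytest -v rllib/env/tests/test_remote_worker_envs.py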
rllib/env/utils.py (49 changed lines, vendored)
@@ -1,10 +1,6 @@
 import gym
-from gym import wrappers
-import os

 from ray.rllib.env.env_context import EnvContext
-from ray.rllib.env.multi_agent_env import MultiAgentEnv
-from ray.rllib.utils import add_mixins
 from ray.rllib.utils.error import ERR_MSG_INVALID_ENV_DESCRIPTOR, EnvError


@@ -54,48 +50,3 @@ def gym_env_creator(env_context: EnvContext, env_descriptor: str) -> gym.Env:
 return gym.make(env_descriptor, **env_context)
 except gym.error.Error:
 raise EnvError(ERR_MSG_INVALID_ENV_DESCRIPTOR.format(env_descriptor))
-
-
-class VideoMonitor(wrappers.Monitor):
-# Same as original method, but doesn't use the StatsRecorder as it will
-# try to add up multi-agent rewards dicts, which throws errors.
-def _after_step(self, observation, reward, done, info):
-if not self.enabled:
-return done
-
-# Use done["__all__"] b/c this is a multi-agent dict.
-if done["__all__"] and self.env_semantics_autoreset:
-# For envs with BlockingReset wrapping VNCEnv, this observation
-# will be the first one of the new episode
-self.reset_video_recorder()
-self.episode_id += 1
-self._flush()
-
-# Record video
-self.video_recorder.capture_frame()
-
-return done
-
-
-def record_env_wrapper(env, record_env, log_dir, policy_config):
-if record_env:
-path_ = record_env if isinstance(record_env, str) else log_dir
-# Relative path: Add logdir here, otherwise, this would
-# not work for non-local workers.
-if not os.path.isabs(path_):
-path_ = os.path.join(log_dir, path_)
-print(f"Setting the path for recording to {path_}")
-wrapper_cls = (
-VideoMonitor if isinstance(env, MultiAgentEnv) else wrappers.Monitor
-)
-if isinstance(env, MultiAgentEnv):
-wrapper_cls = add_mixins(wrapper_cls, [MultiAgentEnv], reversed=True)
-env = wrapper_cls(
-env,
-path_,
-resume=True,
-force=True,
-video_callable=lambda _: True,
-mode="evaluation" if policy_config["in_evaluation"] else "training",
-)
-return env
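The deleted VideoMonitor/record_env_wrapper helpers were built on gym.wrappers.Monitor, which is deprecated in recent gym releases. Not part of this commit, but one possible replacement for recording single-agent episodes under gym >= 0.21 is the RecordVideo wrapper (output folder below is just an example):

    import gym

    env = gym.make("CartPole-v1")
    # Record every episode as an mp4 under the given folder.
    env = gym.wrappers.RecordVideo(
        env, video_folder="/tmp/cartpole_videos", episode_trigger=lambda ep: True
    )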
@@ -4,7 +4,6 @@ import argparse
 import collections
 import copy
 import gym
-from gym import wrappers as gym_wrappers
 import json
 import os
 from pathlib import Path
@@ -339,7 +338,6 @@ def run(args, parser):
 deprecation_warning(old="--no-render", new="--render", error=False)
 args.render = False
 config["render_env"] = args.render
-config["record_env"] = args.video_dir

 ray.init(local_mode=args.local_mode)

@@ -354,12 +352,6 @@ def run(args, parser):
 num_steps = int(args.steps)
 num_episodes = int(args.episodes)

-# Determine the video output directory.
-video_dir = None
-# Allow user to specify a video output path.
-if args.video_dir:
-video_dir = os.path.expanduser(args.video_dir)
-
 # Do the actual rollout.
 with RolloutSaver(
 args.out,
@@ -369,9 +361,7 @@ def run(args, parser):
 target_episodes=num_episodes,
 save_info=args.save_info,
 ) as saver:
-rollout(
-agent, args.env, num_steps, num_episodes, saver, not args.render, video_dir
-)
+rollout(agent, args.env, num_steps, num_episodes, saver, not args.render)
 agent.stop()


@@ -406,7 +396,6 @@ def rollout(
 num_episodes=0,
 saver=None,
 no_render=True,
-video_dir=None,
 ):
 policy_agent_mapping = default_policy_agent_mapping

@@ -473,13 +462,6 @@ def rollout(
 for p, m in policy_map.items()
 }

-# If monitoring has been requested, manually wrap our environment with a
-# gym monitor, which is set to record every episode.
-if video_dir:
-env = gym_wrappers.Monitor(
-env=env, directory=video_dir, video_callable=lambda _: True, force=True
-)
-
 steps = 0
 episodes = 0
 while keep_going(steps, num_steps, episodes, num_episodes):
@@ -27,7 +27,6 @@ from ray.rllib.env.base_env import BaseEnv, convert_to_base_env
 from ray.rllib.env.env_context import EnvContext
 from ray.rllib.env.multi_agent_env import MultiAgentEnv
 from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
-from ray.rllib.env.utils import record_env_wrapper
 from ray.rllib.env.wrappers.atari_wrappers import wrap_deepmind, is_atari
 from ray.rllib.evaluation.sampler import AsyncSampler, SyncSampler
 from ray.rllib.evaluation.metrics import RolloutMetrics
@@ -233,7 +232,6 @@ class RolloutWorker(ParallelIteratorWorker):
 worker_index: int = 0,
 num_workers: int = 0,
 recreated_worker: bool = False,
-record_env: Union[bool, str] = False,
 log_dir: Optional[str] = None,
 log_level: Optional[str] = None,
 callbacks: Type["DefaultCallbacks"] = None,
@@ -253,7 +251,6 @@ class RolloutWorker(ParallelIteratorWorker):
 fake_sampler: bool = False,
 spaces: Optional[Dict[PolicyID, Tuple[Space, Space]]] = None,
 policy=None,
-monitor_path=None,
 disable_env_checking=False,
 ):
 """Initializes a RolloutWorker instance.
@@ -332,10 +329,6 @@ class RolloutWorker(ParallelIteratorWorker):
 `recreate_failed_workers=True` and one of the original workers (or an
 already recreated one) has failed. They don't differ from original
 workers other than the value of this flag (`self.recreated_worker`).
-record_env: Write out episode stats and videos
-using gym.wrappers.Monitor to this directory if specified. If
-True, use the default output dir in ~/ray_results/.... If
-False, do not record anything.
 log_dir: Directory where logs can be placed.
 log_level: Set the root log level on creation.
 callbacks: Custom sub-class of
@@ -374,7 +367,6 @@ class RolloutWorker(ParallelIteratorWorker):
 to (obs_space, action_space)-tuples. This is used in case no
 Env is created on this RolloutWorker.
 policy: Obsoleted arg. Use `policy_spec` instead.
-monitor_path: Obsoleted arg. Use `record_env` instead.
 disable_env_checking: If True, disables the env checking module that
 validates the properties of the passed environment.
 """
@@ -395,10 +387,6 @@ class RolloutWorker(ParallelIteratorWorker):
 for pid, spec in policy_spec.copy().items()
 }

-if monitor_path is not None:
-deprecation_warning("monitor_path", "record_env", error=False)
-record_env = monitor_path
-
 self._original_kwargs: dict = locals().copy()
 del self._original_kwargs["self"]

@@ -490,7 +478,7 @@ class RolloutWorker(ParallelIteratorWorker):
 # 1) Create the env using the user provided env_creator. This may
 # return a gym.Env (incl. MultiAgentEnv), an already vectorized
 # VectorEnv, BaseEnv, ExternalEnv, or an ActorHandle (remote env).
-# 2) Wrap - if applicable - with Atari/recording/rendering wrappers.
+# 2) Wrap - if applicable - with Atari/rendering wrappers.
 # 3) Seed the env, if necessary.
 # 4) Vectorize the existing single env by creating more clones of
 # this env and wrapping it with the RLlib BaseEnv class.
@@ -541,14 +529,12 @@ class RolloutWorker(ParallelIteratorWorker):
 env = wrap_deepmind(
 env, dim=model_config.get("dim"), framestack=use_framestack
 )
-env = record_env_wrapper(env, record_env, log_dir, policy_config)
 return env

-# gym.Env -> Wrap with gym Monitor.
 else:

 def wrap(env):
-return record_env_wrapper(env, record_env, log_dir, policy_config)
+return env

 # Wrap env through the correct wrapper.
 self.env: EnvType = wrap(self.env)
@@ -4,7 +4,6 @@ from gym.spaces import Box, Discrete
 import numpy as np
 import os
 import random
-import tempfile
 import time
 import unittest

@@ -12,7 +11,6 @@ import ray
 from ray.rllib.algorithms.pg import PGTrainer
 from ray.rllib.agents.a3c import A2CTrainer
 from ray.rllib.env.multi_agent_env import MultiAgentEnv
-from ray.rllib.env.utils import VideoMonitor
 from ray.rllib.evaluation.rollout_worker import RolloutWorker
 from ray.rllib.evaluation.metrics import collect_metrics
 from ray.rllib.evaluation.postprocessing import compute_advantages
@@ -376,10 +374,8 @@ class TestRolloutWorker(unittest.TestCase):

 curframe = inspect.currentframe()
 called_from_check = any(
-[
 frame[3] == "check_gym_environments"
 for frame in inspect.getouterframes(curframe, 2)
-]
 )
 # Check, whether the action is immutable.
 if action.flags.writeable and not called_from_check:
@@ -825,15 +821,12 @@ class TestRolloutWorker(unittest.TestCase):
 policy_config={
 "in_evaluation": False,
 },
-record_env=tempfile.gettempdir(),
 )
 # Make sure we can properly sample from the wrapped env.
 ev.sample()
 # Make sure the resulting environment is indeed still an
-# instance of MultiAgentEnv and VideoMonitor.
 self.assertTrue(isinstance(ev.env.unwrapped, MultiAgentEnv))
 self.assertTrue(isinstance(ev.env, gym.Env))
-self.assertTrue(isinstance(ev.env, VideoMonitor))
 ev.stop()

 def test_no_training(self):
@@ -654,7 +654,6 @@ class WorkerSet:
 worker_index=worker_index,
 num_workers=num_workers,
 recreated_worker=recreated_worker,
-record_env=config["record_env"],
 log_dir=self._logdir,
 log_level=config["log_level"],
 callbacks=config["callbacks"],
@@ -1,12 +1,3 @@
-# ---------------
-# IMPORTANT NOTE:
-# ---------------
-# A recent bug in openAI gym prevents RLlib's "record_env" option
-# from recording videos properly. Instead, the produced mp4 files
-# have a size of 1kb and are corrupted.
-# A simple fix for this is described here:
-# https://github.com/openai/gym/issues/1925
-
 import argparse
 import gym
 import numpy as np
@@ -115,13 +106,6 @@ if __name__ == "__main__":
 # Special evaluation config. Keys specified here will override
 # the same keys in the main config, but only for evaluation.
 "evaluation_config": {
-# Store videos in this relative directory here inside
-# the default output dir (~/ray_results/...).
-# Alternatively, you can specify an absolute path.
-# Set to True for using the default output dir (~/ray_results/...).
-# Set to False for not recording anything.
-"record_env": "videos",
-# "record_env": "/Users/xyz/my_videos/",
 # Render the env while evaluating.
 # Note that this will always only render the 1st RolloutWorker's
 # env and only the 1st sub-env in a vectorized env.
@@ -46,6 +46,11 @@ parser.add_argument(
 parser.add_argument(
 "--stop-reward", type=float, default=180.0, help="Reward at which we stop training."
 )
+parser.add_argument(
+"--local-mode",
+action="store_true",
+help="Init Ray in local mode for easier debugging.",
+)


 class NonVectorizedEnvToBeVectorizedIntoRemoteBaseEnv(TaskSettableEnv):
@@ -96,7 +101,7 @@ class TaskSettingCallback(DefaultCallbacks):

 if __name__ == "__main__":
 args = parser.parse_args()
-ray.init(num_cpus=6)
+ray.init(num_cpus=6, local_mode=args.local_mode)

 config = {
 # Specify your custom (single, non-vectorized) env directly as a
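With the new flag, the example can be debugged in a single process; one way to launch it (path relative to the RLlib source tree, flag values illustrative) is:

    python rllib/examples/remote_base_env_with_custom_api.py --local-mode --stop-iters=3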
@@ -6,15 +6,14 @@ import ray
 from ray import tune
 from ray.rllib.agents.registry import get_trainer_class

-from ray.rllib.examples.env.repeat_after_me_env import RepeatAfterMeEnv
+# from ray.rllib.examples.env.repeat_after_me_env import RepeatAfterMeEnv
 from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole

-envs = {"RepeatAfterMeEnv": RepeatAfterMeEnv, "StatelessCartPole": StatelessCartPole}

 config = {
 "name": "RNNSAC_example",
 "local_dir": str(Path(__file__).parent / "example_out"),
-"checkpoint_freq": 1,
+"checkpoint_at_end": True,
 "keep_checkpoints_num": 1,
 "checkpoint_score_attr": "episode_reward_mean",
 "stop": {
@@ -29,11 +28,8 @@ config = {
 "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
 "framework": "torch",
 "num_workers": 4,
-"num_envs_per_worker": 1,
-"num_cpus_per_worker": 1,
-"log_level": "INFO",
-# "env": envs["RepeatAfterMeEnv"],
-"env": envs["StatelessCartPole"],
+# "env": RepeatAfterMeEnv,
+"env": StatelessCartPole,
 "horizon": 1000,
 "gamma": 0.95,
 "batch_mode": "complete_episodes",
@@ -102,7 +98,7 @@ if __name__ == "__main__":
 eps = 0
 ep_reward = 0
 while eps < 10:
-action, state, info_trainer = agent.compute_action(
+action, state, info_trainer = agent.compute_single_action(
 obs,
 state=state,
 prev_action=prev_action,
@@ -115,7 +111,7 @@ if __name__ == "__main__":
 ep_reward += reward
 try:
 env.render()
-except (NotImplementedError, ImportError):
+except Exception:
 pass
 if done:
 eps += 1
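compute_action() is being superseded by compute_single_action() across RLlib; a stripped-down form of the call used above (state and prev_action handling omitted) is simply:

    action = agent.compute_single_action(obs)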
@@ -29,7 +29,6 @@ class TestGymCheckEnv(unittest.TestCase):
 env = Mock(spec=["observation_space"])
 with pytest.raises(AttributeError, match="Env must have action_space."):
 check_gym_environments(env)
-del env

 def test_obs_and_action_spaces_are_gym_spaces(self):
 env = RandomEnv()
@@ -41,7 +40,6 @@ class TestGymCheckEnv(unittest.TestCase):
 env.action_space = "not an action space"
 with pytest.raises(ValueError, match="Action space must be a gym.space"):
 check_env(env)
-del env

 def test_reset(self):
 reset = MagicMock(return_value=5)
@@ -56,7 +54,6 @@ class TestGymCheckEnv(unittest.TestCase):
 env.reset = reset
 with pytest.raises(ValueError, match=error):
 check_env(env)
-del env

 def test_step(self):
 step = MagicMock(return_value=(5, 5, True, {}))
@@ -92,7 +89,6 @@ class TestGymCheckEnv(unittest.TestCase):
 error = "Your step function must return a info that is a dict."
 with pytest.raises(ValueError, match=error):
 check_env(env)
-del env


 class TestCheckMultiAgentEnv(unittest.TestCase):
@@ -104,7 +100,6 @@ class TestCheckMultiAgentEnv(unittest.TestCase):
 env = RandomEnv()
 with pytest.raises(ValueError, match="The passed env is not"):
 check_multiagent_environments(env)
-del env

 def test_check_env_reset_incorrect_error(self):
 reset = MagicMock(return_value=5)
@@ -119,7 +114,6 @@ class TestCheckMultiAgentEnv(unittest.TestCase):
 env.reset = lambda *_: bad_obs
 with pytest.raises(ValueError, match="The observation collected from env"):
 check_env(env)
-del env

 def test_check_incorrect_space_contains_functions_error(self):
 def bad_contains_function(self, x):
@@ -131,7 +125,6 @@ class TestCheckMultiAgentEnv(unittest.TestCase):
 ValueError, match="Your observation_space_contains function has some"
 ):
 check_env(env)
-del env
 env = make_multi_agent("CartPole-v1")({"num_agents": 2})
 bad_action = {0: 2, 1: 2}
 env.action_space_sample = lambda *_: bad_action
@@ -178,7 +171,6 @@ class TestCheckMultiAgentEnv(unittest.TestCase):
 ValueError, match="The action collected from action_space_sample"
 ):
 check_env(env)
-del env
 env = make_multi_agent("CartPole-v1")({"num_agents": 2})
 bad_obs = {
 0: np.array([np.inf, np.inf, np.inf, np.inf]),
@@ -206,7 +198,6 @@ class TestCheckBaseEnv:
 env = RandomEnv()
 with pytest.raises(ValueError, match="The passed env is not"):
 check_base_env(env)
-del env

 def test_check_env_reset_incorrect_error(self):
 reset = MagicMock(return_value=5)
@@ -244,7 +235,6 @@ class TestCheckBaseEnv:
 ValueError, match="The observation collected from try_reset"
 ):
 check_env(env)
-del env

 def test_check_space_contains_functions_errors(self):
 def bad_contains_function(self, x):
@@ -258,14 +248,12 @@ class TestCheckBaseEnv:
 ):
 check_env(env)

-del env
 env = self._make_base_env()
 env.action_space_contains = bad_contains_function
 with pytest.raises(
 ValueError, match="Your action_space_contains function has some error"
 ):
 check_env(env)
-del env

 def test_bad_sample_function(self):
 env = self._make_base_env()
@@ -276,7 +264,6 @@ class TestCheckBaseEnv:
 ):
 check_env(env)

-del env
 env = self._make_base_env()
 bad_obs = {
 0: {