mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[RLlib] Better error messages and hints; + failure-mode tests; (#18466)
This commit is contained in:
parent
ead02b21b9
commit
3f89f35e52
7 changed files with 175 additions and 41 deletions
|
@ -29,6 +29,7 @@ from ray.rllib.utils.annotations import Deprecated, DeveloperAPI, override, \
|
||||||
PublicAPI
|
PublicAPI
|
||||||
from ray.rllib.utils.debug import update_global_seed_if_necessary
|
from ray.rllib.utils.debug import update_global_seed_if_necessary
|
||||||
from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE
|
from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE
|
||||||
|
from ray.rllib.utils.error import EnvError, ERR_MSG_INVALID_ENV_DESCRIPTOR
|
||||||
from ray.rllib.utils.framework import try_import_tf, TensorStructType
|
from ray.rllib.utils.framework import try_import_tf, TensorStructType
|
||||||
from ray.rllib.utils.from_config import from_config
|
from ray.rllib.utils.from_config import from_config
|
||||||
from ray.rllib.utils.multi_agent import check_multi_agent
|
from ray.rllib.utils.multi_agent import check_multi_agent
|
||||||
|
@ -698,8 +699,16 @@ class Trainer(Trainable):
|
||||||
self.env_creator = _global_registry.get(ENV_CREATOR, env)
|
self.env_creator = _global_registry.get(ENV_CREATOR, env)
|
||||||
# A class specifier.
|
# A class specifier.
|
||||||
elif "." in env:
|
elif "." in env:
|
||||||
self.env_creator = \
|
|
||||||
lambda env_context: from_config(env, env_context)
|
def env_creator_from_classpath(env_context):
|
||||||
|
try:
|
||||||
|
env_obj = from_config(env, env_context)
|
||||||
|
except ValueError:
|
||||||
|
raise EnvError(
|
||||||
|
ERR_MSG_INVALID_ENV_DESCRIPTOR.format(env))
|
||||||
|
return env_obj
|
||||||
|
|
||||||
|
self.env_creator = env_creator_from_classpath
|
||||||
# Try gym/PyBullet/Vizdoom.
|
# Try gym/PyBullet/Vizdoom.
|
||||||
else:
|
else:
|
||||||
self.env_creator = functools.partial(
|
self.env_creator = functools.partial(
|
||||||
|
|
21
rllib/env/utils.py
vendored
21
rllib/env/utils.py
vendored
|
@ -4,6 +4,7 @@ import os
|
||||||
from ray.rllib.env.env_context import EnvContext
|
from ray.rllib.env.env_context import EnvContext
|
||||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||||
from ray.rllib.utils import add_mixins
|
from ray.rllib.utils import add_mixins
|
||||||
|
from ray.rllib.utils.error import ERR_MSG_INVALID_ENV_DESCRIPTOR, EnvError
|
||||||
|
|
||||||
|
|
||||||
def gym_env_creator(env_context: EnvContext, env_descriptor: str):
|
def gym_env_creator(env_context: EnvContext, env_descriptor: str):
|
||||||
|
@ -50,25 +51,7 @@ def gym_env_creator(env_context: EnvContext, env_descriptor: str):
|
||||||
try:
|
try:
|
||||||
return gym.make(env_descriptor, **env_context)
|
return gym.make(env_descriptor, **env_context)
|
||||||
except gym.error.Error:
|
except gym.error.Error:
|
||||||
error_msg = f"The env string you provided ('{env_descriptor}') is:" + \
|
raise EnvError(ERR_MSG_INVALID_ENV_DESCRIPTOR.format(env_descriptor))
|
||||||
"""
|
|
||||||
a) Not a supported/installed environment.
|
|
||||||
b) Not a tune-registered environment creator.
|
|
||||||
c) Not a valid env class string.
|
|
||||||
|
|
||||||
Try one of the following:
|
|
||||||
a) For Atari support: `pip install gym[atari] atari_py`.
|
|
||||||
For VizDoom support: Install VizDoom
|
|
||||||
(https://github.com/mwydmuch/ViZDoom/blob/master/doc/Building.md) and
|
|
||||||
`pip install vizdoomgym`.
|
|
||||||
For PyBullet support: `pip install pybullet`.
|
|
||||||
b) To register your custom env, do `from ray import tune;
|
|
||||||
tune.register('[name]', lambda cfg: [return env obj from here using cfg])`.
|
|
||||||
Then in your config, do `config['env'] = [name]`.
|
|
||||||
c) Make sure you provide a fully qualified classpath, e.g.:
|
|
||||||
`ray.rllib.examples.env.repeat_after_me_env.RepeatAfterMeEnv`
|
|
||||||
"""
|
|
||||||
raise gym.error.Error(error_msg)
|
|
||||||
|
|
||||||
|
|
||||||
class VideoMonitor(wrappers.Monitor):
|
class VideoMonitor(wrappers.Monitor):
|
||||||
|
|
|
@ -33,7 +33,8 @@ from ray.rllib.utils import force_list, merge_dicts
|
||||||
from ray.rllib.utils.annotations import DeveloperAPI
|
from ray.rllib.utils.annotations import DeveloperAPI
|
||||||
from ray.rllib.utils.debug import summarize, update_global_seed_if_necessary
|
from ray.rllib.utils.debug import summarize, update_global_seed_if_necessary
|
||||||
from ray.rllib.utils.deprecation import deprecation_warning
|
from ray.rllib.utils.deprecation import deprecation_warning
|
||||||
from ray.rllib.utils.error import EnvError
|
from ray.rllib.utils.error import EnvError, ERR_MSG_NO_GPUS, \
|
||||||
|
HOWTO_CHANGE_CONFIG
|
||||||
from ray.rllib.utils.filter import get_filter, Filter
|
from ray.rllib.utils.filter import get_filter, Filter
|
||||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||||
from ray.rllib.utils.sgd import do_minibatch_sgd
|
from ray.rllib.utils.sgd import do_minibatch_sgd
|
||||||
|
@ -556,16 +557,16 @@ class RolloutWorker(ParallelIteratorWorker):
|
||||||
ray.worker._mode() != ray.worker.LOCAL_MODE and \
|
ray.worker._mode() != ray.worker.LOCAL_MODE and \
|
||||||
not policy_config.get("_fake_gpus"):
|
not policy_config.get("_fake_gpus"):
|
||||||
|
|
||||||
|
devices = []
|
||||||
if policy_config.get("framework") in ["tf2", "tf", "tfe"]:
|
if policy_config.get("framework") in ["tf2", "tf", "tfe"]:
|
||||||
if len(get_tf_gpu_devices()) < num_gpus:
|
devices = get_tf_gpu_devices()
|
||||||
raise RuntimeError(
|
|
||||||
f"Not enough GPUs found for num_gpus={num_gpus}! "
|
|
||||||
f"Found only these IDs: {get_tf_gpu_devices()}.")
|
|
||||||
elif policy_config.get("framework") == "torch":
|
elif policy_config.get("framework") == "torch":
|
||||||
if torch.cuda.device_count() < num_gpus:
|
devices = list(range(torch.cuda.device_count()))
|
||||||
raise RuntimeError(
|
|
||||||
f"Not enough GPUs found ({torch.cuda.device_count()}) "
|
if len(devices) < num_gpus:
|
||||||
f"for num_gpus={num_gpus}!")
|
raise RuntimeError(
|
||||||
|
ERR_MSG_NO_GPUS.format(len(devices), devices) +
|
||||||
|
HOWTO_CHANGE_CONFIG)
|
||||||
# Warn, if running in local-mode and actual GPUs (not faked) are
|
# Warn, if running in local-mode and actual GPUs (not faked) are
|
||||||
# requested.
|
# requested.
|
||||||
elif ray.is_initialized() and \
|
elif ray.is_initialized() and \
|
||||||
|
|
|
@ -51,8 +51,7 @@ class TestGPUs(unittest.TestCase):
|
||||||
print("direct RLlib")
|
print("direct RLlib")
|
||||||
self.assertRaisesRegex(
|
self.assertRaisesRegex(
|
||||||
RuntimeError,
|
RuntimeError,
|
||||||
"Not enough GPUs found.+for "
|
"Found 0 GPUs on your machine",
|
||||||
f"num_gpus={num_gpus}",
|
|
||||||
lambda: PGTrainer(config, env="CartPole-v0"),
|
lambda: PGTrainer(config, env="CartPole-v0"),
|
||||||
)
|
)
|
||||||
# If actual_gpus >= num_gpus or faked,
|
# If actual_gpus >= num_gpus or faked,
|
||||||
|
|
|
@ -11,3 +11,50 @@ class UnsupportedSpaceException(Exception):
|
||||||
class EnvError(Exception):
|
class EnvError(Exception):
|
||||||
"""Error if we encounter an error during RL environment validation."""
|
"""Error if we encounter an error during RL environment validation."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# -------
|
||||||
|
# Error messages
|
||||||
|
# -------
|
||||||
|
|
||||||
|
# Message explaining there are no GPUs available for the
|
||||||
|
# num_gpus=n or num_gpus_per_worker=m settings.
|
||||||
|
ERR_MSG_NO_GPUS = \
|
||||||
|
"""Found {} GPUs on your machine (GPU devices found: {})! If your machine
|
||||||
|
does not have any GPUs, you should set the config keys `num_gpus` and
|
||||||
|
`num_gpus_per_worker` to 0 (they may be set to 1 by default for your
|
||||||
|
particular RL algorithm)."""
|
||||||
|
|
||||||
|
ERR_MSG_INVALID_ENV_DESCRIPTOR = \
|
||||||
|
"""The env string you provided ('{}') is:
|
||||||
|
a) Not a supported/installed environment.
|
||||||
|
b) Not a tune-registered environment creator.
|
||||||
|
c) Not a valid env class string.
|
||||||
|
|
||||||
|
Try one of the following:
|
||||||
|
a) For Atari support: `pip install gym[atari] atari_py`.
|
||||||
|
For VizDoom support: Install VizDoom
|
||||||
|
(https://github.com/mwydmuch/ViZDoom/blob/master/doc/Building.md) and
|
||||||
|
`pip install vizdoomgym`.
|
||||||
|
For PyBullet support: `pip install pybullet`.
|
||||||
|
b) To register your custom env, do `from ray import tune;
|
||||||
|
tune.register('[name]', lambda cfg: [return env obj from here using cfg])`.
|
||||||
|
Then in your config, do `config['env'] = [name]`.
|
||||||
|
c) Make sure you provide a fully qualified classpath, e.g.:
|
||||||
|
`ray.rllib.examples.env.repeat_after_me_env.RepeatAfterMeEnv`
|
||||||
|
"""
|
||||||
|
|
||||||
|
# -------
|
||||||
|
# HOWTO_ strings can be added to any error/warning/into message
|
||||||
|
# to eplain to the user, how to actually fix the encountered problem.
|
||||||
|
# -------
|
||||||
|
|
||||||
|
# HOWTO change the RLlib config, depending on how user runs the job.
|
||||||
|
HOWTO_CHANGE_CONFIG = """
|
||||||
|
To change the config for the `rllib train|rollout` command, use
|
||||||
|
`--config={'[key]': '[value]'}` on the command line.
|
||||||
|
To change the config for `tune.run()` in a script: Modify the python dict
|
||||||
|
passed to `tune.run(config=[...])`.
|
||||||
|
To change the config for an RLlib Trainer instance: Modify the python dict
|
||||||
|
passed to the Trainer's constructor, e.g. `PPOTrainer(config=[...])`.
|
||||||
|
"""
|
||||||
|
|
|
@ -10,8 +10,8 @@ from ray.rllib.utils import force_list, merge_dicts
|
||||||
|
|
||||||
|
|
||||||
def from_config(cls, config=None, **kwargs):
|
def from_config(cls, config=None, **kwargs):
|
||||||
"""
|
"""Uses the given config to create an object.
|
||||||
Uses the given config to create an object.
|
|
||||||
If `config` is a dict, an optional "type" key can be used as a
|
If `config` is a dict, an optional "type" key can be used as a
|
||||||
"constructor hint" to specify a certain class of the object.
|
"constructor hint" to specify a certain class of the object.
|
||||||
If `config` is not a dict, `config`'s value is used directly as this
|
If `config` is not a dict, `config`'s value is used directly as this
|
||||||
|
@ -37,7 +37,7 @@ def from_config(cls, config=None, **kwargs):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cls (class): The class to build an instance for (from `config`).
|
cls (class): The class to build an instance for (from `config`).
|
||||||
config (Optional[dict,str]): The config dict or type-string or
|
config (Optional[dict, str]): The config dict or type-string or
|
||||||
filename.
|
filename.
|
||||||
|
|
||||||
Keyword Args:
|
Keyword Args:
|
||||||
|
@ -143,17 +143,27 @@ def from_config(cls, config=None, **kwargs):
|
||||||
else:
|
else:
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
# Test for absolute module.class specifier.
|
# Test for absolute module.class path specifier.
|
||||||
if type_.find(".") != -1:
|
if type_.find(".") != -1:
|
||||||
module_name, function_name = type_.rsplit(".", 1)
|
module_name, function_name = type_.rsplit(".", 1)
|
||||||
try:
|
try:
|
||||||
module = importlib.import_module(module_name)
|
module = importlib.import_module(module_name)
|
||||||
constructor = getattr(module, function_name)
|
constructor = getattr(module, function_name)
|
||||||
except (ModuleNotFoundError, ImportError):
|
# Module not found.
|
||||||
|
except (ModuleNotFoundError, ImportError, AttributeError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# If constructor still not found, try attaching cls' module,
|
# If constructor still not found, try attaching cls' module,
|
||||||
# then look for type_ in there.
|
# then look for type_ in there.
|
||||||
if constructor is None:
|
if constructor is None:
|
||||||
|
if isinstance(cls, str):
|
||||||
|
# Module found, but doesn't have the specified
|
||||||
|
# c'tor/function.
|
||||||
|
raise ValueError(
|
||||||
|
f"Full classpath specifier ({type_}) must be a valid "
|
||||||
|
"full [module].[class] string! E.g.: "
|
||||||
|
"`my.cool.module.MyCoolClass`.")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
module = importlib.import_module(cls.__module__)
|
module = importlib.import_module(cls.__module__)
|
||||||
constructor = getattr(module, type_)
|
constructor = getattr(module, type_)
|
||||||
|
@ -166,12 +176,12 @@ def from_config(cls, config=None, **kwargs):
|
||||||
constructor = getattr(module, type_)
|
constructor = getattr(module, type_)
|
||||||
except (ModuleNotFoundError, ImportError, AttributeError):
|
except (ModuleNotFoundError, ImportError, AttributeError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if constructor is None:
|
if constructor is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"String specifier ({}) in `from_config` must be a "
|
f"String specifier ({type_}) must be a valid filename, "
|
||||||
"filename, a module+class, a class within '{}', or a key "
|
f"a [module].[class], a class within '{cls.__module__}', "
|
||||||
"into {}.__type_registry__!".format(
|
f"or a key into {cls.__name__}.__type_registry__!")
|
||||||
type_, cls.__module__, cls.__name__))
|
|
||||||
|
|
||||||
if not constructor:
|
if not constructor:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
|
|
85
rllib/utils/tests/test_errors.py
Normal file
85
rllib/utils/tests/test_errors.py
Normal file
|
@ -0,0 +1,85 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import ray
|
||||||
|
import ray.rllib.agents.impala as impala
|
||||||
|
import ray.rllib.agents.pg as pg
|
||||||
|
from ray.rllib.utils.error import EnvError
|
||||||
|
from ray.rllib.utils.test_utils import framework_iterator
|
||||||
|
|
||||||
|
|
||||||
|
class TestErrors(unittest.TestCase):
|
||||||
|
"""Tests various failure-modes, making sure we produce meaningful errmsgs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls) -> None:
|
||||||
|
ray.init()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls) -> None:
|
||||||
|
ray.shutdown()
|
||||||
|
|
||||||
|
def test_no_gpus_error(self):
|
||||||
|
"""Tests errors related to no-GPU/too-few GPUs/etc.
|
||||||
|
|
||||||
|
This test will only work ok on a CPU-only machine.
|
||||||
|
"""
|
||||||
|
|
||||||
|
config = impala.DEFAULT_CONFIG.copy()
|
||||||
|
env = "CartPole-v0"
|
||||||
|
|
||||||
|
for _ in framework_iterator(config):
|
||||||
|
self.assertRaisesRegex(
|
||||||
|
RuntimeError,
|
||||||
|
# (?s): "dot matches all" (also newlines).
|
||||||
|
"(?s)Found 0 GPUs on your machine.+To change the config",
|
||||||
|
lambda: impala.ImpalaTrainer(config=config, env=env),
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_bad_envs(self):
|
||||||
|
"""Tests different "bad env" errors.
|
||||||
|
"""
|
||||||
|
config = pg.DEFAULT_CONFIG.copy()
|
||||||
|
config["num_workers"] = 0
|
||||||
|
|
||||||
|
# Non existing/non-registered gym env string.
|
||||||
|
env = "Alien-Attack-v42"
|
||||||
|
for _ in framework_iterator(config):
|
||||||
|
self.assertRaisesRegex(
|
||||||
|
EnvError,
|
||||||
|
f"The env string you provided \\('{env}'\\) is",
|
||||||
|
lambda: pg.PGTrainer(config=config, env=env),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Malformed gym env string (must have v\d at end).
|
||||||
|
env = "Alien-Attack-part-42"
|
||||||
|
for _ in framework_iterator(config):
|
||||||
|
self.assertRaisesRegex(
|
||||||
|
EnvError,
|
||||||
|
f"The env string you provided \\('{env}'\\) is",
|
||||||
|
lambda: pg.PGTrainer(config=config, env=env),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Non-existing class in a full-class-path.
|
||||||
|
env = "ray.rllib.examples.env.random_env.RandomEnvThatDoesntExist"
|
||||||
|
for _ in framework_iterator(config):
|
||||||
|
self.assertRaisesRegex(
|
||||||
|
EnvError,
|
||||||
|
f"The env string you provided \\('{env}'\\) is",
|
||||||
|
lambda: pg.PGTrainer(config=config, env=env),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Non-existing module inside a full-class-path.
|
||||||
|
env = "ray.rllib.examples.env.module_that_doesnt_exist.SomeEnv"
|
||||||
|
for _ in framework_iterator(config):
|
||||||
|
self.assertRaisesRegex(
|
||||||
|
EnvError,
|
||||||
|
f"The env string you provided \\('{env}'\\) is",
|
||||||
|
lambda: pg.PGTrainer(config=config, env=env),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import pytest
|
||||||
|
import sys
|
||||||
|
sys.exit(pytest.main(["-v", __file__]))
|
Loading…
Add table
Reference in a new issue