mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
Add error messages for missing tf and torch imports (#20205)
Co-authored-by: Sven Mika <sven@anyscale.io> Co-authored-by: sven1977 <svenmika1977@gmail.com>
This commit is contained in:
parent
5fccad4cc9
commit
dc17f0a241
3 changed files with 118 additions and 29 deletions
|
@ -1487,6 +1487,13 @@ py_test(
|
|||
srcs = ["tests/test_nested_observation_spaces.py"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "tests/test_nn_framework_import_errors",
|
||||
tags = ["team:ml", "tests_dir", "tests_dir_N"],
|
||||
size = "small",
|
||||
srcs = ["tests/test_nn_framework_import_errors.py"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "tests/test_pettingzoo_env",
|
||||
tags = ["team:ml", "tests_dir", "tests_dir_P"],
|
||||
|
|
|
@ -38,7 +38,7 @@ from ray.rllib.utils.debug import update_global_seed_if_necessary
|
|||
from ray.rllib.utils.deprecation import Deprecated, deprecation_warning, \
|
||||
DEPRECATED_VALUE
|
||||
from ray.rllib.utils.error import EnvError, ERR_MSG_INVALID_ENV_DESCRIPTOR
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
from ray.rllib.utils.from_config import from_config
|
||||
from ray.rllib.utils.multi_agent import check_multi_agent
|
||||
from ray.rllib.utils.spaces import space_utils
|
||||
|
@ -669,6 +669,9 @@ class Trainer(Trainable):
|
|||
self.config = self.merge_trainer_configs(
|
||||
self.get_default_config(), config, self._allow_unknown_configs)
|
||||
|
||||
# Validate the framework settings in config.
|
||||
self.validate_framework(self.config)
|
||||
|
||||
# Setup the "env creator" callable.
|
||||
env = self._env_id
|
||||
if env:
|
||||
|
@ -698,34 +701,6 @@ class Trainer(Trainable):
|
|||
else:
|
||||
self.env_creator = lambda env_config: None
|
||||
|
||||
# Check and resolve DL framework settings.
|
||||
# Tf-eager (tf2|tfe), possibly with tracing set to True. Recommend
|
||||
# setting tracing to True for speedups.
|
||||
if tf1 and self.config["framework"] in ["tf2", "tfe"]:
|
||||
if self.config["framework"] == "tf2" and tfv < 2:
|
||||
raise ValueError(
|
||||
"You configured `framework`=tf2, but your installed pip "
|
||||
"tf-version is < 2.0! Make sure your TensorFlow version "
|
||||
"is >= 2.x.")
|
||||
if not tf1.executing_eagerly():
|
||||
tf1.enable_eager_execution()
|
||||
logger.info(
|
||||
f"Executing eagerly (framework='{self.config['framework']}'),"
|
||||
f" with eager_tracing={self.config['eager_tracing']}. For "
|
||||
"production workloads, make sure to set `eager_tracing=True` "
|
||||
"in order to match the speed of tf-static-graph "
|
||||
"(framework='tf'). For debugging purposes, "
|
||||
"`eager_tracing=False` is the best choice.")
|
||||
# Tf-static-graph (framework=tf): Recommend upgrading to tf2 and
|
||||
# enabling eager tracing for similar speed.
|
||||
elif tf1 and self.config["framework"] == "tf":
|
||||
logger.info(
|
||||
"Your framework setting is 'tf', meaning you are using static"
|
||||
"-graph mode. Set framework='tf2' to enable eager execution "
|
||||
"with tf2.x. You may also want to then set "
|
||||
"`eager_tracing=True` in order to reach similar execution "
|
||||
"speed as with static-graph mode.")
|
||||
|
||||
# Set Trainer's seed after we have - if necessary - enabled
|
||||
# tf eager-execution.
|
||||
update_global_seed_if_necessary(
|
||||
|
@ -1802,6 +1777,75 @@ class Trainer(Trainable):
|
|||
cls._allow_unknown_subkeys,
|
||||
cls._override_all_subkeys_if_type_changes)
|
||||
|
||||
@staticmethod
|
||||
def validate_framework(config: PartialTrainerConfigDict) -> None:
|
||||
"""Validates the config dictionary wrt the framework settings.
|
||||
|
||||
Args:
|
||||
config: The config dictionary to be validated.
|
||||
|
||||
"""
|
||||
_tf1, _tf, _tfv = None, None, None
|
||||
_torch = None
|
||||
framework = config["framework"]
|
||||
tf_valid_frameworks = {"tf", "tf2", "tfe"}
|
||||
if framework not in tf_valid_frameworks and framework != "torch":
|
||||
return
|
||||
elif framework in tf_valid_frameworks:
|
||||
_tf1, _tf, _tfv = try_import_tf()
|
||||
else:
|
||||
_torch, _ = try_import_torch()
|
||||
|
||||
def check_if_correct_nn_framework_installed():
|
||||
"""Check if tf/torch experiment is running and tf/torch installed.
|
||||
"""
|
||||
if framework in tf_valid_frameworks:
|
||||
if not (_tf1 or _tf):
|
||||
raise ImportError((
|
||||
"TensorFlow was specified as the 'framework' "
|
||||
"inside of your config dictionary. However, there was "
|
||||
"no installation found. You can install TensorFlow "
|
||||
"via `pip install tensorflow`"))
|
||||
elif framework == "torch":
|
||||
if not _torch:
|
||||
raise ImportError(
|
||||
("PyTorch was specified as the 'framework' inside "
|
||||
"of your config dictionary. However, there was no "
|
||||
"installation found. You can install PyTorch via "
|
||||
"`pip install torch`"))
|
||||
|
||||
def resolve_tf_settings():
|
||||
"""Check and resolve tf settings."""
|
||||
|
||||
if _tf1 and config["framework"] in ["tf2", "tfe"]:
|
||||
if config["framework"] == "tf2" and _tfv < 2:
|
||||
raise ValueError(
|
||||
"You configured `framework`=tf2, but your installed "
|
||||
"pip tf-version is < 2.0! Make sure your TensorFlow "
|
||||
"version is >= 2.x.")
|
||||
if not _tf1.executing_eagerly():
|
||||
_tf1.enable_eager_execution()
|
||||
# Recommend setting tracing to True for speedups.
|
||||
logger.info(
|
||||
f"Executing eagerly (framework='{config['framework']}'),"
|
||||
f" with eager_tracing={config['eager_tracing']}. For "
|
||||
"production workloads, make sure to set eager_tracing=True"
|
||||
" in order to match the speed of tf-static-graph "
|
||||
"(framework='tf'). For debugging purposes, "
|
||||
"`eager_tracing=False` is the best choice.")
|
||||
# Tf-static-graph (framework=tf): Recommend upgrading to tf2 and
|
||||
# enabling eager tracing for similar speed.
|
||||
elif _tf1 and config["framework"] == "tf":
|
||||
logger.info(
|
||||
"Your framework setting is 'tf', meaning you are using "
|
||||
"static-graph mode. Set framework='tf2' to enable eager "
|
||||
"execution with tf2.x. You may also then want to set "
|
||||
"eager_tracing=True in order to reach similar execution "
|
||||
"speed as with static-graph mode.")
|
||||
|
||||
check_if_correct_nn_framework_installed()
|
||||
resolve_tf_settings()
|
||||
|
||||
@staticmethod
|
||||
def _validate_config(config: PartialTrainerConfigDict,
|
||||
trainer_obj_or_none: Optional["Trainer"] = None):
|
||||
|
|
38
rllib/tests/test_nn_framework_import_errors.py
Normal file
38
rllib/tests/test_nn_framework_import_errors.py
Normal file
|
@ -0,0 +1,38 @@
|
|||
#!/usr/bin/env python
|
||||
import os
|
||||
import pytest
|
||||
|
||||
import ray.rllib.agents.ppo as ppo
|
||||
from ray.rllib.utils.test_utils import framework_iterator
|
||||
|
||||
|
||||
def test_dont_import_tf_error():
|
||||
"""Check that an error is thrown when tf isn't installed
|
||||
but we try to run a tf experiment.
|
||||
"""
|
||||
# Do not import tf for testing purposes.
|
||||
os.environ["RLLIB_TEST_NO_TF_IMPORT"] = "1"
|
||||
|
||||
config = {}
|
||||
for _ in framework_iterator(config, frameworks=("tf", "tf2", "tfe")):
|
||||
with pytest.raises(
|
||||
ImportError,
|
||||
match="However, there was no installation found."):
|
||||
ppo.PPOTrainer(config, env="CartPole-v1")
|
||||
|
||||
|
||||
def test_dont_import_torch_error():
|
||||
"""Check that an error is thrown when torch isn't installed
|
||||
but we try to run a torch experiment.
|
||||
"""
|
||||
# Do not import tf for testing purposes.
|
||||
os.environ["RLLIB_TEST_NO_TORCH_IMPORT"] = "1"
|
||||
config = {"framework": "torch"}
|
||||
with pytest.raises(
|
||||
ImportError, match="However, there was no installation found."):
|
||||
ppo.PPOTrainer(config, env="CartPole-v1")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_dont_import_tf_error()
|
||||
test_dont_import_torch_error()
|
Loading…
Add table
Reference in a new issue