Add error messages for missing tf and torch imports (#20205)

Co-authored-by: Sven Mika <sven@anyscale.io>
Co-authored-by: sven1977 <svenmika1977@gmail.com>
This commit is contained in:
Avnish Narayan 2021-11-16 16:30:53 -08:00 committed by GitHub
parent 5fccad4cc9
commit dc17f0a241
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 118 additions and 29 deletions

View file

@ -1487,6 +1487,13 @@ py_test(
srcs = ["tests/test_nested_observation_spaces.py"]
)
# Test target for tests/test_nn_framework_import_errors.py (the script listed
# in `srcs`); declared with size "small".
py_test(
name = "tests/test_nn_framework_import_errors",
tags = ["team:ml", "tests_dir", "tests_dir_N"],
size = "small",
srcs = ["tests/test_nn_framework_import_errors.py"]
)
py_test(
name = "tests/test_pettingzoo_env",
tags = ["team:ml", "tests_dir", "tests_dir_P"],

View file

@ -38,7 +38,7 @@ from ray.rllib.utils.debug import update_global_seed_if_necessary
from ray.rllib.utils.deprecation import Deprecated, deprecation_warning, \
DEPRECATED_VALUE
from ray.rllib.utils.error import EnvError, ERR_MSG_INVALID_ENV_DESCRIPTOR
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.multi_agent import check_multi_agent
from ray.rllib.utils.spaces import space_utils
@ -669,6 +669,9 @@ class Trainer(Trainable):
self.config = self.merge_trainer_configs(
self.get_default_config(), config, self._allow_unknown_configs)
# Validate the framework settings in config.
self.validate_framework(self.config)
# Setup the "env creator" callable.
env = self._env_id
if env:
@ -698,34 +701,6 @@ class Trainer(Trainable):
else:
self.env_creator = lambda env_config: None
# Check and resolve DL framework settings.
# Tf-eager (tf2|tfe), possibly with tracing set to True. Recommend
# setting tracing to True for speedups.
if tf1 and self.config["framework"] in ["tf2", "tfe"]:
if self.config["framework"] == "tf2" and tfv < 2:
raise ValueError(
"You configured `framework`=tf2, but your installed pip "
"tf-version is < 2.0! Make sure your TensorFlow version "
"is >= 2.x.")
if not tf1.executing_eagerly():
tf1.enable_eager_execution()
logger.info(
f"Executing eagerly (framework='{self.config['framework']}'),"
f" with eager_tracing={self.config['eager_tracing']}. For "
"production workloads, make sure to set `eager_tracing=True` "
"in order to match the speed of tf-static-graph "
"(framework='tf'). For debugging purposes, "
"`eager_tracing=False` is the best choice.")
# Tf-static-graph (framework=tf): Recommend upgrading to tf2 and
# enabling eager tracing for similar speed.
elif tf1 and self.config["framework"] == "tf":
logger.info(
"Your framework setting is 'tf', meaning you are using static"
"-graph mode. Set framework='tf2' to enable eager execution "
"with tf2.x. You may also want to then set "
"`eager_tracing=True` in order to reach similar execution "
"speed as with static-graph mode.")
# Set Trainer's seed after we have - if necessary - enabled
# tf eager-execution.
update_global_seed_if_necessary(
@ -1802,6 +1777,75 @@ class Trainer(Trainable):
cls._allow_unknown_subkeys,
cls._override_all_subkeys_if_type_changes)
@staticmethod
def validate_framework(config: PartialTrainerConfigDict) -> None:
    """Checks the `framework` entry of a config and applies tf settings.

    Frameworks other than "tf", "tf2", "tfe" and "torch" are ignored. For
    the supported ones, the matching deep learning library is imported via
    `try_import_tf()` / `try_import_torch()`; for eager tf frameworks,
    eager execution is switched on if it is not already active.

    Args:
        config: The config dictionary to check. Its "framework" key is
            always read; "eager_tracing" is read when logging the eager
            execution hint.

    Raises:
        ImportError: If the configured framework's library could not be
            imported.
        ValueError: If `framework` is "tf2" but the reported tf version
            is < 2.
    """
    framework = config["framework"]
    tf_frameworks = {"tf", "tf2", "tfe"}

    # --- PyTorch: only the installation check applies. ---
    if framework == "torch":
        torch_module, _ = try_import_torch()
        if not torch_module:
            raise ImportError(
                "PyTorch was specified as the 'framework' inside "
                "of your config dictionary. However, there was no "
                "installation found. You can install PyTorch via "
                "`pip install torch`")
        return

    # --- Any other non-tf framework: nothing to validate here. ---
    if framework not in tf_frameworks:
        return

    # --- TensorFlow: installation check, then mode-specific handling. ---
    tf1_module, tf_module, tf_version = try_import_tf()
    if not (tf1_module or tf_module):
        raise ImportError(
            "TensorFlow was specified as the 'framework' "
            "inside of your config dictionary. However, there was "
            "no installation found. You can install TensorFlow "
            "via `pip install tensorflow`")

    # Everything below goes through the tf1 handle; without it there is
    # nothing further to resolve.
    if not tf1_module:
        return

    if framework == "tf":
        # Tf-static-graph (framework=tf): Recommend upgrading to tf2 and
        # enabling eager tracing for similar speed.
        logger.info(
            "Your framework setting is 'tf', meaning you are using "
            "static-graph mode. Set framework='tf2' to enable eager "
            "execution with tf2.x. You may also then want to set "
            "eager_tracing=True in order to reach similar execution "
            "speed as with static-graph mode.")
        return

    # Remaining cases are the eager frameworks: "tf2" or "tfe".
    if framework == "tf2" and tf_version < 2:
        raise ValueError(
            "You configured `framework`=tf2, but your installed "
            "pip tf-version is < 2.0! Make sure your TensorFlow "
            "version is >= 2.x.")
    if not tf1_module.executing_eagerly():
        tf1_module.enable_eager_execution()
    # Recommend setting tracing to True for speedups.
    logger.info(
        f"Executing eagerly (framework='{config['framework']}'),"
        f" with eager_tracing={config['eager_tracing']}. For "
        "production workloads, make sure to set eager_tracing=True"
        " in order to match the speed of tf-static-graph "
        "(framework='tf'). For debugging purposes, "
        "`eager_tracing=False` is the best choice.")
@staticmethod
def _validate_config(config: PartialTrainerConfigDict,
trainer_obj_or_none: Optional["Trainer"] = None):

View file

@ -0,0 +1,38 @@
#!/usr/bin/env python
import os
import pytest
import ray.rllib.agents.ppo as ppo
from ray.rllib.utils.test_utils import framework_iterator
def test_dont_import_tf_error():
    """Check that an error is thrown when tf isn't installed
    but we try to run a tf experiment.

    The `RLLIB_TEST_NO_TF_IMPORT` environment variable is set for the
    duration of the test only; its previous state is restored afterwards
    so the flag cannot leak into other tests run in the same process.
    """
    env_var = "RLLIB_TEST_NO_TF_IMPORT"
    previous_value = os.environ.get(env_var)
    # Do not import tf for testing purposes.
    os.environ[env_var] = "1"
    try:
        config = {}
        # Every tf flavor must produce the same "not installed" error.
        for _ in framework_iterator(config, frameworks=("tf", "tf2", "tfe")):
            with pytest.raises(
                    ImportError,
                    match="However, there was no installation found."):
                ppo.PPOTrainer(config, env="CartPole-v1")
    finally:
        # Put the environment back exactly as it was found.
        if previous_value is None:
            os.environ.pop(env_var, None)
        else:
            os.environ[env_var] = previous_value
def test_dont_import_torch_error():
    """Check that an error is thrown when torch isn't installed
    but we try to run a torch experiment.

    The `RLLIB_TEST_NO_TORCH_IMPORT` environment variable is set for the
    duration of the test only; its previous state is restored afterwards
    so the flag cannot leak into other tests run in the same process.
    """
    env_var = "RLLIB_TEST_NO_TORCH_IMPORT"
    previous_value = os.environ.get(env_var)
    # Do not import torch for testing purposes.
    os.environ[env_var] = "1"
    try:
        config = {"framework": "torch"}
        with pytest.raises(
                ImportError,
                match="However, there was no installation found."):
            ppo.PPOTrainer(config, env="CartPole-v1")
    finally:
        # Put the environment back exactly as it was found.
        if previous_value is None:
            os.environ.pop(env_var, None)
        else:
            os.environ[env_var] = previous_value
if __name__ == "__main__":
    # Script entry point: run both import-error tests in declaration order.
    for run_test in (test_dont_import_tf_error,
                     test_dont_import_torch_error):
        run_test()