From dc17f0a241cd53507101881f688e18ae9c3cdc8f Mon Sep 17 00:00:00 2001 From: Avnish Narayan <38871737+avnishn@users.noreply.github.com> Date: Tue, 16 Nov 2021 16:30:53 -0800 Subject: [PATCH] Add error messages for missing tf and torch imports (#20205) Co-authored-by: Sven Mika Co-authored-by: sven1977 --- rllib/BUILD | 7 ++ rllib/agents/trainer.py | 102 +++++++++++++----- .../tests/test_nn_framework_import_errors.py | 38 +++++++ 3 files changed, 118 insertions(+), 29 deletions(-) create mode 100644 rllib/tests/test_nn_framework_import_errors.py diff --git a/rllib/BUILD b/rllib/BUILD index d2a655b02..d6ec7ffee 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -1487,6 +1487,13 @@ py_test( srcs = ["tests/test_nested_observation_spaces.py"] ) +py_test( + name = "tests/test_nn_framework_import_errors", + tags = ["team:ml", "tests_dir", "tests_dir_N"], + size = "small", + srcs = ["tests/test_nn_framework_import_errors.py"] +) + py_test( name = "tests/test_pettingzoo_env", tags = ["team:ml", "tests_dir", "tests_dir_P"], diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index 822e9788e..20c1e07cf 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -38,7 +38,7 @@ from ray.rllib.utils.debug import update_global_seed_if_necessary from ray.rllib.utils.deprecation import Deprecated, deprecation_warning, \ DEPRECATED_VALUE from ray.rllib.utils.error import EnvError, ERR_MSG_INVALID_ENV_DESCRIPTOR -from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.from_config import from_config from ray.rllib.utils.multi_agent import check_multi_agent from ray.rllib.utils.spaces import space_utils @@ -669,6 +669,9 @@ class Trainer(Trainable): self.config = self.merge_trainer_configs( self.get_default_config(), config, self._allow_unknown_configs) + # Validate the framework settings in config. + self.validate_framework(self.config) + # Setup the "env creator" callable. env = self._env_id if env: @@ -698,34 +701,6 @@ class Trainer(Trainable): else: self.env_creator = lambda env_config: None - # Check and resolve DL framework settings. - # Tf-eager (tf2|tfe), possibly with tracing set to True. Recommend - # setting tracing to True for speedups. - if tf1 and self.config["framework"] in ["tf2", "tfe"]: - if self.config["framework"] == "tf2" and tfv < 2: - raise ValueError( - "You configured `framework`=tf2, but your installed pip " - "tf-version is < 2.0! Make sure your TensorFlow version " - "is >= 2.x.") - if not tf1.executing_eagerly(): - tf1.enable_eager_execution() - logger.info( - f"Executing eagerly (framework='{self.config['framework']}')," - f" with eager_tracing={self.config['eager_tracing']}. For " - "production workloads, make sure to set `eager_tracing=True` " - "in order to match the speed of tf-static-graph " - "(framework='tf'). For debugging purposes, " - "`eager_tracing=False` is the best choice.") - # Tf-static-graph (framework=tf): Recommend upgrading to tf2 and - # enabling eager tracing for similar speed. - elif tf1 and self.config["framework"] == "tf": - logger.info( - "Your framework setting is 'tf', meaning you are using static" - "-graph mode. Set framework='tf2' to enable eager execution " - "with tf2.x. You may also want to then set " - "`eager_tracing=True` in order to reach similar execution " - "speed as with static-graph mode.") - # Set Trainer's seed after we have - if necessary - enabled # tf eager-execution. update_global_seed_if_necessary( @@ -1802,6 +1777,75 @@ class Trainer(Trainable): cls._allow_unknown_subkeys, cls._override_all_subkeys_if_type_changes) + @staticmethod + def validate_framework(config: PartialTrainerConfigDict) -> None: + """Validates the config dictionary wrt the framework settings. + + Args: + config: The config dictionary to be validated. + + """ + _tf1, _tf, _tfv = None, None, None + _torch = None + framework = config["framework"] + tf_valid_frameworks = {"tf", "tf2", "tfe"} + if framework not in tf_valid_frameworks and framework != "torch": + return + elif framework in tf_valid_frameworks: + _tf1, _tf, _tfv = try_import_tf() + else: + _torch, _ = try_import_torch() + + def check_if_correct_nn_framework_installed(): + """Check if tf/torch experiment is running and tf/torch installed. + """ + if framework in tf_valid_frameworks: + if not (_tf1 or _tf): + raise ImportError(( + "TensorFlow was specified as the 'framework' " + "inside of your config dictionary. However, there was " + "no installation found. You can install TensorFlow " + "via `pip install tensorflow`")) + elif framework == "torch": + if not _torch: + raise ImportError( + ("PyTorch was specified as the 'framework' inside " + "of your config dictionary. However, there was no " + "installation found. You can install PyTorch via " + "`pip install torch`")) + + def resolve_tf_settings(): + """Check and resolve tf settings.""" + + if _tf1 and config["framework"] in ["tf2", "tfe"]: + if config["framework"] == "tf2" and _tfv < 2: + raise ValueError( + "You configured `framework`=tf2, but your installed " + "pip tf-version is < 2.0! Make sure your TensorFlow " + "version is >= 2.x.") + if not _tf1.executing_eagerly(): + _tf1.enable_eager_execution() + # Recommend setting tracing to True for speedups. + logger.info( + f"Executing eagerly (framework='{config['framework']}')," + f" with eager_tracing={config['eager_tracing']}. For " + "production workloads, make sure to set eager_tracing=True" + " in order to match the speed of tf-static-graph " + "(framework='tf'). For debugging purposes, " + "`eager_tracing=False` is the best choice.") + # Tf-static-graph (framework=tf): Recommend upgrading to tf2 and + # enabling eager tracing for similar speed. + elif _tf1 and config["framework"] == "tf": + logger.info( + "Your framework setting is 'tf', meaning you are using " + "static-graph mode. Set framework='tf2' to enable eager " + "execution with tf2.x. You may also then want to set " + "eager_tracing=True in order to reach similar execution " + "speed as with static-graph mode.") + + check_if_correct_nn_framework_installed() + resolve_tf_settings() + @staticmethod def _validate_config(config: PartialTrainerConfigDict, trainer_obj_or_none: Optional["Trainer"] = None): diff --git a/rllib/tests/test_nn_framework_import_errors.py b/rllib/tests/test_nn_framework_import_errors.py new file mode 100644 index 000000000..9b2978432 --- /dev/null +++ b/rllib/tests/test_nn_framework_import_errors.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python +import os +import pytest + +import ray.rllib.agents.ppo as ppo +from ray.rllib.utils.test_utils import framework_iterator + + +def test_dont_import_tf_error(): + """Check that an error is thrown when tf isn't installed + but we try to run a tf experiment. + """ + # Do not import tf for testing purposes. + os.environ["RLLIB_TEST_NO_TF_IMPORT"] = "1" + + config = {} + for _ in framework_iterator(config, frameworks=("tf", "tf2", "tfe")): + with pytest.raises( + ImportError, + match="However, there was no installation found."): + ppo.PPOTrainer(config, env="CartPole-v1") + + +def test_dont_import_torch_error(): + """Check that an error is thrown when torch isn't installed + but we try to run a torch experiment. + """ + # Do not import tf for testing purposes. + os.environ["RLLIB_TEST_NO_TORCH_IMPORT"] = "1" + config = {"framework": "torch"} + with pytest.raises( + ImportError, match="However, there was no installation found."): + ppo.PPOTrainer(config, env="CartPole-v1") + + +if __name__ == "__main__": + test_dont_import_tf_error() + test_dont_import_torch_error()