diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index 90f1d2774..467ff8bff 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -20,8 +20,7 @@ from ray.rllib.evaluation.metrics import collect_metrics from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer from ray.rllib.evaluation.worker_set import WorkerSet from ray.rllib.utils import FilterManager, deep_update, merge_dicts -from ray.rllib.utils.framework import check_framework, try_import_tf, \ - TensorStructType +from ray.rllib.utils.framework import try_import_tf, TensorStructType from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning from ray.rllib.utils.from_config import from_config @@ -149,8 +148,7 @@ COMMON_CONFIG = { # tf: TensorFlow # tfe: TensorFlow eager # torch: PyTorch - # auto: "torch" if only PyTorch installed, "tf" otherwise. - "framework": "auto", + "framework": "tf", # Enable tracing in eager mode. This greatly improves performance, but # makes it slightly harder to debug since Python code won't be evaluated # after the initial eager pass. Only possible if framework=tfe. @@ -576,9 +574,7 @@ class Trainer(Trainable): self.config["framework"] = "tfe" self.config.pop("eager") - # Check all dependencies and resolve "auto" framework. - self.config["framework"] = check_framework(self.config["framework"]) - # Notify about eager/tracing support. + # Enable eager/tracing support. if tf and self.config["framework"] == "tfe": if not tf.executing_eagerly(): tf.enable_eager_execution() diff --git a/rllib/models/catalog.py b/rllib/models/catalog.py index d2ad42bc8..5e2135f4c 100644 --- a/rllib/models/catalog.py +++ b/rllib/models/catalog.py @@ -23,7 +23,7 @@ from ray.rllib.utils import try_import_tree from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE from ray.rllib.utils.error import UnsupportedSpaceException -from ray.rllib.utils.framework import check_framework, try_import_tf +from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.spaces.simplex import Simplex from ray.rllib.utils.spaces.space_utils import flatten_space @@ -135,9 +135,6 @@ class ModelCatalog: dist_dim (int): The size of the input vector to the distribution. """ - # Make sure, framework is ok. - framework = check_framework(framework) - dist = None config = config or MODEL_DEFAULTS # Custom distribution given. @@ -288,9 +285,6 @@ class ModelCatalog: model (ModelV2): Model to use for the policy. """ - # Make sure, framework is ok. - framework = check_framework(framework) - if model_config.get("custom_model"): if "custom_options" in model_config and \ @@ -589,9 +583,6 @@ class ModelCatalog: @staticmethod def _get_v2_model_class(obs_space, model_config, framework="tf"): - # Make sure, framework is ok. - framework = check_framework(framework) - if framework == "torch": from ray.rllib.models.torch.fcnet import (FullyConnectedNetwork as FCNet) diff --git a/rllib/utils/__init__.py b/rllib/utils/__init__.py index 74c090350..4030f5c46 100644 --- a/rllib/utils/__init__.py +++ b/rllib/utils/__init__.py @@ -2,7 +2,7 @@ from functools import partial from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI from ray.rllib.utils.framework import try_import_tf, try_import_tfp, \ - try_import_torch, check_framework + try_import_torch from ray.rllib.utils.deprecation import deprecation_warning, renamed_agent, \ renamed_class, renamed_function from ray.rllib.utils.filter_manager import FilterManager @@ -72,7 +72,6 @@ __all__ = [ "add_mixins", "check", "check_compute_action", - "check_framework", "deprecation_warning", "fc", "force_list", diff --git a/rllib/utils/exploration/exploration.py b/rllib/utils/exploration/exploration.py index ca650e82e..921c5aebc 100644 --- a/rllib/utils/exploration/exploration.py +++ b/rllib/utils/exploration/exploration.py @@ -1,8 +1,7 @@ from gym.spaces import Space from typing import Union -from ray.rllib.utils.framework import check_framework, try_import_torch, \ - TensorType +from ray.rllib.utils.framework import try_import_torch, TensorType from ray.rllib.models.action_dist import ActionDistribution from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.utils.annotations import DeveloperAPI @@ -36,7 +35,7 @@ class Exploration: self.model = model self.num_workers = num_workers self.worker_index = worker_index - self.framework = check_framework(framework) + self.framework = framework # The device on which the Model has been placed. # This Exploration will be on the same device. self.device = None diff --git a/rllib/utils/framework.py b/rllib/utils/framework.py index 27770962c..c0434126c 100644 --- a/rllib/utils/framework.py +++ b/rllib/utils/framework.py @@ -3,8 +3,6 @@ import os import sys from typing import Any, Union -from ray.util import log_once - logger = logging.getLogger(__name__) # Represents a generic tensor type. @@ -14,80 +12,6 @@ TensorType = Any TensorStructType = Union[TensorType, dict, tuple] -def get_auto_framework(): - """Returns the framework (str) when framework="auto" in the config. - - If only PyTorch is installed, returns "torch", if only tf is installed, - returns "tf", if both are installed, raises an error. - """ - - # PyTorch is installed. - if torch is not None: - # TF is not installed -> return torch. - if tf is None: - if log_once("get_auto_framework"): - logger.info( - "`framework=auto` found in config -> Detected PyTorch.") - return "torch" - # TF is also installed -> raise error. - else: - raise ValueError( - "framework='auto' (default value) is not allowed if both " - "TensorFlow AND PyTorch are installed! " - "Instead, use framework='tf|tfe|torch' explicitly.") - # PyTorch nor TF installed -> raise error. - if not tf: - raise ValueError( - "Neither TensorFlow nor PyTorch are installed! You must install " - "one of them by running either `pip install tensorflow` OR " - "`pip install torch torchvision`") - # Only TensorFlow installed -> return tf. - if log_once("get_auto_framework"): - logger.info("`framework=auto` found in config -> Detected TensorFlow.") - return "tf" - - -def check_framework(framework, allow_none=True): - """Checks, whether the given framework is "valid". - - Meaning, whether all necessary dependencies are installed. - - Args: - framework (str): Once of "tf", "torch", or None. - allow_none (bool): Whether framework=None (e.g. numpy implementatiopn) - is allowed or not. - - Returns: - str: The input framework string. - - Raises: - ImportError: If given framework is not installed. - """ - # Resolve auto framework first. - if framework == "auto": - framework = get_auto_framework() - - # Check, whether tf is installed. - if framework in ["tf", "tfe"]: - if tf is None: - raise ImportError( - "Could not import `tensorflow`. Try `pip install tensorflow`") - # Check, whether torch is installed. - elif framework == "torch": - if torch is None: - raise ImportError("Could not import `torch`. " - "Try `pip install torch torchvision`") - # Framework is None (use numpy version of the component). - elif framework is None: - if not allow_none: - raise ValueError("framework=None not allowed!") - # Invalid value. - else: - raise ValueError("Invalid framework='{}'. Use one of " - "[tf|tfe|torch|auto].".format(framework)) - return framework - - def try_import_tf(error=False): """Tries importing tf and returns the module (or None). diff --git a/rllib/utils/schedules/schedule.py b/rllib/utils/schedules/schedule.py index a811a4113..52a3205c5 100644 --- a/rllib/utils/schedules/schedule.py +++ b/rllib/utils/schedules/schedule.py @@ -1,7 +1,7 @@ from abc import ABCMeta, abstractmethod from ray.rllib.utils.annotations import DeveloperAPI -from ray.rllib.utils.framework import check_framework, try_import_tf +from ray.rllib.utils.framework import try_import_tf tf = try_import_tf() @@ -24,7 +24,7 @@ class Schedule(metaclass=ABCMeta): """ def __init__(self, framework): - self.framework = check_framework(framework) + self.framework = framework def value(self, t): """Generates the value given a timestep (based on schedule's logic). diff --git a/rllib/utils/test_utils.py b/rllib/utils/test_utils.py index c88437098..c92a6c809 100644 --- a/rllib/utils/test_utils.py +++ b/rllib/utils/test_utils.py @@ -1,8 +1,7 @@ import logging import numpy as np -from ray.rllib.utils.framework import try_import_tf, try_import_torch, \ - get_auto_framework +from ray.rllib.utils.framework import try_import_tf, try_import_torch tf = try_import_tf() if tf: @@ -29,7 +28,7 @@ def framework_iterator(config=None, config (Optional[dict]): An optional config dict to alter in place depending on the iteration. frameworks (Tuple[str]): A list/tuple of the frameworks to be tested. - Allowed are: "tf", "tfe", "torch", and "auto". + Allowed are: "tf", "tfe", and "torch". session (bool): If True, enter a tf.Session() and yield that as well in the tf-case (otherwise, yield (fw, None)). @@ -43,9 +42,6 @@ def framework_iterator(config=None, frameworks = [frameworks] if isinstance(frameworks, str) else frameworks for fw in frameworks: - if fw == "auto": - fw = get_auto_framework() - # Skip non-installed frameworks. if fw == "torch" and not torch: logger.warning(