[rllib] Set framework to tf by default and remove import checks; "Auto" option (#8748)

* tf by default

* Update rllib/agents/trainer.py

Co-authored-by: Sven Mika <sven@anyscale.io>

* remove it

* fix

* remove

* fix

* lint

Co-authored-by: Sven Mika <sven@anyscale.io>
This commit is contained in:
Eric Liang 2020-06-08 23:04:50 -07:00 committed by SangBin Cho
parent f6034fd12e
commit 002e1e7c8d
7 changed files with 11 additions and 106 deletions

View file

@ -20,8 +20,7 @@ from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.utils import FilterManager, deep_update, merge_dicts
from ray.rllib.utils.framework import check_framework, try_import_tf, \
TensorStructType
from ray.rllib.utils.framework import try_import_tf, TensorStructType
from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
from ray.rllib.utils.from_config import from_config
@ -149,8 +148,7 @@ COMMON_CONFIG = {
# tf: TensorFlow
# tfe: TensorFlow eager
# torch: PyTorch
# auto: "torch" if only PyTorch installed, "tf" otherwise.
"framework": "auto",
"framework": "tf",
# Enable tracing in eager mode. This greatly improves performance, but
# makes it slightly harder to debug since Python code won't be evaluated
# after the initial eager pass. Only possible if framework=tfe.
@ -576,9 +574,7 @@ class Trainer(Trainable):
self.config["framework"] = "tfe"
self.config.pop("eager")
# Check all dependencies and resolve "auto" framework.
self.config["framework"] = check_framework(self.config["framework"])
# Notify about eager/tracing support.
# Enable eager/tracing support.
if tf and self.config["framework"] == "tfe":
if not tf.executing_eagerly():
tf.enable_eager_execution()

View file

@ -23,7 +23,7 @@ from ray.rllib.utils import try_import_tree
from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.framework import check_framework, try_import_tf
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.spaces.simplex import Simplex
from ray.rllib.utils.spaces.space_utils import flatten_space
@ -135,9 +135,6 @@ class ModelCatalog:
dist_dim (int): The size of the input vector to the distribution.
"""
# Make sure, framework is ok.
framework = check_framework(framework)
dist = None
config = config or MODEL_DEFAULTS
# Custom distribution given.
@ -288,9 +285,6 @@ class ModelCatalog:
model (ModelV2): Model to use for the policy.
"""
# Make sure, framework is ok.
framework = check_framework(framework)
if model_config.get("custom_model"):
if "custom_options" in model_config and \
@ -589,9 +583,6 @@ class ModelCatalog:
@staticmethod
def _get_v2_model_class(obs_space, model_config, framework="tf"):
# Make sure, framework is ok.
framework = check_framework(framework)
if framework == "torch":
from ray.rllib.models.torch.fcnet import (FullyConnectedNetwork as
FCNet)

View file

@ -2,7 +2,7 @@ from functools import partial
from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
from ray.rllib.utils.framework import try_import_tf, try_import_tfp, \
try_import_torch, check_framework
try_import_torch
from ray.rllib.utils.deprecation import deprecation_warning, renamed_agent, \
renamed_class, renamed_function
from ray.rllib.utils.filter_manager import FilterManager
@ -72,7 +72,6 @@ __all__ = [
"add_mixins",
"check",
"check_compute_action",
"check_framework",
"deprecation_warning",
"fc",
"force_list",

View file

@ -1,8 +1,7 @@
from gym.spaces import Space
from typing import Union
from ray.rllib.utils.framework import check_framework, try_import_torch, \
TensorType
from ray.rllib.utils.framework import try_import_torch, TensorType
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import DeveloperAPI
@ -36,7 +35,7 @@ class Exploration:
self.model = model
self.num_workers = num_workers
self.worker_index = worker_index
self.framework = check_framework(framework)
self.framework = framework
# The device on which the Model has been placed.
# This Exploration will be on the same device.
self.device = None

View file

@ -3,8 +3,6 @@ import os
import sys
from typing import Any, Union
from ray.util import log_once
logger = logging.getLogger(__name__)
# Represents a generic tensor type.
@ -14,80 +12,6 @@ TensorType = Any
TensorStructType = Union[TensorType, dict, tuple]
def get_auto_framework():
"""Returns the framework (str) when framework="auto" in the config.
If only PyTorch is installed, returns "torch", if only tf is installed,
returns "tf", if both are installed, raises an error.
"""
# PyTorch is installed.
if torch is not None:
# TF is not installed -> return torch.
if tf is None:
if log_once("get_auto_framework"):
logger.info(
"`framework=auto` found in config -> Detected PyTorch.")
return "torch"
# TF is also installed -> raise error.
else:
raise ValueError(
"framework='auto' (default value) is not allowed if both "
"TensorFlow AND PyTorch are installed! "
"Instead, use framework='tf|tfe|torch' explicitly.")
# PyTorch nor TF installed -> raise error.
if not tf:
raise ValueError(
"Neither TensorFlow nor PyTorch are installed! You must install "
"one of them by running either `pip install tensorflow` OR "
"`pip install torch torchvision`")
# Only TensorFlow installed -> return tf.
if log_once("get_auto_framework"):
logger.info("`framework=auto` found in config -> Detected TensorFlow.")
return "tf"
def check_framework(framework, allow_none=True):
"""Checks, whether the given framework is "valid".
Meaning, whether all necessary dependencies are installed.
Args:
framework (str): Once of "tf", "torch", or None.
allow_none (bool): Whether framework=None (e.g. numpy implementatiopn)
is allowed or not.
Returns:
str: The input framework string.
Raises:
ImportError: If given framework is not installed.
"""
# Resolve auto framework first.
if framework == "auto":
framework = get_auto_framework()
# Check, whether tf is installed.
if framework in ["tf", "tfe"]:
if tf is None:
raise ImportError(
"Could not import `tensorflow`. Try `pip install tensorflow`")
# Check, whether torch is installed.
elif framework == "torch":
if torch is None:
raise ImportError("Could not import `torch`. "
"Try `pip install torch torchvision`")
# Framework is None (use numpy version of the component).
elif framework is None:
if not allow_none:
raise ValueError("framework=None not allowed!")
# Invalid value.
else:
raise ValueError("Invalid framework='{}'. Use one of "
"[tf|tfe|torch|auto].".format(framework))
return framework
def try_import_tf(error=False):
"""Tries importing tf and returns the module (or None).

View file

@ -1,7 +1,7 @@
from abc import ABCMeta, abstractmethod
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.framework import check_framework, try_import_tf
from ray.rllib.utils.framework import try_import_tf
tf = try_import_tf()
@ -24,7 +24,7 @@ class Schedule(metaclass=ABCMeta):
"""
def __init__(self, framework):
self.framework = check_framework(framework)
self.framework = framework
def value(self, t):
"""Generates the value given a timestep (based on schedule's logic).

View file

@ -1,8 +1,7 @@
import logging
import numpy as np
from ray.rllib.utils.framework import try_import_tf, try_import_torch, \
get_auto_framework
from ray.rllib.utils.framework import try_import_tf, try_import_torch
tf = try_import_tf()
if tf:
@ -29,7 +28,7 @@ def framework_iterator(config=None,
config (Optional[dict]): An optional config dict to alter in place
depending on the iteration.
frameworks (Tuple[str]): A list/tuple of the frameworks to be tested.
Allowed are: "tf", "tfe", "torch", and "auto".
Allowed are: "tf", "tfe", and "torch".
session (bool): If True, enter a tf.Session() and yield that as
well in the tf-case (otherwise, yield (fw, None)).
@ -43,9 +42,6 @@ def framework_iterator(config=None,
frameworks = [frameworks] if isinstance(frameworks, str) else frameworks
for fw in frameworks:
if fw == "auto":
fw = get_auto_framework()
# Skip non-installed frameworks.
if fw == "torch" and not torch:
logger.warning(