mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[rllib] Set framework to tf by default and remove import checks; "Auto" option (#8748)
* tf by default * Update rllib/agents/trainer.py Co-authored-by: Sven Mika <sven@anyscale.io> * remove it * fix * remove * fix * lint Co-authored-by: Sven Mika <sven@anyscale.io>
This commit is contained in:
parent
f6034fd12e
commit
002e1e7c8d
7 changed files with 11 additions and 106 deletions
|
@ -20,8 +20,7 @@ from ray.rllib.evaluation.metrics import collect_metrics
|
|||
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
|
||||
from ray.rllib.evaluation.worker_set import WorkerSet
|
||||
from ray.rllib.utils import FilterManager, deep_update, merge_dicts
|
||||
from ray.rllib.utils.framework import check_framework, try_import_tf, \
|
||||
TensorStructType
|
||||
from ray.rllib.utils.framework import try_import_tf, TensorStructType
|
||||
from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
|
||||
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
|
||||
from ray.rllib.utils.from_config import from_config
|
||||
|
@ -149,8 +148,7 @@ COMMON_CONFIG = {
|
|||
# tf: TensorFlow
|
||||
# tfe: TensorFlow eager
|
||||
# torch: PyTorch
|
||||
# auto: "torch" if only PyTorch installed, "tf" otherwise.
|
||||
"framework": "auto",
|
||||
"framework": "tf",
|
||||
# Enable tracing in eager mode. This greatly improves performance, but
|
||||
# makes it slightly harder to debug since Python code won't be evaluated
|
||||
# after the initial eager pass. Only possible if framework=tfe.
|
||||
|
@ -576,9 +574,7 @@ class Trainer(Trainable):
|
|||
self.config["framework"] = "tfe"
|
||||
self.config.pop("eager")
|
||||
|
||||
# Check all dependencies and resolve "auto" framework.
|
||||
self.config["framework"] = check_framework(self.config["framework"])
|
||||
# Notify about eager/tracing support.
|
||||
# Enable eager/tracing support.
|
||||
if tf and self.config["framework"] == "tfe":
|
||||
if not tf.executing_eagerly():
|
||||
tf.enable_eager_execution()
|
||||
|
|
|
@ -23,7 +23,7 @@ from ray.rllib.utils import try_import_tree
|
|||
from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
|
||||
from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE
|
||||
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
from ray.rllib.utils.framework import check_framework, try_import_tf
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.spaces.simplex import Simplex
|
||||
from ray.rllib.utils.spaces.space_utils import flatten_space
|
||||
|
||||
|
@ -135,9 +135,6 @@ class ModelCatalog:
|
|||
dist_dim (int): The size of the input vector to the distribution.
|
||||
"""
|
||||
|
||||
# Make sure, framework is ok.
|
||||
framework = check_framework(framework)
|
||||
|
||||
dist = None
|
||||
config = config or MODEL_DEFAULTS
|
||||
# Custom distribution given.
|
||||
|
@ -288,9 +285,6 @@ class ModelCatalog:
|
|||
model (ModelV2): Model to use for the policy.
|
||||
"""
|
||||
|
||||
# Make sure, framework is ok.
|
||||
framework = check_framework(framework)
|
||||
|
||||
if model_config.get("custom_model"):
|
||||
|
||||
if "custom_options" in model_config and \
|
||||
|
@ -589,9 +583,6 @@ class ModelCatalog:
|
|||
|
||||
@staticmethod
|
||||
def _get_v2_model_class(obs_space, model_config, framework="tf"):
|
||||
# Make sure, framework is ok.
|
||||
framework = check_framework(framework)
|
||||
|
||||
if framework == "torch":
|
||||
from ray.rllib.models.torch.fcnet import (FullyConnectedNetwork as
|
||||
FCNet)
|
||||
|
|
|
@ -2,7 +2,7 @@ from functools import partial
|
|||
|
||||
from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_tfp, \
|
||||
try_import_torch, check_framework
|
||||
try_import_torch
|
||||
from ray.rllib.utils.deprecation import deprecation_warning, renamed_agent, \
|
||||
renamed_class, renamed_function
|
||||
from ray.rllib.utils.filter_manager import FilterManager
|
||||
|
@ -72,7 +72,6 @@ __all__ = [
|
|||
"add_mixins",
|
||||
"check",
|
||||
"check_compute_action",
|
||||
"check_framework",
|
||||
"deprecation_warning",
|
||||
"fc",
|
||||
"force_list",
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
from gym.spaces import Space
|
||||
from typing import Union
|
||||
|
||||
from ray.rllib.utils.framework import check_framework, try_import_torch, \
|
||||
TensorType
|
||||
from ray.rllib.utils.framework import try_import_torch, TensorType
|
||||
from ray.rllib.models.action_dist import ActionDistribution
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
|
@ -36,7 +35,7 @@ class Exploration:
|
|||
self.model = model
|
||||
self.num_workers = num_workers
|
||||
self.worker_index = worker_index
|
||||
self.framework = check_framework(framework)
|
||||
self.framework = framework
|
||||
# The device on which the Model has been placed.
|
||||
# This Exploration will be on the same device.
|
||||
self.device = None
|
||||
|
|
|
@ -3,8 +3,6 @@ import os
|
|||
import sys
|
||||
from typing import Any, Union
|
||||
|
||||
from ray.util import log_once
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Represents a generic tensor type.
|
||||
|
@ -14,80 +12,6 @@ TensorType = Any
|
|||
TensorStructType = Union[TensorType, dict, tuple]
|
||||
|
||||
|
||||
def get_auto_framework():
|
||||
"""Returns the framework (str) when framework="auto" in the config.
|
||||
|
||||
If only PyTorch is installed, returns "torch", if only tf is installed,
|
||||
returns "tf", if both are installed, raises an error.
|
||||
"""
|
||||
|
||||
# PyTorch is installed.
|
||||
if torch is not None:
|
||||
# TF is not installed -> return torch.
|
||||
if tf is None:
|
||||
if log_once("get_auto_framework"):
|
||||
logger.info(
|
||||
"`framework=auto` found in config -> Detected PyTorch.")
|
||||
return "torch"
|
||||
# TF is also installed -> raise error.
|
||||
else:
|
||||
raise ValueError(
|
||||
"framework='auto' (default value) is not allowed if both "
|
||||
"TensorFlow AND PyTorch are installed! "
|
||||
"Instead, use framework='tf|tfe|torch' explicitly.")
|
||||
# PyTorch nor TF installed -> raise error.
|
||||
if not tf:
|
||||
raise ValueError(
|
||||
"Neither TensorFlow nor PyTorch are installed! You must install "
|
||||
"one of them by running either `pip install tensorflow` OR "
|
||||
"`pip install torch torchvision`")
|
||||
# Only TensorFlow installed -> return tf.
|
||||
if log_once("get_auto_framework"):
|
||||
logger.info("`framework=auto` found in config -> Detected TensorFlow.")
|
||||
return "tf"
|
||||
|
||||
|
||||
def check_framework(framework, allow_none=True):
|
||||
"""Checks, whether the given framework is "valid".
|
||||
|
||||
Meaning, whether all necessary dependencies are installed.
|
||||
|
||||
Args:
|
||||
framework (str): Once of "tf", "torch", or None.
|
||||
allow_none (bool): Whether framework=None (e.g. numpy implementatiopn)
|
||||
is allowed or not.
|
||||
|
||||
Returns:
|
||||
str: The input framework string.
|
||||
|
||||
Raises:
|
||||
ImportError: If given framework is not installed.
|
||||
"""
|
||||
# Resolve auto framework first.
|
||||
if framework == "auto":
|
||||
framework = get_auto_framework()
|
||||
|
||||
# Check, whether tf is installed.
|
||||
if framework in ["tf", "tfe"]:
|
||||
if tf is None:
|
||||
raise ImportError(
|
||||
"Could not import `tensorflow`. Try `pip install tensorflow`")
|
||||
# Check, whether torch is installed.
|
||||
elif framework == "torch":
|
||||
if torch is None:
|
||||
raise ImportError("Could not import `torch`. "
|
||||
"Try `pip install torch torchvision`")
|
||||
# Framework is None (use numpy version of the component).
|
||||
elif framework is None:
|
||||
if not allow_none:
|
||||
raise ValueError("framework=None not allowed!")
|
||||
# Invalid value.
|
||||
else:
|
||||
raise ValueError("Invalid framework='{}'. Use one of "
|
||||
"[tf|tfe|torch|auto].".format(framework))
|
||||
return framework
|
||||
|
||||
|
||||
def try_import_tf(error=False):
|
||||
"""Tries importing tf and returns the module (or None).
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.rllib.utils.framework import check_framework, try_import_tf
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
@ -24,7 +24,7 @@ class Schedule(metaclass=ABCMeta):
|
|||
"""
|
||||
|
||||
def __init__(self, framework):
|
||||
self.framework = check_framework(framework)
|
||||
self.framework = framework
|
||||
|
||||
def value(self, t):
|
||||
"""Generates the value given a timestep (based on schedule's logic).
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
import logging
|
||||
import numpy as np
|
||||
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch, \
|
||||
get_auto_framework
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
|
||||
tf = try_import_tf()
|
||||
if tf:
|
||||
|
@ -29,7 +28,7 @@ def framework_iterator(config=None,
|
|||
config (Optional[dict]): An optional config dict to alter in place
|
||||
depending on the iteration.
|
||||
frameworks (Tuple[str]): A list/tuple of the frameworks to be tested.
|
||||
Allowed are: "tf", "tfe", "torch", and "auto".
|
||||
Allowed are: "tf", "tfe", and "torch".
|
||||
session (bool): If True, enter a tf.Session() and yield that as
|
||||
well in the tf-case (otherwise, yield (fw, None)).
|
||||
|
||||
|
@ -43,9 +42,6 @@ def framework_iterator(config=None,
|
|||
frameworks = [frameworks] if isinstance(frameworks, str) else frameworks
|
||||
|
||||
for fw in frameworks:
|
||||
if fw == "auto":
|
||||
fw = get_auto_framework()
|
||||
|
||||
# Skip non-installed frameworks.
|
||||
if fw == "torch" and not torch:
|
||||
logger.warning(
|
||||
|
|
Loading…
Add table
Reference in a new issue