[tune][minor] gpu warning (#5948)

* gpu

* format

* defaults

* format_and_check

* better registration

* fix

* fix

* trial

* format

* tune
Richard Liaw 2019-10-19 17:09:48 -07:00 committed by GitHub
parent d23696de17
commit 91acecc9f9
6 changed files with 63 additions and 23 deletions

python/ray/tune/ray_trial_executor.py

@@ -93,7 +93,7 @@ class RayTrialExecutor(TrialExecutor):
             memory=trial.resources.memory,
             object_store_memory=trial.resources.object_store_memory,
             resources=trial.resources.custom_resources)(
-                trial._get_trainable_cls())
+                trial.get_trainable_cls())
         trial.init_logger()
         # We checkpoint metadata here to try mitigating logdir duplication
@@ -622,6 +622,11 @@ class RayTrialExecutor(TrialExecutor):
                 trial.runner.export_model.remote(trial.export_formats))
         return {}
 
+    def has_gpus(self):
+        if self._resources_initialized:
+            self._update_avail_resources()
+            return self._avail_resources.gpu > 0
+
 
 def _to_gb(n_bytes):
     return round(n_bytes / (1024**3), 2)
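RayTrialExecutor.has_gpus() refreshes the executor's cached resource view before checking the GPU count. Below is a minimal sketch of the same availability check from user code, assuming a running Ray cluster; ray.cluster_resources() is the public API reporting the cluster totals the executor caches:

import ray

ray.init()  # or ray.init(address=...) to attach to an existing cluster

# The cluster resource table reports a "GPU" entry only when GPUs
# were detected on some node.
cluster_has_gpus = ray.cluster_resources().get("GPU", 0) > 0
print("GPUs detected:", cluster_has_gpus)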

python/ray/tune/registry.py

@@ -7,6 +7,7 @@ from types import FunctionType
 import ray
 import ray.cloudpickle as pickle
 from ray.experimental.internal_kv import _internal_kv_initialized, \
     _internal_kv_get, _internal_kv_put
@@ -23,6 +24,24 @@ KNOWN_CATEGORIES = [
 logger = logging.getLogger(__name__)
 
 
+def has_trainable(trainable_name):
+    return _global_registry.contains(TRAINABLE_CLASS, trainable_name)
+
+
+def get_trainable_cls(trainable_name):
+    validate_trainable(trainable_name)
+    return _global_registry.get(TRAINABLE_CLASS, trainable_name)
+
+
+def validate_trainable(trainable_name):
+    if not has_trainable(trainable_name):
+        # Make sure rllib agents are registered
+        from ray import rllib  # noqa: F401
+        from ray.tune.error import TuneError
+        if not has_trainable(trainable_name):
+            raise TuneError("Unknown trainable: " + trainable_name)
+
+
 def register_trainable(name, trainable):
     """Register a trainable function or class.

python/ray/tune/trial.py

@@ -10,13 +10,12 @@ import uuid
 import time
 import tempfile
 import os
-import ray
 from ray.tune import TuneError
 from ray.tune.logger import pretty_print, UnifiedLogger
 # NOTE(rkn): We import ray.tune.registry here instead of importing the names we
 # need because there are cyclic imports that may cause specific names to not
 # have been defined yet. See https://github.com/ray-project/ray/issues/1716.
-import ray.tune.registry
+from ray.tune.registry import get_trainable_cls, validate_trainable
 from ray.tune.result import DEFAULT_RESULTS_DIR, DONE, TRAINING_ITERATION
 from ray.utils import binary_to_hex, hex_to_binary
 from ray.tune.resources import Resources, json_to_resources, resources_to_json
@@ -30,11 +29,6 @@ def date_str():
     return datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
 
 
-def has_trainable(trainable_name):
-    return ray.tune.registry._global_registry.contains(
-        ray.tune.registry.TRAINABLE_CLASS, trainable_name)
-
-
 class Checkpoint(object):
     """Describes a checkpoint of trial state.
@@ -126,7 +120,7 @@ class Trial(object):
             in ray.tune.config_parser.
         """
-        Trial._registration_check(trainable_name)
+        validate_trainable(trainable_name)
         # Trial config
         self.trainable_name = trainable_name
         self.trial_id = Trial.generate_id() if trial_id is None else trial_id
@@ -136,7 +130,7 @@ class Trial(object):
         #: Parameters that Tune varies across searches.
         self.evaluated_params = evaluated_params or {}
         self.experiment_tag = experiment_tag
-        trainable_cls = self._get_trainable_cls()
+        trainable_cls = self.get_trainable_cls()
         if trainable_cls and hasattr(trainable_cls,
                                      "default_resource_request"):
             default_resources = trainable_cls.default_resource_request(
@@ -202,14 +196,6 @@ class Trial(object):
         if trial_name_creator:
             self.custom_trial_name = trial_name_creator(self)
 
-    @classmethod
-    def _registration_check(cls, trainable_name):
-        if not has_trainable(trainable_name):
-            # Make sure rllib agents are registered
-            from ray import rllib  # noqa: F401
-            if not has_trainable(trainable_name):
-                raise TuneError("Unknown trainable: " + trainable_name)
-
     @classmethod
     def generate_id(cls):
         return str(uuid.uuid1().hex)[:8]
@@ -363,9 +349,8 @@ class Trial(object):
             return True
         return False
 
-    def _get_trainable_cls(self):
-        return ray.tune.registry._global_registry.get(
-            ray.tune.registry.TRAINABLE_CLASS, self.trainable_name)
+    def get_trainable_cls(self):
+        return get_trainable_cls(self.trainable_name)
 
     def set_verbose(self, verbose):
         self.verbose = verbose
@@ -430,6 +415,6 @@ class Trial(object):
             state[key] = cloudpickle.loads(hex_to_binary(state[key]))
         self.__dict__.update(state)
-        Trial._registration_check(self.trainable_name)
+        validate_trainable(self.trainable_name)
         if logger_started:
             self.init_logger()
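A sketch of the new public accessor, assuming ray.init() has run and "my_trainable" was registered as in the earlier example:

from ray.tune.trial import Trial

# Construction now calls validate_trainable(), so an unknown name raises
# TuneError up front instead of failing later inside the executor.
trial = Trial("my_trainable")
print(trial.get_trainable_cls())  # the class registered under that name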

python/ray/tune/trial_executor.py

@@ -226,3 +226,7 @@ class TrialExecutor(object):
         """
         raise NotImplementedError("Subclasses of TrialExecutor must provide "
                                   "export_trial_if_needed() method")
+
+    def has_gpus(self):
+        """Returns True if GPUs are detected on the cluster."""
+        return None
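The base hook returns None ("unknown") rather than False, so callers can distinguish "no GPUs" from "this executor cannot tell". A sketch of a subclass opting in, with a hypothetical executor name:

from ray.tune.trial_executor import TrialExecutor

class CpuOnlyExecutor(TrialExecutor):
    """Hypothetical executor that never schedules GPUs."""

    def has_gpus(self):
        return False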

python/ray/tune/trial_runner.py

@@ -346,7 +346,7 @@ class TrialRunner(object):
                          "up. {}").format(
                              trial.resources.summary_string(),
                              self.trial_executor.resource_string(),
-                             trial._get_trainable_cls().resource_help(
+                             trial.get_trainable_cls().resource_help(
                                  trial.config)))
             elif trial.status == Trial.PAUSED:
                 raise TuneError(

python/ray/tune/tune.py

@@ -10,7 +10,9 @@ from ray.tune.experiment import convert_to_experiment_list, Experiment
 from ray.tune.analysis import ExperimentAnalysis
 from ray.tune.suggest import BasicVariantGenerator
 from ray.tune.trial import Trial, DEBUG_PRINT_INTERVAL
+from ray.tune.trainable import Trainable
 from ray.tune.ray_trial_executor import RayTrialExecutor
+from ray.tune.registry import get_trainable_cls
 from ray.tune.syncer import wait_for_sync
 from ray.tune.trial_runner import TrialRunner
 from ray.tune.progress_reporter import CLIReporter, JupyterNotebookReporter
@@ -42,6 +44,13 @@ def _make_scheduler(args):
             args.scheduler, _SCHEDULERS.keys()))
 
 
+def _check_default_resources_override(run_identifier):
+    trainable_cls = get_trainable_cls(run_identifier)
+    return hasattr(trainable_cls, "default_resource_request") and (
+        trainable_cls.default_resource_request.__code__ !=
+        Trainable.default_resource_request.__code__)
+
+
 def run(run_or_experiment,
         name=None,
         stop=None,
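The helper tells an inherited default_resource_request apart from a user override by comparing __code__ objects: a subclass that overrides the classmethod gets its own code object, while one that merely inherits shares the base class's. A standalone illustration with hypothetical class names:

class Base(object):
    @classmethod
    def default_resource_request(cls, config):
        return None

class Overriding(Base):
    @classmethod
    def default_resource_request(cls, config):
        return {"gpu": 1}

class Inheriting(Base):
    pass

# Only the genuine override carries a distinct code object.
assert Overriding.default_resource_request.__code__ != \
    Base.default_resource_request.__code__
assert Inheriting.default_resource_request.__code__ == \
    Base.default_resource_request.__code__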
@@ -250,6 +259,24 @@ def run(run_or_experiment,
     else:
         reporter = CLIReporter()
 
+    # User Warning for GPUs
+    if trial_executor.has_gpus():
+        if isinstance(resources_per_trial,
+                      dict) and "gpu" in resources_per_trial:
+            # "gpu" is manually set.
+            pass
+        elif _check_default_resources_override(run_identifier):
+            # "default_resources" is manually overridden.
+            pass
+        else:
+            logger.warning("Tune detects GPUs, but no trials are using GPUs. "
+                           "To enable trials to use GPUs, set "
+                           "tune.run(resources_per_trial={'gpu': 1}...) "
+                           "which allows Tune to expose 1 GPU to each trial. "
+                           "You can also override "
+                           "`Trainable.default_resource_request` if using the "
+                           "Trainable API.")
+
     last_debug = 0
     while not runner.is_finished():
         runner.step()
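The warning fires only when all three conditions hold: the executor sees GPUs, resources_per_trial does not request any, and default_resource_request is untouched. A sketch of the two ways to satisfy it, assuming a trainable registered as "my_trainable":

from ray import tune
from ray.tune import Trainable
from ray.tune.resources import Resources

# Option 1: request a GPU per trial explicitly.
tune.run("my_trainable", resources_per_trial={"cpu": 1, "gpu": 1})

# Option 2: override default_resource_request on a Trainable subclass.
class MyTrainable(Trainable):
    @classmethod
    def default_resource_request(cls, config):
        return Resources(cpu=1, gpu=1)

    def _train(self):  # minimal training step so the sketch runs
        return {"done": True}

tune.run(MyTrainable)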