Mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00
[tune][minor] gpu warning (#5948)

* gpu
* format
* defaults
* format_and_check
* better registration
* fix
* fix
* trial
* format
* tune
This commit is contained in:
parent d23696de17
commit 91acecc9f9

6 changed files with 63 additions and 23 deletions
python/ray/tune/ray_trial_executor.py

@@ -93,7 +93,7 @@ class RayTrialExecutor(TrialExecutor):
             memory=trial.resources.memory,
             object_store_memory=trial.resources.object_store_memory,
             resources=trial.resources.custom_resources)(
-                trial._get_trainable_cls())
+                trial.get_trainable_cls())
 
         trial.init_logger()
         # We checkpoint metadata here to try mitigating logdir duplication
@@ -622,6 +622,11 @@ class RayTrialExecutor(TrialExecutor):
                 trial.runner.export_model.remote(trial.export_formats))
         return {}
 
+    def has_gpus(self):
+        if self._resources_initialized:
+            self._update_avail_resources()
+        return self._avail_resources.gpu > 0
+
 
 def _to_gb(n_bytes):
     return round(n_bytes / (1024**3), 2)
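For context (not part of the diff): the new has_gpus() reads the executor's internal _avail_resources cache, which tracks the same totals the public ray.cluster_resources() API reports. A minimal standalone sketch of the equivalent check:

import ray

ray.init()

def cluster_has_gpus():
    # ray.cluster_resources() maps resource names to totals,
    # e.g. {"CPU": 8.0, "GPU": 2.0}; the "GPU" key is absent
    # when no GPUs are registered with the cluster.
    return ray.cluster_resources().get("GPU", 0) > 0

print(cluster_has_gpus())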
python/ray/tune/registry.py

@@ -7,6 +7,7 @@ from types import FunctionType
 
 import ray
 import ray.cloudpickle as pickle
+
 from ray.experimental.internal_kv import _internal_kv_initialized, \
     _internal_kv_get, _internal_kv_put
 
@@ -23,6 +24,24 @@ KNOWN_CATEGORIES = [
 logger = logging.getLogger(__name__)
 
 
+def has_trainable(trainable_name):
+    return _global_registry.contains(TRAINABLE_CLASS, trainable_name)
+
+
+def get_trainable_cls(trainable_name):
+    validate_trainable(trainable_name)
+    return _global_registry.get(TRAINABLE_CLASS, trainable_name)
+
+
+def validate_trainable(trainable_name):
+    if not has_trainable(trainable_name):
+        # Make sure rllib agents are registered
+        from ray import rllib  # noqa: F401
+        from ray.tune.error import TuneError
+        if not has_trainable(trainable_name):
+            raise TuneError("Unknown trainable: " + trainable_name)
+
+
 def register_trainable(name, trainable):
     """Register a trainable function or class.
 
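A short usage sketch of the new registry helpers (not part of the diff; my_func is a placeholder, and function trainables in this era of Tune took (config, reporter)):

from ray.tune import register_trainable
from ray.tune.registry import get_trainable_cls, validate_trainable

def my_func(config, reporter):
    reporter(mean_accuracy=1.0)

register_trainable("my_func", my_func)
validate_trainable("my_func")        # no-op: the name is registered
cls = get_trainable_cls("my_func")   # the wrapped Trainable class
validate_trainable("unknown_name")   # raises TuneError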
python/ray/tune/trial.py

@@ -10,13 +10,12 @@ import uuid
 import time
 import tempfile
 import os
-import ray
 from ray.tune import TuneError
 from ray.tune.logger import pretty_print, UnifiedLogger
 # NOTE(rkn): We import ray.tune.registry here instead of importing the names we
 # need because there are cyclic imports that may cause specific names to not
 # have been defined yet. See https://github.com/ray-project/ray/issues/1716.
-import ray.tune.registry
+from ray.tune.registry import get_trainable_cls, validate_trainable
 from ray.tune.result import DEFAULT_RESULTS_DIR, DONE, TRAINING_ITERATION
 from ray.utils import binary_to_hex, hex_to_binary
 from ray.tune.resources import Resources, json_to_resources, resources_to_json
@@ -30,11 +29,6 @@ def date_str():
     return datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
 
 
-def has_trainable(trainable_name):
-    return ray.tune.registry._global_registry.contains(
-        ray.tune.registry.TRAINABLE_CLASS, trainable_name)
-
-
 class Checkpoint(object):
     """Describes a checkpoint of trial state.
 
@@ -126,7 +120,7 @@ class Trial(object):
             in ray.tune.config_parser.
         """
 
-        Trial._registration_check(trainable_name)
+        validate_trainable(trainable_name)
        # Trial config
         self.trainable_name = trainable_name
         self.trial_id = Trial.generate_id() if trial_id is None else trial_id
@@ -136,7 +130,7 @@ class Trial(object):
         #: Parameters that Tune varies across searches.
         self.evaluated_params = evaluated_params or {}
         self.experiment_tag = experiment_tag
-        trainable_cls = self._get_trainable_cls()
+        trainable_cls = self.get_trainable_cls()
         if trainable_cls and hasattr(trainable_cls,
                                      "default_resource_request"):
             default_resources = trainable_cls.default_resource_request(
@@ -202,14 +196,6 @@ class Trial(object):
         if trial_name_creator:
             self.custom_trial_name = trial_name_creator(self)
 
-    @classmethod
-    def _registration_check(cls, trainable_name):
-        if not has_trainable(trainable_name):
-            # Make sure rllib agents are registered
-            from ray import rllib  # noqa: F401
-            if not has_trainable(trainable_name):
-                raise TuneError("Unknown trainable: " + trainable_name)
-
     @classmethod
     def generate_id(cls):
         return str(uuid.uuid1().hex)[:8]
@@ -363,9 +349,8 @@ class Trial(object):
             return True
         return False
 
-    def _get_trainable_cls(self):
-        return ray.tune.registry._global_registry.get(
-            ray.tune.registry.TRAINABLE_CLASS, self.trainable_name)
+    def get_trainable_cls(self):
+        return get_trainable_cls(self.trainable_name)
 
     def set_verbose(self, verbose):
         self.verbose = verbose
@@ -430,6 +415,6 @@ class Trial(object):
             state[key] = cloudpickle.loads(hex_to_binary(state[key]))
 
         self.__dict__.update(state)
-        Trial._registration_check(self.trainable_name)
+        validate_trainable(self.trainable_name)
         if logger_started:
             self.init_logger()
python/ray/tune/trial_executor.py

@@ -226,3 +226,7 @@ class TrialExecutor(object):
         """
         raise NotImplementedError("Subclasses of TrialExecutor must provide "
                                   "export_trial_if_needed() method")
+
+    def has_gpus(self):
+        """Returns True if GPUs are detected on the cluster."""
+        return None
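The base class returns None rather than raising, so executors that never track resources keep working with the new tune.run() check. A minimal sketch (not part of the diff) of a custom executor opting in; MyExecutor is hypothetical, and a real subclass must also implement the other abstract methods:

from ray.tune.trial_executor import TrialExecutor

class MyExecutor(TrialExecutor):
    def has_gpus(self):
        # Report no GPUs, so tune.run() skips its GPU-warning path.
        return False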
python/ray/tune/trial_runner.py

@@ -346,7 +346,7 @@ class TrialRunner(object):
                      "up. {}").format(
                          trial.resources.summary_string(),
                          self.trial_executor.resource_string(),
-                         trial._get_trainable_cls().resource_help(
+                         trial.get_trainable_cls().resource_help(
                              trial.config)))
             elif trial.status == Trial.PAUSED:
                 raise TuneError(
python/ray/tune/tune.py

@@ -10,7 +10,9 @@ from ray.tune.experiment import convert_to_experiment_list, Experiment
 from ray.tune.analysis import ExperimentAnalysis
 from ray.tune.suggest import BasicVariantGenerator
 from ray.tune.trial import Trial, DEBUG_PRINT_INTERVAL
+from ray.tune.trainable import Trainable
 from ray.tune.ray_trial_executor import RayTrialExecutor
+from ray.tune.registry import get_trainable_cls
 from ray.tune.syncer import wait_for_sync
 from ray.tune.trial_runner import TrialRunner
 from ray.tune.progress_reporter import CLIReporter, JupyterNotebookReporter
@@ -42,6 +44,13 @@ def _make_scheduler(args):
             args.scheduler, _SCHEDULERS.keys()))
 
 
+def _check_default_resources_override(run_identifier):
+    trainable_cls = get_trainable_cls(run_identifier)
+    return hasattr(trainable_cls, "default_resource_request") and (
+        trainable_cls.default_resource_request.__code__ !=
+        Trainable.default_resource_request.__code__)
+
+
 def run(run_or_experiment,
         name=None,
         stop=None,
@@ -250,6 +259,24 @@ def run(run_or_experiment,
     else:
         reporter = CLIReporter()
 
+    # User Warning for GPUs
+    if trial_executor.has_gpus():
+        if isinstance(resources_per_trial,
+                      dict) and "gpu" in resources_per_trial:
+            # "gpu" is manually set.
+            pass
+        elif _check_default_resources_override(run_identifier):
+            # "default_resources" is manually overridden.
+            pass
+        else:
+            logger.warning("Tune detects GPUs, but no trials are using GPUs. "
+                           "To enable trials to use GPUs, set "
+                           "tune.run(resources_per_trial={'gpu': 1}...) "
+                           "which allows Tune to expose 1 GPU to each trial. "
+                           "You can also override "
+                           "`Trainable.default_resource_request` if using the "
+                           "Trainable API.")
+
     last_debug = 0
     while not runner.is_finished():
         runner.step()
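The two ways a user can satisfy the new check, sketched against the Tune API of this era (not part of the diff; "my_func" and MyTrainable are placeholders):

from ray import tune
from ray.tune import Trainable
from ray.tune.resources import Resources

# 1. Request a GPU explicitly for each trial:
tune.run("my_func", resources_per_trial={"cpu": 1, "gpu": 1})

# 2. Or override the Trainable's default resource request:
class MyTrainable(Trainable):
    @classmethod
    def default_resource_request(cls, config):
        # Claim one CPU and one GPU per trial.
        return Resources(cpu=1, gpu=1)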