mirror of
https://github.com/vale981/ray
synced 2025-03-06 18:41:40 -05:00
[tune] cleanup error messaging/diagnose_serialization helper (#10210)
This commit is contained in:
parent 24ee496b89
commit 6bd5458bef
8 changed files with 178 additions and 38 deletions
@@ -197,7 +197,7 @@ Distributed Checkpointing

 On a multinode cluster, Tune automatically creates a copy of all trial checkpoints on the head node. This requires the Ray cluster to be started with the :ref:`cluster launcher <ref-automatic-cluster>` and also requires rsync to be installed.

-Note that you must use the ``tune.checkpoint_dir`` API to trigger syncing. Also, if running Tune on Kubernetes, be sure to use the :ref:`KubernetesSyncer <tune-kubernetes>` to transfer files between different pods.
+Note that you must use the ``tune.checkpoint_dir`` API to trigger syncing. Also, if running Tune on Kubernetes, be sure to use the :ref:`KubernetesSyncer <tune-kubernetes>` to transfer files between different pods.

 If you do not use the cluster launcher, you should set up a NFS or global file system and
 disable cross-node syncing:
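For reference, the ``tune.checkpoint_dir`` API mentioned above is used from inside a function trainable. A minimal sketch follows (the file name ``state.txt`` and the loop bounds are illustrative, not from this commit):

    import os
    from ray import tune

    def train_fn(config, checkpoint_dir=None):
        start = 0
        if checkpoint_dir:  # restore if Tune passed in a prior checkpoint
            with open(os.path.join(checkpoint_dir, "state.txt")) as f:
                start = int(f.read())
        for step in range(start, 10):
            # Writing files inside this context manager is what triggers
            # cross-node checkpoint syncing.
            with tune.checkpoint_dir(step=step) as ckpt_dir:
                with open(os.path.join(ckpt_dir, "state.txt"), "w") as f:
                    f.write(str(step))
            tune.report(step=step)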
@@ -225,7 +225,7 @@ You often will want to compute a large object (e.g., training data, model weight

     # X_id can be referenced in closures
     X_id = pin_in_object_store(np.random.random(size=100000000))

-    def f(config, reporter):
+    def f(config):
         X = get_pinned_object(X_id)
         # use X
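A self-contained version of the snippet above might look like this (a sketch; it assumes the helpers are imported from ``ray.tune.utils``, where they are exported by this commit's ``__init__.py`` change below):

    import numpy as np
    import ray
    from ray.tune.utils import get_pinned_object, pin_in_object_store

    ray.init()

    # Pin once on the driver; the returned id is small and picklable,
    # so it can safely be captured in the trainable's closure.
    X_id = pin_in_object_store(np.random.random(size=100000000))

    def f(config):
        X = get_pinned_object(X_id)  # fetched from the object store inside each trial
        # use X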
@@ -1,5 +1,6 @@
 import copy
 import logging
+from pickle import PicklingError
 import os
 from typing import Sequence
@@ -242,7 +243,24 @@ class Experiment:
             else:
                 logger.warning(
                     "No name detected on trainable. Using {}.".format(name))
-            register_trainable(name, run_object)
+            try:
+                register_trainable(name, run_object)
+            except (TypeError, PicklingError) as e:
+                msg = (
+                    f"{str(e)}. The trainable ({str(run_object)}) could not "
+                    "be serialized, which is needed for parallel execution. "
+                    "To diagnose the issue, try the following:\n\n"
+                    "\t- Run `tune.utils.diagnose_serialization(trainable)` "
+                    "to check if non-serializable variables are captured "
+                    "in scope.\n"
+                    "\t- Try reproducing the issue by calling "
+                    "`pickle.dumps(trainable)`.\n"
+                    "\t- If the error is typing-related, try removing "
+                    "the type annotations and try again.\n\n"
+                    "If you have any suggestions on how to improve "
+                    "this error message, please reach out to the "
+                    "Ray developers on github.com/ray-project/ray/issues/")
+                raise type(e)(msg) from None
             return name
         else:
             raise TuneError("Improper 'run' - not string nor trainable.")
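As a sketch of the failure mode this new message targets: Tune serializes trainables with cloudpickle, and the example below assumes it runs as a script, where cloudpickle pickles ``__main__`` functions by value, closure and all (the names here are illustrative):

    import threading
    from ray import cloudpickle

    e = threading.Event()  # wraps a thread lock, which cannot be pickled

    def trainable(config):
        print(e)  # `e` is captured from the enclosing scope

    try:
        cloudpickle.dumps(trainable)
    except TypeError as err:
        print(f"Reproduced: {err}")  # e.g. "cannot pickle '_thread.lock' object"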
@@ -9,6 +9,7 @@ import uuid

 from six.moves import queue

+from ray.util.debug import log_once
 from ray.tune import TuneError, session
 from ray.tune.trainable import Trainable, TrainableUtil
 from ray.tune.result import (TIME_THIS_ITER_S, RESULT_DUPLICATE,
@@ -476,34 +477,37 @@ def detect_checkpoint_function(train_func, abort=False):
     return validated


-def wrap_function(train_func):
+def wrap_function(train_func, warn=True):
     if hasattr(train_func, "__mixins__"):
         inherit_from = train_func.__mixins__ + (FunctionRunner, )
     else:
         inherit_from = (FunctionRunner, )

+    func_args = inspect.getfullargspec(train_func).args
+    use_checkpoint = detect_checkpoint_function(train_func)
+    if len(func_args) > 1:  # more arguments than just the config
+        if "reporter" not in func_args and not use_checkpoint:
+            raise ValueError(
+                "Unknown argument found in the Trainable function. "
+                "Arguments other than the 'config' arg must be one "
+                "of ['reporter', 'checkpoint_dir']. Found: {}".format(
+                    func_args))
+
+    use_reporter = "reporter" in func_args
+    if not use_checkpoint and not use_reporter:
+        if log_once("tune_function_checkpoint") and warn:
+            logger.warning(
+                "Function checkpointing is disabled. This may result in "
+                "unexpected behavior when using checkpointing features or "
+                "certain schedulers. To enable, set the train function "
+                "arguments to be `func(config, checkpoint_dir=None)`.")
+
     class ImplicitFunc(*inherit_from):
         _name = train_func.__name__ if hasattr(train_func, "__name__") \
             else "func"

         def _trainable_func(self, config, reporter, checkpoint_dir):
-            func_args = inspect.getfullargspec(train_func).args
-            if len(func_args) > 1:  # more arguments than just the config
-                if "reporter" not in func_args and (
-                        not detect_checkpoint_function(train_func)):
-                    raise ValueError(
-                        "Unknown argument found in the Trainable function. "
-                        "Arguments other than the 'config' arg must be one "
-                        "of ['reporter', 'checkpoint_dir']. Found: {}".format(
-                            func_args))
-            use_reporter = "reporter" in func_args
-            use_checkpoint = detect_checkpoint_function(train_func)
             if not use_checkpoint and not use_reporter:
-                logger.warning(
-                    "Function checkpointing is disabled. This may result in "
-                    "unexpected behavior when using checkpointing features or "
-                    "certain schedulers. To enable, set the train function "
-                    "arguments to be `func(config, checkpoint_dir=None)`.")
                 output = train_func(config)
             elif use_checkpoint:
                 output = train_func(config, checkpoint_dir=checkpoint_dir)
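With this change, argument validation happens once at wrap time rather than inside each trial. A sketch of the observable behavior, assuming the post-commit code (``good`` and ``bad`` are illustrative names):

    from ray.tune.function_runner import wrap_function

    def good(config, checkpoint_dir=None):  # recognized extra argument
        pass

    def bad(config, foo=None):  # unknown extra argument
        pass

    wrap_function(good)  # checkpointing detected, no warning

    try:
        wrap_function(bad)  # now fails fast, before any trial starts
    except ValueError as err:
        print(err)  # "Unknown argument found in the Trainable function. ..."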
@@ -196,7 +196,9 @@ class TBXLogger(Logger):
         try:
             from tensorboardX import SummaryWriter
         except ImportError:
-            logger.error("pip install 'ray[tune]' to see TensorBoard files.")
+            if log_once("tbx-install"):
+                logger.info(
+                    "pip install 'ray[tune]' to see TensorBoard files.")
             raise
         self._file_writer = SummaryWriter(self.logdir, flush_secs=30)
         self.last_result = None
@@ -329,8 +331,9 @@ class UnifiedLogger(Logger):
             try:
                 self._loggers.append(cls(self.config, self.logdir, self.trial))
             except Exception as exc:
-                logger.warning("Could not instantiate %s: %s.", cls.__name__,
-                               str(exc))
+                if log_once(f"instantiate:{cls.__name__}"):
+                    logger.warning("Could not instantiate %s: %s.",
+                                   cls.__name__, str(exc))
         self._log_syncer = get_node_syncer(
             self.logdir,
             remote_dir=self.logdir,
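Both logger changes above rely on ``ray.util.debug.log_once``, which returns True only the first time it is called with a given key, so repeated failures no longer spam the log. Roughly:

    from ray.util.debug import log_once

    for attempt in range(3):
        if log_once("tbx-install"):
            print("emitted once")  # subsequent iterations skip this branch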
@@ -12,9 +12,10 @@ ENV_CREATOR = "env_creator"
 RLLIB_MODEL = "rllib_model"
 RLLIB_PREPROCESSOR = "rllib_preprocessor"
 RLLIB_ACTION_DIST = "rllib_action_dist"
+TEST = "__test__"
 KNOWN_CATEGORIES = [
     TRAINABLE_CLASS, ENV_CREATOR, RLLIB_MODEL, RLLIB_PREPROCESSOR,
-    RLLIB_ACTION_DIST
+    RLLIB_ACTION_DIST, TEST
 ]

 logger = logging.getLogger(__name__)
@@ -38,7 +39,7 @@ def validate_trainable(trainable_name):
         raise TuneError("Unknown trainable: " + trainable_name)


-def register_trainable(name, trainable):
+def register_trainable(name, trainable, warn=True):
     """Register a trainable function or class.

     This enables a class or function to be accessed on every Ray process
@@ -58,11 +59,11 @@ def register_trainable(name, trainable):
         logger.debug("Detected class for trainable.")
     elif isinstance(trainable, FunctionType):
         logger.debug("Detected function for trainable.")
-        trainable = wrap_function(trainable)
+        trainable = wrap_function(trainable, warn=warn)
     elif callable(trainable):
-        logger.warning(
+        logger.info(
             "Detected unknown callable for trainable. Converting to class.")
-        trainable = wrap_function(trainable)
+        trainable = wrap_function(trainable, warn=warn)

     if not issubclass(trainable, Trainable):
         raise TypeError("Second argument must be convertable to Trainable",
@@ -86,6 +87,10 @@ def register_env(name, env_creator):
     _global_registry.register(ENV_CREATOR, name, env_creator)


+def check_serializability(key, value):
+    _global_registry.register(TEST, key, value)
+
+
 def _make_key(category, key):
     """Generate a binary key for the given category and key.

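The new helper piggybacks on the registry's pickling path, so a failed registration surfaces as the usual pickling exception. A usage sketch (the keys here are illustrative):

    import threading
    from ray.tune.registry import check_serializability

    check_serializability("config", {"lr": 1e-3})  # picklable, returns quietly

    try:
        check_serializability("event", threading.Event())
    except TypeError as err:  # thread locks cannot be pickled
        print(f"'event' is not serializable: {err}")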
@@ -105,6 +110,11 @@ class _Registry:
         self._to_flush = {}

     def register(self, category, key, value):
+        """Registers the value with the global registry.
+
+        Raises:
+            PicklingError if unable to pickle to provided file.
+        """
         if category not in KNOWN_CATEGORIES:
             from ray.tune import TuneError
             raise TuneError("Unknown category {} not among {}".format(
@@ -1,10 +1,12 @@
 import unittest
+import threading

 import ray
 from ray.rllib import _register_all
 from ray.tune import register_trainable
 from ray.tune.experiment import Experiment, convert_to_experiment_list
 from ray.tune.error import TuneError
+from ray.tune.utils import diagnose_serialization


 class ExperimentTest(unittest.TestCase):
@@ -71,6 +73,28 @@ class ExperimentTest(unittest.TestCase):
         self.assertRaises(TuneError, lambda: convert_to_experiment_list("hi"))


+class ValidateUtilTest(unittest.TestCase):
+    def testDiagnoseSerialization(self):
+
+        # this is not serializable
+        e = threading.Event()
+
+        def test():
+            print(e)
+
+        assert diagnose_serialization(test) is not True
+
+        # should help identify that 'e' should be moved into
+        # the `test` scope.
+
+        # correct implementation
+        def test():
+            e = threading.Event()
+            print(e)
+
+        assert diagnose_serialization(test) is True
+
+
 if __name__ == "__main__":
     import pytest
     import sys
@@ -1,15 +1,9 @@
 from ray.tune.utils.util import deep_update, flatten_dict, get_pinned_object, \
     merge_dicts, pin_in_object_store, unflattened_lookup, UtilMonitor, \
-    validate_save_restore, warn_if_slow
+    validate_save_restore, warn_if_slow, diagnose_serialization

 __all__ = [
-    "deep_update",
-    "flatten_dict",
-    "get_pinned_object",
-    "merge_dicts",
-    "pin_in_object_store",
-    "unflattened_lookup",
-    "UtilMonitor",
-    "validate_save_restore",
-    "warn_if_slow",
+    "deep_update", "flatten_dict", "get_pinned_object", "merge_dicts",
+    "pin_in_object_store", "unflattened_lookup", "UtilMonitor",
+    "validate_save_restore", "warn_if_slow", "diagnose_serialization"
 ]
@@ -1,5 +1,6 @@
 import copy
 import logging
+import inspect
 import threading
 import time
 from collections import defaultdict, deque, Mapping, Sequence
@@ -269,6 +270,92 @@ def _from_pinnable(obj):
     return obj[0]


+def diagnose_serialization(trainable):
+    """Utility for detecting accidentally-scoped objects.
+
+    Args:
+        trainable (cls | func): The trainable object passed to
+            tune.run(trainable).
+
+    Returns:
+        bool | set of unserializable objects.
+
+    Example:
+
+    .. code-block::
+
+        import threading
+        # this is not serializable
+        e = threading.Event()
+
+        def test():
+            print(e)
+
+        diagnose_serialization(test)
+        # should help identify that 'e' should be moved into
+        # the `test` scope.
+
+        # correct implementation
+        def test():
+            e = threading.Event()
+            print(e)
+
+        assert diagnose_serialization(test) is True
+
+    """
+    from ray.tune.registry import register_trainable, check_serializability
+
+    def check_variables(objects, failure_set, printer):
+        for var_name, variable in objects.items():
+            msg = None
+            try:
+                check_serializability(var_name, variable)
+                status = "PASSED"
+            except Exception as e:
+                status = "FAILED"
+                msg = f"{e.__class__.__name__}: {str(e)}"
+                failure_set.add(var_name)
+            printer(f"{str(variable)}[name='{var_name}'']... {status}")
+            if msg:
+                printer(msg)
+
+    print(f"Trying to serialize {trainable}...")
+    try:
+        register_trainable("__test:" + str(trainable), trainable, warn=False)
+        print("Serialization succeeded!")
+        return True
+    except Exception as e:
+        print(f"Serialization failed: {e}")
+
+    print("Inspecting the scope of the trainable by running "
+          f"`inspect.getclosurevars({str(trainable)})`...")
+    closure = inspect.getclosurevars(trainable)
+    failure_set = set()
+    if closure.globals:
+        print(f"Detected {len(closure.globals)} global variables. "
+              "Checking serializability...")
+        check_variables(closure.globals, failure_set,
+                        lambda s: print(" " + s))
+
+    if closure.nonlocals:
+        print(f"Detected {len(closure.nonlocals)} nonlocal variables. "
+              "Checking serializability...")
+        check_variables(closure.nonlocals, failure_set,
+                        lambda s: print(" " + s))
+
+    if not failure_set:
+        print("Nothing was found to have failed the diagnostic test, though "
+              "serialization did not succeed. Feel free to raise an "
+              "issue on github.")
+        return failure_set
+    else:
+        print(f"Variable(s) {failure_set} was found to be non-serializable. "
+              "Consider either removing the instantiation/imports "
+              "of these objects or moving them into the scope of "
+              "the trainable. ")
+        return failure_set
+
+
 def validate_save_restore(trainable_cls,
                           config=None,
                           num_gpus=0,
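Putting the helper to work on the driver side might look like the following sketch (``my_trainable`` is an illustrative name; the return contract — True on success, otherwise the set of offending variable names — comes from the code above):

    from ray.tune.utils import diagnose_serialization

    result = diagnose_serialization(my_trainable)
    if result is True:
        print("safe to pass to tune.run")
    else:
        # `result` is the set of offending variable names, e.g. {'e'}
        print(f"move these into the trainable's scope: {result}")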