Mirror of https://github.com/vale981/ray, synced 2025-03-08 11:31:40 -05:00
[tune] cleanup error messaging/diagnose_serialization helper (#10210)
This commit is contained in:
parent 24ee496b89
commit 6bd5458bef
8 changed files with 178 additions and 38 deletions
@@ -197,7 +197,7 @@ Distributed Checkpointing
 
 On a multinode cluster, Tune automatically creates a copy of all trial checkpoints on the head node. This requires the Ray cluster to be started with the :ref:`cluster launcher <ref-automatic-cluster>` and also requires rsync to be installed.
 
 Note that you must use the ``tune.checkpoint_dir`` API to trigger syncing. Also, if running Tune on Kubernetes, be sure to use the :ref:`KubernetesSyncer <tune-kubernetes>` to transfer files between different pods.
 
 If you do not use the cluster launcher, you should set up a NFS or global file system and
 disable cross-node syncing:
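For readers following along, a minimal sketch of the ``tune.checkpoint_dir`` API that the doc refers to; the file name, step count, and loop structure are illustrative assumptions, not part of this commit:

    import os
    from ray import tune

    def train_func(config, checkpoint_dir=None):
        step = 0
        if checkpoint_dir:
            # Restore from the last synced checkpoint, if any.
            with open(os.path.join(checkpoint_dir, "step.txt")) as f:
                step = int(f.read())
        while step < 100:
            step += 1
            # Writing through tune.checkpoint_dir is what triggers
            # cross-node syncing of checkpoints to the head node.
            with tune.checkpoint_dir(step=step) as ckpt_dir:
                with open(os.path.join(ckpt_dir, "step.txt"), "w") as f:
                    f.write(str(step))
            tune.report(step=step)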
@@ -225,7 +225,7 @@ You often will want to compute a large object (e.g., training data, model weight
 
     # X_id can be referenced in closures
    X_id = pin_in_object_store(np.random.random(size=100000000))
 
-    def f(config, reporter):
+    def f(config):
        X = get_pinned_object(X_id)
        # use X
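A self-contained version of the snippet above, runnable under the assumption that the pinning helpers live in ``ray.tune.utils`` as in this release:

    import numpy as np
    import ray
    from ray import tune
    from ray.tune.utils import pin_in_object_store, get_pinned_object

    ray.init()

    # X_id can be referenced in closures
    X_id = pin_in_object_store(np.random.random(size=100000000))

    def f(config):
        # Fetch the pinned array instead of re-creating it per trial.
        X = get_pinned_object(X_id)
        # use X

    tune.run(f)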
@@ -1,5 +1,6 @@
 import copy
 import logging
+from pickle import PicklingError
 import os
 from typing import Sequence
 
@@ -242,7 +243,24 @@ class Experiment:
             else:
                 logger.warning(
                     "No name detected on trainable. Using {}.".format(name))
-            register_trainable(name, run_object)
+            try:
+                register_trainable(name, run_object)
+            except (TypeError, PicklingError) as e:
+                msg = (
+                    f"{str(e)}. The trainable ({str(run_object)}) could not "
+                    "be serialized, which is needed for parallel execution. "
+                    "To diagnose the issue, try the following:\n\n"
+                    "\t- Run `tune.utils.diagnose_serialization(trainable)` "
+                    "to check if non-serializable variables are captured "
+                    "in scope.\n"
+                    "\t- Try reproducing the issue by calling "
+                    "`pickle.dumps(trainable)`.\n"
+                    "\t- If the error is typing-related, try removing "
+                    "the type annotations and try again.\n\n"
+                    "If you have any suggestions on how to improve "
+                    "this error message, please reach out to the "
+                    "Ray developers on github.com/ray-project/ray/issues/")
+                raise type(e)(msg) from None
             return name
         else:
             raise TuneError("Improper 'run' - not string nor trainable.")
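The new message points users at the diagnostic helper added in this commit. A small sketch of acting on that advice, mirroring the threading.Event capture used in the tests further down:

    import threading
    from ray.tune.utils import diagnose_serialization

    e = threading.Event()  # not picklable; accidentally captured from module scope

    def trainable(config):
        print(e)

    # Prints a PASSED/FAILED report for every captured variable and
    # returns the set of offending names ({'e'} here) instead of True.
    failures = diagnose_serialization(trainable)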
@@ -9,6 +9,7 @@ import uuid
 
 from six.moves import queue
 
+from ray.util.debug import log_once
 from ray.tune import TuneError, session
 from ray.tune.trainable import Trainable, TrainableUtil
 from ray.tune.result import (TIME_THIS_ITER_S, RESULT_DUPLICATE,
@@ -476,34 +477,37 @@ def detect_checkpoint_function(train_func, abort=False):
     return validated
 
 
-def wrap_function(train_func):
+def wrap_function(train_func, warn=True):
     if hasattr(train_func, "__mixins__"):
         inherit_from = train_func.__mixins__ + (FunctionRunner, )
     else:
         inherit_from = (FunctionRunner, )
 
+    func_args = inspect.getfullargspec(train_func).args
+    use_checkpoint = detect_checkpoint_function(train_func)
+    if len(func_args) > 1:  # more arguments than just the config
+        if "reporter" not in func_args and not use_checkpoint:
+            raise ValueError(
+                "Unknown argument found in the Trainable function. "
+                "Arguments other than the 'config' arg must be one "
+                "of ['reporter', 'checkpoint_dir']. Found: {}".format(
+                    func_args))
+
+    use_reporter = "reporter" in func_args
+    if not use_checkpoint and not use_reporter:
+        if log_once("tune_function_checkpoint") and warn:
+            logger.warning(
+                "Function checkpointing is disabled. This may result in "
+                "unexpected behavior when using checkpointing features or "
+                "certain schedulers. To enable, set the train function "
+                "arguments to be `func(config, checkpoint_dir=None)`.")
+
     class ImplicitFunc(*inherit_from):
         _name = train_func.__name__ if hasattr(train_func, "__name__") \
             else "func"
 
         def _trainable_func(self, config, reporter, checkpoint_dir):
-            func_args = inspect.getfullargspec(train_func).args
-            if len(func_args) > 1:  # more arguments than just the config
-                if "reporter" not in func_args and (
-                        not detect_checkpoint_function(train_func)):
-                    raise ValueError(
-                        "Unknown argument found in the Trainable function. "
-                        "Arguments other than the 'config' arg must be one "
-                        "of ['reporter', 'checkpoint_dir']. Found: {}".format(
-                            func_args))
-            use_reporter = "reporter" in func_args
-            use_checkpoint = detect_checkpoint_function(train_func)
             if not use_checkpoint and not use_reporter:
-                logger.warning(
-                    "Function checkpointing is disabled. This may result in "
-                    "unexpected behavior when using checkpointing features or "
-                    "certain schedulers. To enable, set the train function "
-                    "arguments to be `func(config, checkpoint_dir=None)`.")
                 output = train_func(config)
             elif use_checkpoint:
                 output = train_func(config, checkpoint_dir=checkpoint_dir)
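With this change the argument check runs once at wrap time instead of on every `_trainable_func` call, and the checkpointing warning is emitted only once. A hedged illustration of what the up-front validation accepts and rejects (function names are invented for the example):

    def train_plain(config):                       # accepted; warns once that
        pass                                       # function checkpointing is off

    def train_ckpt(config, checkpoint_dir=None):   # accepted; enables
        pass                                       # function checkpointing

    def train_bad(config, batch_size=None):        # rejected: ValueError raised
        pass                                       # as soon as wrap_function runs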
@@ -196,7 +196,9 @@ class TBXLogger(Logger):
         try:
             from tensorboardX import SummaryWriter
         except ImportError:
-            logger.error("pip install 'ray[tune]' to see TensorBoard files.")
+            if log_once("tbx-install"):
+                logger.info(
+                    "pip install 'ray[tune]' to see TensorBoard files.")
             raise
         self._file_writer = SummaryWriter(self.logdir, flush_secs=30)
         self.last_result = None
@@ -329,8 +331,9 @@ class UnifiedLogger(Logger):
             try:
                 self._loggers.append(cls(self.config, self.logdir, self.trial))
             except Exception as exc:
-                logger.warning("Could not instantiate %s: %s.", cls.__name__,
-                               str(exc))
+                if log_once(f"instantiate:{cls.__name__}"):
+                    logger.warning("Could not instantiate %s: %s.",
+                                   cls.__name__, str(exc))
         self._log_syncer = get_node_syncer(
             self.logdir,
             remote_dir=self.logdir,
@@ -12,9 +12,10 @@ ENV_CREATOR = "env_creator"
 RLLIB_MODEL = "rllib_model"
 RLLIB_PREPROCESSOR = "rllib_preprocessor"
 RLLIB_ACTION_DIST = "rllib_action_dist"
+TEST = "__test__"
 KNOWN_CATEGORIES = [
     TRAINABLE_CLASS, ENV_CREATOR, RLLIB_MODEL, RLLIB_PREPROCESSOR,
-    RLLIB_ACTION_DIST
+    RLLIB_ACTION_DIST, TEST
 ]
 
 logger = logging.getLogger(__name__)
@@ -38,7 +39,7 @@ def validate_trainable(trainable_name):
         raise TuneError("Unknown trainable: " + trainable_name)
 
 
-def register_trainable(name, trainable):
+def register_trainable(name, trainable, warn=True):
     """Register a trainable function or class.
 
     This enables a class or function to be accessed on every Ray process
@@ -58,11 +59,11 @@ def register_trainable(name, trainable):
         logger.debug("Detected class for trainable.")
     elif isinstance(trainable, FunctionType):
         logger.debug("Detected function for trainable.")
-        trainable = wrap_function(trainable)
+        trainable = wrap_function(trainable, warn=warn)
     elif callable(trainable):
-        logger.warning(
+        logger.info(
             "Detected unknown callable for trainable. Converting to class.")
-        trainable = wrap_function(trainable)
+        trainable = wrap_function(trainable, warn=warn)
 
     if not issubclass(trainable, Trainable):
         raise TypeError("Second argument must be convertable to Trainable",
@@ -86,6 +87,10 @@ def register_env(name, env_creator):
     _global_registry.register(ENV_CREATOR, name, env_creator)
 
 
+def check_serializability(key, value):
+    _global_registry.register(TEST, key, value)
+
+
 def _make_key(category, key):
     """Generate a binary key for the given category and key.
 
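`check_serializability` simply routes a value through the registry's pickling path under the new `__test__` category, so it surfaces the same serialization error Tune would hit at run time. A quick hedged usage sketch (the keys are arbitrary labels):

    import threading
    from ray.tune.registry import check_serializability

    check_serializability("ok", lambda config: None)   # passes: cloudpickle handles lambdas
    check_serializability("bad", threading.Event())    # raises: thread locks can't be pickled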
@@ -105,6 +110,11 @@ class _Registry:
         self._to_flush = {}
 
     def register(self, category, key, value):
+        """Registers the value with the global registry.
+
+        Raises:
+            PicklingError if unable to pickle to provided file.
+        """
         if category not in KNOWN_CATEGORIES:
             from ray.tune import TuneError
             raise TuneError("Unknown category {} not among {}".format(
@@ -1,10 +1,12 @@
 import unittest
+import threading
 
 import ray
 from ray.rllib import _register_all
 from ray.tune import register_trainable
 from ray.tune.experiment import Experiment, convert_to_experiment_list
 from ray.tune.error import TuneError
+from ray.tune.utils import diagnose_serialization
 
 
 class ExperimentTest(unittest.TestCase):
@@ -71,6 +73,28 @@ class ExperimentTest(unittest.TestCase):
         self.assertRaises(TuneError, lambda: convert_to_experiment_list("hi"))
 
 
+class ValidateUtilTest(unittest.TestCase):
+    def testDiagnoseSerialization(self):
+
+        # this is not serializable
+        e = threading.Event()
+
+        def test():
+            print(e)
+
+        assert diagnose_serialization(test) is not True
+
+        # should help identify that 'e' should be moved into
+        # the `test` scope.
+
+        # correct implementation
+        def test():
+            e = threading.Event()
+            print(e)
+
+        assert diagnose_serialization(test) is True
+
+
 if __name__ == "__main__":
     import pytest
     import sys
@@ -1,15 +1,9 @@
 from ray.tune.utils.util import deep_update, flatten_dict, get_pinned_object, \
     merge_dicts, pin_in_object_store, unflattened_lookup, UtilMonitor, \
-    validate_save_restore, warn_if_slow
+    validate_save_restore, warn_if_slow, diagnose_serialization
 
 __all__ = [
-    "deep_update",
-    "flatten_dict",
-    "get_pinned_object",
-    "merge_dicts",
-    "pin_in_object_store",
-    "unflattened_lookup",
-    "UtilMonitor",
-    "validate_save_restore",
-    "warn_if_slow",
+    "deep_update", "flatten_dict", "get_pinned_object", "merge_dicts",
+    "pin_in_object_store", "unflattened_lookup", "UtilMonitor",
+    "validate_save_restore", "warn_if_slow", "diagnose_serialization"
 ]
@@ -1,5 +1,6 @@
 import copy
 import logging
+import inspect
 import threading
 import time
 from collections import defaultdict, deque, Mapping, Sequence
@@ -269,6 +270,92 @@ def _from_pinnable(obj):
     return obj[0]
 
 
+def diagnose_serialization(trainable):
+    """Utility for detecting accidentally-scoped objects.
+
+    Args:
+        trainable (cls | func): The trainable object passed to
+            tune.run(trainable).
+
+    Returns:
+        bool | set of unserializable objects.
+
+    Example:
+
+    .. code-block:: python
+
+        import threading
+        # this is not serializable
+        e = threading.Event()
+
+        def test():
+            print(e)
+
+        diagnose_serialization(test)
+        # should help identify that 'e' should be moved into
+        # the `test` scope.
+
+        # correct implementation
+        def test():
+            e = threading.Event()
+            print(e)
+
+        assert diagnose_serialization(test) is True
+    """
+    from ray.tune.registry import register_trainable, check_serializability
+
+    def check_variables(objects, failure_set, printer):
+        for var_name, variable in objects.items():
+            msg = None
+            try:
+                check_serializability(var_name, variable)
+                status = "PASSED"
+            except Exception as e:
+                status = "FAILED"
+                msg = f"{e.__class__.__name__}: {str(e)}"
+                failure_set.add(var_name)
+            printer(f"{str(variable)}[name='{var_name}']... {status}")
+            if msg:
+                printer(msg)
+
+    print(f"Trying to serialize {trainable}...")
+    try:
+        register_trainable("__test:" + str(trainable), trainable, warn=False)
+        print("Serialization succeeded!")
+        return True
+    except Exception as e:
+        print(f"Serialization failed: {e}")
+
+    print("Inspecting the scope of the trainable by running "
+          f"`inspect.getclosurevars({str(trainable)})`...")
+    closure = inspect.getclosurevars(trainable)
+    failure_set = set()
+    if closure.globals:
+        print(f"Detected {len(closure.globals)} global variables. "
+              "Checking serializability...")
+        check_variables(closure.globals, failure_set,
+                        lambda s: print(" " + s))
+
+    if closure.nonlocals:
+        print(f"Detected {len(closure.nonlocals)} nonlocal variables. "
+              "Checking serializability...")
+        check_variables(closure.nonlocals, failure_set,
+                        lambda s: print(" " + s))
+
+    if not failure_set:
+        print("Nothing was found to have failed the diagnostic test, though "
+              "serialization did not succeed. Feel free to raise an "
+              "issue on github.")
+        return failure_set
+    else:
+        print(f"Variable(s) {failure_set} was found to be non-serializable. "
+              "Consider either removing the instantiation/imports "
+              "of these objects or moving them into the scope of "
+              "the trainable. ")
+        return failure_set
+
+
 def validate_save_restore(trainable_cls,
                           config=None,
                           num_gpus=0,
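A short note on using the new helper: it returns `True` on success and otherwise a set of failing variable names, so callers can branch on the result. A hedged sketch (`my_trainable` stands in for any Tune trainable):

    failures = diagnose_serialization(my_trainable)
    if failures is not True:
        # e.g. {'e'} for the threading.Event example above
        print("Move these into the trainable's scope:", failures)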