[tune] cleanup error messaging/diagnose_serialization helper (#10210)

Author: Richard Liaw, 2020-08-22 11:50:49 -07:00 (committed by GitHub)
parent 24ee496b89
commit 6bd5458bef
8 changed files with 178 additions and 38 deletions

View file

@@ -197,7 +197,7 @@ Distributed Checkpointing

 On a multinode cluster, Tune automatically creates a copy of all trial checkpoints on the head node. This requires the Ray cluster to be started with the :ref:`cluster launcher <ref-automatic-cluster>` and also requires rsync to be installed.

 Note that you must use the ``tune.checkpoint_dir`` API to trigger syncing. Also, if running Tune on Kubernetes, be sure to use the :ref:`KubernetesSyncer <tune-kubernetes>` to transfer files between different pods.

 If you do not use the cluster launcher, you should set up an NFS or global file system and
 disable cross-node syncing:
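For readers following along, a minimal sketch of the setup this passage describes, assuming the `tune.run(sync_to_driver=...)` flag from this era of Tune (the exact keyword is an assumption and may differ by version):

    from ray import tune

    # Hedged sketch: with a shared NFS/global filesystem, every node sees the
    # same results directory, so Tune's cross-node rsync can be turned off.
    tune.run(
        my_trainable,                         # assumed defined elsewhere
        local_dir="/mnt/shared/ray_results",  # hypothetical shared mount
        sync_to_driver=False,                 # disable cross-node syncing
    )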
@@ -225,7 +225,7 @@ You often will want to compute a large object (e.g., training data, model weights)

     # X_id can be referenced in closures
     X_id = pin_in_object_store(np.random.random(size=100000000))

-    def f(config, reporter):
+    def f(config):
         X = get_pinned_object(X_id)
         # use X

View file

@@ -1,5 +1,6 @@
 import copy
 import logging
+from pickle import PicklingError
 import os
 from typing import Sequence
@@ -242,7 +243,24 @@ class Experiment:
             else:
                 logger.warning(
                     "No name detected on trainable. Using {}.".format(name))
-            register_trainable(name, run_object)
+            try:
+                register_trainable(name, run_object)
+            except (TypeError, PicklingError) as e:
+                msg = (
+                    f"{str(e)}. The trainable ({str(run_object)}) could not "
+                    "be serialized, which is needed for parallel execution. "
+                    "To diagnose the issue, try the following:\n\n"
+                    "\t- Run `tune.utils.diagnose_serialization(trainable)` "
+                    "to check if non-serializable variables are captured "
+                    "in scope.\n"
+                    "\t- Try reproducing the issue by calling "
+                    "`pickle.dumps(trainable)`.\n"
+                    "\t- If the error is typing-related, try removing "
+                    "the type annotations and try again.\n\n"
+                    "If you have any suggestions on how to improve "
+                    "this error message, please reach out to the "
+                    "Ray developers on github.com/ray-project/ray/issues/")
+                raise type(e)(msg) from None
             return name
         else:
             raise TuneError("Improper 'run' - not string nor trainable.")
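To see the new message in action, a minimal reproduction sketch, assuming a function trainable that closes over an unpicklable object (all names here are illustrative):

    import threading
    from ray import tune

    def make_trainable():
        e = threading.Event()  # thread primitives cannot be pickled

        def train(config):
            print(e)  # `e` is captured in the closure

        return train

    # tune.run(make_trainable()) now raises TypeError/PicklingError carrying
    # the diagnostic hints above, instead of a bare serialization traceback.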

View file

@@ -9,6 +9,7 @@ import uuid
 from six.moves import queue

+from ray.util.debug import log_once
 from ray.tune import TuneError, session
 from ray.tune.trainable import Trainable, TrainableUtil
 from ray.tune.result import (TIME_THIS_ITER_S, RESULT_DUPLICATE,
@@ -476,34 +477,37 @@ def detect_checkpoint_function(train_func, abort=False):
     return validated

-def wrap_function(train_func):
+def wrap_function(train_func, warn=True):
     if hasattr(train_func, "__mixins__"):
         inherit_from = train_func.__mixins__ + (FunctionRunner, )
     else:
         inherit_from = (FunctionRunner, )

+    func_args = inspect.getfullargspec(train_func).args
+    use_checkpoint = detect_checkpoint_function(train_func)
+    if len(func_args) > 1:  # more arguments than just the config
+        if "reporter" not in func_args and not use_checkpoint:
+            raise ValueError(
+                "Unknown argument found in the Trainable function. "
+                "Arguments other than the 'config' arg must be one "
+                "of ['reporter', 'checkpoint_dir']. Found: {}".format(
+                    func_args))
+    use_reporter = "reporter" in func_args
+    if not use_checkpoint and not use_reporter:
+        if log_once("tune_function_checkpoint") and warn:
+            logger.warning(
+                "Function checkpointing is disabled. This may result in "
+                "unexpected behavior when using checkpointing features or "
+                "certain schedulers. To enable, set the train function "
+                "arguments to be `func(config, checkpoint_dir=None)`.")

     class ImplicitFunc(*inherit_from):
         _name = train_func.__name__ if hasattr(train_func, "__name__") \
             else "func"

         def _trainable_func(self, config, reporter, checkpoint_dir):
-            func_args = inspect.getfullargspec(train_func).args
-            if len(func_args) > 1:  # more arguments than just the config
-                if "reporter" not in func_args and (
-                        not detect_checkpoint_function(train_func)):
-                    raise ValueError(
-                        "Unknown argument found in the Trainable function. "
-                        "Arguments other than the 'config' arg must be one "
-                        "of ['reporter', 'checkpoint_dir']. Found: {}".format(
-                            func_args))
-            use_reporter = "reporter" in func_args
-            use_checkpoint = detect_checkpoint_function(train_func)
             if not use_checkpoint and not use_reporter:
-                logger.warning(
-                    "Function checkpointing is disabled. This may result in "
-                    "unexpected behavior when using checkpointing features or "
-                    "certain schedulers. To enable, set the train function "
-                    "arguments to be `func(config, checkpoint_dir=None)`.")
                 output = train_func(config)
             elif use_checkpoint:
                 output = train_func(config, checkpoint_dir=checkpoint_dir)
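Note that the signature validation now runs once in `wrap_function` rather than on every `_trainable_func` call, so a bad signature fails fast at registration. A sketch of what the validator accepts and rejects, based on the message strings in this hunk:

    # Accepted: config-only function (checkpointing disabled; warned once).
    def train_simple(config):
        pass

    # Accepted: checkpoint-aware signature, which silences the warning.
    def train_ckpt(config, checkpoint_dir=None):
        pass

    # Rejected at wrap time with ValueError: extra args must be one of
    # ['reporter', 'checkpoint_dir'].
    def train_bad(config, extra_arg):
        pass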

View file

@@ -196,7 +196,9 @@ class TBXLogger(Logger):
         try:
             from tensorboardX import SummaryWriter
         except ImportError:
-            logger.error("pip install 'ray[tune]' to see TensorBoard files.")
+            if log_once("tbx-install"):
+                logger.info(
+                    "pip install 'ray[tune]' to see TensorBoard files.")
             raise
         self._file_writer = SummaryWriter(self.logdir, flush_secs=30)
         self.last_result = None
@@ -329,8 +331,9 @@ class UnifiedLogger(Logger):
             try:
                 self._loggers.append(cls(self.config, self.logdir, self.trial))
             except Exception as exc:
-                logger.warning("Could not instantiate %s: %s.", cls.__name__,
-                               str(exc))
+                if log_once(f"instantiate:{cls.__name__}"):
+                    logger.warning("Could not instantiate %s: %s.",
+                                   cls.__name__, str(exc))
         self._log_syncer = get_node_syncer(
             self.logdir,
             remote_dir=self.logdir,
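`log_once` deduplicates by key, so a message that would otherwise fire once per trial is emitted once per process. A small standalone illustration of the pattern (not the logger code itself):

    from ray.util.debug import log_once

    for _ in range(3):
        if log_once("tbx-install"):  # True only the first time this key is seen
            print("pip install 'ray[tune]' to see TensorBoard files.")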

View file

@@ -12,9 +12,10 @@ ENV_CREATOR = "env_creator"
 RLLIB_MODEL = "rllib_model"
 RLLIB_PREPROCESSOR = "rllib_preprocessor"
 RLLIB_ACTION_DIST = "rllib_action_dist"
+TEST = "__test__"

 KNOWN_CATEGORIES = [
     TRAINABLE_CLASS, ENV_CREATOR, RLLIB_MODEL, RLLIB_PREPROCESSOR,
-    RLLIB_ACTION_DIST
+    RLLIB_ACTION_DIST, TEST
 ]

 logger = logging.getLogger(__name__)
@@ -38,7 +39,7 @@ def validate_trainable(trainable_name):
         raise TuneError("Unknown trainable: " + trainable_name)

-def register_trainable(name, trainable):
+def register_trainable(name, trainable, warn=True):
     """Register a trainable function or class.

     This enables a class or function to be accessed on every Ray process
@@ -58,11 +59,11 @@ def register_trainable(name, trainable):
         logger.debug("Detected class for trainable.")
     elif isinstance(trainable, FunctionType):
         logger.debug("Detected function for trainable.")
-        trainable = wrap_function(trainable)
+        trainable = wrap_function(trainable, warn=warn)
     elif callable(trainable):
-        logger.warning(
+        logger.info(
            "Detected unknown callable for trainable. Converting to class.")
-        trainable = wrap_function(trainable)
+        trainable = wrap_function(trainable, warn=warn)

     if not issubclass(trainable, Trainable):
         raise TypeError("Second argument must be convertable to Trainable",
@@ -86,6 +87,10 @@ def register_env(name, env_creator):
     _global_registry.register(ENV_CREATOR, name, env_creator)

+def check_serializability(key, value):
+    _global_registry.register(TEST, key, value)

 def _make_key(category, key):
     """Generate a binary key for the given category and key.
@@ -105,6 +110,11 @@ class _Registry:
         self._to_flush = {}

     def register(self, category, key, value):
+        """Registers the value with the global registry.
+
+        Raises:
+            PicklingError if unable to pickle to provided file.
+        """
         if category not in KNOWN_CATEGORIES:
             from ray.tune import TuneError
             raise TuneError("Unknown category {} not among {}".format(

View file

@@ -1,10 +1,12 @@
 import unittest
+import threading

 import ray
 from ray.rllib import _register_all

 from ray.tune import register_trainable
 from ray.tune.experiment import Experiment, convert_to_experiment_list
 from ray.tune.error import TuneError
+from ray.tune.utils import diagnose_serialization


 class ExperimentTest(unittest.TestCase):
@@ -71,6 +73,28 @@ class ExperimentTest(unittest.TestCase):
         self.assertRaises(TuneError, lambda: convert_to_experiment_list("hi"))


+class ValidateUtilTest(unittest.TestCase):
+    def testDiagnoseSerialization(self):
+        # this is not serializable
+        e = threading.Event()
+
+        def test():
+            print(e)
+
+        assert diagnose_serialization(test) is not True
+        # should help identify that 'e' should be moved into
+        # the `test` scope.
+
+        # correct implementation
+        def test():
+            e = threading.Event()
+            print(e)
+
+        assert diagnose_serialization(test) is True


 if __name__ == "__main__":
     import pytest
     import sys

View file

@@ -1,15 +1,9 @@
 from ray.tune.utils.util import deep_update, flatten_dict, get_pinned_object, \
     merge_dicts, pin_in_object_store, unflattened_lookup, UtilMonitor, \
-    validate_save_restore, warn_if_slow
+    validate_save_restore, warn_if_slow, diagnose_serialization

 __all__ = [
-    "deep_update",
-    "flatten_dict",
-    "get_pinned_object",
-    "merge_dicts",
-    "pin_in_object_store",
-    "unflattened_lookup",
-    "UtilMonitor",
-    "validate_save_restore",
-    "warn_if_slow",
+    "deep_update", "flatten_dict", "get_pinned_object", "merge_dicts",
+    "pin_in_object_store", "unflattened_lookup", "UtilMonitor",
+    "validate_save_restore", "warn_if_slow", "diagnose_serialization"
 ]

View file

@@ -1,5 +1,6 @@
 import copy
 import logging
+import inspect
 import threading
 import time
 from collections import defaultdict, deque, Mapping, Sequence
@@ -269,6 +270,92 @@ def _from_pinnable(obj):
     return obj[0]


+def diagnose_serialization(trainable):
+    """Utility for detecting accidentally-scoped objects.
+
+    Args:
+        trainable (cls | func): The trainable object passed to
+            tune.run(trainable).
+
+    Returns:
+        bool | set of unserializable objects.
+
+    Example:
+
+    .. code-block:: python
+
+        import threading
+
+        # this is not serializable
+        e = threading.Event()
+
+        def test():
+            print(e)
+
+        diagnose_serialization(test)
+        # should help identify that 'e' should be moved into
+        # the `test` scope.
+
+        # correct implementation
+        def test():
+            e = threading.Event()
+            print(e)
+
+        assert diagnose_serialization(test) is True
+    """
+    from ray.tune.registry import register_trainable, check_serializability
+
+    def check_variables(objects, failure_set, printer):
+        for var_name, variable in objects.items():
+            msg = None
+            try:
+                check_serializability(var_name, variable)
+                status = "PASSED"
+            except Exception as e:
+                status = "FAILED"
+                msg = f"{e.__class__.__name__}: {str(e)}"
+                failure_set.add(var_name)
+            printer(f"{str(variable)}[name='{var_name}']... {status}")
+            if msg:
+                printer(msg)
+
+    print(f"Trying to serialize {trainable}...")
+    try:
+        register_trainable("__test:" + str(trainable), trainable, warn=False)
+        print("Serialization succeeded!")
+        return True
+    except Exception as e:
+        print(f"Serialization failed: {e}")
+
+    print("Inspecting the scope of the trainable by running "
+          f"`inspect.getclosurevars({str(trainable)})`...")
+    closure = inspect.getclosurevars(trainable)
+    failure_set = set()
+    if closure.globals:
+        print(f"Detected {len(closure.globals)} global variables. "
+              "Checking serializability...")
+        check_variables(closure.globals, failure_set,
+                        lambda s: print("   " + s))
+
+    if closure.nonlocals:
+        print(f"Detected {len(closure.nonlocals)} nonlocal variables. "
+              "Checking serializability...")
+        check_variables(closure.nonlocals, failure_set,
+                        lambda s: print("   " + s))
+
+    if not failure_set:
+        print("Nothing was found to have failed the diagnostic test, though "
+              "serialization did not succeed. Feel free to raise an "
+              "issue on github.")
+        return failure_set
+    else:
+        print(f"Variable(s) {failure_set} were found to be non-serializable. "
+              "Consider either removing the instantiation/imports "
+              "of these objects or moving them into the scope of "
+              "the trainable.")
+        return failure_set


 def validate_save_restore(trainable_cls,
                           config=None,
                           num_gpus=0,
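Putting the new helper together, a usage sketch (console output paraphrased in comments, not verbatim):

    import threading
    from ray.tune.utils import diagnose_serialization

    e = threading.Event()

    def train(config):
        print(e)

    failures = diagnose_serialization(train)
    # Prints a PASSED/FAILED line for each variable captured by `train` and
    # returns the failing names, e.g. {'e'}; moving `e` inside `train`
    # makes it return True.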