[tune] cleanup error messaging/diagnose_serialization helper (#10210)

Richard Liaw 2020-08-22 11:50:49 -07:00 committed by GitHub
parent 24ee496b89
commit 6bd5458bef
8 changed files with 178 additions and 38 deletions


@@ -197,7 +197,7 @@ Distributed Checkpointing
On a multinode cluster, Tune automatically creates a copy of all trial checkpoints on the head node. This requires the Ray cluster to be started with the :ref:`cluster launcher <ref-automatic-cluster>` and also requires rsync to be installed.
Note that you must use the ``tune.checkpoint_dir`` API to trigger syncing. Also, if running Tune on Kubernetes, be sure to use the :ref:`KubernetesSyncer <tune-kubernetes>` to transfer files between different pods.
If you do not use the cluster launcher, you should set up an NFS or global file system and
disable cross-node syncing:
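The code block that originally followed this sentence is elided by the diff context; below is a minimal sketch of disabling syncing, assuming the era's `sync_to_driver` flag on `tune.run` (the flag, trainable, and mount path are illustrative assumptions, not taken from this commit):

from ray import tune

# Assumes /mnt/shared is an NFS mount visible at the same path on every node.
# sync_to_driver=False turns off Tune's rsync-based cross-node syncing.
tune.run(
    my_trainable,  # hypothetical trainable defined elsewhere
    local_dir="/mnt/shared/ray_results",
    sync_to_driver=False,
)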
@@ -225,7 +225,7 @@ You often will want to compute a large object (e.g., training data, model weight
# X_id can be referenced in closures
X_id = pin_in_object_store(np.random.random(size=100000000))
def f(config, reporter):
def f(config):
X = get_pinned_object(X_id)
# use X
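For context, a self-contained sketch of the pinning pattern under the new single-argument signature (the array size and the printed statistic are illustrative):

import numpy as np
import ray
from ray.tune.utils import pin_in_object_store, get_pinned_object

ray.init()

# Pin the large array once; X_id is a small handle that closures can
# capture safely without copying the array into every task.
X_id = pin_in_object_store(np.random.random(size=100000))

def f(config):
    # Fetch the pinned array inside the trainable instead of capturing it.
    X = get_pinned_object(X_id)
    print(X.mean())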


@@ -1,5 +1,6 @@
import copy
import logging
from pickle import PicklingError
import os
from typing import Sequence
@@ -242,7 +243,24 @@ class Experiment:
else:
logger.warning(
"No name detected on trainable. Using {}.".format(name))
register_trainable(name, run_object)
try:
register_trainable(name, run_object)
except (TypeError, PicklingError) as e:
msg = (
f"{str(e)}. The trainable ({str(run_object)}) could not "
"be serialized, which is needed for parallel execution. "
"To diagnose the issue, try the following:\n\n"
"\t- Run `tune.utils.diagnose_serialization(trainable)` "
"to check if non-serializable variables are captured "
"in scope.\n"
"\t- Try reproducing the issue by calling "
"`pickle.dumps(trainable)`.\n"
"\t- If the error is typing-related, try removing "
"the type annotations and try again.\n\n"
"If you have any suggestions on how to improve "
"this error message, please reach out to the "
"Ray developers on github.com/ray-project/ray/issues/")
raise type(e)(msg) from None
return name
else:
raise TuneError("Improper 'run' - not string nor trainable.")
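To illustrate the failure this new message targets, here is a minimal sketch (names are illustrative) of a trainable that closes over an unpicklable object:

import threading
from ray import tune

lock = threading.Lock()  # thread locks cannot be pickled

def train(config):
    with lock:  # captured from module scope, so `train` fails to serialize
        tune.report(score=0)

# tune.run(train) now raises with the hint above instead of a bare
# PicklingError, pointing the user at tune.utils.diagnose_serialization(train).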


@@ -9,6 +9,7 @@ import uuid
from six.moves import queue
from ray.util.debug import log_once
from ray.tune import TuneError, session
from ray.tune.trainable import Trainable, TrainableUtil
from ray.tune.result import (TIME_THIS_ITER_S, RESULT_DUPLICATE,
@@ -476,34 +477,37 @@ def detect_checkpoint_function(train_func, abort=False):
return validated
def wrap_function(train_func):
def wrap_function(train_func, warn=True):
if hasattr(train_func, "__mixins__"):
inherit_from = train_func.__mixins__ + (FunctionRunner, )
else:
inherit_from = (FunctionRunner, )
func_args = inspect.getfullargspec(train_func).args
use_checkpoint = detect_checkpoint_function(train_func)
if len(func_args) > 1: # more arguments than just the config
if "reporter" not in func_args and not use_checkpoint:
raise ValueError(
"Unknown argument found in the Trainable function. "
"Arguments other than the 'config' arg must be one "
"of ['reporter', 'checkpoint_dir']. Found: {}".format(
func_args))
use_reporter = "reporter" in func_args
if not use_checkpoint and not use_reporter:
if log_once("tune_function_checkpoint") and warn:
logger.warning(
"Function checkpointing is disabled. This may result in "
"unexpected behavior when using checkpointing features or "
"certain schedulers. To enable, set the train function "
"arguments to be `func(config, checkpoint_dir=None)`.")
class ImplicitFunc(*inherit_from):
_name = train_func.__name__ if hasattr(train_func, "__name__") \
else "func"
def _trainable_func(self, config, reporter, checkpoint_dir):
func_args = inspect.getfullargspec(train_func).args
if len(func_args) > 1: # more arguments than just the config
if "reporter" not in func_args and (
not detect_checkpoint_function(train_func)):
raise ValueError(
"Unknown argument found in the Trainable function. "
"Arguments other than the 'config' arg must be one "
"of ['reporter', 'checkpoint_dir']. Found: {}".format(
func_args))
use_reporter = "reporter" in func_args
use_checkpoint = detect_checkpoint_function(train_func)
if not use_checkpoint and not use_reporter:
logger.warning(
"Function checkpointing is disabled. This may result in "
"unexpected behavior when using checkpointing features or "
"certain schedulers. To enable, set the train function "
"arguments to be `func(config, checkpoint_dir=None)`.")
output = train_func(config)
elif use_checkpoint:
output = train_func(config, checkpoint_dir=checkpoint_dir)
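For reference, a minimal function trainable that satisfies the `func(config, checkpoint_dir=None)` contract the warning asks for (the file name and metric are illustrative):

import os
from ray import tune

def train(config, checkpoint_dir=None):
    start = 0
    if checkpoint_dir:  # resume from the last saved step, if any
        with open(os.path.join(checkpoint_dir, "step.txt")) as f:
            start = int(f.read())
    for step in range(start, 5):
        # tune.checkpoint_dir is the API the docs above require for syncing.
        with tune.checkpoint_dir(step=step) as d:
            with open(os.path.join(d, "step.txt"), "w") as f:
                f.write(str(step))
        tune.report(step=step)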


@@ -196,7 +196,9 @@ class TBXLogger(Logger):
try:
from tensorboardX import SummaryWriter
except ImportError:
logger.error("pip install 'ray[tune]' to see TensorBoard files.")
if log_once("tbx-install"):
logger.info(
"pip install 'ray[tune]' to see TensorBoard files.")
raise
self._file_writer = SummaryWriter(self.logdir, flush_secs=30)
self.last_result = None
@@ -329,8 +331,9 @@ class UnifiedLogger(Logger):
try:
self._loggers.append(cls(self.config, self.logdir, self.trial))
except Exception as exc:
logger.warning("Could not instantiate %s: %s.", cls.__name__,
str(exc))
if log_once(f"instantiate:{cls.__name__}"):
logger.warning("Could not instantiate %s: %s.",
cls.__name__, str(exc))
self._log_syncer = get_node_syncer(
self.logdir,
remote_dir=self.logdir,
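Both logger changes lean on `log_once`, which returns True only the first time a given key is seen in the process; a small sketch of the de-duplication behavior:

import logging
from ray.util.debug import log_once

logger = logging.getLogger(__name__)

for _ in range(3):
    if log_once("tbx-install"):  # True on the first call with this key only
        logger.info("pip install 'ray[tune]' to see TensorBoard files.")
# Exactly one log line is emitted across the three iterations.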


@@ -12,9 +12,10 @@ ENV_CREATOR = "env_creator"
RLLIB_MODEL = "rllib_model"
RLLIB_PREPROCESSOR = "rllib_preprocessor"
RLLIB_ACTION_DIST = "rllib_action_dist"
TEST = "__test__"
KNOWN_CATEGORIES = [
TRAINABLE_CLASS, ENV_CREATOR, RLLIB_MODEL, RLLIB_PREPROCESSOR,
RLLIB_ACTION_DIST
RLLIB_ACTION_DIST, TEST
]
logger = logging.getLogger(__name__)
@@ -38,7 +39,7 @@ def validate_trainable(trainable_name):
raise TuneError("Unknown trainable: " + trainable_name)
def register_trainable(name, trainable):
def register_trainable(name, trainable, warn=True):
"""Register a trainable function or class.
This enables a class or function to be accessed on every Ray process
@@ -58,11 +59,11 @@ def register_trainable(name, trainable):
logger.debug("Detected class for trainable.")
elif isinstance(trainable, FunctionType):
logger.debug("Detected function for trainable.")
trainable = wrap_function(trainable)
trainable = wrap_function(trainable, warn=warn)
elif callable(trainable):
logger.warning(
logger.info(
"Detected unknown callable for trainable. Converting to class.")
trainable = wrap_function(trainable)
trainable = wrap_function(trainable, warn=warn)
if not issubclass(trainable, Trainable):
raise TypeError("Second argument must be convertible to Trainable",
@@ -86,6 +87,10 @@ def register_env(name, env_creator):
_global_registry.register(ENV_CREATOR, name, env_creator)
def check_serializability(key, value):
_global_registry.register(TEST, key, value)
def _make_key(category, key):
"""Generate a binary key for the given category and key.
@@ -105,6 +110,11 @@ class _Registry:
self._to_flush = {}
def register(self, category, key, value):
"""Registers the value with the global registry.
Raises:
PicklingError: if the provided value cannot be pickled.
"""
if category not in KNOWN_CATEGORIES:
from ray.tune import TuneError
raise TuneError("Unknown category {} not among {}".format(


@@ -1,10 +1,12 @@
import unittest
import threading
import ray
from ray.rllib import _register_all
from ray.tune import register_trainable
from ray.tune.experiment import Experiment, convert_to_experiment_list
from ray.tune.error import TuneError
from ray.tune.utils import diagnose_serialization
class ExperimentTest(unittest.TestCase):
@@ -71,6 +73,28 @@ class ExperimentTest(unittest.TestCase):
self.assertRaises(TuneError, lambda: convert_to_experiment_list("hi"))
class ValidateUtilTest(unittest.TestCase):
def testDiagnoseSerialization(self):
# this is not serializable
e = threading.Event()
def test():
print(e)
assert diagnose_serialization(test) is not True
# should help identify that 'e' should be moved into
# the `test` scope.
# correct implementation
def test():
e = threading.Event()
print(e)
assert diagnose_serialization(test) is True
if __name__ == "__main__":
import pytest
import sys


@@ -1,15 +1,9 @@
from ray.tune.utils.util import deep_update, flatten_dict, get_pinned_object, \
merge_dicts, pin_in_object_store, unflattened_lookup, UtilMonitor, \
validate_save_restore, warn_if_slow
validate_save_restore, warn_if_slow, diagnose_serialization
__all__ = [
"deep_update",
"flatten_dict",
"get_pinned_object",
"merge_dicts",
"pin_in_object_store",
"unflattened_lookup",
"UtilMonitor",
"validate_save_restore",
"warn_if_slow",
"deep_update", "flatten_dict", "get_pinned_object", "merge_dicts",
"pin_in_object_store", "unflattened_lookup", "UtilMonitor",
"validate_save_restore", "warn_if_slow", "diagnose_serialization"
]


@@ -1,5 +1,6 @@
import copy
import logging
import inspect
import threading
import time
from collections import defaultdict, deque, Mapping, Sequence
@@ -269,6 +270,92 @@ def _from_pinnable(obj):
return obj[0]
def diagnose_serialization(trainable):
"""Utility for detecting accidentally-scoped objects.
Args:
trainable (cls | func): The trainable object passed to
tune.run(trainable).
Returns:
bool | set of unserializable objects.
Example:
.. code-block:: python
import threading
# this is not serializable
e = threading.Event()
def test():
print(e)
diagnose_serialization(test)
# should help identify that 'e' should be moved into
# the `test` scope.
# correct implementation
def test():
e = threading.Event()
print(e)
assert diagnose_serialization(test) is True
"""
from ray.tune.registry import register_trainable, check_serializability
def check_variables(objects, failure_set, printer):
for var_name, variable in objects.items():
msg = None
try:
check_serializability(var_name, variable)
status = "PASSED"
except Exception as e:
status = "FAILED"
msg = f"{e.__class__.__name__}: {str(e)}"
failure_set.add(var_name)
printer(f"{str(variable)}[name='{var_name}']... {status}")
if msg:
printer(msg)
print(f"Trying to serialize {trainable}...")
try:
register_trainable("__test:" + str(trainable), trainable, warn=False)
print("Serialization succeeded!")
return True
except Exception as e:
print(f"Serialization failed: {e}")
print("Inspecting the scope of the trainable by running "
f"`inspect.getclosurevars({str(trainable)})`...")
closure = inspect.getclosurevars(trainable)
failure_set = set()
if closure.globals:
print(f"Detected {len(closure.globals)} global variables. "
"Checking serializability...")
check_variables(closure.globals, failure_set,
lambda s: print(" " + s))
if closure.nonlocals:
print(f"Detected {len(closure.nonlocals)} nonlocal variables. "
"Checking serializability...")
check_variables(closure.nonlocals, failure_set,
lambda s: print(" " + s))
if not failure_set:
print("Nothing was found to have failed the diagnostic test, though "
"serialization did not succeed. Feel free to raise an "
"issue on github.")
return failure_set
else:
print(f"Variable(s) {failure_set} was found to be non-serializable. "
"Consider either removing the instantiation/imports "
"of these objects or moving them into the scope of "
"the trainable. ")
return failure_set
def validate_save_restore(trainable_cls,
config=None,
num_gpus=0,