diff --git a/doc/source/tune/user-guide.rst b/doc/source/tune/user-guide.rst
index 33712aa0d..d2c8ffb3e 100644
--- a/doc/source/tune/user-guide.rst
+++ b/doc/source/tune/user-guide.rst
@@ -197,7 +197,7 @@ Distributed Checkpointing
 
 On a multinode cluster, Tune automatically creates a copy of all trial checkpoints on the head node. This requires the Ray cluster to be started with the :ref:`cluster launcher ` and also requires rsync to be installed.
 
-Note that you must use the ``tune.checkpoint_dir`` API to trigger syncing. Also, if running Tune on Kubernetes, be sure to use the :ref:`KubernetesSyncer ` to transfer files between different pods.
+Note that you must use the ``tune.checkpoint_dir`` API to trigger syncing. Also, if running Tune on Kubernetes, be sure to use the :ref:`KubernetesSyncer ` to transfer files between different pods. If you do not use the cluster launcher, you should set up an NFS or global file system and disable cross-node syncing:
 
 
@@ -225,7 +225,7 @@ You often will want to compute a large object (e.g., training data, model weights) on the driver and use that object within each trial.
 
     # X_id can be referenced in closures
    X_id = pin_in_object_store(np.random.random(size=100000000))
 
-    def f(config, reporter):
+    def f(config):
        X = get_pinned_object(X_id)
        # use X
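For context, the ``tune.checkpoint_dir`` API the doc change references is the function-API checkpoint hook. A minimal sketch of a trainable that triggers syncing (the function name, metric name, and step count here are illustrative, not part of the patch):

    import os
    from ray import tune

    def train_func(config, checkpoint_dir=None):
        start = 0
        # Restore from the last checkpoint, if one was passed in.
        if checkpoint_dir:
            with open(os.path.join(checkpoint_dir, "checkpoint")) as f:
                start = int(f.read())
        for step in range(start, 10):
            # Writing through tune.checkpoint_dir is what marks the
            # directory as a checkpoint to be synced to the head node.
            with tune.checkpoint_dir(step=step) as ckpt_dir:
                with open(os.path.join(ckpt_dir, "checkpoint"), "w") as f:
                    f.write(str(step))
            tune.report(mean_loss=1.0 / (step + 1))
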
" + "To diagnose the issue, try the following:\n\n" + "\t- Run `tune.utils.diagnose_serialization(trainable)` " + "to check if non-serializable variables are captured " + "in scope.\n" + "\t- Try reproducing the issue by calling " + "`pickle.dumps(trainable)`.\n" + "\t- If the error is typing-related, try removing " + "the type annotations and try again.\n\n" + "If you have any suggestions on how to improve " + "this error message, please reach out to the " + "Ray developers on github.com/ray-project/ray/issues/") + raise type(e)(msg) from None return name else: raise TuneError("Improper 'run' - not string nor trainable.") diff --git a/python/ray/tune/function_runner.py b/python/ray/tune/function_runner.py index 2fd1f9fa1..1c8b5f49c 100644 --- a/python/ray/tune/function_runner.py +++ b/python/ray/tune/function_runner.py @@ -9,6 +9,7 @@ import uuid from six.moves import queue +from ray.util.debug import log_once from ray.tune import TuneError, session from ray.tune.trainable import Trainable, TrainableUtil from ray.tune.result import (TIME_THIS_ITER_S, RESULT_DUPLICATE, @@ -476,34 +477,37 @@ def detect_checkpoint_function(train_func, abort=False): return validated -def wrap_function(train_func): +def wrap_function(train_func, warn=True): if hasattr(train_func, "__mixins__"): inherit_from = train_func.__mixins__ + (FunctionRunner, ) else: inherit_from = (FunctionRunner, ) + func_args = inspect.getfullargspec(train_func).args + use_checkpoint = detect_checkpoint_function(train_func) + if len(func_args) > 1: # more arguments than just the config + if "reporter" not in func_args and not use_checkpoint: + raise ValueError( + "Unknown argument found in the Trainable function. " + "Arguments other than the 'config' arg must be one " + "of ['reporter', 'checkpoint_dir']. Found: {}".format( + func_args)) + + use_reporter = "reporter" in func_args + if not use_checkpoint and not use_reporter: + if log_once("tune_function_checkpoint") and warn: + logger.warning( + "Function checkpointing is disabled. This may result in " + "unexpected behavior when using checkpointing features or " + "certain schedulers. To enable, set the train function " + "arguments to be `func(config, checkpoint_dir=None)`.") + class ImplicitFunc(*inherit_from): _name = train_func.__name__ if hasattr(train_func, "__name__") \ else "func" def _trainable_func(self, config, reporter, checkpoint_dir): - func_args = inspect.getfullargspec(train_func).args - if len(func_args) > 1: # more arguments than just the config - if "reporter" not in func_args and ( - not detect_checkpoint_function(train_func)): - raise ValueError( - "Unknown argument found in the Trainable function. " - "Arguments other than the 'config' arg must be one " - "of ['reporter', 'checkpoint_dir']. Found: {}".format( - func_args)) - use_reporter = "reporter" in func_args - use_checkpoint = detect_checkpoint_function(train_func) if not use_checkpoint and not use_reporter: - logger.warning( - "Function checkpointing is disabled. This may result in " - "unexpected behavior when using checkpointing features or " - "certain schedulers. 
diff --git a/python/ray/tune/function_runner.py b/python/ray/tune/function_runner.py
index 2fd1f9fa1..1c8b5f49c 100644
--- a/python/ray/tune/function_runner.py
+++ b/python/ray/tune/function_runner.py
@@ -9,6 +9,7 @@ import uuid
 
 from six.moves import queue
 
+from ray.util.debug import log_once
 from ray.tune import TuneError, session
 from ray.tune.trainable import Trainable, TrainableUtil
 from ray.tune.result import (TIME_THIS_ITER_S, RESULT_DUPLICATE,
@@ -476,34 +477,37 @@ def detect_checkpoint_function(train_func, abort=False):
     return validated
 
 
-def wrap_function(train_func):
+def wrap_function(train_func, warn=True):
     if hasattr(train_func, "__mixins__"):
         inherit_from = train_func.__mixins__ + (FunctionRunner, )
     else:
         inherit_from = (FunctionRunner, )
 
+    func_args = inspect.getfullargspec(train_func).args
+    use_checkpoint = detect_checkpoint_function(train_func)
+    if len(func_args) > 1:  # more arguments than just the config
+        if "reporter" not in func_args and not use_checkpoint:
+            raise ValueError(
+                "Unknown argument found in the Trainable function. "
+                "Arguments other than the 'config' arg must be one "
+                "of ['reporter', 'checkpoint_dir']. Found: {}".format(
+                    func_args))
+
+    use_reporter = "reporter" in func_args
+    if not use_checkpoint and not use_reporter:
+        if warn and log_once("tune_function_checkpoint"):
+            logger.warning(
+                "Function checkpointing is disabled. This may result in "
+                "unexpected behavior when using checkpointing features or "
+                "certain schedulers. To enable, set the train function "
+                "arguments to be `func(config, checkpoint_dir=None)`.")
+
     class ImplicitFunc(*inherit_from):
         _name = train_func.__name__ if hasattr(train_func, "__name__") \
             else "func"
 
         def _trainable_func(self, config, reporter, checkpoint_dir):
-            func_args = inspect.getfullargspec(train_func).args
-            if len(func_args) > 1:  # more arguments than just the config
-                if "reporter" not in func_args and (
-                        not detect_checkpoint_function(train_func)):
-                    raise ValueError(
-                        "Unknown argument found in the Trainable function. "
-                        "Arguments other than the 'config' arg must be one "
-                        "of ['reporter', 'checkpoint_dir']. Found: {}".format(
-                            func_args))
-            use_reporter = "reporter" in func_args
-            use_checkpoint = detect_checkpoint_function(train_func)
             if not use_checkpoint and not use_reporter:
-                logger.warning(
-                    "Function checkpointing is disabled. This may result in "
-                    "unexpected behavior when using checkpointing features or "
-                    "certain schedulers. To enable, set the train function "
-                    "arguments to be `func(config, checkpoint_dir=None)`.")
                 output = train_func(config)
             elif use_checkpoint:
                 output = train_func(config, checkpoint_dir=checkpoint_dir)
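Hoisting the argspec check out of `_trainable_func` means an invalid signature now fails once, at wrap time. The three accepted shapes (a sketch; the bodies are placeholders):

    from ray import tune

    def train_plain(config):  # config only: checkpointing disabled, warns once
        tune.report(score=config["x"])

    def train_ckpt(config, checkpoint_dir=None):  # function checkpointing on
        tune.report(score=config["x"])

    def train_reporter(config, reporter):  # legacy reporter API
        reporter(score=config["x"])

    # Any other extra argument, e.g. def bad(config, extra): ...,
    # raises ValueError when the function is wrapped, not mid-training.
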
diff --git a/python/ray/tune/logger.py b/python/ray/tune/logger.py
index a1edf9cbd..945c25537 100644
--- a/python/ray/tune/logger.py
+++ b/python/ray/tune/logger.py
@@ -196,7 +196,9 @@
         try:
             from tensorboardX import SummaryWriter
         except ImportError:
-            logger.error("pip install 'ray[tune]' to see TensorBoard files.")
+            if log_once("tbx-install"):
+                logger.info(
+                    "pip install 'ray[tune]' to see TensorBoard files.")
             raise
         self._file_writer = SummaryWriter(self.logdir, flush_secs=30)
         self.last_result = None
@@ -329,8 +331,9 @@
             try:
                 self._loggers.append(cls(self.config, self.logdir, self.trial))
             except Exception as exc:
-                logger.warning("Could not instantiate %s: %s.", cls.__name__,
-                               str(exc))
+                if log_once(f"instantiate:{cls.__name__}"):
+                    logger.warning("Could not instantiate %s: %s.",
+                                   cls.__name__, str(exc))
         self._log_syncer = get_node_syncer(
             self.logdir,
             remote_dir=self.logdir,
diff --git a/python/ray/tune/registry.py b/python/ray/tune/registry.py
index 45fd2a6b9..1cb78f43e 100644
--- a/python/ray/tune/registry.py
+++ b/python/ray/tune/registry.py
@@ -12,9 +12,10 @@ ENV_CREATOR = "env_creator"
 RLLIB_MODEL = "rllib_model"
 RLLIB_PREPROCESSOR = "rllib_preprocessor"
 RLLIB_ACTION_DIST = "rllib_action_dist"
+TEST = "__test__"
 KNOWN_CATEGORIES = [
     TRAINABLE_CLASS, ENV_CREATOR, RLLIB_MODEL, RLLIB_PREPROCESSOR,
-    RLLIB_ACTION_DIST
+    RLLIB_ACTION_DIST, TEST
 ]
 
 logger = logging.getLogger(__name__)
@@ -38,7 +39,7 @@ def validate_trainable(trainable_name):
         raise TuneError("Unknown trainable: " + trainable_name)
 
 
-def register_trainable(name, trainable):
+def register_trainable(name, trainable, warn=True):
     """Register a trainable function or class.
 
     This enables a class or function to be accessed on every Ray process
@@ -58,11 +59,11 @@
         logger.debug("Detected class for trainable.")
     elif isinstance(trainable, FunctionType):
         logger.debug("Detected function for trainable.")
-        trainable = wrap_function(trainable)
+        trainable = wrap_function(trainable, warn=warn)
     elif callable(trainable):
-        logger.warning(
+        logger.info(
             "Detected unknown callable for trainable. Converting to class.")
-        trainable = wrap_function(trainable)
+        trainable = wrap_function(trainable, warn=warn)
 
     if not issubclass(trainable, Trainable):
         raise TypeError("Second argument must be convertable to Trainable",
@@ -86,6 +87,10 @@ def register_env(name, env_creator):
     _global_registry.register(ENV_CREATOR, name, env_creator)
 
 
+def check_serializability(key, value):
+    _global_registry.register(TEST, key, value)
+
+
 def _make_key(category, key):
     """Generate a binary key for the given category and key.
 
@@ -105,6 +110,11 @@
         self._to_flush = {}
 
     def register(self, category, key, value):
+        """Registers the value with the global registry.
+
+        Raises:
+            PicklingError: If the value cannot be pickled.
+        """
         if category not in KNOWN_CATEGORIES:
             from ray.tune import TuneError
             raise TuneError("Unknown category {} not among {}".format(
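The new `check_serializability` helper simply routes a value through the registry's pickling path under the new `__test__` category, so it can be exercised standalone. A sketch of what `diagnose_serialization` does with it for each captured variable:

    import threading
    from ray.tune.registry import check_serializability

    check_serializability("lr", 0.01)  # picklable: returns without error

    try:
        check_serializability("event", threading.Event())
    except Exception as err:
        # e.g. TypeError: cannot pickle '_thread.lock' object
        print(f"'event' is not serializable: {err}")
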
diff --git a/python/ray/tune/tests/test_experiment.py b/python/ray/tune/tests/test_experiment.py
index de5fbc5ca..a154231a3 100644
--- a/python/ray/tune/tests/test_experiment.py
+++ b/python/ray/tune/tests/test_experiment.py
@@ -1,10 +1,12 @@
 import unittest
+import threading
 
 import ray
 from ray.rllib import _register_all
 
 from ray.tune import register_trainable
 from ray.tune.experiment import Experiment, convert_to_experiment_list
 from ray.tune.error import TuneError
+from ray.tune.utils import diagnose_serialization
 
 
 class ExperimentTest(unittest.TestCase):
@@ -71,6 +73,28 @@
         self.assertRaises(TuneError, lambda: convert_to_experiment_list("hi"))
 
 
+class ValidateUtilTest(unittest.TestCase):
+    def testDiagnoseSerialization(self):
+
+        # this is not serializable
+        e = threading.Event()
+
+        def test():
+            print(e)
+
+        assert diagnose_serialization(test) is not True
+
+        # should help identify that 'e' should be moved into
+        # the `test` scope.
+
+        # correct implementation
+        def test():
+            e = threading.Event()
+            print(e)
+
+        assert diagnose_serialization(test) is True
+
+
 if __name__ == "__main__":
     import pytest
     import sys
diff --git a/python/ray/tune/utils/__init__.py b/python/ray/tune/utils/__init__.py
index 0eed502b8..6cecf0f34 100644
--- a/python/ray/tune/utils/__init__.py
+++ b/python/ray/tune/utils/__init__.py
@@ -1,15 +1,9 @@
 from ray.tune.utils.util import deep_update, flatten_dict, get_pinned_object, \
     merge_dicts, pin_in_object_store, unflattened_lookup, UtilMonitor, \
-    validate_save_restore, warn_if_slow
+    validate_save_restore, warn_if_slow, diagnose_serialization
 
 __all__ = [
-    "deep_update",
-    "flatten_dict",
-    "get_pinned_object",
-    "merge_dicts",
-    "pin_in_object_store",
-    "unflattened_lookup",
-    "UtilMonitor",
-    "validate_save_restore",
-    "warn_if_slow",
+    "deep_update", "flatten_dict", "get_pinned_object", "merge_dicts",
+    "pin_in_object_store", "unflattened_lookup", "UtilMonitor",
+    "validate_save_restore", "warn_if_slow", "diagnose_serialization"
 ]
diff --git a/python/ray/tune/utils/util.py b/python/ray/tune/utils/util.py
index fb217d9d9..fefcf7665 100644
--- a/python/ray/tune/utils/util.py
+++ b/python/ray/tune/utils/util.py
@@ -1,5 +1,6 @@
 import copy
 import logging
+import inspect
 import threading
 import time
 from collections import defaultdict, deque, Mapping, Sequence
@@ -269,6 +270,92 @@ def _from_pinnable(obj):
     return obj[0]
 
 
+def diagnose_serialization(trainable):
+    """Utility for detecting accidentally-scoped objects.
+
+    Args:
+        trainable (cls | func): The trainable object passed to
+            tune.run(trainable).
+
+    Returns:
+        bool | set of unserializable objects.
+
+    Example:
+
+    .. code-block:: python
+
+        import threading
+        # this is not serializable
+        e = threading.Event()
+
+        def test():
+            print(e)
+
+        diagnose_serialization(test)
+        # should help identify that 'e' should be moved into
+        # the `test` scope.
+
+        # correct implementation
+        def test():
+            e = threading.Event()
+            print(e)
+
+        assert diagnose_serialization(test) is True
+
+    """
+    from ray.tune.registry import register_trainable, check_serializability
+
+    def check_variables(objects, failure_set, printer):
+        for var_name, variable in objects.items():
+            msg = None
+            try:
+                check_serializability(var_name, variable)
+                status = "PASSED"
+            except Exception as e:
+                status = "FAILED"
+                msg = f"{e.__class__.__name__}: {str(e)}"
+                failure_set.add(var_name)
+            printer(f"{str(variable)}[name='{var_name}']... {status}")
+            if msg:
+                printer(msg)
+
+    print(f"Trying to serialize {trainable}...")
+    try:
+        register_trainable("__test:" + str(trainable), trainable, warn=False)
+        print("Serialization succeeded!")
+        return True
+    except Exception as e:
+        print(f"Serialization failed: {e}")
+
+    print("Inspecting the scope of the trainable by running "
+          f"`inspect.getclosurevars({str(trainable)})`...")
+    closure = inspect.getclosurevars(trainable)
+    failure_set = set()
+    if closure.globals:
+        print(f"Detected {len(closure.globals)} global variables. "
+              "Checking serializability...")
+        check_variables(closure.globals, failure_set,
+                        lambda s: print(" " + s))
+
+    if closure.nonlocals:
+        print(f"Detected {len(closure.nonlocals)} nonlocal variables. "
+              "Checking serializability...")
+        check_variables(closure.nonlocals, failure_set,
+                        lambda s: print(" " + s))
+
+    if not failure_set:
+        print("Nothing was found to have failed the diagnostic test, though "
+              "serialization did not succeed. Feel free to raise an "
+              "issue on GitHub.")
+        return failure_set
+    else:
+        print(f"Variable(s) {failure_set} were found to be non-serializable. "
+              "Consider either removing the instantiation/imports "
+              "of these objects or moving them into the scope of "
+              "the trainable.")
+        return failure_set
+
+
 def validate_save_restore(trainable_cls, config=None, num_gpus=0,
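Putting the pieces together, the intended debugging workflow once this patch lands looks roughly like this (illustrative session; the variable names are arbitrary):

    import threading
    from ray.tune.utils import diagnose_serialization

    def make_trainable():
        stop_event = threading.Event()  # not picklable

        def train_func(config):
            stop_event.wait()

        return train_func

    failures = diagnose_serialization(make_trainable())
    # Prints a PASSED/FAILED line for each variable captured by the
    # trainable and returns {'stop_event'} here; moving the Event
    # inside train_func makes diagnose_serialization return True.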