From 4014168928755c8bb6e774a9d906cd231ac62836 Mon Sep 17 00:00:00 2001
From: Kai Fricke
Date: Fri, 26 Feb 2021 13:59:28 +0100
Subject: [PATCH] [tune] Introduce `durable()` wrapper to convert trainables into durable trainables (#14306)

* [tune] Introduce `durable()` wrapper to convert trainables into durable trainables

* Fix wrong check

* Improve docs, add FAQ for tackling overhead

* Fix bugs in `tune.with_parameters`

* Update doc/source/tune/api_docs/trainable.rst

Co-authored-by: Richard Liaw

* Update doc/source/tune/_tutorials/_faq.rst

Co-authored-by: Richard Liaw

Co-authored-by: Richard Liaw
---
 doc/source/tune/_tutorials/_faq.rst            |  75 +++++++++++++
 doc/source/tune/api_docs/trainable.rst         |  14 +++
 python/ray/tune/__init__.py                    |  22 ++--
 python/ray/tune/durable_trainable.py           |  82 ++++++++++++++
 python/ray/tune/experiment.py                  |  14 +--
 python/ray/tune/function_runner.py             |  16 +--
 python/ray/tune/tests/test_api.py              | 101 +++++++++++++++---
 python/ray/tune/utils/util.py                  |   7 +-
 .../scalability_tests/workloads/_trainable.py  |   1 +
 9 files changed, 294 insertions(+), 38 deletions(-)

diff --git a/doc/source/tune/_tutorials/_faq.rst b/doc/source/tune/_tutorials/_faq.rst
index f525df741..f31a6afb9 100644
--- a/doc/source/tune/_tutorials/_faq.rst
+++ b/doc/source/tune/_tutorials/_faq.rst
@@ -394,6 +394,81 @@ We strongly advise to try reproduction on smaller toy problems first
 before relying on it for larger experiments.
 
 
+How can I avoid bottlenecks?
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Sometimes you might run into a message like this:
+
+.. code-block::
+
+    The `experiment_checkpoint` operation took 2.43 seconds to complete, which may be a performance bottleneck
+
+Most commonly, the ``experiment_checkpoint`` operation throws this warning, but it might be something else,
+like ``process_trial_result``.
+
+These operations should usually take less than 500ms to complete. When they consistently take longer, this might
+indicate a problem or an inefficiency. To get rid of this message, it is important to understand where it comes
+from.
+
+These are the main reasons this problem comes up:
+
+**The Trial config is very large**
+
+This is the case if you, for example, try to pass a dataset or another large object via the ``config`` parameter.
+The dataset is then serialized and written to disk repeatedly during experiment checkpointing, which takes a
+long time.
+
+**Solution**: Use :func:`tune.with_parameters <ray.tune.with_parameters>` to pass large objects to
+function trainables via the object store. For class trainables, you can do this manually via ``ray.put()``
+and ``ray.get()``. If you need to pass a class definition, consider passing an
+indicator (e.g. a string) instead and let the trainable select the class. Generally, your config
+dictionary should only contain primitive types, like numbers or strings.
+
+**The Trial result is very large**
+
+This is the case if you return large objects or data via the return value of ``step()`` in
+your class trainable, or pass them to ``tune.report()`` in your function trainable. The effect is the same as
+above: the results are repeatedly serialized and written to disk, and this can take a long time.
+
+**Solution**: Usually you should be able to write data to the trial directory instead. You can then pass a
+filename back (or assume it is in a known location). The trial dir is usually the current working directory.
+Class trainables have the ``Trainable.logdir`` property and function trainables the :func:`ray.tune.get_trial_dir`
+function to retrieve the logdir.
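+
+For example, a function trainable might write a large artifact to the trial directory and report only small
+primitive values (plus the file location) back to Tune. This is a minimal sketch; the file name, the array, and
+``train_fn`` below are only placeholders:
+
+.. code-block:: python
+
+    import os
+
+    import numpy as np
+    from ray import tune
+
+    def train_fn(config):
+        for step in range(10):
+            # Stand-in for a large per-step result (e.g. per-sample predictions).
+            predictions = np.random.rand(10_000)
+
+            # Write the large object to the trial directory instead of returning it.
+            out_file = os.path.join(tune.get_trial_dir(), f"predictions_{step}.npy")
+            np.save(out_file, predictions)
+
+            # Only report primitives (and the file location) to Tune.
+            tune.report(mean_prediction=float(predictions.mean()),
+                        predictions_file=out_file)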
+
+If you really have to, you can also ``ray.put()`` an object into the Ray
+object store and retrieve it with ``ray.get()`` on the other side. Generally, your result dictionary
+should only contain primitive types, like numbers or strings.
+
+**You are training a large number of trials on a cluster, or you are saving huge checkpoints**
+
+Checkpoints and logs are synced between nodes, usually at least to the driver on the head node, but sometimes
+between worker nodes if needed (e.g. when using :ref:`Population Based Training `). If these checkpoints are
+very large (e.g. for NLP models), or if you are training a large number of trials, this syncing can take a
+long time.
+
+If nothing else is specified, syncing happens via SSH, which can lead to network overhead, as connections are
+not kept open by Ray Tune.
+
+**Solution**: There are multiple solutions, depending on your needs:
+
+1. You can disable syncing to the driver in the :class:`tune.SyncConfig <ray.tune.SyncConfig>`. In this case,
+   logs and checkpoints will not be synced to the driver, so if you need to access them later, you will have to
+   transfer them to where you need them manually.
+
+2. You can use the :ref:`ray.tune.durable <tune-durable-trainable>` wrapper to save logs and checkpoints to a
+   specified ``upload_dir``. This is the preferred way to deal with this. All syncing will be taken care of
+   automatically, as all nodes are able to access the cloud storage. Additionally, your results will be safe,
+   so even when you're working on preemptible instances, you won't lose any of your data.
+
+**You are reporting results too often**
+
+Each result is processed by the search algorithm, trial scheduler, and callbacks (including loggers and the
+trial syncer). If you're reporting a large number of results per trial (e.g. multiple results per second),
+this can take a long time.
+
+**Solution**: The solution here is simple: just don't report results that often. In class trainables, ``step()``
+could process a larger chunk of data per call. In function trainables, you can report only every n-th iteration
+of the training loop. Try to report only as many results as you really need to make scheduling or search
+decisions. If you need more fine-grained metrics for logging or tracking, consider using a separate logging
+mechanism for this instead of the Ray Tune-provided progress logging of results.
+
 
 Further Questions or Issues?
 
diff --git a/doc/source/tune/api_docs/trainable.rst b/doc/source/tune/api_docs/trainable.rst
index 49300401d..fb510cb8c 100644
--- a/doc/source/tune/api_docs/trainable.rst
+++ b/doc/source/tune/api_docs/trainable.rst
@@ -377,11 +377,25 @@ Ray also offers lightweight integrations to distribute your TensorFlow training
 .. autofunction:: ray.tune.integration.tensorflow.DistributedTrainableCreator
     :noindex:
 
+.. _tune-durable-trainable:
+
 tune.DurableTrainable
 ---------------------
+
+Tune provides a :func:`ray.tune.durable` wrapper that can be used to convert any kind of trainable
+to a ``DurableTrainable``, including pre-registered RLlib trainers and :ref:`function trainables `.
+
+The :class:`DurableTrainable <ray.tune.DurableTrainable>` syncs trial logs and checkpoints to cloud storage
+(via the ``upload_dir``). This is especially useful when training a large number of distributed trials, as logs
+and checkpoints are otherwise synchronized via SSH, which can quickly become a performance bottleneck. The
+:class:`DurableTrainable <ray.tune.DurableTrainable>` class inherits from :class:`Trainable <ray.tune.Trainable>`
+and thus can be extended like the base class.
+
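+For example, you can wrap the registered RLlib ``"PPO"`` trainer and upload results to your own cloud storage
+(the S3 URI below is only a placeholder):
+
+.. code-block:: python
+
+    from ray import tune
+
+    analysis = tune.run(
+        tune.durable("PPO"),  # works for trainable classes and functions as well
+        config={"env": "CartPole-v0"},
+        checkpoint_freq=1,
+        sync_config=tune.SyncConfig(
+            sync_to_driver=False,
+            upload_dir="s3://your-s3-bucket/durable-ppo/",
+        ))
+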
 .. autoclass:: ray.tune.DurableTrainable
 
+.. autofunction:: ray.tune.durable
+
 
 StatusReporter
 --------------
 
diff --git a/python/ray/tune/__init__.py b/python/ray/tune/__init__.py
index 1b5831226..5d9b9abdd 100644
--- a/python/ray/tune/__init__.py
+++ b/python/ray/tune/__init__.py
@@ -7,7 +7,7 @@ from ray.tune.analysis import ExperimentAnalysis, Analysis
 from ray.tune.stopper import Stopper, EarlyStopping
 from ray.tune.registry import register_env, register_trainable
 from ray.tune.trainable import Trainable
-from ray.tune.durable_trainable import DurableTrainable
+from ray.tune.durable_trainable import DurableTrainable, durable
 from ray.tune.callback import Callback
 from ray.tune.suggest import grid_search
 from ray.tune.session import (
@@ -23,14 +23,14 @@ from ray.tune.schedulers import create_scheduler
 from ray.tune.utils.placement_groups import PlacementGroupFactory
 
 __all__ = [
-    "Trainable", "DurableTrainable", "Callback", "TuneError", "grid_search",
-    "register_env", "register_trainable", "run", "run_experiments",
-    "with_parameters", "Stopper", "EarlyStopping", "Experiment", "function",
-    "sample_from", "track", "uniform", "quniform", "choice", "randint",
-    "lograndint", "qrandint", "qlograndint", "randn", "qrandn", "loguniform",
-    "qloguniform", "ExperimentAnalysis", "Analysis", "CLIReporter",
-    "JupyterNotebookReporter", "ProgressReporter", "report", "get_trial_dir",
-    "get_trial_name", "get_trial_id", "make_checkpoint_dir", "save_checkpoint",
-    "is_session_enabled", "checkpoint_dir", "SyncConfig", "create_searcher",
-    "create_scheduler", "PlacementGroupFactory"
+    "Trainable", "DurableTrainable", "durable", "Callback", "TuneError",
+    "grid_search", "register_env", "register_trainable", "run",
+    "run_experiments", "with_parameters", "Stopper", "EarlyStopping",
+    "Experiment", "function", "sample_from", "track", "uniform", "quniform",
+    "choice", "randint", "lograndint", "qrandint", "qlograndint", "randn",
+    "qrandn", "loguniform", "qloguniform", "ExperimentAnalysis", "Analysis",
+    "CLIReporter", "JupyterNotebookReporter", "ProgressReporter", "report",
+    "get_trial_dir", "get_trial_name", "get_trial_id", "make_checkpoint_dir",
+    "save_checkpoint", "is_session_enabled", "checkpoint_dir", "SyncConfig",
+    "create_searcher", "create_scheduler", "PlacementGroupFactory"
 ]
diff --git a/python/ray/tune/durable_trainable.py b/python/ray/tune/durable_trainable.py
index d6a12839c..e46130393 100644
--- a/python/ray/tune/durable_trainable.py
+++ b/python/ray/tune/durable_trainable.py
@@ -1,6 +1,11 @@
+from typing import Callable, Type, Union
+
+import inspect
 import logging
 import os
 
+from ray.tune.function_runner import wrap_function
+from ray.tune.registry import get_trainable_cls
 from ray.tune.trainable import Trainable, TrainableUtil
 from ray.tune.syncer import get_cloud_sync_client
 
@@ -99,3 +104,80 @@ class DurableTrainable(Trainable):
     def _storage_path(self, local_path):
         rel_local_path = os.path.relpath(local_path, self.logdir)
         return os.path.join(self.remote_checkpoint_dir, rel_local_path)
+
+
+def durable(trainable: Union[str, Type[Trainable], Callable]):
+    """Convert a trainable into a durable trainable.
+
+    Durable trainables are used to upload trial results and checkpoints
+    to cloud storage, such as AWS S3.
+
+    This function can be used to convert your trainable, i.e. a trainable
+    class, function, or string identifier, into a durable trainable.
+
+    To make durable trainables work, you should pass a valid
+    :class:`SyncConfig <ray.tune.SyncConfig>` object to ``tune.run()``.
+
+    Example:
+
+    .. code-block:: python
+
+        from ray import tune
+
+        analysis = tune.run(
+            tune.durable("PPO"),
+            config={"env": "CartPole-v0"},
+            checkpoint_freq=1,
+            sync_config=tune.SyncConfig(
+                sync_to_driver=False,
+                upload_dir="s3://your-s3-bucket/durable-ppo/",
+            ))
+
+    You can also convert your trainable functions:
+
+    .. code-block:: python
+
+        tune.run(
+            tune.durable(your_training_fn),
+            # ...
+        )
+
+    And your trainable classes:
+
+    .. code-block:: python
+
+        tune.run(
+            tune.durable(YourTrainableClass),
+            # ...
+        )
+
+
+    Args:
+        trainable (str|Type[Trainable]|Callable): Trainable. Can be a
+            string identifier, a trainable class, or a trainable function.
+
+    Returns:
+        A durable trainable class wrapped around your trainable.
+
+    """
+    if isinstance(trainable, str):
+        trainable_cls = get_trainable_cls(trainable)
+    else:
+        trainable_cls = trainable
+
+    if not inspect.isclass(trainable_cls):
+        # Function API
+        return wrap_function(trainable_cls, durable=True)
+
+    if not issubclass(trainable_cls, Trainable):
+        raise ValueError(
+            "You can only use `durable()` with valid trainables. The class "
+            "you passed does not inherit from `Trainable`. Please make sure "
+            f"it does. Got: {type(trainable_cls)}")
+
+    # else: Class API
+    class _WrappedDurableTrainable(DurableTrainable, trainable_cls):
+        _name = trainable_cls.__name__ if hasattr(trainable_cls, "__name__") \
+            else "durable_trainable"
+
+    return _WrappedDurableTrainable
diff --git a/python/ray/tune/experiment.py b/python/ray/tune/experiment.py
index 4ad5c43c7..5d5c89cff 100644
--- a/python/ray/tune/experiment.py
+++ b/python/ray/tune/experiment.py
@@ -1,8 +1,10 @@
-import copy
-import logging
-from pickle import PicklingError
-import os
 from typing import Sequence
+import copy
+import inspect
+import logging
+import os
+
+from pickle import PicklingError
 
 from ray.tune.error import TuneError
 from ray.tune.registry import register_trainable, get_trainable_cls
@@ -11,7 +13,6 @@ from ray.tune.sample import Domain
 from ray.tune.stopper import CombinedStopper, FunctionStopper, Stopper, \
     TimeoutStopper
 from ray.tune.utils import date_str, detect_checkpoint_function
-
 logger = logging.getLogger(__name__)
 
@@ -135,7 +136,8 @@ class Experiment:
                 "`tune.run()`.")
 
         config = config or {}
-        if callable(run) and detect_checkpoint_function(run):
+        if callable(run) and not inspect.isclass(run) and \
+                detect_checkpoint_function(run):
             if checkpoint_at_end:
                 raise ValueError("'checkpoint_at_end' cannot be used with a "
                                  "checkpointable function. You can specify "
diff --git a/python/ray/tune/function_runner.py b/python/ray/tune/function_runner.py
index c7c088293..83078081c 100644
--- a/python/ray/tune/function_runner.py
+++ b/python/ray/tune/function_runner.py
@@ -516,11 +516,15 @@ class FunctionRunner(Trainable):
         pass
 
 
-def wrap_function(train_func, warn=True):
+def wrap_function(train_func, durable=False, warn=True):
+    inherit_from = (FunctionRunner, )
+
     if hasattr(train_func, "__mixins__"):
-        inherit_from = train_func.__mixins__ + (FunctionRunner, )
-    else:
-        inherit_from = (FunctionRunner, )
+        inherit_from = train_func.__mixins__ + inherit_from
+
+    if durable:
+        from ray.tune import DurableTrainable
+        inherit_from = (DurableTrainable, ) + inherit_from
 
     func_args = inspect.getfullargspec(train_func).args
     use_checkpoint = detect_checkpoint_function(train_func)
@@ -617,7 +621,7 @@ def with_parameters(fn, **kwargs):
         )
     """
-    if not callable(fn):
+    if not callable(fn) or inspect.isclass(fn):
         raise ValueError(
             "`tune.with_parameters()` only works with the function API. 
" "If you want to pass parameters to Trainable _classes_, consider " @@ -627,7 +631,7 @@ def with_parameters(fn, **kwargs): for k, v in kwargs.items(): parameter_registry.put(prefix + k, v) - use_checkpoint = detect_checkpoint_function(fn) + use_checkpoint = detect_checkpoint_function(fn, partial=True) keys = list(kwargs.keys()) def inner(config, checkpoint_dir=None): diff --git a/python/ray/tune/tests/test_api.py b/python/ray/tune/tests/test_api.py index a0d8c474e..91e4d3c9c 100644 --- a/python/ray/tune/tests/test_api.py +++ b/python/ray/tune/tests/test_api.py @@ -1,3 +1,4 @@ +import pickle from collections import Counter import shutil import tempfile @@ -15,6 +16,7 @@ from ray import tune from ray.tune import (DurableTrainable, Trainable, TuneError, Stopper, EarlyStopping, run) from ray.tune import register_env, register_trainable, run_experiments +from ray.tune.durable_trainable import durable from ray.tune.schedulers import (TrialScheduler, FIFOScheduler, AsyncHyperBandScheduler) from ray.tune.stopper import MaximumIterationStopper, TrialPlateauStopper @@ -838,14 +840,51 @@ class TrainableFunctionApiTest(unittest.TestCase): ] self.assertTrue(all(complete_results1)) - def testDurableTrainable(self): + def _testDurableTrainable(self, trainable, function=False, cleanup=True): + sync_client = mock_storage_client() + mock_get_client = "ray.tune.durable_trainable.get_cloud_sync_client" + with patch(mock_get_client) as mock_get_cloud_sync_client: + mock_get_cloud_sync_client.return_value = sync_client + test_trainable = trainable(remote_checkpoint_dir=MOCK_REMOTE_DIR) + result = test_trainable.train() + self.assertEqual(result["metric"], 1) + checkpoint_path = test_trainable.save() + result = test_trainable.train() + self.assertEqual(result["metric"], 2) + result = test_trainable.train() + self.assertEqual(result["metric"], 3) + result = test_trainable.train() + self.assertEqual(result["metric"], 4) + + if not function: + test_trainable.state["hi"] = 2 + test_trainable.restore(checkpoint_path) + self.assertEqual(test_trainable.state["hi"], 1) + else: + # Cannot re-use function trainable, create new + tune.session.shutdown() + test_trainable = trainable( + remote_checkpoint_dir=MOCK_REMOTE_DIR) + test_trainable.restore(checkpoint_path) + + result = test_trainable.train() + self.assertEqual(result["metric"], 2) + + if cleanup: + self.addCleanup(shutil.rmtree, MOCK_REMOTE_DIR) + + def testDurableTrainableClass(self): class TestTrain(DurableTrainable): def setup(self, config): self.state = {"hi": 1, "iter": 0} def step(self): self.state["iter"] += 1 - return {"timesteps_this_iter": 1, "done": True} + return { + "timesteps_this_iter": 1, + "metric": self.state["iter"], + "done": self.state["iter"] > 3 + } def save_checkpoint(self, path): return self.state @@ -853,18 +892,54 @@ class TrainableFunctionApiTest(unittest.TestCase): def load_checkpoint(self, state): self.state = state - sync_client = mock_storage_client() - mock_get_client = "ray.tune.durable_trainable.get_cloud_sync_client" - with patch(mock_get_client) as mock_get_cloud_sync_client: - mock_get_cloud_sync_client.return_value = sync_client - test_trainable = TestTrain(remote_checkpoint_dir=MOCK_REMOTE_DIR) - checkpoint_path = test_trainable.save() - test_trainable.train() - test_trainable.state["hi"] = 2 - test_trainable.restore(checkpoint_path) - self.assertEqual(test_trainable.state["hi"], 1) + self._testDurableTrainable(TestTrain) - self.addCleanup(shutil.rmtree, MOCK_REMOTE_DIR) + def testDurableTrainableWrapped(self): + class 
TestTrain(Trainable): + def setup(self, config): + self.state = {"hi": 1, "iter": 0} + + def step(self): + self.state["iter"] += 1 + return { + "timesteps_this_iter": 1, + "metric": self.state["iter"], + "done": self.state["iter"] > 3 + } + + def save_checkpoint(self, path): + return self.state + + def load_checkpoint(self, state): + self.state = state + + self._testDurableTrainable(durable(TestTrain), cleanup=False) + + tune.register_trainable("test_train", TestTrain) + + self._testDurableTrainable(durable("test_train")) + + def testDurableTrainableFunction(self): + def test_train(config, checkpoint_dir=None): + state = {"hi": 1, "iter": 0} + if checkpoint_dir: + with open(os.path.join(checkpoint_dir, "ckpt.pkl"), + "rb") as fp: + state = pickle.load(fp) + + for i in range(4): + state["iter"] += 1 + with tune.checkpoint_dir(step=state["iter"]) as dir: + with open(os.path.join(dir, "ckpt.pkl"), "wb") as fp: + pickle.dump(state, fp) + tune.report( + **{ + "timesteps_this_iter": 1, + "metric": state["iter"], + "done": state["iter"] > 3 + }) + + self._testDurableTrainable(durable(test_train), function=True) def testCheckpointDict(self): class TestTrain(Trainable): diff --git a/python/ray/tune/utils/util.py b/python/ray/tune/utils/util.py index b309ffc81..83d3563e8 100644 --- a/python/ray/tune/utils/util.py +++ b/python/ray/tune/utils/util.py @@ -602,13 +602,16 @@ def validate_save_restore(trainable_cls, return True -def detect_checkpoint_function(train_func, abort=False): +def detect_checkpoint_function(train_func, abort=False, partial=False): """Use checkpointing if any arg has "checkpoint_dir" and args = 2""" func_sig = inspect.signature(train_func) validated = True try: # check if signature is func(config, checkpoint_dir=None) - func_sig.bind({}, checkpoint_dir="tmp/path") + if partial: + func_sig.bind_partial({}, checkpoint_dir="tmp/path") + else: + func_sig.bind({}, checkpoint_dir="tmp/path") except Exception as e: logger.debug(str(e)) validated = False diff --git a/release/tune_tests/scalability_tests/workloads/_trainable.py b/release/tune_tests/scalability_tests/workloads/_trainable.py index c5ce8c005..192e3382e 100644 --- a/release/tune_tests/scalability_tests/workloads/_trainable.py +++ b/release/tune_tests/scalability_tests/workloads/_trainable.py @@ -88,6 +88,7 @@ def timed_tune_run(name: str, checkpoint_size_b: int = 0, **tune_kwargs): durable = "sync_config" in tune_kwargs and \ + tune_kwargs["sync_config"].upload_dir and \ tune_kwargs["sync_config"].upload_dir.startswith("s3://") sleep_time = 1. / results_per_second