[tune] Introduce durable() wrapper to convert trainables into durable trainables (#14306)

* [tune] Introduce `durable()` wrapper to convert trainables into durable trainables

* Fix wrong check

* Improve docs, add FAQ for tackling overhead

* Fix bugs in `tune.with_parameters`

* Update doc/source/tune/api_docs/trainable.rst

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>

* Update doc/source/tune/_tutorials/_faq.rst

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
Kai Fricke 2021-02-26 13:59:28 +01:00 committed by GitHub
parent f1c8c8d12f
commit 4014168928
9 changed files with 294 additions and 38 deletions


@ -394,6 +394,81 @@ We strongly advise trying reproduction on smaller toy problems first before relying
on it for larger experiments.
How can I avoid bottlenecks?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Sometimes you might run into a message like this:
.. code-block::

    The `experiment_checkpoint` operation took 2.43 seconds to complete, which may be a performance bottleneck
Most commonly, the ``experiment_checkpoint`` operation throws this warning, but it might be something else,
like ``process_trial_result``.
These operations should usually take less than 500ms to complete. When they consistently take longer, this might
indicate a problem or an inefficiency. To get rid of this message, it is important to understand where it comes
from.
These are the main reasons this problem comes up:
**The Trial config is very large**
This is the case if you try to pass, for example, a dataset or another large object via the ``config`` parameter.
If you do, that object is serialized and written to disk repeatedly during experiment
checkpointing, which takes a long time.
**Solution**: Use :func:`tune.with_parameters <ray.tune.with_parameters>` to pass large objects to
function trainables via the object store (see the sketch below). For class trainables, you can do this manually via ``ray.put()``
and ``ray.get()``. If you need to pass a class definition, consider passing an
indicator (e.g. a string) instead and let the trainable select the class. Generally, your config
dictionary should only contain primitive types, like numbers or strings.
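A minimal sketch of this pattern, assuming a hypothetical ``train_fn`` and an in-memory ``large_dataset``:

.. code-block:: python

    import numpy as np
    from ray import tune

    # Hypothetical large object we do not want inside the trial config.
    large_dataset = np.random.rand(10_000, 512)

    def train_fn(config, data=None):
        # `data` is resolved from the Ray object store, not serialized with the config.
        score = float(data.mean()) * config["lr"]
        tune.report(score=score)

    tune.run(
        tune.with_parameters(train_fn, data=large_dataset),
        config={"lr": tune.loguniform(1e-4, 1e-1)},
    )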
**The Trial result is very large**
This is the case if you return large objects or data via the return value of ``step()`` in
your class trainable or pass them to ``tune.report()`` in your function trainable. The effect is the same as above:
the results are repeatedly serialized and written to disk, which can take a long time.
**Solution**: Usually you should be able to write data to the trial directory instead. You can then pass a
filename back (or assume it is at a known location), as sketched below. The trial dir is usually the current working directory. Class
trainables have the ``Trainable.logdir`` property, and function trainables can use the :func:`ray.tune.get_trial_dir`
function to retrieve the logdir. If you really have to, you can also ``ray.put()`` an object into the Ray
object store and retrieve it with ``ray.get()`` on the other side. Generally, your result dictionary
should only contain primitive types, like numbers or strings.
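A minimal sketch for a function trainable, assuming a hypothetical ``predictions`` payload:

.. code-block:: python

    import os
    import pickle

    from ray import tune

    def train_fn(config):
        predictions = {"step": 1, "values": [0.1, 0.2, 0.3]}  # stand-in for a large result

        # Write the large artifact to the trial directory instead of reporting it.
        out_path = os.path.join(tune.get_trial_dir(), "predictions.pkl")
        with open(out_path, "wb") as fp:
            pickle.dump(predictions, fp)

        # Only report small, primitive values plus the file location.
        tune.report(mean_value=0.2, predictions_file=out_path)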
**You are training a large number of trials on a cluster, or you are saving huge checkpoints**
Checkpoints and logs are synced between nodes: usually at least to the driver on the head node,
but sometimes between worker nodes if needed (e.g. when
using :ref:`Population Based Training <tune-scheduler-pbt>`). If these checkpoints are very large (e.g. for
NLP models), or if you are training a large number of trials, this syncing can take a long time.
If nothing else is specified, syncing happens via SSH, which can lead to network overhead as connections are
not kept open by Ray Tune.
**Solution**: There are multiple solutions, depending on your needs:
1. You can disable syncing to the driver in the :class:`tune.SyncConfig <ray.tune.SyncConfig>`. In this case,
logs and checkpoints will not be synced to the driver, so if you need to access them later, you will have to
transfer them where you need them manually.
2. You can use the :ref:`ray.tune.durable <tune-durable-trainable>` wrapper to save logs and checkpoints to a specified `upload_dir`
(see the sketch after this list). This is the preferred approach. All syncing will be taken care of automatically, as all nodes
are able to access the cloud storage. Additionally, your results will be safe, so even when you're working on
preemptible instances, you won't lose any of your data.
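A minimal sketch of the durable approach, mirroring the example in the ``durable`` docstring and assuming an S3 bucket you control:

.. code-block:: python

    from ray import tune

    tune.run(
        tune.durable("PPO"),
        config={"env": "CartPole-v0"},
        checkpoint_freq=1,
        sync_config=tune.SyncConfig(
            sync_to_driver=False,
            upload_dir="s3://your-s3-bucket/durable-ppo/",
        ))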
**You are reporting results too often**
Each result is processed by the search algorithm, trial scheduler, and callbacks (including loggers and the
trial syncer). If you're reporting a large number of results per trial (e.g. multiple results per second),
this can take a long time.
**Solution**: Report results less often. In class trainables, ``step()``
could process a larger chunk of data per call. In function trainables, you can report only every n-th iteration
of the training loop (see the sketch below). Report only as many results as you really need to make scheduling or searching
decisions. If you need more fine-grained metrics for logging or tracking, consider using a separate logging
mechanism for this instead of the Ray Tune-provided progress logging of results.
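A minimal sketch of throttled reporting in a function trainable, with a hypothetical ``report_every`` interval:

.. code-block:: python

    from ray import tune

    def train_fn(config):
        report_every = 100  # hypothetical reporting interval
        running_loss = 0.0

        for step in range(10_000):
            running_loss += 0.001  # stand-in for a real training step's loss

            # Report aggregated metrics every `report_every` steps instead of every step.
            if (step + 1) % report_every == 0:
                tune.report(loss=running_loss / (step + 1))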
Further Questions or Issues?


@ -377,11 +377,25 @@ Ray also offers lightweight integrations to distribute your TensorFlow training
.. autofunction:: ray.tune.integration.tensorflow.DistributedTrainableCreator
:noindex:
.. _tune-durable-trainable:
tune.DurableTrainable
---------------------
Tune provides a :func:`ray.tune.durable` wrapper that can be used to convert any kind of trainable
to a ``DurableTrainable``, including pre-registered RLlib trainers and :ref:`function trainables <tune-function-api>`.
The :class:`DurableTrainable <ray.tune.DurableTrainable>` syncs trial logs and checkpoints to cloud storage (via the `upload_dir`). This is especially
useful when training a large number of distributed trials, as logs and checkpoints are otherwise synchronized
via SSH, which can quickly become a performance bottleneck. The :class:`DurableTrainable <ray.tune.DurableTrainable>` class inherits from
:class:`Trainable <ray.tune.Trainable>` and thus can be extended like the base class. A short usage sketch follows below.
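A minimal usage sketch, assuming a hypothetical function trainable ``train_fn`` and an S3 bucket you control:

.. code-block:: python

    from ray import tune

    def train_fn(config):
        tune.report(metric=1)

    tune.run(
        tune.durable(train_fn),
        sync_config=tune.SyncConfig(
            sync_to_driver=False,
            upload_dir="s3://your-s3-bucket/durable-results/",
        ))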
.. autoclass:: ray.tune.DurableTrainable
.. autofunction:: ray.tune.durable
StatusReporter
--------------


@ -7,7 +7,7 @@ from ray.tune.analysis import ExperimentAnalysis, Analysis
from ray.tune.stopper import Stopper, EarlyStopping
from ray.tune.registry import register_env, register_trainable
from ray.tune.trainable import Trainable
from ray.tune.durable_trainable import DurableTrainable
from ray.tune.durable_trainable import DurableTrainable, durable
from ray.tune.callback import Callback
from ray.tune.suggest import grid_search
from ray.tune.session import (
@ -23,14 +23,14 @@ from ray.tune.schedulers import create_scheduler
from ray.tune.utils.placement_groups import PlacementGroupFactory
__all__ = [
"Trainable", "DurableTrainable", "Callback", "TuneError", "grid_search",
"register_env", "register_trainable", "run", "run_experiments",
"with_parameters", "Stopper", "EarlyStopping", "Experiment", "function",
"sample_from", "track", "uniform", "quniform", "choice", "randint",
"lograndint", "qrandint", "qlograndint", "randn", "qrandn", "loguniform",
"qloguniform", "ExperimentAnalysis", "Analysis", "CLIReporter",
"JupyterNotebookReporter", "ProgressReporter", "report", "get_trial_dir",
"get_trial_name", "get_trial_id", "make_checkpoint_dir", "save_checkpoint",
"is_session_enabled", "checkpoint_dir", "SyncConfig", "create_searcher",
"create_scheduler", "PlacementGroupFactory"
"Trainable", "DurableTrainable", "durable", "Callback", "TuneError",
"grid_search", "register_env", "register_trainable", "run",
"run_experiments", "with_parameters", "Stopper", "EarlyStopping",
"Experiment", "function", "sample_from", "track", "uniform", "quniform",
"choice", "randint", "lograndint", "qrandint", "qlograndint", "randn",
"qrandn", "loguniform", "qloguniform", "ExperimentAnalysis", "Analysis",
"CLIReporter", "JupyterNotebookReporter", "ProgressReporter", "report",
"get_trial_dir", "get_trial_name", "get_trial_id", "make_checkpoint_dir",
"save_checkpoint", "is_session_enabled", "checkpoint_dir", "SyncConfig",
"create_searcher", "create_scheduler", "PlacementGroupFactory"
]


@ -1,6 +1,11 @@
from typing import Callable, Type, Union
import inspect
import logging
import os
from ray.tune.function_runner import wrap_function
from ray.tune.registry import get_trainable_cls
from ray.tune.trainable import Trainable, TrainableUtil
from ray.tune.syncer import get_cloud_sync_client
@ -99,3 +104,80 @@ class DurableTrainable(Trainable):
def _storage_path(self, local_path):
rel_local_path = os.path.relpath(local_path, self.logdir)
return os.path.join(self.remote_checkpoint_dir, rel_local_path)
def durable(trainable: Union[str, Type[Trainable], Callable]):
"""Convert trainable into a durable trainable.
Durable trainables are used to upload trial results and checkpoints
to cloud storage, such as AWS S3.
This function can be used to convert your trainable, i.e. your trainable
classes, functions, or string identifiers, to a durable trainable.
To make durable trainables work, you should pass a valid
:class:`SyncConfig <ray.tune.SyncConfig>` object to `tune.run()`.
Example:
.. code-block:: python
from ray import tune
analysis = tune.run(
tune.durable("PPO"),
config={"env": "CartPole-v0"},
checkpoint_freq=1,
sync_config=tune.SyncConfig(
sync_to_driver=False,
upload_dir="s3://your-s3-bucket/durable-ppo/",
))
You can also convert your trainable functions:
.. code-block:: python
tune.run(
tune.durable(your_training_fn),
# ...
)
And your trainable classes:
.. code-block:: python
tune.run(
tune.durable(YourTrainableClass),
# ...
)
Args:
trainable (str|Type[Trainable]|Callable): Trainable. Can be a
string identifier, a trainable class, or a trainable function.
Returns:
A durable trainable class wrapped around your trainable.
"""
if isinstance(trainable, str):
trainable_cls = get_trainable_cls(trainable)
else:
trainable_cls = trainable
if not inspect.isclass(trainable_cls):
# Function API
return wrap_function(trainable_cls, durable=True)
if not issubclass(trainable_cls, Trainable):
raise ValueError(
"You can only use `durable()` with valid trainables. The class "
"you passed does not inherit from `Trainable`. Please make sure "
f"it does. Got: {type(trainable_cls)}")
# else: Class API
class _WrappedDurableTrainable(DurableTrainable, trainable_cls):
_name = trainable_cls.__name__ if hasattr(trainable_cls, "__name__") \
else "durable_trainable"
return _WrappedDurableTrainable


@ -1,8 +1,10 @@
import copy
import logging
from pickle import PicklingError
import os
from typing import Sequence
import copy
import inspect
import logging
import os
from pickle import PicklingError
from ray.tune.error import TuneError
from ray.tune.registry import register_trainable, get_trainable_cls
@ -11,7 +13,6 @@ from ray.tune.sample import Domain
from ray.tune.stopper import CombinedStopper, FunctionStopper, Stopper, \
TimeoutStopper
from ray.tune.utils import date_str, detect_checkpoint_function
logger = logging.getLogger(__name__)
@ -135,7 +136,8 @@ class Experiment:
"`tune.run()`.")
config = config or {}
if callable(run) and detect_checkpoint_function(run):
if callable(run) and not inspect.isclass(run) and \
detect_checkpoint_function(run):
if checkpoint_at_end:
raise ValueError("'checkpoint_at_end' cannot be used with a "
"checkpointable function. You can specify "


@ -516,11 +516,15 @@ class FunctionRunner(Trainable):
pass
def wrap_function(train_func, warn=True):
def wrap_function(train_func, durable=False, warn=True):
inherit_from = (FunctionRunner, )
if hasattr(train_func, "__mixins__"):
inherit_from = train_func.__mixins__ + (FunctionRunner, )
else:
inherit_from = (FunctionRunner, )
inherit_from = train_func.__mixins__ + inherit_from
if durable:
from ray.tune import DurableTrainable
inherit_from = (DurableTrainable, ) + inherit_from
func_args = inspect.getfullargspec(train_func).args
use_checkpoint = detect_checkpoint_function(train_func)
@ -617,7 +621,7 @@ def with_parameters(fn, **kwargs):
)
"""
if not callable(fn):
if not callable(fn) or inspect.isclass(fn):
raise ValueError(
"`tune.with_parameters()` only works with the function API. "
"If you want to pass parameters to Trainable _classes_, consider "
@ -627,7 +631,7 @@ def with_parameters(fn, **kwargs):
for k, v in kwargs.items():
parameter_registry.put(prefix + k, v)
use_checkpoint = detect_checkpoint_function(fn)
use_checkpoint = detect_checkpoint_function(fn, partial=True)
keys = list(kwargs.keys())
def inner(config, checkpoint_dir=None):


@ -1,3 +1,4 @@
import pickle
from collections import Counter
import shutil
import tempfile
@ -15,6 +16,7 @@ from ray import tune
from ray.tune import (DurableTrainable, Trainable, TuneError, Stopper,
EarlyStopping, run)
from ray.tune import register_env, register_trainable, run_experiments
from ray.tune.durable_trainable import durable
from ray.tune.schedulers import (TrialScheduler, FIFOScheduler,
AsyncHyperBandScheduler)
from ray.tune.stopper import MaximumIterationStopper, TrialPlateauStopper
@ -838,14 +840,51 @@ class TrainableFunctionApiTest(unittest.TestCase):
]
self.assertTrue(all(complete_results1))
def testDurableTrainable(self):
def _testDurableTrainable(self, trainable, function=False, cleanup=True):
sync_client = mock_storage_client()
mock_get_client = "ray.tune.durable_trainable.get_cloud_sync_client"
with patch(mock_get_client) as mock_get_cloud_sync_client:
mock_get_cloud_sync_client.return_value = sync_client
test_trainable = trainable(remote_checkpoint_dir=MOCK_REMOTE_DIR)
result = test_trainable.train()
self.assertEqual(result["metric"], 1)
checkpoint_path = test_trainable.save()
result = test_trainable.train()
self.assertEqual(result["metric"], 2)
result = test_trainable.train()
self.assertEqual(result["metric"], 3)
result = test_trainable.train()
self.assertEqual(result["metric"], 4)
if not function:
test_trainable.state["hi"] = 2
test_trainable.restore(checkpoint_path)
self.assertEqual(test_trainable.state["hi"], 1)
else:
# Cannot re-use function trainable, create new
tune.session.shutdown()
test_trainable = trainable(
remote_checkpoint_dir=MOCK_REMOTE_DIR)
test_trainable.restore(checkpoint_path)
result = test_trainable.train()
self.assertEqual(result["metric"], 2)
if cleanup:
self.addCleanup(shutil.rmtree, MOCK_REMOTE_DIR)
def testDurableTrainableClass(self):
class TestTrain(DurableTrainable):
def setup(self, config):
self.state = {"hi": 1, "iter": 0}
def step(self):
self.state["iter"] += 1
return {"timesteps_this_iter": 1, "done": True}
return {
"timesteps_this_iter": 1,
"metric": self.state["iter"],
"done": self.state["iter"] > 3
}
def save_checkpoint(self, path):
return self.state
@ -853,18 +892,54 @@ class TrainableFunctionApiTest(unittest.TestCase):
def load_checkpoint(self, state):
self.state = state
sync_client = mock_storage_client()
mock_get_client = "ray.tune.durable_trainable.get_cloud_sync_client"
with patch(mock_get_client) as mock_get_cloud_sync_client:
mock_get_cloud_sync_client.return_value = sync_client
test_trainable = TestTrain(remote_checkpoint_dir=MOCK_REMOTE_DIR)
checkpoint_path = test_trainable.save()
test_trainable.train()
test_trainable.state["hi"] = 2
test_trainable.restore(checkpoint_path)
self.assertEqual(test_trainable.state["hi"], 1)
self._testDurableTrainable(TestTrain)
self.addCleanup(shutil.rmtree, MOCK_REMOTE_DIR)
def testDurableTrainableWrapped(self):
class TestTrain(Trainable):
def setup(self, config):
self.state = {"hi": 1, "iter": 0}
def step(self):
self.state["iter"] += 1
return {
"timesteps_this_iter": 1,
"metric": self.state["iter"],
"done": self.state["iter"] > 3
}
def save_checkpoint(self, path):
return self.state
def load_checkpoint(self, state):
self.state = state
self._testDurableTrainable(durable(TestTrain), cleanup=False)
tune.register_trainable("test_train", TestTrain)
self._testDurableTrainable(durable("test_train"))
def testDurableTrainableFunction(self):
def test_train(config, checkpoint_dir=None):
state = {"hi": 1, "iter": 0}
if checkpoint_dir:
with open(os.path.join(checkpoint_dir, "ckpt.pkl"),
"rb") as fp:
state = pickle.load(fp)
for i in range(4):
state["iter"] += 1
with tune.checkpoint_dir(step=state["iter"]) as dir:
with open(os.path.join(dir, "ckpt.pkl"), "wb") as fp:
pickle.dump(state, fp)
tune.report(
**{
"timesteps_this_iter": 1,
"metric": state["iter"],
"done": state["iter"] > 3
})
self._testDurableTrainable(durable(test_train), function=True)
def testCheckpointDict(self):
class TestTrain(Trainable):


@ -602,13 +602,16 @@ def validate_save_restore(trainable_cls,
return True
def detect_checkpoint_function(train_func, abort=False):
def detect_checkpoint_function(train_func, abort=False, partial=False):
"""Use checkpointing if any arg has "checkpoint_dir" and args = 2"""
func_sig = inspect.signature(train_func)
validated = True
try:
# check if signature is func(config, checkpoint_dir=None)
func_sig.bind({}, checkpoint_dir="tmp/path")
if partial:
func_sig.bind_partial({}, checkpoint_dir="tmp/path")
else:
func_sig.bind({}, checkpoint_dir="tmp/path")
except Exception as e:
logger.debug(str(e))
validated = False


@ -88,6 +88,7 @@ def timed_tune_run(name: str,
checkpoint_size_b: int = 0,
**tune_kwargs):
durable = "sync_config" in tune_kwargs and \
tune_kwargs["sync_config"].upload_dir and \
tune_kwargs["sync_config"].upload_dir.startswith("s3://")
sleep_time = 1. / results_per_second