mirror of https://github.com/vale981/ray
synced 2025-03-05 18:11:42 -05:00

[tune] Introduce `durable()` wrapper to convert trainables into durable trainables (#14306)

* [tune] Introduce `durable()` wrapper to convert trainables into durable trainables
* Fix wrong check
* Improve docs, add FAQ for tackling overhead
* Fix bugs in `tune.with_parameters`
* Update doc/source/tune/api_docs/trainable.rst
* Update doc/source/tune/_tutorials/_faq.rst

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
parent
f1c8c8d12f
commit
4014168928
9 changed files with 294 additions and 38 deletions
@@ -394,6 +394,81 @@ We strongly advise to try reproduction on smaller toy problems first before relying
on it for larger experiments.
How can I avoid bottlenecks?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Sometimes you might run into a message like this:

.. code-block::

    The `experiment_checkpoint` operation took 2.43 seconds to complete, which may be a performance bottleneck

Most commonly, the ``experiment_checkpoint`` operation is throwing this warning, but it might be something else,
like ``process_trial_result``.

These operations should usually take less than 500ms to complete. When they consistently take longer, this might
indicate a problem or an inefficiency. To get rid of this message, it is important to understand where it comes
from.

These are the main reasons this problem comes up:

**The Trial config is very large**

This is the case if you e.g. try to pass a dataset or other large object via the ``config`` parameter.
In that case, the dataset is serialized and written to disk repeatedly during experiment
checkpointing, which takes a long time.

**Solution**: Use :func:`tune.with_parameters <ray.tune.with_parameters>` to pass large objects to
function trainables via the object store. For class trainables, you can do this manually via ``ray.put()``
and ``ray.get()``. If you need to pass a class definition, consider passing an
indicator (e.g. a string) instead and let the trainable select the class. Generally, your config
dictionary should only contain primitive types, like numbers or strings.
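As an illustration, here is a minimal sketch of this pattern; ``load_large_dataset`` and ``fit_model`` are
hypothetical placeholders, not Ray APIs:

.. code-block:: python

    from ray import tune

    def train_fn(config, data=None):
        # `data` arrives via the Ray object store, not via `config`.
        model = fit_model(config, data)  # placeholder for your training code
        tune.report(loss=model.loss)

    data = load_large_dataset()  # placeholder for your large object

    tune.run(
        # Bad: config={"data": data, ...} would serialize the dataset into
        # every experiment checkpoint. Instead, register it once:
        tune.with_parameters(train_fn, data=data),
        config={"lr": 1e-3})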
**The Trial result is very large**

This is the case if you return objects, data, or other large objects via the return value of ``step()`` in
your class trainable or via ``tune.report()`` in your function trainable. The effect is the same as above:
the results are repeatedly serialized and written to disk, which can take a long time.

**Solution**: Usually you should be able to write data to the trial directory instead. You can then pass a
filename back (or assume it is in a known location). The trial directory is usually the current working directory.
Class trainables have the ``Trainable.logdir`` property, and function trainables can use the
:func:`ray.tune.get_trial_dir` function to retrieve the log directory. If you really have to, you can also
``ray.put()`` an object into the Ray object store and retrieve it with ``ray.get()`` on the other side.
Generally, your result dictionary should only contain primitive types, like numbers or strings.
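A sketch of this pattern for a function trainable follows; ``compute_large_array`` is a hypothetical
placeholder for whatever produces the large object:

.. code-block:: python

    import os
    import numpy as np
    from ray import tune

    def train_fn(config):
        for step in range(10):
            large_array = compute_large_array(config, step)  # placeholder

            # Write the large artifact to the trial directory...
            path = os.path.join(tune.get_trial_dir(), f"artifact_{step}.npy")
            np.save(path, large_array)

            # ...and report only primitives, plus the file location.
            tune.report(mean_value=float(large_array.mean()), artifact_path=path)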
**You are training a large number of trials on a cluster, or you are saving huge checkpoints**

Checkpoints and logs are synced between nodes - usually at least to the driver on the head node, but
sometimes between worker nodes if needed (e.g. when using :ref:`Population Based Training <tune-scheduler-pbt>`).
If these checkpoints are very large (e.g. for NLP models), or if you are training a large number of trials,
this syncing can take a long time.

If nothing else is specified, syncing happens via SSH, which can lead to network overhead, as connections are
not kept open by Ray Tune.

**Solution**: There are multiple solutions, depending on your needs (see the combined sketch after this list):

1. You can disable syncing to the driver in the :class:`tune.SyncConfig <ray.tune.SyncConfig>`. In this case,
   logs and checkpoints will not be synced to the driver, so if you need to access them later, you will have to
   transfer them to where you need them manually.

2. You can use the :ref:`ray.tune.durable <tune-durable-trainable>` wrapper to save logs and checkpoints to a
   specified ``upload_dir``. This is the preferred way to deal with this. All syncing will be taken care of
   automatically, as all nodes are able to access the cloud storage. Additionally, your results will be safe,
   so even when you're working on preemptible instances, you won't lose any of your data.
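A minimal sketch combining both options, adapted from the ``tune.durable`` example in this change (the
trainable and bucket name are placeholders):

.. code-block:: python

    from ray import tune

    tune.run(
        tune.durable(train_fn),  # any function or class trainable
        config={"lr": 1e-3},
        checkpoint_freq=1,
        sync_config=tune.SyncConfig(
            sync_to_driver=False,  # option 1: skip SSH syncing to the driver
            upload_dir="s3://your-bucket/experiment/",  # option 2: durable cloud storage
        ))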
**You are reporting results too often**

Each result is processed by the search algorithm, trial scheduler, and callbacks (including loggers and the
trial syncer). If you're reporting a large number of results per trial (e.g. multiple results per second),
this can take a long time.

**Solution**: The solution here is obvious: just don't report results that often. In class trainables, ``step()``
could process a larger chunk of data. In function trainables, you can report only every n-th iteration
of the training loop. Try to report only as many results as you really need to make scheduling or searching
decisions. If you need more fine-grained metrics for logging or tracking, consider using a separate logging
mechanism for this instead of the Ray Tune-provided progress logging of results.
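For example, a function trainable might aggregate metrics and report only every n-th iteration (a sketch;
``compute_loss_for_batch`` is a hypothetical placeholder):

.. code-block:: python

    from ray import tune

    def train_fn(config, report_every=100):
        running_loss = 0.0
        for i in range(10000):
            running_loss += compute_loss_for_batch(config, i)  # placeholder

            # Report an aggregate every n-th iteration instead of every step.
            if (i + 1) % report_every == 0:
                tune.report(mean_loss=running_loss / report_every)
                running_loss = 0.0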

Further Questions or Issues?

@@ -377,11 +377,25 @@ Ray also offers lightweight integrations to distribute your TensorFlow training
.. autofunction:: ray.tune.integration.tensorflow.DistributedTrainableCreator
    :noindex:

.. _tune-durable-trainable:

tune.DurableTrainable
---------------------

Tune provides a :func:`ray.tune.durable` wrapper that can be used to convert any kind of trainable
to a ``DurableTrainable``, including pre-registered RLlib trainers and :ref:`function trainables <tune-function-api>`.

The :class:`DurableTrainable <ray.tune.DurableTrainable>` syncs trial logs and checkpoints to cloud storage
(via the ``upload_dir``). This is especially useful when training a large number of distributed trials,
as logs and checkpoints are otherwise synchronized via SSH, which can quickly become a performance bottleneck.
The :class:`DurableTrainable <ray.tune.DurableTrainable>` class inherits from
:class:`Trainable <ray.tune.Trainable>` and thus can be extended like the base class.

.. autoclass:: ray.tune.DurableTrainable

.. autofunction:: ray.tune.durable

StatusReporter
--------------

@@ -7,7 +7,7 @@ from ray.tune.analysis import ExperimentAnalysis, Analysis
from ray.tune.stopper import Stopper, EarlyStopping
from ray.tune.registry import register_env, register_trainable
from ray.tune.trainable import Trainable
from ray.tune.durable_trainable import DurableTrainable
from ray.tune.durable_trainable import DurableTrainable, durable
from ray.tune.callback import Callback
from ray.tune.suggest import grid_search
from ray.tune.session import (

@@ -23,14 +23,14 @@ from ray.tune.schedulers import create_scheduler
from ray.tune.utils.placement_groups import PlacementGroupFactory

__all__ = [
    "Trainable", "DurableTrainable", "Callback", "TuneError", "grid_search",
    "register_env", "register_trainable", "run", "run_experiments",
    "with_parameters", "Stopper", "EarlyStopping", "Experiment", "function",
    "sample_from", "track", "uniform", "quniform", "choice", "randint",
    "lograndint", "qrandint", "qlograndint", "randn", "qrandn", "loguniform",
    "qloguniform", "ExperimentAnalysis", "Analysis", "CLIReporter",
    "JupyterNotebookReporter", "ProgressReporter", "report", "get_trial_dir",
    "get_trial_name", "get_trial_id", "make_checkpoint_dir", "save_checkpoint",
    "is_session_enabled", "checkpoint_dir", "SyncConfig", "create_searcher",
    "create_scheduler", "PlacementGroupFactory"
    "Trainable", "DurableTrainable", "durable", "Callback", "TuneError",
    "grid_search", "register_env", "register_trainable", "run",
    "run_experiments", "with_parameters", "Stopper", "EarlyStopping",
    "Experiment", "function", "sample_from", "track", "uniform", "quniform",
    "choice", "randint", "lograndint", "qrandint", "qlograndint", "randn",
    "qrandn", "loguniform", "qloguniform", "ExperimentAnalysis", "Analysis",
    "CLIReporter", "JupyterNotebookReporter", "ProgressReporter", "report",
    "get_trial_dir", "get_trial_name", "get_trial_id", "make_checkpoint_dir",
    "save_checkpoint", "is_session_enabled", "checkpoint_dir", "SyncConfig",
    "create_searcher", "create_scheduler", "PlacementGroupFactory"
]

@@ -1,6 +1,11 @@
from typing import Callable, Type, Union

import inspect
import logging
import os

from ray.tune.function_runner import wrap_function
from ray.tune.registry import get_trainable_cls
from ray.tune.trainable import Trainable, TrainableUtil
from ray.tune.syncer import get_cloud_sync_client

@@ -99,3 +104,80 @@ class DurableTrainable(Trainable):
    def _storage_path(self, local_path):
        rel_local_path = os.path.relpath(local_path, self.logdir)
        return os.path.join(self.remote_checkpoint_dir, rel_local_path)


def durable(trainable: Union[str, Type[Trainable], Callable]):
    """Convert a trainable into a durable trainable.

    Durable trainables are used to upload trial results and checkpoints
    to cloud storage, e.g. to AWS S3.

    This function can be used to convert your trainable, i.e. your trainable
    classes, functions, or string identifiers, to a durable trainable.

    To make durable trainables work, you should pass a valid
    :class:`SyncConfig <ray.tune.SyncConfig>` object to `tune.run()`.

    Example:

    .. code-block:: python

        from ray import tune

        analysis = tune.run(
            tune.durable("PPO"),
            config={"env": "CartPole-v0"},
            checkpoint_freq=1,
            sync_config=tune.SyncConfig(
                sync_to_driver=False,
                upload_dir="s3://your-s3-bucket/durable-ppo/",
            ))

    You can also convert your trainable functions:

    .. code-block:: python

        tune.run(
            tune.durable(your_training_fn),
            # ...
        )

    And your trainable classes:

    .. code-block:: python

        tune.run(
            tune.durable(YourTrainableClass),
            # ...
        )

    Args:
        trainable (str|Type[Trainable]|Callable): Trainable. Can be a
            string identifier, a trainable class, or a trainable function.

    Returns:
        A durable trainable class wrapped around your trainable.
    """
    if isinstance(trainable, str):
        trainable_cls = get_trainable_cls(trainable)
    else:
        trainable_cls = trainable

    if not inspect.isclass(trainable_cls):
        # Function API
        return wrap_function(trainable_cls, durable=True)

    if not issubclass(trainable_cls, Trainable):
        raise ValueError(
            "You can only use `durable()` with valid trainables. The class "
            "you passed does not inherit from `Trainable`. Please make sure "
            f"it does. Got: {type(trainable_cls)}")

    # else: Class API
    class _WrappedDurableTrainable(DurableTrainable, trainable_cls):
        _name = trainable_cls.__name__ if hasattr(trainable_cls, "__name__") \
            else "durable_trainable"

    return _WrappedDurableTrainable
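The class-API branch relies on Python's method resolution order: because ``DurableTrainable`` is listed
first among the bases of the wrapped class, its storage-related overrides take precedence over the methods
of the wrapped class. A standalone sketch of that mixin pattern (the class names here are illustrative, not
Ray code):

.. code-block:: python

    class Base:
        def save(self):
            return "local save"

    class DurableMixin(Base):
        def save(self):
            return "cloud save"  # overrides Base.save

    class UserTrainable(Base):
        pass

    # The mixin is listed first, so its overrides win in the MRO.
    class Wrapped(DurableMixin, UserTrainable):
        pass

    assert Wrapped().save() == "cloud save"
    # Wrapped.__mro__: Wrapped, DurableMixin, UserTrainable, Base, object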

@@ -1,8 +1,10 @@
import copy
import logging
from pickle import PicklingError
import os
from typing import Sequence
import copy
import inspect
import logging
import os

from pickle import PicklingError

from ray.tune.error import TuneError
from ray.tune.registry import register_trainable, get_trainable_cls

@@ -11,7 +13,6 @@ from ray.tune.sample import Domain
from ray.tune.stopper import CombinedStopper, FunctionStopper, Stopper, \
    TimeoutStopper
from ray.tune.utils import date_str, detect_checkpoint_function

logger = logging.getLogger(__name__)

@@ -135,7 +136,8 @@ class Experiment:
                "`tune.run()`.")

        config = config or {}
        if callable(run) and detect_checkpoint_function(run):
        if callable(run) and not inspect.isclass(run) and \
                detect_checkpoint_function(run):
            if checkpoint_at_end:
                raise ValueError("'checkpoint_at_end' cannot be used with a "
                                 "checkpointable function. You can specify "

@@ -516,11 +516,15 @@ class FunctionRunner(Trainable):
        pass


def wrap_function(train_func, warn=True):
def wrap_function(train_func, durable=False, warn=True):
    inherit_from = (FunctionRunner, )

    if hasattr(train_func, "__mixins__"):
        inherit_from = train_func.__mixins__ + (FunctionRunner, )
    else:
        inherit_from = (FunctionRunner, )
        inherit_from = train_func.__mixins__ + inherit_from

    if durable:
        from ray.tune import DurableTrainable
        inherit_from = (DurableTrainable, ) + inherit_from

    func_args = inspect.getfullargspec(train_func).args
    use_checkpoint = detect_checkpoint_function(train_func)

@@ -617,7 +621,7 @@ def with_parameters(fn, **kwargs):
        )

    """
    if not callable(fn):
    if not callable(fn) or inspect.isclass(fn):
        raise ValueError(
            "`tune.with_parameters()` only works with the function API. "
            "If you want to pass parameters to Trainable _classes_, consider "

@@ -627,7 +631,7 @@ def with_parameters(fn, **kwargs):
    for k, v in kwargs.items():
        parameter_registry.put(prefix + k, v)

    use_checkpoint = detect_checkpoint_function(fn)
    use_checkpoint = detect_checkpoint_function(fn, partial=True)
    keys = list(kwargs.keys())

    def inner(config, checkpoint_dir=None):

@@ -1,3 +1,4 @@
import pickle
from collections import Counter
import shutil
import tempfile

@@ -15,6 +16,7 @@ from ray import tune
from ray.tune import (DurableTrainable, Trainable, TuneError, Stopper,
                      EarlyStopping, run)
from ray.tune import register_env, register_trainable, run_experiments
from ray.tune.durable_trainable import durable
from ray.tune.schedulers import (TrialScheduler, FIFOScheduler,
                                 AsyncHyperBandScheduler)
from ray.tune.stopper import MaximumIterationStopper, TrialPlateauStopper

@@ -838,14 +840,51 @@ class TrainableFunctionApiTest(unittest.TestCase):
        ]
        self.assertTrue(all(complete_results1))

    def testDurableTrainable(self):
    def _testDurableTrainable(self, trainable, function=False, cleanup=True):
        sync_client = mock_storage_client()
        mock_get_client = "ray.tune.durable_trainable.get_cloud_sync_client"
        with patch(mock_get_client) as mock_get_cloud_sync_client:
            mock_get_cloud_sync_client.return_value = sync_client
            test_trainable = trainable(remote_checkpoint_dir=MOCK_REMOTE_DIR)
            result = test_trainable.train()
            self.assertEqual(result["metric"], 1)
            checkpoint_path = test_trainable.save()
            result = test_trainable.train()
            self.assertEqual(result["metric"], 2)
            result = test_trainable.train()
            self.assertEqual(result["metric"], 3)
            result = test_trainable.train()
            self.assertEqual(result["metric"], 4)

            if not function:
                test_trainable.state["hi"] = 2
                test_trainable.restore(checkpoint_path)
                self.assertEqual(test_trainable.state["hi"], 1)
            else:
                # Cannot re-use function trainable, create new
                tune.session.shutdown()
                test_trainable = trainable(
                    remote_checkpoint_dir=MOCK_REMOTE_DIR)
                test_trainable.restore(checkpoint_path)

            result = test_trainable.train()
            self.assertEqual(result["metric"], 2)

        if cleanup:
            self.addCleanup(shutil.rmtree, MOCK_REMOTE_DIR)

    def testDurableTrainableClass(self):
        class TestTrain(DurableTrainable):
            def setup(self, config):
                self.state = {"hi": 1, "iter": 0}

            def step(self):
                self.state["iter"] += 1
                return {"timesteps_this_iter": 1, "done": True}
                return {
                    "timesteps_this_iter": 1,
                    "metric": self.state["iter"],
                    "done": self.state["iter"] > 3
                }

            def save_checkpoint(self, path):
                return self.state

@@ -853,18 +892,54 @@ class TrainableFunctionApiTest(unittest.TestCase):
            def load_checkpoint(self, state):
                self.state = state

        sync_client = mock_storage_client()
        mock_get_client = "ray.tune.durable_trainable.get_cloud_sync_client"
        with patch(mock_get_client) as mock_get_cloud_sync_client:
            mock_get_cloud_sync_client.return_value = sync_client
            test_trainable = TestTrain(remote_checkpoint_dir=MOCK_REMOTE_DIR)
            checkpoint_path = test_trainable.save()
            test_trainable.train()
            test_trainable.state["hi"] = 2
            test_trainable.restore(checkpoint_path)
            self.assertEqual(test_trainable.state["hi"], 1)
        self._testDurableTrainable(TestTrain)

        self.addCleanup(shutil.rmtree, MOCK_REMOTE_DIR)
    def testDurableTrainableWrapped(self):
        class TestTrain(Trainable):
            def setup(self, config):
                self.state = {"hi": 1, "iter": 0}

            def step(self):
                self.state["iter"] += 1
                return {
                    "timesteps_this_iter": 1,
                    "metric": self.state["iter"],
                    "done": self.state["iter"] > 3
                }

            def save_checkpoint(self, path):
                return self.state

            def load_checkpoint(self, state):
                self.state = state

        self._testDurableTrainable(durable(TestTrain), cleanup=False)

        tune.register_trainable("test_train", TestTrain)

        self._testDurableTrainable(durable("test_train"))

    def testDurableTrainableFunction(self):
        def test_train(config, checkpoint_dir=None):
            state = {"hi": 1, "iter": 0}
            if checkpoint_dir:
                with open(os.path.join(checkpoint_dir, "ckpt.pkl"),
                          "rb") as fp:
                    state = pickle.load(fp)

            for i in range(4):
                state["iter"] += 1
                with tune.checkpoint_dir(step=state["iter"]) as dir:
                    with open(os.path.join(dir, "ckpt.pkl"), "wb") as fp:
                        pickle.dump(state, fp)
                tune.report(
                    **{
                        "timesteps_this_iter": 1,
                        "metric": state["iter"],
                        "done": state["iter"] > 3
                    })

        self._testDurableTrainable(durable(test_train), function=True)

    def testCheckpointDict(self):
        class TestTrain(Trainable):

@@ -602,13 +602,16 @@ def validate_save_restore(trainable_cls,
    return True


def detect_checkpoint_function(train_func, abort=False):
def detect_checkpoint_function(train_func, abort=False, partial=False):
    """Use checkpointing if any arg has "checkpoint_dir" and args = 2"""
    func_sig = inspect.signature(train_func)
    validated = True
    try:
        # check if signature is func(config, checkpoint_dir=None)
        func_sig.bind({}, checkpoint_dir="tmp/path")
        if partial:
            func_sig.bind_partial({}, checkpoint_dir="tmp/path")
        else:
            func_sig.bind({}, checkpoint_dir="tmp/path")
    except Exception as e:
        logger.debug(str(e))
        validated = False
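The new ``partial=True`` branch is what fixes ``tune.with_parameters``: the functions it checks may declare
extra parameters that are injected later, so a strict ``bind`` would fail on the missing arguments while
``bind_partial`` tolerates them. A standalone illustration (not Ray code):

.. code-block:: python

    import inspect

    # `data` will be injected by with_parameters, so it is unbound at check time.
    def train_fn(config, data, checkpoint_dir=None):
        pass

    sig = inspect.signature(train_fn)
    sig.bind_partial({}, checkpoint_dir="tmp/path")  # ok: unbound params allowed

    try:
        sig.bind({}, checkpoint_dir="tmp/path")      # strict bind
    except TypeError as e:
        print(e)  # missing a required argument: 'data'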

@@ -88,6 +88,7 @@ def timed_tune_run(name: str,
                   checkpoint_size_b: int = 0,
                   **tune_kwargs):
    durable = "sync_config" in tune_kwargs and \
        tune_kwargs["sync_config"].upload_dir and \
        tune_kwargs["sync_config"].upload_dir.startswith("s3://")

    sleep_time = 1. / results_per_second