diff --git a/doc/source/tune/_tutorials/_faq.rst b/doc/source/tune/_tutorials/_faq.rst
index f31a6afb9..a48d860d4 100644
--- a/doc/source/tune/_tutorials/_faq.rst
+++ b/doc/source/tune/_tutorials/_faq.rst
@@ -249,16 +249,14 @@ on other nodes as well. Please refer to the
 about these placement strategies.
 
 
-How can I pass further parameter values to my trainable function?
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-**This is only applicable for the Tune function API.**
+How can I pass further parameter values to my trainable?
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 Ray Tune expects your trainable functions to accept only up to two parameters,
 ``config`` and ``checkpoint_dir``. But sometimes there are cases where you want
 to pass constant arguments, like the number of epochs to run, or a dataset to
 train on. Ray Tune offers a wrapper function to achieve
-just that, called ``tune.with_parameters()``:
+just that, called :func:`tune.with_parameters() <ray.tune.with_parameters>`:
 
 .. code-block:: python
 
@@ -283,6 +281,11 @@ the parameters directly in the Ray object store. This means that you can pass
 even huge objects like datasets, and Ray makes sure that these are efficiently
 stored and retrieved on your cluster machines.
 
+:func:`tune.with_parameters() <ray.tune.with_parameters>`
+also works with class trainables. Please see
+:ref:`here for further details <tune-with-parameters>` and examples.
+
+
 How can I reproduce experiments
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 Reproducing experiments and experiment results means that you get the exact same
diff --git a/doc/source/tune/api_docs/execution.rst b/doc/source/tune/api_docs/execution.rst
index 9eebc3c27..6748b1df7 100644
--- a/doc/source/tune/api_docs/execution.rst
+++ b/doc/source/tune/api_docs/execution.rst
@@ -18,11 +18,6 @@ tune.Experiment
 
 .. autofunction:: ray.tune.Experiment
 
-tune.with_parameters
---------------------
-
-.. autofunction:: ray.tune.with_parameters
-
 .. _tune-sync-config:
 
 tune.SyncConfig
diff --git a/doc/source/tune/api_docs/trainable.rst b/doc/source/tune/api_docs/trainable.rst
index fb510cb8c..0d84600a1 100644
--- a/doc/source/tune/api_docs/trainable.rst
+++ b/doc/source/tune/api_docs/trainable.rst
@@ -396,6 +396,13 @@ via SSH, which quickly can become a performance bottleneck. The :class:`DurableT
 
 .. autofunction:: ray.tune.durable
 
+.. _tune-with-parameters:
+
+tune.with_parameters
+--------------------
+
+.. autofunction:: ray.tune.with_parameters
+
 StatusReporter
 --------------
 
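Reviewer note, not part of the patch: the new FAQ paragraph and the ``tune-with-parameters`` API section added above state that ``tune.with_parameters()`` also works with class trainables, but defer the example to the API docs. A minimal sketch of that usage follows; the ``MyTrainable`` class, the ``dataset`` object, and the stopping criterion are illustrative assumptions, not code from this diff.

.. code-block:: python

    from ray import tune


    class MyTrainable(tune.Trainable):
        def setup(self, config, dataset=None):
            # Keyword arguments given to tune.with_parameters() arrive here,
            # alongside the regular config dict.
            self.dataset = dataset
            self.lr = config["lr"]

        def step(self):
            # Placeholder training step over the injected dataset.
            return {"loss": self.lr * len(self.dataset)}


    dataset = list(range(10_000))  # stands in for a large object

    tune.run(
        tune.with_parameters(MyTrainable, dataset=dataset),
        config={"lr": tune.grid_search([0.01, 0.1])},
        stop={"training_iteration": 1},
    )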
diff --git a/python/ray/tune/__init__.py b/python/ray/tune/__init__.py
index 5d9b9abdd..4d13abe44 100644
--- a/python/ray/tune/__init__.py
+++ b/python/ray/tune/__init__.py
@@ -1,6 +1,5 @@
 from ray.tune.error import TuneError
 from ray.tune.tune import run_experiments, run
-from ray.tune.function_runner import with_parameters
 from ray.tune.syncer import SyncConfig
 from ray.tune.experiment import Experiment
 from ray.tune.analysis import ExperimentAnalysis, Analysis
@@ -21,6 +20,7 @@ from ray.tune.sample import (function, sample_from, uniform, quniform, choice,
 from ray.tune.suggest import create_searcher
 from ray.tune.schedulers import create_scheduler
 from ray.tune.utils.placement_groups import PlacementGroupFactory
+from ray.tune.utils.trainable import with_parameters
 
 __all__ = [
     "Trainable", "DurableTrainable", "durable", "Callback", "TuneError",
diff --git a/python/ray/tune/function_runner.py b/python/ray/tune/function_runner.py
index 83078081c..d4a303444 100644
--- a/python/ray/tune/function_runner.py
+++ b/python/ray/tune/function_runner.py
@@ -10,7 +10,6 @@ import uuid
 from functools import partial
 from numbers import Number
 
-from ray.tune.registry import parameter_registry
 from six.moves import queue
 
 from ray.util.debug import log_once
@@ -20,6 +19,7 @@ from ray.tune.result import (DEFAULT_METRIC, TIME_THIS_ITER_S,
                              RESULT_DUPLICATE, SHOULD_CHECKPOINT)
 from ray.tune.utils import (detect_checkpoint_function, detect_config_single,
                             detect_reporter)
+from ray.tune.utils.trainable import with_parameters  # noqa: F401
 
 logger = logging.getLogger(__name__)
 
@@ -587,83 +587,3 @@ def wrap_function(train_func, durable=False, warn=True):
             return output
 
     return ImplicitFunc
-
-
-def with_parameters(fn, **kwargs):
-    """Wrapper for function trainables to pass arbitrary large data objects.
-
-    This wrapper function will store all passed parameters in the Ray
-    object store and retrieve them when calling the function. It can thus
-    be used to pass arbitrary data, even datasets, to Tune trainable functions.
-
-    This can also be used as an alternative to `functools.partial` to pass
-    default arguments to trainables.
-
-    Args:
-        fn: function to wrap
-        **kwargs: parameters to store in object store.
-
-
-    .. code-block:: python
-
-        from ray import tune
-
-        def train(config, data=None):
-            for sample in data:
-                # ...
-                tune.report(loss=loss)
-
-        data = HugeDataset(download=True)
-
-        tune.run(
-            tune.with_parameters(train, data=data),
-            #...
-        )
-
-    """
-    if not callable(fn) or inspect.isclass(fn):
-        raise ValueError(
-            "`tune.with_parameters()` only works with the function API. "
" - "If you want to pass parameters to Trainable _classes_, consider " - "passing them via the `config` parameter.") - - prefix = f"{str(fn)}_" - for k, v in kwargs.items(): - parameter_registry.put(prefix + k, v) - - use_checkpoint = detect_checkpoint_function(fn, partial=True) - keys = list(kwargs.keys()) - - def inner(config, checkpoint_dir=None): - fn_kwargs = {} - if use_checkpoint: - default = checkpoint_dir - sig = inspect.signature(fn) - if "checkpoint_dir" in sig.parameters: - default = sig.parameters["checkpoint_dir"].default \ - or default - fn_kwargs["checkpoint_dir"] = default - - for k in keys: - fn_kwargs[k] = parameter_registry.get(prefix + k) - fn(config, **fn_kwargs) - - fn_name = getattr(fn, "__name__", "tune_with_parameters") - inner.__name__ = fn_name - - # Use correct function signature if no `checkpoint_dir` parameter is set - if not use_checkpoint: - - def _inner(config): - inner(config, checkpoint_dir=None) - - _inner.__name__ = fn_name - - if hasattr(fn, "__mixins__"): - _inner.__mixins__ = fn.__mixins__ - return _inner - - if hasattr(fn, "__mixins__"): - inner.__mixins__ = fn.__mixins__ - - return inner diff --git a/python/ray/tune/tests/test_api.py b/python/ray/tune/tests/test_api.py index 91e4d3c9c..efc26b180 100644 --- a/python/ray/tune/tests/test_api.py +++ b/python/ray/tune/tests/test_api.py @@ -5,6 +5,7 @@ import tempfile import copy import numpy as np import os +import sys import time import unittest from unittest.mock import patch @@ -1243,6 +1244,53 @@ class TrainableFunctionApiTest(unittest.TestCase): break self.assertFalse(found) + def testWithParameters(self): + class Data: + def __init__(self): + self.data = [0] * 500_000 + + data = Data() + data.data[100] = 1 + + class TestTrainable(Trainable): + def setup(self, config, data): + self.data = data.data + self.data[101] = 2 # Changes are local + + def step(self): + return dict( + metric=len(self.data), hundred=self.data[100], done=True) + + trial_1, trial_2 = tune.run( + tune.with_parameters(TestTrainable, data=data), + num_samples=2).trials + + self.assertEqual(data.data[101], 0) + self.assertEqual(trial_1.last_result["metric"], 500_000) + self.assertEqual(trial_1.last_result["hundred"], 1) + self.assertEqual(trial_2.last_result["metric"], 500_000) + self.assertEqual(trial_2.last_result["hundred"], 1) + self.assertTrue(str(trial_1).startswith("TestTrainable")) + + def testWithParameters2(self): + class Data: + def __init__(self): + import numpy as np + self.data = np.random.rand((2 * 1024 * 1024)) + + class TestTrainable(Trainable): + def setup(self, config, data): + self.data = data.data + + def step(self): + return dict(metric=len(self.data), done=True) + + trainable = tune.with_parameters(TestTrainable, data=Data()) + # ray.cloudpickle will crash for some reason + import cloudpickle as cp + dumped = cp.dumps(trainable) + assert sys.getsizeof(dumped) < 100 * 1024 + class SerializabilityTest(unittest.TestCase): @classmethod @@ -1467,5 +1515,4 @@ class ApiTestFast(unittest.TestCase): if __name__ == "__main__": import pytest - import sys sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/tune/trainable.py b/python/ray/tune/trainable.py index 1f5db4f43..8adee5ff5 100644 --- a/python/ray/tune/trainable.py +++ b/python/ray/tune/trainable.py @@ -779,6 +779,7 @@ class Trainable: Args: config (dict): Hyperparameters and other configs given. Copy of `self.config`. 
+ """ self._setup(config) if self._is_overridden("_setup") and log_once("_setup"): diff --git a/python/ray/tune/utils/trainable.py b/python/ray/tune/utils/trainable.py index 9872788bf..f5cbea7cf 100644 --- a/python/ray/tune/utils/trainable.py +++ b/python/ray/tune/utils/trainable.py @@ -1,14 +1,18 @@ +from typing import Dict, Any + import glob +import inspect import io import logging import shutil -from typing import Dict, Any import pandas as pd import ray.cloudpickle as pickle import os import ray +from ray.tune.registry import parameter_registry +from ray.tune.utils import detect_checkpoint_function from ray.util import placement_group from six import string_types @@ -211,3 +215,136 @@ class PlacementGroupUtil: options["placement_group"] = pg return options, pg + + +def with_parameters(trainable, **kwargs): + """Wrapper for trainables to pass arbitrary large data objects. + + This wrapper function will store all passed parameters in the Ray + object store and retrieve them when calling the function. It can thus + be used to pass arbitrary data, even datasets, to Tune trainables. + + This can also be used as an alternative to ``functools.partial`` to pass + default arguments to trainables. + + When used with the function API, the trainable function is called with + the passed parameters as keyword arguments. When used with the class API, + the ``Trainable.setup()`` method is called with the respective kwargs. + + Args: + trainable: Trainable to wrap. + **kwargs: parameters to store in object store. + + Function API example: + + .. code-block:: python + + from ray import tune + + def train(config, data=None): + for sample in data: + loss = update_model(sample) + tune.report(loss=loss) + + data = HugeDataset(download=True) + + tune.run( + tune.with_parameters(train, data=data), + # ... + ) + + Class API example: + + .. code-block:: python + + from ray import tune + + class MyTrainable(tune.Trainable): + def setup(self, config, data=None): + self.data = data + self.iter = iter(self.data) + self.next_sample = next(self.iter) + + def step(self): + loss = update_model(self.next_sample) + try: + self.next_sample = next(self.iter) + except StopIteration: + return {"loss": loss, done: True} + return {"loss": loss} + + data = HugeDataset(download=True) + + tune.run( + tune.with_parameters(MyTrainable, data=data), + # ... + ) + + """ + from ray.tune.trainable import Trainable + + if not callable(trainable) or (inspect.isclass(trainable) + and not issubclass(trainable, Trainable)): + raise ValueError( + f"`tune.with_parameters() only works with function trainables " + f"or classes that inherit from `tune.Trainable()`. 
Got type: " + f"{type(trainable)}.") + + # Objects are moved into the object store + prefix = f"{str(trainable)}_" + for k, v in kwargs.items(): + parameter_registry.put(prefix + k, v) + + trainable_name = getattr(trainable, "__name__", "tune_with_parameters") + + if inspect.isclass(trainable): + # Class trainable + keys = list(kwargs.keys()) + + class _Inner(trainable): + def setup(self, config): + setup_kwargs = {} + for k in keys: + setup_kwargs[k] = parameter_registry.get(prefix + k) + super(_Inner, self).setup(config, **setup_kwargs) + + _Inner.__name__ = trainable_name + return _Inner + else: + # Function trainable + use_checkpoint = detect_checkpoint_function(trainable, partial=True) + keys = list(kwargs.keys()) + + def inner(config, checkpoint_dir=None): + fn_kwargs = {} + if use_checkpoint: + default = checkpoint_dir + sig = inspect.signature(trainable) + if "checkpoint_dir" in sig.parameters: + default = sig.parameters["checkpoint_dir"].default \ + or default + fn_kwargs["checkpoint_dir"] = default + + for k in keys: + fn_kwargs[k] = parameter_registry.get(prefix + k) + trainable(config, **fn_kwargs) + + inner.__name__ = trainable_name + + # Use correct function signature if no `checkpoint_dir` parameter + # is set + if not use_checkpoint: + + def _inner(config): + inner(config, checkpoint_dir=None) + + _inner.__name__ = trainable_name + + if hasattr(trainable, "__mixins__"): + _inner.__mixins__ = trainable.__mixins__ + return _inner + + if hasattr(trainable, "__mixins__"): + inner.__mixins__ = trainable.__mixins__ + + return inner