[tune] make `tune.with_parameters()` work with the class API (#14532)

* [tune] make `tune.with_parameters()` work with the class API
* Update python/ray/tune/utils/trainable.py

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
parent f2348a5456
commit 43e098402a

8 changed files with 204 additions and 94 deletions
@@ -249,16 +249,14 @@ on other nodes as well. Please refer to the

about these placement strategies.

How can I pass further parameter values to my trainable function?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

**This is only applicable for the Tune function API.**

How can I pass further parameter values to my trainable?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Ray Tune expects your trainable functions to accept only up to two parameters,
``config`` and ``checkpoint_dir``. But sometimes there are cases where
you want to pass constant arguments, like the number of epochs to run,
or a dataset to train on. Ray Tune offers a wrapper function to achieve
just that, called ``tune.with_parameters()``:
just that, called :func:`tune.with_parameters() <ray.tune.with_parameters>`:

.. code-block:: python
@@ -283,6 +281,11 @@ the parameters directly in the Ray object store. This means that you

can pass even huge objects like datasets, and Ray makes sure that these
are efficiently stored and retrieved on your cluster machines.

:func:`tune.with_parameters() <ray.tune.with_parameters>`
also works with class trainables. Please see
:ref:`here for further details <tune-with-parameters>` and examples.

How can I reproduce experiments
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Reproducing experiments and experiment results means that you get the exact same
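For reference, here is a minimal, self-contained sketch of the class-API pattern this documentation change describes. The trainable, the reported metric, and the data below are placeholders for illustration, not code from this commit:

.. code-block:: python

    from ray import tune


    class MyTrainable(tune.Trainable):
        def setup(self, config, data=None):
            # `data` is injected by tune.with_parameters() and fetched from
            # the Ray object store, so it is not pickled with the trainable.
            self.data = data

        def step(self):
            # Placeholder metric: report the size of the injected object and
            # stop after a single iteration.
            return {"num_samples": len(self.data), "done": True}


    data = list(range(100_000))  # stands in for a large dataset

    tune.run(
        tune.with_parameters(MyTrainable, data=data),
        num_samples=2,
    )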
@@ -18,11 +18,6 @@ tune.Experiment

.. autofunction:: ray.tune.Experiment

tune.with_parameters
--------------------

.. autofunction:: ray.tune.with_parameters

.. _tune-sync-config:

tune.SyncConfig

@@ -396,6 +396,13 @@ via SSH, which quickly can become a performance bottleneck. The :class:`DurableT

.. autofunction:: ray.tune.durable

.. _tune-with-parameters:

tune.with_parameters
--------------------

.. autofunction:: ray.tune.with_parameters

StatusReporter
--------------
@@ -1,6 +1,5 @@
from ray.tune.error import TuneError
from ray.tune.tune import run_experiments, run
from ray.tune.function_runner import with_parameters
from ray.tune.syncer import SyncConfig
from ray.tune.experiment import Experiment
from ray.tune.analysis import ExperimentAnalysis, Analysis

@@ -21,6 +20,7 @@ from ray.tune.sample import (function, sample_from, uniform, quniform, choice,
from ray.tune.suggest import create_searcher
from ray.tune.schedulers import create_scheduler
from ray.tune.utils.placement_groups import PlacementGroupFactory
from ray.tune.utils.trainable import with_parameters

__all__ = [
    "Trainable", "DurableTrainable", "durable", "Callback", "TuneError",
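The net effect of these two `__init__.py` hunks is that `with_parameters` is now imported from `ray.tune.utils.trainable` instead of `ray.tune.function_runner`, while the public import path stays the same. A small sketch of the unchanged user-facing usage; the training function and data are placeholders:

.. code-block:: python

    from ray import tune


    def train_fn(config, data=None):
        # Placeholder function trainable; `data` is supplied by the wrapper.
        tune.report(num_samples=len(data))


    # Still accessed as tune.with_parameters, exactly as before the move.
    wrapped = tune.with_parameters(train_fn, data=[1, 2, 3])
    tune.run(wrapped)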
@@ -10,7 +10,6 @@ import uuid
from functools import partial
from numbers import Number

from ray.tune.registry import parameter_registry
from six.moves import queue

from ray.util.debug import log_once

@@ -20,6 +19,7 @@ from ray.tune.result import (DEFAULT_METRIC, TIME_THIS_ITER_S,
                             RESULT_DUPLICATE, SHOULD_CHECKPOINT)
from ray.tune.utils import (detect_checkpoint_function, detect_config_single,
                            detect_reporter)
from ray.tune.utils.trainable import with_parameters  # noqa: F401

logger = logging.getLogger(__name__)
@@ -587,83 +587,3 @@ def wrap_function(train_func, durable=False, warn=True):
        return output

    return ImplicitFunc


def with_parameters(fn, **kwargs):
    """Wrapper for function trainables to pass arbitrary large data objects.

    This wrapper function will store all passed parameters in the Ray
    object store and retrieve them when calling the function. It can thus
    be used to pass arbitrary data, even datasets, to Tune trainable functions.

    This can also be used as an alternative to `functools.partial` to pass
    default arguments to trainables.

    Args:
        fn: function to wrap
        **kwargs: parameters to store in object store.


    .. code-block:: python

        from ray import tune

        def train(config, data=None):
            for sample in data:
                # ...
                tune.report(loss=loss)

        data = HugeDataset(download=True)

        tune.run(
            tune.with_parameters(train, data=data),
            #...
        )

    """
    if not callable(fn) or inspect.isclass(fn):
        raise ValueError(
            "`tune.with_parameters()` only works with the function API. "
            "If you want to pass parameters to Trainable _classes_, consider "
            "passing them via the `config` parameter.")

    prefix = f"{str(fn)}_"
    for k, v in kwargs.items():
        parameter_registry.put(prefix + k, v)

    use_checkpoint = detect_checkpoint_function(fn, partial=True)
    keys = list(kwargs.keys())

    def inner(config, checkpoint_dir=None):
        fn_kwargs = {}
        if use_checkpoint:
            default = checkpoint_dir
            sig = inspect.signature(fn)
            if "checkpoint_dir" in sig.parameters:
                default = sig.parameters["checkpoint_dir"].default \
                    or default
            fn_kwargs["checkpoint_dir"] = default

        for k in keys:
            fn_kwargs[k] = parameter_registry.get(prefix + k)
        fn(config, **fn_kwargs)

    fn_name = getattr(fn, "__name__", "tune_with_parameters")
    inner.__name__ = fn_name

    # Use correct function signature if no `checkpoint_dir` parameter is set
    if not use_checkpoint:

        def _inner(config):
            inner(config, checkpoint_dir=None)

        _inner.__name__ = fn_name

        if hasattr(fn, "__mixins__"):
            _inner.__mixins__ = fn.__mixins__
        return _inner

    if hasattr(fn, "__mixins__"):
        inner.__mixins__ = fn.__mixins__

    return inner
@@ -5,6 +5,7 @@ import tempfile
import copy
import numpy as np
import os
import sys
import time
import unittest
from unittest.mock import patch

@@ -1243,6 +1244,53 @@ class TrainableFunctionApiTest(unittest.TestCase):
                break
        self.assertFalse(found)

    def testWithParameters(self):
        class Data:
            def __init__(self):
                self.data = [0] * 500_000

        data = Data()
        data.data[100] = 1

        class TestTrainable(Trainable):
            def setup(self, config, data):
                self.data = data.data
                self.data[101] = 2  # Changes are local

            def step(self):
                return dict(
                    metric=len(self.data), hundred=self.data[100], done=True)

        trial_1, trial_2 = tune.run(
            tune.with_parameters(TestTrainable, data=data),
            num_samples=2).trials

        self.assertEqual(data.data[101], 0)
        self.assertEqual(trial_1.last_result["metric"], 500_000)
        self.assertEqual(trial_1.last_result["hundred"], 1)
        self.assertEqual(trial_2.last_result["metric"], 500_000)
        self.assertEqual(trial_2.last_result["hundred"], 1)
        self.assertTrue(str(trial_1).startswith("TestTrainable"))

    def testWithParameters2(self):
        class Data:
            def __init__(self):
                import numpy as np
                self.data = np.random.rand((2 * 1024 * 1024))

        class TestTrainable(Trainable):
            def setup(self, config, data):
                self.data = data.data

            def step(self):
                return dict(metric=len(self.data), done=True)

        trainable = tune.with_parameters(TestTrainable, data=Data())
        # ray.cloudpickle will crash for some reason
        import cloudpickle as cp
        dumped = cp.dumps(trainable)
        assert sys.getsizeof(dumped) < 100 * 1024


class SerializabilityTest(unittest.TestCase):
    @classmethod

@@ -1467,5 +1515,4 @@ class ApiTestFast(unittest.TestCase):

if __name__ == "__main__":
    import pytest
    import sys
    sys.exit(pytest.main(["-v", __file__]))
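The `# Changes are local` assertion in `testWithParameters` relies on Ray's object-store semantics: what `setup()` receives is a deserialized copy, so mutating it does not touch the driver-side object. A standalone sketch of that behavior using `ray.put`/`ray.get` directly, outside of Tune:

.. code-block:: python

    import ray

    ray.init(ignore_reinit_error=True)

    original = [0] * 10
    ref = ray.put(original)

    retrieved = ray.get(ref)  # a deserialized copy, not the driver-side list
    retrieved[3] = 42         # the mutation stays local to this copy

    assert original[3] == 0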
@@ -779,6 +779,7 @@ class Trainable:

        Args:
            config (dict): Hyperparameters and other configs given.
                Copy of `self.config`.

        """
        self._setup(config)
        if self._is_overridden("_setup") and log_once("_setup"):
@@ -1,14 +1,18 @@
from typing import Dict, Any

import glob
import inspect
import io
import logging
import shutil
from typing import Dict, Any

import pandas as pd
import ray.cloudpickle as pickle
import os

import ray
from ray.tune.registry import parameter_registry
from ray.tune.utils import detect_checkpoint_function
from ray.util import placement_group
from six import string_types

@@ -211,3 +215,136 @@ class PlacementGroupUtil:
        options["placement_group"] = pg

        return options, pg


def with_parameters(trainable, **kwargs):
    """Wrapper for trainables to pass arbitrary large data objects.

    This wrapper function will store all passed parameters in the Ray
    object store and retrieve them when calling the function. It can thus
    be used to pass arbitrary data, even datasets, to Tune trainables.

    This can also be used as an alternative to ``functools.partial`` to pass
    default arguments to trainables.

    When used with the function API, the trainable function is called with
    the passed parameters as keyword arguments. When used with the class API,
    the ``Trainable.setup()`` method is called with the respective kwargs.

    Args:
        trainable: Trainable to wrap.
        **kwargs: parameters to store in object store.

    Function API example:

    .. code-block:: python

        from ray import tune

        def train(config, data=None):
            for sample in data:
                loss = update_model(sample)
                tune.report(loss=loss)

        data = HugeDataset(download=True)

        tune.run(
            tune.with_parameters(train, data=data),
            # ...
        )

    Class API example:

    .. code-block:: python

        from ray import tune

        class MyTrainable(tune.Trainable):
            def setup(self, config, data=None):
                self.data = data
                self.iter = iter(self.data)
                self.next_sample = next(self.iter)

            def step(self):
                loss = update_model(self.next_sample)
                try:
                    self.next_sample = next(self.iter)
                except StopIteration:
                    return {"loss": loss, "done": True}
                return {"loss": loss}

        data = HugeDataset(download=True)

        tune.run(
            tune.with_parameters(MyTrainable, data=data),
            # ...
        )

    """
    from ray.tune.trainable import Trainable

    if not callable(trainable) or (inspect.isclass(trainable)
                                   and not issubclass(trainable, Trainable)):
        raise ValueError(
            f"`tune.with_parameters()` only works with function trainables "
            f"or classes that inherit from `tune.Trainable()`. Got type: "
            f"{type(trainable)}.")

    # Objects are moved into the object store
    prefix = f"{str(trainable)}_"
    for k, v in kwargs.items():
        parameter_registry.put(prefix + k, v)

    trainable_name = getattr(trainable, "__name__", "tune_with_parameters")

    if inspect.isclass(trainable):
        # Class trainable
        keys = list(kwargs.keys())

        class _Inner(trainable):
            def setup(self, config):
                setup_kwargs = {}
                for k in keys:
                    setup_kwargs[k] = parameter_registry.get(prefix + k)
                super(_Inner, self).setup(config, **setup_kwargs)

        _Inner.__name__ = trainable_name
        return _Inner
    else:
        # Function trainable
        use_checkpoint = detect_checkpoint_function(trainable, partial=True)
        keys = list(kwargs.keys())

        def inner(config, checkpoint_dir=None):
            fn_kwargs = {}
            if use_checkpoint:
                default = checkpoint_dir
                sig = inspect.signature(trainable)
                if "checkpoint_dir" in sig.parameters:
                    default = sig.parameters["checkpoint_dir"].default \
                        or default
                fn_kwargs["checkpoint_dir"] = default

            for k in keys:
                fn_kwargs[k] = parameter_registry.get(prefix + k)
            trainable(config, **fn_kwargs)

        inner.__name__ = trainable_name

        # Use correct function signature if no `checkpoint_dir` parameter
        # is set
        if not use_checkpoint:

            def _inner(config):
                inner(config, checkpoint_dir=None)

            _inner.__name__ = trainable_name

            if hasattr(trainable, "__mixins__"):
                _inner.__mixins__ = trainable.__mixins__
            return _inner

        if hasattr(trainable, "__mixins__"):
            inner.__mixins__ = trainable.__mixins__

        return inner
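Conceptually, the wrapper stores each kwarg in the Ray object store under a prefixed key and re-fetches it inside the generated trainable, so only small keys travel with the pickled class or closure (which is what `testWithParameters2` checks with its 100 KiB bound, even though the wrapped array is several megabytes). Below is a rough single-process sketch of that store-by-key pattern using plain `ray.put`/`ray.get`; the helper names are made up for illustration, and Tune's actual `parameter_registry` additionally makes the stored references available to remote trial workers:

.. code-block:: python

    import ray

    ray.init(ignore_reinit_error=True)

    # Stand-in for ray.tune.registry.parameter_registry: key -> ObjectRef.
    _registry = {}


    def put_parameter(key, value):
        # Store the payload once; only the lightweight ObjectRef is kept here.
        _registry[key] = ray.put(value)


    def get_parameter(key):
        # Fetch the payload back by key (in Tune, this happens inside setup()).
        return ray.get(_registry[key])


    # "Driver" side: stash a large object under a trainable-specific prefix.
    put_parameter("MyTrainable_data", list(range(1_000_000)))

    # "Trainable" side: retrieve it by the same key.
    data = get_parameter("MyTrainable_data")
    assert len(data) == 1_000_000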