[tune] make tune.with_parameters() work with the class API (#14532)

* [tune] make `tune.with_parameters()` work with the class API

* Update python/ray/tune/utils/trainable.py

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
Kai Fricke 2021-03-09 09:36:17 +01:00 committed by GitHub
parent f2348a5456
commit 43e098402a
8 changed files with 204 additions and 94 deletions


@ -249,16 +249,14 @@ on other nodes as well. Please refer to the
about these placement strategies.
How can I pass further parameter values to my trainable function?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
**This is only applicable for the Tune function API.**
How can I pass further parameter values to my trainable?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Ray Tune expects your trainable functions to accept only up to two parameters,
``config`` and ``checkpoint_dir``. But sometimes there are cases where
you want to pass constant arguments, like the number of epochs to run,
or a dataset to train on. Ray Tune offers a wrapper function to achieve
just that, called ``tune.with_parameters()``:
just that, called :func:`tune.with_parameters() <ray.tune.with_parameters>`:
.. code-block:: python
@ -283,6 +281,11 @@ the parameters directly in the Ray object store. This means that you
can pass even huge objects like datasets, and Ray makes sure that these
are efficiently stored and retrieved on your cluster machines.
:func:`tune.with_parameters() <ray.tune.with_parameters>`
also works with class trainables. Please see
:ref:`here for further details <tune-with-parameters>` and examples.
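For illustration, a minimal sketch of the class-API usage this note refers to (``HugeDataset`` and the trainable below are hypothetical placeholders):

.. code-block:: python

    from ray import tune

    class MyTrainable(tune.Trainable):
        def setup(self, config, data=None):
            # `data` is fetched from the Ray object store rather than being
            # pickled together with the trainable class itself.
            self.data = data

        def step(self):
            return {"num_samples": len(self.data), "done": True}

    data = HugeDataset(download=True)  # hypothetical large dataset object

    tune.run(
        tune.with_parameters(MyTrainable, data=data),
        # ...
    )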
How can I reproduce experiments
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Reproducing experiments and experiment results means that you get the exact same


@ -18,11 +18,6 @@ tune.Experiment
.. autofunction:: ray.tune.Experiment
tune.with_parameters
--------------------
.. autofunction:: ray.tune.with_parameters
.. _tune-sync-config:
tune.SyncConfig


@ -396,6 +396,13 @@ via SSH, which quickly can become a performance bottleneck. The :class:`DurableT
.. autofunction:: ray.tune.durable
.. _tune-with-parameters:
tune.with_parameters
--------------------
.. autofunction:: ray.tune.with_parameters
StatusReporter
--------------


@ -1,6 +1,5 @@
from ray.tune.error import TuneError
from ray.tune.tune import run_experiments, run
from ray.tune.function_runner import with_parameters
from ray.tune.syncer import SyncConfig
from ray.tune.experiment import Experiment
from ray.tune.analysis import ExperimentAnalysis, Analysis
@ -21,6 +20,7 @@ from ray.tune.sample import (function, sample_from, uniform, quniform, choice,
from ray.tune.suggest import create_searcher
from ray.tune.schedulers import create_scheduler
from ray.tune.utils.placement_groups import PlacementGroupFactory
from ray.tune.utils.trainable import with_parameters
__all__ = [
"Trainable", "DurableTrainable", "durable", "Callback", "TuneError",


@ -10,7 +10,6 @@ import uuid
from functools import partial
from numbers import Number
from ray.tune.registry import parameter_registry
from six.moves import queue
from ray.util.debug import log_once
@ -20,6 +19,7 @@ from ray.tune.result import (DEFAULT_METRIC, TIME_THIS_ITER_S,
RESULT_DUPLICATE, SHOULD_CHECKPOINT)
from ray.tune.utils import (detect_checkpoint_function, detect_config_single,
detect_reporter)
from ray.tune.utils.trainable import with_parameters # noqa: F401
logger = logging.getLogger(__name__)
@ -587,83 +587,3 @@ def wrap_function(train_func, durable=False, warn=True):
return output
return ImplicitFunc
def with_parameters(fn, **kwargs):
    """Wrapper for function trainables to pass arbitrary large data objects.

    This wrapper function will store all passed parameters in the Ray
    object store and retrieve them when calling the function. It can thus
    be used to pass arbitrary data, even datasets, to Tune trainable functions.

    This can also be used as an alternative to `functools.partial` to pass
    default arguments to trainables.

    Args:
        fn: function to wrap
        **kwargs: parameters to store in object store.

    .. code-block:: python

        from ray import tune

        def train(config, data=None):
            for sample in data:
                # ...
                tune.report(loss=loss)

        data = HugeDataset(download=True)

        tune.run(
            tune.with_parameters(train, data=data),
            # ...
        )

    """
    if not callable(fn) or inspect.isclass(fn):
        raise ValueError(
            "`tune.with_parameters()` only works with the function API. "
            "If you want to pass parameters to Trainable _classes_, consider "
            "passing them via the `config` parameter.")

    prefix = f"{str(fn)}_"
    for k, v in kwargs.items():
        parameter_registry.put(prefix + k, v)

    use_checkpoint = detect_checkpoint_function(fn, partial=True)
    keys = list(kwargs.keys())

    def inner(config, checkpoint_dir=None):
        fn_kwargs = {}
        if use_checkpoint:
            default = checkpoint_dir
            sig = inspect.signature(fn)
            if "checkpoint_dir" in sig.parameters:
                default = sig.parameters["checkpoint_dir"].default \
                          or default
            fn_kwargs["checkpoint_dir"] = default

        for k in keys:
            fn_kwargs[k] = parameter_registry.get(prefix + k)
        fn(config, **fn_kwargs)

    fn_name = getattr(fn, "__name__", "tune_with_parameters")
    inner.__name__ = fn_name

    # Use correct function signature if no `checkpoint_dir` parameter is set
    if not use_checkpoint:

        def _inner(config):
            inner(config, checkpoint_dir=None)

        _inner.__name__ = fn_name

        if hasattr(fn, "__mixins__"):
            _inner.__mixins__ = fn.__mixins__
        return _inner

    if hasattr(fn, "__mixins__"):
        inner.__mixins__ = fn.__mixins__

    return inner


@ -5,6 +5,7 @@ import tempfile
import copy
import numpy as np
import os
import sys
import time
import unittest
from unittest.mock import patch
@ -1243,6 +1244,53 @@ class TrainableFunctionApiTest(unittest.TestCase):
break
self.assertFalse(found)
    def testWithParameters(self):
        class Data:
            def __init__(self):
                self.data = [0] * 500_000

        data = Data()
        data.data[100] = 1

        class TestTrainable(Trainable):
            def setup(self, config, data):
                self.data = data.data
                self.data[101] = 2  # Changes are local

            def step(self):
                return dict(
                    metric=len(self.data), hundred=self.data[100], done=True)

        trial_1, trial_2 = tune.run(
            tune.with_parameters(TestTrainable, data=data),
            num_samples=2).trials

        self.assertEqual(data.data[101], 0)
        self.assertEqual(trial_1.last_result["metric"], 500_000)
        self.assertEqual(trial_1.last_result["hundred"], 1)
        self.assertEqual(trial_2.last_result["metric"], 500_000)
        self.assertEqual(trial_2.last_result["hundred"], 1)
        self.assertTrue(str(trial_1).startswith("TestTrainable"))

    def testWithParameters2(self):
        class Data:
            def __init__(self):
                import numpy as np
                self.data = np.random.rand((2 * 1024 * 1024))

        class TestTrainable(Trainable):
            def setup(self, config, data):
                self.data = data.data

            def step(self):
                return dict(metric=len(self.data), done=True)

        trainable = tune.with_parameters(TestTrainable, data=Data())
        # ray.cloudpickle will crash for some reason
        import cloudpickle as cp
        dumped = cp.dumps(trainable)
        assert sys.getsizeof(dumped) < 100 * 1024
class SerializabilityTest(unittest.TestCase):
@classmethod
@ -1467,5 +1515,4 @@ class ApiTestFast(unittest.TestCase):
if __name__ == "__main__":
    import pytest
    import sys
    sys.exit(pytest.main(["-v", __file__]))


@ -779,6 +779,7 @@ class Trainable:
Args:
config (dict): Hyperparameters and other configs given.
Copy of `self.config`.
"""
self._setup(config)
if self._is_overridden("_setup") and log_once("_setup"):


@ -1,14 +1,18 @@
from typing import Dict, Any
import glob
import inspect
import io
import logging
import shutil
from typing import Dict, Any
import pandas as pd
import ray.cloudpickle as pickle
import os
import ray
from ray.tune.registry import parameter_registry
from ray.tune.utils import detect_checkpoint_function
from ray.util import placement_group
from six import string_types
@ -211,3 +215,136 @@ class PlacementGroupUtil:
options["placement_group"] = pg
return options, pg
def with_parameters(trainable, **kwargs):
    """Wrapper for trainables to pass arbitrary large data objects.

    This wrapper function will store all passed parameters in the Ray
    object store and retrieve them when calling the function. It can thus
    be used to pass arbitrary data, even datasets, to Tune trainables.

    This can also be used as an alternative to ``functools.partial`` to pass
    default arguments to trainables.

    When used with the function API, the trainable function is called with
    the passed parameters as keyword arguments. When used with the class API,
    the ``Trainable.setup()`` method is called with the respective kwargs.

    Args:
        trainable: Trainable to wrap.
        **kwargs: parameters to store in object store.

    Function API example:

    .. code-block:: python

        from ray import tune

        def train(config, data=None):
            for sample in data:
                loss = update_model(sample)
                tune.report(loss=loss)

        data = HugeDataset(download=True)

        tune.run(
            tune.with_parameters(train, data=data),
            # ...
        )

    Class API example:

    .. code-block:: python

        from ray import tune

        class MyTrainable(tune.Trainable):
            def setup(self, config, data=None):
                self.data = data
                self.iter = iter(self.data)
                self.next_sample = next(self.iter)

            def step(self):
                loss = update_model(self.next_sample)
                try:
                    self.next_sample = next(self.iter)
                except StopIteration:
                    return {"loss": loss, "done": True}
                return {"loss": loss}

        data = HugeDataset(download=True)

        tune.run(
            tune.with_parameters(MyTrainable, data=data),
            # ...
        )

    """
    from ray.tune.trainable import Trainable

    if not callable(trainable) or (inspect.isclass(trainable)
                                   and not issubclass(trainable, Trainable)):
        raise ValueError(
            f"`tune.with_parameters()` only works with function trainables "
            f"or classes that inherit from `tune.Trainable()`. Got type: "
            f"{type(trainable)}.")

    # Objects are moved into the object store
    prefix = f"{str(trainable)}_"
    for k, v in kwargs.items():
        parameter_registry.put(prefix + k, v)

    trainable_name = getattr(trainable, "__name__", "tune_with_parameters")

    if inspect.isclass(trainable):
        # Class trainable
        keys = list(kwargs.keys())

        class _Inner(trainable):
            def setup(self, config):
                setup_kwargs = {}
                for k in keys:
                    setup_kwargs[k] = parameter_registry.get(prefix + k)
                super(_Inner, self).setup(config, **setup_kwargs)

        _Inner.__name__ = trainable_name
        return _Inner
    else:
        # Function trainable
        use_checkpoint = detect_checkpoint_function(trainable, partial=True)
        keys = list(kwargs.keys())

        def inner(config, checkpoint_dir=None):
            fn_kwargs = {}
            if use_checkpoint:
                default = checkpoint_dir
                sig = inspect.signature(trainable)
                if "checkpoint_dir" in sig.parameters:
                    default = sig.parameters["checkpoint_dir"].default \
                              or default
                fn_kwargs["checkpoint_dir"] = default

            for k in keys:
                fn_kwargs[k] = parameter_registry.get(prefix + k)
            trainable(config, **fn_kwargs)

        inner.__name__ = trainable_name

        # Use correct function signature if no `checkpoint_dir` parameter
        # is set
        if not use_checkpoint:

            def _inner(config):
                inner(config, checkpoint_dir=None)

            _inner.__name__ = trainable_name

            if hasattr(trainable, "__mixins__"):
                _inner.__mixins__ = trainable.__mixins__
            return _inner

        if hasattr(trainable, "__mixins__"):
            inner.__mixins__ = trainable.__mixins__

        return inner
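As a rough sketch of why the pickled trainable stays small in the test above, assuming the parameter registry is conceptually a thin wrapper around ``ray.put``/``ray.get`` (a simplification, not the actual implementation):

.. code-block:: python

    import ray

    ray.init()

    # The large object goes into the Ray object store exactly once; only a
    # small ObjectRef needs to travel with the (pickled) trainable.
    data_ref = ray.put(list(range(500_000)))

    def train(config):
        # Each trial resolves the reference instead of carrying a serialized
        # copy of the data inside the function or class itself.
        data = ray.get(data_ref)
        return len(data)

    print(train({}))  # 500000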