mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[tune] Function API checkpointing (#8471)
Co-authored-by: krfricke <krfricke@users.noreply.github.com>
This commit is contained in:
parent
91e57f2e53
commit
6c49c01837
21 changed files with 897 additions and 237 deletions
@ -141,6 +141,14 @@ $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE}
python /ray/python/ray/tune/examples/pbt_convnet_example.py \
--smoke-test

$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} --memory-swap=-1 $DOCKER_SHA \
python /ray/python/ray/tune/examples/hyperband_function_example.py \
--smoke-test

$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} --memory-swap=-1 $DOCKER_SHA \
python /ray/python/ray/tune/examples/pbt_function.py \
--smoke-test

$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} --memory-swap=-1 $DOCKER_SHA \
python /ray/python/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist.py \
--smoke-test

|
@ -35,7 +35,7 @@ Here's an example of specifying the objective function using :ref:`the function-

Now, there are two Trainable APIs - one being the :ref:`function-based API <tune-function-api>` that we demonstrated above.

The other is a :ref:`class-based API <tune-class-api>` that enables :ref:`checkpointing and pausing <tune-trainable-save-restore>`. Here's an example of specifying the objective function using the :ref:`class-based API <tune-class-api>`:
The other is a :ref:`class-based API <tune-class-api>`. Here's an example of specifying the objective function using the :ref:`class-based API <tune-class-api>`:

.. code-block:: python

|
@ -147,46 +147,39 @@ When running a hyperparameter search, Tune can automatically and periodically sa
* fault-tolerance when using pre-emptible machines.
* Pausing trials when using Trial Schedulers such as HyperBand and PBT.

To enable checkpointing, you must implement a :ref:`Trainable class <trainable-docs>` (the function-based API is not checkpointable, since it never returns control back to its caller).
Checkpointing assumes that the model state will be saved to disk on whichever node the Trainable is running on.

Checkpointing assumes that the model state will be saved to disk on whichever node the Trainable is running on. You can checkpoint with three different mechanisms: manually, periodically, and at termination.

**Manual Checkpointing**: A custom Trainable can manually trigger checkpointing by returning ``should_checkpoint: True`` (or ``tune.result.SHOULD_CHECKPOINT: True``) in the result dictionary of `_train`. This can be especially helpful in spot instances:
To use Tune's checkpointing features, you must expose a ``checkpoint`` argument in the function signature, and call ``tune.make_checkpoint_dir`` and ``tune.save_checkpoint``:

.. code-block:: python

def _train(self):
# training code
result = {"mean_accuracy": accuracy}
if detect_instance_preemption():
result.update(should_checkpoint=True)
return result
import time
from ray import tune

def train_func(config, checkpoint=None):
start = 0
if checkpoint:
with open(checkpoint) as f:
state = json.loads(f.read())
start = state["step"] + 1

**Periodic Checkpointing**: periodic checkpointing can be used to provide fault-tolerance for experiments. This can be enabled by setting ``checkpoint_freq=<int>`` and ``max_failures=<int>`` to checkpoint trials every *N* iterations and recover from up to *M* crashes per trial, e.g.:
for iter in range(start, 100):
time.sleep(1)

.. code-block:: python
# Obtain a checkpoint directory
checkpoint_dir = tune.make_checkpoint_dir(step=iter)
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"step": iter}))
tune.save_checkpoint(path)

tune.run(
my_trainable,
checkpoint_freq=10,
max_failures=5,
)
tune.report(hello="world", ray="tune")

**Checkpointing at Termination**: The checkpoint_freq may not coincide with the exact end of an experiment. If you want a checkpoint to be created at the end
of a trial, you can additionally set the ``checkpoint_at_end=True``:
tune.run(train_func)

.. code-block:: python
:emphasize-lines: 5
In this example, checkpoints will be saved by training iteration to ``local_dir/exp_name/trial_name/checkpoint_<step>``.

tune.run(
my_trainable,
checkpoint_freq=10,
checkpoint_at_end=True,
max_failures=5,
)

The checkpoint will be saved at a path that looks like ``local_dir/exp_name/trial_name/checkpoint_x/``, where x is the number of iterations completed when the checkpoint is saved. To restore the checkpoint, you can use the ``restore`` argument and specify a checkpoint file. By doing this, you can change parts of the experiment configuration, such as the experiment's name or the number of training iterations:
You can restore a single trial checkpoint by using ``tune.run(restore=<checkpoint_dir>)``. By doing this, you can change parts of the experiment configuration, such as the experiment's name:

.. code-block:: python

|
@ -3,9 +3,7 @@
|
|||
Training (tune.Trainable, tune.report)
|
||||
======================================
|
||||
|
||||
Training can be done with either a **Class API** (``tune.Trainable``) or **function-based API** (``tune.report``).
|
||||
|
||||
You can use the **function-based API** for fast prototyping. On the other hand, the ``tune.Trainable`` interface supports checkpoint/restore functionality and provides more control for advanced algorithms.
|
||||
Training can be done with either a **Class API** (``tune.Trainable``) or **function API** (``tune.report``).
|
||||
|
||||
For the sake of example, let's maximize this objective function:
|
||||
|
||||
|
@ -16,8 +14,10 @@ For the sake of example, let's maximize this objective function:
|
|||
|
||||
.. _tune-function-api:
|
||||
|
||||
Function-based API
|
||||
------------------
|
||||
Function API
|
||||
------------
|
||||
|
||||
Here is a simple example of using the function API. You can report intermediate metrics by simply calling ``tune.report`` within the provided function.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
@ -25,31 +25,74 @@ Function-based API
|
|||
# config (dict): A dict of hyperparameters.
|
||||
|
||||
for x in range(20):
|
||||
score = objective(x, config["a"], config["b"])
|
||||
intermediate_score = objective(x, config["a"], config["b"])
|
||||
|
||||
tune.report(score=score) # This sends the score to Tune.
|
||||
tune.report(value=intermediate_score) # This sends the score to Tune.
|
||||
|
||||
analysis = tune.run(
|
||||
trainable,
|
||||
config={
|
||||
"a": 2,
|
||||
"b": 4
|
||||
})
|
||||
config={"a": 2, "b": 4}
|
||||
)
|
||||
|
||||
print("best config: ", analysis.get_best_config(metric="score", mode="max"))
|
||||
|
||||
.. tip:: Do not use ``tune.track.log`` within a ``Trainable`` class.
|
||||
.. tip:: Do not use ``tune.report`` within a ``Trainable`` class.
|
||||
|
||||
Tune will run this function on a separate thread in a Ray actor process. Note that this API is not checkpointable, since the thread will never return control back to its caller.
|
||||
Tune will run this function on a separate thread in a Ray actor process.
|
||||
|
||||
.. note:: If you want to pass in a Python lambda, you will need to first register the function: ``tune.register_trainable("lambda_id", lambda x: ...)``. You can then use ``lambda_id`` in place of ``my_trainable``.
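For illustration, a minimal sketch of that registration flow might look like the following (the identifier ``lambda_id`` and the toy objective are placeholders, not part of the original docs):

.. code-block:: python

    from ray import tune

    # Lambdas have no importable name, so register one under an id first.
    tune.register_trainable(
        "lambda_id", lambda config: tune.report(score=config["a"]))

    # The registered identifier can then be passed to tune.run.
    tune.run("lambda_id", config={"a": 2})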
|
||||
|
||||
Function API Checkpointing
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Many Tune features rely on checkpointing, including the usage of certain Trial Schedulers and fault tolerance. To use Tune's checkpointing features, you must expose a ``checkpoint`` argument in the function signature, and call ``tune.make_checkpoint_dir`` and ``tune.save_checkpoint``:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import time
|
||||
from ray import tune
|
||||
|
||||
def train_func(config, checkpoint=None):
|
||||
start = 0
|
||||
if checkpoint:
|
||||
with open(checkpoint) as f:
|
||||
state = json.loads(f.read())
|
||||
start = state["step"] + 1
|
||||
|
||||
for iter in range(start, 100):
|
||||
time.sleep(1)
|
||||
|
||||
# Obtain a checkpoint directory
checkpoint_dir = tune.make_checkpoint_dir(step=iter)
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"step": iter}))
tune.save_checkpoint(path)
|
||||
|
||||
tune.report(hello="world", ray="tune")
|
||||
|
||||
tune.run(train_func)
|
||||
|
||||
In this example, checkpoints will be saved by training iteration to ``local_dir/exp_name/trial_name/checkpoint_<step>``. You can restore a single trial checkpoint by using ``tune.run(restore=<checkpoint_dir>)``:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
[trial] = tune.run(train, config={"max_iter": 5}).trials
last_ckpt = trial.checkpoint.value
analysis = tune.run(train, config={"max_iter": 10}, restore=last_ckpt)
|
||||
|
||||
Tune also may copy or move checkpoints during the course of tuning. For this purpose, it is important not to depend on absolute paths in the implementation of ``save``.
|
||||
|
||||
.. _tune-class-api:
|
||||
|
||||
Trainable Class API
|
||||
-------------------
|
||||
|
||||
.. caution:: Do not use ``tune.track.log`` within a ``Trainable`` class.
|
||||
.. caution:: Do not use ``tune.report`` within a ``Trainable`` class.
|
||||
|
||||
The Trainable **class API** will require users to subclass ``ray.tune.Trainable``. Here's a naive example of this API:
|
||||
|
||||
|
@ -87,14 +130,13 @@ As a subclass of ``tune.Trainable``, Tune will create a ``Trainable`` object on
|
|||
|
||||
.. tip:: As a rule of thumb, the execution time of ``_train`` should be large enough to avoid overheads (i.e. more than a few seconds), but short enough to report progress periodically (i.e. at most a few minutes).
|
||||
|
||||
In this example, we only implemented the ``_setup`` and ``_train`` methods for simplification. Next, we'll implement ``_save`` and ``_restore`` for checkpoint and fault tolerance.
|
||||
|
||||
.. _tune-trainable-save-restore:
|
||||
|
||||
Save and Restore
|
||||
~~~~~~~~~~~~~~~~
|
||||
Class API Checkpointing
|
||||
~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Many Tune features rely on ``_save``, and ``_restore``, including the usage of certain Trial Schedulers, fault tolerance, and checkpointing.
|
||||
You can also implement checkpoint/restore using the Trainable Class API:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
@ -108,9 +150,45 @@ Many Tune features rely on ``_save``, and ``_restore``, including the usage of c
|
|||
checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
|
||||
self.model.load_state_dict(torch.load(checkpoint_path))
|
||||
|
||||
Checkpoints will be saved by training iteration to ``local_dir/exp_name/trial_name/checkpoint_<iter>``. You can restore a single trial checkpoint by using ``tune.run(restore=<checkpoint_dir>)``.
|
||||
tune.run(MyTrainableClass, checkpoint_freq=2)
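For example, resuming from the last checkpoint of a finished trial might look like the following sketch (``MyTrainableClass`` as above; the new experiment name is arbitrary):

.. code-block:: python

    analysis = tune.run(MyTrainableClass, checkpoint_freq=2, max_failures=3)
    last_ckpt = analysis.trials[0].checkpoint.value

    # Start a new experiment (e.g. under a different name) from that checkpoint.
    tune.run(MyTrainableClass, name="restored_run", restore=last_ckpt)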
|
||||
|
||||
You can checkpoint with three different mechanisms: manually, periodically, and at termination.
|
||||
|
||||
**Manual Checkpointing**: A custom Trainable can manually trigger checkpointing by returning ``should_checkpoint: True`` (or ``tune.result.SHOULD_CHECKPOINT: True``) in the result dictionary of `_train`. This can be especially helpful in spot instances:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def _train(self):
|
||||
# training code
|
||||
result = {"mean_accuracy": accuracy}
|
||||
if detect_instance_preemption():
|
||||
result.update(should_checkpoint=True)
|
||||
return result
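Equivalently, the ``tune.result.SHOULD_CHECKPOINT`` constant can be used as the result key; a brief sketch (``detect_instance_preemption`` is the same hypothetical helper as above):

.. code-block:: python

    from ray.tune.result import SHOULD_CHECKPOINT

    def _train(self):
        # training code
        result = {"mean_accuracy": accuracy}
        if detect_instance_preemption():
            result[SHOULD_CHECKPOINT] = True
        return result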
|
||||
|
||||
|
||||
**Periodic Checkpointing**: periodic checkpointing can be used to provide fault-tolerance for experiments. This can be enabled by setting ``checkpoint_freq=<int>`` and ``max_failures=<int>`` to checkpoint trials every *N* iterations and recover from up to *M* crashes per trial, e.g.:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
tune.run(
|
||||
my_trainable,
|
||||
checkpoint_freq=10,
|
||||
max_failures=5,
|
||||
)
|
||||
|
||||
**Checkpointing at Termination**: The checkpoint_freq may not coincide with the exact end of an experiment. If you want a checkpoint to be created at the end
|
||||
of a trial, you can additionally set the ``checkpoint_at_end=True``:
|
||||
|
||||
.. code-block:: python
|
||||
:emphasize-lines: 5
|
||||
|
||||
tune.run(
|
||||
my_trainable,
|
||||
checkpoint_freq=10,
|
||||
checkpoint_at_end=True,
|
||||
max_failures=5,
|
||||
)
|
||||
|
||||
Tune also generates temporary checkpoints for pausing and switching between trials. For this purpose, it is important not to depend on absolute paths in the implementation of ``save``.
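For instance, a ``_save`` that writes only inside the directory Tune provides (a minimal sketch assuming a PyTorch model, as in the example above) stays valid even if Tune relocates the checkpoint:

.. code-block:: python

    import os
    import torch

    def _save(self, tmp_checkpoint_dir):
        # Write inside the directory Tune hands us; Tune may later move or
        # copy this directory, so hard-coded absolute paths would break.
        checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
        torch.save(self.model.state_dict(), checkpoint_path)
        return tmp_checkpoint_dir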
|
||||
|
||||
Use ``validate_save_restore`` to catch ``_save``/``_restore`` errors before execution.
|
||||
|
||||
|
@ -122,31 +200,11 @@ Use ``validate_save_restore`` to catch ``_save``/``_restore`` errors before exec
|
|||
validate_save_restore(MyTrainableClass)
|
||||
validate_save_restore(MyTrainableClass, use_object_store=True)
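A slightly fuller sketch (the import path is assumed to be ``ray.tune.utils``):

.. code-block:: python

    import ray
    from ray.tune.utils import validate_save_restore

    ray.init(num_cpus=2)

    # Instantiates the trainable, saves, restores, and compares state,
    # surfacing ``_save``/``_restore`` bugs before a long run starts.
    validate_save_restore(MyTrainableClass)
    validate_save_restore(MyTrainableClass, use_object_store=True)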
|
||||
|
||||
|
||||
Advanced Resource Allocation
|
||||
----------------------------
|
||||
|
||||
Trainables can themselves be distributed. If your trainable function / class creates further Ray actors or tasks that also consume CPU / GPU resources, you will want to set ``extra_cpu`` or ``extra_gpu`` inside ``tune.run`` to reserve extra resource slots. For example, if a trainable class requires 1 GPU itself, but also launches 4 actors, each using another GPU, then you should set ``"gpu": 1, "extra_gpu": 4``.
|
||||
|
||||
.. code-block:: python
|
||||
:emphasize-lines: 4-8
|
||||
|
||||
tune.run(
|
||||
my_trainable,
|
||||
name="my_trainable",
|
||||
resources_per_trial={
|
||||
"cpu": 1,
|
||||
"gpu": 1,
|
||||
"extra_gpu": 4
|
||||
}
|
||||
)
|
||||
|
||||
The ``Trainable`` also provides the ``default_resource_requests`` interface to automatically declare the ``resources_per_trial`` based on the given configuration.
|
||||
|
||||
|
||||
Advanced: Reusing Actors
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. note:: This feature is only for the Trainable Class API.
|
||||
|
||||
Your Trainable can often take a long time to start. To avoid this, you can do ``tune.run(reuse_actors=True)`` to reuse the same Trainable Python process and object for multiple hyperparameters.
|
||||
|
||||
This requires you to implement ``Trainable.reset_config``, which provides a new set of hyperparameters. It is up to the user to correctly update the hyperparameters of your trainable.
|
||||
|
@ -176,8 +234,47 @@ This requires you to implement ``Trainable.reset_config``, which provides a new
|
|||
return True
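A minimal sketch of such a ``reset_config`` (assuming a PyTorch-style trainable where only the learning rate is mutated; ``build_model`` is a hypothetical helper, not a Tune API):

.. code-block:: python

    import torch
    from ray import tune

    class MyTrainable(tune.Trainable):
        def _setup(self, config):
            self.model = build_model(config)
            self.optimizer = torch.optim.SGD(
                self.model.parameters(), lr=config["lr"])

        def reset_config(self, new_config):
            # Reuse the live actor and model; only swap the hyperparameters.
            for param_group in self.optimizer.param_groups:
                param_group["lr"] = new_config["lr"]
            self.config = new_config
            return True  # signal that the in-place reset succeeded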
|
||||
|
||||
|
||||
tune.Trainable
|
||||
--------------
|
||||
Advanced Resource Allocation
|
||||
----------------------------
|
||||
|
||||
Trainables can themselves be distributed. If your trainable function / class creates further Ray actors or tasks that also consume CPU / GPU resources, you will want to set ``extra_cpu`` or ``extra_gpu`` inside ``tune.run`` to reserve extra resource slots. For example, if a trainable class requires 1 GPU itself, but also launches 4 actors, each using another GPU, then you should set ``"gpu": 1, "extra_gpu": 4``.
|
||||
|
||||
.. code-block:: python
|
||||
:emphasize-lines: 4-8
|
||||
|
||||
tune.run(
|
||||
my_trainable,
|
||||
name="my_trainable",
|
||||
resources_per_trial={
|
||||
"cpu": 1,
|
||||
"gpu": 1,
|
||||
"extra_gpu": 4
|
||||
}
|
||||
)
|
||||
|
||||
The ``Trainable`` also provides the ``default_resource_requests`` interface to automatically declare the ``resources_per_trial`` based on the given configuration.
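A minimal sketch of this override (assuming the classmethod is named ``default_resource_request`` and returns a ``Resources`` object):

.. code-block:: python

    from ray import tune
    from ray.tune.resources import Resources

    class MyTrainable(tune.Trainable):
        @classmethod
        def default_resource_request(cls, config):
            # Derive per-trial resources from the trial's own config.
            return Resources(cpu=1, gpu=config.get("num_gpus", 0))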
|
||||
|
||||
|
||||
|
||||
.. _track-docstring:
|
||||
|
||||
tune.report / tune.checkpoint (Function API)
|
||||
--------------------------------------------
|
||||
|
||||
.. autofunction:: ray.tune.report
|
||||
|
||||
.. autofunction:: ray.tune.make_checkpoint_dir
|
||||
|
||||
.. autofunction:: ray.tune.save_checkpoint
|
||||
|
||||
.. autofunction:: ray.tune.get_trial_dir
|
||||
|
||||
.. autofunction:: ray.tune.get_trial_name
|
||||
|
||||
.. autofunction:: ray.tune.get_trial_id
|
||||
|
||||
tune.Trainable (Class API)
|
||||
--------------------------
|
||||
|
||||
|
||||
.. autoclass:: ray.tune.Trainable
|
||||
|
@ -190,21 +287,6 @@ tune.DurableTrainable
|
|||
|
||||
.. autoclass:: ray.tune.DurableTrainable
|
||||
|
||||
.. _track-docstring:
|
||||
|
||||
tune.track
|
||||
----------
|
||||
|
||||
.. automodule:: ray.tune.track
|
||||
:members:
|
||||
:exclude-members: init,
|
||||
|
||||
KerasCallback
|
||||
-------------
|
||||
|
||||
.. automodule:: ray.tune.integration.keras
|
||||
:members:
|
||||
|
||||
|
||||
StatusReporter
|
||||
--------------
|
||||
|
|
|
@ -149,6 +149,14 @@ py_test(
|
|||
tags = ["exclusive"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_function_api",
|
||||
size = "medium",
|
||||
srcs = ["tests/test_function_api.py"],
|
||||
deps = [":tune_lib"],
|
||||
tags = ["exclusive"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_sync",
|
||||
size = "medium",
|
||||
|
|
|
@ -8,7 +8,8 @@ from ray.tune.trainable import Trainable
|
|||
from ray.tune.durable_trainable import DurableTrainable
|
||||
from ray.tune.suggest import grid_search
|
||||
from ray.tune.session import (report, get_trial_dir, get_trial_name,
|
||||
get_trial_id)
|
||||
get_trial_id, make_checkpoint_dir,
|
||||
save_checkpoint)
|
||||
from ray.tune.progress_reporter import (ProgressReporter, CLIReporter,
|
||||
JupyterNotebookReporter)
|
||||
from ray.tune.sample import (function, sample_from, uniform, choice, randint,
|
||||
|
@ -21,5 +22,5 @@ __all__ = [
|
|||
"uniform", "choice", "randint", "randn", "loguniform",
|
||||
"ExperimentAnalysis", "Analysis", "CLIReporter", "JupyterNotebookReporter",
|
||||
"ProgressReporter", "report", "get_trial_dir", "get_trial_name",
|
||||
"get_trial_id"
|
||||
"get_trial_id", "make_checkpoint_dir", "save_checkpoint"
|
||||
]
|
||||
|
|
|
@ -57,7 +57,7 @@ class DurableTrainable(Trainable):
|
|||
Checkpoint path or prefix that may be passed to restore().
|
||||
"""
|
||||
if checkpoint_dir:
|
||||
if checkpoint_dir.starts_with(os.path.abspath(self.logdir)):
|
||||
if checkpoint_dir.startswith(os.path.abspath(self.logdir)):
|
||||
raise ValueError("`checkpoint_dir` must be `self.logdir`, or "
|
||||
"a sub-directory.")
|
||||
checkpoint_path = super(DurableTrainable, self).save(checkpoint_dir)
|
||||
|
|
59
python/ray/tune/examples/hyperband_function_example.py
Normal file
59
python/ray/tune/examples/hyperband_function_example.py
Normal file
|
@ -0,0 +1,59 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.tune.schedulers import HyperBandScheduler
|
||||
|
||||
|
||||
def train(config, checkpoint=None):
|
||||
step = 0
|
||||
if checkpoint:
|
||||
with open(checkpoint) as f:
|
||||
step = json.loads(f.read())["timestep"]
|
||||
|
||||
for timestep in range(step, 100):
|
||||
v = np.tanh(float(timestep) / config.get("width", 1))
|
||||
v *= config.get("height", 1)
|
||||
|
||||
if timestep % 3 == 0:
|
||||
checkpoint_dir = tune.make_checkpoint_dir(step=timestep)
|
||||
path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
with open(path, "w") as f:
|
||||
f.write(json.dumps({"timestep": timestep}))
|
||||
tune.save_checkpoint(path)
|
||||
|
||||
# Here we use `episode_reward_mean`, but you can also report other
|
||||
# objectives such as loss or accuracy.
|
||||
tune.report(episode_reward_mean=v)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--smoke-test", action="store_true", help="Finish quickly for testing")
|
||||
args, _ = parser.parse_known_args()
|
||||
ray.init(num_cpus=4 if args.smoke_test else None)
|
||||
|
||||
# Hyperband early stopping, configured with `episode_reward_mean` as the
|
||||
# objective and `training_iteration` as the time unit,
|
||||
# which is automatically filled by Tune.
|
||||
hyperband = HyperBandScheduler(
|
||||
time_attr="training_iteration",
|
||||
metric="episode_reward_mean",
|
||||
mode="max",
|
||||
max_t=200)
|
||||
|
||||
tune.run(
|
||||
train,
|
||||
name="hyperband_test",
|
||||
num_samples=20,
|
||||
stop={"training_iteration": 10 if args.smoke_test else 99999},
|
||||
config={"height": tune.uniform(0, 100)},
|
||||
scheduler=hyperband,
|
||||
fail_fast=True)
|
119
python/ray/tune/examples/pbt_function.py
Normal file
119
python/ray/tune/examples/pbt_function.py
Normal file
|
@ -0,0 +1,119 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
import numpy as np
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.tune.schedulers import PopulationBasedTraining
|
||||
|
||||
|
||||
def pbt_function(config, checkpoint=None):
|
||||
"""Toy PBT problem for benchmarking adaptive learning rate.
|
||||
|
||||
The goal is to optimize this trainable's accuracy. The accuracy increases
|
||||
fastest at the optimal lr, which is a function of the current accuracy.
|
||||
|
||||
The optimal lr schedule for this problem is the triangle wave as follows.
|
||||
Note that many lr schedules for real models also follow this shape:
|
||||
|
||||
best lr
|
||||
^
|
||||
| /\
|
||||
| / \
|
||||
| / \
|
||||
| / \
|
||||
------------> accuracy
|
||||
|
||||
In this problem, using PBT with a population of 2-4 is sufficient to
|
||||
roughly approximate this lr schedule. Higher population sizes will yield
|
||||
faster convergence. Training will not converge without PBT.
|
||||
"""
|
||||
lr = config["lr"]
|
||||
accuracy = 0.0 # end = 1000
|
||||
start = 0
|
||||
if checkpoint:
|
||||
with open(checkpoint) as f:
|
||||
state = json.loads(f.read())
|
||||
accuracy = state["acc"]
|
||||
start = state["step"]
|
||||
|
||||
midpoint = 100 # lr starts decreasing after acc > midpoint
|
||||
q_tolerance = 3 # penalize exceeding lr by more than this multiple
|
||||
noise_level = 2 # add gaussian noise to the acc increase
|
||||
# triangle wave:
|
||||
# - start at 0.001 @ t=0,
|
||||
# - peak at 0.01 @ t=midpoint,
|
||||
# - end at 0.001 @ t=midpoint * 2,
|
||||
for step in range(start, 100):
|
||||
if accuracy < midpoint:
|
||||
optimal_lr = 0.01 * accuracy / midpoint
|
||||
else:
|
||||
optimal_lr = 0.01 - 0.01 * (accuracy - midpoint) / midpoint
|
||||
optimal_lr = min(0.01, max(0.001, optimal_lr))
|
||||
|
||||
# compute accuracy increase
|
||||
q_err = max(lr, optimal_lr) / min(lr, optimal_lr)
|
||||
if q_err < q_tolerance:
|
||||
accuracy += (1.0 / q_err) * random.random()
|
||||
elif lr > optimal_lr:
|
||||
accuracy -= (q_err - q_tolerance) * random.random()
|
||||
accuracy += noise_level * np.random.normal()
|
||||
accuracy = max(0, accuracy)
|
||||
|
||||
if step % 3 == 0:
|
||||
checkpoint_dir = tune.make_checkpoint_dir(step=step)
|
||||
path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
with open(path, "w") as f:
|
||||
f.write(json.dumps({"acc": accuracy, "step": start}))
|
||||
tune.save_checkpoint(path)
|
||||
|
||||
tune.report(
|
||||
mean_accuracy=accuracy,
|
||||
cur_lr=lr,
|
||||
optimal_lr=optimal_lr, # for debugging
|
||||
q_err=q_err, # for debugging
|
||||
done=accuracy > midpoint * 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--smoke-test", action="store_true", help="Finish quickly for testing")
|
||||
args, _ = parser.parse_known_args()
|
||||
if args.smoke_test:
|
||||
ray.init(num_cpus=2) # force pausing to happen for test
|
||||
else:
|
||||
ray.init()
|
||||
|
||||
pbt = PopulationBasedTraining(
|
||||
time_attr="training_iteration",
|
||||
metric="mean_accuracy",
|
||||
mode="max",
|
||||
perturbation_interval=4,
|
||||
hyperparam_mutations={
|
||||
# distribution for resampling
|
||||
"lr": lambda: random.uniform(0.0001, 0.02),
|
||||
# allow perturbations within this set of categorical values
|
||||
"some_other_factor": [1, 2],
|
||||
})
|
||||
|
||||
tune.run(
|
||||
pbt_function,
|
||||
name="pbt_test",
|
||||
scheduler=pbt,
|
||||
verbose=False,
|
||||
stop={
|
||||
"training_iteration": 30,
|
||||
},
|
||||
num_samples=8,
|
||||
fail_fast=True,
|
||||
config={
|
||||
"lr": 0.0001,
|
||||
# note: this parameter is perturbed but has no effect on
|
||||
# the model training in this example
|
||||
"some_other_factor": 1,
|
||||
})
|
|
@ -3,6 +3,7 @@ import logging
|
|||
import os
|
||||
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.function_runner import detect_checkpoint_function
|
||||
from ray.tune.registry import register_trainable, get_trainable_cls
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
from ray.tune.sample import sample_from
|
||||
|
@ -92,6 +93,18 @@ class Experiment:
|
|||
restore=None):
|
||||
|
||||
config = config or {}
|
||||
|
||||
if callable(run) and detect_checkpoint_function(run):
|
||||
if checkpoint_at_end:
|
||||
raise ValueError(
|
||||
"'checkpoint_at_end' cannot be used with a "
|
||||
"checkpointable function. You can specify and register "
|
||||
"checkpoints within your trainable function.")
|
||||
if checkpoint_freq:
|
||||
raise ValueError(
|
||||
"'checkpoint_freq' cannot be used with a "
|
||||
"checkpointable function. You can specify checkpoints "
|
||||
"within your trainable function.")
|
||||
self._run_identifier = Experiment.register_if_needed(run)
|
||||
self.name = name or self._run_identifier
|
||||
if upload_dir:
|
||||
|
|
|
@ -1,13 +1,18 @@
|
|||
import logging
|
||||
import os
|
||||
import io
|
||||
import time
|
||||
import inspect
|
||||
import shutil
|
||||
import threading
|
||||
import traceback
|
||||
|
||||
from six.moves import queue
|
||||
|
||||
from ray.tune import TuneError, session
|
||||
from ray.tune.trainable import Trainable
|
||||
from ray.tune.result import TIME_THIS_ITER_S, RESULT_DUPLICATE
|
||||
from ray.tune.trainable import Trainable, TrainableUtil
|
||||
from ray.tune.result import (TIME_THIS_ITER_S, RESULT_DUPLICATE,
|
||||
SHOULD_CHECKPOINT)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -40,6 +45,8 @@ class StatusReporter:
|
|||
self._trial_name = trial_name
|
||||
self._trial_id = trial_id
|
||||
self._logdir = logdir
|
||||
self._last_checkpoint = {}
|
||||
self._fresh_checkpoint = False
|
||||
|
||||
def __call__(self, **kwargs):
|
||||
"""Report updated training status.
|
||||
|
@ -77,6 +84,29 @@ class StatusReporter:
|
|||
# resume training.
|
||||
self._continue_semaphore.acquire()
|
||||
|
||||
def make_checkpoint_dir(self, step=None):
|
||||
checkpoint_dir = TrainableUtil.make_checkpoint_dir(
|
||||
self.logdir, index=step)
|
||||
return checkpoint_dir
|
||||
|
||||
def save_checkpoint(self, checkpoint):
|
||||
if isinstance(checkpoint, str):
|
||||
try:
|
||||
TrainableUtil.find_checkpoint_dir(checkpoint)
|
||||
except FileNotFoundError:
|
||||
logger.error("Checkpoint must be created with path given from "
|
||||
"make_checkpoint_dir.")
|
||||
raise
|
||||
self._last_checkpoint = checkpoint
|
||||
self._fresh_checkpoint = True
|
||||
|
||||
def has_new_checkpoint(self):
|
||||
return self._fresh_checkpoint
|
||||
|
||||
def get_checkpoint(self):
|
||||
self._fresh_checkpoint = False
|
||||
return self._last_checkpoint
|
||||
|
||||
def _start(self):
|
||||
self._last_report_time = time.time()
|
||||
|
||||
|
@ -155,21 +185,33 @@ class FunctionRunner(Trainable):
|
|||
trial_id=self.trial_id,
|
||||
logdir=self.logdir)
|
||||
self._last_result = {}
|
||||
config = config.copy()
|
||||
|
||||
session.init(self._status_reporter)
|
||||
|
||||
def entrypoint():
|
||||
return self._trainable_func(config, self._status_reporter)
|
||||
|
||||
# the runner thread is not started until the first call to _train
|
||||
self._runner = _RunnerThread(entrypoint, self._error_queue)
|
||||
self._runner = None
|
||||
self._restore_tmpdir = None
|
||||
self.default_checkpoint_dir = None
|
||||
|
||||
def _trainable_func(self):
|
||||
"""Subclasses can override this to set the trainable func."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def _start(self):
|
||||
def entrypoint():
|
||||
return self._trainable_func(self.config, self._status_reporter,
|
||||
self._status_reporter.get_checkpoint())
|
||||
|
||||
# the runner thread is not started until the first call to _train
|
||||
self._runner = _RunnerThread(entrypoint, self._error_queue)
|
||||
# if not alive, try to start
|
||||
self._status_reporter._start()
|
||||
try:
|
||||
self._runner.start()
|
||||
except RuntimeError:
|
||||
# If this is reached, it means the thread was started and is
|
||||
# now done or has raised an exception.
|
||||
pass
|
||||
|
||||
def _train(self):
|
||||
"""Implements train() for a Function API.
|
||||
|
||||
|
@ -178,19 +220,12 @@ class FunctionRunner(Trainable):
|
|||
along with a result with "done=True". The TrialRunner will handle the
|
||||
result accordingly (see tune/trial_runner.py).
|
||||
"""
|
||||
if self._runner.is_alive():
|
||||
if self._runner and self._runner.is_alive():
|
||||
# if started and alive, inform the reporter to continue and
|
||||
# generate the next result
|
||||
self._continue_semaphore.release()
|
||||
else:
|
||||
# if not alive, try to start
|
||||
self._status_reporter._start()
|
||||
try:
|
||||
self._runner.start()
|
||||
except RuntimeError:
|
||||
# If this is reached, it means the thread was started and is
|
||||
# now done or has raised an exception.
|
||||
pass
|
||||
self._start()
|
||||
|
||||
result = None
|
||||
while result is None and self._runner.is_alive():
|
||||
|
@ -240,8 +275,61 @@ class FunctionRunner(Trainable):
|
|||
result = new_result
|
||||
|
||||
self._last_result = result
|
||||
if self._status_reporter.has_new_checkpoint():
|
||||
result[SHOULD_CHECKPOINT] = True
|
||||
return result
|
||||
|
||||
def create_default_checkpoint_dir(self):
|
||||
self.default_checkpoint_dir = TrainableUtil.make_checkpoint_dir(
|
||||
self.logdir, index="default")
|
||||
return self.default_checkpoint_dir
|
||||
|
||||
def save(self, checkpoint_path=None):
|
||||
if checkpoint_path:
|
||||
raise ValueError(
|
||||
"Checkpoint path should not be used with function API.")
|
||||
|
||||
checkpoint = self._status_reporter.get_checkpoint()
|
||||
state = self.get_state()
|
||||
|
||||
if not checkpoint:
|
||||
state.update(iteration=0, timesteps_total=0, episodes_total=0)
|
||||
parent_dir = self.create_default_checkpoint_dir()
|
||||
elif isinstance(checkpoint, dict):
|
||||
parent_dir = TrainableUtil.make_checkpoint_dir(
|
||||
self.logdir, index=self.training_iteration)
|
||||
else:
|
||||
parent_dir = TrainableUtil.find_checkpoint_dir(checkpoint)
|
||||
checkpoint_path = TrainableUtil.process_checkpoint(
|
||||
checkpoint, parent_dir, state)
|
||||
return checkpoint_path
|
||||
|
||||
def save_to_object(self):
|
||||
checkpoint_path = self.save()
|
||||
data_dict = TrainableUtil.pickle_checkpoint(checkpoint_path)
|
||||
out = io.BytesIO()
|
||||
if len(data_dict) > 10e6: # getting pretty large
|
||||
logger.info("Checkpoint size is {} bytes".format(len(data_dict)))
|
||||
out.write(data_dict)
|
||||
return out.getvalue()
|
||||
|
||||
def _restore(self, checkpoint):
|
||||
# This should be removed once Trainables are refactored.
|
||||
if "tune_checkpoint_path" in checkpoint:
|
||||
del checkpoint["tune_checkpoint_path"]
|
||||
self._status_reporter.save_checkpoint(checkpoint)
|
||||
|
||||
def restore_from_object(self, obj):
|
||||
if self.default_checkpoint_dir is not None and os.path.exists(
|
||||
self.default_checkpoint_dir):
|
||||
shutil.rmtree(self.default_checkpoint_dir)
|
||||
logger.debug("Clearing default checkpoint: %s",
|
||||
self.default_checkpoint_dir)
|
||||
|
||||
checkpoint_dir = self.create_default_checkpoint_dir()
|
||||
checkpoint_path = TrainableUtil.create_from_pickle(obj, checkpoint_dir)
|
||||
self.restore(checkpoint_path)
|
||||
|
||||
def _stop(self):
|
||||
# If everything stayed in synch properly, this should never happen.
|
||||
if not self._results_queue.empty():
|
||||
|
@ -251,7 +339,6 @@ class FunctionRunner(Trainable):
|
|||
|
||||
# Check for any errors that might have been missed.
|
||||
self._report_thread_runner_error()
|
||||
|
||||
session.shutdown()
|
||||
|
||||
def _report_thread_runner_error(self, block=False):
|
||||
|
@ -264,13 +351,35 @@ class FunctionRunner(Trainable):
|
|||
pass
|
||||
|
||||
|
||||
def detect_checkpoint_function(train_func):
|
||||
func_args = inspect.getfullargspec(train_func).args
|
||||
use_checkpoint = "checkpoint" in func_args
|
||||
return use_checkpoint
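# For illustration (hypothetical usage, not part of this module): the helper
# simply inspects the function signature for a `checkpoint` argument.
#
#     def train_a(config): ...
#     def train_b(config, checkpoint=None): ...
#
#     detect_checkpoint_function(train_a)   # -> False
#     detect_checkpoint_function(train_b)   # -> True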
|
||||
|
||||
|
||||
def wrap_function(train_func):
|
||||
class ImplicitFunc(FunctionRunner):
|
||||
def _trainable_func(self, config, reporter):
|
||||
def _trainable_func(self, config, reporter, checkpoint):
|
||||
func_args = inspect.getfullargspec(train_func).args
|
||||
use_track = ("reporter" not in func_args and len(func_args) == 1)
|
||||
if use_track:
|
||||
if len(func_args) > 1: # more arguments than just the config
|
||||
if "reporter" not in func_args and (
|
||||
"checkpoint" not in func_args):
|
||||
raise ValueError(
|
||||
"Unknown argument found in the Trainable function. "
|
||||
"Arguments other than the 'config' arg must be one "
|
||||
"of ['reporter', 'checkpoint']. Found: {}".format(
|
||||
func_args))
|
||||
use_reporter = "reporter" in func_args
|
||||
use_checkpoint = "checkpoint" in func_args
|
||||
if not use_checkpoint and not use_reporter:
|
||||
logger.warning(
|
||||
"Function checkpointing is disabled. This may result in "
|
||||
"unexpected behavior when using checkpointing features or "
|
||||
"certain schedulers. To enable, set the train function "
|
||||
"arguments to be `func(config, checkpoint)`.")
|
||||
output = train_func(config)
|
||||
elif use_checkpoint:
|
||||
output = train_func(config, checkpoint=checkpoint)
|
||||
else:
|
||||
output = train_func(config, reporter)
|
||||
|
||||
|
|
|
@ -669,8 +669,7 @@ class RayTrialExecutor(TrialExecutor):
|
|||
elif trial.sync_on_checkpoint:
|
||||
# This provides FT backwards compatibility in the
|
||||
# case where a DurableTrainable is not provided.
|
||||
logger.warning("Trial %s: Reading checkpoint into memory.",
|
||||
trial)
|
||||
logger.debug("Trial %s: Reading checkpoint into memory", trial)
|
||||
data_dict = TrainableUtil.pickle_checkpoint(value)
|
||||
with self._change_working_directory(trial):
|
||||
remote = trial.runner.restore_from_object.remote(data_dict)
|
||||
|
|
|
@ -381,10 +381,14 @@ class Bracket:
|
|||
|
||||
assert trial in self._live_trials
|
||||
assert self._get_result_time(result) >= 0
|
||||
observed_time = self._get_result_time(result)
|
||||
last_observed = self._get_result_time(self._live_trials[trial])
|
||||
|
||||
delta = self._get_result_time(result) - \
|
||||
self._get_result_time(self._live_trials[trial])
|
||||
assert delta >= 0, (result, self._live_trials[trial])
|
||||
delta = observed_time - last_observed
if delta <= 0:
|
||||
logger.info("Restoring from a previous point in time. "
|
||||
"Previous={}; Now={}".format(last_observed,
|
||||
observed_time))
|
||||
self._completed_progress += delta
|
||||
self._live_trials[trial] = result
|
||||
|
||||
|
@ -424,7 +428,7 @@ class Bracket:
|
|||
def _calculate_total_work(self, n, r, s):
|
||||
work = 0
|
||||
cumulative_r = r
|
||||
for i in range(s + 1):
|
||||
for _ in range(s + 1):
|
||||
work += int(n) * int(r)
|
||||
n /= self._eta
|
||||
n = int(np.ceil(n))
|
||||
|
|
|
@ -5,29 +5,6 @@ logger = logging.getLogger(__name__)
|
|||
_session = None
|
||||
|
||||
|
||||
class _ReporterSession:
|
||||
def __init__(self, tune_reporter):
|
||||
self.tune_reporter = tune_reporter
|
||||
|
||||
def report(self, **metrics):
|
||||
return self.tune_reporter(**metrics)
|
||||
|
||||
@property
|
||||
def logdir(self):
|
||||
"""Trial logdir (subdir of given experiment directory)"""
|
||||
return self.tune_reporter.logdir
|
||||
|
||||
@property
|
||||
def trial_name(self):
|
||||
"""Trial name for the corresponding trial of this Trainable"""
|
||||
return self.tune_reporter.trial_name
|
||||
|
||||
@property
|
||||
def trial_id(self):
|
||||
"""Trial id for the corresponding trial of this Trainable"""
|
||||
return self.tune_reporter.trial_id
|
||||
|
||||
|
||||
def get_session():
|
||||
global _session
|
||||
if _session is None:
|
||||
|
@ -56,7 +33,11 @@ def init(reporter, ignore_reinit_error=True):
|
|||
else:
|
||||
raise ValueError(reinit_msg)
|
||||
|
||||
_session = _ReporterSession(reporter)
|
||||
if reporter is None:
|
||||
logger.warning("You are using a Tune session outside of Tune. "
|
||||
"Most session commands will have no effect.")
|
||||
|
||||
_session = reporter
|
||||
|
||||
|
||||
def shutdown():
|
||||
|
@ -86,34 +67,109 @@ def report(**kwargs):
|
|||
metrics can be used for early stopping or optimization.
|
||||
"""
|
||||
_session = get_session()
|
||||
return _session.report(**kwargs)
|
||||
return _session(**kwargs)
|
||||
|
||||
|
||||
def make_checkpoint_dir(step=None):
|
||||
"""Gets the next checkpoint dir.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import time
|
||||
from ray import tune
|
||||
|
||||
def func(config, checkpoint=None):
|
||||
start = 0
|
||||
if checkpoint:
|
||||
with open(checkpoint) as f:
|
||||
state = json.loads(f.read())
|
||||
start = state["step"] + 1
|
||||
|
||||
for iter in range(start, 100):
|
||||
time.sleep(1)
|
||||
|
||||
checkpoint_dir = tune.make_checkpoint_dir(step=iter)
|
||||
path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
with open(path, "w") as f:
|
||||
f.write(json.dumps({"step": start}))
|
||||
tune.save_checkpoint(path)
|
||||
|
||||
tune.report(hello="world", ray="tune")
|
||||
|
||||
Args:
|
||||
step (int): Current training iteration - used for setting
|
||||
an index to uniquely identify the checkpoint.
|
||||
|
||||
.. versionadded:: 0.8.6
|
||||
|
||||
"""
|
||||
_session = get_session()
|
||||
return _session.make_checkpoint_dir(step=step)
|
||||
|
||||
|
||||
def save_checkpoint(checkpoint):
|
||||
"""Register the given checkpoint.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
from ray import tune
|
||||
|
||||
def func(config, checkpoint=None):
|
||||
start = 0
|
||||
if checkpoint:
|
||||
with open(checkpoint) as f:
|
||||
state = json.loads(f.read())
|
||||
accuracy = state["acc"]
|
||||
start = state["step"] + 1
|
||||
|
||||
for iter in range(start, 10):
|
||||
time.sleep(1)
|
||||
|
||||
checkpoint_dir = tune.make_checkpoint_dir(step=iter)
|
||||
path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
with open(path, "w") as f:
|
||||
f.write(json.dumps({"step": start}))
|
||||
tune.save_checkpoint(path)
|
||||
|
||||
tune.report(hello="world", ray="tune")
|
||||
|
||||
analysis = tune.run(func)
|
||||
|
||||
Args:
|
||||
checkpoint (str|dict): Checkpoint path (as returned by
``make_checkpoint_dir``) or checkpoint dict to register with Tune.
|
||||
|
||||
.. versionadded:: 0.8.6
|
||||
"""
|
||||
_session = get_session()
|
||||
return _session.save_checkpoint(checkpoint)
|
||||
|
||||
|
||||
def get_trial_dir():
|
||||
"""Returns the directory where trial results are saved.
|
||||
|
||||
For function API use only. Do not call this method in the Class API. Use
|
||||
`self.logdir` instead.
|
||||
For function API use only.
|
||||
"""
|
||||
_session = get_session()
|
||||
return _session.logdir
|
||||
|
||||
|
||||
def get_trial_name():
|
||||
"""Trial name for the corresponding trial of this Trainable.
|
||||
"""Trial name for the corresponding trial.
|
||||
|
||||
For function API use only. Do not call this method in the Class API. Use
|
||||
`self.trial_name` instead.
|
||||
For function API use only.
|
||||
"""
|
||||
_session = get_session()
|
||||
return _session.trial_name
|
||||
|
||||
|
||||
def get_trial_id():
|
||||
"""Trial id for the corresponding trial of this Trainable.
|
||||
"""Trial id for the corresponding trial.
|
||||
|
||||
For function API use only. Do not call this method in the Class API. Use
|
||||
`self.trial_id` instead.
|
||||
For function API use only.
|
||||
"""
|
||||
_session = get_session()
|
||||
return _session.trial_id
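# Example (function API only) - a sketch of how these helpers might be used
# inside a trainable function to write auxiliary artifacts:
#
#     def train_func(config):
#         logdir = tune.get_trial_dir()
#         with open(os.path.join(logdir, "notes.txt"), "w") as f:
#             f.write("trial {} ({})".format(
#                 tune.get_trial_id(), tune.get_trial_name()))
#         tune.report(done=True)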
|
||||
|
|
|
@ -58,7 +58,7 @@ class AxSearch(Searcher):
|
|||
def easy_objective(config):
|
||||
for i in range(100):
|
||||
intermediate_result = config["x1"] + config["x2"] * i
|
||||
tune.track.log(score=intermediate_result)
|
||||
tune.report(score=intermediate_result)
|
||||
|
||||
client = AxClient(enforce_sequential_optimization=False)
|
||||
client.create_experiment(parameters=parameters, objective_name="score")
|
||||
|
|
|
@ -394,10 +394,28 @@ def test_migration_checkpoint_removal(start_connected_emptyhead_cluster,
|
|||
}
|
||||
|
||||
# The following patches only affect __fake_remote.
|
||||
find_checkpoint_dir = TrainableUtil.find_checkpoint_dir
|
||||
with patch("ray.tune.logger.get_node_syncer") as mock_get_node_syncer:
|
||||
trainable_util = "ray.tune.ray_trial_executor.TrainableUtil"
|
||||
with patch(trainable_util + ".find_checkpoint_dir") as mock_find_dir:
|
||||
def hide_remote_path(path_function):
|
||||
def hidden_path_func(checkpoint_path):
|
||||
"""Converts back to local path first."""
|
||||
if MOCK_REMOTE_DIR in checkpoint_path:
|
||||
checkpoint_path = checkpoint_path[len(MOCK_REMOTE_DIR):]
|
||||
checkpoint_path = os.path.join("/", checkpoint_path)
|
||||
return path_function(checkpoint_path)
|
||||
|
||||
return hidden_path_func
|
||||
|
||||
trainable_util = "ray.tune.ray_trial_executor.TrainableUtil"
|
||||
_find_ckpt = trainable_util + ".find_checkpoint_dir"
|
||||
find_func = TrainableUtil.find_checkpoint_dir
|
||||
_pickle_ckpt = trainable_util + ".pickle_checkpoint"
|
||||
pickle_func = TrainableUtil.pickle_checkpoint
|
||||
|
||||
with patch(_find_ckpt) as mock_find, patch(_pickle_ckpt) as mock_pkl_ckpt:
|
||||
# __fake_remote trainables save to a separate "remote" directory.
|
||||
# TrainableUtil will not check this path unless we mock it.
|
||||
mock_find.side_effect = hide_remote_path(find_func)
|
||||
mock_pkl_ckpt.side_effect = hide_remote_path(pickle_func)
|
||||
with patch("ray.tune.logger.get_node_syncer") as mock_get_node_syncer:
|
||||
|
||||
def mock_get_syncer_fn(local_dir, remote_dir, sync_function):
|
||||
client = mock_storage_client()
|
||||
|
@ -405,16 +423,6 @@ def test_migration_checkpoint_removal(start_connected_emptyhead_cluster,
|
|||
|
||||
mock_get_node_syncer.side_effect = mock_get_syncer_fn
|
||||
|
||||
def mock_find_dir_fn(checkpoint_path):
|
||||
"""Converts back to local path first."""
|
||||
checkpoint_path = checkpoint_path[len(MOCK_REMOTE_DIR):]
|
||||
checkpoint_path = os.path.join("/", checkpoint_path)
|
||||
return find_checkpoint_dir(checkpoint_path)
|
||||
|
||||
# __fake_remote trainables save to a separate "remote" directory.
|
||||
# TrainableUtil will not check this path unless we mock it.
|
||||
mock_find_dir.side_effect = mock_find_dir_fn
|
||||
|
||||
# Test recovery of trial that has been checkpointed
|
||||
t1 = Trial(trainable_id, **kwargs)
|
||||
runner.add_trial(t1)
|
||||
|
@ -428,7 +436,6 @@ def test_migration_checkpoint_removal(start_connected_emptyhead_cluster,
|
|||
cluster.remove_node(node)
|
||||
cluster.wait_for_nodes()
|
||||
shutil.rmtree(os.path.dirname(t1.checkpoint.value))
|
||||
|
||||
runner.step() # Collect result 3, kick off + fail result 4
|
||||
runner.step() # Dispatch restore
|
||||
runner.step() # Process restore + step 4
|
||||
|
|
169
python/ray/tune/tests/test_function_api.py
Normal file
169
python/ray/tune/tests/test_function_api.py
Normal file
|
@ -0,0 +1,169 @@
|
|||
import json
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import ray
|
||||
from ray.rllib import _register_all
|
||||
|
||||
from ray import tune
|
||||
from ray.tune.function_runner import wrap_function
|
||||
from ray.tune.result import TRAINING_ITERATION
|
||||
|
||||
|
||||
class FunctionApiTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
ray.init(num_cpus=4, num_gpus=0, object_store_memory=150 * 1024 * 1024)
|
||||
|
||||
def tearDown(self):
|
||||
ray.shutdown()
|
||||
_register_all() # re-register the evicted objects
|
||||
|
||||
def testFunctionNoCheckpointing(self):
|
||||
def train(config, checkpoint=None):
|
||||
for i in range(10):
|
||||
tune.report(test=i)
|
||||
|
||||
wrapped = wrap_function(train)
|
||||
|
||||
new_trainable = wrapped()
|
||||
result = new_trainable.train()
|
||||
checkpoint = new_trainable.save()
|
||||
new_trainable.stop()
|
||||
|
||||
new_trainable2 = wrapped()
|
||||
new_trainable2.restore(checkpoint)
|
||||
result = new_trainable2.train()
|
||||
self.assertEquals(result[TRAINING_ITERATION], 1)
|
||||
checkpoint = new_trainable2.save()
|
||||
new_trainable2.stop()
|
||||
|
||||
def testFunctionRecurringSave(self):
|
||||
"""This tests that save and restore are commutative."""
|
||||
|
||||
def train(config, checkpoint=None):
|
||||
for step in range(10):
|
||||
if step % 3 == 0:
|
||||
checkpoint_dir = tune.make_checkpoint_dir(step=step)
|
||||
path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
with open(path, "w") as f:
|
||||
f.write(json.dumps({"step": step}))
|
||||
tune.save_checkpoint(path)
|
||||
tune.report(test=step)
|
||||
|
||||
wrapped = wrap_function(train)
|
||||
|
||||
new_trainable = wrapped()
|
||||
new_trainable.train()
|
||||
checkpoint_obj = new_trainable.save_to_object()
|
||||
new_trainable.restore_from_object(checkpoint_obj)
|
||||
checkpoint = new_trainable.save()
|
||||
new_trainable.stop()
|
||||
|
||||
new_trainable2 = wrapped()
|
||||
new_trainable2.restore(checkpoint)
|
||||
new_trainable2.train()
|
||||
new_trainable2.stop()
|
||||
|
||||
def testCheckpointFunctionAtEnd(self):
|
||||
def train(config, checkpoint=False):
|
||||
for i in range(10):
|
||||
tune.report(test=i)
|
||||
checkpoint_dir = tune.make_checkpoint_dir(step=10)
|
||||
checkpoint_path = os.path.join(checkpoint_dir, "hello")
|
||||
with open(checkpoint_path, "w") as f:
|
||||
f.write("hello")
|
||||
tune.save_checkpoint(checkpoint_path)
|
||||
|
||||
[trial] = tune.run(train).trials
|
||||
assert "hello" in trial.checkpoint.value
|
||||
|
||||
def testVariousCheckpointFunctionAtEnd(self):
|
||||
def train(config, checkpoint=False):
|
||||
for i in range(10):
|
||||
checkpoint_dir = tune.make_checkpoint_dir()
|
||||
checkpoint_path = os.path.join(checkpoint_dir, "hello")
|
||||
with open(checkpoint_path, "w") as f:
|
||||
f.write("hello")
|
||||
tune.save_checkpoint(checkpoint_path)
|
||||
tune.report(test=i)
|
||||
checkpoint_dir = tune.make_checkpoint_dir()
|
||||
checkpoint_path = os.path.join(checkpoint_dir, "goodbye")
|
||||
with open(checkpoint_path, "w") as f:
|
||||
f.write("goodbye")
|
||||
tune.save_checkpoint(checkpoint_path)
|
||||
|
||||
[trial] = tune.run(train, keep_checkpoints_num=3).trials
|
||||
assert "goodbye" in trial.checkpoint.value
|
||||
|
||||
def testReuseCheckpoint(self):
|
||||
def train(config, checkpoint=False):
|
||||
itr = 0
|
||||
if checkpoint:
|
||||
with open(checkpoint, "r") as f:
|
||||
itr = int(f.read()) + 1
|
||||
|
||||
for i in range(itr, config["max_iter"]):
|
||||
checkpoint_dir = tune.make_checkpoint_dir(step=i)
|
||||
checkpoint_path = os.path.join(checkpoint_dir, "goodbye")
|
||||
with open(checkpoint_path, "w") as f:
|
||||
f.write(str(i))
|
||||
tune.save_checkpoint(checkpoint_path)
|
||||
tune.report(test=i, training_iteration=i)
|
||||
|
||||
[trial] = tune.run(
|
||||
train,
|
||||
config={
|
||||
"max_iter": 5
|
||||
},
|
||||
).trials
|
||||
last_ckpt = trial.checkpoint.value
|
||||
assert "goodbye" in last_ckpt
|
||||
analysis = tune.run(train, config={"max_iter": 10}, restore=last_ckpt)
|
||||
trial_dfs = list(analysis.trial_dataframes.values())
|
||||
assert len(trial_dfs[0]["training_iteration"]) == 5
|
||||
|
||||
def testRetry(self):
|
||||
def train(config, checkpoint=None):
|
||||
restored = bool(checkpoint)
|
||||
itr = 0
|
||||
if checkpoint:
|
||||
with open(checkpoint, "r") as f:
|
||||
itr = int(f.read()) + 1
|
||||
|
||||
for i in range(itr, 10):
|
||||
if i == 5 and not restored:
|
||||
raise Exception("try to fail me")
|
||||
checkpoint_dir = tune.make_checkpoint_dir(step=i)
|
||||
checkpoint_path = os.path.join(checkpoint_dir, "goodbye")
|
||||
with open(checkpoint_path, "w") as f:
|
||||
f.write(str(i))
|
||||
tune.save_checkpoint(checkpoint_path)
|
||||
tune.report(test=i, training_iteration=i)
|
||||
|
||||
analysis = tune.run(train, max_failures=3)
|
||||
last_ckpt = analysis.trials[0].checkpoint.value
|
||||
assert "goodbye" in last_ckpt
|
||||
trial_dfs = list(analysis.trial_dataframes.values())
|
||||
assert len(trial_dfs[0]["training_iteration"]) == 10
|
||||
|
||||
def testBlankCheckpoint(self):
|
||||
def train(config, checkpoint=None):
|
||||
restored = bool(checkpoint)
|
||||
itr = 0
|
||||
if checkpoint:
|
||||
with open(checkpoint, "r") as f:
|
||||
itr = int(f.read()) + 1
|
||||
|
||||
for i in range(itr, 10):
|
||||
if i == 5 and not restored:
|
||||
raise Exception("try to fail me")
|
||||
checkpoint_dir = tune.make_checkpoint_dir()
|
||||
checkpoint_path = os.path.join(checkpoint_dir, "goodbye")
|
||||
with open(checkpoint_path, "w") as f:
|
||||
f.write(str(i))
|
||||
tune.save_checkpoint(checkpoint_path)
|
||||
tune.report(test=i, training_iteration=i)
|
||||
|
||||
analysis = tune.run(train, max_failures=3)
|
||||
trial_dfs = list(analysis.trial_dataframes.values())
|
||||
assert len(trial_dfs[0]["training_iteration"]) == 10
|
|
@ -17,18 +17,6 @@ class TrackApiTest(unittest.TestCase):
|
|||
session.shutdown()
|
||||
ray.shutdown()
|
||||
|
||||
def testSessionInitShutdown(self):
|
||||
self.assertTrue(session._session is None)
|
||||
|
||||
# Checks that the singleton _session is created/destroyed
|
||||
# by session.init() and session.shutdown()
|
||||
for _ in range(2):
|
||||
# do it twice to see that we can reopen the session
|
||||
session.init(reporter=None)
|
||||
self.assertTrue(session._session is not None)
|
||||
session.shutdown()
|
||||
self.assertTrue(session._session is None)
|
||||
|
||||
def testSoftDeprecation(self):
|
||||
"""Checks that tune.track.log code does not break."""
|
||||
from ray.tune import track
|
||||
|
|
|
@ -28,6 +28,34 @@ SETUP_TIME_THRESHOLD = 10
|
|||
|
||||
|
||||
class TrainableUtil:
|
||||
@staticmethod
|
||||
def process_checkpoint(checkpoint, parent_dir, trainable_state):
|
||||
saved_as_dict = False
|
||||
if isinstance(checkpoint, string_types):
|
||||
if not checkpoint.startswith(parent_dir):
|
||||
raise ValueError(
|
||||
"The returned checkpoint path must be within the "
|
||||
"given checkpoint dir {}: {}".format(
|
||||
parent_dir, checkpoint))
|
||||
checkpoint_path = checkpoint
|
||||
if os.path.isdir(checkpoint_path):
|
||||
# Add trailing slash to prevent tune metadata from
|
||||
# being written outside the directory.
|
||||
checkpoint_path = os.path.join(checkpoint_path, "")
|
||||
elif isinstance(checkpoint, dict):
|
||||
saved_as_dict = True
|
||||
checkpoint_path = os.path.join(parent_dir, "checkpoint")
|
||||
with open(checkpoint_path, "wb") as f:
|
||||
pickle.dump(checkpoint, f)
|
||||
else:
|
||||
raise ValueError("Returned unexpected type {}. "
|
||||
"Expected str or dict.".format(type(checkpoint)))
|
||||
|
||||
with open(checkpoint_path + ".tune_metadata", "wb") as f:
|
||||
trainable_state["saved_as_dict"] = saved_as_dict
|
||||
pickle.dump(trainable_state, f)
|
||||
return checkpoint_path
|
||||
|
||||
@staticmethod
|
||||
def pickle_checkpoint(checkpoint_path):
|
||||
"""Pickles checkpoint data."""
|
||||
|
@ -39,7 +67,8 @@ class TrainableUtil:
|
|||
with open(path, "rb") as f:
|
||||
data[os.path.relpath(path, checkpoint_dir)] = f.read()
|
||||
# Use normpath so that a directory path isn't mapped to empty string.
|
||||
name = os.path.basename(os.path.normpath(checkpoint_path))
|
||||
name = os.path.relpath(
|
||||
os.path.normpath(checkpoint_path), checkpoint_dir)
|
||||
name += os.path.sep if os.path.isdir(checkpoint_path) else ""
|
||||
data_dict = pickle.dumps({
|
||||
"checkpoint_name": name,
|
||||
|
@ -70,11 +99,38 @@ class TrainableUtil:
|
|||
return checkpoint_dir
|
||||
|
||||
@staticmethod
|
||||
def make_checkpoint_dir(checkpoint_dir):
|
||||
"""Creates a checkpoint directory at the provided path."""
|
||||
def make_checkpoint_dir(checkpoint_dir, index):
|
||||
"""Creates a checkpoint directory within the provided path.
|
||||
|
||||
Args:
|
||||
checkpoint_dir (str): Path to checkpoint directory.
|
||||
index (str): A subdirectory will be created
|
||||
at the checkpoint directory named 'checkpoint_{index}'.
|
||||
"""
|
||||
suffix = "checkpoint"
|
||||
if index is not None:
|
||||
suffix += "_{}".format(index)
|
||||
checkpoint_dir = os.path.join(checkpoint_dir, suffix)
|
||||
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
# Drop marker in directory to identify it as a checkpoint dir.
|
||||
open(os.path.join(checkpoint_dir, ".is_checkpoint"), "a").close()
|
||||
return checkpoint_dir
|
||||
|
||||
@staticmethod
|
||||
def create_from_pickle(obj, tmpdir):
|
||||
info = pickle.loads(obj)
|
||||
data = info["data"]
|
||||
checkpoint_path = os.path.join(tmpdir, info["checkpoint_name"])
|
||||
|
||||
for relpath_name, file_contents in data.items():
|
||||
path = os.path.join(tmpdir, relpath_name)
|
||||
|
||||
# This may be a subdirectory, hence not just using tmpdir
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
with open(path, "wb") as f:
|
||||
f.write(file_contents)
|
||||
return checkpoint_path
|
||||
|
||||
@staticmethod
|
||||
def get_checkpoints_paths(logdir):
|
||||
|
@ -324,6 +380,16 @@ class Trainable:
|
|||
|
||||
return result
|
||||
|
||||
def get_state(self):
|
||||
return {
|
||||
"experiment_id": self._experiment_id,
|
||||
"iteration": self._iteration,
|
||||
"timesteps_total": self._timesteps_total,
|
||||
"time_total": self._time_total,
|
||||
"episodes_total": self._episodes_total,
|
||||
"ray_version": ray.__version__,
|
||||
}
|
||||
|
||||
def save(self, checkpoint_dir=None):
|
||||
"""Saves the current model state to a checkpoint.
|
||||
|
||||
|
@ -336,41 +402,14 @@ class Trainable:
|
|||
Returns:
|
||||
str: Checkpoint path or prefix that may be passed to restore().
|
||||
"""
|
||||
checkpoint_dir = os.path.join(checkpoint_dir or self.logdir,
|
||||
"checkpoint_{}".format(self._iteration))
|
||||
TrainableUtil.make_checkpoint_dir(checkpoint_dir)
|
||||
checkpoint_dir = TrainableUtil.make_checkpoint_dir(
|
||||
checkpoint_dir or self.logdir, index=self.iteration)
|
||||
checkpoint = self._save(checkpoint_dir)
|
||||
saved_as_dict = False
|
||||
if isinstance(checkpoint, string_types):
|
||||
if not checkpoint.startswith(checkpoint_dir):
|
||||
raise ValueError(
|
||||
"The returned checkpoint path must be within the "
|
||||
"given checkpoint dir {}: {}".format(
|
||||
checkpoint_dir, checkpoint))
|
||||
checkpoint_path = checkpoint
|
||||
if os.path.isdir(checkpoint_path):
|
||||
# Add trailing slash to prevent tune metadata from
|
||||
# being written outside the directory.
|
||||
checkpoint_path = os.path.join(checkpoint_path, "")
|
||||
elif isinstance(checkpoint, dict):
|
||||
saved_as_dict = True
|
||||
checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
with open(checkpoint_path, "wb") as f:
|
||||
pickle.dump(checkpoint, f)
|
||||
else:
|
||||
raise ValueError("Returned unexpected type {}. "
|
||||
"Expected str or dict.".format(type(checkpoint)))
|
||||
|
||||
with open(checkpoint_path + ".tune_metadata", "wb") as f:
|
||||
pickle.dump({
|
||||
"experiment_id": self._experiment_id,
|
||||
"iteration": self._iteration,
|
||||
"timesteps_total": self._timesteps_total,
|
||||
"time_total": self._time_total,
|
||||
"episodes_total": self._episodes_total,
|
||||
"saved_as_dict": saved_as_dict,
|
||||
"ray_version": ray.__version__,
|
||||
}, f)
|
||||
trainable_state = self.get_state()
|
||||
checkpoint_path = TrainableUtil.process_checkpoint(
|
||||
checkpoint,
|
||||
parent_dir=checkpoint_dir,
|
||||
trainable_state=trainable_state)
|
||||
return checkpoint_path
|
||||
|
||||
def save_to_object(self):
|
||||
|
@ -434,19 +473,8 @@ class Trainable:
|
|||
|
||||
These checkpoints are returned from calls to save_to_object().
|
||||
"""
|
||||
info = pickle.loads(obj)
|
||||
data = info["data"]
|
||||
tmpdir = tempfile.mkdtemp("restore_from_object", dir=self.logdir)
|
||||
checkpoint_path = os.path.join(tmpdir, info["checkpoint_name"])
|
||||
|
||||
for relpath_name, file_contents in data.items():
|
||||
path = os.path.join(tmpdir, relpath_name)
|
||||
|
||||
# This may be a subdirectory, hence not just using tmpdir
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
with open(path, "wb") as f:
|
||||
f.write(file_contents)
|
||||
|
||||
checkpoint_path = TrainableUtil.create_from_pickle(obj, tmpdir)
|
||||
self.restore(checkpoint_path)
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
|
@ -531,7 +559,10 @@ class Trainable:
|
|||
|
||||
name = self.trial_name
|
||||
"""
|
||||
return self._trial_info.trial_name
|
||||
if self._trial_info:
|
||||
return self._trial_info.trial_name
|
||||
else:
|
||||
return "default"
|
||||
|
||||
@property
|
||||
def trial_id(self):
|
||||
|
@ -543,7 +574,10 @@ class Trainable:
|
|||
|
||||
trial_id = self.trial_id
|
||||
"""
|
||||
return self._trial_info.trial_id
|
||||
if self._trial_info:
|
||||
return self._trial_info.trial_id
|
||||
else:
|
||||
return "default"
|
||||
|
||||
@property
|
||||
def iteration(self):
|
||||
|
|
|
@ -385,10 +385,14 @@ class TrialRunner:
|
|||
self.trial_executor.try_checkpoint_metadata(trial)
|
||||
|
||||
def debug_string(self, delim="\n"):
|
||||
result_keys = [
|
||||
list(t.last_result) for t in self.get_trials() if t.last_result
|
||||
]
|
||||
metrics = set().union(*result_keys)
|
||||
messages = [
|
||||
self._scheduler_alg.debug_string(),
|
||||
self.trial_executor.debug_string(),
|
||||
trial_progress_str(self.get_trials()),
|
||||
trial_progress_str(self.get_trials(), metrics),
|
||||
]
|
||||
return delim.join(messages)
|
||||
|
||||
|
@ -468,6 +472,7 @@ class TrialRunner:
|
|||
result = self.trial_executor.fetch_result(trial)
|
||||
|
||||
is_duplicate = RESULT_DUPLICATE in result
|
||||
force_checkpoint = result.get(SHOULD_CHECKPOINT, False)
|
||||
# TrialScheduler and SearchAlgorithm still receive a
|
||||
# notification because there may be special handling for
|
||||
# the `on_trial_complete` hook.
|
||||
|
@ -506,8 +511,7 @@ class TrialRunner:
|
|||
# the scheduler decision is STOP or PAUSE. Note that
|
||||
# PAUSE only checkpoints to memory and does not update
|
||||
# the global checkpoint state.
|
||||
self._checkpoint_trial_if_needed(
|
||||
trial, force=result.get(SHOULD_CHECKPOINT, False))
|
||||
self._checkpoint_trial_if_needed(trial, force=force_checkpoint)
|
||||
|
||||
if trial.is_saving:
|
||||
# Cache decision to execute on after the save is processed.
|
||||
|
|
|
@ -110,7 +110,10 @@ def run(run_or_experiment,
|
|||
function or class, or the string identifier of a
|
||||
trainable function or class registered in the tune registry.
|
||||
If Experiment, then Tune will execute training based on
|
||||
Experiment.spec.
|
||||
Experiment.spec. If you want to pass in a Python lambda, you
|
||||
will need to first register the function:
|
||||
``tune.register_trainable("lambda_id", lambda x: ...)``. You can
|
||||
then use ``tune.run("lambda_id")``.
|
||||
name (str): Name of experiment.
|
||||
stop (dict | callable | :class:`Stopper`): Stopping criteria. If dict,
|
||||
the keys may be any field in the return result of 'train()',
|
||||
|
@ -154,8 +157,10 @@ def run(run_or_experiment,
|
|||
syncing to driver is disabled.
|
||||
checkpoint_freq (int): How many training iterations between
|
||||
checkpoints. A value of 0 (default) disables checkpointing.
|
||||
This has no effect when using the Functional Training API.
|
||||
checkpoint_at_end (bool): Whether to checkpoint at the end of the
|
||||
experiment regardless of the checkpoint_freq. Default is False.
|
||||
This has no effect when using the Functional Training API.
|
||||
sync_on_checkpoint (bool): Force sync-down of trial checkpoint to
|
||||
driver. If set to False, checkpoint syncing from worker to driver
|
||||
is asynchronous and best-effort. This does not affect persistent
|
||||
|
@ -214,6 +219,8 @@ def run(run_or_experiment,
|
|||
if using a RayTrialExecutor (which is the default) and
|
||||
if Ray is not initialized. Defaults to True.
|
||||
|
||||
|
||||
|
||||
Returns:
|
||||
ExperimentAnalysis: Object for experiment analysis.
|
||||
|
||||
|
|