[tune] Function API checkpointing (#8471)

Co-authored-by: krfricke <krfricke@users.noreply.github.com>
Richard Liaw 2020-06-15 10:42:54 -07:00 committed by GitHub
parent 91e57f2e53
commit 6c49c01837
21 changed files with 897 additions and 237 deletions

View file

@ -141,6 +141,14 @@ $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE}
python /ray/python/ray/tune/examples/pbt_convnet_example.py \
--smoke-test
$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} --memory-swap=-1 $DOCKER_SHA \
python /ray/python/ray/tune/examples/hyperband_function_example.py \
--smoke-test
$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} --memory-swap=-1 $DOCKER_SHA \
python /ray/python/ray/tune/examples/pbt_function.py \
--smoke-test
$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} --memory-swap=-1 $DOCKER_SHA \
python /ray/python/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist.py \
--smoke-test

View file

@ -35,7 +35,7 @@ Here's an example of specifying the objective function using :ref:`the function-
Now, there are two Trainable APIs - one is the :ref:`function-based API <tune-function-api>` that we demonstrated above.
The other is a :ref:`class-based API <tune-class-api>` that enables :ref:`checkpointing and pausing <tune-trainable-save-restore>`. Here's an example of specifying the objective function using the :ref:`class-based API <tune-class-api>`:
The other is a :ref:`class-based API <tune-class-api>`. Here's an example of specifying the objective function using the :ref:`class-based API <tune-class-api>`:
.. code-block:: python
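
    # A minimal sketch of the class-based form (illustrative; not necessarily
    # the exact example from this page). Assumes an ``objective(x, a, b)``
    # helper like the one used in the function-based example above.
    from ray import tune

    class MyTrainable(tune.Trainable):
        def _setup(self, config):
            # config (dict): A dict of hyperparameters.
            self.x = 0
            self.a = config["a"]
            self.b = config["b"]

        def _train(self):
            # Called iteratively; one call per training iteration.
            score = objective(self.x, self.a, self.b)
            self.x += 1
            return {"score": score}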

View file

@ -147,46 +147,39 @@ When running a hyperparameter search, Tune can automatically and periodically sa
* fault-tolerance when using pre-emptible machines.
* Pausing trials when using Trial Schedulers such as HyperBand and PBT.
To enable checkpointing, you must implement a :ref:`Trainable class <trainable-docs>` (the function-based API is not checkpointable, since it never returns control back to its caller).
Checkpointing assumes that the model state will be saved to disk on whichever node the Trainable is running on.
Checkpointing assumes that the model state will be saved to disk on whichever node the Trainable is running on. You can checkpoint with three different mechanisms: manually, periodically, and at termination.
**Manual Checkpointing**: A custom Trainable can manually trigger checkpointing by returning ``should_checkpoint: True`` (or ``tune.result.SHOULD_CHECKPOINT: True``) in the result dictionary of ``_train``. This can be especially helpful when running on spot instances:
To use Tune's checkpointing features, you must expose a ``checkpoint`` argument in the function signature, and call ``tune.make_checkpoint_dir`` and ``tune.save_checkpoint``:
.. code-block:: python
def _train(self):
# training code
result = {"mean_accuracy": accuracy}
if detect_instance_preemption():
result.update(should_checkpoint=True)
return result
import json
import os
import time
from ray import tune
def train_func(config, checkpoint=None):
start = 0
if checkpoint:
with open(checkpoint) as f:
state = json.loads(f.read())
start = state["step"] + 1
**Periodic Checkpointing**: Periodic checkpointing can be used to provide fault tolerance for experiments. This can be enabled by setting ``checkpoint_freq=<int>`` and ``max_failures=<int>`` to checkpoint trials every *N* iterations and recover from up to *M* crashes per trial, e.g.:
for iter in range(start, 100):
time.sleep(1)
.. code-block:: python
# Obtain a checkpoint directory
checkpoint_dir = tune.make_checkpoint_dir(step=iter)
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"step": iter}))
tune.save_checkpoint(path)
tune.run(
my_trainable,
checkpoint_freq=10,
max_failures=5,
)
tune.report(hello="world", ray="tune")
**Checkpointing at Termination**: The ``checkpoint_freq`` may not coincide with the exact end of a trial. If you want a checkpoint to be created at the end of a trial, you can additionally set ``checkpoint_at_end=True``:
tune.run(train_func)
.. code-block:: python
:emphasize-lines: 5
In this example, checkpoints will be saved by training iteration to ``local_dir/exp_name/trial_name/checkpoint_<step>``.
tune.run(
my_trainable,
checkpoint_freq=10,
checkpoint_at_end=True,
max_failures=5,
)
The checkpoint will be saved at a path that looks like ``local_dir/exp_name/trial_name/checkpoint_x/``, where ``x`` is the number of training iterations completed when the checkpoint is saved. To restore the checkpoint, pass a checkpoint path to the ``restore`` argument of ``tune.run``. This also lets you change parts of the experiment configuration, such as the experiment name or the number of training iterations:
You can restore a single trial checkpoint by using ``tune.run(restore=<checkpoint_dir>)``. By doing this, you can change parts of the experiment configuration, such as the experiment's name:
.. code-block:: python
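
    # A hedged sketch: restore a trial from a previously saved checkpoint.
    # The experiment name and checkpoint path below are illustrative
    # placeholders.
    tune.run(
        my_trainable,
        name="restored_experiment",
        restore="/path/to/local_dir/exp_name/trial_name/checkpoint_5/",
    )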

View file

@ -3,9 +3,7 @@
Training (tune.Trainable, tune.report)
======================================
Training can be done with either a **Class API** (``tune.Trainable``) or **function-based API** (``tune.report``).
You can use the **function-based API** for fast prototyping. On the other hand, the ``tune.Trainable`` interface supports checkpoint/restore functionality and provides more control for advanced algorithms.
Training can be done with either a **Class API** (``tune.Trainable``) or **function API** (``tune.report``).
For the sake of example, let's maximize this objective function:
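One plausible form of such an objective (an illustrative sketch; the exact formula used in the docs may differ):

.. code-block:: python

    def objective(x, a, b):
        # Illustrative: the score grows with the square root of x,
        # scaled by ``a`` and offset by ``b``.
        return a * (x ** 0.5) + b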
@ -16,8 +14,10 @@ For the sake of example, let's maximize this objective function:
.. _tune-function-api:
Function-based API
------------------
Function API
------------
Here is a simple example of using the function API. You can report intermediate metrics by simply calling ``tune.report`` within the provided function.
.. code-block:: python
@ -25,31 +25,74 @@ Function-based API
# config (dict): A dict of hyperparameters.
for x in range(20):
score = objective(x, config["a"], config["b"])
intermediate_score = objective(x, config["a"], config["b"])
tune.report(score=score) # This sends the score to Tune.
tune.report(value=intermediate_score) # This sends the score to Tune.
analysis = tune.run(
trainable,
config={
"a": 2,
"b": 4
})
config={"a": 2, "b": 4}
)
print("best config: ", analysis.get_best_config(metric="score", mode="max"))
.. tip:: Do not use ``tune.track.log`` within a ``Trainable`` class.
.. tip:: Do not use ``tune.report`` within a ``Trainable`` class.
Tune will run this function on a separate thread in a Ray actor process. Note that this API is not checkpointable, since the thread will never return control back to its caller.
Tune will run this function on a separate thread in a Ray actor process.
.. note:: If you want to pass in a Python lambda, you will need to first register the function: ``tune.register_trainable("lambda_id", lambda x: ...)``. You can then use ``lambda_id`` in place of ``my_trainable``.
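For example (a minimal sketch; the id, config, and reported metric are illustrative):

.. code-block:: python

    from ray import tune

    # Lambdas cannot be passed to ``tune.run`` directly; register them first.
    tune.register_trainable("lambda_id", lambda config: tune.report(value=config["a"]))
    tune.run("lambda_id", config={"a": 2})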
Function API Checkpointing
~~~~~~~~~~~~~~~~~~~~~~~~~~
Many Tune features rely on checkpointing, including the usage of certain Trial Schedulers and fault tolerance. To use Tune's checkpointing features, you must expose a ``checkpoint`` argument in the function signature, and call ``tune.make_checkpoint_dir`` and ``tune.save_checkpoint``:
.. code-block:: python
import json
import os
import time
from ray import tune
def train_func(config, checkpoint=None):
start = 0
if checkpoint:
with open(checkpoint) as f:
state = json.loads(f.read())
start = state["step"] + 1
for iter in range(start, 100):
time.sleep(1)
# Obtain a checkpoint directory
checkpoint_dir = tune.make_checkpoint_dir(step=iter)
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"step": iter}))
tune.save_checkpoint(path)
tune.report(hello="world", ray="tune")
tune.run(train_func)
In this example, checkpoints will be saved by training iteration to ``local_dir/exp_name/trial_name/checkpoint_<step>``. You can restore a single trial checkpoint by using ``tune.run(restore=<checkpoint_dir>)``:
.. code-block:: python
[trial] = tune.run(
train,
config={
"max_iter": 5
},
).trials
last_ckpt = trial.checkpoint.value
analysis = tune.run(train, config={"max_iter": 10}, restore=last_ckpt)
Tune may also copy or move checkpoints during the course of tuning. For this purpose, it is important not to rely on absolute paths when saving and restoring checkpoints.
.. _tune-class-api:
Trainable Class API
-------------------
.. caution:: Do not use ``tune.track.log`` within a ``Trainable`` class.
.. caution:: Do not use ``tune.report`` within a ``Trainable`` class.
The Trainable **class API** will require users to subclass ``ray.tune.Trainable``. Here's a naive example of this API:
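A minimal sketch of such a subclass (hypothetical hyperparameters and metric; not necessarily the exact example from the docs):

.. code-block:: python

    from ray import tune

    class MyTrainableClass(tune.Trainable):
        def _setup(self, config):
            # Hypothetical: initialize state from the hyperparameters.
            self.lr = config["lr"]
            self.accuracy = 0.0

        def _train(self):
            # One training iteration; return a dict of metrics.
            self.accuracy += self.lr
            return {"mean_accuracy": self.accuracy}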
@ -87,14 +130,13 @@ As a subclass of ``tune.Trainable``, Tune will create a ``Trainable`` object on
.. tip:: As a rule of thumb, the execution time of ``_train`` should be large enough to avoid overheads (i.e. more than a few seconds), but short enough to report progress periodically (i.e. at most a few minutes).
In this example, we only implemented the ``_setup`` and ``_train`` methods for simplicity. Next, we'll implement ``_save`` and ``_restore`` for checkpointing and fault tolerance.
.. _tune-trainable-save-restore:
Save and Restore
~~~~~~~~~~~~~~~~
Class API Checkpointing
~~~~~~~~~~~~~~~~~~~~~~~
Many Tune features rely on ``_save`` and ``_restore``, including certain Trial Schedulers, fault tolerance, and checkpointing.
You can also implement checkpoint/restore using the Trainable Class API:
.. code-block:: python
@ -108,9 +150,45 @@ Many Tune features rely on ``_save``, and ``_restore``, including the usage of c
checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
self.model.load_state_dict(torch.load(checkpoint_path))
Checkpoints will be saved by training iteration to ``local_dir/exp_name/trial_name/checkpoint_<iter>``. You can restore a single trial checkpoint by using ``tune.run(restore=<checkpoint_dir>)``.
tune.run(MyTrainableClass, checkpoint_freq=2)
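For reference, a minimal sketch of a matching ``_save``/``_restore`` pair (assuming a PyTorch model stored on ``self.model``, as in the snippet above; illustrative only):

.. code-block:: python

    import os

    import torch
    from ray import tune

    class MyTrainableClass(tune.Trainable):
        def _save(self, tmp_checkpoint_dir):
            # Write the model weights into the directory provided by Tune.
            checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
            torch.save(self.model.state_dict(), checkpoint_path)
            return tmp_checkpoint_dir

        def _restore(self, tmp_checkpoint_dir):
            checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
            self.model.load_state_dict(torch.load(checkpoint_path))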
You can checkpoint with three different mechanisms: manually, periodically, and at termination.
**Manual Checkpointing**: A custom Trainable can manually trigger checkpointing by returning ``should_checkpoint: True`` (or ``tune.result.SHOULD_CHECKPOINT: True``) in the result dictionary of ``_train``. This can be especially helpful when running on spot instances:
.. code-block:: python
def _train(self):
# training code
result = {"mean_accuracy": accuracy}
if detect_instance_preemption():
result.update(should_checkpoint=True)
return result
**Periodic Checkpointing**: Periodic checkpointing can be used to provide fault tolerance for experiments. This can be enabled by setting ``checkpoint_freq=<int>`` and ``max_failures=<int>`` to checkpoint trials every *N* iterations and recover from up to *M* crashes per trial, e.g.:
.. code-block:: python
tune.run(
my_trainable,
checkpoint_freq=10,
max_failures=5,
)
**Checkpointing at Termination**: The ``checkpoint_freq`` may not coincide with the exact end of a trial. If you want a checkpoint to be created at the end of a trial, you can additionally set ``checkpoint_at_end=True``:
.. code-block:: python
:emphasize-lines: 5
tune.run(
my_trainable,
checkpoint_freq=10,
checkpoint_at_end=True,
max_failures=5,
)
Tune also generates temporary checkpoints for pausing and switching between trials. For this purpose, it is important not to depend on absolute paths in the implementation of ``_save``.
Use ``validate_save_restore`` to catch ``_save``/``_restore`` errors before execution.
@ -122,31 +200,11 @@ Use ``validate_save_restore`` to catch ``_save``/``_restore`` errors before exec
validate_save_restore(MyTrainableClass)
validate_save_restore(MyTrainableClass, use_object_store=True)
Advanced Resource Allocation
----------------------------
Trainables can themselves be distributed. If your trainable function / class creates further Ray actors or tasks that also consume CPU / GPU resources, you will want to set ``extra_cpu`` or ``extra_gpu`` inside ``tune.run`` to reserve extra resource slots. For example, if a trainable class requires 1 GPU itself, but also launches 4 actors, each using another GPU, then you should set ``"gpu": 1, "extra_gpu": 4``.
.. code-block:: python
:emphasize-lines: 4-8
tune.run(
my_trainable,
name="my_trainable",
resources_per_trial={
"cpu": 1,
"gpu": 1,
"extra_gpu": 4
}
)
The ``Trainable`` also provides the ``default_resource_request`` interface to automatically declare the ``resources_per_trial`` based on the given configuration.
Advanced: Reusing Actors
~~~~~~~~~~~~~~~~~~~~~~~~
.. note:: This feature is only for the Trainable Class API.
Your Trainable can often take a long time to start. To avoid this, you can do ``tune.run(reuse_actors=True)`` to reuse the same Trainable Python process and object for multiple hyperparameters.
This requires you to implement ``Trainable.reset_config``, which provides a new set of hyperparameters. It is up to you to correctly update the hyperparameters of your trainable.
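A minimal sketch of what this can look like (hypothetical hyperparameters; illustrative only):

.. code-block:: python

    from ray import tune

    class MyTrainableClass(tune.Trainable):
        def _setup(self, config):
            self.lr = config["lr"]

        def reset_config(self, new_config):
            # Re-apply the new hyperparameters in place instead of tearing
            # down and recreating the actor.
            self.lr = new_config["lr"]
            self.config = new_config
            return True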
@ -176,8 +234,47 @@ This requires you to implement ``Trainable.reset_config``, which provides a new
return True
tune.Trainable
--------------
Advanced Resource Allocation
----------------------------
Trainables can themselves be distributed. If your trainable function / class creates further Ray actors or tasks that also consume CPU / GPU resources, you will want to set ``extra_cpu`` or ``extra_gpu`` inside ``tune.run`` to reserve extra resource slots. For example, if a trainable class requires 1 GPU itself, but also launches 4 actors, each using another GPU, then you should set ``"gpu": 1, "extra_gpu": 4``.
.. code-block:: python
:emphasize-lines: 4-8
tune.run(
my_trainable,
name="my_trainable",
resources_per_trial={
"cpu": 1,
"gpu": 1,
"extra_gpu": 4
}
)
The ``Trainable`` also provides the ``default_resource_request`` interface to automatically declare the ``resources_per_trial`` based on the given configuration.
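A minimal sketch of overriding this interface (assuming the ``Resources`` helper from ``ray.tune.resources``; values are illustrative):

.. code-block:: python

    from ray import tune
    from ray.tune.resources import Resources

    class MyTrainableClass(tune.Trainable):
        @classmethod
        def default_resource_request(cls, config):
            # Reserve one CPU and one GPU for the trainable itself, plus an
            # extra GPU slot per worker it launches, based on the config.
            return Resources(
                cpu=1, gpu=1, extra_gpu=config.get("num_workers", 0))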
.. _track-docstring:
tune.report / tune.checkpoint (Function API)
--------------------------------------------
.. autofunction:: ray.tune.report
.. autofunction:: ray.tune.make_checkpoint_dir
.. autofunction:: ray.tune.save_checkpoint
.. autofunction:: ray.tune.get_trial_dir
.. autofunction:: ray.tune.get_trial_name
.. autofunction:: ray.tune.get_trial_id
tune.Trainable (Class API)
--------------------------
.. autoclass:: ray.tune.Trainable
@ -190,21 +287,6 @@ tune.DurableTrainable
.. autoclass:: ray.tune.DurableTrainable
.. _track-docstring:
tune.track
----------
.. automodule:: ray.tune.track
:members:
:exclude-members: init,
KerasCallback
-------------
.. automodule:: ray.tune.integration.keras
:members:
StatusReporter
--------------

View file

@ -149,6 +149,14 @@ py_test(
tags = ["exclusive"],
)
py_test(
name = "test_function_api",
size = "medium",
srcs = ["tests/test_function_api.py"],
deps = [":tune_lib"],
tags = ["exclusive"],
)
py_test(
name = "test_sync",
size = "medium",

View file

@ -8,7 +8,8 @@ from ray.tune.trainable import Trainable
from ray.tune.durable_trainable import DurableTrainable
from ray.tune.suggest import grid_search
from ray.tune.session import (report, get_trial_dir, get_trial_name,
get_trial_id)
get_trial_id, make_checkpoint_dir,
save_checkpoint)
from ray.tune.progress_reporter import (ProgressReporter, CLIReporter,
JupyterNotebookReporter)
from ray.tune.sample import (function, sample_from, uniform, choice, randint,
@ -21,5 +22,5 @@ __all__ = [
"uniform", "choice", "randint", "randn", "loguniform",
"ExperimentAnalysis", "Analysis", "CLIReporter", "JupyterNotebookReporter",
"ProgressReporter", "report", "get_trial_dir", "get_trial_name",
"get_trial_id"
"get_trial_id", "make_checkpoint_dir", "save_checkpoint"
]

View file

@ -57,7 +57,7 @@ class DurableTrainable(Trainable):
Checkpoint path or prefix that may be passed to restore().
"""
if checkpoint_dir:
if checkpoint_dir.starts_with(os.path.abspath(self.logdir)):
if checkpoint_dir.startswith(os.path.abspath(self.logdir)):
raise ValueError("`checkpoint_dir` must be `self.logdir`, or "
"a sub-directory.")
checkpoint_path = super(DurableTrainable, self).save(checkpoint_dir)

View file

@ -0,0 +1,59 @@
#!/usr/bin/env python
import argparse
import json
import os
import numpy as np
import ray
from ray import tune
from ray.tune.schedulers import HyperBandScheduler
def train(config, checkpoint=None):
step = 0
if checkpoint:
with open(checkpoint) as f:
step = json.loads(f.read())["timestep"]
for timestep in range(step, 100):
v = np.tanh(float(timestep) / config.get("width", 1))
v *= config.get("height", 1)
if timestep % 3 == 0:
checkpoint_dir = tune.make_checkpoint_dir(step=timestep)
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"timestep": timestep}))
tune.save_checkpoint(path)
# Here we use `episode_reward_mean`, but you can also report other
# objectives such as loss or accuracy.
tune.report(episode_reward_mean=v)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--smoke-test", action="store_true", help="Finish quickly for testing")
args, _ = parser.parse_known_args()
ray.init(num_cpus=4 if args.smoke_test else None)
# Hyperband early stopping, configured with `episode_reward_mean` as the
# objective and `training_iteration` as the time unit,
# which is automatically filled by Tune.
hyperband = HyperBandScheduler(
time_attr="training_iteration",
metric="episode_reward_mean",
mode="max",
max_t=200)
tune.run(
train,
name="hyperband_test",
num_samples=20,
stop={"training_iteration": 10 if args.smoke_test else 99999},
config={"height": tune.uniform(0, 100)},
scheduler=hyperband,
fail_fast=True)

View file

@ -0,0 +1,119 @@
#!/usr/bin/env python
import numpy as np
import argparse
import json
import os
import random
import ray
from ray import tune
from ray.tune.schedulers import PopulationBasedTraining
def pbt_function(config, checkpoint=None):
"""Toy PBT problem for benchmarking adaptive learning rate.
The goal is to optimize this trainable's accuracy. The accuracy increases
fastest at the optimal lr, which is a function of the current accuracy.
The optimal lr schedule for this problem is the triangle wave as follows.
Note that many lr schedules for real models also follow this shape:
best lr
^
| /\
| / \
| / \
| / \
------------> accuracy
In this problem, using PBT with a population of 2-4 is sufficient to
roughly approximate this lr schedule. Higher population sizes will yield
faster convergence. Training will not converge without PBT.
"""
lr = config["lr"]
accuracy = 0.0 # end = 1000
start = 0
if checkpoint:
with open(checkpoint) as f:
state = json.loads(f.read())
accuracy = state["acc"]
start = state["step"]
midpoint = 100 # lr starts decreasing after acc > midpoint
q_tolerance = 3 # penalize exceeding lr by more than this multiple
noise_level = 2 # add gaussian noise to the acc increase
# triangle wave:
# - start at 0.001 @ t=0,
# - peak at 0.01 @ t=midpoint,
# - end at 0.001 @ t=midpoint * 2,
for step in range(start, 100):
if accuracy < midpoint:
optimal_lr = 0.01 * accuracy / midpoint
else:
optimal_lr = 0.01 - 0.01 * (accuracy - midpoint) / midpoint
optimal_lr = min(0.01, max(0.001, optimal_lr))
# compute accuracy increase
q_err = max(lr, optimal_lr) / min(lr, optimal_lr)
if q_err < q_tolerance:
accuracy += (1.0 / q_err) * random.random()
elif lr > optimal_lr:
accuracy -= (q_err - q_tolerance) * random.random()
accuracy += noise_level * np.random.normal()
accuracy = max(0, accuracy)
if step % 3 == 0:
checkpoint_dir = tune.make_checkpoint_dir(step=step)
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"acc": accuracy, "step": start}))
tune.save_checkpoint(path)
tune.report(
mean_accuracy=accuracy,
cur_lr=lr,
optimal_lr=optimal_lr, # for debugging
q_err=q_err, # for debugging
done=accuracy > midpoint * 2)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--smoke-test", action="store_true", help="Finish quickly for testing")
args, _ = parser.parse_known_args()
if args.smoke_test:
ray.init(num_cpus=2) # force pausing to happen for test
else:
ray.init()
pbt = PopulationBasedTraining(
time_attr="training_iteration",
metric="mean_accuracy",
mode="max",
perturbation_interval=4,
hyperparam_mutations={
# distribution for resampling
"lr": lambda: random.uniform(0.0001, 0.02),
# allow perturbations within this set of categorical values
"some_other_factor": [1, 2],
})
tune.run(
pbt_function,
name="pbt_test",
scheduler=pbt,
verbose=False,
stop={
"training_iteration": 30,
},
num_samples=8,
fail_fast=True,
config={
"lr": 0.0001,
# note: this parameter is perturbed but has no effect on
# the model training in this example
"some_other_factor": 1,
})

View file

@ -3,6 +3,7 @@ import logging
import os
from ray.tune.error import TuneError
from ray.tune.function_runner import detect_checkpoint_function
from ray.tune.registry import register_trainable, get_trainable_cls
from ray.tune.result import DEFAULT_RESULTS_DIR
from ray.tune.sample import sample_from
@ -92,6 +93,18 @@ class Experiment:
restore=None):
config = config or {}
if callable(run) and detect_checkpoint_function(run):
if checkpoint_at_end:
raise ValueError(
"'checkpoint_at_end' cannot be used with a "
"checkpointable function. You can specify and register "
"checkpoints within your trainable function.")
if checkpoint_freq:
raise ValueError(
"'checkpoint_freq' cannot be used with a "
"checkpointable function. You can specify checkpoints "
"within your trainable function.")
self._run_identifier = Experiment.register_if_needed(run)
self.name = name or self._run_identifier
if upload_dir:

View file

@ -1,13 +1,18 @@
import logging
import os
import io
import time
import inspect
import shutil
import threading
import traceback
from six.moves import queue
from ray.tune import TuneError, session
from ray.tune.trainable import Trainable
from ray.tune.result import TIME_THIS_ITER_S, RESULT_DUPLICATE
from ray.tune.trainable import Trainable, TrainableUtil
from ray.tune.result import (TIME_THIS_ITER_S, RESULT_DUPLICATE,
SHOULD_CHECKPOINT)
logger = logging.getLogger(__name__)
@ -40,6 +45,8 @@ class StatusReporter:
self._trial_name = trial_name
self._trial_id = trial_id
self._logdir = logdir
self._last_checkpoint = {}
self._fresh_checkpoint = False
def __call__(self, **kwargs):
"""Report updated training status.
@ -77,6 +84,29 @@ class StatusReporter:
# resume training.
self._continue_semaphore.acquire()
def make_checkpoint_dir(self, step=None):
checkpoint_dir = TrainableUtil.make_checkpoint_dir(
self.logdir, index=step)
return checkpoint_dir
def save_checkpoint(self, checkpoint):
if isinstance(checkpoint, str):
try:
TrainableUtil.find_checkpoint_dir(checkpoint)
except FileNotFoundError:
logger.error("Checkpoint must be created with path given from "
"make_checkpoint_dir.")
raise
self._last_checkpoint = checkpoint
self._fresh_checkpoint = True
def has_new_checkpoint(self):
return self._fresh_checkpoint
def get_checkpoint(self):
self._fresh_checkpoint = False
return self._last_checkpoint
def _start(self):
self._last_report_time = time.time()
@ -155,21 +185,33 @@ class FunctionRunner(Trainable):
trial_id=self.trial_id,
logdir=self.logdir)
self._last_result = {}
config = config.copy()
session.init(self._status_reporter)
def entrypoint():
return self._trainable_func(config, self._status_reporter)
# the runner thread is not started until the first call to _train
self._runner = _RunnerThread(entrypoint, self._error_queue)
self._runner = None
self._restore_tmpdir = None
self.default_checkpoint_dir = None
def _trainable_func(self):
"""Subclasses can override this to set the trainable func."""
raise NotImplementedError
def _start(self):
def entrypoint():
return self._trainable_func(self.config, self._status_reporter,
self._status_reporter.get_checkpoint())
# the runner thread is not started until the first call to _train
self._runner = _RunnerThread(entrypoint, self._error_queue)
# if not alive, try to start
self._status_reporter._start()
try:
self._runner.start()
except RuntimeError:
# If this is reached, it means the thread was started and is
# now done or has raised an exception.
pass
def _train(self):
"""Implements train() for a Function API.
@ -178,19 +220,12 @@ class FunctionRunner(Trainable):
along with a result with "done=True". The TrialRunner will handle the
result accordingly (see tune/trial_runner.py).
"""
if self._runner.is_alive():
if self._runner and self._runner.is_alive():
# if started and alive, inform the reporter to continue and
# generate the next result
self._continue_semaphore.release()
else:
# if not alive, try to start
self._status_reporter._start()
try:
self._runner.start()
except RuntimeError:
# If this is reached, it means the thread was started and is
# now done or has raised an exception.
pass
self._start()
result = None
while result is None and self._runner.is_alive():
@ -240,8 +275,61 @@ class FunctionRunner(Trainable):
result = new_result
self._last_result = result
if self._status_reporter.has_new_checkpoint():
result[SHOULD_CHECKPOINT] = True
return result
def create_default_checkpoint_dir(self):
self.default_checkpoint_dir = TrainableUtil.make_checkpoint_dir(
self.logdir, index="default")
return self.default_checkpoint_dir
def save(self, checkpoint_path=None):
if checkpoint_path:
raise ValueError(
"Checkpoint path should not be used with function API.")
checkpoint = self._status_reporter.get_checkpoint()
state = self.get_state()
if not checkpoint:
state.update(iteration=0, timesteps_total=0, episodes_total=0)
parent_dir = self.create_default_checkpoint_dir()
elif isinstance(checkpoint, dict):
parent_dir = TrainableUtil.make_checkpoint_dir(
self.logdir, index=self.training_iteration)
else:
parent_dir = TrainableUtil.find_checkpoint_dir(checkpoint)
checkpoint_path = TrainableUtil.process_checkpoint(
checkpoint, parent_dir, state)
return checkpoint_path
def save_to_object(self):
checkpoint_path = self.save()
data_dict = TrainableUtil.pickle_checkpoint(checkpoint_path)
out = io.BytesIO()
if len(data_dict) > 10e6: # getting pretty large
logger.info("Checkpoint size is {} bytes".format(len(data_dict)))
out.write(data_dict)
return out.getvalue()
def _restore(self, checkpoint):
# This should be removed once Trainables are refactored.
if "tune_checkpoint_path" in checkpoint:
del checkpoint["tune_checkpoint_path"]
self._status_reporter.save_checkpoint(checkpoint)
def restore_from_object(self, obj):
if self.default_checkpoint_dir is not None and os.path.exists(
self.default_checkpoint_dir):
shutil.rmtree(self.default_checkpoint_dir)
logger.debug("Clearing default checkpoint: %s",
self.default_checkpoint_dir)
checkpoint_dir = self.create_default_checkpoint_dir()
checkpoint_path = TrainableUtil.create_from_pickle(obj, checkpoint_dir)
self.restore(checkpoint_path)
def _stop(self):
# If everything stayed in sync properly, this should never happen.
if not self._results_queue.empty():
@ -251,7 +339,6 @@ class FunctionRunner(Trainable):
# Check for any errors that might have been missed.
self._report_thread_runner_error()
session.shutdown()
def _report_thread_runner_error(self, block=False):
@ -264,13 +351,35 @@ class FunctionRunner(Trainable):
pass
def detect_checkpoint_function(train_func):
func_args = inspect.getfullargspec(train_func).args
use_checkpoint = "checkpoint" in func_args
return use_checkpoint
def wrap_function(train_func):
class ImplicitFunc(FunctionRunner):
def _trainable_func(self, config, reporter):
def _trainable_func(self, config, reporter, checkpoint):
func_args = inspect.getfullargspec(train_func).args
use_track = ("reporter" not in func_args and len(func_args) == 1)
if use_track:
if len(func_args) > 1: # more arguments than just the config
if "reporter" not in func_args and (
"checkpoint" not in func_args):
raise ValueError(
"Unknown argument found in the Trainable function. "
"Arguments other than the 'config' arg must be one "
"of ['reporter', 'checkpoint']. Found: {}".format(
func_args))
use_reporter = "reporter" in func_args
use_checkpoint = "checkpoint" in func_args
if not use_checkpoint and not use_reporter:
logger.warning(
"Function checkpointing is disabled. This may result in "
"unexpected behavior when using checkpointing features or "
"certain schedulers. To enable, set the train function "
"arguments to be `func(config, checkpoint)`.")
output = train_func(config)
elif use_checkpoint:
output = train_func(config, checkpoint=checkpoint)
else:
output = train_func(config, reporter)

View file

@ -669,8 +669,7 @@ class RayTrialExecutor(TrialExecutor):
elif trial.sync_on_checkpoint:
# This provides FT backwards compatibility in the
# case where a DurableTrainable is not provided.
logger.warning("Trial %s: Reading checkpoint into memory.",
trial)
logger.debug("Trial %s: Reading checkpoint into memory", trial)
data_dict = TrainableUtil.pickle_checkpoint(value)
with self._change_working_directory(trial):
remote = trial.runner.restore_from_object.remote(data_dict)

View file

@ -381,10 +381,14 @@ class Bracket:
assert trial in self._live_trials
assert self._get_result_time(result) >= 0
observed_time = self._get_result_time(result)
last_observed = self._get_result_time(self._live_trials[trial])
delta = self._get_result_time(result) - \
self._get_result_time(self._live_trials[trial])
assert delta >= 0, (result, self._live_trials[trial])
delta = last_observed - observed_time
if delta >= 0:
logger.info("Restoring from a previous point in time. "
"Previous={}; Now={}".format(last_observed,
observed_time))
self._completed_progress += delta
self._live_trials[trial] = result
@ -424,7 +428,7 @@ class Bracket:
def _calculate_total_work(self, n, r, s):
work = 0
cumulative_r = r
for i in range(s + 1):
for _ in range(s + 1):
work += int(n) * int(r)
n /= self._eta
n = int(np.ceil(n))

View file

@ -5,29 +5,6 @@ logger = logging.getLogger(__name__)
_session = None
class _ReporterSession:
def __init__(self, tune_reporter):
self.tune_reporter = tune_reporter
def report(self, **metrics):
return self.tune_reporter(**metrics)
@property
def logdir(self):
"""Trial logdir (subdir of given experiment directory)"""
return self.tune_reporter.logdir
@property
def trial_name(self):
"""Trial name for the corresponding trial of this Trainable"""
return self.tune_reporter.trial_name
@property
def trial_id(self):
"""Trial id for the corresponding trial of this Trainable"""
return self.tune_reporter.trial_id
def get_session():
global _session
if _session is None:
@ -56,7 +33,11 @@ def init(reporter, ignore_reinit_error=True):
else:
raise ValueError(reinit_msg)
_session = _ReporterSession(reporter)
if reporter is None:
logger.warning("You are using a Tune session outside of Tune. "
"Most session commands will have no effect.")
_session = reporter
def shutdown():
@ -86,34 +67,109 @@ def report(**kwargs):
metrics can be used for early stopping or optimization.
"""
_session = get_session()
return _session.report(**kwargs)
return _session(**kwargs)
def make_checkpoint_dir(step=None):
"""Gets the next checkpoint dir.
.. code-block:: python
import json
import os
import time
from ray import tune
def func(config, checkpoint=None):
start = 0
if checkpoint:
with open(checkpoint) as f:
state = json.loads(f.read())
start = state["step"] + 1
for iter in range(start, 100):
time.sleep(1)
checkpoint_dir = tune.make_checkpoint_dir(step=iter)
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"step": iter}))
tune.save_checkpoint(path)
tune.report(hello="world", ray="tune")
Args:
step (int): Current training iteration - used for setting
an index to uniquely identify the checkpoint.
.. versionadded:: 0.8.6
"""
_session = get_session()
return _session.make_checkpoint_dir(step=step)
def save_checkpoint(checkpoint):
"""Register the given checkpoint.
.. code-block:: python
import os
import json
import time
from ray import tune
def func(config, checkpoint=None):
start = 0
if checkpoint:
with open(checkpoint) as f:
state = json.loads(f.read())
accuracy = state["acc"]
start = state["step"] + 1
for iter in range(start, 10):
time.sleep(1)
checkpoint_dir = tune.make_checkpoint_dir(step=iter)
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"step": iter}))
tune.save_checkpoint(path)
tune.report(hello="world", ray="tune")
analysis = tune.run(func)
Args:
checkpoint (str|dict): A checkpoint path created under
    ``tune.make_checkpoint_dir``, or a dict of checkpoint data,
    to be registered with Tune.
.. versionadded:: 0.8.6
"""
_session = get_session()
return _session.save_checkpoint(checkpoint)
def get_trial_dir():
"""Returns the directory where trial results are saved.
For function API use only. Do not call this method in the Class API. Use
`self.logdir` instead.
For function API use only.
"""
_session = get_session()
return _session.logdir
def get_trial_name():
"""Trial name for the corresponding trial of this Trainable.
"""Trial name for the corresponding trial.
For function API use only. Do not call this method in the Class API. Use
`self.trial_name` instead.
For function API use only.
"""
_session = get_session()
return _session.trial_name
def get_trial_id():
"""Trial id for the corresponding trial of this Trainable.
"""Trial id for the corresponding trial.
For function API use only. Do not call this method in the Class API. Use
`self.trial_id` instead.
For function API use only.
"""
_session = get_session()
return _session.trial_id

View file

@ -58,7 +58,7 @@ class AxSearch(Searcher):
def easy_objective(config):
for i in range(100):
intermediate_result = config["x1"] + config["x2"] * i
tune.track.log(score=intermediate_result)
tune.report(score=intermediate_result)
client = AxClient(enforce_sequential_optimization=False)
client.create_experiment(parameters=parameters, objective_name="score")

View file

@ -394,10 +394,28 @@ def test_migration_checkpoint_removal(start_connected_emptyhead_cluster,
}
# The following patches only affect __fake_remote.
find_checkpoint_dir = TrainableUtil.find_checkpoint_dir
with patch("ray.tune.logger.get_node_syncer") as mock_get_node_syncer:
def hide_remote_path(path_function):
def hidden_path_func(checkpoint_path):
"""Converts back to local path first."""
if MOCK_REMOTE_DIR in checkpoint_path:
checkpoint_path = checkpoint_path[len(MOCK_REMOTE_DIR):]
checkpoint_path = os.path.join("/", checkpoint_path)
return path_function(checkpoint_path)
return hidden_path_func
trainable_util = "ray.tune.ray_trial_executor.TrainableUtil"
with patch(trainable_util + ".find_checkpoint_dir") as mock_find_dir:
_find_ckpt = trainable_util + ".find_checkpoint_dir"
find_func = TrainableUtil.find_checkpoint_dir
_pickle_ckpt = trainable_util + ".pickle_checkpoint"
pickle_func = TrainableUtil.pickle_checkpoint
with patch(_find_ckpt) as mock_find, patch(_pickle_ckpt) as mock_pkl_ckpt:
# __fake_remote trainables save to a separate "remote" directory.
# TrainableUtil will not check this path unless we mock it.
mock_find.side_effect = hide_remote_path(find_func)
mock_pkl_ckpt.side_effect = hide_remote_path(pickle_func)
with patch("ray.tune.logger.get_node_syncer") as mock_get_node_syncer:
def mock_get_syncer_fn(local_dir, remote_dir, sync_function):
client = mock_storage_client()
@ -405,16 +423,6 @@ def test_migration_checkpoint_removal(start_connected_emptyhead_cluster,
mock_get_node_syncer.side_effect = mock_get_syncer_fn
def mock_find_dir_fn(checkpoint_path):
"""Converts back to local path first."""
checkpoint_path = checkpoint_path[len(MOCK_REMOTE_DIR):]
checkpoint_path = os.path.join("/", checkpoint_path)
return find_checkpoint_dir(checkpoint_path)
# __fake_remote trainables save to a separate "remote" directory.
# TrainableUtil will not check this path unless we mock it.
mock_find_dir.side_effect = mock_find_dir_fn
# Test recovery of trial that has been checkpointed
t1 = Trial(trainable_id, **kwargs)
runner.add_trial(t1)
@ -428,7 +436,6 @@ def test_migration_checkpoint_removal(start_connected_emptyhead_cluster,
cluster.remove_node(node)
cluster.wait_for_nodes()
shutil.rmtree(os.path.dirname(t1.checkpoint.value))
runner.step() # Collect result 3, kick off + fail result 4
runner.step() # Dispatch restore
runner.step() # Process restore + step 4

View file

@ -0,0 +1,169 @@
import json
import os
import unittest
import ray
from ray.rllib import _register_all
from ray import tune
from ray.tune.function_runner import wrap_function
from ray.tune.result import TRAINING_ITERATION
class FunctionApiTest(unittest.TestCase):
def setUp(self):
ray.init(num_cpus=4, num_gpus=0, object_store_memory=150 * 1024 * 1024)
def tearDown(self):
ray.shutdown()
_register_all() # re-register the evicted objects
def testFunctionNoCheckpointing(self):
def train(config, checkpoint=None):
for i in range(10):
tune.report(test=i)
wrapped = wrap_function(train)
new_trainable = wrapped()
result = new_trainable.train()
checkpoint = new_trainable.save()
new_trainable.stop()
new_trainable2 = wrapped()
new_trainable2.restore(checkpoint)
result = new_trainable2.train()
self.assertEquals(result[TRAINING_ITERATION], 1)
checkpoint = new_trainable2.save()
new_trainable2.stop()
def testFunctionRecurringSave(self):
"""This tests that save and restore are commutative."""
def train(config, checkpoint=None):
for step in range(10):
if step % 3 == 0:
checkpoint_dir = tune.make_checkpoint_dir(step=step)
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"step": step}))
tune.save_checkpoint(path)
tune.report(test=step)
wrapped = wrap_function(train)
new_trainable = wrapped()
new_trainable.train()
checkpoint_obj = new_trainable.save_to_object()
new_trainable.restore_from_object(checkpoint_obj)
checkpoint = new_trainable.save()
new_trainable.stop()
new_trainable2 = wrapped()
new_trainable2.restore(checkpoint)
new_trainable2.train()
new_trainable2.stop()
def testCheckpointFunctionAtEnd(self):
def train(config, checkpoint=False):
for i in range(10):
tune.report(test=i)
checkpoint_dir = tune.make_checkpoint_dir(step=10)
checkpoint_path = os.path.join(checkpoint_dir, "hello")
with open(checkpoint_path, "w") as f:
f.write("hello")
tune.save_checkpoint(checkpoint_path)
[trial] = tune.run(train).trials
assert "hello" in trial.checkpoint.value
def testVariousCheckpointFunctionAtEnd(self):
def train(config, checkpoint=False):
for i in range(10):
checkpoint_dir = tune.make_checkpoint_dir()
checkpoint_path = os.path.join(checkpoint_dir, "hello")
with open(checkpoint_path, "w") as f:
f.write("hello")
tune.save_checkpoint(checkpoint_path)
tune.report(test=i)
checkpoint_dir = tune.make_checkpoint_dir()
checkpoint_path = os.path.join(checkpoint_dir, "goodbye")
with open(checkpoint_path, "w") as f:
f.write("goodbye")
tune.save_checkpoint(checkpoint_path)
[trial] = tune.run(train, keep_checkpoints_num=3).trials
assert "goodbye" in trial.checkpoint.value
def testReuseCheckpoint(self):
def train(config, checkpoint=False):
itr = 0
if checkpoint:
with open(checkpoint, "r") as f:
itr = int(f.read()) + 1
for i in range(itr, config["max_iter"]):
checkpoint_dir = tune.make_checkpoint_dir(step=i)
checkpoint_path = os.path.join(checkpoint_dir, "goodbye")
with open(checkpoint_path, "w") as f:
f.write(str(i))
tune.save_checkpoint(checkpoint_path)
tune.report(test=i, training_iteration=i)
[trial] = tune.run(
train,
config={
"max_iter": 5
},
).trials
last_ckpt = trial.checkpoint.value
assert "goodbye" in last_ckpt
analysis = tune.run(train, config={"max_iter": 10}, restore=last_ckpt)
trial_dfs = list(analysis.trial_dataframes.values())
assert len(trial_dfs[0]["training_iteration"]) == 5
def testRetry(self):
def train(config, checkpoint=None):
restored = bool(checkpoint)
itr = 0
if checkpoint:
with open(checkpoint, "r") as f:
itr = int(f.read()) + 1
for i in range(itr, 10):
if i == 5 and not restored:
raise Exception("try to fail me")
checkpoint_dir = tune.make_checkpoint_dir(step=i)
checkpoint_path = os.path.join(checkpoint_dir, "goodbye")
with open(checkpoint_path, "w") as f:
f.write(str(i))
tune.save_checkpoint(checkpoint_path)
tune.report(test=i, training_iteration=i)
analysis = tune.run(train, max_failures=3)
last_ckpt = analysis.trials[0].checkpoint.value
assert "goodbye" in last_ckpt
trial_dfs = list(analysis.trial_dataframes.values())
assert len(trial_dfs[0]["training_iteration"]) == 10
def testBlankCheckpoint(self):
def train(config, checkpoint=None):
restored = bool(checkpoint)
itr = 0
if checkpoint:
with open(checkpoint, "r") as f:
itr = int(f.read()) + 1
for i in range(itr, 10):
if i == 5 and not restored:
raise Exception("try to fail me")
checkpoint_dir = tune.make_checkpoint_dir()
checkpoint_path = os.path.join(checkpoint_dir, "goodbye")
with open(checkpoint_path, "w") as f:
f.write(str(i))
tune.save_checkpoint(checkpoint_path)
tune.report(test=i, training_iteration=i)
analysis = tune.run(train, max_failures=3)
trial_dfs = list(analysis.trial_dataframes.values())
assert len(trial_dfs[0]["training_iteration"]) == 10

View file

@ -17,18 +17,6 @@ class TrackApiTest(unittest.TestCase):
session.shutdown()
ray.shutdown()
def testSessionInitShutdown(self):
self.assertTrue(session._session is None)
# Checks that the singleton _session is created/destroyed
# by session.init() and session.shutdown()
for _ in range(2):
# do it twice to see that we can reopen the session
session.init(reporter=None)
self.assertTrue(session._session is not None)
session.shutdown()
self.assertTrue(session._session is None)
def testSoftDeprecation(self):
"""Checks that tune.track.log code does not break."""
from ray.tune import track

View file

@ -28,6 +28,34 @@ SETUP_TIME_THRESHOLD = 10
class TrainableUtil:
@staticmethod
def process_checkpoint(checkpoint, parent_dir, trainable_state):
saved_as_dict = False
if isinstance(checkpoint, string_types):
if not checkpoint.startswith(parent_dir):
raise ValueError(
"The returned checkpoint path must be within the "
"given checkpoint dir {}: {}".format(
parent_dir, checkpoint))
checkpoint_path = checkpoint
if os.path.isdir(checkpoint_path):
# Add trailing slash to prevent tune metadata from
# being written outside the directory.
checkpoint_path = os.path.join(checkpoint_path, "")
elif isinstance(checkpoint, dict):
saved_as_dict = True
checkpoint_path = os.path.join(parent_dir, "checkpoint")
with open(checkpoint_path, "wb") as f:
pickle.dump(checkpoint, f)
else:
raise ValueError("Returned unexpected type {}. "
"Expected str or dict.".format(type(checkpoint)))
with open(checkpoint_path + ".tune_metadata", "wb") as f:
trainable_state["saved_as_dict"] = saved_as_dict
pickle.dump(trainable_state, f)
return checkpoint_path
@staticmethod
def pickle_checkpoint(checkpoint_path):
"""Pickles checkpoint data."""
@ -39,7 +67,8 @@ class TrainableUtil:
with open(path, "rb") as f:
data[os.path.relpath(path, checkpoint_dir)] = f.read()
# Use normpath so that a directory path isn't mapped to empty string.
name = os.path.basename(os.path.normpath(checkpoint_path))
name = os.path.relpath(
os.path.normpath(checkpoint_path), checkpoint_dir)
name += os.path.sep if os.path.isdir(checkpoint_path) else ""
data_dict = pickle.dumps({
"checkpoint_name": name,
@ -70,11 +99,38 @@ class TrainableUtil:
return checkpoint_dir
@staticmethod
def make_checkpoint_dir(checkpoint_dir):
"""Creates a checkpoint directory at the provided path."""
def make_checkpoint_dir(checkpoint_dir, index):
"""Creates a checkpoint directory within the provided path.
Args:
checkpoint_dir (str): Path to checkpoint directory.
index (int|str): A subdirectory will be created
at the checkpoint directory named 'checkpoint_{index}'.
"""
suffix = "checkpoint"
if index is not None:
suffix += "_{}".format(index)
checkpoint_dir = os.path.join(checkpoint_dir, suffix)
os.makedirs(checkpoint_dir, exist_ok=True)
# Drop marker in directory to identify it as a checkpoint dir.
open(os.path.join(checkpoint_dir, ".is_checkpoint"), "a").close()
return checkpoint_dir
@staticmethod
def create_from_pickle(obj, tmpdir):
info = pickle.loads(obj)
data = info["data"]
checkpoint_path = os.path.join(tmpdir, info["checkpoint_name"])
for relpath_name, file_contents in data.items():
path = os.path.join(tmpdir, relpath_name)
# This may be a subdirectory, hence not just using tmpdir
os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, "wb") as f:
f.write(file_contents)
return checkpoint_path
@staticmethod
def get_checkpoints_paths(logdir):
@ -324,6 +380,16 @@ class Trainable:
return result
def get_state(self):
return {
"experiment_id": self._experiment_id,
"iteration": self._iteration,
"timesteps_total": self._timesteps_total,
"time_total": self._time_total,
"episodes_total": self._episodes_total,
"ray_version": ray.__version__,
}
def save(self, checkpoint_dir=None):
"""Saves the current model state to a checkpoint.
@ -336,41 +402,14 @@ class Trainable:
Returns:
str: Checkpoint path or prefix that may be passed to restore().
"""
checkpoint_dir = os.path.join(checkpoint_dir or self.logdir,
"checkpoint_{}".format(self._iteration))
TrainableUtil.make_checkpoint_dir(checkpoint_dir)
checkpoint_dir = TrainableUtil.make_checkpoint_dir(
checkpoint_dir or self.logdir, index=self.iteration)
checkpoint = self._save(checkpoint_dir)
saved_as_dict = False
if isinstance(checkpoint, string_types):
if not checkpoint.startswith(checkpoint_dir):
raise ValueError(
"The returned checkpoint path must be within the "
"given checkpoint dir {}: {}".format(
checkpoint_dir, checkpoint))
checkpoint_path = checkpoint
if os.path.isdir(checkpoint_path):
# Add trailing slash to prevent tune metadata from
# being written outside the directory.
checkpoint_path = os.path.join(checkpoint_path, "")
elif isinstance(checkpoint, dict):
saved_as_dict = True
checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
with open(checkpoint_path, "wb") as f:
pickle.dump(checkpoint, f)
else:
raise ValueError("Returned unexpected type {}. "
"Expected str or dict.".format(type(checkpoint)))
with open(checkpoint_path + ".tune_metadata", "wb") as f:
pickle.dump({
"experiment_id": self._experiment_id,
"iteration": self._iteration,
"timesteps_total": self._timesteps_total,
"time_total": self._time_total,
"episodes_total": self._episodes_total,
"saved_as_dict": saved_as_dict,
"ray_version": ray.__version__,
}, f)
trainable_state = self.get_state()
checkpoint_path = TrainableUtil.process_checkpoint(
checkpoint,
parent_dir=checkpoint_dir,
trainable_state=trainable_state)
return checkpoint_path
def save_to_object(self):
@ -434,19 +473,8 @@ class Trainable:
These checkpoints are returned from calls to save_to_object().
"""
info = pickle.loads(obj)
data = info["data"]
tmpdir = tempfile.mkdtemp("restore_from_object", dir=self.logdir)
checkpoint_path = os.path.join(tmpdir, info["checkpoint_name"])
for relpath_name, file_contents in data.items():
path = os.path.join(tmpdir, relpath_name)
# This may be a subdirectory, hence not just using tmpdir
os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, "wb") as f:
f.write(file_contents)
checkpoint_path = TrainableUtil.create_from_pickle(obj, tmpdir)
self.restore(checkpoint_path)
shutil.rmtree(tmpdir)
@ -531,7 +559,10 @@ class Trainable:
name = self.trial_name
"""
if self._trial_info:
return self._trial_info.trial_name
else:
return "default"
@property
def trial_id(self):
@ -543,7 +574,10 @@ class Trainable:
trial_id = self.trial_id
"""
if self._trial_info:
return self._trial_info.trial_id
else:
return "default"
@property
def iteration(self):

View file

@ -385,10 +385,14 @@ class TrialRunner:
self.trial_executor.try_checkpoint_metadata(trial)
def debug_string(self, delim="\n"):
result_keys = [
list(t.last_result) for t in self.get_trials() if t.last_result
]
metrics = set().union(*result_keys)
messages = [
self._scheduler_alg.debug_string(),
self.trial_executor.debug_string(),
trial_progress_str(self.get_trials()),
trial_progress_str(self.get_trials(), metrics),
]
return delim.join(messages)
@ -468,6 +472,7 @@ class TrialRunner:
result = self.trial_executor.fetch_result(trial)
is_duplicate = RESULT_DUPLICATE in result
force_checkpoint = result.get(SHOULD_CHECKPOINT, False)
# TrialScheduler and SearchAlgorithm still receive a
# notification because there may be special handling for
# the `on_trial_complete` hook.
@ -506,8 +511,7 @@ class TrialRunner:
# the scheduler decision is STOP or PAUSE. Note that
# PAUSE only checkpoints to memory and does not update
# the global checkpoint state.
self._checkpoint_trial_if_needed(
trial, force=result.get(SHOULD_CHECKPOINT, False))
self._checkpoint_trial_if_needed(trial, force=force_checkpoint)
if trial.is_saving:
# Cache decision to execute on after the save is processed.

View file

@ -110,7 +110,10 @@ def run(run_or_experiment,
function or class, or the string identifier of a
trainable function or class registered in the tune registry.
If Experiment, then Tune will execute training based on
Experiment.spec.
Experiment.spec. If you want to pass in a Python lambda, you
will need to first register the function:
``tune.register_trainable("lambda_id", lambda x: ...)``. You can
then use ``tune.run("lambda_id")``.
name (str): Name of experiment.
stop (dict | callable | :class:`Stopper`): Stopping criteria. If dict,
the keys may be any field in the return result of 'train()',
@ -154,8 +157,10 @@ def run(run_or_experiment,
syncing to driver is disabled.
checkpoint_freq (int): How many training iterations between
checkpoints. A value of 0 (default) disables checkpointing.
This has no effect when using the Functional Training API.
checkpoint_at_end (bool): Whether to checkpoint at the end of the
experiment regardless of the checkpoint_freq. Default is False.
This has no effect when using the Functional Training API.
sync_on_checkpoint (bool): Force sync-down of trial checkpoint to
driver. If set to False, checkpoint syncing from worker to driver
is asynchronous and best-effort. This does not affect persistent
@ -214,6 +219,8 @@ def run(run_or_experiment,
if using a RayTrialExecutor (which is the default) and
if Ray is not initialized. Defaults to True.
Returns:
ExperimentAnalysis: Object for experiment analysis.