[tune] Adds option to checkpoint at end of trials (#2754)

* Added checkpoint_at_end option to fix #2740

* Added ability to checkpoint at the end of trials if the option is set to True

* checkpoint_at_end option added; consistent with Experiment and Trial runner

* checkpoint_at_end option mentioned in the tune usage guide

* Moved the redundant checkpoint criteria check out of the if-elif

* Added note that checkpoint_at_end is enabled only when checkpoint_freq is not 0

* Added test case for checkpoint_at_end

* Made checkpoint_at_end have an effect regardless of checkpoint_freq

* Removed comment from the test case

* Fixed the indentation

* Fixed pep8 E231

* Handled cases when trainable does not have _save implemented

* Constrained test case to a particular exp using the MockAgent

* Revert "Constrained test case to a particular exp using the MockAgent"

This reverts commit e965a9358ec7859b99a3aabb681286d6ba3c3906.

* Revert "Handled cases when trainable does not have _save implemented"

This reverts commit 0f5382f996ff0cbf3d054742db866c33494d173a.

* Simpler test case for checkpoint_at_end

* Preserved bools from losing their actual value

* Revert "Moved the redundant checkpoint criteria check out of the if-elif"

This reverts commit 783005122902240b0ee177e9e206e397356af9c5.

* Fix linting error.

Authored by Praveen Palanisamy on 2018-08-29 16:14:17 -04:00; committed by Richard Liaw
parent 6edbbf4fbf
commit 357c0d6156
6 changed files with 59 additions and 2 deletions


@@ -254,6 +254,20 @@ Additionally, checkpointing can be used to provide fault-tolerance for experiments.
         },
     })
 
+The ``checkpoint_freq`` may not coincide with the exact end of an experiment. If you want a checkpoint to be
+created at the end of a trial, you can additionally set ``checkpoint_at_end`` to ``True``. An example is shown below:
+
+.. code-block:: python
+    :emphasize-lines: 5
+
+    run_experiments({
+        "my_experiment_name": {
+            "run": my_trainable,
+            "checkpoint_freq": 10,
+            "checkpoint_at_end": True,
+            "max_failures": 5,
+        },
+    })
+
 Handling Large Datasets
 -----------------------
 


@@ -112,6 +112,12 @@ def make_parser(parser_creator=None, **kwargs):
         type=int,
         help="How many training iterations between checkpoints. "
         "A value of 0 (default) disables checkpointing.")
+    parser.add_argument(
+        "--checkpoint-at-end",
+        default=False,
+        type=bool,
+        help="Whether to checkpoint at the end of the experiment. "
+        "Default is False.")
     parser.add_argument(
         "--max-failures",
         default=3,
@@ -149,6 +155,8 @@ def to_argv(config):
             argv.append("--{}".format(k.replace("_", "-")))
         if isinstance(v, string_types):
             argv.append(v)
+        elif isinstance(v, bool):
+            argv.append(v)
         else:
             argv.append(json.dumps(v, cls=_SafeFallbackEncoder))
     return argv
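
The elif isinstance(v, bool) branch above relates to the "Preserved bools from losing their actual value" commit: a bool serialized as a string and re-parsed through type=bool silently flips False to True, because bool() on any non-empty string is truthy. A standalone Python illustration of the pitfall (not code from this commit):

# Standalone illustration (not part of the commit) of why bools need special
# handling when round-tripping a config through argv-style strings:
assert bool("False") is True    # any non-empty string is truthy, value is lost
assert bool("") is False        # only the empty string maps back to False
assert bool(False) is False     # keeping the bool object itself preserves it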
@@ -186,6 +194,7 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
         # json.load leads to str -> unicode in py2.7
         stopping_criterion=spec.get("stop", {}),
         checkpoint_freq=args.checkpoint_freq,
+        checkpoint_at_end=args.checkpoint_at_end,
         # str(None) doesn't create None
         restore_path=spec.get("restore"),
         upload_dir=args.upload_dir,


@@ -43,6 +43,8 @@ class Experiment(object):
             to (e.g. ``s3://bucket``).
         checkpoint_freq (int): How many training iterations between
             checkpoints. A value of 0 (default) disables checkpointing.
+        checkpoint_at_end (bool): Whether to checkpoint at the end of the
+            experiment regardless of the checkpoint_freq. Default is False.
         max_failures (int): Try to recover a trial from its last
             checkpoint at least this many times. Only applies if
             checkpointing is enabled. Defaults to 3.
@@ -82,6 +84,7 @@ class Experiment(object):
                  local_dir=None,
                  upload_dir="",
                  checkpoint_freq=0,
+                 checkpoint_at_end=False,
                  max_failures=3,
                  restore=None):
         spec = {
@@ -93,6 +96,7 @@ class Experiment(object):
             "local_dir": local_dir or DEFAULT_RESULTS_DIR,
             "upload_dir": upload_dir,
             "checkpoint_freq": checkpoint_freq,
+            "checkpoint_at_end": checkpoint_at_end,
             "max_failures": max_failures,
             "restore": restore
         }
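
For users of the Experiment class directly (rather than the spec dict shown in the documentation hunk), a minimal usage sketch of the new keyword follows. This is not part of the commit; "my_trainable" is assumed to be an already registered trainable, and the import paths and call signatures follow the Tune API of this period, which may differ in other Ray versions.

import ray
from ray.tune import Experiment, run_experiments

exp = Experiment(
    "checkpoint_at_end_demo",          # experiment name
    "my_trainable",                    # assumption: a registered trainable name
    stop={"training_iteration": 2},
    checkpoint_freq=0,                 # periodic checkpointing disabled...
    checkpoint_at_end=True,            # ...but a final checkpoint is still saved
    max_failures=0)

ray.init()
run_experiments(exp)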


@@ -938,6 +938,26 @@ class TrialRunnerTest(unittest.TestCase):
         self.assertEqual(ray.get(trials[1].runner.get_info.remote()), 1)
         self.addCleanup(os.remove, path)
 
+    def testCheckpointingAtEnd(self):
+        ray.init(num_cpus=1, num_gpus=1)
+        runner = TrialRunner(BasicVariantGenerator())
+        kwargs = {
+            "stopping_criterion": {
+                "training_iteration": 2
+            },
+            "checkpoint_at_end": True,
+            "resources": Resources(cpu=1, gpu=1),
+        }
+        runner.add_trial(Trial("__fake", **kwargs))
+        trials = runner.get_trials()
+
+        runner.step()
+        self.assertEqual(trials[0].status, Trial.RUNNING)
+        runner.step()
+        runner.step()
+        self.assertEqual(trials[0].last_result[DONE], True)
+        self.assertEqual(trials[0].has_checkpoint(), True)
+
     def testResultDone(self):
         """Tests that last_result is marked `done` after trial is complete."""
         ray.init(num_cpus=1, num_gpus=1)


@@ -112,6 +112,7 @@ class Trial(object):
                  resources=None,
                  stopping_criterion=None,
                  checkpoint_freq=0,
+                 checkpoint_at_end=False,
                  restore_path=None,
                  upload_dir=None,
                  max_failures=0):
@@ -142,6 +143,7 @@ class Trial(object):
         # Local trial state that is updated during the run
         self.last_result = None
         self.checkpoint_freq = checkpoint_freq
+        self.checkpoint_at_end = checkpoint_at_end
         self._checkpoint = Checkpoint(
             storage=Checkpoint.DISK, value=restore_path)
         self.status = Trial.PENDING
@@ -203,9 +205,12 @@ class Trial(object):
             return False
 
-    def should_checkpoint(self):
+    def should_checkpoint(self, result):
         """Whether this trial is due for checkpointing."""
+        if result.get(DONE) and self.checkpoint_at_end:
+            return True
+
         if not self.checkpoint_freq:
             return False
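
A standalone, simplified sketch of the decision order the updated should_checkpoint implements (the function below is illustrative only, not the Trial method; the periodic rule is reduced to a plain modulo check):

def should_checkpoint(result, checkpoint_at_end, checkpoint_freq, iteration):
    # End-of-trial checkpointing takes precedence and works even when
    # periodic checkpointing (checkpoint_freq) is disabled.
    if result.get("done") and checkpoint_at_end:
        return True
    if not checkpoint_freq:
        return False
    return iteration % checkpoint_freq == 0


assert should_checkpoint({"done": True}, True, 0, 7) is True    # final checkpoint
assert should_checkpoint({"done": True}, False, 0, 7) is False  # option not set
assert should_checkpoint({}, True, 5, 10) is True               # periodic rule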


@@ -223,6 +223,7 @@ class TrialRunner(object):
                 self._search_alg.on_trial_complete(
                     trial.trial_id, result=result)
                 decision = TrialScheduler.STOP
+
             else:
                 decision = self._scheduler_alg.on_trial_result(
                     self, trial, result)
@@ -234,13 +235,17 @@ class TrialRunner(object):
                 result, terminate=(decision == TrialScheduler.STOP))
 
             if decision == TrialScheduler.CONTINUE:
-                if trial.should_checkpoint():
+                if trial.should_checkpoint(result):
                     # TODO(rliaw): This is a blocking call
                     self.trial_executor.save(trial)
                 self.trial_executor.continue_training(trial)
             elif decision == TrialScheduler.PAUSE:
                 self.trial_executor.pause_trial(trial)
             elif decision == TrialScheduler.STOP:
+                # Checkpoint before ending the trial
+                # if checkpoint_at_end experiment option is set to True
+                if trial.should_checkpoint(result):
+                    self.trial_executor.save(trial)
                 self.trial_executor.stop_trial(trial)
             else:
                 assert False, "Invalid scheduling decision: {}".format(
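
The STOP branch above saves before stopping, so the checkpoint captures the trial's final state. A standalone sketch of that ordering with assumed names (not the actual TrialRunner/TrialExecutor internals):

def handle_stop_decision(trial, result, executor):
    # Write the end-of-trial checkpoint first (should_checkpoint(result) is
    # True here when the trial is done and checkpoint_at_end was requested)...
    if trial.should_checkpoint(result):
        executor.save(trial)
    # ...and only then stop the trial and release its resources.
    executor.stop_trial(trial)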