mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[tune] Adds option to checkpoint at end of trials (#2754)
* Added checkpoint_at_end option to fix #2740
* Added ability to checkpoint at the end of trials if the option is set to True
* checkpoint_at_end option added; consistent with Experiment and Trial runner
* checkpoint_at_end option mentioned in the tune usage guide
* Moved the redundant checkpoint criteria check out of the if-elif
* Added note that checkpoint_at_end is enabled only when checkpoint_freq is not 0
* Added test case for checkpoint_at_end
* Made checkpoint_at_end have an effect regardless of checkpoint_freq
* Removed comment from the test case
* Fixed the indentation
* Fixed pep8 E231
* Handled cases when trainable does not have _save implemented
* Constrained test case to a particular exp using the MockAgent
* Revert "Constrained test case to a particular exp using the MockAgent"
  This reverts commit e965a9358ec7859b99a3aabb681286d6ba3c3906.
* Revert "Handled cases when trainable does not have _save implemented"
  This reverts commit 0f5382f996ff0cbf3d054742db866c33494d173a.
* Simpler test case for checkpoint_at_end
* Preserved bools from losing their actual value
* Revert "Moved the redundant checkpoint criteria check out of the if-elif"
  This reverts commit 783005122902240b0ee177e9e206e397356af9c5.
* Fix linting error.
parent 6edbbf4fbf
commit 357c0d6156

6 changed files with 59 additions and 2 deletions
@@ -254,6 +254,20 @@ Additionally, checkpointing can be used to provide fault-tolerance for experiments.
         },
     })

+``checkpoint_freq`` may not coincide with the exact end of an experiment. If you want a checkpoint to be created at the
+end of a trial, you can additionally set ``checkpoint_at_end`` to ``True``. An example is shown below:
+
+.. code-block:: python
+    :emphasize-lines: 5
+
+    run_experiments({
+        "my_experiment_name": {
+            "run": my_trainable,
+            "checkpoint_freq": 10,
+            "checkpoint_at_end": True,
+            "max_failures": 5,
+        },
+    })

 Handling Large Datasets
 -----------------------
@@ -112,6 +112,12 @@ def make_parser(parser_creator=None, **kwargs):
         type=int,
         help="How many training iterations between checkpoints. "
         "A value of 0 (default) disables checkpointing.")
+    parser.add_argument(
+        "--checkpoint-at-end",
+        default=False,
+        type=bool,
+        help="Whether to checkpoint at the end of the experiment. "
+        "Default is False.")
     parser.add_argument(
         "--max-failures",
         default=3,
@@ -149,6 +155,8 @@ def to_argv(config):
         argv.append("--{}".format(k.replace("_", "-")))
         if isinstance(v, string_types):
             argv.append(v)
+        elif isinstance(v, bool):
+            argv.append(v)
         else:
             argv.append(json.dumps(v, cls=_SafeFallbackEncoder))
     return argv
@@ -186,6 +194,7 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
         # json.load leads to str -> unicode in py2.7
         stopping_criterion=spec.get("stop", {}),
         checkpoint_freq=args.checkpoint_freq,
+        checkpoint_at_end=args.checkpoint_at_end,
         # str(None) doesn't create None
         restore_path=spec.get("restore"),
         upload_dir=args.upload_dir,
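A note on the new elif isinstance(v, bool) branch in to_argv above, the "preserved bools" item in the commit message: once a boolean has been serialized to text, re-parsing it with type=bool silently turns False into True, because any non-empty string is truthy. A minimal standalone illustration of that pitfall (not part of the diff):

import json

# json.dumps(False) == "false", and bool() of any non-empty string is True,
# so a False that went through a string round trip can no longer be recovered
# by applying bool() to it.
assert bool("False") is True
assert bool(json.dumps(False)) is True

# Keeping the original bool object avoids the lossy round trip, which is why
# to_argv now appends the value itself instead of json.dumps(v) for bools.
assert bool(False) is False
assert bool(True) is True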
@@ -43,6 +43,8 @@ class Experiment(object):
             to (e.g. ``s3://bucket``).
         checkpoint_freq (int): How many training iterations between
             checkpoints. A value of 0 (default) disables checkpointing.
+        checkpoint_at_end (bool): Whether to checkpoint at the end of the
+            experiment regardless of the checkpoint_freq. Default is False.
         max_failures (int): Try to recover a trial from its last
             checkpoint at least this many times. Only applies if
             checkpointing is enabled. Defaults to 3.
@@ -82,6 +84,7 @@ class Experiment(object):
                  local_dir=None,
                  upload_dir="",
                  checkpoint_freq=0,
+                 checkpoint_at_end=False,
                  max_failures=3,
                  restore=None):
         spec = {
@@ -93,6 +96,7 @@ class Experiment(object):
             "local_dir": local_dir or DEFAULT_RESULTS_DIR,
             "upload_dir": upload_dir,
             "checkpoint_freq": checkpoint_freq,
+            "checkpoint_at_end": checkpoint_at_end,
             "max_failures": max_failures,
             "restore": restore
         }
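For reference, a minimal sketch of how the new keyword might be used through this class-based API. Everything here is illustrative rather than taken from the diff: MyTrainable is a made-up Trainable subclass, and the import surface (Experiment, Trainable, run_experiments under ray.tune) is assumed for the Tune version this commit targets.

import json
import os

import ray
from ray.tune import Experiment, Trainable, run_experiments


class MyTrainable(Trainable):
    """Toy trainable; implements _save/_restore so checkpointing can work."""

    def _train(self):
        # Lazily initialize state so the sketch does not depend on the exact
        # _setup signature in this Tune version.
        self.step_count = getattr(self, "step_count", 0) + 1
        return {"mean_accuracy": self.step_count}

    def _save(self, checkpoint_dir):
        path = os.path.join(checkpoint_dir, "state.json")
        with open(path, "w") as f:
            json.dump({"step_count": self.step_count}, f)
        return path

    def _restore(self, checkpoint_path):
        with open(checkpoint_path) as f:
            self.step_count = json.load(f)["step_count"]


if __name__ == "__main__":
    ray.init()
    run_experiments(
        Experiment(
            "my_experiment_name",
            MyTrainable,
            stop={"training_iteration": 20},
            checkpoint_freq=10,      # periodic checkpoint every 10 iterations
            checkpoint_at_end=True,  # plus one final checkpoint at trial end
            max_failures=5))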
@@ -938,6 +938,26 @@ class TrialRunnerTest(unittest.TestCase):
         self.assertEqual(ray.get(trials[1].runner.get_info.remote()), 1)
         self.addCleanup(os.remove, path)

+    def testCheckpointingAtEnd(self):
+        ray.init(num_cpus=1, num_gpus=1)
+        runner = TrialRunner(BasicVariantGenerator())
+        kwargs = {
+            "stopping_criterion": {
+                "training_iteration": 2
+            },
+            "checkpoint_at_end": True,
+            "resources": Resources(cpu=1, gpu=1),
+        }
+        runner.add_trial(Trial("__fake", **kwargs))
+        trials = runner.get_trials()
+
+        runner.step()
+        self.assertEqual(trials[0].status, Trial.RUNNING)
+        runner.step()
+        runner.step()
+        self.assertEqual(trials[0].last_result[DONE], True)
+        self.assertEqual(trials[0].has_checkpoint(), True)
+
     def testResultDone(self):
         """Tests that last_result is marked `done` after trial is complete."""
         ray.init(num_cpus=1, num_gpus=1)
@@ -112,6 +112,7 @@ class Trial(object):
                  resources=None,
                  stopping_criterion=None,
                  checkpoint_freq=0,
+                 checkpoint_at_end=False,
                  restore_path=None,
                  upload_dir=None,
                  max_failures=0):
@@ -142,6 +143,7 @@ class Trial(object):
         # Local trial state that is updated during the run
         self.last_result = None
         self.checkpoint_freq = checkpoint_freq
+        self.checkpoint_at_end = checkpoint_at_end
         self._checkpoint = Checkpoint(
             storage=Checkpoint.DISK, value=restore_path)
         self.status = Trial.PENDING
@@ -203,9 +205,12 @@ class Trial(object):

         return False

-    def should_checkpoint(self):
+    def should_checkpoint(self, result):
         """Whether this trial is due for checkpointing."""

+        if result.get(DONE) and self.checkpoint_at_end:
+            return True
+
         if not self.checkpoint_freq:
             return False

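The last hunk cuts off before the pre-existing periodic branch. As a readability aid, here is a standalone sketch (not the repository code) of the complete decision the method now expresses; the final modulo rule is an assumption based on the documented checkpoint_freq behaviour, not shown in the diff:

def should_checkpoint(result, checkpoint_at_end, checkpoint_freq,
                      training_iteration):
    """Sketch of the checkpoint decision after this change."""
    # New in this commit: always checkpoint a finished trial when
    # checkpoint_at_end is set, regardless of checkpoint_freq.
    if result.get("done") and checkpoint_at_end:
        return True
    # A checkpoint_freq of 0 disables periodic checkpointing.
    if not checkpoint_freq:
        return False
    # Assumed periodic rule: checkpoint every checkpoint_freq iterations.
    return training_iteration % checkpoint_freq == 0


# A finished trial checkpoints even though 7 % 10 != 0.
print(should_checkpoint({"done": True}, True, 10, 7))   # True
print(should_checkpoint({"done": False}, True, 10, 7))  # False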
@@ -223,6 +223,7 @@ class TrialRunner(object):
                 self._search_alg.on_trial_complete(
                     trial.trial_id, result=result)
                 decision = TrialScheduler.STOP
+
             else:
                 decision = self._scheduler_alg.on_trial_result(
                     self, trial, result)
@@ -234,13 +235,17 @@ class TrialRunner(object):
                 result, terminate=(decision == TrialScheduler.STOP))

             if decision == TrialScheduler.CONTINUE:
-                if trial.should_checkpoint():
+                if trial.should_checkpoint(result):
                     # TODO(rliaw): This is a blocking call
                     self.trial_executor.save(trial)
                 self.trial_executor.continue_training(trial)
             elif decision == TrialScheduler.PAUSE:
                 self.trial_executor.pause_trial(trial)
             elif decision == TrialScheduler.STOP:
+                # Checkpoint before ending the trial
+                # if checkpoint_at_end experiment option is set to True
+                if trial.should_checkpoint(result):
+                    self.trial_executor.save(trial)
                 self.trial_executor.stop_trial(trial)
             else:
                 assert False, "Invalid scheduling decision: {}".format(
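To tie the runner change back to the new test case: with a stopping criterion of two training iterations and checkpoint_at_end set, the final result carries done=True, the scheduler decision is STOP, and the runner now saves before stopping. A standalone sketch of that ordering (illustration only; FakeExecutor and the plain dicts stand in for Tune's real executor and result objects):

class FakeExecutor(object):
    """Stands in for Tune's trial executor in this illustration."""

    def __init__(self):
        self.calls = []

    def save(self, trial):
        self.calls.append("save")

    def stop_trial(self, trial):
        self.calls.append("stop_trial")


def handle_stop(executor, result, checkpoint_at_end):
    # Mirrors the new `elif decision == TrialScheduler.STOP` branch:
    # checkpoint first (if due), then stop the trial.
    if result.get("done") and checkpoint_at_end:
        executor.save("trial_0")
    executor.stop_trial("trial_0")


executor = FakeExecutor()
handle_stop(executor, {"done": True, "training_iteration": 2}, True)
assert executor.calls == ["save", "stop_trial"]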