[tune] Recover experiments from last checkpoint (#1532)

Eric Liang 2018-02-12 14:01:19 -08:00 committed by Richard Liaw
parent 7e998db656
commit ca0f08d100
8 changed files with 138 additions and 7 deletions


@@ -199,6 +199,20 @@ Trial Checkpointing

To enable checkpoint / resume, you must subclass ``Trainable`` and implement its ``_train``, ``_save``, and ``_restore`` abstract methods `(example) <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/hyperband_example.py>`__. Implementing this interface is required to support resource multiplexing in schedulers such as HyperBand and PBT.

Additionally, checkpointing can be used to provide fault tolerance for experiments. This can be enabled by setting ``checkpoint_freq: N`` and ``max_failures: M`` to checkpoint trials every *N* iterations and recover from up to *M* crashes per trial, e.g.:

.. code-block:: python

    run_experiments({
        "my_experiment": {
            ...
            "checkpoint_freq": 10,
            "max_failures": 5,
        },
    })

The class interface that must be implemented to enable checkpointing is as follows:

.. autoclass:: ray.tune.trainable.Trainable

Resource Allocation
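
For orientation, here is a minimal sketch of such a ``Trainable`` subclass, loosely modeled on the linked hyperband example. It is not part of this commit; the ``TrainingResult`` import path and its unspecified default fields are assumptions about this version of Tune:

import json
import os

from ray.tune.result import TrainingResult  # import path assumed for this Ray version
from ray.tune.trainable import Trainable


class MyTrainable(Trainable):
    """Toy trainable: counts iterations and checkpoints the counter."""

    timestep = 0  # trial state that _save/_restore round-trip

    def _train(self):
        # One training iteration; the counter stands in for real model state.
        self.timestep += 1
        return TrainingResult(
            episode_reward_mean=self.timestep, timesteps_this_iter=1)

    def _save(self, checkpoint_dir):
        # Persist the state and return the checkpoint path to Tune.
        path = os.path.join(checkpoint_dir, "checkpoint")
        with open(path, "w") as f:
            json.dump({"timestep": self.timestep}, f)
        return path

    def _restore(self, checkpoint_path):
        # Reload the state written by _save.
        with open(checkpoint_path) as f:
            self.timestep = json.load(f)["timestep"]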


@@ -139,12 +139,19 @@ class _MockAgent(Agent):
    """Mock agent for use in tests"""

    _agent_name = "MockAgent"
    _default_config = {}
    _default_config = {
        "mock_error": False,
        "persistent_error": False,
    }

    def _init(self):
        self.info = None
        self.restored = False

    def _train(self):
        if self.config["mock_error"] and self.iteration == 1 \
                and (self.config["persistent_error"] or not self.restored):
            raise Exception("mock error")
        return TrainingResult(
            episode_reward_mean=10, episode_len_mean=10,
            timesteps_this_iter=10, info={})

@@ -159,6 +166,7 @@ class _MockAgent(Agent):
        with open(checkpoint_path, 'rb') as f:
            info = pickle.load(f)
        self.info = info
        self.restored = True

    def set_info(self, info):
        self.info = info


@@ -174,7 +174,9 @@ class DQNEvaluator(TFMultiGPUSupport):
            self.episode_rewards,
            self.episode_lengths,
            self.saved_mean_reward,
            self.obs]
            self.obs,
            self.global_timestep,
            self.local_timestep]

    def restore(self, data):
        self.exploration = data[0]

@@ -182,3 +184,5 @@ class DQNEvaluator(TFMultiGPUSupport):
        self.episode_lengths = data[2]
        self.saved_mean_reward = data[3]
        self.obs = data[4]
        self.global_timestep = data[5]
        self.local_timestep = data[6]


@@ -74,6 +74,10 @@ def make_parser(**kwargs):
        "--checkpoint-freq", default=0, type=int,
        help="How many training iterations between checkpoints. "
        "A value of 0 (default) disables checkpointing.")
    parser.add_argument(
        "--max-failures", default=3, type=int,
        help="Try to recover a trial from its last checkpoint at least this "
        "many times. Only applies if checkpointing is enabled.")
    parser.add_argument(
        "--scheduler", default="FIFO", type=str,
        help="FIFO (default), MedianStopping, or HyperBand.")


@@ -208,6 +208,7 @@ class VariantGeneratorTest(unittest.TestCase):
        trials = generate_trials({
            "run": "PPO",
            "repeat": 2,
            "max_failures": 5,
            "config": {
                "env": "Pong-v0",
                "foo": "bar"

@@ -219,6 +220,7 @@ class VariantGeneratorTest(unittest.TestCase):
        self.assertEqual(trials[0].config, {"foo": "bar", "env": "Pong-v0"})
        self.assertEqual(trials[0].trainable_name, "PPO")
        self.assertEqual(trials[0].experiment_tag, "0")
        self.assertEqual(trials[0].max_failures, 5)
        self.assertEqual(
            trials[0].local_dir,
            os.path.join(DEFAULT_RESULTS_DIR, "tune-pong"))

@@ -457,6 +459,81 @@ class TrialRunnerTest(unittest.TestCase):
        self.assertEqual(trials[0].status, Trial.ERROR)
        self.assertEqual(trials[1].status, Trial.RUNNING)

    def testFailureRecoveryDisabled(self):
        ray.init(num_cpus=1, num_gpus=1)
        runner = TrialRunner()
        kwargs = {
            "resources": Resources(cpu=1, gpu=1),
            "checkpoint_freq": 1,
            "max_failures": 0,
            "config": {
                "mock_error": True,
            },
        }
        runner.add_trial(Trial("__fake", **kwargs))
        trials = runner.get_trials()

        runner.step()
        self.assertEqual(trials[0].status, Trial.RUNNING)
        runner.step()
        self.assertEqual(trials[0].status, Trial.RUNNING)
        runner.step()
        self.assertEqual(trials[0].status, Trial.ERROR)
        self.assertEqual(trials[0].num_failures, 1)

    def testFailureRecoveryEnabled(self):
        ray.init(num_cpus=1, num_gpus=1)
        runner = TrialRunner()
        kwargs = {
            "resources": Resources(cpu=1, gpu=1),
            "checkpoint_freq": 1,
            "max_failures": 1,
            "config": {
                "mock_error": True,
            },
        }
        runner.add_trial(Trial("__fake", **kwargs))
        trials = runner.get_trials()

        runner.step()
        self.assertEqual(trials[0].status, Trial.RUNNING)
        runner.step()
        self.assertEqual(trials[0].status, Trial.RUNNING)
        runner.step()
        self.assertEqual(trials[0].status, Trial.RUNNING)
        self.assertEqual(trials[0].num_failures, 1)
        runner.step()
        self.assertEqual(trials[0].status, Trial.RUNNING)

    def testFailureRecoveryMaxFailures(self):
        ray.init(num_cpus=1, num_gpus=1)
        runner = TrialRunner()
        kwargs = {
            "resources": Resources(cpu=1, gpu=1),
            "checkpoint_freq": 1,
            "max_failures": 2,
            "config": {
                "mock_error": True,
                "persistent_error": True,
            },
        }
        runner.add_trial(Trial("__fake", **kwargs))
        trials = runner.get_trials()

        runner.step()
        self.assertEqual(trials[0].status, Trial.RUNNING)
        runner.step()
        self.assertEqual(trials[0].status, Trial.RUNNING)
        runner.step()
        self.assertEqual(trials[0].status, Trial.RUNNING)
        self.assertEqual(trials[0].num_failures, 1)
        runner.step()
        self.assertEqual(trials[0].status, Trial.RUNNING)
        self.assertEqual(trials[0].num_failures, 2)
        runner.step()
        self.assertEqual(trials[0].status, Trial.ERROR)
        self.assertEqual(trials[0].num_failures, 3)

    def testCheckpointing(self):
        ray.init(num_cpus=1, num_gpus=1)
        runner = TrialRunner()


@@ -78,7 +78,7 @@ class Trial(object):
            self, trainable_name, config=None, local_dir=DEFAULT_RESULTS_DIR,
            experiment_tag=None, resources=Resources(cpu=1, gpu=0),
            stopping_criterion=None, checkpoint_freq=0,
            restore_path=None, upload_dir=None):
            restore_path=None, upload_dir=None, max_failures=0):
        """Initialize a new trial.

        The args here take the same meaning as the command line flags defined

@@ -106,6 +106,7 @@ class Trial(object):
        self.checkpoint_freq = checkpoint_freq
        self.upload_dir = upload_dir
        self.verbose = True
        self.max_failures = max_failures

        # Local trial state that is updated during the run
        self.last_result = None

@@ -119,6 +120,7 @@ class Trial(object):
        self.last_debug = 0
        self.trial_id = binary_to_hex(random_string())[:8]
        self.error_file = None
        self.num_failures = 0

    def start(self, checkpoint_obj=None):
        """Starts this trial.

@@ -158,6 +160,7 @@ class Trial(object):
        try:
            if error_msg and self.logdir:
                self.num_failures += 1
                error_file = os.path.join(
                    self.logdir, "error_{}.txt".format(date_str()))
                with open(error_file, "w") as f:

@@ -268,7 +271,12 @@ class Trial(object):
    def _status_string(self):
        return "{}{}".format(
            self.status,
            " => {}".format(self.error_file) if self.error_file else "")
            ", {} failures: {}".format(self.num_failures, self.error_file)
            if self.error_file else "")

    def has_checkpoint(self):
        return self._checkpoint_path is not None or \
            self._checkpoint_obj is not None

    def checkpoint(self, to_object_store=False):
        """Checkpoints the state of this trial.


@@ -241,8 +241,23 @@ class TrialRunner(object):
            error_msg = traceback.format_exc()
            print("Error processing event:", error_msg)
            if trial.status == Trial.RUNNING:
                self._scheduler_alg.on_trial_error(self, trial)
                self._stop_trial(trial, error=True, error_msg=error_msg)
                if trial.has_checkpoint() and \
                        trial.num_failures < trial.max_failures:
                    self._try_recover(trial, error_msg)
                else:
                    self._scheduler_alg.on_trial_error(self, trial)
                    self._stop_trial(trial, error=True, error_msg=error_msg)

    def _try_recover(self, trial, error_msg):
        try:
            print("Attempting to recover trial state from last checkpoint")
            trial.stop(error=True, error_msg=error_msg, stop_logger=False)
            trial.start()
            self._running[trial.train_remote()] = trial
        except Exception:
            error_msg = traceback.format_exc()
            print("Error recovering trial from checkpoint, abort:", error_msg)
            self._stop_trial(trial, error=True, error_msg=error_msg)

    def _get_runnable(self):
        return self._scheduler_alg.choose_trial_to_run(self)
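
In effect, the recovery path above gives each trial a bounded retry-from-checkpoint loop: a crashed trial is restarted from its last checkpoint unless it has no checkpoint or has already used up its failure budget. A standalone sketch of that pattern (illustrative only; the helper names marked hypothetical below are not Ray APIs):

import traceback


def run_with_recovery(trial, max_failures):
    """Re-run `trial` from its last checkpoint after each crash, giving up
    once it has no checkpoint or has failed more than `max_failures` times."""
    failures = 0
    while True:
        try:
            trial.run_one_iteration()        # hypothetical: may raise on worker failure
            if trial.is_finished():          # hypothetical terminal check
                return
        except Exception:
            failures += 1
            print("Trial failed:", traceback.format_exc())
            if not trial.has_checkpoint() or failures > max_failures:
                raise                        # budget exhausted: surface the error
            trial.restore_from_checkpoint()  # hypothetical: resume from last checkpoint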


@@ -62,7 +62,8 @@ def generate_trials(unresolved_spec, output_path=''):
            stopping_criterion=spec.get("stop", {}),
            checkpoint_freq=args.checkpoint_freq,
            restore_path=spec.get("restore"),
            upload_dir=args.upload_dir)
            upload_dir=args.upload_dir,
            max_failures=args.max_failures)

def generate_variants(unresolved_spec):