Mirror of https://github.com/vale981/ray (synced 2025-04-23 06:25:52 -04:00)
[tune] Recover experiments from last checkpoint (#1532)

parent 7e998db656
commit ca0f08d100

8 changed files with 138 additions and 7 deletions
@@ -199,6 +199,20 @@ Trial Checkpointing

To enable checkpoint / resume, you must subclass ``Trainable`` and implement its ``_train``, ``_save``, and ``_restore`` abstract methods `(example) <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/hyperband_example.py>`__. Implementing this interface is required to support resource multiplexing in schedulers such as HyperBand and PBT.

Additionally, checkpointing can be used to provide fault tolerance for experiments. This can be enabled by setting ``checkpoint_freq: N`` and ``max_failures: M`` to checkpoint trials every *N* iterations and recover from up to *M* crashes per trial, e.g.:

.. code-block:: python

    run_experiments({
        "my_experiment": {
            ...
            "checkpoint_freq": 10,
            "max_failures": 5,
        },
    })

The class interface that must be implemented to enable checkpointing is as follows:

.. autoclass:: ray.tune.trainable.Trainable

Resource Allocation
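For readers unfamiliar with the interface referenced above, a minimal sketch of such a ``Trainable`` subclass follows. It is illustrative only: the class name is made up, the ``ray.tune.result.TrainingResult`` import path and the exact ``_save``/``_restore`` contract (write state under a provided checkpoint directory, return a path, and receive that path back on restore) are assumptions inferred from the surrounding diff rather than taken verbatim from the Ray API.

.. code-block:: python

    import os
    import pickle

    from ray.tune.trainable import Trainable      # path from the autoclass directive above
    from ray.tune.result import TrainingResult    # assumed import path

    class MyCheckpointableTrainable(Trainable):
        """Illustrative example; not part of this commit."""

        def _train(self):
            # One training iteration; state is initialized lazily to avoid
            # assuming a particular setup hook on Trainable.
            self.step_count = getattr(self, "step_count", 0) + 1
            return TrainingResult(
                episode_reward_mean=self.step_count, episode_len_mean=1,
                timesteps_this_iter=1, info={})

        def _save(self, checkpoint_dir):
            # Assumed contract: persist state under checkpoint_dir, return a path.
            path = os.path.join(checkpoint_dir, "checkpoint.pkl")
            with open(path, "wb") as f:
                pickle.dump(self.step_count, f)
            return path

        def _restore(self, checkpoint_path):
            # Assumed contract: receive the path previously returned by _save.
            with open(checkpoint_path, "rb") as f:
                self.step_count = pickle.load(f)

Such a class would then be registered (for example with ``register_trainable``) and referenced by name in the ``run_experiments`` spec, with ``checkpoint_freq`` and ``max_failures`` set as in the documentation snippet above.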
@@ -139,12 +139,19 @@ class _MockAgent(Agent):
    """Mock agent for use in tests"""

    _agent_name = "MockAgent"
    _default_config = {}
    _default_config = {
        "mock_error": False,
        "persistent_error": False,
    }

    def _init(self):
        self.info = None
        self.restored = False

    def _train(self):
        if self.config["mock_error"] and self.iteration == 1 \
                and (self.config["persistent_error"] or not self.restored):
            raise Exception("mock error")
        return TrainingResult(
            episode_reward_mean=10, episode_len_mean=10,
            timesteps_this_iter=10, info={})

@@ -159,6 +166,7 @@ class _MockAgent(Agent):
        with open(checkpoint_path, 'rb') as f:
            info = pickle.load(f)
        self.info = info
        self.restored = True

    def set_info(self, info):
        self.info = info

@@ -174,7 +174,9 @@ class DQNEvaluator(TFMultiGPUSupport):
            self.episode_rewards,
            self.episode_lengths,
            self.saved_mean_reward,
            self.obs]
            self.obs,
            self.global_timestep,
            self.local_timestep]

    def restore(self, data):
        self.exploration = data[0]

@@ -182,3 +184,5 @@ class DQNEvaluator(TFMultiGPUSupport):
        self.episode_lengths = data[2]
        self.saved_mean_reward = data[3]
        self.obs = data[4]
        self.global_timestep = data[5]
        self.local_timestep = data[6]

@@ -74,6 +74,10 @@ def make_parser(**kwargs):
        "--checkpoint-freq", default=0, type=int,
        help="How many training iterations between checkpoints. "
        "A value of 0 (default) disables checkpointing.")
    parser.add_argument(
        "--max-failures", default=3, type=int,
        help="Try to recover a trial from its last checkpoint at least this "
        "many times. Only applies if checkpointing is enabled.")
    parser.add_argument(
        "--scheduler", default="FIFO", type=str,
        help="FIFO (default), MedianStopping, or HyperBand.")

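To make the defaults above concrete, here is a small self-contained illustration using a plain ``argparse`` parser that mirrors the two flags. It deliberately does not call the real ``make_parser``, so nothing beyond the definitions shown above is assumed.

.. code-block:: python

    import argparse

    # Stand-in parser mirroring the flags added above; illustration only.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--checkpoint-freq", default=0, type=int,
        help="How many training iterations between checkpoints. "
             "A value of 0 (default) disables checkpointing.")
    parser.add_argument(
        "--max-failures", default=3, type=int,
        help="Try to recover a trial from its last checkpoint at least this "
             "many times. Only applies if checkpointing is enabled.")

    args = parser.parse_args(["--checkpoint-freq", "10", "--max-failures", "5"])
    assert args.checkpoint_freq == 10 and args.max_failures == 5

    # Omitting the flags keeps the defaults: checkpointing disabled,
    # max_failures left at 3 (which only matters once checkpointing is on).
    defaults = parser.parse_args([])
    assert defaults.checkpoint_freq == 0 and defaults.max_failures == 3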
@@ -208,6 +208,7 @@ class VariantGeneratorTest(unittest.TestCase):
        trials = generate_trials({
            "run": "PPO",
            "repeat": 2,
            "max_failures": 5,
            "config": {
                "env": "Pong-v0",
                "foo": "bar"

@@ -219,6 +220,7 @@ class VariantGeneratorTest(unittest.TestCase):
        self.assertEqual(trials[0].config, {"foo": "bar", "env": "Pong-v0"})
        self.assertEqual(trials[0].trainable_name, "PPO")
        self.assertEqual(trials[0].experiment_tag, "0")
        self.assertEqual(trials[0].max_failures, 5)
        self.assertEqual(
            trials[0].local_dir,
            os.path.join(DEFAULT_RESULTS_DIR, "tune-pong"))

@@ -457,6 +459,81 @@ class TrialRunnerTest(unittest.TestCase):
        self.assertEqual(trials[0].status, Trial.ERROR)
        self.assertEqual(trials[1].status, Trial.RUNNING)

    def testFailureRecoveryDisabled(self):
        ray.init(num_cpus=1, num_gpus=1)
        runner = TrialRunner()
        kwargs = {
            "resources": Resources(cpu=1, gpu=1),
            "checkpoint_freq": 1,
            "max_failures": 0,
            "config": {
                "mock_error": True,
            },
        }
        runner.add_trial(Trial("__fake", **kwargs))
        trials = runner.get_trials()

        runner.step()
        self.assertEqual(trials[0].status, Trial.RUNNING)
        runner.step()
        self.assertEqual(trials[0].status, Trial.RUNNING)
        runner.step()
        self.assertEqual(trials[0].status, Trial.ERROR)
        self.assertEqual(trials[0].num_failures, 1)

    def testFailureRecoveryEnabled(self):
        ray.init(num_cpus=1, num_gpus=1)
        runner = TrialRunner()
        kwargs = {
            "resources": Resources(cpu=1, gpu=1),
            "checkpoint_freq": 1,
            "max_failures": 1,
            "config": {
                "mock_error": True,
            },
        }
        runner.add_trial(Trial("__fake", **kwargs))
        trials = runner.get_trials()

        runner.step()
        self.assertEqual(trials[0].status, Trial.RUNNING)
        runner.step()
        self.assertEqual(trials[0].status, Trial.RUNNING)
        runner.step()
        self.assertEqual(trials[0].status, Trial.RUNNING)
        self.assertEqual(trials[0].num_failures, 1)
        runner.step()
        self.assertEqual(trials[0].status, Trial.RUNNING)

    def testFailureRecoveryMaxFailures(self):
        ray.init(num_cpus=1, num_gpus=1)
        runner = TrialRunner()
        kwargs = {
            "resources": Resources(cpu=1, gpu=1),
            "checkpoint_freq": 1,
            "max_failures": 2,
            "config": {
                "mock_error": True,
                "persistent_error": True,
            },
        }
        runner.add_trial(Trial("__fake", **kwargs))
        trials = runner.get_trials()

        runner.step()
        self.assertEqual(trials[0].status, Trial.RUNNING)
        runner.step()
        self.assertEqual(trials[0].status, Trial.RUNNING)
        runner.step()
        self.assertEqual(trials[0].status, Trial.RUNNING)
        self.assertEqual(trials[0].num_failures, 1)
        runner.step()
        self.assertEqual(trials[0].status, Trial.RUNNING)
        self.assertEqual(trials[0].num_failures, 2)
        runner.step()
        self.assertEqual(trials[0].status, Trial.ERROR)
        self.assertEqual(trials[0].num_failures, 3)

    def testCheckpointing(self):
        ray.init(num_cpus=1, num_gpus=1)
        runner = TrialRunner()

@@ -78,7 +78,7 @@ class Trial(object):
            self, trainable_name, config=None, local_dir=DEFAULT_RESULTS_DIR,
            experiment_tag=None, resources=Resources(cpu=1, gpu=0),
            stopping_criterion=None, checkpoint_freq=0,
            restore_path=None, upload_dir=None):
            restore_path=None, upload_dir=None, max_failures=0):
        """Initialize a new trial.

        The args here take the same meaning as the command line flags defined

@@ -106,6 +106,7 @@ class Trial(object):
        self.checkpoint_freq = checkpoint_freq
        self.upload_dir = upload_dir
        self.verbose = True
        self.max_failures = max_failures

        # Local trial state that is updated during the run
        self.last_result = None

@@ -119,6 +120,7 @@ class Trial(object):
        self.last_debug = 0
        self.trial_id = binary_to_hex(random_string())[:8]
        self.error_file = None
        self.num_failures = 0

    def start(self, checkpoint_obj=None):
        """Starts this trial.

@@ -158,6 +160,7 @@ class Trial(object):

        try:
            if error_msg and self.logdir:
                self.num_failures += 1
                error_file = os.path.join(
                    self.logdir, "error_{}.txt".format(date_str()))
                with open(error_file, "w") as f:

@@ -268,7 +271,12 @@ class Trial(object):
    def _status_string(self):
        return "{}{}".format(
            self.status,
            " => {}".format(self.error_file) if self.error_file else "")
            ", {} failures: {}".format(self.num_failures, self.error_file)
            if self.error_file else "")

    def has_checkpoint(self):
        return self._checkpoint_path is not None or \
            self._checkpoint_obj is not None

    def checkpoint(self, to_object_store=False):
        """Checkpoints the state of this trial.

@@ -241,8 +241,23 @@ class TrialRunner(object):
            error_msg = traceback.format_exc()
            print("Error processing event:", error_msg)
            if trial.status == Trial.RUNNING:
                self._scheduler_alg.on_trial_error(self, trial)
                self._stop_trial(trial, error=True, error_msg=error_msg)
                if trial.has_checkpoint() and \
                        trial.num_failures < trial.max_failures:
                    self._try_recover(trial, error_msg)
                else:
                    self._scheduler_alg.on_trial_error(self, trial)
                    self._stop_trial(trial, error=True, error_msg=error_msg)

    def _try_recover(self, trial, error_msg):
        try:
            print("Attempting to recover trial state from last checkpoint")
            trial.stop(error=True, error_msg=error_msg, stop_logger=False)
            trial.start()
            self._running[trial.train_remote()] = trial
        except Exception:
            error_msg = traceback.format_exc()
            print("Error recovering trial from checkpoint, abort:", error_msg)
            self._stop_trial(trial, error=True, error_msg=error_msg)

    def _get_runnable(self):
        return self._scheduler_alg.choose_trial_to_run(self)

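Read together with the ``Trial`` changes above, the recovery policy is: when a running trial errors, it is restarted from its last checkpoint as long as it has one and its failure count is still below ``max_failures``; otherwise the scheduler is notified and the trial is stopped with an error. A minimal sketch of that decision, using illustrative names rather than the actual ``TrialRunner`` internals:

.. code-block:: python

    def should_recover(trial):
        """Decide whether a crashed trial is restarted from its checkpoint.

        `trial` is assumed to expose has_checkpoint(), num_failures, and
        max_failures, as added to Trial in this commit. With max_failures=M,
        a persistently failing trial is retried M times and marked ERROR on
        its (M + 1)-th failure, matching testFailureRecoveryMaxFailures above.
        """
        return trial.has_checkpoint() and trial.num_failures < trial.max_failures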
@@ -62,7 +62,8 @@ def generate_trials(unresolved_spec, output_path=''):
            stopping_criterion=spec.get("stop", {}),
            checkpoint_freq=args.checkpoint_freq,
            restore_path=spec.get("restore"),
            upload_dir=args.upload_dir)
            upload_dir=args.upload_dir,
            max_failures=args.max_failures)


    def generate_variants(unresolved_spec):