[tune] Retry restore on timeout (#6284)

* Retry recovery on timeout

* Fix bug, revert some code

* Add test for restore timeouts.

* Fix lint

* Address comments

* Don't time out restores.
Ujval Misra authored 2019-12-02 20:01:47 -08:00, committed by Richard Liaw
parent 0b3d5d989b
commit fa5d62e8ba
2 changed files with 105 additions and 70 deletions
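
In short, the change below replaces start_trial's single try/except (and its one-shot fallback that cleared the checkpoint) with a bounded retry loop: a start attempt that fails with RayTimeoutError is retried up to TRIAL_START_ATTEMPTS times, reusing the runner that was already created; other exceptions are also retried, but with a fresh runner and without the checkpoint, and restore itself now raises instead of quietly setting the trial to ERROR. A minimal, self-contained sketch of that control flow, using illustrative names (StartTimeout, create_runner, start_fn) rather than Tune's actual API:

    # Illustrative sketch only -- mirrors the retry shape added to
    # RayTrialExecutor.start_trial in the diff below.
    TRIAL_START_ATTEMPTS = 3  # same constant the commit introduces


    class StartTimeout(Exception):
        """Stand-in for ray.exceptions.RayTimeoutError."""


    def start_with_retries(create_runner, start_fn):
        runner = None
        attempts = 0
        while attempts < TRIAL_START_ATTEMPTS:
            attempts += 1
            try:
                runner = runner or create_runner()
                start_fn(runner)
                break  # started successfully
            except StartTimeout:
                # Timed out: keep `runner` so the next attempt reuses it.
                continue
        else:
            # The loop exhausted its attempts without hitting `break`.
            raise RuntimeError(
                "Start failed after %d attempts" % TRIAL_START_ATTEMPTS)

The while/else is the same construct the diff uses: break covers success and fatal AbortTrialExecution errors, and the else branch only runs once every attempt has been used up.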

python/ray/tune/ray_trial_executor.py

@@ -26,6 +26,7 @@ RESOURCE_REFRESH_PERIOD = 0.5  # Refresh resources every 500 ms
 BOTTLENECK_WARN_PERIOD_S = 60
 NONTRIVIAL_WAIT_TIME_THRESHOLD_S = 1e-3
 DEFAULT_GET_TIMEOUT = 30.0  # seconds
+TRIAL_START_ATTEMPTS = 3
 
 
 class _LocalWrapper(object):
@@ -80,8 +81,8 @@ class RayTrialExecutor(TrialExecutor):
         if (self._reuse_actors and reuse_allowed
                 and self._cached_actor is not None):
-            logger.debug("Reusing cached runner {} for {}".format(
-                self._cached_actor, trial.trial_id))
+            logger.debug("Trial %s: Reusing cached runner %s", trial,
+                         self._cached_actor)
             existing_runner = self._cached_actor
             self._cached_actor = None
             trial.runner = existing_runner
@@ -134,21 +135,25 @@ class RayTrialExecutor(TrialExecutor):
             self._running[remote] = trial
 
-    def _start_trial(self, trial, checkpoint=None):
+    def _start_trial(self, trial, checkpoint=None, runner=None):
         """Starts trial and restores last result if trial was paused.
 
-        Raises:
-            RuntimeError if restoring from checkpoint fails.
+        Args:
+            trial (Trial): The trial to start.
+            checkpoint (Optional[Checkpoint]): The checkpoint to restore from.
+                If None, and no trial checkpoint exists, the trial is started
+                from the beginning.
+            runner (Trainable): The remote runner to use. This can be the
+                cached actor. If None, a new runner is created.
+
+        See `RayTrialExecutor.restore` for possible errors raised.
         """
         prior_status = trial.status
         self.set_status(trial, Trial.RUNNING)
-        trial.runner = self._setup_remote_runner(
+        trial.runner = runner or self._setup_remote_runner(
             trial,
             reuse_allowed=checkpoint is not None or trial.has_checkpoint())
-        if not self.restore(trial, checkpoint):
-            if trial.status == Trial.ERROR:
-                raise RuntimeError(
-                    "Trial {}: Restore from checkpoint failed.".format(trial))
+        self.restore(trial, checkpoint)
 
         previous_run = self._find_item(self._paused, trial)
         if prior_status == Trial.PAUSED and previous_run:
@@ -206,34 +211,49 @@ class RayTrialExecutor(TrialExecutor):
                 of trial.
         """
         self._commit_resources(trial.resources)
-        try:
-            self._start_trial(trial, checkpoint)
-        except AbortTrialExecution:
-            logger.exception("Trial %s: Error starting runner, aborting!",
-                             trial)
-            time.sleep(2)
-            error_msg = traceback.format_exc()
-            self._stop_trial(trial, error=True, error_msg=error_msg)
-            return  # don't retry fatal Tune errors
-        except Exception:
-            logger.exception(
-                "Trial %s: Error starting runner. Attempting "
-                "restart without checkpoint.", trial)
-            time.sleep(2)
-            error_msg = traceback.format_exc()
-            self._stop_trial(trial, error=True, error_msg=error_msg)
+        remote_runner = None
+        attempts = 0
+        while attempts < TRIAL_START_ATTEMPTS:
+            attempts += 1
+            if attempts > 1:
+                logger.warning("Trial %s: Start attempt #%s...", trial,
+                               attempts)
             try:
-                # This forces the trial to not start from checkpoint.
-                trial.clear_checkpoint()
-                self._start_trial(trial)
-            except Exception:
-                logger.exception(
-                    "Trial %s: Error starting runner on second "
-                    "attempt, aborting!", trial)
+                self._start_trial(trial, checkpoint, remote_runner)
+                break
+            except AbortTrialExecution:
+                logger.exception("Trial %s: Error starting runner, aborting!",
+                                 trial)
+                time.sleep(2)
+                error_msg = traceback.format_exc()
+                self._stop_trial(trial, error=True, error_msg=error_msg)
+                break  # don't retry fatal Tune errors
+            except RayTimeoutError:
+                # Reuse the existing runner on retries.
+                remote_runner = trial.runner
+                warning = ("Runner task timed out. This could be due to "
+                           "slow worker startup.")
+                if attempts == TRIAL_START_ATTEMPTS:
+                    error_msg = traceback.format_exc()
+                    self._stop_trial(trial, error=True, error_msg=error_msg)
+                else:
+                    warning += " Reusing the same runner."
+                logger.warning("Trial %s: %s", trial, warning)
+            except Exception:
+                logger.exception("Trial %s: Error starting runner.", trial)
                 time.sleep(2)
                 error_msg = traceback.format_exc()
                 self._stop_trial(trial, error=True, error_msg=error_msg)
+                remote_runner = None
+                # This forces the trial to not start from checkpoint.
+                checkpoint = None
+                trial.clear_checkpoint()
             # Note that we don't return the resources, since they may
             # have been lost. TODO(ujvl): is this the right thing to do?
+        else:
+            logger.exception(
+                "Trial %s: Aborting trial after %s start "
+                "attempts!", trial, TRIAL_START_ATTEMPTS)
 
     def _find_item(self, dictionary, item):
         out = [rid for rid, t in dictionary.items() if t is item]
@@ -562,48 +582,34 @@ class RayTrialExecutor(TrialExecutor):
         This will also sync the trial results to a new location
         if restoring on a different node.
+
+        Raises:
+            RuntimeError: This error is raised if no runner is found.
+            RayTimeoutError: This error is raised if a remote call to the
+                runner times out.
         """
         if checkpoint is None or checkpoint.value is None:
             checkpoint = trial.checkpoint
 
-        if checkpoint is None or checkpoint.value is None:
-            return True
+        if checkpoint.value is None:
+            return
 
         if trial.runner is None:
-            logger.error(
-                "Trial %s: Unable to restore - no runner. "
-                "Setting status to ERROR.", trial)
-            self.set_status(trial, Trial.ERROR)
-            return False
-        try:
-            value = checkpoint.value
-            if checkpoint.storage == Checkpoint.MEMORY:
-                assert type(value) != Checkpoint, type(value)
-                trial.runner.restore_from_object.remote(value)
-            else:
-                logger.info("Trial %s: Attempting restoration from %s", trial,
-                            checkpoint.value)
-                with warn_if_slow("get_current_ip"):
-                    worker_ip = ray.get(trial.runner.current_ip.remote(),
-                                        DEFAULT_GET_TIMEOUT)
-                with warn_if_slow("sync_to_new_location"):
-                    trial.sync_logger_to_new_location(worker_ip)
-                with warn_if_slow("restore_from_disk"):
-                    ray.get(
-                        trial.runner.restore.remote(value),
-                        DEFAULT_GET_TIMEOUT)
-        except RayTimeoutError:
-            logger.exception(
-                "Trial %s: Unable to restore - runner task timed "
-                "out. Setting status to ERROR", trial)
-            self.set_status(trial, Trial.ERROR)
-            return False
-        except Exception:
-            logger.exception(
-                "Trial %s: Unable to restore. Setting status to ERROR", trial)
-            self.set_status(trial, Trial.ERROR)
-            return False
+            raise RuntimeError(
+                "Trial {}: Unable to restore - no runner found.".format(trial))
+        value = checkpoint.value
+        if checkpoint.storage == Checkpoint.MEMORY:
+            assert not isinstance(value, Checkpoint), type(value)
+            trial.runner.restore_from_object.remote(value)
+        else:
+            logger.info("Trial %s: Attempting restore from %s", trial, value)
+            with warn_if_slow("get_current_ip"):
+                worker_ip = ray.get(trial.runner.current_ip.remote(),
+                                    DEFAULT_GET_TIMEOUT)
+            with warn_if_slow("sync_to_new_location"):
+                trial.sync_logger_to_new_location(worker_ip)
+            with warn_if_slow("restore_from_disk"):
+                # TODO(ujvl): Take blocking restores out of the control loop.
+                ray.get(trial.runner.restore.remote(value))
         trial.last_result = checkpoint.result
-        return True
 
     def export_trial_if_needed(self, trial):
         """Exports model of this trial based on trial.export_formats.

python/ray/tune/tests/test_ray_trial_executor.py

@@ -4,9 +4,11 @@ from __future__ import division
 from __future__ import print_function
 
 import json
+import sys
 import unittest
 
 import ray
+from ray.exceptions import RayTimeoutError
 from ray.rllib import _register_all
 from ray.tune import Trainable
 from ray.tune.ray_trial_executor import RayTrialExecutor
@@ -16,6 +18,11 @@ from ray.tune.trial import Trial, Checkpoint
 from ray.tune.resources import Resources
 from ray.cluster_utils import Cluster
 
+if sys.version_info >= (3, 3):
+    from unittest.mock import patch
+else:
+    from mock import patch
+
 
 class RayTrialExecutorTest(unittest.TestCase):
     def setUp(self):
@@ -43,6 +50,28 @@ class RayTrialExecutorTest(unittest.TestCase):
         self.trial_executor.stop_trial(trial)
         self.assertEqual(Trial.TERMINATED, trial.status)
 
+    def testSaveRestoreTimeout(self):
+        trial = Trial("__fake")
+        self.trial_executor.start_trial(trial)
+        self.assertEqual(Trial.RUNNING, trial.status)
+        self.trial_executor.save(trial, Checkpoint.DISK)
+        self.trial_executor.set_status(trial, Trial.PAUSED)
+
+        ray_get = ray.get
+        start_trial = self.trial_executor._start_trial
+        # Timeout on first two attempts, then succeed on subsequent gets.
+        side_effects = [RayTimeoutError, RayTimeoutError, ray_get, ray_get]
+        with patch.object(self.trial_executor, "_start_trial") as mock_start:
+            with patch("ray.get", side_effect=side_effects):
+                mock_start.side_effect = start_trial
+                self.trial_executor.start_trial(trial, trial.checkpoint)
+                # Trial starts successfully on 3rd attempt.
+                assert mock_start.call_count == 3
+
+        self.assertEqual(Trial.RUNNING, trial.status)
+        self.trial_executor.stop_trial(trial)
+
     def testPauseResume(self):
         """Tests that pausing works for trials in flight."""
         trial = Trial("__fake")
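
The new testSaveRestoreTimeout above leans on a property of unittest.mock: when side_effect is an iterable, items that are exception classes (or instances) are raised and any other item is returned as the call's result, so patching ray.get with two RayTimeoutError entries makes the restore time out on the first two start attempts and lets the third attempt go through. A tiny standalone illustration of that mechanism, with made-up names (FakeTimeout, fake_get) rather than the test's fixtures:

    from unittest import mock


    class FakeTimeout(Exception):
        """Stand-in for ray.exceptions.RayTimeoutError."""


    # Exception items in a side_effect iterable are raised in order;
    # anything else is returned as the call's result.
    fake_get = mock.Mock(side_effect=[FakeTimeout, FakeTimeout, "ok"])

    for attempt in range(1, 4):
        try:
            print(attempt, fake_get("some_object_id"))
            break
        except FakeTimeout:
            print(attempt, "timed out, retrying")
    # Prints: 1 timed out, retrying / 2 timed out, retrying / 3 ok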