[tune] Retry restore on timeout (#6284)

* Retry recovery on timeout
* Fix bug, revert some code
* Add test for restore timeouts
* Fix lint
* Address comments
* Don't time out restores

parent 0b3d5d989b
commit fa5d62e8ba
2 changed files with 105 additions and 70 deletions
@@ -26,6 +26,7 @@ RESOURCE_REFRESH_PERIOD = 0.5  # Refresh resources every 500 ms
 BOTTLENECK_WARN_PERIOD_S = 60
 NONTRIVIAL_WAIT_TIME_THRESHOLD_S = 1e-3
 DEFAULT_GET_TIMEOUT = 30.0  # seconds
+TRIAL_START_ATTEMPTS = 3
 
 
 class _LocalWrapper(object):
@@ -80,8 +81,8 @@ class RayTrialExecutor(TrialExecutor):
 
         if (self._reuse_actors and reuse_allowed
                 and self._cached_actor is not None):
-            logger.debug("Reusing cached runner {} for {}".format(
-                self._cached_actor, trial.trial_id))
+            logger.debug("Trial %s: Reusing cached runner %s", trial,
+                         self._cached_actor)
             existing_runner = self._cached_actor
             self._cached_actor = None
             trial.runner = existing_runner
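Beyond adding the `Trial %s:` prefix, the new call passes the arguments to the logger instead of pre-formatting the string, so interpolation only happens if a DEBUG record is actually emitted. A minimal illustration of the difference; the values below are placeholders:

```python
import logging

logger = logging.getLogger(__name__)

trial_id, runner = "trial_1234", "<ActorHandle>"

# Eager: the message string is built even when DEBUG logging is disabled.
logger.debug("Reusing cached runner {} for {}".format(runner, trial_id))

# Lazy: the logger interpolates the arguments only if the record is emitted.
logger.debug("Trial %s: Reusing cached runner %s", trial_id, runner)
```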
@@ -134,21 +135,25 @@ class RayTrialExecutor(TrialExecutor):
 
         self._running[remote] = trial
 
-    def _start_trial(self, trial, checkpoint=None):
+    def _start_trial(self, trial, checkpoint=None, runner=None):
         """Starts trial and restores last result if trial was paused.
 
-        Raises:
-            RuntimeError if restoring from checkpoint fails.
+        Args:
+            trial (Trial): The trial to start.
+            checkpoint (Optional[Checkpoint]): The checkpoint to restore from.
+                If None, and no trial checkpoint exists, the trial is started
+                from the beginning.
+            runner (Trainable): The remote runner to use. This can be the
+                cached actor. If None, a new runner is created.
+
+        See `RayTrialExecutor.restore` for possible errors raised.
         """
         prior_status = trial.status
         self.set_status(trial, Trial.RUNNING)
-        trial.runner = self._setup_remote_runner(
+        trial.runner = runner or self._setup_remote_runner(
             trial,
             reuse_allowed=checkpoint is not None or trial.has_checkpoint())
-        if not self.restore(trial, checkpoint):
-            if trial.status == Trial.ERROR:
-                raise RuntimeError(
-                    "Trial {}: Restore from checkpoint failed.".format(trial))
+        self.restore(trial, checkpoint)
 
         previous_run = self._find_item(self._paused, trial)
         if prior_status == Trial.PAUSED and previous_run:
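The new `runner` argument lets the caller hand back the runner from a previous timed-out attempt instead of creating a fresh actor, and restore errors are no longer swallowed here; they propagate up to `start_trial` (see the next hunk). A toy sketch of the `runner or ...` reuse semantics, using hypothetical names:

```python
# Toy illustration of the new `runner` parameter: reuse what the caller
# passes in, otherwise create a fresh one. All names here are hypothetical.
def make_runner():
    return object()

def setup_runner(runner=None):
    return runner or make_runner()

cached = make_runner()
assert setup_runner(cached) is cached  # An existing runner is reused.
assert setup_runner() is not cached    # A new runner is created otherwise.
```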
@@ -206,34 +211,49 @@ class RayTrialExecutor(TrialExecutor):
             of trial.
         """
         self._commit_resources(trial.resources)
-        try:
-            self._start_trial(trial, checkpoint)
-        except AbortTrialExecution:
-            logger.exception("Trial %s: Error starting runner, aborting!",
-                             trial)
-            time.sleep(2)
-            error_msg = traceback.format_exc()
-            self._stop_trial(trial, error=True, error_msg=error_msg)
-            return  # don't retry fatal Tune errors
-        except Exception:
-            logger.exception(
-                "Trial %s: Error starting runner. Attempting "
-                "restart without checkpoint.", trial)
-            time.sleep(2)
-            error_msg = traceback.format_exc()
-            self._stop_trial(trial, error=True, error_msg=error_msg)
+        remote_runner = None
+        attempts = 0
+        while attempts < TRIAL_START_ATTEMPTS:
+            attempts += 1
+            if attempts > 1:
+                logger.warning("Trial %s: Start attempt #%s...", trial,
+                               attempts)
-            try:
-                # This forces the trial to not start from checkpoint.
-                trial.clear_checkpoint()
-                self._start_trial(trial)
-            except Exception:
-                logger.exception(
-                    "Trial %s: Error starting runner on second "
-                    "attempt, aborting!", trial)
+            try:
+                self._start_trial(trial, checkpoint, remote_runner)
+                break
+            except AbortTrialExecution:
+                logger.exception("Trial %s: Error starting runner, aborting!",
+                                 trial)
+                time.sleep(2)
+                error_msg = traceback.format_exc()
+                self._stop_trial(trial, error=True, error_msg=error_msg)
+                break  # don't retry fatal Tune errors
+            except RayTimeoutError:
+                # Reuse the existing runner on retries.
+                remote_runner = trial.runner
+                warning = ("Runner task timed out. This could be due to "
+                           "slow worker startup.")
+                if attempts == TRIAL_START_ATTEMPTS:
+                    error_msg = traceback.format_exc()
+                    self._stop_trial(trial, error=True, error_msg=error_msg)
+                else:
+                    warning += " Reusing the same runner."
+                logger.warning("Trial %s: %s", trial, warning)
+            except Exception:
+                logger.exception("Trial %s: Error starting runner.", trial)
+                time.sleep(2)
+                error_msg = traceback.format_exc()
+                self._stop_trial(trial, error=True, error_msg=error_msg)
+                remote_runner = None
+                # This forces the trial to not start from checkpoint.
+                checkpoint = None
+                trial.clear_checkpoint()
+                # Note that we don't return the resources, since they may
+                # have been lost. TODO(ujvl): is this the right thing to do?
+        else:
+            logger.exception(
+                "Trial %s: Aborting trial after %s start "
+                "attempts!", trial, TRIAL_START_ATTEMPTS)
 
     def _find_item(self, dictionary, item):
         out = [rid for rid, t in dictionary.items() if t is item]
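The heart of the change is this bounded retry loop: timeouts are treated as transient (the runner is kept and reused), `AbortTrialExecution` stops retrying immediately, and any other error clears the checkpoint before the next attempt. A self-contained sketch of the same pattern, using the built-in `TimeoutError` as a stand-in for `RayTimeoutError` and hypothetical helper names:

```python
# Illustrative sketch of a bounded retry loop that only retries on timeouts.
# `start_once` and MAX_ATTEMPTS are placeholders, not Tune APIs.
MAX_ATTEMPTS = 3

def start_with_retries(start_once, max_attempts=MAX_ATTEMPTS):
    """Calls start_once() until it succeeds, retrying only on timeouts."""
    attempts = 0
    while attempts < max_attempts:
        attempts += 1
        try:
            start_once()
            return True  # Started successfully; stop retrying.
        except TimeoutError:
            # Transient: the remote task may just be slow to come up.
            if attempts == max_attempts:
                return False  # Out of attempts; give up.
        except Exception:
            # Non-timeout failures are treated as fatal in this sketch.
            return False
    return False

# Usage: a flaky starter that times out twice, then succeeds.
calls = {"n": 0}

def flaky_start():
    calls["n"] += 1
    if calls["n"] < 3:
        raise TimeoutError("simulated slow worker startup")

assert start_with_retries(flaky_start)
assert calls["n"] == 3
```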
@@ -562,48 +582,34 @@ class RayTrialExecutor(TrialExecutor):
 
         This will also sync the trial results to a new location
         if restoring on a different node.
+
+        Raises:
+            RuntimeError: This error is raised if no runner is found.
+            RayTimeoutError: This error is raised if a remote call to the
+                runner times out.
         """
         if checkpoint is None or checkpoint.value is None:
             checkpoint = trial.checkpoint
-        if checkpoint is None or checkpoint.value is None:
-            return True
+        if checkpoint.value is None:
+            return
         if trial.runner is None:
-            logger.error(
-                "Trial %s: Unable to restore - no runner. "
-                "Setting status to ERROR.", trial)
-            self.set_status(trial, Trial.ERROR)
-            return False
-        try:
-            value = checkpoint.value
-            if checkpoint.storage == Checkpoint.MEMORY:
-                assert type(value) != Checkpoint, type(value)
-                trial.runner.restore_from_object.remote(value)
-            else:
-                logger.info("Trial %s: Attempting restoration from %s", trial,
-                            checkpoint.value)
-                with warn_if_slow("get_current_ip"):
-                    worker_ip = ray.get(trial.runner.current_ip.remote(),
-                                        DEFAULT_GET_TIMEOUT)
-                with warn_if_slow("sync_to_new_location"):
-                    trial.sync_logger_to_new_location(worker_ip)
-                with warn_if_slow("restore_from_disk"):
-                    ray.get(
-                        trial.runner.restore.remote(value),
-                        DEFAULT_GET_TIMEOUT)
-        except RayTimeoutError:
-            logger.exception(
-                "Trial %s: Unable to restore - runner task timed "
-                "out. Setting status to ERROR", trial)
-            self.set_status(trial, Trial.ERROR)
-            return False
-        except Exception:
-            logger.exception(
-                "Trial %s: Unable to restore. Setting status to ERROR", trial)
-            self.set_status(trial, Trial.ERROR)
-            return False
-
+            raise RuntimeError(
+                "Trial {}: Unable to restore - no runner found.".format(trial))
+        value = checkpoint.value
+        if checkpoint.storage == Checkpoint.MEMORY:
+            assert not isinstance(value, Checkpoint), type(value)
+            trial.runner.restore_from_object.remote(value)
+        else:
+            logger.info("Trial %s: Attempting restore from %s", trial, value)
+            with warn_if_slow("get_current_ip"):
+                worker_ip = ray.get(trial.runner.current_ip.remote(),
+                                    DEFAULT_GET_TIMEOUT)
+            with warn_if_slow("sync_to_new_location"):
+                trial.sync_logger_to_new_location(worker_ip)
+            with warn_if_slow("restore_from_disk"):
+                # TODO(ujvl): Take blocking restores out of the control loop.
+                ray.get(trial.runner.restore.remote(value))
         trial.last_result = checkpoint.result
-        return True
 
     def export_trial_if_needed(self, trial):
         """Exports model of this trial based on trial.export_formats.
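Both the old and new code wrap each remote call in Tune's `warn_if_slow` helper so that slow restores show up in the logs. The sketch below is an illustrative stand-in for such a helper, not Tune's actual implementation; the 0.5-second threshold is an assumption:

```python
import logging
import time
from contextlib import contextmanager

logger = logging.getLogger(__name__)

# Assumed threshold for this sketch; not Tune's actual default.
SLOW_OP_THRESHOLD_S = 0.5

@contextmanager
def warn_if_slow(name, threshold=SLOW_OP_THRESHOLD_S):
    """Logs a warning if the wrapped block takes longer than `threshold`."""
    start = time.time()
    try:
        yield
    finally:
        elapsed = time.time() - start
        if elapsed > threshold:
            logger.warning("The `%s` operation took %.3f s, which may be a "
                           "performance bottleneck.", name, elapsed)

# Usage: wrap a potentially blocking call.
with warn_if_slow("restore_from_disk"):
    time.sleep(0.6)  # Stands in for a blocking ray.get(...) call.
```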
@@ -4,9 +4,11 @@ from __future__ import division
 from __future__ import print_function
 
 import json
+import sys
 import unittest
 
 import ray
+from ray.exceptions import RayTimeoutError
 from ray.rllib import _register_all
 from ray.tune import Trainable
 from ray.tune.ray_trial_executor import RayTrialExecutor
@@ -16,6 +18,11 @@ from ray.tune.trial import Trial, Checkpoint
 from ray.tune.resources import Resources
 from ray.cluster_utils import Cluster
 
+if sys.version_info >= (3, 3):
+    from unittest.mock import patch
+else:
+    from mock import patch
+
 
 class RayTrialExecutorTest(unittest.TestCase):
     def setUp(self):
@@ -43,6 +50,28 @@ class RayTrialExecutorTest(unittest.TestCase):
         self.trial_executor.stop_trial(trial)
         self.assertEqual(Trial.TERMINATED, trial.status)
 
+    def testSaveRestoreTimeout(self):
+        trial = Trial("__fake")
+        self.trial_executor.start_trial(trial)
+        self.assertEqual(Trial.RUNNING, trial.status)
+        self.trial_executor.save(trial, Checkpoint.DISK)
+        self.trial_executor.set_status(trial, Trial.PAUSED)
+
+        ray_get = ray.get
+        start_trial = self.trial_executor._start_trial
+
+        # Timeout on first two attempts, then succeed on subsequent gets.
+        side_effects = [RayTimeoutError, RayTimeoutError, ray_get, ray_get]
+        with patch.object(self.trial_executor, "_start_trial") as mock_start:
+            with patch("ray.get", side_effect=side_effects):
+                mock_start.side_effect = start_trial
+                self.trial_executor.start_trial(trial, trial.checkpoint)
+
+        # Trial starts successfully on 3rd attempt.
+        assert mock_start.call_count == 3
+        self.assertEqual(Trial.RUNNING, trial.status)
+        self.trial_executor.stop_trial(trial)
+
     def testPauseResume(self):
         """Tests that pausing works for trials in flight."""
         trial = Trial("__fake")
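The new test relies on `unittest.mock`'s iterable `side_effect`: exception entries in the list are raised and other entries are returned as-is, one per call, which makes it easy to simulate two timeouts followed by success. A self-contained sketch of the same trick, patching a hypothetical `fetch` helper instead of `ray.get`:

```python
import unittest
from unittest.mock import patch

# `fetch` is a hypothetical stand-in for ray.get in this sketch.
def fetch(ref):
    return ref

class SideEffectExampleTest(unittest.TestCase):
    def test_transient_timeouts(self):
        # First two calls raise, the third returns "ok".
        side_effects = [TimeoutError, TimeoutError, "ok"]
        with patch(__name__ + ".fetch", side_effect=side_effects) as mock_fetch:
            for _ in range(2):
                with self.assertRaises(TimeoutError):
                    fetch("ref")
            self.assertEqual(fetch("ref"), "ok")
        # Calls that raised still count toward call_count.
        self.assertEqual(mock_fetch.call_count, 3)

if __name__ == "__main__":
    unittest.main()
```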