Mirror of https://github.com/vale981/ray (synced 2025-03-08 19:41:38 -05:00)
[tune] Retry restore on timeout (#6284)

* Retry recovery on timeout
* Fix bug, revert some code
* Add test for restore timeouts
* Fix lint
* Address comments
* Don't timeout restores
Parent: 0b3d5d989b
Commit: fa5d62e8ba

2 changed files with 105 additions and 70 deletions
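The main change in ray_trial_executor.py is a bounded retry loop around trial startup: a RayTimeoutError from the runner no longer fails the trial outright but triggers another start attempt that reuses the existing runner, up to TRIAL_START_ATTEMPTS tries. Below is a minimal sketch of that retry shape; it is illustrative only, and the names start_with_retries and start_fn are not part of the Tune API.

import time

from ray.exceptions import RayTimeoutError  # the exception this commit retries on


def start_with_retries(start_fn, max_attempts=3):
    """Call start_fn() until it succeeds or the attempt budget is used up."""
    for attempt in range(1, max_attempts + 1):
        try:
            return start_fn()
        except RayTimeoutError:
            if attempt == max_attempts:
                raise  # budget exhausted; let the caller decide how to fail
            time.sleep(2)  # illustrative pause; the real loop just logs and retries

In the executor itself, the loop also threads the previous trial.runner back into _start_trial via the new runner argument, so a timed-out actor is not recreated on every attempt.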
python/ray/tune/ray_trial_executor.py

@@ -26,6 +26,7 @@ RESOURCE_REFRESH_PERIOD = 0.5  # Refresh resources every 500 ms
 BOTTLENECK_WARN_PERIOD_S = 60
 NONTRIVIAL_WAIT_TIME_THRESHOLD_S = 1e-3
 DEFAULT_GET_TIMEOUT = 30.0  # seconds
+TRIAL_START_ATTEMPTS = 3


 class _LocalWrapper(object):
@@ -80,8 +81,8 @@ class RayTrialExecutor(TrialExecutor):

         if (self._reuse_actors and reuse_allowed
                 and self._cached_actor is not None):
-            logger.debug("Reusing cached runner {} for {}".format(
-                self._cached_actor, trial.trial_id))
+            logger.debug("Trial %s: Reusing cached runner %s", trial,
+                         self._cached_actor)
             existing_runner = self._cached_actor
             self._cached_actor = None
             trial.runner = existing_runner
@@ -134,21 +135,25 @@ class RayTrialExecutor(TrialExecutor):

         self._running[remote] = trial

-    def _start_trial(self, trial, checkpoint=None):
+    def _start_trial(self, trial, checkpoint=None, runner=None):
         """Starts trial and restores last result if trial was paused.

-        Raises:
-            RuntimeError if restoring from checkpoint fails.
+        Args:
+            trial (Trial): The trial to start.
+            checkpoint (Optional[Checkpoint]): The checkpoint to restore from.
+                If None, and no trial checkpoint exists, the trial is started
+                from the beginning.
+            runner (Trainable): The remote runner to use. This can be the
+                cached actor. If None, a new runner is created.
+
+        See `RayTrialExecutor.restore` for possible errors raised.
         """
         prior_status = trial.status
         self.set_status(trial, Trial.RUNNING)
-        trial.runner = self._setup_remote_runner(
+        trial.runner = runner or self._setup_remote_runner(
             trial,
             reuse_allowed=checkpoint is not None or trial.has_checkpoint())
-        if not self.restore(trial, checkpoint):
-            if trial.status == Trial.ERROR:
-                raise RuntimeError(
-                    "Trial {}: Restore from checkpoint failed.".format(trial))
+        self.restore(trial, checkpoint)

         previous_run = self._find_item(self._paused, trial)
         if prior_status == Trial.PAUSED and previous_run:
@@ -206,34 +211,49 @@ class RayTrialExecutor(TrialExecutor):
                 of trial.
         """
         self._commit_resources(trial.resources)
-        try:
-            self._start_trial(trial, checkpoint)
-        except AbortTrialExecution:
-            logger.exception("Trial %s: Error starting runner, aborting!",
-                             trial)
-            time.sleep(2)
-            error_msg = traceback.format_exc()
-            self._stop_trial(trial, error=True, error_msg=error_msg)
-            return  # don't retry fatal Tune errors
-        except Exception:
-            logger.exception(
-                "Trial %s: Error starting runner. Attempting "
-                "restart without checkpoint.", trial)
-            time.sleep(2)
-            error_msg = traceback.format_exc()
-            self._stop_trial(trial, error=True, error_msg=error_msg)
+        remote_runner = None
+        attempts = 0
+        while attempts < TRIAL_START_ATTEMPTS:
+            attempts += 1
+            if attempts > 1:
+                logger.warning("Trial %s: Start attempt #%s...", trial,
+                               attempts)
             try:
-                # This forces the trial to not start from checkpoint.
-                trial.clear_checkpoint()
-                self._start_trial(trial)
-            except Exception:
-                logger.exception(
-                    "Trial %s: Error starting runner on second "
-                    "attempt, aborting!", trial)
+                self._start_trial(trial, checkpoint, remote_runner)
+                break
+            except AbortTrialExecution:
+                logger.exception("Trial %s: Error starting runner, aborting!",
+                                 trial)
+                time.sleep(2)
                 error_msg = traceback.format_exc()
                 self._stop_trial(trial, error=True, error_msg=error_msg)
-            # Note that we don't return the resources, since they may
-            # have been lost. TODO(ujvl): is this the right thing to do?
+                break  # don't retry fatal Tune errors
+            except RayTimeoutError:
+                # Reuse the existing runner on retries.
+                remote_runner = trial.runner
+                warning = ("Runner task timed out. This could be due to "
+                           "slow worker startup.")
+                if attempts == TRIAL_START_ATTEMPTS:
+                    error_msg = traceback.format_exc()
+                    self._stop_trial(trial, error=True, error_msg=error_msg)
+                else:
+                    warning += " Reusing the same runner."
+                logger.warning("Trial %s: %s", trial, warning)
+            except Exception:
+                logger.exception("Trial %s: Error starting runner.", trial)
+                time.sleep(2)
+                error_msg = traceback.format_exc()
+                self._stop_trial(trial, error=True, error_msg=error_msg)
+                remote_runner = None
+                # This forces the trial to not start from checkpoint.
+                checkpoint = None
+                trial.clear_checkpoint()
+                # Note that we don't return the resources, since they may
+                # have been lost. TODO(ujvl): is this the right thing to do?
+        else:
+            logger.exception(
+                "Trial %s: Aborting trial after %s start "
+                "attempts!", trial, TRIAL_START_ATTEMPTS)

     def _find_item(self, dictionary, item):
         out = [rid for rid, t in dictionary.items() if t is item]
@@ -562,48 +582,34 @@ class RayTrialExecutor(TrialExecutor):
         This will also sync the trial results to a new location
         if restoring on a different node.

+        Raises:
+            RuntimeError: This error is raised if no runner is found.
+            RayTimeoutError: This error is raised if a remote call to the
+                runner times out.
         """
         if checkpoint is None or checkpoint.value is None:
             checkpoint = trial.checkpoint
-        if checkpoint is None or checkpoint.value is None:
-            return True
+        if checkpoint.value is None:
+            return
         if trial.runner is None:
-            logger.error(
-                "Trial %s: Unable to restore - no runner. "
-                "Setting status to ERROR.", trial)
-            self.set_status(trial, Trial.ERROR)
-            return False
-        try:
-            value = checkpoint.value
-            if checkpoint.storage == Checkpoint.MEMORY:
-                assert type(value) != Checkpoint, type(value)
-                trial.runner.restore_from_object.remote(value)
-            else:
-                logger.info("Trial %s: Attempting restoration from %s", trial,
-                            checkpoint.value)
-                with warn_if_slow("get_current_ip"):
-                    worker_ip = ray.get(trial.runner.current_ip.remote(),
-                                        DEFAULT_GET_TIMEOUT)
-                with warn_if_slow("sync_to_new_location"):
-                    trial.sync_logger_to_new_location(worker_ip)
-                with warn_if_slow("restore_from_disk"):
-                    ray.get(
-                        trial.runner.restore.remote(value),
-                        DEFAULT_GET_TIMEOUT)
-        except RayTimeoutError:
-            logger.exception(
-                "Trial %s: Unable to restore - runner task timed "
-                "out. Setting status to ERROR", trial)
-            self.set_status(trial, Trial.ERROR)
-            return False
-        except Exception:
-            logger.exception(
-                "Trial %s: Unable to restore. Setting status to ERROR", trial)
-            self.set_status(trial, Trial.ERROR)
-            return False
+            raise RuntimeError(
+                "Trial {}: Unable to restore - no runner found.".format(trial))
+        value = checkpoint.value
+        if checkpoint.storage == Checkpoint.MEMORY:
+            assert not isinstance(value, Checkpoint), type(value)
+            trial.runner.restore_from_object.remote(value)
+        else:
+            logger.info("Trial %s: Attempting restore from %s", trial, value)
+            with warn_if_slow("get_current_ip"):
+                worker_ip = ray.get(trial.runner.current_ip.remote(),
+                                    DEFAULT_GET_TIMEOUT)
+            with warn_if_slow("sync_to_new_location"):
+                trial.sync_logger_to_new_location(worker_ip)
+            with warn_if_slow("restore_from_disk"):
+                # TODO(ujvl): Take blocking restores out of the control loop.
+                ray.get(trial.runner.restore.remote(value))

         trial.last_result = checkpoint.result
-        return True

     def export_trial_if_needed(self, trial):
         """Exports model of this trial based on trial.export_formats.
python/ray/tune/tests/test_ray_trial_executor.py

@@ -4,9 +4,11 @@ from __future__ import division
 from __future__ import print_function

 import json
+import sys
 import unittest

 import ray
+from ray.exceptions import RayTimeoutError
 from ray.rllib import _register_all
 from ray.tune import Trainable
 from ray.tune.ray_trial_executor import RayTrialExecutor
@@ -16,6 +18,11 @@ from ray.tune.trial import Trial, Checkpoint
 from ray.tune.resources import Resources
 from ray.cluster_utils import Cluster

+if sys.version_info >= (3, 3):
+    from unittest.mock import patch
+else:
+    from mock import patch
+

 class RayTrialExecutorTest(unittest.TestCase):
     def setUp(self):
@@ -43,6 +50,28 @@ class RayTrialExecutorTest(unittest.TestCase):
         self.trial_executor.stop_trial(trial)
         self.assertEqual(Trial.TERMINATED, trial.status)

+    def testSaveRestoreTimeout(self):
+        trial = Trial("__fake")
+        self.trial_executor.start_trial(trial)
+        self.assertEqual(Trial.RUNNING, trial.status)
+        self.trial_executor.save(trial, Checkpoint.DISK)
+        self.trial_executor.set_status(trial, Trial.PAUSED)
+
+        ray_get = ray.get
+        start_trial = self.trial_executor._start_trial
+
+        # Timeout on first two attempts, then succeed on subsequent gets.
+        side_effects = [RayTimeoutError, RayTimeoutError, ray_get, ray_get]
+        with patch.object(self.trial_executor, "_start_trial") as mock_start:
+            with patch("ray.get", side_effect=side_effects):
+                mock_start.side_effect = start_trial
+                self.trial_executor.start_trial(trial, trial.checkpoint)
+
+        # Trial starts successfully on 3rd attempt.
+        assert mock_start.call_count == 3
+        self.assertEqual(Trial.RUNNING, trial.status)
+        self.trial_executor.stop_trial(trial)
+
     def testPauseResume(self):
         """Tests that pausing works for trials in flight."""
         trial = Trial("__fake")
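The new testSaveRestoreTimeout relies on unittest.mock's side_effect semantics: items in a side_effect sequence are raised when they are exceptions and returned otherwise, so the first patched ray.get calls raise RayTimeoutError while later calls return the remaining items. A small standalone illustration of that behavior, using only the standard library (flaky_get is a hypothetical stand-in, not part of the test):

from unittest.mock import MagicMock

# Exceptions in a side_effect sequence are raised; other items are returned.
flaky_get = MagicMock(side_effect=[TimeoutError, TimeoutError, "ok"])

for _ in range(2):
    try:
        flaky_get()  # the first two calls raise, like the two timed-out attempts
    except TimeoutError:
        pass

assert flaky_get() == "ok"  # the third call succeeds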