diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py
index 4383cc784..4ad318515 100644
--- a/python/ray/tune/ray_trial_executor.py
+++ b/python/ray/tune/ray_trial_executor.py
@@ -26,6 +26,7 @@ RESOURCE_REFRESH_PERIOD = 0.5  # Refresh resources every 500 ms
 BOTTLENECK_WARN_PERIOD_S = 60
 NONTRIVIAL_WAIT_TIME_THRESHOLD_S = 1e-3
 DEFAULT_GET_TIMEOUT = 30.0  # seconds
+TRIAL_START_ATTEMPTS = 3


 class _LocalWrapper(object):
@@ -80,8 +81,8 @@ class RayTrialExecutor(TrialExecutor):

         if (self._reuse_actors and reuse_allowed
                 and self._cached_actor is not None):
-            logger.debug("Reusing cached runner {} for {}".format(
-                self._cached_actor, trial.trial_id))
+            logger.debug("Trial %s: Reusing cached runner %s", trial,
+                         self._cached_actor)
             existing_runner = self._cached_actor
             self._cached_actor = None
             trial.runner = existing_runner
@@ -134,21 +135,25 @@ class RayTrialExecutor(TrialExecutor):

         self._running[remote] = trial

-    def _start_trial(self, trial, checkpoint=None):
+    def _start_trial(self, trial, checkpoint=None, runner=None):
         """Starts trial and restores last result if trial was paused.

-        Raises:
-            RuntimeError if restoring from checkpoint fails.
+        Args:
+            trial (Trial): The trial to start.
+            checkpoint (Optional[Checkpoint]): The checkpoint to restore from.
+                If None, and no trial checkpoint exists, the trial is started
+                from the beginning.
+            runner (Trainable): The remote runner to use. This can be the
+                cached actor. If None, a new runner is created.
+
+        See `RayTrialExecutor.restore` for possible errors raised.
         """
         prior_status = trial.status
         self.set_status(trial, Trial.RUNNING)
-        trial.runner = self._setup_remote_runner(
+        trial.runner = runner or self._setup_remote_runner(
             trial,
             reuse_allowed=checkpoint is not None or trial.has_checkpoint())
-        if not self.restore(trial, checkpoint):
-            if trial.status == Trial.ERROR:
-                raise RuntimeError(
-                    "Trial {}: Restore from checkpoint failed.".format(trial))
+        self.restore(trial, checkpoint)

         previous_run = self._find_item(self._paused, trial)
         if prior_status == Trial.PAUSED and previous_run:
@@ -206,34 +211,49 @@ class RayTrialExecutor(TrialExecutor):
                 of trial.
         """
         self._commit_resources(trial.resources)
-        try:
-            self._start_trial(trial, checkpoint)
-        except AbortTrialExecution:
-            logger.exception("Trial %s: Error starting runner, aborting!",
-                             trial)
-            time.sleep(2)
-            error_msg = traceback.format_exc()
-            self._stop_trial(trial, error=True, error_msg=error_msg)
-            return  # don't retry fatal Tune errors
-        except Exception:
-            logger.exception(
-                "Trial %s: Error starting runner. Attempting "
-                "restart without checkpoint.", trial)
-            time.sleep(2)
-            error_msg = traceback.format_exc()
-            self._stop_trial(trial, error=True, error_msg=error_msg)
+        remote_runner = None
+        attempts = 0
+        while attempts < TRIAL_START_ATTEMPTS:
+            attempts += 1
+            if attempts > 1:
+                logger.warning("Trial %s: Start attempt #%s...", trial,
+                               attempts)
             try:
-                # This forces the trial to not start from checkpoint.
-                trial.clear_checkpoint()
-                self._start_trial(trial)
-            except Exception:
-                logger.exception(
-                    "Trial %s: Error starting runner on second "
-                    "attempt, aborting!", trial)
+                self._start_trial(trial, checkpoint, remote_runner)
+                break
+            except AbortTrialExecution:
+                logger.exception("Trial %s: Error starting runner, aborting!",
+                                 trial)
+                time.sleep(2)
                 error_msg = traceback.format_exc()
                 self._stop_trial(trial, error=True, error_msg=error_msg)
+                break  # don't retry fatal Tune errors
+            except RayTimeoutError:
+                # Reuse the existing runner on retries.
+                remote_runner = trial.runner
+                warning = ("Runner task timed out. This could be due to "
+                           "slow worker startup.")
+                if attempts == TRIAL_START_ATTEMPTS:
+                    error_msg = traceback.format_exc()
+                    self._stop_trial(trial, error=True, error_msg=error_msg)
+                else:
+                    warning += " Reusing the same runner."
+                logger.warning("Trial %s: %s", trial, warning)
+            except Exception:
+                logger.exception("Trial %s: Error starting runner.", trial)
+                time.sleep(2)
+                error_msg = traceback.format_exc()
+                self._stop_trial(trial, error=True, error_msg=error_msg)
+                remote_runner = None
+            # This forces the trial to not start from checkpoint.
+            checkpoint = None
+            trial.clear_checkpoint()
         # Note that we don't return the resources, since they may
         # have been lost. TODO(ujvl): is this the right thing to do?
+        else:
+            logger.exception(
+                "Trial %s: Aborting trial after %s start "
+                "attempts!", trial, TRIAL_START_ATTEMPTS)

     def _find_item(self, dictionary, item):
         out = [rid for rid, t in dictionary.items() if t is item]
@@ -562,48 +582,34 @@ class RayTrialExecutor(TrialExecutor):

         This will also sync the trial results to a new location
         if restoring on a different node.
+
+        Raises:
+            RuntimeError: This error is raised if no runner is found.
+            RayTimeoutError: This error is raised if a remote call to the
+                runner times out.
         """
         if checkpoint is None or checkpoint.value is None:
             checkpoint = trial.checkpoint
-        if checkpoint is None or checkpoint.value is None:
-            return True
+        if checkpoint.value is None:
+            return
         if trial.runner is None:
-            logger.error(
-                "Trial %s: Unable to restore - no runner. "
-                "Setting status to ERROR.", trial)
-            self.set_status(trial, Trial.ERROR)
-            return False
-        try:
-            value = checkpoint.value
-            if checkpoint.storage == Checkpoint.MEMORY:
-                assert type(value) != Checkpoint, type(value)
-                trial.runner.restore_from_object.remote(value)
-            else:
-                logger.info("Trial %s: Attempting restoration from %s", trial,
-                            checkpoint.value)
-                with warn_if_slow("get_current_ip"):
-                    worker_ip = ray.get(trial.runner.current_ip.remote(),
-                                        DEFAULT_GET_TIMEOUT)
-                with warn_if_slow("sync_to_new_location"):
-                    trial.sync_logger_to_new_location(worker_ip)
-                with warn_if_slow("restore_from_disk"):
-                    ray.get(
-                        trial.runner.restore.remote(value),
-                        DEFAULT_GET_TIMEOUT)
-        except RayTimeoutError:
-            logger.exception(
-                "Trial %s: Unable to restore - runner task timed "
-                "out. Setting status to ERROR", trial)
-            self.set_status(trial, Trial.ERROR)
-            return False
-        except Exception:
-            logger.exception(
-                "Trial %s: Unable to restore. "
-                "Setting status to ERROR", trial)
-            self.set_status(trial, Trial.ERROR)
-            return False
-
+            raise RuntimeError(
+                "Trial {}: Unable to restore - no runner found.".format(trial))
+        value = checkpoint.value
+        if checkpoint.storage == Checkpoint.MEMORY:
+            assert not isinstance(value, Checkpoint), type(value)
+            trial.runner.restore_from_object.remote(value)
+        else:
+            logger.info("Trial %s: Attempting restore from %s", trial, value)
+            with warn_if_slow("get_current_ip"):
+                worker_ip = ray.get(trial.runner.current_ip.remote(),
+                                    DEFAULT_GET_TIMEOUT)
+            with warn_if_slow("sync_to_new_location"):
+                trial.sync_logger_to_new_location(worker_ip)
+            with warn_if_slow("restore_from_disk"):
+                # TODO(ujvl): Take blocking restores out of the control loop.
+                ray.get(trial.runner.restore.remote(value))
         trial.last_result = checkpoint.result
-        return True

     def export_trial_if_needed(self, trial):
         """Exports model of this trial based on trial.export_formats.
diff --git a/python/ray/tune/tests/test_ray_trial_executor.py b/python/ray/tune/tests/test_ray_trial_executor.py
index 1c91e813f..fa12d7ddf 100644
--- a/python/ray/tune/tests/test_ray_trial_executor.py
+++ b/python/ray/tune/tests/test_ray_trial_executor.py
@@ -4,9 +4,11 @@ from __future__ import division
 from __future__ import print_function

 import json
+import sys
 import unittest

 import ray
+from ray.exceptions import RayTimeoutError
 from ray.rllib import _register_all
 from ray.tune import Trainable
 from ray.tune.ray_trial_executor import RayTrialExecutor
@@ -16,6 +18,11 @@ from ray.tune.trial import Trial, Checkpoint
 from ray.tune.resources import Resources
 from ray.cluster_utils import Cluster

+if sys.version_info >= (3, 3):
+    from unittest.mock import patch
+else:
+    from mock import patch
+

 class RayTrialExecutorTest(unittest.TestCase):
     def setUp(self):
@@ -43,6 +50,28 @@ class RayTrialExecutorTest(unittest.TestCase):
         self.trial_executor.stop_trial(trial)
         self.assertEqual(Trial.TERMINATED, trial.status)

+    def testSaveRestoreTimeout(self):
+        trial = Trial("__fake")
+        self.trial_executor.start_trial(trial)
+        self.assertEqual(Trial.RUNNING, trial.status)
+        self.trial_executor.save(trial, Checkpoint.DISK)
+        self.trial_executor.set_status(trial, Trial.PAUSED)
+
+        ray_get = ray.get
+        start_trial = self.trial_executor._start_trial
+
+        # Timeout on first two attempts, then succeed on subsequent gets.
+        side_effects = [RayTimeoutError, RayTimeoutError, ray_get, ray_get]
+        with patch.object(self.trial_executor, "_start_trial") as mock_start:
+            with patch("ray.get", side_effect=side_effects):
+                mock_start.side_effect = start_trial
+                self.trial_executor.start_trial(trial, trial.checkpoint)
+
+        # Trial starts successfully on 3rd attempt.
+        assert mock_start.call_count == 3
+        self.assertEqual(Trial.RUNNING, trial.status)
+        self.trial_executor.stop_trial(trial)
+
     def testPauseResume(self):
         """Tests that pausing works for trials in flight."""
         trial = Trial("__fake")
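
The reworked start_trial above leans on Python's `while ... else`: the `else` branch runs only when the loop exits without hitting `break`, which here means every one of the TRIAL_START_ATTEMPTS attempts failed. Below is a minimal, self-contained sketch of that control flow in plain Python; `flaky_start` is a hypothetical stand-in for `_start_trial`, not a Tune API.

    TRIAL_START_ATTEMPTS = 3


    def flaky_start(attempt, fail_first_n=2):
        """Hypothetical stand-in for _start_trial: early attempts time out."""
        if attempt <= fail_first_n:
            raise TimeoutError("worker was slow to start")


    attempts = 0
    while attempts < TRIAL_START_ATTEMPTS:
        attempts += 1
        try:
            flaky_start(attempts)
            print("started on attempt", attempts)
            break  # success skips the else clause below
        except TimeoutError:
            print("attempt", attempts, "timed out; retrying")
    else:
        # Runs only if no break fired, i.e. every attempt failed.
        print("aborting after", TRIAL_START_ATTEMPTS, "attempts")

With the default fail_first_n=2 this prints two timeout lines and then "started on attempt 3"; with fail_first_n=3 it falls through to the else branch instead, mirroring the abort path in the patch.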
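testSaveRestoreTimeout drives that retry path by patching ray.get with a list-valued side_effect. Under unittest.mock, each call consumes the next item from the list: items that are exception classes or instances are raised, and any other item becomes the call's return value. Here is a small, Ray-independent sketch of that mechanism; the name fake_get and the "restored" value are illustrative only.

    from unittest.mock import Mock

    # First two calls raise, the third returns a plain value.
    fake_get = Mock(side_effect=[TimeoutError, TimeoutError, "restored"])

    for attempt in (1, 2, 3):
        try:
            print("attempt", attempt, "->", fake_get("checkpoint"))
        except TimeoutError:
            print("attempt", attempt, "-> timed out")
    # attempt 1 -> timed out
    # attempt 2 -> timed out
    # attempt 3 -> restored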