[tune] Check node liveness before result fetch (#5844)

* Check if trial's node is alive before trying to fetch result

* Added function for failed trials to trial_executor interface

* Address comments, add test.
Ujval Misra 2019-10-08 11:41:01 -07:00 committed by Richard Liaw
parent 054583ffe6
commit 375852af23
4 changed files with 75 additions and 18 deletions
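
The change boils down to one extra step at the top of the trial runner's event loop: ask the executor for a trial whose node looks dead before blocking on results. A condensed, illustrative sketch (the free function `process_one_event` is not part of Tune; `runner` and `executor` stand in for the TrialRunner and TrialExecutor instances, and the method names follow the diffs below):

def process_one_event(runner, executor):
    """Condensed version of TrialRunner._process_events after this change."""
    failed_trial = executor.get_next_failed_trial()  # non-blocking liveness check
    if failed_trial:
        # Node presumed dead: fail the trial instead of waiting on its result.
        runner._process_trial_failure(
            failed_trial, error_msg="{} node lost".format(failed_trial))
    else:
        trial = executor.get_next_available_trial()  # blocks on ray.wait
        runner._process_trial(trial)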


@@ -301,18 +301,22 @@ class RayTrialExecutor(TrialExecutor):
    def get_current_trial_ips(self):
        return {t.node_ip for t in self.get_running_trials()}

    def get_next_available_trial(self):
    def get_next_failed_trial(self):
        """Gets the first trial found to be running on a node presumed dead.

        Returns:
            A Trial object that is ready for failure processing. None if
            no failure detected.
        """
        if ray.worker._mode() != ray.worker.LOCAL_MODE:
            live_cluster_ips = self.get_alive_node_ips()
            if live_cluster_ips - self.get_current_trial_ips():
                for trial in self.get_running_trials():
                    if trial.node_ip and trial.node_ip not in live_cluster_ips:
                        logger.warning(
                            "{} (ip: {}) detected as stale. This is likely "
                            "because the node was lost. Processing this "
                            "trial first.".format(trial, trial.node_ip))
                        return trial
        return None

    def get_next_available_trial(self):
        shuffled_results = list(self._running.keys())
        random.shuffle(shuffled_results)
        # Note: We shuffle the results because `ray.wait` by default returns
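
`get_next_failed_trial` leans on a `get_alive_node_ips` helper that is not shown in this hunk. A minimal sketch of such a helper, assuming `ray.nodes()` entries carry a liveness flag and a `NodeManagerAddress` field (the exact key names have varied across Ray versions):

import ray

def get_alive_node_ips():
    """Collect the IPs of nodes the cluster currently reports as alive."""
    alive_ips = set()
    for node in ray.nodes():
        # The liveness key has been spelled both "alive" and "Alive"
        # depending on the Ray version; accept either.
        if node.get("alive", node.get("Alive", False)):
            alive_ips.add(node["NodeManagerAddress"])
    return alive_ips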


@@ -8,6 +8,7 @@ import time
import os
import pytest
import shutil
import sys

import ray
from ray import tune
@@ -20,6 +21,11 @@ from ray.tune.trial import Trial
from ray.tune.trial_runner import TrialRunner
from ray.tune.suggest import BasicVariantGenerator

if sys.version_info >= (3, 3):
    from unittest.mock import MagicMock
else:
    from mock import MagicMock


def _start_new_cluster():
    cluster = Cluster(
@@ -98,6 +104,26 @@ def test_counting_resources(start_connected_cluster):
    assert sum(t.status == Trial.RUNNING for t in runner.get_trials()) == 2


def test_trial_processed_after_node_failure(start_connected_emptyhead_cluster):
    """Tests that Tune processes a trial as failed if its node died."""
    cluster = start_connected_emptyhead_cluster
    node = cluster.add_node(num_cpus=1)
    cluster.wait_for_nodes()

    runner = TrialRunner(BasicVariantGenerator())
    mock_process_failure = MagicMock(side_effect=runner._process_trial_failure)
    runner._process_trial_failure = mock_process_failure

    runner.add_trial(Trial("__fake"))
    runner.step()
    runner.step()
    assert not mock_process_failure.called

    cluster.remove_node(node)
    runner.step()
    assert mock_process_failure.called


def test_remove_node_before_result(start_connected_emptyhead_cluster):
    """Tune continues when node is removed before trial returns."""
    cluster = start_connected_emptyhead_cluster
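
The new test wraps `_process_trial_failure` in a `MagicMock` with `side_effect` so the real failure handling still runs while calls are recorded. A standalone illustration of that spy pattern (the `Runner` class here is just a stand-in, not a Tune class):

from unittest.mock import MagicMock

class Runner:
    def handle(self, trial):
        return "processed {}".format(trial)

runner = Runner()
# side_effect delegates to the real method, so behavior is unchanged,
# while the MagicMock records whether and how it was called.
spy = MagicMock(side_effect=runner.handle)
runner.handle = spy

assert runner.handle("t1") == "processed t1"
assert spy.called
assert spy.call_args == (("t1",), {})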


@@ -158,6 +158,15 @@ class TrialExecutor(object):
        """
        raise NotImplementedError

    def get_next_failed_trial(self):
        """Non-blocking call that detects and returns one failed trial.

        Returns:
            A Trial object that is ready for failure processing. None if
            no failure detected.
        """
        raise NotImplementedError

    def fetch_result(self, trial):
        """Fetches one result for the trial.


@@ -497,9 +497,18 @@ class TrialRunner(object):
        return trial

    def _process_events(self):
        trial = self.trial_executor.get_next_available_trial()  # blocking
        with warn_if_slow("process_trial"):
            self._process_trial(trial)
        failed_trial = self.trial_executor.get_next_failed_trial()
        if failed_trial:
            with warn_if_slow("process_failed_trial"):
                self._process_trial_failure(
                    failed_trial,
                    error_msg="{} (ip: {}) detected as stale. This is likely "
                    "because the node was lost".format(failed_trial,
                                                       failed_trial.node_ip))
        else:
            trial = self.trial_executor.get_next_available_trial()  # blocking
            with warn_if_slow("process_trial"):
                self._process_trial(trial)

    def _process_trial(self, trial):
        try:
@@ -558,16 +567,25 @@ class TrialRunner(object):
                decision)
        except Exception:
            logger.exception("Error processing event.")
            error_msg = traceback.format_exc()
            if trial.status == Trial.RUNNING:
                if trial.should_recover():
                    self._try_recover(trial, error_msg)
                else:
                    self._scheduler_alg.on_trial_error(self, trial)
                    self._search_alg.on_trial_complete(
                        trial.trial_id, error=True)
                    self.trial_executor.stop_trial(
                        trial, error=True, error_msg=error_msg)
            self._process_trial_failure(trial, traceback.format_exc())

    def _process_trial_failure(self, trial, error_msg):
        """Handle trial failure.

        Attempt trial recovery if possible, clean up state otherwise.

        Args:
            trial (Trial): Failed trial.
            error_msg (str): Error message prior to invoking this method.
        """
        if trial.status == Trial.RUNNING:
            if trial.should_recover():
                self._try_recover(trial, error_msg)
            else:
                self._scheduler_alg.on_trial_error(self, trial)
                self._search_alg.on_trial_complete(trial.trial_id, error=True)
                self.trial_executor.stop_trial(
                    trial, error=True, error_msg=error_msg)

    def _checkpoint_trial_if_needed(self, trial, force=False):
        """Checkpoints trial based off trial.last_result."""