Mirror of https://github.com/vale981/ray, synced 2025-03-06 10:31:39 -05:00
[tune] Check node liveness before result fetch (#5844)
* Check if trial's node is alive before trying to fetch result
* Added function for failed trials to trial_executor interface
* Address comments, add test.
This commit is contained in:
parent 054583ffe6
commit 375852af23
4 changed files with 75 additions and 18 deletions
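The change routes trials whose node has died to failure handling before the runner blocks on fetching their next result: RayTrialExecutor gains a non-blocking get_next_failed_trial() that compares the IPs of running trials against the IPs of nodes still alive, the method is added to the TrialExecutor interface, TrialRunner._process_events() consults it before falling back to the blocking get_next_available_trial() path, and a cluster test covers the node-failure case.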
@@ -301,18 +301,22 @@ class RayTrialExecutor(TrialExecutor):
     def get_current_trial_ips(self):
         return {t.node_ip for t in self.get_running_trials()}

-    def get_next_available_trial(self):
+    def get_next_failed_trial(self):
+        """Gets the first trial found to be running on a node presumed dead.
+
+        Returns:
+            A Trial object that is ready for failure processing. None if
+            no failure detected.
+        """
         if ray.worker._mode() != ray.worker.LOCAL_MODE:
             live_cluster_ips = self.get_alive_node_ips()
             if live_cluster_ips - self.get_current_trial_ips():
                 for trial in self.get_running_trials():
                     if trial.node_ip and trial.node_ip not in live_cluster_ips:
-                        logger.warning(
-                            "{} (ip: {}) detected as stale. This is likely "
-                            "because the node was lost. Processing this "
-                            "trial first.".format(trial, trial.node_ip))
                         return trial
+        return None
+
+    def get_next_available_trial(self):
         shuffled_results = list(self._running.keys())
         random.shuffle(shuffled_results)
         # Note: We shuffle the results because `ray.wait` by default returns
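get_next_failed_trial() leans on self.get_alive_node_ips(), which this diff does not touch. As a rough sketch of where those IPs could come from, the hypothetical helper below reads the per-node metadata returned by ray.nodes(); the exact key names are an assumption about that schema, not taken from this commit.

import ray

def get_alive_node_ips():
    # Hypothetical stand-in for RayTrialExecutor.get_alive_node_ips().
    # Assumes ray.init() has already connected and that each ray.nodes()
    # entry exposes a liveness flag and the node manager address; the key
    # names (and their casing) vary across Ray versions.
    return {
        node["NodeManagerAddress"]
        for node in ray.nodes()
        if node.get("alive", node.get("Alive", False))
    }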
@@ -8,6 +8,7 @@ import time
 import os
 import pytest
 import shutil
+import sys

 import ray
 from ray import tune
@@ -20,6 +21,11 @@ from ray.tune.trial import Trial
 from ray.tune.trial_runner import TrialRunner
 from ray.tune.suggest import BasicVariantGenerator

+if sys.version_info >= (3, 3):
+    from unittest.mock import MagicMock
+else:
+    from mock import MagicMock
+

 def _start_new_cluster():
     cluster = Cluster(
@@ -98,6 +104,26 @@ def test_counting_resources(start_connected_cluster):
     assert sum(t.status == Trial.RUNNING for t in runner.get_trials()) == 2


+def test_trial_processed_after_node_failure(start_connected_emptyhead_cluster):
+    """Tests that Tune processes a trial as failed if its node died."""
+    cluster = start_connected_emptyhead_cluster
+    node = cluster.add_node(num_cpus=1)
+    cluster.wait_for_nodes()
+
+    runner = TrialRunner(BasicVariantGenerator())
+    mock_process_failure = MagicMock(side_effect=runner._process_trial_failure)
+    runner._process_trial_failure = mock_process_failure
+
+    runner.add_trial(Trial("__fake"))
+    runner.step()
+    runner.step()
+    assert not mock_process_failure.called
+
+    cluster.remove_node(node)
+    runner.step()
+    assert mock_process_failure.called
+
+
 def test_remove_node_before_result(start_connected_emptyhead_cluster):
     """Tune continues when node is removed before trial returns."""
     cluster = start_connected_emptyhead_cluster
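The test wraps runner._process_trial_failure in MagicMock(side_effect=...), which keeps the real failure handling running while recording whether it was called. A minimal standalone sketch of that spy pattern, with a made-up handler:

from unittest.mock import MagicMock

def handle_failure(trial, error_msg):
    # Hypothetical stand-in for the real failure handler.
    return "handled {}: {}".format(trial, error_msg)

# side_effect delegates every call to the wrapped function, so behaviour is
# unchanged, but the mock still records calls for later assertions.
spy = MagicMock(side_effect=handle_failure)
assert spy("trial_1", "node lost") == "handled trial_1: node lost"
assert spy.called and spy.call_count == 1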
@@ -158,6 +158,15 @@ class TrialExecutor(object):
         """
         raise NotImplementedError

+    def get_next_failed_trial(self):
+        """Non-blocking call that detects and returns one failed trial.
+
+        Returns:
+            A Trial object that is ready for failure processing. None if
+            no failure detected.
+        """
+        raise NotImplementedError
+
     def fetch_result(self, trial):
         """Fetches one result for the trial.

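Since get_next_failed_trial() is now part of the TrialExecutor interface, other executor implementations need to provide it. An executor with no notion of node liveness can satisfy the contract by always reporting that nothing failed; the sketch below is illustrative only (other required methods omitted), not part of the commit.

from ray.tune.trial_executor import TrialExecutor

class LocalOnlyExecutor(TrialExecutor):
    # Hypothetical executor without node-level failure detection; the rest
    # of the TrialExecutor methods would still need real implementations.
    def get_next_failed_trial(self):
        # Never reports a stale trial, so TrialRunner always falls through
        # to the blocking get_next_available_trial() path.
        return None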
@@ -497,9 +497,18 @@ class TrialRunner(object):
         return trial

     def _process_events(self):
-        trial = self.trial_executor.get_next_available_trial()  # blocking
-        with warn_if_slow("process_trial"):
-            self._process_trial(trial)
+        failed_trial = self.trial_executor.get_next_failed_trial()
+        if failed_trial:
+            with warn_if_slow("process_failed_trial"):
+                self._process_trial_failure(
+                    failed_trial,
+                    error_msg="{} (ip: {}) detected as stale. This is likely"
+                    "because the node was lost".format(failed_trial,
+                                                       failed_trial.node_ip))
+        else:
+            trial = self.trial_executor.get_next_available_trial()  # blocking
+            with warn_if_slow("process_trial"):
+                self._process_trial(trial)

     def _process_trial(self, trial):
         try:
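In effect, the event loop now performs a cheap, non-blocking liveness check before committing to the blocking wait for a result. A condensed, dependency-free sketch of that decision (the callback names here are made up; the executor methods mirror the diff above):

def process_events(executor, on_result, on_failure):
    # Non-blocking liveness check first: a trial on a dead node will never
    # deliver a result, so fail it instead of waiting in fetch_result.
    failed_trial = executor.get_next_failed_trial()
    if failed_trial is not None:
        on_failure(failed_trial,
                   error_msg="{} (ip: {}) detected as stale.".format(
                       failed_trial, failed_trial.node_ip))
    else:
        # Normal path: block until some running trial has a result ready.
        trial = executor.get_next_available_trial()
        on_result(trial)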
@@ -558,16 +567,25 @@ class TrialRunner(object):
                    decision)
         except Exception:
             logger.exception("Error processing event.")
-            error_msg = traceback.format_exc()
-            if trial.status == Trial.RUNNING:
-                if trial.should_recover():
-                    self._try_recover(trial, error_msg)
-                else:
-                    self._scheduler_alg.on_trial_error(self, trial)
-                    self._search_alg.on_trial_complete(
-                        trial.trial_id, error=True)
-                    self.trial_executor.stop_trial(
-                        trial, error=True, error_msg=error_msg)
+            self._process_trial_failure(trial, traceback.format_exc())
+
+    def _process_trial_failure(self, trial, error_msg):
+        """Handle trial failure.
+
+        Attempt trial recovery if possible, clean up state otherwise.
+
+        Args:
+            trial (Trial): Failed trial.
+            error_msg (str): Error message prior to invoking this method.
+        """
+        if trial.status == Trial.RUNNING:
+            if trial.should_recover():
+                self._try_recover(trial, error_msg)
+            else:
+                self._scheduler_alg.on_trial_error(self, trial)
+                self._search_alg.on_trial_complete(trial.trial_id, error=True)
+                self.trial_executor.stop_trial(
+                    trial, error=True, error_msg=error_msg)

     def _checkpoint_trial_if_needed(self, trial, force=False):
         """Checkpoints trial based off trial.last_result."""
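After this refactor, the exception handler in _process_trial and the stale-node branch in _process_events share a single entry point, _process_trial_failure, which either attempts recovery via _try_recover or notifies the scheduler and search algorithm and stops the trial.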