[tune] Component notification on node failure + Tests (#3414)

Changes include:
 - Notify Components on Requeue
 - Slight refactoring of Node Failure handling
 - Better tests
Richard Liaw 2018-12-04 14:47:31 -08:00 committed by GitHub
parent ce355d13d4
commit 9d0bd50e78
8 changed files with 180 additions and 114 deletions
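For orientation before the per-file hunks: the snippet below is a toy, self-contained model of the notification flow this commit adds. The class and method names mirror the diff below; everything else is stubbed and hypothetical, not Tune's actual classes. On an unrecoverable error both the scheduler and the search algorithm are notified; on a requeue only the scheduler is.

# Toy model of the new failure-handling flow (illustration only; the real
# logic lives in TrialRunner._try_recover / _requeue_trial shown below).
PENDING, RUNNING, ERROR = "PENDING", "RUNNING", "ERROR"

class Trial(object):
    checkpoint_freq, max_failures, num_failures, status = 1, 3, 0, RUNNING

    def should_recover(self):
        # New gate: checkpointing enabled AND failure budget not exhausted.
        return (self.checkpoint_freq > 0
                and self.num_failures < self.max_failures)

class MockScheduler(object):
    def __init__(self):
        self.errored = []
    def on_trial_error(self, runner, trial):
        self.errored.append(trial)
    def on_trial_add(self, runner, trial):
        pass

class MockSearchAlg(object):
    def __init__(self):
        self.errored = []
    def on_trial_complete(self, trial_id, error=False):
        if error:
            self.errored.append(trial_id)

def handle_failure(trial, scheduler, search_alg, has_resources):
    """Condensed version of the runner's new error path."""
    trial.num_failures += 1
    if not trial.should_recover():
        # Unrecoverable: notify both components.
        scheduler.on_trial_error(None, trial)
        search_alg.on_trial_complete(id(trial), error=True)
        trial.status = ERROR
    elif has_resources:
        trial.status = RUNNING  # restart from the last checkpoint
    else:
        # Requeue: only the scheduler is notified; the search algorithm still
        # considers the function evaluation to be in progress.
        scheduler.on_trial_error(None, trial)
        trial.status = PENDING
        scheduler.on_trial_add(None, trial)

trial, sched, alg = Trial(), MockScheduler(), MockSearchAlg()
handle_failure(trial, sched, alg, has_resources=False)  # mimic a lost node
assert trial.status == PENDING and len(sched.errored) == 1
assert alg.errored == []  # the search algorithm is not told about a requeue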

View file

@ -25,7 +25,7 @@ if [[ "$PYTHON" == "2.7" ]] && [[ "$platform" == "linux" ]]; then
bash miniconda.sh -b -p $HOME/miniconda
export PATH="$HOME/miniconda/bin:$PATH"
pip install -q cython==0.27.3 cmake tensorflow gym opencv-python pyyaml pandas==0.23.4 requests \
feather-format lxml openpyxl xlrd py-spy setproctitle faulthandler pytest-timeout
feather-format lxml openpyxl xlrd py-spy setproctitle faulthandler pytest-timeout mock
elif [[ "$PYTHON" == "3.5" ]] && [[ "$platform" == "linux" ]]; then
sudo apt-get update
sudo apt-get install -y cmake pkg-config python-dev python-numpy build-essential autoconf curl libtool unzip
@ -51,7 +51,7 @@ elif [[ "$PYTHON" == "2.7" ]] && [[ "$platform" == "macosx" ]]; then
bash miniconda.sh -b -p $HOME/miniconda
export PATH="$HOME/miniconda/bin:$PATH"
pip install -q cython==0.27.3 cmake tensorflow gym opencv-python pyyaml pandas==0.23.4 requests \
feather-format lxml openpyxl xlrd py-spy setproctitle faulthandler pytest-timeout
feather-format lxml openpyxl xlrd py-spy setproctitle faulthandler pytest-timeout mock
elif [[ "$PYTHON" == "3.5" ]] && [[ "$platform" == "macosx" ]]; then
# check that brew is installed
which -s brew

View file

@ -110,19 +110,27 @@ class RayTrialExecutor(TrialExecutor):
if stop_logger:
trial.close_logger()
def start_trial(self, trial, checkpoint_obj=None):
"""Starts the trial."""
def start_trial(self, trial, checkpoint=None):
"""Starts the trial.
Will not return resources if the trial repeatedly fails on start.
Args:
trial (Trial): Trial to be started.
checkpoint (Checkpoint): A Python object or path storing the state
of the trial.
"""
self._commit_resources(trial.resources)
try:
self._start_trial(trial, checkpoint_obj)
self._start_trial(trial, checkpoint)
except Exception:
logger.exception("Error starting runner - retrying...")
error_msg = traceback.format_exc()
time.sleep(2)
self._stop_trial(trial, error=True, error_msg=error_msg)
try:
self._start_trial(trial)
self._start_trial(trial, checkpoint)
except Exception:
logger.exception("Error starting runner, aborting!")
error_msg = traceback.format_exc()
@ -140,6 +148,7 @@ class RayTrialExecutor(TrialExecutor):
self._stop_trial(
trial, error=error, error_msg=error_msg, stop_logger=stop_logger)
if prior_status == Trial.RUNNING:
logger.debug("Returning resources for this trial.")
self._return_resources(trial.resources)
out = self._find_item(self._running, trial)
for result_id in out:
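The new start_trial docstring above describes a try, clean up, retry once, then give up policy. Below is a generic, hypothetical helper (not part of Tune) that captures the same control flow, followed by a tiny self-check; the resource bookkeeping done by _commit_resources is deliberately omitted.

import logging
import time
import traceback

logger = logging.getLogger(__name__)

def start_with_one_retry(start, stop, retry_delay=2):
    # Try to start; on failure, tear down, wait, and retry exactly once.
    try:
        start()
    except Exception:
        logger.exception("Error starting runner - retrying...")
        stop(error=True, error_msg=traceback.format_exc())
        time.sleep(retry_delay)
        try:
            start()
        except Exception:
            logger.exception("Error starting runner, aborting!")
            stop(error=True, error_msg=traceback.format_exc())

# Self-check: a start that fails once succeeds on the retry.
attempts = []
def flaky_start():
    attempts.append(1)
    if len(attempts) == 1:
        raise RuntimeError("transient start failure")

start_with_one_retry(flaky_start, lambda **kwargs: None, retry_delay=0)
assert len(attempts) == 2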

View file

@ -3,45 +3,22 @@ from __future__ import division
from __future__ import print_function
import json
import time
import pytest
try:
import pytest_timeout
except ImportError:
pytest_timeout = None
from ray.test.cluster_utils import Cluster
import ray
from ray import tune
from ray.rllib import _register_all
from ray.test.cluster_utils import Cluster
from ray.tune.error import TuneError
from ray.tune.trial import Trial
from ray.tune.trial_runner import TrialRunner
from ray.tune.suggest import BasicVariantGenerator
def register_test_trainable():
class _Train(tune.Trainable):
def _setup(self, config):
self.state = {"hi": 1}
def _train(self):
self.state["hi"] += 1
time.sleep(0.5)
return {}
def _save(self, path):
return self.state
def _restore(self, state):
self.state = state
tune.register_trainable("test", _Train)
@pytest.fixture
def start_connected_cluster():
# Start the Ray processes.
def _start_new_cluster():
cluster = Cluster(
initialize_head=True,
connect=True,
@ -51,7 +28,15 @@ def start_connected_cluster():
"num_heartbeats_timeout": 10
})
})
register_test_trainable()
# Pytest doesn't play nicely with imports
_register_all()
return cluster
@pytest.fixture
def start_connected_cluster():
# Start the Ray processes.
cluster = _start_new_cluster()
yield cluster
# The code after the yield will run as teardown code.
ray.shutdown()
@ -71,39 +56,36 @@ def start_connected_emptyhead_cluster():
"num_heartbeats_timeout": 10
})
})
register_test_trainable()
# Pytest doesn't play nicely with imports
_register_all()
yield cluster
# The code after the yield will run as teardown code.
ray.shutdown()
cluster.shutdown()
@pytest.mark.skipif(
pytest_timeout is None,
reason="Timeout package not installed; skipping test.")
@pytest.mark.timeout(10, method="thread")
def test_counting_resources(start_connected_cluster):
"""Tests that Tune accounting is consistent with actual cluster."""
cluster = start_connected_cluster
assert ray.global_state.cluster_resources()["CPU"] == 1
nodes = []
nodes += [cluster.add_node(resources=dict(CPU=1))]
assert cluster.wait_for_nodes()
assert ray.global_state.cluster_resources()["CPU"] == 2
assert ray.global_state.cluster_resources()["CPU"] == 1
runner = TrialRunner(BasicVariantGenerator())
kwargs = {"stopping_criterion": {"training_iteration": 10}}
trials = [Trial("test", **kwargs), Trial("test", **kwargs)]
trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)]
for t in trials:
runner.add_trial(t)
runner.step() # run 1
nodes += [cluster.add_node(resources=dict(CPU=1))]
assert cluster.wait_for_nodes()
assert ray.global_state.cluster_resources()["CPU"] == 2
cluster.remove_node(nodes.pop())
assert cluster.wait_for_nodes()
assert ray.global_state.cluster_resources()["CPU"] == 1
runner.step() # run 2
assert sum(t.status == Trial.RUNNING for t in runner.get_trials()) == 1
for i in range(5):
nodes += [cluster.add_node(resources=dict(CPU=1))]
@ -111,12 +93,7 @@ def test_counting_resources(start_connected_cluster):
assert ray.global_state.cluster_resources()["CPU"] == 6
runner.step() # 1 result
for i in range(5):
node = nodes.pop()
cluster.remove_node(node)
assert cluster.wait_for_nodes()
assert ray.global_state.cluster_resources()["CPU"] == 1
assert sum(t.status == Trial.RUNNING for t in runner.get_trials()) == 2
@pytest.mark.skip("Add this test once reconstruction is fixed")
@ -133,7 +110,7 @@ def test_remove_node_before_result(start_connected_cluster):
runner = TrialRunner(BasicVariantGenerator())
kwargs = {"stopping_criterion": {"training_iteration": 3}}
trials = [Trial("test", **kwargs), Trial("test", **kwargs)]
trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)]
for t in trials:
runner.add_trial(t)
@ -179,7 +156,7 @@ def test_trial_migration(start_connected_emptyhead_cluster):
}
# Test recovery of trial that hasn't been checkpointed
t = Trial("test", **kwargs)
t = Trial("__fake", **kwargs)
runner.add_trial(t)
runner.step() # start
runner.step() # 1 result
@ -199,7 +176,7 @@ def test_trial_migration(start_connected_emptyhead_cluster):
assert t.status == Trial.TERMINATED
# Test recovery of trial that has been checkpointed
t2 = Trial("test", **kwargs)
t2 = Trial("__fake", **kwargs)
runner.add_trial(t2)
runner.step() # start
runner.step() # 1 result
@ -216,7 +193,7 @@ def test_trial_migration(start_connected_emptyhead_cluster):
assert t2.status == Trial.TERMINATED
# Test recovery of trial that won't be checkpointed
t3 = Trial("test", **{"stopping_criterion": {"training_iteration": 3}})
t3 = Trial("__fake", **{"stopping_criterion": {"training_iteration": 3}})
runner.add_trial(t3)
runner.step() # start
runner.step() # 1 result
@ -238,6 +215,7 @@ def test_trial_requeue(start_connected_emptyhead_cluster):
"""Removing a node in full cluster causes Trial to be requeued."""
cluster = start_connected_emptyhead_cluster
node = cluster.add_node(resources=dict(CPU=1))
assert cluster.wait_for_nodes()
runner = TrialRunner(BasicVariantGenerator())
kwargs = {
@ -248,7 +226,7 @@ def test_trial_requeue(start_connected_emptyhead_cluster):
"max_failures": 1
}
trials = [Trial("test", **kwargs), Trial("test", **kwargs)]
trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)]
for t in trials:
runner.add_trial(t)

View file

@ -9,8 +9,9 @@ import ray
from ray.rllib import _register_all
from ray.tune import Trainable
from ray.tune.ray_trial_executor import RayTrialExecutor
from ray.tune.registry import _global_registry, TRAINABLE_CLASS
from ray.tune.suggest import BasicVariantGenerator
from ray.tune.trial import Trial, Checkpoint
from ray.tune.trial import Trial, Checkpoint, Resources
class RayTrialExecutorTest(unittest.TestCase):
@ -50,6 +51,12 @@ class RayTrialExecutorTest(unittest.TestCase):
self.trial_executor.stop_trial(trial)
self.assertEqual(Trial.TERMINATED, trial.status)
def testStartFailure(self):
_global_registry.register(TRAINABLE_CLASS, "asdf", None)
trial = Trial("asdf", resources=Resources(1, 0))
self.trial_executor.start_trial(trial)
self.assertEqual(Trial.ERROR, trial.status)
def testPauseResume2(self):
"""Tests that pausing works for trials being processed."""
trial = Trial("__fake")

View file

@ -3,6 +3,7 @@ from __future__ import division
from __future__ import print_function
import os
import sys
import time
import unittest
@ -25,6 +26,11 @@ from ray.tune.suggest.suggestion import (_MockSuggestionAlgorithm,
SuggestionAlgorithm)
from ray.tune.suggest.variant_generator import RecursiveDependencyError
if sys.version_info >= (3, 3):
from unittest.mock import patch
else:
from mock import patch
class TrainableFunctionApiTest(unittest.TestCase):
def setUp(self):
@ -845,6 +851,25 @@ class VariantGeneratorTest(unittest.TestCase):
self.assertEqual(len(searcher.next_trials()), 0)
def create_mock_components():
class _MockScheduler(FIFOScheduler):
errored_trials = []
def on_trial_error(self, trial_runner, trial):
self.errored_trials += [trial]
class _MockSearchAlg(BasicVariantGenerator):
errored_trials = []
def on_trial_complete(self, trial_id, error=False, **kwargs):
if error:
self.errored_trials += [trial_id]
searchalg = _MockSearchAlg()
scheduler = _MockScheduler()
return searchalg, scheduler
class TrialRunnerTest(unittest.TestCase):
def tearDown(self):
ray.shutdown()
@ -889,16 +914,6 @@ class TrialRunnerTest(unittest.TestCase):
self.assertLessEqual(len(trial.logdir), 200)
trial_executor.stop_trial(trial)
def testTrialErrorOnStart(self):
ray.init()
trial_executor = RayTrialExecutor()
_global_registry.register(TRAINABLE_CLASS, "asdf", None)
trial = Trial("asdf", resources=Resources(1, 0))
try:
trial_executor.start_trial(trial)
except Exception as e:
self.assertIn("a class", str(e))
def testExtraResources(self):
ray.init(num_cpus=4, num_gpus=2)
runner = TrialRunner(BasicVariantGenerator())
@ -1055,7 +1070,9 @@ class TrialRunnerTest(unittest.TestCase):
def testFailureRecoveryDisabled(self):
ray.init(num_cpus=1, num_gpus=1)
runner = TrialRunner(BasicVariantGenerator())
searchalg, scheduler = create_mock_components()
runner = TrialRunner(searchalg, scheduler=scheduler)
kwargs = {
"resources": Resources(cpu=1, gpu=1),
"checkpoint_freq": 1,
@ -1074,10 +1091,15 @@ class TrialRunnerTest(unittest.TestCase):
runner.step()
self.assertEqual(trials[0].status, Trial.ERROR)
self.assertEqual(trials[0].num_failures, 1)
self.assertEqual(len(searchalg.errored_trials), 1)
self.assertEqual(len(scheduler.errored_trials), 1)
def testFailureRecoveryEnabled(self):
ray.init(num_cpus=1, num_gpus=1)
runner = TrialRunner(BasicVariantGenerator())
searchalg, scheduler = create_mock_components()
runner = TrialRunner(searchalg, scheduler=scheduler)
kwargs = {
"resources": Resources(cpu=1, gpu=1),
"checkpoint_freq": 1,
@ -1098,6 +1120,40 @@ class TrialRunnerTest(unittest.TestCase):
self.assertEqual(trials[0].num_failures, 1)
runner.step()
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(len(searchalg.errored_trials), 0)
self.assertEqual(len(scheduler.errored_trials), 0)
def testFailureRecoveryNodeRemoval(self):
ray.init(num_cpus=1, num_gpus=1)
searchalg, scheduler = create_mock_components()
runner = TrialRunner(searchalg, scheduler=scheduler)
kwargs = {
"resources": Resources(cpu=1, gpu=1),
"checkpoint_freq": 1,
"max_failures": 1,
"config": {
"mock_error": True,
},
}
runner.add_trial(Trial("__fake", **kwargs))
trials = runner.get_trials()
with patch('ray.global_state.cluster_resources') as resource_mock:
resource_mock.return_value = {"CPU": 1, "GPU": 1}
runner.step()
self.assertEqual(trials[0].status, Trial.RUNNING)
runner.step()
self.assertEqual(trials[0].status, Trial.RUNNING)
# Mimic a node failure
resource_mock.return_value = {"CPU": 0, "GPU": 0}
runner.step()
self.assertEqual(trials[0].status, Trial.PENDING)
self.assertEqual(trials[0].num_failures, 1)
self.assertEqual(len(searchalg.errored_trials), 0)
self.assertEqual(len(scheduler.errored_trials), 1)
def testFailureRecoveryMaxFailures(self):
ray.init(num_cpus=1, num_gpus=1)
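testFailureRecoveryNodeRemoval above simulates a lost node by patching ray.global_state.cluster_resources rather than tearing down a real cluster, which is why the mock package is added to the CI scripts earlier in this diff. A trimmed-down illustration of that patching pattern; the cluster_has_cpus helper is hypothetical and only assumes ray is importable.

try:
    from unittest.mock import patch  # Python 3
except ImportError:
    from mock import patch  # Python 2 backport

import ray

def cluster_has_cpus():
    # Hypothetical helper: does the cluster currently report any CPUs?
    return ray.global_state.cluster_resources().get("CPU", 0) > 0

with patch("ray.global_state.cluster_resources") as resource_mock:
    resource_mock.return_value = {"CPU": 1, "GPU": 1}
    assert cluster_has_cpus()
    resource_mock.return_value = {"CPU": 0, "GPU": 0}  # mimic a node failure
    assert not cluster_has_cpus()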

View file

@ -216,17 +216,19 @@ class Trial(object):
return False
def should_checkpoint(self, result):
def should_checkpoint(self):
"""Whether this trial is due for checkpointing."""
result = self.last_result or {}
if result.get(DONE) and self.checkpoint_at_end:
return True
if not self.checkpoint_freq:
if self.checkpoint_freq:
return result.get(TRAINING_ITERATION,
0) % self.checkpoint_freq == 0
else:
return False
return self.last_result[TRAINING_ITERATION] % self.checkpoint_freq == 0
def progress_string(self):
"""Returns a progress message for printing out to the console."""
@ -281,10 +283,12 @@ class Trial(object):
def should_recover(self):
"""Returns whether the trial qualifies for restoring.
This is if a checkpoint frequency is set, which includes settings
where there may not yet be a checkpoint.
This is if a checkpoint frequency is set and the trial has failed fewer
than max_failures times. This may return True even when there is not
yet a checkpoint.
"""
return self.checkpoint_freq > 0
return (self.checkpoint_freq > 0
and self.num_failures < self.max_failures)
def update_last_result(self, result, terminate=False):
if terminate:
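The two predicates changed above are small but drive the runner's recovery and checkpointing decisions. Here is a standalone restatement of the new logic, a simplified reimplementation for illustration only, assuming the DONE and TRAINING_ITERATION result keys are the strings "done" and "training_iteration".

def should_checkpoint(last_result, checkpoint_freq, checkpoint_at_end):
    # Mirrors Trial.should_checkpoint above, now driven by last_result.
    result = last_result or {}
    if result.get("done") and checkpoint_at_end:
        return True
    if checkpoint_freq:
        return result.get("training_iteration", 0) % checkpoint_freq == 0
    return False

def should_recover(checkpoint_freq, num_failures, max_failures):
    # Recovery now also requires that the failure budget is not exhausted.
    return checkpoint_freq > 0 and num_failures < max_failures

assert should_recover(checkpoint_freq=1, num_failures=0, max_failures=1)
assert not should_recover(checkpoint_freq=1, num_failures=1, max_failures=1)
assert should_checkpoint({"training_iteration": 4}, 2, False)
assert not should_checkpoint({"training_iteration": 3}, 2, False)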

View file

@ -32,12 +32,10 @@ class TrialExecutor(object):
"has_resources() method")
def start_trial(self, trial, checkpoint=None):
"""Starts the trial restoring from checkpoint if checkpoint != None.
If an error is encountered when starting the trial, an exception will
be thrown.
"""Starts the trial, restoring from a checkpoint if one is provided.
Args:
trial (Trial): Trial to be started.
checkpoint (Checkpoint): A Python object or path storing the state
of the trial.
"""
@ -59,26 +57,6 @@ class TrialExecutor(object):
raise NotImplementedError("Subclasses of TrialExecutor must provide "
"stop_trial() method")
def restart_trial(self, trial, error_msg=None):
"""Restarts or requeues the trial.
The state of the trial should be restored from the last checkpoint. The trial
is requeued if the cluster no longer has resources to accommodate it.
Args:
error_msg (str): Optional error message.
"""
self.stop_trial(
trial,
error=error_msg is not None,
error_msg=error_msg,
stop_logger=False)
trial.result_logger.flush()
if self.has_resources(trial.resources):
self.start_trial(trial)
else:
trial.status = Trial.PENDING
def continue_training(self, trial):
"""Continues the training of this trial."""
pass

View file

@ -12,7 +12,7 @@ import traceback
from ray.tune import TuneError
from ray.tune.ray_trial_executor import RayTrialExecutor
from ray.tune.result import TIME_THIS_ITER_S
from ray.tune.trial import Trial
from ray.tune.trial import Trial, Checkpoint
from ray.tune.schedulers import FIFOScheduler, TrialScheduler
from ray.tune.web_server import TuneServer
@ -279,17 +279,14 @@ class TrialRunner(object):
result, terminate=(decision == TrialScheduler.STOP))
if decision == TrialScheduler.CONTINUE:
if trial.should_checkpoint(result):
# TODO(rliaw): This is a blocking call
self.trial_executor.save(trial)
self._checkpoint_if_needed(trial)
self.trial_executor.continue_training(trial)
elif decision == TrialScheduler.PAUSE:
self.trial_executor.pause_trial(trial)
elif decision == TrialScheduler.STOP:
# Checkpoint before ending the trial
# if checkpoint_at_end experiment option is set to True
if trial.should_checkpoint(result):
self.trial_executor.save(trial)
self._checkpoint_if_needed(trial)
self.trial_executor.stop_trial(trial)
else:
assert False, "Invalid scheduling decision: {}".format(
@ -298,24 +295,61 @@ class TrialRunner(object):
logger.exception("Error processing event.")
error_msg = traceback.format_exc()
if trial.status == Trial.RUNNING:
if trial.should_recover() and \
trial.num_failures < trial.max_failures:
if trial.should_recover():
self._try_recover(trial, error_msg)
else:
self._scheduler_alg.on_trial_error(self, trial)
self._search_alg.on_trial_complete(
trial.trial_id, error=True)
self.trial_executor.stop_trial(trial, True, error_msg)
self.trial_executor.stop_trial(
trial, error=True, error_msg=error_msg)
def _checkpoint_if_needed(self, trial):
"""Checkpoints the trial based on trial.last_result."""
if trial.should_checkpoint():
# Save trial runtime if possible
if hasattr(trial, "runner") and trial.runner:
self.trial_executor.save(trial, storage=Checkpoint.DISK)
def _try_recover(self, trial, error_msg):
"""Tries to recover the trial.
Notifies the SearchAlgorithm and Scheduler if recovery fails.
Args:
trial (Trial): Trial to recover.
error_msg (str): Error message captured before this method was invoked.
"""
try:
logger.info("Attempting to recover"
" trial state from last checkpoint.")
self.trial_executor.restart_trial(trial, error_msg)
self.trial_executor.stop_trial(
trial,
error=error_msg is not None,
error_msg=error_msg,
stop_logger=False)
trial.result_logger.flush()
if self.trial_executor.has_resources(trial.resources):
logger.info("Attempting to recover"
" trial state from last checkpoint.")
self.trial_executor.start_trial(trial)
if trial.status == Trial.ERROR:
raise RuntimeError("Trial did not start correctly.")
else:
logger.debug("Notifying Scheduler and requeueing trial.")
self._requeue_trial(trial)
except Exception:
error_msg = traceback.format_exc()
logger.warning("Error recovering trial from checkpoint, abort.")
self.trial_executor.stop_trial(trial, True, error_msg=error_msg)
logger.exception("Error recovering trial from checkpoint, abort.")
self._scheduler_alg.on_trial_error(self, trial)
self._search_alg.on_trial_complete(trial.trial_id, error=True)
def _requeue_trial(self, trial):
"""Notifies the TrialScheduler and requeues the trial.
This does not notify the SearchAlgorithm because
the function evaluation is still in progress.
"""
self._scheduler_alg.on_trial_error(self, trial)
trial.status = Trial.PENDING
self._scheduler_alg.on_trial_add(self, trial)
def _update_trial_queue(self, blocking=False, timeout=600):
"""Adds next trials to queue if possible.