[tune] Component notification on node failure + Tests (#3414)
Changes include:
- Notify Components on Requeue
- Slight refactoring of Node Failure handling
- Better tests
Parent: ce355d13d4
Commit: 9d0bd50e78
8 changed files with 180 additions and 114 deletions
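Before the file-by-file diff, here is a rough sketch of the behavior this commit adds: on an unrecoverable trial error both the scheduler and the search algorithm are notified, but when a trial only needs to be requeued (for example, its node died and the cluster no longer has room to restart it), only the scheduler hears about it. The helper below is an illustrative stand-in distilled from the TrialRunner hunks near the end of the diff, not the actual method layout; `runner` is assumed to be a TrialRunner-like object.

    from ray.tune.trial import Trial

    def recover_or_requeue(runner, trial, error_msg):
        # Illustrative stand-in for TrialRunner._try_recover/_requeue_trial
        # shown later in this diff; logging and logger flushing are omitted.
        executor = runner.trial_executor
        try:
            executor.stop_trial(trial, error=error_msg is not None,
                                error_msg=error_msg, stop_logger=False)
            if executor.has_resources(trial.resources):
                # Enough room left: restart the trial from its last checkpoint.
                executor.start_trial(trial)
                if trial.status == Trial.ERROR:
                    raise RuntimeError("Trial did not start correctly.")
            else:
                # Requeue: notify only the scheduler; the search algorithm is
                # not told, because the function evaluation is still in progress.
                runner._scheduler_alg.on_trial_error(runner, trial)
                trial.status = Trial.PENDING
                runner._scheduler_alg.on_trial_add(runner, trial)
        except Exception:
            # Recovery failed for good: both components learn about the error.
            runner._scheduler_alg.on_trial_error(runner, trial)
            runner._search_alg.on_trial_complete(trial.trial_id, error=True)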
@@ -25,7 +25,7 @@ if [[ "$PYTHON" == "2.7" ]] && [[ "$platform" == "linux" ]]; then
   bash miniconda.sh -b -p $HOME/miniconda
   export PATH="$HOME/miniconda/bin:$PATH"
   pip install -q cython==0.27.3 cmake tensorflow gym opencv-python pyyaml pandas==0.23.4 requests \
-    feather-format lxml openpyxl xlrd py-spy setproctitle faulthandler pytest-timeout
+    feather-format lxml openpyxl xlrd py-spy setproctitle faulthandler pytest-timeout mock
 elif [[ "$PYTHON" == "3.5" ]] && [[ "$platform" == "linux" ]]; then
   sudo apt-get update
   sudo apt-get install -y cmake pkg-config python-dev python-numpy build-essential autoconf curl libtool unzip
@@ -51,7 +51,7 @@ elif [[ "$PYTHON" == "2.7" ]] && [[ "$platform" == "macosx" ]]; then
   bash miniconda.sh -b -p $HOME/miniconda
   export PATH="$HOME/miniconda/bin:$PATH"
   pip install -q cython==0.27.3 cmake tensorflow gym opencv-python pyyaml pandas==0.23.4 requests \
-    feather-format lxml openpyxl xlrd py-spy setproctitle faulthandler pytest-timeout
+    feather-format lxml openpyxl xlrd py-spy setproctitle faulthandler pytest-timeout mock
 elif [[ "$PYTHON" == "3.5" ]] && [[ "$platform" == "macosx" ]]; then
   # check that brew is installed
   which -s brew
@@ -110,19 +110,27 @@ class RayTrialExecutor(TrialExecutor):
         if stop_logger:
             trial.close_logger()
 
-    def start_trial(self, trial, checkpoint_obj=None):
-        """Starts the trial."""
+    def start_trial(self, trial, checkpoint=None):
+        """Starts the trial.
+
+        Will not return resources if trial repeatedly fails on start.
+
+        Args:
+            trial (Trial): Trial to be started.
+            checkpoint (Checkpoint): A Python object or path storing the state
+                of trial.
+        """
 
         self._commit_resources(trial.resources)
         try:
-            self._start_trial(trial, checkpoint_obj)
+            self._start_trial(trial, checkpoint)
         except Exception:
             logger.exception("Error stopping runner - retrying...")
             error_msg = traceback.format_exc()
             time.sleep(2)
             self._stop_trial(trial, error=True, error_msg=error_msg)
             try:
-                self._start_trial(trial)
+                self._start_trial(trial, checkpoint)
             except Exception:
                 logger.exception("Error starting runner, aborting!")
                 error_msg = traceback.format_exc()
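Read in isolation, the start path above is a try-once-then-retry-once pattern: commit resources, attempt the start, and on failure clean up, wait briefly, and attempt exactly one more start before giving up. A minimal standalone sketch of that pattern, with hypothetical start_fn/cleanup_fn callables standing in for the executor's private _start_trial/_stop_trial:

    import logging
    import time
    import traceback

    logger = logging.getLogger(__name__)

    def start_with_one_retry(start_fn, cleanup_fn, retry_delay_s=2):
        # Hypothetical helper illustrating the retry flow above; not part of Tune.
        try:
            start_fn()
        except Exception:
            logger.exception("Error starting - retrying once...")
            cleanup_fn(error_msg=traceback.format_exc())
            time.sleep(retry_delay_s)
            try:
                start_fn()
            except Exception:
                logger.exception("Error starting, aborting!")
                cleanup_fn(error_msg=traceback.format_exc())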
@@ -140,6 +148,7 @@ class RayTrialExecutor(TrialExecutor):
         self._stop_trial(
             trial, error=error, error_msg=error_msg, stop_logger=stop_logger)
         if prior_status == Trial.RUNNING:
+            logger.debug("Returning resources for this trial.")
             self._return_resources(trial.resources)
             out = self._find_item(self._running, trial)
             for result_id in out:
@@ -3,45 +3,22 @@ from __future__ import division
 from __future__ import print_function
 
 import json
 import time
 import pytest
 try:
     import pytest_timeout
 except ImportError:
     pytest_timeout = None
 
-from ray.test.cluster_utils import Cluster
 import ray
 from ray import tune
+from ray.rllib import _register_all
+from ray.test.cluster_utils import Cluster
 from ray.tune.error import TuneError
 from ray.tune.trial import Trial
 from ray.tune.trial_runner import TrialRunner
 from ray.tune.suggest import BasicVariantGenerator
 
 
-def register_test_trainable():
-    class _Train(tune.Trainable):
-        def _setup(self, config):
-            self.state = {"hi": 1}
-
-        def _train(self):
-            self.state["hi"] += 1
-            time.sleep(0.5)
-            return {}
-
-        def _save(self, path):
-            return self.state
-
-        def _restore(self, state):
-            self.state = state
-
-    tune.register_trainable("test", _Train)
-
-
-@pytest.fixture
-def start_connected_cluster():
-    # Start the Ray processes.
-
+def _start_new_cluster():
     cluster = Cluster(
         initialize_head=True,
         connect=True,
@@ -51,7 +28,15 @@ def start_connected_cluster():
                 "num_heartbeats_timeout": 10
             })
         })
-    register_test_trainable()
+    # Pytest doesn't play nicely with imports
+    _register_all()
+    return cluster
+
+
+@pytest.fixture
+def start_connected_cluster():
+    # Start the Ray processes.
+    cluster = _start_new_cluster()
     yield cluster
     # The code after the yield will run as teardown code.
     ray.shutdown()
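Both fixtures rely on pytest's yield-fixture convention: everything after the yield executes as teardown once the test using the fixture completes, even if the test failed. A generic minimal example, with purely illustrative names:

    import pytest

    @pytest.fixture
    def connected_cluster():
        cluster = object()   # setup: stand-in for starting Ray / a Cluster
        yield cluster        # the test function receives this value
        # teardown: runs after the test finishes
        del cluster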
@@ -71,39 +56,36 @@ def start_connected_emptyhead_cluster():
                 "num_heartbeats_timeout": 10
             })
         })
-    register_test_trainable()
+    # Pytest doesn't play nicely with imports
+    _register_all()
     yield cluster
     # The code after the yield will run as teardown code.
     ray.shutdown()
     cluster.shutdown()
 
 
 @pytest.mark.skipif(
     pytest_timeout is None,
     reason="Timeout package not installed; skipping test.")
 @pytest.mark.timeout(10, method="thread")
 def test_counting_resources(start_connected_cluster):
     """Tests that Tune accounting is consistent with actual cluster."""
 
     cluster = start_connected_cluster
-    assert ray.global_state.cluster_resources()["CPU"] == 1
     nodes = []
-    nodes += [cluster.add_node(resources=dict(CPU=1))]
-    assert cluster.wait_for_nodes()
-    assert ray.global_state.cluster_resources()["CPU"] == 2
-
+    assert ray.global_state.cluster_resources()["CPU"] == 1
     runner = TrialRunner(BasicVariantGenerator())
     kwargs = {"stopping_criterion": {"training_iteration": 10}}
 
-    trials = [Trial("test", **kwargs), Trial("test", **kwargs)]
+    trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)]
     for t in trials:
         runner.add_trial(t)
 
     runner.step()  # run 1
+    nodes += [cluster.add_node(resources=dict(CPU=1))]
+    assert cluster.wait_for_nodes()
+    assert ray.global_state.cluster_resources()["CPU"] == 2
     cluster.remove_node(nodes.pop())
     assert cluster.wait_for_nodes()
     assert ray.global_state.cluster_resources()["CPU"] == 1
     runner.step()  # run 2
     assert sum(t.status == Trial.RUNNING for t in runner.get_trials()) == 1
 
     for i in range(5):
         nodes += [cluster.add_node(resources=dict(CPU=1))]
@@ -111,12 +93,7 @@ def test_counting_resources(start_connected_cluster):
     assert ray.global_state.cluster_resources()["CPU"] == 6
 
     runner.step()  # 1 result
-
-    for i in range(5):
-        node = nodes.pop()
-        cluster.remove_node(node)
-        assert cluster.wait_for_nodes()
-    assert ray.global_state.cluster_resources()["CPU"] == 1
     assert sum(t.status == Trial.RUNNING for t in runner.get_trials()) == 2
 
 
+@pytest.mark.skip("Add this test once reconstruction is fixed")
@@ -133,7 +110,7 @@ def test_remove_node_before_result(start_connected_cluster):
 
     runner = TrialRunner(BasicVariantGenerator())
     kwargs = {"stopping_criterion": {"training_iteration": 3}}
-    trials = [Trial("test", **kwargs), Trial("test", **kwargs)]
+    trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)]
     for t in trials:
         runner.add_trial(t)
 
@@ -179,7 +156,7 @@ def test_trial_migration(start_connected_emptyhead_cluster):
     }
 
     # Test recovery of trial that hasn't been checkpointed
-    t = Trial("test", **kwargs)
+    t = Trial("__fake", **kwargs)
     runner.add_trial(t)
     runner.step()  # start
     runner.step()  # 1 result
@@ -199,7 +176,7 @@ def test_trial_migration(start_connected_emptyhead_cluster):
     assert t.status == Trial.TERMINATED
 
     # Test recovery of trial that has been checkpointed
-    t2 = Trial("test", **kwargs)
+    t2 = Trial("__fake", **kwargs)
     runner.add_trial(t2)
     runner.step()  # start
     runner.step()  # 1 result
@@ -216,7 +193,7 @@ def test_trial_migration(start_connected_emptyhead_cluster):
     assert t2.status == Trial.TERMINATED
 
     # Test recovery of trial that won't be checkpointed
-    t3 = Trial("test", **{"stopping_criterion": {"training_iteration": 3}})
+    t3 = Trial("__fake", **{"stopping_criterion": {"training_iteration": 3}})
     runner.add_trial(t3)
     runner.step()  # start
     runner.step()  # 1 result
@@ -238,6 +215,7 @@ def test_trial_requeue(start_connected_emptyhead_cluster):
     """Removing a node in full cluster causes Trial to be requeued."""
     cluster = start_connected_emptyhead_cluster
     node = cluster.add_node(resources=dict(CPU=1))
+    assert cluster.wait_for_nodes()
 
     runner = TrialRunner(BasicVariantGenerator())
     kwargs = {
@@ -248,7 +226,7 @@ def test_trial_requeue(start_connected_emptyhead_cluster):
         "max_failures": 1
     }
 
-    trials = [Trial("test", **kwargs), Trial("test", **kwargs)]
+    trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)]
     for t in trials:
         runner.add_trial(t)
 
@@ -9,8 +9,9 @@ import ray
 from ray.rllib import _register_all
 from ray.tune import Trainable
 from ray.tune.ray_trial_executor import RayTrialExecutor
+from ray.tune.registry import _global_registry, TRAINABLE_CLASS
 from ray.tune.suggest import BasicVariantGenerator
-from ray.tune.trial import Trial, Checkpoint
+from ray.tune.trial import Trial, Checkpoint, Resources
 
 
 class RayTrialExecutorTest(unittest.TestCase):
@@ -50,6 +51,12 @@ class RayTrialExecutorTest(unittest.TestCase):
         self.trial_executor.stop_trial(trial)
         self.assertEqual(Trial.TERMINATED, trial.status)
 
+    def testStartFailure(self):
+        _global_registry.register(TRAINABLE_CLASS, "asdf", None)
+        trial = Trial("asdf", resources=Resources(1, 0))
+        self.trial_executor.start_trial(trial)
+        self.assertEqual(Trial.ERROR, trial.status)
+
     def testPauseResume2(self):
         """Tests that pausing works for trials being processed."""
         trial = Trial("__fake")
@@ -3,6 +3,7 @@ from __future__ import division
 from __future__ import print_function
 
 import os
+import sys
 import time
 import unittest
 
@@ -25,6 +26,11 @@ from ray.tune.suggest.suggestion import (_MockSuggestionAlgorithm,
                                          SuggestionAlgorithm)
 from ray.tune.suggest.variant_generator import RecursiveDependencyError
 
+if sys.version_info >= (3, 3):
+    from unittest.mock import patch
+else:
+    from mock import patch
+
 
 class TrainableFunctionApiTest(unittest.TestCase):
     def setUp(self):
@@ -845,6 +851,25 @@ class VariantGeneratorTest(unittest.TestCase):
         self.assertEqual(len(searcher.next_trials()), 0)
 
 
+def create_mock_components():
+    class _MockScheduler(FIFOScheduler):
+        errored_trials = []
+
+        def on_trial_error(self, trial_runner, trial):
+            self.errored_trials += [trial]
+
+    class _MockSearchAlg(BasicVariantGenerator):
+        errored_trials = []
+
+        def on_trial_complete(self, trial_id, error=False, **kwargs):
+            if error:
+                self.errored_trials += [trial_id]
+
+    searchalg = _MockSearchAlg()
+    scheduler = _MockScheduler()
+    return searchalg, scheduler
+
+
 class TrialRunnerTest(unittest.TestCase):
     def tearDown(self):
         ray.shutdown()
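The helper above returns a scheduler and a search algorithm that simply record which trials they were told had errored; the failure-recovery tests below assert on those lists. A tiny, cluster-free sketch of what gets recorded (it assumes the surrounding test module's imports, since create_mock_components relies on FIFOScheduler and BasicVariantGenerator):

    searchalg, scheduler = create_mock_components()

    class _FakeTrial(object):
        trial_id = "t1"   # hypothetical stand-in; the mocks only append it

    scheduler.on_trial_error(None, _FakeTrial())    # runner argument unused by the mock
    searchalg.on_trial_complete("t1", error=True)   # recorded as an errored trial
    searchalg.on_trial_complete("t2", error=False)  # ignored: not an error
    assert len(scheduler.errored_trials) == 1
    assert len(searchalg.errored_trials) == 1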
@@ -889,16 +914,6 @@ class TrialRunnerTest(unittest.TestCase):
         self.assertLessEqual(len(trial.logdir), 200)
         trial_executor.stop_trial(trial)
 
-    def testTrialErrorOnStart(self):
-        ray.init()
-        trial_executor = RayTrialExecutor()
-        _global_registry.register(TRAINABLE_CLASS, "asdf", None)
-        trial = Trial("asdf", resources=Resources(1, 0))
-        try:
-            trial_executor.start_trial(trial)
-        except Exception as e:
-            self.assertIn("a class", str(e))
-
     def testExtraResources(self):
         ray.init(num_cpus=4, num_gpus=2)
         runner = TrialRunner(BasicVariantGenerator())
@@ -1055,7 +1070,9 @@ class TrialRunnerTest(unittest.TestCase):
 
     def testFailureRecoveryDisabled(self):
         ray.init(num_cpus=1, num_gpus=1)
-        runner = TrialRunner(BasicVariantGenerator())
+        searchalg, scheduler = create_mock_components()
+
+        runner = TrialRunner(searchalg, scheduler=scheduler)
         kwargs = {
             "resources": Resources(cpu=1, gpu=1),
             "checkpoint_freq": 1,
@@ -1074,10 +1091,15 @@ class TrialRunnerTest(unittest.TestCase):
         runner.step()
         self.assertEqual(trials[0].status, Trial.ERROR)
         self.assertEqual(trials[0].num_failures, 1)
+        self.assertEqual(len(searchalg.errored_trials), 1)
+        self.assertEqual(len(scheduler.errored_trials), 1)
 
     def testFailureRecoveryEnabled(self):
         ray.init(num_cpus=1, num_gpus=1)
-        runner = TrialRunner(BasicVariantGenerator())
+        searchalg, scheduler = create_mock_components()
+
+        runner = TrialRunner(searchalg, scheduler=scheduler)
+
         kwargs = {
             "resources": Resources(cpu=1, gpu=1),
             "checkpoint_freq": 1,
@@ -1098,6 +1120,40 @@ class TrialRunnerTest(unittest.TestCase):
         self.assertEqual(trials[0].num_failures, 1)
         runner.step()
         self.assertEqual(trials[0].status, Trial.RUNNING)
+        self.assertEqual(len(searchalg.errored_trials), 0)
+        self.assertEqual(len(scheduler.errored_trials), 0)
+
+    def testFailureRecoveryNodeRemoval(self):
+        ray.init(num_cpus=1, num_gpus=1)
+        searchalg, scheduler = create_mock_components()
+
+        runner = TrialRunner(searchalg, scheduler=scheduler)
+
+        kwargs = {
+            "resources": Resources(cpu=1, gpu=1),
+            "checkpoint_freq": 1,
+            "max_failures": 1,
+            "config": {
+                "mock_error": True,
+            },
+        }
+        runner.add_trial(Trial("__fake", **kwargs))
+        trials = runner.get_trials()
+
+        with patch('ray.global_state.cluster_resources') as resource_mock:
+            resource_mock.return_value = {"CPU": 1, "GPU": 1}
+            runner.step()
+            self.assertEqual(trials[0].status, Trial.RUNNING)
+            runner.step()
+            self.assertEqual(trials[0].status, Trial.RUNNING)
+
+            # Mimic a node failure
+            resource_mock.return_value = {"CPU": 0, "GPU": 0}
+            runner.step()
+            self.assertEqual(trials[0].status, Trial.PENDING)
+            self.assertEqual(trials[0].num_failures, 1)
+            self.assertEqual(len(searchalg.errored_trials), 0)
+            self.assertEqual(len(scheduler.errored_trials), 1)
 
     def testFailureRecoveryMaxFailures(self):
         ray.init(num_cpus=1, num_gpus=1)
@@ -216,17 +216,19 @@ class Trial(object):
 
         return False
 
-    def should_checkpoint(self, result):
+    def should_checkpoint(self):
         """Whether this trial is due for checkpointing."""
+        result = self.last_result or {}
 
         if result.get(DONE) and self.checkpoint_at_end:
             return True
 
-        if not self.checkpoint_freq:
+        if self.checkpoint_freq:
+            return result.get(TRAINING_ITERATION,
+                              0) % self.checkpoint_freq == 0
+        else:
             return False
-
-        return self.last_result[TRAINING_ITERATION] % self.checkpoint_freq == 0
 
     def progress_string(self):
         """Returns a progress message for printing out to the console."""
 
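For intuition, the refactored predicate depends only on the last result, the checkpoint frequency, and the checkpoint_at_end flag. A standalone restatement with a few worked cases (the DONE/TRAINING_ITERATION keys mirror ray.tune.result; this helper is illustrative, not the Trial method itself):

    DONE = "done"
    TRAINING_ITERATION = "training_iteration"

    def should_checkpoint(last_result, checkpoint_freq, checkpoint_at_end):
        # Standalone restatement of Trial.should_checkpoint() above.
        result = last_result or {}
        if result.get(DONE) and checkpoint_at_end:
            return True
        if checkpoint_freq:
            return result.get(TRAINING_ITERATION, 0) % checkpoint_freq == 0
        return False

    assert should_checkpoint({TRAINING_ITERATION: 6}, 3, False)   # every 3rd iteration
    assert not should_checkpoint({TRAINING_ITERATION: 7}, 3, False)
    assert should_checkpoint({DONE: True}, 0, True)               # checkpoint_at_end
    assert not should_checkpoint(None, 0, False)                  # no freq, not done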
@@ -281,10 +283,12 @@ class Trial(object):
     def should_recover(self):
         """Returns whether the trial qualifies for restoring.
 
-        This is if a checkpoint frequency is set, which includes settings
-        where there may not yet be a checkpoint.
+        This is if a checkpoint frequency is set and has not failed more than
+        max_failures. This may return true even when there may not yet
+        be a checkpoint.
         """
-        return self.checkpoint_freq > 0
+        return (self.checkpoint_freq > 0
+                and self.num_failures < self.max_failures)
 
     def update_last_result(self, result, terminate=False):
         if terminate:
@@ -32,12 +32,10 @@ class TrialExecutor(object):
                                   "has_resources() method")
 
     def start_trial(self, trial, checkpoint=None):
-        """Starts the trial restoring from checkpoint if checkpoint != None.
-
-        If an error is encountered when starting the trial, an exception will
-        be thrown.
+        """Starts the trial restoring from checkpoint if checkpoint is provided.
+
         Args:
             trial (Trial): Trial to be started.
             checkpoint(Checkpoint): A Python object or path storing the state
                 of trial.
         """
@@ -59,26 +57,6 @@ class TrialExecutor(object):
         raise NotImplementedError("Subclasses of TrialExecutor must provide "
                                   "stop_trial() method")
 
-    def restart_trial(self, trial, error_msg=None):
-        """Restarts or requeues the trial.
-
-        The state of the trial should restore from the last checkpoint. Trial
-        is requeued if the cluster no longer has resources to accomodate it.
-
-        Args:
-            error_msg (str): Optional error message.
-        """
-        self.stop_trial(
-            trial,
-            error=error_msg is not None,
-            error_msg=error_msg,
-            stop_logger=False)
-        trial.result_logger.flush()
-        if self.has_resources(trial.resources):
-            self.start_trial(trial)
-        else:
-            trial.status = Trial.PENDING
-
     def continue_training(self, trial):
         """Continues the training of this trial."""
         pass
@@ -12,7 +12,7 @@ import traceback
 from ray.tune import TuneError
 from ray.tune.ray_trial_executor import RayTrialExecutor
 from ray.tune.result import TIME_THIS_ITER_S
-from ray.tune.trial import Trial
+from ray.tune.trial import Trial, Checkpoint
 from ray.tune.schedulers import FIFOScheduler, TrialScheduler
 from ray.tune.web_server import TuneServer
 
@@ -279,17 +279,14 @@ class TrialRunner(object):
                     result, terminate=(decision == TrialScheduler.STOP))
 
                 if decision == TrialScheduler.CONTINUE:
-                    if trial.should_checkpoint(result):
-                        # TODO(rliaw): This is a blocking call
-                        self.trial_executor.save(trial)
+                    self._checkpoint_if_needed(trial)
                     self.trial_executor.continue_training(trial)
                 elif decision == TrialScheduler.PAUSE:
                     self.trial_executor.pause_trial(trial)
                 elif decision == TrialScheduler.STOP:
                     # Checkpoint before ending the trial
                     # if checkpoint_at_end experiment option is set to True
-                    if trial.should_checkpoint(result):
-                        self.trial_executor.save(trial)
+                    self._checkpoint_if_needed(trial)
                     self.trial_executor.stop_trial(trial)
                 else:
                     assert False, "Invalid scheduling decision: {}".format(
@@ -298,24 +295,61 @@ class TrialRunner(object):
             logger.exception("Error processing event.")
             error_msg = traceback.format_exc()
             if trial.status == Trial.RUNNING:
-                if trial.should_recover() and \
-                        trial.num_failures < trial.max_failures:
+                if trial.should_recover():
                     self._try_recover(trial, error_msg)
                 else:
                     self._scheduler_alg.on_trial_error(self, trial)
                     self._search_alg.on_trial_complete(
                         trial.trial_id, error=True)
-                    self.trial_executor.stop_trial(trial, True, error_msg)
+                    self.trial_executor.stop_trial(
+                        trial, error=True, error_msg=error_msg)
 
+    def _checkpoint_if_needed(self, trial):
+        """Checkpoints trial based off trial.last_result."""
+        if trial.should_checkpoint():
+            # Save trial runtime if possible
+            if hasattr(trial, "runner") and trial.runner:
+                self.trial_executor.save(trial, storage=Checkpoint.DISK)
+
     def _try_recover(self, trial, error_msg):
+        """Tries to recover trial.
+
+        Notifies SearchAlgorithm and Scheduler if failure to recover.
+
+        Args:
+            trial (Trial): Trial to recover.
+            error_msg (str): Error message from prior to invoking this method.
+        """
         try:
-            logger.info("Attempting to recover"
-                        " trial state from last checkpoint.")
-            self.trial_executor.restart_trial(trial, error_msg)
+            self.trial_executor.stop_trial(
+                trial,
+                error=error_msg is not None,
+                error_msg=error_msg,
+                stop_logger=False)
+            trial.result_logger.flush()
+            if self.trial_executor.has_resources(trial.resources):
+                logger.info("Attempting to recover"
+                            " trial state from last checkpoint.")
+                self.trial_executor.start_trial(trial)
+                if trial.status == Trial.ERROR:
+                    raise RuntimeError("Trial did not start correctly.")
+            else:
+                logger.debug("Notifying Scheduler and requeueing trial.")
+                self._requeue_trial(trial)
         except Exception:
-            error_msg = traceback.format_exc()
-            logger.warning("Error recovering trial from checkpoint, abort.")
-            self.trial_executor.stop_trial(trial, True, error_msg=error_msg)
+            logger.exception("Error recovering trial from checkpoint, abort.")
+            self._scheduler_alg.on_trial_error(self, trial)
+            self._search_alg.on_trial_complete(trial.trial_id, error=True)
+
+    def _requeue_trial(self, trial):
+        """Notification to TrialScheduler and requeue trial.
+
+        This does not notify the SearchAlgorithm because
+        the function evaluation is still in progress.
+        """
+        self._scheduler_alg.on_trial_error(self, trial)
+        trial.status = Trial.PENDING
+        self._scheduler_alg.on_trial_add(self, trial)
 
     def _update_trial_queue(self, blocking=False, timeout=600):
         """Adds next trials to queue if possible.
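As a usage recap, the node-failure path can be exercised end to end the same way testFailureRecoveryNodeRemoval does above. The script-style condensation below assumes a local Ray installation and reuses create_mock_components() from this diff; the step counts and keyword arguments mirror that test rather than any separately documented API:

    import ray
    from ray.rllib import _register_all
    from ray.tune.trial import Trial, Resources
    from ray.tune.trial_runner import TrialRunner
    try:
        from unittest.mock import patch
    except ImportError:      # Python 2, using the newly added `mock` package
        from mock import patch

    _register_all()          # registers the "__fake" trainable used below
    ray.init(num_cpus=1, num_gpus=1)

    searchalg, scheduler = create_mock_components()
    runner = TrialRunner(searchalg, scheduler=scheduler)
    runner.add_trial(
        Trial(
            "__fake",
            resources=Resources(cpu=1, gpu=1),
            checkpoint_freq=1,
            max_failures=1,
            config={"mock_error": True}))
    trial = runner.get_trials()[0]

    with patch("ray.global_state.cluster_resources") as resource_mock:
        resource_mock.return_value = {"CPU": 1, "GPU": 1}
        runner.step()        # start
        runner.step()        # 1 result

        resource_mock.return_value = {"CPU": 0, "GPU": 0}   # mimic a node failure
        runner.step()        # error surfaces; no resources left to restart

    assert trial.status == Trial.PENDING        # requeued, not errored out
    assert len(scheduler.errored_trials) == 1   # scheduler was notified
    assert len(searchalg.errored_trials) == 0   # search algorithm was not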