[tune] fix error handling for fail_fast case. (#22982)

This commit is contained in:
xwjiang2010 2022-03-10 12:10:05 -08:00 committed by GitHub
parent 832354ce3f
commit b1496d235f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 28 additions and 8 deletions

View file

@ -11,12 +11,12 @@ import time
import traceback
from contextlib import contextmanager
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Union,
Set,
)
@ -170,7 +170,7 @@ class ExecutorEvent:
self,
event_type: ExecutorEventType,
trial: Optional[Trial] = None,
result: Optional[Union[str, Dict]] = None,
result: Optional[Any] = None,
):
self.type = event_type
self.trial = trial
@ -1010,9 +1010,9 @@ class RayTrialExecutor(TrialExecutor):
return ExecutorEvent(result_type, trial, result=future_result)
else:
raise TuneError(f"Unexpected future type - [{result_type}]")
except Exception:
except Exception as e:
return ExecutorEvent(
ExecutorEventType.ERROR, trial, traceback.format_exc()
ExecutorEventType.ERROR, trial, (e, traceback.format_exc())
)

View file

@ -194,7 +194,9 @@ class TrialRunnerCallbacks(unittest.TestCase):
# Let the first trial error
self.executor.next_future_result = ExecutorEvent(
event_type=ExecutorEventType.ERROR, trial=trials[0]
event_type=ExecutorEventType.ERROR,
trial=trials[0],
result=(Exception(), "error"),
)
self.trial_runner.step()
self.assertEqual(self.callback.state["trial_fail"]["iteration"], 6)

View file

@ -13,11 +13,13 @@ import unittest
import ray
from ray import tune
from ray._private.test_utils import recursive_fnmatch
from ray.exceptions import RayTaskError
from ray.rllib import _register_all
from ray.tune.callback import Callback
from ray.tune.suggest.basic_variant import BasicVariantGenerator
from ray.tune.suggest import Searcher
from ray.tune.trial import Trial
from ray.tune.trial_runner import TrialRunner
from ray.tune.utils import validate_save_restore
from ray.tune.utils.mock_trainable import MyTrainableClass
@ -525,6 +527,20 @@ class WorkingDirectoryTest(unittest.TestCase):
tune.run(f)
class TrainableCrashWithFailFast(unittest.TestCase):
def test(self):
"""Trainable crashes with fail_fast flag and the original crash message
should bubble up."""
def f(config):
tune.report({"a": 1})
time.sleep(0.1)
raise RuntimeError("Error happens in trainable!!")
with self.assertRaisesRegex(RayTaskError, "Error happens in trainable!!"):
tune.run(f, fail_fast=TrialRunner.RAISE)
if __name__ == "__main__":
import pytest
import sys

View file

@ -739,7 +739,7 @@ class TrialRunner:
self._on_saving_result(trial, result)
self._post_process_on_training_saving_result(trial)
except Exception as e:
if e is TuneError:
if e is TuneError or self._fail_fast == TrialRunner.RAISE:
raise e
else:
raise TuneError(traceback.format_exc())
@ -868,10 +868,12 @@ class TrialRunner:
error_msg = f"Trial {trial}: Error processing event."
if self._fail_fast == TrialRunner.RAISE:
logger.error(error_msg)
raise
assert isinstance(result[0], Exception)
raise result[0]
else:
logger.exception(error_msg)
self._process_trial_failure(trial, result)
assert isinstance(result[1], str)
self._process_trial_failure(trial, result[1])
def get_trial(self, tid):
trial = [t for t in self._trials if t.trial_id == tid]