mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[tune] fix error handling for fail_fast case. (#22982)
This commit is contained in:
parent
832354ce3f
commit
b1496d235f
4 changed files with 28 additions and 8 deletions
|
@ -11,12 +11,12 @@ import time
|
|||
import traceback
|
||||
from contextlib import contextmanager
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Union,
|
||||
Set,
|
||||
)
|
||||
|
||||
|
@ -170,7 +170,7 @@ class ExecutorEvent:
|
|||
self,
|
||||
event_type: ExecutorEventType,
|
||||
trial: Optional[Trial] = None,
|
||||
result: Optional[Union[str, Dict]] = None,
|
||||
result: Optional[Any] = None,
|
||||
):
|
||||
self.type = event_type
|
||||
self.trial = trial
|
||||
|
@ -1010,9 +1010,9 @@ class RayTrialExecutor(TrialExecutor):
|
|||
return ExecutorEvent(result_type, trial, result=future_result)
|
||||
else:
|
||||
raise TuneError(f"Unexpected future type - [{result_type}]")
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
return ExecutorEvent(
|
||||
ExecutorEventType.ERROR, trial, traceback.format_exc()
|
||||
ExecutorEventType.ERROR, trial, (e, traceback.format_exc())
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -194,7 +194,9 @@ class TrialRunnerCallbacks(unittest.TestCase):
|
|||
|
||||
# Let the first trial error
|
||||
self.executor.next_future_result = ExecutorEvent(
|
||||
event_type=ExecutorEventType.ERROR, trial=trials[0]
|
||||
event_type=ExecutorEventType.ERROR,
|
||||
trial=trials[0],
|
||||
result=(Exception(), "error"),
|
||||
)
|
||||
self.trial_runner.step()
|
||||
self.assertEqual(self.callback.state["trial_fail"]["iteration"], 6)
|
||||
|
|
|
@ -13,11 +13,13 @@ import unittest
|
|||
import ray
|
||||
from ray import tune
|
||||
from ray._private.test_utils import recursive_fnmatch
|
||||
from ray.exceptions import RayTaskError
|
||||
from ray.rllib import _register_all
|
||||
from ray.tune.callback import Callback
|
||||
from ray.tune.suggest.basic_variant import BasicVariantGenerator
|
||||
from ray.tune.suggest import Searcher
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.trial_runner import TrialRunner
|
||||
from ray.tune.utils import validate_save_restore
|
||||
from ray.tune.utils.mock_trainable import MyTrainableClass
|
||||
|
||||
|
@ -525,6 +527,20 @@ class WorkingDirectoryTest(unittest.TestCase):
|
|||
tune.run(f)
|
||||
|
||||
|
||||
class TrainableCrashWithFailFast(unittest.TestCase):
    """Regression test for error propagation when ``fail_fast`` is RAISE."""

    def test(self):
        """A trainable that crashes while ``fail_fast=TrialRunner.RAISE`` is
        set should surface its original crash message to the caller."""
        crash_message = "Error happens in trainable!!"

        def f(config):
            # Report one result, pause briefly, then crash.
            tune.report({"a": 1})
            time.sleep(0.1)
            raise RuntimeError(crash_message)

        # The error raised by tune.run must carry the trainable's own message.
        with self.assertRaisesRegex(RayTaskError, crash_message):
            tune.run(f, fail_fast=TrialRunner.RAISE)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
|
|
|
@ -739,7 +739,7 @@ class TrialRunner:
|
|||
self._on_saving_result(trial, result)
|
||||
self._post_process_on_training_saving_result(trial)
|
||||
except Exception as e:
|
||||
if e is TuneError:
|
||||
if e is TuneError or self._fail_fast == TrialRunner.RAISE:
|
||||
raise e
|
||||
else:
|
||||
raise TuneError(traceback.format_exc())
|
||||
|
@ -868,10 +868,12 @@ class TrialRunner:
|
|||
error_msg = f"Trial {trial}: Error processing event."
|
||||
if self._fail_fast == TrialRunner.RAISE:
|
||||
logger.error(error_msg)
|
||||
raise
|
||||
assert isinstance(result[0], Exception)
|
||||
raise result[0]
|
||||
else:
|
||||
logger.exception(error_msg)
|
||||
self._process_trial_failure(trial, result)
|
||||
assert isinstance(result[1], str)
|
||||
self._process_trial_failure(trial, result[1])
|
||||
|
||||
def get_trial(self, tid):
|
||||
trial = [t for t in self._trials if t.trial_id == tid]
|
||||
|
|
Loading…
Add table
Reference in a new issue