[tune] Deflake test_tune_restore.py (#20776)

By switching to on_step_end and keeping track of the number of trials we avoid race conditions in this test suite.
This commit is contained in:
xwjiang2010 2021-11-29 16:38:46 -08:00 committed by GitHub
parent c03b937b95
commit d697c13bda
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -169,15 +169,12 @@ class TuneFailResumeGridTest(unittest.TestCase):
class FailureInjectorCallback(Callback):
"""Adds random failure injection to the TrialExecutor."""
def __init__(self, steps=20):
self._step = 0
self.steps = steps
def __init__(self, num_trials=20):
self.num_trials = num_trials
def on_trial_start(self, trials, **info):
self._step += 1
if self._step >= self.steps:
print(f"Failing after step {self._step} with "
f"{len(trials)} trials")
def on_step_end(self, trials, **kwargs):
if len(trials) == self.num_trials:
print(f"Failing after {self.num_trials} trials.")
raise RuntimeError
class CheckStateCallback(Callback):