mirror of
https://github.com/vale981/ray
synced 2025-03-11 21:56:39 -04:00
[tune] Fixing up Hyperband (#1207)
* Fixing up Hyperband * nit * cleanup * Timing test Added * added_exception_back * fixup_tests * reverse placement * fixes_and_tests * fix * fix * fixlint * cleanup_timing * lint * Update hyperband.py
This commit is contained in:
parent
7c38f964b7
commit
71f8cd2403
2 changed files with 176 additions and 161 deletions
|
@ -60,11 +60,9 @@ class HyperBandScheduler(FIFOScheduler):
|
||||||
|
|
||||||
def on_trial_add(self, trial_runner, trial):
|
def on_trial_add(self, trial_runner, trial):
|
||||||
"""On a new trial add, if current bracket is not filled,
|
"""On a new trial add, if current bracket is not filled,
|
||||||
add to current bracket. Else, if current hp iteration is not filled,
|
add to current bracket. Else, if current band is not filled,
|
||||||
create new bracket, add to current bracket.
|
create new bracket, add to current bracket.
|
||||||
Else, create new iteration, create new bracket, add to bracket.
|
Else, create new iteration, create new bracket, add to bracket."""
|
||||||
|
|
||||||
TODO(rliaw): This is messy."""
|
|
||||||
|
|
||||||
cur_bracket = self._state["bracket"]
|
cur_bracket = self._state["bracket"]
|
||||||
cur_band = self._hyperbands[self._state["band_idx"]]
|
cur_band = self._hyperbands[self._state["band_idx"]]
|
||||||
|
@ -76,9 +74,9 @@ class HyperBandScheduler(FIFOScheduler):
|
||||||
self._hyperbands.append(cur_band)
|
self._hyperbands.append(cur_band)
|
||||||
self._state["band_idx"] += 1
|
self._state["band_idx"] += 1
|
||||||
|
|
||||||
# cur_band will always be less than s_max or else filled
|
# cur_band will always be less than s_max_1 or else filled
|
||||||
s = self._s_max_1 - len(cur_band) - 1
|
s = len(cur_band)
|
||||||
assert s >= 0, "Current band is filled but adding bracket!"
|
assert s < self._s_max_1, "Current band is filled!"
|
||||||
|
|
||||||
# create new bracket
|
# create new bracket
|
||||||
cur_bracket = Bracket(self._get_n0(s),
|
cur_bracket = Bracket(self._get_n0(s),
|
||||||
|
@ -102,34 +100,44 @@ class HyperBandScheduler(FIFOScheduler):
|
||||||
|
|
||||||
If a given trial finishes and bracket iteration is not done,
|
If a given trial finishes and bracket iteration is not done,
|
||||||
the trial will be paused and resources will be given up.
|
the trial will be paused and resources will be given up.
|
||||||
When bracket iteration is done, Trials will be successively halved,
|
|
||||||
and during each halving phase, bad trials will be stopped while good
|
This scheduler will not start trials but will stop trials.
|
||||||
trials will return to "PENDING". This scheduler will not start trials
|
The current running trial will not be handled,
|
||||||
but will stop trials. The current running trial will not be handled,
|
|
||||||
as the trialrunner will be given control to handle it.
|
as the trialrunner will be given control to handle it.
|
||||||
|
|
||||||
# TODO(rliaw) should be only called if trial has not errored"""
|
# TODO(rliaw) should be only called if trial has not errored"""
|
||||||
bracket, _ = self._trial_info[trial]
|
bracket, _ = self._trial_info[trial]
|
||||||
bracket.update_trial_stats(trial, result)
|
bracket.update_trial_stats(trial, result)
|
||||||
|
|
||||||
if bracket.continue_trial(trial):
|
if bracket.continue_trial(trial):
|
||||||
return TrialScheduler.CONTINUE
|
return TrialScheduler.CONTINUE
|
||||||
|
|
||||||
signal = TrialScheduler.PAUSE
|
action = self._process_bracket(trial_runner, bracket, trial)
|
||||||
|
return action
|
||||||
|
|
||||||
|
def _process_bracket(self, trial_runner, bracket, trial):
|
||||||
|
"""This is called whenever a trial makes progress.
|
||||||
|
|
||||||
|
When all live trials in the bracket have no more iterations left,
|
||||||
|
Trials will be successively halved. If bracket is done, all
|
||||||
|
non-running trials will be stopped and cleaned up,
|
||||||
|
and during each halving phase, bad trials will be stopped while good
|
||||||
|
trials will return to "PENDING"."""
|
||||||
|
|
||||||
|
action = TrialScheduler.PAUSE
|
||||||
if bracket.cur_iter_done():
|
if bracket.cur_iter_done():
|
||||||
if bracket.finished():
|
if bracket.finished():
|
||||||
self._cleanup_bracket(trial_runner, bracket)
|
self._cleanup_bracket(trial_runner, bracket)
|
||||||
return TrialScheduler.STOP
|
return TrialScheduler.STOP
|
||||||
# what if bracket is done and trial not completed?
|
|
||||||
good, bad = bracket.successive_halving()
|
good, bad = bracket.successive_halving()
|
||||||
# kill bad trials
|
# kill bad trials
|
||||||
for t in bad:
|
for t in bad:
|
||||||
self._num_stopped += 1
|
|
||||||
if t.status == Trial.PAUSED:
|
if t.status == Trial.PAUSED:
|
||||||
trial_runner._stop_trial(t)
|
self._cleanup_trial(trial_runner, t, bracket, hard=True)
|
||||||
bracket.cleanup_trial_early(t)
|
elif t.status == Trial.RUNNING:
|
||||||
elif t is trial:
|
self._cleanup_trial(trial_runner, t, bracket, hard=False)
|
||||||
signal = TrialScheduler.STOP
|
action = TrialScheduler.STOP
|
||||||
else:
|
else:
|
||||||
raise Exception("Trial with unexpected status encountered")
|
raise Exception("Trial with unexpected status encountered")
|
||||||
|
|
||||||
|
@ -137,38 +145,42 @@ class HyperBandScheduler(FIFOScheduler):
|
||||||
for t in good:
|
for t in good:
|
||||||
if t.status == Trial.PAUSED:
|
if t.status == Trial.PAUSED:
|
||||||
t.unpause()
|
t.unpause()
|
||||||
elif t is trial:
|
elif t.status == Trial.RUNNING:
|
||||||
signal = TrialScheduler.CONTINUE
|
action = TrialScheduler.CONTINUE
|
||||||
else:
|
else:
|
||||||
raise Exception("Trial with unexpected status encountered")
|
raise Exception("Trial with unexpected status encountered")
|
||||||
|
return action
|
||||||
|
|
||||||
return signal
|
def _cleanup_trial(self, trial_runner, t, bracket, hard=False):
|
||||||
|
"""Bookkeeping for trials finished. If `hard=True`, then
|
||||||
|
this scheduler will force the trial_runner to release resources.
|
||||||
|
|
||||||
|
Otherwise, only clean up trial information locally."""
|
||||||
|
self._num_stopped += 1
|
||||||
|
if hard:
|
||||||
|
trial_runner._stop_trial(t)
|
||||||
|
bracket.cleanup_trial(t)
|
||||||
|
|
||||||
def _cleanup_bracket(self, trial_runner, bracket):
|
def _cleanup_bracket(self, trial_runner, bracket):
|
||||||
"""Cleans up bracket after bracket is completely finished.
|
"""Cleans up bracket after bracket is completely finished."""
|
||||||
|
for trial in bracket.current_trials():
|
||||||
Bracket information will only be cleaned up after the trialrunner has
|
self._cleanup_trial(
|
||||||
finished its bookkeeping."""
|
trial_runner, trial, bracket,
|
||||||
for t in bracket.current_trials():
|
hard=(trial.status == Trial.PAUSED))
|
||||||
if t.status == Trial.PAUSED:
|
|
||||||
trial_runner._stop_trial(t)
|
|
||||||
bracket.cleanup_trial_early(t)
|
|
||||||
|
|
||||||
def on_trial_complete(self, trial_runner, trial, result):
|
def on_trial_complete(self, trial_runner, trial, result):
|
||||||
"""Cleans up trial info from bracket if trial completed early.
|
"""Cleans up trial info from bracket if trial completed early."""
|
||||||
|
|
||||||
Bracket information will only be cleaned up after the trialrunner has
|
|
||||||
finished its bookkeeping."""
|
|
||||||
bracket, _ = self._trial_info[trial]
|
bracket, _ = self._trial_info[trial]
|
||||||
bracket.cleanup_trial_early(trial)
|
self._cleanup_trial(trial_runner, trial, bracket, hard=False)
|
||||||
|
self._process_bracket(trial_runner, bracket, trial)
|
||||||
|
|
||||||
def on_trial_error(self, trial_runner, trial):
|
def on_trial_error(self, trial_runner, trial):
|
||||||
"""Cleans up trial info from bracket if trial errored early.
|
"""Cleans up trial info from bracket if trial errored early."""
|
||||||
|
|
||||||
Bracket information will only be cleaned up after the trialrunner has
|
|
||||||
finished its bookkeeping."""
|
|
||||||
bracket, _ = self._trial_info[trial]
|
bracket, _ = self._trial_info[trial]
|
||||||
bracket.cleanup_trial_early(trial)
|
self._cleanup_trial(trial_runner, trial, bracket, hard=False)
|
||||||
|
self._process_bracket(trial_runner, bracket, trial)
|
||||||
|
|
||||||
def choose_trial_to_run(self, trial_runner, *args):
|
def choose_trial_to_run(self, trial_runner, *args):
|
||||||
"""Fair scheduling within iteration by completion percentage.
|
"""Fair scheduling within iteration by completion percentage.
|
||||||
|
@ -177,6 +189,7 @@ class HyperBandScheduler(FIFOScheduler):
|
||||||
|
|
||||||
If iteration is occupied (ie, no trials to run), then look into
|
If iteration is occupied (ie, no trials to run), then look into
|
||||||
next iteration."""
|
next iteration."""
|
||||||
|
|
||||||
for hyperband in self._hyperbands:
|
for hyperband in self._hyperbands:
|
||||||
for bracket in sorted(hyperband,
|
for bracket in sorted(hyperband,
|
||||||
key=lambda b: b.completion_percentage()):
|
key=lambda b: b.completion_percentage()):
|
||||||
|
@ -187,10 +200,17 @@ class HyperBandScheduler(FIFOScheduler):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def debug_string(self):
|
def debug_string(self):
|
||||||
|
brackets = [
|
||||||
|
"({0}/{1})".format(
|
||||||
|
len(bracket._live_trials), len(bracket._all_trials))
|
||||||
|
for band in self._hyperbands for bracket in band]
|
||||||
return " ".join([
|
return " ".join([
|
||||||
"Using HyperBand:",
|
"Using HyperBand:",
|
||||||
"num_stopped={}".format(self._num_stopped),
|
"num_stopped={}".format(self._num_stopped),
|
||||||
"brackets={}".format(sum(len(band) for band in self._hyperbands))])
|
"total_brackets={}".format(
|
||||||
|
sum(len(band) for band in self._hyperbands)),
|
||||||
|
" ".join(brackets)
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
class Bracket():
|
class Bracket():
|
||||||
|
@ -278,8 +298,9 @@ class Bracket():
|
||||||
self._live_trials[trial] = (result, itr - 1)
|
self._live_trials[trial] = (result, itr - 1)
|
||||||
self._completed_progress += 1
|
self._completed_progress += 1
|
||||||
|
|
||||||
def cleanup_trial_early(self, trial):
|
def cleanup_trial(self, trial):
|
||||||
"""Clean up statistics tracking for trial that terminated early.
|
"""Clean up statistics tracking for terminated trials (either by force
|
||||||
|
or otherwise).
|
||||||
|
|
||||||
This may cause bad trials to continue for a long time, in the case
|
This may cause bad trials to continue for a long time, in the case
|
||||||
where all the good trials finish early and there are only bad trials
|
where all the good trials finish early and there are only bad trials
|
||||||
|
|
|
@ -151,17 +151,26 @@ class _MockTrialRunner():
|
||||||
|
|
||||||
|
|
||||||
class HyperbandSuite(unittest.TestCase):
|
class HyperbandSuite(unittest.TestCase):
|
||||||
def basicSetup(self):
|
|
||||||
"""s_max_1 = 3;
|
|
||||||
brackets: iter (n, r) | iter (n, r) | iter (n, r)
|
|
||||||
(9, 1) -> (3, 3) -> (1, 9)
|
|
||||||
(9, 1) -> (3, 3) -> (1, 9)
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
def schedulerSetup(self, num_trials):
|
||||||
|
"""Setup a scheduler and Runner with max Iter = 9
|
||||||
|
|
||||||
|
Bracketing is placed as follows:
|
||||||
|
(3, 9);
|
||||||
|
(5, 3) -> (2, 9);
|
||||||
|
(9, 1) -> (3, 3) -> (1, 9); """
|
||||||
sched = HyperBandScheduler(9, eta=3)
|
sched = HyperBandScheduler(9, eta=3)
|
||||||
for i in range(17):
|
for i in range(num_trials):
|
||||||
t = Trial("t%d" % i, "__fake")
|
t = Trial("t%d" % i, "__fake")
|
||||||
sched.on_trial_add(None, t)
|
sched.on_trial_add(None, t)
|
||||||
|
runner = _MockTrialRunner()
|
||||||
|
return sched, runner
|
||||||
|
|
||||||
|
def basicSetup(self):
|
||||||
|
"""Setup and verify full band.
|
||||||
|
"""
|
||||||
|
|
||||||
|
sched, _ = self.schedulerSetup(17)
|
||||||
|
|
||||||
self.assertEqual(len(sched._hyperbands), 1)
|
self.assertEqual(len(sched._hyperbands), 1)
|
||||||
self.assertEqual(sched._cur_band_filled(), True)
|
self.assertEqual(sched._cur_band_filled(), True)
|
||||||
|
@ -173,69 +182,80 @@ class HyperbandSuite(unittest.TestCase):
|
||||||
|
|
||||||
def advancedSetup(self):
|
def advancedSetup(self):
|
||||||
sched = self.basicSetup()
|
sched = self.basicSetup()
|
||||||
for i in range(3):
|
for i in range(4):
|
||||||
t = Trial("t%d" % (i + 20), "__fake")
|
t = Trial("t%d" % (i + 20), "__fake")
|
||||||
sched.on_trial_add(None, t)
|
sched.on_trial_add(None, t)
|
||||||
|
|
||||||
self.assertEqual(sched._cur_band_filled(), False)
|
self.assertEqual(sched._cur_band_filled(), False)
|
||||||
|
|
||||||
unfilled_band = sched._hyperbands[1]
|
unfilled_band = sched._hyperbands[-1]
|
||||||
self.assertEqual(len(unfilled_band), 1)
|
self.assertEqual(len(unfilled_band), 2)
|
||||||
self.assertEqual(len(sched._hyperbands[1]), 1)
|
bracket = unfilled_band[-1]
|
||||||
bracket = unfilled_band[0]
|
|
||||||
self.assertEqual(bracket.filled(), False)
|
self.assertEqual(bracket.filled(), False)
|
||||||
self.assertEqual(len(bracket.current_trials()), 3)
|
self.assertEqual(len(bracket.current_trials()), 1)
|
||||||
|
|
||||||
return sched
|
return sched
|
||||||
|
|
||||||
def testBasicHalving(self):
|
def stopTrial(self, trial, mock_runner):
|
||||||
sched = self.advancedSetup()
|
self.assertNotEqual(trial.status, Trial.TERMINATED)
|
||||||
mock_runner = _MockTrialRunner()
|
mock_runner._stop_trial(trial)
|
||||||
filled_band = sched._hyperbands[0]
|
|
||||||
big_bracket = filled_band[0]
|
|
||||||
bracket_trials = big_bracket.current_trials()
|
|
||||||
|
|
||||||
for t in bracket_trials:
|
|
||||||
mock_runner._launch_trial(t)
|
|
||||||
|
|
||||||
for i, t in enumerate(bracket_trials):
|
|
||||||
if i == len(bracket_trials) - 1:
|
|
||||||
break
|
|
||||||
self.assertEqual(
|
|
||||||
TrialScheduler.PAUSE,
|
|
||||||
sched.on_trial_result(mock_runner, t, result(i, 10)))
|
|
||||||
mock_runner._pause_trial(t)
|
|
||||||
self.assertEqual(
|
|
||||||
TrialScheduler.CONTINUE,
|
|
||||||
sched.on_trial_result(
|
|
||||||
mock_runner, bracket_trials[-1], result(7, 12)))
|
|
||||||
|
|
||||||
def testSuccessiveHalving(self):
|
def testSuccessiveHalving(self):
|
||||||
sched = HyperBandScheduler(9, eta=3)
|
"""Setup full band, then iterate through last bracket (n=9)
|
||||||
for i in range(9):
|
to make sure successive halving is correct."""
|
||||||
t = Trial("t%d" % i, "__fake")
|
|
||||||
sched.on_trial_add(None, t)
|
|
||||||
filled_band = sched._hyperbands[0]
|
|
||||||
big_bracket = filled_band[0]
|
|
||||||
mock_runner = _MockTrialRunner()
|
|
||||||
|
|
||||||
current_length = len(big_bracket.current_trials())
|
sched, mock_runner = self.schedulerSetup(17)
|
||||||
for i in range(current_length):
|
filled_band = sched._hyperbands[0][-1]
|
||||||
trl = sched.choose_trial_to_run(mock_runner)
|
big_bracket = filled_band
|
||||||
|
|
||||||
|
for trl in big_bracket.current_trials():
|
||||||
mock_runner._launch_trial(trl)
|
mock_runner._launch_trial(trl)
|
||||||
while True:
|
|
||||||
status = sched.on_trial_result(mock_runner, trl, result(1, 10))
|
# Provides results from 0 to 8 in order, keeping the last one running
|
||||||
|
for i, trl in enumerate(big_bracket.current_trials()):
|
||||||
|
status = sched.on_trial_result(mock_runner, trl, result(1, i))
|
||||||
if status == TrialScheduler.CONTINUE:
|
if status == TrialScheduler.CONTINUE:
|
||||||
continue
|
continue
|
||||||
elif status == TrialScheduler.PAUSE:
|
elif status == TrialScheduler.PAUSE:
|
||||||
mock_runner._pause_trial(trl)
|
mock_runner._pause_trial(trl)
|
||||||
break
|
elif status == TrialScheduler.STOP:
|
||||||
|
self.assertNotEqual(trl.status, Trial.TERMINATED)
|
||||||
|
self.stopTrial(trl, mock_runner)
|
||||||
|
|
||||||
def testBasicRun(self):
|
current_length = len(big_bracket.current_trials())
|
||||||
|
self.assertEqual(status, TrialScheduler.CONTINUE)
|
||||||
|
self.assertEqual(current_length, 3)
|
||||||
|
|
||||||
|
# Techincally only need to launch 2/3, as one is already running
|
||||||
|
for trl in big_bracket.current_trials():
|
||||||
|
mock_runner._launch_trial(trl)
|
||||||
|
|
||||||
|
# Provides results from 2 to 0 in order, killing the last one
|
||||||
|
for i, trl in reversed(list(enumerate(big_bracket.current_trials()))):
|
||||||
|
for j in range(3):
|
||||||
|
status = sched.on_trial_result(mock_runner, trl, result(1, i))
|
||||||
|
if status == TrialScheduler.CONTINUE:
|
||||||
|
continue
|
||||||
|
elif status == TrialScheduler.PAUSE:
|
||||||
|
mock_runner._pause_trial(trl)
|
||||||
|
elif status == TrialScheduler.STOP:
|
||||||
|
self.stopTrial(trl, mock_runner)
|
||||||
|
|
||||||
|
self.assertEqual(status, TrialScheduler.STOP)
|
||||||
|
trl = big_bracket.current_trials()[0]
|
||||||
|
for i in range(9):
|
||||||
|
status = sched.on_trial_result(mock_runner, trl, result(1, i))
|
||||||
|
self.assertEqual(status, TrialScheduler.STOP)
|
||||||
|
self.assertEqual(len(big_bracket.current_trials()), 0)
|
||||||
|
self.assertEqual(sched._num_stopped, 9)
|
||||||
|
|
||||||
|
def testScheduling(self):
|
||||||
|
"""Setup two bands, then make sure all trials are running"""
|
||||||
sched = self.advancedSetup()
|
sched = self.advancedSetup()
|
||||||
mock_runner = _MockTrialRunner()
|
mock_runner = _MockTrialRunner()
|
||||||
trl = sched.choose_trial_to_run(mock_runner)
|
trl = sched.choose_trial_to_run(mock_runner)
|
||||||
while trl:
|
while trl:
|
||||||
|
# If band iteration > 0, make sure first band is all running
|
||||||
if sched._trial_info[trl][1] > 0:
|
if sched._trial_info[trl][1] > 0:
|
||||||
first_band = sched._hyperbands[0]
|
first_band = sched._hyperbands[0]
|
||||||
trials = [t for b in first_band for t in b._live_trials]
|
trials = [t for b in first_band for t in b._live_trials]
|
||||||
|
@ -252,93 +272,67 @@ class HyperbandSuite(unittest.TestCase):
|
||||||
all(t.status == Trial.RUNNING for t in trials), True)
|
all(t.status == Trial.RUNNING for t in trials), True)
|
||||||
|
|
||||||
def testTrialErrored(self):
|
def testTrialErrored(self):
|
||||||
sched = HyperBandScheduler(9, eta=3)
|
sched, mock_runner = self.schedulerSetup(10)
|
||||||
t1 = Trial("t1", "__fake")
|
t1, t2 = sched._state["bracket"].current_trials()
|
||||||
t2 = Trial("t2", "__fake")
|
mock_runner._launch_trial(t1)
|
||||||
sched.on_trial_add(None, t1)
|
mock_runner._launch_trial(t2)
|
||||||
sched.on_trial_add(None, t2)
|
|
||||||
mock_runner = _MockTrialRunner()
|
|
||||||
filled_band = sched._hyperbands[0]
|
|
||||||
big_bracket = filled_band[0]
|
|
||||||
bracket_trials = big_bracket.current_trials()
|
|
||||||
|
|
||||||
for t in bracket_trials:
|
|
||||||
mock_runner._launch_trial(t)
|
|
||||||
|
|
||||||
sched.on_trial_error(mock_runner, t2)
|
sched.on_trial_error(mock_runner, t2)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
TrialScheduler.CONTINUE,
|
TrialScheduler.CONTINUE,
|
||||||
sched.on_trial_result(mock_runner, t1, result(3, 10)))
|
sched.on_trial_result(mock_runner, t1, result(1, 10)))
|
||||||
|
|
||||||
|
def testTrialErrored2(self):
|
||||||
|
"""Check successive halving happened even when last trial failed"""
|
||||||
|
sched, mock_runner = self.schedulerSetup(17)
|
||||||
|
trials = sched._state["bracket"].current_trials()
|
||||||
|
self.assertEqual(len(trials), 9)
|
||||||
|
for t in trials[:-1]:
|
||||||
|
mock_runner._launch_trial(t)
|
||||||
|
sched.on_trial_result(mock_runner, t, result(1, 10))
|
||||||
|
|
||||||
|
mock_runner._launch_trial(trials[-1])
|
||||||
|
sched.on_trial_error(mock_runner, trials[-1])
|
||||||
|
self.assertEqual(len(sched._state["bracket"].current_trials()), 3)
|
||||||
|
|
||||||
def testTrialEndedEarly(self):
|
def testTrialEndedEarly(self):
|
||||||
sched = HyperBandScheduler(9, eta=3)
|
sched, mock_runner = self.schedulerSetup(10)
|
||||||
t1 = Trial("t1", "__fake")
|
trials = sched._state["bracket"].current_trials()
|
||||||
t2 = Trial("t2", "__fake")
|
for t in trials:
|
||||||
sched.on_trial_add(None, t1)
|
|
||||||
sched.on_trial_add(None, t2)
|
|
||||||
mock_runner = _MockTrialRunner()
|
|
||||||
filled_band = sched._hyperbands[0]
|
|
||||||
big_bracket = filled_band[0]
|
|
||||||
bracket_trials = big_bracket.current_trials()
|
|
||||||
|
|
||||||
for t in bracket_trials:
|
|
||||||
mock_runner._launch_trial(t)
|
mock_runner._launch_trial(t)
|
||||||
|
|
||||||
sched.on_trial_complete(mock_runner, t2, result(5, 10))
|
sched.on_trial_complete(mock_runner, trials[-1], result(1, 12))
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
TrialScheduler.CONTINUE,
|
TrialScheduler.CONTINUE,
|
||||||
sched.on_trial_result(mock_runner, t1, result(3, 12)))
|
sched.on_trial_result(mock_runner, trials[0], result(1, 12)))
|
||||||
|
|
||||||
def testAddAfterHalf(self):
|
def testTrialEndedEarly2(self):
|
||||||
sched = HyperBandScheduler(9, eta=3)
|
"""Check successive halving happened even when last trial finished"""
|
||||||
for i in range(2):
|
sched, mock_runner = self.schedulerSetup(17)
|
||||||
t = Trial("t%d" % i, "__fake")
|
trials = sched._state["bracket"].current_trials()
|
||||||
sched.on_trial_add(None, t)
|
self.assertEqual(len(trials), 9)
|
||||||
mock_runner = _MockTrialRunner()
|
for t in trials[:-1]:
|
||||||
filled_band = sched._hyperbands[0]
|
mock_runner._launch_trial(t)
|
||||||
big_bracket = filled_band[0]
|
sched.on_trial_result(mock_runner, t, result(1, 10))
|
||||||
bracket_trials = big_bracket.current_trials()
|
|
||||||
|
mock_runner._launch_trial(trials[-1])
|
||||||
|
sched.on_trial_complete(mock_runner, trials[-1], result(1, 12))
|
||||||
|
self.assertEqual(len(sched._state["bracket"].current_trials()), 3)
|
||||||
|
|
||||||
|
def testAddAfterHalving(self):
|
||||||
|
sched, mock_runner = self.schedulerSetup(10)
|
||||||
|
bracket_trials = sched._state["bracket"].current_trials()
|
||||||
|
|
||||||
for t in bracket_trials:
|
for t in bracket_trials:
|
||||||
mock_runner._launch_trial(t)
|
mock_runner._launch_trial(t)
|
||||||
|
|
||||||
for i, t in enumerate(bracket_trials):
|
for i, t in enumerate(bracket_trials):
|
||||||
if i == len(bracket_trials) - 1:
|
res = sched.on_trial_result(
|
||||||
break
|
mock_runner, t, result(1, i))
|
||||||
self.assertEqual(
|
self.assertEqual(res, TrialScheduler.CONTINUE)
|
||||||
TrialScheduler.PAUSE,
|
|
||||||
sched.on_trial_result(mock_runner, t, result(i, 10)))
|
|
||||||
mock_runner._pause_trial(t)
|
|
||||||
self.assertEqual(
|
|
||||||
TrialScheduler.CONTINUE,
|
|
||||||
sched.on_trial_result(
|
|
||||||
mock_runner, bracket_trials[-1], result(7, 12)))
|
|
||||||
t = Trial("t%d" % 5, "__fake")
|
t = Trial("t%d" % 5, "__fake")
|
||||||
sched.on_trial_add(None, t)
|
sched.on_trial_add(None, t)
|
||||||
self.assertEqual(4, big_bracket._live_trials[t][1])
|
self.assertEqual(3 + 1, sched._state["bracket"]._live_trials[t][1])
|
||||||
|
|
||||||
def testDone(self):
|
|
||||||
sched = HyperBandScheduler(3, eta=3)
|
|
||||||
mock_runner = _MockTrialRunner()
|
|
||||||
trials = [Trial("t%d" % i, "__fake") for i in range(5)]
|
|
||||||
for t in trials:
|
|
||||||
sched.on_trial_add(None, t)
|
|
||||||
|
|
||||||
filled_band = sched._hyperbands[0]
|
|
||||||
brack = filled_band[1]
|
|
||||||
bracket_trials = brack.current_trials()
|
|
||||||
for t in bracket_trials:
|
|
||||||
mock_runner._launch_trial(t)
|
|
||||||
for i in range(3):
|
|
||||||
res = sched.on_trial_result(
|
|
||||||
mock_runner, bracket_trials[-1], result(i, 10))
|
|
||||||
self.assertEqual(res, TrialScheduler.PAUSE)
|
|
||||||
mock_runner._pause_trial(bracket_trials[-1])
|
|
||||||
for i in range(3):
|
|
||||||
res = sched.on_trial_result(
|
|
||||||
mock_runner, bracket_trials[-2], result(i, 10))
|
|
||||||
self.assertEqual(res, TrialScheduler.STOP)
|
|
||||||
self.assertEqual(len(brack.current_trials()), 1)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
Loading…
Add table
Reference in a new issue