[tune] Fixing up Hyperband (#1207)

* Fixing up Hyperband

* nit

* cleanup

* Timing test Added

* added_exception_back

* fixup_tests

* reverse placement

* fixes_and_tests

* fix

* fix

* fixlint

* cleanup_timing

* lint

* Update hyperband.py
This commit is contained in:
Richard Liaw 2017-11-12 12:05:32 -08:00 committed by GitHub
parent 7c38f964b7
commit 71f8cd2403
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 176 additions and 161 deletions

View file

@ -60,11 +60,9 @@ class HyperBandScheduler(FIFOScheduler):
def on_trial_add(self, trial_runner, trial):
"""On a new trial add, if current bracket is not filled,
add to current bracket. Else, if current hp iteration is not filled,
add to current bracket. Else, if current band is not filled,
create new bracket, add to current bracket.
Else, create new iteration, create new bracket, add to bracket.
TODO(rliaw): This is messy."""
Else, create new iteration, create new bracket, add to bracket."""
cur_bracket = self._state["bracket"]
cur_band = self._hyperbands[self._state["band_idx"]]
@ -76,9 +74,9 @@ class HyperBandScheduler(FIFOScheduler):
self._hyperbands.append(cur_band)
self._state["band_idx"] += 1
# cur_band will always be less than s_max or else filled
s = self._s_max_1 - len(cur_band) - 1
assert s >= 0, "Current band is filled but adding bracket!"
# cur_band will always be less than s_max_1 or else filled
s = len(cur_band)
assert s < self._s_max_1, "Current band is filled!"
# create new bracket
cur_bracket = Bracket(self._get_n0(s),
@ -102,34 +100,44 @@ class HyperBandScheduler(FIFOScheduler):
If a given trial finishes and bracket iteration is not done,
the trial will be paused and resources will be given up.
When bracket iteration is done, Trials will be successively halved,
and during each halving phase, bad trials will be stopped while good
trials will return to "PENDING". This scheduler will not start trials
but will stop trials. The current running trial will not be handled,
This scheduler will not start trials but will stop trials.
The current running trial will not be handled,
as the trialrunner will be given control to handle it.
# TODO(rliaw) should be only called if trial has not errored"""
bracket, _ = self._trial_info[trial]
bracket.update_trial_stats(trial, result)
if bracket.continue_trial(trial):
return TrialScheduler.CONTINUE
signal = TrialScheduler.PAUSE
action = self._process_bracket(trial_runner, bracket, trial)
return action
def _process_bracket(self, trial_runner, bracket, trial):
"""This is called whenever a trial makes progress.
When all live trials in the bracket have no more iterations left,
Trials will be successively halved. If bracket is done, all
non-running trials will be stopped and cleaned up,
and during each halving phase, bad trials will be stopped while good
trials will return to "PENDING"."""
action = TrialScheduler.PAUSE
if bracket.cur_iter_done():
if bracket.finished():
self._cleanup_bracket(trial_runner, bracket)
return TrialScheduler.STOP
# what if bracket is done and trial not completed?
good, bad = bracket.successive_halving()
# kill bad trials
for t in bad:
self._num_stopped += 1
if t.status == Trial.PAUSED:
trial_runner._stop_trial(t)
bracket.cleanup_trial_early(t)
elif t is trial:
signal = TrialScheduler.STOP
self._cleanup_trial(trial_runner, t, bracket, hard=True)
elif t.status == Trial.RUNNING:
self._cleanup_trial(trial_runner, t, bracket, hard=False)
action = TrialScheduler.STOP
else:
raise Exception("Trial with unexpected status encountered")
@ -137,38 +145,42 @@ class HyperBandScheduler(FIFOScheduler):
for t in good:
if t.status == Trial.PAUSED:
t.unpause()
elif t is trial:
signal = TrialScheduler.CONTINUE
elif t.status == Trial.RUNNING:
action = TrialScheduler.CONTINUE
else:
raise Exception("Trial with unexpected status encountered")
return action
return signal
def _cleanup_trial(self, trial_runner, t, bracket, hard=False):
"""Bookkeeping for trials finished. If `hard=True`, then
this scheduler will force the trial_runner to release resources.
Otherwise, only clean up trial information locally."""
self._num_stopped += 1
if hard:
trial_runner._stop_trial(t)
bracket.cleanup_trial(t)
def _cleanup_bracket(self, trial_runner, bracket):
"""Cleans up bracket after bracket is completely finished.
Bracket information will only be cleaned up after the trialrunner has
finished its bookkeeping."""
for t in bracket.current_trials():
if t.status == Trial.PAUSED:
trial_runner._stop_trial(t)
bracket.cleanup_trial_early(t)
"""Cleans up bracket after bracket is completely finished."""
for trial in bracket.current_trials():
self._cleanup_trial(
trial_runner, trial, bracket,
hard=(trial.status == Trial.PAUSED))
def on_trial_complete(self, trial_runner, trial, result):
"""Cleans up trial info from bracket if trial completed early.
"""Cleans up trial info from bracket if trial completed early."""
Bracket information will only be cleaned up after the trialrunner has
finished its bookkeeping."""
bracket, _ = self._trial_info[trial]
bracket.cleanup_trial_early(trial)
self._cleanup_trial(trial_runner, trial, bracket, hard=False)
self._process_bracket(trial_runner, bracket, trial)
def on_trial_error(self, trial_runner, trial):
"""Cleans up trial info from bracket if trial errored early.
"""Cleans up trial info from bracket if trial errored early."""
Bracket information will only be cleaned up after the trialrunner has
finished its bookkeeping."""
bracket, _ = self._trial_info[trial]
bracket.cleanup_trial_early(trial)
self._cleanup_trial(trial_runner, trial, bracket, hard=False)
self._process_bracket(trial_runner, bracket, trial)
def choose_trial_to_run(self, trial_runner, *args):
"""Fair scheduling within iteration by completion percentage.
@ -177,6 +189,7 @@ class HyperBandScheduler(FIFOScheduler):
If iteration is occupied (ie, no trials to run), then look into
next iteration."""
for hyperband in self._hyperbands:
for bracket in sorted(hyperband,
key=lambda b: b.completion_percentage()):
@ -187,10 +200,17 @@ class HyperBandScheduler(FIFOScheduler):
return None
def debug_string(self):
brackets = [
"({0}/{1})".format(
len(bracket._live_trials), len(bracket._all_trials))
for band in self._hyperbands for bracket in band]
return " ".join([
"Using HyperBand:",
"num_stopped={}".format(self._num_stopped),
"brackets={}".format(sum(len(band) for band in self._hyperbands))])
"total_brackets={}".format(
sum(len(band) for band in self._hyperbands)),
" ".join(brackets)
])
class Bracket():
@ -278,8 +298,9 @@ class Bracket():
self._live_trials[trial] = (result, itr - 1)
self._completed_progress += 1
def cleanup_trial_early(self, trial):
"""Clean up statistics tracking for trial that terminated early.
def cleanup_trial(self, trial):
"""Clean up statistics tracking for terminated trials (either by force
or otherwise).
This may cause bad trials to continue for a long time, in the case
where all the good trials finish early and there are only bad trials

View file

@ -151,17 +151,26 @@ class _MockTrialRunner():
class HyperbandSuite(unittest.TestCase):
def basicSetup(self):
"""s_max_1 = 3;
brackets: iter (n, r) | iter (n, r) | iter (n, r)
(9, 1) -> (3, 3) -> (1, 9)
(9, 1) -> (3, 3) -> (1, 9)
"""
def schedulerSetup(self, num_trials):
"""Setup a scheduler and Runner with max Iter = 9
Bracketing is placed as follows:
(3, 9);
(5, 3) -> (2, 9);
(9, 1) -> (3, 3) -> (1, 9); """
sched = HyperBandScheduler(9, eta=3)
for i in range(17):
for i in range(num_trials):
t = Trial("t%d" % i, "__fake")
sched.on_trial_add(None, t)
runner = _MockTrialRunner()
return sched, runner
def basicSetup(self):
"""Setup and verify full band.
"""
sched, _ = self.schedulerSetup(17)
self.assertEqual(len(sched._hyperbands), 1)
self.assertEqual(sched._cur_band_filled(), True)
@ -173,69 +182,80 @@ class HyperbandSuite(unittest.TestCase):
def advancedSetup(self):
sched = self.basicSetup()
for i in range(3):
for i in range(4):
t = Trial("t%d" % (i + 20), "__fake")
sched.on_trial_add(None, t)
self.assertEqual(sched._cur_band_filled(), False)
unfilled_band = sched._hyperbands[1]
self.assertEqual(len(unfilled_band), 1)
self.assertEqual(len(sched._hyperbands[1]), 1)
bracket = unfilled_band[0]
unfilled_band = sched._hyperbands[-1]
self.assertEqual(len(unfilled_band), 2)
bracket = unfilled_band[-1]
self.assertEqual(bracket.filled(), False)
self.assertEqual(len(bracket.current_trials()), 3)
self.assertEqual(len(bracket.current_trials()), 1)
return sched
def testBasicHalving(self):
sched = self.advancedSetup()
mock_runner = _MockTrialRunner()
filled_band = sched._hyperbands[0]
big_bracket = filled_band[0]
bracket_trials = big_bracket.current_trials()
for t in bracket_trials:
mock_runner._launch_trial(t)
for i, t in enumerate(bracket_trials):
if i == len(bracket_trials) - 1:
break
self.assertEqual(
TrialScheduler.PAUSE,
sched.on_trial_result(mock_runner, t, result(i, 10)))
mock_runner._pause_trial(t)
self.assertEqual(
TrialScheduler.CONTINUE,
sched.on_trial_result(
mock_runner, bracket_trials[-1], result(7, 12)))
def stopTrial(self, trial, mock_runner):
self.assertNotEqual(trial.status, Trial.TERMINATED)
mock_runner._stop_trial(trial)
def testSuccessiveHalving(self):
sched = HyperBandScheduler(9, eta=3)
for i in range(9):
t = Trial("t%d" % i, "__fake")
sched.on_trial_add(None, t)
filled_band = sched._hyperbands[0]
big_bracket = filled_band[0]
mock_runner = _MockTrialRunner()
"""Setup full band, then iterate through last bracket (n=9)
to make sure successive halving is correct."""
sched, mock_runner = self.schedulerSetup(17)
filled_band = sched._hyperbands[0][-1]
big_bracket = filled_band
for trl in big_bracket.current_trials():
mock_runner._launch_trial(trl)
# Provides results from 0 to 8 in order, keeping the last one running
for i, trl in enumerate(big_bracket.current_trials()):
status = sched.on_trial_result(mock_runner, trl, result(1, i))
if status == TrialScheduler.CONTINUE:
continue
elif status == TrialScheduler.PAUSE:
mock_runner._pause_trial(trl)
elif status == TrialScheduler.STOP:
self.assertNotEqual(trl.status, Trial.TERMINATED)
self.stopTrial(trl, mock_runner)
current_length = len(big_bracket.current_trials())
for i in range(current_length):
trl = sched.choose_trial_to_run(mock_runner)
mock_runner._launch_trial(trl)
while True:
status = sched.on_trial_result(mock_runner, trl, result(1, 10))
if status == TrialScheduler.CONTINUE:
continue
elif status == TrialScheduler.PAUSE:
mock_runner._pause_trial(trl)
break
self.assertEqual(status, TrialScheduler.CONTINUE)
self.assertEqual(current_length, 3)
def testBasicRun(self):
# Technically only need to launch 2/3, as one is already running
for trl in big_bracket.current_trials():
mock_runner._launch_trial(trl)
# Provides results from 2 to 0 in order, killing the last one
for i, trl in reversed(list(enumerate(big_bracket.current_trials()))):
for j in range(3):
status = sched.on_trial_result(mock_runner, trl, result(1, i))
if status == TrialScheduler.CONTINUE:
continue
elif status == TrialScheduler.PAUSE:
mock_runner._pause_trial(trl)
elif status == TrialScheduler.STOP:
self.stopTrial(trl, mock_runner)
self.assertEqual(status, TrialScheduler.STOP)
trl = big_bracket.current_trials()[0]
for i in range(9):
status = sched.on_trial_result(mock_runner, trl, result(1, i))
self.assertEqual(status, TrialScheduler.STOP)
self.assertEqual(len(big_bracket.current_trials()), 0)
self.assertEqual(sched._num_stopped, 9)
def testScheduling(self):
"""Setup two bands, then make sure all trials are running"""
sched = self.advancedSetup()
mock_runner = _MockTrialRunner()
trl = sched.choose_trial_to_run(mock_runner)
while trl:
# If band iteration > 0, make sure first band is all running
if sched._trial_info[trl][1] > 0:
first_band = sched._hyperbands[0]
trials = [t for b in first_band for t in b._live_trials]
@ -252,93 +272,67 @@ class HyperbandSuite(unittest.TestCase):
all(t.status == Trial.RUNNING for t in trials), True)
def testTrialErrored(self):
sched = HyperBandScheduler(9, eta=3)
t1 = Trial("t1", "__fake")
t2 = Trial("t2", "__fake")
sched.on_trial_add(None, t1)
sched.on_trial_add(None, t2)
mock_runner = _MockTrialRunner()
filled_band = sched._hyperbands[0]
big_bracket = filled_band[0]
bracket_trials = big_bracket.current_trials()
for t in bracket_trials:
mock_runner._launch_trial(t)
sched, mock_runner = self.schedulerSetup(10)
t1, t2 = sched._state["bracket"].current_trials()
mock_runner._launch_trial(t1)
mock_runner._launch_trial(t2)
sched.on_trial_error(mock_runner, t2)
self.assertEqual(
TrialScheduler.CONTINUE,
sched.on_trial_result(mock_runner, t1, result(3, 10)))
sched.on_trial_result(mock_runner, t1, result(1, 10)))
def testTrialErrored2(self):
"""Check successive halving happened even when last trial failed"""
sched, mock_runner = self.schedulerSetup(17)
trials = sched._state["bracket"].current_trials()
self.assertEqual(len(trials), 9)
for t in trials[:-1]:
mock_runner._launch_trial(t)
sched.on_trial_result(mock_runner, t, result(1, 10))
mock_runner._launch_trial(trials[-1])
sched.on_trial_error(mock_runner, trials[-1])
self.assertEqual(len(sched._state["bracket"].current_trials()), 3)
def testTrialEndedEarly(self):
sched = HyperBandScheduler(9, eta=3)
t1 = Trial("t1", "__fake")
t2 = Trial("t2", "__fake")
sched.on_trial_add(None, t1)
sched.on_trial_add(None, t2)
mock_runner = _MockTrialRunner()
filled_band = sched._hyperbands[0]
big_bracket = filled_band[0]
bracket_trials = big_bracket.current_trials()
for t in bracket_trials:
sched, mock_runner = self.schedulerSetup(10)
trials = sched._state["bracket"].current_trials()
for t in trials:
mock_runner._launch_trial(t)
sched.on_trial_complete(mock_runner, t2, result(5, 10))
sched.on_trial_complete(mock_runner, trials[-1], result(1, 12))
self.assertEqual(
TrialScheduler.CONTINUE,
sched.on_trial_result(mock_runner, t1, result(3, 12)))
sched.on_trial_result(mock_runner, trials[0], result(1, 12)))
def testAddAfterHalf(self):
sched = HyperBandScheduler(9, eta=3)
for i in range(2):
t = Trial("t%d" % i, "__fake")
sched.on_trial_add(None, t)
mock_runner = _MockTrialRunner()
filled_band = sched._hyperbands[0]
big_bracket = filled_band[0]
bracket_trials = big_bracket.current_trials()
def testTrialEndedEarly2(self):
"""Check successive halving happened even when last trial finished"""
sched, mock_runner = self.schedulerSetup(17)
trials = sched._state["bracket"].current_trials()
self.assertEqual(len(trials), 9)
for t in trials[:-1]:
mock_runner._launch_trial(t)
sched.on_trial_result(mock_runner, t, result(1, 10))
mock_runner._launch_trial(trials[-1])
sched.on_trial_complete(mock_runner, trials[-1], result(1, 12))
self.assertEqual(len(sched._state["bracket"].current_trials()), 3)
def testAddAfterHalving(self):
sched, mock_runner = self.schedulerSetup(10)
bracket_trials = sched._state["bracket"].current_trials()
for t in bracket_trials:
mock_runner._launch_trial(t)
for i, t in enumerate(bracket_trials):
if i == len(bracket_trials) - 1:
break
self.assertEqual(
TrialScheduler.PAUSE,
sched.on_trial_result(mock_runner, t, result(i, 10)))
mock_runner._pause_trial(t)
self.assertEqual(
TrialScheduler.CONTINUE,
sched.on_trial_result(
mock_runner, bracket_trials[-1], result(7, 12)))
res = sched.on_trial_result(
mock_runner, t, result(1, i))
self.assertEqual(res, TrialScheduler.CONTINUE)
t = Trial("t%d" % 5, "__fake")
sched.on_trial_add(None, t)
self.assertEqual(4, big_bracket._live_trials[t][1])
def testDone(self):
sched = HyperBandScheduler(3, eta=3)
mock_runner = _MockTrialRunner()
trials = [Trial("t%d" % i, "__fake") for i in range(5)]
for t in trials:
sched.on_trial_add(None, t)
filled_band = sched._hyperbands[0]
brack = filled_band[1]
bracket_trials = brack.current_trials()
for t in bracket_trials:
mock_runner._launch_trial(t)
for i in range(3):
res = sched.on_trial_result(
mock_runner, bracket_trials[-1], result(i, 10))
self.assertEqual(res, TrialScheduler.PAUSE)
mock_runner._pause_trial(bracket_trials[-1])
for i in range(3):
res = sched.on_trial_result(
mock_runner, bracket_trials[-2], result(i, 10))
self.assertEqual(res, TrialScheduler.STOP)
self.assertEqual(len(brack.current_trials()), 1)
self.assertEqual(3 + 1, sched._state["bracket"]._live_trials[t][1])
if __name__ == "__main__":