mirror of
https://github.com/vale981/ray
synced 2025-03-10 21:36:39 -04:00
[tune] Fixing up Hyperband (#1207)
* Fixing up Hyperband * nit * cleanup * Timing test Added * added_exception_back * fixup_tests * reverse placement * fixes_and_tests * fix * fix * fixlint * cleanup_timing * lint * Update hyperband.py
This commit is contained in:
parent
7c38f964b7
commit
71f8cd2403
2 changed files with 176 additions and 161 deletions
|
@ -60,11 +60,9 @@ class HyperBandScheduler(FIFOScheduler):
|
|||
|
||||
def on_trial_add(self, trial_runner, trial):
|
||||
"""On a new trial add, if current bracket is not filled,
|
||||
add to current bracket. Else, if current hp iteration is not filled,
|
||||
add to current bracket. Else, if current band is not filled,
|
||||
create new bracket, add to current bracket.
|
||||
Else, create new iteration, create new bracket, add to bracket.
|
||||
|
||||
TODO(rliaw): This is messy."""
|
||||
Else, create new iteration, create new bracket, add to bracket."""
|
||||
|
||||
cur_bracket = self._state["bracket"]
|
||||
cur_band = self._hyperbands[self._state["band_idx"]]
|
||||
|
@ -76,9 +74,9 @@ class HyperBandScheduler(FIFOScheduler):
|
|||
self._hyperbands.append(cur_band)
|
||||
self._state["band_idx"] += 1
|
||||
|
||||
# cur_band will always be less than s_max or else filled
|
||||
s = self._s_max_1 - len(cur_band) - 1
|
||||
assert s >= 0, "Current band is filled but adding bracket!"
|
||||
# cur_band will always be less than s_max_1 or else filled
|
||||
s = len(cur_band)
|
||||
assert s < self._s_max_1, "Current band is filled!"
|
||||
|
||||
# create new bracket
|
||||
cur_bracket = Bracket(self._get_n0(s),
|
||||
|
@ -102,34 +100,44 @@ class HyperBandScheduler(FIFOScheduler):
|
|||
|
||||
If a given trial finishes and bracket iteration is not done,
|
||||
the trial will be paused and resources will be given up.
|
||||
When bracket iteration is done, Trials will be successively halved,
|
||||
and during each halving phase, bad trials will be stopped while good
|
||||
trials will return to "PENDING". This scheduler will not start trials
|
||||
but will stop trials. The current running trial will not be handled,
|
||||
|
||||
This scheduler will not start trials but will stop trials.
|
||||
The current running trial will not be handled,
|
||||
as the trialrunner will be given control to handle it.
|
||||
|
||||
# TODO(rliaw) should be only called if trial has not errored"""
|
||||
bracket, _ = self._trial_info[trial]
|
||||
bracket.update_trial_stats(trial, result)
|
||||
|
||||
if bracket.continue_trial(trial):
|
||||
return TrialScheduler.CONTINUE
|
||||
|
||||
signal = TrialScheduler.PAUSE
|
||||
action = self._process_bracket(trial_runner, bracket, trial)
|
||||
return action
|
||||
|
||||
def _process_bracket(self, trial_runner, bracket, trial):
|
||||
"""This is called whenever a trial makes progress.
|
||||
|
||||
When all live trials in the bracket have no more iterations left,
|
||||
Trials will be successively halved. If bracket is done, all
|
||||
non-running trials will be stopped and cleaned up,
|
||||
and during each halving phase, bad trials will be stopped while good
|
||||
trials will return to "PENDING"."""
|
||||
|
||||
action = TrialScheduler.PAUSE
|
||||
if bracket.cur_iter_done():
|
||||
if bracket.finished():
|
||||
self._cleanup_bracket(trial_runner, bracket)
|
||||
return TrialScheduler.STOP
|
||||
# what if bracket is done and trial not completed?
|
||||
|
||||
good, bad = bracket.successive_halving()
|
||||
# kill bad trials
|
||||
for t in bad:
|
||||
self._num_stopped += 1
|
||||
if t.status == Trial.PAUSED:
|
||||
trial_runner._stop_trial(t)
|
||||
bracket.cleanup_trial_early(t)
|
||||
elif t is trial:
|
||||
signal = TrialScheduler.STOP
|
||||
self._cleanup_trial(trial_runner, t, bracket, hard=True)
|
||||
elif t.status == Trial.RUNNING:
|
||||
self._cleanup_trial(trial_runner, t, bracket, hard=False)
|
||||
action = TrialScheduler.STOP
|
||||
else:
|
||||
raise Exception("Trial with unexpected status encountered")
|
||||
|
||||
|
@ -137,38 +145,42 @@ class HyperBandScheduler(FIFOScheduler):
|
|||
for t in good:
|
||||
if t.status == Trial.PAUSED:
|
||||
t.unpause()
|
||||
elif t is trial:
|
||||
signal = TrialScheduler.CONTINUE
|
||||
elif t.status == Trial.RUNNING:
|
||||
action = TrialScheduler.CONTINUE
|
||||
else:
|
||||
raise Exception("Trial with unexpected status encountered")
|
||||
return action
|
||||
|
||||
return signal
|
||||
def _cleanup_trial(self, trial_runner, t, bracket, hard=False):
|
||||
"""Bookkeeping for trials finished. If `hard=True`, then
|
||||
this scheduler will force the trial_runner to release resources.
|
||||
|
||||
Otherwise, only clean up trial information locally."""
|
||||
self._num_stopped += 1
|
||||
if hard:
|
||||
trial_runner._stop_trial(t)
|
||||
bracket.cleanup_trial(t)
|
||||
|
||||
def _cleanup_bracket(self, trial_runner, bracket):
|
||||
"""Cleans up bracket after bracket is completely finished.
|
||||
|
||||
Bracket information will only be cleaned up after the trialrunner has
|
||||
finished its bookkeeping."""
|
||||
for t in bracket.current_trials():
|
||||
if t.status == Trial.PAUSED:
|
||||
trial_runner._stop_trial(t)
|
||||
bracket.cleanup_trial_early(t)
|
||||
"""Cleans up bracket after bracket is completely finished."""
|
||||
for trial in bracket.current_trials():
|
||||
self._cleanup_trial(
|
||||
trial_runner, trial, bracket,
|
||||
hard=(trial.status == Trial.PAUSED))
|
||||
|
||||
def on_trial_complete(self, trial_runner, trial, result):
|
||||
"""Cleans up trial info from bracket if trial completed early.
|
||||
"""Cleans up trial info from bracket if trial completed early."""
|
||||
|
||||
Bracket information will only be cleaned up after the trialrunner has
|
||||
finished its bookkeeping."""
|
||||
bracket, _ = self._trial_info[trial]
|
||||
bracket.cleanup_trial_early(trial)
|
||||
self._cleanup_trial(trial_runner, trial, bracket, hard=False)
|
||||
self._process_bracket(trial_runner, bracket, trial)
|
||||
|
||||
def on_trial_error(self, trial_runner, trial):
|
||||
"""Cleans up trial info from bracket if trial errored early.
|
||||
"""Cleans up trial info from bracket if trial errored early."""
|
||||
|
||||
Bracket information will only be cleaned up after the trialrunner has
|
||||
finished its bookkeeping."""
|
||||
bracket, _ = self._trial_info[trial]
|
||||
bracket.cleanup_trial_early(trial)
|
||||
self._cleanup_trial(trial_runner, trial, bracket, hard=False)
|
||||
self._process_bracket(trial_runner, bracket, trial)
|
||||
|
||||
def choose_trial_to_run(self, trial_runner, *args):
|
||||
"""Fair scheduling within iteration by completion percentage.
|
||||
|
@ -177,6 +189,7 @@ class HyperBandScheduler(FIFOScheduler):
|
|||
|
||||
If iteration is occupied (ie, no trials to run), then look into
|
||||
next iteration."""
|
||||
|
||||
for hyperband in self._hyperbands:
|
||||
for bracket in sorted(hyperband,
|
||||
key=lambda b: b.completion_percentage()):
|
||||
|
@ -187,10 +200,17 @@ class HyperBandScheduler(FIFOScheduler):
|
|||
return None
|
||||
|
||||
def debug_string(self):
|
||||
brackets = [
|
||||
"({0}/{1})".format(
|
||||
len(bracket._live_trials), len(bracket._all_trials))
|
||||
for band in self._hyperbands for bracket in band]
|
||||
return " ".join([
|
||||
"Using HyperBand:",
|
||||
"num_stopped={}".format(self._num_stopped),
|
||||
"brackets={}".format(sum(len(band) for band in self._hyperbands))])
|
||||
"total_brackets={}".format(
|
||||
sum(len(band) for band in self._hyperbands)),
|
||||
" ".join(brackets)
|
||||
])
|
||||
|
||||
|
||||
class Bracket():
|
||||
|
@ -278,8 +298,9 @@ class Bracket():
|
|||
self._live_trials[trial] = (result, itr - 1)
|
||||
self._completed_progress += 1
|
||||
|
||||
def cleanup_trial_early(self, trial):
|
||||
"""Clean up statistics tracking for trial that terminated early.
|
||||
def cleanup_trial(self, trial):
|
||||
"""Clean up statistics tracking for terminated trials (either by force
|
||||
or otherwise).
|
||||
|
||||
This may cause bad trials to continue for a long time, in the case
|
||||
where all the good trials finish early and there are only bad trials
|
||||
|
|
|
@ -151,17 +151,26 @@ class _MockTrialRunner():
|
|||
|
||||
|
||||
class HyperbandSuite(unittest.TestCase):
|
||||
def basicSetup(self):
|
||||
"""s_max_1 = 3;
|
||||
brackets: iter (n, r) | iter (n, r) | iter (n, r)
|
||||
(9, 1) -> (3, 3) -> (1, 9)
|
||||
(9, 1) -> (3, 3) -> (1, 9)
|
||||
"""
|
||||
|
||||
def schedulerSetup(self, num_trials):
|
||||
"""Setup a scheduler and Runner with max Iter = 9
|
||||
|
||||
Bracketing is placed as follows:
|
||||
(3, 9);
|
||||
(5, 3) -> (2, 9);
|
||||
(9, 1) -> (3, 3) -> (1, 9); """
|
||||
sched = HyperBandScheduler(9, eta=3)
|
||||
for i in range(17):
|
||||
for i in range(num_trials):
|
||||
t = Trial("t%d" % i, "__fake")
|
||||
sched.on_trial_add(None, t)
|
||||
runner = _MockTrialRunner()
|
||||
return sched, runner
|
||||
|
||||
def basicSetup(self):
|
||||
"""Setup and verify full band.
|
||||
"""
|
||||
|
||||
sched, _ = self.schedulerSetup(17)
|
||||
|
||||
self.assertEqual(len(sched._hyperbands), 1)
|
||||
self.assertEqual(sched._cur_band_filled(), True)
|
||||
|
@ -173,69 +182,80 @@ class HyperbandSuite(unittest.TestCase):
|
|||
|
||||
def advancedSetup(self):
|
||||
sched = self.basicSetup()
|
||||
for i in range(3):
|
||||
for i in range(4):
|
||||
t = Trial("t%d" % (i + 20), "__fake")
|
||||
sched.on_trial_add(None, t)
|
||||
|
||||
self.assertEqual(sched._cur_band_filled(), False)
|
||||
|
||||
unfilled_band = sched._hyperbands[1]
|
||||
self.assertEqual(len(unfilled_band), 1)
|
||||
self.assertEqual(len(sched._hyperbands[1]), 1)
|
||||
bracket = unfilled_band[0]
|
||||
unfilled_band = sched._hyperbands[-1]
|
||||
self.assertEqual(len(unfilled_band), 2)
|
||||
bracket = unfilled_band[-1]
|
||||
self.assertEqual(bracket.filled(), False)
|
||||
self.assertEqual(len(bracket.current_trials()), 3)
|
||||
self.assertEqual(len(bracket.current_trials()), 1)
|
||||
|
||||
return sched
|
||||
|
||||
def testBasicHalving(self):
|
||||
sched = self.advancedSetup()
|
||||
mock_runner = _MockTrialRunner()
|
||||
filled_band = sched._hyperbands[0]
|
||||
big_bracket = filled_band[0]
|
||||
bracket_trials = big_bracket.current_trials()
|
||||
|
||||
for t in bracket_trials:
|
||||
mock_runner._launch_trial(t)
|
||||
|
||||
for i, t in enumerate(bracket_trials):
|
||||
if i == len(bracket_trials) - 1:
|
||||
break
|
||||
self.assertEqual(
|
||||
TrialScheduler.PAUSE,
|
||||
sched.on_trial_result(mock_runner, t, result(i, 10)))
|
||||
mock_runner._pause_trial(t)
|
||||
self.assertEqual(
|
||||
TrialScheduler.CONTINUE,
|
||||
sched.on_trial_result(
|
||||
mock_runner, bracket_trials[-1], result(7, 12)))
|
||||
def stopTrial(self, trial, mock_runner):
|
||||
self.assertNotEqual(trial.status, Trial.TERMINATED)
|
||||
mock_runner._stop_trial(trial)
|
||||
|
||||
def testSuccessiveHalving(self):
|
||||
sched = HyperBandScheduler(9, eta=3)
|
||||
for i in range(9):
|
||||
t = Trial("t%d" % i, "__fake")
|
||||
sched.on_trial_add(None, t)
|
||||
filled_band = sched._hyperbands[0]
|
||||
big_bracket = filled_band[0]
|
||||
mock_runner = _MockTrialRunner()
|
||||
"""Setup full band, then iterate through last bracket (n=9)
|
||||
to make sure successive halving is correct."""
|
||||
|
||||
sched, mock_runner = self.schedulerSetup(17)
|
||||
filled_band = sched._hyperbands[0][-1]
|
||||
big_bracket = filled_band
|
||||
|
||||
for trl in big_bracket.current_trials():
|
||||
mock_runner._launch_trial(trl)
|
||||
|
||||
# Provides results from 0 to 8 in order, keeping the last one running
|
||||
for i, trl in enumerate(big_bracket.current_trials()):
|
||||
status = sched.on_trial_result(mock_runner, trl, result(1, i))
|
||||
if status == TrialScheduler.CONTINUE:
|
||||
continue
|
||||
elif status == TrialScheduler.PAUSE:
|
||||
mock_runner._pause_trial(trl)
|
||||
elif status == TrialScheduler.STOP:
|
||||
self.assertNotEqual(trl.status, Trial.TERMINATED)
|
||||
self.stopTrial(trl, mock_runner)
|
||||
|
||||
current_length = len(big_bracket.current_trials())
|
||||
for i in range(current_length):
|
||||
trl = sched.choose_trial_to_run(mock_runner)
|
||||
mock_runner._launch_trial(trl)
|
||||
while True:
|
||||
status = sched.on_trial_result(mock_runner, trl, result(1, 10))
|
||||
if status == TrialScheduler.CONTINUE:
|
||||
continue
|
||||
elif status == TrialScheduler.PAUSE:
|
||||
mock_runner._pause_trial(trl)
|
||||
break
|
||||
self.assertEqual(status, TrialScheduler.CONTINUE)
|
||||
self.assertEqual(current_length, 3)
|
||||
|
||||
def testBasicRun(self):
|
||||
# Techincally only need to launch 2/3, as one is already running
|
||||
for trl in big_bracket.current_trials():
|
||||
mock_runner._launch_trial(trl)
|
||||
|
||||
# Provides results from 2 to 0 in order, killing the last one
|
||||
for i, trl in reversed(list(enumerate(big_bracket.current_trials()))):
|
||||
for j in range(3):
|
||||
status = sched.on_trial_result(mock_runner, trl, result(1, i))
|
||||
if status == TrialScheduler.CONTINUE:
|
||||
continue
|
||||
elif status == TrialScheduler.PAUSE:
|
||||
mock_runner._pause_trial(trl)
|
||||
elif status == TrialScheduler.STOP:
|
||||
self.stopTrial(trl, mock_runner)
|
||||
|
||||
self.assertEqual(status, TrialScheduler.STOP)
|
||||
trl = big_bracket.current_trials()[0]
|
||||
for i in range(9):
|
||||
status = sched.on_trial_result(mock_runner, trl, result(1, i))
|
||||
self.assertEqual(status, TrialScheduler.STOP)
|
||||
self.assertEqual(len(big_bracket.current_trials()), 0)
|
||||
self.assertEqual(sched._num_stopped, 9)
|
||||
|
||||
def testScheduling(self):
|
||||
"""Setup two bands, then make sure all trials are running"""
|
||||
sched = self.advancedSetup()
|
||||
mock_runner = _MockTrialRunner()
|
||||
trl = sched.choose_trial_to_run(mock_runner)
|
||||
while trl:
|
||||
# If band iteration > 0, make sure first band is all running
|
||||
if sched._trial_info[trl][1] > 0:
|
||||
first_band = sched._hyperbands[0]
|
||||
trials = [t for b in first_band for t in b._live_trials]
|
||||
|
@ -252,93 +272,67 @@ class HyperbandSuite(unittest.TestCase):
|
|||
all(t.status == Trial.RUNNING for t in trials), True)
|
||||
|
||||
def testTrialErrored(self):
|
||||
sched = HyperBandScheduler(9, eta=3)
|
||||
t1 = Trial("t1", "__fake")
|
||||
t2 = Trial("t2", "__fake")
|
||||
sched.on_trial_add(None, t1)
|
||||
sched.on_trial_add(None, t2)
|
||||
mock_runner = _MockTrialRunner()
|
||||
filled_band = sched._hyperbands[0]
|
||||
big_bracket = filled_band[0]
|
||||
bracket_trials = big_bracket.current_trials()
|
||||
|
||||
for t in bracket_trials:
|
||||
mock_runner._launch_trial(t)
|
||||
sched, mock_runner = self.schedulerSetup(10)
|
||||
t1, t2 = sched._state["bracket"].current_trials()
|
||||
mock_runner._launch_trial(t1)
|
||||
mock_runner._launch_trial(t2)
|
||||
|
||||
sched.on_trial_error(mock_runner, t2)
|
||||
self.assertEqual(
|
||||
TrialScheduler.CONTINUE,
|
||||
sched.on_trial_result(mock_runner, t1, result(3, 10)))
|
||||
sched.on_trial_result(mock_runner, t1, result(1, 10)))
|
||||
|
||||
def testTrialErrored2(self):
|
||||
"""Check successive halving happened even when last trial failed"""
|
||||
sched, mock_runner = self.schedulerSetup(17)
|
||||
trials = sched._state["bracket"].current_trials()
|
||||
self.assertEqual(len(trials), 9)
|
||||
for t in trials[:-1]:
|
||||
mock_runner._launch_trial(t)
|
||||
sched.on_trial_result(mock_runner, t, result(1, 10))
|
||||
|
||||
mock_runner._launch_trial(trials[-1])
|
||||
sched.on_trial_error(mock_runner, trials[-1])
|
||||
self.assertEqual(len(sched._state["bracket"].current_trials()), 3)
|
||||
|
||||
def testTrialEndedEarly(self):
|
||||
sched = HyperBandScheduler(9, eta=3)
|
||||
t1 = Trial("t1", "__fake")
|
||||
t2 = Trial("t2", "__fake")
|
||||
sched.on_trial_add(None, t1)
|
||||
sched.on_trial_add(None, t2)
|
||||
mock_runner = _MockTrialRunner()
|
||||
filled_band = sched._hyperbands[0]
|
||||
big_bracket = filled_band[0]
|
||||
bracket_trials = big_bracket.current_trials()
|
||||
|
||||
for t in bracket_trials:
|
||||
sched, mock_runner = self.schedulerSetup(10)
|
||||
trials = sched._state["bracket"].current_trials()
|
||||
for t in trials:
|
||||
mock_runner._launch_trial(t)
|
||||
|
||||
sched.on_trial_complete(mock_runner, t2, result(5, 10))
|
||||
sched.on_trial_complete(mock_runner, trials[-1], result(1, 12))
|
||||
self.assertEqual(
|
||||
TrialScheduler.CONTINUE,
|
||||
sched.on_trial_result(mock_runner, t1, result(3, 12)))
|
||||
sched.on_trial_result(mock_runner, trials[0], result(1, 12)))
|
||||
|
||||
def testAddAfterHalf(self):
|
||||
sched = HyperBandScheduler(9, eta=3)
|
||||
for i in range(2):
|
||||
t = Trial("t%d" % i, "__fake")
|
||||
sched.on_trial_add(None, t)
|
||||
mock_runner = _MockTrialRunner()
|
||||
filled_band = sched._hyperbands[0]
|
||||
big_bracket = filled_band[0]
|
||||
bracket_trials = big_bracket.current_trials()
|
||||
def testTrialEndedEarly2(self):
|
||||
"""Check successive halving happened even when last trial finished"""
|
||||
sched, mock_runner = self.schedulerSetup(17)
|
||||
trials = sched._state["bracket"].current_trials()
|
||||
self.assertEqual(len(trials), 9)
|
||||
for t in trials[:-1]:
|
||||
mock_runner._launch_trial(t)
|
||||
sched.on_trial_result(mock_runner, t, result(1, 10))
|
||||
|
||||
mock_runner._launch_trial(trials[-1])
|
||||
sched.on_trial_complete(mock_runner, trials[-1], result(1, 12))
|
||||
self.assertEqual(len(sched._state["bracket"].current_trials()), 3)
|
||||
|
||||
def testAddAfterHalving(self):
|
||||
sched, mock_runner = self.schedulerSetup(10)
|
||||
bracket_trials = sched._state["bracket"].current_trials()
|
||||
|
||||
for t in bracket_trials:
|
||||
mock_runner._launch_trial(t)
|
||||
|
||||
for i, t in enumerate(bracket_trials):
|
||||
if i == len(bracket_trials) - 1:
|
||||
break
|
||||
self.assertEqual(
|
||||
TrialScheduler.PAUSE,
|
||||
sched.on_trial_result(mock_runner, t, result(i, 10)))
|
||||
mock_runner._pause_trial(t)
|
||||
self.assertEqual(
|
||||
TrialScheduler.CONTINUE,
|
||||
sched.on_trial_result(
|
||||
mock_runner, bracket_trials[-1], result(7, 12)))
|
||||
res = sched.on_trial_result(
|
||||
mock_runner, t, result(1, i))
|
||||
self.assertEqual(res, TrialScheduler.CONTINUE)
|
||||
t = Trial("t%d" % 5, "__fake")
|
||||
sched.on_trial_add(None, t)
|
||||
self.assertEqual(4, big_bracket._live_trials[t][1])
|
||||
|
||||
def testDone(self):
|
||||
sched = HyperBandScheduler(3, eta=3)
|
||||
mock_runner = _MockTrialRunner()
|
||||
trials = [Trial("t%d" % i, "__fake") for i in range(5)]
|
||||
for t in trials:
|
||||
sched.on_trial_add(None, t)
|
||||
|
||||
filled_band = sched._hyperbands[0]
|
||||
brack = filled_band[1]
|
||||
bracket_trials = brack.current_trials()
|
||||
for t in bracket_trials:
|
||||
mock_runner._launch_trial(t)
|
||||
for i in range(3):
|
||||
res = sched.on_trial_result(
|
||||
mock_runner, bracket_trials[-1], result(i, 10))
|
||||
self.assertEqual(res, TrialScheduler.PAUSE)
|
||||
mock_runner._pause_trial(bracket_trials[-1])
|
||||
for i in range(3):
|
||||
res = sched.on_trial_result(
|
||||
mock_runner, bracket_trials[-2], result(i, 10))
|
||||
self.assertEqual(res, TrialScheduler.STOP)
|
||||
self.assertEqual(len(brack.current_trials()), 1)
|
||||
self.assertEqual(3 + 1, sched._state["bracket"]._live_trials[t][1])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Add table
Reference in a new issue