From 71f8cd24033884e453019bcc7e13b4211abe68a0 Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Sun, 12 Nov 2017 12:05:32 -0800 Subject: [PATCH] [tune] Fixing up Hyperband (#1207) * Fixing up Hyperband * nit * cleanup * Timing test Added * added_exception_back * fixup_tests * reverse placement * fixes_and_tests * fix * fix * fixlint * cleanup_timing * lint * Update hyperband.py --- python/ray/tune/hyperband.py | 101 +++++++++------ test/trial_scheduler_test.py | 236 +++++++++++++++++------------------ 2 files changed, 176 insertions(+), 161 deletions(-) diff --git a/python/ray/tune/hyperband.py b/python/ray/tune/hyperband.py index 6b545d318..b127bd4b2 100644 --- a/python/ray/tune/hyperband.py +++ b/python/ray/tune/hyperband.py @@ -60,11 +60,9 @@ class HyperBandScheduler(FIFOScheduler): def on_trial_add(self, trial_runner, trial): """On a new trial add, if current bracket is not filled, - add to current bracket. Else, if current hp iteration is not filled, + add to current bracket. Else, if current band is not filled, create new bracket, add to current bracket. - Else, create new iteration, create new bracket, add to bracket. - - TODO(rliaw): This is messy.""" + Else, create new iteration, create new bracket, add to bracket.""" cur_bracket = self._state["bracket"] cur_band = self._hyperbands[self._state["band_idx"]] @@ -76,9 +74,9 @@ class HyperBandScheduler(FIFOScheduler): self._hyperbands.append(cur_band) self._state["band_idx"] += 1 - # cur_band will always be less than s_max or else filled - s = self._s_max_1 - len(cur_band) - 1 - assert s >= 0, "Current band is filled but adding bracket!" + # cur_band will always be less than s_max_1 or else filled + s = len(cur_band) + assert s < self._s_max_1, "Current band is filled!" # create new bracket cur_bracket = Bracket(self._get_n0(s), @@ -102,34 +100,44 @@ class HyperBandScheduler(FIFOScheduler): If a given trial finishes and bracket iteration is not done, the trial will be paused and resources will be given up. - When bracket iteration is done, Trials will be successively halved, - and during each halving phase, bad trials will be stopped while good - trials will return to "PENDING". This scheduler will not start trials - but will stop trials. The current running trial will not be handled, + + This scheduler will not start trials but will stop trials. + The current running trial will not be handled, as the trialrunner will be given control to handle it. # TODO(rliaw) should be only called if trial has not errored""" bracket, _ = self._trial_info[trial] bracket.update_trial_stats(trial, result) + if bracket.continue_trial(trial): return TrialScheduler.CONTINUE - signal = TrialScheduler.PAUSE + action = self._process_bracket(trial_runner, bracket, trial) + return action + def _process_bracket(self, trial_runner, bracket, trial): + """This is called whenever a trial makes progress. + + When all live trials in the bracket have no more iterations left, + Trials will be successively halved. If bracket is done, all + non-running trials will be stopped and cleaned up, + and during each halving phase, bad trials will be stopped while good + trials will return to "PENDING".""" + + action = TrialScheduler.PAUSE if bracket.cur_iter_done(): if bracket.finished(): self._cleanup_bracket(trial_runner, bracket) return TrialScheduler.STOP - # what if bracket is done and trial not completed? + good, bad = bracket.successive_halving() # kill bad trials for t in bad: - self._num_stopped += 1 if t.status == Trial.PAUSED: - trial_runner._stop_trial(t) - bracket.cleanup_trial_early(t) - elif t is trial: - signal = TrialScheduler.STOP + self._cleanup_trial(trial_runner, t, bracket, hard=True) + elif t.status == Trial.RUNNING: + self._cleanup_trial(trial_runner, t, bracket, hard=False) + action = TrialScheduler.STOP else: raise Exception("Trial with unexpected status encountered") @@ -137,38 +145,42 @@ class HyperBandScheduler(FIFOScheduler): for t in good: if t.status == Trial.PAUSED: t.unpause() - elif t is trial: - signal = TrialScheduler.CONTINUE + elif t.status == Trial.RUNNING: + action = TrialScheduler.CONTINUE else: raise Exception("Trial with unexpected status encountered") + return action - return signal + def _cleanup_trial(self, trial_runner, t, bracket, hard=False): + """Bookkeeping for trials finished. If `hard=True`, then + this scheduler will force the trial_runner to release resources. + + Otherwise, only clean up trial information locally.""" + self._num_stopped += 1 + if hard: + trial_runner._stop_trial(t) + bracket.cleanup_trial(t) def _cleanup_bracket(self, trial_runner, bracket): - """Cleans up bracket after bracket is completely finished. - - Bracket information will only be cleaned up after the trialrunner has - finished its bookkeeping.""" - for t in bracket.current_trials(): - if t.status == Trial.PAUSED: - trial_runner._stop_trial(t) - bracket.cleanup_trial_early(t) + """Cleans up bracket after bracket is completely finished.""" + for trial in bracket.current_trials(): + self._cleanup_trial( + trial_runner, trial, bracket, + hard=(trial.status == Trial.PAUSED)) def on_trial_complete(self, trial_runner, trial, result): - """Cleans up trial info from bracket if trial completed early. + """Cleans up trial info from bracket if trial completed early.""" - Bracket information will only be cleaned up after the trialrunner has - finished its bookkeeping.""" bracket, _ = self._trial_info[trial] - bracket.cleanup_trial_early(trial) + self._cleanup_trial(trial_runner, trial, bracket, hard=False) + self._process_bracket(trial_runner, bracket, trial) def on_trial_error(self, trial_runner, trial): - """Cleans up trial info from bracket if trial errored early. + """Cleans up trial info from bracket if trial errored early.""" - Bracket information will only be cleaned up after the trialrunner has - finished its bookkeeping.""" bracket, _ = self._trial_info[trial] - bracket.cleanup_trial_early(trial) + self._cleanup_trial(trial_runner, trial, bracket, hard=False) + self._process_bracket(trial_runner, bracket, trial) def choose_trial_to_run(self, trial_runner, *args): """Fair scheduling within iteration by completion percentage. @@ -177,6 +189,7 @@ class HyperBandScheduler(FIFOScheduler): If iteration is occupied (ie, no trials to run), then look into next iteration.""" + for hyperband in self._hyperbands: for bracket in sorted(hyperband, key=lambda b: b.completion_percentage()): @@ -187,10 +200,17 @@ class HyperBandScheduler(FIFOScheduler): return None def debug_string(self): + brackets = [ + "({0}/{1})".format( + len(bracket._live_trials), len(bracket._all_trials)) + for band in self._hyperbands for bracket in band] return " ".join([ "Using HyperBand:", "num_stopped={}".format(self._num_stopped), - "brackets={}".format(sum(len(band) for band in self._hyperbands))]) + "total_brackets={}".format( + sum(len(band) for band in self._hyperbands)), + " ".join(brackets) + ]) class Bracket(): @@ -278,8 +298,9 @@ class Bracket(): self._live_trials[trial] = (result, itr - 1) self._completed_progress += 1 - def cleanup_trial_early(self, trial): - """Clean up statistics tracking for trial that terminated early. + def cleanup_trial(self, trial): + """Clean up statistics tracking for terminated trials (either by force + or otherwise). This may cause bad trials to continue for a long time, in the case where all the good trials finish early and there are only bad trials diff --git a/test/trial_scheduler_test.py b/test/trial_scheduler_test.py index 38deab8c6..af2f43d49 100644 --- a/test/trial_scheduler_test.py +++ b/test/trial_scheduler_test.py @@ -151,17 +151,26 @@ class _MockTrialRunner(): class HyperbandSuite(unittest.TestCase): - def basicSetup(self): - """s_max_1 = 3; - brackets: iter (n, r) | iter (n, r) | iter (n, r) - (9, 1) -> (3, 3) -> (1, 9) - (9, 1) -> (3, 3) -> (1, 9) - """ + def schedulerSetup(self, num_trials): + """Setup a scheduler and Runner with max Iter = 9 + + Bracketing is placed as follows: + (3, 9); + (5, 3) -> (2, 9); + (9, 1) -> (3, 3) -> (1, 9); """ sched = HyperBandScheduler(9, eta=3) - for i in range(17): + for i in range(num_trials): t = Trial("t%d" % i, "__fake") sched.on_trial_add(None, t) + runner = _MockTrialRunner() + return sched, runner + + def basicSetup(self): + """Setup and verify full band. + """ + + sched, _ = self.schedulerSetup(17) self.assertEqual(len(sched._hyperbands), 1) self.assertEqual(sched._cur_band_filled(), True) @@ -173,69 +182,80 @@ class HyperbandSuite(unittest.TestCase): def advancedSetup(self): sched = self.basicSetup() - for i in range(3): + for i in range(4): t = Trial("t%d" % (i + 20), "__fake") sched.on_trial_add(None, t) self.assertEqual(sched._cur_band_filled(), False) - unfilled_band = sched._hyperbands[1] - self.assertEqual(len(unfilled_band), 1) - self.assertEqual(len(sched._hyperbands[1]), 1) - bracket = unfilled_band[0] + unfilled_band = sched._hyperbands[-1] + self.assertEqual(len(unfilled_band), 2) + bracket = unfilled_band[-1] self.assertEqual(bracket.filled(), False) - self.assertEqual(len(bracket.current_trials()), 3) + self.assertEqual(len(bracket.current_trials()), 1) return sched - def testBasicHalving(self): - sched = self.advancedSetup() - mock_runner = _MockTrialRunner() - filled_band = sched._hyperbands[0] - big_bracket = filled_band[0] - bracket_trials = big_bracket.current_trials() - - for t in bracket_trials: - mock_runner._launch_trial(t) - - for i, t in enumerate(bracket_trials): - if i == len(bracket_trials) - 1: - break - self.assertEqual( - TrialScheduler.PAUSE, - sched.on_trial_result(mock_runner, t, result(i, 10))) - mock_runner._pause_trial(t) - self.assertEqual( - TrialScheduler.CONTINUE, - sched.on_trial_result( - mock_runner, bracket_trials[-1], result(7, 12))) + def stopTrial(self, trial, mock_runner): + self.assertNotEqual(trial.status, Trial.TERMINATED) + mock_runner._stop_trial(trial) def testSuccessiveHalving(self): - sched = HyperBandScheduler(9, eta=3) - for i in range(9): - t = Trial("t%d" % i, "__fake") - sched.on_trial_add(None, t) - filled_band = sched._hyperbands[0] - big_bracket = filled_band[0] - mock_runner = _MockTrialRunner() + """Setup full band, then iterate through last bracket (n=9) + to make sure successive halving is correct.""" + + sched, mock_runner = self.schedulerSetup(17) + filled_band = sched._hyperbands[0][-1] + big_bracket = filled_band + + for trl in big_bracket.current_trials(): + mock_runner._launch_trial(trl) + + # Provides results from 0 to 8 in order, keeping the last one running + for i, trl in enumerate(big_bracket.current_trials()): + status = sched.on_trial_result(mock_runner, trl, result(1, i)) + if status == TrialScheduler.CONTINUE: + continue + elif status == TrialScheduler.PAUSE: + mock_runner._pause_trial(trl) + elif status == TrialScheduler.STOP: + self.assertNotEqual(trl.status, Trial.TERMINATED) + self.stopTrial(trl, mock_runner) current_length = len(big_bracket.current_trials()) - for i in range(current_length): - trl = sched.choose_trial_to_run(mock_runner) - mock_runner._launch_trial(trl) - while True: - status = sched.on_trial_result(mock_runner, trl, result(1, 10)) - if status == TrialScheduler.CONTINUE: - continue - elif status == TrialScheduler.PAUSE: - mock_runner._pause_trial(trl) - break + self.assertEqual(status, TrialScheduler.CONTINUE) + self.assertEqual(current_length, 3) - def testBasicRun(self): + # Techincally only need to launch 2/3, as one is already running + for trl in big_bracket.current_trials(): + mock_runner._launch_trial(trl) + + # Provides results from 2 to 0 in order, killing the last one + for i, trl in reversed(list(enumerate(big_bracket.current_trials()))): + for j in range(3): + status = sched.on_trial_result(mock_runner, trl, result(1, i)) + if status == TrialScheduler.CONTINUE: + continue + elif status == TrialScheduler.PAUSE: + mock_runner._pause_trial(trl) + elif status == TrialScheduler.STOP: + self.stopTrial(trl, mock_runner) + + self.assertEqual(status, TrialScheduler.STOP) + trl = big_bracket.current_trials()[0] + for i in range(9): + status = sched.on_trial_result(mock_runner, trl, result(1, i)) + self.assertEqual(status, TrialScheduler.STOP) + self.assertEqual(len(big_bracket.current_trials()), 0) + self.assertEqual(sched._num_stopped, 9) + + def testScheduling(self): + """Setup two bands, then make sure all trials are running""" sched = self.advancedSetup() mock_runner = _MockTrialRunner() trl = sched.choose_trial_to_run(mock_runner) while trl: + # If band iteration > 0, make sure first band is all running if sched._trial_info[trl][1] > 0: first_band = sched._hyperbands[0] trials = [t for b in first_band for t in b._live_trials] @@ -252,93 +272,67 @@ class HyperbandSuite(unittest.TestCase): all(t.status == Trial.RUNNING for t in trials), True) def testTrialErrored(self): - sched = HyperBandScheduler(9, eta=3) - t1 = Trial("t1", "__fake") - t2 = Trial("t2", "__fake") - sched.on_trial_add(None, t1) - sched.on_trial_add(None, t2) - mock_runner = _MockTrialRunner() - filled_band = sched._hyperbands[0] - big_bracket = filled_band[0] - bracket_trials = big_bracket.current_trials() - - for t in bracket_trials: - mock_runner._launch_trial(t) + sched, mock_runner = self.schedulerSetup(10) + t1, t2 = sched._state["bracket"].current_trials() + mock_runner._launch_trial(t1) + mock_runner._launch_trial(t2) sched.on_trial_error(mock_runner, t2) self.assertEqual( TrialScheduler.CONTINUE, - sched.on_trial_result(mock_runner, t1, result(3, 10))) + sched.on_trial_result(mock_runner, t1, result(1, 10))) + + def testTrialErrored2(self): + """Check successive halving happened even when last trial failed""" + sched, mock_runner = self.schedulerSetup(17) + trials = sched._state["bracket"].current_trials() + self.assertEqual(len(trials), 9) + for t in trials[:-1]: + mock_runner._launch_trial(t) + sched.on_trial_result(mock_runner, t, result(1, 10)) + + mock_runner._launch_trial(trials[-1]) + sched.on_trial_error(mock_runner, trials[-1]) + self.assertEqual(len(sched._state["bracket"].current_trials()), 3) def testTrialEndedEarly(self): - sched = HyperBandScheduler(9, eta=3) - t1 = Trial("t1", "__fake") - t2 = Trial("t2", "__fake") - sched.on_trial_add(None, t1) - sched.on_trial_add(None, t2) - mock_runner = _MockTrialRunner() - filled_band = sched._hyperbands[0] - big_bracket = filled_band[0] - bracket_trials = big_bracket.current_trials() - - for t in bracket_trials: + sched, mock_runner = self.schedulerSetup(10) + trials = sched._state["bracket"].current_trials() + for t in trials: mock_runner._launch_trial(t) - sched.on_trial_complete(mock_runner, t2, result(5, 10)) + sched.on_trial_complete(mock_runner, trials[-1], result(1, 12)) self.assertEqual( TrialScheduler.CONTINUE, - sched.on_trial_result(mock_runner, t1, result(3, 12))) + sched.on_trial_result(mock_runner, trials[0], result(1, 12))) - def testAddAfterHalf(self): - sched = HyperBandScheduler(9, eta=3) - for i in range(2): - t = Trial("t%d" % i, "__fake") - sched.on_trial_add(None, t) - mock_runner = _MockTrialRunner() - filled_band = sched._hyperbands[0] - big_bracket = filled_band[0] - bracket_trials = big_bracket.current_trials() + def testTrialEndedEarly2(self): + """Check successive halving happened even when last trial finished""" + sched, mock_runner = self.schedulerSetup(17) + trials = sched._state["bracket"].current_trials() + self.assertEqual(len(trials), 9) + for t in trials[:-1]: + mock_runner._launch_trial(t) + sched.on_trial_result(mock_runner, t, result(1, 10)) + + mock_runner._launch_trial(trials[-1]) + sched.on_trial_complete(mock_runner, trials[-1], result(1, 12)) + self.assertEqual(len(sched._state["bracket"].current_trials()), 3) + + def testAddAfterHalving(self): + sched, mock_runner = self.schedulerSetup(10) + bracket_trials = sched._state["bracket"].current_trials() for t in bracket_trials: mock_runner._launch_trial(t) for i, t in enumerate(bracket_trials): - if i == len(bracket_trials) - 1: - break - self.assertEqual( - TrialScheduler.PAUSE, - sched.on_trial_result(mock_runner, t, result(i, 10))) - mock_runner._pause_trial(t) - self.assertEqual( - TrialScheduler.CONTINUE, - sched.on_trial_result( - mock_runner, bracket_trials[-1], result(7, 12))) + res = sched.on_trial_result( + mock_runner, t, result(1, i)) + self.assertEqual(res, TrialScheduler.CONTINUE) t = Trial("t%d" % 5, "__fake") sched.on_trial_add(None, t) - self.assertEqual(4, big_bracket._live_trials[t][1]) - - def testDone(self): - sched = HyperBandScheduler(3, eta=3) - mock_runner = _MockTrialRunner() - trials = [Trial("t%d" % i, "__fake") for i in range(5)] - for t in trials: - sched.on_trial_add(None, t) - - filled_band = sched._hyperbands[0] - brack = filled_band[1] - bracket_trials = brack.current_trials() - for t in bracket_trials: - mock_runner._launch_trial(t) - for i in range(3): - res = sched.on_trial_result( - mock_runner, bracket_trials[-1], result(i, 10)) - self.assertEqual(res, TrialScheduler.PAUSE) - mock_runner._pause_trial(bracket_trials[-1]) - for i in range(3): - res = sched.on_trial_result( - mock_runner, bracket_trials[-2], result(i, 10)) - self.assertEqual(res, TrialScheduler.STOP) - self.assertEqual(len(brack.current_trials()), 1) + self.assertEqual(3 + 1, sched._state["bracket"]._live_trials[t][1]) if __name__ == "__main__":