[tune] Make HyperBand Usable (#1215)

This commit is contained in:
Richard Liaw 2017-11-16 10:31:42 -08:00 committed by GitHub
parent 3a0206a1f4
commit eadb998643
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 379 additions and 193 deletions

View file

@ -1,7 +1,7 @@
cartpole-ppo:
env: CartPole-v0
alg: PPO
num_trials: 20
repeat: 3
stop:
episode_reward_mean: 200
time_total_s: 180

View file

@ -8,13 +8,11 @@ from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler
from ray.tune.trial import Trial
def calculate_bracket_count(max_iter, eta):
return int(np.log(max_iter)/np.log(eta)) + 1
class HyperBandScheduler(FIFOScheduler):
"""Implements HyperBand.
Blog post: https://people.eecs.berkeley.edu/~kjamieson/hyperband.html
This implementation contains 3 logical levels.
Each HyperBand iteration is a "band". There can be multiple
bands running at once, and there can be 1 band that is incomplete.
@ -30,26 +28,39 @@ class HyperBandScheduler(FIFOScheduler):
Trials added will be inserted into the most recent bracket
and band and will spill over to new brackets/bands accordingly.
This maintains the bracket size and max trial count per band
to 5 and 117 respectively, which correspond to that of
`max_attr=81, eta=3` from the blog post. Trials will fill up
from smallest bracket to largest, with largest
having the most rounds of successive halving.
Args:
time_attr (str): The TrainingResult attr to use for comparing time.
Note that you can pass in something non-temporal such as
`training_iteration` as a measure of progress, the only requirement
is that the attribute should increase monotonically.
reward_attr (str): The TrainingResult objective value attribute. As
with `time_attr`, this may refer to any objective value. Stopping
procedures will use this attribute.
max_t (int): max time units per trial. Trials will be stopped after
max_t time units (determined by time_attr) have passed.
The HyperBand scheduler automatically tries to determine a
reasonable number of brackets based on this and eta.
"""
def __init__(self, max_iter=200, eta=3):
"""
args:
max_iter (int): maximum iterations per configuration
eta (int): # defines downsampling rate (default=3)
"""
assert max_iter > 0, "Max Iterations not valid!"
assert eta > 1, "Downsampling rate (eta) not valid!"
def __init__(
self, time_attr='training_iteration',
reward_attr='episode_reward_mean', max_t=81):
assert max_t > 0, "Max (time_attr) not valid!"
FIFOScheduler.__init__(self)
self._eta = eta
self._s_max_1 = s_max_1 = calculate_bracket_count(max_iter, eta)
# total number of iterations per execution of Succesive Halving (n,r)
B = s_max_1 * max_iter
# bracket trial count total
self._get_n0 = lambda s: int(np.ceil(B/max_iter/(s+1)*eta**s))
self._eta = 3
self._s_max_1 = 5
# bracket max trials
self._get_n0 = lambda s: int(
np.ceil(self._s_max_1/(s+1) * self._eta**s))
# bracket initial iterations
self._get_r0 = lambda s: int(max_iter*eta**(-s))
self._get_r0 = lambda s: int((max_t*self._eta**(-s)))
self._hyperbands = [[]] # list of hyperband iterations
self._trial_info = {} # Stores Trial -> Bracket, Band Iteration
@ -57,6 +68,8 @@ class HyperBandScheduler(FIFOScheduler):
self._state = {"bracket": None,
"band_idx": 0}
self._num_stopped = 0
self._reward_attr = reward_attr
self._time_attr = time_attr
def on_trial_add(self, trial_runner, trial):
"""On a new trial add, if current bracket is not filled,
@ -67,22 +80,27 @@ class HyperBandScheduler(FIFOScheduler):
cur_bracket = self._state["bracket"]
cur_band = self._hyperbands[self._state["band_idx"]]
if cur_bracket is None or cur_bracket.filled():
retry = True
while retry:
# if current iteration is filled, create new iteration
if self._cur_band_filled():
cur_band = []
self._hyperbands.append(cur_band)
self._state["band_idx"] += 1
# if current iteration is filled, create new iteration
if self._cur_band_filled():
cur_band = []
self._hyperbands.append(cur_band)
self._state["band_idx"] += 1
# cur_band will always be less than s_max_1 or else filled
s = len(cur_band)
assert s < self._s_max_1, "Current band is filled!"
# create new bracket
cur_bracket = Bracket(self._get_n0(s),
self._get_r0(s), self._eta, s)
cur_band.append(cur_bracket)
self._state["bracket"] = cur_bracket
# cur_band will always be less than s_max_1 or else filled
s = len(cur_band)
assert s < self._s_max_1, "Current band is filled!"
if self._get_r0(s) == 0:
print("Bracket too small - Retrying...")
cur_bracket = None
else:
retry = False
cur_bracket = Bracket(
self._time_attr, self._get_n0(s), self._get_r0(s),
self._eta, s)
cur_band.append(cur_bracket)
self._state["bracket"] = cur_bracket
self._state["bracket"].add_trial(trial)
self._trial_info[trial] = cur_bracket, self._state["band_idx"]
@ -128,9 +146,9 @@ class HyperBandScheduler(FIFOScheduler):
if bracket.cur_iter_done():
if bracket.finished():
self._cleanup_bracket(trial_runner, bracket)
return TrialScheduler.STOP
return TrialScheduler.CONTINUE
good, bad = bracket.successive_halving()
good, bad = bracket.successive_halving(self._reward_attr)
# kill bad trials
for t in bad:
if t.status == Trial.PAUSED:
@ -141,14 +159,15 @@ class HyperBandScheduler(FIFOScheduler):
else:
raise Exception("Trial with unexpected status encountered")
# ready the good trials
# ready the good trials - if trial is too far ahead, don't continue
for t in good:
if t.status == Trial.PAUSED:
t.unpause()
elif t.status == Trial.RUNNING:
action = TrialScheduler.CONTINUE
else:
if t.status not in [Trial.PAUSED, Trial.RUNNING]:
raise Exception("Trial with unexpected status encountered")
if bracket.continue_trial(t):
if t.status == Trial.PAUSED:
t.unpause()
elif t.status == Trial.RUNNING:
action = TrialScheduler.CONTINUE
return action
def _cleanup_trial(self, trial_runner, t, bracket, hard=False):
@ -162,11 +181,14 @@ class HyperBandScheduler(FIFOScheduler):
bracket.cleanup_trial(t)
def _cleanup_bracket(self, trial_runner, bracket):
"""Cleans up bracket after bracket is completely finished."""
"""Cleans up bracket after bracket is completely finished.
Lets the last trial continue to run until termination condition
kicks in."""
for trial in bracket.current_trials():
self._cleanup_trial(
trial_runner, trial, bracket,
hard=(trial.status == Trial.PAUSED))
if (trial.status == Trial.PAUSED):
self._cleanup_trial(
trial_runner, trial, bracket,
hard=True)
def on_trial_complete(self, trial_runner, trial, result):
"""Cleans up trial info from bracket if trial completed early."""
@ -219,12 +241,15 @@ class Bracket():
Also keeps track of progress to ensure good scheduling.
"""
def __init__(self, max_trials, init_iters, eta, s):
self._live_trials = {} # stores (result, itrs left before halving)
def __init__(self, time_attr, max_trials, init_t_attr, eta, s):
self._live_trials = {} # maps trial -> current result
self._all_trials = []
self._time_attr = time_attr # attribute to
self._n = self._n0 = max_trials
self._r = self._r0 = init_iters
self._r = self._r0 = init_t_attr
self._cumul_r = self._r0
self._eta = eta
self._halves = s
@ -237,15 +262,15 @@ class Bracket():
At a later iteration, a newly added trial will be given equal
opportunity to catch up."""
assert not self.filled(), "Cannot add trial to filled bracket!"
self._live_trials[trial] = (None, self._cumul_r)
self._live_trials[trial] = None
self._all_trials.append(trial)
def cur_iter_done(self):
"""Checks if all iterations have completed.
TODO(rliaw): also check that `t.iterations == self._r`"""
all_done = all(itr == 0 for _, itr in self._live_trials.values())
return all_done
return all(self._get_result_time(result) >= self._cumul_r
for result in self._live_trials.values())
def finished(self):
return self._halves == 0 and self.cur_iter_done()
@ -254,8 +279,8 @@ class Bracket():
return list(self._live_trials)
def continue_trial(self, trial):
_, itr = self._live_trials[trial]
if itr > 0:
result = self._live_trials[trial]
if self._get_result_time(result) < self._cumul_r:
return True
else:
return False
@ -265,24 +290,19 @@ class Bracket():
minimizing the need to backtrack and bookkeep previous medians"""
return len(self._live_trials) == self._n
def successive_halving(self):
def successive_halving(self, reward_attr):
assert self._halves > 0
self._halves -= 1
self._n /= self._eta
self._n = int(np.ceil(self._n))
self._r *= self._eta
self._r = int(np.ceil(self._r))
self._r = int((self._r))
self._cumul_r += self._r
sorted_trials = sorted(
self._live_trials,
key=lambda t: self._live_trials[t][0].episode_reward_mean)
key=lambda t: getattr(self._live_trials[t], reward_attr))
good, bad = sorted_trials[-self._n:], sorted_trials[:-self._n]
# reset good trials to track updated iterations
for t in good:
res, old_itr = self._live_trials[t]
self._live_trials[t] = (res, self._r)
return good, bad
def update_trial_stats(self, trial, result):
@ -293,10 +313,13 @@ class Bracket():
in and make sure they're not set as pending later."""
assert trial in self._live_trials
_, itr = self._live_trials[trial]
assert itr > 0
self._live_trials[trial] = (result, itr - 1)
self._completed_progress += 1
assert self._get_result_time(result) >= 0
delta = self._get_result_time(result) - \
self._get_result_time(self._live_trials[trial])
assert delta >= 0
self._completed_progress += delta
self._live_trials[trial] = result
def cleanup_trial(self, trial):
"""Clean up statistics tracking for terminated trials (either by force
@ -315,6 +338,11 @@ class Bracket():
are dropped."""
return self._completed_progress / self._total_work
def _get_result_time(self, result):
if result is None:
return 0
return getattr(result, self._time_attr)
def _calculate_total_work(self, n, r, s):
work = 0
for i in range(s+1):
@ -322,6 +350,7 @@ class Bracket():
n /= self._eta
n = int(np.ceil(n))
r *= self._eta
r = int(r)
return work
def __repr__(self):

View file

@ -3,6 +3,7 @@ from __future__ import division
from __future__ import print_function
import unittest
import numpy as np
from ray.tune.hyperband import HyperBandScheduler
from ray.tune.median_stopping_rule import MedianStoppingRule
@ -12,7 +13,9 @@ from ray.tune.trial_scheduler import TrialScheduler
def result(t, rew):
return TrainingResult(time_total_s=t, episode_reward_mean=rew)
return TrainingResult(time_total_s=t,
episode_reward_mean=rew,
training_iteration=int(t))
class EarlyStoppingSuite(unittest.TestCase):
@ -156,21 +159,46 @@ class HyperbandSuite(unittest.TestCase):
"""Setup a scheduler and Runner with max Iter = 9
Bracketing is placed as follows:
(3, 9);
(5, 3) -> (2, 9);
(9, 1) -> (3, 3) -> (1, 9); """
sched = HyperBandScheduler(9, eta=3)
(5, 81);
(8, 27) -> (3, 81);
(15, 9) -> (5, 27) -> (2, 81);
(34, 3) -> (12, 9) -> (4, 27) -> (2, 81);
(81, 1) -> (27, 3) -> (9, 9) -> (3, 27) -> (1, 81);"""
sched = HyperBandScheduler()
for i in range(num_trials):
t = Trial("t%d" % i, "__fake")
sched.on_trial_add(None, t)
runner = _MockTrialRunner()
return sched, runner
def default_statistics(self):
"""Default statistics for HyperBand"""
sched = HyperBandScheduler()
res = {
str(s): {"n": sched._get_n0(s), "r": sched._get_r0(s)}
for s in range(sched._s_max_1)
}
res["max_trials"] = sum(v["n"] for v in res.values())
res["brack_count"] = sched._s_max_1
res["s_max"] = sched._s_max_1 - 1
return res
def downscale(self, n, sched):
return int(np.ceil(n / sched._eta))
def process(self, trl, mock_runner, action):
if action == TrialScheduler.CONTINUE:
pass
elif action == TrialScheduler.PAUSE:
mock_runner._pause_trial(trl)
elif action == TrialScheduler.STOP:
self.stopTrial(trl, mock_runner)
def basicSetup(self):
"""Setup and verify full band.
"""
sched, _ = self.schedulerSetup(17)
stats = self.default_statistics()
sched, _ = self.schedulerSetup(stats["max_trials"])
self.assertEqual(len(sched._hyperbands), 1)
self.assertEqual(sched._cur_band_filled(), True)
@ -192,7 +220,7 @@ class HyperbandSuite(unittest.TestCase):
self.assertEqual(len(unfilled_band), 2)
bracket = unfilled_band[-1]
self.assertEqual(bracket.filled(), False)
self.assertEqual(len(bracket.current_trials()), 1)
self.assertEqual(len(bracket.current_trials()), 7)
return sched
@ -200,19 +228,254 @@ class HyperbandSuite(unittest.TestCase):
self.assertNotEqual(trial.status, Trial.TERMINATED)
mock_runner._stop_trial(trial)
def testSuccessiveHalving(self):
"""Setup full band, then iterate through last bracket (n=9)
to make sure successive halving is correct."""
def testConfigSameEta(self):
sched = HyperBandScheduler()
i = 0
while not sched._cur_band_filled():
t = Trial("t%d" % (i), "__fake")
sched.on_trial_add(None, t)
i += 1
self.assertEqual(len(sched._hyperbands[0]), 5)
self.assertEqual(sched._hyperbands[0][0]._n, 5)
self.assertEqual(sched._hyperbands[0][0]._r, 81)
self.assertEqual(sched._hyperbands[0][-1]._n, 81)
self.assertEqual(sched._hyperbands[0][-1]._r, 1)
sched, mock_runner = self.schedulerSetup(17)
filled_band = sched._hyperbands[0][-1]
big_bracket = filled_band
sched = HyperBandScheduler(max_t=810)
i = 0
while not sched._cur_band_filled():
t = Trial("t%d" % (i), "__fake")
sched.on_trial_add(None, t)
i += 1
self.assertEqual(len(sched._hyperbands[0]), 5)
self.assertEqual(sched._hyperbands[0][0]._n, 5)
self.assertEqual(sched._hyperbands[0][0]._r, 810)
self.assertEqual(sched._hyperbands[0][-1]._n, 81)
self.assertEqual(sched._hyperbands[0][-1]._r, 10)
def testConfigSameEtaSmall(self):
sched = HyperBandScheduler(max_t=1)
i = 0
while len(sched._hyperbands) < 2:
t = Trial("t%d" % (i), "__fake")
sched.on_trial_add(None, t)
i += 1
self.assertEqual(len(sched._hyperbands[0]), 5)
self.assertTrue(all(v is None for v in sched._hyperbands[0][1:]))
def testSuccessiveHalving(self):
"""Setup full band, then iterate through last bracket (n=81)
to make sure successive halving is correct."""
stats = self.default_statistics()
sched, mock_runner = self.schedulerSetup(stats["max_trials"])
big_bracket = sched._state["bracket"]
cur_units = stats[str(stats["s_max"])]["r"]
# The last bracket will downscale 4 times
for x in range(stats["brack_count"] - 1):
trials = big_bracket.current_trials()
current_length = len(trials)
for trl in trials:
mock_runner._launch_trial(trl)
# Provides results from 0 to 8 in order, keeping last one running
for i, trl in enumerate(trials):
action = sched.on_trial_result(
mock_runner, trl, result(cur_units, i))
if i < current_length - 1:
self.assertEqual(action, TrialScheduler.PAUSE)
self.process(trl, mock_runner, action)
self.assertEqual(action, TrialScheduler.CONTINUE)
new_length = len(big_bracket.current_trials())
self.assertEqual(new_length, self.downscale(current_length, sched))
cur_units += int(cur_units * sched._eta)
self.assertEqual(len(big_bracket.current_trials()), 1)
def testHalvingStop(self):
stats = self.default_statistics()
num_trials = stats[str(0)]["n"] + stats[str(1)]["n"]
sched, mock_runner = self.schedulerSetup(num_trials)
big_bracket = sched._state["bracket"]
for trl in big_bracket.current_trials():
mock_runner._launch_trial(trl)
# # Provides result in reverse order, killing the last one
cur_units = stats[str(1)]["r"]
for i, trl in reversed(list(enumerate(big_bracket.current_trials()))):
action = sched.on_trial_result(
mock_runner, trl, result(cur_units, i))
self.process(trl, mock_runner, action)
self.assertEqual(action, TrialScheduler.STOP)
def testContinueLastOne(self):
stats = self.default_statistics()
num_trials = stats[str(0)]["n"]
sched, mock_runner = self.schedulerSetup(num_trials)
big_bracket = sched._state["bracket"]
for trl in big_bracket.current_trials():
mock_runner._launch_trial(trl)
# # Provides result in reverse order, killing the last one
cur_units = stats[str(0)]["r"]
for i, trl in enumerate(big_bracket.current_trials()):
action = sched.on_trial_result(
mock_runner, trl, result(cur_units, i))
self.process(trl, mock_runner, action)
self.assertEqual(action, TrialScheduler.CONTINUE)
for x in range(100):
action = sched.on_trial_result(
mock_runner, trl, result(cur_units + x, 10))
self.assertEqual(action, TrialScheduler.CONTINUE)
def testTrialErrored(self):
"""If a trial errored, make sure successive halving still happens"""
stats = self.default_statistics()
trial_count = stats[str(0)]["n"] + 3
sched, mock_runner = self.schedulerSetup(trial_count)
t1, t2, t3 = sched._state["bracket"].current_trials()
for t in [t1, t2, t3]:
mock_runner._launch_trial(t)
sched.on_trial_error(mock_runner, t3)
self.assertEqual(
TrialScheduler.PAUSE,
sched.on_trial_result(
mock_runner, t1, result(stats[str(1)]["r"], 10)))
self.assertEqual(
TrialScheduler.CONTINUE,
sched.on_trial_result(
mock_runner, t2, result(stats[str(1)]["r"], 10)))
def testTrialErrored2(self):
"""Check successive halving happened even when last trial failed"""
stats = self.default_statistics()
trial_count = stats[str(0)]["n"] + stats[str(1)]["n"]
sched, mock_runner = self.schedulerSetup(trial_count)
trials = sched._state["bracket"].current_trials()
for t in trials[:-1]:
mock_runner._launch_trial(t)
sched.on_trial_result(
mock_runner, t, result(stats[str(1)]["r"], 10))
mock_runner._launch_trial(trials[-1])
sched.on_trial_error(mock_runner, trials[-1])
self.assertEqual(len(sched._state["bracket"].current_trials()),
self.downscale(stats[str(1)]["n"], sched))
def testTrialEndedEarly(self):
"""Check successive halving happened even when one trial failed"""
stats = self.default_statistics()
trial_count = stats[str(0)]["n"] + 3
sched, mock_runner = self.schedulerSetup(trial_count)
t1, t2, t3 = sched._state["bracket"].current_trials()
for t in [t1, t2, t3]:
mock_runner._launch_trial(t)
sched.on_trial_complete(mock_runner, t3, result(1, 12))
self.assertEqual(
TrialScheduler.PAUSE,
sched.on_trial_result(
mock_runner, t1, result(stats[str(1)]["r"], 10)))
self.assertEqual(
TrialScheduler.CONTINUE,
sched.on_trial_result(
mock_runner, t2, result(stats[str(1)]["r"], 10)))
def testTrialEndedEarly2(self):
"""Check successive halving happened even when last trial failed"""
stats = self.default_statistics()
trial_count = stats[str(0)]["n"] + stats[str(1)]["n"]
sched, mock_runner = self.schedulerSetup(trial_count)
trials = sched._state["bracket"].current_trials()
for t in trials[:-1]:
mock_runner._launch_trial(t)
sched.on_trial_result(
mock_runner, t, result(stats[str(1)]["r"], 10))
mock_runner._launch_trial(trials[-1])
sched.on_trial_complete(mock_runner, trials[-1], result(100, 12))
self.assertEqual(len(sched._state["bracket"].current_trials()),
self.downscale(stats[str(1)]["n"], sched))
def testAddAfterHalving(self):
stats = self.default_statistics()
trial_count = stats[str(0)]["n"] + 1
sched, mock_runner = self.schedulerSetup(trial_count)
bracket_trials = sched._state["bracket"].current_trials()
init_units = stats[str(1)]["r"]
for t in bracket_trials:
mock_runner._launch_trial(t)
for i, t in enumerate(bracket_trials):
status = sched.on_trial_result(
mock_runner, t, result(init_units, i))
self.assertEqual(status, TrialScheduler.CONTINUE)
t = Trial("t%d" % 100, "__fake")
sched.on_trial_add(None, t)
mock_runner._launch_trial(t)
self.assertEqual(len(sched._state["bracket"].current_trials()), 2)
# Make sure that newly added trial gets fair computation (not just 1)
self.assertEqual(
TrialScheduler.CONTINUE,
sched.on_trial_result(mock_runner, t, result(init_units, 12)))
new_units = init_units + int(init_units * sched._eta)
self.assertEqual(
TrialScheduler.PAUSE,
sched.on_trial_result(mock_runner, t, result(new_units, 12)))
def testAlternateMetrics(self):
"""Checking that alternate metrics will pass."""
def result2(t, rew):
return TrainingResult(time_total_s=t, neg_mean_loss=rew)
sched = HyperBandScheduler(
time_attr='time_total_s', reward_attr='neg_mean_loss')
stats = self.default_statistics()
for i in range(stats["max_trials"]):
t = Trial("t%d" % i, "__fake")
sched.on_trial_add(None, t)
runner = _MockTrialRunner()
big_bracket = sched._hyperbands[0][-1]
for trl in big_bracket.current_trials():
runner._launch_trial(trl)
current_length = len(big_bracket.current_trials())
# Provides results from 0 to 8 in order, keeping the last one running
for i, trl in enumerate(big_bracket.current_trials()):
status = sched.on_trial_result(runner, trl, result2(1, i))
if status == TrialScheduler.CONTINUE:
continue
elif status == TrialScheduler.PAUSE:
runner._pause_trial(trl)
elif status == TrialScheduler.STOP:
self.assertNotEqual(trl.status, Trial.TERMINATED)
self.stopTrial(trl, runner)
new_length = len(big_bracket.current_trials())
self.assertEqual(status, TrialScheduler.CONTINUE)
self.assertEqual(new_length, self.downscale(current_length, sched))
def testJumpingTime(self):
sched, mock_runner = self.schedulerSetup(81)
big_bracket = sched._hyperbands[0][-1]
for trl in big_bracket.current_trials():
mock_runner._launch_trial(trl)
# Provides results from 0 to 8 in order, keeping the last one running
for i, trl in enumerate(big_bracket.current_trials()):
main_trials = big_bracket.current_trials()[:-1]
jump = big_bracket.current_trials()[-1]
for i, trl in enumerate(main_trials):
status = sched.on_trial_result(mock_runner, trl, result(1, i))
if status == TrialScheduler.CONTINUE:
continue
@ -222,117 +485,11 @@ class HyperbandSuite(unittest.TestCase):
self.assertNotEqual(trl.status, Trial.TERMINATED)
self.stopTrial(trl, mock_runner)
status = sched.on_trial_result(mock_runner, jump, result(4, i))
self.assertEqual(status, TrialScheduler.PAUSE)
current_length = len(big_bracket.current_trials())
self.assertEqual(status, TrialScheduler.CONTINUE)
self.assertEqual(current_length, 3)
# Techincally only need to launch 2/3, as one is already running
for trl in big_bracket.current_trials():
mock_runner._launch_trial(trl)
# Provides results from 2 to 0 in order, killing the last one
for i, trl in reversed(list(enumerate(big_bracket.current_trials()))):
for j in range(3):
status = sched.on_trial_result(mock_runner, trl, result(1, i))
if status == TrialScheduler.CONTINUE:
continue
elif status == TrialScheduler.PAUSE:
mock_runner._pause_trial(trl)
elif status == TrialScheduler.STOP:
self.stopTrial(trl, mock_runner)
self.assertEqual(status, TrialScheduler.STOP)
trl = big_bracket.current_trials()[0]
for i in range(9):
status = sched.on_trial_result(mock_runner, trl, result(1, i))
self.assertEqual(status, TrialScheduler.STOP)
self.assertEqual(len(big_bracket.current_trials()), 0)
self.assertEqual(sched._num_stopped, 9)
def testScheduling(self):
"""Setup two bands, then make sure all trials are running"""
sched = self.advancedSetup()
mock_runner = _MockTrialRunner()
trl = sched.choose_trial_to_run(mock_runner)
while trl:
# If band iteration > 0, make sure first band is all running
if sched._trial_info[trl][1] > 0:
first_band = sched._hyperbands[0]
trials = [t for b in first_band for t in b._live_trials]
self.assertEqual(
all(t.status == Trial.RUNNING for t in trials),
True)
mock_runner._launch_trial(trl)
res = sched.on_trial_result(mock_runner, trl, result(1, 10))
if res is TrialScheduler.PAUSE:
mock_runner._pause_trial(trl)
trl = sched.choose_trial_to_run(mock_runner)
self.assertEqual(
all(t.status == Trial.RUNNING for t in trials), True)
def testTrialErrored(self):
sched, mock_runner = self.schedulerSetup(10)
t1, t2 = sched._state["bracket"].current_trials()
mock_runner._launch_trial(t1)
mock_runner._launch_trial(t2)
sched.on_trial_error(mock_runner, t2)
self.assertEqual(
TrialScheduler.CONTINUE,
sched.on_trial_result(mock_runner, t1, result(1, 10)))
def testTrialErrored2(self):
"""Check successive halving happened even when last trial failed"""
sched, mock_runner = self.schedulerSetup(17)
trials = sched._state["bracket"].current_trials()
self.assertEqual(len(trials), 9)
for t in trials[:-1]:
mock_runner._launch_trial(t)
sched.on_trial_result(mock_runner, t, result(1, 10))
mock_runner._launch_trial(trials[-1])
sched.on_trial_error(mock_runner, trials[-1])
self.assertEqual(len(sched._state["bracket"].current_trials()), 3)
def testTrialEndedEarly(self):
sched, mock_runner = self.schedulerSetup(10)
trials = sched._state["bracket"].current_trials()
for t in trials:
mock_runner._launch_trial(t)
sched.on_trial_complete(mock_runner, trials[-1], result(1, 12))
self.assertEqual(
TrialScheduler.CONTINUE,
sched.on_trial_result(mock_runner, trials[0], result(1, 12)))
def testTrialEndedEarly2(self):
"""Check successive halving happened even when last trial finished"""
sched, mock_runner = self.schedulerSetup(17)
trials = sched._state["bracket"].current_trials()
self.assertEqual(len(trials), 9)
for t in trials[:-1]:
mock_runner._launch_trial(t)
sched.on_trial_result(mock_runner, t, result(1, 10))
mock_runner._launch_trial(trials[-1])
sched.on_trial_complete(mock_runner, trials[-1], result(1, 12))
self.assertEqual(len(sched._state["bracket"].current_trials()), 3)
def testAddAfterHalving(self):
sched, mock_runner = self.schedulerSetup(10)
bracket_trials = sched._state["bracket"].current_trials()
for t in bracket_trials:
mock_runner._launch_trial(t)
for i, t in enumerate(bracket_trials):
res = sched.on_trial_result(
mock_runner, t, result(1, i))
self.assertEqual(res, TrialScheduler.CONTINUE)
t = Trial("t%d" % 5, "__fake")
sched.on_trial_add(None, t)
self.assertEqual(3 + 1, sched._state["bracket"]._live_trials[t][1])
self.assertLess(current_length, 27)
if __name__ == "__main__":