mirror of
https://github.com/vale981/ray
synced 2025-03-08 11:31:40 -05:00
[tune] Make HyperBand Usable (#1215)
This commit is contained in:
parent
3a0206a1f4
commit
eadb998643
3 changed files with 379 additions and 193 deletions
|
@ -1,7 +1,7 @@
|
|||
cartpole-ppo:
|
||||
env: CartPole-v0
|
||||
alg: PPO
|
||||
num_trials: 20
|
||||
repeat: 3
|
||||
stop:
|
||||
episode_reward_mean: 200
|
||||
time_total_s: 180
|
||||
|
|
|
@ -8,13 +8,11 @@ from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler
|
|||
from ray.tune.trial import Trial
|
||||
|
||||
|
||||
def calculate_bracket_count(max_iter, eta):
|
||||
return int(np.log(max_iter)/np.log(eta)) + 1
|
||||
|
||||
|
||||
class HyperBandScheduler(FIFOScheduler):
|
||||
"""Implements HyperBand.
|
||||
|
||||
Blog post: https://people.eecs.berkeley.edu/~kjamieson/hyperband.html
|
||||
|
||||
This implementation contains 3 logical levels.
|
||||
Each HyperBand iteration is a "band". There can be multiple
|
||||
bands running at once, and there can be 1 band that is incomplete.
|
||||
|
@ -30,26 +28,39 @@ class HyperBandScheduler(FIFOScheduler):
|
|||
|
||||
Trials added will be inserted into the most recent bracket
|
||||
and band and will spill over to new brackets/bands accordingly.
|
||||
|
||||
This maintains the bracket size and max trial count per band
|
||||
to 5 and 117 respectively, which correspond to that of
|
||||
`max_attr=81, eta=3` from the blog post. Trials will fill up
|
||||
from smallest bracket to largest, with largest
|
||||
having the most rounds of successive halving.
|
||||
|
||||
Args:
|
||||
time_attr (str): The TrainingResult attr to use for comparing time.
|
||||
Note that you can pass in something non-temporal such as
|
||||
`training_iteration` as a measure of progress, the only requirement
|
||||
is that the attribute should increase monotonically.
|
||||
reward_attr (str): The TrainingResult objective value attribute. As
|
||||
with `time_attr`, this may refer to any objective value. Stopping
|
||||
procedures will use this attribute.
|
||||
max_t (int): max time units per trial. Trials will be stopped after
|
||||
max_t time units (determined by time_attr) have passed.
|
||||
The HyperBand scheduler automatically tries to determine a
|
||||
reasonable number of brackets based on this and eta.
|
||||
"""
|
||||
|
||||
def __init__(self, max_iter=200, eta=3):
|
||||
"""
|
||||
args:
|
||||
max_iter (int): maximum iterations per configuration
|
||||
eta (int): # defines downsampling rate (default=3)
|
||||
"""
|
||||
assert max_iter > 0, "Max Iterations not valid!"
|
||||
assert eta > 1, "Downsampling rate (eta) not valid!"
|
||||
|
||||
def __init__(
|
||||
self, time_attr='training_iteration',
|
||||
reward_attr='episode_reward_mean', max_t=81):
|
||||
assert max_t > 0, "Max (time_attr) not valid!"
|
||||
FIFOScheduler.__init__(self)
|
||||
self._eta = eta
|
||||
self._s_max_1 = s_max_1 = calculate_bracket_count(max_iter, eta)
|
||||
# total number of iterations per execution of Succesive Halving (n,r)
|
||||
B = s_max_1 * max_iter
|
||||
# bracket trial count total
|
||||
self._get_n0 = lambda s: int(np.ceil(B/max_iter/(s+1)*eta**s))
|
||||
self._eta = 3
|
||||
self._s_max_1 = 5
|
||||
# bracket max trials
|
||||
self._get_n0 = lambda s: int(
|
||||
np.ceil(self._s_max_1/(s+1) * self._eta**s))
|
||||
# bracket initial iterations
|
||||
self._get_r0 = lambda s: int(max_iter*eta**(-s))
|
||||
self._get_r0 = lambda s: int((max_t*self._eta**(-s)))
|
||||
self._hyperbands = [[]] # list of hyperband iterations
|
||||
self._trial_info = {} # Stores Trial -> Bracket, Band Iteration
|
||||
|
||||
|
@ -57,6 +68,8 @@ class HyperBandScheduler(FIFOScheduler):
|
|||
self._state = {"bracket": None,
|
||||
"band_idx": 0}
|
||||
self._num_stopped = 0
|
||||
self._reward_attr = reward_attr
|
||||
self._time_attr = time_attr
|
||||
|
||||
def on_trial_add(self, trial_runner, trial):
|
||||
"""On a new trial add, if current bracket is not filled,
|
||||
|
@ -67,22 +80,27 @@ class HyperBandScheduler(FIFOScheduler):
|
|||
cur_bracket = self._state["bracket"]
|
||||
cur_band = self._hyperbands[self._state["band_idx"]]
|
||||
if cur_bracket is None or cur_bracket.filled():
|
||||
retry = True
|
||||
while retry:
|
||||
# if current iteration is filled, create new iteration
|
||||
if self._cur_band_filled():
|
||||
cur_band = []
|
||||
self._hyperbands.append(cur_band)
|
||||
self._state["band_idx"] += 1
|
||||
|
||||
# if current iteration is filled, create new iteration
|
||||
if self._cur_band_filled():
|
||||
cur_band = []
|
||||
self._hyperbands.append(cur_band)
|
||||
self._state["band_idx"] += 1
|
||||
|
||||
# cur_band will always be less than s_max_1 or else filled
|
||||
s = len(cur_band)
|
||||
assert s < self._s_max_1, "Current band is filled!"
|
||||
|
||||
# create new bracket
|
||||
cur_bracket = Bracket(self._get_n0(s),
|
||||
self._get_r0(s), self._eta, s)
|
||||
cur_band.append(cur_bracket)
|
||||
self._state["bracket"] = cur_bracket
|
||||
# cur_band will always be less than s_max_1 or else filled
|
||||
s = len(cur_band)
|
||||
assert s < self._s_max_1, "Current band is filled!"
|
||||
if self._get_r0(s) == 0:
|
||||
print("Bracket too small - Retrying...")
|
||||
cur_bracket = None
|
||||
else:
|
||||
retry = False
|
||||
cur_bracket = Bracket(
|
||||
self._time_attr, self._get_n0(s), self._get_r0(s),
|
||||
self._eta, s)
|
||||
cur_band.append(cur_bracket)
|
||||
self._state["bracket"] = cur_bracket
|
||||
|
||||
self._state["bracket"].add_trial(trial)
|
||||
self._trial_info[trial] = cur_bracket, self._state["band_idx"]
|
||||
|
@ -128,9 +146,9 @@ class HyperBandScheduler(FIFOScheduler):
|
|||
if bracket.cur_iter_done():
|
||||
if bracket.finished():
|
||||
self._cleanup_bracket(trial_runner, bracket)
|
||||
return TrialScheduler.STOP
|
||||
return TrialScheduler.CONTINUE
|
||||
|
||||
good, bad = bracket.successive_halving()
|
||||
good, bad = bracket.successive_halving(self._reward_attr)
|
||||
# kill bad trials
|
||||
for t in bad:
|
||||
if t.status == Trial.PAUSED:
|
||||
|
@ -141,14 +159,15 @@ class HyperBandScheduler(FIFOScheduler):
|
|||
else:
|
||||
raise Exception("Trial with unexpected status encountered")
|
||||
|
||||
# ready the good trials
|
||||
# ready the good trials - if trial is too far ahead, don't continue
|
||||
for t in good:
|
||||
if t.status == Trial.PAUSED:
|
||||
t.unpause()
|
||||
elif t.status == Trial.RUNNING:
|
||||
action = TrialScheduler.CONTINUE
|
||||
else:
|
||||
if t.status not in [Trial.PAUSED, Trial.RUNNING]:
|
||||
raise Exception("Trial with unexpected status encountered")
|
||||
if bracket.continue_trial(t):
|
||||
if t.status == Trial.PAUSED:
|
||||
t.unpause()
|
||||
elif t.status == Trial.RUNNING:
|
||||
action = TrialScheduler.CONTINUE
|
||||
return action
|
||||
|
||||
def _cleanup_trial(self, trial_runner, t, bracket, hard=False):
|
||||
|
@ -162,11 +181,14 @@ class HyperBandScheduler(FIFOScheduler):
|
|||
bracket.cleanup_trial(t)
|
||||
|
||||
def _cleanup_bracket(self, trial_runner, bracket):
|
||||
"""Cleans up bracket after bracket is completely finished."""
|
||||
"""Cleans up bracket after bracket is completely finished.
|
||||
Lets the last trial continue to run until termination condition
|
||||
kicks in."""
|
||||
for trial in bracket.current_trials():
|
||||
self._cleanup_trial(
|
||||
trial_runner, trial, bracket,
|
||||
hard=(trial.status == Trial.PAUSED))
|
||||
if (trial.status == Trial.PAUSED):
|
||||
self._cleanup_trial(
|
||||
trial_runner, trial, bracket,
|
||||
hard=True)
|
||||
|
||||
def on_trial_complete(self, trial_runner, trial, result):
|
||||
"""Cleans up trial info from bracket if trial completed early."""
|
||||
|
@ -219,12 +241,15 @@ class Bracket():
|
|||
|
||||
Also keeps track of progress to ensure good scheduling.
|
||||
"""
|
||||
def __init__(self, max_trials, init_iters, eta, s):
|
||||
self._live_trials = {} # stores (result, itrs left before halving)
|
||||
def __init__(self, time_attr, max_trials, init_t_attr, eta, s):
|
||||
self._live_trials = {} # maps trial -> current result
|
||||
self._all_trials = []
|
||||
self._time_attr = time_attr # attribute to
|
||||
|
||||
self._n = self._n0 = max_trials
|
||||
self._r = self._r0 = init_iters
|
||||
self._r = self._r0 = init_t_attr
|
||||
self._cumul_r = self._r0
|
||||
|
||||
self._eta = eta
|
||||
self._halves = s
|
||||
|
||||
|
@ -237,15 +262,15 @@ class Bracket():
|
|||
At a later iteration, a newly added trial will be given equal
|
||||
opportunity to catch up."""
|
||||
assert not self.filled(), "Cannot add trial to filled bracket!"
|
||||
self._live_trials[trial] = (None, self._cumul_r)
|
||||
self._live_trials[trial] = None
|
||||
self._all_trials.append(trial)
|
||||
|
||||
def cur_iter_done(self):
|
||||
"""Checks if all iterations have completed.
|
||||
|
||||
TODO(rliaw): also check that `t.iterations == self._r`"""
|
||||
all_done = all(itr == 0 for _, itr in self._live_trials.values())
|
||||
return all_done
|
||||
return all(self._get_result_time(result) >= self._cumul_r
|
||||
for result in self._live_trials.values())
|
||||
|
||||
def finished(self):
|
||||
return self._halves == 0 and self.cur_iter_done()
|
||||
|
@ -254,8 +279,8 @@ class Bracket():
|
|||
return list(self._live_trials)
|
||||
|
||||
def continue_trial(self, trial):
|
||||
_, itr = self._live_trials[trial]
|
||||
if itr > 0:
|
||||
result = self._live_trials[trial]
|
||||
if self._get_result_time(result) < self._cumul_r:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
@ -265,24 +290,19 @@ class Bracket():
|
|||
minimizing the need to backtrack and bookkeep previous medians"""
|
||||
return len(self._live_trials) == self._n
|
||||
|
||||
def successive_halving(self):
|
||||
def successive_halving(self, reward_attr):
|
||||
assert self._halves > 0
|
||||
self._halves -= 1
|
||||
self._n /= self._eta
|
||||
self._n = int(np.ceil(self._n))
|
||||
self._r *= self._eta
|
||||
self._r = int(np.ceil(self._r))
|
||||
self._r = int((self._r))
|
||||
self._cumul_r += self._r
|
||||
sorted_trials = sorted(
|
||||
self._live_trials,
|
||||
key=lambda t: self._live_trials[t][0].episode_reward_mean)
|
||||
key=lambda t: getattr(self._live_trials[t], reward_attr))
|
||||
|
||||
good, bad = sorted_trials[-self._n:], sorted_trials[:-self._n]
|
||||
|
||||
# reset good trials to track updated iterations
|
||||
for t in good:
|
||||
res, old_itr = self._live_trials[t]
|
||||
self._live_trials[t] = (res, self._r)
|
||||
return good, bad
|
||||
|
||||
def update_trial_stats(self, trial, result):
|
||||
|
@ -293,10 +313,13 @@ class Bracket():
|
|||
in and make sure they're not set as pending later."""
|
||||
|
||||
assert trial in self._live_trials
|
||||
_, itr = self._live_trials[trial]
|
||||
assert itr > 0
|
||||
self._live_trials[trial] = (result, itr - 1)
|
||||
self._completed_progress += 1
|
||||
assert self._get_result_time(result) >= 0
|
||||
|
||||
delta = self._get_result_time(result) - \
|
||||
self._get_result_time(self._live_trials[trial])
|
||||
assert delta >= 0
|
||||
self._completed_progress += delta
|
||||
self._live_trials[trial] = result
|
||||
|
||||
def cleanup_trial(self, trial):
|
||||
"""Clean up statistics tracking for terminated trials (either by force
|
||||
|
@ -315,6 +338,11 @@ class Bracket():
|
|||
are dropped."""
|
||||
return self._completed_progress / self._total_work
|
||||
|
||||
def _get_result_time(self, result):
|
||||
if result is None:
|
||||
return 0
|
||||
return getattr(result, self._time_attr)
|
||||
|
||||
def _calculate_total_work(self, n, r, s):
|
||||
work = 0
|
||||
for i in range(s+1):
|
||||
|
@ -322,6 +350,7 @@ class Bracket():
|
|||
n /= self._eta
|
||||
n = int(np.ceil(n))
|
||||
r *= self._eta
|
||||
r = int(r)
|
||||
return work
|
||||
|
||||
def __repr__(self):
|
||||
|
|
|
@ -3,6 +3,7 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
|
||||
from ray.tune.hyperband import HyperBandScheduler
|
||||
from ray.tune.median_stopping_rule import MedianStoppingRule
|
||||
|
@ -12,7 +13,9 @@ from ray.tune.trial_scheduler import TrialScheduler
|
|||
|
||||
|
||||
def result(t, rew):
|
||||
return TrainingResult(time_total_s=t, episode_reward_mean=rew)
|
||||
return TrainingResult(time_total_s=t,
|
||||
episode_reward_mean=rew,
|
||||
training_iteration=int(t))
|
||||
|
||||
|
||||
class EarlyStoppingSuite(unittest.TestCase):
|
||||
|
@ -156,21 +159,46 @@ class HyperbandSuite(unittest.TestCase):
|
|||
"""Setup a scheduler and Runner with max Iter = 9
|
||||
|
||||
Bracketing is placed as follows:
|
||||
(3, 9);
|
||||
(5, 3) -> (2, 9);
|
||||
(9, 1) -> (3, 3) -> (1, 9); """
|
||||
sched = HyperBandScheduler(9, eta=3)
|
||||
(5, 81);
|
||||
(8, 27) -> (3, 81);
|
||||
(15, 9) -> (5, 27) -> (2, 81);
|
||||
(34, 3) -> (12, 9) -> (4, 27) -> (2, 81);
|
||||
(81, 1) -> (27, 3) -> (9, 9) -> (3, 27) -> (1, 81);"""
|
||||
sched = HyperBandScheduler()
|
||||
for i in range(num_trials):
|
||||
t = Trial("t%d" % i, "__fake")
|
||||
sched.on_trial_add(None, t)
|
||||
runner = _MockTrialRunner()
|
||||
return sched, runner
|
||||
|
||||
def default_statistics(self):
|
||||
"""Default statistics for HyperBand"""
|
||||
sched = HyperBandScheduler()
|
||||
res = {
|
||||
str(s): {"n": sched._get_n0(s), "r": sched._get_r0(s)}
|
||||
for s in range(sched._s_max_1)
|
||||
}
|
||||
res["max_trials"] = sum(v["n"] for v in res.values())
|
||||
res["brack_count"] = sched._s_max_1
|
||||
res["s_max"] = sched._s_max_1 - 1
|
||||
return res
|
||||
|
||||
def downscale(self, n, sched):
|
||||
return int(np.ceil(n / sched._eta))
|
||||
|
||||
def process(self, trl, mock_runner, action):
|
||||
if action == TrialScheduler.CONTINUE:
|
||||
pass
|
||||
elif action == TrialScheduler.PAUSE:
|
||||
mock_runner._pause_trial(trl)
|
||||
elif action == TrialScheduler.STOP:
|
||||
self.stopTrial(trl, mock_runner)
|
||||
|
||||
def basicSetup(self):
|
||||
"""Setup and verify full band.
|
||||
"""
|
||||
|
||||
sched, _ = self.schedulerSetup(17)
|
||||
stats = self.default_statistics()
|
||||
sched, _ = self.schedulerSetup(stats["max_trials"])
|
||||
|
||||
self.assertEqual(len(sched._hyperbands), 1)
|
||||
self.assertEqual(sched._cur_band_filled(), True)
|
||||
|
@ -192,7 +220,7 @@ class HyperbandSuite(unittest.TestCase):
|
|||
self.assertEqual(len(unfilled_band), 2)
|
||||
bracket = unfilled_band[-1]
|
||||
self.assertEqual(bracket.filled(), False)
|
||||
self.assertEqual(len(bracket.current_trials()), 1)
|
||||
self.assertEqual(len(bracket.current_trials()), 7)
|
||||
|
||||
return sched
|
||||
|
||||
|
@ -200,19 +228,254 @@ class HyperbandSuite(unittest.TestCase):
|
|||
self.assertNotEqual(trial.status, Trial.TERMINATED)
|
||||
mock_runner._stop_trial(trial)
|
||||
|
||||
def testSuccessiveHalving(self):
|
||||
"""Setup full band, then iterate through last bracket (n=9)
|
||||
to make sure successive halving is correct."""
|
||||
def testConfigSameEta(self):
|
||||
sched = HyperBandScheduler()
|
||||
i = 0
|
||||
while not sched._cur_band_filled():
|
||||
t = Trial("t%d" % (i), "__fake")
|
||||
sched.on_trial_add(None, t)
|
||||
i += 1
|
||||
self.assertEqual(len(sched._hyperbands[0]), 5)
|
||||
self.assertEqual(sched._hyperbands[0][0]._n, 5)
|
||||
self.assertEqual(sched._hyperbands[0][0]._r, 81)
|
||||
self.assertEqual(sched._hyperbands[0][-1]._n, 81)
|
||||
self.assertEqual(sched._hyperbands[0][-1]._r, 1)
|
||||
|
||||
sched, mock_runner = self.schedulerSetup(17)
|
||||
filled_band = sched._hyperbands[0][-1]
|
||||
big_bracket = filled_band
|
||||
sched = HyperBandScheduler(max_t=810)
|
||||
i = 0
|
||||
while not sched._cur_band_filled():
|
||||
t = Trial("t%d" % (i), "__fake")
|
||||
sched.on_trial_add(None, t)
|
||||
i += 1
|
||||
self.assertEqual(len(sched._hyperbands[0]), 5)
|
||||
self.assertEqual(sched._hyperbands[0][0]._n, 5)
|
||||
self.assertEqual(sched._hyperbands[0][0]._r, 810)
|
||||
self.assertEqual(sched._hyperbands[0][-1]._n, 81)
|
||||
self.assertEqual(sched._hyperbands[0][-1]._r, 10)
|
||||
|
||||
def testConfigSameEtaSmall(self):
|
||||
sched = HyperBandScheduler(max_t=1)
|
||||
i = 0
|
||||
while len(sched._hyperbands) < 2:
|
||||
t = Trial("t%d" % (i), "__fake")
|
||||
sched.on_trial_add(None, t)
|
||||
i += 1
|
||||
self.assertEqual(len(sched._hyperbands[0]), 5)
|
||||
self.assertTrue(all(v is None for v in sched._hyperbands[0][1:]))
|
||||
|
||||
def testSuccessiveHalving(self):
|
||||
"""Setup full band, then iterate through last bracket (n=81)
|
||||
to make sure successive halving is correct."""
|
||||
stats = self.default_statistics()
|
||||
sched, mock_runner = self.schedulerSetup(stats["max_trials"])
|
||||
big_bracket = sched._state["bracket"]
|
||||
cur_units = stats[str(stats["s_max"])]["r"]
|
||||
# The last bracket will downscale 4 times
|
||||
for x in range(stats["brack_count"] - 1):
|
||||
trials = big_bracket.current_trials()
|
||||
current_length = len(trials)
|
||||
for trl in trials:
|
||||
mock_runner._launch_trial(trl)
|
||||
|
||||
# Provides results from 0 to 8 in order, keeping last one running
|
||||
for i, trl in enumerate(trials):
|
||||
action = sched.on_trial_result(
|
||||
mock_runner, trl, result(cur_units, i))
|
||||
if i < current_length - 1:
|
||||
self.assertEqual(action, TrialScheduler.PAUSE)
|
||||
self.process(trl, mock_runner, action)
|
||||
|
||||
self.assertEqual(action, TrialScheduler.CONTINUE)
|
||||
new_length = len(big_bracket.current_trials())
|
||||
self.assertEqual(new_length, self.downscale(current_length, sched))
|
||||
cur_units += int(cur_units * sched._eta)
|
||||
self.assertEqual(len(big_bracket.current_trials()), 1)
|
||||
|
||||
def testHalvingStop(self):
|
||||
stats = self.default_statistics()
|
||||
num_trials = stats[str(0)]["n"] + stats[str(1)]["n"]
|
||||
sched, mock_runner = self.schedulerSetup(num_trials)
|
||||
big_bracket = sched._state["bracket"]
|
||||
for trl in big_bracket.current_trials():
|
||||
mock_runner._launch_trial(trl)
|
||||
|
||||
# # Provides result in reverse order, killing the last one
|
||||
cur_units = stats[str(1)]["r"]
|
||||
for i, trl in reversed(list(enumerate(big_bracket.current_trials()))):
|
||||
action = sched.on_trial_result(
|
||||
mock_runner, trl, result(cur_units, i))
|
||||
self.process(trl, mock_runner, action)
|
||||
|
||||
self.assertEqual(action, TrialScheduler.STOP)
|
||||
|
||||
def testContinueLastOne(self):
|
||||
stats = self.default_statistics()
|
||||
num_trials = stats[str(0)]["n"]
|
||||
sched, mock_runner = self.schedulerSetup(num_trials)
|
||||
big_bracket = sched._state["bracket"]
|
||||
for trl in big_bracket.current_trials():
|
||||
mock_runner._launch_trial(trl)
|
||||
|
||||
# # Provides result in reverse order, killing the last one
|
||||
cur_units = stats[str(0)]["r"]
|
||||
for i, trl in enumerate(big_bracket.current_trials()):
|
||||
action = sched.on_trial_result(
|
||||
mock_runner, trl, result(cur_units, i))
|
||||
self.process(trl, mock_runner, action)
|
||||
|
||||
self.assertEqual(action, TrialScheduler.CONTINUE)
|
||||
|
||||
for x in range(100):
|
||||
action = sched.on_trial_result(
|
||||
mock_runner, trl, result(cur_units + x, 10))
|
||||
self.assertEqual(action, TrialScheduler.CONTINUE)
|
||||
|
||||
def testTrialErrored(self):
|
||||
"""If a trial errored, make sure successive halving still happens"""
|
||||
stats = self.default_statistics()
|
||||
trial_count = stats[str(0)]["n"] + 3
|
||||
sched, mock_runner = self.schedulerSetup(trial_count)
|
||||
t1, t2, t3 = sched._state["bracket"].current_trials()
|
||||
for t in [t1, t2, t3]:
|
||||
mock_runner._launch_trial(t)
|
||||
|
||||
sched.on_trial_error(mock_runner, t3)
|
||||
self.assertEqual(
|
||||
TrialScheduler.PAUSE,
|
||||
sched.on_trial_result(
|
||||
mock_runner, t1, result(stats[str(1)]["r"], 10)))
|
||||
self.assertEqual(
|
||||
TrialScheduler.CONTINUE,
|
||||
sched.on_trial_result(
|
||||
mock_runner, t2, result(stats[str(1)]["r"], 10)))
|
||||
|
||||
def testTrialErrored2(self):
|
||||
"""Check successive halving happened even when last trial failed"""
|
||||
stats = self.default_statistics()
|
||||
trial_count = stats[str(0)]["n"] + stats[str(1)]["n"]
|
||||
sched, mock_runner = self.schedulerSetup(trial_count)
|
||||
trials = sched._state["bracket"].current_trials()
|
||||
for t in trials[:-1]:
|
||||
mock_runner._launch_trial(t)
|
||||
sched.on_trial_result(
|
||||
mock_runner, t, result(stats[str(1)]["r"], 10))
|
||||
|
||||
mock_runner._launch_trial(trials[-1])
|
||||
sched.on_trial_error(mock_runner, trials[-1])
|
||||
self.assertEqual(len(sched._state["bracket"].current_trials()),
|
||||
self.downscale(stats[str(1)]["n"], sched))
|
||||
|
||||
def testTrialEndedEarly(self):
|
||||
"""Check successive halving happened even when one trial failed"""
|
||||
stats = self.default_statistics()
|
||||
trial_count = stats[str(0)]["n"] + 3
|
||||
sched, mock_runner = self.schedulerSetup(trial_count)
|
||||
|
||||
t1, t2, t3 = sched._state["bracket"].current_trials()
|
||||
for t in [t1, t2, t3]:
|
||||
mock_runner._launch_trial(t)
|
||||
|
||||
sched.on_trial_complete(mock_runner, t3, result(1, 12))
|
||||
self.assertEqual(
|
||||
TrialScheduler.PAUSE,
|
||||
sched.on_trial_result(
|
||||
mock_runner, t1, result(stats[str(1)]["r"], 10)))
|
||||
self.assertEqual(
|
||||
TrialScheduler.CONTINUE,
|
||||
sched.on_trial_result(
|
||||
mock_runner, t2, result(stats[str(1)]["r"], 10)))
|
||||
|
||||
def testTrialEndedEarly2(self):
|
||||
"""Check successive halving happened even when last trial failed"""
|
||||
stats = self.default_statistics()
|
||||
trial_count = stats[str(0)]["n"] + stats[str(1)]["n"]
|
||||
sched, mock_runner = self.schedulerSetup(trial_count)
|
||||
trials = sched._state["bracket"].current_trials()
|
||||
for t in trials[:-1]:
|
||||
mock_runner._launch_trial(t)
|
||||
sched.on_trial_result(
|
||||
mock_runner, t, result(stats[str(1)]["r"], 10))
|
||||
|
||||
mock_runner._launch_trial(trials[-1])
|
||||
sched.on_trial_complete(mock_runner, trials[-1], result(100, 12))
|
||||
self.assertEqual(len(sched._state["bracket"].current_trials()),
|
||||
self.downscale(stats[str(1)]["n"], sched))
|
||||
|
||||
def testAddAfterHalving(self):
|
||||
stats = self.default_statistics()
|
||||
trial_count = stats[str(0)]["n"] + 1
|
||||
sched, mock_runner = self.schedulerSetup(trial_count)
|
||||
bracket_trials = sched._state["bracket"].current_trials()
|
||||
init_units = stats[str(1)]["r"]
|
||||
|
||||
for t in bracket_trials:
|
||||
mock_runner._launch_trial(t)
|
||||
|
||||
for i, t in enumerate(bracket_trials):
|
||||
status = sched.on_trial_result(
|
||||
mock_runner, t, result(init_units, i))
|
||||
self.assertEqual(status, TrialScheduler.CONTINUE)
|
||||
t = Trial("t%d" % 100, "__fake")
|
||||
sched.on_trial_add(None, t)
|
||||
mock_runner._launch_trial(t)
|
||||
self.assertEqual(len(sched._state["bracket"].current_trials()), 2)
|
||||
|
||||
# Make sure that newly added trial gets fair computation (not just 1)
|
||||
self.assertEqual(
|
||||
TrialScheduler.CONTINUE,
|
||||
sched.on_trial_result(mock_runner, t, result(init_units, 12)))
|
||||
new_units = init_units + int(init_units * sched._eta)
|
||||
self.assertEqual(
|
||||
TrialScheduler.PAUSE,
|
||||
sched.on_trial_result(mock_runner, t, result(new_units, 12)))
|
||||
|
||||
def testAlternateMetrics(self):
|
||||
"""Checking that alternate metrics will pass."""
|
||||
|
||||
def result2(t, rew):
|
||||
return TrainingResult(time_total_s=t, neg_mean_loss=rew)
|
||||
|
||||
sched = HyperBandScheduler(
|
||||
time_attr='time_total_s', reward_attr='neg_mean_loss')
|
||||
stats = self.default_statistics()
|
||||
|
||||
for i in range(stats["max_trials"]):
|
||||
t = Trial("t%d" % i, "__fake")
|
||||
sched.on_trial_add(None, t)
|
||||
runner = _MockTrialRunner()
|
||||
|
||||
big_bracket = sched._hyperbands[0][-1]
|
||||
|
||||
for trl in big_bracket.current_trials():
|
||||
runner._launch_trial(trl)
|
||||
current_length = len(big_bracket.current_trials())
|
||||
|
||||
# Provides results from 0 to 8 in order, keeping the last one running
|
||||
for i, trl in enumerate(big_bracket.current_trials()):
|
||||
status = sched.on_trial_result(runner, trl, result2(1, i))
|
||||
if status == TrialScheduler.CONTINUE:
|
||||
continue
|
||||
elif status == TrialScheduler.PAUSE:
|
||||
runner._pause_trial(trl)
|
||||
elif status == TrialScheduler.STOP:
|
||||
self.assertNotEqual(trl.status, Trial.TERMINATED)
|
||||
self.stopTrial(trl, runner)
|
||||
|
||||
new_length = len(big_bracket.current_trials())
|
||||
self.assertEqual(status, TrialScheduler.CONTINUE)
|
||||
self.assertEqual(new_length, self.downscale(current_length, sched))
|
||||
|
||||
def testJumpingTime(self):
|
||||
sched, mock_runner = self.schedulerSetup(81)
|
||||
big_bracket = sched._hyperbands[0][-1]
|
||||
|
||||
for trl in big_bracket.current_trials():
|
||||
mock_runner._launch_trial(trl)
|
||||
|
||||
# Provides results from 0 to 8 in order, keeping the last one running
|
||||
for i, trl in enumerate(big_bracket.current_trials()):
|
||||
main_trials = big_bracket.current_trials()[:-1]
|
||||
jump = big_bracket.current_trials()[-1]
|
||||
for i, trl in enumerate(main_trials):
|
||||
status = sched.on_trial_result(mock_runner, trl, result(1, i))
|
||||
if status == TrialScheduler.CONTINUE:
|
||||
continue
|
||||
|
@ -222,117 +485,11 @@ class HyperbandSuite(unittest.TestCase):
|
|||
self.assertNotEqual(trl.status, Trial.TERMINATED)
|
||||
self.stopTrial(trl, mock_runner)
|
||||
|
||||
status = sched.on_trial_result(mock_runner, jump, result(4, i))
|
||||
self.assertEqual(status, TrialScheduler.PAUSE)
|
||||
|
||||
current_length = len(big_bracket.current_trials())
|
||||
self.assertEqual(status, TrialScheduler.CONTINUE)
|
||||
self.assertEqual(current_length, 3)
|
||||
|
||||
# Techincally only need to launch 2/3, as one is already running
|
||||
for trl in big_bracket.current_trials():
|
||||
mock_runner._launch_trial(trl)
|
||||
|
||||
# Provides results from 2 to 0 in order, killing the last one
|
||||
for i, trl in reversed(list(enumerate(big_bracket.current_trials()))):
|
||||
for j in range(3):
|
||||
status = sched.on_trial_result(mock_runner, trl, result(1, i))
|
||||
if status == TrialScheduler.CONTINUE:
|
||||
continue
|
||||
elif status == TrialScheduler.PAUSE:
|
||||
mock_runner._pause_trial(trl)
|
||||
elif status == TrialScheduler.STOP:
|
||||
self.stopTrial(trl, mock_runner)
|
||||
|
||||
self.assertEqual(status, TrialScheduler.STOP)
|
||||
trl = big_bracket.current_trials()[0]
|
||||
for i in range(9):
|
||||
status = sched.on_trial_result(mock_runner, trl, result(1, i))
|
||||
self.assertEqual(status, TrialScheduler.STOP)
|
||||
self.assertEqual(len(big_bracket.current_trials()), 0)
|
||||
self.assertEqual(sched._num_stopped, 9)
|
||||
|
||||
def testScheduling(self):
|
||||
"""Setup two bands, then make sure all trials are running"""
|
||||
sched = self.advancedSetup()
|
||||
mock_runner = _MockTrialRunner()
|
||||
trl = sched.choose_trial_to_run(mock_runner)
|
||||
while trl:
|
||||
# If band iteration > 0, make sure first band is all running
|
||||
if sched._trial_info[trl][1] > 0:
|
||||
first_band = sched._hyperbands[0]
|
||||
trials = [t for b in first_band for t in b._live_trials]
|
||||
self.assertEqual(
|
||||
all(t.status == Trial.RUNNING for t in trials),
|
||||
True)
|
||||
mock_runner._launch_trial(trl)
|
||||
res = sched.on_trial_result(mock_runner, trl, result(1, 10))
|
||||
if res is TrialScheduler.PAUSE:
|
||||
mock_runner._pause_trial(trl)
|
||||
trl = sched.choose_trial_to_run(mock_runner)
|
||||
|
||||
self.assertEqual(
|
||||
all(t.status == Trial.RUNNING for t in trials), True)
|
||||
|
||||
def testTrialErrored(self):
|
||||
sched, mock_runner = self.schedulerSetup(10)
|
||||
t1, t2 = sched._state["bracket"].current_trials()
|
||||
mock_runner._launch_trial(t1)
|
||||
mock_runner._launch_trial(t2)
|
||||
|
||||
sched.on_trial_error(mock_runner, t2)
|
||||
self.assertEqual(
|
||||
TrialScheduler.CONTINUE,
|
||||
sched.on_trial_result(mock_runner, t1, result(1, 10)))
|
||||
|
||||
def testTrialErrored2(self):
|
||||
"""Check successive halving happened even when last trial failed"""
|
||||
sched, mock_runner = self.schedulerSetup(17)
|
||||
trials = sched._state["bracket"].current_trials()
|
||||
self.assertEqual(len(trials), 9)
|
||||
for t in trials[:-1]:
|
||||
mock_runner._launch_trial(t)
|
||||
sched.on_trial_result(mock_runner, t, result(1, 10))
|
||||
|
||||
mock_runner._launch_trial(trials[-1])
|
||||
sched.on_trial_error(mock_runner, trials[-1])
|
||||
self.assertEqual(len(sched._state["bracket"].current_trials()), 3)
|
||||
|
||||
def testTrialEndedEarly(self):
|
||||
sched, mock_runner = self.schedulerSetup(10)
|
||||
trials = sched._state["bracket"].current_trials()
|
||||
for t in trials:
|
||||
mock_runner._launch_trial(t)
|
||||
|
||||
sched.on_trial_complete(mock_runner, trials[-1], result(1, 12))
|
||||
self.assertEqual(
|
||||
TrialScheduler.CONTINUE,
|
||||
sched.on_trial_result(mock_runner, trials[0], result(1, 12)))
|
||||
|
||||
def testTrialEndedEarly2(self):
|
||||
"""Check successive halving happened even when last trial finished"""
|
||||
sched, mock_runner = self.schedulerSetup(17)
|
||||
trials = sched._state["bracket"].current_trials()
|
||||
self.assertEqual(len(trials), 9)
|
||||
for t in trials[:-1]:
|
||||
mock_runner._launch_trial(t)
|
||||
sched.on_trial_result(mock_runner, t, result(1, 10))
|
||||
|
||||
mock_runner._launch_trial(trials[-1])
|
||||
sched.on_trial_complete(mock_runner, trials[-1], result(1, 12))
|
||||
self.assertEqual(len(sched._state["bracket"].current_trials()), 3)
|
||||
|
||||
def testAddAfterHalving(self):
|
||||
sched, mock_runner = self.schedulerSetup(10)
|
||||
bracket_trials = sched._state["bracket"].current_trials()
|
||||
|
||||
for t in bracket_trials:
|
||||
mock_runner._launch_trial(t)
|
||||
|
||||
for i, t in enumerate(bracket_trials):
|
||||
res = sched.on_trial_result(
|
||||
mock_runner, t, result(1, i))
|
||||
self.assertEqual(res, TrialScheduler.CONTINUE)
|
||||
t = Trial("t%d" % 5, "__fake")
|
||||
sched.on_trial_add(None, t)
|
||||
self.assertEqual(3 + 1, sched._state["bracket"]._live_trials[t][1])
|
||||
self.assertLess(current_length, 27)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Add table
Reference in a new issue