[tune] Async Hyperband (#1595)

Richard Liaw 2018-03-04 14:05:56 -08:00 committed by GitHub
parent ecb811c26e
commit 78716094b5
7 changed files with 338 additions and 3 deletions

View file

@@ -146,8 +146,6 @@ To reduce costs, long-running trials can often be early stopped if their initial
An example of this can be found in `hyperband_example.py <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/hyperband_example.py>`__. The progress of one such HyperBand run is shown below.
Note that some trial schedulers such as HyperBand and PBT require your Trainable to support checkpointing, which is described in the next section. Checkpointing enables the scheduler to multiplex many concurrent trials onto a limited size cluster.
::
    == Status ==
@@ -180,10 +178,15 @@ Note that some trial schedulers such as HyperBand and PBT require your Trainable
    - my_class_31_height=40,width=10: RUNNING
    - my_class_53_height=28,width=96: RUNNING
Ray Tune also implements an `asynchronous version of HyperBand <https://openreview.net/forum?id=S1Y7OOlRZ>`__, which provides better parallelism and avoids straggler issues during eliminations. An example of this can be found in `async_hyperband_example.py <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/async_hyperband_example.py>`__. We recommend using this over the vanilla HyperBand scheduler.
.. note:: Some trial schedulers such as HyperBand and PBT require your Trainable to support checkpointing, which is described in the next section. Checkpointing enables the scheduler to multiplex many concurrent trials onto a limited size cluster.
Currently, we support the following early-stopping algorithms; you can also write your own that implements the `TrialScheduler <https://github.com/ray-project/ray/blob/master/python/ray/tune/trial_scheduler.py>`__ interface. A short usage sketch follows the class list below.
.. autoclass:: ray.tune.median_stopping_rule.MedianStoppingRule
.. autoclass:: ray.tune.hyperband.HyperBandScheduler
.. autoclass:: ray.tune.async_hyperband.AsyncHyperBandScheduler
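As a quick illustration, the asynchronous scheduler can be constructed and handed to ``run_experiments`` as follows (a minimal sketch adapted from the ``async_hyperband_example.py`` script added in this commit; the experiment name and the ``my_class`` Trainable are placeholders):

::

    from ray.tune import run_experiments
    from ray.tune.async_hyperband import AsyncHyperBandScheduler

    # With the default reduction_factor of 3, each rung keeps roughly
    # the top third of trials seen so far and stops the rest.
    ahb = AsyncHyperBandScheduler(
        time_attr="training_iteration",
        reward_attr="episode_reward_mean",
        grace_period=10, max_t=100)

    run_experiments({
        "my_experiment": {
            "run": "my_class",  # a previously registered Trainable
            "repeat": 20,
        }
    }, scheduler=ahb)
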
Population Based Training
-------------------------

View file

@@ -0,0 +1,148 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler


class AsyncHyperBandScheduler(FIFOScheduler):
    """Implements the Async Successive Halving algorithm.

    This should provide similar theoretical performance as HyperBand but
    avoid straggler issues that HyperBand faces. One implementation detail
    is that, when using multiple brackets, trial allocation to a bracket
    is done randomly with a softmax probability.

    See https://openreview.net/forum?id=S1Y7OOlRZ

    Args:
        time_attr (str): The TrainingResult attr to use for comparing time.
            Note that you can pass in something non-temporal such as
            `training_iteration` as a measure of progress; the only
            requirement is that the attribute increases monotonically.
        reward_attr (str): The TrainingResult objective value attribute. As
            with `time_attr`, this may refer to any objective value. Stopping
            procedures will use this attribute.
        max_t (float): max time units per trial. Trials will be stopped after
            max_t time units (determined by time_attr) have passed.
        grace_period (float): Only stop trials at least this old in time.
            The units are the same as the attribute named by `time_attr`.
        reduction_factor (float): Used to set halving rate and amount. This
            is simply a unit-less scalar.
        brackets (int): Number of brackets. Each bracket has a different
            halving rate, specified by the reduction factor.
    """

    def __init__(
            self, time_attr='training_iteration',
            reward_attr='episode_reward_mean', max_t=100,
            grace_period=10, reduction_factor=3, brackets=3):
        assert max_t > 0, "Max (time_attr) not valid!"
        assert max_t >= grace_period, "grace_period must be <= max_t!"
        assert grace_period > 0, "grace_period must be positive!"
        assert reduction_factor > 1, "Reduction Factor not valid!"
        assert brackets > 0, "brackets must be positive!"
        FIFOScheduler.__init__(self)
        self._reduction_factor = reduction_factor
        self._max_t = max_t

        self._trial_info = {}  # Stores trial_id -> Bracket

        # Tracks state for new trial add
        self._brackets = [_Bracket(
            grace_period, max_t, reduction_factor, s) for s in range(brackets)]
        self._counter = 0  # for bookkeeping (currently unused)
        self._num_stopped = 0
        self._reward_attr = reward_attr
        self._time_attr = time_attr

    def on_trial_add(self, trial_runner, trial):
        sizes = np.array([len(b._rungs) for b in self._brackets])
        probs = np.e ** (sizes - sizes.max())
        normalized = probs / probs.sum()
        idx = np.random.choice(len(self._brackets), p=normalized)
        self._trial_info[trial.trial_id] = self._brackets[idx]

    def on_trial_result(self, trial_runner, trial, result):
        if getattr(result, self._time_attr) >= self._max_t:
            self._num_stopped += 1
            return TrialScheduler.STOP
        bracket = self._trial_info[trial.trial_id]
        action = bracket.on_result(
            trial,
            getattr(result, self._time_attr),
            getattr(result, self._reward_attr))
        return action

    def on_trial_complete(self, trial_runner, trial, result):
        bracket = self._trial_info[trial.trial_id]
        bracket.on_result(
            trial,
            getattr(result, self._time_attr),
            getattr(result, self._reward_attr))
        del self._trial_info[trial.trial_id]

    def on_trial_remove(self, trial_runner, trial):
        del self._trial_info[trial.trial_id]

    def debug_string(self):
        out = "Using AsyncHyperBand: num_stopped={}".format(
            self._num_stopped)
        out += "\n" + "\n".join([b.debug_str() for b in self._brackets])
        return out


class _Bracket():
    """Bookkeeping system to track the cutoffs.

    Rungs are created in reversed order so that we can more easily find
    the correct rung corresponding to the current iteration of the result.

    Example:
        >>> b = _Bracket(1, 10, 2, 0)
        >>> b.on_result(trial1, 1, 2)  # CONTINUE
        >>> b.on_result(trial2, 1, 4)  # CONTINUE
        >>> b.cutoff(b._rungs[-1][1]) == 3.0  # rungs are reversed
        >>> b.on_result(trial3, 1, 1)  # STOP
        >>> b.cutoff(b._rungs[-1][1]) == 2.0
    """

    def __init__(self, min_t, max_t, reduction_factor, s):
        self.rf = reduction_factor
        MAX_RUNGS = int(np.log(max_t / min_t) / np.log(self.rf) - s + 1)
        self._rungs = [(min_t * self.rf**(k + s), {})
                       for k in reversed(range(MAX_RUNGS))]

    def cutoff(self, recorded):
        if not recorded:
            return None
        return np.percentile(list(recorded.values()), (1 - 1 / self.rf) * 100)

    def on_result(self, trial, cur_iter, cur_rew):
        action = TrialScheduler.CONTINUE
        for milestone, recorded in self._rungs:
            if cur_iter < milestone or trial.trial_id in recorded:
                continue
            else:
                cutoff = self.cutoff(recorded)
                if cutoff is not None and cur_rew < cutoff:
                    action = TrialScheduler.STOP
                recorded[trial.trial_id] = cur_rew
                break
        return action

    def debug_str(self):
        iters = " | ".join(
            ["Iter {:.3f}: {}".format(milestone, self.cutoff(recorded))
             for milestone, recorded in self._rungs])
        return "Bracket: " + iters


if __name__ == '__main__':
    sched = AsyncHyperBandScheduler(
        grace_period=1, max_t=10, reduction_factor=2)
    print(sched.debug_string())
    bracket = sched._brackets[0]
    print(bracket.cutoff({str(i): i for i in range(20)}))
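
As a quick sanity check of the two key computations above (the softmax bracket allocation in on_trial_add and the percentile cutoff in _Bracket.cutoff), here is a standalone sketch with made-up values:

from __future__ import division

import numpy as np

# Bracket allocation: brackets with more rungs get exponentially higher
# selection probability, mirroring on_trial_add above.
sizes = np.array([3, 2, 1])            # len(b._rungs) for each bracket
probs = np.e ** (sizes - sizes.max())
print(probs / probs.sum())             # ~[0.665, 0.245, 0.090]

# Rung cutoff with reduction_factor=3: keep the top 1/3 of recorded
# rewards, i.e. stop anything below the 66.7th percentile.
recorded = {"t1": 300.0, "t2": 450.0}
print(np.percentile(list(recorded.values()), (1 - 1 / 3) * 100))  # ~400.0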

View file

@@ -80,7 +80,7 @@ def make_parser(**kwargs):
"many times. Only applies if checkpointing is enabled.")
parser.add_argument(
"--scheduler", default="FIFO", type=str,
help="FIFO (default), MedianStopping, or HyperBand.")
help="FIFO (default), MedianStopping, AsyncHyperBand, or HyperBand.")
parser.add_argument(
"--scheduler-config", default="{}", type=json.loads,
help="Config options to pass to the scheduler.")

View file

@@ -0,0 +1,77 @@
#!/usr/bin/env python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import json
import os
import random

import numpy as np

import ray
from ray.tune import Trainable, TrainingResult, register_trainable, \
    run_experiments
from ray.tune.async_hyperband import AsyncHyperBandScheduler


class MyTrainableClass(Trainable):
    """Example agent whose learning curve is a random sigmoid.

    The dummy hyperparameters "width" and "height" determine the slope and
    maximum reward value reached.
    """

    def _setup(self):
        self.timestep = 0

    def _train(self):
        self.timestep += 1
        v = np.tanh(float(self.timestep) / self.config["width"])
        v *= self.config["height"]

        # Here we use `episode_reward_mean`, but you can also report other
        # objectives such as loss or accuracy (see tune/result.py).
        return TrainingResult(episode_reward_mean=v, timesteps_this_iter=1)

    def _save(self, checkpoint_dir):
        path = os.path.join(checkpoint_dir, "checkpoint")
        with open(path, "w") as f:
            f.write(json.dumps({"timestep": self.timestep}))
        return path

    def _restore(self, checkpoint_path):
        with open(checkpoint_path) as f:
            self.timestep = json.loads(f.read())["timestep"]


register_trainable("my_class", MyTrainableClass)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing")
    args, _ = parser.parse_known_args()
    ray.init()

    # Asynchronous HyperBand early stopping, configured with
    # `episode_reward_mean` as the objective and `timesteps_total`
    # as the time unit.
    ahb = AsyncHyperBandScheduler(
        time_attr="timesteps_total", reward_attr="episode_reward_mean",
        grace_period=5, max_t=100)

    run_experiments({
        "asynchyperband_test": {
            "run": "my_class",
            "stop": {"training_iteration": 1 if args.smoke_test else 99999},
            "repeat": 20,
            "resources": {"cpu": 1, "gpu": 0},
            "config": {
                "width": lambda spec: 10 + int(90 * random.random()),
                "height": lambda spec: int(100 * random.random()),
            },
        }
    }, scheduler=ahb)
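
The dummy learning curve above is just a scaled tanh of the iteration count; a standalone sketch (no Ray needed) of how width and height shape it:

import numpy as np

# v(t) = height * tanh(t / width): width controls how many timesteps the
# curve needs to approach its ceiling, height sets the ceiling itself.
for width in (10, 90):
    curve = [100 * np.tanh(float(t) / width) for t in (5, 20, 100)]
    print(width, ["%.1f" % v for v in curve])
# width=10 saturates almost immediately, while width=90 is still climbing
# at t=100, which is what gives the early-stopping scheduler slow trials
# to cut.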

View file

@@ -8,6 +8,7 @@ import unittest
import numpy as np
from ray.tune.hyperband import HyperBandScheduler
from ray.tune.async_hyperband import AsyncHyperBandScheduler
from ray.tune.pbt import PopulationBasedTraining, explore
from ray.tune.median_stopping_rule import MedianStoppingRule
from ray.tune.result import TrainingResult
@@ -757,5 +758,105 @@ class PopulationBasedTestingSuite(unittest.TestCase):
        self.assertEqual(trials[0].config["float_factor"], 43)


class AsyncHyperBandSuite(unittest.TestCase):
    def basicSetup(self, scheduler):
        t1 = Trial("PPO")  # mean is 450, max 900, t_max=10
        t2 = Trial("PPO")  # mean is 450, max 450, t_max=5
        scheduler.on_trial_add(None, t1)
        scheduler.on_trial_add(None, t2)
        for i in range(10):
            self.assertEqual(
                scheduler.on_trial_result(None, t1, result(i, i * 100)),
                TrialScheduler.CONTINUE)
        for i in range(5):
            self.assertEqual(
                scheduler.on_trial_result(None, t2, result(i, 450)),
                TrialScheduler.CONTINUE)
        return t1, t2

    def testAsyncHBOnComplete(self):
        scheduler = AsyncHyperBandScheduler(max_t=10, brackets=1)
        t1, t2 = self.basicSetup(scheduler)
        t3 = Trial("PPO")
        scheduler.on_trial_add(None, t3)
        scheduler.on_trial_complete(None, t3, result(10, 1000))
        self.assertEqual(
            scheduler.on_trial_result(None, t2, result(101, 0)),
            TrialScheduler.STOP)

    def testAsyncHBGracePeriod(self):
        scheduler = AsyncHyperBandScheduler(
            grace_period=2.5, reduction_factor=3, brackets=1)
        t1, t2 = self.basicSetup(scheduler)
        scheduler.on_trial_complete(None, t1, result(10, 1000))
        scheduler.on_trial_complete(None, t2, result(10, 1000))
        t3 = Trial("PPO")
        scheduler.on_trial_add(None, t3)
        self.assertEqual(
            scheduler.on_trial_result(None, t3, result(1, 10)),
            TrialScheduler.CONTINUE)
        self.assertEqual(
            scheduler.on_trial_result(None, t3, result(2, 10)),
            TrialScheduler.CONTINUE)
        self.assertEqual(
            scheduler.on_trial_result(None, t3, result(3, 10)),
            TrialScheduler.STOP)

    def testAsyncHBAllCompletes(self):
        scheduler = AsyncHyperBandScheduler(max_t=10, brackets=10)
        trials = [Trial("PPO") for i in range(10)]
        for t in trials:
            scheduler.on_trial_add(None, t)
        for t in trials:
            self.assertEqual(
                scheduler.on_trial_result(None, t, result(10, -2)),
                TrialScheduler.STOP)

    def testAsyncHBUsesPercentile(self):
        scheduler = AsyncHyperBandScheduler(
            grace_period=1, max_t=10, reduction_factor=2, brackets=1)
        t1, t2 = self.basicSetup(scheduler)
        scheduler.on_trial_complete(None, t1, result(10, 1000))
        scheduler.on_trial_complete(None, t2, result(10, 1000))
        t3 = Trial("PPO")
        scheduler.on_trial_add(None, t3)
        self.assertEqual(
            scheduler.on_trial_result(None, t3, result(1, 260)),
            TrialScheduler.STOP)
        self.assertEqual(
            scheduler.on_trial_result(None, t3, result(2, 260)),
            TrialScheduler.STOP)

    def testAlternateMetrics(self):
        def result2(t, rew):
            return TrainingResult(training_iteration=t, neg_mean_loss=rew)

        scheduler = AsyncHyperBandScheduler(
            grace_period=1, time_attr='training_iteration',
            reward_attr='neg_mean_loss', brackets=1)
        t1 = Trial("PPO")  # mean is 450, max 900, t_max=10
        t2 = Trial("PPO")  # mean is 450, max 450, t_max=5
        scheduler.on_trial_add(None, t1)
        scheduler.on_trial_add(None, t2)
        for i in range(10):
            self.assertEqual(
                scheduler.on_trial_result(None, t1, result2(i, i * 100)),
                TrialScheduler.CONTINUE)
        for i in range(5):
            self.assertEqual(
                scheduler.on_trial_result(None, t2, result2(i, 450)),
                TrialScheduler.CONTINUE)
        scheduler.on_trial_complete(None, t1, result2(10, 1000))
        self.assertEqual(
            scheduler.on_trial_result(None, t2, result2(5, 450)),
            TrialScheduler.CONTINUE)
        self.assertEqual(
            scheduler.on_trial_result(None, t2, result2(6, 0)),
            TrialScheduler.CONTINUE)


if __name__ == "__main__":
    unittest.main(verbosity=2)
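
To see why result(1, 260) is stopped in testAsyncHBUsesPercentile, the arithmetic can be checked by hand (this assumes the elided result() helper, defined earlier in the test file, fills the scheduler's time attribute with its first argument):

import numpy as np

# After basicSetup, the rung at milestone 1 (grace_period=1) holds the
# recorded rewards {t1: 100, t2: 450}. With reduction_factor=2 the
# cutoff is the 50th percentile of those rewards:
print(np.percentile([100, 450], 50))  # 275.0
# t3 reports reward 260 < 275.0 at its first rung and is therefore
# stopped.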

View file

@@ -6,6 +6,7 @@ import time
from ray.tune import TuneError
from ray.tune.hyperband import HyperBandScheduler
from ray.tune.async_hyperband import AsyncHyperBandScheduler
from ray.tune.median_stopping_rule import MedianStoppingRule
from ray.tune.trial import Trial, DEBUG_PRINT_INTERVAL
from ray.tune.log_sync import wait_for_log_sync
@@ -19,6 +20,7 @@ _SCHEDULERS = {
"FIFO": FIFOScheduler,
"MedianStopping": MedianStoppingRule,
"HyperBand": HyperBandScheduler,
"AsyncHyperBand": AsyncHyperBandScheduler,
}
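
With this registry entry in place, the string passed via --scheduler resolves to the class; a hedged sketch of the lookup (the constructor kwargs are illustrative, and how the runner actually wires this up is not shown in this diff):

scheduler_cls = _SCHEDULERS["AsyncHyperBand"]
scheduler = scheduler_cls(grace_period=5, max_t=100)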

View file

@@ -202,6 +202,10 @@ docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
    python /ray/python/ray/tune/examples/hyperband_example.py \
    --smoke-test
docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
    python /ray/python/ray/tune/examples/async_hyperband_example.py \
    --smoke-test
docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
    python /ray/python/ray/tune/examples/tune_mnist_ray_hyperband.py \
    --smoke-test