From 78716094b580fe58e2432eeb62f620d03bd0bd7d Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Sun, 4 Mar 2018 14:05:56 -0800 Subject: [PATCH] [tune] Async Hyperband (#1595) --- doc/source/tune.rst | 7 +- python/ray/tune/async_hyperband.py | 148 ++++++++++++++++++ python/ray/tune/config_parser.py | 2 +- .../tune/examples/async_hyperband_example.py | 77 +++++++++ python/ray/tune/test/trial_scheduler_test.py | 101 ++++++++++++ python/ray/tune/tune.py | 2 + test/jenkins_tests/run_multi_node_tests.sh | 4 + 7 files changed, 338 insertions(+), 3 deletions(-) create mode 100644 python/ray/tune/async_hyperband.py create mode 100644 python/ray/tune/examples/async_hyperband_example.py diff --git a/doc/source/tune.rst b/doc/source/tune.rst index 04ad02d1f..0d50bd5d6 100644 --- a/doc/source/tune.rst +++ b/doc/source/tune.rst @@ -146,8 +146,6 @@ To reduce costs, long-running trials can often be early stopped if their initial An example of this can be found in `hyperband_example.py `__. The progress of one such HyperBand run is shown below. -Note that some trial schedulers such as HyperBand and PBT require your Trainable to support checkpointing, which is described in the next section. Checkpointing enables the scheduler to multiplex many concurrent trials onto a limited size cluster. - :: == Status == @@ -180,10 +178,15 @@ Note that some trial schedulers such as HyperBand and PBT require your Trainable - my_class_31_height=40,width=10: RUNNING - my_class_53_height=28,width=96: RUNNING +Ray Tune also implements an `asynchronous version of HyperBand `__, providing better parallelism and avoids straggler issues during eliminations. An example of this can be found in `async_hyperband_example.py `__. We recommend using this over the vanilla HyperBand scheduler. + +.. note:: Some trial schedulers such as HyperBand and PBT require your Trainable to support checkpointing, which is described in the next section. Checkpointing enables the scheduler to multiplex many concurrent trials onto a limited size cluster. + Currently we support the following early stopping algorithms, or you can write your own that implements the `TrialScheduler `__ interface. .. autoclass:: ray.tune.median_stopping_rule.MedianStoppingRule .. autoclass:: ray.tune.hyperband.HyperBandScheduler +.. autoclass:: ray.tune.async_hyperband.AsyncHyperBandScheduler Population Based Training ------------------------- diff --git a/python/ray/tune/async_hyperband.py b/python/ray/tune/async_hyperband.py new file mode 100644 index 000000000..28b9a4ff8 --- /dev/null +++ b/python/ray/tune/async_hyperband.py @@ -0,0 +1,148 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler + + +class AsyncHyperBandScheduler(FIFOScheduler): + """Implements the Async Successive Halving. + + This should provide similar theoretical performance as HyperBand but + avoid straggler issues that HyperBand faces. One implementation detail + is when using multiple brackets, trial allocation to bracket is done + randomly with over a softmax probability. + + See https://openreview.net/forum?id=S1Y7OOlRZ + + Args: + time_attr (str): The TrainingResult attr to use for comparing time. + Note that you can pass in something non-temporal such as + `training_iteration` as a measure of progress, the only requirement + is that the attribute should increase monotonically. + reward_attr (str): The TrainingResult objective value attribute. As + with `time_attr`, this may refer to any objective value. Stopping + procedures will use this attribute. + max_t (float): max time units per trial. Trials will be stopped after + max_t time units (determined by time_attr) have passed. + grace_period (float): Only stop trials at least this old in time. + The units are the same as the attribute named by `time_attr`. + reduction_factor (float): Used to set halving rate and amount. This + is simply a unit-less scalar. + brackets (int): Number of brackets. Each bracket has a different + halving rate, specified by the reduction factor. + """ + + def __init__( + self, time_attr='training_iteration', + reward_attr='episode_reward_mean', max_t=100, + grace_period=10, reduction_factor=3, brackets=3): + assert max_t > 0, "Max (time_attr) not valid!" + assert max_t >= grace_period, "grace_period must be <= max_t!" + assert grace_period > 0, "grace_period must be positive!" + assert reduction_factor > 1, "Reduction Factor not valid!" + assert brackets > 0, "brackets must be positive!" + FIFOScheduler.__init__(self) + self._reduction_factor = reduction_factor + self._max_t = max_t + + self._trial_info = {} # Stores Trial -> Bracket + + # Tracks state for new trial add + self._brackets = [_Bracket( + grace_period, max_t, reduction_factor, s) for s in range(brackets)] + self._counter = 0 # for + self._num_stopped = 0 + self._reward_attr = reward_attr + self._time_attr = time_attr + + def on_trial_add(self, trial_runner, trial): + sizes = np.array([len(b._rungs) for b in self._brackets]) + probs = np.e ** (sizes - sizes.max()) + normalized = probs / probs.sum() + idx = np.random.choice(len(self._brackets), p=normalized) + self._trial_info[trial.trial_id] = self._brackets[idx] + + def on_trial_result(self, trial_runner, trial, result): + if getattr(result, self._time_attr) >= self._max_t: + self._num_stopped += 1 + return TrialScheduler.STOP + + bracket = self._trial_info[trial.trial_id] + action = bracket.on_result( + trial, + getattr(result, self._time_attr), + getattr(result, self._reward_attr)) + return action + + def on_trial_complete(self, trial_runner, trial, result): + bracket = self._trial_info[trial.trial_id] + bracket.on_result( + trial, + getattr(result, self._time_attr), + getattr(result, self._reward_attr)) + del self._trial_info[trial.trial_id] + + def on_trial_remove(self, trial_runner, trial): + del self._trial_info[trial.trial_id] + + def debug_string(self): + out = "Using AsyncHyperBand: num_stopped={}".format( + self._num_stopped) + out += "\n" + "\n".join([b.debug_str() for b in self._brackets]) + return out + + +class _Bracket(): + """Bookkeeping system to track the cutoffs. + + Rungs are created in reversed order so that we can more easily find + the correct rung corresponding to the current iteration of the result. + + Example: + >>> b = _Bracket(1, 10, 2, 3) + >>> b.on_result(trial1, 1, 2) # CONTINUE + >>> b.on_result(trial2, 1, 4) # CONTINUE + >>> b.cutoff(b._rungs[-1][1]) == 3.0 # rungs are reversed + >>> b.on_result(trial3, 1, 1) # STOP + >>> b.cutoff(b._rungs[0][1]) == 2.0 + """ + def __init__(self, min_t, max_t, reduction_factor, s): + self.rf = reduction_factor + MAX_RUNGS = int(np.log(max_t / min_t) / np.log(self.rf) - s + 1) + self._rungs = [(min_t * self.rf**(k + s), {}) + for k in reversed(range(MAX_RUNGS))] + + def cutoff(self, recorded): + if not recorded: + return None + return np.percentile(list(recorded.values()), (1 - 1 / self.rf) * 100) + + def on_result(self, trial, cur_iter, cur_rew): + action = TrialScheduler.CONTINUE + for milestone, recorded in self._rungs: + if cur_iter < milestone or trial.trial_id in recorded: + continue + else: + cutoff = self.cutoff(recorded) + if cutoff is not None and cur_rew < cutoff: + action = TrialScheduler.STOP + recorded[trial.trial_id] = cur_rew + break + return action + + def debug_str(self): + iters = " | ".join( + ["Iter {:.3f}: {}".format(milestone, self.cutoff(recorded)) + for milestone, recorded in self._rungs]) + return "Bracket: " + iters + + +if __name__ == '__main__': + sched = AsyncHyperBandScheduler( + grace_period=1, max_t=10, reduction_factor=2) + print(sched.debug_string()) + bracket = sched._brackets[0] + print(bracket.cutoff({str(i): i for i in range(20)})) diff --git a/python/ray/tune/config_parser.py b/python/ray/tune/config_parser.py index f8f086527..ae7ab35c0 100644 --- a/python/ray/tune/config_parser.py +++ b/python/ray/tune/config_parser.py @@ -80,7 +80,7 @@ def make_parser(**kwargs): "many times. Only applies if checkpointing is enabled.") parser.add_argument( "--scheduler", default="FIFO", type=str, - help="FIFO (default), MedianStopping, or HyperBand.") + help="FIFO (default), MedianStopping, AsyncHyperBand, or HyperBand.") parser.add_argument( "--scheduler-config", default="{}", type=json.loads, help="Config options to pass to the scheduler.") diff --git a/python/ray/tune/examples/async_hyperband_example.py b/python/ray/tune/examples/async_hyperband_example.py new file mode 100644 index 000000000..f6abe27f8 --- /dev/null +++ b/python/ray/tune/examples/async_hyperband_example.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import json +import os +import random + +import numpy as np + +import ray +from ray.tune import Trainable, TrainingResult, register_trainable, \ + run_experiments +from ray.tune.async_hyperband import AsyncHyperBandScheduler + + +class MyTrainableClass(Trainable): + """Example agent whose learning curve is a random sigmoid. + + The dummy hyperparameters "width" and "height" determine the slope and + maximum reward value reached. + """ + + def _setup(self): + self.timestep = 0 + + def _train(self): + self.timestep += 1 + v = np.tanh(float(self.timestep) / self.config["width"]) + v *= self.config["height"] + + # Here we use `episode_reward_mean`, but you can also report other + # objectives such as loss or accuracy (see tune/result.py). + return TrainingResult(episode_reward_mean=v, timesteps_this_iter=1) + + def _save(self, checkpoint_dir): + path = os.path.join(checkpoint_dir, "checkpoint") + with open(path, "w") as f: + f.write(json.dumps({"timestep": self.timestep})) + return path + + def _restore(self, checkpoint_path): + with open(checkpoint_path) as f: + self.timestep = json.loads(f.read())["timestep"] + + +register_trainable("my_class", MyTrainableClass) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--smoke-test", action="store_true", help="Finish quickly for testing") + args, _ = parser.parse_known_args() + ray.init() + + # asynchronous hyperband early stopping, configured with + # `episode_reward_mean` as the + # objective and `timesteps_total` as the time unit. + ahb = AsyncHyperBandScheduler( + time_attr="timesteps_total", reward_attr="episode_reward_mean", + grace_period=5, max_t=100) + + run_experiments({ + "asynchyperband_test": { + "run": "my_class", + "stop": {"training_iteration": 1 if args.smoke_test else 99999}, + "repeat": 20, + "resources": {"cpu": 1, "gpu": 0}, + "config": { + "width": lambda spec: 10 + int(90 * random.random()), + "height": lambda spec: int(100 * random.random()), + }, + } + }, scheduler=ahb) diff --git a/python/ray/tune/test/trial_scheduler_test.py b/python/ray/tune/test/trial_scheduler_test.py index bdc8d1a2d..5dc6d6af7 100644 --- a/python/ray/tune/test/trial_scheduler_test.py +++ b/python/ray/tune/test/trial_scheduler_test.py @@ -8,6 +8,7 @@ import unittest import numpy as np from ray.tune.hyperband import HyperBandScheduler +from ray.tune.async_hyperband import AsyncHyperBandScheduler from ray.tune.pbt import PopulationBasedTraining, explore from ray.tune.median_stopping_rule import MedianStoppingRule from ray.tune.result import TrainingResult @@ -757,5 +758,105 @@ class PopulationBasedTestingSuite(unittest.TestCase): self.assertEqual(trials[0].config["float_factor"], 43) +class AsyncHyperBandSuite(unittest.TestCase): + def basicSetup(self, scheduler): + t1 = Trial("PPO") # mean is 450, max 900, t_max=10 + t2 = Trial("PPO") # mean is 450, max 450, t_max=5 + scheduler.on_trial_add(None, t1) + scheduler.on_trial_add(None, t2) + for i in range(10): + self.assertEqual( + scheduler.on_trial_result(None, t1, result(i, i * 100)), + TrialScheduler.CONTINUE) + for i in range(5): + self.assertEqual( + scheduler.on_trial_result(None, t2, result(i, 450)), + TrialScheduler.CONTINUE) + return t1, t2 + + def testAsyncHBOnComplete(self): + scheduler = AsyncHyperBandScheduler( + max_t=10, brackets=1) + t1, t2 = self.basicSetup(scheduler) + t3 = Trial("PPO") + scheduler.on_trial_add(None, t3) + scheduler.on_trial_complete(None, t3, result(10, 1000)) + self.assertEqual( + scheduler.on_trial_result(None, t2, result(101, 0)), + TrialScheduler.STOP) + + def testAsyncHBGracePeriod(self): + scheduler = AsyncHyperBandScheduler( + grace_period=2.5, reduction_factor=3, brackets=1) + t1, t2 = self.basicSetup(scheduler) + scheduler.on_trial_complete(None, t1, result(10, 1000)) + scheduler.on_trial_complete(None, t2, result(10, 1000)) + t3 = Trial("PPO") + scheduler.on_trial_add(None, t3) + self.assertEqual( + scheduler.on_trial_result(None, t3, result(1, 10)), + TrialScheduler.CONTINUE) + self.assertEqual( + scheduler.on_trial_result(None, t3, result(2, 10)), + TrialScheduler.CONTINUE) + self.assertEqual( + scheduler.on_trial_result(None, t3, result(3, 10)), + TrialScheduler.STOP) + + def testAsyncHBAllCompletes(self): + scheduler = AsyncHyperBandScheduler( + max_t=10, brackets=10) + trials = [Trial("PPO") for i in range(10)] + for t in trials: + scheduler.on_trial_add(None, t) + + for t in trials: + self.assertEqual( + scheduler.on_trial_result(None, t, result(10, -2)), + TrialScheduler.STOP) + + def testAsyncHBUsesPercentile(self): + scheduler = AsyncHyperBandScheduler( + grace_period=1, max_t=10, reduction_factor=2, brackets=1) + t1, t2 = self.basicSetup(scheduler) + scheduler.on_trial_complete(None, t1, result(10, 1000)) + scheduler.on_trial_complete(None, t2, result(10, 1000)) + t3 = Trial("PPO") + scheduler.on_trial_add(None, t3) + self.assertEqual( + scheduler.on_trial_result(None, t3, result(1, 260)), + TrialScheduler.STOP) + self.assertEqual( + scheduler.on_trial_result(None, t3, result(2, 260)), + TrialScheduler.STOP) + + def testAlternateMetrics(self): + def result2(t, rew): + return TrainingResult(training_iteration=t, neg_mean_loss=rew) + + scheduler = AsyncHyperBandScheduler( + grace_period=1, time_attr='training_iteration', + reward_attr='neg_mean_loss', brackets=1) + t1 = Trial("PPO") # mean is 450, max 900, t_max=10 + t2 = Trial("PPO") # mean is 450, max 450, t_max=5 + scheduler.on_trial_add(None, t1) + scheduler.on_trial_add(None, t2) + for i in range(10): + self.assertEqual( + scheduler.on_trial_result(None, t1, result2(i, i * 100)), + TrialScheduler.CONTINUE) + for i in range(5): + self.assertEqual( + scheduler.on_trial_result(None, t2, result2(i, 450)), + TrialScheduler.CONTINUE) + scheduler.on_trial_complete(None, t1, result2(10, 1000)) + self.assertEqual( + scheduler.on_trial_result(None, t2, result2(5, 450)), + TrialScheduler.CONTINUE) + self.assertEqual( + scheduler.on_trial_result(None, t2, result2(6, 0)), + TrialScheduler.CONTINUE) + + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py index 13dd96157..0683974b1 100644 --- a/python/ray/tune/tune.py +++ b/python/ray/tune/tune.py @@ -6,6 +6,7 @@ import time from ray.tune import TuneError from ray.tune.hyperband import HyperBandScheduler +from ray.tune.async_hyperband import AsyncHyperBandScheduler from ray.tune.median_stopping_rule import MedianStoppingRule from ray.tune.trial import Trial, DEBUG_PRINT_INTERVAL from ray.tune.log_sync import wait_for_log_sync @@ -19,6 +20,7 @@ _SCHEDULERS = { "FIFO": FIFOScheduler, "MedianStopping": MedianStoppingRule, "HyperBand": HyperBandScheduler, + "AsyncHyperBand": AsyncHyperBandScheduler, } diff --git a/test/jenkins_tests/run_multi_node_tests.sh b/test/jenkins_tests/run_multi_node_tests.sh index ea3287082..ce7e6a304 100755 --- a/test/jenkins_tests/run_multi_node_tests.sh +++ b/test/jenkins_tests/run_multi_node_tests.sh @@ -202,6 +202,10 @@ docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/tune/examples/hyperband_example.py \ --smoke-test +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ + python /ray/python/ray/tune/examples/async_hyperband_example.py \ + --smoke-test + docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/tune/examples/tune_mnist_ray_hyperband.py \ --smoke-test