diff --git a/.travis.yml b/.travis.yml
index 95659fb27..1999f08c9 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -119,6 +119,7 @@ script:
   - python test/recursion_test.py
   - python test/monitor_test.py
   - python test/trial_runner_test.py
+  - python test/trial_scheduler_test.py
   - python -m pytest python/ray/rllib/test/test_catalog.py
 
 
diff --git a/python/ray/rllib/train.py b/python/ray/rllib/train.py
index e32696760..21837bb94 100755
--- a/python/ray/rllib/train.py
+++ b/python/ray/rllib/train.py
@@ -23,6 +23,7 @@ import yaml
 
 import ray
 from ray.tune.config_parser import make_parser, parse_to_trials
+from ray.tune.trial_scheduler import MedianStoppingRule
 from ray.tune.trial_runner import TrialRunner
 from ray.tune.trial import Trial
 
@@ -46,7 +47,7 @@ parser.add_argument("-f", "--config-file", default=None, type=str,
 
 def main(argv):
     args = parser.parse_args(argv)
-    runner = TrialRunner()
+    runner = TrialRunner(MedianStoppingRule())
 
     if args.config_file:
         with open(args.config_file) as f:
diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py
index 195f153da..0e4f14aff 100644
--- a/python/ray/tune/trial_runner.py
+++ b/python/ray/tune/trial_runner.py
@@ -148,6 +148,7 @@ class TrialRunner(object):
         trial.last_result = result
 
         if trial.should_stop(result):
+            self._scheduler_alg.on_trial_complete(self, trial, result)
             self._stop_trial(trial)
         else:
             decision = self._scheduler_alg.on_trial_result(
diff --git a/python/ray/tune/trial_scheduler.py b/python/ray/tune/trial_scheduler.py
index b86e101a8..e8401440c 100644
--- a/python/ray/tune/trial_scheduler.py
+++ b/python/ray/tune/trial_scheduler.py
@@ -1,6 +1,9 @@
 from __future__ import absolute_import
 from __future__ import division
 
+import collections
+import numpy as np
+
 from ray.tune.trial import Trial
 
 
@@ -17,6 +20,13 @@ class TrialScheduler(object):
 
         raise NotImplementedError
 
+    def on_trial_complete(self, trial_runner, trial, result):
+        """Notification for the completion of trial.
+
+        This will only be called when the trial completes naturally."""
+
+        raise NotImplementedError
+
     def choose_trial_to_run(self, trial_runner, trials):
         """Called to choose a new trial to run.
 
@@ -32,9 +42,14 @@
 
 
 class FIFOScheduler(TrialScheduler):
+    """Simple scheduler that just runs trials in submission order."""
+
     def on_trial_result(self, trial_runner, trial, result):
         return TrialScheduler.CONTINUE
 
+    def on_trial_complete(self, trial_runner, trial, result):
+        pass
+
     def choose_trial_to_run(self, trial_runner):
         for trial in trial_runner.get_trials():
             if (trial.status == Trial.PENDING and
@@ -44,3 +59,85 @@ class FIFOScheduler(TrialScheduler):
 
     def debug_string(self):
         return "Using FIFO scheduling algorithm."
+
+
+# TODO(ekl) expose this in the command line API
+class MedianStoppingRule(FIFOScheduler):
+    """Implements the median stopping rule as described in the Vizier paper:
+
+    https://research.google.com/pubs/pub46180.html
+
+    Args:
+        time_attr (str): The TrainingResult attr to use for comparing time.
+            Note that you can pass in something non-temporal such as
+            `training_iteration` as a measure of progress, the only requirement
+            is that the attribute should increase monotonically.
+        reward_attr (str): The TrainingResult objective value attribute. As
+            with `time_attr`, this may refer to any objective value that
+            is supposed to increase with time.
+        grace_period (float): Only stop trials at least this old in time.
+            The units are the same as the attribute named by `time_attr`.
+        min_samples_required (int): Min samples to compute median over.
+    """
+
+    def __init__(
+            self, time_attr='time_total_s', reward_attr='episode_reward_mean',
+            grace_period=60.0, min_samples_required=3):
+        FIFOScheduler.__init__(self)
+        self._completed_trials = set()
+        self._results = collections.defaultdict(list)
+        self._grace_period = grace_period
+        self._min_samples_required = min_samples_required
+        self._reward_attr = reward_attr
+        self._time_attr = time_attr
+        self._num_stopped = 0
+
+    def on_trial_result(self, trial_runner, trial, result):
+        """Callback for early stopping.
+
+        This stopping rule stops a running trial if the trial's best objective
+        value by step `t` is strictly worse than the median of the running
+        averages of all completed trials' objectives reported up to step `t`.
+        """
+
+        time = getattr(result, self._time_attr)
+        self._results[trial].append(result)
+        median_result = self._get_median_result(time)
+        best_result = self._best_result(trial)
+        print("Trial {} best res={} vs median res={} at t={}".format(
+            trial, best_result, median_result, time))
+        if best_result < median_result and time > self._grace_period:
+            print("MedianStoppingRule: early stopping {}".format(trial))
+            self._num_stopped += 1
+            return TrialScheduler.STOP
+        else:
+            return TrialScheduler.CONTINUE
+
+    def on_trial_complete(self, trial_runner, trial, result):
+        self._results[trial].append(result)
+        self._completed_trials.add(trial)
+
+    def debug_string(self):
+        return "Using MedianStoppingRule: num_stopped={}.".format(
+            self._num_stopped)
+
+    def _get_median_result(self, time):
+        scores = []
+        for trial in self._completed_trials:
+            scores.append(self._running_result(trial, time))
+        if len(scores) >= self._min_samples_required:
+            return np.median(scores)
+        else:
+            return float('-inf')
+
+    def _running_result(self, trial, t_max=float('inf')):
+        results = self._results[trial]
+        # TODO(ekl) we could do interpolation to be more precise, but for now
+        # assume len(results) is large and the time diffs are roughly equal
+        return np.mean(
+            [getattr(r, self._reward_attr)
+             for r in results if getattr(r, self._time_attr) <= t_max])
+
+    def _best_result(self, trial):
+        results = self._results[trial]
+        return max([getattr(r, self._reward_attr) for r in results])
diff --git a/test/trial_scheduler_test.py b/test/trial_scheduler_test.py
new file mode 100644
index 000000000..c1767d3b7
--- /dev/null
+++ b/test/trial_scheduler_test.py
@@ -0,0 +1,124 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import unittest
+
+from ray.tune.result import TrainingResult
+from ray.tune.trial import Trial
+from ray.tune.trial_scheduler import MedianStoppingRule, TrialScheduler
+
+
+def result(t, rew):
+    return TrainingResult(time_total_s=t, episode_reward_mean=rew)
+
+
+class EarlyStoppingSuite(unittest.TestCase):
+    def basicSetup(self, rule):
+        t1 = Trial("t1", "PPO")  # mean is 450, max 900, t_max=10
+        t2 = Trial("t2", "PPO")  # mean is 450, max 450, t_max=5
+        for i in range(10):
+            self.assertEqual(
+                rule.on_trial_result(None, t1, result(i, i * 100)),
+                TrialScheduler.CONTINUE)
+        for i in range(5):
+            self.assertEqual(
+                rule.on_trial_result(None, t2, result(i, 450)),
+                TrialScheduler.CONTINUE)
+        return t1, t2
+
+    def testMedianStoppingConstantPerf(self):
+        rule = MedianStoppingRule(grace_period=0, min_samples_required=1)
+        t1, t2 = self.basicSetup(rule)
+        rule.on_trial_complete(None, t1, result(10, 1000))
+        self.assertEqual(
+            rule.on_trial_result(None, t2, result(5, 450)),
+            TrialScheduler.CONTINUE)
+        self.assertEqual(
+            rule.on_trial_result(None, t2, result(6, 0)),
+            TrialScheduler.CONTINUE)
+        self.assertEqual(
+            rule.on_trial_result(None, t2, result(10, 450)),
+            TrialScheduler.STOP)
+
+    def testMedianStoppingOnCompleteOnly(self):
+        rule = MedianStoppingRule(grace_period=0, min_samples_required=1)
+        t1, t2 = self.basicSetup(rule)
+        self.assertEqual(
+            rule.on_trial_result(None, t2, result(100, 0)),
+            TrialScheduler.CONTINUE)
+        rule.on_trial_complete(None, t1, result(10, 1000))
+        self.assertEqual(
+            rule.on_trial_result(None, t2, result(101, 0)),
+            TrialScheduler.STOP)
+
+    def testMedianStoppingGracePeriod(self):
+        rule = MedianStoppingRule(grace_period=2.5, min_samples_required=1)
+        t1, t2 = self.basicSetup(rule)
+        rule.on_trial_complete(None, t1, result(10, 1000))
+        rule.on_trial_complete(None, t2, result(10, 1000))
+        t3 = Trial("t3", "PPO")
+        self.assertEqual(
+            rule.on_trial_result(None, t3, result(1, 10)),
+            TrialScheduler.CONTINUE)
+        self.assertEqual(
+            rule.on_trial_result(None, t3, result(2, 10)),
+            TrialScheduler.CONTINUE)
+        self.assertEqual(
+            rule.on_trial_result(None, t3, result(3, 10)),
+            TrialScheduler.STOP)
+
+    def testMedianStoppingMinSamples(self):
+        rule = MedianStoppingRule(grace_period=0, min_samples_required=2)
+        t1, t2 = self.basicSetup(rule)
+        rule.on_trial_complete(None, t1, result(10, 1000))
+        t3 = Trial("t3", "PPO")
+        self.assertEqual(
+            rule.on_trial_result(None, t3, result(3, 10)),
+            TrialScheduler.CONTINUE)
+        rule.on_trial_complete(None, t2, result(10, 1000))
+        self.assertEqual(
+            rule.on_trial_result(None, t3, result(3, 10)),
+            TrialScheduler.STOP)
+
+    def testMedianStoppingUsesMedian(self):
+        rule = MedianStoppingRule(grace_period=0, min_samples_required=1)
+        t1, t2 = self.basicSetup(rule)
+        rule.on_trial_complete(None, t1, result(10, 1000))
+        rule.on_trial_complete(None, t2, result(10, 1000))
+        t3 = Trial("t3", "PPO")
+        self.assertEqual(
+            rule.on_trial_result(None, t3, result(1, 260)),
+            TrialScheduler.CONTINUE)
+        self.assertEqual(
+            rule.on_trial_result(None, t3, result(2, 260)),
+            TrialScheduler.STOP)
+
+    def testAlternateMetrics(self):
+        def result2(t, rew):
+            return TrainingResult(training_iteration=t, neg_mean_loss=rew)
+
+        rule = MedianStoppingRule(
+            grace_period=0, min_samples_required=1,
+            time_attr='training_iteration', reward_attr='neg_mean_loss')
+        t1 = Trial("t1", "PPO")  # mean is 450, max 900, t_max=10
+        t2 = Trial("t2", "PPO")  # mean is 450, max 450, t_max=5
+        for i in range(10):
+            self.assertEqual(
+                rule.on_trial_result(None, t1, result2(i, i * 100)),
+                TrialScheduler.CONTINUE)
+        for i in range(5):
+            self.assertEqual(
+                rule.on_trial_result(None, t2, result2(i, 450)),
+                TrialScheduler.CONTINUE)
+        rule.on_trial_complete(None, t1, result2(10, 1000))
+        self.assertEqual(
+            rule.on_trial_result(None, t2, result2(5, 450)),
+            TrialScheduler.CONTINUE)
+        self.assertEqual(
+            rule.on_trial_result(None, t2, result2(6, 0)),
+            TrialScheduler.CONTINUE)
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)