[tune] Implement median stopping rule (#1170)

* trial scheduler interface

* remove

* wip median stopping

* remove

* median stopping rule

* update

* docs

* update

* Revrt

* update

* comments

* fix tesT
This commit is contained in:
Eric Liang 2017-11-03 11:25:02 -07:00 committed by Richard Liaw
parent fdf069bd1d
commit d06beacd84
5 changed files with 225 additions and 1 deletions

View file

@ -119,6 +119,7 @@ script:
- python test/recursion_test.py - python test/recursion_test.py
- python test/monitor_test.py - python test/monitor_test.py
- python test/trial_runner_test.py - python test/trial_runner_test.py
- python test/trial_scheduler_test.py
- python -m pytest python/ray/rllib/test/test_catalog.py - python -m pytest python/ray/rllib/test/test_catalog.py

View file

@ -23,6 +23,7 @@ import yaml
import ray import ray
from ray.tune.config_parser import make_parser, parse_to_trials from ray.tune.config_parser import make_parser, parse_to_trials
from ray.tune.trial_scheduler import MedianStoppingRule
from ray.tune.trial_runner import TrialRunner from ray.tune.trial_runner import TrialRunner
from ray.tune.trial import Trial from ray.tune.trial import Trial
@ -46,7 +47,7 @@ parser.add_argument("-f", "--config-file", default=None, type=str,
def main(argv): def main(argv):
args = parser.parse_args(argv) args = parser.parse_args(argv)
runner = TrialRunner() runner = TrialRunner(MedianStoppingRule())
if args.config_file: if args.config_file:
with open(args.config_file) as f: with open(args.config_file) as f:

View file

@ -148,6 +148,7 @@ class TrialRunner(object):
trial.last_result = result trial.last_result = result
if trial.should_stop(result): if trial.should_stop(result):
self._scheduler_alg.on_trial_complete(self, trial, result)
self._stop_trial(trial) self._stop_trial(trial)
else: else:
decision = self._scheduler_alg.on_trial_result( decision = self._scheduler_alg.on_trial_result(

View file

@ -1,6 +1,9 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
import collections
import numpy as np
from ray.tune.trial import Trial from ray.tune.trial import Trial
@ -17,6 +20,13 @@ class TrialScheduler(object):
raise NotImplementedError raise NotImplementedError
def on_trial_complete(self, trial_runner, trial, result):
"""Notification for the completion of trial.
This will only be called when the trial completes naturally."""
raise NotImplementedError
def choose_trial_to_run(self, trial_runner, trials): def choose_trial_to_run(self, trial_runner, trials):
"""Called to choose a new trial to run. """Called to choose a new trial to run.
@ -32,9 +42,14 @@ class TrialScheduler(object):
class FIFOScheduler(TrialScheduler): class FIFOScheduler(TrialScheduler):
"""Simple scheduler that just runs trials in submission order."""
def on_trial_result(self, trial_runner, trial, result): def on_trial_result(self, trial_runner, trial, result):
return TrialScheduler.CONTINUE return TrialScheduler.CONTINUE
def on_trial_complete(self, trial_runner, trial, result):
pass
def choose_trial_to_run(self, trial_runner): def choose_trial_to_run(self, trial_runner):
for trial in trial_runner.get_trials(): for trial in trial_runner.get_trials():
if (trial.status == Trial.PENDING and if (trial.status == Trial.PENDING and
@ -44,3 +59,85 @@ class FIFOScheduler(TrialScheduler):
def debug_string(self): def debug_string(self):
return "Using FIFO scheduling algorithm." return "Using FIFO scheduling algorithm."
# TODO(ekl) expose this in the command line API
class MedianStoppingRule(FIFOScheduler):
"""Implements the median stopping rule as described in the Vizier paper:
https://research.google.com/pubs/pub46180.html
Args:
time_attr (str): The TrainingResult attr to use for comparing time.
Note that you can pass in something non-temporal such as
`training_iteration` as a measure of progress, the only requirement
is that the attribute should increase monotonically.
reward_attr (str): The TrainingResult objective value attribute. As
with `time_attr`, this may refer to any objective value that
is supposed to increase with time.
grace_period (float): Only stop trials at least this old in time.
The units are the same as the attribute named by `time_attr`.
min_samples_required (int): Min samples to compute median over.
"""
def __init__(
self, time_attr='time_total_s', reward_attr='episode_reward_mean',
grace_period=60.0, min_samples_required=3):
FIFOScheduler.__init__(self)
self._completed_trials = set()
self._results = collections.defaultdict(list)
self._grace_period = grace_period
self._min_samples_required = min_samples_required
self._reward_attr = reward_attr
self._time_attr = time_attr
self._num_stopped = 0
def on_trial_result(self, trial_runner, trial, result):
"""Callback for early stopping.
This stopping rule stops a running trial if the trial's best objective
value by step `t` is strictly worse than the median of the running
averages of all completed trials' objectives reported up to step `t`.
"""
time = getattr(result, self._time_attr)
self._results[trial].append(result)
median_result = self._get_median_result(time)
best_result = self._best_result(trial)
print("Trial {} best res={} vs median res={} at t={}".format(
trial, best_result, median_result, time))
if best_result < median_result and time > self._grace_period:
print("MedianStoppingRule: early stopping {}".format(trial))
self._num_stopped += 1
return TrialScheduler.STOP
else:
return TrialScheduler.CONTINUE
def on_trial_complete(self, trial_runner, trial, result):
self._results[trial].append(result)
self._completed_trials.add(trial)
def debug_string(self):
return "Using MedianStoppingRule: num_stopped={}.".format(
self._num_stopped)
def _get_median_result(self, time):
scores = []
for trial in self._completed_trials:
scores.append(self._running_result(trial, time))
if len(scores) >= self._min_samples_required:
return np.median(scores)
else:
return float('-inf')
def _running_result(self, trial, t_max=float('inf')):
results = self._results[trial]
# TODO(ekl) we could do interpolation to be more precise, but for now
# assume len(results) is large and the time diffs are roughly equal
return np.mean(
[getattr(r, self._reward_attr)
for r in results if getattr(r, self._time_attr) <= t_max])
def _best_result(self, trial):
results = self._results[trial]
return max([getattr(r, self._reward_attr) for r in results])

View file

@ -0,0 +1,124 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
from ray.tune.result import TrainingResult
from ray.tune.trial import Trial
from ray.tune.trial_scheduler import MedianStoppingRule, TrialScheduler
def result(t, rew):
return TrainingResult(time_total_s=t, episode_reward_mean=rew)
class EarlyStoppingSuite(unittest.TestCase):
def basicSetup(self, rule):
t1 = Trial("t1", "PPO") # mean is 450, max 900, t_max=10
t2 = Trial("t2", "PPO") # mean is 450, max 450, t_max=5
for i in range(10):
self.assertEqual(
rule.on_trial_result(None, t1, result(i, i * 100)),
TrialScheduler.CONTINUE)
for i in range(5):
self.assertEqual(
rule.on_trial_result(None, t2, result(i, 450)),
TrialScheduler.CONTINUE)
return t1, t2
def testMedianStoppingConstantPerf(self):
rule = MedianStoppingRule(grace_period=0, min_samples_required=1)
t1, t2 = self.basicSetup(rule)
rule.on_trial_complete(None, t1, result(10, 1000))
self.assertEqual(
rule.on_trial_result(None, t2, result(5, 450)),
TrialScheduler.CONTINUE)
self.assertEqual(
rule.on_trial_result(None, t2, result(6, 0)),
TrialScheduler.CONTINUE)
self.assertEqual(
rule.on_trial_result(None, t2, result(10, 450)),
TrialScheduler.STOP)
def testMedianStoppingOnCompleteOnly(self):
rule = MedianStoppingRule(grace_period=0, min_samples_required=1)
t1, t2 = self.basicSetup(rule)
self.assertEqual(
rule.on_trial_result(None, t2, result(100, 0)),
TrialScheduler.CONTINUE)
rule.on_trial_complete(None, t1, result(10, 1000))
self.assertEqual(
rule.on_trial_result(None, t2, result(101, 0)),
TrialScheduler.STOP)
def testMedianStoppingGracePeriod(self):
rule = MedianStoppingRule(grace_period=2.5, min_samples_required=1)
t1, t2 = self.basicSetup(rule)
rule.on_trial_complete(None, t1, result(10, 1000))
rule.on_trial_complete(None, t2, result(10, 1000))
t3 = Trial("t3", "PPO")
self.assertEqual(
rule.on_trial_result(None, t3, result(1, 10)),
TrialScheduler.CONTINUE)
self.assertEqual(
rule.on_trial_result(None, t3, result(2, 10)),
TrialScheduler.CONTINUE)
self.assertEqual(
rule.on_trial_result(None, t3, result(3, 10)),
TrialScheduler.STOP)
def testMedianStoppingMinSamples(self):
rule = MedianStoppingRule(grace_period=0, min_samples_required=2)
t1, t2 = self.basicSetup(rule)
rule.on_trial_complete(None, t1, result(10, 1000))
t3 = Trial("t3", "PPO")
self.assertEqual(
rule.on_trial_result(None, t3, result(3, 10)),
TrialScheduler.CONTINUE)
rule.on_trial_complete(None, t2, result(10, 1000))
self.assertEqual(
rule.on_trial_result(None, t3, result(3, 10)),
TrialScheduler.STOP)
def testMedianStoppingUsesMedian(self):
rule = MedianStoppingRule(grace_period=0, min_samples_required=1)
t1, t2 = self.basicSetup(rule)
rule.on_trial_complete(None, t1, result(10, 1000))
rule.on_trial_complete(None, t2, result(10, 1000))
t3 = Trial("t3", "PPO")
self.assertEqual(
rule.on_trial_result(None, t3, result(1, 260)),
TrialScheduler.CONTINUE)
self.assertEqual(
rule.on_trial_result(None, t3, result(2, 260)),
TrialScheduler.STOP)
def testAlternateMetrics(self):
def result2(t, rew):
return TrainingResult(training_iteration=t, neg_mean_loss=rew)
rule = MedianStoppingRule(
grace_period=0, min_samples_required=1,
time_attr='training_iteration', reward_attr='neg_mean_loss')
t1 = Trial("t1", "PPO") # mean is 450, max 900, t_max=10
t2 = Trial("t2", "PPO") # mean is 450, max 450, t_max=5
for i in range(10):
self.assertEqual(
rule.on_trial_result(None, t1, result2(i, i * 100)),
TrialScheduler.CONTINUE)
for i in range(5):
self.assertEqual(
rule.on_trial_result(None, t2, result2(i, 450)),
TrialScheduler.CONTINUE)
rule.on_trial_complete(None, t1, result2(10, 1000))
self.assertEqual(
rule.on_trial_result(None, t2, result2(5, 450)),
TrialScheduler.CONTINUE)
self.assertEqual(
rule.on_trial_result(None, t2, result2(6, 0)),
TrialScheduler.CONTINUE)
if __name__ == "__main__":
unittest.main(verbosity=2)