[tune] Implement median stopping rule (#1170)
* trial scheduler interface
* remove
* wip median stopping
* remove
* median stopping rule
* update
* docs
* update
* Revert
* update
* comments
* fix test
parent fdf069bd1d
commit d06beacd84

5 changed files with 225 additions and 1 deletion
@@ -119,6 +119,7 @@ script:
 - python test/recursion_test.py
 - python test/monitor_test.py
 - python test/trial_runner_test.py
+- python test/trial_scheduler_test.py

 - python -m pytest python/ray/rllib/test/test_catalog.py

@@ -23,6 +23,7 @@ import yaml

 import ray
 from ray.tune.config_parser import make_parser, parse_to_trials
+from ray.tune.trial_scheduler import MedianStoppingRule
 from ray.tune.trial_runner import TrialRunner
 from ray.tune.trial import Trial

@@ -46,7 +47,7 @@ parser.add_argument("-f", "--config-file", default=None, type=str,

 def main(argv):
     args = parser.parse_args(argv)
-    runner = TrialRunner()
+    runner = TrialRunner(MedianStoppingRule())

     if args.config_file:
         with open(args.config_file) as f:

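The one-line change above is all the wiring required: the scheduler is simply passed to TrialRunner at construction. As a minimal sketch (not part of this diff), the same hookup with non-default parameters could look like the following; the keyword arguments mirror MedianStoppingRule.__init__ further down.

    from ray.tune.trial_runner import TrialRunner
    from ray.tune.trial_scheduler import MedianStoppingRule

    # Sketch: a median stopping rule tuned away from the defaults.
    scheduler = MedianStoppingRule(
        time_attr='time_total_s',           # attribute used as "time"
        reward_attr='episode_reward_mean',  # objective, assumed to increase
        grace_period=60.0,                  # never stop trials younger than this
        min_samples_required=3)             # completed trials needed for a median
    runner = TrialRunner(scheduler)
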
@@ -148,6 +148,7 @@ class TrialRunner(object):
         trial.last_result = result

         if trial.should_stop(result):
+            self._scheduler_alg.on_trial_complete(self, trial, result)
             self._stop_trial(trial)
         else:
             decision = self._scheduler_alg.on_trial_result(

@@ -1,6 +1,9 @@
 from __future__ import absolute_import
 from __future__ import division

+import collections
+import numpy as np
+
 from ray.tune.trial import Trial

@@ -17,6 +20,13 @@ class TrialScheduler(object):

         raise NotImplementedError

+    def on_trial_complete(self, trial_runner, trial, result):
+        """Notification for the completion of trial.
+
+        This will only be called when the trial completes naturally."""
+
+        raise NotImplementedError
+
     def choose_trial_to_run(self, trial_runner, trials):
         """Called to choose a new trial to run.

@@ -32,9 +42,14 @@ class TrialScheduler(object):


 class FIFOScheduler(TrialScheduler):
     """Simple scheduler that just runs trials in submission order."""

     def on_trial_result(self, trial_runner, trial, result):
         return TrialScheduler.CONTINUE

+    def on_trial_complete(self, trial_runner, trial, result):
+        pass
+
     def choose_trial_to_run(self, trial_runner):
         for trial in trial_runner.get_trials():
             if (trial.status == Trial.PENDING and

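Together, on_trial_result, on_trial_complete and choose_trial_to_run make up the whole scheduler interface, and FIFOScheduler supplies a usable default choose_trial_to_run. As a hypothetical illustration (not part of this commit), a custom rule can subclass FIFOScheduler and override only on_trial_result, much like MedianStoppingRule does below.

    from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler

    # Hypothetical example: stop any trial whose reported reward falls below
    # a fixed threshold. Scheduling order is inherited from FIFOScheduler.
    class ThresholdStoppingRule(FIFOScheduler):
        def __init__(self, min_reward=0.0):
            FIFOScheduler.__init__(self)
            self._min_reward = min_reward

        def on_trial_result(self, trial_runner, trial, result):
            if result.episode_reward_mean < self._min_reward:
                return TrialScheduler.STOP
            return TrialScheduler.CONTINUE
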
@@ -44,3 +59,85 @@ class FIFOScheduler(TrialScheduler):

     def debug_string(self):
         return "Using FIFO scheduling algorithm."
+
+
+# TODO(ekl) expose this in the command line API
+class MedianStoppingRule(FIFOScheduler):
+    """Implements the median stopping rule as described in the Vizier paper:
+
+    https://research.google.com/pubs/pub46180.html
+
+    Args:
+        time_attr (str): The TrainingResult attr to use for comparing time.
+            Note that you can pass in something non-temporal such as
+            `training_iteration` as a measure of progress, the only requirement
+            is that the attribute should increase monotonically.
+        reward_attr (str): The TrainingResult objective value attribute. As
+            with `time_attr`, this may refer to any objective value that
+            is supposed to increase with time.
+        grace_period (float): Only stop trials at least this old in time.
+            The units are the same as the attribute named by `time_attr`.
+        min_samples_required (int): Min samples to compute median over.
+    """
+
+    def __init__(
+            self, time_attr='time_total_s', reward_attr='episode_reward_mean',
+            grace_period=60.0, min_samples_required=3):
+        FIFOScheduler.__init__(self)
+        self._completed_trials = set()
+        self._results = collections.defaultdict(list)
+        self._grace_period = grace_period
+        self._min_samples_required = min_samples_required
+        self._reward_attr = reward_attr
+        self._time_attr = time_attr
+        self._num_stopped = 0
+
+    def on_trial_result(self, trial_runner, trial, result):
+        """Callback for early stopping.
+
+        This stopping rule stops a running trial if the trial's best objective
+        value by step `t` is strictly worse than the median of the running
+        averages of all completed trials' objectives reported up to step `t`.
+        """
+
+        time = getattr(result, self._time_attr)
+        self._results[trial].append(result)
+        median_result = self._get_median_result(time)
+        best_result = self._best_result(trial)
+        print("Trial {} best res={} vs median res={} at t={}".format(
+            trial, best_result, median_result, time))
+        if best_result < median_result and time > self._grace_period:
+            print("MedianStoppingRule: early stopping {}".format(trial))
+            self._num_stopped += 1
+            return TrialScheduler.STOP
+        else:
+            return TrialScheduler.CONTINUE
+
+    def on_trial_complete(self, trial_runner, trial, result):
+        self._results[trial].append(result)
+        self._completed_trials.add(trial)
+
+    def debug_string(self):
+        return "Using MedianStoppingRule: num_stopped={}.".format(
+            self._num_stopped)
+
+    def _get_median_result(self, time):
+        scores = []
+        for trial in self._completed_trials:
+            scores.append(self._running_result(trial, time))
+        if len(scores) >= self._min_samples_required:
+            return np.median(scores)
+        else:
+            return float('-inf')
+
+    def _running_result(self, trial, t_max=float('inf')):
+        results = self._results[trial]
+        # TODO(ekl) we could do interpolation to be more precise, but for now
+        # assume len(results) is large and the time diffs are roughly equal
+        return np.mean(
+            [getattr(r, self._reward_attr)
+             for r in results if getattr(r, self._time_attr) <= t_max])
+
+    def _best_result(self, trial):
+        results = self._results[trial]
+        return max([getattr(r, self._reward_attr) for r in results])

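To make the stopping rule concrete, here is a worked example using the same numbers as testMedianStoppingConstantPerf in the new test file below (a sketch, not part of the commit). Completed trial t1 reported rewards 0, 100, ..., 900 at t = 0..9 and 1000 at t = 10; running trial t2 has only ever reported 450.

    import numpy as np

    t1_rewards = [i * 100 for i in range(10)] + [1000]  # reported at t = 0..10
    t2_best = 450

    # Running average of the completed trial up to t = 10 (cf. _running_result).
    running_avg = np.mean(t1_rewards)        # 500.0
    # With one completed trial and min_samples_required=1, the median of the
    # running averages is just that value (cf. _get_median_result).
    median = np.median([running_avg])        # 500.0

    # t2's best result (450) is strictly worse than the median (500) and
    # t = 10 exceeds grace_period=0, so on_trial_result returns STOP.
    assert t2_best < median
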
test/trial_scheduler_test.py (new file, 124 lines)

@@ -0,0 +1,124 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import unittest
+
+from ray.tune.result import TrainingResult
+from ray.tune.trial import Trial
+from ray.tune.trial_scheduler import MedianStoppingRule, TrialScheduler
+
+
+def result(t, rew):
+    return TrainingResult(time_total_s=t, episode_reward_mean=rew)
+
+
+class EarlyStoppingSuite(unittest.TestCase):
+    def basicSetup(self, rule):
+        t1 = Trial("t1", "PPO")  # mean is 450, max 900, t_max=10
+        t2 = Trial("t2", "PPO")  # mean is 450, max 450, t_max=5
+        for i in range(10):
+            self.assertEqual(
+                rule.on_trial_result(None, t1, result(i, i * 100)),
+                TrialScheduler.CONTINUE)
+        for i in range(5):
+            self.assertEqual(
+                rule.on_trial_result(None, t2, result(i, 450)),
+                TrialScheduler.CONTINUE)
+        return t1, t2
+
+    def testMedianStoppingConstantPerf(self):
+        rule = MedianStoppingRule(grace_period=0, min_samples_required=1)
+        t1, t2 = self.basicSetup(rule)
+        rule.on_trial_complete(None, t1, result(10, 1000))
+        self.assertEqual(
+            rule.on_trial_result(None, t2, result(5, 450)),
+            TrialScheduler.CONTINUE)
+        self.assertEqual(
+            rule.on_trial_result(None, t2, result(6, 0)),
+            TrialScheduler.CONTINUE)
+        self.assertEqual(
+            rule.on_trial_result(None, t2, result(10, 450)),
+            TrialScheduler.STOP)
+
+    def testMedianStoppingOnCompleteOnly(self):
+        rule = MedianStoppingRule(grace_period=0, min_samples_required=1)
+        t1, t2 = self.basicSetup(rule)
+        self.assertEqual(
+            rule.on_trial_result(None, t2, result(100, 0)),
+            TrialScheduler.CONTINUE)
+        rule.on_trial_complete(None, t1, result(10, 1000))
+        self.assertEqual(
+            rule.on_trial_result(None, t2, result(101, 0)),
+            TrialScheduler.STOP)
+
+    def testMedianStoppingGracePeriod(self):
+        rule = MedianStoppingRule(grace_period=2.5, min_samples_required=1)
+        t1, t2 = self.basicSetup(rule)
+        rule.on_trial_complete(None, t1, result(10, 1000))
+        rule.on_trial_complete(None, t2, result(10, 1000))
+        t3 = Trial("t3", "PPO")
+        self.assertEqual(
+            rule.on_trial_result(None, t3, result(1, 10)),
+            TrialScheduler.CONTINUE)
+        self.assertEqual(
+            rule.on_trial_result(None, t3, result(2, 10)),
+            TrialScheduler.CONTINUE)
+        self.assertEqual(
+            rule.on_trial_result(None, t3, result(3, 10)),
+            TrialScheduler.STOP)
+
+    def testMedianStoppingMinSamples(self):
+        rule = MedianStoppingRule(grace_period=0, min_samples_required=2)
+        t1, t2 = self.basicSetup(rule)
+        rule.on_trial_complete(None, t1, result(10, 1000))
+        t3 = Trial("t3", "PPO")
+        self.assertEqual(
+            rule.on_trial_result(None, t3, result(3, 10)),
+            TrialScheduler.CONTINUE)
+        rule.on_trial_complete(None, t2, result(10, 1000))
+        self.assertEqual(
+            rule.on_trial_result(None, t3, result(3, 10)),
+            TrialScheduler.STOP)
+
+    def testMedianStoppingUsesMedian(self):
+        rule = MedianStoppingRule(grace_period=0, min_samples_required=1)
+        t1, t2 = self.basicSetup(rule)
+        rule.on_trial_complete(None, t1, result(10, 1000))
+        rule.on_trial_complete(None, t2, result(10, 1000))
+        t3 = Trial("t3", "PPO")
+        self.assertEqual(
+            rule.on_trial_result(None, t3, result(1, 260)),
+            TrialScheduler.CONTINUE)
+        self.assertEqual(
+            rule.on_trial_result(None, t3, result(2, 260)),
+            TrialScheduler.STOP)
+
+    def testAlternateMetrics(self):
+        def result2(t, rew):
+            return TrainingResult(training_iteration=t, neg_mean_loss=rew)
+
+        rule = MedianStoppingRule(
+            grace_period=0, min_samples_required=1,
+            time_attr='training_iteration', reward_attr='neg_mean_loss')
+        t1 = Trial("t1", "PPO")  # mean is 450, max 900, t_max=10
+        t2 = Trial("t2", "PPO")  # mean is 450, max 450, t_max=5
+        for i in range(10):
+            self.assertEqual(
+                rule.on_trial_result(None, t1, result2(i, i * 100)),
+                TrialScheduler.CONTINUE)
+        for i in range(5):
+            self.assertEqual(
+                rule.on_trial_result(None, t2, result2(i, 450)),
+                TrialScheduler.CONTINUE)
+        rule.on_trial_complete(None, t1, result2(10, 1000))
+        self.assertEqual(
+            rule.on_trial_result(None, t2, result2(5, 450)),
+            TrialScheduler.CONTINUE)
+        self.assertEqual(
+            rule.on_trial_result(None, t2, result2(6, 0)),
+            TrialScheduler.CONTINUE)
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)