mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[tune] Implement median stopping rule (#1170)
* trial scheduler interface * remove * wip median stopping * remove * median stopping rule * update * docs * update * Revrt * update * comments * fix tesT
This commit is contained in:
parent
fdf069bd1d
commit
d06beacd84
5 changed files with 225 additions and 1 deletions
|
@ -119,6 +119,7 @@ script:
|
||||||
- python test/recursion_test.py
|
- python test/recursion_test.py
|
||||||
- python test/monitor_test.py
|
- python test/monitor_test.py
|
||||||
- python test/trial_runner_test.py
|
- python test/trial_runner_test.py
|
||||||
|
- python test/trial_scheduler_test.py
|
||||||
|
|
||||||
- python -m pytest python/ray/rllib/test/test_catalog.py
|
- python -m pytest python/ray/rllib/test/test_catalog.py
|
||||||
|
|
||||||
|
|
|
@ -23,6 +23,7 @@ import yaml
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
from ray.tune.config_parser import make_parser, parse_to_trials
|
from ray.tune.config_parser import make_parser, parse_to_trials
|
||||||
|
from ray.tune.trial_scheduler import MedianStoppingRule
|
||||||
from ray.tune.trial_runner import TrialRunner
|
from ray.tune.trial_runner import TrialRunner
|
||||||
from ray.tune.trial import Trial
|
from ray.tune.trial import Trial
|
||||||
|
|
||||||
|
@ -46,7 +47,7 @@ parser.add_argument("-f", "--config-file", default=None, type=str,
|
||||||
|
|
||||||
def main(argv):
|
def main(argv):
|
||||||
args = parser.parse_args(argv)
|
args = parser.parse_args(argv)
|
||||||
runner = TrialRunner()
|
runner = TrialRunner(MedianStoppingRule())
|
||||||
|
|
||||||
if args.config_file:
|
if args.config_file:
|
||||||
with open(args.config_file) as f:
|
with open(args.config_file) as f:
|
||||||
|
|
|
@ -148,6 +148,7 @@ class TrialRunner(object):
|
||||||
trial.last_result = result
|
trial.last_result = result
|
||||||
|
|
||||||
if trial.should_stop(result):
|
if trial.should_stop(result):
|
||||||
|
self._scheduler_alg.on_trial_complete(self, trial, result)
|
||||||
self._stop_trial(trial)
|
self._stop_trial(trial)
|
||||||
else:
|
else:
|
||||||
decision = self._scheduler_alg.on_trial_result(
|
decision = self._scheduler_alg.on_trial_result(
|
||||||
|
|
|
@ -1,6 +1,9 @@
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
|
|
||||||
|
import collections
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from ray.tune.trial import Trial
|
from ray.tune.trial import Trial
|
||||||
|
|
||||||
|
|
||||||
|
@ -17,6 +20,13 @@ class TrialScheduler(object):
|
||||||
|
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def on_trial_complete(self, trial_runner, trial, result):
|
||||||
|
"""Notification for the completion of trial.
|
||||||
|
|
||||||
|
This will only be called when the trial completes naturally."""
|
||||||
|
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
def choose_trial_to_run(self, trial_runner, trials):
|
def choose_trial_to_run(self, trial_runner, trials):
|
||||||
"""Called to choose a new trial to run.
|
"""Called to choose a new trial to run.
|
||||||
|
|
||||||
|
@ -32,9 +42,14 @@ class TrialScheduler(object):
|
||||||
|
|
||||||
|
|
||||||
class FIFOScheduler(TrialScheduler):
|
class FIFOScheduler(TrialScheduler):
|
||||||
|
"""Simple scheduler that just runs trials in submission order."""
|
||||||
|
|
||||||
def on_trial_result(self, trial_runner, trial, result):
|
def on_trial_result(self, trial_runner, trial, result):
|
||||||
return TrialScheduler.CONTINUE
|
return TrialScheduler.CONTINUE
|
||||||
|
|
||||||
|
def on_trial_complete(self, trial_runner, trial, result):
|
||||||
|
pass
|
||||||
|
|
||||||
def choose_trial_to_run(self, trial_runner):
|
def choose_trial_to_run(self, trial_runner):
|
||||||
for trial in trial_runner.get_trials():
|
for trial in trial_runner.get_trials():
|
||||||
if (trial.status == Trial.PENDING and
|
if (trial.status == Trial.PENDING and
|
||||||
|
@ -44,3 +59,85 @@ class FIFOScheduler(TrialScheduler):
|
||||||
|
|
||||||
def debug_string(self):
|
def debug_string(self):
|
||||||
return "Using FIFO scheduling algorithm."
|
return "Using FIFO scheduling algorithm."
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(ekl) expose this in the command line API
|
||||||
|
class MedianStoppingRule(FIFOScheduler):
|
||||||
|
"""Implements the median stopping rule as described in the Vizier paper:
|
||||||
|
|
||||||
|
https://research.google.com/pubs/pub46180.html
|
||||||
|
|
||||||
|
Args:
|
||||||
|
time_attr (str): The TrainingResult attr to use for comparing time.
|
||||||
|
Note that you can pass in something non-temporal such as
|
||||||
|
`training_iteration` as a measure of progress, the only requirement
|
||||||
|
is that the attribute should increase monotonically.
|
||||||
|
reward_attr (str): The TrainingResult objective value attribute. As
|
||||||
|
with `time_attr`, this may refer to any objective value that
|
||||||
|
is supposed to increase with time.
|
||||||
|
grace_period (float): Only stop trials at least this old in time.
|
||||||
|
The units are the same as the attribute named by `time_attr`.
|
||||||
|
min_samples_required (int): Min samples to compute median over.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, time_attr='time_total_s', reward_attr='episode_reward_mean',
|
||||||
|
grace_period=60.0, min_samples_required=3):
|
||||||
|
FIFOScheduler.__init__(self)
|
||||||
|
self._completed_trials = set()
|
||||||
|
self._results = collections.defaultdict(list)
|
||||||
|
self._grace_period = grace_period
|
||||||
|
self._min_samples_required = min_samples_required
|
||||||
|
self._reward_attr = reward_attr
|
||||||
|
self._time_attr = time_attr
|
||||||
|
self._num_stopped = 0
|
||||||
|
|
||||||
|
def on_trial_result(self, trial_runner, trial, result):
|
||||||
|
"""Callback for early stopping.
|
||||||
|
|
||||||
|
This stopping rule stops a running trial if the trial's best objective
|
||||||
|
value by step `t` is strictly worse than the median of the running
|
||||||
|
averages of all completed trials' objectives reported up to step `t`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
time = getattr(result, self._time_attr)
|
||||||
|
self._results[trial].append(result)
|
||||||
|
median_result = self._get_median_result(time)
|
||||||
|
best_result = self._best_result(trial)
|
||||||
|
print("Trial {} best res={} vs median res={} at t={}".format(
|
||||||
|
trial, best_result, median_result, time))
|
||||||
|
if best_result < median_result and time > self._grace_period:
|
||||||
|
print("MedianStoppingRule: early stopping {}".format(trial))
|
||||||
|
self._num_stopped += 1
|
||||||
|
return TrialScheduler.STOP
|
||||||
|
else:
|
||||||
|
return TrialScheduler.CONTINUE
|
||||||
|
|
||||||
|
def on_trial_complete(self, trial_runner, trial, result):
|
||||||
|
self._results[trial].append(result)
|
||||||
|
self._completed_trials.add(trial)
|
||||||
|
|
||||||
|
def debug_string(self):
|
||||||
|
return "Using MedianStoppingRule: num_stopped={}.".format(
|
||||||
|
self._num_stopped)
|
||||||
|
|
||||||
|
def _get_median_result(self, time):
|
||||||
|
scores = []
|
||||||
|
for trial in self._completed_trials:
|
||||||
|
scores.append(self._running_result(trial, time))
|
||||||
|
if len(scores) >= self._min_samples_required:
|
||||||
|
return np.median(scores)
|
||||||
|
else:
|
||||||
|
return float('-inf')
|
||||||
|
|
||||||
|
def _running_result(self, trial, t_max=float('inf')):
|
||||||
|
results = self._results[trial]
|
||||||
|
# TODO(ekl) we could do interpolation to be more precise, but for now
|
||||||
|
# assume len(results) is large and the time diffs are roughly equal
|
||||||
|
return np.mean(
|
||||||
|
[getattr(r, self._reward_attr)
|
||||||
|
for r in results if getattr(r, self._time_attr) <= t_max])
|
||||||
|
|
||||||
|
def _best_result(self, trial):
|
||||||
|
results = self._results[trial]
|
||||||
|
return max([getattr(r, self._reward_attr) for r in results])
|
||||||
|
|
124
test/trial_scheduler_test.py
Normal file
124
test/trial_scheduler_test.py
Normal file
|
@ -0,0 +1,124 @@
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from ray.tune.result import TrainingResult
|
||||||
|
from ray.tune.trial import Trial
|
||||||
|
from ray.tune.trial_scheduler import MedianStoppingRule, TrialScheduler
|
||||||
|
|
||||||
|
|
||||||
|
def result(t, rew):
|
||||||
|
return TrainingResult(time_total_s=t, episode_reward_mean=rew)
|
||||||
|
|
||||||
|
|
||||||
|
class EarlyStoppingSuite(unittest.TestCase):
|
||||||
|
def basicSetup(self, rule):
|
||||||
|
t1 = Trial("t1", "PPO") # mean is 450, max 900, t_max=10
|
||||||
|
t2 = Trial("t2", "PPO") # mean is 450, max 450, t_max=5
|
||||||
|
for i in range(10):
|
||||||
|
self.assertEqual(
|
||||||
|
rule.on_trial_result(None, t1, result(i, i * 100)),
|
||||||
|
TrialScheduler.CONTINUE)
|
||||||
|
for i in range(5):
|
||||||
|
self.assertEqual(
|
||||||
|
rule.on_trial_result(None, t2, result(i, 450)),
|
||||||
|
TrialScheduler.CONTINUE)
|
||||||
|
return t1, t2
|
||||||
|
|
||||||
|
def testMedianStoppingConstantPerf(self):
|
||||||
|
rule = MedianStoppingRule(grace_period=0, min_samples_required=1)
|
||||||
|
t1, t2 = self.basicSetup(rule)
|
||||||
|
rule.on_trial_complete(None, t1, result(10, 1000))
|
||||||
|
self.assertEqual(
|
||||||
|
rule.on_trial_result(None, t2, result(5, 450)),
|
||||||
|
TrialScheduler.CONTINUE)
|
||||||
|
self.assertEqual(
|
||||||
|
rule.on_trial_result(None, t2, result(6, 0)),
|
||||||
|
TrialScheduler.CONTINUE)
|
||||||
|
self.assertEqual(
|
||||||
|
rule.on_trial_result(None, t2, result(10, 450)),
|
||||||
|
TrialScheduler.STOP)
|
||||||
|
|
||||||
|
def testMedianStoppingOnCompleteOnly(self):
|
||||||
|
rule = MedianStoppingRule(grace_period=0, min_samples_required=1)
|
||||||
|
t1, t2 = self.basicSetup(rule)
|
||||||
|
self.assertEqual(
|
||||||
|
rule.on_trial_result(None, t2, result(100, 0)),
|
||||||
|
TrialScheduler.CONTINUE)
|
||||||
|
rule.on_trial_complete(None, t1, result(10, 1000))
|
||||||
|
self.assertEqual(
|
||||||
|
rule.on_trial_result(None, t2, result(101, 0)),
|
||||||
|
TrialScheduler.STOP)
|
||||||
|
|
||||||
|
def testMedianStoppingGracePeriod(self):
|
||||||
|
rule = MedianStoppingRule(grace_period=2.5, min_samples_required=1)
|
||||||
|
t1, t2 = self.basicSetup(rule)
|
||||||
|
rule.on_trial_complete(None, t1, result(10, 1000))
|
||||||
|
rule.on_trial_complete(None, t2, result(10, 1000))
|
||||||
|
t3 = Trial("t3", "PPO")
|
||||||
|
self.assertEqual(
|
||||||
|
rule.on_trial_result(None, t3, result(1, 10)),
|
||||||
|
TrialScheduler.CONTINUE)
|
||||||
|
self.assertEqual(
|
||||||
|
rule.on_trial_result(None, t3, result(2, 10)),
|
||||||
|
TrialScheduler.CONTINUE)
|
||||||
|
self.assertEqual(
|
||||||
|
rule.on_trial_result(None, t3, result(3, 10)),
|
||||||
|
TrialScheduler.STOP)
|
||||||
|
|
||||||
|
def testMedianStoppingMinSamples(self):
|
||||||
|
rule = MedianStoppingRule(grace_period=0, min_samples_required=2)
|
||||||
|
t1, t2 = self.basicSetup(rule)
|
||||||
|
rule.on_trial_complete(None, t1, result(10, 1000))
|
||||||
|
t3 = Trial("t3", "PPO")
|
||||||
|
self.assertEqual(
|
||||||
|
rule.on_trial_result(None, t3, result(3, 10)),
|
||||||
|
TrialScheduler.CONTINUE)
|
||||||
|
rule.on_trial_complete(None, t2, result(10, 1000))
|
||||||
|
self.assertEqual(
|
||||||
|
rule.on_trial_result(None, t3, result(3, 10)),
|
||||||
|
TrialScheduler.STOP)
|
||||||
|
|
||||||
|
def testMedianStoppingUsesMedian(self):
|
||||||
|
rule = MedianStoppingRule(grace_period=0, min_samples_required=1)
|
||||||
|
t1, t2 = self.basicSetup(rule)
|
||||||
|
rule.on_trial_complete(None, t1, result(10, 1000))
|
||||||
|
rule.on_trial_complete(None, t2, result(10, 1000))
|
||||||
|
t3 = Trial("t3", "PPO")
|
||||||
|
self.assertEqual(
|
||||||
|
rule.on_trial_result(None, t3, result(1, 260)),
|
||||||
|
TrialScheduler.CONTINUE)
|
||||||
|
self.assertEqual(
|
||||||
|
rule.on_trial_result(None, t3, result(2, 260)),
|
||||||
|
TrialScheduler.STOP)
|
||||||
|
|
||||||
|
def testAlternateMetrics(self):
|
||||||
|
def result2(t, rew):
|
||||||
|
return TrainingResult(training_iteration=t, neg_mean_loss=rew)
|
||||||
|
|
||||||
|
rule = MedianStoppingRule(
|
||||||
|
grace_period=0, min_samples_required=1,
|
||||||
|
time_attr='training_iteration', reward_attr='neg_mean_loss')
|
||||||
|
t1 = Trial("t1", "PPO") # mean is 450, max 900, t_max=10
|
||||||
|
t2 = Trial("t2", "PPO") # mean is 450, max 450, t_max=5
|
||||||
|
for i in range(10):
|
||||||
|
self.assertEqual(
|
||||||
|
rule.on_trial_result(None, t1, result2(i, i * 100)),
|
||||||
|
TrialScheduler.CONTINUE)
|
||||||
|
for i in range(5):
|
||||||
|
self.assertEqual(
|
||||||
|
rule.on_trial_result(None, t2, result2(i, 450)),
|
||||||
|
TrialScheduler.CONTINUE)
|
||||||
|
rule.on_trial_complete(None, t1, result2(10, 1000))
|
||||||
|
self.assertEqual(
|
||||||
|
rule.on_trial_result(None, t2, result2(5, 450)),
|
||||||
|
TrialScheduler.CONTINUE)
|
||||||
|
self.assertEqual(
|
||||||
|
rule.on_trial_result(None, t2, result2(6, 0)),
|
||||||
|
TrialScheduler.CONTINUE)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main(verbosity=2)
|
Loading…
Add table
Reference in a new issue