[tune] MedianStopping on result (#5402)

* added `median_stopping_result` class to schedulers and updated `__init__`

* Flatten dicts and combine schedulers.

MedianStoppingRule is now combined with MedianStoppingResult; I think
the functionality is essentially the same, so there's no need to
duplicate it.

Dict flattening was already taken care of in a separate PR, so I've
reverted that.

* lint

* revert

* remove time sharing and simplify state

* fix

* fix tests

* updated property names and types to reflect suggestions from the Ray developers; merged `get_median_result` and `get_best_result` into a single method to eliminate duplicate steps; added a resource check on the PAUSE condition; modified the utility function to use the updated properties

* updated tests for `median_stopping_result` in a separate file

* remove stray characters left over from a previous merge conflict

* reformatted and cleaned up dependencies after running the code formatter and linter

* update the scheduler to coordinate the evaluation interval

* modify `median_stopping_result` to synchronize result evaluation at regular intervals, driven by the least common interval

* add some logging info to `median_result`

* add a new scheduler, `SyncMedianStoppingResult`, which evaluates and stops trials synchronously

* Clean up `median_stopping_rule`:

- remove `eval_interval`
- pause trials with insufficient samples if there are other waiting trials
- compute the score only for trials that have reached `result_time`

* Remove extraneous classes

* Fix median stopping rule tests

* Added a `min_time_slice` flag to reduce potential checkpointing cost (see the usage sketch below)

* Only compute the mean after the grace period

* Relegate logging to debug level
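
For context, here is a minimal usage sketch of the scheduler as it stands after these changes. It uses only constructor arguments visible in the diff below; the trainable (`my_trainable`) and the `num_samples`/`min_time_slice` values are hypothetical placeholders.

from ray import tune
from ray.tune.schedulers import MedianStoppingRule

# Arguments mirror the updated __init__ signature in the diff below.
scheduler = MedianStoppingRule(
    time_attr="time_total_s",      # progress measure used for comparisons
    metric="episode_reward_mean",  # objective each trial reports
    mode="max",                    # higher is better
    grace_period=60.0,             # ignore results before t=60; the running
                                   # mean is also computed from t=60 onwards
    min_samples_required=3,        # need >= 3 comparable trials for a median
    min_time_slice=30,             # hold the slot >= 30s before yielding
    hard_stop=True)                # STOP underperformers rather than PAUSE

# `my_trainable` stands in for any Tune trainable.
tune.run(my_trainable, num_samples=20, scheduler=scheduler)
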
Authored by waldroje on 2019-10-08 14:40:41 -04:00; committed by Richard Liaw
parent 486abedcdf
commit 054583ffe6
2 changed files with 121 additions and 74 deletions

Changed file: python/ray/tune/schedulers/median_stopping_rule.py

@@ -27,13 +27,18 @@ class MedianStoppingRule(FIFOScheduler):
         mode (str): One of {min, max}. Determines whether objective is
             minimizing or maximizing the metric attribute.
         grace_period (float): Only stop trials at least this old in time.
-            The units are the same as the attribute named by `time_attr`.
-        min_samples_required (int): Min samples to compute median over.
+            The mean will only be computed from this time onwards. The units
+            are the same as the attribute named by `time_attr`.
+        min_samples_required (int): Minimum number of trials to compute median
+            over.
+        min_time_slice (float): Each trial runs at least this long before
+            yielding (assuming it isn't stopped). Note: trials ONLY yield if
+            there are not enough samples to evaluate performance for the
+            current result AND there are other trials waiting to run.
+            The units are the same as the attribute named by `time_attr`.
         hard_stop (bool): If False, pauses trials instead of stopping
             them. When all other trials are complete, paused trials will be
             resumed and allowed to run FIFO.
-        verbose (bool): If True, will output the median and best result each
-            time a trial reports. Defaults to True.
     """
 
     def __init__(self,
@@ -43,10 +48,9 @@ class MedianStoppingRule(FIFOScheduler):
                  mode="max",
                  grace_period=60.0,
                  min_samples_required=3,
-                 hard_stop=True,
-                 verbose=True):
-        assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
+                 min_time_slice=0,
+                 hard_stop=True):
         if reward_attr is not None:
             mode = "max"
             metric = reward_attr
@@ -54,21 +58,20 @@ class MedianStoppingRule(FIFOScheduler):
                 "`reward_attr` is deprecated and will be removed in a future "
                 "version of Tune. "
                 "Setting `metric={}` and `mode=max`.".format(reward_attr))
         FIFOScheduler.__init__(self)
         self._stopped_trials = set()
-        self._completed_trials = set()
-        self._results = collections.defaultdict(list)
         self._grace_period = grace_period
         self._min_samples_required = min_samples_required
+        self._min_time_slice = min_time_slice
         self._metric = metric
-        if mode == "max":
-            self._metric_op = 1.
-        elif mode == "min":
-            self._metric_op = -1.
+        assert mode in {"min", "max"}, "`mode` must be 'min' or 'max'."
+        self._worst = float("-inf") if mode == "max" else float("inf")
+        self._compare_op = max if mode == "max" else min
         self._time_attr = time_attr
         self._hard_stop = hard_stop
-        self._verbose = verbose
+        self._trial_state = {}
+        self._last_pause = collections.defaultdict(lambda: float("-inf"))
+        self._results = collections.defaultdict(list)
 
     def on_trial_result(self, trial_runner, trial, result):
         """Callback for early stopping.
@@ -82,19 +85,38 @@ class MedianStoppingRule(FIFOScheduler):
         if trial in self._stopped_trials:
             assert not self._hard_stop
-            return TrialScheduler.CONTINUE  # fall back to FIFO
+            # Fall back to FIFO
+            return TrialScheduler.CONTINUE
 
         time = result[self._time_attr]
         self._results[trial].append(result)
-        median_result = self._get_median_result(time)
+
+        if time < self._grace_period:
+            return TrialScheduler.CONTINUE
+
+        trials = self._trials_beyond_time(time)
+        trials.remove(trial)
+
+        if len(trials) < self._min_samples_required:
+            action = self._on_insufficient_samples(trial_runner, trial, time)
+            if action == TrialScheduler.PAUSE:
+                self._last_pause[trial] = time
+                action_str = "Yielding time to other trials."
+            else:
+                action_str = "Continuing anyways."
+            logger.debug(
+                "MedianStoppingRule: insufficient samples={} to evaluate "
+                "trial {} at t={}. {}".format(
+                    len(trials), trial.trial_id, time, action_str))
+            return action
+
+        median_result = self._median_result(trials, time)
         best_result = self._best_result(trial)
-        if self._verbose:
-            logger.info("Trial {} best res={} vs median res={} at t={}".format(
-                trial, best_result, median_result, time))
-        if best_result < median_result and time > self._grace_period:
-            if self._verbose:
-                logger.info("MedianStoppingRule: "
-                            "early stopping {}".format(trial))
+        logger.debug("Trial {} best res={} vs median res={} at t={}".format(
+            trial, best_result, median_result, time))
+
+        if self._compare_op(median_result, best_result) != best_result:
+            logger.debug("MedianStoppingRule: early stopping {}".format(trial))
             self._stopped_trials.add(trial)
             if self._hard_stop:
                 return TrialScheduler.STOP
@@ -105,33 +127,39 @@ class MedianStoppingRule(FIFOScheduler):
 
     def on_trial_complete(self, trial_runner, trial, result):
         self._results[trial].append(result)
-        self._completed_trials.add(trial)
-
-    def on_trial_remove(self, trial_runner, trial):
-        """Marks trial as completed if it is paused and has previously ran."""
-        if trial.status is Trial.PAUSED and trial in self._results:
-            self._completed_trials.add(trial)
 
     def debug_string(self):
         return "Using MedianStoppingRule: num_stopped={}.".format(
             len(self._stopped_trials))
 
-    def _get_median_result(self, time):
-        scores = []
-        for trial in self._completed_trials:
-            scores.append(self._running_result(trial, time))
-        if len(scores) >= self._min_samples_required:
-            return np.median(scores)
-        else:
-            return float("-inf")
+    def _on_insufficient_samples(self, trial_runner, trial, time):
+        pause = time - self._last_pause[trial] > self._min_time_slice
+        pause = pause and [
+            t for t in trial_runner.get_trials()
+            if t.status in (Trial.PENDING, Trial.PAUSED)
+        ]
+        return TrialScheduler.PAUSE if pause else TrialScheduler.CONTINUE
 
-    def _running_result(self, trial, t_max=float("inf")):
+    def _trials_beyond_time(self, time):
+        trials = [
+            trial for trial in self._results
+            if self._results[trial][-1][self._time_attr] >= time
+        ]
+        return trials
+
+    def _median_result(self, trials, time):
+        return np.median([self._running_mean(trial, time) for trial in trials])
+
+    def _running_mean(self, trial, time):
         results = self._results[trial]
         # TODO(ekl) we could do interpolation to be more precise, but for now
         # assume len(results) is large and the time diffs are roughly equal
-        return self._metric_op * np.mean(
-            [r[self._metric] for r in results if r[self._time_attr] <= t_max])
+        scoped_results = [
+            r for r in results
+            if self._grace_period <= r[self._time_attr] <= time
+        ]
+        return np.mean([r[self._metric] for r in scoped_results])
 
     def _best_result(self, trial):
         results = self._results[trial]
-        return max(self._metric_op * r[self._metric] for r in results)
+        return self._compare_op([r[self._metric] for r in results])
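
To make the reworked control flow concrete, here is a self-contained sketch of the same decision rule, mirroring `_trials_beyond_time`, `_running_mean`, `_median_result`, and `_best_result` above but outside of any Ray machinery. The `results` layout (trial id -> list of (time, metric) pairs) is a simplified stand-in for the scheduler's per-trial result lists.

import numpy as np

def should_stop(trial, results, time, grace_period=60.0,
                min_samples_required=3, mode="max"):
    """Sketch of the stopping decision for `trial` at `time`."""
    compare_op = max if mode == "max" else min

    def running_mean(t):
        # Mirrors _running_mean: only results in [grace_period, time] count.
        scoped = [m for (s, m) in results[t] if grace_period <= s <= time]
        return np.mean(scoped)

    # Mirrors _trials_beyond_time (minus the current trial): only trials
    # that have already reached `time` are comparable.
    others = [t for t in results
              if t != trial and results[t][-1][0] >= time]
    if time < grace_period or len(others) < min_samples_required:
        return False  # too early or too few samples: CONTINUE (or yield)

    median = np.median([running_mean(t) for t in others])
    best = compare_op(m for (_, m) in results[trial])
    # Stop when the trial's best result is not the better of (median, best).
    return compare_op(median, best) != best

For example, with grace_period=0 and min_samples_required=2, a trial whose best result is 90 against running means of 150 and 350 elsewhere (median 250) is stopped:

results = {
    "a": [(1, 100), (2, 200)],  # running mean at t=2: 150
    "b": [(1, 300), (2, 400)],  # running mean at t=2: 350
    "c": [(1, 90), (2, 80)],    # best result: 90
}
assert should_stop("c", results, time=2, grace_period=0,
                   min_samples_required=2)
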

Changed file: python/ray/tune/tests/test_trial_scheduler.py

@@ -36,6 +36,12 @@ def result(t, rew):
         time_total_s=t, episode_reward_mean=rew, training_iteration=int(t))
 
 
+def mock_trial_runner(trials=None):
+    trial_runner = MagicMock()
+    trial_runner.get_trials.return_value = trials or []
+    return trial_runner
+
+
 class EarlyStoppingSuite(unittest.TestCase):
     def setUp(self):
         ray.init()
@@ -47,93 +53,105 @@ class EarlyStoppingSuite(unittest.TestCase):
     def basicSetup(self, rule):
         t1 = Trial("PPO")  # mean is 450, max 900, t_max=10
         t2 = Trial("PPO")  # mean is 450, max 450, t_max=5
+        runner = mock_trial_runner()
         for i in range(10):
+            r1 = result(i, i * 100)
+            print("basicSetup:", i)
             self.assertEqual(
-                rule.on_trial_result(None, t1, result(i, i * 100)),
-                TrialScheduler.CONTINUE)
+                rule.on_trial_result(runner, t1, r1), TrialScheduler.CONTINUE)
         for i in range(5):
+            r2 = result(i, 450)
             self.assertEqual(
-                rule.on_trial_result(None, t2, result(i, 450)),
-                TrialScheduler.CONTINUE)
+                rule.on_trial_result(runner, t2, r2), TrialScheduler.CONTINUE)
         return t1, t2
 
     def testMedianStoppingConstantPerf(self):
         rule = MedianStoppingRule(grace_period=0, min_samples_required=1)
         t1, t2 = self.basicSetup(rule)
-        rule.on_trial_complete(None, t1, result(10, 1000))
+        runner = mock_trial_runner()
+        rule.on_trial_complete(runner, t1, result(10, 1000))
         self.assertEqual(
-            rule.on_trial_result(None, t2, result(5, 450)),
+            rule.on_trial_result(runner, t2, result(5, 450)),
             TrialScheduler.CONTINUE)
         self.assertEqual(
-            rule.on_trial_result(None, t2, result(6, 0)),
+            rule.on_trial_result(runner, t2, result(6, 0)),
             TrialScheduler.CONTINUE)
         self.assertEqual(
-            rule.on_trial_result(None, t2, result(10, 450)),
+            rule.on_trial_result(runner, t2, result(10, 450)),
             TrialScheduler.STOP)
 
     def testMedianStoppingOnCompleteOnly(self):
         rule = MedianStoppingRule(grace_period=0, min_samples_required=1)
         t1, t2 = self.basicSetup(rule)
+        runner = mock_trial_runner()
         self.assertEqual(
-            rule.on_trial_result(None, t2, result(100, 0)),
+            rule.on_trial_result(runner, t2, result(100, 0)),
            TrialScheduler.CONTINUE)
-        rule.on_trial_complete(None, t1, result(10, 1000))
+        rule.on_trial_complete(runner, t1, result(101, 1000))
         self.assertEqual(
-            rule.on_trial_result(None, t2, result(101, 0)),
+            rule.on_trial_result(runner, t2, result(101, 0)),
             TrialScheduler.STOP)
 
     def testMedianStoppingGracePeriod(self):
         rule = MedianStoppingRule(grace_period=2.5, min_samples_required=1)
         t1, t2 = self.basicSetup(rule)
-        rule.on_trial_complete(None, t1, result(10, 1000))
-        rule.on_trial_complete(None, t2, result(10, 1000))
+        runner = mock_trial_runner()
+        rule.on_trial_complete(runner, t1, result(10, 1000))
+        rule.on_trial_complete(runner, t2, result(10, 1000))
         t3 = Trial("PPO")
         self.assertEqual(
-            rule.on_trial_result(None, t3, result(1, 10)),
+            rule.on_trial_result(runner, t3, result(1, 10)),
             TrialScheduler.CONTINUE)
         self.assertEqual(
-            rule.on_trial_result(None, t3, result(2, 10)),
+            rule.on_trial_result(runner, t3, result(2, 10)),
             TrialScheduler.CONTINUE)
         self.assertEqual(
-            rule.on_trial_result(None, t3, result(3, 10)), TrialScheduler.STOP)
+            rule.on_trial_result(runner, t3, result(3, 10)),
+            TrialScheduler.STOP)
 
     def testMedianStoppingMinSamples(self):
         rule = MedianStoppingRule(grace_period=0, min_samples_required=2)
         t1, t2 = self.basicSetup(rule)
-        rule.on_trial_complete(None, t1, result(10, 1000))
+        runner = mock_trial_runner()
+        rule.on_trial_complete(runner, t1, result(10, 1000))
         t3 = Trial("PPO")
+        # Insufficient samples to evaluate t3
         self.assertEqual(
-            rule.on_trial_result(None, t3, result(3, 10)),
+            rule.on_trial_result(runner, t3, result(5, 10)),
             TrialScheduler.CONTINUE)
-        rule.on_trial_complete(None, t2, result(10, 1000))
+        rule.on_trial_complete(runner, t2, result(5, 1000))
+        # Sufficient samples to evaluate t3
         self.assertEqual(
-            rule.on_trial_result(None, t3, result(3, 10)), TrialScheduler.STOP)
+            rule.on_trial_result(runner, t3, result(5, 10)),
+            TrialScheduler.STOP)
 
     def testMedianStoppingUsesMedian(self):
         rule = MedianStoppingRule(grace_period=0, min_samples_required=1)
         t1, t2 = self.basicSetup(rule)
-        rule.on_trial_complete(None, t1, result(10, 1000))
-        rule.on_trial_complete(None, t2, result(10, 1000))
+        runner = mock_trial_runner()
+        rule.on_trial_complete(runner, t1, result(10, 1000))
+        rule.on_trial_complete(runner, t2, result(10, 1000))
         t3 = Trial("PPO")
         self.assertEqual(
-            rule.on_trial_result(None, t3, result(1, 260)),
+            rule.on_trial_result(runner, t3, result(1, 260)),
             TrialScheduler.CONTINUE)
         self.assertEqual(
-            rule.on_trial_result(None, t3, result(2, 260)),
+            rule.on_trial_result(runner, t3, result(2, 260)),
             TrialScheduler.STOP)
 
     def testMedianStoppingSoftStop(self):
         rule = MedianStoppingRule(
             grace_period=0, min_samples_required=1, hard_stop=False)
         t1, t2 = self.basicSetup(rule)
-        rule.on_trial_complete(None, t1, result(10, 1000))
-        rule.on_trial_complete(None, t2, result(10, 1000))
+        runner = mock_trial_runner()
+        rule.on_trial_complete(runner, t1, result(10, 1000))
+        rule.on_trial_complete(runner, t2, result(10, 1000))
         t3 = Trial("PPO")
         self.assertEqual(
-            rule.on_trial_result(None, t3, result(1, 260)),
+            rule.on_trial_result(runner, t3, result(1, 260)),
             TrialScheduler.CONTINUE)
         self.assertEqual(
-            rule.on_trial_result(None, t3, result(2, 260)),
+            rule.on_trial_result(runner, t3, result(2, 260)),
             TrialScheduler.PAUSE)
 
     def _test_metrics(self, result_func, metric, mode):
@@ -145,20 +163,21 @@ class EarlyStoppingSuite(unittest.TestCase):
             mode=mode)
         t1 = Trial("PPO")  # mean is 450, max 900, t_max=10
         t2 = Trial("PPO")  # mean is 450, max 450, t_max=5
+        runner = mock_trial_runner()
         for i in range(10):
             self.assertEqual(
-                rule.on_trial_result(None, t1, result_func(i, i * 100)),
+                rule.on_trial_result(runner, t1, result_func(i, i * 100)),
                 TrialScheduler.CONTINUE)
         for i in range(5):
             self.assertEqual(
-                rule.on_trial_result(None, t2, result_func(i, 450)),
+                rule.on_trial_result(runner, t2, result_func(i, 450)),
                 TrialScheduler.CONTINUE)
-        rule.on_trial_complete(None, t1, result_func(10, 1000))
+        rule.on_trial_complete(runner, t1, result_func(10, 1000))
         self.assertEqual(
-            rule.on_trial_result(None, t2, result_func(5, 450)),
+            rule.on_trial_result(runner, t2, result_func(5, 450)),
             TrialScheduler.CONTINUE)
         self.assertEqual(
-            rule.on_trial_result(None, t2, result_func(6, 0)),
+            rule.on_trial_result(runner, t2, result_func(6, 0)),
             TrialScheduler.CONTINUE)
 
     def testAlternateMetrics(self):
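
The suite above stubs the trial runner with a MagicMock so that `_on_insufficient_samples` can see waiting trials. As a hypothetical companion test (not part of this commit), the yielding path could be exercised like this inside `EarlyStoppingSuite`, reusing the file's `result` and `mock_trial_runner` helpers:

def testMedianStoppingYieldsToWaitingTrials(self):
    # With too few comparable trials and a PENDING trial waiting, the
    # rule should yield (PAUSE) once min_time_slice has elapsed.
    rule = MedianStoppingRule(
        grace_period=0, min_samples_required=2, min_time_slice=1)
    t1 = Trial("PPO")
    t2 = Trial("PPO")  # never runs, so it stays PENDING (i.e., waiting)
    runner = mock_trial_runner(trials=[t1, t2])
    self.assertEqual(
        rule.on_trial_result(runner, t1, result(2, 100)),
        TrialScheduler.PAUSE)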