mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[tune] MedianStopping on result (#5402)
* added class median_stopping_result to schedulers and updated __init__
* Dicts flatten and combine schedulers. MedianStoppingRule is now combined with MedianStoppingResult; I think the functionality is essentially the same so there's no need to duplicate. Dict flattening was already taken care of in a separate PR, so I've reverted that.
* lint
* revert
* remove time sharing and simplify state
* fix
* fixtests
* added class median_stopping_result to schedulers and updated __init__
* update property names and types to reflect suggestions by ray developers, merged get_median_result and get_best_result into a single method to eliminate duplicate steps, added resource check on PAUSE condition, modified utility function to use updated properties
* updated tests for median_stopping_result in separate file
* remove stray characters from previous merge conflict
* reformatted and cleaned up dependencies from running code format and linting
* added class median_stopping_result to schedulers and updated __init__
* Dicts flatten and combine schedulers. MedianStoppingRule is now combined with MedianStoppingResult; I think the functionality is essentially the same so there's no need to duplicate. Dict flattening was already taken care of in a separate PR, so I've reverted that.
* lint
* revert
* remove time sharing and simplify state
* fix
* added class median_stopping_result to schedulers and updated __init__
* update property names and types to reflect suggestions by ray developers, merged get_median_result and get_best_result into a single method to eliminate duplicate steps, added resource check on PAUSE condition, modified utility function to use updated properties
* updated tests for median_stopping_result in separate file
* remove stray characters from previous merge conflict
* reformatted and cleaned up dependencies from running code format and linting
* update scheduler to coordinate eval interval
* modify median_stopping_result to synchronize result evaluation at regular intervals, driven by least common interval
* add some logging info to median_result
* add new scheduler, SyncMedianStoppingResult, which evaluates and stops trials in a synchronous fashion
* Cleanup median_stopping_rule - remove eval_interval - pause trials with insufficient samples if there are other waiting trials - compute score only for trials that have reached result_time
* Remove extraneous classes
* Fix median stopping rule tests
* Added min_time_slice flag to reduce potential checkpointing cost
* Only compute mean after grace
* Relegate logging to debug mode
This commit is contained in:
parent
486abedcdf
commit
054583ffe6
2 changed files with 121 additions and 74 deletions
|
@@ -27,13 +27,18 @@ class MedianStoppingRule(FIFOScheduler):
|
|||
mode (str): One of {min, max}. Determines whether objective is
|
||||
minimizing or maximizing the metric attribute.
|
||||
grace_period (float): Only stop trials at least this old in time.
|
||||
The mean will only be computed from this time onwards. The units
|
||||
are the same as the attribute named by `time_attr`.
|
||||
min_samples_required (int): Minimum number of trials to compute median
|
||||
over.
|
||||
min_time_slice (float): Each trial runs at least this long before
|
||||
yielding (assuming it isn't stopped). Note: trials ONLY yield if
|
||||
there are not enough samples to evaluate performance for the
|
||||
current result AND there are other trials waiting to run.
|
||||
The units are the same as the attribute named by `time_attr`.
|
||||
min_samples_required (int): Min samples to compute median over.
|
||||
hard_stop (bool): If False, pauses trials instead of stopping
|
||||
them. When all other trials are complete, paused trials will be
|
||||
resumed and allowed to run FIFO.
|
||||
verbose (bool): If True, will output the median and best result each
|
||||
time a trial reports. Defaults to True.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@@ -43,10 +48,9 @@ class MedianStoppingRule(FIFOScheduler):
|
|||
mode="max",
|
||||
grace_period=60.0,
|
||||
min_samples_required=3,
|
||||
hard_stop=True,
|
||||
verbose=True):
|
||||
min_time_slice=0,
|
||||
hard_stop=True):
|
||||
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
|
||||
|
||||
if reward_attr is not None:
|
||||
mode = "max"
|
||||
metric = reward_attr
|
||||
|
@@ -54,21 +58,20 @@ class MedianStoppingRule(FIFOScheduler):
|
|||
"`reward_attr` is deprecated and will be removed in a future "
|
||||
"version of Tune. "
|
||||
"Setting `metric={}` and `mode=max`.".format(reward_attr))
|
||||
|
||||
FIFOScheduler.__init__(self)
|
||||
self._stopped_trials = set()
|
||||
self._completed_trials = set()
|
||||
self._results = collections.defaultdict(list)
|
||||
self._grace_period = grace_period
|
||||
self._min_samples_required = min_samples_required
|
||||
self._min_time_slice = min_time_slice
|
||||
self._metric = metric
|
||||
if mode == "max":
|
||||
self._metric_op = 1.
|
||||
elif mode == "min":
|
||||
self._metric_op = -1.
|
||||
assert mode in {"min", "max"}, "`mode` must be 'min' or 'max'."
|
||||
self._worst = float("-inf") if mode == "max" else float("inf")
|
||||
self._compare_op = max if mode == "max" else min
|
||||
self._time_attr = time_attr
|
||||
self._hard_stop = hard_stop
|
||||
self._verbose = verbose
|
||||
self._trial_state = {}
|
||||
self._last_pause = collections.defaultdict(lambda: float("-inf"))
|
||||
self._results = collections.defaultdict(list)
|
||||
|
||||
def on_trial_result(self, trial_runner, trial, result):
|
||||
"""Callback for early stopping.
|
||||
|
@@ -82,19 +85,38 @@ class MedianStoppingRule(FIFOScheduler):
|
|||
|
||||
if trial in self._stopped_trials:
|
||||
assert not self._hard_stop
|
||||
return TrialScheduler.CONTINUE # fall back to FIFO
|
||||
# Fall back to FIFO
|
||||
return TrialScheduler.CONTINUE
|
||||
|
||||
time = result[self._time_attr]
|
||||
self._results[trial].append(result)
|
||||
median_result = self._get_median_result(time)
|
||||
|
||||
if time < self._grace_period:
|
||||
return TrialScheduler.CONTINUE
|
||||
|
||||
trials = self._trials_beyond_time(time)
|
||||
trials.remove(trial)
|
||||
|
||||
if len(trials) < self._min_samples_required:
|
||||
action = self._on_insufficient_samples(trial_runner, trial, time)
|
||||
if action == TrialScheduler.PAUSE:
|
||||
self._last_pause[trial] = time
|
||||
action_str = "Yielding time to other trials."
|
||||
else:
|
||||
action_str = "Continuing anyways."
|
||||
logger.debug(
|
||||
"MedianStoppingRule: insufficient samples={} to evaluate "
|
||||
"trial {} at t={}. {}".format(
|
||||
len(trials), trial.trial_id, time, action_str))
|
||||
return action
|
||||
|
||||
median_result = self._median_result(trials, time)
|
||||
best_result = self._best_result(trial)
|
||||
if self._verbose:
|
||||
logger.info("Trial {} best res={} vs median res={} at t={}".format(
|
||||
trial, best_result, median_result, time))
|
||||
if best_result < median_result and time > self._grace_period:
|
||||
if self._verbose:
|
||||
logger.info("MedianStoppingRule: "
|
||||
"early stopping {}".format(trial))
|
||||
logger.debug("Trial {} best res={} vs median res={} at t={}".format(
|
||||
trial, best_result, median_result, time))
|
||||
|
||||
if self._compare_op(median_result, best_result) != best_result:
|
||||
logger.debug("MedianStoppingRule: early stopping {}".format(trial))
|
||||
self._stopped_trials.add(trial)
|
||||
if self._hard_stop:
|
||||
return TrialScheduler.STOP
|
||||
|
@@ -105,33 +127,39 @@ class MedianStoppingRule(FIFOScheduler):
|
|||
|
||||
def on_trial_complete(self, trial_runner, trial, result):
|
||||
self._results[trial].append(result)
|
||||
self._completed_trials.add(trial)
|
||||
|
||||
def on_trial_remove(self, trial_runner, trial):
|
||||
"""Marks trial as completed if it is paused and has previously ran."""
|
||||
if trial.status is Trial.PAUSED and trial in self._results:
|
||||
self._completed_trials.add(trial)
|
||||
|
||||
def debug_string(self):
|
||||
return "Using MedianStoppingRule: num_stopped={}.".format(
|
||||
len(self._stopped_trials))
|
||||
|
||||
def _get_median_result(self, time):
|
||||
scores = []
|
||||
for trial in self._completed_trials:
|
||||
scores.append(self._running_result(trial, time))
|
||||
if len(scores) >= self._min_samples_required:
|
||||
return np.median(scores)
|
||||
else:
|
||||
return float("-inf")
|
||||
def _on_insufficient_samples(self, trial_runner, trial, time):
|
||||
pause = time - self._last_pause[trial] > self._min_time_slice
|
||||
pause = pause and [
|
||||
t for t in trial_runner.get_trials()
|
||||
if t.status in (Trial.PENDING, Trial.PAUSED)
|
||||
]
|
||||
return TrialScheduler.PAUSE if pause else TrialScheduler.CONTINUE
|
||||
|
||||
def _running_result(self, trial, t_max=float("inf")):
|
||||
def _trials_beyond_time(self, time):
|
||||
trials = [
|
||||
trial for trial in self._results
|
||||
if self._results[trial][-1][self._time_attr] >= time
|
||||
]
|
||||
return trials
|
||||
|
||||
def _median_result(self, trials, time):
|
||||
return np.median([self._running_mean(trial, time) for trial in trials])
|
||||
|
||||
def _running_mean(self, trial, time):
|
||||
results = self._results[trial]
|
||||
# TODO(ekl) we could do interpolation to be more precise, but for now
|
||||
# assume len(results) is large and the time diffs are roughly equal
|
||||
return self._metric_op * np.mean(
|
||||
[r[self._metric] for r in results if r[self._time_attr] <= t_max])
|
||||
scoped_results = [
|
||||
r for r in results
|
||||
if self._grace_period <= r[self._time_attr] <= time
|
||||
]
|
||||
return np.mean([r[self._metric] for r in scoped_results])
|
||||
|
||||
def _best_result(self, trial):
|
||||
results = self._results[trial]
|
||||
return max(self._metric_op * r[self._metric] for r in results)
|
||||
return self._compare_op([r[self._metric] for r in results])
|
||||
|
|
|
@@ -36,6 +36,12 @@ def result(t, rew):
|
|||
time_total_s=t, episode_reward_mean=rew, training_iteration=int(t))
|
||||
|
||||
|
||||
def mock_trial_runner(trials=None):
|
||||
trial_runner = MagicMock()
|
||||
trial_runner.get_trials.return_value = trials or []
|
||||
return trial_runner
|
||||
|
||||
|
||||
class EarlyStoppingSuite(unittest.TestCase):
|
||||
def setUp(self):
|
||||
ray.init()
|
||||
|
@@ -47,93 +53,105 @@ class EarlyStoppingSuite(unittest.TestCase):
|
|||
def basicSetup(self, rule):
|
||||
t1 = Trial("PPO") # mean is 450, max 900, t_max=10
|
||||
t2 = Trial("PPO") # mean is 450, max 450, t_max=5
|
||||
runner = mock_trial_runner()
|
||||
for i in range(10):
|
||||
r1 = result(i, i * 100)
|
||||
print("basicSetup:", i)
|
||||
self.assertEqual(
|
||||
rule.on_trial_result(None, t1, result(i, i * 100)),
|
||||
TrialScheduler.CONTINUE)
|
||||
rule.on_trial_result(runner, t1, r1), TrialScheduler.CONTINUE)
|
||||
for i in range(5):
|
||||
r2 = result(i, 450)
|
||||
self.assertEqual(
|
||||
rule.on_trial_result(None, t2, result(i, 450)),
|
||||
TrialScheduler.CONTINUE)
|
||||
rule.on_trial_result(runner, t2, r2), TrialScheduler.CONTINUE)
|
||||
return t1, t2
|
||||
|
||||
def testMedianStoppingConstantPerf(self):
|
||||
rule = MedianStoppingRule(grace_period=0, min_samples_required=1)
|
||||
t1, t2 = self.basicSetup(rule)
|
||||
rule.on_trial_complete(None, t1, result(10, 1000))
|
||||
runner = mock_trial_runner()
|
||||
rule.on_trial_complete(runner, t1, result(10, 1000))
|
||||
self.assertEqual(
|
||||
rule.on_trial_result(None, t2, result(5, 450)),
|
||||
rule.on_trial_result(runner, t2, result(5, 450)),
|
||||
TrialScheduler.CONTINUE)
|
||||
self.assertEqual(
|
||||
rule.on_trial_result(None, t2, result(6, 0)),
|
||||
rule.on_trial_result(runner, t2, result(6, 0)),
|
||||
TrialScheduler.CONTINUE)
|
||||
self.assertEqual(
|
||||
rule.on_trial_result(None, t2, result(10, 450)),
|
||||
rule.on_trial_result(runner, t2, result(10, 450)),
|
||||
TrialScheduler.STOP)
|
||||
|
||||
def testMedianStoppingOnCompleteOnly(self):
|
||||
rule = MedianStoppingRule(grace_period=0, min_samples_required=1)
|
||||
t1, t2 = self.basicSetup(rule)
|
||||
runner = mock_trial_runner()
|
||||
self.assertEqual(
|
||||
rule.on_trial_result(None, t2, result(100, 0)),
|
||||
rule.on_trial_result(runner, t2, result(100, 0)),
|
||||
TrialScheduler.CONTINUE)
|
||||
rule.on_trial_complete(None, t1, result(10, 1000))
|
||||
rule.on_trial_complete(runner, t1, result(101, 1000))
|
||||
self.assertEqual(
|
||||
rule.on_trial_result(None, t2, result(101, 0)),
|
||||
rule.on_trial_result(runner, t2, result(101, 0)),
|
||||
TrialScheduler.STOP)
|
||||
|
||||
def testMedianStoppingGracePeriod(self):
|
||||
rule = MedianStoppingRule(grace_period=2.5, min_samples_required=1)
|
||||
t1, t2 = self.basicSetup(rule)
|
||||
rule.on_trial_complete(None, t1, result(10, 1000))
|
||||
rule.on_trial_complete(None, t2, result(10, 1000))
|
||||
runner = mock_trial_runner()
|
||||
rule.on_trial_complete(runner, t1, result(10, 1000))
|
||||
rule.on_trial_complete(runner, t2, result(10, 1000))
|
||||
t3 = Trial("PPO")
|
||||
self.assertEqual(
|
||||
rule.on_trial_result(None, t3, result(1, 10)),
|
||||
rule.on_trial_result(runner, t3, result(1, 10)),
|
||||
TrialScheduler.CONTINUE)
|
||||
self.assertEqual(
|
||||
rule.on_trial_result(None, t3, result(2, 10)),
|
||||
rule.on_trial_result(runner, t3, result(2, 10)),
|
||||
TrialScheduler.CONTINUE)
|
||||
self.assertEqual(
|
||||
rule.on_trial_result(None, t3, result(3, 10)), TrialScheduler.STOP)
|
||||
rule.on_trial_result(runner, t3, result(3, 10)),
|
||||
TrialScheduler.STOP)
|
||||
|
||||
def testMedianStoppingMinSamples(self):
|
||||
rule = MedianStoppingRule(grace_period=0, min_samples_required=2)
|
||||
t1, t2 = self.basicSetup(rule)
|
||||
rule.on_trial_complete(None, t1, result(10, 1000))
|
||||
runner = mock_trial_runner()
|
||||
rule.on_trial_complete(runner, t1, result(10, 1000))
|
||||
t3 = Trial("PPO")
|
||||
# Insufficient samples to evaluate t3
|
||||
self.assertEqual(
|
||||
rule.on_trial_result(None, t3, result(3, 10)),
|
||||
rule.on_trial_result(runner, t3, result(5, 10)),
|
||||
TrialScheduler.CONTINUE)
|
||||
rule.on_trial_complete(None, t2, result(10, 1000))
|
||||
rule.on_trial_complete(runner, t2, result(5, 1000))
|
||||
# Sufficient samples to evaluate t3
|
||||
self.assertEqual(
|
||||
rule.on_trial_result(None, t3, result(3, 10)), TrialScheduler.STOP)
|
||||
rule.on_trial_result(runner, t3, result(5, 10)),
|
||||
TrialScheduler.STOP)
|
||||
|
||||
def testMedianStoppingUsesMedian(self):
|
||||
rule = MedianStoppingRule(grace_period=0, min_samples_required=1)
|
||||
t1, t2 = self.basicSetup(rule)
|
||||
rule.on_trial_complete(None, t1, result(10, 1000))
|
||||
rule.on_trial_complete(None, t2, result(10, 1000))
|
||||
runner = mock_trial_runner()
|
||||
rule.on_trial_complete(runner, t1, result(10, 1000))
|
||||
rule.on_trial_complete(runner, t2, result(10, 1000))
|
||||
t3 = Trial("PPO")
|
||||
self.assertEqual(
|
||||
rule.on_trial_result(None, t3, result(1, 260)),
|
||||
rule.on_trial_result(runner, t3, result(1, 260)),
|
||||
TrialScheduler.CONTINUE)
|
||||
self.assertEqual(
|
||||
rule.on_trial_result(None, t3, result(2, 260)),
|
||||
rule.on_trial_result(runner, t3, result(2, 260)),
|
||||
TrialScheduler.STOP)
|
||||
|
||||
def testMedianStoppingSoftStop(self):
|
||||
rule = MedianStoppingRule(
|
||||
grace_period=0, min_samples_required=1, hard_stop=False)
|
||||
t1, t2 = self.basicSetup(rule)
|
||||
rule.on_trial_complete(None, t1, result(10, 1000))
|
||||
rule.on_trial_complete(None, t2, result(10, 1000))
|
||||
runner = mock_trial_runner()
|
||||
rule.on_trial_complete(runner, t1, result(10, 1000))
|
||||
rule.on_trial_complete(runner, t2, result(10, 1000))
|
||||
t3 = Trial("PPO")
|
||||
self.assertEqual(
|
||||
rule.on_trial_result(None, t3, result(1, 260)),
|
||||
rule.on_trial_result(runner, t3, result(1, 260)),
|
||||
TrialScheduler.CONTINUE)
|
||||
self.assertEqual(
|
||||
rule.on_trial_result(None, t3, result(2, 260)),
|
||||
rule.on_trial_result(runner, t3, result(2, 260)),
|
||||
TrialScheduler.PAUSE)
|
||||
|
||||
def _test_metrics(self, result_func, metric, mode):
|
||||
|
@@ -145,20 +163,21 @@ class EarlyStoppingSuite(unittest.TestCase):
|
|||
mode=mode)
|
||||
t1 = Trial("PPO") # mean is 450, max 900, t_max=10
|
||||
t2 = Trial("PPO") # mean is 450, max 450, t_max=5
|
||||
runner = mock_trial_runner()
|
||||
for i in range(10):
|
||||
self.assertEqual(
|
||||
rule.on_trial_result(None, t1, result_func(i, i * 100)),
|
||||
rule.on_trial_result(runner, t1, result_func(i, i * 100)),
|
||||
TrialScheduler.CONTINUE)
|
||||
for i in range(5):
|
||||
self.assertEqual(
|
||||
rule.on_trial_result(None, t2, result_func(i, 450)),
|
||||
rule.on_trial_result(runner, t2, result_func(i, 450)),
|
||||
TrialScheduler.CONTINUE)
|
||||
rule.on_trial_complete(None, t1, result_func(10, 1000))
|
||||
rule.on_trial_complete(runner, t1, result_func(10, 1000))
|
||||
self.assertEqual(
|
||||
rule.on_trial_result(None, t2, result_func(5, 450)),
|
||||
rule.on_trial_result(runner, t2, result_func(5, 450)),
|
||||
TrialScheduler.CONTINUE)
|
||||
self.assertEqual(
|
||||
rule.on_trial_result(None, t2, result_func(6, 0)),
|
||||
rule.on_trial_result(runner, t2, result_func(6, 0)),
|
||||
TrialScheduler.CONTINUE)
|
||||
|
||||
def testAlternateMetrics(self):
|
||||
|
|
Loading…
Add table
Reference in a new issue