From 054583ffe691747e51506296196c1d9b41c3b3cc Mon Sep 17 00:00:00 2001 From: waldroje <31750961+waldroje@users.noreply.github.com> Date: Tue, 8 Oct 2019 14:40:41 -0400 Subject: [PATCH] [tune] MedianStopping on result (#5402) * added class median_stopping_result to schedulers and updated __init__ * Dicts flatten and combine schedulers. MedianStoppingRule is now combined with MedianStoppingResult; I think the functionality is essentially the same so there's no need to duplicate. Dict flattening was already taken care of in a separate PR, so I've reverted that. * lint * revert * remove time sharing and simplify state * fix * fixtests * added class median_stopping_result to schedulers and updated __init__ * update property names and types to reflect suggestions by ray developers, merged get_median_result and get_best_result into a single method to eliminate duplicate steps, added resource check on PAUSE condition, modified utility function to use updated properties * updated tests for median_stopping_result in separate file * remove stray characters from previous merge conflict * reformatted and cleaned up dependencies from running code format and linting * added class median_stopping_result to schedulers and updated __init__ * Dicts flatten and combine schedulers. MedianStoppingRule is now combined with MedianStoppingResult; I think the functionality is essentially the same so there's no need to duplicate. Dict flattening was already taken care of in a separate PR, so I've reverted that. * lint * revert * remove time sharing and simplify state * fix * added class median_stopping_result to schedulers and updated __init__ * update property names and types to reflect suggestions by ray developers, merged get_median_result and get_best_result into a single method to eliminate duplicate steps, added resource check on PAUSE condition, modified utility function to use updated properties * updated tests for median_stopping_result in separate file * remove stray characters from previous merge conflict * reformatted and cleaned up dependencies from running code format and linting * update scheduler to coordinate eval interval * modify median_stopping_result to synchronize result evaluation at regular intervals, driven by least common interval * add some logging info to median_result * add new scheduler, SyncMedianStoppingResult, which evaluates and stops trials in a synchronous fashion * Cleanup median_stopping_rule - remove eval_interval - pause trials with insufficient samples if there are other waiting trials - compute score only for trials that have reached result_time * Remove extraneous classes * Fix median stopping rule tests * Added min_time_slice flag to reduce potential checkpointing cost * Only compute mean after grace * Relegate logging to debug mode --- .../tune/schedulers/median_stopping_rule.py | 110 +++++++++++------- python/ray/tune/tests/test_trial_scheduler.py | 85 ++++++++------ 2 files changed, 121 insertions(+), 74 deletions(-) diff --git a/python/ray/tune/schedulers/median_stopping_rule.py b/python/ray/tune/schedulers/median_stopping_rule.py index 3609eee73..99b4c4d80 100644 --- a/python/ray/tune/schedulers/median_stopping_rule.py +++ b/python/ray/tune/schedulers/median_stopping_rule.py @@ -27,13 +27,18 @@ class MedianStoppingRule(FIFOScheduler): mode (str): One of {min, max}. Determines whether objective is minimizing or maximizing the metric attribute. grace_period (float): Only stop trials at least this old in time. + The mean will only be computed from this time onwards. The units + are the same as the attribute named by `time_attr`. + min_samples_required (int): Minimum number of trials to compute median + over. + min_time_slice (float): Each trial runs at least this long before + yielding (assuming it isn't stopped). Note: trials ONLY yield if + there are not enough samples to evaluate performance for the + current result AND there are other trials waiting to run. The units are the same as the attribute named by `time_attr`. - min_samples_required (int): Min samples to compute median over. hard_stop (bool): If False, pauses trials instead of stopping them. When all other trials are complete, paused trials will be resumed and allowed to run FIFO. - verbose (bool): If True, will output the median and best result each - time a trial reports. Defaults to True. """ def __init__(self, @@ -43,10 +48,9 @@ class MedianStoppingRule(FIFOScheduler): mode="max", grace_period=60.0, min_samples_required=3, - hard_stop=True, - verbose=True): + min_time_slice=0, + hard_stop=True): assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!" - if reward_attr is not None: mode = "max" metric = reward_attr @@ -54,21 +58,20 @@ class MedianStoppingRule(FIFOScheduler): "`reward_attr` is deprecated and will be removed in a future " "version of Tune. " "Setting `metric={}` and `mode=max`.".format(reward_attr)) - FIFOScheduler.__init__(self) self._stopped_trials = set() - self._completed_trials = set() - self._results = collections.defaultdict(list) self._grace_period = grace_period self._min_samples_required = min_samples_required + self._min_time_slice = min_time_slice self._metric = metric - if mode == "max": - self._metric_op = 1. - elif mode == "min": - self._metric_op = -1. + assert mode in {"min", "max"}, "`mode` must be 'min' or 'max'." + self._worst = float("-inf") if mode == "max" else float("inf") + self._compare_op = max if mode == "max" else min self._time_attr = time_attr self._hard_stop = hard_stop - self._verbose = verbose + self._trial_state = {} + self._last_pause = collections.defaultdict(lambda: float("-inf")) + self._results = collections.defaultdict(list) def on_trial_result(self, trial_runner, trial, result): """Callback for early stopping. @@ -82,19 +85,38 @@ class MedianStoppingRule(FIFOScheduler): if trial in self._stopped_trials: assert not self._hard_stop - return TrialScheduler.CONTINUE # fall back to FIFO + # Fall back to FIFO + return TrialScheduler.CONTINUE time = result[self._time_attr] self._results[trial].append(result) - median_result = self._get_median_result(time) + + if time < self._grace_period: + return TrialScheduler.CONTINUE + + trials = self._trials_beyond_time(time) + trials.remove(trial) + + if len(trials) < self._min_samples_required: + action = self._on_insufficient_samples(trial_runner, trial, time) + if action == TrialScheduler.PAUSE: + self._last_pause[trial] = time + action_str = "Yielding time to other trials." + else: + action_str = "Continuing anyways." + logger.debug( + "MedianStoppingRule: insufficient samples={} to evaluate " + "trial {} at t={}. {}".format( + len(trials), trial.trial_id, time, action_str)) + return action + + median_result = self._median_result(trials, time) best_result = self._best_result(trial) - if self._verbose: - logger.info("Trial {} best res={} vs median res={} at t={}".format( - trial, best_result, median_result, time)) - if best_result < median_result and time > self._grace_period: - if self._verbose: - logger.info("MedianStoppingRule: " - "early stopping {}".format(trial)) + logger.debug("Trial {} best res={} vs median res={} at t={}".format( + trial, best_result, median_result, time)) + + if self._compare_op(median_result, best_result) != best_result: + logger.debug("MedianStoppingRule: early stopping {}".format(trial)) self._stopped_trials.add(trial) if self._hard_stop: return TrialScheduler.STOP @@ -105,33 +127,39 @@ class MedianStoppingRule(FIFOScheduler): def on_trial_complete(self, trial_runner, trial, result): self._results[trial].append(result) - self._completed_trials.add(trial) - - def on_trial_remove(self, trial_runner, trial): - """Marks trial as completed if it is paused and has previously ran.""" - if trial.status is Trial.PAUSED and trial in self._results: - self._completed_trials.add(trial) def debug_string(self): return "Using MedianStoppingRule: num_stopped={}.".format( len(self._stopped_trials)) - def _get_median_result(self, time): - scores = [] - for trial in self._completed_trials: - scores.append(self._running_result(trial, time)) - if len(scores) >= self._min_samples_required: - return np.median(scores) - else: - return float("-inf") + def _on_insufficient_samples(self, trial_runner, trial, time): + pause = time - self._last_pause[trial] > self._min_time_slice + pause = pause and [ + t for t in trial_runner.get_trials() + if t.status in (Trial.PENDING, Trial.PAUSED) + ] + return TrialScheduler.PAUSE if pause else TrialScheduler.CONTINUE - def _running_result(self, trial, t_max=float("inf")): + def _trials_beyond_time(self, time): + trials = [ + trial for trial in self._results + if self._results[trial][-1][self._time_attr] >= time + ] + return trials + + def _median_result(self, trials, time): + return np.median([self._running_mean(trial, time) for trial in trials]) + + def _running_mean(self, trial, time): results = self._results[trial] # TODO(ekl) we could do interpolation to be more precise, but for now # assume len(results) is large and the time diffs are roughly equal - return self._metric_op * np.mean( - [r[self._metric] for r in results if r[self._time_attr] <= t_max]) + scoped_results = [ + r for r in results + if self._grace_period <= r[self._time_attr] <= time + ] + return np.mean([r[self._metric] for r in scoped_results]) def _best_result(self, trial): results = self._results[trial] - return max(self._metric_op * r[self._metric] for r in results) + return self._compare_op([r[self._metric] for r in results]) diff --git a/python/ray/tune/tests/test_trial_scheduler.py b/python/ray/tune/tests/test_trial_scheduler.py index 4d9eb8d07..e12cb21cf 100644 --- a/python/ray/tune/tests/test_trial_scheduler.py +++ b/python/ray/tune/tests/test_trial_scheduler.py @@ -36,6 +36,12 @@ def result(t, rew): time_total_s=t, episode_reward_mean=rew, training_iteration=int(t)) +def mock_trial_runner(trials=None): + trial_runner = MagicMock() + trial_runner.get_trials.return_value = trials or [] + return trial_runner + + class EarlyStoppingSuite(unittest.TestCase): def setUp(self): ray.init() @@ -47,93 +53,105 @@ class EarlyStoppingSuite(unittest.TestCase): def basicSetup(self, rule): t1 = Trial("PPO") # mean is 450, max 900, t_max=10 t2 = Trial("PPO") # mean is 450, max 450, t_max=5 + runner = mock_trial_runner() for i in range(10): + r1 = result(i, i * 100) + print("basicSetup:", i) self.assertEqual( - rule.on_trial_result(None, t1, result(i, i * 100)), - TrialScheduler.CONTINUE) + rule.on_trial_result(runner, t1, r1), TrialScheduler.CONTINUE) for i in range(5): + r2 = result(i, 450) self.assertEqual( - rule.on_trial_result(None, t2, result(i, 450)), - TrialScheduler.CONTINUE) + rule.on_trial_result(runner, t2, r2), TrialScheduler.CONTINUE) return t1, t2 def testMedianStoppingConstantPerf(self): rule = MedianStoppingRule(grace_period=0, min_samples_required=1) t1, t2 = self.basicSetup(rule) - rule.on_trial_complete(None, t1, result(10, 1000)) + runner = mock_trial_runner() + rule.on_trial_complete(runner, t1, result(10, 1000)) self.assertEqual( - rule.on_trial_result(None, t2, result(5, 450)), + rule.on_trial_result(runner, t2, result(5, 450)), TrialScheduler.CONTINUE) self.assertEqual( - rule.on_trial_result(None, t2, result(6, 0)), + rule.on_trial_result(runner, t2, result(6, 0)), TrialScheduler.CONTINUE) self.assertEqual( - rule.on_trial_result(None, t2, result(10, 450)), + rule.on_trial_result(runner, t2, result(10, 450)), TrialScheduler.STOP) def testMedianStoppingOnCompleteOnly(self): rule = MedianStoppingRule(grace_period=0, min_samples_required=1) t1, t2 = self.basicSetup(rule) + runner = mock_trial_runner() self.assertEqual( - rule.on_trial_result(None, t2, result(100, 0)), + rule.on_trial_result(runner, t2, result(100, 0)), TrialScheduler.CONTINUE) - rule.on_trial_complete(None, t1, result(10, 1000)) + rule.on_trial_complete(runner, t1, result(101, 1000)) self.assertEqual( - rule.on_trial_result(None, t2, result(101, 0)), + rule.on_trial_result(runner, t2, result(101, 0)), TrialScheduler.STOP) def testMedianStoppingGracePeriod(self): rule = MedianStoppingRule(grace_period=2.5, min_samples_required=1) t1, t2 = self.basicSetup(rule) - rule.on_trial_complete(None, t1, result(10, 1000)) - rule.on_trial_complete(None, t2, result(10, 1000)) + runner = mock_trial_runner() + rule.on_trial_complete(runner, t1, result(10, 1000)) + rule.on_trial_complete(runner, t2, result(10, 1000)) t3 = Trial("PPO") self.assertEqual( - rule.on_trial_result(None, t3, result(1, 10)), + rule.on_trial_result(runner, t3, result(1, 10)), TrialScheduler.CONTINUE) self.assertEqual( - rule.on_trial_result(None, t3, result(2, 10)), + rule.on_trial_result(runner, t3, result(2, 10)), TrialScheduler.CONTINUE) self.assertEqual( - rule.on_trial_result(None, t3, result(3, 10)), TrialScheduler.STOP) + rule.on_trial_result(runner, t3, result(3, 10)), + TrialScheduler.STOP) def testMedianStoppingMinSamples(self): rule = MedianStoppingRule(grace_period=0, min_samples_required=2) t1, t2 = self.basicSetup(rule) - rule.on_trial_complete(None, t1, result(10, 1000)) + runner = mock_trial_runner() + rule.on_trial_complete(runner, t1, result(10, 1000)) t3 = Trial("PPO") + # Insufficient samples to evaluate t3 self.assertEqual( - rule.on_trial_result(None, t3, result(3, 10)), + rule.on_trial_result(runner, t3, result(5, 10)), TrialScheduler.CONTINUE) - rule.on_trial_complete(None, t2, result(10, 1000)) + rule.on_trial_complete(runner, t2, result(5, 1000)) + # Sufficient samples to evaluate t3 self.assertEqual( - rule.on_trial_result(None, t3, result(3, 10)), TrialScheduler.STOP) + rule.on_trial_result(runner, t3, result(5, 10)), + TrialScheduler.STOP) def testMedianStoppingUsesMedian(self): rule = MedianStoppingRule(grace_period=0, min_samples_required=1) t1, t2 = self.basicSetup(rule) - rule.on_trial_complete(None, t1, result(10, 1000)) - rule.on_trial_complete(None, t2, result(10, 1000)) + runner = mock_trial_runner() + rule.on_trial_complete(runner, t1, result(10, 1000)) + rule.on_trial_complete(runner, t2, result(10, 1000)) t3 = Trial("PPO") self.assertEqual( - rule.on_trial_result(None, t3, result(1, 260)), + rule.on_trial_result(runner, t3, result(1, 260)), TrialScheduler.CONTINUE) self.assertEqual( - rule.on_trial_result(None, t3, result(2, 260)), + rule.on_trial_result(runner, t3, result(2, 260)), TrialScheduler.STOP) def testMedianStoppingSoftStop(self): rule = MedianStoppingRule( grace_period=0, min_samples_required=1, hard_stop=False) t1, t2 = self.basicSetup(rule) - rule.on_trial_complete(None, t1, result(10, 1000)) - rule.on_trial_complete(None, t2, result(10, 1000)) + runner = mock_trial_runner() + rule.on_trial_complete(runner, t1, result(10, 1000)) + rule.on_trial_complete(runner, t2, result(10, 1000)) t3 = Trial("PPO") self.assertEqual( - rule.on_trial_result(None, t3, result(1, 260)), + rule.on_trial_result(runner, t3, result(1, 260)), TrialScheduler.CONTINUE) self.assertEqual( - rule.on_trial_result(None, t3, result(2, 260)), + rule.on_trial_result(runner, t3, result(2, 260)), TrialScheduler.PAUSE) def _test_metrics(self, result_func, metric, mode): @@ -145,20 +163,21 @@ class EarlyStoppingSuite(unittest.TestCase): mode=mode) t1 = Trial("PPO") # mean is 450, max 900, t_max=10 t2 = Trial("PPO") # mean is 450, max 450, t_max=5 + runner = mock_trial_runner() for i in range(10): self.assertEqual( - rule.on_trial_result(None, t1, result_func(i, i * 100)), + rule.on_trial_result(runner, t1, result_func(i, i * 100)), TrialScheduler.CONTINUE) for i in range(5): self.assertEqual( - rule.on_trial_result(None, t2, result_func(i, 450)), + rule.on_trial_result(runner, t2, result_func(i, 450)), TrialScheduler.CONTINUE) - rule.on_trial_complete(None, t1, result_func(10, 1000)) + rule.on_trial_complete(runner, t1, result_func(10, 1000)) self.assertEqual( - rule.on_trial_result(None, t2, result_func(5, 450)), + rule.on_trial_result(runner, t2, result_func(5, 450)), TrialScheduler.CONTINUE) self.assertEqual( - rule.on_trial_result(None, t2, result_func(6, 0)), + rule.on_trial_result(runner, t2, result_func(6, 0)), TrialScheduler.CONTINUE) def testAlternateMetrics(self):