[Tune] Restore old max concurrent logic in BOHB (#26529)

As discussed on Ray Slack (https://ray-distributed.slack.com/archives/CNECXMW22/p1657051287814569), the changes introduced in #18770 and #20822 have caused the concurrency limiting logic in BOHB to work incorrectly. This PR restores the old logic while making use of the set_max_concurrency API (as e.g. HEBO does), maintaining backwards compatibility.

It should be noted that the old logic this PR reintroduces is essentially a hack and should be refactored in the future. This PR is intended to rapidly fix a bug causing search performance to be suboptimal.

Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>

Co-authored-by: Kai Fricke <krfricke@users.noreply.github.com>
This commit is contained in:
Antoni Baum 2022-07-14 16:40:51 +02:00 committed by GitHub
parent c54916bc0f
commit c168c09281
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 70 additions and 0 deletions

View file

@ -97,10 +97,27 @@ class HyperBandForBOHB(HyperBandScheduler):
if not bracket.filled() or any(
status != Trial.PAUSED for t, status in statuses if t is not trial
):
# BOHB Specific. This hack existed in old Ray versions
# and was removed, but it needs to be brought back
# as otherwise the BOHB doesn't behave as intended.
# The default concurrency limiter works by discarding
# new suggestions if there are more running trials
# than the limit. That doesn't take into account paused
# trials. With BOHB, this leads to N trials finishing
# completely and then another N trials starting,
# instead of trials being paused and resumed in brackets
# as intended.
# There should be a better API for this.
# TODO(team-ml): Refactor alongside HyperBandForBOHB
trial_runner._search_alg.searcher.on_pause(trial.trial_id)
return TrialScheduler.PAUSE
action = self._process_bracket(trial_runner, bracket)
return action
def _unpause_trial(self, trial_runner: "trial_runner.TrialRunner", trial: Trial):
    """Notify the searcher that ``trial`` is no longer paused.

    Hack: this is the counterpart of the ``on_pause`` notification sent
    from ``on_trial_result``; see the comment there for the rationale.

    Args:
        trial_runner: Runner whose ``_search_alg.searcher`` is notified.
        trial: Trial being unpaused; only its ``trial_id`` is forwarded.
    """
    searcher = trial_runner._search_alg.searcher
    searcher.on_unpause(trial.trial_id)
def choose_trial_to_run(
self, trial_runner: "trial_runner.TrialRunner", allow_recurse: bool = True
) -> Optional[Trial]:

View file

@ -272,11 +272,16 @@ class HyperBandScheduler(FIFOScheduler):
)
if bracket.continue_trial(t):
if t.status == Trial.PAUSED:
self._unpause_trial(trial_runner, t)
t.status = Trial.PENDING
elif t.status == Trial.RUNNING:
action = TrialScheduler.CONTINUE
return action
def _unpause_trial(self, trial_runner: "trial_runner.TrialRunner", trial: Trial):
    """Hook called right before a paused trial is set back to pending.

    Does nothing by default; subclasses may override it to react to a
    trial being unpaused.

    Args:
        trial_runner: The trial runner driving the experiment (unused here).
        trial: The trial about to be unpaused (unused here).
    """
    return None
def on_trial_remove(self, trial_runner: "trial_runner.TrialRunner", trial: Trial):
"""Notification when trial terminates.

View file

@ -74,6 +74,10 @@ class TuneBOHB(Searcher):
seed: Optional random seed to initialize the random number
generator. Setting this should lead to identical initial
configurations at each run.
max_concurrent: Maximum number of concurrent trials.
If this Searcher is used in a ``ConcurrencyLimiter``, the
``max_concurrent`` value passed to it will override the
value passed here. Set to <= 0 for no limit on concurrency.
Tune automatically converts search spaces to TuneBOHB's format:
@ -128,6 +132,7 @@ class TuneBOHB(Searcher):
mode: Optional[str] = None,
points_to_evaluate: Optional[List[Dict]] = None,
seed: Optional[int] = None,
max_concurrent: int = 0,
):
assert (
BOHB is not None
@ -152,6 +157,10 @@ class TuneBOHB(Searcher):
self._space = space
self._seed = seed
self.running = set()
self.paused = set()
self._max_concurrent = max_concurrent
self._points_to_evaluate = points_to_evaluate
super(TuneBOHB, self).__init__(
@ -162,6 +171,10 @@ class TuneBOHB(Searcher):
if self._space:
self._setup_bohb()
def set_max_concurrency(self, max_concurrent: int) -> bool:
    """Store the concurrency limit that this searcher enforces itself.

    Args:
        max_concurrent: New limit on concurrently running trials. The
            value is stored as-is in ``self._max_concurrent``.

    Returns:
        Always ``True``.
    """
    limit = max_concurrent
    self._max_concurrent = limit
    return True
def _setup_bohb(self):
from hpbandster.optimizers.config_generators.bohb import BOHB
@ -177,6 +190,9 @@ class TuneBOHB(Searcher):
if self._seed is not None:
self._space.seed(self._seed)
self.running = set()
self.paused = set()
bohb_config = self._bohb_config or {}
self.bohber = BOHB(self._space, **bohb_config)
@ -211,15 +227,24 @@ class TuneBOHB(Searcher):
)
)
max_concurrent = (
self._max_concurrent if self._max_concurrent > 0 else float("inf")
)
if len(self.running) >= max_concurrent:
return None
if self._points_to_evaluate:
config = self._points_to_evaluate.pop(0)
else:
# This parameter is not used in hpbandster implementation.
config, _ = self.bohber.get_config(None)
self.trial_to_params[trial_id] = copy.deepcopy(config)
self.running.add(trial_id)
return unflatten_list_dict(config)
def on_trial_result(self, trial_id: str, result: Dict):
if trial_id not in self.paused:
self.running.add(trial_id)
if "hyperband_info" not in result:
logger.warning(
"BOHB Info not detected in result. Are you using "
@ -233,6 +258,8 @@ class TuneBOHB(Searcher):
self, trial_id: str, result: Optional[Dict] = None, error: bool = False
):
del self.trial_to_params[trial_id]
self.paused.discard(trial_id)
self.running.discard(trial_id)
def to_wrapper(self, trial_id: str, result: Dict) -> _BOHBJobWrapper:
return _BOHBJobWrapper(
@ -241,6 +268,16 @@ class TuneBOHB(Searcher):
self.trial_to_params[trial_id],
)
# BOHB Specific.
# TODO(team-ml): Refactor alongside HyperBandForBOHB
def on_pause(self, trial_id: str):
    """Mark a trial as paused so it stops counting as running.

    Only the size of ``self.running`` is compared against the concurrency
    limit, so moving a paused trial out of that set frees a slot for a new
    suggestion.

    Args:
        trial_id: ID of the trial that has just been paused.
    """
    self.paused.add(trial_id)
    # discard() instead of remove(): the trial may not be tracked as running
    # (e.g. the sets were reset, or it was already paused). A bookkeeping
    # notification should not crash the experiment with a KeyError.
    self.running.discard(trial_id)
def on_unpause(self, trial_id: str):
    """Mark a previously paused trial as running again.

    Args:
        trial_id: ID of the trial that is being resumed.
    """
    # discard() instead of remove(): the trial may not be tracked as paused
    # (e.g. the sets were reset since it was paused). The end state is the
    # same either way: not paused, and counted as running.
    self.paused.discard(trial_id)
    self.running.add(trial_id)
@staticmethod
def convert_search_space(spec: Dict) -> "ConfigSpace.ConfigurationSpace":
resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)

View file

@ -161,3 +161,11 @@ class ConcurrencyLimiter(Searcher):
def restore(self, checkpoint_path: str):
    """Delegate restoring from ``checkpoint_path`` to the wrapped searcher.

    Args:
        checkpoint_path: Path forwarded unchanged to ``self.searcher.restore``.
    """
    wrapped = self.searcher
    wrapped.restore(checkpoint_path)
# BOHB Specific.
# TODO(team-ml): Refactor alongside HyperBandForBOHB
def on_pause(self, trial_id: str):
    """Forward a pause notification to the wrapped searcher.

    Args:
        trial_id: ID of the paused trial, passed through unchanged.
    """
    wrapped = self.searcher
    wrapped.on_pause(trial_id)
def on_unpause(self, trial_id: str):
    """Forward an unpause notification to the wrapped searcher.

    Args:
        trial_id: ID of the resumed trial, passed through unchanged.
    """
    wrapped = self.searcher
    wrapped.on_unpause(trial_id)

View file

@ -733,6 +733,8 @@ class BOHBSuite(unittest.TestCase):
decision = sched.on_trial_result(runner, trials[-1], spy_result)
self.assertEqual(decision, TrialScheduler.STOP)
sched.choose_trial_to_run(runner)
self.assertEqual(runner._search_alg.searcher.on_pause.call_count, 2)
self.assertEqual(runner._search_alg.searcher.on_unpause.call_count, 1)
self.assertTrue("hyperband_info" in spy_result)
self.assertEqual(spy_result["hyperband_info"]["budget"], 1)
@ -759,6 +761,7 @@ class BOHBSuite(unittest.TestCase):
decision = sched.on_trial_result(runner, trials[-1], spy_result)
self.assertEqual(decision, TrialScheduler.CONTINUE)
sched.choose_trial_to_run(runner)
self.assertEqual(runner._search_alg.searcher.on_pause.call_count, 2)
self.assertTrue("hyperband_info" in spy_result)
self.assertEqual(spy_result["hyperband_info"]["budget"], 1)