Mirror of https://github.com/vale981/ray, synced 2025-03-05 18:11:42 -05:00
[Tune] Restore old max concurrent logic in BOHB (#26529)
As discussed on Ray Slack (https://ray-distributed.slack.com/archives/CNECXMW22/p1657051287814569), the changes introduced in #18770 and #20822 caused the concurrency-limiting logic in BOHB to work incorrectly. This PR restores the old logic while making use of the set_max_concurrency API (as e.g. HEBO does), maintaining backwards compatibility. Note that the old logic this PR reintroduces is essentially a hack and should be refactored in the future; the PR is intended to rapidly fix a bug that was making search performance suboptimal.

Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
Co-authored-by: Kai Fricke <krfricke@users.noreply.github.com>
Parent: c54916bc0f
Commit: c168c09281
5 changed files with 70 additions and 0 deletions
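Before the diff itself, a toy illustration of the failure mode the commit message describes may help: a limiter that only counts RUNNING trials never frees a slot when BOHB pauses a trial at a bracket boundary, so no new configurations are suggested until the first batch of trials has run to completion. The class names below are illustrative stand-ins, not Ray code.

# Toy illustration (not Ray code): why counting only RUNNING trials
# starves BOHB once it starts pausing trials at bracket boundaries.

class NaiveLimiter:
    """Discards new suggestions while live trials >= limit."""
    def __init__(self, max_concurrent):
        self.max_concurrent = max_concurrent
        self.live = set()          # trials the limiter believes are live

    def suggest(self, trial_id):
        if len(self.live) >= self.max_concurrent:
            return None            # nothing new until a trial finishes
        self.live.add(trial_id)
        return {"config": trial_id}

class PauseAwareLimiter(NaiveLimiter):
    """Restored behaviour: paused trials give their slot back."""
    def __init__(self, max_concurrent):
        super().__init__(max_concurrent)
        self.paused = set()

    def on_pause(self, trial_id):
        self.paused.add(trial_id)
        self.live.discard(trial_id)

    def on_unpause(self, trial_id):
        self.paused.discard(trial_id)
        self.live.add(trial_id)

naive, aware = NaiveLimiter(2), PauseAwareLimiter(2)
for limiter in (naive, aware):
    limiter.suggest("t1"), limiter.suggest("t2")
# BOHB pauses t1 at a bracket boundary; only the pause-aware limiter
# releases the concurrency slot and can suggest a new trial.
aware.on_pause("t1")
print(naive.suggest("t3"))   # None -> stuck until t1/t2 fully finish
print(aware.suggest("t3"))   # {'config': 't3'} -> bracket keeps filling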
@@ -97,10 +97,27 @@ class HyperBandForBOHB(HyperBandScheduler):
            if not bracket.filled() or any(
                status != Trial.PAUSED for t, status in statuses if t is not trial
            ):
                # BOHB Specific. This hack existed in old Ray versions
                # and was removed, but it needs to be brought back
                # as otherwise the BOHB doesn't behave as intended.
                # The default concurrency limiter works by discarding
                # new suggestions if there are more running trials
                # than the limit. That doesn't take into account paused
                # trials. With BOHB, this leads to N trials finishing
                # completely and then another N trials starting,
                # instead of trials being paused and resumed in brackets
                # as intended.
                # There should be a better API for this.
                # TODO(team-ml): Refactor alongside HyperBandForBOHB
                trial_runner._search_alg.searcher.on_pause(trial.trial_id)
                return TrialScheduler.PAUSE
            action = self._process_bracket(trial_runner, bracket)
        return action

    def _unpause_trial(self, trial_runner: "trial_runner.TrialRunner", trial: Trial):
        # Hack. See comment in on_trial_result
        trial_runner._search_alg.searcher.on_unpause(trial.trial_id)

    def choose_trial_to_run(
        self, trial_runner: "trial_runner.TrialRunner", allow_recurse: bool = True
    ) -> Optional[Trial]:
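In the hunk above, the scheduler notifies the searcher through `trial_runner._search_alg.searcher`; in a typical BOHB setup that attribute is a `ConcurrencyLimiter` wrapping the `TuneBOHB` instance, which is why the limiter also gains `on_pause`/`on_unpause` pass-throughs later in this diff. A toy sketch of that notification chain (illustrative names, not Ray classes):

# Toy stand-ins for the pause-notification chain:
# scheduler -> _search_alg.searcher (limiter) -> wrapped searcher.

class ToyBOHBSearcher:
    def __init__(self):
        self.running, self.paused = set(), set()

    def on_pause(self, trial_id):
        self.paused.add(trial_id)
        self.running.discard(trial_id)

    def on_unpause(self, trial_id):
        self.paused.discard(trial_id)
        self.running.add(trial_id)

class ToyLimiter:
    def __init__(self, searcher):
        self.searcher = searcher

    # Pass-throughs mirroring the ConcurrencyLimiter change further down.
    def on_pause(self, trial_id):
        self.searcher.on_pause(trial_id)

    def on_unpause(self, trial_id):
        self.searcher.on_unpause(trial_id)

searcher = ToyLimiter(ToyBOHBSearcher())
searcher.searcher.running.add("trial_1")
searcher.on_pause("trial_1")          # what the scheduler hack triggers
assert "trial_1" in searcher.searcher.paused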
@@ -272,11 +272,16 @@ class HyperBandScheduler(FIFOScheduler):
                    )
                if bracket.continue_trial(t):
                    if t.status == Trial.PAUSED:
                        self._unpause_trial(trial_runner, t)
                        t.status = Trial.PENDING
                    elif t.status == Trial.RUNNING:
                        action = TrialScheduler.CONTINUE
        return action

    def _unpause_trial(self, trial_runner: "trial_runner.TrialRunner", trial: Trial):
        """No-op by default."""
        return

    def on_trial_remove(self, trial_runner: "trial_runner.TrialRunner", trial: Trial):
        """Notification when trial terminates.
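The new `_unpause_trial` on the base `HyperBandScheduler` is deliberately a no-op hook: plain HyperBand has nothing to tell the searcher when a good paused trial is promoted back to PENDING, while `HyperBandForBOHB` (first hunk above) overrides it to forward the event. A minimal sketch of that hook pattern, with toy names rather than Ray's:

# Toy sketch of the no-op hook pattern (names illustrative, not Ray's).
class ToyHyperBand:
    def _unpause_trial(self, trial_runner, trial):
        """No-op by default; subclasses may notify the searcher."""
        return

class ToyHyperBandForBOHB(ToyHyperBand):
    def _unpause_trial(self, trial_runner, trial):
        # Mirrors the HyperBandForBOHB override shown earlier.
        trial_runner._search_alg.searcher.on_unpause(trial.trial_id)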
@@ -74,6 +74,10 @@ class TuneBOHB(Searcher):
        seed: Optional random seed to initialize the random number
            generator. Setting this should lead to identical initial
            configurations at each run.
        max_concurrent: Number of maximum concurrent trials.
            If this Searcher is used in a ``ConcurrencyLimiter``, the
            ``max_concurrent`` value passed to it will override the
            value passed here. Set to <= 0 for no limit on concurrency.

    Tune automatically converts search spaces to TuneBOHB's format:
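As a usage-level illustration of the override rule documented above: the snippet below assumes the `ray.tune.suggest` module layout current at the time of this PR (later Ray releases moved these classes under `ray.tune.search`) and a user-defined `trainable`.

from ray.tune.schedulers import HyperBandForBOHB
from ray.tune.suggest import ConcurrencyLimiter
from ray.tune.suggest.bohb import TuneBOHB

# TuneBOHB's own limit would be 4 concurrent trials ...
algo = TuneBOHB(metric="loss", mode="min", max_concurrent=4)
# ... but the ConcurrencyLimiter's value overrides it, so at most
# 2 trials run at a time.
algo = ConcurrencyLimiter(algo, max_concurrent=2)
scheduler = HyperBandForBOHB(
    time_attr="training_iteration", metric="loss", mode="min", max_t=81
)
# tune.run(trainable, search_alg=algo, scheduler=scheduler, num_samples=8)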
@@ -128,6 +132,7 @@ class TuneBOHB(Searcher):
        mode: Optional[str] = None,
        points_to_evaluate: Optional[List[Dict]] = None,
        seed: Optional[int] = None,
        max_concurrent: int = 0,
    ):
        assert (
            BOHB is not None
@@ -152,6 +157,10 @@ class TuneBOHB(Searcher):
        self._space = space
        self._seed = seed

        self.running = set()
        self.paused = set()

        self._max_concurrent = max_concurrent
        self._points_to_evaluate = points_to_evaluate

        super(TuneBOHB, self).__init__(
@@ -162,6 +171,10 @@ class TuneBOHB(Searcher):
        if self._space:
            self._setup_bohb()

    def set_max_concurrency(self, max_concurrent: int) -> bool:
        self._max_concurrent = max_concurrent
        return True

    def _setup_bohb(self):
        from hpbandster.optimizers.config_generators.bohb import BOHB
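`set_max_concurrency` is the hand-off point referenced in the commit message: when the searcher is wrapped in a `ConcurrencyLimiter`, the limiter can push its own `max_concurrent` value down into the searcher and let the searcher enforce it (the same pattern HEBO uses). A toy sketch of that hand-off, with illustrative names only:

# Toy sketch (illustrative names, not Ray internals): a wrapper handing
# its concurrency limit down to a searcher that enforces it itself.
class ToySearcher:
    def __init__(self, max_concurrent=0):
        self._max_concurrent = max_concurrent

    def set_max_concurrency(self, max_concurrent):
        self._max_concurrent = max_concurrent
        return True  # "I will enforce the limit myself"

class ToyWrapper:
    def __init__(self, searcher, max_concurrent):
        self.searcher = searcher
        # If the searcher accepts the limit, the wrapper does not need to
        # apply its own (pause-unaware) gating on top of it.
        self._delegated = searcher.set_max_concurrency(max_concurrent)

wrapper = ToyWrapper(ToySearcher(max_concurrent=4), max_concurrent=2)
assert wrapper.searcher._max_concurrent == 2   # wrapper value wins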
@@ -177,6 +190,9 @@ class TuneBOHB(Searcher):
        if self._seed is not None:
            self._space.seed(self._seed)

        self.running = set()
        self.paused = set()

        bohb_config = self._bohb_config or {}
        self.bohber = BOHB(self._space, **bohb_config)
@@ -211,15 +227,24 @@ class TuneBOHB(Searcher):
                )
            )

        max_concurrent = (
            self._max_concurrent if self._max_concurrent > 0 else float("inf")
        )
        if len(self.running) >= max_concurrent:
            return None

        if self._points_to_evaluate:
            config = self._points_to_evaluate.pop(0)
        else:
            # This parameter is not used in hpbandster implementation.
            config, _ = self.bohber.get_config(None)
        self.trial_to_params[trial_id] = copy.deepcopy(config)
        self.running.add(trial_id)
        return unflatten_list_dict(config)

    def on_trial_result(self, trial_id: str, result: Dict):
        if trial_id not in self.paused:
            self.running.add(trial_id)
        if "hyperband_info" not in result:
            logger.warning(
                "BOHB Info not detected in result. Are you using "
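The gate added to `suggest` above only counts trials in `self.running`, so pausing a trial (which moves it to `self.paused`) immediately frees a slot; `max_concurrent <= 0` maps to `float("inf")`, i.e. no limit. A small standalone trace of that arithmetic (plain Python, no Ray imports):

# Standalone trace of the concurrency gate in suggest() above.
running, paused = set(), set()
max_concurrent = 2  # as if TuneBOHB(max_concurrent=2)
limit = max_concurrent if max_concurrent > 0 else float("inf")

def can_suggest():
    return len(running) < limit

running.update({"t1", "t2"})
print(can_suggest())          # False: 2 running >= limit of 2

# The scheduler pauses t1 at a bracket boundary -> on_pause("t1")
paused.add("t1"); running.discard("t1")
print(can_suggest())          # True: the paused trial no longer counts

# With max_concurrent=0 the limit becomes float("inf"), never blocking.
print(len(running) < (0 if 0 > 0 else float("inf")))  # True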
@@ -233,6 +258,8 @@ class TuneBOHB(Searcher):
        self, trial_id: str, result: Optional[Dict] = None, error: bool = False
    ):
        del self.trial_to_params[trial_id]
        self.paused.discard(trial_id)
        self.running.discard(trial_id)

    def to_wrapper(self, trial_id: str, result: Dict) -> _BOHBJobWrapper:
        return _BOHBJobWrapper(
@@ -241,6 +268,16 @@ class TuneBOHB(Searcher):
            self.trial_to_params[trial_id],
        )

    # BOHB Specific.
    # TODO(team-ml): Refactor alongside HyperBandForBOHB
    def on_pause(self, trial_id: str):
        self.paused.add(trial_id)
        self.running.remove(trial_id)

    def on_unpause(self, trial_id: str):
        self.paused.remove(trial_id)
        self.running.add(trial_id)

    @staticmethod
    def convert_search_space(spec: Dict) -> "ConfigSpace.ConfigurationSpace":
        resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)
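A detail worth noting in the bookkeeping above: a trial id is meant to live in at most one of `running` and `paused` at a time, and `on_pause`/`on_unpause` use `set.remove` (which raises `KeyError` on an unknown id), whereas `on_trial_complete` uses the forgiving `set.discard`. A quick illustration of that difference:

# set.remove vs set.discard, as used in the bookkeeping above.
running, paused = {"t1"}, set()

paused.add("t1"); running.remove("t1")      # on_pause: t1 must be running
paused.remove("t1"); running.add("t1")      # on_unpause: t1 must be paused

running.discard("t1")                        # on_trial_complete: safe even
paused.discard("t1")                         # if t1 is in neither set
try:
    running.remove("t1")                     # remove() on a missing id ...
except KeyError:
    print("remove() raises; discard() does not")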
@@ -161,3 +161,11 @@ class ConcurrencyLimiter(Searcher):

    def restore(self, checkpoint_path: str):
        self.searcher.restore(checkpoint_path)

    # BOHB Specific.
    # TODO(team-ml): Refactor alongside HyperBandForBOHB
    def on_pause(self, trial_id: str):
        self.searcher.on_pause(trial_id)

    def on_unpause(self, trial_id: str):
        self.searcher.on_unpause(trial_id)
@@ -733,6 +733,8 @@ class BOHBSuite(unittest.TestCase):
        decision = sched.on_trial_result(runner, trials[-1], spy_result)
        self.assertEqual(decision, TrialScheduler.STOP)
        sched.choose_trial_to_run(runner)
        self.assertEqual(runner._search_alg.searcher.on_pause.call_count, 2)
        self.assertEqual(runner._search_alg.searcher.on_unpause.call_count, 1)
        self.assertTrue("hyperband_info" in spy_result)
        self.assertEqual(spy_result["hyperband_info"]["budget"], 1)
@@ -759,6 +761,7 @@ class BOHBSuite(unittest.TestCase):
        decision = sched.on_trial_result(runner, trials[-1], spy_result)
        self.assertEqual(decision, TrialScheduler.CONTINUE)
        sched.choose_trial_to_run(runner)
        self.assertEqual(runner._search_alg.searcher.on_pause.call_count, 2)
        self.assertTrue("hyperband_info" in spy_result)
        self.assertEqual(spy_result["hyperband_info"]["budget"], 1)