mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[Tune] HEBO concurrency fix after discussion with authors (#14504)
This commit is contained in:
parent
ef944bc5f0
commit
2002cff42e
2 changed files with 29 additions and 30 deletions
|
@ -5,7 +5,6 @@ It also checks that it is usable with a separate scheduler.
|
|||
import time
|
||||
|
||||
from ray import tune
|
||||
from ray.tune.suggest import ConcurrencyLimiter
|
||||
from ray.tune.schedulers import AsyncHyperBandScheduler
|
||||
from ray.tune.suggest.hebo import HEBOSearch
|
||||
|
||||
|
@ -74,23 +73,15 @@ if __name__ == "__main__":
|
|||
]
|
||||
known_rewards = [-189, -1144]
|
||||
|
||||
# setting the n_suggestions parameter to >1 enables
|
||||
# the evolutionary part of HEBO. For best results,
|
||||
# use HEBOSearch with a ConcurrencyLimiter() set up
|
||||
# as below
|
||||
n_suggestions = 8
|
||||
# maximum number of concurrent trials
|
||||
max_concurrent = 8
|
||||
|
||||
algo = HEBOSearch(
|
||||
# space = space, # If you want to set the space
|
||||
points_to_evaluate=previously_run_params,
|
||||
evaluated_rewards=known_rewards,
|
||||
random_state_seed=123, # for reproducibility
|
||||
n_suggestions=n_suggestions,
|
||||
)
|
||||
algo = ConcurrencyLimiter(
|
||||
algo,
|
||||
max_concurrent=n_suggestions,
|
||||
batch=True,
|
||||
max_concurrent=max_concurrent,
|
||||
)
|
||||
|
||||
scheduler = AsyncHyperBandScheduler()
|
||||
|
|
|
@ -47,7 +47,11 @@ class HEBOSearch(Searcher):
|
|||
|
||||
Please note that the first few trials will be random and used
|
||||
to kickstart the search process. In order to achieve good results,
|
||||
we recommend setting the number of trials to at least 15.
|
||||
we recommend setting the number of trials to at least 16.
|
||||
|
||||
Maximum number of concurrent trials is determined by `max_concurrent`
|
||||
argument. Trials will be done in batches of `max_concurrent` trials.
|
||||
It is not recommended to use this Searcher in a `ConcurrencyLimiter`.
|
||||
|
||||
Args:
|
||||
space (dict|hebo.design_space.design_space.DesignSpace):
|
||||
|
@ -73,13 +77,7 @@ class HEBOSearch(Searcher):
|
|||
results. Defaults to None. Please note that setting this to a value
|
||||
will change global random states for `numpy` and `torch`
|
||||
on initalization and loading from checkpoint.
|
||||
n_suggestions (int, 1): If higher than one, suggestions will
|
||||
be made in batches and cached. Higher values may increase
|
||||
convergence speed in certain cases (authors recommend 8).
|
||||
For best results, wrap this searcher in a
|
||||
``ConcurrencyLimiter(max_concurrent=n, batch=True)``
|
||||
where `n == n_suggestions`.
|
||||
Refer to `tune/examples/hebo_example.py`.
|
||||
max_concurrent (int, 8): Number of maximum concurrent trials.
|
||||
**kwargs: The keyword arguments will be passed to `HEBO()``.
|
||||
|
||||
Tune automatically converts search spaces to HEBO's format:
|
||||
|
@ -126,15 +124,15 @@ class HEBOSearch(Searcher):
|
|||
points_to_evaluate: Optional[List[Dict]] = None,
|
||||
evaluated_rewards: Optional[List] = None,
|
||||
random_state_seed: Optional[int] = None,
|
||||
n_suggestions: int = 1,
|
||||
max_concurrent: int = 8,
|
||||
**kwargs):
|
||||
assert hebo is not None, (
|
||||
"HEBO must be installed!. You can install HEBO with"
|
||||
" the command: `pip install HEBO`.")
|
||||
if mode:
|
||||
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."
|
||||
assert isinstance(n_suggestions, int) and n_suggestions >= 1, (
|
||||
"`n_suggestions` must be an integer and at least 1.")
|
||||
assert isinstance(max_concurrent, int) and max_concurrent >= 1, (
|
||||
"`max_concurrent` must be an integer and at least 1.")
|
||||
if random_state_seed is not None:
|
||||
assert isinstance(
|
||||
random_state_seed, int
|
||||
|
@ -164,8 +162,9 @@ class HEBOSearch(Searcher):
|
|||
self._initial_points = []
|
||||
self._live_trial_mapping = {}
|
||||
|
||||
self._n_suggestions = n_suggestions
|
||||
self._max_concurrent = max_concurrent
|
||||
self._suggestions_cache = []
|
||||
self._batch_filled = False
|
||||
|
||||
self._opt = None
|
||||
if space:
|
||||
|
@ -240,17 +239,25 @@ class HEBOSearch(Searcher):
|
|||
metric=self._metric,
|
||||
mode=self._mode))
|
||||
|
||||
if not self._live_trial_mapping:
|
||||
self._batch_filled = False
|
||||
|
||||
if self._initial_points:
|
||||
params = self._initial_points.pop(0)
|
||||
suggestion = pd.DataFrame(params, index=[0])
|
||||
else:
|
||||
if self._batch_filled or len(
|
||||
self._live_trial_mapping) >= self._max_concurrent:
|
||||
return None
|
||||
if not self._suggestions_cache:
|
||||
suggestion = self._opt.suggest(
|
||||
n_suggestions=self._n_suggestions)
|
||||
n_suggestions=self._max_concurrent)
|
||||
self._suggestions_cache = suggestion.to_dict("records")
|
||||
params = self._suggestions_cache.pop(0)
|
||||
suggestion = pd.DataFrame(params, index=[0])
|
||||
self._live_trial_mapping[trial_id] = suggestion
|
||||
if len(self._live_trial_mapping) >= self._max_concurrent:
|
||||
self._batch_filled = True
|
||||
return unflatten_dict(params)
|
||||
|
||||
def on_trial_complete(self,
|
||||
|
@ -282,16 +289,17 @@ class HEBOSearch(Searcher):
|
|||
with open(checkpoint_path, "wb") as f:
|
||||
pickle.dump((self._opt, self._initial_points, numpy_random_state,
|
||||
torch_random_state, self._live_trial_mapping,
|
||||
self._n_suggestions, self._suggestions_cache,
|
||||
self._space, self._hebo_config), f)
|
||||
self._max_concurrent, self._suggestions_cache,
|
||||
self._space, self._hebo_config, self._batch_filled),
|
||||
f)
|
||||
|
||||
def restore(self, checkpoint_path: str):
|
||||
"""Restoring current optimizer state."""
|
||||
with open(checkpoint_path, "rb") as f:
|
||||
(self._opt, self._initial_points, numpy_random_state,
|
||||
torch_random_state, self._live_trial_mapping, self._n_suggestions,
|
||||
self._suggestions_cache, self._space,
|
||||
self._hebo_config) = pickle.load(f)
|
||||
torch_random_state, self._live_trial_mapping,
|
||||
self._max_concurrent, self._suggestions_cache, self._space,
|
||||
self._hebo_config, self._batch_filled) = pickle.load(f)
|
||||
if numpy_random_state is not None:
|
||||
np.random.set_state(numpy_random_state)
|
||||
if torch_random_state is not None:
|
||||
|
|
Loading…
Add table
Reference in a new issue