mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[tune] Move Optuna to ask/tell interface (#14387)
This commit is contained in:
parent
bacbdd297b
commit
47603045f9
2 changed files with 26 additions and 34 deletions
|
@ -133,7 +133,7 @@ class OptunaSearch(Searcher):
|
|||
|
||||
self._space = space
|
||||
|
||||
self._points_to_evaluate = points_to_evaluate
|
||||
self._points_to_evaluate = points_to_evaluate or []
|
||||
|
||||
self._study_name = "optuna" # Fixed study name for in-memory storage
|
||||
self._sampler = sampler or ot.samplers.TPESampler()
|
||||
|
@ -141,9 +141,6 @@ class OptunaSearch(Searcher):
|
|||
"You can only pass an instance of `optuna.samplers.BaseSampler` " \
|
||||
"as a sampler to `OptunaSearcher`."
|
||||
|
||||
self._pruner = ot.pruners.NopPruner()
|
||||
self._storage = ot.storages.InMemoryStorage()
|
||||
|
||||
self._ot_trials = {}
|
||||
self._ot_study = None
|
||||
if self._space:
|
||||
|
@ -154,14 +151,20 @@ class OptunaSearch(Searcher):
|
|||
# If only a mode was passed, use anonymous metric
|
||||
self._metric = DEFAULT_METRIC
|
||||
|
||||
pruner = ot.pruners.NopPruner()
|
||||
storage = ot.storages.InMemoryStorage()
|
||||
|
||||
self._ot_study = ot.study.create_study(
|
||||
storage=self._storage,
|
||||
storage=storage,
|
||||
sampler=self._sampler,
|
||||
pruner=self._pruner,
|
||||
pruner=pruner,
|
||||
study_name=self._study_name,
|
||||
direction="minimize" if mode == "min" else "maximize",
|
||||
load_if_exists=True)
|
||||
|
||||
for point in self._points_to_evaluate:
|
||||
self._ot_study.enqueue_trial(point)
|
||||
|
||||
def set_search_properties(self, metric: Optional[str], mode: Optional[str],
|
||||
config: Dict) -> bool:
|
||||
if self._space:
|
||||
|
@ -189,21 +192,17 @@ class OptunaSearch(Searcher):
|
|||
mode=self._mode))
|
||||
|
||||
if trial_id not in self._ot_trials:
|
||||
ot_trial_id = self._storage.create_new_trial(
|
||||
self._ot_study._study_id)
|
||||
self._ot_trials[trial_id] = ot.trial.Trial(self._ot_study,
|
||||
ot_trial_id)
|
||||
self._ot_trials[trial_id] = self._ot_study.ask()
|
||||
|
||||
ot_trial = self._ot_trials[trial_id]
|
||||
|
||||
if self._points_to_evaluate:
|
||||
params = self._points_to_evaluate.pop(0)
|
||||
else:
|
||||
# getattr will fetch the trial.suggest_ function on Optuna trials
|
||||
params = {
|
||||
args[0] if len(args) > 0 else kwargs["name"]: getattr(
|
||||
ot_trial, fn)(*args, **kwargs)
|
||||
for (fn, args, kwargs) in self._space
|
||||
}
|
||||
# getattr will fetch the trial.suggest_ function on Optuna trials
|
||||
params = {
|
||||
args[0] if len(args) > 0 else kwargs["name"]: getattr(
|
||||
ot_trial, fn)(*args, **kwargs)
|
||||
for (fn, args, kwargs) in self._space
|
||||
}
|
||||
|
||||
return unflatten_dict(params)
|
||||
|
||||
def on_trial_result(self, trial_id: str, result: Dict):
|
||||
|
@ -217,21 +216,15 @@ class OptunaSearch(Searcher):
|
|||
result: Optional[Dict] = None,
|
||||
error: bool = False):
|
||||
ot_trial = self._ot_trials[trial_id]
|
||||
ot_trial_id = ot_trial._trial_id
|
||||
|
||||
val = result.get(self.metric, None)
|
||||
if hasattr(self._storage, "set_trial_value"):
|
||||
# Backwards compatibility with optuna < 2.4.0
|
||||
self._storage.set_trial_value(ot_trial_id, val)
|
||||
else:
|
||||
self._storage.set_trial_values(ot_trial_id, [val])
|
||||
|
||||
self._storage.set_trial_state(ot_trial_id,
|
||||
ot.trial.TrialState.COMPLETE)
|
||||
val = result.get(self.metric, None) if result else None
|
||||
try:
|
||||
self._ot_study.tell(ot_trial, val)
|
||||
except ValueError as exc:
|
||||
logger.warning(exc) # E.g. if NaN was reported
|
||||
|
||||
def save(self, checkpoint_path: str):
|
||||
save_object = (self._storage, self._pruner, self._sampler,
|
||||
self._ot_trials, self._ot_study,
|
||||
save_object = (self._sampler, self._ot_trials, self._ot_study,
|
||||
self._points_to_evaluate)
|
||||
with open(checkpoint_path, "wb") as outputFile:
|
||||
pickle.dump(save_object, outputFile)
|
||||
|
@ -239,8 +232,7 @@ class OptunaSearch(Searcher):
|
|||
def restore(self, checkpoint_path: str):
|
||||
with open(checkpoint_path, "rb") as inputFile:
|
||||
save_object = pickle.load(inputFile)
|
||||
self._storage, self._pruner, self._sampler, \
|
||||
self._ot_trials, self._ot_study, \
|
||||
self._sampler, self._ot_trials, self._ot_study, \
|
||||
self._points_to_evaluate = save_object
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -18,7 +18,7 @@ matplotlib==3.3.4
|
|||
mlflow==1.14.0
|
||||
mxnet==1.7.0.post1
|
||||
nevergrad==0.4.2.post5
|
||||
optuna==2.4.0
|
||||
optuna==2.5.0
|
||||
pytest-remotedata==0.3.2
|
||||
pytorch-lightning-bolts==0.2.5
|
||||
pytorch-lightning==1.0.3
|
||||
|
|
Loading…
Add table
Reference in a new issue