[tune] Move Optuna to ask/tell interface (#14387)

This commit is contained in:
Kai Fricke 2021-03-03 00:35:11 +01:00 committed by GitHub
parent bacbdd297b
commit 47603045f9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 26 additions and 34 deletions

View file

@ -133,7 +133,7 @@ class OptunaSearch(Searcher):
self._space = space
self._points_to_evaluate = points_to_evaluate
self._points_to_evaluate = points_to_evaluate or []
self._study_name = "optuna" # Fixed study name for in-memory storage
self._sampler = sampler or ot.samplers.TPESampler()
@ -141,9 +141,6 @@ class OptunaSearch(Searcher):
"You can only pass an instance of `optuna.samplers.BaseSampler` " \
"as a sampler to `OptunaSearcher`."
self._pruner = ot.pruners.NopPruner()
self._storage = ot.storages.InMemoryStorage()
self._ot_trials = {}
self._ot_study = None
if self._space:
@ -154,14 +151,20 @@ class OptunaSearch(Searcher):
# If only a mode was passed, use anonymous metric
self._metric = DEFAULT_METRIC
pruner = ot.pruners.NopPruner()
storage = ot.storages.InMemoryStorage()
self._ot_study = ot.study.create_study(
storage=self._storage,
storage=storage,
sampler=self._sampler,
pruner=self._pruner,
pruner=pruner,
study_name=self._study_name,
direction="minimize" if mode == "min" else "maximize",
load_if_exists=True)
for point in self._points_to_evaluate:
self._ot_study.enqueue_trial(point)
def set_search_properties(self, metric: Optional[str], mode: Optional[str],
config: Dict) -> bool:
if self._space:
@ -189,21 +192,17 @@ class OptunaSearch(Searcher):
mode=self._mode))
if trial_id not in self._ot_trials:
ot_trial_id = self._storage.create_new_trial(
self._ot_study._study_id)
self._ot_trials[trial_id] = ot.trial.Trial(self._ot_study,
ot_trial_id)
self._ot_trials[trial_id] = self._ot_study.ask()
ot_trial = self._ot_trials[trial_id]
if self._points_to_evaluate:
params = self._points_to_evaluate.pop(0)
else:
# getattr will fetch the trial.suggest_ function on Optuna trials
params = {
args[0] if len(args) > 0 else kwargs["name"]: getattr(
ot_trial, fn)(*args, **kwargs)
for (fn, args, kwargs) in self._space
}
# getattr will fetch the trial.suggest_ function on Optuna trials
params = {
args[0] if len(args) > 0 else kwargs["name"]: getattr(
ot_trial, fn)(*args, **kwargs)
for (fn, args, kwargs) in self._space
}
return unflatten_dict(params)
def on_trial_result(self, trial_id: str, result: Dict):
@ -217,21 +216,15 @@ class OptunaSearch(Searcher):
result: Optional[Dict] = None,
error: bool = False):
ot_trial = self._ot_trials[trial_id]
ot_trial_id = ot_trial._trial_id
val = result.get(self.metric, None)
if hasattr(self._storage, "set_trial_value"):
# Backwards compatibility with optuna < 2.4.0
self._storage.set_trial_value(ot_trial_id, val)
else:
self._storage.set_trial_values(ot_trial_id, [val])
self._storage.set_trial_state(ot_trial_id,
ot.trial.TrialState.COMPLETE)
val = result.get(self.metric, None) if result else None
try:
self._ot_study.tell(ot_trial, val)
except ValueError as exc:
logger.warning(exc) # E.g. if NaN was reported
def save(self, checkpoint_path: str):
save_object = (self._storage, self._pruner, self._sampler,
self._ot_trials, self._ot_study,
save_object = (self._sampler, self._ot_trials, self._ot_study,
self._points_to_evaluate)
with open(checkpoint_path, "wb") as outputFile:
pickle.dump(save_object, outputFile)
@ -239,8 +232,7 @@ class OptunaSearch(Searcher):
def restore(self, checkpoint_path: str):
with open(checkpoint_path, "rb") as inputFile:
save_object = pickle.load(inputFile)
self._storage, self._pruner, self._sampler, \
self._ot_trials, self._ot_study, \
self._sampler, self._ot_trials, self._ot_study, \
self._points_to_evaluate = save_object
@staticmethod

View file

@ -18,7 +18,7 @@ matplotlib==3.3.4
mlflow==1.14.0
mxnet==1.7.0.post1
nevergrad==0.4.2.post5
optuna==2.4.0
optuna==2.5.0
pytest-remotedata==0.3.2
pytorch-lightning-bolts==0.2.5
pytorch-lightning==1.0.3