mirror of
https://github.com/vale981/ray
synced 2025-03-08 19:41:38 -05:00
[tune] OptunaSearch: check compatibility of search space with evaluated_rewards (#18625)
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com> Co-authored-by: Kai Fricke <krfricke@users.noreply.github.com>
This commit is contained in:
parent
99b1d8c95f
commit
882f7d3863
2 changed files with 70 additions and 1 deletions
|
@ -114,6 +114,13 @@ class OptunaSearch(Searcher):
|
||||||
needing to re-compute the trial. Must be the same length as
|
needing to re-compute the trial. Must be the same length as
|
||||||
points_to_evaluate.
|
points_to_evaluate.
|
||||||
|
|
||||||
|
..warning::
|
||||||
|
When using ``evaluated_rewards``, the search space ``space``
|
||||||
|
must be provided as a :class:`dict` with parameter names as
|
||||||
|
keys and ``optuna.distributions`` instances as values. The
|
||||||
|
define-by-run search space definition is not yet supported with
|
||||||
|
this functionality.
|
||||||
|
|
||||||
Tune automatically converts search spaces to Optuna's format:
|
Tune automatically converts search spaces to Optuna's format:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
@ -139,7 +146,7 @@ class OptunaSearch(Searcher):
|
||||||
from ray.tune.suggest.optuna import OptunaSearch
|
from ray.tune.suggest.optuna import OptunaSearch
|
||||||
import optuna
|
import optuna
|
||||||
|
|
||||||
config = {
|
space = {
|
||||||
"a": optuna.distributions.UniformDistribution(6, 8),
|
"a": optuna.distributions.UniformDistribution(6, 8),
|
||||||
"b": optuna.distributions.LogUniformDistribution(1e-4, 1e-2),
|
"b": optuna.distributions.LogUniformDistribution(1e-4, 1e-2),
|
||||||
}
|
}
|
||||||
|
@ -166,6 +173,49 @@ class OptunaSearch(Searcher):
|
||||||
|
|
||||||
tune.run(trainable, search_alg=optuna_search)
|
tune.run(trainable, search_alg=optuna_search)
|
||||||
|
|
||||||
|
You can pass configs that will be evaluated first using
|
||||||
|
``points_to_evaluate``:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from ray.tune.suggest.optuna import OptunaSearch
|
||||||
|
import optuna
|
||||||
|
|
||||||
|
space = {
|
||||||
|
"a": optuna.distributions.UniformDistribution(6, 8),
|
||||||
|
"b": optuna.distributions.LogUniformDistribution(1e-4, 1e-2),
|
||||||
|
}
|
||||||
|
|
||||||
|
optuna_search = OptunaSearch(
|
||||||
|
space,
|
||||||
|
points_to_evaluate=[{"a": 6.5, "b": 5e-4}, {"a": 7.5, "b": 1e-3}]
|
||||||
|
metric="loss",
|
||||||
|
mode="min")
|
||||||
|
|
||||||
|
tune.run(trainable, search_alg=optuna_search)
|
||||||
|
|
||||||
|
Avoid re-running evaluated trials by passing the rewards together with
|
||||||
|
`points_to_evaluate`:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from ray.tune.suggest.optuna import OptunaSearch
|
||||||
|
import optuna
|
||||||
|
|
||||||
|
space = {
|
||||||
|
"a": optuna.distributions.UniformDistribution(6, 8),
|
||||||
|
"b": optuna.distributions.LogUniformDistribution(1e-4, 1e-2),
|
||||||
|
}
|
||||||
|
|
||||||
|
optuna_search = OptunaSearch(
|
||||||
|
space,
|
||||||
|
points_to_evaluate=[{"a": 6.5, "b": 5e-4}, {"a": 7.5, "b": 1e-3}]
|
||||||
|
evaluated_rewards=[0.89, 0.42]
|
||||||
|
metric="loss",
|
||||||
|
mode="min")
|
||||||
|
|
||||||
|
tune.run(trainable, search_alg=optuna_search)
|
||||||
|
|
||||||
.. versionadded:: 0.8.8
|
.. versionadded:: 0.8.8
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
@ -367,6 +417,12 @@ class OptunaSearch(Searcher):
|
||||||
cls=self.__class__.__name__,
|
cls=self.__class__.__name__,
|
||||||
metric=self._metric,
|
metric=self._metric,
|
||||||
mode=self._mode))
|
mode=self._mode))
|
||||||
|
if callable(self._space):
|
||||||
|
raise TypeError(
|
||||||
|
"Define-by-run function passed in `space` argument is not "
|
||||||
|
"yet supported when using `evaluated_rewards`. Please provide "
|
||||||
|
"an `OptunaDistribution` dict or pass a Ray Tune "
|
||||||
|
"search space to `tune.run()`.")
|
||||||
|
|
||||||
ot_trial_state = OptunaTrialState.COMPLETE
|
ot_trial_state = OptunaTrialState.COMPLETE
|
||||||
if error:
|
if error:
|
||||||
|
|
|
@ -317,6 +317,19 @@ class AddEvaluatedPointTest(unittest.TestCase):
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
searcher._ot_study.trials[-1].state == TrialState.PRUNED)
|
searcher._ot_study.trials[-1].state == TrialState.PRUNED)
|
||||||
|
|
||||||
|
def dbr_space(trial):
|
||||||
|
return {
|
||||||
|
self.param_name: trial.suggest_float(self.param_name, 0.0, 5.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
dbr_searcher = OptunaSearch(
|
||||||
|
space=dbr_space,
|
||||||
|
metric="metric",
|
||||||
|
mode="max",
|
||||||
|
)
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
dbr_searcher.add_evaluated_point(point, 1.0)
|
||||||
|
|
||||||
def testHEBO(self):
|
def testHEBO(self):
|
||||||
from ray.tune.suggest.hebo import HEBOSearch
|
from ray.tune.suggest.hebo import HEBOSearch
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue