mirror of
https://github.com/vale981/ray
synced 2025-03-06 18:41:40 -05:00
[tune] Hotfix Ax breakage when fixing backwards-compat (#8285)
This commit is contained in:
parent
eda526c154
commit
40dfb337bf
2 changed files with 38 additions and 29 deletions
|
@ -16,8 +16,6 @@ class AxSearch(Searcher):
|
||||||
Facebook for configuring and optimizing experiments. More information
|
Facebook for configuring and optimizing experiments. More information
|
||||||
can be found in https://ax.dev/.
|
can be found in https://ax.dev/.
|
||||||
|
|
||||||
This module manages its own concurrency.
|
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
parameters (list[dict]): Parameters in the experiment search space.
|
parameters (list[dict]): Parameters in the experiment search space.
|
||||||
Required elements in the dictionaries are: "name" (name of
|
Required elements in the dictionaries are: "name" (name of
|
||||||
|
@ -30,18 +28,18 @@ class AxSearch(Searcher):
|
||||||
experiment. This metric must be present in `raw_data` argument
|
experiment. This metric must be present in `raw_data` argument
|
||||||
to `log_data`. This metric must also be present in the dict
|
to `log_data`. This metric must also be present in the dict
|
||||||
reported/returned by the Trainable.
|
reported/returned by the Trainable.
|
||||||
max_concurrent (int): Number of maximum concurrent trials. Defaults
|
|
||||||
to 10.
|
|
||||||
mode (str): One of {min, max}. Determines whether objective is
|
mode (str): One of {min, max}. Determines whether objective is
|
||||||
minimizing or maximizing the metric attribute. Defaults to "max".
|
minimizing or maximizing the metric attribute. Defaults to "max".
|
||||||
parameter_constraints (list[str]): Parameter constraints, such as
|
parameter_constraints (list[str]): Parameter constraints, such as
|
||||||
"x3 >= x4" or "x3 + x4 >= 2".
|
"x3 >= x4" or "x3 + x4 >= 2".
|
||||||
outcome_constraints (list[str]): Outcome constraints of form
|
outcome_constraints (list[str]): Outcome constraints of form
|
||||||
"metric_name >= bound", like "m1 <= 3."
|
"metric_name >= bound", like "m1 <= 3."
|
||||||
|
max_concurrent (int): Deprecated.
|
||||||
use_early_stopped_trials: Deprecated.
|
use_early_stopped_trials: Deprecated.
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
|
from ax.service.ax_client import AxClient
|
||||||
from ray import tune
|
from ray import tune
|
||||||
from ray.tune.suggest.ax import AxSearch
|
from ray.tune.suggest.ax import AxSearch
|
||||||
|
|
||||||
|
@ -50,40 +48,45 @@ class AxSearch(Searcher):
|
||||||
{"name": "x2", "type": "range", "bounds": [0.0, 1.0]},
|
{"name": "x2", "type": "range", "bounds": [0.0, 1.0]},
|
||||||
]
|
]
|
||||||
|
|
||||||
algo = AxSearch(parameters=parameters,
|
def easy_objective(config):
|
||||||
objective_name="hartmann6", max_concurrent=4)
|
for i in range(100):
|
||||||
tune.run(my_func, algo=algo)
|
intermediate_result = config["x1"] + config["x2"] * i
|
||||||
|
tune.track.log(score=intermediate_result)
|
||||||
|
|
||||||
|
client = AxClient(enforce_sequential_optimization=False)
|
||||||
|
client.create_experiment(parameters=parameters, objective_name="score")
|
||||||
|
algo = AxSearch(client)
|
||||||
|
tune.run(easy_objective, search_alg=algo)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
ax_client,
|
ax_client,
|
||||||
max_concurrent=10,
|
|
||||||
mode="max",
|
mode="max",
|
||||||
use_early_stopped_trials=None):
|
use_early_stopped_trials=None,
|
||||||
|
max_concurrent=None):
|
||||||
assert ax is not None, "Ax must be installed!"
|
assert ax is not None, "Ax must be installed!"
|
||||||
assert type(max_concurrent) is int and max_concurrent > 0
|
|
||||||
self._ax = ax_client
|
self._ax = ax_client
|
||||||
exp = self._ax.experiment
|
exp = self._ax.experiment
|
||||||
self._objective_name = exp.optimization_config.objective.metric.name
|
self._objective_name = exp.optimization_config.objective.metric.name
|
||||||
if self._ax._enforce_sequential_optimization:
|
|
||||||
logger.warning("Detected sequential enforcement. Setting max "
|
|
||||||
"concurrency to 1.")
|
|
||||||
max_concurrent = 1
|
|
||||||
self.max_concurrent = max_concurrent
|
self.max_concurrent = max_concurrent
|
||||||
self._parameters = list(exp.parameters)
|
self._parameters = list(exp.parameters)
|
||||||
self._live_index_mapping = {}
|
self._live_trial_mapping = {}
|
||||||
super(AxSearch, self).__init__(
|
super(AxSearch, self).__init__(
|
||||||
metric=self._objective_name,
|
metric=self._objective_name,
|
||||||
mode=mode,
|
mode=mode,
|
||||||
|
max_concurrent=max_concurrent,
|
||||||
use_early_stopped_trials=use_early_stopped_trials)
|
use_early_stopped_trials=use_early_stopped_trials)
|
||||||
|
if self._ax._enforce_sequential_optimization:
|
||||||
|
logger.warning("Detected sequential enforcement. Be sure to use "
|
||||||
|
"a ConcurrencyLimiter.")
|
||||||
|
|
||||||
def suggest(self, trial_id):
|
def suggest(self, trial_id):
|
||||||
if self.max_concurrent:
|
if self.max_concurrent:
|
||||||
if len(self._live_trial_mapping) >= self.max_concurrent:
|
if len(self._live_trial_mapping) >= self.max_concurrent:
|
||||||
return None
|
return None
|
||||||
parameters, trial_index = self._ax.get_next_trial()
|
parameters, trial_index = self._ax.get_next_trial()
|
||||||
self._live_index_mapping[trial_id] = trial_index
|
self._live_trial_mapping[trial_id] = trial_index
|
||||||
return parameters
|
return parameters
|
||||||
|
|
||||||
def on_trial_complete(self, trial_id, result=None, error=False):
|
def on_trial_complete(self, trial_id, result=None, error=False):
|
||||||
|
@ -93,10 +96,10 @@ class AxSearch(Searcher):
|
||||||
"""
|
"""
|
||||||
if result:
|
if result:
|
||||||
self._process_result(trial_id, result)
|
self._process_result(trial_id, result)
|
||||||
self._live_index_mapping.pop(trial_id)
|
self._live_trial_mapping.pop(trial_id)
|
||||||
|
|
||||||
def _process_result(self, trial_id, result):
|
def _process_result(self, trial_id, result):
|
||||||
ax_trial_index = self._live_index_mapping[trial_id]
|
ax_trial_index = self._live_trial_mapping[trial_id]
|
||||||
metric_dict = {
|
metric_dict = {
|
||||||
self._objective_name: (result[self._objective_name], 0.0)
|
self._objective_name: (result[self._objective_name], 0.0)
|
||||||
}
|
}
|
||||||
|
|
|
@ -195,7 +195,9 @@ class HyperoptWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
|
||||||
metric="loss",
|
metric="loss",
|
||||||
mode="min",
|
mode="min",
|
||||||
random_state_seed=5,
|
random_state_seed=5,
|
||||||
n_initial_points=1)
|
n_initial_points=1,
|
||||||
|
max_concurrent=1000 # Here to avoid breaking back-compat.
|
||||||
|
)
|
||||||
return search_alg, cost
|
return search_alg, cost
|
||||||
|
|
||||||
|
|
||||||
|
@ -228,9 +230,11 @@ class SkoptWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
|
||||||
reporter(loss=(space["height"]**2 + space["width"]**2))
|
reporter(loss=(space["height"]**2 + space["width"]**2))
|
||||||
|
|
||||||
search_alg = SkOptSearch(
|
search_alg = SkOptSearch(
|
||||||
optimizer, ["width", "height"],
|
optimizer,
|
||||||
|
["width", "height"],
|
||||||
metric="loss",
|
metric="loss",
|
||||||
mode="min",
|
mode="min",
|
||||||
|
max_concurrent=1000, # Here to avoid breaking back-compat.
|
||||||
points_to_evaluate=previously_run_params,
|
points_to_evaluate=previously_run_params,
|
||||||
evaluated_rewards=known_rewards)
|
evaluated_rewards=known_rewards)
|
||||||
return search_alg, cost
|
return search_alg, cost
|
||||||
|
@ -243,11 +247,15 @@ class NevergradWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
|
||||||
optimizer = optimizerlib.OnePlusOne(instrumentation)
|
optimizer = optimizerlib.OnePlusOne(instrumentation)
|
||||||
|
|
||||||
def cost(space, reporter):
|
def cost(space, reporter):
|
||||||
reporter(
|
reporter(loss=(space["height"] - 14)**2 - abs(space["width"] - 3))
|
||||||
mean_loss=(space["height"] - 14)**2 - abs(space["width"] - 3))
|
|
||||||
|
|
||||||
search_alg = NevergradSearch(
|
search_alg = NevergradSearch(
|
||||||
optimizer, parameter_names, metric="mean_loss", mode="min")
|
optimizer,
|
||||||
|
parameter_names,
|
||||||
|
metric="loss",
|
||||||
|
mode="min",
|
||||||
|
max_concurrent=1000, # Here to avoid breaking back-compat.
|
||||||
|
)
|
||||||
return search_alg, cost
|
return search_alg, cost
|
||||||
|
|
||||||
|
|
||||||
|
@ -273,14 +281,13 @@ class SigOptWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
|
||||||
]
|
]
|
||||||
|
|
||||||
def cost(space, reporter):
|
def cost(space, reporter):
|
||||||
reporter(
|
reporter(loss=(space["height"] - 14)**2 - abs(space["width"] - 3))
|
||||||
mean_loss=(space["height"] - 14)**2 - abs(space["width"] - 3))
|
|
||||||
|
|
||||||
search_alg = SigOptSearch(
|
search_alg = SigOptSearch(
|
||||||
space,
|
space,
|
||||||
name="SigOpt Example Experiment",
|
name="SigOpt Example Experiment",
|
||||||
max_concurrent=1,
|
max_concurrent=1,
|
||||||
metric="mean_loss",
|
metric="loss",
|
||||||
mode="min")
|
mode="min")
|
||||||
return search_alg, cost
|
return search_alg, cost
|
||||||
|
|
||||||
|
@ -298,9 +305,8 @@ class ZOOptWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
|
||||||
"width": (ValueType.DISCRETE, [0, 20], False)
|
"width": (ValueType.DISCRETE, [0, 20], False)
|
||||||
}
|
}
|
||||||
|
|
||||||
def cost(dim_dict, reporter):
|
def cost(param, reporter):
|
||||||
reporter(
|
reporter(loss=(param["height"] - 14)**2 - abs(param["width"] - 3))
|
||||||
loss=(dim_dict["height"] - 14)**2 - abs(dim_dict["width"] - 3))
|
|
||||||
|
|
||||||
search_alg = ZOOptSearch(
|
search_alg = ZOOptSearch(
|
||||||
algo="Asracos", # only support ASRacos currently
|
algo="Asracos", # only support ASRacos currently
|
||||||
|
|
Loading…
Add table
Reference in a new issue