mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[tune] Hotfix Ax breakage when fixing backwards-compat (#8285)
This commit is contained in:
parent
eda526c154
commit
40dfb337bf
2 changed files with 38 additions and 29 deletions
|
@ -16,8 +16,6 @@ class AxSearch(Searcher):
|
|||
Facebook for configuring and optimizing experiments. More information
|
||||
can be found in https://ax.dev/.
|
||||
|
||||
This module manages its own concurrency.
|
||||
|
||||
Parameters:
|
||||
parameters (list[dict]): Parameters in the experiment search space.
|
||||
Required elements in the dictionaries are: "name" (name of
|
||||
|
@ -30,18 +28,18 @@ class AxSearch(Searcher):
|
|||
experiment. This metric must be present in `raw_data` argument
|
||||
to `log_data`. This metric must also be present in the dict
|
||||
reported/returned by the Trainable.
|
||||
max_concurrent (int): Number of maximum concurrent trials. Defaults
|
||||
to 10.
|
||||
mode (str): One of {min, max}. Determines whether objective is
|
||||
minimizing or maximizing the metric attribute. Defaults to "max".
|
||||
parameter_constraints (list[str]): Parameter constraints, such as
|
||||
"x3 >= x4" or "x3 + x4 >= 2".
|
||||
outcome_constraints (list[str]): Outcome constraints of form
|
||||
"metric_name >= bound", like "m1 <= 3."
|
||||
max_concurrent (int): Deprecated.
|
||||
use_early_stopped_trials: Deprecated.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from ax.service.ax_client import AxClient
|
||||
from ray import tune
|
||||
from ray.tune.suggest.ax import AxSearch
|
||||
|
||||
|
@ -50,40 +48,45 @@ class AxSearch(Searcher):
|
|||
{"name": "x2", "type": "range", "bounds": [0.0, 1.0]},
|
||||
]
|
||||
|
||||
algo = AxSearch(parameters=parameters,
|
||||
objective_name="hartmann6", max_concurrent=4)
|
||||
tune.run(my_func, algo=algo)
|
||||
def easy_objective(config):
|
||||
for i in range(100):
|
||||
intermediate_result = config["x1"] + config["x2"] * i
|
||||
tune.track.log(score=intermediate_result)
|
||||
|
||||
client = AxClient(enforce_sequential_optimization=False)
|
||||
client.create_experiment(parameters=parameters, objective_name="score")
|
||||
algo = AxSearch(client)
|
||||
tune.run(easy_objective, search_alg=algo)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
ax_client,
|
||||
max_concurrent=10,
|
||||
mode="max",
|
||||
use_early_stopped_trials=None):
|
||||
use_early_stopped_trials=None,
|
||||
max_concurrent=None):
|
||||
assert ax is not None, "Ax must be installed!"
|
||||
assert type(max_concurrent) is int and max_concurrent > 0
|
||||
self._ax = ax_client
|
||||
exp = self._ax.experiment
|
||||
self._objective_name = exp.optimization_config.objective.metric.name
|
||||
if self._ax._enforce_sequential_optimization:
|
||||
logger.warning("Detected sequential enforcement. Setting max "
|
||||
"concurrency to 1.")
|
||||
max_concurrent = 1
|
||||
self.max_concurrent = max_concurrent
|
||||
self._parameters = list(exp.parameters)
|
||||
self._live_index_mapping = {}
|
||||
self._live_trial_mapping = {}
|
||||
super(AxSearch, self).__init__(
|
||||
metric=self._objective_name,
|
||||
mode=mode,
|
||||
max_concurrent=max_concurrent,
|
||||
use_early_stopped_trials=use_early_stopped_trials)
|
||||
if self._ax._enforce_sequential_optimization:
|
||||
logger.warning("Detected sequential enforcement. Be sure to use "
|
||||
"a ConcurrencyLimiter.")
|
||||
|
||||
def suggest(self, trial_id):
|
||||
if self.max_concurrent:
|
||||
if len(self._live_trial_mapping) >= self.max_concurrent:
|
||||
return None
|
||||
parameters, trial_index = self._ax.get_next_trial()
|
||||
self._live_index_mapping[trial_id] = trial_index
|
||||
self._live_trial_mapping[trial_id] = trial_index
|
||||
return parameters
|
||||
|
||||
def on_trial_complete(self, trial_id, result=None, error=False):
|
||||
|
@ -93,10 +96,10 @@ class AxSearch(Searcher):
|
|||
"""
|
||||
if result:
|
||||
self._process_result(trial_id, result)
|
||||
self._live_index_mapping.pop(trial_id)
|
||||
self._live_trial_mapping.pop(trial_id)
|
||||
|
||||
def _process_result(self, trial_id, result):
|
||||
ax_trial_index = self._live_index_mapping[trial_id]
|
||||
ax_trial_index = self._live_trial_mapping[trial_id]
|
||||
metric_dict = {
|
||||
self._objective_name: (result[self._objective_name], 0.0)
|
||||
}
|
||||
|
|
|
@ -195,7 +195,9 @@ class HyperoptWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
|
|||
metric="loss",
|
||||
mode="min",
|
||||
random_state_seed=5,
|
||||
n_initial_points=1)
|
||||
n_initial_points=1,
|
||||
max_concurrent=1000 # Here to avoid breaking back-compat.
|
||||
)
|
||||
return search_alg, cost
|
||||
|
||||
|
||||
|
@ -228,9 +230,11 @@ class SkoptWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
|
|||
reporter(loss=(space["height"]**2 + space["width"]**2))
|
||||
|
||||
search_alg = SkOptSearch(
|
||||
optimizer, ["width", "height"],
|
||||
optimizer,
|
||||
["width", "height"],
|
||||
metric="loss",
|
||||
mode="min",
|
||||
max_concurrent=1000, # Here to avoid breaking back-compat.
|
||||
points_to_evaluate=previously_run_params,
|
||||
evaluated_rewards=known_rewards)
|
||||
return search_alg, cost
|
||||
|
@ -243,11 +247,15 @@ class NevergradWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
|
|||
optimizer = optimizerlib.OnePlusOne(instrumentation)
|
||||
|
||||
def cost(space, reporter):
|
||||
reporter(
|
||||
mean_loss=(space["height"] - 14)**2 - abs(space["width"] - 3))
|
||||
reporter(loss=(space["height"] - 14)**2 - abs(space["width"] - 3))
|
||||
|
||||
search_alg = NevergradSearch(
|
||||
optimizer, parameter_names, metric="mean_loss", mode="min")
|
||||
optimizer,
|
||||
parameter_names,
|
||||
metric="loss",
|
||||
mode="min",
|
||||
max_concurrent=1000, # Here to avoid breaking back-compat.
|
||||
)
|
||||
return search_alg, cost
|
||||
|
||||
|
||||
|
@ -273,14 +281,13 @@ class SigOptWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
|
|||
]
|
||||
|
||||
def cost(space, reporter):
|
||||
reporter(
|
||||
mean_loss=(space["height"] - 14)**2 - abs(space["width"] - 3))
|
||||
reporter(loss=(space["height"] - 14)**2 - abs(space["width"] - 3))
|
||||
|
||||
search_alg = SigOptSearch(
|
||||
space,
|
||||
name="SigOpt Example Experiment",
|
||||
max_concurrent=1,
|
||||
metric="mean_loss",
|
||||
metric="loss",
|
||||
mode="min")
|
||||
return search_alg, cost
|
||||
|
||||
|
@ -298,9 +305,8 @@ class ZOOptWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
|
|||
"width": (ValueType.DISCRETE, [0, 20], False)
|
||||
}
|
||||
|
||||
def cost(dim_dict, reporter):
|
||||
reporter(
|
||||
loss=(dim_dict["height"] - 14)**2 - abs(dim_dict["width"] - 3))
|
||||
def cost(param, reporter):
|
||||
reporter(loss=(param["height"] - 14)**2 - abs(param["width"] - 3))
|
||||
|
||||
search_alg = ZOOptSearch(
|
||||
algo="Asracos", # only support ASRacos currently
|
||||
|
|
Loading…
Add table
Reference in a new issue