[tune] Hotfix Ax breakage when fixing backwards-compat (#8285)

This commit is contained in:
Richard Liaw 2020-05-02 20:42:50 -07:00 committed by GitHub
parent eda526c154
commit 40dfb337bf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 38 additions and 29 deletions

View file

@ -16,8 +16,6 @@ class AxSearch(Searcher):
Facebook for configuring and optimizing experiments. More information Facebook for configuring and optimizing experiments. More information
can be found in https://ax.dev/. can be found in https://ax.dev/.
This module manages its own concurrency.
Parameters: Parameters:
parameters (list[dict]): Parameters in the experiment search space. parameters (list[dict]): Parameters in the experiment search space.
Required elements in the dictionaries are: "name" (name of Required elements in the dictionaries are: "name" (name of
@ -30,18 +28,18 @@ class AxSearch(Searcher):
experiment. This metric must be present in `raw_data` argument experiment. This metric must be present in `raw_data` argument
to `log_data`. This metric must also be present in the dict to `log_data`. This metric must also be present in the dict
reported/returned by the Trainable. reported/returned by the Trainable.
max_concurrent (int): Number of maximum concurrent trials. Defaults
to 10.
mode (str): One of {min, max}. Determines whether objective is mode (str): One of {min, max}. Determines whether objective is
minimizing or maximizing the metric attribute. Defaults to "max". minimizing or maximizing the metric attribute. Defaults to "max".
parameter_constraints (list[str]): Parameter constraints, such as parameter_constraints (list[str]): Parameter constraints, such as
"x3 >= x4" or "x3 + x4 >= 2". "x3 >= x4" or "x3 + x4 >= 2".
outcome_constraints (list[str]): Outcome constraints of form outcome_constraints (list[str]): Outcome constraints of form
"metric_name >= bound", like "m1 <= 3." "metric_name >= bound", like "m1 <= 3."
max_concurrent (int): Deprecated.
use_early_stopped_trials: Deprecated. use_early_stopped_trials: Deprecated.
.. code-block:: python .. code-block:: python
from ax.service.ax_client import AxClient
from ray import tune from ray import tune
from ray.tune.suggest.ax import AxSearch from ray.tune.suggest.ax import AxSearch
@ -50,40 +48,45 @@ class AxSearch(Searcher):
{"name": "x2", "type": "range", "bounds": [0.0, 1.0]}, {"name": "x2", "type": "range", "bounds": [0.0, 1.0]},
] ]
algo = AxSearch(parameters=parameters, def easy_objective(config):
objective_name="hartmann6", max_concurrent=4) for i in range(100):
tune.run(my_func, algo=algo) intermediate_result = config["x1"] + config["x2"] * i
tune.track.log(score=intermediate_result)
client = AxClient(enforce_sequential_optimization=False)
client.create_experiment(parameters=parameters, objective_name="score")
algo = AxSearch(client)
tune.run(easy_objective, search_alg=algo)
""" """
def __init__(self, def __init__(self,
ax_client, ax_client,
max_concurrent=10,
mode="max", mode="max",
use_early_stopped_trials=None): use_early_stopped_trials=None,
max_concurrent=None):
assert ax is not None, "Ax must be installed!" assert ax is not None, "Ax must be installed!"
assert type(max_concurrent) is int and max_concurrent > 0
self._ax = ax_client self._ax = ax_client
exp = self._ax.experiment exp = self._ax.experiment
self._objective_name = exp.optimization_config.objective.metric.name self._objective_name = exp.optimization_config.objective.metric.name
if self._ax._enforce_sequential_optimization:
logger.warning("Detected sequential enforcement. Setting max "
"concurrency to 1.")
max_concurrent = 1
self.max_concurrent = max_concurrent self.max_concurrent = max_concurrent
self._parameters = list(exp.parameters) self._parameters = list(exp.parameters)
self._live_index_mapping = {} self._live_trial_mapping = {}
super(AxSearch, self).__init__( super(AxSearch, self).__init__(
metric=self._objective_name, metric=self._objective_name,
mode=mode, mode=mode,
max_concurrent=max_concurrent,
use_early_stopped_trials=use_early_stopped_trials) use_early_stopped_trials=use_early_stopped_trials)
if self._ax._enforce_sequential_optimization:
logger.warning("Detected sequential enforcement. Be sure to use "
"a ConcurrencyLimiter.")
def suggest(self, trial_id): def suggest(self, trial_id):
if self.max_concurrent: if self.max_concurrent:
if len(self._live_trial_mapping) >= self.max_concurrent: if len(self._live_trial_mapping) >= self.max_concurrent:
return None return None
parameters, trial_index = self._ax.get_next_trial() parameters, trial_index = self._ax.get_next_trial()
self._live_index_mapping[trial_id] = trial_index self._live_trial_mapping[trial_id] = trial_index
return parameters return parameters
def on_trial_complete(self, trial_id, result=None, error=False): def on_trial_complete(self, trial_id, result=None, error=False):
@ -93,10 +96,10 @@ class AxSearch(Searcher):
""" """
if result: if result:
self._process_result(trial_id, result) self._process_result(trial_id, result)
self._live_index_mapping.pop(trial_id) self._live_trial_mapping.pop(trial_id)
def _process_result(self, trial_id, result): def _process_result(self, trial_id, result):
ax_trial_index = self._live_index_mapping[trial_id] ax_trial_index = self._live_trial_mapping[trial_id]
metric_dict = { metric_dict = {
self._objective_name: (result[self._objective_name], 0.0) self._objective_name: (result[self._objective_name], 0.0)
} }

View file

@ -195,7 +195,9 @@ class HyperoptWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
metric="loss", metric="loss",
mode="min", mode="min",
random_state_seed=5, random_state_seed=5,
n_initial_points=1) n_initial_points=1,
max_concurrent=1000 # Here to avoid breaking back-compat.
)
return search_alg, cost return search_alg, cost
@ -228,9 +230,11 @@ class SkoptWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
reporter(loss=(space["height"]**2 + space["width"]**2)) reporter(loss=(space["height"]**2 + space["width"]**2))
search_alg = SkOptSearch( search_alg = SkOptSearch(
optimizer, ["width", "height"], optimizer,
["width", "height"],
metric="loss", metric="loss",
mode="min", mode="min",
max_concurrent=1000, # Here to avoid breaking back-compat.
points_to_evaluate=previously_run_params, points_to_evaluate=previously_run_params,
evaluated_rewards=known_rewards) evaluated_rewards=known_rewards)
return search_alg, cost return search_alg, cost
@ -243,11 +247,15 @@ class NevergradWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
optimizer = optimizerlib.OnePlusOne(instrumentation) optimizer = optimizerlib.OnePlusOne(instrumentation)
def cost(space, reporter): def cost(space, reporter):
reporter( reporter(loss=(space["height"] - 14)**2 - abs(space["width"] - 3))
mean_loss=(space["height"] - 14)**2 - abs(space["width"] - 3))
search_alg = NevergradSearch( search_alg = NevergradSearch(
optimizer, parameter_names, metric="mean_loss", mode="min") optimizer,
parameter_names,
metric="loss",
mode="min",
max_concurrent=1000, # Here to avoid breaking back-compat.
)
return search_alg, cost return search_alg, cost
@ -273,14 +281,13 @@ class SigOptWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
] ]
def cost(space, reporter): def cost(space, reporter):
reporter( reporter(loss=(space["height"] - 14)**2 - abs(space["width"] - 3))
mean_loss=(space["height"] - 14)**2 - abs(space["width"] - 3))
search_alg = SigOptSearch( search_alg = SigOptSearch(
space, space,
name="SigOpt Example Experiment", name="SigOpt Example Experiment",
max_concurrent=1, max_concurrent=1,
metric="mean_loss", metric="loss",
mode="min") mode="min")
return search_alg, cost return search_alg, cost
@ -298,9 +305,8 @@ class ZOOptWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
"width": (ValueType.DISCRETE, [0, 20], False) "width": (ValueType.DISCRETE, [0, 20], False)
} }
def cost(dim_dict, reporter): def cost(param, reporter):
reporter( reporter(loss=(param["height"] - 14)**2 - abs(param["width"] - 3))
loss=(dim_dict["height"] - 14)**2 - abs(dim_dict["width"] - 3))
search_alg = ZOOptSearch( search_alg = ZOOptSearch(
algo="Asracos", # only support ASRacos currently algo="Asracos", # only support ASRacos currently