[tune] Hotfix Ax breakage when fixing backwards-compat (#8285)

Richard Liaw 2020-05-02 20:42:50 -07:00 committed by GitHub
parent eda526c154
commit 40dfb337bf
2 changed files with 38 additions and 29 deletions
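
The gist of the fix: an earlier backwards-compat pass reintroduced max_concurrent as a required, validated argument, which broke the new AxClient-based AxSearch constructor. This hotfix keeps max_concurrent only as a deprecated keyword (defaulting to None) and defers throttling to a ConcurrencyLimiter wrapper, as the new warning message suggests. A minimal sketch of the post-hotfix usage pattern, assuming ConcurrencyLimiter is importable from ray.tune.suggest; num_samples and the parameter names are illustrative:

# Sketch only; mirrors the updated docstring example plus the recommended
# ConcurrencyLimiter wrapper. Assumes ray[tune] and ax-platform are installed.
from ax.service.ax_client import AxClient
from ray import tune
from ray.tune.suggest import ConcurrencyLimiter
from ray.tune.suggest.ax import AxSearch

parameters = [
    {"name": "x1", "type": "range", "bounds": [0.0, 1.0]},
    {"name": "x2", "type": "range", "bounds": [0.0, 1.0]},
]

def easy_objective(config):
    for i in range(100):
        intermediate_result = config["x1"] + config["x2"] * i
        tune.track.log(score=intermediate_result)

# Sequential enforcement must be off for concurrent trials.
client = AxClient(enforce_sequential_optimization=False)
client.create_experiment(parameters=parameters, objective_name="score")

# Concurrency is capped by the wrapper, not by AxSearch itself.
algo = ConcurrencyLimiter(AxSearch(client), max_concurrent=4)
tune.run(easy_objective, search_alg=algo, num_samples=10)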

python/ray/tune/suggest/ax.py

@@ -16,8 +16,6 @@ class AxSearch(Searcher):
     Facebook for configuring and optimizing experiments. More information
     can be found in https://ax.dev/.
 
-    This module manages its own concurrency.
-
     Parameters:
         parameters (list[dict]): Parameters in the experiment search space.
            Required elements in the dictionaries are: "name" (name of
@@ -30,18 +28,18 @@ class AxSearch(Searcher):
             experiment. This metric must be present in `raw_data` argument
             to `log_data`. This metric must also be present in the dict
             reported/returned by the Trainable.
-        max_concurrent (int): Number of maximum concurrent trials. Defaults
-            to 10.
         mode (str): One of {min, max}. Determines whether objective is
             minimizing or maximizing the metric attribute. Defaults to "max".
         parameter_constraints (list[str]): Parameter constraints, such as
             "x3 >= x4" or "x3 + x4 >= 2".
         outcome_constraints (list[str]): Outcome constraints of form
             "metric_name >= bound", like "m1 <= 3."
+        max_concurrent (int): Deprecated.
+        use_early_stopped_trials: Deprecated.
 
     .. code-block:: python
 
         from ax.service.ax_client import AxClient
         from ray import tune
         from ray.tune.suggest.ax import AxSearch
@@ -50,40 +48,45 @@ class AxSearch(Searcher):
             {"name": "x2", "type": "range", "bounds": [0.0, 1.0]},
         ]
-        algo = AxSearch(parameters=parameters,
-            objective_name="hartmann6", max_concurrent=4)
-        tune.run(my_func, algo=algo)
+        def easy_objective(config):
+            for i in range(100):
+                intermediate_result = config["x1"] + config["x2"] * i
+                tune.track.log(score=intermediate_result)
+
+        client = AxClient(enforce_sequential_optimization=False)
+        client.create_experiment(parameters=parameters, objective_name="score")
+        algo = AxSearch(client)
+        tune.run(easy_objective, search_alg=algo)
     """
 
     def __init__(self,
                  ax_client,
-                 max_concurrent=10,
                  mode="max",
-                 use_early_stopped_trials=None):
+                 use_early_stopped_trials=None,
+                 max_concurrent=None):
         assert ax is not None, "Ax must be installed!"
-        assert type(max_concurrent) is int and max_concurrent > 0
         self._ax = ax_client
         exp = self._ax.experiment
         self._objective_name = exp.optimization_config.objective.metric.name
-        if self._ax._enforce_sequential_optimization:
-            logger.warning("Detected sequential enforcement. Setting max "
-                           "concurrency to 1.")
-            max_concurrent = 1
         self.max_concurrent = max_concurrent
         self._parameters = list(exp.parameters)
-        self._live_index_mapping = {}
+        self._live_trial_mapping = {}
         super(AxSearch, self).__init__(
             metric=self._objective_name,
             mode=mode,
             max_concurrent=max_concurrent,
             use_early_stopped_trials=use_early_stopped_trials)
+        if self._ax._enforce_sequential_optimization:
+            logger.warning("Detected sequential enforcement. Be sure to use "
+                           "a ConcurrencyLimiter.")
 
     def suggest(self, trial_id):
+        if self.max_concurrent:
+            if len(self._live_trial_mapping) >= self.max_concurrent:
+                return None
         parameters, trial_index = self._ax.get_next_trial()
-        self._live_index_mapping[trial_id] = trial_index
+        self._live_trial_mapping[trial_id] = trial_index
         return parameters
 
     def on_trial_complete(self, trial_id, result=None, error=False):
@@ -93,10 +96,10 @@ class AxSearch(Searcher):
         """
         if result:
             self._process_result(trial_id, result)
-        self._live_index_mapping.pop(trial_id)
+        self._live_trial_mapping.pop(trial_id)
 
     def _process_result(self, trial_id, result):
-        ax_trial_index = self._live_index_mapping[trial_id]
+        ax_trial_index = self._live_trial_mapping[trial_id]
         metric_dict = {
             self._objective_name: (result[self._objective_name], 0.0)
         }
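
With the deprecated keyword still set, the restored gate in suggest() simply reports no suggestion (None) once the number of live Ax trials reaches the cap; Tune retries later, and on_trial_complete() frees a slot. A hedged sketch of that gate in isolation, assuming client is an AxClient configured as in the docstring above; the trial ids are made up:

# Back-compat path only; new code should prefer ConcurrencyLimiter.
algo = AxSearch(client, max_concurrent=2)

assert algo.suggest("trial_1") is not None  # one live trial
assert algo.suggest("trial_2") is not None  # two live trials, at the cap
assert algo.suggest("trial_3") is None      # gate trips until a slot frees

algo.on_trial_complete("trial_1", result={"score": 1.0})  # frees a slot
assert algo.suggest("trial_3") is not None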

python/ray/tune/tests/test_tune_restore.py

@@ -195,7 +195,9 @@ class HyperoptWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
             metric="loss",
             mode="min",
             random_state_seed=5,
-            n_initial_points=1)
+            n_initial_points=1,
+            max_concurrent=1000  # Here to avoid breaking back-compat.
+        )
         return search_alg, cost
@@ -228,9 +230,11 @@ class SkoptWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
             reporter(loss=(space["height"]**2 + space["width"]**2))
 
         search_alg = SkOptSearch(
-            optimizer, ["width", "height"],
+            optimizer,
+            ["width", "height"],
             metric="loss",
             mode="min",
+            max_concurrent=1000,  # Here to avoid breaking back-compat.
             points_to_evaluate=previously_run_params,
             evaluated_rewards=known_rewards)
         return search_alg, cost
@@ -243,11 +247,15 @@ class NevergradWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
         optimizer = optimizerlib.OnePlusOne(instrumentation)
 
         def cost(space, reporter):
-            reporter(
-                mean_loss=(space["height"] - 14)**2 - abs(space["width"] - 3))
+            reporter(loss=(space["height"] - 14)**2 - abs(space["width"] - 3))
 
         search_alg = NevergradSearch(
-            optimizer, parameter_names, metric="mean_loss", mode="min")
+            optimizer,
+            parameter_names,
+            metric="loss",
+            mode="min",
+            max_concurrent=1000,  # Here to avoid breaking back-compat.
+        )
         return search_alg, cost
@@ -273,14 +281,13 @@ class SigOptWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
         ]
 
         def cost(space, reporter):
-            reporter(
-                mean_loss=(space["height"] - 14)**2 - abs(space["width"] - 3))
+            reporter(loss=(space["height"] - 14)**2 - abs(space["width"] - 3))
 
         search_alg = SigOptSearch(
             space,
             name="SigOpt Example Experiment",
             max_concurrent=1,
-            metric="mean_loss",
+            metric="loss",
             mode="min")
         return search_alg, cost
@@ -298,9 +305,8 @@ class ZOOptWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
             "width": (ValueType.DISCRETE, [0, 20], False)
         }
 
-        def cost(dim_dict, reporter):
-            reporter(
-                loss=(dim_dict["height"] - 14)**2 - abs(dim_dict["width"] - 3))
+        def cost(param, reporter):
+            reporter(loss=(param["height"] - 14)**2 - abs(param["width"] - 3))
 
         search_alg = ZOOptSearch(
             algo="Asracos",  # only support ASRacos currently