mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
[Tune] Add ZOOpt search algorithm (#7960)
* add zoopt * add zoopt search algo * add zoopt * fix zoopt * add zoopt requirements * fix zoopt * remove generated guides * Apply suggestions from code review Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
parent
956ea7c944
commit
5c274fe631
7 changed files with 309 additions and 2 deletions
|
@ -127,6 +127,10 @@ $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE}
|
|||
python /ray/python/ray/tune/examples/dragonfly_example.py \
|
||||
--smoke-test
|
||||
|
||||
$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} --memory-swap=-1 $DOCKER_SHA \
|
||||
python /ray/python/ray/tune/examples/zoopt_example.py \
|
||||
--smoke-test
|
||||
|
||||
# Commenting out because flaky
|
||||
# $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} --memory-swap=-1 $DOCKER_SHA \
|
||||
# python /ray/python/ray/tune/examples/pbt_memnn_example.py \
|
||||
|
|
|
@ -286,6 +286,59 @@ Take a look at `an example here <https://github.com/ray-project/ray/blob/master/
|
|||
:show-inheritance:
|
||||
:noindex:
|
||||
|
||||
ZOOpt Search
|
||||
------------
|
||||
|
||||
The ``ZOOptSearch`` is a SearchAlgorithm for derivative-free optimization. It is backed by the `ZOOpt <https://github.com/polixir/ZOOpt>`__ package. Currently, Asynchronous Sequential RAndomized COordinate Shrinking (ASRacos) algorithm is implemented in Tune. Note that this class does not extend ``ray.tune.suggest.BasicVariantGenerator``, so you will not be able to use Tune’s default variant generation/search space declaration when using ZOOptSearch.
|
||||
|
||||
In order to use this search algorithm, you will need to install the ZOOpt package **(>=0.4.0)** via the following command:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ pip install -U zoopt
|
||||
|
||||
Keep in mind that zoopt only supports Python 3.
|
||||
|
||||
This algorithm allows users to mix continuous dimensions and discrete dimensions, for example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
dim_dict = {
|
||||
# for continuous dimensions: (continuous, search_range, precision)
|
||||
"height": (ValueType.CONTINUOUS, [-10, 10], 1e-2),
|
||||
# for discrete dimensions: (discrete, search_range, has_order)
|
||||
"width": (ValueType.DISCRETE, [-10, 10], False)
|
||||
}
|
||||
|
||||
config = {
|
||||
"num_samples": 200 if args.smoke_test else 1000,
|
||||
"config": {
|
||||
"iterations": 10, # evaluation times
|
||||
},
|
||||
"stop": {
|
||||
"timesteps_total": 10 # cumstom stop rules
|
||||
}
|
||||
}
|
||||
|
||||
zoopt_search = ZOOptSearch(
|
||||
algo="Asracos", # only support ASRacos currently
|
||||
budget=config["num_samples"],
|
||||
dim_dict=dim_dict,
|
||||
max_concurrent=4,
|
||||
metric="mean_loss",
|
||||
mode="min")
|
||||
|
||||
run(my_objective,
|
||||
search_alg=zoopt_search,
|
||||
name="zoopt_search",
|
||||
**config)
|
||||
|
||||
An example of this can be found in `zoopt_example.py <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/zoopt_example.py>`__.
|
||||
|
||||
.. autoclass:: ray.tune.suggest.zoopt.ZOOptSearch
|
||||
:show-inheritance:
|
||||
:noindex:
|
||||
|
||||
Contributing a New Algorithm
|
||||
----------------------------
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@ RUN pip install -U h5py # Mutes FutureWarnings
|
|||
RUN pip install --upgrade bayesian-optimization
|
||||
RUN pip install --upgrade hyperopt==0.1.2
|
||||
RUN pip install ConfigSpace==0.4.10
|
||||
RUN pip install --upgrade sigopt nevergrad scikit-optimize hpbandster lightgbm xgboost torch torchvision tensorboardX dragonfly-opt
|
||||
RUN pip install --upgrade sigopt nevergrad scikit-optimize hpbandster lightgbm xgboost torch torchvision tensorboardX dragonfly-opt zoopt
|
||||
RUN pip install -U tabulate mlflow
|
||||
RUN pip install -U pytest-remotedata>=0.3.1
|
||||
RUN pip install -U matplotlib jupyter pandas
|
||||
|
|
|
@ -26,4 +26,6 @@ tensorboardX
|
|||
tensorflow_probability
|
||||
torch
|
||||
torchvision
|
||||
xgboost
|
||||
xgboost
|
||||
zoopt>=0.4.0
|
||||
dill
|
||||
|
|
65
python/ray/tune/examples/zoopt_example.py
Normal file
65
python/ray/tune/examples/zoopt_example.py
Normal file
|
@ -0,0 +1,65 @@
|
|||
"""This test checks that ZOOpt is functional.
|
||||
|
||||
It also checks that it is usable with a separate scheduler.
|
||||
"""
|
||||
|
||||
import ray
|
||||
from ray.tune import run
|
||||
from ray.tune.suggest.zoopt import ZOOptSearch
|
||||
from ray.tune.schedulers import AsyncHyperBandScheduler
|
||||
from zoopt import ValueType
|
||||
|
||||
|
||||
def easy_objective(config, reporter):
|
||||
import time
|
||||
time.sleep(0.2)
|
||||
for i in range(config["iterations"]):
|
||||
reporter(
|
||||
timesteps_total=i,
|
||||
mean_loss=(config["height"] - 14)**2 - abs(config["width"] - 3))
|
||||
time.sleep(0.02)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--smoke-test", action="store_true", help="Finish quickly for testing")
|
||||
args, _ = parser.parse_known_args()
|
||||
ray.init()
|
||||
|
||||
# This dict could mix continuous dimensions and discrete dimensions,
|
||||
# for example:
|
||||
dim_dict = {
|
||||
# for continuous dimensions: (continuous, search_range, precision)
|
||||
"height": (ValueType.CONTINUOUS, [-10, 10], 1e-2),
|
||||
# for discrete dimensions: (discrete, search_range, has_order)
|
||||
"width": (ValueType.DISCRETE, [-10, 10], False)
|
||||
}
|
||||
|
||||
config = {
|
||||
"num_samples": 200 if args.smoke_test else 1000,
|
||||
"config": {
|
||||
"iterations": 10, # evaluation times
|
||||
},
|
||||
"stop": {
|
||||
"timesteps_total": 10 # cumstom stop rules
|
||||
}
|
||||
}
|
||||
|
||||
zoopt_search = ZOOptSearch(
|
||||
algo="Asracos", # only support ASRacos currently
|
||||
budget=config["num_samples"],
|
||||
dim_dict=dim_dict,
|
||||
max_concurrent=4,
|
||||
metric="mean_loss",
|
||||
mode="min")
|
||||
|
||||
scheduler = AsyncHyperBandScheduler(metric="mean_loss", mode="min")
|
||||
|
||||
run(easy_objective,
|
||||
search_alg=zoopt_search,
|
||||
name="zoopt_search",
|
||||
scheduler=scheduler,
|
||||
**config)
|
159
python/ray/tune/suggest/zoopt.py
Normal file
159
python/ray/tune/suggest/zoopt.py
Normal file
|
@ -0,0 +1,159 @@
|
|||
import copy
|
||||
import logging
|
||||
import dill as pickle
|
||||
from zoopt import Dimension2, Parameter
|
||||
from zoopt.algos.opt_algorithms.racos.sracos import SRacosTune
|
||||
from ray.tune.suggest.suggestion import SuggestionAlgorithm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ZOOptSearch(SuggestionAlgorithm):
|
||||
"""A wrapper around ZOOpt to provide trial suggestions.
|
||||
|
||||
Requires zoopt package (>=0.4.0) to be installed. You can install it
|
||||
with the command: ``pip install -U zoopt``.
|
||||
|
||||
Parameters:
|
||||
algo (str): To specify an algorithm in zoopt you want to use.
|
||||
Only support ASRacos currently.
|
||||
budget (int): Number of samples.
|
||||
dim_dict (dict): Dimension dictionary.
|
||||
For continuous dimensions: (continuous, search_range, precision);
|
||||
For discrete dimensions: (discrete, search_range, has_order).
|
||||
More details can be found in zoopt package.
|
||||
max_concurrent (int): Number of maximum concurrent trials.
|
||||
Defaults to 10.
|
||||
metric (str): The training result objective value attribute.
|
||||
Defaults to "episode_reward_mean".
|
||||
mode (str): One of {min, max}. Determines whether objective is
|
||||
minimizing or maximizing the metric attribute.
|
||||
Defaults to "min".
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from ray.tune import run
|
||||
from ray.tune.suggest.zoopt import ZOOptSearch
|
||||
from zoopt import ValueType
|
||||
|
||||
dim_dict = {
|
||||
"height": (ValueType.CONTINUOUS, [-10, 10], 1e-2),
|
||||
"width": (ValueType.DISCRETE, [-10, 10], False)
|
||||
}
|
||||
|
||||
config = {
|
||||
"num_samples": 200,
|
||||
"config": {
|
||||
"iterations": 10, # evaluation times
|
||||
},
|
||||
"stop": {
|
||||
"timesteps_total": 10 # cumstom stop rules
|
||||
}
|
||||
}
|
||||
|
||||
zoopt_search = ZOOptSearch(
|
||||
algo="Asracos", # only support Asracos currently
|
||||
budget=config["num_samples"],
|
||||
dim_dict=dim_dict,
|
||||
max_concurrent=4,
|
||||
metric="mean_loss",
|
||||
mode="min")
|
||||
|
||||
run(my_objective,
|
||||
search_alg=zoopt_search,
|
||||
name="zoopt_search",
|
||||
**config)
|
||||
|
||||
"""
|
||||
|
||||
optimizer = None
|
||||
|
||||
def __init__(self,
|
||||
algo="asracos",
|
||||
budget=None,
|
||||
dim_dict=None,
|
||||
max_concurrent=10,
|
||||
metric="episode_reward_mean",
|
||||
mode="min",
|
||||
**kwargs):
|
||||
|
||||
assert budget is not None, "`budget` should not be None!"
|
||||
assert dim_dict is not None, "`dim_list` should not be None!"
|
||||
assert type(max_concurrent) is int and max_concurrent > 0
|
||||
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
|
||||
_algo = algo.lower()
|
||||
assert _algo in ["asracos", "sracos"
|
||||
], "`algo` must be in ['asracos', 'sracos'] currently"
|
||||
|
||||
self._max_concurrent = max_concurrent
|
||||
self._metric = metric
|
||||
if mode == "max":
|
||||
self._metric_op = -1.
|
||||
elif mode == "min":
|
||||
self._metric_op = 1.
|
||||
self._live_trial_mapping = {}
|
||||
|
||||
self._dim_keys = []
|
||||
_dim_list = []
|
||||
for k in dim_dict:
|
||||
self._dim_keys.append(k)
|
||||
_dim_list.append(dim_dict[k])
|
||||
|
||||
dim = Dimension2(_dim_list)
|
||||
par = Parameter(budget=budget)
|
||||
if _algo == "sracos" or _algo == "asracos":
|
||||
self.optimizer = SRacosTune(dimension=dim, parameter=par)
|
||||
|
||||
self.solution_dict = {}
|
||||
self.best_solution_list = []
|
||||
|
||||
super(ZOOptSearch, self).__init__(
|
||||
metric=self._metric, mode=mode, **kwargs)
|
||||
|
||||
def suggest(self, trial_id):
|
||||
if self._num_live_trials() >= self._max_concurrent:
|
||||
return None
|
||||
|
||||
_solution = self.optimizer.suggest()
|
||||
if _solution:
|
||||
self.solution_dict[str(trial_id)] = _solution
|
||||
_x = _solution.get_x()
|
||||
new_trial = dict(zip(self._dim_keys, _x))
|
||||
self._live_trial_mapping[trial_id] = new_trial
|
||||
return copy.deepcopy(new_trial)
|
||||
|
||||
def on_trial_result(self, trial_id, result):
|
||||
pass
|
||||
|
||||
def on_trial_complete(self,
|
||||
trial_id,
|
||||
result=None,
|
||||
error=False,
|
||||
early_terminated=False):
|
||||
"""Notification for the completion of trial."""
|
||||
if result:
|
||||
_solution = self.solution_dict[str(trial_id)]
|
||||
_best_solution_so_far = self.optimizer.complete(
|
||||
_solution, self._metric_op * result[self._metric])
|
||||
if _best_solution_so_far:
|
||||
self.best_solution_list.append(_best_solution_so_far)
|
||||
self._process_result(trial_id, result, early_terminated)
|
||||
|
||||
del self._live_trial_mapping[trial_id]
|
||||
|
||||
def _process_result(self, trial_id, result, early_terminated=False):
|
||||
if early_terminated and self._use_early_stopped is False:
|
||||
return
|
||||
|
||||
def _num_live_trials(self):
|
||||
return len(self._live_trial_mapping)
|
||||
|
||||
def save(self, checkpoint_dir):
|
||||
trials_object = self.optimizer
|
||||
with open(checkpoint_dir, "wb") as output:
|
||||
pickle.dump(trials_object, output)
|
||||
|
||||
def restore(self, checkpoint_dir):
|
||||
with open(checkpoint_dir, "rb") as input:
|
||||
trials_object = pickle.load(input)
|
||||
self.optimizer = trials_object
|
|
@ -17,6 +17,8 @@ from ray.tune.suggest.bayesopt import BayesOptSearch
|
|||
from ray.tune.suggest.skopt import SkOptSearch
|
||||
from ray.tune.suggest.nevergrad import NevergradSearch
|
||||
from ray.tune.suggest.sigopt import SigOptSearch
|
||||
from ray.tune.suggest.zoopt import ZOOptSearch
|
||||
from zoopt import ValueType
|
||||
from ray.tune.utils import validate_save_restore
|
||||
|
||||
|
||||
|
@ -288,6 +290,28 @@ class SigOptWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
|
|||
super().testWarmStart()
|
||||
|
||||
|
||||
class ZOOptWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
|
||||
def set_basic_conf(self):
|
||||
dim_dict = {
|
||||
"height": (ValueType.CONTINUOUS, [-100, 100], 1e-2),
|
||||
"width": (ValueType.DISCRETE, [0, 20], False)
|
||||
}
|
||||
|
||||
def cost(dim_dict, reporter):
|
||||
reporter(
|
||||
loss=(dim_dict["height"] - 14)**2 - abs(dim_dict["width"] - 3))
|
||||
|
||||
search_alg = ZOOptSearch(
|
||||
algo="Asracos", # only support ASRacos currently
|
||||
budget=200,
|
||||
dim_dict=dim_dict,
|
||||
max_concurrent=1,
|
||||
metric="loss",
|
||||
mode="min")
|
||||
|
||||
return search_alg, cost
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
|
|
Loading…
Add table
Reference in a new issue