[tune] Add Nevergrad to Tune (#3985)

This commit is contained in:
Adi Zimmerman 2019-02-12 11:00:04 -08:00 committed by Richard Liaw
parent c523bc04ad
commit dac1969647
7 changed files with 183 additions and 2 deletions

View file

@@ -9,11 +9,13 @@ You can utilize these search algorithms as follows:
run_experiments(experiments, search_alg=SearchAlgorithm(...))
Currently, Tune offers the following search algorithms (and library integrations):
- `Grid Search and Random Search <tune-searchalg.html#variant-generation-grid-search-random-search>`__
- `BayesOpt <tune-searchalg.html#bayesopt-search>`__
- `HyperOpt <tune-searchalg.html#hyperopt-search-tree-structured-parzen-estimators>`__
- `SigOpt <tune-searchalg.html#sigopt-search>`__
- `Nevergrad <tune-searchalg.html#nevergrad-search>`__
Variant Generation (Grid Search/Random Search)
@@ -103,6 +105,31 @@ An example of this can be found in `sigopt_example.py <https://github.com/ray-pr
:show-inheritance:
:noindex:
Nevergrad Search
----------------
The ``NevergradSearch`` is a SearchAlgorithm that is backed by `Nevergrad <https://github.com/facebookresearch/nevergrad>`__ to perform derivative-free hyperparameter optimization. Note that this class does not extend ``ray.tune.suggest.BasicVariantGenerator``, so you will not be able to use Tune's default variant generation/search space declaration when using NevergradSearch.

In order to use this search algorithm, you will need to install Nevergrad via the following command:
.. code-block:: bash

    $ pip install nevergrad
Keep in mind that ``nevergrad`` is a Python 3.6+ library.

This algorithm requires using an optimizer provided by ``nevergrad``, of which there are many options. A good rundown can be found in the `Optimization <https://github.com/facebookresearch/nevergrad>`__ section of their README. You can use ``NevergradSearch`` as follows:
.. code-block:: python

    run_experiments(experiment_config, search_alg=NevergradSearch(optimizer, parameter_names, ... ))
An example of this can be found in `nevergrad_example.py <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/nevergrad_example.py>`__.
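For orientation, here is a minimal end-to-end sketch condensed from that example; the trainable name ``my_func``, the toy objective, and the experiment settings are illustrative rather than prescribed by the API:

.. code-block:: python

    import ray
    from ray.tune import register_trainable, run_experiments
    from ray.tune.suggest import NevergradSearch
    from nevergrad.optimization import optimizerlib

    def my_func(config, reporter):
        # Report a value for NevergradSearch to maximize (see reward_attr below).
        reporter(neg_mean_loss=-(config["height"] - 14)**2 + abs(config["width"] - 3))

    ray.init()
    register_trainable("my_func", my_func)

    # The optimizer's dimension must match the number of parameter names.
    optimizer = optimizerlib.OnePlusOne(dimension=2)
    algo = NevergradSearch(
        optimizer, ["height", "width"],
        max_concurrent=4,
        reward_attr="neg_mean_loss")

    run_experiments(
        {
            "nevergrad_exp": {
                "run": "my_func",
                "num_samples": 10,
                "stop": {"training_iteration": 1},
            }
        },
        search_alg=algo)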
.. autoclass:: ray.tune.suggest.NevergradSearch
:show-inheritance:
:noindex:
Scikit-Optimize Search
----------------------

View file

@@ -25,7 +25,7 @@ Features
- `HyperBand <tune-schedulers.html#asynchronous-hyperband>`__
* Mix and match different hyperparameter optimization approaches - such as using `HyperOpt with HyperBand`_ or `Nevergrad with HyperBand`_.
* Visualize results with `TensorBoard <https://www.tensorflow.org/get_started/summaries_and_tensorboard>`__, `parallel coordinates (Plot.ly) <https://plot.ly/python/parallel-coordinates-plot/>`__, and `rllab's VisKit <https://media.readthedocs.org/pdf/rllab/latest/rllab.pdf>`__.
@@ -109,3 +109,4 @@ If Tune helps you in your academic research, you are encouraged to cite `our pap
.. _HyperOpt with HyperBand: https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/hyperopt_example.py
.. _Nevergrad with HyperBand: https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/nevergrad_example.py

View file

@@ -10,5 +10,6 @@ RUN pip install -U h5py # Mutes FutureWarnings
RUN pip install --upgrade bayesian-optimization
RUN pip install --upgrade git+git://github.com/hyperopt/hyperopt.git
RUN pip install --upgrade sigopt
# RUN pip install --upgrade nevergrad
RUN pip install --upgrade scikit-optimize
RUN conda install pytorch-cpu torchvision-cpu -c pytorch

View file

@@ -0,0 +1,56 @@
"""This test checks that Nevergrad is functional.
It also checks that it is usable with a separate scheduler.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ray
from ray.tune import run_experiments, register_trainable
from ray.tune.schedulers import AsyncHyperBandScheduler
from ray.tune.suggest import NevergradSearch
def easy_objective(config, reporter):
import time
time.sleep(0.2)
for i in range(config["iterations"]):
reporter(
timesteps_total=i,
neg_mean_loss=-(config["height"] - 14)**2 +
abs(config["width"] - 3))
time.sleep(0.02)
if __name__ == '__main__':
import argparse
from nevergrad.optimization import optimizerlib
parser = argparse.ArgumentParser()
parser.add_argument(
"--smoke-test", action="store_true", help="Finish quickly for testing")
args, _ = parser.parse_known_args()
ray.init(redirect_output=True)
register_trainable("exp", easy_objective)
config = {
"nevergrad": {
"run": "exp",
"num_samples": 10 if args.smoke_test else 50,
"config": {
"iterations": 100,
},
"stop": {
"timesteps_total": 100
},
}
}
optimizer = optimizerlib.OnePlusOne(dimension=2)
algo = NevergradSearch(
optimizer, ["height", "width"],
max_concurrent=4,
reward_attr="neg_mean_loss")
scheduler = AsyncHyperBandScheduler(reward_attr="neg_mean_loss")
run_experiments(config, search_alg=algo, scheduler=scheduler)
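To try the example directly (it needs Python 3.6+ with ``ray`` and ``nevergrad`` installed), it can be invoked with the ``--smoke-test`` flag used by the test suite, e.g. ``python nevergrad_example.py --smoke-test``. Note that the optimizer's ``dimension=2`` should match the two parameter names handed to ``NevergradSearch``.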

View file

@@ -3,6 +3,7 @@ from ray.tune.suggest.basic_variant import BasicVariantGenerator
from ray.tune.suggest.suggestion import SuggestionAlgorithm
from ray.tune.suggest.bayesopt import BayesOptSearch
from ray.tune.suggest.hyperopt import HyperOptSearch
from ray.tune.suggest.nevergrad import NevergradSearch
from ray.tune.suggest.skopt import SkOptSearch
from ray.tune.suggest.sigopt import SigOptSearch
from ray.tune.suggest.variant_generator import grid_search, function, \
@@ -13,6 +14,7 @@ __all__ = [
"BasicVariantGenerator",
"BayesOptSearch",
"HyperOptSearch",
"NevergradSearch",
"SkOptSearch",
"SigOptSearch",
"SuggestionAlgorithm",

View file

@@ -0,0 +1,89 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

try:
    import nevergrad
except Exception:
    nevergrad = None

from ray.tune.suggest.suggestion import SuggestionAlgorithm


class NevergradSearch(SuggestionAlgorithm):
    """A wrapper around Nevergrad to provide trial suggestions.

    Requires Nevergrad to be installed.
    Nevergrad is an open source tool from Facebook for derivative-free
    optimization of parameters and/or hyperparameters. It features a wide
    range of optimizers in a standard ask-and-tell interface. More information
    can be found at https://github.com/facebookresearch/nevergrad.

    Parameters:
        optimizer (nevergrad.optimization.Optimizer): Optimizer provided
            from Nevergrad.
        parameter_names (list): List of parameter names. Should match
            the dimension of the optimizer output.
        max_concurrent (int): Number of maximum concurrent trials. Defaults
            to 10.
        reward_attr (str): The training result objective value attribute.
            This refers to an increasing value.

    Example:
        >>> from nevergrad.optimization import optimizerlib
        >>> optimizer = optimizerlib.OnePlusOne(dimension=1, budget=100)
        >>> config = {
        >>>     "my_exp": {
        >>>         "run": "exp",
        >>>         "num_samples": 10,
        >>>         "stop": {
        >>>             "training_iteration": 100
        >>>         },
        >>>     }
        >>> }
        >>> algo = NevergradSearch(
        >>>     optimizer, ["height"], max_concurrent=4,
        >>>     reward_attr="neg_mean_loss")
    """

    def __init__(self,
                 optimizer,
                 parameter_names,
                 max_concurrent=10,
                 reward_attr="episode_reward_mean",
                 **kwargs):
        assert nevergrad is not None, "Nevergrad must be installed!"
        assert type(max_concurrent) is int and max_concurrent > 0
        self._max_concurrent = max_concurrent
        self._parameters = parameter_names
        self._reward_attr = reward_attr
        self._nevergrad_opt = optimizer
        self._live_trial_mapping = {}
        super(NevergradSearch, self).__init__(**kwargs)

    def _suggest(self, trial_id):
        if self._num_live_trials() >= self._max_concurrent:
            return None
        # Ask the optimizer for a new point and remember it so that the
        # final result can be reported back to it in on_trial_complete().
        suggested_config = self._nevergrad_opt.ask()
        self._live_trial_mapping[trial_id] = suggested_config
        # Map the raw point onto the user-supplied parameter names.
        return dict(zip(self._parameters, suggested_config))

    def on_trial_result(self, trial_id, result):
        pass

    def on_trial_complete(self,
                          trial_id,
                          result=None,
                          error=False,
                          early_terminated=False):
        """Passes the result to Nevergrad unless early terminated or errored.

        The result is internally negated when interacting with Nevergrad
        so that Nevergrad Optimizers can "maximize" this value,
        as it minimizes by default.
        """
        ng_trial_info = self._live_trial_mapping.pop(trial_id)
        # Report the (negated) final reward back to Nevergrad, skipping
        # errored or early-terminated trials as documented above.
        if result and not error and not early_terminated:
            self._nevergrad_opt.tell(ng_trial_info, -result[self._reward_attr])

    def _num_live_trials(self):
        return len(self._live_trial_mapping)
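For context, here is a minimal standalone sketch (not part of this module) of the ask-and-tell loop that ``NevergradSearch`` drives on the user's behalf, assuming the ``ask``/``tell``/``provide_recommendation`` interface of Nevergrad's optimizers; the quadratic objective is purely illustrative:

from nevergrad.optimization import optimizerlib

# Create a 2-dimensional optimizer with a fixed evaluation budget.
optimizer = optimizerlib.OnePlusOne(dimension=2, budget=20)
for _ in range(20):
    x = optimizer.ask()            # analogous to _suggest()
    loss = sum(v**2 for v in x)    # toy objective to minimize
    # NevergradSearch instead calls tell() with the *negated* reward_attr,
    # since Nevergrad minimizes while Tune's reward_attr is maximized.
    optimizer.tell(x, loss)
print(optimizer.provide_recommendation())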

View file

@@ -362,6 +362,11 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} -e SIGOPT_KEY $DO
python /ray/python/ray/tune/examples/sigopt_example.py \
--smoke-test
# Runs only on Python3
# docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
# python /ray/python/ray/tune/examples/nevergrad_example.py \
# --smoke-test
docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
python /ray/python/ray/tune/examples/tune_mnist_keras.py \
--smoke-test