[tune] SigOpt multi-objective search + experiments (#10457)

This commit is contained in:
raoul-khour-ts 2020-09-01 19:22:29 -04:00 committed by GitHub
parent 2b95b613f2
commit 3b10b67a15
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 345 additions and 25 deletions

View file

@ -564,6 +564,27 @@ py_test(
# args = ["--smoke-test"]
# )
# Needs SigOpt API key.
# py_test(
# name = "sigopt_multi_objective_example",
# size = "medium",
# srcs = ["examples/sigopt_multi_objective_example.py"],
# deps = [":tune_lib"],
# tags = ["exclusive", "example"],
# args = ["--smoke-test"]
# )
# Needs SigOpt API key.
# py_test(
# name = "sigopt_prior_beliefs_example",
# size = "medium",
# srcs = ["examples/sigopt_prior_beliefs_example.py"],
# deps = [":tune_lib"],
# tags = ["exclusive", "example"],
# args = ["--smoke-test"]
# )
py_test(
name = "skopt_example",
size = "medium",

View file

@ -0,0 +1,79 @@
"""This test checks that SigOpt is functional.
It also checks that it is usable with a separate scheduler.
"""
import time
import ray
import numpy as np
from ray import tune
from ray.tune.schedulers import FIFOScheduler
from ray.tune.suggest.sigopt import SigOptSearch
# Fixed seed so every run draws the same two sample vectors.
np.random.seed(0)
vector1 = np.random.normal(0, 0.1, 100)
vector2 = np.random.normal(0, 0.1, 100)


def evaluate(w1, w2):
    """Score a weighting of the two module-level sample vectors.

    Args:
        w1: Weight applied to ``vector1``.
        w2: Weight applied to ``vector2``.

    Returns:
        Tuple ``(mean, std)`` of the element-wise weighted sum.
    """
    blended = vector1 * w1 + vector2 * w2
    return blended.mean(), blended.std()
def easy_objective(config):
    """Trainable that reports the mean, std and their ratio for one config.

    Args:
        config (dict): Must provide ``"w1"`` and ``"total_weight"``; the
            second weight is whatever remains of the total after ``w1``.
    """
    # Hyperparameters
    first_weight = config["w1"]
    second_weight = config["total_weight"] - first_weight
    average, std = evaluate(first_weight, second_weight)
    sharpe = average / std
    tune.report(average=average, std=std, sharpe=sharpe)
    time.sleep(0.1)
if __name__ == "__main__":
    import argparse
    import os

    assert "SIGOPT_KEY" in os.environ, \
        "SigOpt API key must be stored as environment variable at SIGOPT_KEY"

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing")
    args, _ = parser.parse_known_args()
    ray.init()

    # The smoke test shrinks both the number of trials and the SigOpt
    # observation budget to the same small value.
    budget = 10 if args.smoke_test else 1000

    # Only w1 is searched; easy_objective derives the second weight from
    # the fixed total_weight passed through the Tune config.
    space = [dict(name="w1", type="double", bounds=dict(min=0, max=1))]
    config = dict(num_samples=budget, config=dict(total_weight=1))

    # "sharpe" is reported with mode "obs", i.e. it is not an optimization
    # target alongside "average" (max) and "std" (min).
    algo = SigOptSearch(
        space,
        name="SigOpt Example Multi Objective Experiment",
        observation_budget=budget,
        max_concurrent=1,
        metric=["average", "std", "sharpe"],
        mode=["max", "min", "obs"])
    scheduler = FIFOScheduler()

    tune.run(
        easy_objective,
        name="my_exp",
        search_alg=algo,
        scheduler=scheduler,
        **config)

View file

@ -0,0 +1,110 @@
"""This test checks that SigOpt is functional.
It also checks that it is usable with a separate scheduler.
"""
import time
import ray
import numpy as np
from ray import tune
from ray.tune.schedulers import FIFOScheduler
from ray.tune.suggest.sigopt import SigOptSearch
# Fixed seed so every run draws the same three sample vectors.
np.random.seed(0)
vector1 = np.random.normal(0.0, 0.1, 100)
vector2 = np.random.normal(0.0, 0.1, 100)
vector3 = np.random.normal(0.0, 0.1, 100)


def evaluate(w1, w2, w3):
    """Score a weighting of the three module-level sample vectors.

    Args:
        w1: Weight applied to ``vector1``.
        w2: Weight applied to ``vector2``.
        w3: Weight applied to ``vector3``.

    Returns:
        Tuple ``(mean, std)`` of the element-wise weighted sum.
    """
    blended = vector1 * w1 + vector2 * w2 + vector3 * w3
    return blended.mean(), blended.std()
def easy_objective(config):
    """Trainable that reports the mean and std for one weight pair.

    Args:
        config (dict): Must provide the weights ``"w1"`` and ``"w2"``; the
            third weight is derived so the three never sum to more than one.
    """
    # Hyperparameters
    w1, w2 = config["w1"], config["w2"]
    total = w1 + w2
    w3 = 0
    if total > 1:
        # Over-allocated: rescale w1 and w2 to sum to one, leaving w3 at 0.
        w1 /= total
        w2 /= total
    else:
        # Whatever is left of the unit budget goes to the third vector.
        w3 = 1 - total
    average, std = evaluate(w1, w2, w3)
    tune.report(average=average, std=std)
    time.sleep(0.1)
if __name__ == "__main__":
    import argparse
    import os

    from sigopt import Connection

    assert "SIGOPT_KEY" in os.environ, \
        "SigOpt API key must be stored as environment variable at SIGOPT_KEY"

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing")
    args, _ = parser.parse_known_args()
    ray.init()

    samples = 10 if args.smoke_test else 1000

    # Both weights share the same bounds and normal prior, so their
    # parameter definitions are generated from a single template.
    parameters = [{
        "name": weight_name,
        "bounds": {
            "max": 1,
            "min": 0
        },
        "prior": {
            "mean": 1 / 3,
            "name": "normal",
            "scale": 0.2
        },
        "type": "double"
    } for weight_name in ("w1", "w2")]

    # "std" is minimized; "average" is only stored with each observation.
    metrics = [
        {"name": "std", "objective": "minimize", "strategy": "optimize"},
        {"name": "average", "strategy": "store"},
    ]

    # The experiment is created directly through the SigOpt client so the
    # priors can be set, then handed to SigOptSearch by id.
    conn = Connection(client_token=os.environ["SIGOPT_KEY"])
    experiment = conn.experiments().create(
        name="prior experiment example",
        parameters=parameters,
        metrics=metrics,
        observation_budget=samples,
        parallel_bandwidth=1)

    config = {"num_samples": samples, "config": {}}
    algo = SigOptSearch(
        connection=conn,
        experiment_id=experiment.id,
        name="SigOpt Example Existing Experiment",
        max_concurrent=1,
        metric=["average", "std"],
        mode=["obs", "min"])
    scheduler = FIFOScheduler()

    tune.run(
        easy_objective,
        name="my_exp",
        search_alg=algo,
        scheduler=scheduler,
        **config)

View file

@ -32,16 +32,27 @@ class SigOptSearch(Searcher):
space (list of dict): SigOpt configuration. Parameters will be sampled
from this configuration and will be used to override
parameters generated in the variant generation process.
Not used if existing experiment_id is given
name (str): Name of experiment. Required by SigOpt.
max_concurrent (int): Number of maximum concurrent trials supported
based on the user's SigOpt plan. Defaults to 1.
connection (Connection): An existing connection to SigOpt.
experiment_id (str): Optional, if given will connect to an existing
experiment. This allows for a more interactive experience with
SigOpt, such as prior beliefs and constraints.
observation_budget (int): Optional, can improve SigOpt performance.
project (str): Optional, Project name to assign this experiment to.
SigOpt can group experiments by project
metric (str): The training result objective value attribute.
mode (str): One of {min, max}. Determines whether objective is
minimizing or maximizing the metric attribute.
metric (str or list(str)): If str then the training result
objective value attribute. If list(str) then a list of
metrics that can be optimized together. SigOpt currently
supports up to 2 metrics.
mode (str or list(str)): If experiment_id is given then this
field is ignored. If str then must be one of {min, max}.
If list then must be comprised of {min, max, obs}. Determines
whether objective is minimizing or maximizing the metric
attribute. If metric is a list then mode must be a list
of the same length as metric.
Example:
@ -68,21 +79,63 @@ class SigOptSearch(Searcher):
algo = SigOptSearch(
space, name="SigOpt Example Experiment",
max_concurrent=1, metric="mean_loss", mode="min")
Example:
.. code-block:: python
space = [
{
'name': 'width',
'type': 'int',
'bounds': {
'min': 0,
'max': 20
},
},
{
'name': 'height',
'type': 'int',
'bounds': {
'min': -100,
'max': 100
},
},
]
algo = SigOptSearch(
space, name="SigOpt Multi Objective Example Experiment",
max_concurrent=1, metric=["average", "std"], mode=["max", "min"])
"""
# Maps each accepted `mode` entry to the objective/strategy fields of a
# SigOpt metric definition; "obs" metrics carry only a "store" strategy.
OBJECTIVE_MAP = dict(
    max=dict(objective="maximize", strategy="optimize"),
    min=dict(objective="minimize", strategy="optimize"),
    obs=dict(strategy="store"),
)
def __init__(self,
space,
space=None,
name="Default Tune Experiment",
max_concurrent=1,
reward_attr=None,
connection=None,
experiment_id=None,
observation_budget=None,
project=None,
metric="episode_reward_mean",
mode="max",
**kwargs):
assert (experiment_id is
None) ^ (space is None), "space xor experiment_id must be set"
assert type(max_concurrent) is int and max_concurrent > 0
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
if connection is not None:
self.conn = connection
@ -95,25 +148,33 @@ class SigOptSearch(Searcher):
self.conn = sgo.Connection(client_token=os.environ["SIGOPT_KEY"])
self._max_concurrent = max_concurrent
if isinstance(metric, str):
metric = [metric]
mode = [mode]
self._metric = metric
if mode == "max":
self._metric_op = 1.
elif mode == "min":
self._metric_op = -1.
self._live_trial_mapping = {}
sigopt_params = dict(
name=name,
parameters=space,
parallel_bandwidth=self._max_concurrent)
if experiment_id is None:
sigopt_params = dict(
name=name,
parameters=space,
parallel_bandwidth=self._max_concurrent)
if observation_budget is not None:
sigopt_params["observation_budget"] = observation_budget
if observation_budget is not None:
sigopt_params["observation_budget"] = observation_budget
if project is not None:
sigopt_params["project"] = project
if project is not None:
sigopt_params["project"] = project
self.experiment = self.conn.experiments().create(**sigopt_params)
if len(metric) > 1 and observation_budget is None:
    # Adjacent string literals are joined verbatim, so the space between
    # "an" and "experiment" has to be written out explicitly.
    raise ValueError(
        "observation_budget is required for an "
        "experiment with more than one optimized metric")
sigopt_params["metrics"] = self.serialize_metric(metric, mode)
self.experiment = self.conn.experiments().create(**sigopt_params)
else:
self.experiment = self.conn.experiments(experiment_id).fetch()
super(SigOptSearch, self).__init__(metric=metric, mode=mode, **kwargs)
@ -139,10 +200,11 @@ class SigOptSearch(Searcher):
Creates SigOpt Observation object for trial.
"""
if result:
self.conn.experiments(self.experiment.id).observations().create(
payload = dict(
suggestion=self._live_trial_mapping[trial_id].id,
value=self._metric_op * result[self._metric],
)
values=self.serialize_result(result))
self.conn.experiments(
self.experiment.id).observations().create(**payload)
# Update the experiment object
self.experiment = self.conn.experiments(self.experiment.id).fetch()
elif error:
@ -151,6 +213,37 @@ class SigOptSearch(Searcher):
failed=True, suggestion=self._live_trial_mapping[trial_id].id)
del self._live_trial_mapping[trial_id]
@staticmethod
def serialize_metric(metrics, modes):
    """Convert Tune metric names and modes into SigOpt metric definitions.

    See https://app.sigopt.com/docs/objects/metric

    Args:
        metrics (list(str)): Names of the reported result attributes.
        modes (list(str)): One key of ``OBJECTIVE_MAP`` per metric.

    Returns:
        list(dict): One dict per metric holding its ``name`` plus the
        fields that ``OBJECTIVE_MAP`` defines for its mode.

    Raises:
        ValueError: If ``metrics`` and ``modes`` differ in length.
        KeyError: If a mode is not a key of ``OBJECTIVE_MAP``.
    """
    metrics, modes = list(metrics), list(modes)
    # zip() stops at the shorter input, which would silently drop metrics
    # from the experiment definition; reject the mismatch up front instead.
    if len(metrics) != len(modes):
        raise ValueError(
            "metric and mode must have the same length, got "
            f"{len(metrics)} metric(s) and {len(modes)} mode(s)")
    # The ** unpacking builds a fresh dict per metric, so the shared
    # OBJECTIVE_MAP entries are never mutated or aliased.
    return [
        dict(name=metric, **SigOptSearch.OBJECTIVE_MAP[mode])
        for metric, mode in zip(metrics, modes)
    ]
def serialize_result(self, result):
    """Build the SigOpt metric evaluations for one trial result.

    See https://app.sigopt.com/docs/objects/metric_evaluation

    Args:
        result (dict): Trial result keyed by metric name.

    Returns:
        list(dict): A ``{"name", "value"}`` entry per configured metric,
        in the order of ``self._metric``.

    Raises:
        ValueError: If any configured metric is absent from ``result``.
    """
    missing_scores = [name for name in self._metric if name not in result]
    if missing_scores:
        raise ValueError(
            f"Some metrics specified during initialization are missing. "
            f"Missing metrics: {missing_scores}, provided result {result}")
    return [dict(name=name, value=result[name]) for name in self._metric]
def save(self, checkpoint_path):
trials_object = (self.conn, self.experiment)
with open(checkpoint_path, "wb") as outputFile:

View file

@ -21,10 +21,14 @@ class Searcher:
`suggest` will be passed a trial_id, which will be used in
subsequent notifications.
Not all implementations support multi objectives.
Args:
metric (str): The training result objective value attribute.
mode (str): One of {min, max}. Determines whether objective is
minimizing or maximizing the metric attribute.
metric (str or list): The training result objective value attribute. If
list, then a list of training result objective value attributes.
mode (str or list): If str, one of {min, max}. If list, each entry is
one of {min, max, obs}. Determines whether objective is minimizing
or maximizing the metric attribute. Must match type of metric.
.. code-block:: python
@ -65,7 +69,20 @@ class Searcher:
"DeprecationWarning: `max_concurrent` is deprecated for this "
"search algorithm. Use tune.suggest.ConcurrencyLimiter() "
"instead. This will raise an error in future versions of Ray.")
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
# metric and mode are either both a single str or both a list.
assert isinstance(
    metric, type(mode)), "metric and mode must be of the same type"
if isinstance(mode, str):
    assert mode in ["min", "max"
                    ], "if `mode` is a str must be 'min' or 'max'!"
elif isinstance(mode, list):
    # One mode per metric; "obs" is additionally accepted in list form.
    assert len(mode) == len(
        metric), "Metric and mode must be the same length"
    assert all(mod in ["min", "max", "obs"] for mod in
               mode), "All of mode must be 'min' or 'max' or 'obs'!"
else:
    raise ValueError("Mode must either be a list or string")
self._metric = metric
self._mode = mode