[tune] Save and Restore for bayesopt (#5623)

Hersh Godse authored 2019-09-10 13:11:59 -07:00, committed by Richard Liaw
parent 4d16677a68
commit 336aef1774
3 changed files with 79 additions and 1 deletion

python/ray/tune/suggest/bayesopt.py

@@ -4,6 +4,7 @@ from __future__ import print_function
 
 import copy
 import logging
+import pickle
 try:  # Python 3 only -- needed for lint test.
     import bayes_opt as byo
 except ImportError:
@@ -111,3 +112,13 @@ class BayesOptSearch(SuggestionAlgorithm):
     def _num_live_trials(self):
         return len(self._live_trial_mapping)
+
+    def save(self, checkpoint_dir):
+        trials_object = self.optimizer
+        with open(checkpoint_dir, "wb") as output:
+            pickle.dump(trials_object, output)
+
+    def restore(self, checkpoint_dir):
+        with open(checkpoint_dir, "rb") as input:
+            trials_object = pickle.load(input)
+        self.optimizer = trials_object
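
For context, a minimal usage sketch of the new methods (illustrative, not part of this commit): run a few trials, checkpoint the searcher, then warm-start a fresh searcher from that file. The search space, objective, and searcher configuration mirror the test added below; the make_searcher helper, the sample counts, and the checkpoint path are hypothetical.

# Sketch only: warm-starting BayesOptSearch with the new save()/restore().
import os
import tempfile

import ray
from ray import tune
from ray.tune.suggest.bayesopt import BayesOptSearch

space = {"width": (0, 20), "height": (-100, 100)}


def cost(config, reporter):
    # Toy objective, same shape as in the test below.
    reporter(loss=config["width"]**2 + config["height"]**2)


def make_searcher():
    # Hypothetical helper; mirrors the configuration used by the test.
    return BayesOptSearch(
        space,
        max_concurrent=1,
        metric="loss",
        mode="min",
        utility_kwargs={"kind": "ucb", "kappa": 2.5, "xi": 0.0})


ray.init()

algo = make_searcher()
tune.run(cost, num_samples=10, search_alg=algo)

# save() pickles the wrapped bayes_opt optimizer (self.optimizer) to this path.
checkpoint_path = os.path.join(tempfile.mkdtemp(), "bayesopt.pkl")
algo.save(checkpoint_path)

# Later, possibly in a new process: restore() loads the pickle back, so the
# next run continues from the points the first run already observed.
algo2 = make_searcher()
algo2.restore(checkpoint_path)
tune.run(cost, num_samples=10, search_alg=algo2)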

python/ray/tune/suggest/skopt.py

@@ -3,6 +3,7 @@ from __future__ import division
 from __future__ import print_function
 
 import logging
+import pickle
 try:
     import skopt as sko
 except ImportError:
@@ -157,3 +158,13 @@ class SkOptSearch(SuggestionAlgorithm):
     def _num_live_trials(self):
         return len(self._live_trial_mapping)
+
+    def save(self, checkpoint_dir):
+        trials_object = self._skopt_opt
+        with open(checkpoint_dir, "wb") as output:
+            pickle.dump(trials_object, output)
+
+    def restore(self, checkpoint_dir):
+        with open(checkpoint_dir, "rb") as input:
+            trials_object = pickle.load(input)
+        self._skopt_opt = trials_object
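
The SkOptSearch change is symmetric: it pickles the wrapped skopt optimizer instead. A comparable sketch follows (illustrative, not part of this commit; it assumes the SkOptSearch constructor of this Tune version takes a pre-built skopt.Optimizer plus the parameter names, and it reuses the toy objective from above).

# Sketch only: the same save()/restore() round trip for SkOptSearch.
import os
import tempfile

import skopt
import ray
from ray import tune
from ray.tune.suggest.skopt import SkOptSearch


def cost(config, reporter):
    reporter(loss=config["width"]**2 + config["height"]**2)


def make_searcher():
    # Assumed constructor: a skopt.Optimizer over the bounds, plus the
    # corresponding parameter names.
    opt = skopt.Optimizer([(0, 20), (-100, 100)])
    return SkOptSearch(
        opt, ["width", "height"], max_concurrent=1, metric="loss", mode="min")


ray.init()

algo = make_searcher()
tune.run(cost, num_samples=10, search_alg=algo)

# save() pickles the wrapped skopt.Optimizer (self._skopt_opt) to this path.
checkpoint_path = os.path.join(tempfile.mkdtemp(), "skopt.pkl")
algo.save(checkpoint_path)

# restore() loads that pickle into a fresh searcher, carrying over the
# already-evaluated points.
algo2 = make_searcher()
algo2.restore(checkpoint_path)
tune.run(cost, num_samples=10, search_alg=algo2)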

python/ray/tune/tests/test_tune_restore.py

@@ -15,6 +15,7 @@ from ray.tests.utils import recursive_fnmatch
 from ray.tune.util import validate_save_restore
 from ray.rllib import _register_all
 from ray.tune.suggest.hyperopt import HyperOptSearch
+from ray.tune.suggest.bayesopt import BayesOptSearch
 
 
 class TuneRestoreTest(unittest.TestCase):
@@ -146,7 +147,7 @@ class HyperoptWarmStartTest(unittest.TestCase):
 
     def run_exp_1(self):
         search_alg, cost = self.set_basic_conf()
         results_exp_1 = tune.run(cost, num_samples=15, search_alg=search_alg)
-        self.log_dir = os.path.join(self.tmpdir, "trials_algo1.pkl")
+        self.log_dir = os.path.join(self.tmpdir, "trials_algo_hyo.pkl")
         search_alg.save(self.log_dir)
         return results_exp_1
@@ -169,5 +170,60 @@ class HyperoptWarmStartTest(unittest.TestCase):
         self.assertEqual(trials_1_config + trials_2_config, trials_3_config)
 
 
+class BayesoptWarmStartTest(unittest.TestCase):
+    def setUp(self):
+        ray.init(local_mode=True)
+        self.tmpdir = tempfile.mkdtemp()
+
+    def tearDown(self):
+        shutil.rmtree(self.tmpdir)
+        ray.shutdown()
+        _register_all()
+
+    def set_basic_conf(self):
+        space = {"width": (0, 20), "height": (-100, 100)}
+
+        def cost(space, reporter):
+            loss = space["width"]**2 + space["height"]**2
+            reporter(loss=loss)
+
+        search_alg = BayesOptSearch(
+            space,
+            max_concurrent=1,
+            metric="loss",
+            mode="min",
+            utility_kwargs={
+                "kind": "ucb",
+                "kappa": 2.5,
+                "xi": 0.0
+            })
+        return search_alg, cost
+
+    def run_exp_1(self):
+        search_alg, cost = self.set_basic_conf()
+        results_exp_1 = tune.run(cost, num_samples=15, search_alg=search_alg)
+        self.log_dir = os.path.join(self.tmpdir, "trials_algo_byo.pkl")
+        search_alg.save(self.log_dir)
+        return results_exp_1
+
+    def run_exp_2(self):
+        search_alg2, cost = self.set_basic_conf()
+        search_alg2.restore(self.log_dir)
+        return tune.run(cost, num_samples=15, search_alg=search_alg2)
+
+    def run_exp_3(self):
+        search_alg3, cost = self.set_basic_conf()
+        return tune.run(cost, num_samples=30, search_alg=search_alg3)
+
+    def testBayesoptWarmStart(self):
+        results_exp_1 = self.run_exp_1()
+        results_exp_2 = self.run_exp_2()
+        results_exp_3 = self.run_exp_3()
+        trials_1_config = [trial.config for trial in results_exp_1.trials]
+        trials_2_config = [trial.config for trial in results_exp_2.trials]
+        trials_3_config = [trial.config for trial in results_exp_3.trials]
+        self.assertEqual(trials_1_config + trials_2_config, trials_3_config)
+
+
 if __name__ == "__main__":
     unittest.main(verbosity=2)