mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[tune] Save and Restore for bayesopt (#5623)
This commit is contained in:
parent
4d16677a68
commit
336aef1774
3 changed files with 79 additions and 1 deletions
|
@ -4,6 +4,7 @@ from __future__ import print_function
|
|||
|
||||
import copy
|
||||
import logging
|
||||
import pickle
|
||||
try: # Python 3 only -- needed for lint test.
|
||||
import bayes_opt as byo
|
||||
except ImportError:
|
||||
|
@ -111,3 +112,13 @@ class BayesOptSearch(SuggestionAlgorithm):
|
|||
|
||||
def _num_live_trials(self):
|
||||
return len(self._live_trial_mapping)
|
||||
|
||||
def save(self, checkpoint_dir):
|
||||
trials_object = self.optimizer
|
||||
with open(checkpoint_dir, "wb") as output:
|
||||
pickle.dump(trials_object, output)
|
||||
|
||||
def restore(self, checkpoint_dir):
|
||||
with open(checkpoint_dir, "rb") as input:
|
||||
trials_object = pickle.load(input)
|
||||
self.optimizer = trials_object
|
||||
|
|
|
@ -3,6 +3,7 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
|
||||
import logging
|
||||
import pickle
|
||||
try:
|
||||
import skopt as sko
|
||||
except ImportError:
|
||||
|
@ -157,3 +158,13 @@ class SkOptSearch(SuggestionAlgorithm):
|
|||
|
||||
def _num_live_trials(self):
|
||||
return len(self._live_trial_mapping)
|
||||
|
||||
def save(self, checkpoint_dir):
|
||||
trials_object = self._skopt_opt
|
||||
with open(checkpoint_dir, "wb") as output:
|
||||
pickle.dump(trials_object, output)
|
||||
|
||||
def restore(self, checkpoint_dir):
|
||||
with open(checkpoint_dir, "rb") as input:
|
||||
trials_object = pickle.load(input)
|
||||
self._skopt_opt = trials_object
|
||||
|
|
|
@ -15,6 +15,7 @@ from ray.tests.utils import recursive_fnmatch
|
|||
from ray.tune.util import validate_save_restore
|
||||
from ray.rllib import _register_all
|
||||
from ray.tune.suggest.hyperopt import HyperOptSearch
|
||||
from ray.tune.suggest.bayesopt import BayesOptSearch
|
||||
|
||||
|
||||
class TuneRestoreTest(unittest.TestCase):
|
||||
|
@ -146,7 +147,7 @@ class HyperoptWarmStartTest(unittest.TestCase):
|
|||
def run_exp_1(self):
|
||||
search_alg, cost = self.set_basic_conf()
|
||||
results_exp_1 = tune.run(cost, num_samples=15, search_alg=search_alg)
|
||||
self.log_dir = os.path.join(self.tmpdir, "trials_algo1.pkl")
|
||||
self.log_dir = os.path.join(self.tmpdir, "trials_algo_hyo.pkl")
|
||||
search_alg.save(self.log_dir)
|
||||
return results_exp_1
|
||||
|
||||
|
@ -169,5 +170,60 @@ class HyperoptWarmStartTest(unittest.TestCase):
|
|||
self.assertEqual(trials_1_config + trials_2_config, trials_3_config)
|
||||
|
||||
|
||||
class BayesoptWarmStartTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
ray.init(local_mode=True)
|
||||
self.tmpdir = tempfile.mkdtemp()
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmpdir)
|
||||
ray.shutdown()
|
||||
_register_all()
|
||||
|
||||
def set_basic_conf(self):
|
||||
space = {"width": (0, 20), "height": (-100, 100)}
|
||||
|
||||
def cost(space, reporter):
|
||||
loss = space["width"]**2 + space["height"]**2
|
||||
reporter(loss=loss)
|
||||
|
||||
search_alg = BayesOptSearch(
|
||||
space,
|
||||
max_concurrent=1,
|
||||
metric="loss",
|
||||
mode="min",
|
||||
utility_kwargs={
|
||||
"kind": "ucb",
|
||||
"kappa": 2.5,
|
||||
"xi": 0.0
|
||||
})
|
||||
return search_alg, cost
|
||||
|
||||
def run_exp_1(self):
|
||||
search_alg, cost = self.set_basic_conf()
|
||||
results_exp_1 = tune.run(cost, num_samples=15, search_alg=search_alg)
|
||||
self.log_dir = os.path.join(self.tmpdir, "trials_algo_byo.pkl")
|
||||
search_alg.save(self.log_dir)
|
||||
return results_exp_1
|
||||
|
||||
def run_exp_2(self):
|
||||
search_alg2, cost = self.set_basic_conf()
|
||||
search_alg2.restore(self.log_dir)
|
||||
return tune.run(cost, num_samples=15, search_alg=search_alg2)
|
||||
|
||||
def run_exp_3(self):
|
||||
search_alg3, cost = self.set_basic_conf()
|
||||
return tune.run(cost, num_samples=30, search_alg=search_alg3)
|
||||
|
||||
def testBayesoptWarmStart(self):
|
||||
results_exp_1 = self.run_exp_1()
|
||||
results_exp_2 = self.run_exp_2()
|
||||
results_exp_3 = self.run_exp_3()
|
||||
trials_1_config = [trial.config for trial in results_exp_1.trials]
|
||||
trials_2_config = [trial.config for trial in results_exp_2.trials]
|
||||
trials_3_config = [trial.config for trial in results_exp_3.trials]
|
||||
self.assertEqual(trials_1_config + trials_2_config, trials_3_config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
|
|
Loading…
Add table
Reference in a new issue