mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[Tune] Fix HEBO evaluated rewards for max mode & save/restore (#14427)
* Fix HEBO evaluated rewards for max mode * Lint * Make sure everything necessary is saved
This commit is contained in:
parent
63c2b7356e
commit
85a092c3d7
1 changed files with 9 additions and 5 deletions
|
@ -209,7 +209,7 @@ class HEBOSearch(Searcher):
|
||||||
if self._evaluated_rewards:
|
if self._evaluated_rewards:
|
||||||
self._opt.observe(
|
self._opt.observe(
|
||||||
pd.DataFrame(self._points_to_evaluate),
|
pd.DataFrame(self._points_to_evaluate),
|
||||||
np.array(self._evaluated_rewards))
|
np.array(self._evaluated_rewards) * self._metric_op)
|
||||||
else:
|
else:
|
||||||
self._initial_points = self._points_to_evaluate
|
self._initial_points = self._points_to_evaluate
|
||||||
|
|
||||||
|
@ -280,14 +280,18 @@ class HEBOSearch(Searcher):
|
||||||
numpy_random_state = None
|
numpy_random_state = None
|
||||||
torch_random_state = None
|
torch_random_state = None
|
||||||
with open(checkpoint_path, "wb") as f:
|
with open(checkpoint_path, "wb") as f:
|
||||||
pickle.dump((self._opt, self._points_to_evaluate,
|
pickle.dump((self._opt, self._initial_points, numpy_random_state,
|
||||||
numpy_random_state, torch_random_state), f)
|
torch_random_state, self._live_trial_mapping,
|
||||||
|
self._n_suggestions, self._suggestions_cache,
|
||||||
|
self._space, self._hebo_config), f)
|
||||||
|
|
||||||
def restore(self, checkpoint_path: str):
|
def restore(self, checkpoint_path: str):
|
||||||
"""Restoring current optimizer state."""
|
"""Restoring current optimizer state."""
|
||||||
with open(checkpoint_path, "rb") as f:
|
with open(checkpoint_path, "rb") as f:
|
||||||
(self._opt, self._points_to_evaluate, numpy_random_state,
|
(self._opt, self._initial_points, numpy_random_state,
|
||||||
torch_random_state) = pickle.load(f)
|
torch_random_state, self._live_trial_mapping, self._n_suggestions,
|
||||||
|
self._suggestions_cache, self._space,
|
||||||
|
self._hebo_config) = pickle.load(f)
|
||||||
if numpy_random_state is not None:
|
if numpy_random_state is not None:
|
||||||
np.random.set_state(numpy_random_state)
|
np.random.set_state(numpy_random_state)
|
||||||
if torch_random_state is not None:
|
if torch_random_state is not None:
|
||||||
|
|
Loading…
Add table
Reference in a new issue