[tune] add mode/metric parameters to tune.run (#10627)

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>

Parent: edce7a05e6
Commit: 756a9ea641

26 changed files with 322 additions and 88 deletions
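
In short: `tune.run()` now accepts `metric` and `mode` directly and forwards them to the scheduler, the search algorithm, and the returned `ExperimentAnalysis` (as `default_metric`/`default_mode`). A minimal sketch of the new call pattern, reusing the `train_mnist` trainable and its reported `mean_accuracy` from the example files touched below:

    from ray import tune
    from ray.tune.schedulers import AsyncHyperBandScheduler

    # The scheduler no longer needs metric/mode up front ...
    sched = AsyncHyperBandScheduler(time_attr="training_iteration")

    # ... because tune.run() now accepts them and passes them through.
    analysis = tune.run(
        train_mnist,
        metric="mean_accuracy",
        mode="max",
        scheduler=sched,
        num_samples=10)
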
@@ -29,5 +29,5 @@ git+https://github.com/executablebooks/sphinx-book-theme.git@0a87d26e214c419d2e6
 tabulate
 uvicorn
 werkzeug
-tune-sklearn==0.0.5
+git+git://github.com/ray-project/tune-sklearn@master#tune-sklearn
 scikit-optimize
@@ -127,7 +127,7 @@ tune_search = TuneSearchCV(
     clf,
     parameter_grid,
     search_optimization="bayesian",
-    n_iter=3,
+    n_trials=3,
     early_stopping=True,
     max_iters=10,
 )
@@ -149,7 +149,7 @@ py_test(
 
 py_test(
     name = "test_sample",
-    size = "medium",
+    size = "small",
     srcs = ["tests/test_sample.py"],
     deps = [":tune_lib"],
     tags = ["exclusive"],
@@ -121,7 +121,7 @@ if __name__ == "__main__":
     else:
         ray.init(num_cpus=2 if args.smoke_test else None)
     sched = AsyncHyperBandScheduler(
-        time_attr="training_iteration", metric="mean_accuracy")
+        time_attr="training_iteration", metric="mean_accuracy", mode="max")
     analysis = tune.run(
         train_mnist,
         name="exp",
@@ -65,9 +65,11 @@ class TrainMNIST(tune.Trainable):
 if __name__ == "__main__":
     args = parser.parse_args()
     ray.init(address=args.ray_address, num_cpus=6 if args.smoke_test else None)
-    sched = ASHAScheduler(metric="mean_accuracy")
+    sched = ASHAScheduler()
     analysis = tune.run(
         TrainMNIST,
+        metric="mean_accuracy",
+        mode="max",
         scheduler=sched,
         stop={
             "mean_accuracy": 0.95,
@@ -10,8 +10,8 @@ from ray.tune.schedulers.pbt import (PopulationBasedTraining,
 
 def create_scheduler(
         scheduler,
-        metric="episode_reward_mean",
-        mode="max",
+        metric=None,
+        mode=None,
         **kwargs,
 ):
     """Instantiate a scheduler based on the given string.
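
The `create_scheduler` shim now mirrors the schedulers themselves: `metric` and `mode` default to None and can be supplied later by `tune.run()`. A hedged sketch, assuming "async_hyperband" is one of the keys the shim registers:

    from ray.tune.schedulers import create_scheduler

    # metric/mode deliberately omitted; tune.run(metric=..., mode=...)
    # fills them in via set_search_properties().
    sched = create_scheduler("async_hyperband", max_t=100)
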
@@ -38,8 +38,8 @@ class AsyncHyperBandScheduler(FIFOScheduler):
     def __init__(self,
                  time_attr="training_iteration",
                  reward_attr=None,
-                 metric="episode_reward_mean",
-                 mode="max",
+                 metric=None,
+                 mode=None,
                  max_t=100,
                  grace_period=1,
                  reduction_factor=4,
@@ -49,7 +49,8 @@ class AsyncHyperBandScheduler(FIFOScheduler):
         assert grace_period > 0, "grace_period must be positive!"
         assert reduction_factor > 1, "Reduction Factor not valid!"
         assert brackets > 0, "brackets must be positive!"
-        assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
+        if mode:
+            assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
 
         if reward_attr is not None:
             mode = "max"
@@ -73,13 +74,41 @@ class AsyncHyperBandScheduler(FIFOScheduler):
         self._counter = 0  # for
         self._num_stopped = 0
         self._metric = metric
-        if mode == "max":
+        self._mode = mode
+        self._metric_op = None
+        if self._mode == "max":
             self._metric_op = 1.
-        elif mode == "min":
+        elif self._mode == "min":
             self._metric_op = -1.
         self._time_attr = time_attr
 
+    def set_search_properties(self, metric, mode):
+        if self._metric and metric:
+            return False
+        if self._mode and mode:
+            return False
+
+        if metric:
+            self._metric = metric
+        if mode:
+            self._mode = mode
+
+        if self._mode == "max":
+            self._metric_op = 1.
+        elif self._mode == "min":
+            self._metric_op = -1.
+
+        return True
+
     def on_trial_add(self, trial_runner, trial):
+        if not self._metric or not self._metric_op:
+            raise ValueError(
+                "{} has been instantiated without a valid `metric` ({}) or "
+                "`mode` ({}) parameter. Either pass these parameters when "
+                "instantiating the scheduler, or pass them as parameters "
+                "to `tune.run()`".format(self.__class__.__name__, self._metric,
+                                         self._mode))
+
         sizes = np.array([len(b._rungs) for b in self._brackets])
         probs = np.e**(sizes - sizes.max())
         normalized = probs / probs.sum()
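
The pattern above repeats for every scheduler in this commit: `metric`/`mode` may stay unset at construction time, `set_search_properties()` fills them in from `tune.run()`, and `on_trial_add()` raises if neither side supplied them. Shown directly against the scheduler (normally `tune.run()` makes these calls for you):

    from ray.tune.schedulers import AsyncHyperBandScheduler

    sched = AsyncHyperBandScheduler(time_attr="training_iteration", max_t=100)

    # First call succeeds: nothing was set at construction time.
    assert sched.set_search_properties("mean_accuracy", "max")

    # A second, conflicting call is rejected; this is what lets tune.run()
    # detect schedulers that were already configured with their own values.
    assert not sched.set_search_properties("episode_reward_mean", "min")
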
@@ -30,6 +30,13 @@ class HyperBandForBOHB(HyperBandScheduler):
         to current bracket. Else, create new iteration, create new bracket,
         add to bracket.
         """
+        if not self._metric or not self._metric_op:
+            raise ValueError(
+                "{} has been instantiated without a valid `metric` ({}) or "
+                "`mode` ({}) parameter. Either pass these parameters when "
+                "instantiating the scheduler, or pass them as parameters "
+                "to `tune.run()`".format(self.__class__.__name__, self._metric,
+                                         self._mode))
+
         cur_bracket = self._state["bracket"]
         cur_band = self._hyperbands[self._state["band_idx"]]
@@ -76,12 +76,13 @@ class HyperBandScheduler(FIFOScheduler):
     def __init__(self,
                  time_attr="training_iteration",
                  reward_attr=None,
-                 metric="episode_reward_mean",
-                 mode="max",
+                 metric=None,
+                 mode=None,
                  max_t=81,
                  reduction_factor=3):
         assert max_t > 0, "Max (time_attr) not valid!"
-        assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
+        if mode:
+            assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
 
         if reward_attr is not None:
             mode = "max"
@@ -108,12 +109,33 @@ class HyperBandScheduler(FIFOScheduler):
         self._state = {"bracket": None, "band_idx": 0}
         self._num_stopped = 0
         self._metric = metric
-        if mode == "max":
+        self._mode = mode
+        self._metric_op = None
+
+        if self._mode == "max":
             self._metric_op = 1.
-        elif mode == "min":
+        elif self._mode == "min":
             self._metric_op = -1.
         self._time_attr = time_attr
 
+    def set_search_properties(self, metric, mode):
+        if self._metric and metric:
+            return False
+        if self._mode and mode:
+            return False
+
+        if metric:
+            self._metric = metric
+        if mode:
+            self._mode = mode
+
+        if self._mode == "max":
+            self._metric_op = 1.
+        elif self._mode == "min":
+            self._metric_op = -1.
+
+        return True
+
     def on_trial_add(self, trial_runner, trial):
         """Adds new trial.
 
@@ -121,6 +143,13 @@ class HyperBandScheduler(FIFOScheduler):
         add to current bracket. Else, if current band is not filled,
         create new bracket, add to current bracket.
         Else, create new iteration, create new bracket, add to bracket."""
+        if not self._metric or not self._metric_op:
+            raise ValueError(
+                "{} has been instantiated without a valid `metric` ({}) or "
+                "`mode` ({}) parameter. Either pass these parameters when "
+                "instantiating the scheduler, or pass them as parameters "
+                "to `tune.run()`".format(self.__class__.__name__, self._metric,
+                                         self._mode))
+
         cur_bracket = self._state["bracket"]
         cur_band = self._hyperbands[self._state["band_idx"]]
@@ -40,13 +40,12 @@ class MedianStoppingRule(FIFOScheduler):
     def __init__(self,
                  time_attr="time_total_s",
                  reward_attr=None,
-                 metric="episode_reward_mean",
-                 mode="max",
+                 metric=None,
+                 mode=None,
                  grace_period=60.0,
                  min_samples_required=3,
                  min_time_slice=0,
                  hard_stop=True):
-        assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
         if reward_attr is not None:
             mode = "max"
             metric = reward_attr
@@ -60,15 +59,49 @@ class MedianStoppingRule(FIFOScheduler):
         self._min_samples_required = min_samples_required
         self._min_time_slice = min_time_slice
         self._metric = metric
-        assert mode in {"min", "max"}, "`mode` must be 'min' or 'max'."
-        self._worst = float("-inf") if mode == "max" else float("inf")
-        self._compare_op = max if mode == "max" else min
+        self._worst = None
+        self._compare_op = None
+
+        self._mode = mode
+        if mode:
+            assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."
+            self._worst = float("-inf") if self._mode == "max" else float(
+                "inf")
+            self._compare_op = max if self._mode == "max" else min
+
         self._time_attr = time_attr
         self._hard_stop = hard_stop
         self._trial_state = {}
         self._last_pause = collections.defaultdict(lambda: float("-inf"))
         self._results = collections.defaultdict(list)
 
+    def set_search_properties(self, metric, mode):
+        if self._metric and metric:
+            return False
+        if self._mode and mode:
+            return False
+
+        if metric:
+            self._metric = metric
+        if mode:
+            self._mode = mode
+
+        self._worst = float("-inf") if self._mode == "max" else float("inf")
+        self._compare_op = max if self._mode == "max" else min
+
+        return True
+
+    def on_trial_add(self, trial_runner, trial):
+        if not self._metric or not self._worst or not self._compare_op:
+            raise ValueError(
+                "{} has been instantiated without a valid `metric` ({}) or "
+                "`mode` ({}) parameter. Either pass these parameters when "
+                "instantiating the scheduler, or pass them as parameters "
+                "to `tune.run()`".format(self.__class__.__name__, self._metric,
+                                         self._mode))
+
+        super(MedianStoppingRule, self).on_trial_add(trial_runner, trial)
+
     def on_trial_result(self, trial_runner, trial, result):
         """Callback for early stopping.
 
@@ -216,8 +216,8 @@ class PopulationBasedTraining(FIFOScheduler):
     def __init__(self,
                  time_attr="time_total_s",
                  reward_attr=None,
-                 metric="episode_reward_mean",
-                 mode="max",
+                 metric=None,
+                 mode=None,
                  perturbation_interval=60.0,
                  hyperparam_mutations={},
                  quantile_fraction=0.25,
@@ -253,7 +253,8 @@ class PopulationBasedTraining(FIFOScheduler):
                 "perturbation_interval must be a positive number greater "
                 "than 0. Current value: '{}'".format(perturbation_interval))
 
-        assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
+        if mode:
+            assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."
 
         if reward_attr is not None:
             mode = "max"
@@ -265,9 +266,11 @@ class PopulationBasedTraining(FIFOScheduler):
 
         FIFOScheduler.__init__(self)
         self._metric = metric
-        if mode == "max":
+        self._mode = mode
+        self._metric_op = None
+        if self._mode == "max":
             self._metric_op = 1.
-        elif mode == "min":
+        elif self._mode == "min":
             self._metric_op = -1.
         self._time_attr = time_attr
         self._perturbation_interval = perturbation_interval
@@ -285,7 +288,33 @@ class PopulationBasedTraining(FIFOScheduler):
         self._num_checkpoints = 0
         self._num_perturbations = 0
 
+    def set_search_properties(self, metric, mode):
+        if self._metric and metric:
+            return False
+        if self._mode and mode:
+            return False
+
+        if metric:
+            self._metric = metric
+        if mode:
+            self._mode = mode
+
+        if self._mode == "max":
+            self._metric_op = 1.
+        elif self._mode == "min":
+            self._metric_op = -1.
+
+        return True
+
     def on_trial_add(self, trial_runner, trial):
+        if not self._metric or not self._metric_op:
+            raise ValueError(
+                "{} has been instantiated without a valid `metric` ({}) or "
+                "`mode` ({}) parameter. Either pass these parameters when "
+                "instantiating the scheduler, or pass them as parameters "
+                "to `tune.run()`".format(self.__class__.__name__, self._metric,
+                                         self._mode))
+
         self._trial_state[trial] = PBTTrialState(trial)
 
         for attr in self._hyperparam_mutations.keys():
@@ -8,6 +8,18 @@ class TrialScheduler:
     PAUSE = "PAUSE"  #: Status for pausing trial execution
     STOP = "STOP"  #: Status for stopping trial execution
 
+    def set_search_properties(self, metric, mode):
+        """Pass search properties to scheduler.
+
+        This method acts as an alternative to instantiating schedulers
+        that react to metrics with their own `metric` and `mode` parameters.
+
+        Args:
+            metric (str): Metric to optimize
+            mode (str): One of ["min", "max"]. Direction to optimize.
+        """
+        return True
+
     def on_trial_add(self, trial_runner, trial):
         """Called when a new trial is added to the trial runner."""
 
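
The base-class hook above is what custom schedulers override. A minimal, hypothetical subclass following the same convention as the built-in schedulers in this commit (reject conflicting values, accept missing ones):

    from ray.tune.schedulers import FIFOScheduler

    class MyScheduler(FIFOScheduler):  # hypothetical example, not part of Ray
        def __init__(self, metric=None, mode=None):
            super(MyScheduler, self).__init__()
            self._metric = metric
            self._mode = mode

        def set_search_properties(self, metric, mode):
            # Refuse conflicting values, accept whatever is still unset.
            if self._metric and metric:
                return False
            if self._mode and mode:
                return False
            if metric:
                self._metric = metric
            if mode:
                self._mode = mode
            return True
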
@@ -8,8 +8,8 @@ from ray.tune.suggest.repeater import Repeater
 
 def create_searcher(
         search_alg,
-        metric="episode_reward_mean",
-        mode="max",
+        metric=None,
+        mode=None,
        **kwargs,
 ):
     """Instantiate a search algorithm based on the given string.
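
`create_searcher` gets the same treatment as `create_scheduler`: no baked-in metric or mode. A sketch, assuming "hyperopt" is one of the registered shim keys:

    from ray.tune.suggest import create_searcher

    # metric/mode left unset; tune.run() will provide them.
    searcher = create_searcher("hyperopt")
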
@@ -104,15 +104,16 @@ class AxSearch(Searcher):
 
     def __init__(self,
                  space=None,
-                 metric="episode_reward_mean",
-                 mode="max",
+                 metric=None,
+                 mode=None,
                  parameter_constraints=None,
                  outcome_constraints=None,
                  ax_client=None,
                  use_early_stopped_trials=None,
                  max_concurrent=None):
         assert ax is not None, "Ax must be installed!"
-        assert mode in ["min", "max"], "`mode` must be one of ['min', 'max']"
+        if mode:
+            assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."
 
         super(AxSearch, self).__init__(
             metric=metric,
@@ -101,8 +101,8 @@ class BayesOptSearch(Searcher):
 
     def __init__(self,
                  space=None,
-                 metric="episode_reward_mean",
-                 mode="max",
+                 metric=None,
+                 mode=None,
                  utility_kwargs=None,
                  random_state=42,
                  random_search_steps=10,
@@ -144,7 +144,8 @@ class BayesOptSearch(Searcher):
         assert byo is not None, (
             "BayesOpt must be installed!. You can install BayesOpt with"
             " the command: `pip install bayesian-optimization`.")
-        assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
+        if mode:
+            assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."
         self.max_concurrent = max_concurrent
         self._config_counter = defaultdict(int)
         self._patience = patience
@@ -95,11 +95,12 @@ class TuneBOHB(Searcher):
                  space=None,
                  bohb_config=None,
                  max_concurrent=10,
-                 metric="neg_mean_loss",
-                 mode="max"):
+                 metric=None,
+                 mode=None):
         from hpbandster.optimizers.config_generators.bohb import BOHB
         assert BOHB is not None, "HpBandSter must be installed!"
-        assert mode in ["min", "max"], "`mode` must be in [min, max]!"
+        if mode:
+            assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."
         self._max_concurrent = max_concurrent
         self.trial_to_params = {}
         self.running = set()
@@ -130,15 +130,16 @@ class DragonflySearch(Searcher):
                  optimizer=None,
                  domain=None,
                  space=None,
-                 metric="episode_reward_mean",
-                 mode="max",
+                 metric=None,
+                 mode=None,
                  points_to_evaluate=None,
                  evaluated_rewards=None,
                  **kwargs):
         assert dragonfly is not None, """dragonfly must be installed!
             You can install Dragonfly with the command:
             `pip install dragonfly-opt`."""
-        assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
+        if mode:
+            assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."
 
         super(DragonflySearch, self).__init__(
             metric=metric, mode=mode, **kwargs)
@@ -118,8 +118,8 @@ class HyperOptSearch(Searcher):
     def __init__(
             self,
             space=None,
-            metric="episode_reward_mean",
-            mode="max",
+            metric=None,
+            mode=None,
             points_to_evaluate=None,
             n_initial_points=20,
             random_state_seed=None,
@@ -129,6 +129,8 @@ class HyperOptSearch(Searcher):
     ):
         assert hpo is not None, (
             "HyperOpt must be installed! Run `pip install hyperopt`.")
+        if mode:
+            assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."
         from hyperopt.fmin import generate_trials_to_calculate
         super(HyperOptSearch, self).__init__(
             metric=metric,
@@ -87,12 +87,13 @@ class NevergradSearch(Searcher):
     def __init__(self,
                  optimizer=None,
                  space=None,
-                 metric="episode_reward_mean",
-                 mode="max",
+                 metric=None,
+                 mode=None,
                  max_concurrent=None,
                  **kwargs):
         assert ng is not None, "Nevergrad must be installed!"
-        assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
+        if mode:
+            assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."
 
         super(NevergradSearch, self).__init__(
             metric=metric, mode=mode, max_concurrent=max_concurrent, **kwargs)
@@ -100,11 +100,7 @@ class OptunaSearch(Searcher):
 
     """
 
-    def __init__(self,
-                 space=None,
-                 metric="episode_reward_mean",
-                 mode="max",
-                 sampler=None):
+    def __init__(self, space=None, metric=None, mode=None, sampler=None):
         assert ot is not None, (
             "Optuna must be installed! Run `pip install optuna`.")
         super(OptunaSearch, self).__init__(
@@ -127,8 +127,8 @@ class SkOptSearch(Searcher):
     def __init__(self,
                  optimizer=None,
                  space=None,
-                 metric="episode_reward_mean",
-                 mode="max",
+                 metric=None,
+                 mode=None,
                  points_to_evaluate=None,
                  evaluated_rewards=None,
                  max_concurrent=None,
@@ -137,7 +137,8 @@ class SkOptSearch(Searcher):
             You can install Skopt with the command:
             `pip install scikit-optimize`."""
 
-        assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
+        if mode:
+            assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."
         self.max_concurrent = max_concurrent
         super(SkOptSearch, self).__init__(
             metric=metric,
@@ -56,8 +56,8 @@ class Searcher:
     CKPT_FILE_TMPL = "searcher-state-{}.pkl"
 
     def __init__(self,
-                 metric="episode_reward_mean",
-                 mode="max",
+                 metric=None,
+                 mode=None,
                  max_concurrent=None,
                  use_early_stopped_trials=None):
         if use_early_stopped_trials is False:
@@ -70,6 +70,13 @@ class Searcher:
                 "search algorithm. Use tune.suggest.ConcurrencyLimiter() "
                 "instead. This will raise an error in future versions of Ray.")
 
+        self._metric = metric
+        self._mode = mode
+
+        if not mode or not metric:
+            # Early return to avoid assertions
+            return
+
         assert isinstance(
             metric, type(mode)), "metric and mode must be of the same type"
         if isinstance(mode, str):
@@ -83,9 +90,6 @@ class Searcher:
         else:
             raise ValueError("Mode most either be a list or string")
 
-        self._metric = metric
-        self._mode = mode
-
     def set_search_properties(self, metric, mode, config):
         """Pass search properties to searcher.
 
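
On the searcher side the same contract applies: `metric` and `mode` may be omitted at construction and are injected later through `set_search_properties(metric, mode, config)`, which `tune.run()` now calls with its own arguments instead of `(None, None, config)`. A sketch with HyperOptSearch and an explicit hyperopt space (the `trainable` and the `lr` parameter are placeholders, not part of this diff):

    from hyperopt import hp
    from ray import tune
    from ray.tune.suggest.hyperopt import HyperOptSearch

    searcher = HyperOptSearch(
        space={"lr": hp.loguniform("lr", -10, -1)})  # no metric/mode here

    analysis = tune.run(
        trainable,
        search_alg=searcher,
        metric="episode_reward_mean",
        mode="max",
        num_samples=20)
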
@@ -109,12 +109,13 @@ class ZOOptSearch(Searcher):
                  algo="asracos",
                  budget=None,
                  dim_dict=None,
-                 metric="episode_reward_mean",
-                 mode="min",
+                 metric=None,
+                 mode=None,
                  **kwargs):
         assert zoopt is not None, "Zoopt not found - please install zoopt."
         assert budget is not None, "`budget` should not be None!"
-        assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
+        if mode:
+            assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."
         _algo = algo.lower()
         assert _algo in ["asracos", "sracos"
                          ], "`algo` must be in ['asracos', 'sracos'] currently"
@@ -60,7 +60,11 @@ class EarlyStoppingSuite(unittest.TestCase):
         return t1, t2
 
     def testMedianStoppingConstantPerf(self):
-        rule = MedianStoppingRule(grace_period=0, min_samples_required=1)
+        rule = MedianStoppingRule(
+            metric="episode_reward_mean",
+            mode="max",
+            grace_period=0,
+            min_samples_required=1)
         t1, t2 = self.basicSetup(rule)
         runner = mock_trial_runner()
         rule.on_trial_complete(runner, t1, result(10, 1000))
@@ -75,7 +79,11 @@ class EarlyStoppingSuite(unittest.TestCase):
             TrialScheduler.STOP)
 
     def testMedianStoppingOnCompleteOnly(self):
-        rule = MedianStoppingRule(grace_period=0, min_samples_required=1)
+        rule = MedianStoppingRule(
+            metric="episode_reward_mean",
+            mode="max",
+            grace_period=0,
+            min_samples_required=1)
         t1, t2 = self.basicSetup(rule)
         runner = mock_trial_runner()
         self.assertEqual(
@@ -87,7 +95,11 @@ class EarlyStoppingSuite(unittest.TestCase):
             TrialScheduler.STOP)
 
     def testMedianStoppingGracePeriod(self):
-        rule = MedianStoppingRule(grace_period=2.5, min_samples_required=1)
+        rule = MedianStoppingRule(
+            metric="episode_reward_mean",
+            mode="max",
+            grace_period=2.5,
+            min_samples_required=1)
         t1, t2 = self.basicSetup(rule)
         runner = mock_trial_runner()
         rule.on_trial_complete(runner, t1, result(10, 1000))
@@ -104,7 +116,11 @@ class EarlyStoppingSuite(unittest.TestCase):
             TrialScheduler.STOP)
 
     def testMedianStoppingMinSamples(self):
-        rule = MedianStoppingRule(grace_period=0, min_samples_required=2)
+        rule = MedianStoppingRule(
+            metric="episode_reward_mean",
+            mode="max",
+            grace_period=0,
+            min_samples_required=2)
         t1, t2 = self.basicSetup(rule)
         runner = mock_trial_runner()
         rule.on_trial_complete(runner, t1, result(10, 1000))
@@ -120,7 +136,11 @@ class EarlyStoppingSuite(unittest.TestCase):
             TrialScheduler.STOP)
 
     def testMedianStoppingUsesMedian(self):
-        rule = MedianStoppingRule(grace_period=0, min_samples_required=1)
+        rule = MedianStoppingRule(
+            metric="episode_reward_mean",
+            mode="max",
+            grace_period=0,
+            min_samples_required=1)
         t1, t2 = self.basicSetup(rule)
         runner = mock_trial_runner()
         rule.on_trial_complete(runner, t1, result(10, 1000))
@@ -135,7 +155,11 @@ class EarlyStoppingSuite(unittest.TestCase):
 
     def testMedianStoppingSoftStop(self):
         rule = MedianStoppingRule(
-            grace_period=0, min_samples_required=1, hard_stop=False)
+            metric="episode_reward_mean",
+            mode="max",
+            grace_period=0,
+            min_samples_required=1,
+            hard_stop=False)
         t1, t2 = self.basicSetup(rule)
         runner = mock_trial_runner()
         rule.on_trial_complete(runner, t1, result(10, 1000))
@@ -265,7 +289,8 @@ class HyperbandSuite(unittest.TestCase):
         (15, 9) -> (5, 27) -> (2, 45);
         (34, 3) -> (12, 9) -> (4, 27) -> (2, 42);
         (81, 1) -> (27, 3) -> (9, 9) -> (3, 27) -> (1, 41);"""
-        sched = HyperBandScheduler(max_t=max_t)
+        sched = HyperBandScheduler(
+            metric="episode_reward_mean", mode="max", max_t=max_t)
         for i in range(num_trials):
             t = Trial("__fake")
             sched.on_trial_add(None, t)
@@ -321,7 +346,7 @@ class HyperbandSuite(unittest.TestCase):
         return sched
 
     def testConfigSameEta(self):
-        sched = HyperBandScheduler()
+        sched = HyperBandScheduler(metric="episode_reward_mean", mode="max")
         i = 0
         while not sched._cur_band_filled():
             t = Trial("__fake")
@@ -335,7 +360,10 @@ class HyperbandSuite(unittest.TestCase):
 
         reduction_factor = 10
         sched = HyperBandScheduler(
-            max_t=1000, reduction_factor=reduction_factor)
+            metric="episode_reward_mean",
+            mode="max",
+            max_t=1000,
+            reduction_factor=reduction_factor)
         i = 0
         while not sched._cur_band_filled():
             t = Trial("__fake")
@@ -348,7 +376,8 @@ class HyperbandSuite(unittest.TestCase):
         self.assertEqual(sched._hyperbands[0][-1]._r, 1)
 
     def testConfigSameEtaSmall(self):
-        sched = HyperBandScheduler(max_t=1)
+        sched = HyperBandScheduler(
+            metric="episode_reward_mean", mode="max", max_t=1)
         i = 0
         while len(sched._hyperbands) < 2:
             t = Trial("__fake")
@@ -627,7 +656,11 @@ class BOHBSuite(unittest.TestCase):
         _register_all()  # re-register the evicted objects
 
     def testLargestBracketFirst(self):
-        sched = HyperBandForBOHB(max_t=3, reduction_factor=3)
+        sched = HyperBandForBOHB(
+            metric="episode_reward_mean",
+            mode="max",
+            max_t=3,
+            reduction_factor=3)
         runner = _MockTrialRunner(sched)
         for i in range(3):
             t = Trial("__fake")
@@ -642,7 +675,11 @@ class BOHBSuite(unittest.TestCase):
         def result(score, ts):
             return {"episode_reward_mean": score, TRAINING_ITERATION: ts}
 
-        sched = HyperBandForBOHB(max_t=3, reduction_factor=3)
+        sched = HyperBandForBOHB(
+            metric="episode_reward_mean",
+            mode="max",
+            max_t=3,
+            reduction_factor=3)
         runner = _MockTrialRunner(sched)
         runner._search_alg = MagicMock()
         runner._search_alg.searcher = MagicMock()
@@ -668,7 +705,11 @@ class BOHBSuite(unittest.TestCase):
         def result(score, ts):
             return {"episode_reward_mean": score, TRAINING_ITERATION: ts}
 
-        sched = HyperBandForBOHB(max_t=3, reduction_factor=3, mode="min")
+        sched = HyperBandForBOHB(
+            metric="episode_reward_mean",
+            mode="min",
+            max_t=3,
+            reduction_factor=3)
         runner = _MockTrialRunner(sched)
         runner._search_alg = MagicMock()
         runner._search_alg.searcher = MagicMock()
@@ -693,7 +734,11 @@ class BOHBSuite(unittest.TestCase):
         def result(score, ts):
             return {"episode_reward_mean": score, TRAINING_ITERATION: ts}
 
-        sched = HyperBandForBOHB(max_t=10, reduction_factor=3, mode="min")
+        sched = HyperBandForBOHB(
+            metric="episode_reward_mean",
+            mode="min",
+            max_t=10,
+            reduction_factor=3)
         runner = _MockTrialRunner(sched)
         runner._search_alg = MagicMock()
         runner._search_alg.searcher = MagicMock()
@@ -761,6 +806,8 @@ class PopulationBasedTestingSuite(unittest.TestCase):
         }
         pbt = PopulationBasedTraining(
             time_attr="training_iteration",
+            metric="episode_reward_mean",
+            mode="max",
             perturbation_interval=perturbation_interval,
             resample_probability=resample_prob,
             quantile_fraction=0.25,
@@ -1675,6 +1722,7 @@ class E2EPopulationBasedTestingSuite(unittest.TestCase):
         }
         pbt = PopulationBasedTraining(
             metric="mean_accuracy",
+            mode="max",
             time_attr="training_iteration",
             perturbation_interval=perturbation_interval,
             resample_probability=resample_prob,
@@ -1791,7 +1839,8 @@ class AsyncHyperBandSuite(unittest.TestCase):
         return t1, t2
 
     def testAsyncHBOnComplete(self):
-        scheduler = AsyncHyperBandScheduler(max_t=10, brackets=1)
+        scheduler = AsyncHyperBandScheduler(
+            metric="episode_reward_mean", mode="max", max_t=10, brackets=1)
         t1, t2 = self.basicSetup(scheduler)
         t3 = Trial("PPO")
         scheduler.on_trial_add(None, t3)
@@ -1802,7 +1851,11 @@ class AsyncHyperBandSuite(unittest.TestCase):
 
     def testAsyncHBGracePeriod(self):
         scheduler = AsyncHyperBandScheduler(
-            grace_period=2.5, reduction_factor=3, brackets=1)
+            metric="episode_reward_mean",
+            mode="max",
+            grace_period=2.5,
+            reduction_factor=3,
+            brackets=1)
         t1, t2 = self.basicSetup(scheduler)
         scheduler.on_trial_complete(None, t1, result(10, 1000))
         scheduler.on_trial_complete(None, t2, result(10, 1000))
@@ -1819,7 +1872,8 @@ class AsyncHyperBandSuite(unittest.TestCase):
             TrialScheduler.STOP)
 
     def testAsyncHBAllCompletes(self):
-        scheduler = AsyncHyperBandScheduler(max_t=10, brackets=10)
+        scheduler = AsyncHyperBandScheduler(
+            metric="episode_reward_mean", mode="max", max_t=10, brackets=10)
         trials = [Trial("PPO") for i in range(10)]
         for t in trials:
             scheduler.on_trial_add(None, t)
@@ -1831,7 +1885,12 @@ class AsyncHyperBandSuite(unittest.TestCase):
 
     def testAsyncHBUsesPercentile(self):
        scheduler = AsyncHyperBandScheduler(
-            grace_period=1, max_t=10, reduction_factor=2, brackets=1)
+            metric="episode_reward_mean",
+            mode="max",
+            grace_period=1,
+            max_t=10,
+            reduction_factor=2,
+            brackets=1)
         t1, t2 = self.basicSetup(scheduler)
         scheduler.on_trial_complete(None, t1, result(10, 1000))
         scheduler.on_trial_complete(None, t2, result(10, 1000))
@@ -1846,7 +1905,12 @@ class AsyncHyperBandSuite(unittest.TestCase):
 
     def testAsyncHBNanPercentile(self):
         scheduler = AsyncHyperBandScheduler(
-            grace_period=1, max_t=10, reduction_factor=2, brackets=1)
+            metric="episode_reward_mean",
+            mode="max",
+            grace_period=1,
+            max_t=10,
+            reduction_factor=2,
+            brackets=1)
         t1, t2 = self.nanSetup(scheduler)
         scheduler.on_trial_complete(None, t1, result(10, 450))
         scheduler.on_trial_complete(None, t2, result(10, np.nan))
@@ -68,6 +68,8 @@ def _report_progress(runner, reporter, done=False):
 def run(
         run_or_experiment,
         name=None,
+        metric=None,
+        mode=None,
         stop=None,
         time_budget_s=None,
         config=None,
@@ -147,6 +149,12 @@ def run(
             will need to first register the function:
             ``tune.register_trainable("lambda_id", lambda x: ...)``. You can
             then use ``tune.run("lambda_id")``.
+        metric (str): Metric to optimize. This metric should be reported
+            with `tune.report()`. If set, will be passed to the search
+            algorithm and scheduler.
+        mode (str): Must be one of [min, max]. Determines whether objective is
+            minimizing or maximizing the metric attribute. If set, will be
+            passed to the search algorithm and scheduler.
         name (str): Name of experiment.
         stop (dict | callable | :class:`Stopper`): Stopping criteria. If dict,
             the keys may be any field in the return result of 'train()',
@@ -276,6 +284,11 @@ def run(
             "sync_config=SyncConfig(...)`. See `ray.tune.SyncConfig` for "
             "more details.")
 
+    if mode and mode not in ["min", "max"]:
+        raise ValueError(
+            "The `mode` parameter passed to `tune.run()` has to be one of "
+            "['min', 'max']")
+
     config = config or {}
     sync_config = sync_config or SyncConfig()
     set_sync_periods(sync_config)
@@ -329,8 +342,7 @@ def run(
     if not search_alg:
         search_alg = BasicVariantGenerator()
 
-    # TODO (krfricke): Introduce metric/mode as top level API
-    if config and not search_alg.set_search_properties(None, None, config):
+    if config and not search_alg.set_search_properties(metric, mode, config):
         if has_unresolved_values(config):
             raise ValueError(
                 "You passed a `config` parameter to `tune.run()` with "
@@ -339,9 +351,17 @@ def run(
             "does not contain any more parameter definitions - include "
             "them in the search algorithm's search space if necessary.")
 
+    scheduler = scheduler or FIFOScheduler()
+    if not scheduler.set_search_properties(metric, mode):
+        raise ValueError(
+            "You passed a `metric` or `mode` argument to `tune.run()`, but "
+            "the scheduler you are using was already instantiated with their "
+            "own `metric` and `mode` parameters. Either remove the arguments "
+            "from your scheduler or from your call to `tune.run()`")
+
     runner = TrialRunner(
         search_alg=search_alg,
-        scheduler=scheduler or FIFOScheduler(),
+        scheduler=scheduler,
         local_checkpoint_dir=experiments[0].checkpoint_dir,
         remote_checkpoint_dir=experiments[0].remote_checkpoint_dir,
         sync_to_cloud=sync_config.sync_to_cloud,
@@ -413,8 +433,8 @@ def run(
     return ExperimentAnalysis(
         runner.checkpoint_file,
         trials=trials,
-        default_metric=None,
-        default_mode=None)
+        default_metric=metric,
+        default_mode=mode)
 
 
 def run_experiments(experiments,
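
The two validation paths added to `tune.run()` above behave as follows: an invalid `mode` fails fast with "The `mode` parameter passed to `tune.run()` has to be one of ['min', 'max']", and supplying `metric`/`mode` both on the scheduler and on `tune.run()` is rejected instead of silently overridden. A sketch of the second case (`trainable` is a placeholder for any registered trainable):

    from ray import tune
    from ray.tune.schedulers import AsyncHyperBandScheduler

    sched = AsyncHyperBandScheduler(metric="mean_accuracy", mode="max")

    try:
        # metric/mode given twice: once here, once at scheduler construction.
        tune.run(trainable, scheduler=sched, metric="mean_accuracy", mode="max")
    except ValueError as exc:
        print(exc)
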
@@ -26,7 +26,7 @@ timm
 torch>=1.5.0
 torchvision>=0.6.0
 transformers
-tune-sklearn==0.0.5
+git+git://github.com/ray-project/tune-sklearn@master#tune-sklearn
 wandb
 xgboost
 zoopt>=0.4.0