[tune] add mode/metric parameters to tune.run (#10627)

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
Kai Fricke 2020-09-09 01:06:21 +01:00 committed by GitHub
parent edce7a05e6
commit 756a9ea641
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
26 changed files with 322 additions and 88 deletions

View file

@ -29,5 +29,5 @@ git+https://github.com/executablebooks/sphinx-book-theme.git@0a87d26e214c419d2e6
tabulate
uvicorn
werkzeug
tune-sklearn==0.0.5
git+git://github.com/ray-project/tune-sklearn@master#tune-sklearn
scikit-optimize

View file

@ -127,7 +127,7 @@ tune_search = TuneSearchCV(
clf,
parameter_grid,
search_optimization="bayesian",
n_iter=3,
n_trials=3,
early_stopping=True,
max_iters=10,
)

View file

@ -149,7 +149,7 @@ py_test(
py_test(
name = "test_sample",
size = "medium",
size = "small",
srcs = ["tests/test_sample.py"],
deps = [":tune_lib"],
tags = ["exclusive"],

View file

@ -121,7 +121,7 @@ if __name__ == "__main__":
else:
ray.init(num_cpus=2 if args.smoke_test else None)
sched = AsyncHyperBandScheduler(
time_attr="training_iteration", metric="mean_accuracy")
time_attr="training_iteration", metric="mean_accuracy", mode="max")
analysis = tune.run(
train_mnist,
name="exp",

View file

@ -65,9 +65,11 @@ class TrainMNIST(tune.Trainable):
if __name__ == "__main__":
args = parser.parse_args()
ray.init(address=args.ray_address, num_cpus=6 if args.smoke_test else None)
sched = ASHAScheduler(metric="mean_accuracy")
sched = ASHAScheduler()
analysis = tune.run(
TrainMNIST,
metric="mean_accuracy",
mode="max",
scheduler=sched,
stop={
"mean_accuracy": 0.95,

View file

@ -10,8 +10,8 @@ from ray.tune.schedulers.pbt import (PopulationBasedTraining,
def create_scheduler(
scheduler,
metric="episode_reward_mean",
mode="max",
metric=None,
mode=None,
**kwargs,
):
"""Instantiate a scheduler based on the given string.

View file

@ -38,8 +38,8 @@ class AsyncHyperBandScheduler(FIFOScheduler):
def __init__(self,
time_attr="training_iteration",
reward_attr=None,
metric="episode_reward_mean",
mode="max",
metric=None,
mode=None,
max_t=100,
grace_period=1,
reduction_factor=4,
@ -49,7 +49,8 @@ class AsyncHyperBandScheduler(FIFOScheduler):
assert grace_period > 0, "grace_period must be positive!"
assert reduction_factor > 1, "Reduction Factor not valid!"
assert brackets > 0, "brackets must be positive!"
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
if mode:
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
if reward_attr is not None:
mode = "max"
@ -73,13 +74,41 @@ class AsyncHyperBandScheduler(FIFOScheduler):
self._counter = 0 # for
self._num_stopped = 0
self._metric = metric
if mode == "max":
self._mode = mode
self._metric_op = None
if self._mode == "max":
self._metric_op = 1.
elif mode == "min":
elif self._mode == "min":
self._metric_op = -1.
self._time_attr = time_attr
def set_search_properties(self, metric, mode):
if self._metric and metric:
return False
if self._mode and mode:
return False
if metric:
self._metric = metric
if mode:
self._mode = mode
if self._mode == "max":
self._metric_op = 1.
elif self._mode == "min":
self._metric_op = -1.
return True
def on_trial_add(self, trial_runner, trial):
if not self._metric or not self._metric_op:
raise ValueError(
"{} has been instantiated without a valid `metric` ({}) or "
"`mode` ({}) parameter. Either pass these parameters when "
"instantiating the scheduler, or pass them as parameters "
"to `tune.run()`".format(self.__class__.__name__, self._metric,
self._mode))
sizes = np.array([len(b._rungs) for b in self._brackets])
probs = np.e**(sizes - sizes.max())
normalized = probs / probs.sum()

View file

@ -30,6 +30,13 @@ class HyperBandForBOHB(HyperBandScheduler):
to current bracket. Else, create new iteration, create new bracket,
add to bracket.
"""
if not self._metric or not self._metric_op:
raise ValueError(
"{} has been instantiated without a valid `metric` ({}) or "
"`mode` ({}) parameter. Either pass these parameters when "
"instantiating the scheduler, or pass them as parameters "
"to `tune.run()`".format(self.__class__.__name__, self._metric,
self._mode))
cur_bracket = self._state["bracket"]
cur_band = self._hyperbands[self._state["band_idx"]]

View file

@ -76,12 +76,13 @@ class HyperBandScheduler(FIFOScheduler):
def __init__(self,
time_attr="training_iteration",
reward_attr=None,
metric="episode_reward_mean",
mode="max",
metric=None,
mode=None,
max_t=81,
reduction_factor=3):
assert max_t > 0, "Max (time_attr) not valid!"
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
if mode:
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
if reward_attr is not None:
mode = "max"
@ -108,12 +109,33 @@ class HyperBandScheduler(FIFOScheduler):
self._state = {"bracket": None, "band_idx": 0}
self._num_stopped = 0
self._metric = metric
if mode == "max":
self._mode = mode
self._metric_op = None
if self._mode == "max":
self._metric_op = 1.
elif mode == "min":
elif self._mode == "min":
self._metric_op = -1.
self._time_attr = time_attr
def set_search_properties(self, metric, mode):
if self._metric and metric:
return False
if self._mode and mode:
return False
if metric:
self._metric = metric
if mode:
self._mode = mode
if self._mode == "max":
self._metric_op = 1.
elif self._mode == "min":
self._metric_op = -1.
return True
def on_trial_add(self, trial_runner, trial):
"""Adds new trial.
@ -121,6 +143,13 @@ class HyperBandScheduler(FIFOScheduler):
add to current bracket. Else, if current band is not filled,
create new bracket, add to current bracket.
Else, create new iteration, create new bracket, add to bracket."""
if not self._metric or not self._metric_op:
raise ValueError(
"{} has been instantiated without a valid `metric` ({}) or "
"`mode` ({}) parameter. Either pass these parameters when "
"instantiating the scheduler, or pass them as parameters "
"to `tune.run()`".format(self.__class__.__name__, self._metric,
self._mode))
cur_bracket = self._state["bracket"]
cur_band = self._hyperbands[self._state["band_idx"]]

View file

@ -40,13 +40,12 @@ class MedianStoppingRule(FIFOScheduler):
def __init__(self,
time_attr="time_total_s",
reward_attr=None,
metric="episode_reward_mean",
mode="max",
metric=None,
mode=None,
grace_period=60.0,
min_samples_required=3,
min_time_slice=0,
hard_stop=True):
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
if reward_attr is not None:
mode = "max"
metric = reward_attr
@ -60,15 +59,49 @@ class MedianStoppingRule(FIFOScheduler):
self._min_samples_required = min_samples_required
self._min_time_slice = min_time_slice
self._metric = metric
assert mode in {"min", "max"}, "`mode` must be 'min' or 'max'."
self._worst = float("-inf") if mode == "max" else float("inf")
self._compare_op = max if mode == "max" else min
self._worst = None
self._compare_op = None
self._mode = mode
if mode:
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."
self._worst = float("-inf") if self._mode == "max" else float(
"inf")
self._compare_op = max if self._mode == "max" else min
self._time_attr = time_attr
self._hard_stop = hard_stop
self._trial_state = {}
self._last_pause = collections.defaultdict(lambda: float("-inf"))
self._results = collections.defaultdict(list)
def set_search_properties(self, metric, mode):
if self._metric and metric:
return False
if self._mode and mode:
return False
if metric:
self._metric = metric
if mode:
self._mode = mode
self._worst = float("-inf") if self._mode == "max" else float("inf")
self._compare_op = max if self._mode == "max" else min
return True
def on_trial_add(self, trial_runner, trial):
if not self._metric or not self._worst or not self._compare_op:
raise ValueError(
"{} has been instantiated without a valid `metric` ({}) or "
"`mode` ({}) parameter. Either pass these parameters when "
"instantiating the scheduler, or pass them as parameters "
"to `tune.run()`".format(self.__class__.__name__, self._metric,
self._mode))
super(MedianStoppingRule, self).on_trial_add(trial_runner, trial)
def on_trial_result(self, trial_runner, trial, result):
"""Callback for early stopping.

View file

@ -216,8 +216,8 @@ class PopulationBasedTraining(FIFOScheduler):
def __init__(self,
time_attr="time_total_s",
reward_attr=None,
metric="episode_reward_mean",
mode="max",
metric=None,
mode=None,
perturbation_interval=60.0,
hyperparam_mutations={},
quantile_fraction=0.25,
@ -253,7 +253,8 @@ class PopulationBasedTraining(FIFOScheduler):
"perturbation_interval must be a positive number greater "
"than 0. Current value: '{}'".format(perturbation_interval))
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
if mode:
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."
if reward_attr is not None:
mode = "max"
@ -265,9 +266,11 @@ class PopulationBasedTraining(FIFOScheduler):
FIFOScheduler.__init__(self)
self._metric = metric
if mode == "max":
self._mode = mode
self._metric_op = None
if self._mode == "max":
self._metric_op = 1.
elif mode == "min":
elif self._mode == "min":
self._metric_op = -1.
self._time_attr = time_attr
self._perturbation_interval = perturbation_interval
@ -285,7 +288,33 @@ class PopulationBasedTraining(FIFOScheduler):
self._num_checkpoints = 0
self._num_perturbations = 0
def set_search_properties(self, metric, mode):
if self._metric and metric:
return False
if self._mode and mode:
return False
if metric:
self._metric = metric
if mode:
self._mode = mode
if self._mode == "max":
self._metric_op = 1.
elif self._mode == "min":
self._metric_op = -1.
return True
def on_trial_add(self, trial_runner, trial):
if not self._metric or not self._metric_op:
raise ValueError(
"{} has been instantiated without a valid `metric` ({}) or "
"`mode` ({}) parameter. Either pass these parameters when "
"instantiating the scheduler, or pass them as parameters "
"to `tune.run()`".format(self.__class__.__name__, self._metric,
self._mode))
self._trial_state[trial] = PBTTrialState(trial)
for attr in self._hyperparam_mutations.keys():

View file

@ -8,6 +8,18 @@ class TrialScheduler:
PAUSE = "PAUSE" #: Status for pausing trial execution
STOP = "STOP" #: Status for stopping trial execution
def set_search_properties(self, metric, mode):
"""Pass search properties to scheduler.
This method acts as an alternative to instantiating schedulers
that react to metrics with their own `metric` and `mode` parameters.
Args:
metric (str): Metric to optimize
mode (str): One of ["min", "max"]. Direction to optimize.
"""
return True
def on_trial_add(self, trial_runner, trial):
"""Called when a new trial is added to the trial runner."""

View file

@ -8,8 +8,8 @@ from ray.tune.suggest.repeater import Repeater
def create_searcher(
search_alg,
metric="episode_reward_mean",
mode="max",
metric=None,
mode=None,
**kwargs,
):
"""Instantiate a search algorithm based on the given string.

View file

@ -104,15 +104,16 @@ class AxSearch(Searcher):
def __init__(self,
space=None,
metric="episode_reward_mean",
mode="max",
metric=None,
mode=None,
parameter_constraints=None,
outcome_constraints=None,
ax_client=None,
use_early_stopped_trials=None,
max_concurrent=None):
assert ax is not None, "Ax must be installed!"
assert mode in ["min", "max"], "`mode` must be one of ['min', 'max']"
if mode:
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."
super(AxSearch, self).__init__(
metric=metric,

View file

@ -101,8 +101,8 @@ class BayesOptSearch(Searcher):
def __init__(self,
space=None,
metric="episode_reward_mean",
mode="max",
metric=None,
mode=None,
utility_kwargs=None,
random_state=42,
random_search_steps=10,
@ -144,7 +144,8 @@ class BayesOptSearch(Searcher):
assert byo is not None, (
"BayesOpt must be installed!. You can install BayesOpt with"
" the command: `pip install bayesian-optimization`.")
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
if mode:
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."
self.max_concurrent = max_concurrent
self._config_counter = defaultdict(int)
self._patience = patience

View file

@ -95,11 +95,12 @@ class TuneBOHB(Searcher):
space=None,
bohb_config=None,
max_concurrent=10,
metric="neg_mean_loss",
mode="max"):
metric=None,
mode=None):
from hpbandster.optimizers.config_generators.bohb import BOHB
assert BOHB is not None, "HpBandSter must be installed!"
assert mode in ["min", "max"], "`mode` must be in [min, max]!"
if mode:
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."
self._max_concurrent = max_concurrent
self.trial_to_params = {}
self.running = set()

View file

@ -130,15 +130,16 @@ class DragonflySearch(Searcher):
optimizer=None,
domain=None,
space=None,
metric="episode_reward_mean",
mode="max",
metric=None,
mode=None,
points_to_evaluate=None,
evaluated_rewards=None,
**kwargs):
assert dragonfly is not None, """dragonfly must be installed!
You can install Dragonfly with the command:
`pip install dragonfly-opt`."""
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
if mode:
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."
super(DragonflySearch, self).__init__(
metric=metric, mode=mode, **kwargs)

View file

@ -118,8 +118,8 @@ class HyperOptSearch(Searcher):
def __init__(
self,
space=None,
metric="episode_reward_mean",
mode="max",
metric=None,
mode=None,
points_to_evaluate=None,
n_initial_points=20,
random_state_seed=None,
@ -129,6 +129,8 @@ class HyperOptSearch(Searcher):
):
assert hpo is not None, (
"HyperOpt must be installed! Run `pip install hyperopt`.")
if mode:
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."
from hyperopt.fmin import generate_trials_to_calculate
super(HyperOptSearch, self).__init__(
metric=metric,

View file

@ -87,12 +87,13 @@ class NevergradSearch(Searcher):
def __init__(self,
optimizer=None,
space=None,
metric="episode_reward_mean",
mode="max",
metric=None,
mode=None,
max_concurrent=None,
**kwargs):
assert ng is not None, "Nevergrad must be installed!"
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
if mode:
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."
super(NevergradSearch, self).__init__(
metric=metric, mode=mode, max_concurrent=max_concurrent, **kwargs)

View file

@ -100,11 +100,7 @@ class OptunaSearch(Searcher):
"""
def __init__(self,
space=None,
metric="episode_reward_mean",
mode="max",
sampler=None):
def __init__(self, space=None, metric=None, mode=None, sampler=None):
assert ot is not None, (
"Optuna must be installed! Run `pip install optuna`.")
super(OptunaSearch, self).__init__(

View file

@ -127,8 +127,8 @@ class SkOptSearch(Searcher):
def __init__(self,
optimizer=None,
space=None,
metric="episode_reward_mean",
mode="max",
metric=None,
mode=None,
points_to_evaluate=None,
evaluated_rewards=None,
max_concurrent=None,
@ -137,7 +137,8 @@ class SkOptSearch(Searcher):
You can install Skopt with the command:
`pip install scikit-optimize`."""
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
if mode:
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."
self.max_concurrent = max_concurrent
super(SkOptSearch, self).__init__(
metric=metric,

View file

@ -56,8 +56,8 @@ class Searcher:
CKPT_FILE_TMPL = "searcher-state-{}.pkl"
def __init__(self,
metric="episode_reward_mean",
mode="max",
metric=None,
mode=None,
max_concurrent=None,
use_early_stopped_trials=None):
if use_early_stopped_trials is False:
@ -70,6 +70,13 @@ class Searcher:
"search algorithm. Use tune.suggest.ConcurrencyLimiter() "
"instead. This will raise an error in future versions of Ray.")
self._metric = metric
self._mode = mode
if not mode or not metric:
# Early return to avoid assertions
return
assert isinstance(
metric, type(mode)), "metric and mode must be of the same type"
if isinstance(mode, str):
@ -83,9 +90,6 @@ class Searcher:
else:
raise ValueError("Mode most either be a list or string")
self._metric = metric
self._mode = mode
def set_search_properties(self, metric, mode, config):
"""Pass search properties to searcher.

View file

@ -109,12 +109,13 @@ class ZOOptSearch(Searcher):
algo="asracos",
budget=None,
dim_dict=None,
metric="episode_reward_mean",
mode="min",
metric=None,
mode=None,
**kwargs):
assert zoopt is not None, "Zoopt not found - please install zoopt."
assert budget is not None, "`budget` should not be None!"
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
if mode:
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."
_algo = algo.lower()
assert _algo in ["asracos", "sracos"
], "`algo` must be in ['asracos', 'sracos'] currently"

View file

@ -60,7 +60,11 @@ class EarlyStoppingSuite(unittest.TestCase):
return t1, t2
def testMedianStoppingConstantPerf(self):
rule = MedianStoppingRule(grace_period=0, min_samples_required=1)
rule = MedianStoppingRule(
metric="episode_reward_mean",
mode="max",
grace_period=0,
min_samples_required=1)
t1, t2 = self.basicSetup(rule)
runner = mock_trial_runner()
rule.on_trial_complete(runner, t1, result(10, 1000))
@ -75,7 +79,11 @@ class EarlyStoppingSuite(unittest.TestCase):
TrialScheduler.STOP)
def testMedianStoppingOnCompleteOnly(self):
rule = MedianStoppingRule(grace_period=0, min_samples_required=1)
rule = MedianStoppingRule(
metric="episode_reward_mean",
mode="max",
grace_period=0,
min_samples_required=1)
t1, t2 = self.basicSetup(rule)
runner = mock_trial_runner()
self.assertEqual(
@ -87,7 +95,11 @@ class EarlyStoppingSuite(unittest.TestCase):
TrialScheduler.STOP)
def testMedianStoppingGracePeriod(self):
rule = MedianStoppingRule(grace_period=2.5, min_samples_required=1)
rule = MedianStoppingRule(
metric="episode_reward_mean",
mode="max",
grace_period=2.5,
min_samples_required=1)
t1, t2 = self.basicSetup(rule)
runner = mock_trial_runner()
rule.on_trial_complete(runner, t1, result(10, 1000))
@ -104,7 +116,11 @@ class EarlyStoppingSuite(unittest.TestCase):
TrialScheduler.STOP)
def testMedianStoppingMinSamples(self):
rule = MedianStoppingRule(grace_period=0, min_samples_required=2)
rule = MedianStoppingRule(
metric="episode_reward_mean",
mode="max",
grace_period=0,
min_samples_required=2)
t1, t2 = self.basicSetup(rule)
runner = mock_trial_runner()
rule.on_trial_complete(runner, t1, result(10, 1000))
@ -120,7 +136,11 @@ class EarlyStoppingSuite(unittest.TestCase):
TrialScheduler.STOP)
def testMedianStoppingUsesMedian(self):
rule = MedianStoppingRule(grace_period=0, min_samples_required=1)
rule = MedianStoppingRule(
metric="episode_reward_mean",
mode="max",
grace_period=0,
min_samples_required=1)
t1, t2 = self.basicSetup(rule)
runner = mock_trial_runner()
rule.on_trial_complete(runner, t1, result(10, 1000))
@ -135,7 +155,11 @@ class EarlyStoppingSuite(unittest.TestCase):
def testMedianStoppingSoftStop(self):
rule = MedianStoppingRule(
grace_period=0, min_samples_required=1, hard_stop=False)
metric="episode_reward_mean",
mode="max",
grace_period=0,
min_samples_required=1,
hard_stop=False)
t1, t2 = self.basicSetup(rule)
runner = mock_trial_runner()
rule.on_trial_complete(runner, t1, result(10, 1000))
@ -265,7 +289,8 @@ class HyperbandSuite(unittest.TestCase):
(15, 9) -> (5, 27) -> (2, 45);
(34, 3) -> (12, 9) -> (4, 27) -> (2, 42);
(81, 1) -> (27, 3) -> (9, 9) -> (3, 27) -> (1, 41);"""
sched = HyperBandScheduler(max_t=max_t)
sched = HyperBandScheduler(
metric="episode_reward_mean", mode="max", max_t=max_t)
for i in range(num_trials):
t = Trial("__fake")
sched.on_trial_add(None, t)
@ -321,7 +346,7 @@ class HyperbandSuite(unittest.TestCase):
return sched
def testConfigSameEta(self):
sched = HyperBandScheduler()
sched = HyperBandScheduler(metric="episode_reward_mean", mode="max")
i = 0
while not sched._cur_band_filled():
t = Trial("__fake")
@ -335,7 +360,10 @@ class HyperbandSuite(unittest.TestCase):
reduction_factor = 10
sched = HyperBandScheduler(
max_t=1000, reduction_factor=reduction_factor)
metric="episode_reward_mean",
mode="max",
max_t=1000,
reduction_factor=reduction_factor)
i = 0
while not sched._cur_band_filled():
t = Trial("__fake")
@ -348,7 +376,8 @@ class HyperbandSuite(unittest.TestCase):
self.assertEqual(sched._hyperbands[0][-1]._r, 1)
def testConfigSameEtaSmall(self):
sched = HyperBandScheduler(max_t=1)
sched = HyperBandScheduler(
metric="episode_reward_mean", mode="max", max_t=1)
i = 0
while len(sched._hyperbands) < 2:
t = Trial("__fake")
@ -627,7 +656,11 @@ class BOHBSuite(unittest.TestCase):
_register_all() # re-register the evicted objects
def testLargestBracketFirst(self):
sched = HyperBandForBOHB(max_t=3, reduction_factor=3)
sched = HyperBandForBOHB(
metric="episode_reward_mean",
mode="max",
max_t=3,
reduction_factor=3)
runner = _MockTrialRunner(sched)
for i in range(3):
t = Trial("__fake")
@ -642,7 +675,11 @@ class BOHBSuite(unittest.TestCase):
def result(score, ts):
return {"episode_reward_mean": score, TRAINING_ITERATION: ts}
sched = HyperBandForBOHB(max_t=3, reduction_factor=3)
sched = HyperBandForBOHB(
metric="episode_reward_mean",
mode="max",
max_t=3,
reduction_factor=3)
runner = _MockTrialRunner(sched)
runner._search_alg = MagicMock()
runner._search_alg.searcher = MagicMock()
@ -668,7 +705,11 @@ class BOHBSuite(unittest.TestCase):
def result(score, ts):
return {"episode_reward_mean": score, TRAINING_ITERATION: ts}
sched = HyperBandForBOHB(max_t=3, reduction_factor=3, mode="min")
sched = HyperBandForBOHB(
metric="episode_reward_mean",
mode="min",
max_t=3,
reduction_factor=3)
runner = _MockTrialRunner(sched)
runner._search_alg = MagicMock()
runner._search_alg.searcher = MagicMock()
@ -693,7 +734,11 @@ class BOHBSuite(unittest.TestCase):
def result(score, ts):
return {"episode_reward_mean": score, TRAINING_ITERATION: ts}
sched = HyperBandForBOHB(max_t=10, reduction_factor=3, mode="min")
sched = HyperBandForBOHB(
metric="episode_reward_mean",
mode="min",
max_t=10,
reduction_factor=3)
runner = _MockTrialRunner(sched)
runner._search_alg = MagicMock()
runner._search_alg.searcher = MagicMock()
@ -761,6 +806,8 @@ class PopulationBasedTestingSuite(unittest.TestCase):
}
pbt = PopulationBasedTraining(
time_attr="training_iteration",
metric="episode_reward_mean",
mode="max",
perturbation_interval=perturbation_interval,
resample_probability=resample_prob,
quantile_fraction=0.25,
@ -1675,6 +1722,7 @@ class E2EPopulationBasedTestingSuite(unittest.TestCase):
}
pbt = PopulationBasedTraining(
metric="mean_accuracy",
mode="max",
time_attr="training_iteration",
perturbation_interval=perturbation_interval,
resample_probability=resample_prob,
@ -1791,7 +1839,8 @@ class AsyncHyperBandSuite(unittest.TestCase):
return t1, t2
def testAsyncHBOnComplete(self):
scheduler = AsyncHyperBandScheduler(max_t=10, brackets=1)
scheduler = AsyncHyperBandScheduler(
metric="episode_reward_mean", mode="max", max_t=10, brackets=1)
t1, t2 = self.basicSetup(scheduler)
t3 = Trial("PPO")
scheduler.on_trial_add(None, t3)
@ -1802,7 +1851,11 @@ class AsyncHyperBandSuite(unittest.TestCase):
def testAsyncHBGracePeriod(self):
scheduler = AsyncHyperBandScheduler(
grace_period=2.5, reduction_factor=3, brackets=1)
metric="episode_reward_mean",
mode="max",
grace_period=2.5,
reduction_factor=3,
brackets=1)
t1, t2 = self.basicSetup(scheduler)
scheduler.on_trial_complete(None, t1, result(10, 1000))
scheduler.on_trial_complete(None, t2, result(10, 1000))
@ -1819,7 +1872,8 @@ class AsyncHyperBandSuite(unittest.TestCase):
TrialScheduler.STOP)
def testAsyncHBAllCompletes(self):
scheduler = AsyncHyperBandScheduler(max_t=10, brackets=10)
scheduler = AsyncHyperBandScheduler(
metric="episode_reward_mean", mode="max", max_t=10, brackets=10)
trials = [Trial("PPO") for i in range(10)]
for t in trials:
scheduler.on_trial_add(None, t)
@ -1831,7 +1885,12 @@ class AsyncHyperBandSuite(unittest.TestCase):
def testAsyncHBUsesPercentile(self):
scheduler = AsyncHyperBandScheduler(
grace_period=1, max_t=10, reduction_factor=2, brackets=1)
metric="episode_reward_mean",
mode="max",
grace_period=1,
max_t=10,
reduction_factor=2,
brackets=1)
t1, t2 = self.basicSetup(scheduler)
scheduler.on_trial_complete(None, t1, result(10, 1000))
scheduler.on_trial_complete(None, t2, result(10, 1000))
@ -1846,7 +1905,12 @@ class AsyncHyperBandSuite(unittest.TestCase):
def testAsyncHBNanPercentile(self):
scheduler = AsyncHyperBandScheduler(
grace_period=1, max_t=10, reduction_factor=2, brackets=1)
metric="episode_reward_mean",
mode="max",
grace_period=1,
max_t=10,
reduction_factor=2,
brackets=1)
t1, t2 = self.nanSetup(scheduler)
scheduler.on_trial_complete(None, t1, result(10, 450))
scheduler.on_trial_complete(None, t2, result(10, np.nan))

View file

@ -68,6 +68,8 @@ def _report_progress(runner, reporter, done=False):
def run(
run_or_experiment,
name=None,
metric=None,
mode=None,
stop=None,
time_budget_s=None,
config=None,
@ -147,6 +149,12 @@ def run(
will need to first register the function:
``tune.register_trainable("lambda_id", lambda x: ...)``. You can
then use ``tune.run("lambda_id")``.
metric (str): Metric to optimize. This metric should be reported
with `tune.report()`. If set, will be passed to the search
algorithm and scheduler.
mode (str): Must be one of [min, max]. Determines whether objective is
minimizing or maximizing the metric attribute. If set, will be
passed to the search algorithm and scheduler.
name (str): Name of experiment.
stop (dict | callable | :class:`Stopper`): Stopping criteria. If dict,
the keys may be any field in the return result of 'train()',
@ -276,6 +284,11 @@ def run(
"sync_config=SyncConfig(...)`. See `ray.tune.SyncConfig` for "
"more details.")
if mode and mode not in ["min", "max"]:
raise ValueError(
"The `mode` parameter passed to `tune.run()` has to be one of "
"['min', 'max']")
config = config or {}
sync_config = sync_config or SyncConfig()
set_sync_periods(sync_config)
@ -329,8 +342,7 @@ def run(
if not search_alg:
search_alg = BasicVariantGenerator()
# TODO (krfricke): Introduce metric/mode as top level API
if config and not search_alg.set_search_properties(None, None, config):
if config and not search_alg.set_search_properties(metric, mode, config):
if has_unresolved_values(config):
raise ValueError(
"You passed a `config` parameter to `tune.run()` with "
@ -339,9 +351,17 @@ def run(
"does not contain any more parameter definitions - include "
"them in the search algorithm's search space if necessary.")
scheduler = scheduler or FIFOScheduler()
if not scheduler.set_search_properties(metric, mode):
raise ValueError(
"You passed a `metric` or `mode` argument to `tune.run()`, but "
"the scheduler you are using was already instantiated with their "
"own `metric` and `mode` parameters. Either remove the arguments "
"from your scheduler or from your call to `tune.run()`")
runner = TrialRunner(
search_alg=search_alg,
scheduler=scheduler or FIFOScheduler(),
scheduler=scheduler,
local_checkpoint_dir=experiments[0].checkpoint_dir,
remote_checkpoint_dir=experiments[0].remote_checkpoint_dir,
sync_to_cloud=sync_config.sync_to_cloud,
@ -413,8 +433,8 @@ def run(
return ExperimentAnalysis(
runner.checkpoint_file,
trials=trials,
default_metric=None,
default_mode=None)
default_metric=metric,
default_mode=mode)
def run_experiments(experiments,

View file

@ -26,7 +26,7 @@ timm
torch>=1.5.0
torchvision>=0.6.0
transformers
tune-sklearn==0.0.5
git+git://github.com/ray-project/tune-sklearn@master#tune-sklearn
wandb
xgboost
zoopt>=0.4.0