mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[tune] Add points_to_evaluate
to BasicVariantGenerator (#12916)
Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
parent
124c8318a8
commit
3d72000826
11 changed files with 396 additions and 34 deletions
|
@ -263,10 +263,7 @@ Grid Search API
|
|||
|
||||
.. autofunction:: ray.tune.grid_search
|
||||
|
||||
Internals
|
||||
---------
|
||||
References
|
||||
----------
|
||||
|
||||
BasicVariantGenerator
|
||||
~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: ray.tune.suggest.BasicVariantGenerator
|
||||
See also :ref:`tune-basicvariant`.
|
|
@ -22,6 +22,10 @@ Summary
|
|||
- Summary
|
||||
- Website
|
||||
- Code Example
|
||||
* - :ref:`Random search/grid search <tune-basicvariant>`
|
||||
- Random search/grid search
|
||||
-
|
||||
- :doc:`/tune/examples/tune_basic_example`
|
||||
* - :ref:`AxSearch <tune-ax>`
|
||||
- Bayesian/Bandit Optimization
|
||||
- [`Ax <https://ax.dev/>`__]
|
||||
|
@ -123,6 +127,21 @@ identifier.
|
|||
|
||||
.. note:: This is currently not implemented for: AxSearch, TuneBOHB, SigOptSearch, and DragonflySearch.
|
||||
|
||||
.. _tune-basicvariant:
|
||||
|
||||
Random search and grid search (tune.suggest.basic_variant.BasicVariantGenerator)
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
The default and most basic way to do hyperparameter search is via random and grid search.
|
||||
Ray Tune does this through the :class:`BasicVariantGenerator <ray.tune.suggest.basic_variant.BasicVariantGenerator>`
|
||||
class that generates trial variants given a search space definition.
|
||||
|
||||
The :class:`BasicVariantGenerator <ray.tune.suggest.basic_variant.BasicVariantGenerator>` is used per
|
||||
default if no search algorithm is passed to
|
||||
:func:`tune.run() <ray.tune.run>`.
|
||||
|
||||
.. autoclass:: ray.tune.suggest.basic_variant.BasicVariantGenerator
|
||||
|
||||
.. _tune-ax:
|
||||
|
||||
Ax (tune.suggest.ax.AxSearch)
|
||||
|
|
|
@ -13,7 +13,7 @@ If any example is broken, or if you'd like to add an example to this page, feel
|
|||
General Examples
|
||||
----------------
|
||||
|
||||
|
||||
- :doc:`/tune/examples/tune_basic_example`: Simple example for doing a basic random and grid search.
|
||||
- :doc:`/tune/examples/async_hyperband_example`: Example of using a simple tuning function with AsyncHyperBandScheduler.
|
||||
- :doc:`/tune/examples/hyperband_function_example`: Example of using a Trainable function with HyperBandScheduler. Also uses the AsyncHyperBandScheduler.
|
||||
- :doc:`/tune/examples/pbt_function`: Example of using the function API with a PopulationBasedTraining scheduler.
|
||||
|
|
6
doc/source/tune/examples/tune_basic_example.rst
Normal file
6
doc/source/tune/examples/tune_basic_example.rst
Normal file
|
@ -0,0 +1,6 @@
|
|||
:orphan:
|
||||
|
||||
tune_basic_example
|
||||
~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. literalinclude:: /../../python/ray/tune/examples/tune_basic_example.py
|
|
@ -157,7 +157,7 @@ py_test(
|
|||
|
||||
py_test(
|
||||
name = "test_sample",
|
||||
size = "small",
|
||||
size = "medium",
|
||||
srcs = ["tests/test_sample.py"],
|
||||
deps = [":tune_lib"],
|
||||
tags = ["exclusive"],
|
||||
|
@ -696,6 +696,15 @@ py_test(
|
|||
args = ["--smoke-test"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "tune_basic_example",
|
||||
size = "small",
|
||||
srcs = ["examples/tune_basic_example.py"],
|
||||
deps = [":tune_lib"],
|
||||
tags = ["exclusive", "example"],
|
||||
args = ["--smoke-test"]
|
||||
)
|
||||
|
||||
# Downloads too much data.
|
||||
# py_test(
|
||||
# name = "tune_cifar10_gluon",
|
||||
|
|
51
python/ray/tune/examples/tune_basic_example.py
Normal file
51
python/ray/tune/examples/tune_basic_example.py
Normal file
|
@ -0,0 +1,51 @@
|
|||
"""This example demonstrates basic Ray Tune random search and grid search."""
|
||||
import time
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
|
||||
|
||||
def evaluation_fn(step, width, height):
|
||||
time.sleep(0.1)
|
||||
return (0.1 + width * step / 100)**(-1) + height * 0.1
|
||||
|
||||
|
||||
def easy_objective(config):
|
||||
# Hyperparameters
|
||||
width, height = config["width"], config["height"]
|
||||
|
||||
for step in range(config["steps"]):
|
||||
# Iterative training function - can be any arbitrary training procedure
|
||||
intermediate_score = evaluation_fn(step, width, height)
|
||||
# Feed the score back back to Tune.
|
||||
tune.report(iterations=step, mean_loss=intermediate_score)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--smoke-test", action="store_true", help="Finish quickly for testing")
|
||||
args, _ = parser.parse_known_args()
|
||||
ray.init(configure_logging=False)
|
||||
|
||||
# This will do a grid search over the `activation` parameter. This means
|
||||
# that each of the two values (`relu` and `tanh`) will be sampled once
|
||||
# for each sample (`num_samples`). We end up with 2 * 50 = 100 samples.
|
||||
# The `width` and `height` parameters are sampled randomly.
|
||||
# `steps` is a constant parameter.
|
||||
|
||||
analysis = tune.run(
|
||||
easy_objective,
|
||||
metric="mean_loss",
|
||||
mode="min",
|
||||
num_samples=5 if args.smoke_test else 50,
|
||||
config={
|
||||
"steps": 5 if args.smoke_test else 100,
|
||||
"width": tune.uniform(0, 20),
|
||||
"height": tune.uniform(-100, 100),
|
||||
"activation": tune.grid_search(["relu", "tanh"])
|
||||
})
|
||||
|
||||
print("Best hyperparameters found were: ", analysis.best_config)
|
|
@ -53,6 +53,14 @@ class Domain:
|
|||
def is_function(self):
|
||||
return False
|
||||
|
||||
def is_valid(self, value: Any):
|
||||
"""Returns True if `value` is a valid value in this domain."""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def domain_str(self):
|
||||
return "(unknown)"
|
||||
|
||||
|
||||
class Sampler:
|
||||
def sample(self,
|
||||
|
@ -203,6 +211,13 @@ class Float(Domain):
|
|||
new.set_sampler(Quantized(new.get_sampler(), q), allow_override=True)
|
||||
return new
|
||||
|
||||
def is_valid(self, value: float):
|
||||
return self.lower <= value <= self.upper
|
||||
|
||||
@property
|
||||
def domain_str(self):
|
||||
return f"({self.lower}, {self.upper})"
|
||||
|
||||
|
||||
class Integer(Domain):
|
||||
class _Uniform(Uniform):
|
||||
|
@ -232,6 +247,13 @@ class Integer(Domain):
|
|||
new.set_sampler(self._Uniform())
|
||||
return new
|
||||
|
||||
def is_valid(self, value: int):
|
||||
return self.lower <= value <= self.upper
|
||||
|
||||
@property
|
||||
def domain_str(self):
|
||||
return f"({self.lower}, {self.upper})"
|
||||
|
||||
|
||||
class Categorical(Domain):
|
||||
class _Uniform(Uniform):
|
||||
|
@ -264,6 +286,13 @@ class Categorical(Domain):
|
|||
def __getitem__(self, item):
|
||||
return self.categories[item]
|
||||
|
||||
def is_valid(self, value: Any):
|
||||
return value in self.categories
|
||||
|
||||
@property
|
||||
def domain_str(self):
|
||||
return f"{self.categories}"
|
||||
|
||||
|
||||
class Function(Domain):
|
||||
class _CallSampler(BaseSampler):
|
||||
|
@ -295,6 +324,13 @@ class Function(Domain):
|
|||
def is_function(self):
|
||||
return True
|
||||
|
||||
def is_valid(self, value: Any):
|
||||
return True # This is user-defined, so lets not assume anything
|
||||
|
||||
@property
|
||||
def domain_str(self):
|
||||
return f"{self.func}()"
|
||||
|
||||
|
||||
class Quantized(Sampler):
|
||||
def __init__(self, sampler: Sampler, q: Union[float, int]):
|
||||
|
|
|
@ -1,44 +1,95 @@
|
|||
import copy
|
||||
import itertools
|
||||
import os
|
||||
import uuid
|
||||
from typing import Dict, List, Union
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.experiment import Experiment, convert_to_experiment_list
|
||||
from ray.tune.config_parser import make_parser, create_trial_from_spec
|
||||
from ray.tune.suggest.variant_generator import (
|
||||
count_variants, generate_variants, format_vars, flatten_resolved_vars)
|
||||
count_variants, generate_variants, format_vars, flatten_resolved_vars,
|
||||
get_preset_variants)
|
||||
from ray.tune.suggest.search import SearchAlgorithm
|
||||
|
||||
|
||||
class BasicVariantGenerator(SearchAlgorithm):
|
||||
"""Uses Tune's variant generation for resolving variables.
|
||||
|
||||
See also: `ray.tune.suggest.variant_generator`.
|
||||
This is the default search algorithm used if no other search algorithm
|
||||
is specified.
|
||||
|
||||
User API:
|
||||
|
||||
Args:
|
||||
points_to_evaluate (list): Initial parameter suggestions to be run
|
||||
first. This is for when you already have some good parameters
|
||||
you want to run first to help the algorithm make better suggestions
|
||||
for future parameters. Needs to be a list of dicts containing the
|
||||
configurations.
|
||||
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from ray import tune
|
||||
from ray.tune.suggest import BasicVariantGenerator
|
||||
|
||||
searcher = BasicVariantGenerator()
|
||||
tune.run(my_trainable_func, algo=searcher)
|
||||
# This will automatically use the `BasicVariantGenerator`
|
||||
tune.run(
|
||||
lambda config: config["a"] + config["b"],
|
||||
config={
|
||||
"a": tune.grid_search([1, 2]),
|
||||
"b": tune.randint(0, 3)
|
||||
},
|
||||
num_samples=4)
|
||||
|
||||
Internal API:
|
||||
In the example above, 8 trials will be generated: For each sample
|
||||
(``4``), each of the grid search variants for ``a`` will be sampled
|
||||
once. The ``b`` parameter will be sampled randomly.
|
||||
|
||||
The generator accepts a pre-set list of points that should be evaluated.
|
||||
The points will replace the first samples of each experiment passed to
|
||||
the ``BasicVariantGenerator``.
|
||||
|
||||
Each point will replace one sample of the specified ``num_samples``. If
|
||||
grid search variables are overwritten with the values specified in the
|
||||
presets, the number of samples will thus be reduced.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from ray.tune.suggest import BasicVariantGenerator
|
||||
from ray import tune
|
||||
from ray.tune.suggest.basic_variant import BasicVariantGenerator
|
||||
|
||||
|
||||
tune.run(
|
||||
lambda config: config["a"] + config["b"],
|
||||
config={
|
||||
"a": tune.grid_search([1, 2]),
|
||||
"b": tune.randint(0, 3)
|
||||
},
|
||||
search_alg=BasicVariantGenerator(points_to_evaluate=[
|
||||
{"a": 2, "b": 2},
|
||||
{"a": 1},
|
||||
{"b": 2}
|
||||
]),
|
||||
num_samples=4)
|
||||
|
||||
The example above will produce six trials via four samples:
|
||||
|
||||
- The first sample will produce one trial with ``a=2`` and ``b=2``.
|
||||
- The second sample will produce one trial with ``a=1`` and ``b`` sampled
|
||||
randomly
|
||||
- The third sample will produce two trials, one for each grid search
|
||||
value of ``a``. It will be ``b=2`` for both of these trials.
|
||||
- The fourth sample will produce two trials, one for each grid search
|
||||
value of ``a``. ``b`` will be sampled randomly and independently for
|
||||
both of these trials.
|
||||
|
||||
searcher = BasicVariantGenerator()
|
||||
searcher.add_configurations({"experiment": { ... }})
|
||||
trial = searcher.next_trial()
|
||||
searcher.is_finished == True
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, points_to_evaluate: Optional[List[Dict]] = None):
|
||||
"""Initializes the Variant Generator.
|
||||
|
||||
"""
|
||||
|
@ -48,6 +99,8 @@ class BasicVariantGenerator(SearchAlgorithm):
|
|||
self._counter = 0
|
||||
self._finished = False
|
||||
|
||||
self._points_to_evaluate = points_to_evaluate or []
|
||||
|
||||
# Unique prefix for all trials generated, e.g., trial ids start as
|
||||
# 2f1e_00001, 2f1ef_00002, 2f1ef_0003, etc. Overridable for testing.
|
||||
force_test_uuid = os.environ.get("_TEST_TUNE_TRIAL_UUID")
|
||||
|
@ -72,12 +125,14 @@ class BasicVariantGenerator(SearchAlgorithm):
|
|||
"""
|
||||
experiment_list = convert_to_experiment_list(experiments)
|
||||
for experiment in experiment_list:
|
||||
self._total_samples += count_variants(experiment.spec)
|
||||
points_to_evaluate = copy.deepcopy(self._points_to_evaluate)
|
||||
self._total_samples += count_variants(experiment.spec,
|
||||
points_to_evaluate)
|
||||
self._trial_generator = itertools.chain(
|
||||
self._trial_generator,
|
||||
self._generate_trials(
|
||||
experiment.spec.get("num_samples", 1), experiment.spec,
|
||||
experiment.dir_name))
|
||||
experiment.dir_name, points_to_evaluate))
|
||||
|
||||
def next_trial(self):
|
||||
"""Provides one Trial object to be queued into the TrialRunner.
|
||||
|
@ -95,7 +150,11 @@ class BasicVariantGenerator(SearchAlgorithm):
|
|||
self.set_finished()
|
||||
return None
|
||||
|
||||
def _generate_trials(self, num_samples, unresolved_spec, output_path=""):
|
||||
def _generate_trials(self,
|
||||
num_samples,
|
||||
unresolved_spec,
|
||||
output_path="",
|
||||
points_to_evaluate=None):
|
||||
"""Generates Trial objects with the variant generation process.
|
||||
|
||||
Uses a fixed point iteration to resolve variants. All trials
|
||||
|
@ -109,6 +168,28 @@ class BasicVariantGenerator(SearchAlgorithm):
|
|||
|
||||
if "run" not in unresolved_spec:
|
||||
raise TuneError("Must specify `run` in {}".format(unresolved_spec))
|
||||
|
||||
points_to_evaluate = points_to_evaluate or []
|
||||
|
||||
while points_to_evaluate:
|
||||
config = points_to_evaluate.pop(0)
|
||||
for resolved_vars, spec in get_preset_variants(
|
||||
unresolved_spec, config):
|
||||
trial_id = self._uuid_prefix + ("%05d" % self._counter)
|
||||
experiment_tag = str(self._counter)
|
||||
self._counter += 1
|
||||
yield create_trial_from_spec(
|
||||
spec,
|
||||
output_path,
|
||||
self._parser,
|
||||
evaluated_params=flatten_resolved_vars(resolved_vars),
|
||||
trial_id=trial_id,
|
||||
experiment_tag=experiment_tag)
|
||||
num_samples -= 1
|
||||
|
||||
if num_samples <= 0:
|
||||
return
|
||||
|
||||
for _ in range(num_samples):
|
||||
for resolved_vars, spec in generate_variants(unresolved_spec):
|
||||
trial_id = self._uuid_prefix + ("%05d" % self._counter)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import copy
|
||||
import logging
|
||||
from typing import Any, Dict, Generator, List, Tuple
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Dict, Generator, List, Optional, Tuple
|
||||
|
||||
import numpy
|
||||
import random
|
||||
|
@ -138,13 +139,38 @@ def parse_spec_vars(spec: Dict) -> Tuple[List[Tuple[Tuple, Any]], List[Tuple[
|
|||
return resolved_vars, domain_vars, grid_vars
|
||||
|
||||
|
||||
def count_variants(spec: Dict) -> int:
|
||||
spec = copy.deepcopy(spec)
|
||||
_, domain_vars, grid_vars = parse_spec_vars(spec)
|
||||
grid_count = 1
|
||||
for path, domain in grid_vars:
|
||||
grid_count *= len(domain.categories)
|
||||
return spec.get("num_samples", 1) * grid_count
|
||||
def count_variants(spec: Dict, presets: Optional[List[Dict]] = None) -> int:
|
||||
# Helper function: Deep update dictionary
|
||||
def deep_update(d, u):
|
||||
for k, v in u.items():
|
||||
if isinstance(v, Mapping):
|
||||
d[k] = deep_update(d.get(k, {}), v)
|
||||
else:
|
||||
d[k] = v
|
||||
return d
|
||||
|
||||
# Count samples for a specific spec
|
||||
def spec_samples(spec, num_samples=1):
|
||||
_, domain_vars, grid_vars = parse_spec_vars(spec)
|
||||
grid_count = 1
|
||||
for path, domain in grid_vars:
|
||||
grid_count *= len(domain.categories)
|
||||
return num_samples * grid_count
|
||||
|
||||
total_samples = 0
|
||||
total_num_samples = spec.get("num_samples", 1)
|
||||
# For each preset, overwrite the spec and count the samples generated
|
||||
# for this preset
|
||||
for preset in presets:
|
||||
preset_spec = copy.deepcopy(spec)
|
||||
deep_update(preset_spec["config"], preset)
|
||||
total_samples += spec_samples(preset_spec, 1)
|
||||
total_num_samples -= 1
|
||||
|
||||
# Add the remaining samples
|
||||
if total_num_samples > 0:
|
||||
total_samples += spec_samples(spec, total_num_samples)
|
||||
return total_samples
|
||||
|
||||
|
||||
def _generate_variants(spec: Dict) -> Tuple[Dict, Dict]:
|
||||
|
@ -172,6 +198,45 @@ def _generate_variants(spec: Dict) -> Tuple[Dict, Dict]:
|
|||
yield resolved_vars, spec
|
||||
|
||||
|
||||
def get_preset_variants(spec: Dict, config: Dict):
|
||||
"""Get variants according to a spec, initialized with a config.
|
||||
|
||||
Variables from the spec are overwritten by the variables in the config.
|
||||
Thus, we may end up with less sampled parameters.
|
||||
|
||||
This function also checks if values used to overwrite search space
|
||||
parameters are valid, and logs a warning if not.
|
||||
"""
|
||||
spec = copy.deepcopy(spec)
|
||||
|
||||
resolved, _, _ = parse_spec_vars(config)
|
||||
|
||||
for path, val in resolved:
|
||||
try:
|
||||
domain = _get_value(spec["config"], path)
|
||||
if isinstance(domain, dict):
|
||||
if "grid_search" in domain:
|
||||
domain = Categorical(domain["grid_search"])
|
||||
else:
|
||||
# If users want to overwrite an entire subdict,
|
||||
# let them do it.
|
||||
domain = None
|
||||
except IndexError as exc:
|
||||
raise ValueError(
|
||||
f"Pre-set config key `{'/'.join(path)}` does not correspond "
|
||||
f"to a valid key in the search space definition. Please add "
|
||||
f"this path to the `config` variable passed to `tune.run()`."
|
||||
) from exc
|
||||
|
||||
if domain and not domain.is_valid(val):
|
||||
logger.warning(
|
||||
f"Pre-set value `{val}` is not within valid values of "
|
||||
f"parameter `{'/'.join(path)}`: {domain.domain_str}")
|
||||
assign_value(spec["config"], path, val)
|
||||
|
||||
return _generate_variants(spec)
|
||||
|
||||
|
||||
def assign_value(spec: Dict, path: Tuple, value: Any):
|
||||
for k in path[:-1]:
|
||||
spec = spec[k]
|
||||
|
|
|
@ -2,6 +2,7 @@ import numpy as np
|
|||
import unittest
|
||||
|
||||
from ray import tune
|
||||
from ray.tune import Experiment
|
||||
from ray.tune.suggest.variant_generator import generate_variants
|
||||
|
||||
|
||||
|
@ -871,6 +872,102 @@ class SearchSpaceTest(unittest.TestCase):
|
|||
return self._testPointsToEvaluate(
|
||||
ZOOptSearch, config, budget=10, parallel_num=8)
|
||||
|
||||
def testPointsToEvaluateBasicVariant(self):
|
||||
config = {
|
||||
"metric": tune.sample.Categorical([1, 2, 3, 4]).uniform(),
|
||||
"a": tune.sample.Categorical(["t1", "t2", "t3", "t4"]).uniform(),
|
||||
"b": tune.sample.Integer(0, 5),
|
||||
"c": tune.sample.Float(1e-4, 1e-1).loguniform()
|
||||
}
|
||||
|
||||
from ray.tune.suggest.basic_variant import BasicVariantGenerator
|
||||
return self._testPointsToEvaluate(BasicVariantGenerator, config)
|
||||
|
||||
def testPointsToEvaluateBasicVariantAdvanced(self):
|
||||
config = {
|
||||
"grid_1": tune.grid_search(["a", "b", "c", "d"]),
|
||||
"grid_2": tune.grid_search(["x", "y", "z"]),
|
||||
"nested": {
|
||||
"random": tune.uniform(2., 10.),
|
||||
"dependent": tune.sample_from(
|
||||
lambda spec: -1. * spec.config.nested.random)
|
||||
}
|
||||
}
|
||||
|
||||
points = [
|
||||
{
|
||||
"grid_1": "b"
|
||||
},
|
||||
{
|
||||
"grid_2": "z"
|
||||
},
|
||||
{
|
||||
"grid_1": "a",
|
||||
"grid_2": "y"
|
||||
},
|
||||
{
|
||||
"nested": {
|
||||
"random": 8.0
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
from ray.tune.suggest.basic_variant import BasicVariantGenerator
|
||||
|
||||
# grid_1 * grid_2 are 3 * 4 = 12 variants per complete grid search
|
||||
# However if one grid var is set by preset variables, that run
|
||||
# is excluded from grid search.
|
||||
|
||||
# Point 1 overwrites grid_1, so the first trial only grid searches
|
||||
# over grid_2 (3 trials).
|
||||
# The remaining 5 trials search over the whole space (5 * 12 trials)
|
||||
searcher = BasicVariantGenerator(points_to_evaluate=[points[0]])
|
||||
exp = Experiment(
|
||||
run=_mock_objective, name="test", config=config, num_samples=6)
|
||||
searcher.add_configurations(exp)
|
||||
self.assertEqual(searcher.total_samples, 1 * 3 + 5 * 12)
|
||||
|
||||
# Point 2 overwrites grid_2, so the first trial only grid searches
|
||||
# over grid_1 (4 trials).
|
||||
# The remaining 5 trials search over the whole space (5 * 12 trials)
|
||||
searcher = BasicVariantGenerator(points_to_evaluate=[points[1]])
|
||||
exp = Experiment(
|
||||
run=_mock_objective, name="test", config=config, num_samples=6)
|
||||
searcher.add_configurations(exp)
|
||||
self.assertEqual(searcher.total_samples, 1 * 4 + 5 * 12)
|
||||
|
||||
# Point 3 overwrites grid_1 and grid_2, so the first trial does not
|
||||
# grid search.
|
||||
# The remaining 5 trials search over the whole space (5 * 12 trials)
|
||||
searcher = BasicVariantGenerator(points_to_evaluate=[points[2]])
|
||||
exp = Experiment(
|
||||
run=_mock_objective, name="test", config=config, num_samples=6)
|
||||
searcher.add_configurations(exp)
|
||||
self.assertEqual(searcher.total_samples, 1 + 5 * 12)
|
||||
|
||||
# When initialized with all points, the first three trials are
|
||||
# defined by the logic above. Only 3 trials are grid searched
|
||||
# compeletely.
|
||||
searcher = BasicVariantGenerator(points_to_evaluate=points)
|
||||
exp = Experiment(
|
||||
run=_mock_objective, name="test", config=config, num_samples=6)
|
||||
searcher.add_configurations(exp)
|
||||
self.assertEqual(searcher.total_samples, 1 * 3 + 1 * 4 + 1 + 3 * 12)
|
||||
|
||||
# Run this and confirm results
|
||||
analysis = tune.run(exp, search_alg=searcher)
|
||||
configs = [trial.config for trial in analysis.trials]
|
||||
|
||||
self.assertEqual(len(configs), searcher.total_samples)
|
||||
self.assertTrue(
|
||||
all(config["grid_1"] == "b" for config in configs[0:3]))
|
||||
self.assertTrue(
|
||||
all(config["grid_2"] == "z" for config in configs[3:7]))
|
||||
self.assertTrue(configs[7]["grid_1"] == "a"
|
||||
and configs[7]["grid_2"] == "y")
|
||||
self.assertTrue(configs[8]["nested"]["random"] == 8.0)
|
||||
self.assertTrue(configs[8]["nested"]["dependent"] == -8.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
|
|
|
@ -171,7 +171,8 @@ def run(
|
|||
samples are generated until a stopping condition is met.
|
||||
local_dir (str): Local dir to save training results to.
|
||||
Defaults to ``~/ray_results``.
|
||||
search_alg (Searcher): Search algorithm for optimization.
|
||||
search_alg (Searcher|SearchAlgorithm): Search algorithm for
|
||||
optimization.
|
||||
scheduler (TrialScheduler): Scheduler for executing
|
||||
the experiment. Choose among FIFO (default), MedianStopping,
|
||||
AsyncHyperBand, HyperBand and PopulationBasedTraining. Refer to
|
||||
|
|
Loading…
Add table
Reference in a new issue