mirror of
https://github.com/vale981/ray
synced 2025-03-07 02:51:39 -05:00
[tune] Add points_to_evaluate
to BasicVariantGenerator (#12916)
Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
parent
124c8318a8
commit
3d72000826
11 changed files with 396 additions and 34 deletions
|
@ -263,10 +263,7 @@ Grid Search API
|
||||||
|
|
||||||
.. autofunction:: ray.tune.grid_search
|
.. autofunction:: ray.tune.grid_search
|
||||||
|
|
||||||
Internals
|
References
|
||||||
---------
|
----------
|
||||||
|
|
||||||
BasicVariantGenerator
|
See also :ref:`tune-basicvariant`.
|
||||||
~~~~~~~~~~~~~~~~~~~~~
|
|
||||||
|
|
||||||
.. autoclass:: ray.tune.suggest.BasicVariantGenerator
|
|
|
@ -22,6 +22,10 @@ Summary
|
||||||
- Summary
|
- Summary
|
||||||
- Website
|
- Website
|
||||||
- Code Example
|
- Code Example
|
||||||
|
* - :ref:`Random search/grid search <tune-basicvariant>`
|
||||||
|
- Random search/grid search
|
||||||
|
-
|
||||||
|
- :doc:`/tune/examples/tune_basic_example`
|
||||||
* - :ref:`AxSearch <tune-ax>`
|
* - :ref:`AxSearch <tune-ax>`
|
||||||
- Bayesian/Bandit Optimization
|
- Bayesian/Bandit Optimization
|
||||||
- [`Ax <https://ax.dev/>`__]
|
- [`Ax <https://ax.dev/>`__]
|
||||||
|
@ -123,6 +127,21 @@ identifier.
|
||||||
|
|
||||||
.. note:: This is currently not implemented for: AxSearch, TuneBOHB, SigOptSearch, and DragonflySearch.
|
.. note:: This is currently not implemented for: AxSearch, TuneBOHB, SigOptSearch, and DragonflySearch.
|
||||||
|
|
||||||
|
.. _tune-basicvariant:
|
||||||
|
|
||||||
|
Random search and grid search (tune.suggest.basic_variant.BasicVariantGenerator)
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
The default and most basic way to do hyperparameter search is via random and grid search.
|
||||||
|
Ray Tune does this through the :class:`BasicVariantGenerator <ray.tune.suggest.basic_variant.BasicVariantGenerator>`
|
||||||
|
class that generates trial variants given a search space definition.
|
||||||
|
|
||||||
|
The :class:`BasicVariantGenerator <ray.tune.suggest.basic_variant.BasicVariantGenerator>` is used per
|
||||||
|
default if no search algorithm is passed to
|
||||||
|
:func:`tune.run() <ray.tune.run>`.
|
||||||
|
|
||||||
|
.. autoclass:: ray.tune.suggest.basic_variant.BasicVariantGenerator
|
||||||
|
|
||||||
.. _tune-ax:
|
.. _tune-ax:
|
||||||
|
|
||||||
Ax (tune.suggest.ax.AxSearch)
|
Ax (tune.suggest.ax.AxSearch)
|
||||||
|
|
|
@ -13,7 +13,7 @@ If any example is broken, or if you'd like to add an example to this page, feel
|
||||||
General Examples
|
General Examples
|
||||||
----------------
|
----------------
|
||||||
|
|
||||||
|
- :doc:`/tune/examples/tune_basic_example`: Simple example for doing a basic random and grid search.
|
||||||
- :doc:`/tune/examples/async_hyperband_example`: Example of using a simple tuning function with AsyncHyperBandScheduler.
|
- :doc:`/tune/examples/async_hyperband_example`: Example of using a simple tuning function with AsyncHyperBandScheduler.
|
||||||
- :doc:`/tune/examples/hyperband_function_example`: Example of using a Trainable function with HyperBandScheduler. Also uses the AsyncHyperBandScheduler.
|
- :doc:`/tune/examples/hyperband_function_example`: Example of using a Trainable function with HyperBandScheduler. Also uses the AsyncHyperBandScheduler.
|
||||||
- :doc:`/tune/examples/pbt_function`: Example of using the function API with a PopulationBasedTraining scheduler.
|
- :doc:`/tune/examples/pbt_function`: Example of using the function API with a PopulationBasedTraining scheduler.
|
||||||
|
|
6
doc/source/tune/examples/tune_basic_example.rst
Normal file
6
doc/source/tune/examples/tune_basic_example.rst
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
:orphan:
|
||||||
|
|
||||||
|
tune_basic_example
|
||||||
|
~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. literalinclude:: /../../python/ray/tune/examples/tune_basic_example.py
|
|
@ -157,7 +157,7 @@ py_test(
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
name = "test_sample",
|
name = "test_sample",
|
||||||
size = "small",
|
size = "medium",
|
||||||
srcs = ["tests/test_sample.py"],
|
srcs = ["tests/test_sample.py"],
|
||||||
deps = [":tune_lib"],
|
deps = [":tune_lib"],
|
||||||
tags = ["exclusive"],
|
tags = ["exclusive"],
|
||||||
|
@ -696,6 +696,15 @@ py_test(
|
||||||
args = ["--smoke-test"]
|
args = ["--smoke-test"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "tune_basic_example",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["examples/tune_basic_example.py"],
|
||||||
|
deps = [":tune_lib"],
|
||||||
|
tags = ["exclusive", "example"],
|
||||||
|
args = ["--smoke-test"]
|
||||||
|
)
|
||||||
|
|
||||||
# Downloads too much data.
|
# Downloads too much data.
|
||||||
# py_test(
|
# py_test(
|
||||||
# name = "tune_cifar10_gluon",
|
# name = "tune_cifar10_gluon",
|
||||||
|
|
51
python/ray/tune/examples/tune_basic_example.py
Normal file
51
python/ray/tune/examples/tune_basic_example.py
Normal file
|
@ -0,0 +1,51 @@
|
||||||
|
"""This example demonstrates basic Ray Tune random search and grid search."""
|
||||||
|
import time
|
||||||
|
|
||||||
|
import ray
|
||||||
|
from ray import tune
|
||||||
|
|
||||||
|
|
||||||
|
def evaluation_fn(step, width, height):
|
||||||
|
time.sleep(0.1)
|
||||||
|
return (0.1 + width * step / 100)**(-1) + height * 0.1
|
||||||
|
|
||||||
|
|
||||||
|
def easy_objective(config):
|
||||||
|
# Hyperparameters
|
||||||
|
width, height = config["width"], config["height"]
|
||||||
|
|
||||||
|
for step in range(config["steps"]):
|
||||||
|
# Iterative training function - can be any arbitrary training procedure
|
||||||
|
intermediate_score = evaluation_fn(step, width, height)
|
||||||
|
# Feed the score back back to Tune.
|
||||||
|
tune.report(iterations=step, mean_loss=intermediate_score)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--smoke-test", action="store_true", help="Finish quickly for testing")
|
||||||
|
args, _ = parser.parse_known_args()
|
||||||
|
ray.init(configure_logging=False)
|
||||||
|
|
||||||
|
# This will do a grid search over the `activation` parameter. This means
|
||||||
|
# that each of the two values (`relu` and `tanh`) will be sampled once
|
||||||
|
# for each sample (`num_samples`). We end up with 2 * 50 = 100 samples.
|
||||||
|
# The `width` and `height` parameters are sampled randomly.
|
||||||
|
# `steps` is a constant parameter.
|
||||||
|
|
||||||
|
analysis = tune.run(
|
||||||
|
easy_objective,
|
||||||
|
metric="mean_loss",
|
||||||
|
mode="min",
|
||||||
|
num_samples=5 if args.smoke_test else 50,
|
||||||
|
config={
|
||||||
|
"steps": 5 if args.smoke_test else 100,
|
||||||
|
"width": tune.uniform(0, 20),
|
||||||
|
"height": tune.uniform(-100, 100),
|
||||||
|
"activation": tune.grid_search(["relu", "tanh"])
|
||||||
|
})
|
||||||
|
|
||||||
|
print("Best hyperparameters found were: ", analysis.best_config)
|
|
@ -53,6 +53,14 @@ class Domain:
|
||||||
def is_function(self):
|
def is_function(self):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def is_valid(self, value: Any):
|
||||||
|
"""Returns True if `value` is a valid value in this domain."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@property
|
||||||
|
def domain_str(self):
|
||||||
|
return "(unknown)"
|
||||||
|
|
||||||
|
|
||||||
class Sampler:
|
class Sampler:
|
||||||
def sample(self,
|
def sample(self,
|
||||||
|
@ -203,6 +211,13 @@ class Float(Domain):
|
||||||
new.set_sampler(Quantized(new.get_sampler(), q), allow_override=True)
|
new.set_sampler(Quantized(new.get_sampler(), q), allow_override=True)
|
||||||
return new
|
return new
|
||||||
|
|
||||||
|
def is_valid(self, value: float):
|
||||||
|
return self.lower <= value <= self.upper
|
||||||
|
|
||||||
|
@property
|
||||||
|
def domain_str(self):
|
||||||
|
return f"({self.lower}, {self.upper})"
|
||||||
|
|
||||||
|
|
||||||
class Integer(Domain):
|
class Integer(Domain):
|
||||||
class _Uniform(Uniform):
|
class _Uniform(Uniform):
|
||||||
|
@ -232,6 +247,13 @@ class Integer(Domain):
|
||||||
new.set_sampler(self._Uniform())
|
new.set_sampler(self._Uniform())
|
||||||
return new
|
return new
|
||||||
|
|
||||||
|
def is_valid(self, value: int):
|
||||||
|
return self.lower <= value <= self.upper
|
||||||
|
|
||||||
|
@property
|
||||||
|
def domain_str(self):
|
||||||
|
return f"({self.lower}, {self.upper})"
|
||||||
|
|
||||||
|
|
||||||
class Categorical(Domain):
|
class Categorical(Domain):
|
||||||
class _Uniform(Uniform):
|
class _Uniform(Uniform):
|
||||||
|
@ -264,6 +286,13 @@ class Categorical(Domain):
|
||||||
def __getitem__(self, item):
|
def __getitem__(self, item):
|
||||||
return self.categories[item]
|
return self.categories[item]
|
||||||
|
|
||||||
|
def is_valid(self, value: Any):
|
||||||
|
return value in self.categories
|
||||||
|
|
||||||
|
@property
|
||||||
|
def domain_str(self):
|
||||||
|
return f"{self.categories}"
|
||||||
|
|
||||||
|
|
||||||
class Function(Domain):
|
class Function(Domain):
|
||||||
class _CallSampler(BaseSampler):
|
class _CallSampler(BaseSampler):
|
||||||
|
@ -295,6 +324,13 @@ class Function(Domain):
|
||||||
def is_function(self):
|
def is_function(self):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def is_valid(self, value: Any):
|
||||||
|
return True # This is user-defined, so lets not assume anything
|
||||||
|
|
||||||
|
@property
|
||||||
|
def domain_str(self):
|
||||||
|
return f"{self.func}()"
|
||||||
|
|
||||||
|
|
||||||
class Quantized(Sampler):
|
class Quantized(Sampler):
|
||||||
def __init__(self, sampler: Sampler, q: Union[float, int]):
|
def __init__(self, sampler: Sampler, q: Union[float, int]):
|
||||||
|
|
|
@ -1,44 +1,95 @@
|
||||||
|
import copy
|
||||||
import itertools
|
import itertools
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Dict, List, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
from ray.tune.error import TuneError
|
from ray.tune.error import TuneError
|
||||||
from ray.tune.experiment import Experiment, convert_to_experiment_list
|
from ray.tune.experiment import Experiment, convert_to_experiment_list
|
||||||
from ray.tune.config_parser import make_parser, create_trial_from_spec
|
from ray.tune.config_parser import make_parser, create_trial_from_spec
|
||||||
from ray.tune.suggest.variant_generator import (
|
from ray.tune.suggest.variant_generator import (
|
||||||
count_variants, generate_variants, format_vars, flatten_resolved_vars)
|
count_variants, generate_variants, format_vars, flatten_resolved_vars,
|
||||||
|
get_preset_variants)
|
||||||
from ray.tune.suggest.search import SearchAlgorithm
|
from ray.tune.suggest.search import SearchAlgorithm
|
||||||
|
|
||||||
|
|
||||||
class BasicVariantGenerator(SearchAlgorithm):
|
class BasicVariantGenerator(SearchAlgorithm):
|
||||||
"""Uses Tune's variant generation for resolving variables.
|
"""Uses Tune's variant generation for resolving variables.
|
||||||
|
|
||||||
See also: `ray.tune.suggest.variant_generator`.
|
This is the default search algorithm used if no other search algorithm
|
||||||
|
is specified.
|
||||||
|
|
||||||
User API:
|
|
||||||
|
Args:
|
||||||
|
points_to_evaluate (list): Initial parameter suggestions to be run
|
||||||
|
first. This is for when you already have some good parameters
|
||||||
|
you want to run first to help the algorithm make better suggestions
|
||||||
|
for future parameters. Needs to be a list of dicts containing the
|
||||||
|
configurations.
|
||||||
|
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
from ray import tune
|
from ray import tune
|
||||||
from ray.tune.suggest import BasicVariantGenerator
|
|
||||||
|
|
||||||
searcher = BasicVariantGenerator()
|
# This will automatically use the `BasicVariantGenerator`
|
||||||
tune.run(my_trainable_func, algo=searcher)
|
tune.run(
|
||||||
|
lambda config: config["a"] + config["b"],
|
||||||
|
config={
|
||||||
|
"a": tune.grid_search([1, 2]),
|
||||||
|
"b": tune.randint(0, 3)
|
||||||
|
},
|
||||||
|
num_samples=4)
|
||||||
|
|
||||||
Internal API:
|
In the example above, 8 trials will be generated: For each sample
|
||||||
|
(``4``), each of the grid search variants for ``a`` will be sampled
|
||||||
|
once. The ``b`` parameter will be sampled randomly.
|
||||||
|
|
||||||
|
The generator accepts a pre-set list of points that should be evaluated.
|
||||||
|
The points will replace the first samples of each experiment passed to
|
||||||
|
the ``BasicVariantGenerator``.
|
||||||
|
|
||||||
|
Each point will replace one sample of the specified ``num_samples``. If
|
||||||
|
grid search variables are overwritten with the values specified in the
|
||||||
|
presets, the number of samples will thus be reduced.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
from ray.tune.suggest import BasicVariantGenerator
|
from ray import tune
|
||||||
|
from ray.tune.suggest.basic_variant import BasicVariantGenerator
|
||||||
|
|
||||||
|
|
||||||
|
tune.run(
|
||||||
|
lambda config: config["a"] + config["b"],
|
||||||
|
config={
|
||||||
|
"a": tune.grid_search([1, 2]),
|
||||||
|
"b": tune.randint(0, 3)
|
||||||
|
},
|
||||||
|
search_alg=BasicVariantGenerator(points_to_evaluate=[
|
||||||
|
{"a": 2, "b": 2},
|
||||||
|
{"a": 1},
|
||||||
|
{"b": 2}
|
||||||
|
]),
|
||||||
|
num_samples=4)
|
||||||
|
|
||||||
|
The example above will produce six trials via four samples:
|
||||||
|
|
||||||
|
- The first sample will produce one trial with ``a=2`` and ``b=2``.
|
||||||
|
- The second sample will produce one trial with ``a=1`` and ``b`` sampled
|
||||||
|
randomly
|
||||||
|
- The third sample will produce two trials, one for each grid search
|
||||||
|
value of ``a``. It will be ``b=2`` for both of these trials.
|
||||||
|
- The fourth sample will produce two trials, one for each grid search
|
||||||
|
value of ``a``. ``b`` will be sampled randomly and independently for
|
||||||
|
both of these trials.
|
||||||
|
|
||||||
searcher = BasicVariantGenerator()
|
|
||||||
searcher.add_configurations({"experiment": { ... }})
|
|
||||||
trial = searcher.next_trial()
|
|
||||||
searcher.is_finished == True
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, points_to_evaluate: Optional[List[Dict]] = None):
|
||||||
"""Initializes the Variant Generator.
|
"""Initializes the Variant Generator.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
@ -48,6 +99,8 @@ class BasicVariantGenerator(SearchAlgorithm):
|
||||||
self._counter = 0
|
self._counter = 0
|
||||||
self._finished = False
|
self._finished = False
|
||||||
|
|
||||||
|
self._points_to_evaluate = points_to_evaluate or []
|
||||||
|
|
||||||
# Unique prefix for all trials generated, e.g., trial ids start as
|
# Unique prefix for all trials generated, e.g., trial ids start as
|
||||||
# 2f1e_00001, 2f1ef_00002, 2f1ef_0003, etc. Overridable for testing.
|
# 2f1e_00001, 2f1ef_00002, 2f1ef_0003, etc. Overridable for testing.
|
||||||
force_test_uuid = os.environ.get("_TEST_TUNE_TRIAL_UUID")
|
force_test_uuid = os.environ.get("_TEST_TUNE_TRIAL_UUID")
|
||||||
|
@ -72,12 +125,14 @@ class BasicVariantGenerator(SearchAlgorithm):
|
||||||
"""
|
"""
|
||||||
experiment_list = convert_to_experiment_list(experiments)
|
experiment_list = convert_to_experiment_list(experiments)
|
||||||
for experiment in experiment_list:
|
for experiment in experiment_list:
|
||||||
self._total_samples += count_variants(experiment.spec)
|
points_to_evaluate = copy.deepcopy(self._points_to_evaluate)
|
||||||
|
self._total_samples += count_variants(experiment.spec,
|
||||||
|
points_to_evaluate)
|
||||||
self._trial_generator = itertools.chain(
|
self._trial_generator = itertools.chain(
|
||||||
self._trial_generator,
|
self._trial_generator,
|
||||||
self._generate_trials(
|
self._generate_trials(
|
||||||
experiment.spec.get("num_samples", 1), experiment.spec,
|
experiment.spec.get("num_samples", 1), experiment.spec,
|
||||||
experiment.dir_name))
|
experiment.dir_name, points_to_evaluate))
|
||||||
|
|
||||||
def next_trial(self):
|
def next_trial(self):
|
||||||
"""Provides one Trial object to be queued into the TrialRunner.
|
"""Provides one Trial object to be queued into the TrialRunner.
|
||||||
|
@ -95,7 +150,11 @@ class BasicVariantGenerator(SearchAlgorithm):
|
||||||
self.set_finished()
|
self.set_finished()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _generate_trials(self, num_samples, unresolved_spec, output_path=""):
|
def _generate_trials(self,
|
||||||
|
num_samples,
|
||||||
|
unresolved_spec,
|
||||||
|
output_path="",
|
||||||
|
points_to_evaluate=None):
|
||||||
"""Generates Trial objects with the variant generation process.
|
"""Generates Trial objects with the variant generation process.
|
||||||
|
|
||||||
Uses a fixed point iteration to resolve variants. All trials
|
Uses a fixed point iteration to resolve variants. All trials
|
||||||
|
@ -109,6 +168,28 @@ class BasicVariantGenerator(SearchAlgorithm):
|
||||||
|
|
||||||
if "run" not in unresolved_spec:
|
if "run" not in unresolved_spec:
|
||||||
raise TuneError("Must specify `run` in {}".format(unresolved_spec))
|
raise TuneError("Must specify `run` in {}".format(unresolved_spec))
|
||||||
|
|
||||||
|
points_to_evaluate = points_to_evaluate or []
|
||||||
|
|
||||||
|
while points_to_evaluate:
|
||||||
|
config = points_to_evaluate.pop(0)
|
||||||
|
for resolved_vars, spec in get_preset_variants(
|
||||||
|
unresolved_spec, config):
|
||||||
|
trial_id = self._uuid_prefix + ("%05d" % self._counter)
|
||||||
|
experiment_tag = str(self._counter)
|
||||||
|
self._counter += 1
|
||||||
|
yield create_trial_from_spec(
|
||||||
|
spec,
|
||||||
|
output_path,
|
||||||
|
self._parser,
|
||||||
|
evaluated_params=flatten_resolved_vars(resolved_vars),
|
||||||
|
trial_id=trial_id,
|
||||||
|
experiment_tag=experiment_tag)
|
||||||
|
num_samples -= 1
|
||||||
|
|
||||||
|
if num_samples <= 0:
|
||||||
|
return
|
||||||
|
|
||||||
for _ in range(num_samples):
|
for _ in range(num_samples):
|
||||||
for resolved_vars, spec in generate_variants(unresolved_spec):
|
for resolved_vars, spec in generate_variants(unresolved_spec):
|
||||||
trial_id = self._uuid_prefix + ("%05d" % self._counter)
|
trial_id = self._uuid_prefix + ("%05d" % self._counter)
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import copy
|
import copy
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, Generator, List, Tuple
|
from collections.abc import Mapping
|
||||||
|
from typing import Any, Dict, Generator, List, Optional, Tuple
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
import random
|
import random
|
||||||
|
@ -138,13 +139,38 @@ def parse_spec_vars(spec: Dict) -> Tuple[List[Tuple[Tuple, Any]], List[Tuple[
|
||||||
return resolved_vars, domain_vars, grid_vars
|
return resolved_vars, domain_vars, grid_vars
|
||||||
|
|
||||||
|
|
||||||
def count_variants(spec: Dict) -> int:
|
def count_variants(spec: Dict, presets: Optional[List[Dict]] = None) -> int:
|
||||||
spec = copy.deepcopy(spec)
|
# Helper function: Deep update dictionary
|
||||||
_, domain_vars, grid_vars = parse_spec_vars(spec)
|
def deep_update(d, u):
|
||||||
grid_count = 1
|
for k, v in u.items():
|
||||||
for path, domain in grid_vars:
|
if isinstance(v, Mapping):
|
||||||
grid_count *= len(domain.categories)
|
d[k] = deep_update(d.get(k, {}), v)
|
||||||
return spec.get("num_samples", 1) * grid_count
|
else:
|
||||||
|
d[k] = v
|
||||||
|
return d
|
||||||
|
|
||||||
|
# Count samples for a specific spec
|
||||||
|
def spec_samples(spec, num_samples=1):
|
||||||
|
_, domain_vars, grid_vars = parse_spec_vars(spec)
|
||||||
|
grid_count = 1
|
||||||
|
for path, domain in grid_vars:
|
||||||
|
grid_count *= len(domain.categories)
|
||||||
|
return num_samples * grid_count
|
||||||
|
|
||||||
|
total_samples = 0
|
||||||
|
total_num_samples = spec.get("num_samples", 1)
|
||||||
|
# For each preset, overwrite the spec and count the samples generated
|
||||||
|
# for this preset
|
||||||
|
for preset in presets:
|
||||||
|
preset_spec = copy.deepcopy(spec)
|
||||||
|
deep_update(preset_spec["config"], preset)
|
||||||
|
total_samples += spec_samples(preset_spec, 1)
|
||||||
|
total_num_samples -= 1
|
||||||
|
|
||||||
|
# Add the remaining samples
|
||||||
|
if total_num_samples > 0:
|
||||||
|
total_samples += spec_samples(spec, total_num_samples)
|
||||||
|
return total_samples
|
||||||
|
|
||||||
|
|
||||||
def _generate_variants(spec: Dict) -> Tuple[Dict, Dict]:
|
def _generate_variants(spec: Dict) -> Tuple[Dict, Dict]:
|
||||||
|
@ -172,6 +198,45 @@ def _generate_variants(spec: Dict) -> Tuple[Dict, Dict]:
|
||||||
yield resolved_vars, spec
|
yield resolved_vars, spec
|
||||||
|
|
||||||
|
|
||||||
|
def get_preset_variants(spec: Dict, config: Dict):
|
||||||
|
"""Get variants according to a spec, initialized with a config.
|
||||||
|
|
||||||
|
Variables from the spec are overwritten by the variables in the config.
|
||||||
|
Thus, we may end up with less sampled parameters.
|
||||||
|
|
||||||
|
This function also checks if values used to overwrite search space
|
||||||
|
parameters are valid, and logs a warning if not.
|
||||||
|
"""
|
||||||
|
spec = copy.deepcopy(spec)
|
||||||
|
|
||||||
|
resolved, _, _ = parse_spec_vars(config)
|
||||||
|
|
||||||
|
for path, val in resolved:
|
||||||
|
try:
|
||||||
|
domain = _get_value(spec["config"], path)
|
||||||
|
if isinstance(domain, dict):
|
||||||
|
if "grid_search" in domain:
|
||||||
|
domain = Categorical(domain["grid_search"])
|
||||||
|
else:
|
||||||
|
# If users want to overwrite an entire subdict,
|
||||||
|
# let them do it.
|
||||||
|
domain = None
|
||||||
|
except IndexError as exc:
|
||||||
|
raise ValueError(
|
||||||
|
f"Pre-set config key `{'/'.join(path)}` does not correspond "
|
||||||
|
f"to a valid key in the search space definition. Please add "
|
||||||
|
f"this path to the `config` variable passed to `tune.run()`."
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
if domain and not domain.is_valid(val):
|
||||||
|
logger.warning(
|
||||||
|
f"Pre-set value `{val}` is not within valid values of "
|
||||||
|
f"parameter `{'/'.join(path)}`: {domain.domain_str}")
|
||||||
|
assign_value(spec["config"], path, val)
|
||||||
|
|
||||||
|
return _generate_variants(spec)
|
||||||
|
|
||||||
|
|
||||||
def assign_value(spec: Dict, path: Tuple, value: Any):
|
def assign_value(spec: Dict, path: Tuple, value: Any):
|
||||||
for k in path[:-1]:
|
for k in path[:-1]:
|
||||||
spec = spec[k]
|
spec = spec[k]
|
||||||
|
|
|
@ -2,6 +2,7 @@ import numpy as np
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from ray import tune
|
from ray import tune
|
||||||
|
from ray.tune import Experiment
|
||||||
from ray.tune.suggest.variant_generator import generate_variants
|
from ray.tune.suggest.variant_generator import generate_variants
|
||||||
|
|
||||||
|
|
||||||
|
@ -871,6 +872,102 @@ class SearchSpaceTest(unittest.TestCase):
|
||||||
return self._testPointsToEvaluate(
|
return self._testPointsToEvaluate(
|
||||||
ZOOptSearch, config, budget=10, parallel_num=8)
|
ZOOptSearch, config, budget=10, parallel_num=8)
|
||||||
|
|
||||||
|
def testPointsToEvaluateBasicVariant(self):
|
||||||
|
config = {
|
||||||
|
"metric": tune.sample.Categorical([1, 2, 3, 4]).uniform(),
|
||||||
|
"a": tune.sample.Categorical(["t1", "t2", "t3", "t4"]).uniform(),
|
||||||
|
"b": tune.sample.Integer(0, 5),
|
||||||
|
"c": tune.sample.Float(1e-4, 1e-1).loguniform()
|
||||||
|
}
|
||||||
|
|
||||||
|
from ray.tune.suggest.basic_variant import BasicVariantGenerator
|
||||||
|
return self._testPointsToEvaluate(BasicVariantGenerator, config)
|
||||||
|
|
||||||
|
def testPointsToEvaluateBasicVariantAdvanced(self):
|
||||||
|
config = {
|
||||||
|
"grid_1": tune.grid_search(["a", "b", "c", "d"]),
|
||||||
|
"grid_2": tune.grid_search(["x", "y", "z"]),
|
||||||
|
"nested": {
|
||||||
|
"random": tune.uniform(2., 10.),
|
||||||
|
"dependent": tune.sample_from(
|
||||||
|
lambda spec: -1. * spec.config.nested.random)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
points = [
|
||||||
|
{
|
||||||
|
"grid_1": "b"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"grid_2": "z"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"grid_1": "a",
|
||||||
|
"grid_2": "y"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"nested": {
|
||||||
|
"random": 8.0
|
||||||
|
}
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
from ray.tune.suggest.basic_variant import BasicVariantGenerator
|
||||||
|
|
||||||
|
# grid_1 * grid_2 are 3 * 4 = 12 variants per complete grid search
|
||||||
|
# However if one grid var is set by preset variables, that run
|
||||||
|
# is excluded from grid search.
|
||||||
|
|
||||||
|
# Point 1 overwrites grid_1, so the first trial only grid searches
|
||||||
|
# over grid_2 (3 trials).
|
||||||
|
# The remaining 5 trials search over the whole space (5 * 12 trials)
|
||||||
|
searcher = BasicVariantGenerator(points_to_evaluate=[points[0]])
|
||||||
|
exp = Experiment(
|
||||||
|
run=_mock_objective, name="test", config=config, num_samples=6)
|
||||||
|
searcher.add_configurations(exp)
|
||||||
|
self.assertEqual(searcher.total_samples, 1 * 3 + 5 * 12)
|
||||||
|
|
||||||
|
# Point 2 overwrites grid_2, so the first trial only grid searches
|
||||||
|
# over grid_1 (4 trials).
|
||||||
|
# The remaining 5 trials search over the whole space (5 * 12 trials)
|
||||||
|
searcher = BasicVariantGenerator(points_to_evaluate=[points[1]])
|
||||||
|
exp = Experiment(
|
||||||
|
run=_mock_objective, name="test", config=config, num_samples=6)
|
||||||
|
searcher.add_configurations(exp)
|
||||||
|
self.assertEqual(searcher.total_samples, 1 * 4 + 5 * 12)
|
||||||
|
|
||||||
|
# Point 3 overwrites grid_1 and grid_2, so the first trial does not
|
||||||
|
# grid search.
|
||||||
|
# The remaining 5 trials search over the whole space (5 * 12 trials)
|
||||||
|
searcher = BasicVariantGenerator(points_to_evaluate=[points[2]])
|
||||||
|
exp = Experiment(
|
||||||
|
run=_mock_objective, name="test", config=config, num_samples=6)
|
||||||
|
searcher.add_configurations(exp)
|
||||||
|
self.assertEqual(searcher.total_samples, 1 + 5 * 12)
|
||||||
|
|
||||||
|
# When initialized with all points, the first three trials are
|
||||||
|
# defined by the logic above. Only 3 trials are grid searched
|
||||||
|
# compeletely.
|
||||||
|
searcher = BasicVariantGenerator(points_to_evaluate=points)
|
||||||
|
exp = Experiment(
|
||||||
|
run=_mock_objective, name="test", config=config, num_samples=6)
|
||||||
|
searcher.add_configurations(exp)
|
||||||
|
self.assertEqual(searcher.total_samples, 1 * 3 + 1 * 4 + 1 + 3 * 12)
|
||||||
|
|
||||||
|
# Run this and confirm results
|
||||||
|
analysis = tune.run(exp, search_alg=searcher)
|
||||||
|
configs = [trial.config for trial in analysis.trials]
|
||||||
|
|
||||||
|
self.assertEqual(len(configs), searcher.total_samples)
|
||||||
|
self.assertTrue(
|
||||||
|
all(config["grid_1"] == "b" for config in configs[0:3]))
|
||||||
|
self.assertTrue(
|
||||||
|
all(config["grid_2"] == "z" for config in configs[3:7]))
|
||||||
|
self.assertTrue(configs[7]["grid_1"] == "a"
|
||||||
|
and configs[7]["grid_2"] == "y")
|
||||||
|
self.assertTrue(configs[8]["nested"]["random"] == 8.0)
|
||||||
|
self.assertTrue(configs[8]["nested"]["dependent"] == -8.0)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import pytest
|
import pytest
|
||||||
|
|
|
@ -171,7 +171,8 @@ def run(
|
||||||
samples are generated until a stopping condition is met.
|
samples are generated until a stopping condition is met.
|
||||||
local_dir (str): Local dir to save training results to.
|
local_dir (str): Local dir to save training results to.
|
||||||
Defaults to ``~/ray_results``.
|
Defaults to ``~/ray_results``.
|
||||||
search_alg (Searcher): Search algorithm for optimization.
|
search_alg (Searcher|SearchAlgorithm): Search algorithm for
|
||||||
|
optimization.
|
||||||
scheduler (TrialScheduler): Scheduler for executing
|
scheduler (TrialScheduler): Scheduler for executing
|
||||||
the experiment. Choose among FIFO (default), MedianStopping,
|
the experiment. Choose among FIFO (default), MedianStopping,
|
||||||
AsyncHyperBand, HyperBand and PopulationBasedTraining. Refer to
|
AsyncHyperBand, HyperBand and PopulationBasedTraining. Refer to
|
||||||
|
|
Loading…
Add table
Reference in a new issue