diff --git a/doc/source/tune/api_docs/search_space.rst b/doc/source/tune/api_docs/search_space.rst index 005942fe9..3c069760f 100644 --- a/doc/source/tune/api_docs/search_space.rst +++ b/doc/source/tune/api_docs/search_space.rst @@ -263,10 +263,7 @@ Grid Search API .. autofunction:: ray.tune.grid_search -Internals ---------- +References +---------- -BasicVariantGenerator -~~~~~~~~~~~~~~~~~~~~~ - -.. autoclass:: ray.tune.suggest.BasicVariantGenerator +See also :ref:`tune-basicvariant`. \ No newline at end of file diff --git a/doc/source/tune/api_docs/suggestion.rst b/doc/source/tune/api_docs/suggestion.rst index 9675f8537..05c3d466a 100644 --- a/doc/source/tune/api_docs/suggestion.rst +++ b/doc/source/tune/api_docs/suggestion.rst @@ -22,6 +22,10 @@ Summary - Summary - Website - Code Example + * - :ref:`Random search/grid search ` + - Random search/grid search + - + - :doc:`/tune/examples/tune_basic_example` * - :ref:`AxSearch ` - Bayesian/Bandit Optimization - [`Ax `__] @@ -123,6 +127,21 @@ identifier. .. note:: This is currently not implemented for: AxSearch, TuneBOHB, SigOptSearch, and DragonflySearch. +.. _tune-basicvariant: + +Random search and grid search (tune.suggest.basic_variant.BasicVariantGenerator) +-------------------------------------------------------------------------------- + +The default and most basic way to do hyperparameter search is via random and grid search. +Ray Tune does this through the :class:`BasicVariantGenerator ` +class that generates trial variants given a search space definition. + +The :class:`BasicVariantGenerator ` is used per +default if no search algorithm is passed to +:func:`tune.run() `. + +.. autoclass:: ray.tune.suggest.basic_variant.BasicVariantGenerator + .. _tune-ax: Ax (tune.suggest.ax.AxSearch) diff --git a/doc/source/tune/examples/index.rst b/doc/source/tune/examples/index.rst index 89af9deb9..54852d550 100644 --- a/doc/source/tune/examples/index.rst +++ b/doc/source/tune/examples/index.rst @@ -13,7 +13,7 @@ If any example is broken, or if you'd like to add an example to this page, feel General Examples ---------------- - +- :doc:`/tune/examples/tune_basic_example`: Simple example for doing a basic random and grid search. - :doc:`/tune/examples/async_hyperband_example`: Example of using a simple tuning function with AsyncHyperBandScheduler. - :doc:`/tune/examples/hyperband_function_example`: Example of using a Trainable function with HyperBandScheduler. Also uses the AsyncHyperBandScheduler. - :doc:`/tune/examples/pbt_function`: Example of using the function API with a PopulationBasedTraining scheduler. diff --git a/doc/source/tune/examples/tune_basic_example.rst b/doc/source/tune/examples/tune_basic_example.rst new file mode 100644 index 000000000..1be5ab3f1 --- /dev/null +++ b/doc/source/tune/examples/tune_basic_example.rst @@ -0,0 +1,6 @@ +:orphan: + +tune_basic_example +~~~~~~~~~~~~~~~~~~ + +.. literalinclude:: /../../python/ray/tune/examples/tune_basic_example.py diff --git a/python/ray/tune/BUILD b/python/ray/tune/BUILD index f10df3ec9..8b3439853 100644 --- a/python/ray/tune/BUILD +++ b/python/ray/tune/BUILD @@ -157,7 +157,7 @@ py_test( py_test( name = "test_sample", - size = "small", + size = "medium", srcs = ["tests/test_sample.py"], deps = [":tune_lib"], tags = ["exclusive"], @@ -696,6 +696,15 @@ py_test( args = ["--smoke-test"] ) +py_test( + name = "tune_basic_example", + size = "small", + srcs = ["examples/tune_basic_example.py"], + deps = [":tune_lib"], + tags = ["exclusive", "example"], + args = ["--smoke-test"] +) + # Downloads too much data. # py_test( # name = "tune_cifar10_gluon", diff --git a/python/ray/tune/examples/tune_basic_example.py b/python/ray/tune/examples/tune_basic_example.py new file mode 100644 index 000000000..30677bc0c --- /dev/null +++ b/python/ray/tune/examples/tune_basic_example.py @@ -0,0 +1,51 @@ +"""This example demonstrates basic Ray Tune random search and grid search.""" +import time + +import ray +from ray import tune + + +def evaluation_fn(step, width, height): + time.sleep(0.1) + return (0.1 + width * step / 100)**(-1) + height * 0.1 + + +def easy_objective(config): + # Hyperparameters + width, height = config["width"], config["height"] + + for step in range(config["steps"]): + # Iterative training function - can be any arbitrary training procedure + intermediate_score = evaluation_fn(step, width, height) + # Feed the score back back to Tune. + tune.report(iterations=step, mean_loss=intermediate_score) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--smoke-test", action="store_true", help="Finish quickly for testing") + args, _ = parser.parse_known_args() + ray.init(configure_logging=False) + + # This will do a grid search over the `activation` parameter. This means + # that each of the two values (`relu` and `tanh`) will be sampled once + # for each sample (`num_samples`). We end up with 2 * 50 = 100 samples. + # The `width` and `height` parameters are sampled randomly. + # `steps` is a constant parameter. + + analysis = tune.run( + easy_objective, + metric="mean_loss", + mode="min", + num_samples=5 if args.smoke_test else 50, + config={ + "steps": 5 if args.smoke_test else 100, + "width": tune.uniform(0, 20), + "height": tune.uniform(-100, 100), + "activation": tune.grid_search(["relu", "tanh"]) + }) + + print("Best hyperparameters found were: ", analysis.best_config) diff --git a/python/ray/tune/sample.py b/python/ray/tune/sample.py index a9d82331a..7190c69d2 100644 --- a/python/ray/tune/sample.py +++ b/python/ray/tune/sample.py @@ -53,6 +53,14 @@ class Domain: def is_function(self): return False + def is_valid(self, value: Any): + """Returns True if `value` is a valid value in this domain.""" + raise NotImplementedError + + @property + def domain_str(self): + return "(unknown)" + class Sampler: def sample(self, @@ -203,6 +211,13 @@ class Float(Domain): new.set_sampler(Quantized(new.get_sampler(), q), allow_override=True) return new + def is_valid(self, value: float): + return self.lower <= value <= self.upper + + @property + def domain_str(self): + return f"({self.lower}, {self.upper})" + class Integer(Domain): class _Uniform(Uniform): @@ -232,6 +247,13 @@ class Integer(Domain): new.set_sampler(self._Uniform()) return new + def is_valid(self, value: int): + return self.lower <= value <= self.upper + + @property + def domain_str(self): + return f"({self.lower}, {self.upper})" + class Categorical(Domain): class _Uniform(Uniform): @@ -264,6 +286,13 @@ class Categorical(Domain): def __getitem__(self, item): return self.categories[item] + def is_valid(self, value: Any): + return value in self.categories + + @property + def domain_str(self): + return f"{self.categories}" + class Function(Domain): class _CallSampler(BaseSampler): @@ -295,6 +324,13 @@ class Function(Domain): def is_function(self): return True + def is_valid(self, value: Any): + return True # This is user-defined, so lets not assume anything + + @property + def domain_str(self): + return f"{self.func}()" + class Quantized(Sampler): def __init__(self, sampler: Sampler, q: Union[float, int]): diff --git a/python/ray/tune/suggest/basic_variant.py b/python/ray/tune/suggest/basic_variant.py index 435e6dd01..46f54888b 100644 --- a/python/ray/tune/suggest/basic_variant.py +++ b/python/ray/tune/suggest/basic_variant.py @@ -1,44 +1,95 @@ +import copy import itertools import os import uuid -from typing import Dict, List, Union +from typing import Dict, List, Optional, Union from ray.tune.error import TuneError from ray.tune.experiment import Experiment, convert_to_experiment_list from ray.tune.config_parser import make_parser, create_trial_from_spec from ray.tune.suggest.variant_generator import ( - count_variants, generate_variants, format_vars, flatten_resolved_vars) + count_variants, generate_variants, format_vars, flatten_resolved_vars, + get_preset_variants) from ray.tune.suggest.search import SearchAlgorithm class BasicVariantGenerator(SearchAlgorithm): """Uses Tune's variant generation for resolving variables. - See also: `ray.tune.suggest.variant_generator`. + This is the default search algorithm used if no other search algorithm + is specified. - User API: + + Args: + points_to_evaluate (list): Initial parameter suggestions to be run + first. This is for when you already have some good parameters + you want to run first to help the algorithm make better suggestions + for future parameters. Needs to be a list of dicts containing the + configurations. + + + Example: .. code-block:: python from ray import tune - from ray.tune.suggest import BasicVariantGenerator - searcher = BasicVariantGenerator() - tune.run(my_trainable_func, algo=searcher) + # This will automatically use the `BasicVariantGenerator` + tune.run( + lambda config: config["a"] + config["b"], + config={ + "a": tune.grid_search([1, 2]), + "b": tune.randint(0, 3) + }, + num_samples=4) - Internal API: + In the example above, 8 trials will be generated: For each sample + (``4``), each of the grid search variants for ``a`` will be sampled + once. The ``b`` parameter will be sampled randomly. + + The generator accepts a pre-set list of points that should be evaluated. + The points will replace the first samples of each experiment passed to + the ``BasicVariantGenerator``. + + Each point will replace one sample of the specified ``num_samples``. If + grid search variables are overwritten with the values specified in the + presets, the number of samples will thus be reduced. + + Example: .. code-block:: python - from ray.tune.suggest import BasicVariantGenerator + from ray import tune + from ray.tune.suggest.basic_variant import BasicVariantGenerator + + + tune.run( + lambda config: config["a"] + config["b"], + config={ + "a": tune.grid_search([1, 2]), + "b": tune.randint(0, 3) + }, + search_alg=BasicVariantGenerator(points_to_evaluate=[ + {"a": 2, "b": 2}, + {"a": 1}, + {"b": 2} + ]), + num_samples=4) + + The example above will produce six trials via four samples: + + - The first sample will produce one trial with ``a=2`` and ``b=2``. + - The second sample will produce one trial with ``a=1`` and ``b`` sampled + randomly + - The third sample will produce two trials, one for each grid search + value of ``a``. It will be ``b=2`` for both of these trials. + - The fourth sample will produce two trials, one for each grid search + value of ``a``. ``b`` will be sampled randomly and independently for + both of these trials. - searcher = BasicVariantGenerator() - searcher.add_configurations({"experiment": { ... }}) - trial = searcher.next_trial() - searcher.is_finished == True """ - def __init__(self): + def __init__(self, points_to_evaluate: Optional[List[Dict]] = None): """Initializes the Variant Generator. """ @@ -48,6 +99,8 @@ class BasicVariantGenerator(SearchAlgorithm): self._counter = 0 self._finished = False + self._points_to_evaluate = points_to_evaluate or [] + # Unique prefix for all trials generated, e.g., trial ids start as # 2f1e_00001, 2f1ef_00002, 2f1ef_0003, etc. Overridable for testing. force_test_uuid = os.environ.get("_TEST_TUNE_TRIAL_UUID") @@ -72,12 +125,14 @@ class BasicVariantGenerator(SearchAlgorithm): """ experiment_list = convert_to_experiment_list(experiments) for experiment in experiment_list: - self._total_samples += count_variants(experiment.spec) + points_to_evaluate = copy.deepcopy(self._points_to_evaluate) + self._total_samples += count_variants(experiment.spec, + points_to_evaluate) self._trial_generator = itertools.chain( self._trial_generator, self._generate_trials( experiment.spec.get("num_samples", 1), experiment.spec, - experiment.dir_name)) + experiment.dir_name, points_to_evaluate)) def next_trial(self): """Provides one Trial object to be queued into the TrialRunner. @@ -95,7 +150,11 @@ class BasicVariantGenerator(SearchAlgorithm): self.set_finished() return None - def _generate_trials(self, num_samples, unresolved_spec, output_path=""): + def _generate_trials(self, + num_samples, + unresolved_spec, + output_path="", + points_to_evaluate=None): """Generates Trial objects with the variant generation process. Uses a fixed point iteration to resolve variants. All trials @@ -109,6 +168,28 @@ class BasicVariantGenerator(SearchAlgorithm): if "run" not in unresolved_spec: raise TuneError("Must specify `run` in {}".format(unresolved_spec)) + + points_to_evaluate = points_to_evaluate or [] + + while points_to_evaluate: + config = points_to_evaluate.pop(0) + for resolved_vars, spec in get_preset_variants( + unresolved_spec, config): + trial_id = self._uuid_prefix + ("%05d" % self._counter) + experiment_tag = str(self._counter) + self._counter += 1 + yield create_trial_from_spec( + spec, + output_path, + self._parser, + evaluated_params=flatten_resolved_vars(resolved_vars), + trial_id=trial_id, + experiment_tag=experiment_tag) + num_samples -= 1 + + if num_samples <= 0: + return + for _ in range(num_samples): for resolved_vars, spec in generate_variants(unresolved_spec): trial_id = self._uuid_prefix + ("%05d" % self._counter) diff --git a/python/ray/tune/suggest/variant_generator.py b/python/ray/tune/suggest/variant_generator.py index 7048cd804..849b3b012 100644 --- a/python/ray/tune/suggest/variant_generator.py +++ b/python/ray/tune/suggest/variant_generator.py @@ -1,6 +1,7 @@ import copy import logging -from typing import Any, Dict, Generator, List, Tuple +from collections.abc import Mapping +from typing import Any, Dict, Generator, List, Optional, Tuple import numpy import random @@ -138,13 +139,38 @@ def parse_spec_vars(spec: Dict) -> Tuple[List[Tuple[Tuple, Any]], List[Tuple[ return resolved_vars, domain_vars, grid_vars -def count_variants(spec: Dict) -> int: - spec = copy.deepcopy(spec) - _, domain_vars, grid_vars = parse_spec_vars(spec) - grid_count = 1 - for path, domain in grid_vars: - grid_count *= len(domain.categories) - return spec.get("num_samples", 1) * grid_count +def count_variants(spec: Dict, presets: Optional[List[Dict]] = None) -> int: + # Helper function: Deep update dictionary + def deep_update(d, u): + for k, v in u.items(): + if isinstance(v, Mapping): + d[k] = deep_update(d.get(k, {}), v) + else: + d[k] = v + return d + + # Count samples for a specific spec + def spec_samples(spec, num_samples=1): + _, domain_vars, grid_vars = parse_spec_vars(spec) + grid_count = 1 + for path, domain in grid_vars: + grid_count *= len(domain.categories) + return num_samples * grid_count + + total_samples = 0 + total_num_samples = spec.get("num_samples", 1) + # For each preset, overwrite the spec and count the samples generated + # for this preset + for preset in presets: + preset_spec = copy.deepcopy(spec) + deep_update(preset_spec["config"], preset) + total_samples += spec_samples(preset_spec, 1) + total_num_samples -= 1 + + # Add the remaining samples + if total_num_samples > 0: + total_samples += spec_samples(spec, total_num_samples) + return total_samples def _generate_variants(spec: Dict) -> Tuple[Dict, Dict]: @@ -172,6 +198,45 @@ def _generate_variants(spec: Dict) -> Tuple[Dict, Dict]: yield resolved_vars, spec +def get_preset_variants(spec: Dict, config: Dict): + """Get variants according to a spec, initialized with a config. + + Variables from the spec are overwritten by the variables in the config. + Thus, we may end up with less sampled parameters. + + This function also checks if values used to overwrite search space + parameters are valid, and logs a warning if not. + """ + spec = copy.deepcopy(spec) + + resolved, _, _ = parse_spec_vars(config) + + for path, val in resolved: + try: + domain = _get_value(spec["config"], path) + if isinstance(domain, dict): + if "grid_search" in domain: + domain = Categorical(domain["grid_search"]) + else: + # If users want to overwrite an entire subdict, + # let them do it. + domain = None + except IndexError as exc: + raise ValueError( + f"Pre-set config key `{'/'.join(path)}` does not correspond " + f"to a valid key in the search space definition. Please add " + f"this path to the `config` variable passed to `tune.run()`." + ) from exc + + if domain and not domain.is_valid(val): + logger.warning( + f"Pre-set value `{val}` is not within valid values of " + f"parameter `{'/'.join(path)}`: {domain.domain_str}") + assign_value(spec["config"], path, val) + + return _generate_variants(spec) + + def assign_value(spec: Dict, path: Tuple, value: Any): for k in path[:-1]: spec = spec[k] diff --git a/python/ray/tune/tests/test_sample.py b/python/ray/tune/tests/test_sample.py index d40900a6c..921e0c9ca 100644 --- a/python/ray/tune/tests/test_sample.py +++ b/python/ray/tune/tests/test_sample.py @@ -2,6 +2,7 @@ import numpy as np import unittest from ray import tune +from ray.tune import Experiment from ray.tune.suggest.variant_generator import generate_variants @@ -871,6 +872,102 @@ class SearchSpaceTest(unittest.TestCase): return self._testPointsToEvaluate( ZOOptSearch, config, budget=10, parallel_num=8) + def testPointsToEvaluateBasicVariant(self): + config = { + "metric": tune.sample.Categorical([1, 2, 3, 4]).uniform(), + "a": tune.sample.Categorical(["t1", "t2", "t3", "t4"]).uniform(), + "b": tune.sample.Integer(0, 5), + "c": tune.sample.Float(1e-4, 1e-1).loguniform() + } + + from ray.tune.suggest.basic_variant import BasicVariantGenerator + return self._testPointsToEvaluate(BasicVariantGenerator, config) + + def testPointsToEvaluateBasicVariantAdvanced(self): + config = { + "grid_1": tune.grid_search(["a", "b", "c", "d"]), + "grid_2": tune.grid_search(["x", "y", "z"]), + "nested": { + "random": tune.uniform(2., 10.), + "dependent": tune.sample_from( + lambda spec: -1. * spec.config.nested.random) + } + } + + points = [ + { + "grid_1": "b" + }, + { + "grid_2": "z" + }, + { + "grid_1": "a", + "grid_2": "y" + }, + { + "nested": { + "random": 8.0 + } + }, + ] + + from ray.tune.suggest.basic_variant import BasicVariantGenerator + + # grid_1 * grid_2 are 3 * 4 = 12 variants per complete grid search + # However if one grid var is set by preset variables, that run + # is excluded from grid search. + + # Point 1 overwrites grid_1, so the first trial only grid searches + # over grid_2 (3 trials). + # The remaining 5 trials search over the whole space (5 * 12 trials) + searcher = BasicVariantGenerator(points_to_evaluate=[points[0]]) + exp = Experiment( + run=_mock_objective, name="test", config=config, num_samples=6) + searcher.add_configurations(exp) + self.assertEqual(searcher.total_samples, 1 * 3 + 5 * 12) + + # Point 2 overwrites grid_2, so the first trial only grid searches + # over grid_1 (4 trials). + # The remaining 5 trials search over the whole space (5 * 12 trials) + searcher = BasicVariantGenerator(points_to_evaluate=[points[1]]) + exp = Experiment( + run=_mock_objective, name="test", config=config, num_samples=6) + searcher.add_configurations(exp) + self.assertEqual(searcher.total_samples, 1 * 4 + 5 * 12) + + # Point 3 overwrites grid_1 and grid_2, so the first trial does not + # grid search. + # The remaining 5 trials search over the whole space (5 * 12 trials) + searcher = BasicVariantGenerator(points_to_evaluate=[points[2]]) + exp = Experiment( + run=_mock_objective, name="test", config=config, num_samples=6) + searcher.add_configurations(exp) + self.assertEqual(searcher.total_samples, 1 + 5 * 12) + + # When initialized with all points, the first three trials are + # defined by the logic above. Only 3 trials are grid searched + # compeletely. + searcher = BasicVariantGenerator(points_to_evaluate=points) + exp = Experiment( + run=_mock_objective, name="test", config=config, num_samples=6) + searcher.add_configurations(exp) + self.assertEqual(searcher.total_samples, 1 * 3 + 1 * 4 + 1 + 3 * 12) + + # Run this and confirm results + analysis = tune.run(exp, search_alg=searcher) + configs = [trial.config for trial in analysis.trials] + + self.assertEqual(len(configs), searcher.total_samples) + self.assertTrue( + all(config["grid_1"] == "b" for config in configs[0:3])) + self.assertTrue( + all(config["grid_2"] == "z" for config in configs[3:7])) + self.assertTrue(configs[7]["grid_1"] == "a" + and configs[7]["grid_2"] == "y") + self.assertTrue(configs[8]["nested"]["random"] == 8.0) + self.assertTrue(configs[8]["nested"]["dependent"] == -8.0) + if __name__ == "__main__": import pytest diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py index fe26e12e5..ab3df8ba8 100644 --- a/python/ray/tune/tune.py +++ b/python/ray/tune/tune.py @@ -171,7 +171,8 @@ def run( samples are generated until a stopping condition is met. local_dir (str): Local dir to save training results to. Defaults to ``~/ray_results``. - search_alg (Searcher): Search algorithm for optimization. + search_alg (Searcher|SearchAlgorithm): Search algorithm for + optimization. scheduler (TrialScheduler): Scheduler for executing the experiment. Choose among FIFO (default), MedianStopping, AsyncHyperBand, HyperBand and PopulationBasedTraining. Refer to