diff --git a/doc/source/tune.rst b/doc/source/tune.rst
index b0cdceb86..72f8a1c72 100644
--- a/doc/source/tune.rst
+++ b/doc/source/tune.rst
@@ -62,7 +62,7 @@ Features
 
 Ray Tune has the following features:
 
-- Scalable implementations of search algorithms such as `Population Based Training (PBT) `__, `Median Stopping Rule `__, and `HyperBand `__.
+- Scalable implementations of search algorithms such as `Population Based Training (PBT) `__, `Median Stopping Rule `__, Model-Based Optimization (HyperOpt), and `HyperBand `__.
 
 - Integration with visualization tools such as `TensorBoard `__, `rllab's VisKit `__, and a `parallel coordinates visualization `__.
 
@@ -94,12 +94,28 @@ You can find the code for Ray Tune `here on GitHub `__.
 
-By default, Ray Tune schedules trials in serial order with the ``FIFOScheduler`` class. However, you can also specify a custom scheduling algorithm that can early stop trials, perturb parameters, or incorporate suggestions from an external service. Currently implemented trial schedulers include `Population Based Training (PBT) `__, `Median Stopping Rule `__, and `HyperBand `__.
+By default, Ray Tune schedules trials in serial order with the ``FIFOScheduler`` class. However, you can also specify a custom scheduling algorithm that can early stop trials, perturb parameters, or incorporate suggestions from an external service. Currently implemented trial schedulers include `Population Based Training (PBT) `__, `Median Stopping Rule `__, Model-Based Optimization (HyperOpt), and `HyperBand `__.
 
 .. code-block:: python
 
     run_experiments({...}, scheduler=AsyncHyperBandScheduler())
+
+HyperOpt Integration
+--------------------
+
+The ``HyperOptScheduler`` is a Trial Scheduler backed by HyperOpt that performs sequential model-based hyperparameter optimization.
+In order to use this scheduler, you will need to install HyperOpt via the following command:
+
+.. code-block:: bash
+
+    $ pip install --upgrade git+git://github.com/hyperopt/hyperopt.git
+
+An example of this can be found in `hyperopt_example.py `__.
+
+.. autoclass:: ray.tune.hpo_scheduler.HyperOptScheduler
+
+
 Visualizing Results
 -------------------
diff --git a/docker/examples/Dockerfile b/docker/examples/Dockerfile
index c9a964d8a..72a4314da 100644
--- a/docker/examples/Dockerfile
+++ b/docker/examples/Dockerfile
@@ -4,4 +4,5 @@ FROM ray-project/deploy
 RUN conda install -y -c conda-forge tensorflow
 RUN apt-get install -y zlib1g-dev
 RUN pip install gym[atari] opencv-python==3.2.0.8 smart_open
+RUN pip install --upgrade git+git://github.com/hyperopt/hyperopt.git
 # RUN conda install -y -q pytorch torchvision -c soumith
diff --git a/python/ray/tune/config_parser.py b/python/ray/tune/config_parser.py
index d4f851b63..8733547fa 100644
--- a/python/ray/tune/config_parser.py
+++ b/python/ray/tune/config_parser.py
@@ -95,7 +95,8 @@ def make_parser(**kwargs):
         "many times. Only applies if checkpointing is enabled.")
     parser.add_argument(
         "--scheduler", default="FIFO", type=str,
-        help="FIFO (default), MedianStopping, AsyncHyperBand, or HyperBand.")
+        help="FIFO (default), MedianStopping, AsyncHyperBand, "
+        "HyperBand, or HyperOpt.")
     parser.add_argument(
         "--scheduler-config", default="{}", type=json.loads,
         help="Config options to pass to the scheduler.")
diff --git a/python/ray/tune/examples/hyperopt_example.py b/python/ray/tune/examples/hyperopt_example.py
new file mode 100644
index 000000000..42091d2d0
--- /dev/null
+++ b/python/ray/tune/examples/hyperopt_example.py
@@ -0,0 +1,45 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import ray
+from ray.tune import run_experiments, register_trainable
+from ray.tune.hpo_scheduler import HyperOptScheduler
+
+
+def easy_objective(config, reporter):
+    import time
+    time.sleep(0.2)
+    reporter(
+        timesteps_total=1,
+        episode_reward_mean=-((config["height"]-14) ** 2 +
+                              abs(config["width"]-3)))
+    time.sleep(0.2)
+
+
+if __name__ == '__main__':
+    import argparse
+    from hyperopt import hp
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--smoke-test", action="store_true", help="Finish quickly for testing")
+    args, _ = parser.parse_known_args()
+    ray.init(redirect_output=True)
+
+    register_trainable("exp", easy_objective)
+
+    space = {
+        'width': hp.uniform('width', 0, 20),
+        'height': hp.uniform('height', -100, 100),
+    }
+
+    config = {"my_exp": {
+        "run": "exp",
+        "repeat": 5 if args.smoke_test else 1000,
+        "stop": {"training_iteration": 1},
+        "config": {
+            "space": space}}}
+    hpo_sched = HyperOptScheduler()
+
+    run_experiments(config, verbose=False, scheduler=hpo_sched)
diff --git a/python/ray/tune/experiment.py b/python/ray/tune/experiment.py
index 134e87207..bfaf72be3 100644
--- a/python/ray/tune/experiment.py
+++ b/python/ray/tune/experiment.py
@@ -2,8 +2,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from ray.tune.variant_generator import generate_trials
 from ray.tune.result import DEFAULT_RESULTS_DIR
+from ray.tune.error import TuneError
 
 
 class Experiment(object):
@@ -49,8 +49,21 @@ class Experiment(object):
             "checkpoint_freq": checkpoint_freq,
             "max_failures": max_failures
         }
-        self._trials = generate_trials(spec, name)
 
-    def trials(self):
-        for trial in self._trials:
-            yield trial
+        self.name = name
+        self.spec = spec
+
+    @classmethod
+    def from_json(cls, name, spec):
+        """Generates an Experiment object from JSON.
+
+        Args:
+            name (str): Name of Experiment.
+            spec (dict): JSON configuration of experiment.
+ """ + if "run" not in spec: + raise TuneError("No trainable specified!") + exp = cls(name, spec["run"]) + exp.name = name + exp.spec = spec + return exp diff --git a/python/ray/tune/hpo_scheduler.py b/python/ray/tune/hpo_scheduler.py new file mode 100644 index 000000000..1b61ffbd3 --- /dev/null +++ b/python/ray/tune/hpo_scheduler.py @@ -0,0 +1,201 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import copy +import numpy as np +try: + import hyperopt as hpo +except Exception as e: + hpo = None + +from ray.tune.trial import Trial +from ray.tune.error import TuneError +from ray.tune.trial_scheduler import TrialScheduler, FIFOScheduler +from ray.tune.config_parser import make_parser +from ray.tune.variant_generator import to_argv + + +class HyperOptScheduler(FIFOScheduler): + """FIFOScheduler that uses HyperOpt to provide trial suggestions. + + Requires HyperOpt to be installed via source. + Uses the Tree-structured Parzen Estimators algorithm. Externally added + trials will not be tracked by HyperOpt. Also, + variant generation will be limited, as the hyperparameter configuration + must be specified using HyperOpt primitives. + + Parameters: + max_concurrent (int | None): Number of maximum concurrent trials. + If None, then trials will be queued only if resources + are available. + reward_attr (str): The TrainingResult objective value attribute. + This refers to an increasing value, which is internally negated + when interacting with HyperOpt. Suggestion procedures + will use this attribute. + + Examples: + >>> space = {'param': hp.uniform('param', 0, 20)} + >>> config = {"my_exp": { + "run": "exp", + "repeat": 5, + "config": {"space": space}}} + >>> run_experiments(config, scheduler=HyperOptScheduler()) + """ + + def __init__(self, max_concurrent=None, reward_attr="episode_reward_mean"): + assert hpo is not None, "HyperOpt must be installed!" + assert type(max_concurrent) in [type(None), int] + if type(max_concurrent) is int: + assert max_concurrent > 0 + self._max_concurrent = max_concurrent # NOTE: this is modified later + self._reward_attr = reward_attr + self._experiment = None + + def add_experiment(self, experiment, trial_runner): + """Tracks one experiment. + + Will error if one tries to track multiple experiments. + """ + assert self._experiment is None, "HyperOpt only tracks one experiment!" + self._experiment = experiment + + self._output_path = experiment.name + spec = copy.deepcopy(experiment.spec) + + # Set Scheduler field, as Tune Parser will default to FIFO + assert spec.get("scheduler") in [None, "HyperOpt"], "Incorrectly " \ + "specified scheduler!" 
+ spec["scheduler"] = "HyperOpt" + + if "env" in spec: + spec["config"] = spec.get("config", {}) + spec["config"]["env"] = spec["env"] + del spec["env"] + + space = spec["config"]["space"] + del spec["config"]["space"] + + self.parser = make_parser() + self.args = self.parser.parse_args(to_argv(spec)) + self.args.scheduler = "HyperOpt" + self.default_config = copy.deepcopy(spec["config"]) + + self.algo = hpo.tpe.suggest + self.domain = hpo.Domain(lambda spc: spc, space) + self._hpopt_trials = hpo.Trials() + self._tune_to_hp = {} + self._num_trials_left = self.args.repeat + + if type(self._max_concurrent) is int: + self._max_concurrent = min(self._max_concurrent, self.args.repeat) + + self.rstate = np.random.RandomState() + self.trial_generator = self._trial_generator() + self._add_new_trials_if_needed(trial_runner) + + def _trial_generator(self): + while self._num_trials_left > 0: + new_cfg = copy.deepcopy(self.default_config) + new_ids = self._hpopt_trials.new_trial_ids(1) + self._hpopt_trials.refresh() + + # Get new suggestion from + new_trials = self.algo( + new_ids, self.domain, self._hpopt_trials, + self.rstate.randint(2 ** 31 - 1)) + self._hpopt_trials.insert_trial_docs(new_trials) + self._hpopt_trials.refresh() + new_trial = new_trials[0] + new_trial_id = new_trial["tid"] + suggested_config = hpo.base.spec_from_misc(new_trial["misc"]) + new_cfg.update(suggested_config) + + kv_str = "_".join(["{}={}".format(k, str(v)[:5]) + for k, v in sorted(suggested_config.items())]) + experiment_tag = "{}_{}".format(new_trial_id, kv_str) + + # Keep this consistent with tune.variant_generator + trial = Trial( + trainable_name=self.args.run, + config=new_cfg, + local_dir=os.path.join(self.args.local_dir, self._output_path), + experiment_tag=experiment_tag, + resources=self.args.trial_resources, + stopping_criterion=self.args.stop, + checkpoint_freq=self.args.checkpoint_freq, + restore_path=self.args.restore, + upload_dir=self.args.upload_dir, + max_failures=self.args.max_failures) + + self._tune_to_hp[trial] = new_trial_id + self._num_trials_left -= 1 + yield trial + + def on_trial_result(self, trial_runner, trial, result): + ho_trial = self._get_hyperopt_trial(self._tune_to_hp[trial]) + now = hpo.utils.coarse_utcnow() + ho_trial['book_time'] = now + ho_trial['refresh_time'] = now + return TrialScheduler.CONTINUE + + def on_trial_error(self, trial_runner, trial): + ho_trial = self._get_hyperopt_trial(self._tune_to_hp[trial]) + ho_trial['refresh_time'] = hpo.utils.coarse_utcnow() + ho_trial['state'] = hpo.base.JOB_STATE_ERROR + ho_trial['misc']['error'] = (str(TuneError), "Tune Error") + self._hpopt_trials.refresh() + del self._tune_to_hp[trial] + + def on_trial_remove(self, trial_runner, trial): + ho_trial = self._get_hyperopt_trial(self._tune_to_hp[trial]) + ho_trial['refresh_time'] = hpo.utils.coarse_utcnow() + ho_trial['state'] = hpo.base.JOB_STATE_ERROR + ho_trial['misc']['error'] = (str(TuneError), "Tune Removed") + self._hpopt_trials.refresh() + del self._tune_to_hp[trial] + + def on_trial_complete(self, trial_runner, trial, result): + ho_trial = self._get_hyperopt_trial(self._tune_to_hp[trial]) + ho_trial['refresh_time'] = hpo.utils.coarse_utcnow() + ho_trial['state'] = hpo.base.JOB_STATE_DONE + hp_result = self._to_hyperopt_result(result) + ho_trial['result'] = hp_result + self._hpopt_trials.refresh() + del self._tune_to_hp[trial] + + def _to_hyperopt_result(self, result): + return {"loss": -getattr(result, self._reward_attr), + "status": "ok"} + + def _get_hyperopt_trial(self, tid): + 
return [t for t in self._hpopt_trials.trials if t["tid"] == tid][0] + + def choose_trial_to_run(self, trial_runner): + self._add_new_trials_if_needed(trial_runner) + return FIFOScheduler.choose_trial_to_run(self, trial_runner) + + def _add_new_trials_if_needed(self, trial_runner): + """Checks if there is a next trial ready to be queued. + + This is determined by tracking the number of concurrent + experiments and trials left to run. If self._max_concurrent is None, + scheduler will add new trial if there is none that are pending. + """ + pending = [t for t in trial_runner.get_trials() + if t.status == Trial.PENDING] + if self._num_trials_left <= 0: + return + if self._max_concurrent is None: + if not pending: + trial_runner.add_trial(next(self.trial_generator)) + else: + while self._num_live_trials() < self._max_concurrent: + try: + trial_runner.add_trial(next(self.trial_generator)) + except StopIteration: + break + + def _num_live_trials(self): + return len(self._tune_to_hp) diff --git a/python/ray/tune/test/trial_runner_test.py b/python/ray/tune/test/trial_runner_test.py index bea3d98bd..a45be3104 100644 --- a/python/ray/tune/test/trial_runner_test.py +++ b/python/ray/tune/test/trial_runner_test.py @@ -100,6 +100,7 @@ class TrainableFunctionApiTest(unittest.TestCase): def testBadParams2(self): def f(): run_experiments({"foo": { + "run": "asdf", "bah": "this param is not allowed", }}) self.assertRaises(TuneError, f) diff --git a/python/ray/tune/test/trial_scheduler_test.py b/python/ray/tune/test/trial_scheduler_test.py index 5dc6d6af7..fbc461bee 100644 --- a/python/ray/tune/test/trial_scheduler_test.py +++ b/python/ray/tune/test/trial_scheduler_test.py @@ -1,4 +1,3 @@ - from __future__ import absolute_import from __future__ import division from __future__ import print_function diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index 908a23d72..63085112e 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -335,8 +335,8 @@ class Trial(object): def update_last_result(self, result, terminate=False): if terminate: result = result._replace(done=True) - if terminate or ( - self.verbose and + if self.verbose and ( + terminate or time.time() - self.last_debug > DEBUG_PRINT_INTERVAL): print("TrainingResult for {}:".format(self)) print(" {}".format(pretty_print(result).replace("\n", "\n "))) diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index 5b4c77789..48b0f6c2e 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -40,13 +40,16 @@ class TrialRunner(object): """ def __init__(self, scheduler=None, launch_web_server=False, - server_port=TuneServer.DEFAULT_PORT): + server_port=TuneServer.DEFAULT_PORT, verbose=True): """Initializes a new TrialRunner. Args: scheduler (TrialScheduler): Defaults to FIFOScheduler. launch_web_server (bool): Flag for starting TuneServer - server_port (int): Port number for launching TuneServer""" + server_port (int): Port number for launching TuneServer + verbose (bool): Flag for verbosity. If False, trial results + will not be output. + """ self._scheduler_alg = scheduler or FIFOScheduler() self._trials = [] @@ -64,6 +67,7 @@ class TrialRunner(object): if launch_web_server: self._server = TuneServer(self, server_port) self._stop_queue = [] + self._verbose = verbose def is_finished(self): """Returns whether all trials have finished running.""" @@ -85,8 +89,9 @@ class TrialRunner(object): Callers should typically run this method repeatedly in a loop. 
They may inspect or modify the runner's state in between calls to step(). """ - if self._can_launch_more(): - self._launch_trial() + next_trial = self._get_next_trial() + if next_trial is not None: + self._launch_trial(next_trial) elif self._running: self._process_events() else: @@ -127,7 +132,11 @@ class TrialRunner(object): """Adds a new trial to this TrialRunner. Trials may be added at any time. + + Args: + trial (Trial): Trial to queue. """ + trial.set_verbose(self._verbose) self._scheduler_alg.on_trial_add(self, trial) self._trials.append(trial) @@ -185,13 +194,12 @@ class TrialRunner(object): resources.cpu_total() <= cpu_avail and resources.gpu_total() <= gpu_avail) - def _can_launch_more(self): + def _get_next_trial(self): self._update_avail_resources() - trial = self._get_runnable() - return trial is not None + trial = self._scheduler_alg.choose_trial_to_run(self) + return trial - def _launch_trial(self, custom_trial=None): - trial = custom_trial or self._get_runnable() + def _launch_trial(self, trial): self._commit_resources(trial.resources) try: trial.start() @@ -219,6 +227,7 @@ class TrialRunner(object): self._total_time += result.time_this_iter_s if trial.should_stop(result): + # Hook into scheduler self._scheduler_alg.on_trial_complete(self, trial, result) decision = TrialScheduler.STOP else: @@ -262,9 +271,6 @@ class TrialRunner(object): print("Error recovering trial from checkpoint, abort:", error_msg) self._stop_trial(trial, error=True, error_msg=error_msg) - def _get_runnable(self): - return self._scheduler_alg.choose_trial_to_run(self) - def _commit_resources(self, resources): self._committed_resources = Resources( self._committed_resources.cpu + resources.cpu_total(), diff --git a/python/ray/tune/trial_scheduler.py b/python/ray/tune/trial_scheduler.py index 5aa5238fc..cbb8bb564 100644 --- a/python/ray/tune/trial_scheduler.py +++ b/python/ray/tune/trial_scheduler.py @@ -3,6 +3,7 @@ from __future__ import division from __future__ import print_function from ray.tune.trial import Trial +from ray.tune.variant_generator import generate_trials class TrialScheduler(object): @@ -47,11 +48,28 @@ class TrialScheduler(object): raise NotImplementedError - def choose_trial_to_run(self, trial_runner, trials): + def add_experiment(self, experiment, trial_runner): + """Adds an experiment to the scheduler. + + The scheduler is responsible for adding the trials of the experiment + to the runner, which can be done immediately (if there are a finite + set of trials), or over time (if there is an infinite stream of trials + or if the scheduler is iterative in nature). + """ + generator = generate_trials(experiment.spec, experiment.name) + while True: + try: + trial_runner.add_trial(next(generator)) + except StopIteration: + break + + def choose_trial_to_run(self, trial_runner): """Called to choose a new trial to run. This should return one of the trials in trial_runner that is in - the PENDING or PAUSED state.""" + the PENDING or PAUSED state. This function must be idempotent. 
+
+        If no trial is ready, return None."""
 
         raise NotImplementedError
diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py
index 61e8e0722..cc71188f0 100644
--- a/python/ray/tune/tune.py
+++ b/python/ray/tune/tune.py
@@ -4,16 +4,16 @@ from __future__ import print_function
 
 import time
 
-from ray.tune import TuneError
+from ray.tune.error import TuneError
 from ray.tune.hyperband import HyperBandScheduler
 from ray.tune.async_hyperband import AsyncHyperBandScheduler
 from ray.tune.median_stopping_rule import MedianStoppingRule
+from ray.tune.hpo_scheduler import HyperOptScheduler
 from ray.tune.trial import Trial, DEBUG_PRINT_INTERVAL
 from ray.tune.log_sync import wait_for_log_sync
 from ray.tune.trial_runner import TrialRunner
 from ray.tune.trial_scheduler import FIFOScheduler
 from ray.tune.web_server import TuneServer
-from ray.tune.variant_generator import generate_trials
 from ray.tune.experiment import Experiment
 
 
@@ -22,6 +22,7 @@ _SCHEDULERS = {
     "MedianStopping": MedianStoppingRule,
     "HyperBand": HyperBandScheduler,
     "AsyncHyperBand": AsyncHyperBandScheduler,
+    "HyperOpt": HyperOptScheduler,
 }
 
 
@@ -42,7 +43,7 @@ def run_experiments(experiments, scheduler=None, with_server=False,
         experiments (Experiment | list | dict): Experiments to run.
         scheduler (TrialScheduler): Scheduler for executing
             the experiment. Choose among FIFO (default), MedianStopping,
-            AsyncHyperBand, or HyperBand.
+            AsyncHyperBand, HyperBand, or HyperOpt.
         with_server (bool): Starts a background Tune server. Needed for
             using the Client API.
         server_port (int): Port number for launching TuneServer.
@@ -53,23 +54,21 @@
         scheduler = FIFOScheduler()
 
     runner = TrialRunner(
-        scheduler, launch_web_server=with_server, server_port=server_port)
+        scheduler, launch_web_server=with_server, server_port=server_port,
+        verbose=verbose)
+    exp_list = experiments
+    if isinstance(experiments, Experiment):
+        exp_list = [experiments]
+    elif type(experiments) is dict:
+        exp_list = [Experiment.from_json(name, spec)
+                    for name, spec in experiments.items()]
 
-    if type(experiments) is dict:
-        for name, spec in experiments.items():
-            for trial in generate_trials(spec, name):
-                trial.set_verbose(verbose)
-                runner.add_trial(trial)
-    elif (type(experiments) is list and
-          all(isinstance(exp, Experiment) for exp in experiments)):
-        for experiment in experiments:
-            for trial in experiment.trials():
-                trial.set_verbose(verbose)
-                runner.add_trial(trial)
-    elif isinstance(experiments, Experiment):
-        for trial in experiments.trials():
-            trial.set_verbose(verbose)
-            runner.add_trial(trial)
+    if (type(exp_list) is list and
+            all(isinstance(exp, Experiment) for exp in exp_list)):
+        for experiment in exp_list:
+            scheduler.add_experiment(experiment, runner)
+    else:
+        raise TuneError("Invalid argument: {}".format(experiments))
 
     print(runner.debug_string(max_debug=99999))
diff --git a/test/jenkins_tests/run_multi_node_tests.sh b/test/jenkins_tests/run_multi_node_tests.sh
index c309915a0..70ee6cf55 100755
--- a/test/jenkins_tests/run_multi_node_tests.sh
+++ b/test/jenkins_tests/run_multi_node_tests.sh
@@ -214,6 +214,10 @@ docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
     python /ray/python/ray/tune/examples/tune_mnist_async_hyperband.py \
     --smoke-test
 
+docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
+    python /ray/python/ray/tune/examples/hyperopt_example.py \
+    --smoke-test
+
 docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
     python /ray/python/ray/rllib/examples/multiagent_mountaincar.py
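To try this change outside of the Docker-based Jenkins run, the new example can be exercised with the same smoke-test flag. This is only a sketch of the intended workflow; it assumes the working directory is the repository root and uses the HyperOpt install command from the docs change above:

    # Install HyperOpt from source, as required by the new HyperOptScheduler.
    pip install --upgrade git+git://github.com/hyperopt/hyperopt.git

    # Run the new Tune example in smoke-test mode (finishes quickly).
    python python/ray/tune/examples/hyperopt_example.py --smoke-test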