mirror of
https://github.com/vale981/ray
synced 2025-03-05 18:11:42 -05:00
[tune] HyperOpt Support (v2) (#1763)
parent 5a9e83761d
commit 888e70f1be
13 changed files with 348 additions and 44 deletions
@@ -62,7 +62,7 @@ Features

 Ray Tune has the following features:

-- Scalable implementations of search algorithms such as `Population Based Training (PBT) <pbt.html>`__, `Median Stopping Rule <hyperband.html#median-stopping-rule>`__, and `HyperBand <hyperband.html>`__.
+- Scalable implementations of search algorithms such as `Population Based Training (PBT) <pbt.html>`__, `Median Stopping Rule <hyperband.html#median-stopping-rule>`__, Model-Based Optimization (HyperOpt), and `HyperBand <hyperband.html>`__.

 - Integration with visualization tools such as `TensorBoard <https://www.tensorflow.org/get_started/summaries_and_tensorboard>`__, `rllab's VisKit <https://media.readthedocs.org/pdf/rllab/latest/rllab.pdf>`__, and a `parallel coordinates visualization <https://en.wikipedia.org/wiki/Parallel_coordinates>`__.

@@ -94,12 +94,28 @@ You can find the code for Ray Tune `here on GitHub <https://github.com/ray-proje

 Trial Schedulers
 ----------------

-By default, Ray Tune schedules trials in serial order with the ``FIFOScheduler`` class. However, you can also specify a custom scheduling algorithm that can early stop trials, perturb parameters, or incorporate suggestions from an external service. Currently implemented trial schedulers include `Population Based Training (PBT) <pbt.html>`__, `Median Stopping Rule <hyperband.html#median-stopping-rule>`__, and `HyperBand <hyperband.html>`__.
+By default, Ray Tune schedules trials in serial order with the ``FIFOScheduler`` class. However, you can also specify a custom scheduling algorithm that can early stop trials, perturb parameters, or incorporate suggestions from an external service. Currently implemented trial schedulers include `Population Based Training (PBT) <pbt.html>`__, `Median Stopping Rule <hyperband.html#median-stopping-rule>`__, Model-Based Optimization (HyperOpt), and `HyperBand <hyperband.html>`__.

 .. code-block:: python

     run_experiments({...}, scheduler=AsyncHyperBandScheduler())

+HyperOpt Integration
+--------------------
+
+The ``HyperOptScheduler`` is a Trial Scheduler backed by HyperOpt that performs sequential model-based hyperparameter optimization.
+In order to use this scheduler, you will need to install HyperOpt via the following command:
+
+.. code-block:: bash
+
+    $ pip install --upgrade git+git://github.com/hyperopt/hyperopt.git
+
+An example of this can be found in `hyperopt_example.py <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/hyperopt_example.py>`__.
+
+.. autoclass:: ray.tune.hpo_scheduler.HyperOptScheduler
+

 Visualizing Results
 -------------------

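For orientation, here is a minimal end-to-end sketch of the workflow the new docs describe (the objective function, experiment name, and search space below are illustrative placeholders; the scheduler arguments follow the ``HyperOptScheduler`` signature introduced in this commit):

.. code-block:: python

    import ray
    from hyperopt import hp
    from ray.tune import register_trainable, run_experiments
    from ray.tune.hpo_scheduler import HyperOptScheduler

    def my_objective(config, reporter):
        # Report a reward for HyperOpt to maximize; the scheduler negates it
        # into a loss internally.
        reporter(timesteps_total=1,
                 episode_reward_mean=-(config["x"] - 3) ** 2)

    ray.init()
    register_trainable("my_objective", my_objective)

    run_experiments(
        {"my_hpo_exp": {
            "run": "my_objective",
            "repeat": 100,  # number of HyperOpt suggestions to evaluate
            "stop": {"training_iteration": 1},
            "config": {"space": {"x": hp.uniform("x", 0, 10)}},
        }},
        scheduler=HyperOptScheduler(max_concurrent=4,
                                    reward_attr="episode_reward_mean"))
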
@@ -4,4 +4,5 @@ FROM ray-project/deploy
 RUN conda install -y -c conda-forge tensorflow
 RUN apt-get install -y zlib1g-dev
 RUN pip install gym[atari] opencv-python==3.2.0.8 smart_open
+RUN pip install --upgrade git+git://github.com/hyperopt/hyperopt.git
 # RUN conda install -y -q pytorch torchvision -c soumith

@@ -95,7 +95,8 @@ def make_parser(**kwargs):
              "many times. Only applies if checkpointing is enabled.")
     parser.add_argument(
         "--scheduler", default="FIFO", type=str,
-        help="FIFO (default), MedianStopping, AsyncHyperBand, or HyperBand.")
+        help="FIFO (default), MedianStopping, AsyncHyperBand, "
+             "HyperBand, or HyperOpt.")
     parser.add_argument(
         "--scheduler-config", default="{}", type=json.loads,
         help="Config options to pass to the scheduler.")

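For context, a sketch of how these two flags are typically combined downstream: ``--scheduler`` selects a scheduler class by name and ``--scheduler-config`` is parsed by ``json.loads`` and forwarded as keyword arguments. The mapping and literal values below are illustrative and abbreviated, not the exact helper used in the codebase:

.. code-block:: python

    import json

    from ray.tune.hpo_scheduler import HyperOptScheduler

    schedulers = {"HyperOpt": HyperOptScheduler}  # abbreviated name -> class map
    scheduler_name = "HyperOpt"                   # value of --scheduler
    scheduler_config = json.loads('{"max_concurrent": 4}')  # --scheduler-config

    scheduler = schedulers[scheduler_name](**scheduler_config)
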
python/ray/tune/examples/hyperopt_example.py (new file, 45 lines)
@@ -0,0 +1,45 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import ray
from ray.tune import run_experiments, register_trainable
from ray.tune.hpo_scheduler import HyperOptScheduler


def easy_objective(config, reporter):
    import time
    time.sleep(0.2)
    reporter(
        timesteps_total=1,
        episode_reward_mean=-((config["height"] - 14) ** 2
                              + abs(config["width"] - 3)))
    time.sleep(0.2)


if __name__ == '__main__':
    import argparse
    from hyperopt import hp

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing")
    args, _ = parser.parse_known_args()
    ray.init(redirect_output=True)

    register_trainable("exp", easy_objective)

    space = {
        'width': hp.uniform('width', 0, 20),
        'height': hp.uniform('height', -100, 100),
    }

    config = {"my_exp": {
        "run": "exp",
        "repeat": 5 if args.smoke_test else 1000,
        "stop": {"training_iteration": 1},
        "config": {
            "space": space}}}
    hpo_sched = HyperOptScheduler()

    run_experiments(config, verbose=False, scheduler=hpo_sched)

@@ -2,8 +2,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-from ray.tune.variant_generator import generate_trials
 from ray.tune.result import DEFAULT_RESULTS_DIR
+from ray.tune.error import TuneError


 class Experiment(object):

@@ -49,8 +49,21 @@ class Experiment(object):
             "checkpoint_freq": checkpoint_freq,
             "max_failures": max_failures
         }
-        self._trials = generate_trials(spec, name)
-
-    def trials(self):
-        for trial in self._trials:
-            yield trial
+        self.name = name
+        self.spec = spec
+
+    @classmethod
+    def from_json(cls, name, spec):
+        """Generates an Experiment object from JSON.
+
+        Args:
+            name (str): Name of Experiment.
+            spec (dict): JSON configuration of experiment.
+        """
+        if "run" not in spec:
+            raise TuneError("No trainable specified!")
+        exp = cls(name, spec["run"])
+        exp.name = name
+        exp.spec = spec
+        return exp

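A short sketch of the new classmethod in use (the experiment name and spec fields are illustrative; ``from_json`` only requires that the spec contain a ``run`` entry):

.. code-block:: python

    from ray.tune.experiment import Experiment

    spec = {
        "run": "my_trainable",   # name of a registered trainable (illustrative)
        "repeat": 10,
        "stop": {"training_iteration": 1},
        "config": {"lr": 0.01},
    }
    exp = Experiment.from_json("my_exp", spec)
    assert exp.name == "my_exp"
    assert exp.spec == spec
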
python/ray/tune/hpo_scheduler.py (new file, 201 lines)
@@ -0,0 +1,201 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import copy
import numpy as np
try:
    import hyperopt as hpo
except Exception as e:
    hpo = None

from ray.tune.trial import Trial
from ray.tune.error import TuneError
from ray.tune.trial_scheduler import TrialScheduler, FIFOScheduler
from ray.tune.config_parser import make_parser
from ray.tune.variant_generator import to_argv


class HyperOptScheduler(FIFOScheduler):
    """FIFOScheduler that uses HyperOpt to provide trial suggestions.

    Requires HyperOpt to be installed from source.
    Uses the Tree-structured Parzen Estimators algorithm. Externally added
    trials will not be tracked by HyperOpt. Also, variant generation will be
    limited, as the hyperparameter configuration must be specified using
    HyperOpt primitives.

    Parameters:
        max_concurrent (int | None): Number of maximum concurrent trials.
            If None, then trials will be queued only if resources
            are available.
        reward_attr (str): The TrainingResult objective value attribute.
            This refers to an increasing value, which is internally negated
            when interacting with HyperOpt. Suggestion procedures
            will use this attribute.

    Examples:
        >>> space = {'param': hp.uniform('param', 0, 20)}
        >>> config = {"my_exp": {
                "run": "exp",
                "repeat": 5,
                "config": {"space": space}}}
        >>> run_experiments(config, scheduler=HyperOptScheduler())
    """

    def __init__(self, max_concurrent=None, reward_attr="episode_reward_mean"):
        assert hpo is not None, "HyperOpt must be installed!"
        assert type(max_concurrent) in [type(None), int]
        if type(max_concurrent) is int:
            assert max_concurrent > 0
        self._max_concurrent = max_concurrent  # NOTE: this is modified later
        self._reward_attr = reward_attr
        self._experiment = None

    def add_experiment(self, experiment, trial_runner):
        """Tracks one experiment.

        Will error if one tries to track multiple experiments.
        """
        assert self._experiment is None, "HyperOpt only tracks one experiment!"
        self._experiment = experiment

        self._output_path = experiment.name
        spec = copy.deepcopy(experiment.spec)

        # Set Scheduler field, as Tune Parser will default to FIFO
        assert spec.get("scheduler") in [None, "HyperOpt"], "Incorrectly " \
            "specified scheduler!"
        spec["scheduler"] = "HyperOpt"

        if "env" in spec:
            spec["config"] = spec.get("config", {})
            spec["config"]["env"] = spec["env"]
            del spec["env"]

        space = spec["config"]["space"]
        del spec["config"]["space"]

        self.parser = make_parser()
        self.args = self.parser.parse_args(to_argv(spec))
        self.args.scheduler = "HyperOpt"
        self.default_config = copy.deepcopy(spec["config"])

        self.algo = hpo.tpe.suggest
        self.domain = hpo.Domain(lambda spc: spc, space)
        self._hpopt_trials = hpo.Trials()
        self._tune_to_hp = {}
        self._num_trials_left = self.args.repeat

        if type(self._max_concurrent) is int:
            self._max_concurrent = min(self._max_concurrent, self.args.repeat)

        self.rstate = np.random.RandomState()
        self.trial_generator = self._trial_generator()
        self._add_new_trials_if_needed(trial_runner)

    def _trial_generator(self):
        while self._num_trials_left > 0:
            new_cfg = copy.deepcopy(self.default_config)
            new_ids = self._hpopt_trials.new_trial_ids(1)
            self._hpopt_trials.refresh()

            # Get a new suggestion from the TPE algorithm
            new_trials = self.algo(
                new_ids, self.domain, self._hpopt_trials,
                self.rstate.randint(2 ** 31 - 1))
            self._hpopt_trials.insert_trial_docs(new_trials)
            self._hpopt_trials.refresh()
            new_trial = new_trials[0]
            new_trial_id = new_trial["tid"]
            suggested_config = hpo.base.spec_from_misc(new_trial["misc"])
            new_cfg.update(suggested_config)

            kv_str = "_".join(["{}={}".format(k, str(v)[:5])
                               for k, v in sorted(suggested_config.items())])
            experiment_tag = "{}_{}".format(new_trial_id, kv_str)

            # Keep this consistent with tune.variant_generator
            trial = Trial(
                trainable_name=self.args.run,
                config=new_cfg,
                local_dir=os.path.join(self.args.local_dir, self._output_path),
                experiment_tag=experiment_tag,
                resources=self.args.trial_resources,
                stopping_criterion=self.args.stop,
                checkpoint_freq=self.args.checkpoint_freq,
                restore_path=self.args.restore,
                upload_dir=self.args.upload_dir,
                max_failures=self.args.max_failures)

            self._tune_to_hp[trial] = new_trial_id
            self._num_trials_left -= 1
            yield trial

    def on_trial_result(self, trial_runner, trial, result):
        ho_trial = self._get_hyperopt_trial(self._tune_to_hp[trial])
        now = hpo.utils.coarse_utcnow()
        ho_trial['book_time'] = now
        ho_trial['refresh_time'] = now
        return TrialScheduler.CONTINUE

    def on_trial_error(self, trial_runner, trial):
        ho_trial = self._get_hyperopt_trial(self._tune_to_hp[trial])
        ho_trial['refresh_time'] = hpo.utils.coarse_utcnow()
        ho_trial['state'] = hpo.base.JOB_STATE_ERROR
        ho_trial['misc']['error'] = (str(TuneError), "Tune Error")
        self._hpopt_trials.refresh()
        del self._tune_to_hp[trial]

    def on_trial_remove(self, trial_runner, trial):
        ho_trial = self._get_hyperopt_trial(self._tune_to_hp[trial])
        ho_trial['refresh_time'] = hpo.utils.coarse_utcnow()
        ho_trial['state'] = hpo.base.JOB_STATE_ERROR
        ho_trial['misc']['error'] = (str(TuneError), "Tune Removed")
        self._hpopt_trials.refresh()
        del self._tune_to_hp[trial]

    def on_trial_complete(self, trial_runner, trial, result):
        ho_trial = self._get_hyperopt_trial(self._tune_to_hp[trial])
        ho_trial['refresh_time'] = hpo.utils.coarse_utcnow()
        ho_trial['state'] = hpo.base.JOB_STATE_DONE
        hp_result = self._to_hyperopt_result(result)
        ho_trial['result'] = hp_result
        self._hpopt_trials.refresh()
        del self._tune_to_hp[trial]

    def _to_hyperopt_result(self, result):
        return {"loss": -getattr(result, self._reward_attr),
                "status": "ok"}

    def _get_hyperopt_trial(self, tid):
        return [t for t in self._hpopt_trials.trials if t["tid"] == tid][0]

    def choose_trial_to_run(self, trial_runner):
        self._add_new_trials_if_needed(trial_runner)
        return FIFOScheduler.choose_trial_to_run(self, trial_runner)

    def _add_new_trials_if_needed(self, trial_runner):
        """Checks if there is a next trial ready to be queued.

        This is determined by tracking the number of concurrent
        experiments and trials left to run. If self._max_concurrent is None,
        the scheduler will add a new trial only if none are pending.
        """
        pending = [t for t in trial_runner.get_trials()
                   if t.status == Trial.PENDING]
        if self._num_trials_left <= 0:
            return
        if self._max_concurrent is None:
            if not pending:
                trial_runner.add_trial(next(self.trial_generator))
        else:
            while self._num_live_trials() < self._max_concurrent:
                try:
                    trial_runner.add_trial(next(self.trial_generator))
                except StopIteration:
                    break

    def _num_live_trials(self):
        return len(self._tune_to_hp)

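As the class docstring notes, the search space handed to ``HyperOptScheduler`` must be expressed with HyperOpt primitives rather than Tune's own variant-generation syntax. A sketch of a slightly richer space (the parameter names and ranges are illustrative):

.. code-block:: python

    from hyperopt import hp

    space = {
        "lr": hp.loguniform("lr", -10, 0),                     # continuous, log scale
        "batch_size": hp.choice("batch_size", [32, 64, 128]),  # categorical
        "momentum": hp.uniform("momentum", 0.1, 0.9),          # continuous
    }

    # The space is passed through the experiment config, e.g.:
    # {"my_exp": {"run": "exp", "repeat": 50, "config": {"space": space}}}
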
@@ -100,6 +100,7 @@ class TrainableFunctionApiTest(unittest.TestCase):
     def testBadParams2(self):
         def f():
             run_experiments({"foo": {
+                "run": "asdf",
                 "bah": "this param is not allowed",
             }})
         self.assertRaises(TuneError, f)

@@ -1,4 +1,3 @@
-
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

@@ -335,8 +335,8 @@ class Trial(object):
     def update_last_result(self, result, terminate=False):
         if terminate:
             result = result._replace(done=True)
-        if terminate or (
-                self.verbose and
+        if self.verbose and (
+                terminate or
                 time.time() - self.last_debug > DEBUG_PRINT_INTERVAL):
             print("TrainingResult for {}:".format(self))
             print("  {}".format(pretty_print(result).replace("\n", "\n  ")))

@@ -40,13 +40,16 @@ class TrialRunner(object):
     """

     def __init__(self, scheduler=None, launch_web_server=False,
-                 server_port=TuneServer.DEFAULT_PORT):
+                 server_port=TuneServer.DEFAULT_PORT, verbose=True):
         """Initializes a new TrialRunner.

         Args:
             scheduler (TrialScheduler): Defaults to FIFOScheduler.
             launch_web_server (bool): Flag for starting TuneServer
-            server_port (int): Port number for launching TuneServer"""
+            server_port (int): Port number for launching TuneServer
+            verbose (bool): Flag for verbosity. If False, trial results
+                will not be output.
+        """

         self._scheduler_alg = scheduler or FIFOScheduler()
         self._trials = []

@@ -64,6 +67,7 @@ class TrialRunner(object):
         if launch_web_server:
             self._server = TuneServer(self, server_port)
         self._stop_queue = []
+        self._verbose = verbose

     def is_finished(self):
         """Returns whether all trials have finished running."""

@@ -85,8 +89,9 @@ class TrialRunner(object):
         Callers should typically run this method repeatedly in a loop. They
         may inspect or modify the runner's state in between calls to step().
         """
-        if self._can_launch_more():
-            self._launch_trial()
+        next_trial = self._get_next_trial()
+        if next_trial is not None:
+            self._launch_trial(next_trial)
         elif self._running:
             self._process_events()
         else:

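The driver loop that the ``step()`` docstring describes looks roughly like this (a sketch; trial creation is elided and would normally happen through a scheduler's ``add_experiment``):

.. code-block:: python

    from ray.tune.trial_runner import TrialRunner
    from ray.tune.trial_scheduler import FIFOScheduler

    runner = TrialRunner(FIFOScheduler(), verbose=True)
    # ... add trials here, e.g. scheduler.add_experiment(experiment, runner) ...

    while not runner.is_finished():
        runner.step()  # launch a ready trial if any, otherwise process events
    print(runner.debug_string())
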
@@ -127,7 +132,11 @@ class TrialRunner(object):
         """Adds a new trial to this TrialRunner.

         Trials may be added at any time.
+
+        Args:
+            trial (Trial): Trial to queue.
         """
+        trial.set_verbose(self._verbose)
         self._scheduler_alg.on_trial_add(self, trial)
         self._trials.append(trial)

@@ -185,13 +194,12 @@ class TrialRunner(object):
             resources.cpu_total() <= cpu_avail and
             resources.gpu_total() <= gpu_avail)

-    def _can_launch_more(self):
+    def _get_next_trial(self):
         self._update_avail_resources()
-        trial = self._get_runnable()
-        return trial is not None
+        trial = self._scheduler_alg.choose_trial_to_run(self)
+        return trial

-    def _launch_trial(self, custom_trial=None):
-        trial = custom_trial or self._get_runnable()
+    def _launch_trial(self, trial):
         self._commit_resources(trial.resources)
         try:
             trial.start()

@@ -219,6 +227,7 @@ class TrialRunner(object):
         self._total_time += result.time_this_iter_s

         if trial.should_stop(result):
+            # Hook into scheduler
             self._scheduler_alg.on_trial_complete(self, trial, result)
             decision = TrialScheduler.STOP
         else:

@@ -262,9 +271,6 @@ class TrialRunner(object):
             print("Error recovering trial from checkpoint, abort:", error_msg)
             self._stop_trial(trial, error=True, error_msg=error_msg)

-    def _get_runnable(self):
-        return self._scheduler_alg.choose_trial_to_run(self)
-
     def _commit_resources(self, resources):
         self._committed_resources = Resources(
             self._committed_resources.cpu + resources.cpu_total(),

@@ -3,6 +3,7 @@ from __future__ import division
 from __future__ import print_function

 from ray.tune.trial import Trial
+from ray.tune.variant_generator import generate_trials


 class TrialScheduler(object):

@@ -47,11 +48,28 @@ class TrialScheduler(object):

         raise NotImplementedError

-    def choose_trial_to_run(self, trial_runner, trials):
+    def add_experiment(self, experiment, trial_runner):
+        """Adds an experiment to the scheduler.
+
+        The scheduler is responsible for adding the trials of the experiment
+        to the runner, which can be done immediately (if there is a finite
+        set of trials), or over time (if there is an infinite stream of trials
+        or if the scheduler is iterative in nature).
+        """
+        generator = generate_trials(experiment.spec, experiment.name)
+        while True:
+            try:
+                trial_runner.add_trial(next(generator))
+            except StopIteration:
+                break
+
+    def choose_trial_to_run(self, trial_runner):
         """Called to choose a new trial to run.

         This should return one of the trials in trial_runner that is in
-        the PENDING or PAUSED state."""
+        the PENDING or PAUSED state. This function must be idempotent.
+
+        If no trial is ready, return None."""

         raise NotImplementedError

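To make the contract above concrete, here is a hypothetical subclass that fills in both hooks (``LIFOScheduler`` is purely illustrative; it assumes ``trial_runner.get_trials()`` returns a list and that ``trial_runner.has_resources()`` is available, as ``FIFOScheduler`` relies on):

.. code-block:: python

    from ray.tune.trial import Trial
    from ray.tune.trial_scheduler import FIFOScheduler
    from ray.tune.variant_generator import generate_trials

    class LIFOScheduler(FIFOScheduler):
        """Hypothetical scheduler that runs the most recently added trial."""

        def add_experiment(self, experiment, trial_runner):
            # Same as the default: enqueue every generated trial up front.
            for trial in generate_trials(experiment.spec, experiment.name):
                trial_runner.add_trial(trial)

        def choose_trial_to_run(self, trial_runner):
            for trial in reversed(trial_runner.get_trials()):
                if (trial.status == Trial.PENDING and
                        trial_runner.has_resources(trial.resources)):
                    return trial
            return None  # no trial is ready; safe to call again (idempotent)
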
@@ -4,16 +4,16 @@ from __future__ import print_function

 import time

-from ray.tune import TuneError
+from ray.tune.error import TuneError
 from ray.tune.hyperband import HyperBandScheduler
 from ray.tune.async_hyperband import AsyncHyperBandScheduler
 from ray.tune.median_stopping_rule import MedianStoppingRule
+from ray.tune.hpo_scheduler import HyperOptScheduler
 from ray.tune.trial import Trial, DEBUG_PRINT_INTERVAL
 from ray.tune.log_sync import wait_for_log_sync
 from ray.tune.trial_runner import TrialRunner
 from ray.tune.trial_scheduler import FIFOScheduler
 from ray.tune.web_server import TuneServer
-from ray.tune.variant_generator import generate_trials
 from ray.tune.experiment import Experiment

@@ -22,6 +22,7 @@ _SCHEDULERS = {
     "MedianStopping": MedianStoppingRule,
     "HyperBand": HyperBandScheduler,
     "AsyncHyperBand": AsyncHyperBandScheduler,
+    "HyperOpt": HyperOptScheduler,
 }

@@ -42,7 +43,7 @@ def run_experiments(experiments, scheduler=None, with_server=False,
         experiments (Experiment | list | dict): Experiments to run.
         scheduler (TrialScheduler): Scheduler for executing
             the experiment. Choose among FIFO (default), MedianStopping,
-            AsyncHyperBand, or HyperBand.
+            AsyncHyperBand, HyperBand, or HyperOpt.
         with_server (bool): Starts a background Tune server. Needed for
             using the Client API.
         server_port (int): Port number for launching TuneServer.

@@ -53,23 +54,21 @@ def run_experiments(experiments, scheduler=None, with_server=False,
         scheduler = FIFOScheduler()

     runner = TrialRunner(
-        scheduler, launch_web_server=with_server, server_port=server_port)
+        scheduler, launch_web_server=with_server, server_port=server_port,
+        verbose=verbose)
+    exp_list = experiments
+    if isinstance(experiments, Experiment):
+        exp_list = [experiments]
+    elif type(experiments) is dict:
+        exp_list = [Experiment.from_json(name, spec)
+                    for name, spec in experiments.items()]

-    if type(experiments) is dict:
-        for name, spec in experiments.items():
-            for trial in generate_trials(spec, name):
-                trial.set_verbose(verbose)
-                runner.add_trial(trial)
-    elif (type(experiments) is list and
-          all(isinstance(exp, Experiment) for exp in experiments)):
-        for experiment in experiments:
-            for trial in experiment.trials():
-                trial.set_verbose(verbose)
-                runner.add_trial(trial)
-    elif isinstance(experiments, Experiment):
-        for trial in experiments.trials():
-            trial.set_verbose(verbose)
-            runner.add_trial(trial)
+    if (type(exp_list) is list and
+            all(isinstance(exp, Experiment) for exp in exp_list)):
+        for experiment in exp_list:
+            scheduler.add_experiment(experiment, runner)
+    else:
+        raise TuneError("Invalid argument: {}".format(experiments))

     print(runner.debug_string(max_debug=99999))

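With this change, every accepted form of the ``experiments`` argument is normalized to a list of ``Experiment`` objects and handed to ``scheduler.add_experiment``. A sketch of the three equivalent call sites (the spec and trainable name are illustrative):

.. code-block:: python

    from ray.tune import run_experiments
    from ray.tune.experiment import Experiment
    from ray.tune.trial_scheduler import FIFOScheduler

    spec = {"run": "my_trainable", "repeat": 2, "stop": {"training_iteration": 1}}

    # 1) dict of name -> spec (converted internally via Experiment.from_json)
    run_experiments({"exp1": spec}, scheduler=FIFOScheduler())

    # 2) a single Experiment instance
    run_experiments(Experiment.from_json("exp1", spec), scheduler=FIFOScheduler())

    # 3) a list of Experiment instances
    run_experiments([Experiment.from_json("exp1", spec)], scheduler=FIFOScheduler())
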
@@ -214,6 +214,10 @@ docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
     python /ray/python/ray/tune/examples/tune_mnist_async_hyperband.py \
     --smoke-test

+docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
+    python /ray/python/ray/tune/examples/hyperopt_example.py \
+    --smoke-test
+
 docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
     python /ray/python/ray/rllib/examples/multiagent_mountaincar.py