[tune] HyperOpt Support (v2) (#1763)
Parent: 5a9e83761d
Commit: 888e70f1be
13 changed files with 348 additions and 44 deletions
@@ -62,7 +62,7 @@ Features
 Ray Tune has the following features:

-- Scalable implementations of search algorithms such as `Population Based Training (PBT) <pbt.html>`__, `Median Stopping Rule <hyperband.html#median-stopping-rule>`__, and `HyperBand <hyperband.html>`__.
+- Scalable implementations of search algorithms such as `Population Based Training (PBT) <pbt.html>`__, `Median Stopping Rule <hyperband.html#median-stopping-rule>`__, Model-Based Optimization (HyperOpt), and `HyperBand <hyperband.html>`__.

 - Integration with visualization tools such as `TensorBoard <https://www.tensorflow.org/get_started/summaries_and_tensorboard>`__, `rllab's VisKit <https://media.readthedocs.org/pdf/rllab/latest/rllab.pdf>`__, and a `parallel coordinates visualization <https://en.wikipedia.org/wiki/Parallel_coordinates>`__.

@@ -94,12 +94,28 @@ You can find the code for Ray Tune `here on GitHub <https://github.com/ray-proje
 Trial Schedulers
 ----------------

-By default, Ray Tune schedules trials in serial order with the ``FIFOScheduler`` class. However, you can also specify a custom scheduling algorithm that can early stop trials, perturb parameters, or incorporate suggestions from an external service. Currently implemented trial schedulers include `Population Based Training (PBT) <pbt.html>`__, `Median Stopping Rule <hyperband.html#median-stopping-rule>`__, and `HyperBand <hyperband.html>`__.
+By default, Ray Tune schedules trials in serial order with the ``FIFOScheduler`` class. However, you can also specify a custom scheduling algorithm that can early stop trials, perturb parameters, or incorporate suggestions from an external service. Currently implemented trial schedulers include `Population Based Training (PBT) <pbt.html>`__, `Median Stopping Rule <hyperband.html#median-stopping-rule>`__, Model-Based Optimization (HyperOpt), and `HyperBand <hyperband.html>`__.

 .. code-block:: python

     run_experiments({...}, scheduler=AsyncHyperBandScheduler())

+HyperOpt Integration
+--------------------
+
+The ``HyperOptScheduler`` is a Trial Scheduler backed by HyperOpt that performs sequential model-based hyperparameter optimization.
+To use this scheduler, you will need to install HyperOpt via the following command:
+
+.. code-block:: bash
+
+    $ pip install --upgrade git+git://github.com/hyperopt/hyperopt.git
+
+An example of this can be found in `hyperopt_example.py <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/hyperopt_example.py>`__.
+
+.. autoclass:: ray.tune.hpo_scheduler.HyperOptScheduler
+
 Visualizing Results
 -------------------
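A minimal end-to-end sketch of the usage documented above, not part of the commit itself: it assumes HyperOpt is installed and mirrors the trainable and search-space pattern from hyperopt_example.py further down in this diff; the toy objective and the max_concurrent value are illustrative.

import ray
from hyperopt import hp
from ray.tune import register_trainable, run_experiments
from ray.tune.hpo_scheduler import HyperOptScheduler


def exp(config, reporter):
    # Toy objective: report a reward for HyperOpt (TPE) to maximize.
    reporter(timesteps_total=1,
             episode_reward_mean=-(config["width"] - 3) ** 2)


if __name__ == "__main__":
    ray.init()
    register_trainable("exp", exp)

    space = {"width": hp.uniform("width", 0, 20)}
    config = {"my_exp": {
        "run": "exp",
        "repeat": 20,
        "stop": {"training_iteration": 1},
        "config": {"space": space}}}

    # max_concurrent caps simultaneously running trials; None (the default)
    # queues new trials whenever resources are available.
    run_experiments(config, scheduler=HyperOptScheduler(max_concurrent=4))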
@@ -4,4 +4,5 @@ FROM ray-project/deploy
 RUN conda install -y -c conda-forge tensorflow
 RUN apt-get install -y zlib1g-dev
 RUN pip install gym[atari] opencv-python==3.2.0.8 smart_open
+RUN pip install --upgrade git+git://github.com/hyperopt/hyperopt.git
 # RUN conda install -y -q pytorch torchvision -c soumith
@@ -95,7 +95,8 @@ def make_parser(**kwargs):
              "many times. Only applies if checkpointing is enabled.")
     parser.add_argument(
         "--scheduler", default="FIFO", type=str,
-        help="FIFO (default), MedianStopping, AsyncHyperBand, or HyperBand.")
+        help="FIFO (default), MedianStopping, AsyncHyperBand, "
+             "HyperBand, or HyperOpt.")
     parser.add_argument(
         "--scheduler-config", default="{}", type=json.loads,
         help="Config options to pass to the scheduler.")
python/ray/tune/examples/hyperopt_example.py (new file, 45 lines)
@@ -0,0 +1,45 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import ray
from ray.tune import run_experiments, register_trainable
from ray.tune.hpo_scheduler import HyperOptScheduler


def easy_objective(config, reporter):
    import time
    time.sleep(0.2)
    reporter(
        timesteps_total=1,
        episode_reward_mean=-((config["height"]-14) ** 2
                              + abs(config["width"]-3)))
    time.sleep(0.2)


if __name__ == '__main__':
    import argparse
    from hyperopt import hp

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing")
    args, _ = parser.parse_known_args()
    ray.init(redirect_output=True)

    register_trainable("exp", easy_objective)

    space = {
        'width': hp.uniform('width', 0, 20),
        'height': hp.uniform('height', -100, 100),
    }

    config = {"my_exp": {
        "run": "exp",
        "repeat": 5 if args.smoke_test else 1000,
        "stop": {"training_iteration": 1},
        "config": {
            "space": space}}}
    hpo_sched = HyperOptScheduler()

    run_experiments(config, verbose=False, scheduler=hpo_sched)
@@ -2,8 +2,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-from ray.tune.variant_generator import generate_trials
+from ray.tune.result import DEFAULT_RESULTS_DIR
 from ray.tune.error import TuneError


 class Experiment(object):

@@ -49,8 +49,21 @@ class Experiment(object):
             "checkpoint_freq": checkpoint_freq,
             "max_failures": max_failures
         }
-        self._trials = generate_trials(spec, name)
-
-    def trials(self):
-        for trial in self._trials:
-            yield trial
+        self.name = name
+        self.spec = spec
+
+    @classmethod
+    def from_json(cls, name, spec):
+        """Generates an Experiment object from JSON.
+
+        Args:
+            name (str): Name of Experiment.
+            spec (dict): JSON configuration of experiment.
+        """
+        if "run" not in spec:
+            raise TuneError("No trainable specified!")
+        exp = cls(name, spec["run"])
+        exp.name = name
+        exp.spec = spec
+        return exp
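Since an Experiment now carries its name and spec rather than pre-generated trials, here is a short sketch (illustrative, not part of the commit) of the new from_json path; only the "run" key is required by the check above, the other spec values are placeholders.

from ray.tune.error import TuneError
from ray.tune.experiment import Experiment

# A spec in the same per-experiment dict form that run_experiments accepts.
spec = {"run": "exp", "repeat": 5, "stop": {"training_iteration": 1}}

exp = Experiment.from_json("my_exp", spec)
print(exp.name, exp.spec)  # "my_exp", the spec dict above

# Omitting "run" is rejected before any trials are generated.
try:
    Experiment.from_json("broken", {"repeat": 5})
except TuneError as err:
    print("rejected:", err)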
python/ray/tune/hpo_scheduler.py (new file, 201 lines)
@@ -0,0 +1,201 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import copy
import numpy as np
try:
    import hyperopt as hpo
except Exception as e:
    hpo = None

from ray.tune.trial import Trial
from ray.tune.error import TuneError
from ray.tune.trial_scheduler import TrialScheduler, FIFOScheduler
from ray.tune.config_parser import make_parser
from ray.tune.variant_generator import to_argv


class HyperOptScheduler(FIFOScheduler):
    """FIFOScheduler that uses HyperOpt to provide trial suggestions.

    Requires HyperOpt to be installed from source.
    Uses the Tree-structured Parzen Estimators algorithm. Externally added
    trials will not be tracked by HyperOpt. Also, variant generation will be
    limited, as the hyperparameter configuration must be specified using
    HyperOpt primitives.

    Parameters:
        max_concurrent (int | None): Number of maximum concurrent trials.
            If None, then trials will be queued only if resources
            are available.
        reward_attr (str): The TrainingResult objective value attribute.
            This refers to an increasing value, which is internally negated
            when interacting with HyperOpt. Suggestion procedures
            will use this attribute.

    Examples:
        >>> space = {'param': hp.uniform('param', 0, 20)}
        >>> config = {"my_exp": {
                "run": "exp",
                "repeat": 5,
                "config": {"space": space}}}
        >>> run_experiments(config, scheduler=HyperOptScheduler())
    """

    def __init__(self, max_concurrent=None, reward_attr="episode_reward_mean"):
        assert hpo is not None, "HyperOpt must be installed!"
        assert type(max_concurrent) in [type(None), int]
        if type(max_concurrent) is int:
            assert max_concurrent > 0
        self._max_concurrent = max_concurrent  # NOTE: this is modified later
        self._reward_attr = reward_attr
        self._experiment = None

    def add_experiment(self, experiment, trial_runner):
        """Tracks one experiment.

        Will error if one tries to track multiple experiments.
        """
        assert self._experiment is None, "HyperOpt only tracks one experiment!"
        self._experiment = experiment

        self._output_path = experiment.name
        spec = copy.deepcopy(experiment.spec)

        # Set Scheduler field, as Tune Parser will default to FIFO
        assert spec.get("scheduler") in [None, "HyperOpt"], "Incorrectly " \
            "specified scheduler!"
        spec["scheduler"] = "HyperOpt"

        if "env" in spec:
            spec["config"] = spec.get("config", {})
            spec["config"]["env"] = spec["env"]
            del spec["env"]

        space = spec["config"]["space"]
        del spec["config"]["space"]

        self.parser = make_parser()
        self.args = self.parser.parse_args(to_argv(spec))
        self.args.scheduler = "HyperOpt"
        self.default_config = copy.deepcopy(spec["config"])

        self.algo = hpo.tpe.suggest
        self.domain = hpo.Domain(lambda spc: spc, space)
        self._hpopt_trials = hpo.Trials()
        self._tune_to_hp = {}
        self._num_trials_left = self.args.repeat

        if type(self._max_concurrent) is int:
            self._max_concurrent = min(self._max_concurrent, self.args.repeat)

        self.rstate = np.random.RandomState()
        self.trial_generator = self._trial_generator()
        self._add_new_trials_if_needed(trial_runner)

    def _trial_generator(self):
        while self._num_trials_left > 0:
            new_cfg = copy.deepcopy(self.default_config)
            new_ids = self._hpopt_trials.new_trial_ids(1)
            self._hpopt_trials.refresh()

            # Get new suggestion from HyperOpt
            new_trials = self.algo(
                new_ids, self.domain, self._hpopt_trials,
                self.rstate.randint(2 ** 31 - 1))
            self._hpopt_trials.insert_trial_docs(new_trials)
            self._hpopt_trials.refresh()
            new_trial = new_trials[0]
            new_trial_id = new_trial["tid"]
            suggested_config = hpo.base.spec_from_misc(new_trial["misc"])
            new_cfg.update(suggested_config)

            kv_str = "_".join(["{}={}".format(k, str(v)[:5])
                               for k, v in sorted(suggested_config.items())])
            experiment_tag = "{}_{}".format(new_trial_id, kv_str)

            # Keep this consistent with tune.variant_generator
            trial = Trial(
                trainable_name=self.args.run,
                config=new_cfg,
                local_dir=os.path.join(self.args.local_dir, self._output_path),
                experiment_tag=experiment_tag,
                resources=self.args.trial_resources,
                stopping_criterion=self.args.stop,
                checkpoint_freq=self.args.checkpoint_freq,
                restore_path=self.args.restore,
                upload_dir=self.args.upload_dir,
                max_failures=self.args.max_failures)

            self._tune_to_hp[trial] = new_trial_id
            self._num_trials_left -= 1
            yield trial

    def on_trial_result(self, trial_runner, trial, result):
        ho_trial = self._get_hyperopt_trial(self._tune_to_hp[trial])
        now = hpo.utils.coarse_utcnow()
        ho_trial['book_time'] = now
        ho_trial['refresh_time'] = now
        return TrialScheduler.CONTINUE

    def on_trial_error(self, trial_runner, trial):
        ho_trial = self._get_hyperopt_trial(self._tune_to_hp[trial])
        ho_trial['refresh_time'] = hpo.utils.coarse_utcnow()
        ho_trial['state'] = hpo.base.JOB_STATE_ERROR
        ho_trial['misc']['error'] = (str(TuneError), "Tune Error")
        self._hpopt_trials.refresh()
        del self._tune_to_hp[trial]

    def on_trial_remove(self, trial_runner, trial):
        ho_trial = self._get_hyperopt_trial(self._tune_to_hp[trial])
        ho_trial['refresh_time'] = hpo.utils.coarse_utcnow()
        ho_trial['state'] = hpo.base.JOB_STATE_ERROR
        ho_trial['misc']['error'] = (str(TuneError), "Tune Removed")
        self._hpopt_trials.refresh()
        del self._tune_to_hp[trial]

    def on_trial_complete(self, trial_runner, trial, result):
        ho_trial = self._get_hyperopt_trial(self._tune_to_hp[trial])
        ho_trial['refresh_time'] = hpo.utils.coarse_utcnow()
        ho_trial['state'] = hpo.base.JOB_STATE_DONE
        hp_result = self._to_hyperopt_result(result)
        ho_trial['result'] = hp_result
        self._hpopt_trials.refresh()
        del self._tune_to_hp[trial]

    def _to_hyperopt_result(self, result):
        return {"loss": -getattr(result, self._reward_attr),
                "status": "ok"}

    def _get_hyperopt_trial(self, tid):
        return [t for t in self._hpopt_trials.trials if t["tid"] == tid][0]

    def choose_trial_to_run(self, trial_runner):
        self._add_new_trials_if_needed(trial_runner)
        return FIFOScheduler.choose_trial_to_run(self, trial_runner)

    def _add_new_trials_if_needed(self, trial_runner):
        """Checks if there is a next trial ready to be queued.

        This is determined by tracking the number of concurrent
        experiments and trials left to run. If self._max_concurrent is None,
        the scheduler will add a new trial if none are pending.
        """
        pending = [t for t in trial_runner.get_trials()
                   if t.status == Trial.PENDING]
        if self._num_trials_left <= 0:
            return
        if self._max_concurrent is None:
            if not pending:
                trial_runner.add_trial(next(self.trial_generator))
        else:
            while self._num_live_trials() < self._max_concurrent:
                try:
                    trial_runner.add_trial(next(self.trial_generator))
                except StopIteration:
                    break

    def _num_live_trials(self):
        return len(self._tune_to_hp)
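To make the suggestion loop in _trial_generator easier to follow in isolation, here is a standalone sketch that strings together the same HyperOpt calls without any Ray plumbing; it assumes HyperOpt is installed, and the search space is illustrative.

import hyperopt as hpo
import numpy as np
from hyperopt import hp

space = {"width": hp.uniform("width", 0, 20)}

# Same objects the scheduler keeps: an identity Domain, a Trials store, and
# a random state used to seed each TPE suggestion.
domain = hpo.Domain(lambda spc: spc, space)
hpopt_trials = hpo.Trials()
rstate = np.random.RandomState(0)

# One suggestion, exactly as _trial_generator requests it per Tune trial.
new_ids = hpopt_trials.new_trial_ids(1)
hpopt_trials.refresh()
new_trials = hpo.tpe.suggest(
    new_ids, domain, hpopt_trials, rstate.randint(2 ** 31 - 1))
hpopt_trials.insert_trial_docs(new_trials)
hpopt_trials.refresh()

suggested_config = hpo.base.spec_from_misc(new_trials[0]["misc"])
print(suggested_config)  # e.g. {'width': 10.9...}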
@@ -100,6 +100,7 @@ class TrainableFunctionApiTest(unittest.TestCase):
     def testBadParams2(self):
         def f():
             run_experiments({"foo": {
+                "run": "asdf",
                 "bah": "this param is not allowed",
             }})
         self.assertRaises(TuneError, f)
@@ -1,4 +1,3 @@
-
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -335,8 +335,8 @@ class Trial(object):
     def update_last_result(self, result, terminate=False):
         if terminate:
             result = result._replace(done=True)
-        if terminate or (
-                self.verbose and
+        if self.verbose and (
+                terminate or
                 time.time() - self.last_debug > DEBUG_PRINT_INTERVAL):
             print("TrainingResult for {}:".format(self))
             print("  {}".format(pretty_print(result).replace("\n", "\n  ")))
@@ -40,13 +40,16 @@ class TrialRunner(object):
     """

     def __init__(self, scheduler=None, launch_web_server=False,
-                 server_port=TuneServer.DEFAULT_PORT):
+                 server_port=TuneServer.DEFAULT_PORT, verbose=True):
         """Initializes a new TrialRunner.

         Args:
             scheduler (TrialScheduler): Defaults to FIFOScheduler.
             launch_web_server (bool): Flag for starting TuneServer
-            server_port (int): Port number for launching TuneServer"""
+            server_port (int): Port number for launching TuneServer
+            verbose (bool): Flag for verbosity. If False, trial results
+                will not be output.
+        """

         self._scheduler_alg = scheduler or FIFOScheduler()
         self._trials = []

@@ -64,6 +67,7 @@ class TrialRunner(object):
         if launch_web_server:
             self._server = TuneServer(self, server_port)
         self._stop_queue = []
+        self._verbose = verbose

     def is_finished(self):
         """Returns whether all trials have finished running."""

@@ -85,8 +89,9 @@ class TrialRunner(object):
         Callers should typically run this method repeatedly in a loop. They
         may inspect or modify the runner's state in between calls to step().
         """
-        if self._can_launch_more():
-            self._launch_trial()
+        next_trial = self._get_next_trial()
+        if next_trial is not None:
+            self._launch_trial(next_trial)
         elif self._running:
             self._process_events()
         else:

@@ -127,7 +132,11 @@ class TrialRunner(object):
         """Adds a new trial to this TrialRunner.

         Trials may be added at any time.
+
+        Args:
+            trial (Trial): Trial to queue.
         """
+        trial.set_verbose(self._verbose)
         self._scheduler_alg.on_trial_add(self, trial)
         self._trials.append(trial)

@@ -185,13 +194,12 @@ class TrialRunner(object):
                 resources.cpu_total() <= cpu_avail and
                 resources.gpu_total() <= gpu_avail)

-    def _can_launch_more(self):
+    def _get_next_trial(self):
         self._update_avail_resources()
-        trial = self._get_runnable()
-        return trial is not None
+        trial = self._scheduler_alg.choose_trial_to_run(self)
+        return trial

-    def _launch_trial(self, custom_trial=None):
-        trial = custom_trial or self._get_runnable()
+    def _launch_trial(self, trial):
         self._commit_resources(trial.resources)
         try:
             trial.start()

@@ -219,6 +227,7 @@ class TrialRunner(object):
             self._total_time += result.time_this_iter_s

             if trial.should_stop(result):
+                # Hook into scheduler
                 self._scheduler_alg.on_trial_complete(self, trial, result)
                 decision = TrialScheduler.STOP
             else:

@@ -262,9 +271,6 @@ class TrialRunner(object):
             print("Error recovering trial from checkpoint, abort:", error_msg)
             self._stop_trial(trial, error=True, error_msg=error_msg)

-    def _get_runnable(self):
-        return self._scheduler_alg.choose_trial_to_run(self)
-
     def _commit_resources(self, resources):
         self._committed_resources = Resources(
             self._committed_resources.cpu + resources.cpu_total(),
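The step()/_get_next_trial() refactor above is easiest to read alongside the driver loop that run_experiments now builds around the runner. A condensed, self-contained sketch of that loop follows; the trainable and experiment spec are illustrative, and the loop itself mirrors the runner's documented contract rather than copying run_experiments verbatim.

import ray
from ray.tune import register_trainable
from ray.tune.experiment import Experiment
from ray.tune.trial_runner import TrialRunner
from ray.tune.trial_scheduler import FIFOScheduler


def trainable(config, reporter):
    reporter(timesteps_total=1, episode_reward_mean=0)


ray.init()
register_trainable("exp", trainable)

experiment = Experiment.from_json(
    "sketch_exp",
    {"run": "exp", "repeat": 2, "stop": {"training_iteration": 1}})

scheduler = FIFOScheduler()
runner = TrialRunner(scheduler, verbose=True)

# The scheduler, not the caller, now queues trials onto the runner.
scheduler.add_experiment(experiment, runner)

while not runner.is_finished():
    # Each step launches the scheduler's next chosen trial, if any
    # (_get_next_trial -> choose_trial_to_run), or processes result events.
    runner.step()

print(runner.debug_string(max_debug=99999))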
@@ -3,6 +3,7 @@ from __future__ import division
 from __future__ import print_function

 from ray.tune.trial import Trial
+from ray.tune.variant_generator import generate_trials


 class TrialScheduler(object):

@@ -47,11 +48,28 @@ class TrialScheduler(object):

         raise NotImplementedError

-    def choose_trial_to_run(self, trial_runner, trials):
+    def add_experiment(self, experiment, trial_runner):
+        """Adds an experiment to the scheduler.
+
+        The scheduler is responsible for adding the trials of the experiment
+        to the runner, which can be done immediately (if there is a finite
+        set of trials) or over time (if there is an infinite stream of trials
+        or if the scheduler is iterative in nature).
+        """
+        generator = generate_trials(experiment.spec, experiment.name)
+        while True:
+            try:
+                trial_runner.add_trial(next(generator))
+            except StopIteration:
+                break
+
+    def choose_trial_to_run(self, trial_runner):
         """Called to choose a new trial to run.

         This should return one of the trials in trial_runner that is in
-        the PENDING or PAUSED state."""
+        the PENDING or PAUSED state. This function must be idempotent.
+
+        If no trial is ready, return None."""

         raise NotImplementedError
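As a sketch of what the new add_experiment hook permits (the class below is illustrative and not part of this commit): a scheduler can stream trials into the runner lazily instead of materializing them all up front, which is exactly how HyperOptScheduler uses the interface above.

from ray.tune.trial import Trial
from ray.tune.trial_scheduler import FIFOScheduler
from ray.tune.variant_generator import generate_trials


class LazyScheduler(FIFOScheduler):
    """Hypothetical scheduler that keeps at most one trial pending."""

    def add_experiment(self, experiment, trial_runner):
        # Keep the generator instead of draining it like the default does.
        self._generator = generate_trials(experiment.spec, experiment.name)
        self._maybe_queue(trial_runner)

    def choose_trial_to_run(self, trial_runner):
        self._maybe_queue(trial_runner)
        return FIFOScheduler.choose_trial_to_run(self, trial_runner)

    def _maybe_queue(self, trial_runner):
        pending = [t for t in trial_runner.get_trials()
                   if t.status == Trial.PENDING]
        if not pending:
            try:
                trial_runner.add_trial(next(self._generator))
            except StopIteration:
                pass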
@@ -4,16 +4,16 @@ from __future__ import print_function

 import time

-from ray.tune import TuneError
+from ray.tune.error import TuneError
 from ray.tune.hyperband import HyperBandScheduler
 from ray.tune.async_hyperband import AsyncHyperBandScheduler
 from ray.tune.median_stopping_rule import MedianStoppingRule
+from ray.tune.hpo_scheduler import HyperOptScheduler
 from ray.tune.trial import Trial, DEBUG_PRINT_INTERVAL
 from ray.tune.log_sync import wait_for_log_sync
 from ray.tune.trial_runner import TrialRunner
 from ray.tune.trial_scheduler import FIFOScheduler
 from ray.tune.web_server import TuneServer
-from ray.tune.variant_generator import generate_trials
 from ray.tune.experiment import Experiment


@@ -22,6 +22,7 @@ _SCHEDULERS = {
     "MedianStopping": MedianStoppingRule,
     "HyperBand": HyperBandScheduler,
     "AsyncHyperBand": AsyncHyperBandScheduler,
+    "HyperOpt": HyperOptScheduler,
 }


@@ -42,7 +43,7 @@ def run_experiments(experiments, scheduler=None, with_server=False,
         experiments (Experiment | list | dict): Experiments to run.
         scheduler (TrialScheduler): Scheduler for executing
             the experiment. Choose among FIFO (default), MedianStopping,
-            AsyncHyperBand, or HyperBand.
+            AsyncHyperBand, HyperBand, or HyperOpt.
         with_server (bool): Starts a background Tune server. Needed for
             using the Client API.
         server_port (int): Port number for launching TuneServer.

@@ -53,23 +54,21 @@ def run_experiments(experiments, scheduler=None, with_server=False,
         scheduler = FIFOScheduler()

     runner = TrialRunner(
-        scheduler, launch_web_server=with_server, server_port=server_port)
+        scheduler, launch_web_server=with_server, server_port=server_port,
+        verbose=verbose)
+    exp_list = experiments
+    if isinstance(experiments, Experiment):
+        exp_list = [experiments]
+    elif type(experiments) is dict:
+        exp_list = [Experiment.from_json(name, spec)
+                    for name, spec in experiments.items()]

-    if type(experiments) is dict:
-        for name, spec in experiments.items():
-            for trial in generate_trials(spec, name):
-                trial.set_verbose(verbose)
-                runner.add_trial(trial)
-    elif (type(experiments) is list and
-          all(isinstance(exp, Experiment) for exp in experiments)):
-        for experiment in experiments:
-            for trial in experiment.trials():
-                trial.set_verbose(verbose)
-                runner.add_trial(trial)
-    elif isinstance(experiments, Experiment):
-        for trial in experiments.trials():
-            trial.set_verbose(verbose)
-            runner.add_trial(trial)
+    if (type(exp_list) is list and
+            all(isinstance(exp, Experiment) for exp in exp_list)):
+        for experiment in exp_list:
+            scheduler.add_experiment(experiment, runner)
     else:
         raise TuneError("Invalid argument: {}".format(experiments))

     print(runner.debug_string(max_debug=99999))
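A brief sketch of the three input forms the reworked run_experiments now normalizes into a list of Experiment objects before handing them to the scheduler; it assumes ray.init() has been called and a trainable named "exp" is registered, as in the example earlier in this diff.

from ray.tune import run_experiments
from ray.tune.experiment import Experiment

spec = {"run": "exp", "repeat": 2, "stop": {"training_iteration": 1}}
exp_a = Experiment.from_json("exp_a", spec)
exp_b = Experiment.from_json("exp_b", spec)

# All three calls take the same path: exp_list -> scheduler.add_experiment(...)
run_experiments(exp_a)              # single Experiment
run_experiments([exp_a, exp_b])     # list of Experiments
run_experiments({"exp_c": spec})    # dict, converted via Experiment.from_json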
@@ -214,6 +214,10 @@ docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
     python /ray/python/ray/tune/examples/tune_mnist_async_hyperband.py \
     --smoke-test

+docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
+    python /ray/python/ray/tune/examples/hyperopt_example.py \
+    --smoke-test
+
 docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
     python /ray/python/ray/rllib/examples/multiagent_mountaincar.py