mirror of
https://github.com/vale981/ray
synced 2025-04-23 06:25:52 -04:00
[tune] Experiment Management API (#1328)
* init for exposing external interface * revisions * http server * small * simplify * ui * fixes * test * nit * nit * merge * untested * nits * nit * init tests * tests * more tests * nit * fix hyperband * cleanup * nits * good stuff * cleanup * comments and need to test * nit * notebook * testing * test and expose server * server_tests * docs * periods * fix tests * committing test * fi
This commit is contained in:
parent
1d2a28ab07
commit
a7d544424c
15 changed files with 599 additions and 88 deletions
|
@ -123,6 +123,7 @@ script:
|
|||
- python test/monitor_test.py
|
||||
- python test/trial_runner_test.py
|
||||
- python test/trial_scheduler_test.py
|
||||
- python test/tune_server_test.py
|
||||
- python test/cython_test.py
|
||||
- python -m pytest python/ray/dataframe/test/test_dataframe.py
|
||||
- python -m pytest python/ray/dataframe/test/test_series.py
|
||||
|
|
|
@ -159,3 +159,26 @@ Running in a large cluster
|
|||
--------------------------
|
||||
|
||||
The ``run_experiments`` function also accepts any arguments that ``ray.init()`` does. This can be used to pass in the redis address of a multi-node Ray cluster. For more details, check out the `tune.py script <https://github.com/ray-project/ray/blob/master/python/ray/tune/tune.py>`__.
|
||||
|
||||
Client API
|
||||
----------
|
||||
|
||||
You can modify an ongoing experiment by adding or deleting trials using the Tune Client API. To do this, start your experiment with the Tune server enabled, either from the command line, e.g.:
|
||||
|
||||
::
|
||||
|
||||
cd ray/python/ray/tune
|
||||
./tune.py -f examples/tune_mnist_ray.yaml --server=True --server-port=4321
|
||||
|
||||
Or within the Python API, e.g.:
|
||||
::
|
||||
|
||||
run_experiments({...}, with_server=True, server_port=4321)
|
||||
|
||||
Then, on the client side, you can use the following class. The server address defaults to ``localhost:4321``. If on a cluster, you may want to forward this port so that you can use the Client on your local machine.
|
||||
|
||||
.. autoclass:: ray.tune.web_server.TuneClient
|
||||
:members:
|
||||
|
||||
|
||||
For an example notebook that uses the Client API, see the `Client API Example <https://github.com/ray-project/ray/tree/master/python/ray/tune/TuneClient.ipynb>`__.
|
||||
|
|
88
python/ray/tune/TuneClient.ipynb
Normal file
88
python/ray/tune/TuneClient.ipynb
Normal file
|
@ -0,0 +1,88 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from ray.tune.web_server import TuneClient\n",
|
||||
"\n",
|
||||
"manager = TuneClient(tune_address=\"localhost:4321\")\n",
|
||||
"\n",
|
||||
"x = manager.get_all_trials()\n",
|
||||
"\n",
|
||||
"[((y[\"id\"]), y[\"status\"]) for y in x[\"trials\"]]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": true,
|
||||
"scrolled": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for y in x[\"trials\"][-10:]:\n",
|
||||
" manager.stop_trial(y[\"id\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from ray.tune.variant_generator import generate_trials\n",
|
||||
"import yaml\n",
|
||||
"\n",
|
||||
"with open(\"../rllib/tuned_examples/hyperband-cartpole.yaml\") as f:\n",
|
||||
" d = yaml.load(f)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"name, spec = [x for x in d.items()][0]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"manager.add_trial(name, spec)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.6.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
|
@ -6,3 +6,8 @@ from __future__ import print_function
|
|||
class TuneError(Exception):
    """Base class for the errors raised by ray.tune."""
|
||||
|
||||
|
||||
class TuneManagerError(TuneError):
    """Raised when an operation on the Tune Manager fails."""
|
||||
|
|
|
@ -87,7 +87,9 @@ class HyperBandScheduler(FIFOScheduler):
|
|||
self._time_attr = time_attr
|
||||
|
||||
def on_trial_add(self, trial_runner, trial):
|
||||
"""On a new trial add, if current bracket is not filled,
|
||||
"""Adds new trial.
|
||||
|
||||
On a new trial add, if current bracket is not filled,
|
||||
add to current bracket. Else, if current band is not filled,
|
||||
create new bracket, add to current bracket.
|
||||
Else, create new iteration, create new bracket, add to bracket."""
|
||||
|
@ -136,9 +138,8 @@ class HyperBandScheduler(FIFOScheduler):
|
|||
|
||||
This scheduler will not start trials but will stop trials.
|
||||
The current running trial will not be handled,
|
||||
as the trialrunner will be given control to handle it.
|
||||
as the trialrunner will be given control to handle it."""
|
||||
|
||||
# TODO(rliaw) should be only called if trial has not errored"""
|
||||
bracket, _ = self._trial_info[trial]
|
||||
bracket.update_trial_stats(trial, result)
|
||||
|
||||
|
@ -160,16 +161,17 @@ class HyperBandScheduler(FIFOScheduler):
|
|||
action = TrialScheduler.PAUSE
|
||||
if bracket.cur_iter_done():
|
||||
if bracket.finished():
|
||||
self._cleanup_bracket(trial_runner, bracket)
|
||||
bracket.cleanup_full(trial_runner)
|
||||
return TrialScheduler.CONTINUE
|
||||
|
||||
good, bad = bracket.successive_halving(self._reward_attr)
|
||||
# kill bad trials
|
||||
self._num_stopped += len(bad)
|
||||
for t in bad:
|
||||
if t.status == Trial.PAUSED:
|
||||
self._cleanup_trial(trial_runner, t, bracket, hard=True)
|
||||
trial_runner.stop_trial(t)
|
||||
elif t.status == Trial.RUNNING:
|
||||
self._cleanup_trial(trial_runner, t, bracket, hard=False)
|
||||
bracket.cleanup_trial(t)
|
||||
action = TrialScheduler.STOP
|
||||
else:
|
||||
raise Exception("Trial with unexpected status encountered")
|
||||
|
@ -185,47 +187,30 @@ class HyperBandScheduler(FIFOScheduler):
|
|||
action = TrialScheduler.CONTINUE
|
||||
return action
|
||||
|
||||
def _cleanup_trial(self, trial_runner, t, bracket, hard=False):
|
||||
"""Bookkeeping for trials finished. If `hard=True`, then
|
||||
this scheduler will force the trial_runner to release resources.
|
||||
def on_trial_remove(self, trial_runner, trial):
|
||||
"""Notification when trial terminates.
|
||||
|
||||
Otherwise, only clean up trial information locally."""
|
||||
self._num_stopped += 1
|
||||
if hard:
|
||||
trial_runner._stop_trial(t)
|
||||
bracket.cleanup_trial(t)
|
||||
|
||||
def _cleanup_bracket(self, trial_runner, bracket):
|
||||
"""Cleans up bracket after bracket is completely finished.
|
||||
Lets the last trial continue to run until termination condition
|
||||
kicks in."""
|
||||
for trial in bracket.current_trials():
|
||||
if (trial.status == Trial.PAUSED):
|
||||
self._cleanup_trial(
|
||||
trial_runner, trial, bracket,
|
||||
hard=True)
|
||||
Trial info is removed from bracket. Triggers halving if bracket is
|
||||
not finished."""
|
||||
bracket, _ = self._trial_info[trial]
|
||||
bracket.cleanup_trial(trial)
|
||||
if not bracket.finished():
|
||||
self._process_bracket(trial_runner, bracket, trial)
|
||||
|
||||
def on_trial_complete(self, trial_runner, trial, result):
|
||||
"""Cleans up trial info from bracket if trial completed early."""
|
||||
|
||||
bracket, _ = self._trial_info[trial]
|
||||
self._cleanup_trial(trial_runner, trial, bracket, hard=False)
|
||||
self._process_bracket(trial_runner, bracket, trial)
|
||||
self.on_trial_remove(trial_runner, trial)
|
||||
|
||||
def on_trial_error(self, trial_runner, trial):
|
||||
"""Cleans up trial info from bracket if trial errored early."""
|
||||
|
||||
bracket, _ = self._trial_info[trial]
|
||||
self._cleanup_trial(trial_runner, trial, bracket, hard=False)
|
||||
self._process_bracket(trial_runner, bracket, trial)
|
||||
self.on_trial_remove(trial_runner, trial)
|
||||
|
||||
def choose_trial_to_run(self, trial_runner, *args):
|
||||
"""Fair scheduling within iteration by completion percentage.
|
||||
List of trials not used since all trials are tracked as state
|
||||
of scheduler.
|
||||
|
||||
If iteration is occupied (ie, no trials to run), then look into
|
||||
next iteration."""
|
||||
List of trials not used since all trials are tracked as state
|
||||
of scheduler. If iteration is occupied (ie, no trials to run),
|
||||
then look into next iteration."""
|
||||
|
||||
for hyperband in self._hyperbands:
|
||||
for bracket in sorted(hyperband,
|
||||
|
@ -237,6 +222,7 @@ class HyperBandScheduler(FIFOScheduler):
|
|||
return None
|
||||
|
||||
def debug_string(self):
|
||||
# TODO(rliaw): This debug string needs work
|
||||
brackets = [
|
||||
"({0}/{1})".format(
|
||||
len(bracket._live_trials), len(bracket._all_trials))
|
||||
|
@ -301,8 +287,11 @@ class Bracket():
|
|||
return False
|
||||
|
||||
def filled(self):
|
||||
"""We will only let new trials be added at current level,
|
||||
minimizing the need to backtrack and bookkeep previous medians"""
|
||||
"""Checks if bracket is filled.
|
||||
|
||||
Only let new trials be added at current level minimizing the need
|
||||
to backtrack and bookkeep previous medians."""
|
||||
|
||||
return len(self._live_trials) == self._n
|
||||
|
||||
def successive_halving(self, reward_attr):
|
||||
|
@ -346,6 +335,15 @@ class Bracket():
|
|||
assert trial in self._live_trials
|
||||
del self._live_trials[trial]
|
||||
|
||||
def cleanup_full(self, trial_runner):
    """Stops the paused trials of a bracket that is completely finished.

    Trials in any other state are skipped, which lets the last trial
    keep running until its own termination condition kicks in.
    """
    for candidate in self.current_trials():
        # Status is inspected at iteration time, since stopping one trial
        # may cause the scheduler to change the state of others.
        if candidate.status != Trial.PAUSED:
            continue
        trial_runner.stop_trial(candidate)
|
||||
|
||||
def completion_percentage(self):
|
||||
"""Returns a progress metric.
|
||||
|
||||
|
@ -374,5 +372,8 @@ class Bracket():
|
|||
"r={}".format(self._r),
|
||||
"progress={}".format(self.completion_percentage())
|
||||
])
|
||||
return "Bracket({})".format(status)
|
||||
|
||||
def debug_string(self):
|
||||
trials = ", ".join([t.status for t in self._live_trials])
|
||||
return "Bracket({})[{}]".format(status, trials)
|
||||
return "{}[{}]".format(self, trials)
|
||||
|
|
|
@ -5,6 +5,7 @@ from __future__ import print_function
|
|||
import collections
|
||||
import numpy as np
|
||||
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler
|
||||
|
||||
|
||||
|
@ -74,6 +75,11 @@ class MedianStoppingRule(FIFOScheduler):
|
|||
self._results[trial].append(result)
|
||||
self._completed_trials.add(trial)
|
||||
|
||||
def on_trial_remove(self, trial_runner, trial):
    """Marks trial as completed if it is paused and has previously run.

    Args:
        trial_runner: Runner that is removing the trial (not used here).
        trial (Trial): Trial being removed.
    """
    # Compare statuses by value: identity (`is`) on strings only holds
    # when both sides happen to be the very same object.
    if trial.status == Trial.PAUSED and trial in self._results:
        self._completed_trials.add(trial)
|
||||
|
||||
def debug_string(self):
|
||||
return "Using MedianStoppingRule: num_stopped={}.".format(
|
||||
len(self._stopped_trials))
|
||||
|
|
|
@ -29,7 +29,7 @@ TrainingResult = namedtuple("TrainingResult", [
|
|||
# (Required) Accumulated timesteps for this entire experiment.
|
||||
"timesteps_total",
|
||||
|
||||
# (Optional) If training is finished.
|
||||
# (Optional) If training is terminated.
|
||||
"done",
|
||||
|
||||
# (Optional) Custom metadata to report for this iteration.
|
||||
|
|
|
@ -9,6 +9,7 @@ import ray
|
|||
import os
|
||||
|
||||
from collections import namedtuple
|
||||
from ray.utils import random_string, binary_to_hex
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.logger import NoopLogger, UnifiedLogger
|
||||
from ray.tune.result import TrainingResult, DEFAULT_RESULTS_DIR, pretty_print
|
||||
|
@ -105,6 +106,7 @@ class Trial(object):
|
|||
self.location = None
|
||||
self.logdir = None
|
||||
self.result_logger = None
|
||||
self.trial_id = binary_to_hex(random_string())[:8]
|
||||
|
||||
def start(self):
|
||||
"""Starts this trial.
|
||||
|
|
|
@ -8,6 +8,7 @@ import time
|
|||
import traceback
|
||||
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.web_server import TuneServer
|
||||
from ray.tune.trial import Trial, Resources
|
||||
from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler
|
||||
|
||||
|
@ -34,8 +35,14 @@ class TrialRunner(object):
|
|||
misleading benchmark results.
|
||||
"""
|
||||
|
||||
def __init__(self, scheduler=None):
|
||||
"""Initializes a new TrialRunner."""
|
||||
def __init__(self, scheduler=None, launch_web_server=False,
|
||||
server_port=TuneServer.DEFAULT_PORT):
|
||||
"""Initializes a new TrialRunner.
|
||||
|
||||
Args:
|
||||
scheduler (TrialScheduler): Defaults to FIFOScheduler.
|
||||
launch_web_server (bool): Flag for starting TuneServer
|
||||
server_port (int): Port number for launching TuneServer"""
|
||||
|
||||
self._scheduler_alg = scheduler or FIFOScheduler()
|
||||
self._trials = []
|
||||
|
@ -49,6 +56,10 @@ class TrialRunner(object):
|
|||
self._global_time_limit = float(
|
||||
os.environ.get("TRIALRUNNER_WALLTIME_LIMIT", float('inf')))
|
||||
self._total_time = 0
|
||||
self._server = None
|
||||
if launch_web_server:
|
||||
self._server = TuneServer(self, server_port)
|
||||
self._stop_queue = []
|
||||
|
||||
def is_finished(self):
|
||||
"""Returns whether all trials have finished running."""
|
||||
|
@ -70,7 +81,6 @@ class TrialRunner(object):
|
|||
Callers should typically run this method repeatedly in a loop. They
|
||||
may inspect or modify the runner's state in between calls to step().
|
||||
"""
|
||||
|
||||
if self._can_launch_more():
|
||||
self._launch_trial()
|
||||
elif self._running:
|
||||
|
@ -91,6 +101,16 @@ class TrialRunner(object):
|
|||
"trials with sufficient resources.")
|
||||
raise TuneError("Called step when all trials finished?")
|
||||
|
||||
if self._server:
|
||||
self._process_requests()
|
||||
|
||||
if self.is_finished():
|
||||
self._server.shutdown()
|
||||
|
||||
def get_trial(self, tid):
    """Returns the first managed trial whose `trial_id` equals `tid`.

    Returns None if no managed trial has that id.
    """
    for candidate in self._trials:
        if candidate.trial_id == tid:
            return candidate
    return None
|
||||
|
||||
def get_trials(self):
|
||||
"""Returns the list of trials managed by this TrialRunner.
|
||||
|
||||
|
@ -207,6 +227,43 @@ class TrialRunner(object):
|
|||
assert self._committed_resources.cpu >= 0
|
||||
assert self._committed_resources.gpu >= 0
|
||||
|
||||
def request_stop_trial(self, trial):
    """Queues `trial` so that it is stopped when requests are processed."""
    pending = self._stop_queue
    pending.append(trial)
|
||||
|
||||
def _process_requests(self):
|
||||
while self._stop_queue:
|
||||
t = self._stop_queue.pop()
|
||||
self.stop_trial(t)
|
||||
|
||||
def stop_trial(self, trial):
    """Stops trial.

    Trials may be stopped at any time. If trial is in state PENDING
    or PAUSED, calls `scheduler.on_trial_remove`. If it is RUNNING, waits
    for the trial's outstanding result and calls
    `scheduler.on_trial_complete`, or `scheduler.on_trial_error` if
    fetching or processing that result raises. Trials already in state
    ERROR or TERMINATED are left untouched.

    Args:
        trial (Trial): Trial to be stopped.
    """
    error = False

    if trial.status in [Trial.ERROR, Trial.TERMINATED]:
        return
    elif trial.status in [Trial.PENDING, Trial.PAUSED]:
        self._scheduler_alg.on_trial_remove(self, trial)
    # Compare by value, like the membership tests above; an identity
    # check on strings would silently skip this branch for an equal but
    # distinct status object.
    elif trial.status == Trial.RUNNING:
        # NOTE: There should only be one...
        result_id = [rid for rid, t in self._running.items()
                     if t is trial][0]
        self._running.pop(result_id)
        try:
            result = ray.get(result_id)
            trial.update_last_result(result, terminate=True)
            self._scheduler_alg.on_trial_complete(self, trial, result)
        except Exception:
            print("Error processing event:", traceback.format_exc())
            self._scheduler_alg.on_trial_error(self, trial)
            error = True

    self._stop_trial(trial, error=error)
|
||||
|
||||
def _stop_trial(self, trial, error=False):
|
||||
"""Only returns resources if resources allocated."""
|
||||
prior_status = trial.status
|
||||
|
|
|
@ -26,14 +26,24 @@ class TrialScheduler(object):
|
|||
"""Called on each intermediate result returned by a trial.
|
||||
|
||||
At this point, the trial scheduler can make a decision by returning
|
||||
one of CONTINUE, PAUSE, and STOP."""
|
||||
one of CONTINUE, PAUSE, and STOP. This will only be called when the
|
||||
trial is in the RUNNING state."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def on_trial_complete(self, trial_runner, trial, result):
|
||||
"""Notification for the completion of trial.
|
||||
|
||||
This will only be called when the trial completes naturally."""
|
||||
This will only be called when the trial is in the RUNNING state and
|
||||
either completes naturally or by manual termination."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def on_trial_remove(self, trial_runner, trial):
    """Notification that a trial is being removed.

    Invoked for trials in the PAUSED or PENDING state. For other
    states, `on_trial_complete` is the method to call.
    """
    raise NotImplementedError()
|
||||
|
||||
|
@ -66,6 +76,9 @@ class FIFOScheduler(TrialScheduler):
|
|||
def on_trial_complete(self, trial_runner, trial, result):
|
||||
pass
|
||||
|
||||
def on_trial_remove(self, trial_runner, trial):
    """No-op: nothing is done when a trial is removed."""
    return None
|
||||
|
||||
def choose_trial_to_run(self, trial_runner):
|
||||
for trial in trial_runner.get_trials():
|
||||
if (trial.status == Trial.PENDING and
|
||||
|
|
|
@ -16,6 +16,7 @@ from ray.tune.median_stopping_rule import MedianStoppingRule
|
|||
from ray.tune.trial import Trial
|
||||
from ray.tune.trial_runner import TrialRunner
|
||||
from ray.tune.trial_scheduler import FIFOScheduler
|
||||
from ray.tune.web_server import TuneServer
|
||||
from ray.tune.variant_generator import generate_trials
|
||||
|
||||
|
||||
|
@ -41,6 +42,10 @@ parser.add_argument("--scheduler", default="FIFO", type=str,
|
|||
help="FIFO, MedianStopping, or HyperBand")
|
||||
parser.add_argument("--scheduler-config", default="{}", type=json.loads,
|
||||
help="Config options to pass to the scheduler.")
|
||||
# argparse's `type=bool` turns every non-empty string (including "False")
# into True, so the flag's text is interpreted explicitly instead.
parser.add_argument("--server", default=False,
                    type=lambda s: s.strip().lower() in (
                        "true", "t", "yes", "y", "1"),
                    help="Option to launch Tune Server")
parser.add_argument("--server-port", default=TuneServer.DEFAULT_PORT,
                    type=int, help="Port number for launching Tune Server")
|
||||
parser.add_argument("-f", "--config-file", required=True, type=str,
|
||||
help="Read experiment options from this JSON/YAML file.")
|
||||
|
||||
|
@ -61,10 +66,13 @@ def _make_scheduler(args):
|
|||
args.scheduler, _SCHEDULERS.keys()))
|
||||
|
||||
|
||||
def run_experiments(experiments, scheduler=None, **ray_args):
|
||||
def run_experiments(experiments, scheduler=None, with_server=False,
|
||||
server_port=TuneServer.DEFAULT_PORT, **ray_args):
|
||||
if scheduler is None:
|
||||
scheduler = FIFOScheduler()
|
||||
runner = TrialRunner(scheduler)
|
||||
|
||||
runner = TrialRunner(
|
||||
scheduler, launch_web_server=with_server, server_port=server_port)
|
||||
|
||||
for name, spec in experiments.items():
|
||||
for trial in generate_trials(spec, name):
|
||||
|
@ -78,6 +86,7 @@ def run_experiments(experiments, scheduler=None, **ray_args):
|
|||
print(runner.debug_string())
|
||||
|
||||
for trial in runner.get_trials():
|
||||
# TODO(rliaw): What about errored?
|
||||
if trial.status != Trial.TERMINATED:
|
||||
raise TuneError("Trial did not complete", trial)
|
||||
|
||||
|
@ -90,5 +99,6 @@ if __name__ == "__main__":
|
|||
with open(args.config_file) as f:
|
||||
experiments = yaml.load(f)
|
||||
run_experiments(
|
||||
experiments, _make_scheduler(args), redis_address=args.redis_address,
|
||||
experiments, _make_scheduler(args), with_server=args.server,
|
||||
server_port=args.server_port, redis_address=args.redis_address,
|
||||
num_cpus=args.num_cpus, num_gpus=args.num_gpus)
|
||||
|
|
146
python/ray/tune/web_server.py
Normal file
146
python/ray/tune/web_server.py
Normal file
|
@ -0,0 +1,146 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import requests
|
||||
import json
|
||||
import threading
|
||||
from http.server import HTTPServer, BaseHTTPRequestHandler
|
||||
|
||||
from ray.tune.error import TuneError, TuneManagerError
|
||||
from ray.tune.variant_generator import generate_trials
|
||||
|
||||
|
||||
class TuneClient(object):
    """HTTP client for interacting with an ongoing Tune experiment.

    The Tune server must already be running before requests are made.

    Args:
        tune_address (str): Address of the Tune server, e.g.
            "localhost:4321".
    """

    # Command identifiers sent in the "command" field of each request.
    STOP = "STOP"
    ADD = "ADD"
    GET_LIST = "GET_LIST"
    GET_TRIAL = "GET_TRIAL"

    def __init__(self, tune_address):
        # TODO(rliaw): Better to specify address and port forward
        self._tune_address = tune_address
        self._path = "http://{}".format(tune_address)

    def get_all_trials(self):
        """Returns the server's listing of all trials."""
        request = {"command": TuneClient.GET_LIST}
        return self._get_response(request)

    def get_trial(self, trial_id):
        """Returns the server's information on the queried trial."""
        request = {
            "command": TuneClient.GET_TRIAL,
            "trial_id": trial_id,
        }
        return self._get_response(request)

    def add_trial(self, name, trial_spec):
        """Adds trials named `name`, generated from `trial_spec`."""
        # TODO(rliaw): have better way of specifying a new trial
        request = {
            "command": TuneClient.ADD,
            "name": name,
            "spec": trial_spec,
        }
        return self._get_response(request)

    def stop_trial(self, trial_id):
        """Requests that the trial with id `trial_id` be stopped."""
        request = {
            "command": TuneClient.STOP,
            "trial_id": trial_id,
        }
        return self._get_response(request)

    def _get_response(self, data):
        """Sends `data` as a JSON body and returns the parsed JSON reply."""
        body = json.dumps(data).encode()
        reply = requests.get(self._path, data=body)
        return reply.json()
|
||||
|
||||
|
||||
def RunnerHandler(runner):
    """Builds an HTTP request handler class bound to `runner`.

    Args:
        runner: Object exposing `get_trial`, `get_trials`, `add_trial`
            and `request_stop_trial`; commands are executed against it.

    Returns:
        A `BaseHTTPRequestHandler` subclass suitable for `HTTPServer`.
    """
    class Handler(BaseHTTPRequestHandler):

        def do_GET(self):
            """Reads a JSON command from the request body and replies.

            Responds with status 200 when the command succeeds and 400
            otherwise; the response body is always a JSON object.
            """
            # The header may be absent, in which case the body is empty.
            content_len = int(self.headers.get('Content-Length', 0))
            raw_body = self.rfile.read(content_len)
            try:
                parsed_input = json.loads(raw_body.decode())
            except ValueError:
                # Covers both undecodable bytes and malformed JSON.
                status = False
                response = {"message": "Request body is not valid JSON."}
            else:
                status, response = self.execute_command(parsed_input)
            if status:
                self.send_response(200)
            else:
                self.send_response(400)
            self.end_headers()
            self.wfile.write(json.dumps(
                response).encode())

        def trial_info(self, trial):
            """Returns a dict describing `trial` and its last result."""
            if trial.last_result:
                result = trial.last_result._asdict()
            else:
                result = None
            info_dict = {
                "id": trial.trial_id,
                "trainable_name": trial.trainable_name,
                "config": trial.config,
                "status": trial.status,
                "result": result
            }
            return info_dict

        def execute_command(self, args):
            """Runs the command described by `args` against the runner.

            Args:
                args (dict): Parsed request; must contain "command" plus
                    the fields that command needs ("trial_id", or
                    "name" and "spec").

            Returns:
                Tuple `(status, response)`: `status` is False when a
                `TuneError` was raised, in which case `response` holds
                its text under "message".
            """
            def get_trial():
                trial = runner.get_trial(args["trial_id"])
                if trial is None:
                    error = "Trial ({}) not found.".format(args["trial_id"])
                    raise TuneManagerError(error)
                else:
                    return trial

            command = args["command"]
            response = {}
            try:
                if command == TuneClient.GET_LIST:
                    response["trials"] = [self.trial_info(t)
                                          for t in runner.get_trials()]
                elif command == TuneClient.GET_TRIAL:
                    trial = get_trial()
                    response["trial_info"] = self.trial_info(trial)
                elif command == TuneClient.STOP:
                    trial = get_trial()
                    # Only queued here; the runner performs the stop.
                    runner.request_stop_trial(trial)
                elif command == TuneClient.ADD:
                    name = args["name"]
                    spec = args["spec"]
                    for trial in generate_trials(spec, name):
                        runner.add_trial(trial)
                else:
                    raise TuneManagerError("Unknown command.")
                status = True
            except TuneError as e:
                status = False
                response["message"] = str(e)

            return status, response

    return Handler
|
||||
|
||||
|
||||
class TuneServer(threading.Thread):
    """Thread that serves Tune Client requests over HTTP.

    The server binds to localhost and the thread is started as soon as
    the object is constructed.

    Args:
        runner: Runner object that the request handler operates on.
        port (int): Port to listen on; falls back to `DEFAULT_PORT`
            when not given.
    """

    DEFAULT_PORT = 4321

    def __init__(self, runner, port=None):

        threading.Thread.__init__(self)
        self._port = port if port else self.DEFAULT_PORT
        address = ('localhost', self._port)
        print("Starting Tune Server...")
        self._server = HTTPServer(
            address, RunnerHandler(runner))
        self.start()

    def run(self):
        """Handles requests until `shutdown` is called."""
        self._server.serve_forever()

    def shutdown(self):
        """Stops the serving loop and releases the listening socket."""
        self._server.shutdown()
        # shutdown() only ends serve_forever(); the socket stays open
        # (and the port stays bound) until server_close() is called.
        self._server.server_close()
|
|
@ -497,6 +497,46 @@ class TrialRunnerTest(unittest.TestCase):
|
|||
runner.step()
|
||||
self.assertEqual(trials[0].status, Trial.TERMINATED)
|
||||
|
||||
def testStopTrial(self):
|
||||
ray.init(num_cpus=4, num_gpus=2)
|
||||
runner = TrialRunner()
|
||||
kwargs = {
|
||||
"stopping_criterion": {"training_iteration": 5},
|
||||
"resources": Resources(cpu=1, gpu=1),
|
||||
}
|
||||
trials = [
|
||||
Trial("__fake", **kwargs),
|
||||
Trial("__fake", **kwargs),
|
||||
Trial("__fake", **kwargs),
|
||||
Trial("__fake", **kwargs)]
|
||||
for t in trials:
|
||||
runner.add_trial(t)
|
||||
runner.step()
|
||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
||||
self.assertEqual(trials[1].status, Trial.PENDING)
|
||||
|
||||
# Stop trial while running
|
||||
runner.stop_trial(trials[0])
|
||||
self.assertEqual(trials[0].status, Trial.TERMINATED)
|
||||
self.assertEqual(trials[1].status, Trial.PENDING)
|
||||
|
||||
runner.step()
|
||||
self.assertEqual(trials[0].status, Trial.TERMINATED)
|
||||
self.assertEqual(trials[1].status, Trial.RUNNING)
|
||||
self.assertEqual(trials[-1].status, Trial.PENDING)
|
||||
|
||||
# Stop trial while pending
|
||||
runner.stop_trial(trials[-1])
|
||||
self.assertEqual(trials[0].status, Trial.TERMINATED)
|
||||
self.assertEqual(trials[1].status, Trial.RUNNING)
|
||||
self.assertEqual(trials[-1].status, Trial.TERMINATED)
|
||||
|
||||
runner.step()
|
||||
self.assertEqual(trials[0].status, Trial.TERMINATED)
|
||||
self.assertEqual(trials[1].status, Trial.RUNNING)
|
||||
self.assertEqual(trials[2].status, Trial.RUNNING)
|
||||
self.assertEqual(trials[-1].status, Trial.TERMINATED)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
|
|
|
@ -140,8 +140,25 @@ class EarlyStoppingSuite(unittest.TestCase):
|
|||
|
||||
|
||||
class _MockTrialRunner():
|
||||
def _stop_trial(self, trial):
|
||||
trial.stop()
|
||||
def __init__(self, scheduler):
|
||||
self._scheduler_alg = scheduler
|
||||
|
||||
def process_action(self, trial, action):
|
||||
if action == TrialScheduler.CONTINUE:
|
||||
pass
|
||||
elif action == TrialScheduler.PAUSE:
|
||||
self._pause_trial(trial)
|
||||
elif action == TrialScheduler.STOP:
|
||||
trial.stop()
|
||||
|
||||
def stop_trial(self, trial):
|
||||
if trial.status in [Trial.ERROR, Trial.TERMINATED]:
|
||||
return
|
||||
elif trial.status in [Trial.PENDING, Trial.PAUSED]:
|
||||
self._scheduler_alg.on_trial_remove(self, trial)
|
||||
else:
|
||||
|
||||
self._scheduler_alg.on_trial_complete(self, trial, result(100, 10))
|
||||
|
||||
def has_resources(self, resources):
|
||||
return True
|
||||
|
@ -168,7 +185,7 @@ class HyperbandSuite(unittest.TestCase):
|
|||
for i in range(num_trials):
|
||||
t = Trial("__fake")
|
||||
sched.on_trial_add(None, t)
|
||||
runner = _MockTrialRunner()
|
||||
runner = _MockTrialRunner(sched)
|
||||
return sched, runner
|
||||
|
||||
def default_statistics(self):
|
||||
|
@ -186,14 +203,6 @@ class HyperbandSuite(unittest.TestCase):
|
|||
def downscale(self, n, sched):
|
||||
return int(np.ceil(n / sched._eta))
|
||||
|
||||
def process(self, trl, mock_runner, action):
|
||||
if action == TrialScheduler.CONTINUE:
|
||||
pass
|
||||
elif action == TrialScheduler.PAUSE:
|
||||
mock_runner._pause_trial(trl)
|
||||
elif action == TrialScheduler.STOP:
|
||||
self.stopTrial(trl, mock_runner)
|
||||
|
||||
def basicSetup(self):
|
||||
"""Setup and verify full band.
|
||||
"""
|
||||
|
@ -224,10 +233,6 @@ class HyperbandSuite(unittest.TestCase):
|
|||
|
||||
return sched
|
||||
|
||||
def stopTrial(self, trial, mock_runner):
|
||||
self.assertNotEqual(trial.status, Trial.TERMINATED)
|
||||
mock_runner._stop_trial(trial)
|
||||
|
||||
def testConfigSameEta(self):
|
||||
sched = HyperBandScheduler()
|
||||
i = 0
|
||||
|
@ -283,7 +288,7 @@ class HyperbandSuite(unittest.TestCase):
|
|||
mock_runner, trl, result(cur_units, i))
|
||||
if i < current_length - 1:
|
||||
self.assertEqual(action, TrialScheduler.PAUSE)
|
||||
self.process(trl, mock_runner, action)
|
||||
mock_runner.process_action(trl, action)
|
||||
|
||||
self.assertEqual(action, TrialScheduler.CONTINUE)
|
||||
new_length = len(big_bracket.current_trials())
|
||||
|
@ -304,7 +309,7 @@ class HyperbandSuite(unittest.TestCase):
|
|||
for i, trl in reversed(list(enumerate(big_bracket.current_trials()))):
|
||||
action = sched.on_trial_result(
|
||||
mock_runner, trl, result(cur_units, i))
|
||||
self.process(trl, mock_runner, action)
|
||||
mock_runner.process_action(trl, action)
|
||||
|
||||
self.assertEqual(action, TrialScheduler.STOP)
|
||||
|
||||
|
@ -321,7 +326,7 @@ class HyperbandSuite(unittest.TestCase):
|
|||
for i, trl in enumerate(big_bracket.current_trials()):
|
||||
action = sched.on_trial_result(
|
||||
mock_runner, trl, result(cur_units, i))
|
||||
self.process(trl, mock_runner, action)
|
||||
mock_runner.process_action(trl, action)
|
||||
|
||||
self.assertEqual(action, TrialScheduler.CONTINUE)
|
||||
|
||||
|
@ -412,9 +417,9 @@ class HyperbandSuite(unittest.TestCase):
|
|||
mock_runner._launch_trial(t)
|
||||
|
||||
for i, t in enumerate(bracket_trials):
|
||||
status = sched.on_trial_result(
|
||||
action = sched.on_trial_result(
|
||||
mock_runner, t, result(init_units, i))
|
||||
self.assertEqual(status, TrialScheduler.CONTINUE)
|
||||
self.assertEqual(action, TrialScheduler.CONTINUE)
|
||||
t = Trial("__fake")
|
||||
sched.on_trial_add(None, t)
|
||||
mock_runner._launch_trial(t)
|
||||
|
@ -442,7 +447,7 @@ class HyperbandSuite(unittest.TestCase):
|
|||
for i in range(stats["max_trials"]):
|
||||
t = Trial("__fake")
|
||||
sched.on_trial_add(None, t)
|
||||
runner = _MockTrialRunner()
|
||||
runner = _MockTrialRunner(sched)
|
||||
|
||||
big_bracket = sched._hyperbands[0][-1]
|
||||
|
||||
|
@ -452,17 +457,11 @@ class HyperbandSuite(unittest.TestCase):
|
|||
|
||||
# Provides results from 0 to 8 in order, keeping the last one running
|
||||
for i, trl in enumerate(big_bracket.current_trials()):
|
||||
status = sched.on_trial_result(runner, trl, result2(1, i))
|
||||
if status == TrialScheduler.CONTINUE:
|
||||
continue
|
||||
elif status == TrialScheduler.PAUSE:
|
||||
runner._pause_trial(trl)
|
||||
elif status == TrialScheduler.STOP:
|
||||
self.assertNotEqual(trl.status, Trial.TERMINATED)
|
||||
self.stopTrial(trl, runner)
|
||||
action = sched.on_trial_result(runner, trl, result2(1, i))
|
||||
runner.process_action(trl, action)
|
||||
|
||||
new_length = len(big_bracket.current_trials())
|
||||
self.assertEqual(status, TrialScheduler.CONTINUE)
|
||||
self.assertEqual(action, TrialScheduler.CONTINUE)
|
||||
self.assertEqual(new_length, self.downscale(current_length, sched))
|
||||
|
||||
def testJumpingTime(self):
|
||||
|
@ -476,21 +475,38 @@ class HyperbandSuite(unittest.TestCase):
|
|||
main_trials = big_bracket.current_trials()[:-1]
|
||||
jump = big_bracket.current_trials()[-1]
|
||||
for i, trl in enumerate(main_trials):
|
||||
status = sched.on_trial_result(mock_runner, trl, result(1, i))
|
||||
if status == TrialScheduler.CONTINUE:
|
||||
continue
|
||||
elif status == TrialScheduler.PAUSE:
|
||||
mock_runner._pause_trial(trl)
|
||||
elif status == TrialScheduler.STOP:
|
||||
self.assertNotEqual(trl.status, Trial.TERMINATED)
|
||||
self.stopTrial(trl, mock_runner)
|
||||
action = sched.on_trial_result(mock_runner, trl, result(1, i))
|
||||
mock_runner.process_action(trl, action)
|
||||
|
||||
status = sched.on_trial_result(mock_runner, jump, result(4, i))
|
||||
self.assertEqual(status, TrialScheduler.PAUSE)
|
||||
action = sched.on_trial_result(mock_runner, jump, result(4, i))
|
||||
self.assertEqual(action, TrialScheduler.PAUSE)
|
||||
|
||||
current_length = len(big_bracket.current_trials())
|
||||
self.assertLess(current_length, 27)
|
||||
|
||||
def testRemove(self):
    """Test with 4: start 1, remove 1 pending, add 2, remove 1 pending"""
    sched, runner = self.schedulerSetup(4)
    ordered = sorted(sched._trial_info, key=lambda trl: trl.trial_id)
    started, pending = ordered[0], ordered[1]

    # Launch one trial and feed it a result; the second stays pending.
    runner._launch_trial(started)
    sched.on_trial_result(runner, started, result(1, 5))
    self.assertEqual(started.status, Trial.RUNNING)
    self.assertEqual(pending.status, Trial.PENDING)

    # Removing a pending trial must drop it from its bracket.
    bracket, _ = sched._trial_info[pending]
    self.assertIn(pending, bracket._live_trials)
    sched.on_trial_remove(runner, pending)
    self.assertNotIn(pending, bracket._live_trials)

    # Add two fresh trials; only the last one added is inspected below.
    for _ in range(2):
        trial = Trial("__fake")
        sched.on_trial_add(None, trial)

    bracket, _ = sched._trial_info[trial]
    self.assertIn(trial, bracket._live_trials)
    sched.on_trial_remove(runner, trial)  # where trial is not running
    self.assertNotIn(trial, bracket._live_trials)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
|
|
103
test/tune_server_test.py
Normal file
103
test/tune_server_test.py
Normal file
|
@ -0,0 +1,103 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import socket
|
||||
|
||||
import ray
|
||||
from ray.rllib import _register_all
|
||||
from ray.tune.trial import Trial, Resources
|
||||
from ray.tune.web_server import TuneClient
|
||||
from ray.tune.trial_runner import TrialRunner
|
||||
|
||||
|
||||
def get_valid_port(start_port=4321, host="127.0.0.1"):
    """Return the first TCP port at or above ``start_port`` that can be bound.

    Each candidate port is probed by binding a throwaway socket to ``host``.
    The probe socket is always closed, including when the bind fails, so
    scanning past busy ports does not leak file descriptors.

    Args:
        start_port (int): First port to try. Defaults to 4321.
        host (str): Local address to probe on. Defaults to "127.0.0.1".

    Returns:
        int: A port that was bindable at the time of the probe. The probe
            socket is closed before returning, so the port is not reserved
            for the caller.

    Raises:
        RuntimeError: If no port from ``start_port`` through 65535 could
            be bound.
    """
    for port in range(start_port, 65536):
        print("Trying port", port)
        port_test_socket = socket.socket()
        try:
            port_test_socket.bind((host, port))
        except socket.error:
            # Port is busy or otherwise unusable; move on to the next one.
            continue
        finally:
            port_test_socket.close()
        return port
    raise RuntimeError(
        "No free port found in range {}-65535".format(start_port))
|
||||
|
||||
|
||||
class TuneServerSuite(unittest.TestCase):
    """Tests that drive a server-enabled TrialRunner through TuneClient."""

    def basicSetup(self):
        """Start Ray and a web-server-enabled runner holding two trials.

        Returns:
            tuple: ``(runner, client)``, where ``client`` is a TuneClient
                addressed at ``localhost`` on the port the runner's server
                was started with.
        """
        ray.init(num_cpus=4, num_gpus=1)
        port = get_valid_port()
        self.runner = TrialRunner(
            launch_web_server=True, server_port=port)
        runner = self.runner
        # Ray is started with a single GPU and every trial requests one;
        # presumably this keeps at most one trial running at a time, which
        # testStopTrial's "exactly one RUNNING" assertion relies on.
        kwargs = {
            "stopping_criterion": {"training_iteration": 3},
            "resources": Resources(cpu=1, gpu=1),
        }
        trials = [
            Trial("__fake", **kwargs),
            Trial("__fake", **kwargs)]
        for t in trials:
            runner.add_trial(t)
        client = TuneClient("localhost:{}".format(port))
        return runner, client

    def tearDown(self):
        """Shut down the runner's server (best effort) and clean up Ray."""
        print("Tearing down....")
        try:
            self.runner._server.shutdown()
        except Exception as e:
            # Best effort: a failed or missing server must not prevent the
            # Ray cleanup below from running.
            print(e)
        finally:
            # Drop the runner reference whether or not shutdown succeeded.
            self.runner = None
        ray.worker.cleanup()
        _register_all()

    def testAddTrial(self):
        """Adding a trial through the client grows the trial list to 3."""
        runner, client = self.basicSetup()
        for _ in range(3):
            runner.step()
        spec = {
            "run": "__fake",
            "stop": {"training_iteration": 3},
            "resources": dict(cpu=1, gpu=1),
        }
        client.add_trial("test", spec)
        runner.step()
        all_trials = client.get_all_trials()["trials"]
        runner.step()
        self.assertEqual(len(all_trials), 3)

    def testGetTrials(self):
        """Listing and fetching trials does not change the trial count."""
        runner, client = self.basicSetup()
        for _ in range(3):
            runner.step()
        all_trials = client.get_all_trials()["trials"]
        self.assertEqual(len(all_trials), 2)
        tid = all_trials[0]["id"]
        client.get_trial(tid)
        runner.step()
        # Re-query the server: asserting on the list fetched before the
        # extra step would only re-check stale data and could never fail.
        all_trials = client.get_all_trials()["trials"]
        self.assertEqual(len(all_trials), 2)

    def testStopTrial(self):
        """Check if Stop Trial works"""
        runner, client = self.basicSetup()
        for _ in range(2):
            runner.step()
        all_trials = client.get_all_trials()["trials"]
        running = [t for t in all_trials if t["status"] == Trial.RUNNING]
        self.assertEqual(len(running), 1)

        tid = running[0]["id"]
        client.stop_trial(tid)
        runner.step()

        all_trials = client.get_all_trials()["trials"]
        self.assertEqual(
            len([t for t in all_trials if t["status"] == Trial.RUNNING]), 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
Loading…
Add table
Reference in a new issue