[tune] Experiment Management API (#1328)

* init for exposing external interface

* revisions

* http server

* small

* simplify

* ui

* fixes

* test

* nit

* nit

* merge

* untested

* nits

* nit

* init tests

* tests

* more tests

* nit

* fix hyperband

* cleanup

* nits

* good stuff

* cleanup

* comments and need to test

* nit

* notebook

* testing

* test and expose server

* server_tests

* docs

* periods

* fix tests

* committing test

* fi
This commit is contained in:
Richard Liaw 2018-01-24 13:45:10 -08:00 committed by Eric Liang
parent 1d2a28ab07
commit a7d544424c
15 changed files with 599 additions and 88 deletions

View file

@ -123,6 +123,7 @@ script:
- python test/monitor_test.py
- python test/trial_runner_test.py
- python test/trial_scheduler_test.py
- python test/tune_server_test.py
- python test/cython_test.py
- python -m pytest python/ray/dataframe/test/test_dataframe.py
- python -m pytest python/ray/dataframe/test/test_series.py

View file

@ -159,3 +159,26 @@ Running in a large cluster
--------------------------
The ``run_experiments`` also takes any arguments that ``ray.init()`` does. This can be used to pass in the redis address of a multi-node Ray cluster. For more details, check out the `tune.py script <https://github.com/ray-project/ray/blob/master/python/ray/tune/tune.py>`__.
Client API
----------
You can modify an ongoing experiment by adding or deleting trials using the Tune Client API. To do this, start your experiment with a flag, either from the command-line, e.g.:
::
cd ray/python/ray/tune
./tune.py -f examples/tune_mnist_ray.yaml --server=True --server-port=4321
Or within the Python API, e.g.:
::
run_experiments({...}, with_server=True, server_port=4321)
Then, on the client side, you can use the following class. The server address defaults to ``localhost:4321``. If on a cluster, you may want to forward this port so that you can use the Client on your local machine.
.. autoclass:: ray.tune.web_server.TuneClient
:members:
For an example notebook for using the Client API, see the `Client API Example <https://github.com/ray-project/ray/tree/master/python/ray/tune/TuneClient.ipynb>`__.

View file

@ -0,0 +1,88 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from ray.tune.web_server import TuneClient\n",
"\n",
"manager = TuneClient(tune_address=\"localhost:4321\")\n",
"\n",
"x = manager.get_all_trials()\n",
"\n",
"[((y[\"id\"]), y[\"status\"]) for y in x[\"trials\"]]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"scrolled": false
},
"outputs": [],
"source": [
"for y in x[\"trials\"][-10:]:\n",
" manager.stop_trial(y[\"id\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from ray.tune.variant_generator import generate_trials\n",
"import yaml\n",
"\n",
"with open(\"../rllib/tuned_examples/hyperband-cartpole.yaml\") as f:\n",
" d = yaml.load(f)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"name, spec = [x for x in d.items()][0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"manager.add_trial(name, spec)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View file

@ -6,3 +6,8 @@ from __future__ import print_function
class TuneError(Exception):
    """Base class for the errors raised by ray.tune."""
class TuneManagerError(TuneError):
    """Raised when an operation on the Tune Manager fails."""

View file

@ -87,7 +87,9 @@ class HyperBandScheduler(FIFOScheduler):
self._time_attr = time_attr
def on_trial_add(self, trial_runner, trial):
"""On a new trial add, if current bracket is not filled,
"""Adds new trial.
On a new trial add, if current bracket is not filled,
add to current bracket. Else, if current band is not filled,
create new bracket, add to current bracket.
Else, create new iteration, create new bracket, add to bracket."""
@ -136,9 +138,8 @@ class HyperBandScheduler(FIFOScheduler):
This scheduler will not start trials but will stop trials.
The current running trial will not be handled,
as the trialrunner will be given control to handle it.
as the trialrunner will be given control to handle it."""
# TODO(rliaw) should be only called if trial has not errored"""
bracket, _ = self._trial_info[trial]
bracket.update_trial_stats(trial, result)
@ -160,16 +161,17 @@ class HyperBandScheduler(FIFOScheduler):
action = TrialScheduler.PAUSE
if bracket.cur_iter_done():
if bracket.finished():
self._cleanup_bracket(trial_runner, bracket)
bracket.cleanup_full(trial_runner)
return TrialScheduler.CONTINUE
good, bad = bracket.successive_halving(self._reward_attr)
# kill bad trials
self._num_stopped += len(bad)
for t in bad:
if t.status == Trial.PAUSED:
self._cleanup_trial(trial_runner, t, bracket, hard=True)
trial_runner.stop_trial(t)
elif t.status == Trial.RUNNING:
self._cleanup_trial(trial_runner, t, bracket, hard=False)
bracket.cleanup_trial(t)
action = TrialScheduler.STOP
else:
raise Exception("Trial with unexpected status encountered")
@ -185,47 +187,30 @@ class HyperBandScheduler(FIFOScheduler):
action = TrialScheduler.CONTINUE
return action
def _cleanup_trial(self, trial_runner, t, bracket, hard=False):
"""Bookkeeping for trials finished. If `hard=True`, then
this scheduler will force the trial_runner to release resources.
def on_trial_remove(self, trial_runner, trial):
"""Notification when trial terminates.
Otherwise, only clean up trial information locally."""
self._num_stopped += 1
if hard:
trial_runner._stop_trial(t)
bracket.cleanup_trial(t)
def _cleanup_bracket(self, trial_runner, bracket):
"""Cleans up bracket after bracket is completely finished.
Lets the last trial continue to run until termination condition
kicks in."""
for trial in bracket.current_trials():
if (trial.status == Trial.PAUSED):
self._cleanup_trial(
trial_runner, trial, bracket,
hard=True)
Trial info is removed from bracket. Triggers halving if bracket is
not finished."""
bracket, _ = self._trial_info[trial]
bracket.cleanup_trial(trial)
if not bracket.finished():
self._process_bracket(trial_runner, bracket, trial)
def on_trial_complete(self, trial_runner, trial, result):
"""Cleans up trial info from bracket if trial completed early."""
bracket, _ = self._trial_info[trial]
self._cleanup_trial(trial_runner, trial, bracket, hard=False)
self._process_bracket(trial_runner, bracket, trial)
self.on_trial_remove(trial_runner, trial)
def on_trial_error(self, trial_runner, trial):
"""Cleans up trial info from bracket if trial errored early."""
bracket, _ = self._trial_info[trial]
self._cleanup_trial(trial_runner, trial, bracket, hard=False)
self._process_bracket(trial_runner, bracket, trial)
self.on_trial_remove(trial_runner, trial)
def choose_trial_to_run(self, trial_runner, *args):
"""Fair scheduling within iteration by completion percentage.
List of trials not used since all trials are tracked as state
of scheduler.
If iteration is occupied (ie, no trials to run), then look into
next iteration."""
List of trials not used since all trials are tracked as state
of scheduler. If iteration is occupied (ie, no trials to run),
then look into next iteration."""
for hyperband in self._hyperbands:
for bracket in sorted(hyperband,
@ -237,6 +222,7 @@ class HyperBandScheduler(FIFOScheduler):
return None
def debug_string(self):
# TODO(rliaw): This debug string needs work
brackets = [
"({0}/{1})".format(
len(bracket._live_trials), len(bracket._all_trials))
@ -301,8 +287,11 @@ class Bracket():
return False
def filled(self):
"""We will only let new trials be added at current level,
minimizing the need to backtrack and bookkeep previous medians"""
"""Checks if bracket is filled.
Only let new trials be added at current level minimizing the need
to backtrack and bookkeep previous medians."""
return len(self._live_trials) == self._n
def successive_halving(self, reward_attr):
@ -346,6 +335,15 @@ class Bracket():
assert trial in self._live_trials
del self._live_trials[trial]
def cleanup_full(self, trial_runner):
"""Cleans up bracket after bracket is completely finished.
Lets the last trial continue to run until termination condition
kicks in."""
for trial in self.current_trials():
if (trial.status == Trial.PAUSED):
trial_runner.stop_trial(trial)
def completion_percentage(self):
"""Returns a progress metric.
@ -374,5 +372,8 @@ class Bracket():
"r={}".format(self._r),
"progress={}".format(self.completion_percentage())
])
return "Bracket({})".format(status)
def debug_string(self):
trials = ", ".join([t.status for t in self._live_trials])
return "Bracket({})[{}]".format(status, trials)
return "{}[{}]".format(self, trials)

View file

@ -5,6 +5,7 @@ from __future__ import print_function
import collections
import numpy as np
from ray.tune.trial import Trial
from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler
@ -74,6 +75,11 @@ class MedianStoppingRule(FIFOScheduler):
self._results[trial].append(result)
self._completed_trials.add(trial)
def on_trial_remove(self, trial_runner, trial):
    """Marks trial as completed if it is paused and has previously run.

    Args:
        trial_runner: Runner requesting the removal (not used here).
        trial: Trial being removed from the scheduler.
    """
    # Compare the status by value. `is` only succeeds when both sides are
    # the very same string object, which is an interning accident rather
    # than a guarantee.
    if trial.status == Trial.PAUSED and trial in self._results:
        self._completed_trials.add(trial)
def debug_string(self):
return "Using MedianStoppingRule: num_stopped={}.".format(
len(self._stopped_trials))

View file

@ -29,7 +29,7 @@ TrainingResult = namedtuple("TrainingResult", [
# (Required) Accumulated timesteps for this entire experiment.
"timesteps_total",
# (Optional) If training is finished.
# (Optional) If training is terminated.
"done",
# (Optional) Custom metadata to report for this iteration.

View file

@ -9,6 +9,7 @@ import ray
import os
from collections import namedtuple
from ray.utils import random_string, binary_to_hex
from ray.tune import TuneError
from ray.tune.logger import NoopLogger, UnifiedLogger
from ray.tune.result import TrainingResult, DEFAULT_RESULTS_DIR, pretty_print
@ -105,6 +106,7 @@ class Trial(object):
self.location = None
self.logdir = None
self.result_logger = None
self.trial_id = binary_to_hex(random_string())[:8]
def start(self):
"""Starts this trial.

View file

@ -8,6 +8,7 @@ import time
import traceback
from ray.tune import TuneError
from ray.tune.web_server import TuneServer
from ray.tune.trial import Trial, Resources
from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler
@ -34,8 +35,14 @@ class TrialRunner(object):
misleading benchmark results.
"""
def __init__(self, scheduler=None):
"""Initializes a new TrialRunner."""
def __init__(self, scheduler=None, launch_web_server=False,
server_port=TuneServer.DEFAULT_PORT):
"""Initializes a new TrialRunner.
Args:
scheduler (TrialScheduler): Defaults to FIFOScheduler.
launch_web_server (bool): Flag for starting TuneServer
server_port (int): Port number for launching TuneServer"""
self._scheduler_alg = scheduler or FIFOScheduler()
self._trials = []
@ -49,6 +56,10 @@ class TrialRunner(object):
self._global_time_limit = float(
os.environ.get("TRIALRUNNER_WALLTIME_LIMIT", float('inf')))
self._total_time = 0
self._server = None
if launch_web_server:
self._server = TuneServer(self, server_port)
self._stop_queue = []
def is_finished(self):
"""Returns whether all trials have finished running."""
@ -70,7 +81,6 @@ class TrialRunner(object):
Callers should typically run this method repeatedly in a loop. They
may inspect or modify the runner's state in between calls to step().
"""
if self._can_launch_more():
self._launch_trial()
elif self._running:
@ -91,6 +101,16 @@ class TrialRunner(object):
"trials with sufficient resources.")
raise TuneError("Called step when all trials finished?")
if self._server:
self._process_requests()
if self.is_finished():
self._server.shutdown()
def get_trial(self, tid):
    """Returns the first managed trial whose `trial_id` equals `tid`.

    Args:
        tid: Trial identifier to look up.

    Returns:
        The matching trial, or None when no managed trial has that id.
    """
    for candidate in self._trials:
        if candidate.trial_id == tid:
            return candidate
    return None
def get_trials(self):
"""Returns the list of trials managed by this TrialRunner.
@ -207,6 +227,43 @@ class TrialRunner(object):
assert self._committed_resources.cpu >= 0
assert self._committed_resources.gpu >= 0
def request_stop_trial(self, trial):
    """Records a request to stop `trial`.

    The trial is only appended to the pending-stop list here; the stop
    itself is carried out when `_process_requests` drains that list.

    Args:
        trial: Trial that should be stopped.
    """
    pending_stops = self._stop_queue
    pending_stops.append(trial)
def _process_requests(self):
while self._stop_queue:
t = self._stop_queue.pop()
self.stop_trial(t)
def stop_trial(self, trial):
    """Stops trial.

    Trials may be stopped at any time. If trial is in state PENDING
    or PAUSED, calls `scheduler.on_trial_remove`. Otherwise waits for
    result for the trial and calls `scheduler.on_trial_complete`
    if RUNNING. Trials already in ERROR or TERMINATED are left alone.

    Args:
        trial: Trial to stop.
    """
    if trial.status in [Trial.ERROR, Trial.TERMINATED]:
        # Nothing left to stop.
        return

    error = False
    if trial.status in [Trial.PENDING, Trial.PAUSED]:
        self._scheduler_alg.on_trial_remove(self, trial)
    # Status values are compared by value everywhere else; `is` would
    # depend on string identity and could silently skip this branch.
    elif trial.status == Trial.RUNNING:
        # NOTE: There should only be one...
        result_id = [rid for rid, t in self._running.items()
                     if t is trial][0]
        self._running.pop(result_id)
        try:
            # Wait for the in-flight result so the scheduler is told about
            # the final state of the trial.
            result = ray.get(result_id)
            trial.update_last_result(result, terminate=True)
            self._scheduler_alg.on_trial_complete(self, trial, result)
        except Exception:
            print("Error processing event:", traceback.format_exc())
            self._scheduler_alg.on_trial_error(self, trial)
            error = True

    self._stop_trial(trial, error=error)
def _stop_trial(self, trial, error=False):
"""Only returns resources if resources allocated."""
prior_status = trial.status

View file

@ -26,14 +26,24 @@ class TrialScheduler(object):
"""Called on each intermediate result returned by a trial.
At this point, the trial scheduler can make a decision by returning
one of CONTINUE, PAUSE, and STOP."""
one of CONTINUE, PAUSE, and STOP. This will only be called when the
trial is in the RUNNING state."""
raise NotImplementedError
def on_trial_complete(self, trial_runner, trial, result):
"""Notification for the completion of trial.
This will only be called when the trial completes naturally."""
This will only be called when the trial is in the RUNNING state and
either completes naturally or by manual termination."""
raise NotImplementedError
def on_trial_remove(self, trial_runner, trial):
"""Called to remove trial.
This is called when the trial is in PAUSED or PENDING state. Otherwise,
call `on_trial_complete`."""
raise NotImplementedError
@ -66,6 +76,9 @@ class FIFOScheduler(TrialScheduler):
def on_trial_complete(self, trial_runner, trial, result):
pass
def on_trial_remove(self, trial_runner, trial):
pass
def choose_trial_to_run(self, trial_runner):
for trial in trial_runner.get_trials():
if (trial.status == Trial.PENDING and

View file

@ -16,6 +16,7 @@ from ray.tune.median_stopping_rule import MedianStoppingRule
from ray.tune.trial import Trial
from ray.tune.trial_runner import TrialRunner
from ray.tune.trial_scheduler import FIFOScheduler
from ray.tune.web_server import TuneServer
from ray.tune.variant_generator import generate_trials
@ -41,6 +42,10 @@ parser.add_argument("--scheduler", default="FIFO", type=str,
help="FIFO, MedianStopping, or HyperBand")
parser.add_argument("--scheduler-config", default="{}", type=json.loads,
help="Config options to pass to the scheduler.")
parser.add_argument(
    "--server", default=False,
    # `type=bool` would turn every non-empty string, including "False",
    # into True, so interpret the text explicitly instead.
    type=lambda s: str(s).strip().lower() in ("1", "true", "t", "yes", "y"),
    help="Option to launch Tune Server")
parser.add_argument("--server-port", default=TuneServer.DEFAULT_PORT,
                    type=int, help="Port number for launching Tune Server")
parser.add_argument("-f", "--config-file", required=True, type=str,
help="Read experiment options from this JSON/YAML file.")
@ -61,10 +66,13 @@ def _make_scheduler(args):
args.scheduler, _SCHEDULERS.keys()))
def run_experiments(experiments, scheduler=None, **ray_args):
def run_experiments(experiments, scheduler=None, with_server=False,
server_port=TuneServer.DEFAULT_PORT, **ray_args):
if scheduler is None:
scheduler = FIFOScheduler()
runner = TrialRunner(scheduler)
runner = TrialRunner(
scheduler, launch_web_server=with_server, server_port=server_port)
for name, spec in experiments.items():
for trial in generate_trials(spec, name):
@ -78,6 +86,7 @@ def run_experiments(experiments, scheduler=None, **ray_args):
print(runner.debug_string())
for trial in runner.get_trials():
# TODO(rliaw): What about errored?
if trial.status != Trial.TERMINATED:
raise TuneError("Trial did not complete", trial)
@ -90,5 +99,6 @@ if __name__ == "__main__":
with open(args.config_file) as f:
experiments = yaml.load(f)
run_experiments(
experiments, _make_scheduler(args), redis_address=args.redis_address,
experiments, _make_scheduler(args), with_server=args.server,
server_port=args.server_port, redis_address=args.redis_address,
num_cpus=args.num_cpus, num_gpus=args.num_gpus)

View file

@ -0,0 +1,146 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import requests
import json
import threading
from http.server import HTTPServer, BaseHTTPRequestHandler
from ray.tune.error import TuneError, TuneManagerError
from ray.tune.variant_generator import generate_trials
class TuneClient(object):
    """Client for interacting with an ongoing Tune experiment.

    The experiment's server must already be running before any request
    is made.
    """

    STOP = "STOP"
    ADD = "ADD"
    GET_LIST = "GET_LIST"
    GET_TRIAL = "GET_TRIAL"

    def __init__(self, tune_address):
        """Stores the server address.

        Args:
            tune_address (str): `host:port` of the Tune server.
        """
        # TODO(rliaw): Better to specify address and port forward
        self._tune_address = tune_address
        self._path = "http://{}".format(tune_address)

    def get_all_trials(self):
        """Requests the list of all trials from the server."""
        request = {"command": TuneClient.GET_LIST}
        return self._get_response(request)

    def get_trial(self, trial_id):
        """Requests the information of the trial with id `trial_id`."""
        request = {"command": TuneClient.GET_TRIAL}
        request["trial_id"] = trial_id
        return self._get_response(request)

    def add_trial(self, name, trial_spec):
        """Requests new trial(s) named `name` built from `trial_spec`."""
        # TODO(rliaw): have better way of specifying a new trial
        request = {"command": TuneClient.ADD}
        request["name"] = name
        request["spec"] = trial_spec
        return self._get_response(request)

    def stop_trial(self, trial_id):
        """Requests that the trial with id `trial_id` be stopped."""
        request = {"command": TuneClient.STOP}
        request["trial_id"] = trial_id
        return self._get_response(request)

    def _get_response(self, data):
        """Sends `data` as a JSON body and returns the decoded JSON reply."""
        body = json.dumps(data).encode()
        return requests.get(self._path, data=body).json()
def RunnerHandler(runner):
    """Builds an HTTP request-handler class bound to `runner`.

    Args:
        runner: Object providing `get_trial`, `get_trials`,
            `request_stop_trial` and `add_trial`; every command is
            executed against it.

    Returns:
        A `BaseHTTPRequestHandler` subclass suitable for `HTTPServer`.
    """

    class Handler(BaseHTTPRequestHandler):

        def do_GET(self):
            """Reads a JSON command from the body and replies with JSON.

            Replies 200 when the command succeeded and 400 otherwise; a
            failure reply carries the reason under "message".
            """
            # The default belongs to `get`, not to `int`: `int(x, 0)` means
            # "parse with base 0" and `int(None, 0)` raises when the header
            # is absent.
            content_len = int(self.headers.get('Content-Length', 0))
            raw_body = self.rfile.read(content_len)
            try:
                parsed_input = json.loads(raw_body.decode())
            except ValueError:
                # Empty or malformed body: answer with an error instead of
                # letting the exception abort the request.
                status = False
                response = {"message": "Request body is not valid JSON."}
            else:
                status, response = self.execute_command(parsed_input)
            if status:
                self.send_response(200)
            else:
                self.send_response(400)
            self.end_headers()
            self.wfile.write(json.dumps(
                response).encode())

        def trial_info(self, trial):
            """Returns a dict describing `trial` (id, name, config, status
            and last result, which is None if there is no result yet)."""
            if trial.last_result:
                result = trial.last_result._asdict()
            else:
                result = None
            info_dict = {
                "id": trial.trial_id,
                "trainable_name": trial.trainable_name,
                "config": trial.config,
                "status": trial.status,
                "result": result
            }
            return info_dict

        def execute_command(self, args):
            """Runs the command described by `args` against the runner.

            Args:
                args (dict): Decoded request; must contain "command" plus
                    the fields that command needs.

            Returns:
                Tuple `(status, response)`; `status` is False and
                `response["message"]` is set when a TuneError occurred.
            """
            def require(key):
                # A missing field becomes a TuneManagerError so that it is
                # reported as a 400 reply like any other bad request.
                try:
                    return args[key]
                except (KeyError, TypeError):
                    raise TuneManagerError(
                        "Missing required field '{}'.".format(key))

            def get_trial():
                trial_id = require("trial_id")
                trial = runner.get_trial(trial_id)
                if trial is None:
                    error = "Trial ({}) not found.".format(trial_id)
                    raise TuneManagerError(error)
                else:
                    return trial

            response = {}
            try:
                command = require("command")
                if command == TuneClient.GET_LIST:
                    response["trials"] = [self.trial_info(t)
                                          for t in runner.get_trials()]
                elif command == TuneClient.GET_TRIAL:
                    trial = get_trial()
                    response["trial_info"] = self.trial_info(trial)
                elif command == TuneClient.STOP:
                    trial = get_trial()
                    # Only queued here; the runner performs the stop.
                    runner.request_stop_trial(trial)
                elif command == TuneClient.ADD:
                    name = require("name")
                    spec = require("spec")
                    for trial in generate_trials(spec, name):
                        runner.add_trial(trial)
                else:
                    raise TuneManagerError("Unknown command.")
                status = True
            except TuneError as e:
                status = False
                response["message"] = str(e)
            return status, response

    return Handler
class TuneServer(threading.Thread):
    """Background thread serving Tune client requests over HTTP.

    The server binds to localhost and starts serving as soon as the
    object is constructed.
    """

    DEFAULT_PORT = 4321

    def __init__(self, runner, port=None):
        """Binds the HTTP server and starts the serving thread.

        Args:
            runner: Runner handed to `RunnerHandler`; requests act on it.
            port (int): Port to listen on. Falsy values fall back to
                `DEFAULT_PORT`.
        """
        threading.Thread.__init__(self)
        # Do not keep the interpreter alive just for the server: if the
        # runner fails before `shutdown` is called, the process can still
        # exit.
        self.daemon = True
        self._port = port if port else self.DEFAULT_PORT
        address = ('localhost', self._port)
        print("Starting Tune Server...")
        self._server = HTTPServer(
            address, RunnerHandler(runner))
        self.start()

    def run(self):
        """Serves requests until `shutdown` is called."""
        self._server.serve_forever()

    def shutdown(self):
        """Stops the serving loop and releases the listening socket."""
        self._server.shutdown()
        # `shutdown` only ends `serve_forever`; the socket stays bound until
        # it is closed explicitly.
        self._server.server_close()

View file

@ -497,6 +497,46 @@ class TrialRunnerTest(unittest.TestCase):
runner.step()
self.assertEqual(trials[0].status, Trial.TERMINATED)
def testStopTrial(self):
ray.init(num_cpus=4, num_gpus=2)
runner = TrialRunner()
kwargs = {
"stopping_criterion": {"training_iteration": 5},
"resources": Resources(cpu=1, gpu=1),
}
trials = [
Trial("__fake", **kwargs),
Trial("__fake", **kwargs),
Trial("__fake", **kwargs),
Trial("__fake", **kwargs)]
for t in trials:
runner.add_trial(t)
runner.step()
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(trials[1].status, Trial.PENDING)
# Stop trial while running
runner.stop_trial(trials[0])
self.assertEqual(trials[0].status, Trial.TERMINATED)
self.assertEqual(trials[1].status, Trial.PENDING)
runner.step()
self.assertEqual(trials[0].status, Trial.TERMINATED)
self.assertEqual(trials[1].status, Trial.RUNNING)
self.assertEqual(trials[-1].status, Trial.PENDING)
# Stop trial while pending
runner.stop_trial(trials[-1])
self.assertEqual(trials[0].status, Trial.TERMINATED)
self.assertEqual(trials[1].status, Trial.RUNNING)
self.assertEqual(trials[-1].status, Trial.TERMINATED)
runner.step()
self.assertEqual(trials[0].status, Trial.TERMINATED)
self.assertEqual(trials[1].status, Trial.RUNNING)
self.assertEqual(trials[2].status, Trial.RUNNING)
self.assertEqual(trials[-1].status, Trial.TERMINATED)
if __name__ == "__main__":
unittest.main(verbosity=2)

View file

@ -140,8 +140,25 @@ class EarlyStoppingSuite(unittest.TestCase):
class _MockTrialRunner():
def _stop_trial(self, trial):
trial.stop()
def __init__(self, scheduler):
self._scheduler_alg = scheduler
def process_action(self, trial, action):
if action == TrialScheduler.CONTINUE:
pass
elif action == TrialScheduler.PAUSE:
self._pause_trial(trial)
elif action == TrialScheduler.STOP:
trial.stop()
def stop_trial(self, trial):
if trial.status in [Trial.ERROR, Trial.TERMINATED]:
return
elif trial.status in [Trial.PENDING, Trial.PAUSED]:
self._scheduler_alg.on_trial_remove(self, trial)
else:
self._scheduler_alg.on_trial_complete(self, trial, result(100, 10))
def has_resources(self, resources):
return True
@ -168,7 +185,7 @@ class HyperbandSuite(unittest.TestCase):
for i in range(num_trials):
t = Trial("__fake")
sched.on_trial_add(None, t)
runner = _MockTrialRunner()
runner = _MockTrialRunner(sched)
return sched, runner
def default_statistics(self):
@ -186,14 +203,6 @@ class HyperbandSuite(unittest.TestCase):
def downscale(self, n, sched):
return int(np.ceil(n / sched._eta))
def process(self, trl, mock_runner, action):
if action == TrialScheduler.CONTINUE:
pass
elif action == TrialScheduler.PAUSE:
mock_runner._pause_trial(trl)
elif action == TrialScheduler.STOP:
self.stopTrial(trl, mock_runner)
def basicSetup(self):
"""Setup and verify full band.
"""
@ -224,10 +233,6 @@ class HyperbandSuite(unittest.TestCase):
return sched
def stopTrial(self, trial, mock_runner):
self.assertNotEqual(trial.status, Trial.TERMINATED)
mock_runner._stop_trial(trial)
def testConfigSameEta(self):
sched = HyperBandScheduler()
i = 0
@ -283,7 +288,7 @@ class HyperbandSuite(unittest.TestCase):
mock_runner, trl, result(cur_units, i))
if i < current_length - 1:
self.assertEqual(action, TrialScheduler.PAUSE)
self.process(trl, mock_runner, action)
mock_runner.process_action(trl, action)
self.assertEqual(action, TrialScheduler.CONTINUE)
new_length = len(big_bracket.current_trials())
@ -304,7 +309,7 @@ class HyperbandSuite(unittest.TestCase):
for i, trl in reversed(list(enumerate(big_bracket.current_trials()))):
action = sched.on_trial_result(
mock_runner, trl, result(cur_units, i))
self.process(trl, mock_runner, action)
mock_runner.process_action(trl, action)
self.assertEqual(action, TrialScheduler.STOP)
@ -321,7 +326,7 @@ class HyperbandSuite(unittest.TestCase):
for i, trl in enumerate(big_bracket.current_trials()):
action = sched.on_trial_result(
mock_runner, trl, result(cur_units, i))
self.process(trl, mock_runner, action)
mock_runner.process_action(trl, action)
self.assertEqual(action, TrialScheduler.CONTINUE)
@ -412,9 +417,9 @@ class HyperbandSuite(unittest.TestCase):
mock_runner._launch_trial(t)
for i, t in enumerate(bracket_trials):
status = sched.on_trial_result(
action = sched.on_trial_result(
mock_runner, t, result(init_units, i))
self.assertEqual(status, TrialScheduler.CONTINUE)
self.assertEqual(action, TrialScheduler.CONTINUE)
t = Trial("__fake")
sched.on_trial_add(None, t)
mock_runner._launch_trial(t)
@ -442,7 +447,7 @@ class HyperbandSuite(unittest.TestCase):
for i in range(stats["max_trials"]):
t = Trial("__fake")
sched.on_trial_add(None, t)
runner = _MockTrialRunner()
runner = _MockTrialRunner(sched)
big_bracket = sched._hyperbands[0][-1]
@ -452,17 +457,11 @@ class HyperbandSuite(unittest.TestCase):
# Provides results from 0 to 8 in order, keeping the last one running
for i, trl in enumerate(big_bracket.current_trials()):
status = sched.on_trial_result(runner, trl, result2(1, i))
if status == TrialScheduler.CONTINUE:
continue
elif status == TrialScheduler.PAUSE:
runner._pause_trial(trl)
elif status == TrialScheduler.STOP:
self.assertNotEqual(trl.status, Trial.TERMINATED)
self.stopTrial(trl, runner)
action = sched.on_trial_result(runner, trl, result2(1, i))
runner.process_action(trl, action)
new_length = len(big_bracket.current_trials())
self.assertEqual(status, TrialScheduler.CONTINUE)
self.assertEqual(action, TrialScheduler.CONTINUE)
self.assertEqual(new_length, self.downscale(current_length, sched))
def testJumpingTime(self):
@ -476,21 +475,38 @@ class HyperbandSuite(unittest.TestCase):
main_trials = big_bracket.current_trials()[:-1]
jump = big_bracket.current_trials()[-1]
for i, trl in enumerate(main_trials):
status = sched.on_trial_result(mock_runner, trl, result(1, i))
if status == TrialScheduler.CONTINUE:
continue
elif status == TrialScheduler.PAUSE:
mock_runner._pause_trial(trl)
elif status == TrialScheduler.STOP:
self.assertNotEqual(trl.status, Trial.TERMINATED)
self.stopTrial(trl, mock_runner)
action = sched.on_trial_result(mock_runner, trl, result(1, i))
mock_runner.process_action(trl, action)
status = sched.on_trial_result(mock_runner, jump, result(4, i))
self.assertEqual(status, TrialScheduler.PAUSE)
action = sched.on_trial_result(mock_runner, jump, result(4, i))
self.assertEqual(action, TrialScheduler.PAUSE)
current_length = len(big_bracket.current_trials())
self.assertLess(current_length, 27)
def testRemove(self):
"""Test with 4: start 1, remove 1 pending, add 2, remove 1 pending"""
sched, runner = self.schedulerSetup(4)
trials = sorted(list(sched._trial_info), key=lambda t: t.trial_id)
runner._launch_trial(trials[0])
sched.on_trial_result(runner, trials[0], result(1, 5))
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(trials[1].status, Trial.PENDING)
bracket, _ = sched._trial_info[trials[1]]
self.assertTrue(trials[1] in bracket._live_trials)
sched.on_trial_remove(runner, trials[1])
self.assertFalse(trials[1] in bracket._live_trials)
for i in range(2):
trial = Trial("__fake")
sched.on_trial_add(None, trial)
bracket, _ = sched._trial_info[trial]
self.assertTrue(trial in bracket._live_trials)
sched.on_trial_remove(runner, trial) # where trial is not running
self.assertFalse(trial in bracket._live_trials)
if __name__ == "__main__":
unittest.main(verbosity=2)

103
test/tune_server_test.py Normal file
View file

@ -0,0 +1,103 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
import socket
import ray
from ray.rllib import _register_all
from ray.tune.trial import Trial, Resources
from ray.tune.web_server import TuneClient
from ray.tune.trial_runner import TrialRunner
def get_valid_port():
    """Returns the first port, starting at 4321, that binds on 127.0.0.1.

    Each candidate is probed with a throwaway socket which is always
    closed again, whether or not the bind succeeded.
    """
    port = 4321
    while True:
        print("Trying port", port)
        port_test_socket = socket.socket()
        try:
            port_test_socket.bind(("127.0.0.1", port))
            return port
        except socket.error:
            port += 1
        finally:
            # Close on the failure path too; otherwise every busy port
            # leaks a socket.
            port_test_socket.close()
class TuneServerSuite(unittest.TestCase):
    """Drives a server-enabled TrialRunner through the TuneClient."""

    def basicSetup(self):
        """Starts ray and a runner with its server and two fake trials.

        Returns:
            Tuple `(runner, client)`, the client pointing at the runner's
            server.
        """
        ray.init(num_cpus=4, num_gpus=1)
        port = get_valid_port()
        self.runner = TrialRunner(
            launch_web_server=True, server_port=port)
        runner = self.runner
        kwargs = {
            "stopping_criterion": {"training_iteration": 3},
            "resources": Resources(cpu=1, gpu=1),
        }
        trials = [
            Trial("__fake", **kwargs),
            Trial("__fake", **kwargs)]
        for t in trials:
            runner.add_trial(t)
        client = TuneClient("localhost:{}".format(port))
        return runner, client

    def tearDown(self):
        """Shuts the server down (best effort) and resets ray."""
        print("Tearing down....")
        try:
            self.runner._server.shutdown()
            self.runner = None
        except Exception as e:
            # Cleanup must go on even if the server is already gone.
            print(e)
        ray.worker.cleanup()
        _register_all()

    def testAddTrial(self):
        """A trial added through the client shows up in the trial list."""
        runner, client = self.basicSetup()
        for _ in range(3):
            runner.step()
        spec = {
            "run": "__fake",
            "stop": {"training_iteration": 3},
            "resources": dict(cpu=1, gpu=1),
        }
        client.add_trial("test", spec)
        runner.step()
        all_trials = client.get_all_trials()["trials"]
        runner.step()
        self.assertEqual(len(all_trials), 3)

    def testGetTrials(self):
        """Listing and querying trials does not change the trial set."""
        runner, client = self.basicSetup()
        for _ in range(3):
            runner.step()
        all_trials = client.get_all_trials()["trials"]
        self.assertEqual(len(all_trials), 2)
        tid = all_trials[0]["id"]
        client.get_trial(tid)
        runner.step()
        # Ask the server again; checking the earlier list a second time
        # would not verify anything new.
        all_trials = client.get_all_trials()["trials"]
        self.assertEqual(len(all_trials), 2)

    def testStopTrial(self):
        """Check if Stop Trial works"""
        runner, client = self.basicSetup()
        for _ in range(2):
            runner.step()
        all_trials = client.get_all_trials()["trials"]
        self.assertEqual(
            len([t for t in all_trials if t["status"] == Trial.RUNNING]), 1)

        tid = [t for t in all_trials if t["status"] == Trial.RUNNING][0]["id"]
        client.stop_trial(tid)
        # The stop is only queued by the server; the next step applies it.
        runner.step()

        all_trials = client.get_all_trials()["trials"]
        self.assertEqual(
            len([t for t in all_trials if t["status"] == Trial.RUNNING]), 0)
if __name__ == "__main__":
unittest.main(verbosity=2)