ray/python/ray/tune/function_runner.py
Eric Liang 173f1d629a
[tune] Ray Tune API cleanup (#1454)
Remove rllib dep: trainable is now a standalone abstract class that can be easily subclassed.

Clean up hyperband: fix debug string and add an example.

Remove YAML api / ScriptRunner: this was never really used.

Move ray.init() out of run_experiments(): This provides greater flexibility and should be less confusing since there isn't an implicit init() done there. Note that this is a breaking API change for tune.
2018-01-24 16:55:17 -08:00

125 lines
3.9 KiB
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
import threading
import traceback
from ray.tune import TuneError
from ray.tune.trainable import Trainable
from ray.tune.result import TrainingResult
class StatusReporter(object):
    """Callable handed to the user's function for reporting progress.

    The user's function calls the reporter with keyword arguments; the
    consumer polls ``_get_and_clear_status`` to pick the reports up.

    Internal state:
      * ``_latest_result`` -- report not yet consumed (``None`` once taken).
      * ``_last_result`` -- most recent report ever made; kept so that a
        final ``done=True`` result can be produced after the function ends.
      * ``_error`` -- truthy once the run failed or was stopped.
      * ``_done`` -- set by the runner thread when the function returns.
    """

    def __init__(self):
        self._lock = threading.Lock()
        self._latest_result = None
        self._last_result = None
        self._done = False
        self._error = None

    def __call__(self, **kwargs):
        """Report updated training status.

        Args:
            kwargs (TrainingResult): Latest training result status. You must
                at least define `timesteps_total`, but probably want to
                report some of the other metrics as well.
        """
        with self._lock:
            fresh = TrainingResult(**kwargs)
            self._latest_result = fresh
            self._last_result = fresh

    def _get_and_clear_status(self):
        """Return the pending report, if any, and mark it as consumed.

        Returns:
            The unconsumed result, or ``None`` when nothing new has been
            reported. Once the function has finished and nothing is
            pending, the last report is returned with ``done=True``.

        Raises:
            TuneError: if an error was recorded, or if the function
                finished without ever reporting a result.
        """
        if self._error:
            raise TuneError("Error running trial: " + str(self._error))

        finished_and_drained = self._done and not self._latest_result
        if finished_and_drained:
            if self._last_result:
                return self._last_result._replace(done=True)
            raise TuneError("Trial finished without reporting result!")

        with self._lock:
            # Hand the pending report over and reset the slot in one step.
            pending, self._latest_result = self._latest_result, None
            return pending

    def _stop(self):
        """Flag the reporter as stopped so later polls raise TuneError."""
        self._error = "Agent stopped"
# Settings read by FunctionRunner itself; these keys are removed from the
# config before it is passed to the user's function.
DEFAULT_CONFIG = dict(
    # Seconds to wait at the start of each training iteration before
    # polling for a result, so results are batched to at least this
    # granularity.
    script_min_iter_time_s=1,
)
class _RunnerThread(threading.Thread):
"""Supervisor thread that runs your script."""
def __init__(self, entrypoint, config, status_reporter):
self._entrypoint = entrypoint
self._entrypoint_args = [config, status_reporter]
self._status_reporter = status_reporter
threading.Thread.__init__(self)
self.daemon = True
def run(self):
try:
self._entrypoint(*self._entrypoint_args)
except Exception as e:
self._status_reporter._error = e
print("Runner thread raised: {}".format(traceback.format_exc()))
raise e
finally:
self._status_reporter._done = True
class FunctionRunner(Trainable):
    """Trainable that runs a user function returning training results.

    The function is executed on a background thread started in ``_setup``
    and reports progress through a ``StatusReporter``; each ``_train``
    call waits for and returns the next reported result.

    This mode of execution does not support checkpoint/restore.
    """

    _name = "func"
    _default_config = DEFAULT_CONFIG

    def _setup(self):
        """Create the reporter and start the thread running the function."""
        user_func = self._trainable_func()
        self._status_reporter = StatusReporter()

        # The user's function should not see the runner's own settings.
        user_config = self.config.copy()
        for internal_key in self._default_config:
            user_config.pop(internal_key, None)

        self._runner = _RunnerThread(
            user_func, user_config, self._status_reporter)
        self._start_time = time.time()
        self._last_reported_timestep = 0
        self._runner.start()

    def _trainable_func(self):
        """Subclasses can override this to set the trainable func."""
        raise NotImplementedError

    def _train(self):
        """Wait for the next reported result and return it.

        Sleeps for ``script_min_iter_time_s`` seconds first, then polls the
        reporter once per second until a result is available.

        Returns:
            The reported result with ``timesteps_this_iter`` set to the
            increase in ``timesteps_total`` since the previous call.

        Raises:
            TuneError: if the result does not define ``timesteps_total``
                (or if the reporter itself raises while being polled).
        """
        wait_s = self.config.get(
            "script_min_iter_time_s",
            self._default_config["script_min_iter_time_s"])
        while True:
            time.sleep(wait_s)
            result = self._status_reporter._get_and_clear_status()
            if result is not None:
                break
            # Nothing reported yet: keep polling at a one second interval.
            wait_s = 1

        total = result.timesteps_total
        if total is None:
            raise TuneError("Must specify timesteps_total in result", result)

        updated = result._replace(
            timesteps_this_iter=total - self._last_reported_timestep)
        self._last_reported_timestep = total
        return updated

    def _stop(self):
        """Stop the trial by flagging the status reporter as stopped."""
        self._status_reporter._stop()