[tune] reuse actors for function API (#11230)

Co-authored-by: Kristian Hartikainen <kristian.hartikainen@gmail.com>
Kai Fricke, 2020-10-09 00:15:02 +01:00, committed by GitHub
parent 587319debc
commit b450cb030a
7 changed files with 154 additions and 28 deletions


@@ -608,6 +608,8 @@ These are the environment variables Ray Tune currently considers:
   or a search algorithm, Tune will error
   if the metric was not reported in the result. Setting this environment variable
   to ``1`` will disable this check.
+* **TUNE_FUNCTION_THREAD_TIMEOUT_S**: Time in seconds the function API waits
+  for threads to finish after instructing them to complete. Defaults to ``2``.
 * **TUNE_GLOBAL_CHECKPOINT_S**: Time in seconds that limits how often Tune's
   experiment state is checkpointed. If not set this will default to ``10``.
 * **TUNE_MAX_LEN_IDENTIFIER**: Maximum length of trial subdirectory names (those
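A minimal sketch of overriding the new timeout, assuming it is exported before the tuning loop starts (the value here is illustrative; only the variable name comes from this commit):

    import os

    # Allow up to 10 seconds (instead of the default 2) for a function
    # trainable's thread to finish before giving up on reusing the actor.
    os.environ["TUNE_FUNCTION_THREAD_TIMEOUT_S"] = "10"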


@@ -1,5 +1,6 @@
 import logging
 import os
+import sys
 import time
 import inspect
 import shutil
@@ -120,12 +121,21 @@ class StatusReporter:
     def __init__(self,
                  result_queue,
                  continue_semaphore,
+                 end_event,
                  trial_name=None,
                  trial_id=None,
                  logdir=None):
         self._queue = result_queue
         self._last_report_time = None
         self._continue_semaphore = continue_semaphore
+        self._end_event = end_event
+        self._trial_name = trial_name
+        self._trial_id = trial_id
+        self._logdir = logdir
+        self._last_checkpoint = None
+        self._fresh_checkpoint = False
+
+    def reset(self, trial_name=None, trial_id=None, logdir=None):
         self._trial_name = trial_name
         self._trial_id = trial_id
         self._logdir = logdir
@@ -171,6 +181,11 @@ class StatusReporter:
         # resume training.
         self._continue_semaphore.acquire()

+        # If the trial should be terminated, exit gracefully.
+        if self._end_event.is_set():
+            self._end_event.clear()
+            sys.exit(0)
+
     def make_checkpoint_dir(self, step):
         checkpoint_dir = TrainableUtil.make_checkpoint_dir(
             self.logdir, index=step)
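The check above is the reporter's half of a shutdown handshake: the user function runs on a background thread that blocks on a semaphore between results, and the actor signals termination through an event before releasing it. A self-contained sketch of the same pattern, with illustrative names rather than Tune's internals:

    import sys
    import threading

    end_event = threading.Event()
    continue_semaphore = threading.Semaphore(0)

    def worker():
        while True:
            # ... produce one result here, then block until told to continue.
            continue_semaphore.acquire()
            # Exit gracefully if the controller asked us to stop.
            if end_event.is_set():
                sys.exit(0)  # SystemExit only terminates this thread

    thread = threading.Thread(target=worker)
    thread.start()

    # Controller side: request termination, then wake the worker up.
    end_event.set()
    continue_semaphore.release()
    thread.join(timeout=2)
    assert not thread.is_alive()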
@@ -264,6 +279,10 @@ class FunctionRunner(Trainable):
         # and to generate the next result.
         self._continue_semaphore = threading.Semaphore(0)

+        # Event for notifying the reporter to exit gracefully, terminating
+        # the thread.
+        self._end_event = threading.Event()
+
         # Queue for passing results between threads
         self._results_queue = queue.Queue(1)
@@ -275,6 +294,7 @@ class FunctionRunner(Trainable):
         self._status_reporter = StatusReporter(
             self._results_queue,
             self._continue_semaphore,
+            self._end_event,
             trial_name=self.trial_name,
             trial_id=self.trial_id,
             logdir=self.logdir)
@@ -363,7 +383,7 @@ class FunctionRunner(Trainable):
         # This keyword appears if the train_func using the Function API
         # finishes without "done=True". This duplicates the last result, but
         # the TrialRunner will not log this result again.
-        if "__duplicate__" in result:
+        if RESULT_DUPLICATE in result:
             new_result = self._last_result.copy()
             new_result.update(result)
             result = new_result
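The magic string is replaced by the RESULT_DUPLICATE constant imported from ray.tune.result, which at the time of this commit carries the same sentinel value, so the check's behavior is unchanged:

    # ray/tune/result.py (as of this commit)
    RESULT_DUPLICATE = "__duplicate__"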
@@ -441,6 +461,11 @@ class FunctionRunner(Trainable):
             self.restore(checkpoint_path)

     def cleanup(self):
+        # Trigger thread termination
+        self._end_event.set()
+        self._continue_semaphore.release()
+        # Do not wait for thread termination here.
+
         # If everything stayed in sync properly, this should never happen.
         if not self._results_queue.empty():
             logger.warning(
@@ -457,6 +482,29 @@ class FunctionRunner(Trainable):
             logger.debug("Clearing temporary checkpoint: %s",
                          self.temp_checkpoint_dir)

+    def reset_config(self, new_config):
+        if self._runner and self._runner.is_alive():
+            self._end_event.set()
+            self._continue_semaphore.release()
+
+            # Wait for thread termination so it is safe to reuse the same
+            # actor.
+            thread_timeout = int(
+                os.environ.get("TUNE_FUNCTION_THREAD_TIMEOUT_S", 2))
+            self._runner.join(timeout=thread_timeout)
+            if self._runner.is_alive():
+                # Did not finish within timeout, reset unsuccessful.
+                return False
+
+        self._runner = None
+        self._last_result = {}
+
+        self._status_reporter.reset(
+            trial_name=self.trial_name,
+            trial_id=self.trial_id,
+            logdir=self.logdir)
+
+        return True
+
     def _report_thread_runner_error(self, block=False):
         try:
             err_tb_str = self._error_queue.get(


@@ -14,6 +14,7 @@ from ray import ray_constants
 from ray.resource_spec import ResourceSpec
 from ray.tune.durable_trainable import DurableTrainable
 from ray.tune.error import AbortTrialExecution, TuneError
+from ray.tune.function_runner import FunctionRunner
 from ray.tune.logger import NoopLogger
 from ray.tune.result import TRIAL_INFO, STDOUT_FILE, STDERR_FILE
 from ray.tune.resources import Resources
@@ -276,13 +277,13 @@ class RayTrialExecutor(TrialExecutor):
         """
         prior_status = trial.status
         if runner is None:
-            # TODO: Right now, we only support reuse if there has been
-            # previously instantiated state on the worker. However,
-            # we should consider the case where function evaluations
-            # can be very fast - thereby extending the need to support
-            # reuse to cases where there has not been previously
-            # instantiated state before.
-            reuse_allowed = checkpoint is not None or trial.has_checkpoint()
+            # We reuse actors when there is previously instantiated state on
+            # the actor. Function API calls are also supported when there is
+            # no checkpoint to continue from.
+            # TODO: Check preconditions - why is previous state needed?
+            reuse_allowed = checkpoint is not None or trial.has_checkpoint() \
+                or issubclass(trial.get_trainable_cls(),
+                              FunctionRunner)
             runner = self._setup_remote_runner(trial, reuse_allowed)
             trial.set_runner(runner)
             self.restore(trial, checkpoint)
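End to end, the new branch means function-API trials qualify for actor reuse even when there is no prior checkpoint state. A hedged sketch of the user-facing feature (the trainable below is illustrative, not from this commit):

    import ray
    from ray import tune

    def train_fn(config):
        # Illustrative function-API trainable reporting a few results.
        for step in range(3):
            tune.report(iter=step, lr=config["lr"])

    ray.init()
    # With reuse_actors=True, consecutive function-API trials can now run
    # on the same actor instead of each starting a fresh process.
    tune.run(
        train_fn,
        config={"lr": tune.grid_search([0.01, 0.1, 1.0])},
        reuse_actors=True)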


@@ -1,11 +1,14 @@
 import os
+import pickle
 import unittest
 import sys
+from collections import defaultdict

 import ray
 from ray import tune, logger
 from ray.tune import Trainable, run_experiments, register_trainable
 from ray.tune.error import TuneError
+from ray.tune.function_runner import wrap_function
 from ray.tune.schedulers.trial_scheduler import FIFOScheduler, TrialScheduler
@@ -30,6 +33,7 @@ def create_resettable_class():
             logger.info("LOG_STDERR: {}".format(self.msg))
             return {
+                "id": self.config["id"],
                 "num_resets": self.num_resets,
                 "done": self.iter > 1,
                 "iter": self.iter
@@ -51,6 +55,35 @@ def create_resettable_class():
     return MyResettableClass


+def create_resettable_function(num_resets: defaultdict):
+    def trainable(config, checkpoint_dir=None):
+        if checkpoint_dir:
+            with open(os.path.join(checkpoint_dir, "chkpt"), "rb") as fp:
+                step = pickle.load(fp)
+        else:
+            step = 0
+
+        while step < 2:
+            step += 1
+            with tune.checkpoint_dir(step) as checkpoint_dir:
+                with open(os.path.join(checkpoint_dir, "chkpt"), "wb") as fp:
+                    pickle.dump(step, fp)
+            tune.report(**{
+                "done": step >= 2,
+                "iter": step,
+                "id": config["id"]
+            })
+
+    trainable = wrap_function(trainable)
+
+    class ResetCountTrainable(trainable):
+        def reset_config(self, new_config):
+            num_resets[self.trial_id] += 1
+            return super().reset_config(new_config)
+
+    return ResetCountTrainable
+
+
 class ActorReuseTest(unittest.TestCase):
     def setUp(self):
         ray.init(num_cpus=1, num_gpus=0)
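Note the pattern in the helper above: wrap_function turns the checkpointing function into a FunctionRunner subclass, so the test can count resets simply by subclassing the wrapped class, incrementing a shared counter in reset_config, and delegating to super().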
@@ -58,38 +91,56 @@ class ActorReuseTest(unittest.TestCase):
     def tearDown(self):
         ray.shutdown()

-    def testTrialReuseDisabled(self):
+    def _run_trials_with_frequent_pauses(self, trainable, reuse=False):
         trials = run_experiments(
             {
                 "foo": {
-                    "run": create_resettable_class(),
-                    "num_samples": 4,
-                    "config": {},
+                    "run": trainable,
+                    "num_samples": 1,
+                    "config": {
+                        "id": tune.grid_search([0, 1, 2, 3])
+                    },
                 }
             },
-            reuse_actors=False,
+            reuse_actors=reuse,
             scheduler=FrequentPausesScheduler(),
             verbose=0)
+        return trials
+
+    def testTrialReuseDisabled(self):
+        trials = self._run_trials_with_frequent_pauses(
+            create_resettable_class(), reuse=False)
+        self.assertEqual([t.last_result["id"] for t in trials], [0, 1, 2, 3])
         self.assertEqual([t.last_result["iter"] for t in trials], [2, 2, 2, 2])
         self.assertEqual([t.last_result["num_resets"] for t in trials],
                          [0, 0, 0, 0])

+    def testTrialReuseDisabledFunction(self):
+        num_resets = defaultdict(lambda: 0)
+        trials = self._run_trials_with_frequent_pauses(
+            create_resettable_function(num_resets), reuse=False)
+        self.assertEqual([t.last_result["id"] for t in trials], [0, 1, 2, 3])
+        self.assertEqual([t.last_result["iter"] for t in trials], [2, 2, 2, 2])
+        self.assertEqual([num_resets[t.trial_id] for t in trials],
+                         [0, 0, 0, 0])
+
     def testTrialReuseEnabled(self):
-        trials = run_experiments(
-            {
-                "foo": {
-                    "run": create_resettable_class(),
-                    "num_samples": 4,
-                    "config": {},
-                }
-            },
-            reuse_actors=True,
-            scheduler=FrequentPausesScheduler(),
-            verbose=0)
+        trials = self._run_trials_with_frequent_pauses(
+            create_resettable_class(), reuse=True)
+        self.assertEqual([t.last_result["id"] for t in trials], [0, 1, 2, 3])
         self.assertEqual([t.last_result["iter"] for t in trials], [2, 2, 2, 2])
         self.assertEqual([t.last_result["num_resets"] for t in trials],
                          [1, 2, 3, 4])

+    def testTrialReuseEnabledFunction(self):
+        num_resets = defaultdict(lambda: 0)
+        trials = self._run_trials_with_frequent_pauses(
+            create_resettable_function(num_resets), reuse=True)
+        self.assertEqual([t.last_result["id"] for t in trials], [0, 1, 2, 3])
+        self.assertEqual([t.last_result["iter"] for t in trials], [2, 2, 2, 2])
+        self.assertEqual([num_resets[t.trial_id] for t in trials],
+                         [0, 0, 0, 0])
+
     def testReuseEnabledError(self):
         def run():
             run_experiments(
@@ -97,8 +148,9 @@ class ActorReuseTest(unittest.TestCase):
                     "foo": {
                         "run": create_resettable_class(),
                         "max_failures": 1,
-                        "num_samples": 4,
+                        "num_samples": 1,
                         "config": {
+                            "id": tune.grid_search([0, 1, 2, 3]),
                             "fake_reset_not_supported": True
                         },
                     }
@@ -115,7 +167,8 @@ class ActorReuseTest(unittest.TestCase):
         [trial1, trial2] = tune.run(
             "foo2",
             config={
-                "message": tune.grid_search(["First", "Second"])
+                "message": tune.grid_search(["First", "Second"]),
+                "id": -1
             },
             log_to_file=True,
             scheduler=FrequentPausesScheduler(),


@@ -552,7 +552,22 @@ class Trainable:
             self._close_logfiles()
             self._open_logfiles(stdout_file, stderr_file)

-        return self.reset_config(new_config)
+        success = self.reset_config(new_config)
+        if not success:
+            return False
+
+        # Reset attributes. Will be overwritten by `restore` if a checkpoint
+        # is provided.
+        self._iteration = 0
+        self._time_total = 0.0
+        self._timesteps_total = None
+        self._episodes_total = None
+        self._time_since_restore = 0.0
+        self._timesteps_since_restore = 0
+        self._iterations_since_restore = 0
+        self._restored = False
+
+        return True

     def reset_config(self, new_config):
         """Resets configuration without restarting the trial.

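For class-based trainables the contract is unchanged: reuse only works if the user overrides reset_config and returns True. A minimal sketch, assuming the Ray 1.0-era setup/step API (illustrative, not from this commit):

    from ray import tune

    class MyTrainable(tune.Trainable):
        def setup(self, config):
            self.lr = config["lr"]

        def step(self):
            return {"lr": self.lr}

        def reset_config(self, new_config):
            # Re-apply whatever can change between trials.
            self.lr = new_config["lr"]
            # True signals that the in-place reset succeeded and the actor
            # may be reused; False forces a fresh actor.
            return True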

@@ -724,7 +724,6 @@ class TrialRunner:
         """
         try:
             result = self.trial_executor.fetch_result(trial)
-
             is_duplicate = RESULT_DUPLICATE in result
             force_checkpoint = result.get(SHOULD_CHECKPOINT, False)
             # TrialScheduler and SearchAlgorithm still receive a


@@ -1,5 +1,6 @@
 import logging
 import sys
+import time

 from ray.tune.error import TuneError
 from ray.tune.experiment import convert_to_experiment_list, Experiment
@@ -255,6 +256,7 @@ def run(
     Raises:
         TuneError: Any trials failed and `raise_on_failed_trial` is True.
     """
+    all_start = time.time()
     if global_checkpoint_period:
         raise ValueError("global_checkpoint_period is deprecated. Set env var "
                          "'TUNE_GLOBAL_CHECKPOINT_S' instead.")
@@ -404,10 +406,12 @@ def run(
             "`Trainable.default_resource_request` if using the "
             "Trainable API.")

+    tune_start = time.time()
     while not runner.is_finished():
         runner.step()
         if verbose:
             _report_progress(runner, progress_reporter)
+    tune_taken = time.time() - tune_start

     try:
         runner.checkpoint(force=True)
@@ -431,6 +435,10 @@ def run(
         else:
             logger.error("Trials did not complete: %s", incomplete_trials)

+    all_taken = time.time() - all_start
+    logger.info(f"Total run time: {all_taken:.2f} seconds "
+                f"({tune_taken:.2f} seconds for the tuning loop).")
+
     trials = runner.get_trials()
     return ExperimentAnalysis(
         runner.checkpoint_file,
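With these two timers, every run now ends with a summary log entry of the following shape (the numbers here are made up for illustration; only the format comes from the f-string above):

    Total run time: 131.07 seconds (128.52 seconds for the tuning loop).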