[tune] reuse actors for function API (#11230)

Co-authored-by: Kristian Hartikainen <kristian.hartikainen@gmail.com>
This commit is contained in:
Kai Fricke 2020-10-09 00:15:02 +01:00 committed by GitHub
parent 587319debc
commit b450cb030a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 154 additions and 28 deletions

View file

@ -608,6 +608,8 @@ These are the environment variables Ray Tune currently considers:
or a search algorithm, Tune will error
if the metric was not reported in the result. Setting this environment variable
to ``1`` will disable this check.
* **TUNE_FUNCTION_THREAD_TIMEOUT_S**: Time in seconds the function API waits
for threads to finish after instructing them to complete. Defaults to ``2``.
* **TUNE_GLOBAL_CHECKPOINT_S**: Time in seconds that limits how often Tune's
experiment state is checkpointed. If not set this will default to ``10``.
* **TUNE_MAX_LEN_IDENTIFIER**: Maximum length of trial subdirectory names (those

View file

@ -1,5 +1,6 @@
import logging
import os
import sys
import time
import inspect
import shutil
@ -120,12 +121,21 @@ class StatusReporter:
def __init__(self,
result_queue,
continue_semaphore,
end_event,
trial_name=None,
trial_id=None,
logdir=None):
self._queue = result_queue
self._last_report_time = None
self._continue_semaphore = continue_semaphore
self._end_event = end_event
self._trial_name = trial_name
self._trial_id = trial_id
self._logdir = logdir
self._last_checkpoint = None
self._fresh_checkpoint = False
def reset(self, trial_name=None, trial_id=None, logdir=None):
self._trial_name = trial_name
self._trial_id = trial_id
self._logdir = logdir
@ -171,6 +181,11 @@ class StatusReporter:
# resume training.
self._continue_semaphore.acquire()
# If the trial should be terminated, exit gracefully.
if self._end_event.is_set():
self._end_event.clear()
sys.exit(0)
def make_checkpoint_dir(self, step):
checkpoint_dir = TrainableUtil.make_checkpoint_dir(
self.logdir, index=step)
@ -264,6 +279,10 @@ class FunctionRunner(Trainable):
# and to generate the next result.
self._continue_semaphore = threading.Semaphore(0)
# Event for notifying the reporter to exit gracefully, terminating
# the thread.
self._end_event = threading.Event()
# Queue for passing results between threads
self._results_queue = queue.Queue(1)
@ -275,6 +294,7 @@ class FunctionRunner(Trainable):
self._status_reporter = StatusReporter(
self._results_queue,
self._continue_semaphore,
self._end_event,
trial_name=self.trial_name,
trial_id=self.trial_id,
logdir=self.logdir)
@ -363,7 +383,7 @@ class FunctionRunner(Trainable):
# This keyword appears if the train_func using the Function API
# finishes without "done=True". This duplicates the last result, but
# the TrialRunner will not log this result again.
if "__duplicate__" in result:
if RESULT_DUPLICATE in result:
new_result = self._last_result.copy()
new_result.update(result)
result = new_result
@ -441,6 +461,11 @@ class FunctionRunner(Trainable):
self.restore(checkpoint_path)
def cleanup(self):
# Trigger thread termination
self._end_event.set()
self._continue_semaphore.release()
# Do not wait for thread termination here.
# If everything stayed in synch properly, this should never happen.
if not self._results_queue.empty():
logger.warning(
@ -457,6 +482,29 @@ class FunctionRunner(Trainable):
logger.debug("Clearing temporary checkpoint: %s",
self.temp_checkpoint_dir)
def reset_config(self, new_config):
if self._runner and self._runner.is_alive():
self._end_event.set()
self._continue_semaphore.release()
# Wait for thread termination so it is save to re-use the same
# actor.
thread_timeout = int(
os.environ.get("TUNE_FUNCTION_THREAD_TIMEOUT_S", 2))
self._runner.join(timeout=thread_timeout)
if self._runner.is_alive():
# Did not finish within timeout, reset unsuccessful.
return False
self._runner = None
self._last_result = {}
self._status_reporter.reset(
trial_name=self.trial_name,
trial_id=self.trial_id,
logdir=self.logdir)
return True
def _report_thread_runner_error(self, block=False):
try:
err_tb_str = self._error_queue.get(

View file

@ -14,6 +14,7 @@ from ray import ray_constants
from ray.resource_spec import ResourceSpec
from ray.tune.durable_trainable import DurableTrainable
from ray.tune.error import AbortTrialExecution, TuneError
from ray.tune.function_runner import FunctionRunner
from ray.tune.logger import NoopLogger
from ray.tune.result import TRIAL_INFO, STDOUT_FILE, STDERR_FILE
from ray.tune.resources import Resources
@ -276,13 +277,13 @@ class RayTrialExecutor(TrialExecutor):
"""
prior_status = trial.status
if runner is None:
# TODO: Right now, we only support reuse if there has been
# previously instantiated state on the worker. However,
# we should consider the case where function evaluations
# can be very fast - thereby extending the need to support
# reuse to cases where there has not been previously
# instantiated state before.
reuse_allowed = checkpoint is not None or trial.has_checkpoint()
# We reuse actors when there is previously instantiated state on
# the actor. Function API calls are also supported when there is
# no checkpoint to continue from.
# TODO: Check preconditions - why is previous state needed?
reuse_allowed = checkpoint is not None or trial.has_checkpoint() \
or issubclass(trial.get_trainable_cls(),
FunctionRunner)
runner = self._setup_remote_runner(trial, reuse_allowed)
trial.set_runner(runner)
self.restore(trial, checkpoint)

View file

@ -1,11 +1,14 @@
import os
import pickle
import unittest
import sys
from collections import defaultdict
import ray
from ray import tune, logger
from ray.tune import Trainable, run_experiments, register_trainable
from ray.tune.error import TuneError
from ray.tune.function_runner import wrap_function
from ray.tune.schedulers.trial_scheduler import FIFOScheduler, TrialScheduler
@ -30,6 +33,7 @@ def create_resettable_class():
logger.info("LOG_STDERR: {}".format(self.msg))
return {
"id": self.config["id"],
"num_resets": self.num_resets,
"done": self.iter > 1,
"iter": self.iter
@ -51,6 +55,35 @@ def create_resettable_class():
return MyResettableClass
def create_resettable_function(num_resets: defaultdict):
def trainable(config, checkpoint_dir=None):
if checkpoint_dir:
with open(os.path.join(checkpoint_dir, "chkpt"), "rb") as fp:
step = pickle.load(fp)
else:
step = 0
while step < 2:
step += 1
with tune.checkpoint_dir(step) as checkpoint_dir:
with open(os.path.join(checkpoint_dir, "chkpt"), "wb") as fp:
pickle.dump(step, fp)
tune.report(**{
"done": step >= 2,
"iter": step,
"id": config["id"]
})
trainable = wrap_function(trainable)
class ResetCountTrainable(trainable):
def reset_config(self, new_config):
num_resets[self.trial_id] += 1
return super().reset_config(new_config)
return ResetCountTrainable
class ActorReuseTest(unittest.TestCase):
def setUp(self):
ray.init(num_cpus=1, num_gpus=0)
@ -58,38 +91,56 @@ class ActorReuseTest(unittest.TestCase):
def tearDown(self):
ray.shutdown()
def testTrialReuseDisabled(self):
def _run_trials_with_frequent_pauses(self, trainable, reuse=False):
trials = run_experiments(
{
"foo": {
"run": create_resettable_class(),
"num_samples": 4,
"config": {},
"run": trainable,
"num_samples": 1,
"config": {
"id": tune.grid_search([0, 1, 2, 3])
},
}
},
reuse_actors=False,
reuse_actors=reuse,
scheduler=FrequentPausesScheduler(),
verbose=0)
return trials
def testTrialReuseDisabled(self):
trials = self._run_trials_with_frequent_pauses(
create_resettable_class(), reuse=False)
self.assertEqual([t.last_result["id"] for t in trials], [0, 1, 2, 3])
self.assertEqual([t.last_result["iter"] for t in trials], [2, 2, 2, 2])
self.assertEqual([t.last_result["num_resets"] for t in trials],
[0, 0, 0, 0])
def testTrialReuseDisabledFunction(self):
num_resets = defaultdict(lambda: 0)
trials = self._run_trials_with_frequent_pauses(
create_resettable_function(num_resets), reuse=False)
self.assertEqual([t.last_result["id"] for t in trials], [0, 1, 2, 3])
self.assertEqual([t.last_result["iter"] for t in trials], [2, 2, 2, 2])
self.assertEqual([num_resets[t.trial_id] for t in trials],
[0, 0, 0, 0])
def testTrialReuseEnabled(self):
trials = run_experiments(
{
"foo": {
"run": create_resettable_class(),
"num_samples": 4,
"config": {},
}
},
reuse_actors=True,
scheduler=FrequentPausesScheduler(),
verbose=0)
trials = self._run_trials_with_frequent_pauses(
create_resettable_class(), reuse=True)
self.assertEqual([t.last_result["id"] for t in trials], [0, 1, 2, 3])
self.assertEqual([t.last_result["iter"] for t in trials], [2, 2, 2, 2])
self.assertEqual([t.last_result["num_resets"] for t in trials],
[1, 2, 3, 4])
def testTrialReuseEnabledFunction(self):
num_resets = defaultdict(lambda: 0)
trials = self._run_trials_with_frequent_pauses(
create_resettable_function(num_resets), reuse=True)
self.assertEqual([t.last_result["id"] for t in trials], [0, 1, 2, 3])
self.assertEqual([t.last_result["iter"] for t in trials], [2, 2, 2, 2])
self.assertEqual([num_resets[t.trial_id] for t in trials],
[0, 0, 0, 0])
def testReuseEnabledError(self):
def run():
run_experiments(
@ -97,8 +148,9 @@ class ActorReuseTest(unittest.TestCase):
"foo": {
"run": create_resettable_class(),
"max_failures": 1,
"num_samples": 4,
"num_samples": 1,
"config": {
"id": tune.grid_search([0, 1, 2, 3]),
"fake_reset_not_supported": True
},
}
@ -115,7 +167,8 @@ class ActorReuseTest(unittest.TestCase):
[trial1, trial2] = tune.run(
"foo2",
config={
"message": tune.grid_search(["First", "Second"])
"message": tune.grid_search(["First", "Second"]),
"id": -1
},
log_to_file=True,
scheduler=FrequentPausesScheduler(),

View file

@ -552,7 +552,22 @@ class Trainable:
self._close_logfiles()
self._open_logfiles(stdout_file, stderr_file)
return self.reset_config(new_config)
success = self.reset_config(new_config)
if not success:
return False
# Reset attributes. Will be overwritten by `restore` if a checkpoint
# is provided.
self._iteration = 0
self._time_total = 0.0
self._timesteps_total = None
self._episodes_total = None
self._time_since_restore = 0.0
self._timesteps_since_restore = 0
self._iterations_since_restore = 0
self._restored = False
return True
def reset_config(self, new_config):
"""Resets configuration without restarting the trial.

View file

@ -724,7 +724,6 @@ class TrialRunner:
"""
try:
result = self.trial_executor.fetch_result(trial)
is_duplicate = RESULT_DUPLICATE in result
force_checkpoint = result.get(SHOULD_CHECKPOINT, False)
# TrialScheduler and SearchAlgorithm still receive a

View file

@ -1,5 +1,6 @@
import logging
import sys
import time
from ray.tune.error import TuneError
from ray.tune.experiment import convert_to_experiment_list, Experiment
@ -255,6 +256,7 @@ def run(
Raises:
TuneError: Any trials failed and `raise_on_failed_trial` is True.
"""
all_start = time.time()
if global_checkpoint_period:
raise ValueError("global_checkpoint_period is deprecated. Set env var "
"'TUNE_GLOBAL_CHECKPOINT_S' instead.")
@ -404,10 +406,12 @@ def run(
"`Trainable.default_resource_request` if using the "
"Trainable API.")
tune_start = time.time()
while not runner.is_finished():
runner.step()
if verbose:
_report_progress(runner, progress_reporter)
tune_taken = time.time() - tune_start
try:
runner.checkpoint(force=True)
@ -431,6 +435,10 @@ def run(
else:
logger.error("Trials did not complete: %s", incomplete_trials)
all_taken = time.time() - all_start
logger.info(f"Total run time: {all_taken:.2f} seconds "
f"({tune_taken:.2f} seconds for the tuning loop).")
trials = runner.get_trials()
return ExperimentAnalysis(
runner.checkpoint_file,