[tune] reuse actors for function API (#11230)
Co-authored-by: Kristian Hartikainen <kristian.hartikainen@gmail.com>
Parent: 587319debc
Commit: b450cb030a

7 changed files with 154 additions and 28 deletions
@@ -608,6 +608,8 @@ These are the environment variables Ray Tune currently considers:
  or a search algorithm, Tune will error
  if the metric was not reported in the result. Setting this environment variable
  to ``1`` will disable this check.
+ * **TUNE_FUNCTION_THREAD_TIMEOUT_S**: Time in seconds the function API waits
+   for threads to finish after instructing them to complete. Defaults to ``2``.
* **TUNE_GLOBAL_CHECKPOINT_S**: Time in seconds that limits how often Tune's
  experiment state is checkpointed. If not set this will default to ``10``.
* **TUNE_MAX_LEN_IDENTIFIER**: Maximum length of trial subdirectory names (those
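As a usage sketch (not part of this commit's diff): like the other settings listed above, the new timeout is set through an environment variable before Tune starts. The values and the trainable function below are only illustrative.

import os
from ray import tune

# Give function-API threads more time to wind down (default is 2 seconds).
os.environ["TUNE_FUNCTION_THREAD_TIMEOUT_S"] = "5"
# Checkpoint the experiment state at most every 30 seconds (default is 10).
os.environ["TUNE_GLOBAL_CHECKPOINT_S"] = "30"

def trainable(config):
    tune.report(metric=1)

tune.run(trainable)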
@@ -1,5 +1,6 @@
import logging
import os
import sys
import time
import inspect
import shutil
@@ -120,12 +121,21 @@ class StatusReporter:
    def __init__(self,
                 result_queue,
                 continue_semaphore,
                 end_event,
                 trial_name=None,
                 trial_id=None,
                 logdir=None):
        self._queue = result_queue
        self._last_report_time = None
        self._continue_semaphore = continue_semaphore
        self._end_event = end_event
        self._trial_name = trial_name
        self._trial_id = trial_id
        self._logdir = logdir
        self._last_checkpoint = None
        self._fresh_checkpoint = False

    def reset(self, trial_name=None, trial_id=None, logdir=None):
        self._trial_name = trial_name
        self._trial_id = trial_id
        self._logdir = logdir
@@ -171,6 +181,11 @@ class StatusReporter:
        # resume training.
        self._continue_semaphore.acquire()

        # If the trial should be terminated, exit gracefully.
        if self._end_event.is_set():
            self._end_event.clear()
            sys.exit(0)

    def make_checkpoint_dir(self, step):
        checkpoint_dir = TrainableUtil.make_checkpoint_dir(
            self.logdir, index=step)
@@ -264,6 +279,10 @@ class FunctionRunner(Trainable):
        # and to generate the next result.
        self._continue_semaphore = threading.Semaphore(0)

        # Event for notifying the reporter to exit gracefully, terminating
        # the thread.
        self._end_event = threading.Event()

        # Queue for passing results between threads
        self._results_queue = queue.Queue(1)
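For orientation, here is a minimal, self-contained sketch of the handshake these three primitives implement (generic threading code, not Tune's actual classes): the runner releases the semaphore when it wants the next result and sets the event when the worker thread should exit.

import queue
import threading

continue_semaphore = threading.Semaphore(0)
end_event = threading.Event()
results = queue.Queue(1)

def worker():
    step = 0
    while True:
        step += 1
        results.put({"step": step})    # "report" a result
        continue_semaphore.acquire()   # block until asked to continue
        if end_event.is_set():         # asked to terminate, so exit the thread
            return

thread = threading.Thread(target=worker)
thread.start()

print(results.get())            # {'step': 1}
continue_semaphore.release()    # request the next result
print(results.get())            # {'step': 2}

end_event.set()                 # signal termination ...
continue_semaphore.release()    # ... and unblock the waiting worker
thread.join(timeout=2)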
@@ -275,6 +294,7 @@ class FunctionRunner(Trainable):
        self._status_reporter = StatusReporter(
            self._results_queue,
            self._continue_semaphore,
            self._end_event,
            trial_name=self.trial_name,
            trial_id=self.trial_id,
            logdir=self.logdir)
@@ -363,7 +383,7 @@ class FunctionRunner(Trainable):
        # This keyword appears if the train_func using the Function API
        # finishes without "done=True". This duplicates the last result, but
        # the TrialRunner will not log this result again.
-       if "__duplicate__" in result:
+       if RESULT_DUPLICATE in result:
            new_result = self._last_result.copy()
            new_result.update(result)
            result = new_result
@@ -441,6 +461,11 @@ class FunctionRunner(Trainable):
        self.restore(checkpoint_path)

    def cleanup(self):
        # Trigger thread termination
        self._end_event.set()
        self._continue_semaphore.release()
        # Do not wait for thread termination here.

        # If everything stayed in synch properly, this should never happen.
        if not self._results_queue.empty():
            logger.warning(
@@ -457,6 +482,29 @@ class FunctionRunner(Trainable):
            logger.debug("Clearing temporary checkpoint: %s",
                         self.temp_checkpoint_dir)

    def reset_config(self, new_config):
        if self._runner and self._runner.is_alive():
            self._end_event.set()
            self._continue_semaphore.release()
            # Wait for thread termination so it is safe to re-use the same
            # actor.
            thread_timeout = int(
                os.environ.get("TUNE_FUNCTION_THREAD_TIMEOUT_S", 2))
            self._runner.join(timeout=thread_timeout)
            if self._runner.is_alive():
                # Did not finish within timeout, reset unsuccessful.
                return False

        self._runner = None
        self._last_result = {}

        self._status_reporter.reset(
            trial_name=self.trial_name,
            trial_id=self.trial_id,
            logdir=self.logdir)

        return True

    def _report_thread_runner_error(self, block=False):
        try:
            err_tb_str = self._error_queue.get(
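Taken together, these changes are what let reuse_actors=True work with the function API. A minimal usage sketch, not part of this diff; the trainable and config keys are illustrative.

from ray import tune

def trainable(config, checkpoint_dir=None):
    # Ordinary function-API trainable; nothing reuse-specific is needed here.
    for step in range(1, 3):
        tune.report(iter=step, id=config["id"], done=step >= 2)

analysis = tune.run(
    trainable,
    config={"id": tune.grid_search([0, 1, 2, 3])},
    # Previously only class-based Trainables could be reused; with this change
    # the same actor can also be recycled between function-API trials.
    reuse_actors=True)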
@@ -14,6 +14,7 @@ from ray import ray_constants
from ray.resource_spec import ResourceSpec
from ray.tune.durable_trainable import DurableTrainable
from ray.tune.error import AbortTrialExecution, TuneError
+from ray.tune.function_runner import FunctionRunner
from ray.tune.logger import NoopLogger
from ray.tune.result import TRIAL_INFO, STDOUT_FILE, STDERR_FILE
from ray.tune.resources import Resources
@@ -276,13 +277,13 @@ class RayTrialExecutor(TrialExecutor):
        """
        prior_status = trial.status
        if runner is None:
-           # TODO: Right now, we only support reuse if there has been
-           # previously instantiated state on the worker. However,
-           # we should consider the case where function evaluations
-           # can be very fast - thereby extending the need to support
-           # reuse to cases where there has not been previously
-           # instantiated state before.
-           reuse_allowed = checkpoint is not None or trial.has_checkpoint()
+           # We reuse actors when there is previously instantiated state on
+           # the actor. Function API calls are also supported when there is
+           # no checkpoint to continue from.
+           # TODO: Check preconditions - why is previous state needed?
+           reuse_allowed = checkpoint is not None or trial.has_checkpoint() \
+               or issubclass(trial.get_trainable_cls(),
+                             FunctionRunner)
            runner = self._setup_remote_runner(trial, reuse_allowed)
            trial.set_runner(runner)
        self.restore(trial, checkpoint)
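The issubclass check above works because wrap_function turns a plain training function into a Trainable class that derives from FunctionRunner. A small sketch assuming only imports already used elsewhere in this diff; my_trainable is a hypothetical function used for illustration.

from ray.tune.function_runner import FunctionRunner, wrap_function

def my_trainable(config):
    pass

# wrap_function returns a Trainable subclass that inherits from FunctionRunner,
# which is what the executor's reuse check relies on.
WrappedCls = wrap_function(my_trainable)
assert issubclass(WrappedCls, FunctionRunner)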
@@ -1,11 +1,14 @@
import os
import pickle
import unittest
import sys
from collections import defaultdict

import ray
from ray import tune, logger
from ray.tune import Trainable, run_experiments, register_trainable
from ray.tune.error import TuneError
from ray.tune.function_runner import wrap_function
from ray.tune.schedulers.trial_scheduler import FIFOScheduler, TrialScheduler
@@ -30,6 +33,7 @@ def create_resettable_class():
            logger.info("LOG_STDERR: {}".format(self.msg))

            return {
+               "id": self.config["id"],
                "num_resets": self.num_resets,
                "done": self.iter > 1,
                "iter": self.iter
@@ -51,6 +55,35 @@ def create_resettable_class():
    return MyResettableClass


def create_resettable_function(num_resets: defaultdict):
    def trainable(config, checkpoint_dir=None):
        if checkpoint_dir:
            with open(os.path.join(checkpoint_dir, "chkpt"), "rb") as fp:
                step = pickle.load(fp)
        else:
            step = 0

        while step < 2:
            step += 1
            with tune.checkpoint_dir(step) as checkpoint_dir:
                with open(os.path.join(checkpoint_dir, "chkpt"), "wb") as fp:
                    pickle.dump(step, fp)
            tune.report(**{
                "done": step >= 2,
                "iter": step,
                "id": config["id"]
            })

    trainable = wrap_function(trainable)

    class ResetCountTrainable(trainable):
        def reset_config(self, new_config):
            num_resets[self.trial_id] += 1
            return super().reset_config(new_config)

    return ResetCountTrainable


class ActorReuseTest(unittest.TestCase):
    def setUp(self):
        ray.init(num_cpus=1, num_gpus=0)
@@ -58,38 +91,56 @@ class ActorReuseTest(unittest.TestCase):
    def tearDown(self):
        ray.shutdown()

-   def testTrialReuseDisabled(self):
+   def _run_trials_with_frequent_pauses(self, trainable, reuse=False):
        trials = run_experiments(
            {
                "foo": {
-                   "run": create_resettable_class(),
-                   "num_samples": 4,
-                   "config": {},
+                   "run": trainable,
+                   "num_samples": 1,
+                   "config": {
+                       "id": tune.grid_search([0, 1, 2, 3])
+                   },
                }
            },
-           reuse_actors=False,
+           reuse_actors=reuse,
            scheduler=FrequentPausesScheduler(),
            verbose=0)
+       return trials
+
+   def testTrialReuseDisabled(self):
+       trials = self._run_trials_with_frequent_pauses(
+           create_resettable_class(), reuse=False)
        self.assertEqual([t.last_result["id"] for t in trials], [0, 1, 2, 3])
        self.assertEqual([t.last_result["iter"] for t in trials], [2, 2, 2, 2])
        self.assertEqual([t.last_result["num_resets"] for t in trials],
                         [0, 0, 0, 0])

+   def testTrialReuseDisabledFunction(self):
+       num_resets = defaultdict(lambda: 0)
+       trials = self._run_trials_with_frequent_pauses(
+           create_resettable_function(num_resets), reuse=False)
+       self.assertEqual([t.last_result["id"] for t in trials], [0, 1, 2, 3])
+       self.assertEqual([t.last_result["iter"] for t in trials], [2, 2, 2, 2])
+       self.assertEqual([num_resets[t.trial_id] for t in trials],
+                        [0, 0, 0, 0])
+
    def testTrialReuseEnabled(self):
-       trials = run_experiments(
-           {
-               "foo": {
-                   "run": create_resettable_class(),
-                   "num_samples": 4,
-                   "config": {},
-               }
-           },
-           reuse_actors=True,
-           scheduler=FrequentPausesScheduler(),
-           verbose=0)
+       trials = self._run_trials_with_frequent_pauses(
+           create_resettable_class(), reuse=True)
        self.assertEqual([t.last_result["id"] for t in trials], [0, 1, 2, 3])
        self.assertEqual([t.last_result["iter"] for t in trials], [2, 2, 2, 2])
        self.assertEqual([t.last_result["num_resets"] for t in trials],
                         [1, 2, 3, 4])

+   def testTrialReuseEnabledFunction(self):
+       num_resets = defaultdict(lambda: 0)
+       trials = self._run_trials_with_frequent_pauses(
+           create_resettable_function(num_resets), reuse=True)
+       self.assertEqual([t.last_result["id"] for t in trials], [0, 1, 2, 3])
+       self.assertEqual([t.last_result["iter"] for t in trials], [2, 2, 2, 2])
+       self.assertEqual([num_resets[t.trial_id] for t in trials],
+                        [0, 0, 0, 0])
+
    def testReuseEnabledError(self):
        def run():
            run_experiments(
@@ -97,8 +148,9 @@ class ActorReuseTest(unittest.TestCase):
                    "foo": {
                        "run": create_resettable_class(),
                        "max_failures": 1,
-                       "num_samples": 4,
+                       "num_samples": 1,
                        "config": {
+                           "id": tune.grid_search([0, 1, 2, 3]),
                            "fake_reset_not_supported": True
                        },
                    }
@@ -115,7 +167,8 @@ class ActorReuseTest(unittest.TestCase):
        [trial1, trial2] = tune.run(
            "foo2",
            config={
-               "message": tune.grid_search(["First", "Second"])
+               "message": tune.grid_search(["First", "Second"]),
+               "id": -1
            },
            log_to_file=True,
            scheduler=FrequentPausesScheduler(),
@@ -552,7 +552,22 @@ class Trainable:
        self._close_logfiles()
        self._open_logfiles(stdout_file, stderr_file)

-       return self.reset_config(new_config)
+       success = self.reset_config(new_config)
+       if not success:
+           return False
+
+       # Reset attributes. Will be overwritten by `restore` if a checkpoint
+       # is provided.
+       self._iteration = 0
+       self._time_total = 0.0
+       self._timesteps_total = None
+       self._episodes_total = None
+       self._time_since_restore = 0.0
+       self._timesteps_since_restore = 0
+       self._iterations_since_restore = 0
+       self._restored = False
+
+       return True

    def reset_config(self, new_config):
        """Resets configuration without restarting the trial.
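For context, this is the class-API contract that reset() builds on: a Trainable opts into actor reuse by implementing reset_config() and returning True; returning False makes Tune fall back to stopping the actor and starting a fresh one. A hedged sketch with illustrative names, not code from this diff.

from ray import tune
from ray.tune import Trainable

class MyTrainable(Trainable):
    def setup(self, config):
        self.lr = config["lr"]
        self.iter = 0

    def step(self):
        self.iter += 1
        return {"iter": self.iter, "done": self.iter >= 2}

    def reset_config(self, new_config):
        # Re-initialize in place; counters such as _iteration are cleared by
        # Trainable.reset() above once this returns True.
        self.lr = new_config["lr"]
        self.iter = 0
        return True

tune.run(
    MyTrainable,
    config={"lr": tune.grid_search([0.1, 0.01])},
    reuse_actors=True)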
@@ -724,7 +724,6 @@ class TrialRunner:
        """
        try:
            result = self.trial_executor.fetch_result(trial)

            is_duplicate = RESULT_DUPLICATE in result
            force_checkpoint = result.get(SHOULD_CHECKPOINT, False)
            # TrialScheduler and SearchAlgorithm still receive a
@@ -1,5 +1,6 @@
import logging
import sys
+import time

from ray.tune.error import TuneError
from ray.tune.experiment import convert_to_experiment_list, Experiment
@@ -255,6 +256,7 @@ def run(
    Raises:
        TuneError: Any trials failed and `raise_on_failed_trial` is True.
    """
+   all_start = time.time()
    if global_checkpoint_period:
        raise ValueError("global_checkpoint_period is deprecated. Set env var "
                         "'TUNE_GLOBAL_CHECKPOINT_S' instead.")
@@ -404,10 +406,12 @@ def run(
            "`Trainable.default_resource_request` if using the "
            "Trainable API.")

+   tune_start = time.time()
    while not runner.is_finished():
        runner.step()
        if verbose:
            _report_progress(runner, progress_reporter)
+   tune_taken = time.time() - tune_start

    try:
        runner.checkpoint(force=True)
@@ -431,6 +435,10 @@ def run(
    else:
        logger.error("Trials did not complete: %s", incomplete_trials)

+   all_taken = time.time() - all_start
+   logger.info(f"Total run time: {all_taken:.2f} seconds "
+               f"({tune_taken:.2f} seconds for the tuning loop).")
+
    trials = runner.get_trials()
    return ExperimentAnalysis(
        runner.checkpoint_file,