mirror of https://github.com/vale981/ray (synced 2025-03-06 10:31:39 -05:00)
[tune] reuse actors for function API (#11230)
Co-authored-by: Kristian Hartikainen <kristian.hartikainen@gmail.com>
parent 587319debc
commit b450cb030a
7 changed files with 154 additions and 28 deletions

@@ -608,6 +608,8 @@ These are the environment variables Ray Tune currently considers:
   or a search algorithm, Tune will error
   if the metric was not reported in the result. Setting this environment variable
   to ``1`` will disable this check.
+* **TUNE_FUNCTION_THREAD_TIMEOUT_S**: Time in seconds the function API waits
+  for threads to finish after instructing them to complete. Defaults to ``2``.
 * **TUNE_GLOBAL_CHECKPOINT_S**: Time in seconds that limits how often Tune's
   experiment state is checkpointed. If not set this will default to ``10``.
 * **TUNE_MAX_LEN_IDENTIFIER**: Maximum length of trial subdirectory names (those

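A rough usage sketch of what this change enables (the trainable, the config values, and the 5-second timeout below are illustrative, not from the diff): with ``reuse_actors=True`` the same Ray actor can now serve successive function-API trials, and ``TUNE_FUNCTION_THREAD_TIMEOUT_S`` bounds how long Tune waits for the previous trial's thread to exit when an actor is reset.

import os

import ray
from ray import tune


def my_trainable(config, checkpoint_dir=None):
    # Illustrative function-API trainable: report two results and stop.
    for step in range(2):
        tune.report(iter=step, id=config["id"])


# Assumption: give the worker thread 5 s (instead of the default 2 s) to
# finish when its actor is reset for reuse.
os.environ["TUNE_FUNCTION_THREAD_TIMEOUT_S"] = "5"

ray.init(num_cpus=1)
tune.run(
    my_trainable,
    config={"id": tune.grid_search([0, 1, 2, 3])},
    reuse_actors=True,  # reuse actors across trials, now also for the function API
    verbose=0,
)
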
@@ -1,5 +1,6 @@
 import logging
 import os
+import sys
 import time
 import inspect
 import shutil

@@ -120,12 +121,21 @@ class StatusReporter:
     def __init__(self,
                  result_queue,
                  continue_semaphore,
+                 end_event,
                  trial_name=None,
                  trial_id=None,
                  logdir=None):
         self._queue = result_queue
         self._last_report_time = None
         self._continue_semaphore = continue_semaphore
+        self._end_event = end_event
+        self._trial_name = trial_name
+        self._trial_id = trial_id
+        self._logdir = logdir
+        self._last_checkpoint = None
+        self._fresh_checkpoint = False
+
+    def reset(self, trial_name=None, trial_id=None, logdir=None):
         self._trial_name = trial_name
         self._trial_id = trial_id
         self._logdir = logdir

@@ -171,6 +181,11 @@ class StatusReporter:
         # resume training.
         self._continue_semaphore.acquire()
 
+        # If the trial should be terminated, exit gracefully.
+        if self._end_event.is_set():
+            self._end_event.clear()
+            sys.exit(0)
+
     def make_checkpoint_dir(self, step):
         checkpoint_dir = TrainableUtil.make_checkpoint_dir(
             self.logdir, index=step)

@@ -264,6 +279,10 @@ class FunctionRunner(Trainable):
         # and to generate the next result.
         self._continue_semaphore = threading.Semaphore(0)
 
+        # Event for notifying the reporter to exit gracefully, terminating
+        # the thread.
+        self._end_event = threading.Event()
+
         # Queue for passing results between threads
         self._results_queue = queue.Queue(1)
 
@@ -275,6 +294,7 @@ class FunctionRunner(Trainable):
         self._status_reporter = StatusReporter(
             self._results_queue,
             self._continue_semaphore,
+            self._end_event,
             trial_name=self.trial_name,
             trial_id=self.trial_id,
             logdir=self.logdir)

@@ -363,7 +383,7 @@ class FunctionRunner(Trainable):
         # This keyword appears if the train_func using the Function API
         # finishes without "done=True". This duplicates the last result, but
         # the TrialRunner will not log this result again.
-        if "__duplicate__" in result:
+        if RESULT_DUPLICATE in result:
             new_result = self._last_result.copy()
             new_result.update(result)
             result = new_result

@@ -441,6 +461,11 @@ class FunctionRunner(Trainable):
         self.restore(checkpoint_path)
 
     def cleanup(self):
+        # Trigger thread termination
+        self._end_event.set()
+        self._continue_semaphore.release()
+        # Do not wait for thread termination here.
+
         # If everything stayed in synch properly, this should never happen.
         if not self._results_queue.empty():
             logger.warning(

@@ -457,6 +482,29 @@ class FunctionRunner(Trainable):
             logger.debug("Clearing temporary checkpoint: %s",
                          self.temp_checkpoint_dir)
 
+    def reset_config(self, new_config):
+        if self._runner and self._runner.is_alive():
+            self._end_event.set()
+            self._continue_semaphore.release()
+            # Wait for thread termination so it is save to re-use the same
+            # actor.
+            thread_timeout = int(
+                os.environ.get("TUNE_FUNCTION_THREAD_TIMEOUT_S", 2))
+            self._runner.join(timeout=thread_timeout)
+            if self._runner.is_alive():
+                # Did not finish within timeout, reset unsuccessful.
+                return False
+
+        self._runner = None
+        self._last_result = {}
+
+        self._status_reporter.reset(
+            trial_name=self.trial_name,
+            trial_id=self.trial_id,
+            logdir=self.logdir)
+
+        return True
+
     def _report_thread_runner_error(self, block=False):
         try:
             err_tb_str = self._error_queue.get(

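A minimal standalone sketch (plain ``threading``, no Ray; variable names assumed) of the handshake that ``cleanup`` and ``reset_config`` above rely on: the driver sets the end event, releases the semaphore so the blocked worker wakes up, and joins with a bounded timeout before declaring the reset successful.

import sys
import threading

end_event = threading.Event()
continue_semaphore = threading.Semaphore(0)


def worker():
    while True:
        # ... run one training iteration and hand a result to the driver ...
        # Block until the driver asks for the next iteration (or for shutdown).
        continue_semaphore.acquire()
        if end_event.is_set():
            end_event.clear()
            sys.exit(0)  # SystemExit in a thread ends only that thread


runner = threading.Thread(target=worker)
runner.start()

# Reset path, mirroring the timeout handling in the hunk above:
end_event.set()
continue_semaphore.release()
runner.join(timeout=2)
reset_successful = not runner.is_alive()
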
@@ -14,6 +14,7 @@ from ray import ray_constants
 from ray.resource_spec import ResourceSpec
 from ray.tune.durable_trainable import DurableTrainable
 from ray.tune.error import AbortTrialExecution, TuneError
+from ray.tune.function_runner import FunctionRunner
 from ray.tune.logger import NoopLogger
 from ray.tune.result import TRIAL_INFO, STDOUT_FILE, STDERR_FILE
 from ray.tune.resources import Resources

@@ -276,13 +277,13 @@ class RayTrialExecutor(TrialExecutor):
         """
         prior_status = trial.status
         if runner is None:
-            # TODO: Right now, we only support reuse if there has been
-            # previously instantiated state on the worker. However,
-            # we should consider the case where function evaluations
-            # can be very fast - thereby extending the need to support
-            # reuse to cases where there has not been previously
-            # instantiated state before.
-            reuse_allowed = checkpoint is not None or trial.has_checkpoint()
+            # We reuse actors when there is previously instantiated state on
+            # the actor. Function API calls are also supported when there is
+            # no checkpoint to continue from.
+            # TODO: Check preconditions - why is previous state needed?
+            reuse_allowed = checkpoint is not None or trial.has_checkpoint() \
+                or issubclass(trial.get_trainable_cls(),
+                              FunctionRunner)
             runner = self._setup_remote_runner(trial, reuse_allowed)
             trial.set_runner(runner)
             self.restore(trial, checkpoint)

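The new condition can be read as the following standalone predicate (a sketch with hypothetical names, not executor code):

def actor_reuse_allowed(trial, checkpoint, function_runner_cls):
    # Class-API trainables are only reused when there is state to restore,
    # either an in-flight checkpoint object or a checkpoint saved by the trial.
    has_state = checkpoint is not None or trial.has_checkpoint()
    # Function-API trainables (FunctionRunner subclasses) may now be reused
    # even without prior state, since reset_config() restarts them cleanly.
    is_function_api = issubclass(trial.get_trainable_cls(), function_runner_cls)
    return has_state or is_function_api
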
@@ -1,11 +1,14 @@
 import os
+import pickle
 import unittest
 import sys
+from collections import defaultdict
 
 import ray
 from ray import tune, logger
 from ray.tune import Trainable, run_experiments, register_trainable
 from ray.tune.error import TuneError
+from ray.tune.function_runner import wrap_function
 from ray.tune.schedulers.trial_scheduler import FIFOScheduler, TrialScheduler
 
 
@@ -30,6 +33,7 @@ def create_resettable_class():
             logger.info("LOG_STDERR: {}".format(self.msg))
 
             return {
+                "id": self.config["id"],
                 "num_resets": self.num_resets,
                 "done": self.iter > 1,
                 "iter": self.iter

@@ -51,6 +55,35 @@ def create_resettable_class():
     return MyResettableClass
 
 
+def create_resettable_function(num_resets: defaultdict):
+    def trainable(config, checkpoint_dir=None):
+        if checkpoint_dir:
+            with open(os.path.join(checkpoint_dir, "chkpt"), "rb") as fp:
+                step = pickle.load(fp)
+        else:
+            step = 0
+
+        while step < 2:
+            step += 1
+            with tune.checkpoint_dir(step) as checkpoint_dir:
+                with open(os.path.join(checkpoint_dir, "chkpt"), "wb") as fp:
+                    pickle.dump(step, fp)
+            tune.report(**{
+                "done": step >= 2,
+                "iter": step,
+                "id": config["id"]
+            })
+
+    trainable = wrap_function(trainable)
+
+    class ResetCountTrainable(trainable):
+        def reset_config(self, new_config):
+            num_resets[self.trial_id] += 1
+            return super().reset_config(new_config)
+
+    return ResetCountTrainable
+
+
 class ActorReuseTest(unittest.TestCase):
     def setUp(self):
         ray.init(num_cpus=1, num_gpus=0)

@@ -58,38 +91,56 @@ class ActorReuseTest(unittest.TestCase):
     def tearDown(self):
         ray.shutdown()
 
-    def testTrialReuseDisabled(self):
+    def _run_trials_with_frequent_pauses(self, trainable, reuse=False):
         trials = run_experiments(
             {
                 "foo": {
-                    "run": create_resettable_class(),
-                    "num_samples": 4,
-                    "config": {},
+                    "run": trainable,
+                    "num_samples": 1,
+                    "config": {
+                        "id": tune.grid_search([0, 1, 2, 3])
+                    },
                 }
             },
-            reuse_actors=False,
+            reuse_actors=reuse,
             scheduler=FrequentPausesScheduler(),
             verbose=0)
+        return trials
+
+    def testTrialReuseDisabled(self):
+        trials = self._run_trials_with_frequent_pauses(
+            create_resettable_class(), reuse=False)
+        self.assertEqual([t.last_result["id"] for t in trials], [0, 1, 2, 3])
         self.assertEqual([t.last_result["iter"] for t in trials], [2, 2, 2, 2])
         self.assertEqual([t.last_result["num_resets"] for t in trials],
                          [0, 0, 0, 0])
 
+    def testTrialReuseDisabledFunction(self):
+        num_resets = defaultdict(lambda: 0)
+        trials = self._run_trials_with_frequent_pauses(
+            create_resettable_function(num_resets), reuse=False)
+        self.assertEqual([t.last_result["id"] for t in trials], [0, 1, 2, 3])
+        self.assertEqual([t.last_result["iter"] for t in trials], [2, 2, 2, 2])
+        self.assertEqual([num_resets[t.trial_id] for t in trials],
+                         [0, 0, 0, 0])
+
     def testTrialReuseEnabled(self):
-        trials = run_experiments(
-            {
-                "foo": {
-                    "run": create_resettable_class(),
-                    "num_samples": 4,
-                    "config": {},
-                }
-            },
-            reuse_actors=True,
-            scheduler=FrequentPausesScheduler(),
-            verbose=0)
+        trials = self._run_trials_with_frequent_pauses(
+            create_resettable_class(), reuse=True)
+        self.assertEqual([t.last_result["id"] for t in trials], [0, 1, 2, 3])
         self.assertEqual([t.last_result["iter"] for t in trials], [2, 2, 2, 2])
         self.assertEqual([t.last_result["num_resets"] for t in trials],
                          [1, 2, 3, 4])
 
+    def testTrialReuseEnabledFunction(self):
+        num_resets = defaultdict(lambda: 0)
+        trials = self._run_trials_with_frequent_pauses(
+            create_resettable_function(num_resets), reuse=True)
+        self.assertEqual([t.last_result["id"] for t in trials], [0, 1, 2, 3])
+        self.assertEqual([t.last_result["iter"] for t in trials], [2, 2, 2, 2])
+        self.assertEqual([num_resets[t.trial_id] for t in trials],
+                         [0, 0, 0, 0])
+
     def testReuseEnabledError(self):
         def run():
             run_experiments(

@@ -97,8 +148,9 @@ class ActorReuseTest(unittest.TestCase):
                     "foo": {
                         "run": create_resettable_class(),
                         "max_failures": 1,
-                        "num_samples": 4,
+                        "num_samples": 1,
                         "config": {
+                            "id": tune.grid_search([0, 1, 2, 3]),
                             "fake_reset_not_supported": True
                         },
                     }

@@ -115,7 +167,8 @@ class ActorReuseTest(unittest.TestCase):
         [trial1, trial2] = tune.run(
             "foo2",
             config={
-                "message": tune.grid_search(["First", "Second"])
+                "message": tune.grid_search(["First", "Second"]),
+                "id": -1
             },
             log_to_file=True,
             scheduler=FrequentPausesScheduler(),

@@ -552,7 +552,22 @@ class Trainable:
         self._close_logfiles()
         self._open_logfiles(stdout_file, stderr_file)
 
-        return self.reset_config(new_config)
+        success = self.reset_config(new_config)
+        if not success:
+            return False
+
+        # Reset attributes. Will be overwritten by `restore` if a checkpoint
+        # is provided.
+        self._iteration = 0
+        self._time_total = 0.0
+        self._timesteps_total = None
+        self._episodes_total = None
+        self._time_since_restore = 0.0
+        self._timesteps_since_restore = 0
+        self._iterations_since_restore = 0
+        self._restored = False
+
+        return True
 
     def reset_config(self, new_config):
         """Resets configuration without restarting the trial.

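For class-API trainables, reuse still hinges on ``reset_config`` returning ``True``; only then does ``reset`` re-initialize the counters above. A brief sketch of a user-defined Trainable that opts into reuse (class, metric, and config names are illustrative, assuming the setup/step-style Trainable API):

from ray import tune


class MyTrainable(tune.Trainable):
    def setup(self, config):
        self.lr = config["lr"]

    def step(self):
        # Report a dummy metric each training iteration.
        return {"score": self.lr}

    def reset_config(self, new_config):
        # Called in place of re-creating the actor when reuse_actors=True.
        # Returning True signals that the in-place reset succeeded.
        self.lr = new_config["lr"]
        return True
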
@@ -724,7 +724,6 @@ class TrialRunner:
         """
         try:
             result = self.trial_executor.fetch_result(trial)
-
             is_duplicate = RESULT_DUPLICATE in result
             force_checkpoint = result.get(SHOULD_CHECKPOINT, False)
             # TrialScheduler and SearchAlgorithm still receive a

@@ -1,5 +1,6 @@
 import logging
 import sys
+import time
 
 from ray.tune.error import TuneError
 from ray.tune.experiment import convert_to_experiment_list, Experiment

@@ -255,6 +256,7 @@ def run(
     Raises:
         TuneError: Any trials failed and `raise_on_failed_trial` is True.
     """
+    all_start = time.time()
    if global_checkpoint_period:
        raise ValueError("global_checkpoint_period is deprecated. Set env var "
                         "'TUNE_GLOBAL_CHECKPOINT_S' instead.")

@@ -404,10 +406,12 @@
             "`Trainable.default_resource_request` if using the "
             "Trainable API.")
 
+    tune_start = time.time()
     while not runner.is_finished():
         runner.step()
         if verbose:
             _report_progress(runner, progress_reporter)
+    tune_taken = time.time() - tune_start
 
     try:
         runner.checkpoint(force=True)

@@ -431,6 +435,10 @@
     else:
         logger.error("Trials did not complete: %s", incomplete_trials)
 
+    all_taken = time.time() - all_start
+    logger.info(f"Total run time: {all_taken:.2f} seconds "
+                f"({tune_taken:.2f} seconds for the tuning loop).")
+
     trials = runner.get_trials()
     return ExperimentAnalysis(
         runner.checkpoint_file,