mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[tune] Support true pooling and batched concurrency (#10352)
This commit is contained in:
parent
e5d089384b
commit
09d4a3241f
7 changed files with 98 additions and 38 deletions
|
@ -1,5 +1,6 @@
|
|||
# coding: utf-8
|
||||
import copy
|
||||
from functools import partial
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
|
@ -14,7 +15,7 @@ from ray.resource_spec import ResourceSpec
|
|||
from ray.tune.durable_trainable import DurableTrainable
|
||||
from ray.tune.error import AbortTrialExecution, TuneError
|
||||
from ray.tune.logger import NoopLogger
|
||||
from ray.tune.result import TRIAL_INFO, LOGDIR_PATH, STDOUT_FILE, STDERR_FILE
|
||||
from ray.tune.result import TRIAL_INFO, STDOUT_FILE, STDERR_FILE
|
||||
from ray.tune.resources import Resources
|
||||
from ray.tune.trainable import TrainableUtil
|
||||
from ray.tune.trial import Trial, Checkpoint, Location, TrialInfo
|
||||
|
@ -122,6 +123,14 @@ class _TrialCleanup:
|
|||
del self._cleanup_map[done]
|
||||
|
||||
|
||||
def noop_logger_creator(config, logdir):
    """Prepare ``logdir`` and return a ``NoopLogger`` bound to it.

    Args:
        config: Configuration forwarded unchanged to ``NoopLogger``.
        logdir: Directory to create (if missing) and hand to the logger.

    Returns:
        A ``NoopLogger`` constructed from ``config`` and ``logdir``.
    """
    os.makedirs(logdir, exist_ok=True)

    # Set the working dir in the remote process, for user file writes.
    # The chdir is skipped when Ray reports local mode.
    running_locally = ray.worker._mode() == ray.worker.LOCAL_MODE
    if not running_locally:
        os.chdir(logdir)

    return NoopLogger(config, logdir)
|
||||
|
||||
|
||||
class RayTrialExecutor(TrialExecutor):
|
||||
"""An implementation of TrialExecutor based on Ray."""
|
||||
|
||||
|
@ -163,7 +172,7 @@ class RayTrialExecutor(TrialExecutor):
|
|||
trial.init_logger()
|
||||
# We checkpoint metadata here to try mitigating logdir duplication
|
||||
self.try_checkpoint_metadata(trial)
|
||||
remote_logdir = trial.logdir
|
||||
logger_creator = partial(noop_logger_creator, logdir=trial.logdir)
|
||||
|
||||
if (self._reuse_actors and reuse_allowed
|
||||
and self._cached_actor is not None):
|
||||
|
@ -172,7 +181,8 @@ class RayTrialExecutor(TrialExecutor):
|
|||
existing_runner = self._cached_actor
|
||||
self._cached_actor = None
|
||||
trial.set_runner(existing_runner)
|
||||
if not self.reset_trial(trial, trial.config, trial.experiment_tag):
|
||||
if not self.reset_trial(trial, trial.config, trial.experiment_tag,
|
||||
logger_creator):
|
||||
raise AbortTrialExecution(
|
||||
"Trainable runner reuse requires reset_config() to be "
|
||||
"implemented and return True.")
|
||||
|
@ -192,15 +202,6 @@ class RayTrialExecutor(TrialExecutor):
|
|||
memory=trial.resources.memory or None,
|
||||
object_store_memory=trial.resources.object_store_memory or None,
|
||||
resources=trial.resources.custom_resources)
|
||||
|
||||
def logger_creator(config):
|
||||
# Set the working dir in the remote process, for user file writes
|
||||
logdir = config.pop(LOGDIR_PATH, remote_logdir)
|
||||
os.makedirs(logdir, exist_ok=True)
|
||||
if not ray.worker._mode() == ray.worker.LOCAL_MODE:
|
||||
os.chdir(logdir)
|
||||
return NoopLogger(config, logdir)
|
||||
|
||||
# Clear the Trial's location (to be updated later on result)
|
||||
# since we don't know where the remote runner is placed.
|
||||
trial.set_location(Location())
|
||||
|
@ -268,6 +269,12 @@ class RayTrialExecutor(TrialExecutor):
|
|||
"""
|
||||
prior_status = trial.status
|
||||
if runner is None:
|
||||
# TODO: Right now, we only support reuse if state has already
|
||||
# been instantiated on the worker. However, function
|
||||
# evaluations can be very fast, so we should consider
|
||||
# extending reuse to cases where no state has been
|
||||
# instantiated on the worker yet.
|
||||
reuse_allowed = checkpoint is not None or trial.has_checkpoint()
|
||||
runner = self._setup_remote_runner(trial, reuse_allowed)
|
||||
trial.set_runner(runner)
|
||||
|
@ -377,13 +384,19 @@ class RayTrialExecutor(TrialExecutor):
|
|||
self._paused[trial_future[0]] = trial
|
||||
super(RayTrialExecutor, self).pause_trial(trial)
|
||||
|
||||
def reset_trial(self, trial, new_config, new_experiment_tag):
|
||||
def reset_trial(self,
|
||||
trial,
|
||||
new_config,
|
||||
new_experiment_tag,
|
||||
logger_creator=None):
|
||||
"""Tries to invoke `Trainable.reset()` to reset trial.
|
||||
|
||||
Args:
|
||||
trial (Trial): Trial to be reset.
|
||||
new_config (dict): New configuration for Trial trainable.
|
||||
new_experiment_tag (str): New experiment name for trial.
|
||||
logger_creator (Callable[[Dict], Logger]): A function that
|
||||
instantiates a logger on the actor process.
|
||||
|
||||
Returns:
|
||||
True if `reset_config` is successful else False.
|
||||
|
@ -395,7 +408,7 @@ class RayTrialExecutor(TrialExecutor):
|
|||
with warn_if_slow("reset"):
|
||||
try:
|
||||
reset_val = ray.get(
|
||||
trainable.reset.remote(new_config, trial.logdir),
|
||||
trainable.reset.remote(new_config, logger_creator),
|
||||
timeout=DEFAULT_GET_TIMEOUT)
|
||||
except GetTimeoutError:
|
||||
logger.exception("Trial %s: reset timed out.", trial)
|
||||
|
|
|
@ -74,10 +74,6 @@ TRIAL_INFO = "__trial_info__"
|
|||
STDOUT_FILE = "__stdout_file__"
|
||||
STDERR_FILE = "__stderr_file__"
|
||||
|
||||
# __logdir_path__ is a magic keyword used internally to pass a new
|
||||
# logdir to existing loggers.
|
||||
LOGDIR_PATH = "__logdir_path__"
|
||||
|
||||
# Where Tune writes result files by default
|
||||
DEFAULT_RESULTS_DIR = (os.environ.get("TEST_TMPDIR")
|
||||
or os.environ.get("TUNE_RESULT_DIR")
|
||||
|
|
|
@ -256,6 +256,10 @@ class ConcurrencyLimiter(Searcher):
|
|||
Args:
|
||||
searcher (Searcher): Searcher object that the
|
||||
ConcurrencyLimiter will manage.
|
||||
max_concurrent (int): Maximum concurrent samples from the underlying
|
||||
searcher.
|
||||
batch (bool): Whether to wait for all concurrent samples
|
||||
to finish before updating the underlying searcher.
|
||||
|
||||
Example:
|
||||
|
||||
|
@ -267,11 +271,13 @@ class ConcurrencyLimiter(Searcher):
|
|||
tune.run(trainable, search_alg=search_alg)
|
||||
"""
|
||||
|
||||
def __init__(self, searcher, max_concurrent):
|
||||
def __init__(self, searcher, max_concurrent, batch=False):
|
||||
assert type(max_concurrent) is int and max_concurrent > 0
|
||||
self.searcher = searcher
|
||||
self.max_concurrent = max_concurrent
|
||||
self.batch = batch
|
||||
self.live_trials = set()
|
||||
self.cached_results = {}
|
||||
super(ConcurrencyLimiter, self).__init__(
|
||||
metric=self.searcher.metric, mode=self.searcher.mode)
|
||||
|
||||
|
@ -284,6 +290,7 @@ class ConcurrencyLimiter(Searcher):
|
|||
"concurrency limit: %s/%s.", len(self.live_trials),
|
||||
self.max_concurrent)
|
||||
return
|
||||
|
||||
suggestion = self.searcher.suggest(trial_id)
|
||||
if suggestion not in (None, Searcher.FINISHED):
|
||||
self.live_trials.add(trial_id)
|
||||
|
@ -292,6 +299,18 @@ class ConcurrencyLimiter(Searcher):
|
|||
def on_trial_complete(self, trial_id, result=None, error=False):
|
||||
if trial_id not in self.live_trials:
|
||||
return
|
||||
elif self.batch:
|
||||
self.cached_results[trial_id] = (result, error)
|
||||
if len(self.cached_results) == self.max_concurrent:
|
||||
# Update the underlying searcher once the
|
||||
# full batch is completed.
|
||||
for trial_id, (result, error) in self.cached_results.items():
|
||||
self.searcher.on_trial_complete(
|
||||
trial_id, result=result, error=error)
|
||||
self.live_trials.remove(trial_id)
|
||||
self.cached_results = {}
|
||||
else:
|
||||
return
|
||||
else:
|
||||
self.searcher.on_trial_complete(
|
||||
trial_id, result=result, error=error)
|
||||
|
|
|
@ -29,7 +29,11 @@ def create_resettable_class():
|
|||
print("PRINT_STDERR: {}".format(self.msg), file=sys.stderr)
|
||||
logger.info("LOG_STDERR: {}".format(self.msg))
|
||||
|
||||
return {"num_resets": self.num_resets, "done": self.iter > 1}
|
||||
return {
|
||||
"num_resets": self.num_resets,
|
||||
"done": self.iter > 1,
|
||||
"iter": self.iter
|
||||
}
|
||||
|
||||
def save_checkpoint(self, chkpt_dir):
|
||||
return {"iter": self.iter}
|
||||
|
@ -64,7 +68,9 @@ class ActorReuseTest(unittest.TestCase):
|
|||
}
|
||||
},
|
||||
reuse_actors=False,
|
||||
scheduler=FrequentPausesScheduler())
|
||||
scheduler=FrequentPausesScheduler(),
|
||||
verbose=0)
|
||||
self.assertEqual([t.last_result["iter"] for t in trials], [2, 2, 2, 2])
|
||||
self.assertEqual([t.last_result["num_resets"] for t in trials],
|
||||
[0, 0, 0, 0])
|
||||
|
||||
|
@ -78,11 +84,13 @@ class ActorReuseTest(unittest.TestCase):
|
|||
}
|
||||
},
|
||||
reuse_actors=True,
|
||||
scheduler=FrequentPausesScheduler())
|
||||
scheduler=FrequentPausesScheduler(),
|
||||
verbose=0)
|
||||
self.assertEqual([t.last_result["iter"] for t in trials], [2, 2, 2, 2])
|
||||
self.assertEqual([t.last_result["num_resets"] for t in trials],
|
||||
[1, 2, 3, 4])
|
||||
|
||||
def testTrialReuseEnabledError(self):
|
||||
def testReuseEnabledError(self):
|
||||
def run():
|
||||
run_experiments(
|
||||
{
|
||||
|
|
|
@ -70,7 +70,7 @@ tune.run_experiments({
|
|||
"c": tune.grid_search(list(range(10))),
|
||||
},
|
||||
},
|
||||
}, reuse_actors=True, verbose=1)"""
|
||||
}, verbose=1)"""
|
||||
|
||||
EXPECTED_END_TO_END_START = """Number of trials: 30 (29 PENDING, 1 RUNNING)
|
||||
+---------------+----------+-------+-----+-----+
|
||||
|
|
|
@ -764,6 +764,33 @@ class SearchAlgorithmTest(unittest.TestCase):
|
|||
limiter2.on_trial_complete("test_2", {"result": 3})
|
||||
assert limiter2.suggest("test_3")["score"] == 3
|
||||
|
||||
def testBatchLimiter(self):
    # With batch=True and max_concurrent=2, the limiter should keep
    # refusing new suggestions until both outstanding trials complete.
    ray.init(num_cpus=4)

    class _CountingSearcher(Searcher):
        # Returns an increasing "score" per suggestion and records every
        # result passed to on_trial_complete.

        def __init__(self, start):
            self.index = start
            self.returned_result = []
            super().__init__(metric="result", mode="max")

        def suggest(self, trial_id):
            self.index = self.index + 1
            return {"score": self.index}

        def on_trial_complete(self, trial_id, result=None, **kwargs):
            self.returned_result.append(result)

    base_searcher = _CountingSearcher(0)
    batch_limiter = ConcurrencyLimiter(
        base_searcher, max_concurrent=2, batch=True)

    # Two suggestions are handed out; the third is refused.
    first_suggestion = batch_limiter.suggest("test_1")
    assert first_suggestion["score"] == 1
    second_suggestion = batch_limiter.suggest("test_2")
    assert second_suggestion["score"] == 2
    assert batch_limiter.suggest("test_3") is None

    # Completing only one of the two trials still yields no suggestion.
    batch_limiter.on_trial_complete("test_1", {"result": 3})
    assert batch_limiter.suggest("test_3") is None

    # Once the second trial completes, suggestions resume.
    batch_limiter.on_trial_complete("test_2", {"result": 3})
    assert batch_limiter.suggest("test_3") is not None
|
||||
|
||||
|
||||
class ResourcesTest(unittest.TestCase):
|
||||
def testSubtraction(self):
|
||||
|
|
|
@ -24,7 +24,7 @@ from ray.tune.logger import UnifiedLogger
|
|||
from ray.tune.result import (
|
||||
DEFAULT_RESULTS_DIR, TIME_THIS_ITER_S, TIMESTEPS_THIS_ITER, DONE,
|
||||
TIMESTEPS_TOTAL, EPISODES_THIS_ITER, EPISODES_TOTAL, TRAINING_ITERATION,
|
||||
RESULT_DUPLICATE, TRIAL_INFO, STDOUT_FILE, STDERR_FILE, LOGDIR_PATH)
|
||||
RESULT_DUPLICATE, TRIAL_INFO, STDOUT_FILE, STDERR_FILE)
|
||||
from ray.tune.utils import UtilMonitor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -224,9 +224,8 @@ class Trainable:
|
|||
self.config = config or {}
|
||||
trial_info = self.config.pop(TRIAL_INFO, None)
|
||||
|
||||
self._logger_creator = logger_creator
|
||||
self._result_logger = self._logdir = None
|
||||
self._create_logger(self.config)
|
||||
self._create_logger(self.config, logger_creator)
|
||||
|
||||
self._stdout_context = self._stdout_fp = self._stdout_stream = None
|
||||
self._stderr_context = self._stderr_fp = self._stderr_stream = None
|
||||
|
@ -535,22 +534,17 @@ class Trainable:
|
|||
export_dir = export_dir or self.logdir
|
||||
return self._export_model(export_formats, export_dir)
|
||||
|
||||
def reset(self, new_config, new_logdir):
|
||||
def reset(self, new_config, logger_creator=None):
|
||||
"""Resets trial for use with new config.
|
||||
|
||||
Subclasses should override reset_config() to actually
|
||||
reset actor behavior for the new config."""
|
||||
self.config = new_config
|
||||
|
||||
logger_config = new_config.copy()
|
||||
logger_config[LOGDIR_PATH] = new_logdir
|
||||
|
||||
self._logdir = new_logdir
|
||||
|
||||
self._result_logger.flush()
|
||||
self._result_logger.close()
|
||||
|
||||
self._create_logger(logger_config)
|
||||
self._create_logger(new_config.copy(), logger_creator)
|
||||
|
||||
stdout_file = new_config.pop(STDOUT_FILE, None)
|
||||
stderr_file = new_config.pop(STDERR_FILE, None)
|
||||
|
@ -576,10 +570,13 @@ class Trainable:
|
|||
"""
|
||||
return False
|
||||
|
||||
def _create_logger(self, config):
|
||||
"""Create logger from logger creator"""
|
||||
if self._logger_creator:
|
||||
self._result_logger = self._logger_creator(config)
|
||||
def _create_logger(self, config, logger_creator=None):
|
||||
"""Create logger from logger creator.
|
||||
|
||||
Sets _logdir and _result_logger.
|
||||
"""
|
||||
if logger_creator:
|
||||
self._result_logger = logger_creator(config)
|
||||
self._logdir = self._result_logger.logdir
|
||||
else:
|
||||
logdir_prefix = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
|
|
Loading…
Add table
Reference in a new issue