[tune] Support true pooling and batched concurrency (#10352)

Richard Liaw 2020-09-01 10:33:49 -07:00 committed by GitHub
parent e5d089384b
commit 09d4a3241f
7 changed files with 98 additions and 38 deletions


@@ -1,5 +1,6 @@
# coding: utf-8
import copy
from functools import partial
import logging
import os
import random
@@ -14,7 +15,7 @@ from ray.resource_spec import ResourceSpec
from ray.tune.durable_trainable import DurableTrainable
from ray.tune.error import AbortTrialExecution, TuneError
from ray.tune.logger import NoopLogger
from ray.tune.result import TRIAL_INFO, LOGDIR_PATH, STDOUT_FILE, STDERR_FILE
from ray.tune.result import TRIAL_INFO, STDOUT_FILE, STDERR_FILE
from ray.tune.resources import Resources
from ray.tune.trainable import TrainableUtil
from ray.tune.trial import Trial, Checkpoint, Location, TrialInfo
@@ -122,6 +123,14 @@ class _TrialCleanup:
del self._cleanup_map[done]
def noop_logger_creator(config, logdir):
# Set the working dir in the remote process, for user file writes
os.makedirs(logdir, exist_ok=True)
if not ray.worker._mode() == ray.worker.LOCAL_MODE:
os.chdir(logdir)
return NoopLogger(config, logdir)
class RayTrialExecutor(TrialExecutor):
"""An implementation of TrialExecutor based on Ray."""
@@ -163,7 +172,7 @@ class RayTrialExecutor(TrialExecutor):
trial.init_logger()
# We checkpoint metadata here to try mitigating logdir duplication
self.try_checkpoint_metadata(trial)
remote_logdir = trial.logdir
logger_creator = partial(noop_logger_creator, logdir=trial.logdir)
if (self._reuse_actors and reuse_allowed
and self._cached_actor is not None):
@@ -172,7 +181,8 @@ class RayTrialExecutor(TrialExecutor):
existing_runner = self._cached_actor
self._cached_actor = None
trial.set_runner(existing_runner)
if not self.reset_trial(trial, trial.config, trial.experiment_tag):
if not self.reset_trial(trial, trial.config, trial.experiment_tag,
logger_creator):
raise AbortTrialExecution(
"Trainable runner reuse requires reset_config() to be "
"implemented and return True.")
@@ -192,15 +202,6 @@ class RayTrialExecutor(TrialExecutor):
memory=trial.resources.memory or None,
object_store_memory=trial.resources.object_store_memory or None,
resources=trial.resources.custom_resources)
def logger_creator(config):
# Set the working dir in the remote process, for user file writes
logdir = config.pop(LOGDIR_PATH, remote_logdir)
os.makedirs(logdir, exist_ok=True)
if not ray.worker._mode() == ray.worker.LOCAL_MODE:
os.chdir(logdir)
return NoopLogger(config, logdir)
# Clear the Trial's location (to be updated later on result)
# since we don't know where the remote runner is placed.
trial.set_location(Location())
@@ -268,6 +269,12 @@
"""
prior_status = trial.status
if runner is None:
# TODO: Right now, we only support reuse if there is previously
# instantiated state on the worker. However, we should also consider
# the case where function evaluations are very fast, which would
# extend the need for reuse to workers without any previously
# instantiated state.
reuse_allowed = checkpoint is not None or trial.has_checkpoint()
runner = self._setup_remote_runner(trial, reuse_allowed)
trial.set_runner(runner)
@@ -377,13 +384,19 @@ class RayTrialExecutor(TrialExecutor):
self._paused[trial_future[0]] = trial
super(RayTrialExecutor, self).pause_trial(trial)
def reset_trial(self, trial, new_config, new_experiment_tag):
def reset_trial(self,
trial,
new_config,
new_experiment_tag,
logger_creator=None):
"""Tries to invoke `Trainable.reset()` to reset trial.
Args:
trial (Trial): Trial to be reset.
new_config (dict): New configuration for Trial trainable.
new_experiment_tag (str): New experiment name for trial.
logger_creator (Callable[[Dict], Logger]): A function that
instantiates a logger on the actor process.
Returns:
True if `reset_config` is successful else False.
@@ -395,7 +408,7 @@ class RayTrialExecutor(TrialExecutor):
with warn_if_slow("reset"):
try:
reset_val = ray.get(
trainable.reset.remote(new_config, trial.logdir),
trainable.reset.remote(new_config, logger_creator),
timeout=DEFAULT_GET_TIMEOUT)
except GetTimeoutError:
logger.exception("Trial %s: reset timed out.", trial)
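
To see the new logger hand-off outside the diff: a minimal standalone sketch, assuming the NoopLogger import shown in this file; trial_logdir is a made-up path, and the LOCAL_MODE check is omitted for brevity.

import os
from functools import partial

from ray.tune.logger import NoopLogger


def noop_logger_creator(config, logdir):
    # Create the trial's logdir and chdir into it so user file writes land
    # there (the executor skips the chdir when Ray runs in local mode).
    os.makedirs(logdir, exist_ok=True)
    os.chdir(logdir)
    return NoopLogger(config, logdir)


trial_logdir = "/tmp/ray_results/example_trial"  # made-up path
# The executor binds the logdir once and ships the callable to the actor,
# both at construction time and through reset(new_config, logger_creator),
# replacing the old __logdir_path__ magic key.
logger_creator = partial(noop_logger_creator, logdir=trial_logdir)
result_logger = logger_creator({})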


@@ -74,10 +74,6 @@ TRIAL_INFO = "__trial_info__"
STDOUT_FILE = "__stdout_file__"
STDERR_FILE = "__stderr_file__"
# __logdir_path__ is a magic keyword used internally to pass a new
# logdir to existing loggers.
LOGDIR_PATH = "__logdir_path__"
# Where Tune writes result files by default
DEFAULT_RESULTS_DIR = (os.environ.get("TEST_TMPDIR")
or os.environ.get("TUNE_RESULT_DIR")


@@ -256,6 +256,10 @@ class ConcurrencyLimiter(Searcher):
Args:
searcher (Searcher): Searcher object that the
ConcurrencyLimiter will manage.
max_concurrent (int): Maximum concurrent samples from the underlying
searcher.
batch (bool): Whether to wait for all concurrent samples
to finish before updating the underlying searcher.
Example:
@@ -267,11 +271,13 @@ class ConcurrencyLimiter(Searcher):
tune.run(trainable, search_alg=search_alg)
"""
def __init__(self, searcher, max_concurrent):
def __init__(self, searcher, max_concurrent, batch=False):
assert type(max_concurrent) is int and max_concurrent > 0
self.searcher = searcher
self.max_concurrent = max_concurrent
self.batch = batch
self.live_trials = set()
self.cached_results = {}
super(ConcurrencyLimiter, self).__init__(
metric=self.searcher.metric, mode=self.searcher.mode)
@@ -284,6 +290,7 @@ class ConcurrencyLimiter(Searcher):
"concurrency limit: %s/%s.", len(self.live_trials),
self.max_concurrent)
return
suggestion = self.searcher.suggest(trial_id)
if suggestion not in (None, Searcher.FINISHED):
self.live_trials.add(trial_id)
@@ -292,6 +299,18 @@ class ConcurrencyLimiter(Searcher):
def on_trial_complete(self, trial_id, result=None, error=False):
if trial_id not in self.live_trials:
return
elif self.batch:
self.cached_results[trial_id] = (result, error)
if len(self.cached_results) == self.max_concurrent:
# Update the underlying searcher once the
# full batch is completed.
for trial_id, (result, error) in self.cached_results.items():
self.searcher.on_trial_complete(
trial_id, result=result, error=error)
self.live_trials.remove(trial_id)
self.cached_results = {}
else:
return
else:
self.searcher.on_trial_complete(
trial_id, result=result, error=error)
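
For reference, a minimal usage sketch of the new batch flag, assuming ConcurrencyLimiter and Searcher can be imported from ray.tune.suggest in this release; ToySearcher and trainable below are made up for illustration.

from ray import tune
from ray.tune.suggest import ConcurrencyLimiter, Searcher


class ToySearcher(Searcher):
    # Made-up searcher that just proposes increasing values of "x".
    def __init__(self):
        super().__init__(metric="score", mode="max")
        self._index = 0
        self.batch_results = []

    def suggest(self, trial_id):
        self._index += 1
        return {"x": self._index}

    def on_trial_complete(self, trial_id, result=None, **kwargs):
        # With batch=True, the limiter defers this call until every trial
        # in the current batch of max_concurrent trials has finished.
        self.batch_results.append(result)


def trainable(config):
    tune.report(score=config["x"] ** 2)


search_alg = ConcurrencyLimiter(ToySearcher(), max_concurrent=2, batch=True)
tune.run(trainable, search_alg=search_alg, num_samples=4)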


@@ -29,7 +29,11 @@ def create_resettable_class():
print("PRINT_STDERR: {}".format(self.msg), file=sys.stderr)
logger.info("LOG_STDERR: {}".format(self.msg))
return {"num_resets": self.num_resets, "done": self.iter > 1}
return {
"num_resets": self.num_resets,
"done": self.iter > 1,
"iter": self.iter
}
def save_checkpoint(self, chkpt_dir):
return {"iter": self.iter}
@@ -64,7 +68,9 @@ class ActorReuseTest(unittest.TestCase):
}
},
reuse_actors=False,
scheduler=FrequentPausesScheduler())
scheduler=FrequentPausesScheduler(),
verbose=0)
self.assertEqual([t.last_result["iter"] for t in trials], [2, 2, 2, 2])
self.assertEqual([t.last_result["num_resets"] for t in trials],
[0, 0, 0, 0])
@@ -78,11 +84,13 @@ class ActorReuseTest(unittest.TestCase):
}
},
reuse_actors=True,
scheduler=FrequentPausesScheduler())
scheduler=FrequentPausesScheduler(),
verbose=0)
self.assertEqual([t.last_result["iter"] for t in trials], [2, 2, 2, 2])
self.assertEqual([t.last_result["num_resets"] for t in trials],
[1, 2, 3, 4])
def testTrialReuseEnabledError(self):
def testReuseEnabledError(self):
def run():
run_experiments(
{


@@ -70,7 +70,7 @@ tune.run_experiments({
"c": tune.grid_search(list(range(10))),
},
},
}, reuse_actors=True, verbose=1)"""
}, verbose=1)"""
EXPECTED_END_TO_END_START = """Number of trials: 30 (29 PENDING, 1 RUNNING)
+---------------+----------+-------+-----+-----+


@@ -764,6 +764,33 @@ class SearchAlgorithmTest(unittest.TestCase):
limiter2.on_trial_complete("test_2", {"result": 3})
assert limiter2.suggest("test_3")["score"] == 3
def testBatchLimiter(self):
ray.init(num_cpus=4)
class TestSuggestion(Searcher):
def __init__(self, index):
self.index = index
self.returned_result = []
super().__init__(metric="result", mode="max")
def suggest(self, trial_id):
self.index += 1
return {"score": self.index}
def on_trial_complete(self, trial_id, result=None, **kwargs):
self.returned_result.append(result)
searcher = TestSuggestion(0)
limiter = ConcurrencyLimiter(searcher, max_concurrent=2, batch=True)
assert limiter.suggest("test_1")["score"] == 1
assert limiter.suggest("test_2")["score"] == 2
assert limiter.suggest("test_3") is None
limiter.on_trial_complete("test_1", {"result": 3})
assert limiter.suggest("test_3") is None
limiter.on_trial_complete("test_2", {"result": 3})
assert limiter.suggest("test_3") is not None
class ResourcesTest(unittest.TestCase):
def testSubtraction(self):


@@ -24,7 +24,7 @@ from ray.tune.logger import UnifiedLogger
from ray.tune.result import (
DEFAULT_RESULTS_DIR, TIME_THIS_ITER_S, TIMESTEPS_THIS_ITER, DONE,
TIMESTEPS_TOTAL, EPISODES_THIS_ITER, EPISODES_TOTAL, TRAINING_ITERATION,
RESULT_DUPLICATE, TRIAL_INFO, STDOUT_FILE, STDERR_FILE, LOGDIR_PATH)
RESULT_DUPLICATE, TRIAL_INFO, STDOUT_FILE, STDERR_FILE)
from ray.tune.utils import UtilMonitor
logger = logging.getLogger(__name__)
@@ -224,9 +224,8 @@ class Trainable:
self.config = config or {}
trial_info = self.config.pop(TRIAL_INFO, None)
self._logger_creator = logger_creator
self._result_logger = self._logdir = None
self._create_logger(self.config)
self._create_logger(self.config, logger_creator)
self._stdout_context = self._stdout_fp = self._stdout_stream = None
self._stderr_context = self._stderr_fp = self._stderr_stream = None
@@ -535,22 +534,17 @@ class Trainable:
export_dir = export_dir or self.logdir
return self._export_model(export_formats, export_dir)
def reset(self, new_config, new_logdir):
def reset(self, new_config, logger_creator=None):
"""Resets trial for use with new config.
Subclasses should override reset_config() to actually
reset actor behavior for the new config."""
self.config = new_config
logger_config = new_config.copy()
logger_config[LOGDIR_PATH] = new_logdir
self._logdir = new_logdir
self._result_logger.flush()
self._result_logger.close()
self._create_logger(logger_config)
self._create_logger(new_config.copy(), logger_creator)
stdout_file = new_config.pop(STDOUT_FILE, None)
stderr_file = new_config.pop(STDERR_FILE, None)
@@ -576,10 +570,13 @@
"""
return False
def _create_logger(self, config):
"""Create logger from logger creator"""
if self._logger_creator:
self._result_logger = self._logger_creator(config)
def _create_logger(self, config, logger_creator=None):
"""Create logger from logger creator.
Sets _logdir and _result_logger.
"""
if logger_creator:
self._result_logger = logger_creator(config)
self._logdir = self._result_logger.logdir
else:
logdir_prefix = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
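
Actor reuse still hinges on the Trainable implementing reset_config() and returning True, as the AbortTrialExecution message above states. A hypothetical sketch of that contract, assuming the setup/step Trainable method names of this Tune release:

from ray import tune


class ReusableTrainable(tune.Trainable):
    def setup(self, config):
        self.lr = config["lr"]
        self.num_resets = 0

    def step(self):
        return {"score": self.lr, "num_resets": self.num_resets}

    def reset_config(self, new_config):
        # Reconfigure in place so the executor can hand this actor the next
        # trial instead of tearing it down and starting a new one.
        self.lr = new_config["lr"]
        self.num_resets += 1
        return True


# reuse_actors=True lets tune.run keep actors alive across trials.
tune.run(
    ReusableTrainable,
    config={"lr": tune.grid_search([0.1, 0.01, 0.001])},
    stop={"training_iteration": 2},
    reuse_actors=True)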