[tune] Allow actor reuse for new trials (#13549)

* Allow actor reuse for new trials

* Fix tests and update conf when starting new trial

* Move magic config to `reset_trial`
This commit is contained in:
Kai Fricke 2021-01-20 11:25:33 +01:00 committed by GitHub
parent 800304acfb
commit 6c23bef2a7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 22 additions and 16 deletions

View file

@@ -16,7 +16,6 @@ from ray import ray_constants
from ray.resource_spec import ResourceSpec
from ray.tune.durable_trainable import DurableTrainable
from ray.tune.error import AbortTrialExecution, TuneError
from ray.tune.function_runner import FunctionRunner
from ray.tune.logger import NoopLogger
from ray.tune.result import TRIAL_INFO, STDOUT_FILE, STDERR_FILE
from ray.tune.resources import Resources
@@ -246,19 +245,19 @@ class RayTrialExecutor(TrialExecutor):
return None
def _setup_remote_runner(self, trial, reuse_allowed):
def _setup_remote_runner(self, trial):
trial.init_logdir()
# We checkpoint metadata here to try mitigating logdir duplication
self.try_checkpoint_metadata(trial)
logger_creator = partial(noop_logger_creator, logdir=trial.logdir)
if (self._reuse_actors and reuse_allowed
and self._cached_actor is not None):
if (self._reuse_actors and self._cached_actor is not None):
logger.debug("Trial %s: Reusing cached runner %s", trial,
self._cached_actor)
existing_runner = self._cached_actor
self._cached_actor = None
trial.set_runner(existing_runner)
if not self.reset_trial(trial, trial.config, trial.experiment_tag,
logger_creator):
raise AbortTrialExecution(
@@ -378,14 +377,7 @@ class RayTrialExecutor(TrialExecutor):
"""
prior_status = trial.status
if runner is None:
# We reuse actors when there is previously instantiated state on
# the actor. Function API calls are also supported when there is
# no checkpoint to continue from.
# TODO: Check preconditions - why is previous state needed?
reuse_allowed = checkpoint is not None or trial.has_checkpoint() \
or issubclass(trial.get_trainable_cls(),
FunctionRunner)
runner = self._setup_remote_runner(trial, reuse_allowed)
runner = self._setup_remote_runner(trial)
if not runner:
return False
trial.set_runner(runner)
@@ -520,11 +512,20 @@ class RayTrialExecutor(TrialExecutor):
trial.set_experiment_tag(new_experiment_tag)
trial.set_config(new_config)
trainable = trial.runner
# Pass magic variables
extra_config = copy.deepcopy(new_config)
extra_config[TRIAL_INFO] = TrialInfo(trial)
stdout_file, stderr_file = trial.log_to_file
extra_config[STDOUT_FILE] = stdout_file
extra_config[STDERR_FILE] = stderr_file
with self._change_working_directory(trial):
with warn_if_slow("reset"):
try:
reset_val = ray.get(
trainable.reset.remote(new_config, logger_creator),
trainable.reset.remote(extra_config, logger_creator),
timeout=DEFAULT_GET_TIMEOUT)
except GetTimeoutError:
logger.exception("Trial %s: reset timed out.", trial)

View file

@@ -49,6 +49,7 @@ def create_resettable_class():
if "fake_reset_not_supported" in self.config:
return False
self.num_resets += 1
self.iter = 0
self.msg = new_config.get("message", "No message")
return True
@@ -131,7 +132,7 @@ class ActorReuseTest(unittest.TestCase):
self.assertEqual([t.last_result["id"] for t in trials], [0, 1, 2, 3])
self.assertEqual([t.last_result["iter"] for t in trials], [2, 2, 2, 2])
self.assertEqual([t.last_result["num_resets"] for t in trials],
[1, 2, 3, 4])
[4, 5, 6, 7])
def testTrialReuseEnabledFunction(self):
num_resets = defaultdict(lambda: 0)
@@ -176,7 +177,7 @@ class ActorReuseTest(unittest.TestCase):
reuse_actors=True).trials
# Check trial 1
self.assertEqual(trial1.last_result["num_resets"], 1)
self.assertEqual(trial1.last_result["num_resets"], 2)
self.assertTrue(os.path.exists(os.path.join(trial1.logdir, "stdout")))
self.assertTrue(os.path.exists(os.path.join(trial1.logdir, "stderr")))
with open(os.path.join(trial1.logdir, "stdout"), "rt") as fp:
@@ -191,7 +192,7 @@ class ActorReuseTest(unittest.TestCase):
self.assertNotIn("LOG_STDERR: Second", content)
# Check trial 2
self.assertEqual(trial2.last_result["num_resets"], 2)
self.assertEqual(trial2.last_result["num_resets"], 3)
self.assertTrue(os.path.exists(os.path.join(trial2.logdir, "stdout")))
self.assertTrue(os.path.exists(os.path.join(trial2.logdir, "stderr")))
with open(os.path.join(trial2.logdir, "stdout"), "rt") as fp:

View file

@@ -431,6 +431,10 @@ class Trainable:
reset actor behavior for the new config."""
self.config = new_config
trial_info = new_config.pop(TRIAL_INFO, None)
if trial_info:
self._trial_info = trial_info
self._result_logger.flush()
self._result_logger.close()