Mirror of https://github.com/vale981/ray, synced 2025-03-06 10:31:39 -05:00
[tune] Allow actor reuse for new trials (#13549)
* Allow actor reuse for new trials
* Fix tests and update conf when starting new trial
* Move magic config to `reset_trial`
parent 800304acfb
commit 6c23bef2a7

3 changed files with 22 additions and 16 deletions
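In short: with `reuse_actors=True`, Tune can now hand a cached actor to a brand-new trial rather than tearing it down. A minimal usage sketch (the trainable below is illustrative, not part of this commit):

from ray import tune


class MyTrainable(tune.Trainable):
    def setup(self, config):
        self.lr = config["lr"]

    def step(self):
        return {"lr": self.lr}

    def reset_config(self, new_config):
        # Reconfigure the live actor in place; returning True tells
        # Tune the reuse succeeded, so no new actor is started.
        self.lr = new_config["lr"]
        return True


tune.run(
    MyTrainable,
    config={"lr": tune.grid_search([0.01, 0.1])},
    stop={"training_iteration": 2},
    reuse_actors=True)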
python/ray/tune/ray_trial_executor.py

@@ -16,7 +16,6 @@ from ray import ray_constants
 from ray.resource_spec import ResourceSpec
 from ray.tune.durable_trainable import DurableTrainable
 from ray.tune.error import AbortTrialExecution, TuneError
-from ray.tune.function_runner import FunctionRunner
 from ray.tune.logger import NoopLogger
 from ray.tune.result import TRIAL_INFO, STDOUT_FILE, STDERR_FILE
 from ray.tune.resources import Resources
@@ -246,19 +245,19 @@ class RayTrialExecutor(TrialExecutor):
 
         return None
 
-    def _setup_remote_runner(self, trial, reuse_allowed):
+    def _setup_remote_runner(self, trial):
         trial.init_logdir()
         # We checkpoint metadata here to try mitigating logdir duplication
         self.try_checkpoint_metadata(trial)
         logger_creator = partial(noop_logger_creator, logdir=trial.logdir)
 
-        if (self._reuse_actors and reuse_allowed
-                and self._cached_actor is not None):
+        if (self._reuse_actors and self._cached_actor is not None):
             logger.debug("Trial %s: Reusing cached runner %s", trial,
                          self._cached_actor)
             existing_runner = self._cached_actor
             self._cached_actor = None
             trial.set_runner(existing_runner)
+
             if not self.reset_trial(trial, trial.config, trial.experiment_tag,
                                     logger_creator):
                 raise AbortTrialExecution(
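The reuse path above boils down to: pop the cached actor, attach it to the incoming trial, and reconfigure it via `reset_trial`, aborting the trial if the trainable does not support resetting. A Ray-free sketch of that cache-and-reset pattern (names are illustrative, not the executor's real API):

class Runner:
    """Stand-in for a remote trainable actor."""

    def __init__(self, config):
        self.config = config

    def reset(self, new_config):
        # Mirrors reset_config's contract: True means "reuse worked".
        self.config = new_config
        return True


class Executor:
    def __init__(self, reuse_actors=True):
        self._reuse_actors = reuse_actors
        self._cached_runner = None

    def setup_runner(self, config):
        if self._reuse_actors and self._cached_runner is not None:
            runner, self._cached_runner = self._cached_runner, None
            if not runner.reset(config):
                raise RuntimeError(
                    "Runner reuse requires reset() to return True.")
            return runner
        return Runner(config)

    def stop_runner(self, runner):
        if self._reuse_actors and self._cached_runner is None:
            self._cached_runner = runner  # keep it warm for the next trial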
@@ -378,14 +377,7 @@ class RayTrialExecutor(TrialExecutor):
         """
         prior_status = trial.status
         if runner is None:
-            # We reuse actors when there is previously instantiated state on
-            # the actor. Function API calls are also supported when there is
-            # no checkpoint to continue from.
-            # TODO: Check preconditions - why is previous state needed?
-            reuse_allowed = checkpoint is not None or trial.has_checkpoint() \
-                or issubclass(trial.get_trainable_cls(),
-                              FunctionRunner)
-            runner = self._setup_remote_runner(trial, reuse_allowed)
+            runner = self._setup_remote_runner(trial)
         if not runner:
             return False
         trial.set_runner(runner)
@@ -520,11 +512,20 @@ class RayTrialExecutor(TrialExecutor):
         trial.set_experiment_tag(new_experiment_tag)
         trial.set_config(new_config)
         trainable = trial.runner
+
+        # Pass magic variables
+        extra_config = copy.deepcopy(new_config)
+        extra_config[TRIAL_INFO] = TrialInfo(trial)
+
+        stdout_file, stderr_file = trial.log_to_file
+        extra_config[STDOUT_FILE] = stdout_file
+        extra_config[STDERR_FILE] = stderr_file
+
         with self._change_working_directory(trial):
             with warn_if_slow("reset"):
                 try:
                     reset_val = ray.get(
-                        trainable.reset.remote(new_config, logger_creator),
+                        trainable.reset.remote(extra_config, logger_creator),
                         timeout=DEFAULT_GET_TIMEOUT)
                 except GetTimeoutError:
                     logger.exception("Trial %s: reset timed out.", trial)
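`reset_trial` now rebuilds the "magic" bookkeeping keys (trial info plus stdout/stderr targets) on a deep copy of the new config, so a reused actor picks up the new trial's identity and log files just as a freshly constructed one would. A self-contained sketch of that injection step; the key values here are stand-ins, not necessarily Tune's actual constants:

import copy

# Illustrative stand-ins for the private config keys Tune injects.
TRIAL_INFO = "__trial_info__"
STDOUT_FILE = "__stdout_file__"
STDERR_FILE = "__stderr_file__"


def with_magic_variables(new_config, trial_info, log_to_file):
    # Deep-copy so the trial's user-visible config is never polluted
    # with bookkeeping keys.
    extra_config = copy.deepcopy(new_config)
    extra_config[TRIAL_INFO] = trial_info
    stdout_file, stderr_file = log_to_file  # e.g. ("stdout", "stderr")
    extra_config[STDOUT_FILE] = stdout_file
    extra_config[STDERR_FILE] = stderr_file
    return extra_config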
python/ray/tune/tests/test_actor_reuse.py

@@ -49,6 +49,7 @@ def create_resettable_class():
             if "fake_reset_not_supported" in self.config:
                 return False
             self.num_resets += 1
+            self.iter = 0
             self.msg = new_config.get("message", "No message")
             return True
 
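The one-line test change matters because a reused actor would otherwise carry the previous trial's iteration counter into the next trial; zeroing `iter` in `reset_config` keeps the existing `iter == [2, 2, 2, 2]` expectation valid. A condensed, hypothetical version of the resettable test trainable:

from ray import tune


class Resettable(tune.Trainable):
    def setup(self, config):
        self.num_resets = 0
        self.iter = 0

    def step(self):
        self.iter += 1
        return {"iter": self.iter, "num_resets": self.num_resets}

    def reset_config(self, new_config):
        if "fake_reset_not_supported" in self.config:
            return False
        self.num_resets += 1
        self.iter = 0  # the added line: each new trial starts at iter 0
        return True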
@@ -131,7 +132,7 @@ class ActorReuseTest(unittest.TestCase):
         self.assertEqual([t.last_result["id"] for t in trials], [0, 1, 2, 3])
         self.assertEqual([t.last_result["iter"] for t in trials], [2, 2, 2, 2])
         self.assertEqual([t.last_result["num_resets"] for t in trials],
-                         [1, 2, 3, 4])
+                         [4, 5, 6, 7])
 
     def testTrialReuseEnabledFunction(self):
         num_resets = defaultdict(lambda: 0)
@@ -176,7 +177,7 @@ class ActorReuseTest(unittest.TestCase):
             reuse_actors=True).trials
 
         # Check trial 1
-        self.assertEqual(trial1.last_result["num_resets"], 1)
+        self.assertEqual(trial1.last_result["num_resets"], 2)
         self.assertTrue(os.path.exists(os.path.join(trial1.logdir, "stdout")))
         self.assertTrue(os.path.exists(os.path.join(trial1.logdir, "stderr")))
         with open(os.path.join(trial1.logdir, "stdout"), "rt") as fp:
@@ -191,7 +192,7 @@ class ActorReuseTest(unittest.TestCase):
         self.assertNotIn("LOG_STDERR: Second", content)
 
         # Check trial 2
-        self.assertEqual(trial2.last_result["num_resets"], 2)
+        self.assertEqual(trial2.last_result["num_resets"], 3)
         self.assertTrue(os.path.exists(os.path.join(trial2.logdir, "stdout")))
         self.assertTrue(os.path.exists(os.path.join(trial2.logdir, "stderr")))
         with open(os.path.join(trial2.logdir, "stdout"), "rt") as fp:
python/ray/tune/trainable.py

@@ -431,6 +431,10 @@ class Trainable:
         reset actor behavior for the new config."""
         self.config = new_config
 
+        trial_info = new_config.pop(TRIAL_INFO, None)
+        if trial_info:
+            self._trial_info = trial_info
+
         self._result_logger.flush()
         self._result_logger.close()
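On the actor side, `reset` consumes the injected trial metadata before the new config reaches user code, updating `self._trial_info` the way the constructor does. A minimal stand-alone illustration of that pop-before-use step (the key value and helper are hypothetical; the real constant lives in ray.tune.result per the import above):

TRIAL_INFO = "__trial_info__"  # stand-in for ray.tune.result.TRIAL_INFO


def consume_trial_info(new_config, current_trial_info):
    # Pop the private key so user code only sees its own settings;
    # keep the previous trial info if none was injected.
    trial_info = new_config.pop(TRIAL_INFO, None)
    return trial_info if trial_info is not None else current_trial_info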