[tune] Track live trials in a set in the TrialRunner to reduce linear scans (#15811)

This commit is contained in:
Kai Fricke 2021-06-17 09:36:07 +01:00 committed by GitHub
parent 85bc1b2979
commit e547a27944
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 38 additions and 12 deletions

View file

@ -7,7 +7,7 @@ import random
import time
import traceback
from contextlib import contextmanager
from typing import List, Optional
from typing import Iterable, Optional
import ray
from ray.actor import ActorHandle
@ -207,7 +207,7 @@ class RayTrialExecutor(TrialExecutor):
def set_max_pending_trials(self, max_pending: int):
self._pg_manager.set_max_staging(max_pending)
def stage_and_update_status(self, trials: List[Trial]):
def stage_and_update_status(self, trials: Iterable[Trial]):
"""Check and update statuses of scheduled placement groups.
Stages placement groups of all trials.

View file

@ -179,7 +179,7 @@ class MedianStoppingRule(FIFOScheduler):
trial: Trial, time: float) -> str:
pause = time - self._last_pause[trial] > self._min_time_slice
pause = pause and [
t for t in trial_runner.get_trials()
t for t in trial_runner.get_live_trials()
if t.status in (Trial.PENDING, Trial.PAUSED)
]
return TrialScheduler.PAUSE if pause else TrialScheduler.CONTINUE

View file

@ -42,7 +42,7 @@ def _check_trial_running(trial):
def _get_running_trials(runner):
return [t for t in runner.get_trials() if t.status == Trial.RUNNING]
return [t for t in runner.get_live_trials() if t.status == Trial.RUNNING]
def _start_new_cluster():

View file

@ -274,6 +274,7 @@ class TrialRunner:
self._server = TuneServer(self, self._server_port)
self._trials = []
self._live_trials = set() # Set of non-terminated trials
self._cached_trial_decisions = {}
self._queued_trial_decisions = {}
self._updated_queue = False
@ -474,7 +475,14 @@ class TrialRunner:
def is_finished(self):
"""Returns whether all trials have finished running."""
trials_done = all(trial.is_finished() for trial in self._trials)
# The checks here are partly redundant but optimized for quick
# evaluation. Specifically, if there are live trials, we check
# these live trials first. Only if none of the live trials is
# live anymore do we loop over all trials for a final check.
trials_done = (len(self._live_trials) == 0 or all(
trial.is_finished()
for trial in self._live_trials)) and all(trial.is_finished()
for trial in self._trials)
return trials_done and self._search_alg.is_finished()
def step(self):
@ -500,14 +508,14 @@ class TrialRunner:
# continue updating if this was successful (next_trial is not None)
if not self._updated_queue or (self._updated_queue and next_trial):
num_pending_trials = len(
[t for t in self._trials if t.status == Trial.PENDING])
[t for t in self._live_trials if t.status == Trial.PENDING])
while num_pending_trials < self._max_pending_trials:
if not self._update_trial_queue(blocking=False):
break
num_pending_trials += 1
# Update status of staged placement groups
self.trial_executor.stage_and_update_status(self._trials)
self.trial_executor.stage_and_update_status(self._live_trials)
def _start_trial(trial: Trial) -> bool:
"""Helper function to start trial and call callbacks"""
@ -562,6 +570,8 @@ class TrialRunner:
self._callbacks.on_step_end(
iteration=self._iteration, trials=self._trials)
self._reconcile_live_trials()
def get_trial(self, tid):
trial = [t for t in self._trials if t.trial_id == tid]
return trial[0] if trial else None
@ -573,6 +583,10 @@ class TrialRunner:
"""
return self._trials
def get_live_trials(self):
"""Returns the set of trials that are not in Trial.TERMINATED state."""
return self._live_trials
def add_trial(self, trial):
"""Adds a new trial to this TrialRunner.
@ -582,6 +596,8 @@ class TrialRunner:
trial (Trial): Trial to queue.
"""
self._trials.append(trial)
if trial.status != Trial.TERMINATED:
self._live_trials.add(trial)
with warn_if_slow("scheduler.on_trial_add"):
self._scheduler_alg.on_trial_add(self, trial)
self.trial_executor.try_checkpoint_metadata(trial)
@ -625,11 +641,11 @@ class TrialRunner:
Blocks if all trials queued have finished, but search algorithm is
still not finished.
"""
trials_done = all(trial.is_finished() for trial in self._trials)
trials_done = all(trial.is_finished() for trial in self._live_trials)
wait_for_trial = trials_done and not self._search_alg.is_finished()
# Only fetch a new trial if we have no pending trial
if not any(trial.status == Trial.PENDING for trial in self._trials) \
or wait_for_trial:
if not any(trial.status == Trial.PENDING
for trial in self._live_trials) or wait_for_trial:
self._update_trial_queue(blocking=wait_for_trial)
with warn_if_slow("choose_trial_to_run"):
trial = self._scheduler_alg.choose_trial_to_run(self)
@ -935,6 +951,7 @@ class TrialRunner:
logger.debug("Trial %s: Restore processed successfully", trial)
self.trial_executor.set_status(trial, Trial.RUNNING)
self.trial_executor.continue_training(trial)
self._live_trials.add(trial)
except Exception:
logger.exception("Trial %s: Error processing restore.", trial)
if self._fail_fast == TrialRunner.RAISE:
@ -1067,6 +1084,7 @@ class TrialRunner:
# See https://github.com/ray-project/ray/issues/5168
self._trials.pop(self._trials.index(trial))
self._trials.append(trial)
self._live_trials.add(trial)
with warn_if_slow("scheduler.on_trial_add"):
self._scheduler_alg.on_trial_add(self, trial)
@ -1158,10 +1176,18 @@ class TrialRunner:
trial=trial)
error = True
self.trial_executor.stop_trial(trial, error=error, error_msg=error_msg)
self._live_trials.discard(trial)
def cleanup_trials(self):
self.trial_executor.cleanup(self)
def _reconcile_live_trials(self):
"""Loop through live trials and remove if terminated"""
for trial in list(self._live_trials):
# Only for TERMINATED trials. ERRORed trials might be retried.
if trial.status == Trial.TERMINATED:
self._live_trials.remove(trial)
def __getstate__(self):
"""Gets state for trial.
@ -1170,8 +1196,8 @@ class TrialRunner:
"""
state = self.__dict__.copy()
for k in [
"_trials", "_stop_queue", "_server", "_search_alg",
"_scheduler_alg", "_pending_trial_queue_times",
"_trials", "_live_trials", "_stop_queue", "_server",
"_search_alg", "_scheduler_alg", "_pending_trial_queue_times",
"trial_executor", "_syncer", "_callbacks",
"_checkpoint_manager"
]: