mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[tune] Track live trials in a set in the TrialRunner to reduce linear scans (#15811)
This commit is contained in:
parent
85bc1b2979
commit
e547a27944
4 changed files with 38 additions and 12 deletions
|
@ -7,7 +7,7 @@ import random
|
|||
import time
|
||||
import traceback
|
||||
from contextlib import contextmanager
|
||||
from typing import List, Optional
|
||||
from typing import Iterable, Optional
|
||||
|
||||
import ray
|
||||
from ray.actor import ActorHandle
|
||||
|
@ -207,7 +207,7 @@ class RayTrialExecutor(TrialExecutor):
|
|||
def set_max_pending_trials(self, max_pending: int):
|
||||
self._pg_manager.set_max_staging(max_pending)
|
||||
|
||||
def stage_and_update_status(self, trials: List[Trial]):
|
||||
def stage_and_update_status(self, trials: Iterable[Trial]):
|
||||
"""Check and update statuses of scheduled placement groups.
|
||||
|
||||
Stages placement groups of all trials.
|
||||
|
|
|
@ -179,7 +179,7 @@ class MedianStoppingRule(FIFOScheduler):
|
|||
trial: Trial, time: float) -> str:
|
||||
pause = time - self._last_pause[trial] > self._min_time_slice
|
||||
pause = pause and [
|
||||
t for t in trial_runner.get_trials()
|
||||
t for t in trial_runner.get_live_trials()
|
||||
if t.status in (Trial.PENDING, Trial.PAUSED)
|
||||
]
|
||||
return TrialScheduler.PAUSE if pause else TrialScheduler.CONTINUE
|
||||
|
|
|
@ -42,7 +42,7 @@ def _check_trial_running(trial):
|
|||
|
||||
|
||||
def _get_running_trials(runner):
|
||||
return [t for t in runner.get_trials() if t.status == Trial.RUNNING]
|
||||
return [t for t in runner.get_live_trials() if t.status == Trial.RUNNING]
|
||||
|
||||
|
||||
def _start_new_cluster():
|
||||
|
|
|
@ -274,6 +274,7 @@ class TrialRunner:
|
|||
self._server = TuneServer(self, self._server_port)
|
||||
|
||||
self._trials = []
|
||||
self._live_trials = set() # Set of non-terminated trials
|
||||
self._cached_trial_decisions = {}
|
||||
self._queued_trial_decisions = {}
|
||||
self._updated_queue = False
|
||||
|
@ -474,7 +475,14 @@ class TrialRunner:
|
|||
|
||||
def is_finished(self):
|
||||
"""Returns whether all trials have finished running."""
|
||||
trials_done = all(trial.is_finished() for trial in self._trials)
|
||||
# The checks here are partly redundant but optimized for quick
|
||||
# evaluation. Specifically, if there are live trials, we check
|
||||
# these live trials first. Only if none of the live trials is
|
||||
# live anymore do we loop over all trials for a final check.
|
||||
trials_done = (len(self._live_trials) == 0 or all(
|
||||
trial.is_finished()
|
||||
for trial in self._live_trials)) and all(trial.is_finished()
|
||||
for trial in self._trials)
|
||||
return trials_done and self._search_alg.is_finished()
|
||||
|
||||
def step(self):
|
||||
|
@ -500,14 +508,14 @@ class TrialRunner:
|
|||
# continue updating if this was successful (next_trial is not None)
|
||||
if not self._updated_queue or (self._updated_queue and next_trial):
|
||||
num_pending_trials = len(
|
||||
[t for t in self._trials if t.status == Trial.PENDING])
|
||||
[t for t in self._live_trials if t.status == Trial.PENDING])
|
||||
while num_pending_trials < self._max_pending_trials:
|
||||
if not self._update_trial_queue(blocking=False):
|
||||
break
|
||||
num_pending_trials += 1
|
||||
|
||||
# Update status of staged placement groups
|
||||
self.trial_executor.stage_and_update_status(self._trials)
|
||||
self.trial_executor.stage_and_update_status(self._live_trials)
|
||||
|
||||
def _start_trial(trial: Trial) -> bool:
|
||||
"""Helper function to start trial and call callbacks"""
|
||||
|
@ -562,6 +570,8 @@ class TrialRunner:
|
|||
self._callbacks.on_step_end(
|
||||
iteration=self._iteration, trials=self._trials)
|
||||
|
||||
self._reconcile_live_trials()
|
||||
|
||||
def get_trial(self, tid):
|
||||
trial = [t for t in self._trials if t.trial_id == tid]
|
||||
return trial[0] if trial else None
|
||||
|
@ -573,6 +583,10 @@ class TrialRunner:
|
|||
"""
|
||||
return self._trials
|
||||
|
||||
def get_live_trials(self):
|
||||
"""Returns the set of trials that are not in Trial.TERMINATED state."""
|
||||
return self._live_trials
|
||||
|
||||
def add_trial(self, trial):
|
||||
"""Adds a new trial to this TrialRunner.
|
||||
|
||||
|
@ -582,6 +596,8 @@ class TrialRunner:
|
|||
trial (Trial): Trial to queue.
|
||||
"""
|
||||
self._trials.append(trial)
|
||||
if trial.status != Trial.TERMINATED:
|
||||
self._live_trials.add(trial)
|
||||
with warn_if_slow("scheduler.on_trial_add"):
|
||||
self._scheduler_alg.on_trial_add(self, trial)
|
||||
self.trial_executor.try_checkpoint_metadata(trial)
|
||||
|
@ -625,11 +641,11 @@ class TrialRunner:
|
|||
Blocks if all trials queued have finished, but search algorithm is
|
||||
still not finished.
|
||||
"""
|
||||
trials_done = all(trial.is_finished() for trial in self._trials)
|
||||
trials_done = all(trial.is_finished() for trial in self._live_trials)
|
||||
wait_for_trial = trials_done and not self._search_alg.is_finished()
|
||||
# Only fetch a new trial if we have no pending trial
|
||||
if not any(trial.status == Trial.PENDING for trial in self._trials) \
|
||||
or wait_for_trial:
|
||||
if not any(trial.status == Trial.PENDING
|
||||
for trial in self._live_trials) or wait_for_trial:
|
||||
self._update_trial_queue(blocking=wait_for_trial)
|
||||
with warn_if_slow("choose_trial_to_run"):
|
||||
trial = self._scheduler_alg.choose_trial_to_run(self)
|
||||
|
@ -935,6 +951,7 @@ class TrialRunner:
|
|||
logger.debug("Trial %s: Restore processed successfully", trial)
|
||||
self.trial_executor.set_status(trial, Trial.RUNNING)
|
||||
self.trial_executor.continue_training(trial)
|
||||
self._live_trials.add(trial)
|
||||
except Exception:
|
||||
logger.exception("Trial %s: Error processing restore.", trial)
|
||||
if self._fail_fast == TrialRunner.RAISE:
|
||||
|
@ -1067,6 +1084,7 @@ class TrialRunner:
|
|||
# See https://github.com/ray-project/ray/issues/5168
|
||||
self._trials.pop(self._trials.index(trial))
|
||||
self._trials.append(trial)
|
||||
self._live_trials.add(trial)
|
||||
|
||||
with warn_if_slow("scheduler.on_trial_add"):
|
||||
self._scheduler_alg.on_trial_add(self, trial)
|
||||
|
@ -1158,10 +1176,18 @@ class TrialRunner:
|
|||
trial=trial)
|
||||
error = True
|
||||
self.trial_executor.stop_trial(trial, error=error, error_msg=error_msg)
|
||||
self._live_trials.discard(trial)
|
||||
|
||||
def cleanup_trials(self):
|
||||
self.trial_executor.cleanup(self)
|
||||
|
||||
def _reconcile_live_trials(self):
|
||||
"""Loop through live trials and remove if terminated"""
|
||||
for trial in list(self._live_trials):
|
||||
# Only for TERMINATED trials. ERRORed trials might be retried.
|
||||
if trial.status == Trial.TERMINATED:
|
||||
self._live_trials.remove(trial)
|
||||
|
||||
def __getstate__(self):
|
||||
"""Gets state for trial.
|
||||
|
||||
|
@ -1170,8 +1196,8 @@ class TrialRunner:
|
|||
"""
|
||||
state = self.__dict__.copy()
|
||||
for k in [
|
||||
"_trials", "_stop_queue", "_server", "_search_alg",
|
||||
"_scheduler_alg", "_pending_trial_queue_times",
|
||||
"_trials", "_live_trials", "_stop_queue", "_server",
|
||||
"_search_alg", "_scheduler_alg", "_pending_trial_queue_times",
|
||||
"trial_executor", "_syncer", "_callbacks",
|
||||
"_checkpoint_manager"
|
||||
]:
|
||||
|
|
Loading…
Add table
Reference in a new issue