[tune] Single wait refactor. (#21852)

This is a down-scoped change. For the full picture of the Tune control loop refactoring, see [`Tune control loop refactoring`](https://docs.google.com/document/d/1RDsW7SVzwMPZfA0WLOPA4YTqbRyXIHGYmBenJk33HaE/edit#heading=h.2za3bbxbs5gn).

1. Previously there were separate waits on PG readiness and on other events. As a result, there were quite a few timing tweaks that were inefficient, hard to understand, and hard to unit test. This PR consolidates them into a single wait that is handled by TrialRunner in each step.
- A few event types are introduced; their mapping to scenarios is as follows (see the sketch after the list):
  * PG_READY --> A trial should be placed onto the ready placement group. If somehow there is no trial to place, the PG is put back into `_ready` momentarily. This is because, historically, resources have been conceptualized as a pull-based model.
  * NO_RUNNING_TRIAL_TIMEOUT --> Likely an insufficient-resources case.
  * TRAINING_RESULT
  * SAVING_RESULT
  * RESTORING_RESULT
  * YIELD --> Training is simply taking a long time. We need to punt back to the main loop to print out status info etc.
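
A minimal usage sketch of the new single-wait API, modeled on the unit-test helpers added in this PR (`_simulate_starting_trial` / `_simulate_getting_result`). It assumes a local Ray instance and the `__fake` stub trainable that Tune's tests register via `ray.rllib._register_all()`; it is illustrative, not the actual `TrialRunner` loop:

```python
import ray
from ray.rllib import _register_all  # registers the "__fake" stub trainable used in Tune's tests
from ray.tune.ray_trial_executor import RayTrialExecutor, ExecutorEventType
from ray.tune.trial import Trial

ray.init(num_cpus=2)
_register_all()

executor = RayTrialExecutor()
trial = Trial("__fake")

# Single blocking wait (up to TUNE_GET_EXECUTOR_EVENT_WAIT_S seconds). With a
# pending trial and free resources, the staged placement group becomes ready.
event = executor.get_next_executor_event(live_trials={trial}, next_trial_exists=True)
if event.type == ExecutorEventType.PG_READY:
    executor.start_trial(trial)  # place the trial onto the ready placement group

# The same wait later surfaces results for running trials.
event = executor.get_next_executor_event(live_trials={trial}, next_trial_exists=False)
if event.type == ExecutorEventType.TRAINING_RESULT:
    results = event.result if isinstance(event.result, list) else [event.result]
    for result in results:
        trial.update_last_result(result)
elif event.type == ExecutorEventType.YIELD:
    pass  # training just needs more time; TrialRunner would print status and loop again
elif event.type == ExecutorEventType.ERROR:
    print(event.result)  # traceback string from the failed future

executor.stop_trial(trial)
```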

2. Previously, TrialCleanup was not very efficient and could race between `Trainable.stop()` and `return_placement_group`. This PR streamlines the trial cleanup process by explicitly letting `Trainable.stop()` finish before `return_placement_group(pg)` is called. Note that graceful shutdown is needed in cases like `pause_trial`, where checkpointing to memory needs to be given time to happen before the actor is gone.
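
As a rough standalone illustration of this ordering (not Tune's actual cleanup code; `FakeTrainable` and the 5-second timeout below are made up for the example), the graceful `stop()` is awaited with a bounded deadline before the placement group is released:

```python
import time

import ray
from ray.exceptions import GetTimeoutError
from ray.util.placement_group import placement_group, remove_placement_group


@ray.remote(num_cpus=1)
class FakeTrainable:
    def stop(self):
        # Stand-in for Trainable.stop(), e.g. checkpointing to memory during pause_trial.
        time.sleep(1)
        return "stopped"


ray.init(num_cpus=2)
pg = placement_group([{"CPU": 1}])
ray.get(pg.ready())
actor = FakeTrainable.options(placement_group=pg).remote()

stop_future = actor.stop.remote()    # 1) request a graceful shutdown
try:
    ray.get(stop_future, timeout=5)  # 2) give stop() time to finish
                                     #    (TUNE_FORCE_TRIAL_CLEANUP_S plays this role in Tune)
except GetTimeoutError:
    pass                             #    deadline hit: fall back to forceful cleanup
finally:
    remove_placement_group(pg)       # 3) only now release the placement group
```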

3. Quite a few env variables (timing tweaks) are removed, which I consider OK to do without a deprecation cycle.
xwjiang2010 2022-02-09 07:31:17 -08:00 committed by GitHub
parent dea3574050
commit 323511b716
20 changed files with 825 additions and 802 deletions

View file

@ -36,6 +36,8 @@ These are the environment variables Ray Tune currently considers:
letting them finish the current training step and any user-defined cleanup. letting them finish the current training step and any user-defined cleanup.
Setting this variable to a non-zero, positive integer will cause trials to be forcefully Setting this variable to a non-zero, positive integer will cause trials to be forcefully
terminated after a grace period of that many seconds. Defaults to ``0``. terminated after a grace period of that many seconds. Defaults to ``0``.
* **TUNE_GET_EXECUTOR_EVENT_WAIT_S**: The time that TrialRunner waits for the
next ExecutorEvent in a blocking fashion. Defaults to ``5``.
* **TUNE_FUNCTION_THREAD_TIMEOUT_S**: Time in seconds the function API waits * **TUNE_FUNCTION_THREAD_TIMEOUT_S**: Time in seconds the function API waits
for threads to finish after instructing them to complete. Defaults to ``2``. for threads to finish after instructing them to complete. Defaults to ``2``.
* **TUNE_GLOBAL_CHECKPOINT_S**: Time in seconds that limits how often Tune's * **TUNE_GLOBAL_CHECKPOINT_S**: Time in seconds that limits how often Tune's
@ -57,10 +59,6 @@ These are the environment variables Ray Tune currently considers:
In normal circumstances these shouldn't differ anyway, but reconcilation makes sure to capture cases when In normal circumstances these shouldn't differ anyway, but reconcilation makes sure to capture cases when
placement groups are manually destroyed. Reconcilation doesn't take much time, but it can add up when placement groups are manually destroyed. Reconcilation doesn't take much time, but it can add up when
running a large number of short trials. Defaults to every ``5`` (seconds). running a large number of short trials. Defaults to every ``5`` (seconds).
* **TUNE_PLACEMENT_GROUP_WAIT_S**: Default time the trial executor waits for placement
groups to be placed before continuing the tuning loop. Setting this to a float
will block for that many seconds. This is mostly used for testing purposes. Defaults
to -1, which disables blocking.
* **TUNE_RESULT_DIR**: Directory where Ray Tune trial results are stored. If this * **TUNE_RESULT_DIR**: Directory where Ray Tune trial results are stored. If this
is not set, ``~/ray_results`` will be used. is not set, ``~/ray_results`` will be used.
* **TUNE_RESULT_BUFFER_LENGTH**: Ray Tune can buffer results from trainables before they are passed * **TUNE_RESULT_BUFFER_LENGTH**: Ray Tune can buffer results from trainables before they are passed
@ -74,11 +72,6 @@ These are the environment variables Ray Tune currently considers:
but never longer than this value. Defaults to 100 (seconds). but never longer than this value. Defaults to 100 (seconds).
* **TUNE_RESULT_BUFFER_MIN_TIME_S**: Additionally, you can specify a minimum time to buffer results. Defaults to 0. * **TUNE_RESULT_BUFFER_MIN_TIME_S**: Additionally, you can specify a minimum time to buffer results. Defaults to 0.
* **TUNE_SYNCER_VERBOSITY**: Amount of command output when using Tune with Docker Syncer. Defaults to 0. * **TUNE_SYNCER_VERBOSITY**: Amount of command output when using Tune with Docker Syncer. Defaults to 0.
* **TUNE_TRIAL_RESULT_WAIT_TIME_S**: Amount of time Ray Tune will block until a result from a running trial is received.
Defaults to 1 (second).
* **TUNE_TRIAL_STARTUP_GRACE_PERIOD**: Amount of time after starting a trial that Ray Tune checks for successful
trial startups. After the grace period, Tune will block for up to ``TUNE_TRIAL_RESULT_WAIT_TIME_S`` seconds
until a result from a running trial is received. Can be disabled by setting this to lower or equal to 0.
* **TUNE_WARN_THRESHOLD_S**: Threshold for logging if an Tune event loop operation takes too long. Defaults to 0.5 (seconds). * **TUNE_WARN_THRESHOLD_S**: Threshold for logging if an Tune event loop operation takes too long. Defaults to 0.5 (seconds).
* **TUNE_WARN_INSUFFICENT_RESOURCE_THRESHOLD_S**: Threshold for throwing a warning if no active trials are in ``RUNNING`` state * **TUNE_WARN_INSUFFICENT_RESOURCE_THRESHOLD_S**: Threshold for throwing a warning if no active trials are in ``RUNNING`` state
for this amount of seconds. If the Ray Tune job is stuck in this state (most likely due to insufficient resources), for this amount of seconds. If the Ray Tune job is stuck in this state (most likely due to insufficient resources),

View file

@ -1,11 +1,12 @@
# coding: utf-8 # coding: utf-8
import copy import copy
import inspect import inspect
import random
from collections import deque from collections import deque
from enum import Enum
from functools import partial from functools import partial
import logging import logging
import os import os
import random
import time import time
import traceback import traceback
from contextlib import contextmanager from contextlib import contextmanager
@ -15,14 +16,15 @@ from typing import (
Iterable, Iterable,
List, List,
Optional, Optional,
Union,
Set,
) )
import ray import ray
from ray.actor import ActorHandle
from ray.exceptions import GetTimeoutError from ray.exceptions import GetTimeoutError
from ray import ray_constants from ray import ray_constants
from ray._private.resource_spec import NODE_ID_PREFIX from ray._private.resource_spec import NODE_ID_PREFIX
from ray.tune.error import AbortTrialExecution from ray.tune.error import AbortTrialExecution, TuneError
from ray.tune.logger import NoopLogger from ray.tune.logger import NoopLogger
from ray.tune.result import TRIAL_INFO, STDOUT_FILE, STDERR_FILE from ray.tune.result import TRIAL_INFO, STDOUT_FILE, STDERR_FILE
from ray.tune.resources import Resources from ray.tune.resources import Resources
@ -33,14 +35,12 @@ from ray.tune.trial_executor import TrialExecutor
from ray.tune.utils import warn_if_slow from ray.tune.utils import warn_if_slow
from ray.util import log_once from ray.util import log_once
from ray.util.annotations import DeveloperAPI from ray.util.annotations import DeveloperAPI
from ray.util.placement_group import remove_placement_group, PlacementGroup
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
TUNE_STATE_REFRESH_PERIOD = 10 # Refresh resources every 10 s TUNE_STATE_REFRESH_PERIOD = 10 # Refresh resources every 10 s
BOTTLENECK_WARN_PERIOD_S = 60
NONTRIVIAL_WAIT_TIME_THRESHOLD_S = 1e-3
DEFAULT_GET_TIMEOUT = 60.0 # seconds DEFAULT_GET_TIMEOUT = 60.0 # seconds
TRIAL_CLEANUP_THRESHOLD = 100
class _ActorClassCache: class _ActorClassCache:
@ -86,75 +86,56 @@ class _LocalWrapper:
return self._result return self._result
def post_stop_cleanup(future, pg):
"""Things to be done after a trial is stopped."""
assert isinstance(pg, PlacementGroup)
try:
# This should not be blocking as
# we are only here when triggered.
ray.get(future, timeout=0)
except GetTimeoutError:
if log_once("tune_trial_cleanup_timeout"):
logger.error(
"Timed out when trying to stop the Ray actor gracefully. "
"Consider making `stop` a faster operation."
)
except Exception:
if log_once("tune_trial_cleanup_exception"):
logger.error(
f"An exception occurred when trying to stop the Ray actor:"
f"{traceback.format_exc()}"
)
finally:
remove_placement_group(pg)
class _TrialCleanup: class _TrialCleanup:
"""Mechanism for ensuring trial stop futures are cleaned up. """Responsible for triggering force cleanup of remote actors,
without waiting for `Trainable.stop()` to finish.
Args: Only instantiated when `TUNE_FORCE_TRIAL_CLEANUP_S` is set up.
threshold (int): Number of futures to hold at once. If the threshold
is passed, cleanup will kick in and remove futures.
force_cleanup (int): Grace periods for forceful actor termination.
If 0, actors will not be forcefully terminated.
""" """
def __init__( def __init__(self, force_cleanup):
self, threshold: int = TRIAL_CLEANUP_THRESHOLD, force_cleanup: int = 0 assert force_cleanup
):
self.threshold = threshold
self._cleanup_map = {}
if force_cleanup < 0:
force_cleanup = 0
self._force_cleanup = force_cleanup self._force_cleanup = force_cleanup
self._future_to_insert_time = deque()
def add(self, trial: Trial, actor: ActorHandle): def add(self, future):
"""Adds a trial actor to be stopped. self._future_to_insert_time.append((future, time.time()))
If the number of futures exceeds the threshold, the cleanup mechanism def get_next(self):
will kick in. """Get the next future that is eligible to be cleaned up forcibly."""
if (
Args: len(self._future_to_insert_time) > 0
trial (Trial): The trial corresponding to the future. and self._future_to_insert_time[0][1] + self._force_cleanup < time.time()
actor (ActorHandle): Handle to the trainable to be stopped. ):
""" return self._future_to_insert_time.popleft()
future = actor.stop.remote()
del actor
self._cleanup_map[future] = trial
if len(self._cleanup_map) > self.threshold:
self.cleanup(partial=True)
def cleanup(self, partial: bool = True):
"""Waits for cleanup to finish.
If partial=False, all futures are expected to return. If a future
does not return within the timeout period, the cleanup terminates.
"""
# At this point, self._cleanup_map holds the last references
# to actors. Removing those references either one-by-one
# (graceful termination case) or all at once, by reinstantiating
# self._cleanup_map (forceful termination case) will cause Ray
# to kill the actors during garbage collection.
logger.debug("Cleaning up futures")
num_to_keep = int(self.threshold) / 2 if partial else 0
while len(self._cleanup_map) > num_to_keep:
dones, _ = ray.wait(
list(self._cleanup_map),
timeout=DEFAULT_GET_TIMEOUT
if not self._force_cleanup
else self._force_cleanup,
)
if not dones:
logger.warning(
"Skipping cleanup - trainable.stop did not return in "
"time. Consider making `stop` a faster operation."
)
if not partial and self._force_cleanup:
logger.warning("Forcing trainable cleanup by terminating actors.")
self._cleanup_map = {}
return
else: else:
done = dones[0] return None
del self._cleanup_map[done]
def is_empty(self):
return len(self._future_to_insert_time) == 0
def noop_logger_creator(config, logdir): def noop_logger_creator(config, logdir):
@ -165,6 +146,39 @@ def noop_logger_creator(config, logdir):
return NoopLogger(config, logdir) return NoopLogger(config, logdir)
class ExecutorEventType(Enum):
"""The executor event type.
Some of the events are internal events to executor while others
are handled by runner."""
NO_RUNNING_TRIAL_TIMEOUT = 1
PG_READY = 2
TRAINING_RESULT = 3
SAVING_RESULT = 4
RESTORING_RESULT = 5
STOP_RESULT = 6 # Internally to executor only.
ERROR = 7 # This is to signal to TrialRunner that there is an error.
YIELD = 8 # Yielding back to TrialRunner's main event loop.
class ExecutorEvent:
"""A struct that describes the event to be processed by TrialRunner."""
def __init__(
self,
event_type: ExecutorEventType,
trial: Optional[Trial] = None,
result: Optional[Union[str, Dict]] = None,
):
self.type = event_type
self.trial = trial
self.result = result
def __repr__(self):
return f"[{self.type}] for {self.trial}"
@DeveloperAPI @DeveloperAPI
class RayTrialExecutor(TrialExecutor): class RayTrialExecutor(TrialExecutor):
"""An implementation of TrialExecutor based on Ray.""" """An implementation of TrialExecutor based on Ray."""
@ -177,10 +191,17 @@ class RayTrialExecutor(TrialExecutor):
wait_for_placement_group: Optional[float] = None, wait_for_placement_group: Optional[float] = None,
): ):
super(RayTrialExecutor, self).__init__() super(RayTrialExecutor, self).__init__()
self._running = {} # future --> (type, trial/pg)
self._futures = {}
force_trial_cleanup = int(os.environ.get("TUNE_FORCE_TRIAL_CLEANUP_S", "0")) force_trial_cleanup = int(os.environ.get("TUNE_FORCE_TRIAL_CLEANUP_S", "0"))
self._trial_cleanup = _TrialCleanup(force_cleanup=force_trial_cleanup) self._get_next_event_wait = int(
os.environ.get("TUNE_GET_EXECUTOR_EVENT_WAIT_S", "5")
)
if force_trial_cleanup:
self._trial_cleanup = _TrialCleanup(force_trial_cleanup)
else:
self._trial_cleanup = None
self._has_cleaned_up_pgs = False self._has_cleaned_up_pgs = False
self._reuse_actors = reuse_actors self._reuse_actors = reuse_actors
# The maxlen will be updated when `set_max_pending_trials()` is called # The maxlen will be updated when `set_max_pending_trials()` is called
@ -189,7 +210,6 @@ class RayTrialExecutor(TrialExecutor):
self._avail_resources = Resources(cpu=0, gpu=0) self._avail_resources = Resources(cpu=0, gpu=0)
self._pg_manager = PlacementGroupManager(prefix=get_tune_pg_prefix()) self._pg_manager = PlacementGroupManager(prefix=get_tune_pg_prefix())
self._staged_trials = set() self._staged_trials = set()
self._just_staged_trials = set()
self._trial_just_finished = False self._trial_just_finished = False
self._trial_just_finished_before = False self._trial_just_finished_before = False
@ -201,12 +221,6 @@ class RayTrialExecutor(TrialExecutor):
) )
self._refresh_period = refresh_period self._refresh_period = refresh_period
self._wait_for_pg = wait_for_placement_group or float(
os.environ.get("TUNE_PLACEMENT_GROUP_WAIT_S", "-1")
)
if self._wait_for_pg < 0:
self._wait_for_pg = None
self.last_pg_recon = 0 self.last_pg_recon = 0
self.pg_recon_interval = float( self.pg_recon_interval = float(
os.environ.get("TUNE_PLACEMENT_GROUP_RECON_INTERVAL", "5") os.environ.get("TUNE_PLACEMENT_GROUP_RECON_INTERVAL", "5")
@ -229,10 +243,6 @@ class RayTrialExecutor(TrialExecutor):
if ray.is_initialized(): if ray.is_initialized():
self._update_avail_resources() self._update_avail_resources()
def in_staging_grace_period(self) -> bool:
"""Returns True if trials have recently been staged."""
return self._pg_manager.in_staging_grace_period()
def set_max_pending_trials(self, max_pending: int) -> None: def set_max_pending_trials(self, max_pending: int) -> None:
if len(self._cached_actor_pg) > 0: if len(self._cached_actor_pg) > 0:
logger.warning( logger.warning(
@ -243,7 +253,7 @@ class RayTrialExecutor(TrialExecutor):
self._cached_actor_pg = deque(maxlen=max_pending) self._cached_actor_pg = deque(maxlen=max_pending)
self._pg_manager.set_max_staging(max_pending) self._pg_manager.set_max_staging(max_pending)
def stage_and_update_status(self, trials: Iterable[Trial]): def _stage_and_update_status(self, trials: Iterable[Trial]):
"""Check and update statuses of scheduled placement groups. """Check and update statuses of scheduled placement groups.
Stages placement groups of all trials. Stages placement groups of all trials.
@ -255,7 +265,7 @@ class RayTrialExecutor(TrialExecutor):
self._has_cleaned_up_pgs = True self._has_cleaned_up_pgs = True
for trial in trials: for trial in trials:
if trial.status != Trial.PENDING: if trial.status not in (Trial.PENDING, Trial.PAUSED):
continue continue
if trial in self._staged_trials: if trial in self._staged_trials:
continue continue
@ -266,7 +276,6 @@ class RayTrialExecutor(TrialExecutor):
# Break if we reached the limit of pending placement groups. # Break if we reached the limit of pending placement groups.
break break
self._staged_trials.add(trial) self._staged_trials.add(trial)
self._just_staged_trials.add(trial)
self._pg_manager.update_status() self._pg_manager.update_status()
@ -279,6 +288,7 @@ class RayTrialExecutor(TrialExecutor):
Trial object or None. Trial object or None.
""" """
# TODO(xwjiang): This method should consider `self._cached_actor_pg`.
for trial in self._staged_trials: for trial in self._staged_trials:
if self._pg_manager.has_ready(trial): if self._pg_manager.has_ready(trial):
return trial return trial
@ -317,43 +327,7 @@ class RayTrialExecutor(TrialExecutor):
) )
_actor_cls = _class_cache.get(trainable_cls) _actor_cls = _class_cache.get(trainable_cls)
if not self._pg_manager.has_ready(trial, update=True):
if trial not in self._staged_trials:
if self._pg_manager.stage_trial_pg(trial):
self._staged_trials.add(trial)
self._just_staged_trials.add(trial)
just_staged = trial in self._just_staged_trials
# This part of the code is mostly here for testing
# purposes. If self._wait_for_pg is set, we will wait here
# for that many seconds until the placement group is ready.
# This ensures that the trial can be started right away and
# not just in the next step() of the trial runner.
# We only do this if we have reason to believe that resources
# will be ready, soon, i.e. when a) we just staged the PG,
# b) another trial just exited, freeing resources, or c)
# when there are no currently running trials.
if self._wait_for_pg is not None and (
just_staged
or self._trial_just_finished_before
or not self.get_running_trials()
):
logger.debug(
f"Waiting up to {self._wait_for_pg} seconds for "
f"placement group of trial {trial} to become ready."
)
wait_end = time.monotonic() + self._wait_for_pg
while time.monotonic() < wait_end:
self._pg_manager.update_status()
if self._pg_manager.has_ready(trial):
break
time.sleep(0.1)
else:
return None
if not self._pg_manager.has_ready(trial): if not self._pg_manager.has_ready(trial):
# PG may have become ready during waiting period
return None return None
full_actor_class = self._pg_manager.get_full_actor_cls(trial, _actor_cls) full_actor_class = self._pg_manager.get_full_actor_cls(trial, _actor_cls)
@ -403,7 +377,7 @@ class RayTrialExecutor(TrialExecutor):
def _train(self, trial): def _train(self, trial):
"""Start one iteration of training and save remote id.""" """Start one iteration of training and save remote id."""
if self._find_item(self._running, trial): if self._find_future(trial):
logging.debug( logging.debug(
"Trial {} already has a queued future. Skipping this " "Trial {} already has a queued future. Skipping this "
"`train` call. This may occur if a trial has " "`train` call. This may occur if a trial has "
@ -414,7 +388,7 @@ class RayTrialExecutor(TrialExecutor):
assert trial.status == Trial.RUNNING, trial.status assert trial.status == Trial.RUNNING, trial.status
buffer_time_s = max( buffer_time_s = max(
self._buffer_min_time_s, self._buffer_min_time_s,
min(self._buffer_max_time_s, len(self._running) // 10), min(self._buffer_max_time_s, len(self._futures) // 10),
) )
with self._change_working_directory(trial): with self._change_working_directory(trial):
buffer_length = self._buffer_length buffer_length = self._buffer_length
@ -441,8 +415,8 @@ class RayTrialExecutor(TrialExecutor):
if isinstance(remote, dict): if isinstance(remote, dict):
remote = _LocalWrapper(remote) remote = _LocalWrapper(remote)
self._running[remote] = trial self._futures[remote] = (ExecutorEventType.TRAINING_RESULT, trial)
trial_item = self._find_item(self._running, trial) trial_item = self._find_future(trial)
assert len(trial_item) < 2, trial_item assert len(trial_item) < 2, trial_item
def _start_trial(self, trial) -> bool: def _start_trial(self, trial) -> bool:
@ -540,11 +514,13 @@ class RayTrialExecutor(TrialExecutor):
if should_destroy_actor: if should_destroy_actor:
logger.debug("Trial %s: Destroying actor.", trial) logger.debug("Trial %s: Destroying actor.", trial)
# Try to return the placement group for other trials to use
self._pg_manager.return_pg(trial)
with self._change_working_directory(trial): with self._change_working_directory(trial):
self._trial_cleanup.add(trial, actor=trial.runner) future = trial.runner.stop.remote()
pg = self._pg_manager.remove_from_in_use(trial)
self._futures[future] = (ExecutorEventType.STOP_RESULT, pg)
if self._trial_cleanup: # force trial cleanup within a deadline
self._trial_cleanup.add(future)
if trial in self._staged_trials: if trial in self._staged_trials:
self._staged_trials.remove(trial) self._staged_trials.remove(trial)
@ -584,8 +560,8 @@ class RayTrialExecutor(TrialExecutor):
# have been lost. TODO(ujvl): is this the right thing to do? # have been lost. TODO(ujvl): is this the right thing to do?
return False return False
def _find_item(self, dictionary, item): def _find_future(self, trial):
out = [rid for rid, t in dictionary.items() if t is item] out = [rid for rid, t in self._futures.items() if t[1] is trial]
assert ( assert (
len(out) <= 1 len(out) <= 1
), "Expecting one future for any given trial at any given time." ), "Expecting one future for any given trial at any given time."
@ -598,9 +574,9 @@ class RayTrialExecutor(TrialExecutor):
self._stop_trial(trial, error=error, error_msg=error_msg) self._stop_trial(trial, error=error, error_msg=error_msg)
if prior_status == Trial.RUNNING: if prior_status == Trial.RUNNING:
logger.debug("Trial %s: Returning resources.", trial) logger.debug("Trial %s: Returning resources.", trial)
out = self._find_item(self._running, trial) out = self._find_future(trial)
for result_id in out: for result_id in out:
self._running.pop(result_id) self._futures.pop(result_id)
def continue_training(self, trial: Trial) -> None: def continue_training(self, trial: Trial) -> None:
"""Continues the training of this trial.""" """Continues the training of this trial."""
@ -649,63 +625,6 @@ class RayTrialExecutor(TrialExecutor):
return False return False
return reset_val return reset_val
def get_running_trials(self) -> List[Trial]:
"""Returns the running trials."""
return list(self._running.values())
def get_next_available_trial(
self, timeout: Optional[float] = None
) -> Optional[Trial]:
if not self._running:
return None
shuffled_results = list(self._running.keys())
random.shuffle(shuffled_results)
# Note: We shuffle the results because `ray.wait` by default returns
# the first available result, and we want to guarantee that slower
# trials (i.e. trials that run remotely) also get fairly reported.
# See https://github.com/ray-project/ray/issues/4211 for details.
start = time.time()
ready, _ = ray.wait(shuffled_results, timeout=timeout)
if not ready:
return None
result_id = ready[0]
wait_time = time.time() - start
if wait_time > NONTRIVIAL_WAIT_TIME_THRESHOLD_S:
self._last_nontrivial_wait = time.time()
if time.time() - self._last_nontrivial_wait > BOTTLENECK_WARN_PERIOD_S:
logger.warning(
"Over the last {} seconds, the Tune event loop has been "
"backlogged processing new results. Consider increasing your "
"period of result reporting to improve performance.".format(
BOTTLENECK_WARN_PERIOD_S
)
)
self._last_nontrivial_wait = time.time()
return self._running[result_id]
def fetch_result(self, trial) -> List[Dict]:
"""Fetches result list of the running trials.
Returns:
Result of the most recent trial training run.
"""
trial_future = self._find_item(self._running, trial)
if not trial_future:
raise ValueError("Trial was not running.")
self._running.pop(trial_future[0])
with warn_if_slow("fetch_result"):
result = ray.get(trial_future[0], timeout=DEFAULT_GET_TIMEOUT)
# For local mode
if isinstance(result, _LocalWrapper):
result = result.unwrap()
if not isinstance(result, list):
return [result]
return result
def _update_avail_resources(self, num_retries=5): def _update_avail_resources(self, num_retries=5):
if time.time() - self._last_resource_refresh < self._refresh_period: if time.time() - self._last_resource_refresh < self._refresh_period:
return return
@ -814,8 +733,7 @@ class RayTrialExecutor(TrialExecutor):
self._trial_just_finished = False self._trial_just_finished = False
def on_step_end(self, trials: List[Trial]) -> None: def on_step_end(self, trials: List[Trial]) -> None:
self._just_staged_trials.clear() self._do_force_trial_cleanup()
if time.time() > self.last_pg_recon + self.pg_recon_interval: if time.time() > self.last_pg_recon + self.pg_recon_interval:
# Only do this every now and then - usually the placement groups # Only do this every now and then - usually the placement groups
# should not get out of sync, and calling this often is inefficient # should not get out of sync, and calling this often is inefficient
@ -824,6 +742,20 @@ class RayTrialExecutor(TrialExecutor):
self._pg_manager.cleanup() self._pg_manager.cleanup()
def _do_force_trial_cleanup(self) -> None:
if self._trial_cleanup:
while True:
next_future_to_clean = self._trial_cleanup.get_next()
if not next_future_to_clean:
break
if next_future_to_clean in self._futures.keys():
_, pg = self._futures.pop(next_future_to_clean)
post_stop_cleanup(next_future_to_clean, pg)
else:
# This just means that before the deadline reaches,
# the future is already cleaned up.
pass
def force_reconcilation_on_next_step_end(self) -> None: def force_reconcilation_on_next_step_end(self) -> None:
self.last_pg_recon = -float("inf") self.last_pg_recon = -float("inf")
@ -842,6 +774,7 @@ class RayTrialExecutor(TrialExecutor):
Returns: Returns:
Checkpoint object, or None if an Exception occurs. Checkpoint object, or None if an Exception occurs.
""" """
logger.info(f"saving trial {trial}")
result = result or trial.last_result result = result or trial.last_result
with self._change_working_directory(trial): with self._change_working_directory(trial):
if storage == Checkpoint.MEMORY: if storage == Checkpoint.MEMORY:
@ -852,7 +785,7 @@ class RayTrialExecutor(TrialExecutor):
value = trial.runner.save.remote() value = trial.runner.save.remote()
checkpoint = Checkpoint(storage, value, result) checkpoint = Checkpoint(storage, value, result)
trial.saving_to = checkpoint trial.saving_to = checkpoint
self._running[value] = trial self._futures[value] = (ExecutorEventType.SAVING_RESULT, trial)
return checkpoint return checkpoint
def restore(self, trial) -> None: def restore(self, trial) -> None:
@ -899,7 +832,7 @@ class RayTrialExecutor(TrialExecutor):
"storage-based restoration" "storage-based restoration"
) )
self._running[remote] = trial self._futures[remote] = (ExecutorEventType.RESTORING_RESULT, trial)
trial.restoring_from = checkpoint trial.restoring_from = checkpoint
def export_trial_if_needed(self, trial: Trial) -> Dict: def export_trial_if_needed(self, trial: Trial) -> Dict:
@ -922,7 +855,19 @@ class RayTrialExecutor(TrialExecutor):
return self._avail_resources.gpu > 0 return self._avail_resources.gpu > 0
def cleanup(self, trials: List[Trial]) -> None: def cleanup(self, trials: List[Trial]) -> None:
self._trial_cleanup.cleanup(partial=False) while True:
if self._trial_cleanup and self._trial_cleanup.is_empty():
break
elif not self._trial_cleanup and len(self._futures) == 0:
break
self._do_force_trial_cleanup()
ready, _ = ray.wait(list(self._futures.keys()), timeout=0)
if not ready:
continue
event_type, trial_or_pg = self._futures.pop(ready[0])
if event_type == ExecutorEventType.STOP_RESULT:
post_stop_cleanup(ready[0], trial_or_pg)
self._pg_manager.reconcile_placement_groups(trials) self._pg_manager.reconcile_placement_groups(trials)
self._pg_manager.cleanup(force=True) self._pg_manager.cleanup(force=True)
self._pg_manager.cleanup_existing_pg(block=True) self._pg_manager.cleanup_existing_pg(block=True)
@ -944,6 +889,150 @@ class RayTrialExecutor(TrialExecutor):
else: else:
yield yield
def get_next_executor_event(
self, live_trials: Set[Trial], next_trial_exists: bool
) -> ExecutorEvent:
"""Get the next executor event to be processed in TrialRunner.
In case there are multiple events available for handling, the next
event is determined by the following priority:
1. if there is `next_trial_exists`, and if there is cached resources
to use, PG_READY is emitted.
2. if there is `next_trial_exists` and there is no cached resources
to use, wait on pg future and randomized other futures. If multiple
futures are ready, pg future will take priority to be handled first.
3. if there is no `next_trial_exists`, wait on just randomized other
futures.
An example of #3 would be synchronous hyperband. Although there are pgs
ready, the scheduler is holding back scheduling new trials since the
whole band of trials is waiting for the slowest trial to finish. In
this case, we prioritize handling training result to avoid deadlock
situation.
This is a blocking wait with a timeout (specified with env var).
The reason for the timeout is
we still want to print status info periodically in TrialRunner for
better user experience.
The handle of `ExecutorEvent.STOP_RESULT` is purely internal to
RayTrialExecutor itself. All the other future results are handled by
TrialRunner.
In the future we may want to do most of the handle of
`ExecutorEvent.RESTORE_RESULT` and `SAVING_RESULT` in
RayTrialExecutor itself and only notify TrialRunner to invoke
corresponding callbacks. This view is more consistent with our goal
of TrialRunner responsible for external facing Trial state transition,
while RayTrialExecutor responsible for internal facing transitions,
namely, `is_saving`, `is_restoring` etc.
Also you may notice that the boundary between RayTrialExecutor and
PlacementGroupManager right now is really blurry. This will be
improved once we move to an ActorPool abstraction.
`next_trial_exists` means that there is a trial to run - prioritize
returning PG_READY in this case.
"""
# First update status of staged placement groups
self._stage_and_update_status(live_trials)
while True:
###################################################################
# when next_trial_exists and there are cached resources
###################################################################
# There could be existing PGs from either `self._cached_actor_pg`
# or from `self._pg_manager._ready`. If so and if there is indeed
# a next trial to run, we return `PG_READY` future for trial
# runner. The next trial can then be scheduled on this PG.
if next_trial_exists:
if len(self._cached_actor_pg) > 0:
return ExecutorEvent(ExecutorEventType.PG_READY)
# TODO(xwjiang): Expose proper API when we decide to do
# ActorPool abstraction.
if any(len(r) > 0 for r in self._pg_manager._ready.values()):
return ExecutorEvent(ExecutorEventType.PG_READY)
###################################################################
# Prepare for futures to wait
###################################################################
futures_to_wait = list(self._futures.keys())
random.shuffle(futures_to_wait)
if next_trial_exists:
# Only wait for pg explicitly if there is next trial to run.
# In which case, handling PG_READY triumphs handling other events.
# Since we want to place pending trial ASAP.
futures_to_wait = (
self._pg_manager.get_staging_future_list() + futures_to_wait
)
logger.debug(
f"get_next_executor_event before wait with futures "
f"{futures_to_wait} and "
f"next_trial_exists={next_trial_exists}"
)
ready_futures, _ = ray.wait(
futures_to_wait, num_returns=1, timeout=self._get_next_event_wait
)
###################################################################
# Dealing with no future returned case.
###################################################################
if len(ready_futures) == 0:
if len(self._futures) == 0:
# No running trial and timing out with wait, could be we may
# have insufficient cluster resources that makes tune run
# infeasible.
# TODO: Move InsufficientResourceManager's logic
# to TrialExecutor. It is not Runner's responsibility!
return ExecutorEvent(ExecutorEventType.NO_RUNNING_TRIAL_TIMEOUT)
else:
# Training simply takes long time, yield the control back to main
# event loop to print progress info etc.
return ExecutorEvent(ExecutorEventType.YIELD)
###################################################################
# If there is future returned.
###################################################################
assert len(ready_futures) == 1
ready_future = ready_futures[0]
###################################################################
# If it is a PG_READY event.
###################################################################
if ready_future not in self._futures.keys():
# This is a ready future.
self._pg_manager.handle_ready_future(ready_future)
return ExecutorEvent(ExecutorEventType.PG_READY)
###################################################################
# non PG_READY event
###################################################################
result_type, trial_or_pg = self._futures.pop(ready_future)
if result_type == ExecutorEventType.STOP_RESULT:
pg = trial_or_pg
post_stop_cleanup(ready_future, pg)
else:
trial = trial_or_pg
assert isinstance(trial, Trial)
try:
future_result = ray.get(ready_future)
# For local mode
if isinstance(future_result, _LocalWrapper):
future_result = future_result.unwrap()
if result_type in (
ExecutorEventType.TRAINING_RESULT,
ExecutorEventType.SAVING_RESULT,
ExecutorEventType.RESTORING_RESULT,
):
logger.debug(f"Returning [{result_type}] for trial {trial}")
return ExecutorEvent(result_type, trial, result=future_result)
else:
raise TuneError(f"Unexpected future type - [{result_type}]")
except Exception:
return ExecutorEvent(
ExecutorEventType.ERROR, trial, traceback.format_exc()
)
def _to_gb(n_bytes): def _to_gb(n_bytes):
return round(n_bytes / (1024 ** 3), 2) return round(n_bytes / (1024 ** 3), 2)

View file

@ -780,7 +780,7 @@ class TrainableFunctionApiTest(unittest.TestCase):
return [0, 1, True, {}] return [0, 1, True, {}]
class FailureInjectionCallback(Callback): class FailureInjectionCallback(Callback):
def on_trial_start(self, trials, **info): def on_step_end(self, **info):
raise RuntimeError raise RuntimeError
with self.assertRaises(Exception): with self.assertRaises(Exception):
@ -870,7 +870,6 @@ class TrainableFunctionApiTest(unittest.TestCase):
trial.last_result.get("trial_resources"), trial.placement_group_factory trial.last_result.get("trial_resources"), trial.placement_group_factory
) )
@patch("ray.tune.ray_trial_executor.TRIAL_CLEANUP_THRESHOLD", 3)
def testLotsOfStops(self): def testLotsOfStops(self):
class TestTrainable(Trainable): class TestTrainable(Trainable):
def step(self): def step(self):

View file

@ -30,12 +30,6 @@ from ray.tune.utils.mock import (
MOCK_REMOTE_DIR, MOCK_REMOTE_DIR,
) )
# Wait up to five seconds for placement groups when starting a trial
os.environ["TUNE_PLACEMENT_GROUP_WAIT_S"] = "5"
# Block for results even when placement groups are pending
os.environ["TUNE_TRIAL_STARTUP_GRACE_PERIOD"] = "0"
os.environ["TUNE_TRIAL_RESULT_WAIT_TIME_S"] = "9999"
def _check_trial_running(trial): def _check_trial_running(trial):
if trial.runner: if trial.runner:
@ -203,17 +197,10 @@ def test_remove_node_before_result(start_connected_emptyhead_cluster):
assert trial.last_result.get("training_iteration") == 1 assert trial.last_result.get("training_iteration") == 1
# Process result: discover failure, recover, _train (from scratch) # Process result: discover failure, recover, _train (from scratch)
while trial.status != Trial.TERMINATED:
runner.step() runner.step()
runner.step() # Process result, invoke _train assert trial.last_result.get("training_iteration") > 1
assert trial.last_result.get("training_iteration") == 1
runner.step() # Process result, invoke _save
assert trial.last_result.get("training_iteration") == 2
# process save, invoke _train
runner.step()
# process result
runner.step()
assert trial.status == Trial.TERMINATED
with pytest.raises(TuneError): with pytest.raises(TuneError):
runner.step() runner.step()
@ -262,7 +249,7 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
# assert t.last_result is None, "Trial result not restored correctly." # assert t.last_result is None, "Trial result not restored correctly."
# Process result (x2), process save, process result (x2), process save # Process result (x2), process save, process result (x2), process save
for _ in range(6): while not runner.is_finished():
runner.step() runner.step()
assert t.status == Trial.TERMINATED, runner.debug_string() assert t.status == Trial.TERMINATED, runner.debug_string()
@ -271,19 +258,13 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
t2 = Trial(trainable_id, **kwargs) t2 = Trial(trainable_id, **kwargs)
runner.add_trial(t2) runner.add_trial(t2)
# Start trial, process result (x2), process save # Start trial, process result (x2), process save
for _ in range(4): while not t2.has_checkpoint():
runner.step() runner.step()
assert t2.has_checkpoint()
node3 = cluster.add_node(num_cpus=1) node3 = cluster.add_node(num_cpus=1)
cluster.remove_node(node2) cluster.remove_node(node2)
cluster.wait_for_nodes() cluster.wait_for_nodes()
runner.step() # Process result 3 + start and fail 4 result while not runner.is_finished():
runner.step() # Dispatch restore runner.step()
runner.step() # Process restore
runner.step() # Process result 5
if t2.status != Trial.TERMINATED:
runner.step() # Process result 6, dispatch save
runner.step() # Process save
assert t2.status == Trial.TERMINATED, runner.debug_string() assert t2.status == Trial.TERMINATED, runner.debug_string()
# Test recovery of trial that won't be checkpointed # Test recovery of trial that won't be checkpointed
@ -301,8 +282,7 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
cluster.add_node(num_cpus=1) cluster.add_node(num_cpus=1)
cluster.remove_node(node3) cluster.remove_node(node3)
cluster.wait_for_nodes() cluster.wait_for_nodes()
runner.step() # Error handling step while not runner.is_finished():
if t3.status != Trial.ERROR:
runner.step() runner.step()
assert t3.status == Trial.ERROR, runner.debug_string() assert t3.status == Trial.ERROR, runner.debug_string()
@ -406,19 +386,14 @@ def test_migration_checkpoint_removal(start_connected_emptyhead_cluster, trainab
runner.add_trial(t1) runner.add_trial(t1)
# Start trial, process result (x2), process save # Start trial, process result (x2), process save
for _ in range(4): while not t1.has_checkpoint():
runner.step() runner.step()
assert t1.has_checkpoint()
cluster.add_node(num_cpus=1) cluster.add_node(num_cpus=1)
cluster.remove_node(node) cluster.remove_node(node)
cluster.wait_for_nodes() cluster.wait_for_nodes()
shutil.rmtree(os.path.dirname(t1.checkpoint.value)) shutil.rmtree(os.path.dirname(t1.checkpoint.value))
runner.step() # Collect result 3, kick off + fail result 4 while not runner.is_finished():
runner.step() # Dispatch restore
runner.step() # Process restore + step 4
for _ in range(3):
if t1.status != Trial.TERMINATED:
runner.step() runner.step()
assert t1.status == Trial.TERMINATED, runner.debug_string() assert t1.status == Trial.TERMINATED, runner.debug_string()

View file

@ -295,11 +295,6 @@ with patch("ray.tune.progress_reporter._get_trial_location",
class ProgressReporterTest(unittest.TestCase): class ProgressReporterTest(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
# Wait up to five seconds for placement groups when starting a trial
os.environ["TUNE_PLACEMENT_GROUP_WAIT_S"] = "5"
# Block for results even when placement groups are pending
os.environ["TUNE_TRIAL_STARTUP_GRACE_PERIOD"] = "0"
os.environ["TUNE_TRIAL_RESULT_WAIT_TIME_S"] = "99999"
os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "auto" os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "auto"
def mock_trial(self, status, i): def mock_trial(self, status, i):

View file

@ -10,7 +10,7 @@ from ray import tune
from ray.rllib import _register_all from ray.rllib import _register_all
from ray.tune import Trainable from ray.tune import Trainable
from ray.tune.callback import Callback from ray.tune.callback import Callback
from ray.tune.ray_trial_executor import RayTrialExecutor from ray.tune.ray_trial_executor import RayTrialExecutor, ExecutorEventType
from ray.tune.registry import _global_registry, TRAINABLE_CLASS from ray.tune.registry import _global_registry, TRAINABLE_CLASS
from ray.tune.result import PID, TRAINING_ITERATION, TRIAL_ID from ray.tune.result import PID, TRAINING_ITERATION, TRIAL_ID
from ray.tune.suggest import BasicVariantGenerator from ray.tune.suggest import BasicVariantGenerator
@ -83,12 +83,6 @@ class TrialExecutorInsufficientResourcesTest(unittest.TestCase):
class RayTrialExecutorTest(unittest.TestCase): class RayTrialExecutorTest(unittest.TestCase):
def setUp(self): def setUp(self):
# Wait up to five seconds for placement groups when starting a trial
os.environ["TUNE_PLACEMENT_GROUP_WAIT_S"] = "5"
# Block for results even when placement groups are pending
os.environ["TUNE_TRIAL_STARTUP_GRACE_PERIOD"] = "0"
os.environ["TUNE_TRIAL_RESULT_WAIT_TIME_S"] = "99999"
self.trial_executor = RayTrialExecutor() self.trial_executor = RayTrialExecutor()
ray.init(num_cpus=2, ignore_reinit_error=True) ray.init(num_cpus=2, ignore_reinit_error=True)
_register_all() # Needed for flaky tests _register_all() # Needed for flaky tests
@ -97,34 +91,63 @@ class RayTrialExecutorTest(unittest.TestCase):
ray.shutdown() ray.shutdown()
_register_all() # re-register the evicted objects _register_all() # re-register the evicted objects
def _simulate_starting_trial(self, trial):
future_result = self.trial_executor.get_next_executor_event(
live_trials={trial}, next_trial_exists=True
)
assert future_result.type == ExecutorEventType.PG_READY
self.assertTrue(self.trial_executor.start_trial(trial))
self.assertEqual(Trial.RUNNING, trial.status)
def _simulate_getting_result(self, trial):
while True:
future_result = self.trial_executor.get_next_executor_event(
live_trials={trial}, next_trial_exists=False
)
if future_result.type == ExecutorEventType.TRAINING_RESULT:
break
if isinstance(future_result.result, list):
for r in future_result.result:
trial.update_last_result(r)
else:
trial.update_last_result(future_result.result)
def _simulate_saving(self, trial):
checkpoint = self.trial_executor.save(trial, Checkpoint.PERSISTENT)
self.assertEqual(checkpoint, trial.saving_to)
self.assertEqual(trial.checkpoint.value, None)
future_result = self.trial_executor.get_next_executor_event(
live_trials={trial}, next_trial_exists=False
)
assert future_result.type == ExecutorEventType.SAVING_RESULT
self.process_trial_save(trial, future_result.result)
self.assertEqual(checkpoint, trial.checkpoint)
def testStartStop(self): def testStartStop(self):
trial = Trial("__fake") trial = Trial("__fake")
self.trial_executor.start_trial(trial) self._simulate_starting_trial(trial)
running = self.trial_executor.get_running_trials()
self.assertEqual(1, len(running))
self.trial_executor.stop_trial(trial) self.trial_executor.stop_trial(trial)
def testAsyncSave(self): def testAsyncSave(self):
"""Tests that saved checkpoint value not immediately set.""" """Tests that saved checkpoint value not immediately set."""
trial = Trial("__fake") trial = Trial("__fake")
self.trial_executor.start_trial(trial) self._simulate_starting_trial(trial)
self.assertEqual(Trial.RUNNING, trial.status)
trial.last_result = self.trial_executor.fetch_result(trial)[-1] self._simulate_getting_result(trial)
checkpoint = self.trial_executor.save(trial, Checkpoint.PERSISTENT)
self.assertEqual(checkpoint, trial.saving_to) self._simulate_saving(trial)
self.assertEqual(trial.checkpoint.value, None)
self.process_trial_save(trial)
self.assertEqual(checkpoint, trial.checkpoint)
self.trial_executor.stop_trial(trial) self.trial_executor.stop_trial(trial)
self.assertEqual(Trial.TERMINATED, trial.status) self.assertEqual(Trial.TERMINATED, trial.status)
def testSaveRestore(self): def testSaveRestore(self):
trial = Trial("__fake") trial = Trial("__fake")
self.trial_executor.start_trial(trial) self._simulate_starting_trial(trial)
self.assertEqual(Trial.RUNNING, trial.status)
trial.last_result = self.trial_executor.fetch_result(trial)[-1] self._simulate_getting_result(trial)
self.trial_executor.save(trial, Checkpoint.PERSISTENT)
self.process_trial_save(trial) self._simulate_saving(trial)
self.trial_executor.restore(trial) self.trial_executor.restore(trial)
self.trial_executor.stop_trial(trial) self.trial_executor.stop_trial(trial)
self.assertEqual(Trial.TERMINATED, trial.status) self.assertEqual(Trial.TERMINATED, trial.status)
@ -132,40 +155,44 @@ class RayTrialExecutorTest(unittest.TestCase):
def testPauseResume(self): def testPauseResume(self):
"""Tests that pausing works for trials in flight.""" """Tests that pausing works for trials in flight."""
trial = Trial("__fake") trial = Trial("__fake")
self.trial_executor.start_trial(trial) self._simulate_starting_trial(trial)
self.assertEqual(Trial.RUNNING, trial.status)
self.trial_executor.pause_trial(trial) self.trial_executor.pause_trial(trial)
self.assertEqual(Trial.PAUSED, trial.status) self.assertEqual(Trial.PAUSED, trial.status)
self.trial_executor.start_trial(trial)
self.assertEqual(Trial.RUNNING, trial.status) self._simulate_starting_trial(trial)
self.trial_executor.stop_trial(trial) self.trial_executor.stop_trial(trial)
self.assertEqual(Trial.TERMINATED, trial.status) self.assertEqual(Trial.TERMINATED, trial.status)
def testSavePauseResumeErrorRestore(self): def testSavePauseResumeErrorRestore(self):
"""Tests that pause checkpoint does not replace restore checkpoint.""" """Tests that pause checkpoint does not replace restore checkpoint."""
trial = Trial("__fake") trial = Trial("__fake")
self.trial_executor.start_trial(trial) self._simulate_starting_trial(trial)
trial.last_result = self.trial_executor.fetch_result(trial)[-1]
self._simulate_getting_result(trial)
# Save # Save
checkpoint = self.trial_executor.save(trial, Checkpoint.PERSISTENT) self._simulate_saving(trial)
self.assertEqual(Trial.RUNNING, trial.status)
self.assertEqual(checkpoint.storage, Checkpoint.PERSISTENT)
# Process save result (simulates trial runner)
self.process_trial_save(trial)
# Train # Train
self.trial_executor.continue_training(trial) self.trial_executor.continue_training(trial)
trial.last_result = self.trial_executor.fetch_result(trial)[-1] self._simulate_getting_result(trial)
# Pause # Pause
self.trial_executor.pause_trial(trial) self.trial_executor.pause_trial(trial)
self.assertEqual(Trial.PAUSED, trial.status) self.assertEqual(Trial.PAUSED, trial.status)
self.assertEqual(trial.checkpoint.storage, Checkpoint.MEMORY) self.assertEqual(trial.checkpoint.storage, Checkpoint.MEMORY)
# Resume # Resume
self.trial_executor.start_trial(trial) self._simulate_starting_trial(trial)
self.assertEqual(Trial.RUNNING, trial.status)
# Error # Error
trial.set_status(Trial.ERROR) trial.set_status(Trial.ERROR)
# Restore # Restore
self.trial_executor.restore(trial) self.trial_executor.restore(trial)
self.trial_executor.stop_trial(trial) self.trial_executor.stop_trial(trial)
self.assertEqual(Trial.TERMINATED, trial.status) self.assertEqual(Trial.TERMINATED, trial.status)
@ -178,13 +205,14 @@ class RayTrialExecutorTest(unittest.TestCase):
def testPauseResume2(self): def testPauseResume2(self):
"""Tests that pausing works for trials being processed.""" """Tests that pausing works for trials being processed."""
trial = Trial("__fake") trial = Trial("__fake")
self.trial_executor.start_trial(trial) self._simulate_starting_trial(trial)
self.assertEqual(Trial.RUNNING, trial.status)
self.trial_executor.fetch_result(trial) self._simulate_getting_result(trial)
self.trial_executor.pause_trial(trial) self.trial_executor.pause_trial(trial)
self.assertEqual(Trial.PAUSED, trial.status) self.assertEqual(Trial.PAUSED, trial.status)
self.trial_executor.start_trial(trial)
self.assertEqual(Trial.RUNNING, trial.status) self._simulate_starting_trial(trial)
self.trial_executor.stop_trial(trial) self.trial_executor.stop_trial(trial)
self.assertEqual(Trial.TERMINATED, trial.status) self.assertEqual(Trial.TERMINATED, trial.status)
@ -199,15 +227,17 @@ class RayTrialExecutorTest(unittest.TestCase):
base = max(result_buffer_length, 1) base = max(result_buffer_length, 1)
trial = Trial("__fake") trial = Trial("__fake")
self.trial_executor.start_trial(trial) self._simulate_starting_trial(trial)
self.assertEqual(Trial.RUNNING, trial.status)
trial.last_result = self.trial_executor.fetch_result(trial)[-1] self._simulate_getting_result(trial)
self.assertEqual(trial.last_result.get(TRAINING_ITERATION), base) self.assertEqual(trial.last_result.get(TRAINING_ITERATION), base)
self.trial_executor.pause_trial(trial) self.trial_executor.pause_trial(trial)
self.assertEqual(Trial.PAUSED, trial.status) self.assertEqual(Trial.PAUSED, trial.status)
self.trial_executor.start_trial(trial)
self.assertEqual(Trial.RUNNING, trial.status) self._simulate_starting_trial(trial)
trial.last_result = self.trial_executor.fetch_result(trial)[-1]
self._simulate_getting_result(trial)
self.assertEqual(trial.last_result.get(TRAINING_ITERATION), base * 2) self.assertEqual(trial.last_result.get(TRAINING_ITERATION), base * 2)
self.trial_executor.stop_trial(trial) self.trial_executor.stop_trial(trial)
self.assertEqual(Trial.TERMINATED, trial.status) self.assertEqual(Trial.TERMINATED, trial.status)
@ -224,7 +254,7 @@ class RayTrialExecutorTest(unittest.TestCase):
def testNoResetTrial(self): def testNoResetTrial(self):
"""Tests that reset handles NotImplemented properly.""" """Tests that reset handles NotImplemented properly."""
trial = Trial("__fake") trial = Trial("__fake")
self.trial_executor.start_trial(trial) self._simulate_starting_trial(trial)
exists = self.trial_executor.reset_trial(trial, {}, "modified_mock") exists = self.trial_executor.reset_trial(trial, {}, "modified_mock")
self.assertEqual(exists, False) self.assertEqual(exists, False)
self.assertEqual(Trial.RUNNING, trial.status) self.assertEqual(Trial.RUNNING, trial.status)
@ -248,18 +278,18 @@ class RayTrialExecutorTest(unittest.TestCase):
"grid_search", "grid_search",
) )
trial = trials[0] trial = trials[0]
self.trial_executor.start_trial(trial) self._simulate_starting_trial(trial)
exists = self.trial_executor.reset_trial(trial, {"hi": 1}, "modified_mock") exists = self.trial_executor.reset_trial(trial, {"hi": 1}, "modified_mock")
self.assertEqual(exists, True) self.assertEqual(exists, True)
self.assertEqual(trial.config.get("hi"), 1) self.assertEqual(trial.config.get("hi"), 1)
self.assertEqual(trial.experiment_tag, "modified_mock") self.assertEqual(trial.experiment_tag, "modified_mock")
self.assertEqual(Trial.RUNNING, trial.status) self.assertEqual(Trial.RUNNING, trial.status)
def testForceTrialCleanup(self): def testTrialCleanup(self):
class B(Trainable): class B(Trainable):
def step(self): def step(self):
print("Step start") print("Step start")
time.sleep(10) time.sleep(4)
print("Step done") print("Step done")
return dict(my_metric=1, timesteps_this_iter=1, done=True) return dict(my_metric=1, timesteps_this_iter=1, done=True)
@ -269,7 +299,7 @@ class RayTrialExecutorTest(unittest.TestCase):
def cleanup(self): def cleanup(self):
print("Cleanup start") print("Cleanup start")
time.sleep(10) time.sleep(4)
print("Cleanup done") print("Cleanup done")
# First check if the trials terminate gracefully by default # First check if the trials terminate gracefully by default
@ -281,15 +311,15 @@ class RayTrialExecutorTest(unittest.TestCase):
"grid_search", "grid_search",
) )
trial = trials[0] trial = trials[0]
self.trial_executor.start_trial(trial) self._simulate_starting_trial(trial)
self.assertEqual(Trial.RUNNING, trial.status) time.sleep(1)
time.sleep(5)
print("Stop trial") print("Stop trial")
self.trial_executor.stop_trial(trial) self.trial_executor.stop_trial(trial)
print("Start trial cleanup") print("Start trial cleanup")
start = time.time() start = time.time()
self.trial_executor.cleanup([trial]) self.trial_executor.cleanup([trial])
self.assertGreaterEqual(time.time() - start, 12.0) # 4 - 1 + 4.
self.assertGreaterEqual(time.time() - start, 6)
# Check forceful termination. It should run for much less than the # Check forceful termination. It should run for much less than the
# sleep periods in the Trainable # sleep periods in the Trainable
@ -304,15 +334,16 @@ class RayTrialExecutorTest(unittest.TestCase):
os.environ["TUNE_FORCE_TRIAL_CLEANUP_S"] = "1" os.environ["TUNE_FORCE_TRIAL_CLEANUP_S"] = "1"
self.trial_executor = RayTrialExecutor() self.trial_executor = RayTrialExecutor()
os.environ["TUNE_FORCE_TRIAL_CLEANUP_S"] = "0" os.environ["TUNE_FORCE_TRIAL_CLEANUP_S"] = "0"
self.trial_executor.start_trial(trial) self._simulate_starting_trial(trial)
self.assertEqual(Trial.RUNNING, trial.status) self.assertEqual(Trial.RUNNING, trial.status)
time.sleep(5) time.sleep(1)
print("Stop trial") print("Stop trial")
self.trial_executor.stop_trial(trial) self.trial_executor.stop_trial(trial)
print("Start trial cleanup") print("Start trial cleanup")
start = time.time() start = time.time()
self.trial_executor.cleanup([trial]) self.trial_executor.cleanup([trial])
self.assertLess(time.time() - start, 5.0) # less than 1 with some margin.
self.assertLess(time.time() - start, 2.0)
# also check if auto-filled metrics were returned # also check if auto-filled metrics were returned
self.assertIn(PID, trial.last_result) self.assertIn(PID, trial.last_result)
@ -332,10 +363,9 @@ class RayTrialExecutorTest(unittest.TestCase):
break break
return trials return trials
def process_trial_save(self, trial): def process_trial_save(self, trial, checkpoint_value):
"""Simulates trial runner save.""" """Simulates trial runner save."""
checkpoint = trial.saving_to checkpoint = trial.saving_to
checkpoint_value = self.trial_executor.fetch_result(trial)[-1]
checkpoint.value = checkpoint_value checkpoint.value = checkpoint_value
trial.on_checkpoint(checkpoint) trial.on_checkpoint(checkpoint)
@ -460,10 +490,8 @@ class LocalModeExecutorTest(RayTrialExecutorTest):
ray.shutdown() ray.shutdown()
_register_all() # re-register the evicted objects _register_all() # re-register the evicted objects
def testForceTrialCleanup(self): def testTrialCleanup(self):
self.skipTest( self.skipTest("Skipping as trial cleanup is not applicable" " for local mode.")
"Skipping as force trial cleanup is not applicable" " for local mode."
)
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -15,6 +15,7 @@ import ray
from ray.rllib import _register_all from ray.rllib import _register_all
from ray import tune from ray import tune
from ray.tune import TuneError
from ray.tune.integration.docker import DockerSyncer from ray.tune.integration.docker import DockerSyncer
from ray.tune.integration.kubernetes import KubernetesSyncer from ray.tune.integration.kubernetes import KubernetesSyncer
from ray.tune.sync_client import NOOP from ray.tune.sync_client import NOOP
@ -29,12 +30,6 @@ from ray.tune.utils.callback import create_default_callbacks
class TestSyncFunctionality(unittest.TestCase): class TestSyncFunctionality(unittest.TestCase):
def setUp(self): def setUp(self):
# Wait up to 1.5 seconds for placement groups when starting a trial
os.environ["TUNE_PLACEMENT_GROUP_WAIT_S"] = "1.5"
# Block for results even when placement groups are pending
os.environ["TUNE_TRIAL_STARTUP_GRACE_PERIOD"] = "0"
os.environ["TUNE_TRIAL_RESULT_WAIT_TIME_S"] = "99999"
ray.init(num_cpus=2) ray.init(num_cpus=2)
def tearDown(self): def tearDown(self):
@@ -120,7 +115,7 @@ class TestSyncFunctionality(unittest.TestCase):
     def testClusterProperString(self):
         """Tests that invalid commands throw.."""
-        with self.assertRaises(ValueError):
+        with self.assertRaises(TuneError):
             # This raises ValueError because logger is init in safe zone.
             sync_config = tune.SyncConfig(syncer="ls {target}")
             [trial] = tune.run(
@@ -131,7 +126,7 @@ class TestSyncFunctionality(unittest.TestCase):
             sync_config=sync_config,
         ).trials

-        with self.assertRaises(ValueError):
+        with self.assertRaises(TuneError):
             # This raises ValueError because logger is init in safe zone.
             sync_config = tune.SyncConfig(syncer="ls {source}")
             [trial] = tune.run(


@@ -19,29 +19,11 @@ from ray.tune.utils.placement_groups import PlacementGroupFactory
 class TrialRunnerTest(unittest.TestCase):
     def setUp(self):
-        # Wait up to five seconds for placement groups when starting a trial
-        os.environ["TUNE_PLACEMENT_GROUP_WAIT_S"] = "5"
-        # Block for results even when placement groups are pending
-        os.environ["TUNE_TRIAL_STARTUP_GRACE_PERIOD"] = "0"
-        os.environ["TUNE_TRIAL_RESULT_WAIT_TIME_S"] = "99999"
         _register_all()  # re-register the evicted objects

     def tearDown(self):
         ray.shutdown()

-    def testTrialStatus(self):
-        ray.init(num_cpus=2)
-        trial = Trial("__fake")
-        trial_executor = RayTrialExecutor()
-        self.assertEqual(trial.status, Trial.PENDING)
-        trial_executor.start_trial(trial)
-        self.assertEqual(trial.status, Trial.RUNNING)
-        trial_executor.stop_trial(trial)
-        self.assertEqual(trial.status, Trial.TERMINATED)
-        trial_executor.stop_trial(trial, error=True)
-        self.assertEqual(trial.status, Trial.ERROR)
-
     def testExperimentTagTruncation(self):
         ray.init(num_cpus=2)
@@ -74,7 +56,8 @@ class TrialRunnerTest(unittest.TestCase):
     def testExtraResources(self):
         ray.init(num_cpus=4, num_gpus=2)
-        runner = TrialRunner()
+        snapshot = TrialStatusSnapshot()
+        runner = TrialRunner(callbacks=[TrialStatusSnapshotTaker(snapshot)])
         kwargs = {
             "stopping_criterion": {"training_iteration": 1},
             "placement_group_factory": PlacementGroupFactory(
@@ -85,17 +68,18 @@ class TrialRunnerTest(unittest.TestCase):
         for t in trials:
             runner.add_trial(t)

-        runner.step()
-        self.assertEqual(trials[0].status, Trial.RUNNING)
-        self.assertEqual(trials[1].status, Trial.PENDING)
+        while not runner.is_finished():
+            runner.step()

-        runner.step()
-        self.assertEqual(trials[0].status, Trial.TERMINATED)
-        self.assertEqual(trials[1].status, Trial.PENDING)
+        self.assertLess(snapshot.max_running_trials(), 2)
+        self.assertTrue(snapshot.all_trials_are_terminated())
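`TrialStatusSnapshot` and `TrialStatusSnapshotTaker` are test utilities added elsewhere in this PR; their implementation is not part of these hunks. A minimal sketch of what such a snapshot-taking callback could look like, assuming only the public `Callback.on_step_end` hook and the `Trial.status` constants (the real helpers may differ in detail):

from ray.tune import Callback
from ray.tune.trial import Trial


class TrialStatusSnapshot:
    """Hypothetical helper that records trial statuses after every runner step."""

    def __init__(self):
        self._snapshots = []

    def record(self, trials):
        self._snapshots.append([t.status for t in trials])

    def max_running_trials(self) -> int:
        # Largest number of concurrently RUNNING trials observed in any step.
        return max(
            (statuses.count(Trial.RUNNING) for statuses in self._snapshots),
            default=0,
        )

    def all_trials_are_terminated(self) -> bool:
        # The final snapshot should contain only TERMINATED trials.
        return bool(self._snapshots) and all(
            status == Trial.TERMINATED for status in self._snapshots[-1]
        )


class TrialStatusSnapshotTaker(Callback):
    """Hypothetical callback that feeds the snapshot after each step."""

    def __init__(self, snapshot: TrialStatusSnapshot):
        self._snapshot = snapshot

    def on_step_end(self, iteration, trials, **kwargs):
        self._snapshot.record(trials)

Because the refactored loop no longer guarantees a fixed number of `step()` calls per state transition, asserting on aggregate snapshots like this is what lets the tests stay independent of internal timing.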
def testCustomResources(self): def testCustomResources(self):
ray.init(num_cpus=4, num_gpus=2, resources={"a": 2}) ray.init(num_cpus=4, num_gpus=2, resources={"a": 2})
runner = TrialRunner() # Since each trial will occupy the full custom resources,
# there are at most 1 trial running at any given moment.
snapshot = TrialStatusSnapshot()
runner = TrialRunner(callbacks=[TrialStatusSnapshotTaker(snapshot)])
kwargs = { kwargs = {
"stopping_criterion": {"training_iteration": 1}, "stopping_criterion": {"training_iteration": 1},
"placement_group_factory": PlacementGroupFactory([{"CPU": 1, "a": 2}]), "placement_group_factory": PlacementGroupFactory([{"CPU": 1, "a": 2}]),
@ -104,16 +88,18 @@ class TrialRunnerTest(unittest.TestCase):
for t in trials: for t in trials:
runner.add_trial(t) runner.add_trial(t)
while not runner.is_finished():
runner.step() runner.step()
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(trials[1].status, Trial.PENDING) self.assertLess(snapshot.max_running_trials(), 2)
runner.step() self.assertTrue(snapshot.all_trials_are_terminated())
self.assertEqual(trials[0].status, Trial.TERMINATED)
self.assertEqual(trials[1].status, Trial.PENDING)
def testExtraCustomResources(self): def testExtraCustomResources(self):
ray.init(num_cpus=4, num_gpus=2, resources={"a": 2}) ray.init(num_cpus=4, num_gpus=2, resources={"a": 2})
runner = TrialRunner() # Since each trial will occupy the full custom resources,
# there are at most 1 trial running at any given moment.
snapshot = TrialStatusSnapshot()
runner = TrialRunner(callbacks=[TrialStatusSnapshotTaker(snapshot)])
kwargs = { kwargs = {
"stopping_criterion": {"training_iteration": 1}, "stopping_criterion": {"training_iteration": 1},
"placement_group_factory": PlacementGroupFactory([{"CPU": 1}, {"a": 2}]), "placement_group_factory": PlacementGroupFactory([{"CPU": 1}, {"a": 2}]),
@ -122,14 +108,11 @@ class TrialRunnerTest(unittest.TestCase):
for t in trials: for t in trials:
runner.add_trial(t) runner.add_trial(t)
while not runner.is_finished():
runner.step() runner.step()
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(trials[1].status, Trial.PENDING)
runner.step() self.assertLess(snapshot.max_running_trials(), 2)
self.assertTrue(sum(t.status == Trial.RUNNING for t in trials) < 2) self.assertTrue(snapshot.all_trials_are_terminated())
self.assertEqual(trials[0].status, Trial.TERMINATED)
self.assertEqual(trials[1].status, Trial.PENDING)
def testFractionalGpus(self): def testFractionalGpus(self):
ray.init(num_cpus=4, num_gpus=1) ray.init(num_cpus=4, num_gpus=1)
@ -209,12 +192,7 @@ class TrialRunnerTest(unittest.TestCase):
for t in trials: for t in trials:
runner.add_trial(t) runner.add_trial(t)
runner.step() while not runner.is_finished():
self.assertEqual(trials[0].status, Trial.RUNNING)
runner.step()
self.assertEqual(trials[0].status, Trial.RUNNING)
runner.step() runner.step()
self.assertEqual(trials[0].status, Trial.TERMINATED) self.assertEqual(trials[0].status, Trial.TERMINATED)
self.assertRaises(TuneError, runner.step) self.assertRaises(TuneError, runner.step)
@ -224,15 +202,23 @@ class TrialRunnerTest(unittest.TestCase):
ray.init(num_cpus=2) ray.init(num_cpus=2)
class ChangingScheduler(FIFOScheduler): class ChangingScheduler(FIFOScheduler):
def __init__(self):
self._has_received_one_trial_result = False
# For figuring out how many runner.step there are.
def has_received_one_trial_result(self):
return self._has_received_one_trial_result
def on_trial_result(self, trial_runner, trial, result): def on_trial_result(self, trial_runner, trial, result):
if result["training_iteration"] == 1: if result["training_iteration"] == 1:
self._has_received_one_trial_result = True
executor = trial_runner.trial_executor executor = trial_runner.trial_executor
executor.stop_trial(trial) executor.pause_trial(trial)
trial.update_resources(dict(cpu=2, gpu=0)) trial.update_resources(dict(cpu=2, gpu=0))
executor.start_trial(trial) return TrialScheduler.NOOP
return TrialScheduler.CONTINUE
runner = TrialRunner(scheduler=ChangingScheduler()) scheduler = ChangingScheduler()
runner = TrialRunner(scheduler=scheduler)
kwargs = { kwargs = {
"stopping_criterion": {"training_iteration": 2}, "stopping_criterion": {"training_iteration": 2},
"resources": Resources(cpu=1, gpu=0), "resources": Resources(cpu=1, gpu=0),
@ -250,8 +236,11 @@ class TrialRunnerTest(unittest.TestCase):
ValueError, lambda: trials[0].update_resources(dict(cpu=2, gpu=0)) ValueError, lambda: trials[0].update_resources(dict(cpu=2, gpu=0))
) )
while not scheduler.has_received_one_trial_result():
runner.step()
self.assertEqual(trials[0].status, Trial.PAUSED)
# extra step for tune loop to stage the resource requests.
runner.step() runner.step()
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual( self.assertEqual(
runner.trial_executor._pg_manager.occupied_resources().get("CPU"), 2 runner.trial_executor._pg_manager.occupied_resources().get("CPU"), 2
) )


@@ -13,6 +13,7 @@ from ray.tune.trial import Trial
 from ray.tune.trial_runner import TrialRunner
 from ray.tune.resources import Resources
 from ray.tune.suggest import BasicVariantGenerator
+from ray.tune.tests.test_trial_runner_utils import TrialResultObserver


 def create_mock_components():
@@ -37,11 +38,6 @@ def create_mock_components():
 class TrialRunnerTest2(unittest.TestCase):
     def setUp(self):
         os.environ["TUNE_STATE_REFRESH_PERIOD"] = "0.1"
-        # Wait up to five seconds for placement groups when starting a trial
-        os.environ["TUNE_PLACEMENT_GROUP_WAIT_S"] = "5"
-        # Block for results even when placement groups are pending
-        os.environ["TUNE_TRIAL_STARTUP_GRACE_PERIOD"] = "0"
-        os.environ["TUNE_TRIAL_RESULT_WAIT_TIME_S"] = "99999"

     def tearDown(self):
         ray.shutdown()
@ -89,12 +85,9 @@ class TrialRunnerTest2(unittest.TestCase):
runner.add_trial(Trial("__fake", **kwargs)) runner.add_trial(Trial("__fake", **kwargs))
trials = runner.get_trials() trials = runner.get_trials()
runner.step() # Start trial while not runner.is_finished():
self.assertEqual(trials[0].status, Trial.RUNNING) runner.step()
runner.step() # Process result, dispatch save
self.assertEqual(trials[0].status, Trial.RUNNING)
runner.step() # Process save
runner.step() # Error
self.assertEqual(trials[0].status, Trial.ERROR) self.assertEqual(trials[0].status, Trial.ERROR)
self.assertEqual(trials[0].num_failures, 1) self.assertEqual(trials[0].num_failures, 1)
self.assertEqual(len(searchalg.errored_trials), 1) self.assertEqual(len(searchalg.errored_trials), 1)
@ -107,6 +100,7 @@ class TrialRunnerTest2(unittest.TestCase):
runner = TrialRunner(searchalg, scheduler=scheduler) runner = TrialRunner(searchalg, scheduler=scheduler)
kwargs = { kwargs = {
"stopping_criterion": {"training_iteration": 2},
"resources": Resources(cpu=1, gpu=1), "resources": Resources(cpu=1, gpu=1),
"checkpoint_freq": 1, "checkpoint_freq": 1,
"max_failures": 1, "max_failures": 1,
@ -117,18 +111,15 @@ class TrialRunnerTest2(unittest.TestCase):
runner.add_trial(Trial("__fake", **kwargs)) runner.add_trial(Trial("__fake", **kwargs))
trials = runner.get_trials() trials = runner.get_trials()
runner.step() # Start trial while not runner.is_finished():
self.assertEqual(trials[0].status, Trial.RUNNING) runner.step()
runner.step() # Process result, dispatch save self.assertEqual(trials[0].status, Trial.TERMINATED)
self.assertEqual(trials[0].status, Trial.RUNNING)
runner.step() # Process save
runner.step() # Error (transient), dispatch restore
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(trials[0].num_failures, 1) self.assertEqual(trials[0].num_failures, 1)
runner.step() # Process restore
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(len(searchalg.errored_trials), 0) self.assertEqual(len(searchalg.errored_trials), 0)
self.assertEqual(len(scheduler.errored_trials), 0) # Notice this is 1 since during recovery, the previously errored trial
# is "requeued". This will call scheduler.on_trial_error.
# Searcher.on_trial_error is, however, not called in this process.
self.assertEqual(len(scheduler.errored_trials), 1)
def testFailureRecoveryMaxFailures(self): def testFailureRecoveryMaxFailures(self):
ray.init(num_cpus=1, num_gpus=1) ray.init(num_cpus=1, num_gpus=1)
@ -145,20 +136,8 @@ class TrialRunnerTest2(unittest.TestCase):
runner.add_trial(Trial("__fake", **kwargs)) runner.add_trial(Trial("__fake", **kwargs))
trials = runner.get_trials() trials = runner.get_trials()
runner.step() # Start trial while not runner.is_finished():
self.assertEqual(trials[0].status, Trial.RUNNING) runner.step()
runner.step() # Process result, dispatch save
self.assertEqual(trials[0].status, Trial.RUNNING)
runner.step() # Process save
runner.step() # Error (transient), dispatch restore
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(trials[0].num_failures, 1)
runner.step() # Process restore
runner.step() # Error (transient), dispatch restore
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(trials[0].num_failures, 2)
runner.step() # Process restore
runner.step() # Error (terminal)
self.assertEqual(trials[0].status, Trial.ERROR) self.assertEqual(trials[0].status, Trial.ERROR)
self.assertEqual(trials[0].num_failures, 3) self.assertEqual(trials[0].num_failures, 3)
@ -178,13 +157,12 @@ class TrialRunnerTest2(unittest.TestCase):
runner.add_trial(Trial("__fake", **kwargs)) runner.add_trial(Trial("__fake", **kwargs))
trials = runner.get_trials() trials = runner.get_trials()
runner.step() # Start trial while not runner.is_finished():
self.assertEqual(trials[0].status, Trial.RUNNING) runner.step()
runner.step() # Process result, dispatch save
self.assertEqual(trials[0].status, Trial.RUNNING)
runner.step() # Process save
runner.step() # Error
self.assertEqual(trials[0].status, Trial.ERROR) self.assertEqual(trials[0].status, Trial.ERROR)
# Somehow with `fail_fast=True`, if one errors out, the others are
# then stopped with `TERMINATED` status.
self.assertEqual(trials[1].status, Trial.TERMINATED)
self.assertRaises(TuneError, lambda: runner.step()) self.assertRaises(TuneError, lambda: runner.step())
def testFailFastRaise(self): def testFailFastRaise(self):
@ -203,13 +181,14 @@ class TrialRunnerTest2(unittest.TestCase):
runner.add_trial(Trial("__fake", **kwargs)) runner.add_trial(Trial("__fake", **kwargs))
trials = runner.get_trials() trials = runner.get_trials()
runner.step() # Start trial
self.assertEqual(trials[0].status, Trial.RUNNING)
runner.step() # Process result, dispatch save
self.assertEqual(trials[0].status, Trial.RUNNING)
runner.step() # Process save
with self.assertRaises(Exception): with self.assertRaises(Exception):
runner.step() # Error while not runner.is_finished():
runner.step()
# Not critical checks. Only to showcase the difference
# with none raise type FailFast.
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(trials[1].status, Trial.PENDING)
def testCheckpointing(self): def testCheckpointing(self):
ray.init(num_cpus=1, num_gpus=1) ray.init(num_cpus=1, num_gpus=1)
@ -244,35 +223,38 @@ class TrialRunnerTest2(unittest.TestCase):
def testRestoreMetricsAfterCheckpointing(self): def testRestoreMetricsAfterCheckpointing(self):
ray.init(num_cpus=1, num_gpus=1) ray.init(num_cpus=1, num_gpus=1)
runner = TrialRunner()
observer = TrialResultObserver()
runner = TrialRunner(callbacks=[observer])
kwargs = { kwargs = {
"stopping_criterion": {"training_iteration": 2},
"resources": Resources(cpu=1, gpu=1), "resources": Resources(cpu=1, gpu=1),
"checkpoint_freq": 1, "checkpoint_freq": 1,
} }
runner.add_trial(Trial("__fake", **kwargs)) runner.add_trial(Trial("__fake", **kwargs))
trials = runner.get_trials() trials = runner.get_trials()
runner.step() # Start trial while not runner.is_finished():
self.assertEqual(trials[0].status, Trial.RUNNING) runner.step()
self.assertEqual(ray.get(trials[0].runner.set_info.remote(1)), 1)
runner.step() # Process result, dispatch save
runner.step() # Process save
runner.trial_executor.stop_trial(trials[0])
kwargs["restore_path"] = trials[0].checkpoint.value
self.assertEqual(trials[0].status, Trial.TERMINATED)
kwargs["restore_path"] = trials[0].checkpoint.value
kwargs.pop("stopping_criterion")
kwargs.pop("checkpoint_freq") # No checkpointing for next trial kwargs.pop("checkpoint_freq") # No checkpointing for next trial
runner.add_trial(Trial("__fake", **kwargs)) runner.add_trial(Trial("__fake", **kwargs))
trials = runner.get_trials() trials = runner.get_trials()
runner.step() # Start trial, dispatch restore observer.reset()
self.assertEqual(trials[0].status, Trial.TERMINATED) while not observer.just_received_a_result():
self.assertEqual(trials[1].status, Trial.RUNNING) runner.step()
runner.step() # Process restore
runner.step() # Process result
self.assertEqual(trials[1].last_result["timesteps_since_restore"], 10) self.assertEqual(trials[1].last_result["timesteps_since_restore"], 10)
self.assertEqual(trials[1].last_result["iterations_since_restore"], 1) self.assertEqual(trials[1].last_result["iterations_since_restore"], 1)
self.assertGreater(trials[1].last_result["time_since_restore"], 0) self.assertGreater(trials[1].last_result["time_since_restore"], 0)
runner.step() # Process restore
while not observer.just_received_a_result():
runner.step()
self.assertEqual(trials[1].last_result["timesteps_since_restore"], 20) self.assertEqual(trials[1].last_result["timesteps_since_restore"], 20)
self.assertEqual(trials[1].last_result["iterations_since_restore"], 2) self.assertEqual(trials[1].last_result["iterations_since_restore"], 2)
self.assertGreater(trials[1].last_result["time_since_restore"], 0) self.assertGreater(trials[1].last_result["time_since_restore"], 0)
@ -289,12 +271,9 @@ class TrialRunnerTest2(unittest.TestCase):
runner.add_trial(Trial("__fake", **kwargs)) runner.add_trial(Trial("__fake", **kwargs))
trials = runner.get_trials() trials = runner.get_trials()
runner.step() # Start trial while not runner.is_finished():
self.assertEqual(trials[0].status, Trial.RUNNING) runner.step()
runner.step() # Process result
runner.step() # Process result, dispatch save
self.assertEqual(trials[0].last_result[DONE], True) self.assertEqual(trials[0].last_result[DONE], True)
runner.step() # Process save
self.assertEqual(trials[0].has_checkpoint(), True) self.assertEqual(trials[0].has_checkpoint(), True)
def testResultDone(self): def testResultDone(self):
@ -308,10 +287,7 @@ class TrialRunnerTest2(unittest.TestCase):
runner.add_trial(Trial("__fake", **kwargs)) runner.add_trial(Trial("__fake", **kwargs))
trials = runner.get_trials() trials = runner.get_trials()
runner.step() while not runner.is_finished():
self.assertEqual(trials[0].status, Trial.RUNNING)
runner.step()
self.assertNotEqual(trials[0].last_result[DONE], True)
runner.step() runner.step()
self.assertEqual(trials[0].last_result[DONE], True) self.assertEqual(trials[0].last_result[DONE], True)


@@ -25,16 +25,11 @@ from ray.tune.suggest._mock import _MockSuggestionAlgorithm
 from ray.tune.suggest.suggestion import Searcher, ConcurrencyLimiter
 from ray.tune.suggest.search_generator import SearchGenerator
 from ray.tune.syncer import SyncConfig
+from ray.tune.tests.test_trial_runner_utils import TrialResultObserver


 class TrialRunnerTest3(unittest.TestCase):
     def setUp(self):
-        # Wait up to five seconds for placement groups when starting a trial
-        os.environ["TUNE_PLACEMENT_GROUP_WAIT_S"] = "5"
-        # Block for results even when placement groups are pending
-        os.environ["TUNE_TRIAL_STARTUP_GRACE_PERIOD"] = "0"
-        os.environ["TUNE_TRIAL_RESULT_WAIT_TIME_S"] = "99999"
         os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "auto"  # Reset default
         self.tmpdir = tempfile.mkdtemp()
@ -131,15 +126,9 @@ class TrialRunnerTest3(unittest.TestCase):
searcher = search_alg.searcher searcher = search_alg.searcher
search_alg.add_configurations(experiments) search_alg.add_configurations(experiments)
runner = TrialRunner(search_alg=search_alg) runner = TrialRunner(search_alg=search_alg)
runner.step()
trials = runner.get_trials()
self.assertEqual(trials[0].status, Trial.RUNNING)
while not runner.is_finished():
runner.step() runner.step()
self.assertEqual(trials[0].status, Trial.RUNNING)
runner.step()
self.assertEqual(trials[0].status, Trial.TERMINATED)
self.assertEqual(searcher.counter["result"], 1) self.assertEqual(searcher.counter["result"], 1)
self.assertEqual(searcher.counter["complete"], 1) self.assertEqual(searcher.counter["complete"], 1)
@ -204,18 +193,17 @@ class TrialRunnerTest3(unittest.TestCase):
runner = TrialRunner(search_alg=search_alg) runner = TrialRunner(search_alg=search_alg)
runner.step() runner.step()
trials = runner.get_trials() trials = runner.get_trials()
self.assertEqual(trials[0].status, Trial.RUNNING) while trials[0].status != Trial.TERMINATED:
runner.step()
runner.step() runner.step()
self.assertEqual(trials[0].status, Trial.TERMINATED)
trials = runner.get_trials() trials = runner.get_trials()
runner.step()
self.assertEqual(trials[1].status, Trial.RUNNING) self.assertEqual(trials[1].status, Trial.RUNNING)
self.assertEqual(len(searcher.live_trials), 1) self.assertEqual(len(searcher.live_trials), 1)
searcher.stall = True searcher.stall = True
while trials[1].status != Trial.TERMINATED:
runner.step() runner.step()
self.assertEqual(trials[1].status, Trial.TERMINATED) self.assertEqual(trials[1].status, Trial.TERMINATED)
self.assertEqual(len(searcher.live_trials), 0) self.assertEqual(len(searcher.live_trials), 0)
@ -231,8 +219,9 @@ class TrialRunnerTest3(unittest.TestCase):
self.assertEqual(trials[2].status, Trial.RUNNING) self.assertEqual(trials[2].status, Trial.RUNNING)
self.assertEqual(len(searcher.live_trials), 1) self.assertEqual(len(searcher.live_trials), 1)
while trials[2].status != Trial.TERMINATED:
runner.step() runner.step()
self.assertEqual(trials[2].status, Trial.TERMINATED)
self.assertEqual(len(searcher.live_trials), 0) self.assertEqual(len(searcher.live_trials), 0)
self.assertTrue(search_alg.is_finished()) self.assertTrue(search_alg.is_finished())
self.assertTrue(runner.is_finished()) self.assertTrue(runner.is_finished())
@ -445,9 +434,9 @@ class TrialRunnerTest3(unittest.TestCase):
) )
] ]
runner.add_trial(trials[0]) runner.add_trial(trials[0])
runner.step() # Start trial while not runner.is_finished():
runner.step() # Process result, dispatch save # Start trial, process result, dispatch save and process save.
runner.step() # Process save runner.step()
self.assertEqual(trials[0].status, Trial.TERMINATED) self.assertEqual(trials[0].status, Trial.TERMINATED)
trials += [ trials += [
@ -460,10 +449,13 @@ class TrialRunnerTest3(unittest.TestCase):
) )
] ]
runner.add_trial(trials[1]) runner.add_trial(trials[1])
runner.step() # Start trial while not runner.is_finished():
runner.step() # Process result, dispatch save # Start trial,
runner.step() # Process save # Process result,
runner.step() # Error # Dispatch save,
# Process save and
# Error.
runner.step()
self.assertEqual(trials[1].status, Trial.ERROR) self.assertEqual(trials[1].status, Trial.ERROR)
trials += [ trials += [
@ -488,12 +480,14 @@ class TrialRunnerTest3(unittest.TestCase):
restored_trial = runner2.get_trial("trial_succ") restored_trial = runner2.get_trial("trial_succ")
self.assertEqual(Trial.PENDING, restored_trial.status) self.assertEqual(Trial.PENDING, restored_trial.status)
runner2.step() # Start trial while not runner2.is_finished():
runner2.step() # Process result, dispatch save # Start trial,
runner2.step() # Process save # Process result, dispatch save
runner2.step() # Process result, dispatch save # Process save
runner2.step() # Process save # Process result, dispatch save
self.assertRaises(TuneError, runner2.step) # Process save.
runner2.step()
self.assertEqual(restored_trial.status, Trial.TERMINATED)
def testTrialNoCheckpointSave(self): def testTrialNoCheckpointSave(self):
"""Check that non-checkpointing trials *are* saved.""" """Check that non-checkpointing trials *are* saved."""
@ -533,7 +527,8 @@ class TrialRunnerTest3(unittest.TestCase):
) )
) )
runner.step() old_trials = runner.get_trials()
while not old_trials[2].has_reported_at_least_once:
runner.step() runner.step()
runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=self.tmpdir) runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=self.tmpdir)
@ -643,29 +638,42 @@ class TrialRunnerTest3(unittest.TestCase):
checkpoint_at_end=True, checkpoint_at_end=True,
stopping_criterion={"training_iteration": 4}, stopping_criterion={"training_iteration": 4},
) )
observer = TrialResultObserver()
runner = TrialRunner( runner = TrialRunner(
local_checkpoint_dir=self.tmpdir, local_checkpoint_dir=self.tmpdir,
checkpoint_period=0, checkpoint_period=0,
trial_executor=RayTrialExecutor(result_buffer_length=7), trial_executor=RayTrialExecutor(result_buffer_length=7),
callbacks=[observer],
) )
runner.add_trial(trial) runner.add_trial(trial)
runner.step() # start trial while not observer.just_received_a_result():
runner.step()
runner.step() # run iteration 1
self.assertEqual(trial.last_result[TRAINING_ITERATION], 1) self.assertEqual(trial.last_result[TRAINING_ITERATION], 1)
self.assertEqual(num_checkpoints(trial), 0) self.assertEqual(num_checkpoints(trial), 0)
runner.step() # run iteration 2 while True:
runner.step()
if observer.just_received_a_result():
break
self.assertEqual(trial.last_result[TRAINING_ITERATION], 2) self.assertEqual(trial.last_result[TRAINING_ITERATION], 2)
self.assertEqual(num_checkpoints(trial), 0) self.assertEqual(num_checkpoints(trial), 0)
runner.step() # run iteration 3 while True:
runner.step()
if observer.just_received_a_result():
break
self.assertEqual(trial.last_result[TRAINING_ITERATION], 3) self.assertEqual(trial.last_result[TRAINING_ITERATION], 3)
self.assertEqual(num_checkpoints(trial), 0) self.assertEqual(num_checkpoints(trial), 0)
runner.step() # run iteration 4 while True:
runner.step()
if observer.just_received_a_result():
break
self.assertEqual(trial.last_result[TRAINING_ITERATION], 4) self.assertEqual(trial.last_result[TRAINING_ITERATION], 4)
while not runner.is_finished():
runner.step()
self.assertEqual(num_checkpoints(trial), 1) self.assertEqual(num_checkpoints(trial), 1)
def testUserCheckpoint(self): def testUserCheckpoint(self):


@@ -9,11 +9,14 @@ from collections import OrderedDict
 import ray
 from ray import tune
-from ray.exceptions import RayActorError
 from ray.rllib import _register_all
 from ray.tune.checkpoint_manager import Checkpoint
 from ray.tune.logger import DEFAULT_LOGGERS, LoggerCallback, LegacyLoggerCallback
-from ray.tune.ray_trial_executor import RayTrialExecutor
+from ray.tune.ray_trial_executor import (
+    RayTrialExecutor,
+    ExecutorEvent,
+    ExecutorEventType,
+)
 from ray.tune.result import TRAINING_ITERATION
 from ray.tune.syncer import SyncConfig, SyncerCallback
@@ -65,28 +68,25 @@ class TestCallback(Callback):
         self.state["experiment_end"] = info


+# TODO(xwjiang): Move this to a testing util.
 class _MockTrialExecutor(RayTrialExecutor):
     def __init__(self):
         super().__init__()
-        self.next_trial = None
-        self.results = {}
-        self.should_fail_in_fetch_result = False
+        self.next_future_result = None

-    def fetch_result(self, trial):
-        if self.should_fail_in_fetch_result:
-            raise RayActorError(
-                "The actor died unexpectedly before finishing this task."
-            )
-        else:
-            return [self.results.get(trial, {})]
+    def start_trial(self, trial: Trial):
+        trial.status = Trial.RUNNING
+        return True

-    def get_next_available_trial(self, timeout=None):
-        return self.next_trial or super().get_next_available_trial()
+    def continue_training(self, trial: Trial):
+        pass
+
+    def get_next_executor_event(self, live_trials, next_trial_exists):
+        return self.next_future_result


 class TrialRunnerCallbacks(unittest.TestCase):
     def setUp(self):
-        os.environ["TUNE_PLACEMENT_GROUP_WAIT_S"] = "1"
         ray.init()
         self.tmpdir = tempfile.mkdtemp()
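With this mock in place, the callback tests further down no longer stub `fetch_result`; they hand the runner a pre-built `ExecutorEvent` and call `step()` once per event. A condensed sketch of that pattern, with `executor`, `trial_runner`, `trial`, and `result` standing in for the objects the test sets up (names and kwargs follow the hunks below):

from ray.tune.ray_trial_executor import ExecutorEvent, ExecutorEventType
from ray.tune.result import TRAINING_ITERATION

# Pretend the placement group for the next trial is ready:
executor.next_future_result = ExecutorEvent(event_type=ExecutorEventType.PG_READY)
trial_runner.step()  # the runner starts the next trial

# Pretend a training result arrived for a running trial:
result = {TRAINING_ITERATION: 1, "metric": 800, "done": False}
executor.next_future_result = ExecutorEvent(
    event_type=ExecutorEventType.TRAINING_RESULT, trial=trial, result=result
)
trial_runner.step()  # the runner dispatches the result to its handler

Each `step()` consumes exactly one event, which is what makes the callback iteration counters in these tests deterministic.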
@ -110,7 +110,9 @@ class TrialRunnerCallbacks(unittest.TestCase):
for t in trials: for t in trials:
self.trial_runner.add_trial(t) self.trial_runner.add_trial(t)
self.executor.next_trial = trials[0] self.executor.next_future_result = ExecutorEvent(
event_type=ExecutorEventType.PG_READY
)
self.trial_runner.step() self.trial_runner.step()
# Trial 1 has been started # Trial 1 has been started
@ -132,7 +134,9 @@ class TrialRunnerCallbacks(unittest.TestCase):
) )
) )
self.executor.next_trial = trials[1] self.executor.next_future_result = ExecutorEvent(
event_type=ExecutorEventType.PG_READY
)
self.trial_runner.step() self.trial_runner.step()
# Iteration not increased yet # Iteration not increased yet
@ -148,7 +152,9 @@ class TrialRunnerCallbacks(unittest.TestCase):
cp = Checkpoint(Checkpoint.PERSISTENT, "__checkpoint", {TRAINING_ITERATION: 0}) cp = Checkpoint(Checkpoint.PERSISTENT, "__checkpoint", {TRAINING_ITERATION: 0})
# Let the first trial save a checkpoint # Let the first trial save a checkpoint
self.executor.next_trial = trials[0] self.executor.next_future_result = ExecutorEvent(
event_type=ExecutorEventType.SAVING_RESULT, trial=trials[0]
)
trials[0].saving_to = cp trials[0].saving_to = cp
self.trial_runner.step() self.trial_runner.step()
self.assertEqual(self.callback.state["trial_save"]["iteration"], 2) self.assertEqual(self.callback.state["trial_save"]["iteration"], 2)
@ -156,8 +162,9 @@ class TrialRunnerCallbacks(unittest.TestCase):
# Let the second trial send a result # Let the second trial send a result
result = {TRAINING_ITERATION: 1, "metric": 800, "done": False} result = {TRAINING_ITERATION: 1, "metric": 800, "done": False}
self.executor.results[trials[1]] = result self.executor.next_future_result = ExecutorEvent(
self.executor.next_trial = trials[1] event_type=ExecutorEventType.TRAINING_RESULT, trial=trials[1], result=result
)
self.assertTrue(not trials[1].has_reported_at_least_once) self.assertTrue(not trials[1].has_reported_at_least_once)
self.trial_runner.step() self.trial_runner.step()
self.assertEqual(self.callback.state["trial_result"]["iteration"], 3) self.assertEqual(self.callback.state["trial_result"]["iteration"], 3)
@ -167,25 +174,28 @@ class TrialRunnerCallbacks(unittest.TestCase):
# Let the second trial restore from a checkpoint # Let the second trial restore from a checkpoint
trials[1].restoring_from = cp trials[1].restoring_from = cp
self.executor.results[trials[1]] = trials[1].last_result self.executor.next_future_result = ExecutorEvent(
event_type=ExecutorEventType.RESTORING_RESULT, trial=trials[1]
)
self.trial_runner.step() self.trial_runner.step()
self.assertEqual(self.callback.state["trial_restore"]["iteration"], 4) self.assertEqual(self.callback.state["trial_restore"]["iteration"], 4)
self.assertEqual(self.callback.state["trial_restore"]["trial"].trial_id, "two") self.assertEqual(self.callback.state["trial_restore"]["trial"].trial_id, "two")
# Let the second trial finish # Let the second trial finish
trials[1].restoring_from = None trials[1].restoring_from = None
self.executor.results[trials[1]] = { self.executor.next_future_result = ExecutorEvent(
TRAINING_ITERATION: 2, event_type=ExecutorEventType.TRAINING_RESULT,
"metric": 900, trial=trials[1],
"done": True, result={TRAINING_ITERATION: 2, "metric": 900, "done": True},
} )
self.trial_runner.step() self.trial_runner.step()
self.assertEqual(self.callback.state["trial_complete"]["iteration"], 5) self.assertEqual(self.callback.state["trial_complete"]["iteration"], 5)
self.assertEqual(self.callback.state["trial_complete"]["trial"].trial_id, "two") self.assertEqual(self.callback.state["trial_complete"]["trial"].trial_id, "two")
# Let the first trial error # Let the first trial error
self.executor.next_trial = trials[0] self.executor.next_future_result = ExecutorEvent(
self.executor.should_fail_in_fetch_result = True event_type=ExecutorEventType.ERROR, trial=trials[0]
)
self.trial_runner.step() self.trial_runner.step()
self.assertEqual(self.callback.state["trial_fail"]["iteration"], 6) self.assertEqual(self.callback.state["trial_fail"]["iteration"], 6)
self.assertEqual(self.callback.state["trial_fail"]["trial"].trial_id, "one") self.assertEqual(self.callback.state["trial_fail"]["trial"].trial_id, "one")


@@ -53,7 +53,6 @@ class TrialRunnerPlacementGroupTest(unittest.TestCase):
             self.assertFalse(pg_manager._staging[pgf])
         for pgf in pg_manager._ready:
             self.assertFalse(pg_manager._ready[pgf])
-        self.assertTrue(pg_manager._latest_staging_start_time)

         num_non_removed_pgs = len(
             [p for pid, p in placement_group_table().items() if p["state"] != "REMOVED"]
@ -109,17 +108,6 @@ class TrialRunnerPlacementGroupTest(unittest.TestCase):
total_num_tracked = num_staging + num_ready + num_in_use + num_cached total_num_tracked = num_staging + num_ready + num_in_use + num_cached
num_non_removed_pgs = len(
[
p
for pid, p in placement_group_table().items()
if p["state"] != "REMOVED"
]
)
num_removal_scheduled_pgs = len(
trial_executor._pg_manager._pgs_for_removal
)
# All trials should be scheduled # All trials should be scheduled
this.assertEqual( this.assertEqual(
scheduled, scheduled,
@ -141,13 +129,6 @@ class TrialRunnerPlacementGroupTest(unittest.TestCase):
msg=f"Num tracked iter {iteration}", msg=f"Num tracked iter {iteration}",
) )
# The number of actual placement groups should match this
this.assertGreaterEqual(
max(scheduled, len(trials)) - num_finished + num_parallel_reuse,
num_non_removed_pgs - num_removal_scheduled_pgs,
msg=f"Num actual iter {iteration}",
)
start = time.time() start = time.time()
out = tune.run( out = tune.run(
train, train,


@@ -0,0 +1,22 @@
from ray.tune import Callback
class TrialResultObserver(Callback):
"""Helper class to control runner.step() count."""
def __init__(self):
self._counter = 0
self._last_counter = 0
def reset(self):
self._last_counter = self._counter
def just_received_a_result(self):
if self._last_counter == self._counter:
return False
else:
self._last_counter = self._counter
return True
def on_trial_result(self, **kwargs):
self._counter += 1
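The runner tests use this observer to advance the loop until a fixed number of new results have been reported. Roughly, the pattern from the updated checkpoint/restore tests (with `runner` and `observer` constructed as shown there) is:

observer = TrialResultObserver()
runner = TrialRunner(callbacks=[observer])

# Drive the loop until exactly one new on_trial_result callback has fired,
# then make assertions about trial.last_result.
observer.reset()
while not observer.just_received_a_result():
    runner.step()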


@@ -173,7 +173,6 @@ class PopulationBasedTrainingFileDescriptorTest(unittest.TestCase):
 class PopulationBasedTrainingSynchTest(unittest.TestCase):
     def setUp(self):
-        os.environ["TUNE_TRIAL_STARTUP_GRACE_PERIOD"] = "0"
         ray.init(num_cpus=2)

         def MockTrainingFuncSync(config, checkpoint_dir=None):


@@ -105,13 +105,6 @@ def _run(local_dir, driver_semaphore, trainer_semaphore):
 class TuneInterruptionTest(unittest.TestCase):
-    def setUp(self) -> None:
-        # Wait up to five seconds for placement groups when starting a trial
-        os.environ["TUNE_PLACEMENT_GROUP_WAIT_S"] = "5"
-        # Block for results even when placement groups are pending
-        os.environ["TUNE_TRIAL_STARTUP_GRACE_PERIOD"] = "0"
-        os.environ["TUNE_TRIAL_RESULT_WAIT_TIME_S"] = "99999"
-
     def testExperimentInterrupted(self):
         local_dir = tempfile.mkdtemp()
         # Unix platforms may default to "fork", which is problematic with
@@ -214,11 +207,6 @@ class TuneFailResumeGridTest(unittest.TestCase):
     def setUp(self):
         self.logdir = tempfile.mkdtemp()
         os.environ["TUNE_GLOBAL_CHECKPOINT_S"] = "0"
-        # Wait up to 1.5 seconds for placement groups when starting a trial
-        os.environ["TUNE_PLACEMENT_GROUP_WAIT_S"] = "1.5"
-        # Block for results even when placement groups are pending
-        os.environ["TUNE_TRIAL_STARTUP_GRACE_PERIOD"] = "0"
-        os.environ["TUNE_TRIAL_RESULT_WAIT_TIME_S"] = "99999"

         # Change back to local_mode=True after this is resolved:
         # https://github.com/ray-project/ray/issues/13932


@@ -1,4 +1,3 @@
-import os
 import requests
 import socket
 import subprocess
@@ -28,11 +27,6 @@ def get_valid_port():
 class TuneServerSuite(unittest.TestCase):
     def basicSetup(self):
-        # Wait up to five seconds for placement groups when starting a trial
-        os.environ["TUNE_PLACEMENT_GROUP_WAIT_S"] = "5"
-        # Block for results even when placement groups are pending
-        os.environ["TUNE_TRIAL_STARTUP_GRACE_PERIOD"] = "0"
-        os.environ["TUNE_TRIAL_RESULT_WAIT_TIME_S"] = "99999"
         ray.init(num_cpus=4, num_gpus=1)

         port = get_valid_port()


@@ -4,7 +4,6 @@ import logging
 from typing import Dict, List, Optional
 import warnings

-from ray.tune.resources import Resources
 from ray.util.annotations import DeveloperAPI
 from ray.tune.trial import Trial, Checkpoint
@@ -79,11 +78,6 @@ class TrialExecutor(metaclass=_WarnOnDirectInheritanceMeta):
         self._trials_to_cache.clear()
         return self._cached_trial_state

-    @abstractmethod
-    def has_resources(self, resources: Resources) -> bool:
-        """Returns whether this runner has at least the specified resources."""
-        pass
-
     @abstractmethod
     def start_trial(self, trial: Trial) -> bool:
         """Starts the trial restoring from checkpoint if checkpoint is provided.
@@ -150,11 +144,6 @@ class TrialExecutor(metaclass=_WarnOnDirectInheritanceMeta):
         """
         pass

-    @abstractmethod
-    def get_running_trials(self) -> List[Trial]:
-        """Returns all running trials."""
-        pass
-
     def on_step_begin(self, trials: List[Trial]) -> None:
         """A hook called before running one step of the trial event loop.
@@ -176,26 +165,6 @@ class TrialExecutor(metaclass=_WarnOnDirectInheritanceMeta):
     def force_reconcilation_on_next_step_end(self) -> None:
         pass

-    @abstractmethod
-    def get_next_available_trial(self) -> Optional[Trial]:
-        """Blocking call that waits until one result is ready.
-
-        Returns:
-            Trial object that is ready for intermediate processing.
-        """
-        pass
-
-    @abstractmethod
-    def fetch_result(self, trial: Trial) -> List[Trial]:
-        """Fetches one result for the trial.
-
-        Assumes the trial is running.
-
-        Returns:
-            Result object for the trial.
-        """
-        pass
-
     @abstractmethod
     def debug_string(self) -> str:
         """Returns a human readable message for printing to the console."""
@@ -260,10 +229,6 @@ class TrialExecutor(metaclass=_WarnOnDirectInheritanceMeta):
         """
         pass

-    def in_staging_grace_period(self) -> bool:
-        """Returns True if trials have recently been staged."""
-        return False
-
     def set_max_pending_trials(self, max_pending: int) -> None:
         """Set the maximum number of allowed pending trials."""
         pass


@@ -15,7 +15,7 @@ from ray.tune import TuneError
 from ray.tune.callback import CallbackList
 from ray.tune.experiment import Experiment
 from ray.tune.insufficient_resources_manager import InsufficientResourcesManager
-from ray.tune.ray_trial_executor import RayTrialExecutor
+from ray.tune.ray_trial_executor import RayTrialExecutor, ExecutorEventType
 from ray.tune.result import (
     DEBUG_METRICS,
     DEFAULT_METRIC,
@@ -338,7 +338,6 @@ class TrialRunner:
         self._cached_trial_decisions = {}
         self._queued_trial_decisions = {}
         self._updated_queue = False
-        self._result_wait_time = int(os.getenv("TUNE_TRIAL_RESULT_WAIT_TIME_S", "1"))

         self._stop_queue = []
         self._should_stop_experiment = False  # used by TuneServer
@@ -685,23 +684,14 @@ class TrialRunner:
         ) and all(trial.is_finished() for trial in self._trials)
         return trials_done and self._search_alg.is_finished()

-    def step(self):
-        """Runs one step of the trial event loop.
+    def _update_trial_queue_and_get_next_trial(self) -> Optional[Trial]:
+        """Adding suggested trials to the live queue of trials (they start as PENDING trials).

-        Callers should typically run this method repeatedly in a loop. They
-        may inspect or modify the runner's state in between calls to step().
+        Returns:
+            next_trial: Trial
         """
         self._updated_queue = False

-        if self.is_finished():
-            raise TuneError("Called step when all trials finished?")
-
-        with warn_if_slow("on_step_begin"):
-            self.trial_executor.on_step_begin(self.get_trials())
-
-        with warn_if_slow("callbacks.on_step_begin"):
-            self._callbacks.on_step_begin(
-                iteration=self._iteration, trials=self._trials
-            )
-
         # This will contain the next trial to start
         next_trial = self._get_next_trial()  # blocking

         # Create pending trials. If the queue was updated before, only
@@ -715,46 +705,64 @@ class TrialRunner:
                 break
             num_pending_trials += 1

-        # Update status of staged placement groups
-        self.trial_executor.stage_and_update_status(self._live_trials)
-
-        def _start_trial(trial: Trial) -> bool:
-            """Helper function to start trial and call callbacks"""
-            with warn_if_slow("start_trial"):
-                if self.trial_executor.start_trial(trial):
-                    self._callbacks.on_trial_start(
-                        iteration=self._iteration, trials=self._trials, trial=trial
-                    )
-                    return True
-            return False
-
-        may_handle_events = True
-        if next_trial is not None:
-            if _start_trial(next_trial):
-                may_handle_events = False
-            elif next_trial.status != Trial.ERROR:
-                # Only try to start another trial if previous trial startup
-                # did not error (e.g. it just didn't start because its
-                # placement group is not ready, yet).
-                # Without this clause, this test fails:
-                # test_trial_runner_pg.py::
-                #     TrialRunnerPlacementGroupHeterogeneousTest::
-                #     testResourceDeadlock
-                next_trial = self.trial_executor.get_staged_trial()
-                if next_trial is not None:
-                    if _start_trial(next_trial):
-                        may_handle_events = False
-
-        if may_handle_events:
-            if self.trial_executor.get_running_trials():
-                timeout = self._result_wait_time
-                if self.trial_executor.in_staging_grace_period():
-                    timeout = 0.1
-                self._process_events(timeout=timeout)
-            else:
-                self._insufficient_resources_manager.on_no_available_trials(
-                    self.get_trials()
-                )
+        return next_trial
+
+    def _wait_and_handle_event(self, next_trial: Optional[Trial]):
+        try:
+            # Single wait of entire tune loop.
+            future_result = self.trial_executor.get_next_executor_event(
+                self._live_trials, next_trial is not None
+            )
+            if future_result.type == ExecutorEventType.PG_READY:
+                self._on_pg_ready(next_trial)
+            elif future_result.type == ExecutorEventType.NO_RUNNING_TRIAL_TIMEOUT:
+                self._insufficient_resources_manager.on_no_available_trials(
+                    self.get_trials()
+                )
+            elif future_result.type == ExecutorEventType.YIELD:
+                pass
+            else:
+                trial = future_result.trial
+                result = future_result.result
+                if future_result.type == ExecutorEventType.ERROR:
+                    self._on_executor_error(trial, result)
+                elif future_result.type == ExecutorEventType.RESTORING_RESULT:
+                    self._on_restoring_result(trial)
+                else:
+                    assert future_result.type in (
+                        ExecutorEventType.SAVING_RESULT,
+                        ExecutorEventType.TRAINING_RESULT,
+                    ), f"Unexpected future type - {future_result.type}"
+                    if future_result.type == ExecutorEventType.TRAINING_RESULT:
+                        self._on_training_result(trial, result)
+                    else:
+                        self._on_saving_result(trial, result)
+                    self._post_process_on_training_saving_result(trial)
+        except Exception as e:
+            if e is TuneError:
+                raise e
+            else:
+                raise TuneError(traceback.format_exc())
def step(self):
"""Runs one step of the trial event loop.
Callers should typically run this method repeatedly in a loop. They
may inspect or modify the runner's state in between calls to step().
"""
if self.is_finished():
raise TuneError("Called step when all trials finished?")
with warn_if_slow("on_step_begin"):
self.trial_executor.on_step_begin(self.get_trials())
with warn_if_slow("callbacks.on_step_begin"):
self._callbacks.on_step_begin(
iteration=self._iteration, trials=self._trials
)
next_trial = self._update_trial_queue_and_get_next_trial()
self._wait_and_handle_event(next_trial)
self._stop_experiment_if_needed() self._stop_experiment_if_needed()
@@ -770,12 +778,100 @@ class TrialRunner:
             if self.is_finished():
                 self._server.shutdown()

+        self._reconcile_live_trials()
+
         with warn_if_slow("on_step_end"):
             self.trial_executor.on_step_end(self.get_trials())
         with warn_if_slow("callbacks.on_step_end"):
             self._callbacks.on_step_end(iteration=self._iteration, trials=self._trials)

-        self._reconcile_live_trials()
-
+    def _on_pg_ready(self, next_trial: Optional[Trial]):
def _start_trial(trial: Trial) -> bool:
"""Helper function to start trial and call callbacks"""
with warn_if_slow("start_trial"):
if self.trial_executor.start_trial(trial):
self._callbacks.on_trial_start(
iteration=self._iteration, trials=self._trials, trial=trial
)
return True
return False
assert next_trial is not None
logger.info(f"starting {next_trial}")
if not _start_trial(next_trial) and next_trial.status != Trial.ERROR:
# Only try to start another trial if previous trial startup
# did not error (e.g. it just didn't start because its
# placement group is not ready, yet).
# Without this clause, this test fails:
# test_trial_runner_pg.py::
# TrialRunnerPlacementGroupHeterogeneousTest::
# testResourceDeadlock
next_trial = self.trial_executor.get_staged_trial()
if next_trial is not None:
# Must be able to start.
assert _start_trial(next_trial)
else:
logger.info(f"reconciling {self.get_trials()}")
self.trial_executor._pg_manager.reconcile_placement_groups(
self.get_trials()
)
def _on_saving_result(self, trial, result):
with warn_if_slow("process_trial_save") as _profile:
self._process_trial_save(trial, result)
with warn_if_slow("callbacks.on_trial_save"):
self._callbacks.on_trial_save(
iteration=self._iteration, trials=self._trials, trial=trial
)
if _profile.too_slow and trial.sync_on_checkpoint:
# TODO(ujvl): Suggest using cloud checkpointing once
# API has converged.
msg = (
"Consider turning off forced head-worker trial "
"checkpoint syncs by setting sync_on_checkpoint=False"
". Note that this may result in faulty trial "
"restoration if a failure occurs while the checkpoint "
"is being synced from the worker to the head node."
)
if trial.location.hostname and (
trial.location.hostname != get_node_ip_address()
):
if log_once("tune_head_worker_checkpoint"):
logger.warning(msg)
def _on_restoring_result(self, trial):
with warn_if_slow("process_trial_restore"):
self._process_trial_restore(trial)
with warn_if_slow("callbacks.on_trial_restore"):
self._callbacks.on_trial_restore(
iteration=self._iteration, trials=self._trials, trial=trial
)
def _on_training_result(self, trial, result):
if not isinstance(result, list):
result = [result]
with warn_if_slow("process_trial_result"):
self._process_trial_results(trial, result)
def _post_process_on_training_saving_result(self, trial):
# `self._queued_trial_decisions` now contains a final decision
# based on all results
if trial not in self._cached_trial_decisions:
final_decision = self._queued_trial_decisions.pop(trial.trial_id, None)
if final_decision:
self._execute_action(trial, final_decision)
def _on_executor_error(self, trial, result):
error_msg = f"Trial {trial}: Error processing event."
if self._fail_fast == TrialRunner.RAISE:
logger.error(error_msg)
raise
else:
logger.exception(error_msg)
self._process_trial_failure(trial, result)
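For callers, the contract is unchanged by the refactor: `step()` is still driven in a loop until the runner reports completion; each call now performs exactly one blocking wait on the executor and handles the single event it returns. A sketch of that driver loop, as the updated tests use it (`trials` construction elided):

runner = TrialRunner()
for trial in trials:
    runner.add_trial(trial)

# Each step() blocks once on get_next_executor_event and dispatches the
# resulting PG_READY / TRAINING_RESULT / SAVING_RESULT / ... event.
while not runner.is_finished():
    runner.step()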
def get_trial(self, tid): def get_trial(self, tid):
trial = [t for t in self._trials if t.trial_id == tid] trial = [t for t in self._trials if t.trial_id == tid]
@@ -857,75 +953,8 @@ class TrialRunner:
         logger.debug("Running trial {}".format(trial))
         return trial

+    def _process_trial_results(self, trial, results):
+        logger.debug(f"process_trial_results {results}")
-    def _process_events(self, timeout: Optional[float] = None):
-        # TODO(ujvl): Consider combining get_next_available_trial and
-        # fetch_result functionality so that we don't timeout on fetch.
-        trial = self.trial_executor.get_next_available_trial(
-            timeout=timeout
-        )  # blocking
-        if not trial:
-            return
if trial.is_restoring:
with warn_if_slow("process_trial_restore"):
self._process_trial_restore(trial)
with warn_if_slow("callbacks.on_trial_restore"):
self._callbacks.on_trial_restore(
iteration=self._iteration, trials=self._trials, trial=trial
)
elif trial.is_saving:
with warn_if_slow("process_trial_save") as _profile:
self._process_trial_save(trial)
with warn_if_slow("callbacks.on_trial_save"):
self._callbacks.on_trial_save(
iteration=self._iteration, trials=self._trials, trial=trial
)
if _profile.too_slow and trial.sync_on_checkpoint:
# TODO(ujvl): Suggest using cloud checkpointing once
# API has converged.
msg = (
"Consider turning off forced head-worker trial "
"checkpoint syncs by setting sync_on_checkpoint=False"
". Note that this may result in faulty trial "
"restoration if a failure occurs while the checkpoint "
"is being synced from the worker to the head node."
)
if trial.location.hostname and (
trial.location.hostname != get_node_ip_address()
):
if log_once("tune_head_worker_checkpoint"):
logger.warning(msg)
else:
with warn_if_slow("process_trial"):
self._process_trial(trial)
# `self._queued_trial_decisions` now contains a final decision
# based on all results
if trial not in self._cached_trial_decisions:
final_decision = self._queued_trial_decisions.pop(trial.trial_id, None)
if final_decision:
self._execute_action(trial, final_decision)
def _process_trial(self, trial):
"""Processes a trial result.
Fetches the trial's latest result and makes a scheduling decision
regarding its next action. If a checkpoint is taken, the decided
action is cached and acted on only after the checkpoint is later
processed (see `_process_trial_save`). Otherwise the decision is
acted on immediately.
If multiple results are received (e.g. because of buffering), all
results are processed and the final action is determined. STOP
takes precedence over PAUSE, which takes precedence over CONTINUE.
Args:
trial (Trial): Trial with a result ready to be processed.
"""
try:
results = self.trial_executor.fetch_result(trial)
with warn_if_slow( with warn_if_slow(
"process_trial_results", "process_trial_results",
message="Processing trial results took {duration:.3f} s, " message="Processing trial results took {duration:.3f} s, "
@@ -955,14 +984,6 @@ class TrialRunner:
                 # If the decision is to stop the trial,
                 # ignore all results that came after that.
                 break
-        except Exception:
-            error_msg = "Trial %s: Error processing event." % trial
-            if self._fail_fast == TrialRunner.RAISE:
-                logger.error(error_msg)
-                raise
-            else:
-                logger.exception(error_msg)
-                self._process_trial_failure(trial, traceback.format_exc())

     def _process_trial_result(self, trial, result):
         result.update(trial_id=trial.trial_id)
@@ -1014,6 +1035,7 @@ class TrialRunner:
         self._checkpoint_trial_if_needed(trial, force=force_checkpoint)

         if trial.is_saving:
+            logger.info(f"caching trial decision {trial}")
             # Cache decision to execute on after the save is processed.
             # This prevents changing the trial's state or kicking off
             # another training step prematurely.
@@ -1085,7 +1107,7 @@ class TrialRunner:
             )
         )

-    def _process_trial_save(self, trial):
+    def _process_trial_save(self, trial, result):
         """Processes a trial save.

         Acts on the decision cached during the last `_process_trial` call.
@@ -1094,19 +1116,9 @@ class TrialRunner:
             trial (Trial): Trial being saved.
         """
         logger.debug("Trial %s: Processing trial save.", trial)
-        checkpoint_value = None
-
-        try:
-            results = self.trial_executor.fetch_result(trial)
-            checkpoint_value = results[-1]
-        except Exception:
-            logger.exception("Trial %s: Error processing result.", trial)
-            if self._fail_fast == TrialRunner.RAISE:
-                raise
-            self._process_trial_failure(trial, traceback.format_exc())
-
-        if checkpoint_value:
         try:
-            trial.saving_to.value = checkpoint_value
+            trial.saving_to.value = result
             self._callbacks.on_checkpoint(
                 iteration=self._iteration,
                 trials=self._trials,
@@ -1117,15 +1129,13 @@ class TrialRunner:
             if trial.checkpoint.storage != Checkpoint.MEMORY:
                 self.trial_executor.mark_trial_to_checkpoint(trial)
         except Exception:
-            logger.exception(
-                "Trial %s: Error handling checkpoint %s", trial, checkpoint_value
-            )
+            logger.exception("Trial %s: Error handling checkpoint %s", trial, result)
             if self._fail_fast == TrialRunner.RAISE:
                 raise

         trial.saving_to = None
         decision = self._cached_trial_decisions.pop(trial.trial_id, None)
-        if decision and checkpoint_value:
+        if decision and result:
             self._queue_decision(trial, decision)
     def _process_trial_restore(self, trial):
@@ -1135,18 +1145,11 @@ class TrialRunner:
             trial (Trial): Trial being restored.
         """
         logger.debug("Trial %s: Processing trial restore.", trial)
-        try:
-            self.trial_executor.fetch_result(trial)
-            trial.on_restore()
-            logger.debug("Trial %s: Restore processed successfully", trial)
-            self.trial_executor.set_status(trial, Trial.RUNNING)
-            self.trial_executor.continue_training(trial)
-            self._live_trials.add(trial)
-        except Exception:
-            logger.exception("Trial %s: Error processing restore.", trial)
-            if self._fail_fast == TrialRunner.RAISE:
-                raise
-            self._process_trial_failure(trial, traceback.format_exc())
+        trial.on_restore()
+        logger.debug("Trial %s: Restore processed successfully", trial)
+        self.trial_executor.set_status(trial, Trial.RUNNING)
+        self.trial_executor.continue_training(trial)
+        self._live_trials.add(trial)
     def _process_trial_failure(self, trial, error_msg):
         """Handle trial failure.
@@ -1217,6 +1220,10 @@ class TrialRunner:
             trial (Trial): Trial to recover.
             error_msg (str): Error message from prior to invoking this method.
         """
+        self._cached_trial_decisions.pop(trial.trial_id, None)
+        # Resetting this, in case that the trial is in saving status when it crashes.
+        if trial.is_saving:
+            trial.saving_to = None
         if trial.is_restoring:
             # Restore was unsuccessful, try again without checkpoint.
             trial.clear_checkpoint()
@@ -1229,6 +1236,8 @@ class TrialRunner:
                 "Trial %s: Attempting to restore " "trial state from last checkpoint.",
                 trial,
             )
+            # TODO(xwjiang): For better consistency, consider not starting
+            # trials here. Instead rely on requeuing the trial.
             started = self.trial_executor.start_trial(trial)
             if not started:
                 requeue_trial = True


@@ -306,6 +306,16 @@ def run(
            "removing this argument from your call to `tune.run()`"
        )

+   # Starting deprecation in ray 1.10.
+   if os.environ.get("TUNE_TRIAL_RESULT_WAIT_TIME_S") is not None:
+       warnings.warn("`TUNE_TRIAL_RESULT_WAIT_TIME_S` is deprecated.")
+   if os.environ.get("TUNE_TRIAL_STARTUP_GRACE_PERIOD") is not None:
+       warnings.warn("`TUNE_TRIAL_STARTUP_GRACE_PERIOD` is deprecated.")
+   if os.environ.get("TUNE_PLACEMENT_GROUP_WAIT_S") is not None:
+       warnings.warn("`TUNE_PLACEMENT_GROUP_WAIT_S` is deprecated.")
+
    # NO CODE IS TO BE ADDED ABOVE THIS COMMENT
    # remote_run_kwargs must be defined before any other
    # code is ran to ensure that at this point,
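The three checks added above share a single pattern: warn if a removed timing knob is still set. Purely as an illustration of that pattern (not how `tune.py` is actually written), the same behavior could be expressed as a loop over the removed names:

    # Equivalent illustration of the deprecation checks above.
    import os
    import warnings

    _REMOVED_ENV_VARS = (
        "TUNE_TRIAL_RESULT_WAIT_TIME_S",
        "TUNE_TRIAL_STARTUP_GRACE_PERIOD",
        "TUNE_PLACEMENT_GROUP_WAIT_S",
    )

    for _var in _REMOVED_ENV_VARS:
        if os.environ.get(_var) is not None:
            warnings.warn(f"`{_var}` is deprecated.")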


@@ -316,16 +316,16 @@ class PlacementGroupManager:
        self._cached_pgs: Dict[PlacementGroup, PlacementGroupFactory] = {}

        # Placement groups scheduled for delayed removal.
+       # This is used as a damper to filter out some high frequency change
+       # in resources request.
+       # Only PGs that have never been used go here.
+       # TODO(xwjiang): `self._pgs_for_removal` and `self._unstaged_xxx`
+       # are really the same now. We should consolidate to using one.
+       # Also `remove_placement_group` method should just be combined with
+       # `unstage_unused_xxx`.
        self._pgs_for_removal: Dict[PlacementGroup, float] = {}
        self._removal_delay = TUNE_PLACEMENT_GROUP_REMOVAL_DELAY

-       # Latest PG staging time to check if still in grace period.
-       self._latest_staging_start_time = time.time()
-       # Seconds we wait for a trial to come up before we make blocking calls
-       # to process events
-       self._grace_period = float(os.getenv("TUNE_TRIAL_STARTUP_GRACE_PERIOD", 10.0))

        self._max_staging = max_staging

    def set_max_staging(self, max_staging: int):

@@ -440,7 +440,6 @@ class PlacementGroupManager:
        self._staging[pgf].add(pg)
        self._staging_futures[pg.ready()] = (pgf, pg)
-       self._latest_staging_start_time = time.time()

        return True

@@ -461,11 +460,17 @@ class PlacementGroupManager:
        ready, _ = ray.wait(list(self._staging_futures.keys()), timeout=0)

        for ready_fut in ready:
+           self.handle_ready_future(ready_fut)

+   def handle_ready_future(self, ready_fut):
        ready_pgf, ready_pg = self._staging_futures.pop(ready_fut)
        self._staging[ready_pgf].remove(ready_pg)
        self._ready[ready_pgf].add(ready_pg)

+   def get_staging_future_list(self):
+       return list(self._staging_futures.keys())

    def get_full_actor_cls(
        self, trial: "Trial", actor_cls: ActorClass
    ) -> Optional[ActorClass]:
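`get_staging_future_list` exposes the placement-group ready futures so a caller can include them in one `ray.wait` alongside its other futures, and `handle_ready_future` moves the corresponding PG from staging to ready once its future resolves. A rough sketch of such a caller follows; `training_futures` and the function itself are assumed placeholders, not existing Tune APIs:

    # Hedged sketch: one ray.wait over PG staging futures plus other futures.
    import ray

    def wait_for_next_event(pg_manager, training_futures, timeout_s=5.0):
        staging = pg_manager.get_staging_future_list()
        ready, _ = ray.wait(staging + list(training_futures),
                            num_returns=1, timeout=timeout_s)
        if not ready:
            return None  # timed out; let the caller yield back to its main loop
        [fut] = ready
        if fut in staging:
            pg_manager.handle_ready_future(fut)  # a placement group became ready
        return fut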
@@ -602,7 +607,7 @@ class PlacementGroupManager:
    def clean_cached_pg(self, pg: PlacementGroup):
        self._cached_pgs.pop(pg)

-   def return_pg(self, trial: "Trial"):
+   def remove_from_in_use(self, trial: "Trial") -> PlacementGroup:
        """Return pg back to Core scheduling.

        Args:

@@ -612,7 +617,7 @@ class PlacementGroupManager:
        pg = self._in_use_trials.pop(trial)
        self._in_use_pgs.pop(pg)

-       self.remove_pg(pg)
+       return pg

    def _unstage_unused_pg(
        self, pgf: PlacementGroupFactory
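Where `return_pg` both detached the placement group from the trial and removed it, the renamed `remove_from_in_use` only detaches it and hands the PG back to the caller, so the caller decides when to actually give it back to Ray. A hedged caller-side sketch, where the two callables are supplied by the executor and are assumptions, not existing APIs:

    # Hedged sketch of one possible caller-side cleanup ordering.
    # `stop_trial_fn` and `return_pg_fn` are assumed callables, not Tune APIs.
    def stop_and_release(trial, pg_manager, stop_trial_fn, return_pg_fn):
        pg = pg_manager.remove_from_in_use(trial)  # detach only; the PG is handed back
        stop_trial_fn(trial)                       # let the trial's actor stop cleanly
        return_pg_fn(pg)                           # then return the PG to Ray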
@@ -664,13 +669,6 @@ class PlacementGroupManager:
        return trial_pg

-   def in_staging_grace_period(self):
-       return (
-           self._staging_futures
-           and self._grace_period
-           and time.time() <= self._latest_staging_start_time + self._grace_period
-       )

    def reconcile_placement_groups(self, trials: List["Trial"]):
        """Reconcile placement groups to match requirements.