[tune] Single wait refactor. (#21852)

This is a down-scoped change. For the full picture of the Tune control loop, see [`Tune control loop refactoring`](https://docs.google.com/document/d/1RDsW7SVzwMPZfA0WLOPA4YTqbRyXIHGYmBenJk33HaE/edit#heading=h.2za3bbxbs5gn).

1. Previously there were separate waits on placement group readiness and on other events. As a result, there were quite a few timing tweaks that were inefficient, hard to understand, and hard to unit test. This PR consolidates them into a single wait that is handled by TrialRunner in each step.
- A few event types are introduced, along with their mapping to scenarios (see the sketch after this list):
  * PG_READY --> A placement group is ready and a trial should be placed onto it. If somehow there is no trial to place, the pg is put back into `_ready` momentarily. This is because, historically, resources are conceptualized as a pull-based model.
  * NO_RUNNING_TRIAL_TIMEOUT --> likely an insufficient-resources case.
  * TRAINING_RESULT
  * SAVING_RESULT
  * RESTORING_RESULT
  * YIELD --> Training is simply taking a long time. We punt back to the main loop to print out status info, etc.
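
To make the consolidated flow concrete, here is a minimal sketch of how a caller could dispatch on the single wait. `get_next_executor_event`, `ExecutorEvent`, and `ExecutorEventType` are the APIs introduced in this PR; the `runner_step` function and the `print` handlers are illustrative stand-ins, not the actual TrialRunner code (and `STOP_RESULT` never reaches the caller, since it is handled inside the executor):

```python
from typing import Optional, Set

from ray.tune.ray_trial_executor import (
    ExecutorEvent,
    ExecutorEventType,
    RayTrialExecutor,
)
from ray.tune.trial import Trial


def runner_step(
    executor: RayTrialExecutor,
    live_trials: Set[Trial],
    next_trial: Optional[Trial],
) -> None:
    """One illustrative control-loop step built around the single wait."""
    event: ExecutorEvent = executor.get_next_executor_event(
        live_trials=live_trials, next_trial_exists=next_trial is not None
    )
    if event.type == ExecutorEventType.PG_READY and next_trial is not None:
        # A placement group is ready: place the pending trial onto it.
        executor.start_trial(next_trial)
    elif event.type == ExecutorEventType.NO_RUNNING_TRIAL_TIMEOUT:
        # Nothing is running and the wait timed out: possibly insufficient resources.
        print("No running trials; the cluster may not have enough resources.")
    elif event.type == ExecutorEventType.YIELD:
        # Training is just taking a while: fall through so the caller can
        # print status info and call step() again.
        pass
    elif event.type in (
        ExecutorEventType.TRAINING_RESULT,
        ExecutorEventType.SAVING_RESULT,
        ExecutorEventType.RESTORING_RESULT,
    ):
        # event.trial / event.result carry the trial and the fetched future result.
        print(f"Got {event.type} for {event.trial}")
    elif event.type == ExecutorEventType.ERROR:
        # For ERROR events, event.result holds the traceback string.
        print(f"Trial {event.trial} errored:\n{event.result}")
```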

2. Previously, TrialCleanup was not very efficient and could race between `Trainable.stop()` and `return_placement_group`. This PR streamlines the trial cleanup process by explicitly letting `Trainable.stop()` finish, followed by `return_placement_group(pg)`. Note that graceful shutdown is needed in cases like `pause_trial`, where checkpointing to memory needs to be given time to happen before the actor is gone.
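
A toy sketch of that ordering (not the executor's actual code: `FakeTrainable`, `stop_then_return_pg`, and the resource numbers are made up, the actor/placement-group wiring is schematic, and the real executor tracks the stop future asynchronously instead of blocking on it):

```python
import ray
from ray.util.placement_group import placement_group, remove_placement_group


@ray.remote(num_cpus=1)
class FakeTrainable:
    def stop(self):
        # Graceful shutdown, e.g. a final in-memory checkpoint for `pause_trial`.
        return "stopped"


def stop_then_return_pg():
    """Minimal sketch: let stop() finish before returning the placement group."""
    pg = placement_group([{"CPU": 1}])
    ray.get(pg.ready())
    actor = FakeTrainable.options(placement_group=pg).remote()

    # 1) Ask the trainable to stop and hold on to the future.
    stop_future = actor.stop.remote()

    # 2) Only once stop() has finished (or a force-cleanup deadline has passed)
    #    is the placement group removed/returned, avoiding the old race.
    ray.get(stop_future)
    remove_placement_group(pg)


if __name__ == "__main__":
    ray.init(num_cpus=1)
    stop_then_return_pg()
    ray.shutdown()
```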

3. Quite a few environment variables (timing tweaks) are removed; I consider it OK to proceed without a deprecation cycle.
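
For completeness, a sketch of how the remaining knobs would be set. The values here are arbitrary examples; per the docs change below, `TUNE_GET_EXECUTOR_EVENT_WAIT_S` defaults to 5 seconds and `TUNE_FORCE_TRIAL_CLEANUP_S` defaults to 0 (disabled), while the removed variables simply have no effect anymore:

```python
import os

from ray import tune

# The two knobs that remain relevant after this change (example values only):
# TUNE_GET_EXECUTOR_EVENT_WAIT_S is the new single-wait timeout (default 5 s),
# TUNE_FORCE_TRIAL_CLEANUP_S is the grace period for forceful actor termination
# (default 0, i.e. disabled).
os.environ["TUNE_GET_EXECUTOR_EVENT_WAIT_S"] = "5"
os.environ["TUNE_FORCE_TRIAL_CLEANUP_S"] = "10"

# The removed timing tweaks no longer have any effect and can simply be dropped:
for removed in (
    "TUNE_PLACEMENT_GROUP_WAIT_S",
    "TUNE_TRIAL_RESULT_WAIT_TIME_S",
    "TUNE_TRIAL_STARTUP_GRACE_PERIOD",
):
    os.environ.pop(removed, None)


def trainable(config):
    tune.report(metric=1)


if __name__ == "__main__":
    tune.run(trainable, num_samples=1)
```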
xwjiang2010 2022-02-09 07:31:17 -08:00 committed by GitHub
parent dea3574050
commit 323511b716
20 changed files with 825 additions and 802 deletions


@ -36,6 +36,8 @@ These are the environment variables Ray Tune currently considers:
letting them finish the current training step and any user-defined cleanup.
Setting this variable to a non-zero, positive integer will cause trials to be forcefully
terminated after a grace period of that many seconds. Defaults to ``0``.
* **TUNE_GET_EXECUTOR_EVENT_WAIT_S**: The time that TrialRunner waits for the
next ExecutorEvent in a blocking fashion. Defaults to ``5``.
* **TUNE_FUNCTION_THREAD_TIMEOUT_S**: Time in seconds the function API waits
for threads to finish after instructing them to complete. Defaults to ``2``.
* **TUNE_GLOBAL_CHECKPOINT_S**: Time in seconds that limits how often Tune's
@ -57,10 +59,6 @@ These are the environment variables Ray Tune currently considers:
In normal circumstances these shouldn't differ anyway, but reconcilation makes sure to capture cases when
placement groups are manually destroyed. Reconcilation doesn't take much time, but it can add up when
running a large number of short trials. Defaults to every ``5`` (seconds).
* **TUNE_PLACEMENT_GROUP_WAIT_S**: Default time the trial executor waits for placement
groups to be placed before continuing the tuning loop. Setting this to a float
will block for that many seconds. This is mostly used for testing purposes. Defaults
to -1, which disables blocking.
* **TUNE_RESULT_DIR**: Directory where Ray Tune trial results are stored. If this
is not set, ``~/ray_results`` will be used.
* **TUNE_RESULT_BUFFER_LENGTH**: Ray Tune can buffer results from trainables before they are passed
@ -74,11 +72,6 @@ These are the environment variables Ray Tune currently considers:
but never longer than this value. Defaults to 100 (seconds).
* **TUNE_RESULT_BUFFER_MIN_TIME_S**: Additionally, you can specify a minimum time to buffer results. Defaults to 0.
* **TUNE_SYNCER_VERBOSITY**: Amount of command output when using Tune with Docker Syncer. Defaults to 0.
* **TUNE_TRIAL_RESULT_WAIT_TIME_S**: Amount of time Ray Tune will block until a result from a running trial is received.
Defaults to 1 (second).
* **TUNE_TRIAL_STARTUP_GRACE_PERIOD**: Amount of time after starting a trial that Ray Tune checks for successful
trial startups. After the grace period, Tune will block for up to ``TUNE_TRIAL_RESULT_WAIT_TIME_S`` seconds
until a result from a running trial is received. Can be disabled by setting this to lower or equal to 0.
* **TUNE_WARN_THRESHOLD_S**: Threshold for logging if an Tune event loop operation takes too long. Defaults to 0.5 (seconds).
* **TUNE_WARN_INSUFFICENT_RESOURCE_THRESHOLD_S**: Threshold for throwing a warning if no active trials are in ``RUNNING`` state
for this amount of seconds. If the Ray Tune job is stuck in this state (most likely due to insufficient resources),


@ -1,11 +1,12 @@
# coding: utf-8
import copy
import inspect
import random
from collections import deque
from enum import Enum
from functools import partial
import logging
import os
import random
import time
import traceback
from contextlib import contextmanager
@ -15,14 +16,15 @@ from typing import (
Iterable,
List,
Optional,
Union,
Set,
)
import ray
from ray.actor import ActorHandle
from ray.exceptions import GetTimeoutError
from ray import ray_constants
from ray._private.resource_spec import NODE_ID_PREFIX
from ray.tune.error import AbortTrialExecution
from ray.tune.error import AbortTrialExecution, TuneError
from ray.tune.logger import NoopLogger
from ray.tune.result import TRIAL_INFO, STDOUT_FILE, STDERR_FILE
from ray.tune.resources import Resources
@ -33,14 +35,12 @@ from ray.tune.trial_executor import TrialExecutor
from ray.tune.utils import warn_if_slow
from ray.util import log_once
from ray.util.annotations import DeveloperAPI
from ray.util.placement_group import remove_placement_group, PlacementGroup
logger = logging.getLogger(__name__)
TUNE_STATE_REFRESH_PERIOD = 10 # Refresh resources every 10 s
BOTTLENECK_WARN_PERIOD_S = 60
NONTRIVIAL_WAIT_TIME_THRESHOLD_S = 1e-3
DEFAULT_GET_TIMEOUT = 60.0 # seconds
TRIAL_CLEANUP_THRESHOLD = 100
class _ActorClassCache:
@ -86,75 +86,56 @@ class _LocalWrapper:
return self._result
def post_stop_cleanup(future, pg):
"""Things to be done after a trial is stopped."""
assert isinstance(pg, PlacementGroup)
try:
# This should not be blocking as
# we are only here when triggered.
ray.get(future, timeout=0)
except GetTimeoutError:
if log_once("tune_trial_cleanup_timeout"):
logger.error(
"Timed out when trying to stop the Ray actor gracefully. "
"Consider making `stop` a faster operation."
)
except Exception:
if log_once("tune_trial_cleanup_exception"):
logger.error(
f"An exception occurred when trying to stop the Ray actor:"
f"{traceback.format_exc()}"
)
finally:
remove_placement_group(pg)
class _TrialCleanup:
"""Mechanism for ensuring trial stop futures are cleaned up.
"""Responsible for triggering force cleanup of remote actors,
without waiting for `Trainable.stop()` to finish.
Args:
threshold (int): Number of futures to hold at once. If the threshold
is passed, cleanup will kick in and remove futures.
force_cleanup (int): Grace periods for forceful actor termination.
If 0, actors will not be forcefully terminated.
Only instantiated when `TUNE_FORCE_TRIAL_CLEANUP_S` is set up.
"""
def __init__(
self, threshold: int = TRIAL_CLEANUP_THRESHOLD, force_cleanup: int = 0
):
self.threshold = threshold
self._cleanup_map = {}
if force_cleanup < 0:
force_cleanup = 0
def __init__(self, force_cleanup):
assert force_cleanup
self._force_cleanup = force_cleanup
self._future_to_insert_time = deque()
def add(self, trial: Trial, actor: ActorHandle):
"""Adds a trial actor to be stopped.
def add(self, future):
self._future_to_insert_time.append((future, time.time()))
If the number of futures exceeds the threshold, the cleanup mechanism
will kick in.
Args:
trial (Trial): The trial corresponding to the future.
actor (ActorHandle): Handle to the trainable to be stopped.
"""
future = actor.stop.remote()
del actor
self._cleanup_map[future] = trial
if len(self._cleanup_map) > self.threshold:
self.cleanup(partial=True)
def cleanup(self, partial: bool = True):
"""Waits for cleanup to finish.
If partial=False, all futures are expected to return. If a future
does not return within the timeout period, the cleanup terminates.
"""
# At this point, self._cleanup_map holds the last references
# to actors. Removing those references either one-by-one
# (graceful termination case) or all at once, by reinstantiating
# self._cleanup_map (forceful termination case) will cause Ray
# to kill the actors during garbage collection.
logger.debug("Cleaning up futures")
num_to_keep = int(self.threshold) / 2 if partial else 0
while len(self._cleanup_map) > num_to_keep:
dones, _ = ray.wait(
list(self._cleanup_map),
timeout=DEFAULT_GET_TIMEOUT
if not self._force_cleanup
else self._force_cleanup,
)
if not dones:
logger.warning(
"Skipping cleanup - trainable.stop did not return in "
"time. Consider making `stop` a faster operation."
)
if not partial and self._force_cleanup:
logger.warning("Forcing trainable cleanup by terminating actors.")
self._cleanup_map = {}
return
def get_next(self):
"""Get the next future that is eligible to be cleaned up forcibly."""
if (
len(self._future_to_insert_time) > 0
and self._future_to_insert_time[0][1] + self._force_cleanup < time.time()
):
return self._future_to_insert_time.popleft()
else:
done = dones[0]
del self._cleanup_map[done]
return None
def is_empty(self):
return len(self._future_to_insert_time) == 0
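For reference, a self-contained toy version of the deadline bookkeeping used by the new `_TrialCleanup` above. The class name `DeadlineQueue` and the string futures are made up for illustration; in the executor, futures returned by `get_next` are passed to `post_stop_cleanup` together with their placement groups:
```python
import time
from collections import deque
from typing import Any, Optional, Tuple


class DeadlineQueue:
    """Toy sketch of the _TrialCleanup bookkeeping: a future becomes eligible
    for forced cleanup once it has been pending longer than `force_cleanup`
    seconds."""

    def __init__(self, force_cleanup: float):
        assert force_cleanup > 0
        self._force_cleanup = force_cleanup
        self._future_to_insert_time: deque = deque()

    def add(self, future: Any) -> None:
        self._future_to_insert_time.append((future, time.time()))

    def get_next(self) -> Optional[Tuple[Any, float]]:
        # FIFO order: only the oldest entry can be past its deadline.
        if (
            self._future_to_insert_time
            and self._future_to_insert_time[0][1] + self._force_cleanup < time.time()
        ):
            return self._future_to_insert_time.popleft()
        return None

    def is_empty(self) -> bool:
        return len(self._future_to_insert_time) == 0


if __name__ == "__main__":
    q = DeadlineQueue(force_cleanup=0.1)
    q.add("stop_future_1")
    assert q.get_next() is None  # deadline not reached yet
    time.sleep(0.2)
    print(q.get_next())  # ('stop_future_1', <insert time>)
```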
def noop_logger_creator(config, logdir):
@ -165,6 +146,39 @@ def noop_logger_creator(config, logdir):
return NoopLogger(config, logdir)
class ExecutorEventType(Enum):
"""The executor event type.
Some of the events are internal to the executor while others
are handled by the runner."""
NO_RUNNING_TRIAL_TIMEOUT = 1
PG_READY = 2
TRAINING_RESULT = 3
SAVING_RESULT = 4
RESTORING_RESULT = 5
STOP_RESULT = 6 # Internally to executor only.
ERROR = 7 # This is to signal to TrialRunner that there is an error.
YIELD = 8 # Yielding back to TrialRunner's main event loop.
class ExecutorEvent:
"""A struct that describes the event to be processed by TrialRunner."""
def __init__(
self,
event_type: ExecutorEventType,
trial: Optional[Trial] = None,
result: Optional[Union[str, Dict]] = None,
):
self.type = event_type
self.trial = trial
self.result = result
def __repr__(self):
return f"[{self.type}] for {self.trial}"
@DeveloperAPI
class RayTrialExecutor(TrialExecutor):
"""An implementation of TrialExecutor based on Ray."""
@ -177,10 +191,17 @@ class RayTrialExecutor(TrialExecutor):
wait_for_placement_group: Optional[float] = None,
):
super(RayTrialExecutor, self).__init__()
self._running = {}
# future --> (type, trial/pg)
self._futures = {}
force_trial_cleanup = int(os.environ.get("TUNE_FORCE_TRIAL_CLEANUP_S", "0"))
self._trial_cleanup = _TrialCleanup(force_cleanup=force_trial_cleanup)
self._get_next_event_wait = int(
os.environ.get("TUNE_GET_EXECUTOR_EVENT_WAIT_S", "5")
)
if force_trial_cleanup:
self._trial_cleanup = _TrialCleanup(force_trial_cleanup)
else:
self._trial_cleanup = None
self._has_cleaned_up_pgs = False
self._reuse_actors = reuse_actors
# The maxlen will be updated when `set_max_pending_trials()` is called
@ -189,7 +210,6 @@ class RayTrialExecutor(TrialExecutor):
self._avail_resources = Resources(cpu=0, gpu=0)
self._pg_manager = PlacementGroupManager(prefix=get_tune_pg_prefix())
self._staged_trials = set()
self._just_staged_trials = set()
self._trial_just_finished = False
self._trial_just_finished_before = False
@ -201,12 +221,6 @@ class RayTrialExecutor(TrialExecutor):
)
self._refresh_period = refresh_period
self._wait_for_pg = wait_for_placement_group or float(
os.environ.get("TUNE_PLACEMENT_GROUP_WAIT_S", "-1")
)
if self._wait_for_pg < 0:
self._wait_for_pg = None
self.last_pg_recon = 0
self.pg_recon_interval = float(
os.environ.get("TUNE_PLACEMENT_GROUP_RECON_INTERVAL", "5")
@ -229,10 +243,6 @@ class RayTrialExecutor(TrialExecutor):
if ray.is_initialized():
self._update_avail_resources()
def in_staging_grace_period(self) -> bool:
"""Returns True if trials have recently been staged."""
return self._pg_manager.in_staging_grace_period()
def set_max_pending_trials(self, max_pending: int) -> None:
if len(self._cached_actor_pg) > 0:
logger.warning(
@ -243,7 +253,7 @@ class RayTrialExecutor(TrialExecutor):
self._cached_actor_pg = deque(maxlen=max_pending)
self._pg_manager.set_max_staging(max_pending)
def stage_and_update_status(self, trials: Iterable[Trial]):
def _stage_and_update_status(self, trials: Iterable[Trial]):
"""Check and update statuses of scheduled placement groups.
Stages placement groups of all trials.
@ -255,7 +265,7 @@ class RayTrialExecutor(TrialExecutor):
self._has_cleaned_up_pgs = True
for trial in trials:
if trial.status != Trial.PENDING:
if trial.status not in (Trial.PENDING, Trial.PAUSED):
continue
if trial in self._staged_trials:
continue
@ -266,7 +276,6 @@ class RayTrialExecutor(TrialExecutor):
# Break if we reached the limit of pending placement groups.
break
self._staged_trials.add(trial)
self._just_staged_trials.add(trial)
self._pg_manager.update_status()
@ -279,6 +288,7 @@ class RayTrialExecutor(TrialExecutor):
Trial object or None.
"""
# TODO(xwjiang): This method should consider `self._cached_actor_pg`.
for trial in self._staged_trials:
if self._pg_manager.has_ready(trial):
return trial
@ -317,43 +327,7 @@ class RayTrialExecutor(TrialExecutor):
)
_actor_cls = _class_cache.get(trainable_cls)
if not self._pg_manager.has_ready(trial, update=True):
if trial not in self._staged_trials:
if self._pg_manager.stage_trial_pg(trial):
self._staged_trials.add(trial)
self._just_staged_trials.add(trial)
just_staged = trial in self._just_staged_trials
# This part of the code is mostly here for testing
# purposes. If self._wait_for_pg is set, we will wait here
# for that many seconds until the placement group is ready.
# This ensures that the trial can be started right away and
# not just in the next step() of the trial runner.
# We only do this if we have reason to believe that resources
# will be ready, soon, i.e. when a) we just staged the PG,
# b) another trial just exited, freeing resources, or c)
# when there are no currently running trials.
if self._wait_for_pg is not None and (
just_staged
or self._trial_just_finished_before
or not self.get_running_trials()
):
logger.debug(
f"Waiting up to {self._wait_for_pg} seconds for "
f"placement group of trial {trial} to become ready."
)
wait_end = time.monotonic() + self._wait_for_pg
while time.monotonic() < wait_end:
self._pg_manager.update_status()
if self._pg_manager.has_ready(trial):
break
time.sleep(0.1)
else:
return None
if not self._pg_manager.has_ready(trial):
# PG may have become ready during waiting period
return None
full_actor_class = self._pg_manager.get_full_actor_cls(trial, _actor_cls)
@ -403,7 +377,7 @@ class RayTrialExecutor(TrialExecutor):
def _train(self, trial):
"""Start one iteration of training and save remote id."""
if self._find_item(self._running, trial):
if self._find_future(trial):
logging.debug(
"Trial {} already has a queued future. Skipping this "
"`train` call. This may occur if a trial has "
@ -414,7 +388,7 @@ class RayTrialExecutor(TrialExecutor):
assert trial.status == Trial.RUNNING, trial.status
buffer_time_s = max(
self._buffer_min_time_s,
min(self._buffer_max_time_s, len(self._running) // 10),
min(self._buffer_max_time_s, len(self._futures) // 10),
)
with self._change_working_directory(trial):
buffer_length = self._buffer_length
@ -441,8 +415,8 @@ class RayTrialExecutor(TrialExecutor):
if isinstance(remote, dict):
remote = _LocalWrapper(remote)
self._running[remote] = trial
trial_item = self._find_item(self._running, trial)
self._futures[remote] = (ExecutorEventType.TRAINING_RESULT, trial)
trial_item = self._find_future(trial)
assert len(trial_item) < 2, trial_item
def _start_trial(self, trial) -> bool:
@ -540,11 +514,13 @@ class RayTrialExecutor(TrialExecutor):
if should_destroy_actor:
logger.debug("Trial %s: Destroying actor.", trial)
# Try to return the placement group for other trials to use
self._pg_manager.return_pg(trial)
with self._change_working_directory(trial):
self._trial_cleanup.add(trial, actor=trial.runner)
future = trial.runner.stop.remote()
pg = self._pg_manager.remove_from_in_use(trial)
self._futures[future] = (ExecutorEventType.STOP_RESULT, pg)
if self._trial_cleanup: # force trial cleanup within a deadline
self._trial_cleanup.add(future)
if trial in self._staged_trials:
self._staged_trials.remove(trial)
@ -584,8 +560,8 @@ class RayTrialExecutor(TrialExecutor):
# have been lost. TODO(ujvl): is this the right thing to do?
return False
def _find_item(self, dictionary, item):
out = [rid for rid, t in dictionary.items() if t is item]
def _find_future(self, trial):
out = [rid for rid, t in self._futures.items() if t[1] is trial]
assert (
len(out) <= 1
), "Expecting one future for any given trial at any given time."
@ -598,9 +574,9 @@ class RayTrialExecutor(TrialExecutor):
self._stop_trial(trial, error=error, error_msg=error_msg)
if prior_status == Trial.RUNNING:
logger.debug("Trial %s: Returning resources.", trial)
out = self._find_item(self._running, trial)
out = self._find_future(trial)
for result_id in out:
self._running.pop(result_id)
self._futures.pop(result_id)
def continue_training(self, trial: Trial) -> None:
"""Continues the training of this trial."""
@ -649,63 +625,6 @@ class RayTrialExecutor(TrialExecutor):
return False
return reset_val
def get_running_trials(self) -> List[Trial]:
"""Returns the running trials."""
return list(self._running.values())
def get_next_available_trial(
self, timeout: Optional[float] = None
) -> Optional[Trial]:
if not self._running:
return None
shuffled_results = list(self._running.keys())
random.shuffle(shuffled_results)
# Note: We shuffle the results because `ray.wait` by default returns
# the first available result, and we want to guarantee that slower
# trials (i.e. trials that run remotely) also get fairly reported.
# See https://github.com/ray-project/ray/issues/4211 for details.
start = time.time()
ready, _ = ray.wait(shuffled_results, timeout=timeout)
if not ready:
return None
result_id = ready[0]
wait_time = time.time() - start
if wait_time > NONTRIVIAL_WAIT_TIME_THRESHOLD_S:
self._last_nontrivial_wait = time.time()
if time.time() - self._last_nontrivial_wait > BOTTLENECK_WARN_PERIOD_S:
logger.warning(
"Over the last {} seconds, the Tune event loop has been "
"backlogged processing new results. Consider increasing your "
"period of result reporting to improve performance.".format(
BOTTLENECK_WARN_PERIOD_S
)
)
self._last_nontrivial_wait = time.time()
return self._running[result_id]
def fetch_result(self, trial) -> List[Dict]:
"""Fetches result list of the running trials.
Returns:
Result of the most recent trial training run.
"""
trial_future = self._find_item(self._running, trial)
if not trial_future:
raise ValueError("Trial was not running.")
self._running.pop(trial_future[0])
with warn_if_slow("fetch_result"):
result = ray.get(trial_future[0], timeout=DEFAULT_GET_TIMEOUT)
# For local mode
if isinstance(result, _LocalWrapper):
result = result.unwrap()
if not isinstance(result, list):
return [result]
return result
def _update_avail_resources(self, num_retries=5):
if time.time() - self._last_resource_refresh < self._refresh_period:
return
@ -814,8 +733,7 @@ class RayTrialExecutor(TrialExecutor):
self._trial_just_finished = False
def on_step_end(self, trials: List[Trial]) -> None:
self._just_staged_trials.clear()
self._do_force_trial_cleanup()
if time.time() > self.last_pg_recon + self.pg_recon_interval:
# Only do this every now and then - usually the placement groups
# should not get out of sync, and calling this often is inefficient
@ -824,6 +742,20 @@ class RayTrialExecutor(TrialExecutor):
self._pg_manager.cleanup()
def _do_force_trial_cleanup(self) -> None:
if self._trial_cleanup:
while True:
next_future_to_clean = self._trial_cleanup.get_next()
if not next_future_to_clean:
break
if next_future_to_clean in self._futures.keys():
_, pg = self._futures.pop(next_future_to_clean)
post_stop_cleanup(next_future_to_clean, pg)
else:
# This just means that the future was already cleaned up
# before the deadline was reached.
pass
def force_reconcilation_on_next_step_end(self) -> None:
self.last_pg_recon = -float("inf")
@ -842,6 +774,7 @@ class RayTrialExecutor(TrialExecutor):
Returns:
Checkpoint object, or None if an Exception occurs.
"""
logger.info(f"saving trial {trial}")
result = result or trial.last_result
with self._change_working_directory(trial):
if storage == Checkpoint.MEMORY:
@ -852,7 +785,7 @@ class RayTrialExecutor(TrialExecutor):
value = trial.runner.save.remote()
checkpoint = Checkpoint(storage, value, result)
trial.saving_to = checkpoint
self._running[value] = trial
self._futures[value] = (ExecutorEventType.SAVING_RESULT, trial)
return checkpoint
def restore(self, trial) -> None:
@ -899,7 +832,7 @@ class RayTrialExecutor(TrialExecutor):
"storage-based restoration"
)
self._running[remote] = trial
self._futures[remote] = (ExecutorEventType.RESTORING_RESULT, trial)
trial.restoring_from = checkpoint
def export_trial_if_needed(self, trial: Trial) -> Dict:
@ -922,7 +855,19 @@ class RayTrialExecutor(TrialExecutor):
return self._avail_resources.gpu > 0
def cleanup(self, trials: List[Trial]) -> None:
self._trial_cleanup.cleanup(partial=False)
while True:
if self._trial_cleanup and self._trial_cleanup.is_empty():
break
elif not self._trial_cleanup and len(self._futures) == 0:
break
self._do_force_trial_cleanup()
ready, _ = ray.wait(list(self._futures.keys()), timeout=0)
if not ready:
continue
event_type, trial_or_pg = self._futures.pop(ready[0])
if event_type == ExecutorEventType.STOP_RESULT:
post_stop_cleanup(ready[0], trial_or_pg)
self._pg_manager.reconcile_placement_groups(trials)
self._pg_manager.cleanup(force=True)
self._pg_manager.cleanup_existing_pg(block=True)
@ -944,6 +889,150 @@ class RayTrialExecutor(TrialExecutor):
else:
yield
def get_next_executor_event(
self, live_trials: Set[Trial], next_trial_exists: bool
) -> ExecutorEvent:
"""Get the next executor event to be processed in TrialRunner.
In case there are multiple events available for handling, the next
event is determined by the following priority:
1. If `next_trial_exists` and there are cached resources
to use, PG_READY is emitted.
2. If `next_trial_exists` and there are no cached resources
to use, wait on the pg futures and the randomized other futures. If multiple
futures are ready, the pg future takes priority and is handled first.
3. If not `next_trial_exists`, wait on just the randomized other
futures.
An example of #3 would be synchronous hyperband. Although there are pgs
ready, the scheduler is holding back scheduling new trials since the
whole band of trials is waiting for the slowest trial to finish. In
this case, we prioritize handling training results to avoid a deadlock
situation.
This is a blocking wait with a timeout (specified via an env var).
The reason for the timeout is that
we still want to print status info periodically in TrialRunner for
a better user experience.
The handling of `ExecutorEventType.STOP_RESULT` is purely internal to
RayTrialExecutor itself. All the other future results are handled by
TrialRunner.
In the future we may want to do most of the handling of
`ExecutorEventType.RESTORING_RESULT` and `SAVING_RESULT` in
RayTrialExecutor itself and only notify TrialRunner to invoke the
corresponding callbacks. This view is more consistent with our goal
of TrialRunner being responsible for external-facing Trial state transitions,
while RayTrialExecutor is responsible for internal-facing transitions,
namely `is_saving`, `is_restoring`, etc.
Also you may notice that the boundary between RayTrialExecutor and
PlacementGroupManager right now is really blurry. This will be
improved once we move to an ActorPool abstraction.
`next_trial_exists` means that there is a trial to run - prioritize
returning PG_READY in this case.
"""
# First update status of staged placement groups
self._stage_and_update_status(live_trials)
while True:
###################################################################
# when next_trial_exists and there are cached resources
###################################################################
# There could be existing PGs from either `self._cached_actor_pg`
# or from `self._pg_manager._ready`. If so and if there is indeed
# a next trial to run, we return `PG_READY` future for trial
# runner. The next trial can then be scheduled on this PG.
if next_trial_exists:
if len(self._cached_actor_pg) > 0:
return ExecutorEvent(ExecutorEventType.PG_READY)
# TODO(xwjiang): Expose proper API when we decide to do
# ActorPool abstraction.
if any(len(r) > 0 for r in self._pg_manager._ready.values()):
return ExecutorEvent(ExecutorEventType.PG_READY)
###################################################################
# Prepare for futures to wait
###################################################################
futures_to_wait = list(self._futures.keys())
random.shuffle(futures_to_wait)
if next_trial_exists:
# Only wait for pg futures explicitly if there is a next trial to run.
# In that case, handling PG_READY trumps handling other events,
# since we want to place the pending trial ASAP.
futures_to_wait = (
self._pg_manager.get_staging_future_list() + futures_to_wait
)
logger.debug(
f"get_next_executor_event before wait with futures "
f"{futures_to_wait} and "
f"next_trial_exists={next_trial_exists}"
)
ready_futures, _ = ray.wait(
futures_to_wait, num_returns=1, timeout=self._get_next_event_wait
)
###################################################################
# Dealing with no future returned case.
###################################################################
if len(ready_futures) == 0:
if len(self._futures) == 0:
# No running trials and the wait timed out; we may
# have insufficient cluster resources that make the Tune run
# infeasible.
# TODO: Move InsufficientResourceManager's logic
# to TrialExecutor. It is not Runner's responsibility!
return ExecutorEvent(ExecutorEventType.NO_RUNNING_TRIAL_TIMEOUT)
else:
# Training is simply taking a long time; yield control back to the main
# event loop to print progress info, etc.
return ExecutorEvent(ExecutorEventType.YIELD)
###################################################################
# If there is future returned.
###################################################################
assert len(ready_futures) == 1
ready_future = ready_futures[0]
###################################################################
# If it is a PG_READY event.
###################################################################
if ready_future not in self._futures.keys():
# This is a ready future.
self._pg_manager.handle_ready_future(ready_future)
return ExecutorEvent(ExecutorEventType.PG_READY)
###################################################################
# non PG_READY event
###################################################################
result_type, trial_or_pg = self._futures.pop(ready_future)
if result_type == ExecutorEventType.STOP_RESULT:
pg = trial_or_pg
post_stop_cleanup(ready_future, pg)
else:
trial = trial_or_pg
assert isinstance(trial, Trial)
try:
future_result = ray.get(ready_future)
# For local mode
if isinstance(future_result, _LocalWrapper):
future_result = future_result.unwrap()
if result_type in (
ExecutorEventType.TRAINING_RESULT,
ExecutorEventType.SAVING_RESULT,
ExecutorEventType.RESTORING_RESULT,
):
logger.debug(f"Returning [{result_type}] for trial {trial}")
return ExecutorEvent(result_type, trial, result=future_result)
else:
raise TuneError(f"Unexpected future type - [{result_type}]")
except Exception:
return ExecutorEvent(
ExecutorEventType.ERROR, trial, traceback.format_exc()
)
def _to_gb(n_bytes):
return round(n_bytes / (1024 ** 3), 2)


@ -780,7 +780,7 @@ class TrainableFunctionApiTest(unittest.TestCase):
return [0, 1, True, {}]
class FailureInjectionCallback(Callback):
def on_trial_start(self, trials, **info):
def on_step_end(self, **info):
raise RuntimeError
with self.assertRaises(Exception):
@ -870,7 +870,6 @@ class TrainableFunctionApiTest(unittest.TestCase):
trial.last_result.get("trial_resources"), trial.placement_group_factory
)
@patch("ray.tune.ray_trial_executor.TRIAL_CLEANUP_THRESHOLD", 3)
def testLotsOfStops(self):
class TestTrainable(Trainable):
def step(self):


@ -30,12 +30,6 @@ from ray.tune.utils.mock import (
MOCK_REMOTE_DIR,
)
# Wait up to five seconds for placement groups when starting a trial
os.environ["TUNE_PLACEMENT_GROUP_WAIT_S"] = "5"
# Block for results even when placement groups are pending
os.environ["TUNE_TRIAL_STARTUP_GRACE_PERIOD"] = "0"
os.environ["TUNE_TRIAL_RESULT_WAIT_TIME_S"] = "9999"
def _check_trial_running(trial):
if trial.runner:
@ -203,17 +197,10 @@ def test_remove_node_before_result(start_connected_emptyhead_cluster):
assert trial.last_result.get("training_iteration") == 1
# Process result: discover failure, recover, _train (from scratch)
while trial.status != Trial.TERMINATED:
runner.step()
runner.step() # Process result, invoke _train
assert trial.last_result.get("training_iteration") == 1
runner.step() # Process result, invoke _save
assert trial.last_result.get("training_iteration") == 2
# process save, invoke _train
runner.step()
# process result
runner.step()
assert trial.status == Trial.TERMINATED
assert trial.last_result.get("training_iteration") > 1
with pytest.raises(TuneError):
runner.step()
@ -262,7 +249,7 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
# assert t.last_result is None, "Trial result not restored correctly."
# Process result (x2), process save, process result (x2), process save
for _ in range(6):
while not runner.is_finished():
runner.step()
assert t.status == Trial.TERMINATED, runner.debug_string()
@ -271,19 +258,13 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
t2 = Trial(trainable_id, **kwargs)
runner.add_trial(t2)
# Start trial, process result (x2), process save
for _ in range(4):
while not t2.has_checkpoint():
runner.step()
assert t2.has_checkpoint()
node3 = cluster.add_node(num_cpus=1)
cluster.remove_node(node2)
cluster.wait_for_nodes()
runner.step() # Process result 3 + start and fail 4 result
runner.step() # Dispatch restore
runner.step() # Process restore
runner.step() # Process result 5
if t2.status != Trial.TERMINATED:
runner.step() # Process result 6, dispatch save
runner.step() # Process save
while not runner.is_finished():
runner.step()
assert t2.status == Trial.TERMINATED, runner.debug_string()
# Test recovery of trial that won't be checkpointed
@ -301,8 +282,7 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
cluster.add_node(num_cpus=1)
cluster.remove_node(node3)
cluster.wait_for_nodes()
runner.step() # Error handling step
if t3.status != Trial.ERROR:
while not runner.is_finished():
runner.step()
assert t3.status == Trial.ERROR, runner.debug_string()
@ -406,19 +386,14 @@ def test_migration_checkpoint_removal(start_connected_emptyhead_cluster, trainab
runner.add_trial(t1)
# Start trial, process result (x2), process save
for _ in range(4):
while not t1.has_checkpoint():
runner.step()
assert t1.has_checkpoint()
cluster.add_node(num_cpus=1)
cluster.remove_node(node)
cluster.wait_for_nodes()
shutil.rmtree(os.path.dirname(t1.checkpoint.value))
runner.step() # Collect result 3, kick off + fail result 4
runner.step() # Dispatch restore
runner.step() # Process restore + step 4
for _ in range(3):
if t1.status != Trial.TERMINATED:
while not runner.is_finished():
runner.step()
assert t1.status == Trial.TERMINATED, runner.debug_string()


@ -295,11 +295,6 @@ with patch("ray.tune.progress_reporter._get_trial_location",
class ProgressReporterTest(unittest.TestCase):
def setUp(self) -> None:
# Wait up to five seconds for placement groups when starting a trial
os.environ["TUNE_PLACEMENT_GROUP_WAIT_S"] = "5"
# Block for results even when placement groups are pending
os.environ["TUNE_TRIAL_STARTUP_GRACE_PERIOD"] = "0"
os.environ["TUNE_TRIAL_RESULT_WAIT_TIME_S"] = "99999"
os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "auto"
def mock_trial(self, status, i):


@ -10,7 +10,7 @@ from ray import tune
from ray.rllib import _register_all
from ray.tune import Trainable
from ray.tune.callback import Callback
from ray.tune.ray_trial_executor import RayTrialExecutor
from ray.tune.ray_trial_executor import RayTrialExecutor, ExecutorEventType
from ray.tune.registry import _global_registry, TRAINABLE_CLASS
from ray.tune.result import PID, TRAINING_ITERATION, TRIAL_ID
from ray.tune.suggest import BasicVariantGenerator
@ -83,12 +83,6 @@ class TrialExecutorInsufficientResourcesTest(unittest.TestCase):
class RayTrialExecutorTest(unittest.TestCase):
def setUp(self):
# Wait up to five seconds for placement groups when starting a trial
os.environ["TUNE_PLACEMENT_GROUP_WAIT_S"] = "5"
# Block for results even when placement groups are pending
os.environ["TUNE_TRIAL_STARTUP_GRACE_PERIOD"] = "0"
os.environ["TUNE_TRIAL_RESULT_WAIT_TIME_S"] = "99999"
self.trial_executor = RayTrialExecutor()
ray.init(num_cpus=2, ignore_reinit_error=True)
_register_all() # Needed for flaky tests
@ -97,34 +91,63 @@ class RayTrialExecutorTest(unittest.TestCase):
ray.shutdown()
_register_all() # re-register the evicted objects
def _simulate_starting_trial(self, trial):
future_result = self.trial_executor.get_next_executor_event(
live_trials={trial}, next_trial_exists=True
)
assert future_result.type == ExecutorEventType.PG_READY
self.assertTrue(self.trial_executor.start_trial(trial))
self.assertEqual(Trial.RUNNING, trial.status)
def _simulate_getting_result(self, trial):
while True:
future_result = self.trial_executor.get_next_executor_event(
live_trials={trial}, next_trial_exists=False
)
if future_result.type == ExecutorEventType.TRAINING_RESULT:
break
if isinstance(future_result.result, list):
for r in future_result.result:
trial.update_last_result(r)
else:
trial.update_last_result(future_result.result)
def _simulate_saving(self, trial):
checkpoint = self.trial_executor.save(trial, Checkpoint.PERSISTENT)
self.assertEqual(checkpoint, trial.saving_to)
self.assertEqual(trial.checkpoint.value, None)
future_result = self.trial_executor.get_next_executor_event(
live_trials={trial}, next_trial_exists=False
)
assert future_result.type == ExecutorEventType.SAVING_RESULT
self.process_trial_save(trial, future_result.result)
self.assertEqual(checkpoint, trial.checkpoint)
def testStartStop(self):
trial = Trial("__fake")
self.trial_executor.start_trial(trial)
running = self.trial_executor.get_running_trials()
self.assertEqual(1, len(running))
self._simulate_starting_trial(trial)
self.trial_executor.stop_trial(trial)
def testAsyncSave(self):
"""Tests that saved checkpoint value not immediately set."""
trial = Trial("__fake")
self.trial_executor.start_trial(trial)
self.assertEqual(Trial.RUNNING, trial.status)
trial.last_result = self.trial_executor.fetch_result(trial)[-1]
checkpoint = self.trial_executor.save(trial, Checkpoint.PERSISTENT)
self.assertEqual(checkpoint, trial.saving_to)
self.assertEqual(trial.checkpoint.value, None)
self.process_trial_save(trial)
self.assertEqual(checkpoint, trial.checkpoint)
self._simulate_starting_trial(trial)
self._simulate_getting_result(trial)
self._simulate_saving(trial)
self.trial_executor.stop_trial(trial)
self.assertEqual(Trial.TERMINATED, trial.status)
def testSaveRestore(self):
trial = Trial("__fake")
self.trial_executor.start_trial(trial)
self.assertEqual(Trial.RUNNING, trial.status)
trial.last_result = self.trial_executor.fetch_result(trial)[-1]
self.trial_executor.save(trial, Checkpoint.PERSISTENT)
self.process_trial_save(trial)
self._simulate_starting_trial(trial)
self._simulate_getting_result(trial)
self._simulate_saving(trial)
self.trial_executor.restore(trial)
self.trial_executor.stop_trial(trial)
self.assertEqual(Trial.TERMINATED, trial.status)
@ -132,40 +155,44 @@ class RayTrialExecutorTest(unittest.TestCase):
def testPauseResume(self):
"""Tests that pausing works for trials in flight."""
trial = Trial("__fake")
self.trial_executor.start_trial(trial)
self.assertEqual(Trial.RUNNING, trial.status)
self._simulate_starting_trial(trial)
self.trial_executor.pause_trial(trial)
self.assertEqual(Trial.PAUSED, trial.status)
self.trial_executor.start_trial(trial)
self.assertEqual(Trial.RUNNING, trial.status)
self._simulate_starting_trial(trial)
self.trial_executor.stop_trial(trial)
self.assertEqual(Trial.TERMINATED, trial.status)
def testSavePauseResumeErrorRestore(self):
"""Tests that pause checkpoint does not replace restore checkpoint."""
trial = Trial("__fake")
self.trial_executor.start_trial(trial)
trial.last_result = self.trial_executor.fetch_result(trial)[-1]
self._simulate_starting_trial(trial)
self._simulate_getting_result(trial)
# Save
checkpoint = self.trial_executor.save(trial, Checkpoint.PERSISTENT)
self.assertEqual(Trial.RUNNING, trial.status)
self.assertEqual(checkpoint.storage, Checkpoint.PERSISTENT)
# Process save result (simulates trial runner)
self.process_trial_save(trial)
self._simulate_saving(trial)
# Train
self.trial_executor.continue_training(trial)
trial.last_result = self.trial_executor.fetch_result(trial)[-1]
self._simulate_getting_result(trial)
# Pause
self.trial_executor.pause_trial(trial)
self.assertEqual(Trial.PAUSED, trial.status)
self.assertEqual(trial.checkpoint.storage, Checkpoint.MEMORY)
# Resume
self.trial_executor.start_trial(trial)
self.assertEqual(Trial.RUNNING, trial.status)
self._simulate_starting_trial(trial)
# Error
trial.set_status(Trial.ERROR)
# Restore
self.trial_executor.restore(trial)
self.trial_executor.stop_trial(trial)
self.assertEqual(Trial.TERMINATED, trial.status)
@ -178,13 +205,14 @@ class RayTrialExecutorTest(unittest.TestCase):
def testPauseResume2(self):
"""Tests that pausing works for trials being processed."""
trial = Trial("__fake")
self.trial_executor.start_trial(trial)
self.assertEqual(Trial.RUNNING, trial.status)
self.trial_executor.fetch_result(trial)
self._simulate_starting_trial(trial)
self._simulate_getting_result(trial)
self.trial_executor.pause_trial(trial)
self.assertEqual(Trial.PAUSED, trial.status)
self.trial_executor.start_trial(trial)
self.assertEqual(Trial.RUNNING, trial.status)
self._simulate_starting_trial(trial)
self.trial_executor.stop_trial(trial)
self.assertEqual(Trial.TERMINATED, trial.status)
@ -199,15 +227,17 @@ class RayTrialExecutorTest(unittest.TestCase):
base = max(result_buffer_length, 1)
trial = Trial("__fake")
self.trial_executor.start_trial(trial)
self.assertEqual(Trial.RUNNING, trial.status)
trial.last_result = self.trial_executor.fetch_result(trial)[-1]
self._simulate_starting_trial(trial)
self._simulate_getting_result(trial)
self.assertEqual(trial.last_result.get(TRAINING_ITERATION), base)
self.trial_executor.pause_trial(trial)
self.assertEqual(Trial.PAUSED, trial.status)
self.trial_executor.start_trial(trial)
self.assertEqual(Trial.RUNNING, trial.status)
trial.last_result = self.trial_executor.fetch_result(trial)[-1]
self._simulate_starting_trial(trial)
self._simulate_getting_result(trial)
self.assertEqual(trial.last_result.get(TRAINING_ITERATION), base * 2)
self.trial_executor.stop_trial(trial)
self.assertEqual(Trial.TERMINATED, trial.status)
@ -224,7 +254,7 @@ class RayTrialExecutorTest(unittest.TestCase):
def testNoResetTrial(self):
"""Tests that reset handles NotImplemented properly."""
trial = Trial("__fake")
self.trial_executor.start_trial(trial)
self._simulate_starting_trial(trial)
exists = self.trial_executor.reset_trial(trial, {}, "modified_mock")
self.assertEqual(exists, False)
self.assertEqual(Trial.RUNNING, trial.status)
@ -248,18 +278,18 @@ class RayTrialExecutorTest(unittest.TestCase):
"grid_search",
)
trial = trials[0]
self.trial_executor.start_trial(trial)
self._simulate_starting_trial(trial)
exists = self.trial_executor.reset_trial(trial, {"hi": 1}, "modified_mock")
self.assertEqual(exists, True)
self.assertEqual(trial.config.get("hi"), 1)
self.assertEqual(trial.experiment_tag, "modified_mock")
self.assertEqual(Trial.RUNNING, trial.status)
def testForceTrialCleanup(self):
def testTrialCleanup(self):
class B(Trainable):
def step(self):
print("Step start")
time.sleep(10)
time.sleep(4)
print("Step done")
return dict(my_metric=1, timesteps_this_iter=1, done=True)
@ -269,7 +299,7 @@ class RayTrialExecutorTest(unittest.TestCase):
def cleanup(self):
print("Cleanup start")
time.sleep(10)
time.sleep(4)
print("Cleanup done")
# First check if the trials terminate gracefully by default
@ -281,15 +311,15 @@ class RayTrialExecutorTest(unittest.TestCase):
"grid_search",
)
trial = trials[0]
self.trial_executor.start_trial(trial)
self.assertEqual(Trial.RUNNING, trial.status)
time.sleep(5)
self._simulate_starting_trial(trial)
time.sleep(1)
print("Stop trial")
self.trial_executor.stop_trial(trial)
print("Start trial cleanup")
start = time.time()
self.trial_executor.cleanup([trial])
self.assertGreaterEqual(time.time() - start, 12.0)
# 4 - 1 + 4.
self.assertGreaterEqual(time.time() - start, 6)
# Check forceful termination. It should run for much less than the
# sleep periods in the Trainable
@ -304,15 +334,16 @@ class RayTrialExecutorTest(unittest.TestCase):
os.environ["TUNE_FORCE_TRIAL_CLEANUP_S"] = "1"
self.trial_executor = RayTrialExecutor()
os.environ["TUNE_FORCE_TRIAL_CLEANUP_S"] = "0"
self.trial_executor.start_trial(trial)
self._simulate_starting_trial(trial)
self.assertEqual(Trial.RUNNING, trial.status)
time.sleep(5)
time.sleep(1)
print("Stop trial")
self.trial_executor.stop_trial(trial)
print("Start trial cleanup")
start = time.time()
self.trial_executor.cleanup([trial])
self.assertLess(time.time() - start, 5.0)
# less than 1 with some margin.
self.assertLess(time.time() - start, 2.0)
# also check if auto-filled metrics were returned
self.assertIn(PID, trial.last_result)
@ -332,10 +363,9 @@ class RayTrialExecutorTest(unittest.TestCase):
break
return trials
def process_trial_save(self, trial):
def process_trial_save(self, trial, checkpoint_value):
"""Simulates trial runner save."""
checkpoint = trial.saving_to
checkpoint_value = self.trial_executor.fetch_result(trial)[-1]
checkpoint.value = checkpoint_value
trial.on_checkpoint(checkpoint)
@ -460,10 +490,8 @@ class LocalModeExecutorTest(RayTrialExecutorTest):
ray.shutdown()
_register_all() # re-register the evicted objects
def testForceTrialCleanup(self):
self.skipTest(
"Skipping as force trial cleanup is not applicable" " for local mode."
)
def testTrialCleanup(self):
self.skipTest("Skipping as trial cleanup is not applicable" " for local mode.")
if __name__ == "__main__":


@ -15,6 +15,7 @@ import ray
from ray.rllib import _register_all
from ray import tune
from ray.tune import TuneError
from ray.tune.integration.docker import DockerSyncer
from ray.tune.integration.kubernetes import KubernetesSyncer
from ray.tune.sync_client import NOOP
@ -29,12 +30,6 @@ from ray.tune.utils.callback import create_default_callbacks
class TestSyncFunctionality(unittest.TestCase):
def setUp(self):
# Wait up to 1.5 seconds for placement groups when starting a trial
os.environ["TUNE_PLACEMENT_GROUP_WAIT_S"] = "1.5"
# Block for results even when placement groups are pending
os.environ["TUNE_TRIAL_STARTUP_GRACE_PERIOD"] = "0"
os.environ["TUNE_TRIAL_RESULT_WAIT_TIME_S"] = "99999"
ray.init(num_cpus=2)
def tearDown(self):
@ -120,7 +115,7 @@ class TestSyncFunctionality(unittest.TestCase):
def testClusterProperString(self):
"""Tests that invalid commands throw.."""
with self.assertRaises(ValueError):
with self.assertRaises(TuneError):
# This raises ValueError because logger is init in safe zone.
sync_config = tune.SyncConfig(syncer="ls {target}")
[trial] = tune.run(
@ -131,7 +126,7 @@ class TestSyncFunctionality(unittest.TestCase):
sync_config=sync_config,
).trials
with self.assertRaises(ValueError):
with self.assertRaises(TuneError):
# This raises ValueError because logger is init in safe zone.
sync_config = tune.SyncConfig(syncer="ls {source}")
[trial] = tune.run(


@ -19,29 +19,11 @@ from ray.tune.utils.placement_groups import PlacementGroupFactory
class TrialRunnerTest(unittest.TestCase):
def setUp(self):
# Wait up to five seconds for placement groups when starting a trial
os.environ["TUNE_PLACEMENT_GROUP_WAIT_S"] = "5"
# Block for results even when placement groups are pending
os.environ["TUNE_TRIAL_STARTUP_GRACE_PERIOD"] = "0"
os.environ["TUNE_TRIAL_RESULT_WAIT_TIME_S"] = "99999"
_register_all() # re-register the evicted objects
def tearDown(self):
ray.shutdown()
def testTrialStatus(self):
ray.init(num_cpus=2)
trial = Trial("__fake")
trial_executor = RayTrialExecutor()
self.assertEqual(trial.status, Trial.PENDING)
trial_executor.start_trial(trial)
self.assertEqual(trial.status, Trial.RUNNING)
trial_executor.stop_trial(trial)
self.assertEqual(trial.status, Trial.TERMINATED)
trial_executor.stop_trial(trial, error=True)
self.assertEqual(trial.status, Trial.ERROR)
def testExperimentTagTruncation(self):
ray.init(num_cpus=2)
@ -74,7 +56,8 @@ class TrialRunnerTest(unittest.TestCase):
def testExtraResources(self):
ray.init(num_cpus=4, num_gpus=2)
runner = TrialRunner()
snapshot = TrialStatusSnapshot()
runner = TrialRunner(callbacks=[TrialStatusSnapshotTaker(snapshot)])
kwargs = {
"stopping_criterion": {"training_iteration": 1},
"placement_group_factory": PlacementGroupFactory(
@ -85,17 +68,18 @@ class TrialRunnerTest(unittest.TestCase):
for t in trials:
runner.add_trial(t)
while not runner.is_finished():
runner.step()
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(trials[1].status, Trial.PENDING)
runner.step()
self.assertEqual(trials[0].status, Trial.TERMINATED)
self.assertEqual(trials[1].status, Trial.PENDING)
self.assertLess(snapshot.max_running_trials(), 2)
self.assertTrue(snapshot.all_trials_are_terminated())
def testCustomResources(self):
ray.init(num_cpus=4, num_gpus=2, resources={"a": 2})
runner = TrialRunner()
# Since each trial will occupy the full custom resources,
# there is at most 1 trial running at any given moment.
snapshot = TrialStatusSnapshot()
runner = TrialRunner(callbacks=[TrialStatusSnapshotTaker(snapshot)])
kwargs = {
"stopping_criterion": {"training_iteration": 1},
"placement_group_factory": PlacementGroupFactory([{"CPU": 1, "a": 2}]),
@ -104,16 +88,18 @@ class TrialRunnerTest(unittest.TestCase):
for t in trials:
runner.add_trial(t)
while not runner.is_finished():
runner.step()
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(trials[1].status, Trial.PENDING)
runner.step()
self.assertEqual(trials[0].status, Trial.TERMINATED)
self.assertEqual(trials[1].status, Trial.PENDING)
self.assertLess(snapshot.max_running_trials(), 2)
self.assertTrue(snapshot.all_trials_are_terminated())
def testExtraCustomResources(self):
ray.init(num_cpus=4, num_gpus=2, resources={"a": 2})
runner = TrialRunner()
# Since each trial will occupy the full custom resources,
# there is at most 1 trial running at any given moment.
snapshot = TrialStatusSnapshot()
runner = TrialRunner(callbacks=[TrialStatusSnapshotTaker(snapshot)])
kwargs = {
"stopping_criterion": {"training_iteration": 1},
"placement_group_factory": PlacementGroupFactory([{"CPU": 1}, {"a": 2}]),
@ -122,14 +108,11 @@ class TrialRunnerTest(unittest.TestCase):
for t in trials:
runner.add_trial(t)
while not runner.is_finished():
runner.step()
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(trials[1].status, Trial.PENDING)
runner.step()
self.assertTrue(sum(t.status == Trial.RUNNING for t in trials) < 2)
self.assertEqual(trials[0].status, Trial.TERMINATED)
self.assertEqual(trials[1].status, Trial.PENDING)
self.assertLess(snapshot.max_running_trials(), 2)
self.assertTrue(snapshot.all_trials_are_terminated())
def testFractionalGpus(self):
ray.init(num_cpus=4, num_gpus=1)
@ -209,12 +192,7 @@ class TrialRunnerTest(unittest.TestCase):
for t in trials:
runner.add_trial(t)
runner.step()
self.assertEqual(trials[0].status, Trial.RUNNING)
runner.step()
self.assertEqual(trials[0].status, Trial.RUNNING)
while not runner.is_finished():
runner.step()
self.assertEqual(trials[0].status, Trial.TERMINATED)
self.assertRaises(TuneError, runner.step)
@ -224,15 +202,23 @@ class TrialRunnerTest(unittest.TestCase):
ray.init(num_cpus=2)
class ChangingScheduler(FIFOScheduler):
def __init__(self):
self._has_received_one_trial_result = False
# For figuring out how many runner.step there are.
def has_received_one_trial_result(self):
return self._has_received_one_trial_result
def on_trial_result(self, trial_runner, trial, result):
if result["training_iteration"] == 1:
self._has_received_one_trial_result = True
executor = trial_runner.trial_executor
executor.stop_trial(trial)
executor.pause_trial(trial)
trial.update_resources(dict(cpu=2, gpu=0))
executor.start_trial(trial)
return TrialScheduler.CONTINUE
return TrialScheduler.NOOP
runner = TrialRunner(scheduler=ChangingScheduler())
scheduler = ChangingScheduler()
runner = TrialRunner(scheduler=scheduler)
kwargs = {
"stopping_criterion": {"training_iteration": 2},
"resources": Resources(cpu=1, gpu=0),
@ -250,8 +236,11 @@ class TrialRunnerTest(unittest.TestCase):
ValueError, lambda: trials[0].update_resources(dict(cpu=2, gpu=0))
)
while not scheduler.has_received_one_trial_result():
runner.step()
self.assertEqual(trials[0].status, Trial.PAUSED)
# extra step for tune loop to stage the resource requests.
runner.step()
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(
runner.trial_executor._pg_manager.occupied_resources().get("CPU"), 2
)


@ -13,6 +13,7 @@ from ray.tune.trial import Trial
from ray.tune.trial_runner import TrialRunner
from ray.tune.resources import Resources
from ray.tune.suggest import BasicVariantGenerator
from ray.tune.tests.test_trial_runner_utils import TrialResultObserver
def create_mock_components():
@ -37,11 +38,6 @@ def create_mock_components():
class TrialRunnerTest2(unittest.TestCase):
def setUp(self):
os.environ["TUNE_STATE_REFRESH_PERIOD"] = "0.1"
# Wait up to five seconds for placement groups when starting a trial
os.environ["TUNE_PLACEMENT_GROUP_WAIT_S"] = "5"
# Block for results even when placement groups are pending
os.environ["TUNE_TRIAL_STARTUP_GRACE_PERIOD"] = "0"
os.environ["TUNE_TRIAL_RESULT_WAIT_TIME_S"] = "99999"
def tearDown(self):
ray.shutdown()
@ -89,12 +85,9 @@ class TrialRunnerTest2(unittest.TestCase):
runner.add_trial(Trial("__fake", **kwargs))
trials = runner.get_trials()
runner.step() # Start trial
self.assertEqual(trials[0].status, Trial.RUNNING)
runner.step() # Process result, dispatch save
self.assertEqual(trials[0].status, Trial.RUNNING)
runner.step() # Process save
runner.step() # Error
while not runner.is_finished():
runner.step()
self.assertEqual(trials[0].status, Trial.ERROR)
self.assertEqual(trials[0].num_failures, 1)
self.assertEqual(len(searchalg.errored_trials), 1)
@ -107,6 +100,7 @@ class TrialRunnerTest2(unittest.TestCase):
runner = TrialRunner(searchalg, scheduler=scheduler)
kwargs = {
"stopping_criterion": {"training_iteration": 2},
"resources": Resources(cpu=1, gpu=1),
"checkpoint_freq": 1,
"max_failures": 1,
@ -117,18 +111,15 @@ class TrialRunnerTest2(unittest.TestCase):
runner.add_trial(Trial("__fake", **kwargs))
trials = runner.get_trials()
runner.step() # Start trial
self.assertEqual(trials[0].status, Trial.RUNNING)
runner.step() # Process result, dispatch save
self.assertEqual(trials[0].status, Trial.RUNNING)
runner.step() # Process save
runner.step() # Error (transient), dispatch restore
self.assertEqual(trials[0].status, Trial.RUNNING)
while not runner.is_finished():
runner.step()
self.assertEqual(trials[0].status, Trial.TERMINATED)
self.assertEqual(trials[0].num_failures, 1)
runner.step() # Process restore
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(len(searchalg.errored_trials), 0)
self.assertEqual(len(scheduler.errored_trials), 0)
# Notice this is 1 since during recovery, the previously errored trial
# is "requeued". This will call scheduler.on_trial_error.
# Searcher.on_trial_error is, however, not called in this process.
self.assertEqual(len(scheduler.errored_trials), 1)
def testFailureRecoveryMaxFailures(self):
ray.init(num_cpus=1, num_gpus=1)
@ -145,20 +136,8 @@ class TrialRunnerTest2(unittest.TestCase):
runner.add_trial(Trial("__fake", **kwargs))
trials = runner.get_trials()
runner.step() # Start trial
self.assertEqual(trials[0].status, Trial.RUNNING)
runner.step() # Process result, dispatch save
self.assertEqual(trials[0].status, Trial.RUNNING)
runner.step() # Process save
runner.step() # Error (transient), dispatch restore
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(trials[0].num_failures, 1)
runner.step() # Process restore
runner.step() # Error (transient), dispatch restore
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(trials[0].num_failures, 2)
runner.step() # Process restore
runner.step() # Error (terminal)
while not runner.is_finished():
runner.step()
self.assertEqual(trials[0].status, Trial.ERROR)
self.assertEqual(trials[0].num_failures, 3)
@ -178,13 +157,12 @@ class TrialRunnerTest2(unittest.TestCase):
runner.add_trial(Trial("__fake", **kwargs))
trials = runner.get_trials()
runner.step() # Start trial
self.assertEqual(trials[0].status, Trial.RUNNING)
runner.step() # Process result, dispatch save
self.assertEqual(trials[0].status, Trial.RUNNING)
runner.step() # Process save
runner.step() # Error
while not runner.is_finished():
runner.step()
self.assertEqual(trials[0].status, Trial.ERROR)
# Somehow with `fail_fast=True`, if one errors out, the others are
# then stopped with `TERMINATED` status.
self.assertEqual(trials[1].status, Trial.TERMINATED)
self.assertRaises(TuneError, lambda: runner.step())
def testFailFastRaise(self):
@ -203,13 +181,14 @@ class TrialRunnerTest2(unittest.TestCase):
runner.add_trial(Trial("__fake", **kwargs))
trials = runner.get_trials()
runner.step() # Start trial
self.assertEqual(trials[0].status, Trial.RUNNING)
runner.step() # Process result, dispatch save
self.assertEqual(trials[0].status, Trial.RUNNING)
runner.step() # Process save
with self.assertRaises(Exception):
runner.step() # Error
while not runner.is_finished():
runner.step()
# Not critical checks. Only to showcase the difference
# from the non-raising FailFast mode.
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(trials[1].status, Trial.PENDING)
def testCheckpointing(self):
ray.init(num_cpus=1, num_gpus=1)
@ -244,35 +223,38 @@ class TrialRunnerTest2(unittest.TestCase):
def testRestoreMetricsAfterCheckpointing(self):
ray.init(num_cpus=1, num_gpus=1)
runner = TrialRunner()
observer = TrialResultObserver()
runner = TrialRunner(callbacks=[observer])
kwargs = {
"stopping_criterion": {"training_iteration": 2},
"resources": Resources(cpu=1, gpu=1),
"checkpoint_freq": 1,
}
runner.add_trial(Trial("__fake", **kwargs))
trials = runner.get_trials()
runner.step() # Start trial
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(ray.get(trials[0].runner.set_info.remote(1)), 1)
runner.step() # Process result, dispatch save
runner.step() # Process save
runner.trial_executor.stop_trial(trials[0])
kwargs["restore_path"] = trials[0].checkpoint.value
while not runner.is_finished():
runner.step()
self.assertEqual(trials[0].status, Trial.TERMINATED)
kwargs["restore_path"] = trials[0].checkpoint.value
kwargs.pop("stopping_criterion")
kwargs.pop("checkpoint_freq") # No checkpointing for next trial
runner.add_trial(Trial("__fake", **kwargs))
trials = runner.get_trials()
runner.step() # Start trial, dispatch restore
self.assertEqual(trials[0].status, Trial.TERMINATED)
self.assertEqual(trials[1].status, Trial.RUNNING)
runner.step() # Process restore
runner.step() # Process result
observer.reset()
while not observer.just_received_a_result():
runner.step()
self.assertEqual(trials[1].last_result["timesteps_since_restore"], 10)
self.assertEqual(trials[1].last_result["iterations_since_restore"], 1)
self.assertGreater(trials[1].last_result["time_since_restore"], 0)
runner.step() # Process restore
while not observer.just_received_a_result():
runner.step()
self.assertEqual(trials[1].last_result["timesteps_since_restore"], 20)
self.assertEqual(trials[1].last_result["iterations_since_restore"], 2)
self.assertGreater(trials[1].last_result["time_since_restore"], 0)
@ -289,12 +271,9 @@ class TrialRunnerTest2(unittest.TestCase):
runner.add_trial(Trial("__fake", **kwargs))
trials = runner.get_trials()
runner.step() # Start trial
self.assertEqual(trials[0].status, Trial.RUNNING)
runner.step() # Process result
runner.step() # Process result, dispatch save
while not runner.is_finished():
runner.step()
self.assertEqual(trials[0].last_result[DONE], True)
runner.step() # Process save
self.assertEqual(trials[0].has_checkpoint(), True)
def testResultDone(self):
@ -308,10 +287,7 @@ class TrialRunnerTest2(unittest.TestCase):
runner.add_trial(Trial("__fake", **kwargs))
trials = runner.get_trials()
runner.step()
self.assertEqual(trials[0].status, Trial.RUNNING)
runner.step()
self.assertNotEqual(trials[0].last_result[DONE], True)
while not runner.is_finished():
runner.step()
self.assertEqual(trials[0].last_result[DONE], True)

View file

@ -25,16 +25,11 @@ from ray.tune.suggest._mock import _MockSuggestionAlgorithm
from ray.tune.suggest.suggestion import Searcher, ConcurrencyLimiter
from ray.tune.suggest.search_generator import SearchGenerator
from ray.tune.syncer import SyncConfig
from ray.tune.tests.test_trial_runner_utils import TrialResultObserver
class TrialRunnerTest3(unittest.TestCase):
def setUp(self):
# Wait up to five seconds for placement groups when starting a trial
os.environ["TUNE_PLACEMENT_GROUP_WAIT_S"] = "5"
# Block for results even when placement groups are pending
os.environ["TUNE_TRIAL_STARTUP_GRACE_PERIOD"] = "0"
os.environ["TUNE_TRIAL_RESULT_WAIT_TIME_S"] = "99999"
os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "auto" # Reset default
self.tmpdir = tempfile.mkdtemp()
@ -131,15 +126,9 @@ class TrialRunnerTest3(unittest.TestCase):
searcher = search_alg.searcher
search_alg.add_configurations(experiments)
runner = TrialRunner(search_alg=search_alg)
runner.step()
trials = runner.get_trials()
self.assertEqual(trials[0].status, Trial.RUNNING)
while not runner.is_finished():
runner.step()
self.assertEqual(trials[0].status, Trial.RUNNING)
runner.step()
self.assertEqual(trials[0].status, Trial.TERMINATED)
self.assertEqual(searcher.counter["result"], 1)
self.assertEqual(searcher.counter["complete"], 1)
@ -204,18 +193,17 @@ class TrialRunnerTest3(unittest.TestCase):
runner = TrialRunner(search_alg=search_alg)
runner.step()
trials = runner.get_trials()
self.assertEqual(trials[0].status, Trial.RUNNING)
while trials[0].status != Trial.TERMINATED:
runner.step()
runner.step()
self.assertEqual(trials[0].status, Trial.TERMINATED)
trials = runner.get_trials()
runner.step()
self.assertEqual(trials[1].status, Trial.RUNNING)
self.assertEqual(len(searcher.live_trials), 1)
searcher.stall = True
while trials[1].status != Trial.TERMINATED:
runner.step()
self.assertEqual(trials[1].status, Trial.TERMINATED)
self.assertEqual(len(searcher.live_trials), 0)
@ -231,8 +219,9 @@ class TrialRunnerTest3(unittest.TestCase):
self.assertEqual(trials[2].status, Trial.RUNNING)
self.assertEqual(len(searcher.live_trials), 1)
while trials[2].status != Trial.TERMINATED:
runner.step()
self.assertEqual(trials[2].status, Trial.TERMINATED)
self.assertEqual(len(searcher.live_trials), 0)
self.assertTrue(search_alg.is_finished())
self.assertTrue(runner.is_finished())
@ -445,9 +434,9 @@ class TrialRunnerTest3(unittest.TestCase):
)
]
runner.add_trial(trials[0])
runner.step() # Start trial
runner.step() # Process result, dispatch save
runner.step() # Process save
while not runner.is_finished():
# Start trial, process result, dispatch save and process save.
runner.step()
self.assertEqual(trials[0].status, Trial.TERMINATED)
trials += [
@ -460,10 +449,13 @@ class TrialRunnerTest3(unittest.TestCase):
)
]
runner.add_trial(trials[1])
runner.step() # Start trial
runner.step() # Process result, dispatch save
runner.step() # Process save
runner.step() # Error
while not runner.is_finished():
# Start trial,
# Process result,
# Dispatch save,
# Process save and
# Error.
runner.step()
self.assertEqual(trials[1].status, Trial.ERROR)
trials += [
@ -488,12 +480,14 @@ class TrialRunnerTest3(unittest.TestCase):
restored_trial = runner2.get_trial("trial_succ")
self.assertEqual(Trial.PENDING, restored_trial.status)
runner2.step() # Start trial
runner2.step() # Process result, dispatch save
runner2.step() # Process save
runner2.step() # Process result, dispatch save
runner2.step() # Process save
self.assertRaises(TuneError, runner2.step)
while not runner2.is_finished():
# Start trial,
# Process result, dispatch save
# Process save
# Process result, dispatch save
# Process save.
runner2.step()
self.assertEqual(restored_trial.status, Trial.TERMINATED)
def testTrialNoCheckpointSave(self):
"""Check that non-checkpointing trials *are* saved."""
@ -533,7 +527,8 @@ class TrialRunnerTest3(unittest.TestCase):
)
)
runner.step()
old_trials = runner.get_trials()
while not old_trials[2].has_reported_at_least_once:
runner.step()
runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=self.tmpdir)
@ -643,29 +638,42 @@ class TrialRunnerTest3(unittest.TestCase):
checkpoint_at_end=True,
stopping_criterion={"training_iteration": 4},
)
observer = TrialResultObserver()
runner = TrialRunner(
local_checkpoint_dir=self.tmpdir,
checkpoint_period=0,
trial_executor=RayTrialExecutor(result_buffer_length=7),
callbacks=[observer],
)
runner.add_trial(trial)
runner.step() # start trial
runner.step() # run iteration 1
while not observer.just_received_a_result():
runner.step()
self.assertEqual(trial.last_result[TRAINING_ITERATION], 1)
self.assertEqual(num_checkpoints(trial), 0)
runner.step() # run iteration 2
while True:
runner.step()
if observer.just_received_a_result():
break
self.assertEqual(trial.last_result[TRAINING_ITERATION], 2)
self.assertEqual(num_checkpoints(trial), 0)
runner.step() # run iteration 3
while True:
runner.step()
if observer.just_received_a_result():
break
self.assertEqual(trial.last_result[TRAINING_ITERATION], 3)
self.assertEqual(num_checkpoints(trial), 0)
runner.step() # run iteration 4
while True:
runner.step()
if observer.just_received_a_result():
break
self.assertEqual(trial.last_result[TRAINING_ITERATION], 4)
while not runner.is_finished():
runner.step()
self.assertEqual(num_checkpoints(trial), 1)
def testUserCheckpoint(self):

View file

@ -9,11 +9,14 @@ from collections import OrderedDict
import ray
from ray import tune
from ray.exceptions import RayActorError
from ray.rllib import _register_all
from ray.tune.checkpoint_manager import Checkpoint
from ray.tune.logger import DEFAULT_LOGGERS, LoggerCallback, LegacyLoggerCallback
from ray.tune.ray_trial_executor import RayTrialExecutor
from ray.tune.ray_trial_executor import (
RayTrialExecutor,
ExecutorEvent,
ExecutorEventType,
)
from ray.tune.result import TRAINING_ITERATION
from ray.tune.syncer import SyncConfig, SyncerCallback
@ -65,28 +68,25 @@ class TestCallback(Callback):
self.state["experiment_end"] = info
# TODO(xwjiang): Move this to a testing util.
class _MockTrialExecutor(RayTrialExecutor):
def __init__(self):
super().__init__()
self.next_trial = None
self.results = {}
self.should_fail_in_fetch_result = False
self.next_future_result = None
def fetch_result(self, trial):
if self.should_fail_in_fetch_result:
raise RayActorError(
"The actor died unexpectedly before finishing this task."
)
else:
return [self.results.get(trial, {})]
def start_trial(self, trial: Trial):
trial.status = Trial.RUNNING
return True
def get_next_available_trial(self, timeout=None):
return self.next_trial or super().get_next_available_trial()
def continue_training(self, trial: Trial):
pass
def get_next_executor_event(self, live_trials, next_trial_exists):
return self.next_future_result
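# A simplified, illustrative sketch (NOT the actual RayTrialExecutor
# implementation) of what the single wait behind get_next_executor_event
# conceptually does: one ray.wait() over both placement-group staging futures
# and in-flight trial futures, translated into an ExecutorEvent.
# `trial_futures` (a mapping from ObjectRef to (trial, ExecutorEventType)) and
# `wait_s` are assumed names used for illustration only.
def _sketch_get_next_executor_event(pg_manager, trial_futures, next_trial_exists, wait_s=5.0):
    staging_futures = pg_manager.get_staging_future_list() if next_trial_exists else []
    futures = staging_futures + list(trial_futures)
    if not futures:
        # Nothing to wait on: report that no trial is running.
        return ExecutorEvent(event_type=ExecutorEventType.NO_RUNNING_TRIAL_TIMEOUT)
    ready, _ = ray.wait(futures, num_returns=1, timeout=wait_s)
    if not ready:
        # Nothing completed within the wait window; yield control back to
        # TrialRunner.step() so it can report status, checkpoint, etc.
        return ExecutorEvent(event_type=ExecutorEventType.YIELD)
    future = ready[0]
    if future in staging_futures:
        # Move the placement group from _staging to _ready.
        pg_manager.handle_ready_future(future)
        return ExecutorEvent(event_type=ExecutorEventType.PG_READY)
    trial, event_type = trial_futures.pop(future)
    return ExecutorEvent(event_type=event_type, trial=trial, result=ray.get(future))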
class TrialRunnerCallbacks(unittest.TestCase):
def setUp(self):
os.environ["TUNE_PLACEMENT_GROUP_WAIT_S"] = "1"
ray.init()
self.tmpdir = tempfile.mkdtemp()
@ -110,7 +110,9 @@ class TrialRunnerCallbacks(unittest.TestCase):
for t in trials:
self.trial_runner.add_trial(t)
self.executor.next_trial = trials[0]
self.executor.next_future_result = ExecutorEvent(
event_type=ExecutorEventType.PG_READY
)
self.trial_runner.step()
# Trial 1 has been started
@ -132,7 +134,9 @@ class TrialRunnerCallbacks(unittest.TestCase):
)
)
self.executor.next_trial = trials[1]
self.executor.next_future_result = ExecutorEvent(
event_type=ExecutorEventType.PG_READY
)
self.trial_runner.step()
# Iteration not increased yet
@ -148,7 +152,9 @@ class TrialRunnerCallbacks(unittest.TestCase):
cp = Checkpoint(Checkpoint.PERSISTENT, "__checkpoint", {TRAINING_ITERATION: 0})
# Let the first trial save a checkpoint
self.executor.next_trial = trials[0]
self.executor.next_future_result = ExecutorEvent(
event_type=ExecutorEventType.SAVING_RESULT, trial=trials[0]
)
trials[0].saving_to = cp
self.trial_runner.step()
self.assertEqual(self.callback.state["trial_save"]["iteration"], 2)
@ -156,8 +162,9 @@ class TrialRunnerCallbacks(unittest.TestCase):
# Let the second trial send a result
result = {TRAINING_ITERATION: 1, "metric": 800, "done": False}
self.executor.results[trials[1]] = result
self.executor.next_trial = trials[1]
self.executor.next_future_result = ExecutorEvent(
event_type=ExecutorEventType.TRAINING_RESULT, trial=trials[1], result=result
)
self.assertTrue(not trials[1].has_reported_at_least_once)
self.trial_runner.step()
self.assertEqual(self.callback.state["trial_result"]["iteration"], 3)
@ -167,25 +174,28 @@ class TrialRunnerCallbacks(unittest.TestCase):
# Let the second trial restore from a checkpoint
trials[1].restoring_from = cp
self.executor.results[trials[1]] = trials[1].last_result
self.executor.next_future_result = ExecutorEvent(
event_type=ExecutorEventType.RESTORING_RESULT, trial=trials[1]
)
self.trial_runner.step()
self.assertEqual(self.callback.state["trial_restore"]["iteration"], 4)
self.assertEqual(self.callback.state["trial_restore"]["trial"].trial_id, "two")
# Let the second trial finish
trials[1].restoring_from = None
self.executor.results[trials[1]] = {
TRAINING_ITERATION: 2,
"metric": 900,
"done": True,
}
self.executor.next_future_result = ExecutorEvent(
event_type=ExecutorEventType.TRAINING_RESULT,
trial=trials[1],
result={TRAINING_ITERATION: 2, "metric": 900, "done": True},
)
self.trial_runner.step()
self.assertEqual(self.callback.state["trial_complete"]["iteration"], 5)
self.assertEqual(self.callback.state["trial_complete"]["trial"].trial_id, "two")
# Let the first trial error
self.executor.next_trial = trials[0]
self.executor.should_fail_in_fetch_result = True
self.executor.next_future_result = ExecutorEvent(
event_type=ExecutorEventType.ERROR, trial=trials[0]
)
self.trial_runner.step()
self.assertEqual(self.callback.state["trial_fail"]["iteration"], 6)
self.assertEqual(self.callback.state["trial_fail"]["trial"].trial_id, "one")

View file

@ -53,7 +53,6 @@ class TrialRunnerPlacementGroupTest(unittest.TestCase):
self.assertFalse(pg_manager._staging[pgf])
for pgf in pg_manager._ready:
self.assertFalse(pg_manager._ready[pgf])
self.assertTrue(pg_manager._latest_staging_start_time)
num_non_removed_pgs = len(
[p for pid, p in placement_group_table().items() if p["state"] != "REMOVED"]
@ -109,17 +108,6 @@ class TrialRunnerPlacementGroupTest(unittest.TestCase):
total_num_tracked = num_staging + num_ready + num_in_use + num_cached
num_non_removed_pgs = len(
[
p
for pid, p in placement_group_table().items()
if p["state"] != "REMOVED"
]
)
num_removal_scheduled_pgs = len(
trial_executor._pg_manager._pgs_for_removal
)
# All trials should be scheduled
self.assertEqual(
scheduled,
@ -141,13 +129,6 @@ class TrialRunnerPlacementGroupTest(unittest.TestCase):
msg=f"Num tracked iter {iteration}",
)
# The number of actual placement groups should match this
self.assertGreaterEqual(
max(scheduled, len(trials)) - num_finished + num_parallel_reuse,
num_non_removed_pgs - num_removal_scheduled_pgs,
msg=f"Num actual iter {iteration}",
)
start = time.time()
out = tune.run(
train,

View file

@ -0,0 +1,22 @@
from ray.tune import Callback
class TrialResultObserver(Callback):
"""Helper class to control runner.step() count."""
def __init__(self):
self._counter = 0
self._last_counter = 0
def reset(self):
self._last_counter = self._counter
def just_received_a_result(self):
if self._last_counter == self._counter:
return False
else:
self._last_counter = self._counter
return True
def on_trial_result(self, **kwargs):
self._counter += 1
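# Typical usage (illustrative only): register the observer as a runner
# callback, call reset(), then keep stepping until exactly one new
# on_trial_result has fired. Note that just_received_a_result() also advances
# the internal marker, so each True return corresponds to one newly observed
# result.
#
#     observer = TrialResultObserver()
#     runner = TrialRunner(callbacks=[observer])
#     observer.reset()
#     while not observer.just_received_a_result():
#         runner.step()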

View file

@ -173,7 +173,6 @@ class PopulationBasedTrainingFileDescriptorTest(unittest.TestCase):
class PopulationBasedTrainingSynchTest(unittest.TestCase):
def setUp(self):
os.environ["TUNE_TRIAL_STARTUP_GRACE_PERIOD"] = "0"
ray.init(num_cpus=2)
def MockTrainingFuncSync(config, checkpoint_dir=None):

View file

@ -105,13 +105,6 @@ def _run(local_dir, driver_semaphore, trainer_semaphore):
class TuneInterruptionTest(unittest.TestCase):
def setUp(self) -> None:
# Wait up to five seconds for placement groups when starting a trial
os.environ["TUNE_PLACEMENT_GROUP_WAIT_S"] = "5"
# Block for results even when placement groups are pending
os.environ["TUNE_TRIAL_STARTUP_GRACE_PERIOD"] = "0"
os.environ["TUNE_TRIAL_RESULT_WAIT_TIME_S"] = "99999"
def testExperimentInterrupted(self):
local_dir = tempfile.mkdtemp()
# Unix platforms may default to "fork", which is problematic with
@ -214,11 +207,6 @@ class TuneFailResumeGridTest(unittest.TestCase):
def setUp(self):
self.logdir = tempfile.mkdtemp()
os.environ["TUNE_GLOBAL_CHECKPOINT_S"] = "0"
# Wait up to 1.5 seconds for placement groups when starting a trial
os.environ["TUNE_PLACEMENT_GROUP_WAIT_S"] = "1.5"
# Block for results even when placement groups are pending
os.environ["TUNE_TRIAL_STARTUP_GRACE_PERIOD"] = "0"
os.environ["TUNE_TRIAL_RESULT_WAIT_TIME_S"] = "99999"
# Change back to local_mode=True after this is resolved:
# https://github.com/ray-project/ray/issues/13932

View file

@ -1,4 +1,3 @@
import os
import requests
import socket
import subprocess
@ -28,11 +27,6 @@ def get_valid_port():
class TuneServerSuite(unittest.TestCase):
def basicSetup(self):
# Wait up to five seconds for placement groups when starting a trial
os.environ["TUNE_PLACEMENT_GROUP_WAIT_S"] = "5"
# Block for results even when placement groups are pending
os.environ["TUNE_TRIAL_STARTUP_GRACE_PERIOD"] = "0"
os.environ["TUNE_TRIAL_RESULT_WAIT_TIME_S"] = "99999"
ray.init(num_cpus=4, num_gpus=1)
port = get_valid_port()

View file

@ -4,7 +4,6 @@ import logging
from typing import Dict, List, Optional
import warnings
from ray.tune.resources import Resources
from ray.util.annotations import DeveloperAPI
from ray.tune.trial import Trial, Checkpoint
@ -79,11 +78,6 @@ class TrialExecutor(metaclass=_WarnOnDirectInheritanceMeta):
self._trials_to_cache.clear()
return self._cached_trial_state
@abstractmethod
def has_resources(self, resources: Resources) -> bool:
"""Returns whether this runner has at least the specified resources."""
pass
@abstractmethod
def start_trial(self, trial: Trial) -> bool:
"""Starts the trial restoring from checkpoint if checkpoint is provided.
@ -150,11 +144,6 @@ class TrialExecutor(metaclass=_WarnOnDirectInheritanceMeta):
"""
pass
@abstractmethod
def get_running_trials(self) -> List[Trial]:
"""Returns all running trials."""
pass
def on_step_begin(self, trials: List[Trial]) -> None:
"""A hook called before running one step of the trial event loop.
@ -176,26 +165,6 @@ class TrialExecutor(metaclass=_WarnOnDirectInheritanceMeta):
def force_reconcilation_on_next_step_end(self) -> None:
pass
@abstractmethod
def get_next_available_trial(self) -> Optional[Trial]:
"""Blocking call that waits until one result is ready.
Returns:
Trial object that is ready for intermediate processing.
"""
pass
@abstractmethod
def fetch_result(self, trial: Trial) -> List[Trial]:
"""Fetches one result for the trial.
Assumes the trial is running.
Returns:
Result object for the trial.
"""
pass
@abstractmethod
def debug_string(self) -> str:
"""Returns a human readable message for printing to the console."""
@ -260,10 +229,6 @@ class TrialExecutor(metaclass=_WarnOnDirectInheritanceMeta):
"""
pass
def in_staging_grace_period(self) -> bool:
"""Returns True if trials have recently been staged."""
return False
def set_max_pending_trials(self, max_pending: int) -> None:
"""Set the maximum number of allowed pending trials."""
pass

View file

@ -15,7 +15,7 @@ from ray.tune import TuneError
from ray.tune.callback import CallbackList
from ray.tune.experiment import Experiment
from ray.tune.insufficient_resources_manager import InsufficientResourcesManager
from ray.tune.ray_trial_executor import RayTrialExecutor
from ray.tune.ray_trial_executor import RayTrialExecutor, ExecutorEventType
from ray.tune.result import (
DEBUG_METRICS,
DEFAULT_METRIC,
@ -338,7 +338,6 @@ class TrialRunner:
self._cached_trial_decisions = {}
self._queued_trial_decisions = {}
self._updated_queue = False
self._result_wait_time = int(os.getenv("TUNE_TRIAL_RESULT_WAIT_TIME_S", "1"))
self._stop_queue = []
self._should_stop_experiment = False # used by TuneServer
@ -685,23 +684,14 @@ class TrialRunner:
) and all(trial.is_finished() for trial in self._trials)
return trials_done and self._search_alg.is_finished()
def step(self):
"""Runs one step of the trial event loop.
def _update_trial_queue_and_get_next_trial(self) -> Optional[Trial]:
"""Adding suggested trials to the live queue of trials (they start as PENDING trials).
Callers should typically run this method repeatedly in a loop. They
may inspect or modify the runner's state in between calls to step().
Returns:
next_trial: Trial
"""
self._updated_queue = False
if self.is_finished():
raise TuneError("Called step when all trials finished?")
with warn_if_slow("on_step_begin"):
self.trial_executor.on_step_begin(self.get_trials())
with warn_if_slow("callbacks.on_step_begin"):
self._callbacks.on_step_begin(
iteration=self._iteration, trials=self._trials
)
# This will contain the next trial to start
next_trial = self._get_next_trial() # blocking
# Create pending trials. If the queue was updated before, only
@ -715,46 +705,64 @@ class TrialRunner:
break
num_pending_trials += 1
# Update status of staged placement groups
self.trial_executor.stage_and_update_status(self._live_trials)
return next_trial
def _start_trial(trial: Trial) -> bool:
"""Helper function to start trial and call callbacks"""
with warn_if_slow("start_trial"):
if self.trial_executor.start_trial(trial):
self._callbacks.on_trial_start(
iteration=self._iteration, trials=self._trials, trial=trial
def _wait_and_handle_event(self, next_trial: Optional[Trial]):
try:
# Single wait of entire tune loop.
future_result = self.trial_executor.get_next_executor_event(
self._live_trials, next_trial is not None
)
return True
return False
may_handle_events = True
if next_trial is not None:
if _start_trial(next_trial):
may_handle_events = False
elif next_trial.status != Trial.ERROR:
# Only try to start another trial if previous trial startup
# did not error (e.g. it just didn't start because its
# placement group is not ready, yet).
# Without this clause, this test fails:
# test_trial_runner_pg.py::
# TrialRunnerPlacementGroupHeterogeneousTest::
# testResourceDeadlock
next_trial = self.trial_executor.get_staged_trial()
if next_trial is not None:
if _start_trial(next_trial):
may_handle_events = False
if may_handle_events:
if self.trial_executor.get_running_trials():
timeout = self._result_wait_time
if self.trial_executor.in_staging_grace_period():
timeout = 0.1
self._process_events(timeout=timeout)
else:
if future_result.type == ExecutorEventType.PG_READY:
self._on_pg_ready(next_trial)
elif future_result.type == ExecutorEventType.NO_RUNNING_TRIAL_TIMEOUT:
self._insufficient_resources_manager.on_no_available_trials(
self.get_trials()
)
elif future_result.type == ExecutorEventType.YIELD:
pass
else:
trial = future_result.trial
result = future_result.result
if future_result.type == ExecutorEventType.ERROR:
self._on_executor_error(trial, result)
elif future_result.type == ExecutorEventType.RESTORING_RESULT:
self._on_restoring_result(trial)
else:
assert future_result.type in (
ExecutorEventType.SAVING_RESULT,
ExecutorEventType.TRAINING_RESULT,
), f"Unexpected future type - {future_result.type}"
if future_result.type == ExecutorEventType.TRAINING_RESULT:
self._on_training_result(trial, result)
else:
self._on_saving_result(trial, result)
self._post_process_on_training_saving_result(trial)
except Exception as e:
if isinstance(e, TuneError):
raise e
else:
raise TuneError(traceback.format_exc())
def step(self):
"""Runs one step of the trial event loop.
Callers should typically run this method repeatedly in a loop. They
may inspect or modify the runner's state in between calls to step().
"""
if self.is_finished():
raise TuneError("Called step when all trials finished?")
with warn_if_slow("on_step_begin"):
self.trial_executor.on_step_begin(self.get_trials())
with warn_if_slow("callbacks.on_step_begin"):
self._callbacks.on_step_begin(
iteration=self._iteration, trials=self._trials
)
next_trial = self._update_trial_queue_and_get_next_trial()
self._wait_and_handle_event(next_trial)
self._stop_experiment_if_needed()
@ -770,12 +778,100 @@ class TrialRunner:
if self.is_finished():
self._server.shutdown()
self._reconcile_live_trials()
with warn_if_slow("on_step_end"):
self.trial_executor.on_step_end(self.get_trials())
with warn_if_slow("callbacks.on_step_end"):
self._callbacks.on_step_end(iteration=self._iteration, trials=self._trials)
self._reconcile_live_trials()
def _on_pg_ready(self, next_trial: Optional[Trial]):
def _start_trial(trial: Trial) -> bool:
"""Helper function to start trial and call callbacks"""
with warn_if_slow("start_trial"):
if self.trial_executor.start_trial(trial):
self._callbacks.on_trial_start(
iteration=self._iteration, trials=self._trials, trial=trial
)
return True
return False
assert next_trial is not None
logger.info(f"starting {next_trial}")
if not _start_trial(next_trial) and next_trial.status != Trial.ERROR:
# Only try to start another trial if previous trial startup
# did not error (e.g. it just didn't start because its
# placement group is not ready, yet).
# Without this clause, this test fails:
# test_trial_runner_pg.py::
# TrialRunnerPlacementGroupHeterogeneousTest::
# testResourceDeadlock
next_trial = self.trial_executor.get_staged_trial()
if next_trial is not None:
# Must be able to start.
assert _start_trial(next_trial)
else:
logger.info(f"reconciling {self.get_trials()}")
self.trial_executor._pg_manager.reconcile_placement_groups(
self.get_trials()
)
def _on_saving_result(self, trial, result):
with warn_if_slow("process_trial_save") as _profile:
self._process_trial_save(trial, result)
with warn_if_slow("callbacks.on_trial_save"):
self._callbacks.on_trial_save(
iteration=self._iteration, trials=self._trials, trial=trial
)
if _profile.too_slow and trial.sync_on_checkpoint:
# TODO(ujvl): Suggest using cloud checkpointing once
# API has converged.
msg = (
"Consider turning off forced head-worker trial "
"checkpoint syncs by setting sync_on_checkpoint=False"
". Note that this may result in faulty trial "
"restoration if a failure occurs while the checkpoint "
"is being synced from the worker to the head node."
)
if trial.location.hostname and (
trial.location.hostname != get_node_ip_address()
):
if log_once("tune_head_worker_checkpoint"):
logger.warning(msg)
def _on_restoring_result(self, trial):
with warn_if_slow("process_trial_restore"):
self._process_trial_restore(trial)
with warn_if_slow("callbacks.on_trial_restore"):
self._callbacks.on_trial_restore(
iteration=self._iteration, trials=self._trials, trial=trial
)
def _on_training_result(self, trial, result):
if not isinstance(result, list):
result = [result]
with warn_if_slow("process_trial_result"):
self._process_trial_results(trial, result)
def _post_process_on_training_saving_result(self, trial):
# `self._queued_trial_decisions` now contains a final decision
# based on all results
if trial not in self._cached_trial_decisions:
final_decision = self._queued_trial_decisions.pop(trial.trial_id, None)
if final_decision:
self._execute_action(trial, final_decision)
def _on_executor_error(self, trial, result):
error_msg = f"Trial {trial}: Error processing event."
if self._fail_fast == TrialRunner.RAISE:
logger.error(error_msg)
raise
else:
logger.exception(error_msg)
self._process_trial_failure(trial, result)
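# For reference, a summary of how _wait_and_handle_event above dispatches each
# ExecutorEventType (derived from the code in this file):
#   PG_READY                 -> _on_pg_ready: start next_trial, or a staged trial.
#   NO_RUNNING_TRIAL_TIMEOUT -> _insufficient_resources_manager.on_no_available_trials.
#   YIELD                    -> no-op; return to step() so status can be reported.
#   TRAINING_RESULT          -> _on_training_result, then _post_process_on_training_saving_result.
#   SAVING_RESULT            -> _on_saving_result, then _post_process_on_training_saving_result.
#   RESTORING_RESULT         -> _on_restoring_result.
#   ERROR                    -> _on_executor_error.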
def get_trial(self, tid):
trial = [t for t in self._trials if t.trial_id == tid]
@ -857,75 +953,8 @@ class TrialRunner:
logger.debug("Running trial {}".format(trial))
return trial
def _process_events(self, timeout: Optional[float] = None):
# TODO(ujvl): Consider combining get_next_available_trial and
# fetch_result functionality so that we don't timeout on fetch.
trial = self.trial_executor.get_next_available_trial(
timeout=timeout
) # blocking
if not trial:
return
if trial.is_restoring:
with warn_if_slow("process_trial_restore"):
self._process_trial_restore(trial)
with warn_if_slow("callbacks.on_trial_restore"):
self._callbacks.on_trial_restore(
iteration=self._iteration, trials=self._trials, trial=trial
)
elif trial.is_saving:
with warn_if_slow("process_trial_save") as _profile:
self._process_trial_save(trial)
with warn_if_slow("callbacks.on_trial_save"):
self._callbacks.on_trial_save(
iteration=self._iteration, trials=self._trials, trial=trial
)
if _profile.too_slow and trial.sync_on_checkpoint:
# TODO(ujvl): Suggest using cloud checkpointing once
# API has converged.
msg = (
"Consider turning off forced head-worker trial "
"checkpoint syncs by setting sync_on_checkpoint=False"
". Note that this may result in faulty trial "
"restoration if a failure occurs while the checkpoint "
"is being synced from the worker to the head node."
)
if trial.location.hostname and (
trial.location.hostname != get_node_ip_address()
):
if log_once("tune_head_worker_checkpoint"):
logger.warning(msg)
else:
with warn_if_slow("process_trial"):
self._process_trial(trial)
# `self._queued_trial_decisions` now contains a final decision
# based on all results
if trial not in self._cached_trial_decisions:
final_decision = self._queued_trial_decisions.pop(trial.trial_id, None)
if final_decision:
self._execute_action(trial, final_decision)
def _process_trial(self, trial):
"""Processes a trial result.
Fetches the trial's latest result and makes a scheduling decision
regarding its next action. If a checkpoint is taken, the decided
action is cached and acted on only after the checkpoint is later
processed (see `_process_trial_save`). Otherwise the decision is
acted on immediately.
If multiple results are received (e.g. because of buffering), all
results are processed and the final action is determined. STOP
takes precedence over PAUSE, which takes precedence over CONTINUE.
Args:
trial (Trial): Trial with a result ready to be processed.
"""
try:
results = self.trial_executor.fetch_result(trial)
def _process_trial_results(self, trial, results):
logger.debug(f"process_trial_results {results}")
with warn_if_slow(
"process_trial_results",
message="Processing trial results took {duration:.3f} s, "
@ -955,14 +984,6 @@ class TrialRunner:
# If the decision is to stop the trial,
# ignore all results that came after that.
break
except Exception:
error_msg = "Trial %s: Error processing event." % trial
if self._fail_fast == TrialRunner.RAISE:
logger.error(error_msg)
raise
else:
logger.exception(error_msg)
self._process_trial_failure(trial, traceback.format_exc())
def _process_trial_result(self, trial, result):
result.update(trial_id=trial.trial_id)
@ -1014,6 +1035,7 @@ class TrialRunner:
self._checkpoint_trial_if_needed(trial, force=force_checkpoint)
if trial.is_saving:
logger.info(f"caching trial decision {trial}")
# Cache decision to execute on after the save is processed.
# This prevents changing the trial's state or kicking off
# another training step prematurely.
@ -1085,7 +1107,7 @@ class TrialRunner:
)
)
def _process_trial_save(self, trial):
def _process_trial_save(self, trial, result):
"""Processes a trial save.
Acts on the decision cached during the last `_process_trial_result` call.
@ -1094,19 +1116,9 @@ class TrialRunner:
trial (Trial): Trial being saved.
"""
logger.debug("Trial %s: Processing trial save.", trial)
checkpoint_value = None
try:
results = self.trial_executor.fetch_result(trial)
checkpoint_value = results[-1]
except Exception:
logger.exception("Trial %s: Error processing result.", trial)
if self._fail_fast == TrialRunner.RAISE:
raise
self._process_trial_failure(trial, traceback.format_exc())
if checkpoint_value:
try:
trial.saving_to.value = checkpoint_value
trial.saving_to.value = result
self._callbacks.on_checkpoint(
iteration=self._iteration,
trials=self._trials,
@ -1117,15 +1129,13 @@ class TrialRunner:
if trial.checkpoint.storage != Checkpoint.MEMORY:
self.trial_executor.mark_trial_to_checkpoint(trial)
except Exception:
logger.exception(
"Trial %s: Error handling checkpoint %s", trial, checkpoint_value
)
logger.exception("Trial %s: Error handling checkpoint %s", trial, result)
if self._fail_fast == TrialRunner.RAISE:
raise
trial.saving_to = None
decision = self._cached_trial_decisions.pop(trial.trial_id, None)
if decision and checkpoint_value:
if decision and result:
self._queue_decision(trial, decision)
def _process_trial_restore(self, trial):
@ -1135,18 +1145,11 @@ class TrialRunner:
trial (Trial): Trial being restored.
"""
logger.debug("Trial %s: Processing trial restore.", trial)
try:
self.trial_executor.fetch_result(trial)
trial.on_restore()
logger.debug("Trial %s: Restore processed successfully", trial)
self.trial_executor.set_status(trial, Trial.RUNNING)
self.trial_executor.continue_training(trial)
self._live_trials.add(trial)
except Exception:
logger.exception("Trial %s: Error processing restore.", trial)
if self._fail_fast == TrialRunner.RAISE:
raise
self._process_trial_failure(trial, traceback.format_exc())
def _process_trial_failure(self, trial, error_msg):
"""Handle trial failure.
@ -1217,6 +1220,10 @@ class TrialRunner:
trial (Trial): Trial to recover.
error_msg (str): Error message from prior to invoking this method.
"""
self._cached_trial_decisions.pop(trial.trial_id, None)
# Resetting this, in case the trial is in saving status when it crashes.
if trial.is_saving:
trial.saving_to = None
if trial.is_restoring:
# Restore was unsuccessful, try again without checkpoint.
trial.clear_checkpoint()
@ -1229,6 +1236,8 @@ class TrialRunner:
"Trial %s: Attempting to restore " "trial state from last checkpoint.",
trial,
)
# TODO(xwjiang): For better consistency, consider not starting
# trials here. Instead rely on requeuing the trial.
started = self.trial_executor.start_trial(trial)
if not started:
requeue_trial = True

View file

@ -306,6 +306,16 @@ def run(
"removing this argument from your call to `tune.run()`"
)
# Starting deprecation in ray 1.10.
if os.environ.get("TUNE_TRIAL_RESULT_WAIT_TIME_S") is not None:
warnings.warn("`TUNE_TRIAL_RESULT_WAIT_TIME_S` is deprecated.")
if os.environ.get("TUNE_TRIAL_STARTUP_GRACE_PERIOD") is not None:
warnings.warn("`TUNE_TRIAL_STARTUP_GRACE_PERIOD` is deprecated.")
if os.environ.get("TUNE_PLACEMENT_GROUP_WAIT_S") is not None:
warnings.warn("`TUNE_PLACEMENT_GROUP_WAIT_S` is deprecated.")
# NO CODE IS TO BE ADDED ABOVE THIS COMMENT
# remote_run_kwargs must be defined before any other
# code is ran to ensure that at this point,

View file

@ -316,16 +316,16 @@ class PlacementGroupManager:
self._cached_pgs: Dict[PlacementGroup, PlacementGroupFactory] = {}
# Placement groups scheduled for delayed removal.
# This is used as a damper to filter out high-frequency changes in
# resource requests.
# Only PGs that have never been used go here.
# TODO(xwjiang): `self._pgs_for_removal` and `self._unstaged_xxx`
# are really the same now. We should consolidate to using one.
# Also `remove_placement_group` method should just be combined with
# `unstage_unused_xxx`.
self._pgs_for_removal: Dict[PlacementGroup, float] = {}
self._removal_delay = TUNE_PLACEMENT_GROUP_REMOVAL_DELAY
# Latest PG staging time to check if still in grace period.
self._latest_staging_start_time = time.time()
# Seconds we wait for a trial to come up before we make blocking calls
# to process events
self._grace_period = float(os.getenv("TUNE_TRIAL_STARTUP_GRACE_PERIOD", 10.0))
self._max_staging = max_staging
def set_max_staging(self, max_staging: int):
@ -440,7 +440,6 @@ class PlacementGroupManager:
self._staging[pgf].add(pg)
self._staging_futures[pg.ready()] = (pgf, pg)
self._latest_staging_start_time = time.time()
return True
@ -461,11 +460,17 @@ class PlacementGroupManager:
ready, _ = ray.wait(list(self._staging_futures.keys()), timeout=0)
for ready_fut in ready:
self.handle_ready_future(ready_fut)
def handle_ready_future(self, ready_fut):
ready_pgf, ready_pg = self._staging_futures.pop(ready_fut)
self._staging[ready_pgf].remove(ready_pg)
self._ready[ready_pgf].add(ready_pg)
def get_staging_future_list(self):
return list(self._staging_futures.keys())
def get_full_actor_cls(
self, trial: "Trial", actor_cls: ActorClass
) -> Optional[ActorClass]:
@ -602,7 +607,7 @@ class PlacementGroupManager:
def clean_cached_pg(self, pg: PlacementGroup):
self._cached_pgs.pop(pg)
def return_pg(self, trial: "Trial"):
def remove_from_in_use(self, trial: "Trial") -> PlacementGroup:
"""Return pg back to Core scheduling.
Args:
@ -612,7 +617,7 @@ class PlacementGroupManager:
pg = self._in_use_trials.pop(trial)
self._in_use_pgs.pop(pg)
self.remove_pg(pg)
return pg
def _unstage_unused_pg(
self, pgf: PlacementGroupFactory
@ -664,13 +669,6 @@ class PlacementGroupManager:
return trial_pg
def in_staging_grace_period(self):
return (
self._staging_futures
and self._grace_period
and time.time() <= self._latest_staging_start_time + self._grace_period
)
def reconcile_placement_groups(self, trials: List["Trial"]):
"""Reconcile placement groups to match requirements.