mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[tune] Single wait refactor. (#21852)
This is a down-scoped change. For the full picture of the Tune control loop, see [`Tune control loop refactoring`](https://docs.google.com/document/d/1RDsW7SVzwMPZfA0WLOPA4YTqbRyXIHGYmBenJk33HaE/edit#heading=h.2za3bbxbs5gn).

1. Previously there were separate waits on placement group readiness and on other events. As a result, there were quite a few timing tweaks that were inefficient, hard to understand, and hard to unit test. This PR consolidates them into a single wait that TrialRunner handles in each step.
   - A few event types are introduced, along with their mapping to scenarios:
     * PG_READY --> A trial should be placed onto the ready placement group. If somehow there is no trial to place, the pg is returned to `_ready` momentarily. This is because resources have historically been conceptualized as a pull-based model.
     * NO_RUNNING_TRIAL_TIMEOUT --> likely an insufficient-resources case.
     * TRAINING_RESULT
     * SAVING_RESULT
     * RESTORING_RESULT
     * YIELD --> Training is simply taking a long time; we need to punt back to the main loop to print status info etc.
2. Previously, trial cleanup was not very efficient and could race between `Trainable.stop()` and `return_placement_group`. This PR streamlines the trial cleanup process by explicitly letting `Trainable.stop()` finish, followed by `return_placement_group(pg)`. Note that graceful shutdown is still needed in cases like `pause_trial`, where checkpointing to memory needs to be given time to happen before the actor is gone.
3. Quite a few environment variables (timing tweaks) are removed; I consider it OK to proceed without a deprecation cycle.
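To make point 1 concrete, here is a minimal sketch of the consolidated control loop: TrialRunner blocks on a single `get_next_executor_event()` call and dispatches on the returned event type. Only `get_next_executor_event` and `ExecutorEventType` come from this PR; the runner helper methods shown are illustrative placeholders, not the actual TrialRunner API.

```python
from ray.tune.ray_trial_executor import ExecutorEventType


def run_one_step(runner, executor):
    """One TrialRunner step: a single blocking wait, then dispatch."""
    event = executor.get_next_executor_event(
        live_trials=runner.get_live_trials(),
        next_trial_exists=runner.has_next_trial(),  # hypothetical helper
    )
    if event.type == ExecutorEventType.PG_READY:
        runner.start_next_trial()                # place a trial on the ready pg
    elif event.type == ExecutorEventType.NO_RUNNING_TRIAL_TIMEOUT:
        runner.warn_if_insufficient_resources()  # likely not enough resources
    elif event.type == ExecutorEventType.YIELD:
        pass                                     # fall through to print status info
    elif event.type == ExecutorEventType.TRAINING_RESULT:
        runner.process_training_result(event.trial, event.result)
    elif event.type == ExecutorEventType.SAVING_RESULT:
        runner.process_saving_result(event.trial, event.result)
    elif event.type == ExecutorEventType.RESTORING_RESULT:
        runner.process_restoring_result(event.trial, event.result)
    elif event.type == ExecutorEventType.ERROR:
        runner.handle_error(event.trial, event.result)  # result is a traceback string
```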
This commit is contained in:
parent dea3574050
commit 323511b716
20 changed files with 825 additions and 802 deletions
@@ -36,6 +36,8 @@ These are the environment variables Ray Tune currently considers:
   letting them finish the current training step and any user-defined cleanup.
   Setting this variable to a non-zero, positive integer will cause trials to be forcefully
   terminated after a grace period of that many seconds. Defaults to ``0``.
+* **TUNE_GET_EXECUTOR_EVENT_WAIT_S**: The time that TrialRunner waits for the
+  next ExecutorEvent in a blocking fashion. Defaults to ``5``.
 * **TUNE_FUNCTION_THREAD_TIMEOUT_S**: Time in seconds the function API waits
   for threads to finish after instructing them to complete. Defaults to ``2``.
 * **TUNE_GLOBAL_CHECKPOINT_S**: Time in seconds that limits how often Tune's
@@ -57,10 +59,6 @@ These are the environment variables Ray Tune currently considers:
   In normal circumstances these shouldn't differ anyway, but reconcilation makes sure to capture cases when
   placement groups are manually destroyed. Reconcilation doesn't take much time, but it can add up when
   running a large number of short trials. Defaults to every ``5`` (seconds).
-* **TUNE_PLACEMENT_GROUP_WAIT_S**: Default time the trial executor waits for placement
-  groups to be placed before continuing the tuning loop. Setting this to a float
-  will block for that many seconds. This is mostly used for testing purposes. Defaults
-  to -1, which disables blocking.
 * **TUNE_RESULT_DIR**: Directory where Ray Tune trial results are stored. If this
   is not set, ``~/ray_results`` will be used.
 * **TUNE_RESULT_BUFFER_LENGTH**: Ray Tune can buffer results from trainables before they are passed
@@ -74,11 +72,6 @@ These are the environment variables Ray Tune currently considers:
   but never longer than this value. Defaults to 100 (seconds).
 * **TUNE_RESULT_BUFFER_MIN_TIME_S**: Additionally, you can specify a minimum time to buffer results. Defaults to 0.
 * **TUNE_SYNCER_VERBOSITY**: Amount of command output when using Tune with Docker Syncer. Defaults to 0.
-* **TUNE_TRIAL_RESULT_WAIT_TIME_S**: Amount of time Ray Tune will block until a result from a running trial is received.
-  Defaults to 1 (second).
-* **TUNE_TRIAL_STARTUP_GRACE_PERIOD**: Amount of time after starting a trial that Ray Tune checks for successful
-  trial startups. After the grace period, Tune will block for up to ``TUNE_TRIAL_RESULT_WAIT_TIME_S`` seconds
-  until a result from a running trial is received. Can be disabled by setting this to lower or equal to 0.
 * **TUNE_WARN_THRESHOLD_S**: Threshold for logging if an Tune event loop operation takes too long. Defaults to 0.5 (seconds).
 * **TUNE_WARN_INSUFFICENT_RESOURCE_THRESHOLD_S**: Threshold for throwing a warning if no active trials are in ``RUNNING`` state
   for this amount of seconds. If the Ray Tune job is stuck in this state (most likely due to insufficient resources),
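As a hedged illustration of the remaining knobs: the example below sets the new `TUNE_GET_EXECUTOR_EVENT_WAIT_S` together with the pre-existing `TUNE_FORCE_TRIAL_CLEANUP_S` before starting a run. The values and the toy trainable are arbitrary; the only requirement is that the variables are set before Tune constructs its trial executor.

```python
# Illustrative values only; set these before Tune constructs its trial executor.
import os

os.environ["TUNE_GET_EXECUTOR_EVENT_WAIT_S"] = "10"  # block up to 10 s per executor wait (new)
os.environ["TUNE_FORCE_TRIAL_CLEANUP_S"] = "30"      # force cleanup 30 s after stop() is issued

from ray import tune  # imported after the env vars are set on purpose


def trainable(config):
    tune.report(score=config["x"] ** 2)


tune.run(trainable, config={"x": tune.grid_search([1, 2, 3])})
```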
@@ -1,11 +1,12 @@
 # coding: utf-8
 import copy
 import inspect
+import random
 from collections import deque
+from enum import Enum
 from functools import partial
 import logging
 import os
-import random
 import time
 import traceback
 from contextlib import contextmanager
@@ -15,14 +16,15 @@ from typing import (
     Iterable,
     List,
     Optional,
+    Union,
+    Set,
 )

 import ray
-from ray.actor import ActorHandle
 from ray.exceptions import GetTimeoutError
 from ray import ray_constants
 from ray._private.resource_spec import NODE_ID_PREFIX
-from ray.tune.error import AbortTrialExecution
+from ray.tune.error import AbortTrialExecution, TuneError
 from ray.tune.logger import NoopLogger
 from ray.tune.result import TRIAL_INFO, STDOUT_FILE, STDERR_FILE
 from ray.tune.resources import Resources
@@ -33,14 +35,12 @@ from ray.tune.trial_executor import TrialExecutor
 from ray.tune.utils import warn_if_slow
 from ray.util import log_once
 from ray.util.annotations import DeveloperAPI
+from ray.util.placement_group import remove_placement_group, PlacementGroup

 logger = logging.getLogger(__name__)

 TUNE_STATE_REFRESH_PERIOD = 10  # Refresh resources every 10 s
-BOTTLENECK_WARN_PERIOD_S = 60
-NONTRIVIAL_WAIT_TIME_THRESHOLD_S = 1e-3
 DEFAULT_GET_TIMEOUT = 60.0  # seconds
-TRIAL_CLEANUP_THRESHOLD = 100


 class _ActorClassCache:
@ -86,75 +86,56 @@ class _LocalWrapper:
|
||||||
return self._result
|
return self._result
|
||||||
|
|
||||||
|
|
||||||
|
def post_stop_cleanup(future, pg):
|
||||||
|
"""Things to be done after a trial is stopped."""
|
||||||
|
assert isinstance(pg, PlacementGroup)
|
||||||
|
try:
|
||||||
|
# This should not be blocking as
|
||||||
|
# we are only here when triggered.
|
||||||
|
ray.get(future, timeout=0)
|
||||||
|
except GetTimeoutError:
|
||||||
|
if log_once("tune_trial_cleanup_timeout"):
|
||||||
|
logger.error(
|
||||||
|
"Timed out when trying to stop the Ray actor gracefully. "
|
||||||
|
"Consider making `stop` a faster operation."
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
if log_once("tune_trial_cleanup_exception"):
|
||||||
|
logger.error(
|
||||||
|
f"An exception occurred when trying to stop the Ray actor:"
|
||||||
|
f"{traceback.format_exc()}"
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
remove_placement_group(pg)
|
||||||
|
|
||||||
|
|
||||||
class _TrialCleanup:
|
class _TrialCleanup:
|
||||||
"""Mechanism for ensuring trial stop futures are cleaned up.
|
"""Responsible for triggering force cleanup of remote actors,
|
||||||
|
without waiting for `Trainable.stop()` to finish.
|
||||||
|
|
||||||
Args:
|
Only instantiated when `TUNE_FORCE_TRIAL_CLEANUP_S` is set up.
|
||||||
threshold (int): Number of futures to hold at once. If the threshold
|
|
||||||
is passed, cleanup will kick in and remove futures.
|
|
||||||
force_cleanup (int): Grace periods for forceful actor termination.
|
|
||||||
If 0, actors will not be forcefully terminated.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, force_cleanup):
|
||||||
self, threshold: int = TRIAL_CLEANUP_THRESHOLD, force_cleanup: int = 0
|
assert force_cleanup
|
||||||
):
|
|
||||||
self.threshold = threshold
|
|
||||||
self._cleanup_map = {}
|
|
||||||
if force_cleanup < 0:
|
|
||||||
force_cleanup = 0
|
|
||||||
self._force_cleanup = force_cleanup
|
self._force_cleanup = force_cleanup
|
||||||
|
self._future_to_insert_time = deque()
|
||||||
|
|
||||||
def add(self, trial: Trial, actor: ActorHandle):
|
def add(self, future):
|
||||||
"""Adds a trial actor to be stopped.
|
self._future_to_insert_time.append((future, time.time()))
|
||||||
|
|
||||||
If the number of futures exceeds the threshold, the cleanup mechanism
|
def get_next(self):
|
||||||
will kick in.
|
"""Get the next future that is eligible to be cleaned up forcibly."""
|
||||||
|
if (
|
||||||
Args:
|
len(self._future_to_insert_time) > 0
|
||||||
trial (Trial): The trial corresponding to the future.
|
and self._future_to_insert_time[0][1] + self._force_cleanup < time.time()
|
||||||
actor (ActorHandle): Handle to the trainable to be stopped.
|
):
|
||||||
"""
|
return self._future_to_insert_time.popleft()
|
||||||
future = actor.stop.remote()
|
|
||||||
|
|
||||||
del actor
|
|
||||||
|
|
||||||
self._cleanup_map[future] = trial
|
|
||||||
if len(self._cleanup_map) > self.threshold:
|
|
||||||
self.cleanup(partial=True)
|
|
||||||
|
|
||||||
def cleanup(self, partial: bool = True):
|
|
||||||
"""Waits for cleanup to finish.
|
|
||||||
|
|
||||||
If partial=False, all futures are expected to return. If a future
|
|
||||||
does not return within the timeout period, the cleanup terminates.
|
|
||||||
"""
|
|
||||||
# At this point, self._cleanup_map holds the last references
|
|
||||||
# to actors. Removing those references either one-by-one
|
|
||||||
# (graceful termination case) or all at once, by reinstantiating
|
|
||||||
# self._cleanup_map (forceful termination case) will cause Ray
|
|
||||||
# to kill the actors during garbage collection.
|
|
||||||
logger.debug("Cleaning up futures")
|
|
||||||
num_to_keep = int(self.threshold) / 2 if partial else 0
|
|
||||||
while len(self._cleanup_map) > num_to_keep:
|
|
||||||
dones, _ = ray.wait(
|
|
||||||
list(self._cleanup_map),
|
|
||||||
timeout=DEFAULT_GET_TIMEOUT
|
|
||||||
if not self._force_cleanup
|
|
||||||
else self._force_cleanup,
|
|
||||||
)
|
|
||||||
if not dones:
|
|
||||||
logger.warning(
|
|
||||||
"Skipping cleanup - trainable.stop did not return in "
|
|
||||||
"time. Consider making `stop` a faster operation."
|
|
||||||
)
|
|
||||||
if not partial and self._force_cleanup:
|
|
||||||
logger.warning("Forcing trainable cleanup by terminating actors.")
|
|
||||||
self._cleanup_map = {}
|
|
||||||
return
|
|
||||||
else:
|
else:
|
||||||
done = dones[0]
|
return None
|
||||||
del self._cleanup_map[done]
|
|
||||||
|
def is_empty(self):
|
||||||
|
return len(self._future_to_insert_time) == 0
|
||||||
|
|
||||||
|
|
||||||
def noop_logger_creator(config, logdir):
|
def noop_logger_creator(config, logdir):
|
||||||
|
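The rewritten `_TrialCleanup` in the hunk above is essentially a deadline queue: each `Trainable.stop()` future is enqueued with its insertion time, and `get_next()` only releases a future once the `TUNE_FORCE_TRIAL_CLEANUP_S` grace period has elapsed, at which point the executor forcibly runs `post_stop_cleanup()` (i.e. removes the placement group). A minimal, Ray-independent sketch of that idea, with illustrative names:

```python
import time
from collections import deque


class DeadlineQueue:
    """Release items only after they have waited at least `grace_period_s`."""

    def __init__(self, grace_period_s: float):
        assert grace_period_s > 0
        self._grace_period_s = grace_period_s
        self._items = deque()  # (item, insert_time) pairs, oldest first

    def add(self, item) -> None:
        self._items.append((item, time.time()))

    def get_next(self):
        """Return the oldest overdue (item, insert_time) pair, or None."""
        if self._items and self._items[0][1] + self._grace_period_s < time.time():
            return self._items.popleft()
        return None

    def is_empty(self) -> bool:
        return len(self._items) == 0
```

In the executor this polling happens in `on_step_end()` via `_do_force_trial_cleanup()`, as shown further down in the diff.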
@ -165,6 +146,39 @@ def noop_logger_creator(config, logdir):
|
||||||
return NoopLogger(config, logdir)
|
return NoopLogger(config, logdir)
|
||||||
|
|
||||||
|
|
||||||
|
class ExecutorEventType(Enum):
|
||||||
|
"""The executor event type.
|
||||||
|
|
||||||
|
Some of the events are internal events to executor while others
|
||||||
|
are handled by runner."""
|
||||||
|
|
||||||
|
NO_RUNNING_TRIAL_TIMEOUT = 1
|
||||||
|
PG_READY = 2
|
||||||
|
TRAINING_RESULT = 3
|
||||||
|
SAVING_RESULT = 4
|
||||||
|
RESTORING_RESULT = 5
|
||||||
|
STOP_RESULT = 6 # Internally to executor only.
|
||||||
|
ERROR = 7 # This is to signal to TrialRunner that there is an error.
|
||||||
|
YIELD = 8 # Yielding back to TrialRunner's main event loop.
|
||||||
|
|
||||||
|
|
||||||
|
class ExecutorEvent:
|
||||||
|
"""A struct that describes the event to be processed by TrialRunner."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
event_type: ExecutorEventType,
|
||||||
|
trial: Optional[Trial] = None,
|
||||||
|
result: Optional[Union[str, Dict]] = None,
|
||||||
|
):
|
||||||
|
self.type = event_type
|
||||||
|
self.trial = trial
|
||||||
|
self.result = result
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"[{self.type}] for {self.trial}"
|
||||||
|
|
||||||
|
|
||||||
@DeveloperAPI
|
@DeveloperAPI
|
||||||
class RayTrialExecutor(TrialExecutor):
|
class RayTrialExecutor(TrialExecutor):
|
||||||
"""An implementation of TrialExecutor based on Ray."""
|
"""An implementation of TrialExecutor based on Ray."""
|
||||||
|
@ -177,10 +191,17 @@ class RayTrialExecutor(TrialExecutor):
|
||||||
wait_for_placement_group: Optional[float] = None,
|
wait_for_placement_group: Optional[float] = None,
|
||||||
):
|
):
|
||||||
super(RayTrialExecutor, self).__init__()
|
super(RayTrialExecutor, self).__init__()
|
||||||
self._running = {}
|
# future --> (type, trial/pg)
|
||||||
|
self._futures = {}
|
||||||
|
|
||||||
force_trial_cleanup = int(os.environ.get("TUNE_FORCE_TRIAL_CLEANUP_S", "0"))
|
force_trial_cleanup = int(os.environ.get("TUNE_FORCE_TRIAL_CLEANUP_S", "0"))
|
||||||
self._trial_cleanup = _TrialCleanup(force_cleanup=force_trial_cleanup)
|
self._get_next_event_wait = int(
|
||||||
|
os.environ.get("TUNE_GET_EXECUTOR_EVENT_WAIT_S", "5")
|
||||||
|
)
|
||||||
|
if force_trial_cleanup:
|
||||||
|
self._trial_cleanup = _TrialCleanup(force_trial_cleanup)
|
||||||
|
else:
|
||||||
|
self._trial_cleanup = None
|
||||||
self._has_cleaned_up_pgs = False
|
self._has_cleaned_up_pgs = False
|
||||||
self._reuse_actors = reuse_actors
|
self._reuse_actors = reuse_actors
|
||||||
# The maxlen will be updated when `set_max_pending_trials()` is called
|
# The maxlen will be updated when `set_max_pending_trials()` is called
|
||||||
|
@ -189,7 +210,6 @@ class RayTrialExecutor(TrialExecutor):
|
||||||
self._avail_resources = Resources(cpu=0, gpu=0)
|
self._avail_resources = Resources(cpu=0, gpu=0)
|
||||||
self._pg_manager = PlacementGroupManager(prefix=get_tune_pg_prefix())
|
self._pg_manager = PlacementGroupManager(prefix=get_tune_pg_prefix())
|
||||||
self._staged_trials = set()
|
self._staged_trials = set()
|
||||||
self._just_staged_trials = set()
|
|
||||||
self._trial_just_finished = False
|
self._trial_just_finished = False
|
||||||
self._trial_just_finished_before = False
|
self._trial_just_finished_before = False
|
||||||
|
|
||||||
|
@ -201,12 +221,6 @@ class RayTrialExecutor(TrialExecutor):
|
||||||
)
|
)
|
||||||
self._refresh_period = refresh_period
|
self._refresh_period = refresh_period
|
||||||
|
|
||||||
self._wait_for_pg = wait_for_placement_group or float(
|
|
||||||
os.environ.get("TUNE_PLACEMENT_GROUP_WAIT_S", "-1")
|
|
||||||
)
|
|
||||||
if self._wait_for_pg < 0:
|
|
||||||
self._wait_for_pg = None
|
|
||||||
|
|
||||||
self.last_pg_recon = 0
|
self.last_pg_recon = 0
|
||||||
self.pg_recon_interval = float(
|
self.pg_recon_interval = float(
|
||||||
os.environ.get("TUNE_PLACEMENT_GROUP_RECON_INTERVAL", "5")
|
os.environ.get("TUNE_PLACEMENT_GROUP_RECON_INTERVAL", "5")
|
||||||
|
@ -229,10 +243,6 @@ class RayTrialExecutor(TrialExecutor):
|
||||||
if ray.is_initialized():
|
if ray.is_initialized():
|
||||||
self._update_avail_resources()
|
self._update_avail_resources()
|
||||||
|
|
||||||
def in_staging_grace_period(self) -> bool:
|
|
||||||
"""Returns True if trials have recently been staged."""
|
|
||||||
return self._pg_manager.in_staging_grace_period()
|
|
||||||
|
|
||||||
def set_max_pending_trials(self, max_pending: int) -> None:
|
def set_max_pending_trials(self, max_pending: int) -> None:
|
||||||
if len(self._cached_actor_pg) > 0:
|
if len(self._cached_actor_pg) > 0:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
@ -243,7 +253,7 @@ class RayTrialExecutor(TrialExecutor):
|
||||||
self._cached_actor_pg = deque(maxlen=max_pending)
|
self._cached_actor_pg = deque(maxlen=max_pending)
|
||||||
self._pg_manager.set_max_staging(max_pending)
|
self._pg_manager.set_max_staging(max_pending)
|
||||||
|
|
||||||
def stage_and_update_status(self, trials: Iterable[Trial]):
|
def _stage_and_update_status(self, trials: Iterable[Trial]):
|
||||||
"""Check and update statuses of scheduled placement groups.
|
"""Check and update statuses of scheduled placement groups.
|
||||||
|
|
||||||
Stages placement groups of all trials.
|
Stages placement groups of all trials.
|
||||||
|
@ -255,7 +265,7 @@ class RayTrialExecutor(TrialExecutor):
|
||||||
self._has_cleaned_up_pgs = True
|
self._has_cleaned_up_pgs = True
|
||||||
|
|
||||||
for trial in trials:
|
for trial in trials:
|
||||||
if trial.status != Trial.PENDING:
|
if trial.status not in (Trial.PENDING, Trial.PAUSED):
|
||||||
continue
|
continue
|
||||||
if trial in self._staged_trials:
|
if trial in self._staged_trials:
|
||||||
continue
|
continue
|
||||||
|
@ -266,7 +276,6 @@ class RayTrialExecutor(TrialExecutor):
|
||||||
# Break if we reached the limit of pending placement groups.
|
# Break if we reached the limit of pending placement groups.
|
||||||
break
|
break
|
||||||
self._staged_trials.add(trial)
|
self._staged_trials.add(trial)
|
||||||
self._just_staged_trials.add(trial)
|
|
||||||
|
|
||||||
self._pg_manager.update_status()
|
self._pg_manager.update_status()
|
||||||
|
|
||||||
|
@ -279,6 +288,7 @@ class RayTrialExecutor(TrialExecutor):
|
||||||
Trial object or None.
|
Trial object or None.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
# TODO(xwjiang): This method should consider `self._cached_actor_pg`.
|
||||||
for trial in self._staged_trials:
|
for trial in self._staged_trials:
|
||||||
if self._pg_manager.has_ready(trial):
|
if self._pg_manager.has_ready(trial):
|
||||||
return trial
|
return trial
|
||||||
|
@ -317,43 +327,7 @@ class RayTrialExecutor(TrialExecutor):
|
||||||
)
|
)
|
||||||
_actor_cls = _class_cache.get(trainable_cls)
|
_actor_cls = _class_cache.get(trainable_cls)
|
||||||
|
|
||||||
if not self._pg_manager.has_ready(trial, update=True):
|
|
||||||
if trial not in self._staged_trials:
|
|
||||||
if self._pg_manager.stage_trial_pg(trial):
|
|
||||||
self._staged_trials.add(trial)
|
|
||||||
self._just_staged_trials.add(trial)
|
|
||||||
|
|
||||||
just_staged = trial in self._just_staged_trials
|
|
||||||
|
|
||||||
# This part of the code is mostly here for testing
|
|
||||||
# purposes. If self._wait_for_pg is set, we will wait here
|
|
||||||
# for that many seconds until the placement group is ready.
|
|
||||||
# This ensures that the trial can be started right away and
|
|
||||||
# not just in the next step() of the trial runner.
|
|
||||||
# We only do this if we have reason to believe that resources
|
|
||||||
# will be ready, soon, i.e. when a) we just staged the PG,
|
|
||||||
# b) another trial just exited, freeing resources, or c)
|
|
||||||
# when there are no currently running trials.
|
|
||||||
if self._wait_for_pg is not None and (
|
|
||||||
just_staged
|
|
||||||
or self._trial_just_finished_before
|
|
||||||
or not self.get_running_trials()
|
|
||||||
):
|
|
||||||
logger.debug(
|
|
||||||
f"Waiting up to {self._wait_for_pg} seconds for "
|
|
||||||
f"placement group of trial {trial} to become ready."
|
|
||||||
)
|
|
||||||
wait_end = time.monotonic() + self._wait_for_pg
|
|
||||||
while time.monotonic() < wait_end:
|
|
||||||
self._pg_manager.update_status()
|
|
||||||
if self._pg_manager.has_ready(trial):
|
|
||||||
break
|
|
||||||
time.sleep(0.1)
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if not self._pg_manager.has_ready(trial):
|
if not self._pg_manager.has_ready(trial):
|
||||||
# PG may have become ready during waiting period
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
full_actor_class = self._pg_manager.get_full_actor_cls(trial, _actor_cls)
|
full_actor_class = self._pg_manager.get_full_actor_cls(trial, _actor_cls)
|
||||||
|
@ -403,7 +377,7 @@ class RayTrialExecutor(TrialExecutor):
|
||||||
def _train(self, trial):
|
def _train(self, trial):
|
||||||
"""Start one iteration of training and save remote id."""
|
"""Start one iteration of training and save remote id."""
|
||||||
|
|
||||||
if self._find_item(self._running, trial):
|
if self._find_future(trial):
|
||||||
logging.debug(
|
logging.debug(
|
||||||
"Trial {} already has a queued future. Skipping this "
|
"Trial {} already has a queued future. Skipping this "
|
||||||
"`train` call. This may occur if a trial has "
|
"`train` call. This may occur if a trial has "
|
||||||
|
@ -414,7 +388,7 @@ class RayTrialExecutor(TrialExecutor):
|
||||||
assert trial.status == Trial.RUNNING, trial.status
|
assert trial.status == Trial.RUNNING, trial.status
|
||||||
buffer_time_s = max(
|
buffer_time_s = max(
|
||||||
self._buffer_min_time_s,
|
self._buffer_min_time_s,
|
||||||
min(self._buffer_max_time_s, len(self._running) // 10),
|
min(self._buffer_max_time_s, len(self._futures) // 10),
|
||||||
)
|
)
|
||||||
with self._change_working_directory(trial):
|
with self._change_working_directory(trial):
|
||||||
buffer_length = self._buffer_length
|
buffer_length = self._buffer_length
|
||||||
|
@ -441,8 +415,8 @@ class RayTrialExecutor(TrialExecutor):
|
||||||
if isinstance(remote, dict):
|
if isinstance(remote, dict):
|
||||||
remote = _LocalWrapper(remote)
|
remote = _LocalWrapper(remote)
|
||||||
|
|
||||||
self._running[remote] = trial
|
self._futures[remote] = (ExecutorEventType.TRAINING_RESULT, trial)
|
||||||
trial_item = self._find_item(self._running, trial)
|
trial_item = self._find_future(trial)
|
||||||
assert len(trial_item) < 2, trial_item
|
assert len(trial_item) < 2, trial_item
|
||||||
|
|
||||||
def _start_trial(self, trial) -> bool:
|
def _start_trial(self, trial) -> bool:
|
||||||
|
@ -540,11 +514,13 @@ class RayTrialExecutor(TrialExecutor):
|
||||||
if should_destroy_actor:
|
if should_destroy_actor:
|
||||||
logger.debug("Trial %s: Destroying actor.", trial)
|
logger.debug("Trial %s: Destroying actor.", trial)
|
||||||
|
|
||||||
# Try to return the placement group for other trials to use
|
|
||||||
self._pg_manager.return_pg(trial)
|
|
||||||
|
|
||||||
with self._change_working_directory(trial):
|
with self._change_working_directory(trial):
|
||||||
self._trial_cleanup.add(trial, actor=trial.runner)
|
future = trial.runner.stop.remote()
|
||||||
|
|
||||||
|
pg = self._pg_manager.remove_from_in_use(trial)
|
||||||
|
self._futures[future] = (ExecutorEventType.STOP_RESULT, pg)
|
||||||
|
if self._trial_cleanup: # force trial cleanup within a deadline
|
||||||
|
self._trial_cleanup.add(future)
|
||||||
|
|
||||||
if trial in self._staged_trials:
|
if trial in self._staged_trials:
|
||||||
self._staged_trials.remove(trial)
|
self._staged_trials.remove(trial)
|
||||||
|
@ -584,8 +560,8 @@ class RayTrialExecutor(TrialExecutor):
|
||||||
# have been lost. TODO(ujvl): is this the right thing to do?
|
# have been lost. TODO(ujvl): is this the right thing to do?
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _find_item(self, dictionary, item):
|
def _find_future(self, trial):
|
||||||
out = [rid for rid, t in dictionary.items() if t is item]
|
out = [rid for rid, t in self._futures.items() if t[1] is trial]
|
||||||
assert (
|
assert (
|
||||||
len(out) <= 1
|
len(out) <= 1
|
||||||
), "Expecting one future for any given trial at any given time."
|
), "Expecting one future for any given trial at any given time."
|
||||||
|
@ -598,9 +574,9 @@ class RayTrialExecutor(TrialExecutor):
|
||||||
self._stop_trial(trial, error=error, error_msg=error_msg)
|
self._stop_trial(trial, error=error, error_msg=error_msg)
|
||||||
if prior_status == Trial.RUNNING:
|
if prior_status == Trial.RUNNING:
|
||||||
logger.debug("Trial %s: Returning resources.", trial)
|
logger.debug("Trial %s: Returning resources.", trial)
|
||||||
out = self._find_item(self._running, trial)
|
out = self._find_future(trial)
|
||||||
for result_id in out:
|
for result_id in out:
|
||||||
self._running.pop(result_id)
|
self._futures.pop(result_id)
|
||||||
|
|
||||||
def continue_training(self, trial: Trial) -> None:
|
def continue_training(self, trial: Trial) -> None:
|
||||||
"""Continues the training of this trial."""
|
"""Continues the training of this trial."""
|
||||||
|
@ -649,63 +625,6 @@ class RayTrialExecutor(TrialExecutor):
|
||||||
return False
|
return False
|
||||||
return reset_val
|
return reset_val
|
||||||
|
|
||||||
def get_running_trials(self) -> List[Trial]:
|
|
||||||
"""Returns the running trials."""
|
|
||||||
return list(self._running.values())
|
|
||||||
|
|
||||||
def get_next_available_trial(
|
|
||||||
self, timeout: Optional[float] = None
|
|
||||||
) -> Optional[Trial]:
|
|
||||||
if not self._running:
|
|
||||||
return None
|
|
||||||
shuffled_results = list(self._running.keys())
|
|
||||||
random.shuffle(shuffled_results)
|
|
||||||
|
|
||||||
# Note: We shuffle the results because `ray.wait` by default returns
|
|
||||||
# the first available result, and we want to guarantee that slower
|
|
||||||
# trials (i.e. trials that run remotely) also get fairly reported.
|
|
||||||
# See https://github.com/ray-project/ray/issues/4211 for details.
|
|
||||||
start = time.time()
|
|
||||||
ready, _ = ray.wait(shuffled_results, timeout=timeout)
|
|
||||||
if not ready:
|
|
||||||
return None
|
|
||||||
result_id = ready[0]
|
|
||||||
wait_time = time.time() - start
|
|
||||||
if wait_time > NONTRIVIAL_WAIT_TIME_THRESHOLD_S:
|
|
||||||
self._last_nontrivial_wait = time.time()
|
|
||||||
if time.time() - self._last_nontrivial_wait > BOTTLENECK_WARN_PERIOD_S:
|
|
||||||
logger.warning(
|
|
||||||
"Over the last {} seconds, the Tune event loop has been "
|
|
||||||
"backlogged processing new results. Consider increasing your "
|
|
||||||
"period of result reporting to improve performance.".format(
|
|
||||||
BOTTLENECK_WARN_PERIOD_S
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
self._last_nontrivial_wait = time.time()
|
|
||||||
return self._running[result_id]
|
|
||||||
|
|
||||||
def fetch_result(self, trial) -> List[Dict]:
|
|
||||||
"""Fetches result list of the running trials.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Result of the most recent trial training run.
|
|
||||||
"""
|
|
||||||
trial_future = self._find_item(self._running, trial)
|
|
||||||
if not trial_future:
|
|
||||||
raise ValueError("Trial was not running.")
|
|
||||||
self._running.pop(trial_future[0])
|
|
||||||
with warn_if_slow("fetch_result"):
|
|
||||||
result = ray.get(trial_future[0], timeout=DEFAULT_GET_TIMEOUT)
|
|
||||||
|
|
||||||
# For local mode
|
|
||||||
if isinstance(result, _LocalWrapper):
|
|
||||||
result = result.unwrap()
|
|
||||||
|
|
||||||
if not isinstance(result, list):
|
|
||||||
return [result]
|
|
||||||
return result
|
|
||||||
|
|
||||||
def _update_avail_resources(self, num_retries=5):
|
def _update_avail_resources(self, num_retries=5):
|
||||||
if time.time() - self._last_resource_refresh < self._refresh_period:
|
if time.time() - self._last_resource_refresh < self._refresh_period:
|
||||||
return
|
return
|
||||||
|
@ -814,8 +733,7 @@ class RayTrialExecutor(TrialExecutor):
|
||||||
self._trial_just_finished = False
|
self._trial_just_finished = False
|
||||||
|
|
||||||
def on_step_end(self, trials: List[Trial]) -> None:
|
def on_step_end(self, trials: List[Trial]) -> None:
|
||||||
self._just_staged_trials.clear()
|
self._do_force_trial_cleanup()
|
||||||
|
|
||||||
if time.time() > self.last_pg_recon + self.pg_recon_interval:
|
if time.time() > self.last_pg_recon + self.pg_recon_interval:
|
||||||
# Only do this every now and then - usually the placement groups
|
# Only do this every now and then - usually the placement groups
|
||||||
# should not get out of sync, and calling this often is inefficient
|
# should not get out of sync, and calling this often is inefficient
|
||||||
|
@ -824,6 +742,20 @@ class RayTrialExecutor(TrialExecutor):
|
||||||
|
|
||||||
self._pg_manager.cleanup()
|
self._pg_manager.cleanup()
|
||||||
|
|
||||||
|
def _do_force_trial_cleanup(self) -> None:
|
||||||
|
if self._trial_cleanup:
|
||||||
|
while True:
|
||||||
|
next_future_to_clean = self._trial_cleanup.get_next()
|
||||||
|
if not next_future_to_clean:
|
||||||
|
break
|
||||||
|
if next_future_to_clean in self._futures.keys():
|
||||||
|
_, pg = self._futures.pop(next_future_to_clean)
|
||||||
|
post_stop_cleanup(next_future_to_clean, pg)
|
||||||
|
else:
|
||||||
|
# This just means that before the deadline reaches,
|
||||||
|
# the future is already cleaned up.
|
||||||
|
pass
|
||||||
|
|
||||||
def force_reconcilation_on_next_step_end(self) -> None:
|
def force_reconcilation_on_next_step_end(self) -> None:
|
||||||
self.last_pg_recon = -float("inf")
|
self.last_pg_recon = -float("inf")
|
||||||
|
|
||||||
|
@ -842,6 +774,7 @@ class RayTrialExecutor(TrialExecutor):
|
||||||
Returns:
|
Returns:
|
||||||
Checkpoint object, or None if an Exception occurs.
|
Checkpoint object, or None if an Exception occurs.
|
||||||
"""
|
"""
|
||||||
|
logger.info(f"saving trial {trial}")
|
||||||
result = result or trial.last_result
|
result = result or trial.last_result
|
||||||
with self._change_working_directory(trial):
|
with self._change_working_directory(trial):
|
||||||
if storage == Checkpoint.MEMORY:
|
if storage == Checkpoint.MEMORY:
|
||||||
|
@ -852,7 +785,7 @@ class RayTrialExecutor(TrialExecutor):
|
||||||
value = trial.runner.save.remote()
|
value = trial.runner.save.remote()
|
||||||
checkpoint = Checkpoint(storage, value, result)
|
checkpoint = Checkpoint(storage, value, result)
|
||||||
trial.saving_to = checkpoint
|
trial.saving_to = checkpoint
|
||||||
self._running[value] = trial
|
self._futures[value] = (ExecutorEventType.SAVING_RESULT, trial)
|
||||||
return checkpoint
|
return checkpoint
|
||||||
|
|
||||||
def restore(self, trial) -> None:
|
def restore(self, trial) -> None:
|
||||||
|
@ -899,7 +832,7 @@ class RayTrialExecutor(TrialExecutor):
|
||||||
"storage-based restoration"
|
"storage-based restoration"
|
||||||
)
|
)
|
||||||
|
|
||||||
self._running[remote] = trial
|
self._futures[remote] = (ExecutorEventType.RESTORING_RESULT, trial)
|
||||||
trial.restoring_from = checkpoint
|
trial.restoring_from = checkpoint
|
||||||
|
|
||||||
def export_trial_if_needed(self, trial: Trial) -> Dict:
|
def export_trial_if_needed(self, trial: Trial) -> Dict:
|
||||||
|
@ -922,7 +855,19 @@ class RayTrialExecutor(TrialExecutor):
|
||||||
return self._avail_resources.gpu > 0
|
return self._avail_resources.gpu > 0
|
||||||
|
|
||||||
def cleanup(self, trials: List[Trial]) -> None:
|
def cleanup(self, trials: List[Trial]) -> None:
|
||||||
self._trial_cleanup.cleanup(partial=False)
|
while True:
|
||||||
|
if self._trial_cleanup and self._trial_cleanup.is_empty():
|
||||||
|
break
|
||||||
|
elif not self._trial_cleanup and len(self._futures) == 0:
|
||||||
|
break
|
||||||
|
self._do_force_trial_cleanup()
|
||||||
|
ready, _ = ray.wait(list(self._futures.keys()), timeout=0)
|
||||||
|
if not ready:
|
||||||
|
continue
|
||||||
|
event_type, trial_or_pg = self._futures.pop(ready[0])
|
||||||
|
if event_type == ExecutorEventType.STOP_RESULT:
|
||||||
|
post_stop_cleanup(ready[0], trial_or_pg)
|
||||||
|
|
||||||
self._pg_manager.reconcile_placement_groups(trials)
|
self._pg_manager.reconcile_placement_groups(trials)
|
||||||
self._pg_manager.cleanup(force=True)
|
self._pg_manager.cleanup(force=True)
|
||||||
self._pg_manager.cleanup_existing_pg(block=True)
|
self._pg_manager.cleanup_existing_pg(block=True)
|
||||||
|
@ -944,6 +889,150 @@ class RayTrialExecutor(TrialExecutor):
|
||||||
else:
|
else:
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
def get_next_executor_event(
|
||||||
|
self, live_trials: Set[Trial], next_trial_exists: bool
|
||||||
|
) -> ExecutorEvent:
|
||||||
|
"""Get the next executor event to be processed in TrialRunner.
|
||||||
|
|
||||||
|
In case there are multiple events available for handling, the next
|
||||||
|
event is determined by the following priority:
|
||||||
|
1. if there is `next_trial_exists`, and if there is cached resources
|
||||||
|
to use, PG_READY is emitted.
|
||||||
|
2. if there is `next_trial_exists` and there is no cached resources
|
||||||
|
to use, wait on pg future and randomized other futures. If multiple
|
||||||
|
futures are ready, pg future will take priority to be handled first.
|
||||||
|
3. if there is no `next_trial_exists`, wait on just randomized other
|
||||||
|
futures.
|
||||||
|
|
||||||
|
An example of #3 would be synchronous hyperband. Although there are pgs
|
||||||
|
ready, the scheduler is holding back scheduling new trials since the
|
||||||
|
whole band of trials is waiting for the slowest trial to finish. In
|
||||||
|
this case, we prioritize handling training result to avoid deadlock
|
||||||
|
situation.
|
||||||
|
|
||||||
|
This is a blocking wait with a timeout (specified with env var).
|
||||||
|
The reason for the timeout is
|
||||||
|
we still want to print status info periodically in TrialRunner for
|
||||||
|
better user experience.
|
||||||
|
|
||||||
|
The handle of `ExecutorEvent.STOP_RESULT` is purely internal to
|
||||||
|
RayTrialExecutor itself. All the other future results are handled by
|
||||||
|
TrialRunner.
|
||||||
|
|
||||||
|
In the future we may want to do most of the handle of
|
||||||
|
`ExecutorEvent.RESTORE_RESULT` and `SAVING_RESULT` in
|
||||||
|
RayTrialExecutor itself and only notify TrialRunner to invoke
|
||||||
|
corresponding callbacks. This view is more consistent with our goal
|
||||||
|
of TrialRunner responsible for external facing Trial state transition,
|
||||||
|
while RayTrialExecutor responsible for internal facing transitions,
|
||||||
|
namely, `is_saving`, `is_restoring` etc.
|
||||||
|
|
||||||
|
Also you may notice that the boundary between RayTrialExecutor and
|
||||||
|
PlacementGroupManager right now is really blurry. This will be
|
||||||
|
improved once we move to an ActorPool abstraction.
|
||||||
|
|
||||||
|
`next_trial_exists` means that there is a trial to run - prioritize
|
||||||
|
returning PG_READY in this case.
|
||||||
|
"""
|
||||||
|
# First update status of staged placement groups
|
||||||
|
self._stage_and_update_status(live_trials)
|
||||||
|
while True:
|
||||||
|
###################################################################
|
||||||
|
# when next_trial_exists and there are cached resources
|
||||||
|
###################################################################
|
||||||
|
# There could be existing PGs from either `self._cached_actor_pg`
|
||||||
|
# or from `self._pg_manager._ready`. If so and if there is indeed
|
||||||
|
# a next trial to run, we return `PG_READY` future for trial
|
||||||
|
# runner. The next trial can then be scheduled on this PG.
|
||||||
|
if next_trial_exists:
|
||||||
|
if len(self._cached_actor_pg) > 0:
|
||||||
|
return ExecutorEvent(ExecutorEventType.PG_READY)
|
||||||
|
# TODO(xwjiang): Expose proper API when we decide to do
|
||||||
|
# ActorPool abstraction.
|
||||||
|
if any(len(r) > 0 for r in self._pg_manager._ready.values()):
|
||||||
|
return ExecutorEvent(ExecutorEventType.PG_READY)
|
||||||
|
|
||||||
|
###################################################################
|
||||||
|
# Prepare for futures to wait
|
||||||
|
###################################################################
|
||||||
|
futures_to_wait = list(self._futures.keys())
|
||||||
|
random.shuffle(futures_to_wait)
|
||||||
|
if next_trial_exists:
|
||||||
|
# Only wait for pg explicitly if there is next trial to run.
|
||||||
|
# In which case, handling PG_READY triumphs handling other events.
|
||||||
|
# Since we want to place pending trial ASAP.
|
||||||
|
futures_to_wait = (
|
||||||
|
self._pg_manager.get_staging_future_list() + futures_to_wait
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
f"get_next_executor_event before wait with futures "
|
||||||
|
f"{futures_to_wait} and "
|
||||||
|
f"next_trial_exists={next_trial_exists}"
|
||||||
|
)
|
||||||
|
|
||||||
|
ready_futures, _ = ray.wait(
|
||||||
|
futures_to_wait, num_returns=1, timeout=self._get_next_event_wait
|
||||||
|
)
|
||||||
|
|
||||||
|
###################################################################
|
||||||
|
# Dealing with no future returned case.
|
||||||
|
###################################################################
|
||||||
|
if len(ready_futures) == 0:
|
||||||
|
if len(self._futures) == 0:
|
||||||
|
# No running trial and timing out with wait, could be we may
|
||||||
|
# have insufficient cluster resources that makes tune run
|
||||||
|
# infeasible.
|
||||||
|
# TODO: Move InsufficientResourceManager's logic
|
||||||
|
# to TrialExecutor. It is not Runner's responsibility!
|
||||||
|
return ExecutorEvent(ExecutorEventType.NO_RUNNING_TRIAL_TIMEOUT)
|
||||||
|
else:
|
||||||
|
# Training simply takes long time, yield the control back to main
|
||||||
|
# event loop to print progress info etc.
|
||||||
|
return ExecutorEvent(ExecutorEventType.YIELD)
|
||||||
|
|
||||||
|
###################################################################
|
||||||
|
# If there is future returned.
|
||||||
|
###################################################################
|
||||||
|
assert len(ready_futures) == 1
|
||||||
|
ready_future = ready_futures[0]
|
||||||
|
|
||||||
|
###################################################################
|
||||||
|
# If it is a PG_READY event.
|
||||||
|
###################################################################
|
||||||
|
if ready_future not in self._futures.keys():
|
||||||
|
# This is a ready future.
|
||||||
|
self._pg_manager.handle_ready_future(ready_future)
|
||||||
|
return ExecutorEvent(ExecutorEventType.PG_READY)
|
||||||
|
|
||||||
|
###################################################################
|
||||||
|
# non PG_READY event
|
||||||
|
###################################################################
|
||||||
|
result_type, trial_or_pg = self._futures.pop(ready_future)
|
||||||
|
if result_type == ExecutorEventType.STOP_RESULT:
|
||||||
|
pg = trial_or_pg
|
||||||
|
post_stop_cleanup(ready_future, pg)
|
||||||
|
else:
|
||||||
|
trial = trial_or_pg
|
||||||
|
assert isinstance(trial, Trial)
|
||||||
|
try:
|
||||||
|
future_result = ray.get(ready_future)
|
||||||
|
# For local mode
|
||||||
|
if isinstance(future_result, _LocalWrapper):
|
||||||
|
future_result = future_result.unwrap()
|
||||||
|
if result_type in (
|
||||||
|
ExecutorEventType.TRAINING_RESULT,
|
||||||
|
ExecutorEventType.SAVING_RESULT,
|
||||||
|
ExecutorEventType.RESTORING_RESULT,
|
||||||
|
):
|
||||||
|
logger.debug(f"Returning [{result_type}] for trial {trial}")
|
||||||
|
return ExecutorEvent(result_type, trial, result=future_result)
|
||||||
|
else:
|
||||||
|
raise TuneError(f"Unexpected future type - [{result_type}]")
|
||||||
|
except Exception:
|
||||||
|
return ExecutorEvent(
|
||||||
|
ExecutorEventType.ERROR, trial, traceback.format_exc()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _to_gb(n_bytes):
|
def _to_gb(n_bytes):
|
||||||
return round(n_bytes / (1024 ** 3), 2)
|
return round(n_bytes / (1024 ** 3), 2)
|
||||||
|
|
|
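The tests in the next file drive the executor directly through `get_next_executor_event()`. A condensed sketch of that polling pattern, mirroring the `_simulate_getting_result` helper added below (`executor` and `trial` are assumed to already be set up):

```python
from ray.tune.ray_trial_executor import ExecutorEventType


def wait_for_training_result(executor, trial):
    """Poll the executor until the given trial produces a training result."""
    while True:
        event = executor.get_next_executor_event(
            live_trials={trial}, next_trial_exists=False
        )
        if event.type == ExecutorEventType.ERROR:
            # For ERROR events, `result` carries the formatted traceback string.
            raise RuntimeError(event.result)
        if event.type == ExecutorEventType.TRAINING_RESULT:
            # A single dict, or a list of dicts when result buffering is on.
            return event.result
```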
@ -780,7 +780,7 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||||
return [0, 1, True, {}]
|
return [0, 1, True, {}]
|
||||||
|
|
||||||
class FailureInjectionCallback(Callback):
|
class FailureInjectionCallback(Callback):
|
||||||
def on_trial_start(self, trials, **info):
|
def on_step_end(self, **info):
|
||||||
raise RuntimeError
|
raise RuntimeError
|
||||||
|
|
||||||
with self.assertRaises(Exception):
|
with self.assertRaises(Exception):
|
||||||
|
@ -870,7 +870,6 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||||
trial.last_result.get("trial_resources"), trial.placement_group_factory
|
trial.last_result.get("trial_resources"), trial.placement_group_factory
|
||||||
)
|
)
|
||||||
|
|
||||||
@patch("ray.tune.ray_trial_executor.TRIAL_CLEANUP_THRESHOLD", 3)
|
|
||||||
def testLotsOfStops(self):
|
def testLotsOfStops(self):
|
||||||
class TestTrainable(Trainable):
|
class TestTrainable(Trainable):
|
||||||
def step(self):
|
def step(self):
|
||||||
|
|
|
@ -30,12 +30,6 @@ from ray.tune.utils.mock import (
|
||||||
MOCK_REMOTE_DIR,
|
MOCK_REMOTE_DIR,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Wait up to five seconds for placement groups when starting a trial
|
|
||||||
os.environ["TUNE_PLACEMENT_GROUP_WAIT_S"] = "5"
|
|
||||||
# Block for results even when placement groups are pending
|
|
||||||
os.environ["TUNE_TRIAL_STARTUP_GRACE_PERIOD"] = "0"
|
|
||||||
os.environ["TUNE_TRIAL_RESULT_WAIT_TIME_S"] = "9999"
|
|
||||||
|
|
||||||
|
|
||||||
def _check_trial_running(trial):
|
def _check_trial_running(trial):
|
||||||
if trial.runner:
|
if trial.runner:
|
||||||
|
@ -203,17 +197,10 @@ def test_remove_node_before_result(start_connected_emptyhead_cluster):
|
||||||
assert trial.last_result.get("training_iteration") == 1
|
assert trial.last_result.get("training_iteration") == 1
|
||||||
|
|
||||||
# Process result: discover failure, recover, _train (from scratch)
|
# Process result: discover failure, recover, _train (from scratch)
|
||||||
|
while trial.status != Trial.TERMINATED:
|
||||||
runner.step()
|
runner.step()
|
||||||
|
|
||||||
runner.step() # Process result, invoke _train
|
assert trial.last_result.get("training_iteration") > 1
|
||||||
assert trial.last_result.get("training_iteration") == 1
|
|
||||||
runner.step() # Process result, invoke _save
|
|
||||||
assert trial.last_result.get("training_iteration") == 2
|
|
||||||
# process save, invoke _train
|
|
||||||
runner.step()
|
|
||||||
# process result
|
|
||||||
runner.step()
|
|
||||||
assert trial.status == Trial.TERMINATED
|
|
||||||
|
|
||||||
with pytest.raises(TuneError):
|
with pytest.raises(TuneError):
|
||||||
runner.step()
|
runner.step()
|
||||||
|
@ -262,7 +249,7 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
|
||||||
# assert t.last_result is None, "Trial result not restored correctly."
|
# assert t.last_result is None, "Trial result not restored correctly."
|
||||||
|
|
||||||
# Process result (x2), process save, process result (x2), process save
|
# Process result (x2), process save, process result (x2), process save
|
||||||
for _ in range(6):
|
while not runner.is_finished():
|
||||||
runner.step()
|
runner.step()
|
||||||
|
|
||||||
assert t.status == Trial.TERMINATED, runner.debug_string()
|
assert t.status == Trial.TERMINATED, runner.debug_string()
|
||||||
|
@ -271,19 +258,13 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
|
||||||
t2 = Trial(trainable_id, **kwargs)
|
t2 = Trial(trainable_id, **kwargs)
|
||||||
runner.add_trial(t2)
|
runner.add_trial(t2)
|
||||||
# Start trial, process result (x2), process save
|
# Start trial, process result (x2), process save
|
||||||
for _ in range(4):
|
while not t2.has_checkpoint():
|
||||||
runner.step()
|
runner.step()
|
||||||
assert t2.has_checkpoint()
|
|
||||||
node3 = cluster.add_node(num_cpus=1)
|
node3 = cluster.add_node(num_cpus=1)
|
||||||
cluster.remove_node(node2)
|
cluster.remove_node(node2)
|
||||||
cluster.wait_for_nodes()
|
cluster.wait_for_nodes()
|
||||||
runner.step() # Process result 3 + start and fail 4 result
|
while not runner.is_finished():
|
||||||
runner.step() # Dispatch restore
|
runner.step()
|
||||||
runner.step() # Process restore
|
|
||||||
runner.step() # Process result 5
|
|
||||||
if t2.status != Trial.TERMINATED:
|
|
||||||
runner.step() # Process result 6, dispatch save
|
|
||||||
runner.step() # Process save
|
|
||||||
assert t2.status == Trial.TERMINATED, runner.debug_string()
|
assert t2.status == Trial.TERMINATED, runner.debug_string()
|
||||||
|
|
||||||
# Test recovery of trial that won't be checkpointed
|
# Test recovery of trial that won't be checkpointed
|
||||||
|
@ -301,8 +282,7 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
|
||||||
cluster.add_node(num_cpus=1)
|
cluster.add_node(num_cpus=1)
|
||||||
cluster.remove_node(node3)
|
cluster.remove_node(node3)
|
||||||
cluster.wait_for_nodes()
|
cluster.wait_for_nodes()
|
||||||
runner.step() # Error handling step
|
while not runner.is_finished():
|
||||||
if t3.status != Trial.ERROR:
|
|
||||||
runner.step()
|
runner.step()
|
||||||
assert t3.status == Trial.ERROR, runner.debug_string()
|
assert t3.status == Trial.ERROR, runner.debug_string()
|
||||||
|
|
||||||
|
@ -406,19 +386,14 @@ def test_migration_checkpoint_removal(start_connected_emptyhead_cluster, trainab
|
||||||
runner.add_trial(t1)
|
runner.add_trial(t1)
|
||||||
|
|
||||||
# Start trial, process result (x2), process save
|
# Start trial, process result (x2), process save
|
||||||
for _ in range(4):
|
while not t1.has_checkpoint():
|
||||||
runner.step()
|
runner.step()
|
||||||
assert t1.has_checkpoint()
|
|
||||||
|
|
||||||
cluster.add_node(num_cpus=1)
|
cluster.add_node(num_cpus=1)
|
||||||
cluster.remove_node(node)
|
cluster.remove_node(node)
|
||||||
cluster.wait_for_nodes()
|
cluster.wait_for_nodes()
|
||||||
shutil.rmtree(os.path.dirname(t1.checkpoint.value))
|
shutil.rmtree(os.path.dirname(t1.checkpoint.value))
|
||||||
runner.step() # Collect result 3, kick off + fail result 4
|
while not runner.is_finished():
|
||||||
runner.step() # Dispatch restore
|
|
||||||
runner.step() # Process restore + step 4
|
|
||||||
for _ in range(3):
|
|
||||||
if t1.status != Trial.TERMINATED:
|
|
||||||
runner.step()
|
runner.step()
|
||||||
assert t1.status == Trial.TERMINATED, runner.debug_string()
|
assert t1.status == Trial.TERMINATED, runner.debug_string()
|
||||||
|
|
||||||
|
|
|
@ -295,11 +295,6 @@ with patch("ray.tune.progress_reporter._get_trial_location",
|
||||||
|
|
||||||
class ProgressReporterTest(unittest.TestCase):
|
class ProgressReporterTest(unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
# Wait up to five seconds for placement groups when starting a trial
|
|
||||||
os.environ["TUNE_PLACEMENT_GROUP_WAIT_S"] = "5"
|
|
||||||
# Block for results even when placement groups are pending
|
|
||||||
os.environ["TUNE_TRIAL_STARTUP_GRACE_PERIOD"] = "0"
|
|
||||||
os.environ["TUNE_TRIAL_RESULT_WAIT_TIME_S"] = "99999"
|
|
||||||
os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "auto"
|
os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "auto"
|
||||||
|
|
||||||
def mock_trial(self, status, i):
|
def mock_trial(self, status, i):
|
||||||
|
|
|
@ -10,7 +10,7 @@ from ray import tune
|
||||||
from ray.rllib import _register_all
|
from ray.rllib import _register_all
|
||||||
from ray.tune import Trainable
|
from ray.tune import Trainable
|
||||||
from ray.tune.callback import Callback
|
from ray.tune.callback import Callback
|
||||||
from ray.tune.ray_trial_executor import RayTrialExecutor
|
from ray.tune.ray_trial_executor import RayTrialExecutor, ExecutorEventType
|
||||||
from ray.tune.registry import _global_registry, TRAINABLE_CLASS
|
from ray.tune.registry import _global_registry, TRAINABLE_CLASS
|
||||||
from ray.tune.result import PID, TRAINING_ITERATION, TRIAL_ID
|
from ray.tune.result import PID, TRAINING_ITERATION, TRIAL_ID
|
||||||
from ray.tune.suggest import BasicVariantGenerator
|
from ray.tune.suggest import BasicVariantGenerator
|
||||||
|
@ -83,12 +83,6 @@ class TrialExecutorInsufficientResourcesTest(unittest.TestCase):
|
||||||
|
|
||||||
class RayTrialExecutorTest(unittest.TestCase):
|
class RayTrialExecutorTest(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
# Wait up to five seconds for placement groups when starting a trial
|
|
||||||
os.environ["TUNE_PLACEMENT_GROUP_WAIT_S"] = "5"
|
|
||||||
# Block for results even when placement groups are pending
|
|
||||||
os.environ["TUNE_TRIAL_STARTUP_GRACE_PERIOD"] = "0"
|
|
||||||
os.environ["TUNE_TRIAL_RESULT_WAIT_TIME_S"] = "99999"
|
|
||||||
|
|
||||||
self.trial_executor = RayTrialExecutor()
|
self.trial_executor = RayTrialExecutor()
|
||||||
ray.init(num_cpus=2, ignore_reinit_error=True)
|
ray.init(num_cpus=2, ignore_reinit_error=True)
|
||||||
_register_all() # Needed for flaky tests
|
_register_all() # Needed for flaky tests
|
||||||
|
@ -97,34 +91,63 @@ class RayTrialExecutorTest(unittest.TestCase):
|
||||||
ray.shutdown()
|
ray.shutdown()
|
||||||
_register_all() # re-register the evicted objects
|
_register_all() # re-register the evicted objects
|
||||||
|
|
||||||
|
def _simulate_starting_trial(self, trial):
|
||||||
|
future_result = self.trial_executor.get_next_executor_event(
|
||||||
|
live_trials={trial}, next_trial_exists=True
|
||||||
|
)
|
||||||
|
assert future_result.type == ExecutorEventType.PG_READY
|
||||||
|
self.assertTrue(self.trial_executor.start_trial(trial))
|
||||||
|
self.assertEqual(Trial.RUNNING, trial.status)
|
||||||
|
|
||||||
|
def _simulate_getting_result(self, trial):
|
||||||
|
while True:
|
||||||
|
future_result = self.trial_executor.get_next_executor_event(
|
||||||
|
live_trials={trial}, next_trial_exists=False
|
||||||
|
)
|
||||||
|
if future_result.type == ExecutorEventType.TRAINING_RESULT:
|
||||||
|
break
|
||||||
|
if isinstance(future_result.result, list):
|
||||||
|
for r in future_result.result:
|
||||||
|
trial.update_last_result(r)
|
||||||
|
else:
|
||||||
|
trial.update_last_result(future_result.result)
|
||||||
|
|
||||||
|
def _simulate_saving(self, trial):
|
||||||
|
checkpoint = self.trial_executor.save(trial, Checkpoint.PERSISTENT)
|
||||||
|
self.assertEqual(checkpoint, trial.saving_to)
|
||||||
|
self.assertEqual(trial.checkpoint.value, None)
|
||||||
|
future_result = self.trial_executor.get_next_executor_event(
|
||||||
|
live_trials={trial}, next_trial_exists=False
|
||||||
|
)
|
||||||
|
assert future_result.type == ExecutorEventType.SAVING_RESULT
|
||||||
|
self.process_trial_save(trial, future_result.result)
|
||||||
|
self.assertEqual(checkpoint, trial.checkpoint)
|
||||||
|
|
||||||
def testStartStop(self):
|
def testStartStop(self):
|
||||||
trial = Trial("__fake")
|
trial = Trial("__fake")
|
||||||
self.trial_executor.start_trial(trial)
|
self._simulate_starting_trial(trial)
|
||||||
running = self.trial_executor.get_running_trials()
|
|
||||||
self.assertEqual(1, len(running))
|
|
||||||
self.trial_executor.stop_trial(trial)
|
self.trial_executor.stop_trial(trial)
|
||||||
|
|
||||||
def testAsyncSave(self):
|
def testAsyncSave(self):
|
||||||
"""Tests that saved checkpoint value not immediately set."""
|
"""Tests that saved checkpoint value not immediately set."""
|
||||||
trial = Trial("__fake")
|
trial = Trial("__fake")
|
||||||
self.trial_executor.start_trial(trial)
|
self._simulate_starting_trial(trial)
|
||||||
self.assertEqual(Trial.RUNNING, trial.status)
|
|
||||||
trial.last_result = self.trial_executor.fetch_result(trial)[-1]
|
self._simulate_getting_result(trial)
|
||||||
checkpoint = self.trial_executor.save(trial, Checkpoint.PERSISTENT)
|
|
||||||
self.assertEqual(checkpoint, trial.saving_to)
|
self._simulate_saving(trial)
|
||||||
self.assertEqual(trial.checkpoint.value, None)
|
|
||||||
self.process_trial_save(trial)
|
|
||||||
self.assertEqual(checkpoint, trial.checkpoint)
|
|
||||||
self.trial_executor.stop_trial(trial)
|
self.trial_executor.stop_trial(trial)
|
||||||
self.assertEqual(Trial.TERMINATED, trial.status)
|
self.assertEqual(Trial.TERMINATED, trial.status)
|
||||||
|
|
||||||
def testSaveRestore(self):
|
def testSaveRestore(self):
|
||||||
trial = Trial("__fake")
|
trial = Trial("__fake")
|
||||||
self.trial_executor.start_trial(trial)
|
self._simulate_starting_trial(trial)
|
||||||
self.assertEqual(Trial.RUNNING, trial.status)
|
|
||||||
trial.last_result = self.trial_executor.fetch_result(trial)[-1]
|
self._simulate_getting_result(trial)
|
||||||
self.trial_executor.save(trial, Checkpoint.PERSISTENT)
|
|
||||||
self.process_trial_save(trial)
|
self._simulate_saving(trial)
|
||||||
|
|
||||||
self.trial_executor.restore(trial)
|
self.trial_executor.restore(trial)
|
||||||
self.trial_executor.stop_trial(trial)
|
self.trial_executor.stop_trial(trial)
|
||||||
self.assertEqual(Trial.TERMINATED, trial.status)
|
self.assertEqual(Trial.TERMINATED, trial.status)
|
||||||
|
@ -132,40 +155,44 @@ class RayTrialExecutorTest(unittest.TestCase):
|
||||||
def testPauseResume(self):
|
def testPauseResume(self):
|
||||||
"""Tests that pausing works for trials in flight."""
|
"""Tests that pausing works for trials in flight."""
|
||||||
trial = Trial("__fake")
|
trial = Trial("__fake")
|
||||||
self.trial_executor.start_trial(trial)
|
self._simulate_starting_trial(trial)
|
||||||
self.assertEqual(Trial.RUNNING, trial.status)
|
|
||||||
self.trial_executor.pause_trial(trial)
|
self.trial_executor.pause_trial(trial)
|
||||||
self.assertEqual(Trial.PAUSED, trial.status)
|
self.assertEqual(Trial.PAUSED, trial.status)
|
||||||
self.trial_executor.start_trial(trial)
|
|
||||||
self.assertEqual(Trial.RUNNING, trial.status)
|
self._simulate_starting_trial(trial)
|
||||||
|
|
||||||
self.trial_executor.stop_trial(trial)
|
self.trial_executor.stop_trial(trial)
|
||||||
self.assertEqual(Trial.TERMINATED, trial.status)
|
self.assertEqual(Trial.TERMINATED, trial.status)
|
||||||
|
|
||||||
def testSavePauseResumeErrorRestore(self):
|
def testSavePauseResumeErrorRestore(self):
|
||||||
"""Tests that pause checkpoint does not replace restore checkpoint."""
|
"""Tests that pause checkpoint does not replace restore checkpoint."""
|
||||||
trial = Trial("__fake")
|
trial = Trial("__fake")
|
||||||
self.trial_executor.start_trial(trial)
|
self._simulate_starting_trial(trial)
|
||||||
trial.last_result = self.trial_executor.fetch_result(trial)[-1]
|
|
||||||
|
self._simulate_getting_result(trial)
|
||||||
|
|
||||||
# Save
|
# Save
|
||||||
checkpoint = self.trial_executor.save(trial, Checkpoint.PERSISTENT)
|
self._simulate_saving(trial)
|
||||||
self.assertEqual(Trial.RUNNING, trial.status)
|
|
||||||
self.assertEqual(checkpoint.storage, Checkpoint.PERSISTENT)
|
|
||||||
# Process save result (simulates trial runner)
|
|
||||||
self.process_trial_save(trial)
|
|
||||||
# Train
|
# Train
|
||||||
self.trial_executor.continue_training(trial)
|
self.trial_executor.continue_training(trial)
|
||||||
trial.last_result = self.trial_executor.fetch_result(trial)[-1]
|
self._simulate_getting_result(trial)
|
||||||
|
|
||||||
# Pause
|
# Pause
|
||||||
self.trial_executor.pause_trial(trial)
|
self.trial_executor.pause_trial(trial)
|
||||||
self.assertEqual(Trial.PAUSED, trial.status)
|
self.assertEqual(Trial.PAUSED, trial.status)
|
||||||
self.assertEqual(trial.checkpoint.storage, Checkpoint.MEMORY)
|
self.assertEqual(trial.checkpoint.storage, Checkpoint.MEMORY)
|
||||||
|
|
||||||
# Resume
|
# Resume
|
||||||
self.trial_executor.start_trial(trial)
|
self._simulate_starting_trial(trial)
|
||||||
self.assertEqual(Trial.RUNNING, trial.status)
|
|
||||||
# Error
|
# Error
|
||||||
trial.set_status(Trial.ERROR)
|
trial.set_status(Trial.ERROR)
|
||||||
|
|
||||||
# Restore
|
# Restore
|
||||||
self.trial_executor.restore(trial)
|
self.trial_executor.restore(trial)
|
||||||
|
|
||||||
self.trial_executor.stop_trial(trial)
|
self.trial_executor.stop_trial(trial)
|
||||||
self.assertEqual(Trial.TERMINATED, trial.status)
|
self.assertEqual(Trial.TERMINATED, trial.status)
|
||||||
|
|
||||||
|
@ -178,13 +205,14 @@ class RayTrialExecutorTest(unittest.TestCase):
|
||||||
def testPauseResume2(self):
|
def testPauseResume2(self):
|
||||||
"""Tests that pausing works for trials being processed."""
|
"""Tests that pausing works for trials being processed."""
|
||||||
trial = Trial("__fake")
|
trial = Trial("__fake")
|
||||||
self.trial_executor.start_trial(trial)
|
self._simulate_starting_trial(trial)
|
||||||
self.assertEqual(Trial.RUNNING, trial.status)
|
|
||||||
self.trial_executor.fetch_result(trial)
|
self._simulate_getting_result(trial)
|
||||||
|
|
||||||
self.trial_executor.pause_trial(trial)
|
self.trial_executor.pause_trial(trial)
|
||||||
self.assertEqual(Trial.PAUSED, trial.status)
|
self.assertEqual(Trial.PAUSED, trial.status)
|
||||||
self.trial_executor.start_trial(trial)
|
|
||||||
self.assertEqual(Trial.RUNNING, trial.status)
|
self._simulate_starting_trial(trial)
|
||||||
self.trial_executor.stop_trial(trial)
|
self.trial_executor.stop_trial(trial)
|
||||||
self.assertEqual(Trial.TERMINATED, trial.status)
|
self.assertEqual(Trial.TERMINATED, trial.status)
|
||||||
|
|
||||||
|
@ -199,15 +227,17 @@ class RayTrialExecutorTest(unittest.TestCase):
|
||||||
base = max(result_buffer_length, 1)
|
base = max(result_buffer_length, 1)
|
||||||
|
|
||||||
trial = Trial("__fake")
|
trial = Trial("__fake")
|
||||||
self.trial_executor.start_trial(trial)
|
self._simulate_starting_trial(trial)
|
||||||
self.assertEqual(Trial.RUNNING, trial.status)
|
|
||||||
trial.last_result = self.trial_executor.fetch_result(trial)[-1]
|
self._simulate_getting_result(trial)
|
||||||
self.assertEqual(trial.last_result.get(TRAINING_ITERATION), base)
|
self.assertEqual(trial.last_result.get(TRAINING_ITERATION), base)
|
||||||
|
|
||||||
self.trial_executor.pause_trial(trial)
|
self.trial_executor.pause_trial(trial)
|
||||||
self.assertEqual(Trial.PAUSED, trial.status)
|
self.assertEqual(Trial.PAUSED, trial.status)
|
||||||
self.trial_executor.start_trial(trial)
|
|
||||||
self.assertEqual(Trial.RUNNING, trial.status)
|
self._simulate_starting_trial(trial)
|
||||||
trial.last_result = self.trial_executor.fetch_result(trial)[-1]
|
|
||||||
|
self._simulate_getting_result(trial)
|
||||||
self.assertEqual(trial.last_result.get(TRAINING_ITERATION), base * 2)
|
self.assertEqual(trial.last_result.get(TRAINING_ITERATION), base * 2)
|
||||||
self.trial_executor.stop_trial(trial)
|
self.trial_executor.stop_trial(trial)
|
||||||
self.assertEqual(Trial.TERMINATED, trial.status)
|
self.assertEqual(Trial.TERMINATED, trial.status)
|
||||||
|
@ -224,7 +254,7 @@ class RayTrialExecutorTest(unittest.TestCase):
|
||||||
def testNoResetTrial(self):
|
def testNoResetTrial(self):
|
||||||
"""Tests that reset handles NotImplemented properly."""
|
"""Tests that reset handles NotImplemented properly."""
|
||||||
trial = Trial("__fake")
|
trial = Trial("__fake")
|
||||||
self.trial_executor.start_trial(trial)
|
self._simulate_starting_trial(trial)
|
||||||
exists = self.trial_executor.reset_trial(trial, {}, "modified_mock")
|
exists = self.trial_executor.reset_trial(trial, {}, "modified_mock")
|
||||||
self.assertEqual(exists, False)
|
self.assertEqual(exists, False)
|
||||||
self.assertEqual(Trial.RUNNING, trial.status)
|
self.assertEqual(Trial.RUNNING, trial.status)
|
||||||
|
@ -248,18 +278,18 @@ class RayTrialExecutorTest(unittest.TestCase):
|
||||||
"grid_search",
|
"grid_search",
|
||||||
)
|
)
|
||||||
trial = trials[0]
|
trial = trials[0]
|
||||||
self.trial_executor.start_trial(trial)
|
self._simulate_starting_trial(trial)
|
||||||
exists = self.trial_executor.reset_trial(trial, {"hi": 1}, "modified_mock")
|
exists = self.trial_executor.reset_trial(trial, {"hi": 1}, "modified_mock")
|
||||||
self.assertEqual(exists, True)
|
self.assertEqual(exists, True)
|
||||||
self.assertEqual(trial.config.get("hi"), 1)
|
self.assertEqual(trial.config.get("hi"), 1)
|
||||||
self.assertEqual(trial.experiment_tag, "modified_mock")
|
self.assertEqual(trial.experiment_tag, "modified_mock")
|
||||||
self.assertEqual(Trial.RUNNING, trial.status)
|
self.assertEqual(Trial.RUNNING, trial.status)
|
||||||
|
|
||||||
def testForceTrialCleanup(self):
|
def testTrialCleanup(self):
|
||||||
class B(Trainable):
|
class B(Trainable):
|
||||||
def step(self):
|
def step(self):
|
||||||
print("Step start")
|
print("Step start")
|
||||||
time.sleep(10)
|
time.sleep(4)
|
||||||
print("Step done")
|
print("Step done")
|
||||||
return dict(my_metric=1, timesteps_this_iter=1, done=True)
|
return dict(my_metric=1, timesteps_this_iter=1, done=True)
|
||||||
|
|
||||||
|
@ -269,7 +299,7 @@ class RayTrialExecutorTest(unittest.TestCase):
|
||||||
|
|
||||||
def cleanup(self):
|
def cleanup(self):
|
||||||
print("Cleanup start")
|
print("Cleanup start")
|
||||||
time.sleep(10)
|
time.sleep(4)
|
||||||
print("Cleanup done")
|
print("Cleanup done")
|
||||||
|
|
||||||
# First check if the trials terminate gracefully by default
|
# First check if the trials terminate gracefully by default
|
||||||
|
@ -281,15 +311,15 @@ class RayTrialExecutorTest(unittest.TestCase):
|
||||||
"grid_search",
|
"grid_search",
|
||||||
)
|
)
|
||||||
trial = trials[0]
|
trial = trials[0]
|
||||||
self.trial_executor.start_trial(trial)
|
self._simulate_starting_trial(trial)
|
||||||
self.assertEqual(Trial.RUNNING, trial.status)
|
time.sleep(1)
|
||||||
time.sleep(5)
|
|
||||||
print("Stop trial")
|
print("Stop trial")
|
||||||
self.trial_executor.stop_trial(trial)
|
self.trial_executor.stop_trial(trial)
|
||||||
print("Start trial cleanup")
|
print("Start trial cleanup")
|
||||||
start = time.time()
|
start = time.time()
|
||||||
self.trial_executor.cleanup([trial])
|
self.trial_executor.cleanup([trial])
|
||||||
self.assertGreaterEqual(time.time() - start, 12.0)
|
# 4 - 1 + 4.
|
||||||
|
self.assertGreaterEqual(time.time() - start, 6)
|
||||||
|
|
||||||
# Check forceful termination. It should run for much less than the
|
# Check forceful termination. It should run for much less than the
|
||||||
# sleep periods in the Trainable
|
# sleep periods in the Trainable
|
||||||
|
@ -304,15 +334,16 @@ class RayTrialExecutorTest(unittest.TestCase):
|
||||||
os.environ["TUNE_FORCE_TRIAL_CLEANUP_S"] = "1"
|
os.environ["TUNE_FORCE_TRIAL_CLEANUP_S"] = "1"
|
||||||
self.trial_executor = RayTrialExecutor()
|
self.trial_executor = RayTrialExecutor()
|
||||||
os.environ["TUNE_FORCE_TRIAL_CLEANUP_S"] = "0"
|
os.environ["TUNE_FORCE_TRIAL_CLEANUP_S"] = "0"
|
||||||
self.trial_executor.start_trial(trial)
|
self._simulate_starting_trial(trial)
|
||||||
self.assertEqual(Trial.RUNNING, trial.status)
|
self.assertEqual(Trial.RUNNING, trial.status)
|
||||||
time.sleep(5)
|
time.sleep(1)
|
||||||
print("Stop trial")
|
print("Stop trial")
|
||||||
self.trial_executor.stop_trial(trial)
|
self.trial_executor.stop_trial(trial)
|
||||||
print("Start trial cleanup")
|
print("Start trial cleanup")
|
||||||
start = time.time()
|
start = time.time()
|
||||||
self.trial_executor.cleanup([trial])
|
self.trial_executor.cleanup([trial])
|
||||||
self.assertLess(time.time() - start, 5.0)
|
# less than 1 with some margin.
|
||||||
|
self.assertLess(time.time() - start, 2.0)
|
||||||
|
|
||||||
# also check if auto-filled metrics were returned
|
# also check if auto-filled metrics were returned
|
||||||
self.assertIn(PID, trial.last_result)
|
self.assertIn(PID, trial.last_result)
|
||||||
|
@ -332,10 +363,9 @@ class RayTrialExecutorTest(unittest.TestCase):
|
||||||
break
|
break
|
||||||
return trials
|
return trials
|
||||||
|
|
||||||
def process_trial_save(self, trial):
|
def process_trial_save(self, trial, checkpoint_value):
|
||||||
"""Simulates trial runner save."""
|
"""Simulates trial runner save."""
|
||||||
checkpoint = trial.saving_to
|
checkpoint = trial.saving_to
|
||||||
checkpoint_value = self.trial_executor.fetch_result(trial)[-1]
|
|
||||||
checkpoint.value = checkpoint_value
|
checkpoint.value = checkpoint_value
|
||||||
trial.on_checkpoint(checkpoint)
|
trial.on_checkpoint(checkpoint)
|
||||||
|
|
||||||
|
@ -460,10 +490,8 @@ class LocalModeExecutorTest(RayTrialExecutorTest):
|
||||||
ray.shutdown()
|
ray.shutdown()
|
||||||
_register_all() # re-register the evicted objects
|
_register_all() # re-register the evicted objects
|
||||||
|
|
||||||
def testForceTrialCleanup(self):
|
def testTrialCleanup(self):
|
||||||
self.skipTest(
|
self.skipTest("Skipping as trial cleanup is not applicable" " for local mode.")
|
||||||
"Skipping as force trial cleanup is not applicable" " for local mode."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -15,6 +15,7 @@ import ray
|
||||||
from ray.rllib import _register_all
|
from ray.rllib import _register_all
|
||||||
|
|
||||||
from ray import tune
|
from ray import tune
|
||||||
|
from ray.tune import TuneError
|
||||||
from ray.tune.integration.docker import DockerSyncer
|
from ray.tune.integration.docker import DockerSyncer
|
||||||
from ray.tune.integration.kubernetes import KubernetesSyncer
|
from ray.tune.integration.kubernetes import KubernetesSyncer
|
||||||
from ray.tune.sync_client import NOOP
|
from ray.tune.sync_client import NOOP
|
||||||
|
@ -29,12 +30,6 @@ from ray.tune.utils.callback import create_default_callbacks
|
||||||
|
|
||||||
class TestSyncFunctionality(unittest.TestCase):
|
class TestSyncFunctionality(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
# Wait up to 1.5 seconds for placement groups when starting a trial
|
|
||||||
os.environ["TUNE_PLACEMENT_GROUP_WAIT_S"] = "1.5"
|
|
||||||
# Block for results even when placement groups are pending
|
|
||||||
os.environ["TUNE_TRIAL_STARTUP_GRACE_PERIOD"] = "0"
|
|
||||||
os.environ["TUNE_TRIAL_RESULT_WAIT_TIME_S"] = "99999"
|
|
||||||
|
|
||||||
ray.init(num_cpus=2)
|
ray.init(num_cpus=2)
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
|
@ -120,7 +115,7 @@ class TestSyncFunctionality(unittest.TestCase):
|
||||||
|
|
||||||
def testClusterProperString(self):
|
def testClusterProperString(self):
|
||||||
"""Tests that invalid commands throw.."""
|
"""Tests that invalid commands throw.."""
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(TuneError):
|
||||||
# This raises ValueError because logger is init in safe zone.
|
# This raises ValueError because logger is init in safe zone.
|
||||||
sync_config = tune.SyncConfig(syncer="ls {target}")
|
sync_config = tune.SyncConfig(syncer="ls {target}")
|
||||||
[trial] = tune.run(
|
[trial] = tune.run(
|
||||||
|
@ -131,7 +126,7 @@ class TestSyncFunctionality(unittest.TestCase):
|
||||||
sync_config=sync_config,
|
sync_config=sync_config,
|
||||||
).trials
|
).trials
|
||||||
|
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(TuneError):
|
||||||
# This raises ValueError because logger is init in safe zone.
|
# This raises ValueError because logger is init in safe zone.
|
||||||
sync_config = tune.SyncConfig(syncer="ls {source}")
|
sync_config = tune.SyncConfig(syncer="ls {source}")
|
||||||
[trial] = tune.run(
|
[trial] = tune.run(
|
||||||
|
|
|
@ -19,29 +19,11 @@ from ray.tune.utils.placement_groups import PlacementGroupFactory
|
||||||
|
|
||||||
class TrialRunnerTest(unittest.TestCase):
|
class TrialRunnerTest(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
# Wait up to five seconds for placement groups when starting a trial
|
|
||||||
os.environ["TUNE_PLACEMENT_GROUP_WAIT_S"] = "5"
|
|
||||||
# Block for results even when placement groups are pending
|
|
||||||
os.environ["TUNE_TRIAL_STARTUP_GRACE_PERIOD"] = "0"
|
|
||||||
os.environ["TUNE_TRIAL_RESULT_WAIT_TIME_S"] = "99999"
|
|
||||||
_register_all() # re-register the evicted objects
|
_register_all() # re-register the evicted objects
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
ray.shutdown()
|
ray.shutdown()
|
||||||
|
|
||||||
def testTrialStatus(self):
|
|
||||||
ray.init(num_cpus=2)
|
|
||||||
|
|
||||||
trial = Trial("__fake")
|
|
||||||
trial_executor = RayTrialExecutor()
|
|
||||||
self.assertEqual(trial.status, Trial.PENDING)
|
|
||||||
trial_executor.start_trial(trial)
|
|
||||||
self.assertEqual(trial.status, Trial.RUNNING)
|
|
||||||
trial_executor.stop_trial(trial)
|
|
||||||
self.assertEqual(trial.status, Trial.TERMINATED)
|
|
||||||
trial_executor.stop_trial(trial, error=True)
|
|
||||||
self.assertEqual(trial.status, Trial.ERROR)
|
|
||||||
|
|
||||||
def testExperimentTagTruncation(self):
|
def testExperimentTagTruncation(self):
|
||||||
ray.init(num_cpus=2)
|
ray.init(num_cpus=2)
|
||||||
|
|
||||||
|
@ -74,7 +56,8 @@ class TrialRunnerTest(unittest.TestCase):
|
||||||
|
|
||||||
def testExtraResources(self):
|
def testExtraResources(self):
|
||||||
ray.init(num_cpus=4, num_gpus=2)
|
ray.init(num_cpus=4, num_gpus=2)
|
||||||
runner = TrialRunner()
|
snapshot = TrialStatusSnapshot()
|
||||||
|
runner = TrialRunner(callbacks=[TrialStatusSnapshotTaker(snapshot)])
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"stopping_criterion": {"training_iteration": 1},
|
"stopping_criterion": {"training_iteration": 1},
|
||||||
"placement_group_factory": PlacementGroupFactory(
|
"placement_group_factory": PlacementGroupFactory(
|
||||||
|
@ -85,17 +68,18 @@ class TrialRunnerTest(unittest.TestCase):
|
||||||
for t in trials:
|
for t in trials:
|
||||||
runner.add_trial(t)
|
runner.add_trial(t)
|
||||||
|
|
||||||
|
while not runner.is_finished():
|
||||||
runner.step()
|
runner.step()
|
||||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
|
||||||
self.assertEqual(trials[1].status, Trial.PENDING)
|
|
||||||
|
|
||||||
runner.step()
|
self.assertLess(snapshot.max_running_trials(), 2)
|
||||||
self.assertEqual(trials[0].status, Trial.TERMINATED)
|
self.assertTrue(snapshot.all_trials_are_terminated())
|
||||||
self.assertEqual(trials[1].status, Trial.PENDING)
|
|
||||||
|
|
||||||
def testCustomResources(self):
|
def testCustomResources(self):
|
||||||
ray.init(num_cpus=4, num_gpus=2, resources={"a": 2})
|
ray.init(num_cpus=4, num_gpus=2, resources={"a": 2})
|
||||||
runner = TrialRunner()
|
# Since each trial will occupy the full custom resources,
|
||||||
|
# there are at most 1 trial running at any given moment.
|
||||||
|
snapshot = TrialStatusSnapshot()
|
||||||
|
runner = TrialRunner(callbacks=[TrialStatusSnapshotTaker(snapshot)])
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"stopping_criterion": {"training_iteration": 1},
|
"stopping_criterion": {"training_iteration": 1},
|
||||||
"placement_group_factory": PlacementGroupFactory([{"CPU": 1, "a": 2}]),
|
"placement_group_factory": PlacementGroupFactory([{"CPU": 1, "a": 2}]),
|
||||||
|
@ -104,16 +88,18 @@ class TrialRunnerTest(unittest.TestCase):
|
||||||
for t in trials:
|
for t in trials:
|
||||||
runner.add_trial(t)
|
runner.add_trial(t)
|
||||||
|
|
||||||
|
while not runner.is_finished():
|
||||||
runner.step()
|
runner.step()
|
||||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
|
||||||
self.assertEqual(trials[1].status, Trial.PENDING)
|
self.assertLess(snapshot.max_running_trials(), 2)
|
||||||
runner.step()
|
self.assertTrue(snapshot.all_trials_are_terminated())
|
||||||
self.assertEqual(trials[0].status, Trial.TERMINATED)
|
|
||||||
self.assertEqual(trials[1].status, Trial.PENDING)
|
|
||||||
|
|
||||||
def testExtraCustomResources(self):
|
def testExtraCustomResources(self):
|
||||||
ray.init(num_cpus=4, num_gpus=2, resources={"a": 2})
|
ray.init(num_cpus=4, num_gpus=2, resources={"a": 2})
|
||||||
runner = TrialRunner()
|
# Since each trial will occupy the full custom resources,
|
||||||
|
# there are at most 1 trial running at any given moment.
|
||||||
|
snapshot = TrialStatusSnapshot()
|
||||||
|
runner = TrialRunner(callbacks=[TrialStatusSnapshotTaker(snapshot)])
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"stopping_criterion": {"training_iteration": 1},
|
"stopping_criterion": {"training_iteration": 1},
|
||||||
"placement_group_factory": PlacementGroupFactory([{"CPU": 1}, {"a": 2}]),
|
"placement_group_factory": PlacementGroupFactory([{"CPU": 1}, {"a": 2}]),
|
||||||
|
@ -122,14 +108,11 @@ class TrialRunnerTest(unittest.TestCase):
|
||||||
for t in trials:
|
for t in trials:
|
||||||
runner.add_trial(t)
|
runner.add_trial(t)
|
||||||
|
|
||||||
|
while not runner.is_finished():
|
||||||
runner.step()
|
runner.step()
|
||||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
|
||||||
self.assertEqual(trials[1].status, Trial.PENDING)
|
|
||||||
|
|
||||||
runner.step()
|
self.assertLess(snapshot.max_running_trials(), 2)
|
||||||
self.assertTrue(sum(t.status == Trial.RUNNING for t in trials) < 2)
|
self.assertTrue(snapshot.all_trials_are_terminated())
|
||||||
self.assertEqual(trials[0].status, Trial.TERMINATED)
|
|
||||||
self.assertEqual(trials[1].status, Trial.PENDING)
|
|
||||||
|
|
||||||
def testFractionalGpus(self):
|
def testFractionalGpus(self):
|
||||||
ray.init(num_cpus=4, num_gpus=1)
|
ray.init(num_cpus=4, num_gpus=1)
|
||||||
|
@ -209,12 +192,7 @@ class TrialRunnerTest(unittest.TestCase):
|
||||||
for t in trials:
|
for t in trials:
|
||||||
runner.add_trial(t)
|
runner.add_trial(t)
|
||||||
|
|
||||||
runner.step()
|
while not runner.is_finished():
|
||||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
|
||||||
|
|
||||||
runner.step()
|
|
||||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
|
||||||
|
|
||||||
runner.step()
|
runner.step()
|
||||||
self.assertEqual(trials[0].status, Trial.TERMINATED)
|
self.assertEqual(trials[0].status, Trial.TERMINATED)
|
||||||
self.assertRaises(TuneError, runner.step)
|
self.assertRaises(TuneError, runner.step)
|
||||||
|
@ -224,15 +202,23 @@ class TrialRunnerTest(unittest.TestCase):
|
||||||
ray.init(num_cpus=2)
|
ray.init(num_cpus=2)
|
||||||
|
|
||||||
class ChangingScheduler(FIFOScheduler):
|
class ChangingScheduler(FIFOScheduler):
|
||||||
|
def __init__(self):
|
||||||
|
self._has_received_one_trial_result = False
|
||||||
|
|
||||||
|
# For figuring out how many runner.step there are.
|
||||||
|
def has_received_one_trial_result(self):
|
||||||
|
return self._has_received_one_trial_result
|
||||||
|
|
||||||
def on_trial_result(self, trial_runner, trial, result):
|
def on_trial_result(self, trial_runner, trial, result):
|
||||||
if result["training_iteration"] == 1:
|
if result["training_iteration"] == 1:
|
||||||
|
self._has_received_one_trial_result = True
|
||||||
executor = trial_runner.trial_executor
|
executor = trial_runner.trial_executor
|
||||||
executor.stop_trial(trial)
|
executor.pause_trial(trial)
|
||||||
trial.update_resources(dict(cpu=2, gpu=0))
|
trial.update_resources(dict(cpu=2, gpu=0))
|
||||||
executor.start_trial(trial)
|
return TrialScheduler.NOOP
|
||||||
return TrialScheduler.CONTINUE
|
|
||||||
|
|
||||||
runner = TrialRunner(scheduler=ChangingScheduler())
|
scheduler = ChangingScheduler()
|
||||||
|
runner = TrialRunner(scheduler=scheduler)
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"stopping_criterion": {"training_iteration": 2},
|
"stopping_criterion": {"training_iteration": 2},
|
||||||
"resources": Resources(cpu=1, gpu=0),
|
"resources": Resources(cpu=1, gpu=0),
|
||||||
|
@ -250,8 +236,11 @@ class TrialRunnerTest(unittest.TestCase):
|
||||||
ValueError, lambda: trials[0].update_resources(dict(cpu=2, gpu=0))
|
ValueError, lambda: trials[0].update_resources(dict(cpu=2, gpu=0))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
while not scheduler.has_received_one_trial_result():
|
||||||
|
runner.step()
|
||||||
|
self.assertEqual(trials[0].status, Trial.PAUSED)
|
||||||
|
# extra step for tune loop to stage the resource requests.
|
||||||
runner.step()
|
runner.step()
|
||||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
runner.trial_executor._pg_manager.occupied_resources().get("CPU"), 2
|
runner.trial_executor._pg_manager.occupied_resources().get("CPU"), 2
|
||||||
)
|
)
|
||||||
|
|
|
@ -13,6 +13,7 @@ from ray.tune.trial import Trial
|
||||||
from ray.tune.trial_runner import TrialRunner
|
from ray.tune.trial_runner import TrialRunner
|
||||||
from ray.tune.resources import Resources
|
from ray.tune.resources import Resources
|
||||||
from ray.tune.suggest import BasicVariantGenerator
|
from ray.tune.suggest import BasicVariantGenerator
|
||||||
|
from ray.tune.tests.test_trial_runner_utils import TrialResultObserver
|
||||||
|
|
||||||
|
|
||||||
def create_mock_components():
|
def create_mock_components():
|
||||||
|
@ -37,11 +38,6 @@ def create_mock_components():
|
||||||
class TrialRunnerTest2(unittest.TestCase):
|
class TrialRunnerTest2(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
os.environ["TUNE_STATE_REFRESH_PERIOD"] = "0.1"
|
os.environ["TUNE_STATE_REFRESH_PERIOD"] = "0.1"
|
||||||
# Wait up to five seconds for placement groups when starting a trial
|
|
||||||
os.environ["TUNE_PLACEMENT_GROUP_WAIT_S"] = "5"
|
|
||||||
# Block for results even when placement groups are pending
|
|
||||||
os.environ["TUNE_TRIAL_STARTUP_GRACE_PERIOD"] = "0"
|
|
||||||
os.environ["TUNE_TRIAL_RESULT_WAIT_TIME_S"] = "99999"
|
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
ray.shutdown()
|
ray.shutdown()
|
||||||
|
@ -89,12 +85,9 @@ class TrialRunnerTest2(unittest.TestCase):
|
||||||
runner.add_trial(Trial("__fake", **kwargs))
|
runner.add_trial(Trial("__fake", **kwargs))
|
||||||
trials = runner.get_trials()
|
trials = runner.get_trials()
|
||||||
|
|
||||||
runner.step() # Start trial
|
while not runner.is_finished():
|
||||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
runner.step()
|
||||||
runner.step() # Process result, dispatch save
|
|
||||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
|
||||||
runner.step() # Process save
|
|
||||||
runner.step() # Error
|
|
||||||
self.assertEqual(trials[0].status, Trial.ERROR)
|
self.assertEqual(trials[0].status, Trial.ERROR)
|
||||||
self.assertEqual(trials[0].num_failures, 1)
|
self.assertEqual(trials[0].num_failures, 1)
|
||||||
self.assertEqual(len(searchalg.errored_trials), 1)
|
self.assertEqual(len(searchalg.errored_trials), 1)
|
||||||
|
@ -107,6 +100,7 @@ class TrialRunnerTest2(unittest.TestCase):
|
||||||
runner = TrialRunner(searchalg, scheduler=scheduler)
|
runner = TrialRunner(searchalg, scheduler=scheduler)
|
||||||
|
|
||||||
kwargs = {
|
kwargs = {
|
||||||
|
"stopping_criterion": {"training_iteration": 2},
|
||||||
"resources": Resources(cpu=1, gpu=1),
|
"resources": Resources(cpu=1, gpu=1),
|
||||||
"checkpoint_freq": 1,
|
"checkpoint_freq": 1,
|
||||||
"max_failures": 1,
|
"max_failures": 1,
|
||||||
|
@ -117,18 +111,15 @@ class TrialRunnerTest2(unittest.TestCase):
|
||||||
runner.add_trial(Trial("__fake", **kwargs))
|
runner.add_trial(Trial("__fake", **kwargs))
|
||||||
trials = runner.get_trials()
|
trials = runner.get_trials()
|
||||||
|
|
||||||
runner.step() # Start trial
|
while not runner.is_finished():
|
||||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
runner.step()
|
||||||
runner.step() # Process result, dispatch save
|
self.assertEqual(trials[0].status, Trial.TERMINATED)
|
||||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
|
||||||
runner.step() # Process save
|
|
||||||
runner.step() # Error (transient), dispatch restore
|
|
||||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
|
||||||
self.assertEqual(trials[0].num_failures, 1)
|
self.assertEqual(trials[0].num_failures, 1)
|
||||||
runner.step() # Process restore
|
|
||||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
|
||||||
self.assertEqual(len(searchalg.errored_trials), 0)
|
self.assertEqual(len(searchalg.errored_trials), 0)
|
||||||
self.assertEqual(len(scheduler.errored_trials), 0)
|
# Notice this is 1 since during recovery, the previously errored trial
|
||||||
|
# is "requeued". This will call scheduler.on_trial_error.
|
||||||
|
# Searcher.on_trial_error is, however, not called in this process.
|
||||||
|
self.assertEqual(len(scheduler.errored_trials), 1)
|
||||||
|
|
||||||
def testFailureRecoveryMaxFailures(self):
|
def testFailureRecoveryMaxFailures(self):
|
||||||
ray.init(num_cpus=1, num_gpus=1)
|
ray.init(num_cpus=1, num_gpus=1)
|
||||||
|
@ -145,20 +136,8 @@ class TrialRunnerTest2(unittest.TestCase):
|
||||||
runner.add_trial(Trial("__fake", **kwargs))
|
runner.add_trial(Trial("__fake", **kwargs))
|
||||||
trials = runner.get_trials()
|
trials = runner.get_trials()
|
||||||
|
|
||||||
runner.step() # Start trial
|
while not runner.is_finished():
|
||||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
runner.step()
|
||||||
runner.step() # Process result, dispatch save
|
|
||||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
|
||||||
runner.step() # Process save
|
|
||||||
runner.step() # Error (transient), dispatch restore
|
|
||||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
|
||||||
self.assertEqual(trials[0].num_failures, 1)
|
|
||||||
runner.step() # Process restore
|
|
||||||
runner.step() # Error (transient), dispatch restore
|
|
||||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
|
||||||
self.assertEqual(trials[0].num_failures, 2)
|
|
||||||
runner.step() # Process restore
|
|
||||||
runner.step() # Error (terminal)
|
|
||||||
self.assertEqual(trials[0].status, Trial.ERROR)
|
self.assertEqual(trials[0].status, Trial.ERROR)
|
||||||
self.assertEqual(trials[0].num_failures, 3)
|
self.assertEqual(trials[0].num_failures, 3)
|
||||||
|
|
||||||
|
@ -178,13 +157,12 @@ class TrialRunnerTest2(unittest.TestCase):
|
||||||
runner.add_trial(Trial("__fake", **kwargs))
|
runner.add_trial(Trial("__fake", **kwargs))
|
||||||
trials = runner.get_trials()
|
trials = runner.get_trials()
|
||||||
|
|
||||||
runner.step() # Start trial
|
while not runner.is_finished():
|
||||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
runner.step()
|
||||||
runner.step() # Process result, dispatch save
|
|
||||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
|
||||||
runner.step() # Process save
|
|
||||||
runner.step() # Error
|
|
||||||
self.assertEqual(trials[0].status, Trial.ERROR)
|
self.assertEqual(trials[0].status, Trial.ERROR)
|
||||||
|
# Somehow with `fail_fast=True`, if one errors out, the others are
|
||||||
|
# then stopped with `TERMINATED` status.
|
||||||
|
self.assertEqual(trials[1].status, Trial.TERMINATED)
|
||||||
self.assertRaises(TuneError, lambda: runner.step())
|
self.assertRaises(TuneError, lambda: runner.step())
|
||||||
|
|
||||||
def testFailFastRaise(self):
|
def testFailFastRaise(self):
|
||||||
|
@ -203,13 +181,14 @@ class TrialRunnerTest2(unittest.TestCase):
|
||||||
runner.add_trial(Trial("__fake", **kwargs))
|
runner.add_trial(Trial("__fake", **kwargs))
|
||||||
trials = runner.get_trials()
|
trials = runner.get_trials()
|
||||||
|
|
||||||
runner.step() # Start trial
|
|
||||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
|
||||||
runner.step() # Process result, dispatch save
|
|
||||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
|
||||||
runner.step() # Process save
|
|
||||||
with self.assertRaises(Exception):
|
with self.assertRaises(Exception):
|
||||||
runner.step() # Error
|
while not runner.is_finished():
|
||||||
|
runner.step()
|
||||||
|
|
||||||
|
# Not critical checks. Only to showcase the difference
|
||||||
|
# with none raise type FailFast.
|
||||||
|
self.assertEqual(trials[0].status, Trial.RUNNING)
|
||||||
|
self.assertEqual(trials[1].status, Trial.PENDING)
|
||||||
|
|
||||||
def testCheckpointing(self):
|
def testCheckpointing(self):
|
||||||
ray.init(num_cpus=1, num_gpus=1)
|
ray.init(num_cpus=1, num_gpus=1)
|
||||||
|
@ -244,35 +223,38 @@ class TrialRunnerTest2(unittest.TestCase):
|
||||||
|
|
||||||
def testRestoreMetricsAfterCheckpointing(self):
|
def testRestoreMetricsAfterCheckpointing(self):
|
||||||
ray.init(num_cpus=1, num_gpus=1)
|
ray.init(num_cpus=1, num_gpus=1)
|
||||||
runner = TrialRunner()
|
|
||||||
|
observer = TrialResultObserver()
|
||||||
|
runner = TrialRunner(callbacks=[observer])
|
||||||
kwargs = {
|
kwargs = {
|
||||||
|
"stopping_criterion": {"training_iteration": 2},
|
||||||
"resources": Resources(cpu=1, gpu=1),
|
"resources": Resources(cpu=1, gpu=1),
|
||||||
"checkpoint_freq": 1,
|
"checkpoint_freq": 1,
|
||||||
}
|
}
|
||||||
runner.add_trial(Trial("__fake", **kwargs))
|
runner.add_trial(Trial("__fake", **kwargs))
|
||||||
trials = runner.get_trials()
|
trials = runner.get_trials()
|
||||||
|
|
||||||
runner.step() # Start trial
|
while not runner.is_finished():
|
||||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
runner.step()
|
||||||
self.assertEqual(ray.get(trials[0].runner.set_info.remote(1)), 1)
|
|
||||||
runner.step() # Process result, dispatch save
|
|
||||||
runner.step() # Process save
|
|
||||||
runner.trial_executor.stop_trial(trials[0])
|
|
||||||
kwargs["restore_path"] = trials[0].checkpoint.value
|
|
||||||
|
|
||||||
|
self.assertEqual(trials[0].status, Trial.TERMINATED)
|
||||||
|
|
||||||
|
kwargs["restore_path"] = trials[0].checkpoint.value
|
||||||
|
kwargs.pop("stopping_criterion")
|
||||||
kwargs.pop("checkpoint_freq") # No checkpointing for next trial
|
kwargs.pop("checkpoint_freq") # No checkpointing for next trial
|
||||||
runner.add_trial(Trial("__fake", **kwargs))
|
runner.add_trial(Trial("__fake", **kwargs))
|
||||||
trials = runner.get_trials()
|
trials = runner.get_trials()
|
||||||
|
|
||||||
runner.step() # Start trial, dispatch restore
|
observer.reset()
|
||||||
self.assertEqual(trials[0].status, Trial.TERMINATED)
|
while not observer.just_received_a_result():
|
||||||
self.assertEqual(trials[1].status, Trial.RUNNING)
|
runner.step()
|
||||||
runner.step() # Process restore
|
|
||||||
runner.step() # Process result
|
|
||||||
self.assertEqual(trials[1].last_result["timesteps_since_restore"], 10)
|
self.assertEqual(trials[1].last_result["timesteps_since_restore"], 10)
|
||||||
self.assertEqual(trials[1].last_result["iterations_since_restore"], 1)
|
self.assertEqual(trials[1].last_result["iterations_since_restore"], 1)
|
||||||
self.assertGreater(trials[1].last_result["time_since_restore"], 0)
|
self.assertGreater(trials[1].last_result["time_since_restore"], 0)
|
||||||
runner.step() # Process restore
|
|
||||||
|
while not observer.just_received_a_result():
|
||||||
|
runner.step()
|
||||||
|
|
||||||
self.assertEqual(trials[1].last_result["timesteps_since_restore"], 20)
|
self.assertEqual(trials[1].last_result["timesteps_since_restore"], 20)
|
||||||
self.assertEqual(trials[1].last_result["iterations_since_restore"], 2)
|
self.assertEqual(trials[1].last_result["iterations_since_restore"], 2)
|
||||||
self.assertGreater(trials[1].last_result["time_since_restore"], 0)
|
self.assertGreater(trials[1].last_result["time_since_restore"], 0)
|
||||||
|
@ -289,12 +271,9 @@ class TrialRunnerTest2(unittest.TestCase):
|
||||||
runner.add_trial(Trial("__fake", **kwargs))
|
runner.add_trial(Trial("__fake", **kwargs))
|
||||||
trials = runner.get_trials()
|
trials = runner.get_trials()
|
||||||
|
|
||||||
runner.step() # Start trial
|
while not runner.is_finished():
|
||||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
runner.step()
|
||||||
runner.step() # Process result
|
|
||||||
runner.step() # Process result, dispatch save
|
|
||||||
self.assertEqual(trials[0].last_result[DONE], True)
|
self.assertEqual(trials[0].last_result[DONE], True)
|
||||||
runner.step() # Process save
|
|
||||||
self.assertEqual(trials[0].has_checkpoint(), True)
|
self.assertEqual(trials[0].has_checkpoint(), True)
|
||||||
|
|
||||||
def testResultDone(self):
|
def testResultDone(self):
|
||||||
|
@ -308,10 +287,7 @@ class TrialRunnerTest2(unittest.TestCase):
|
||||||
runner.add_trial(Trial("__fake", **kwargs))
|
runner.add_trial(Trial("__fake", **kwargs))
|
||||||
trials = runner.get_trials()
|
trials = runner.get_trials()
|
||||||
|
|
||||||
runner.step()
|
while not runner.is_finished():
|
||||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
|
||||||
runner.step()
|
|
||||||
self.assertNotEqual(trials[0].last_result[DONE], True)
|
|
||||||
runner.step()
|
runner.step()
|
||||||
self.assertEqual(trials[0].last_result[DONE], True)
|
self.assertEqual(trials[0].last_result[DONE], True)
|
||||||
|
|
||||||
|
|
|
@ -25,16 +25,11 @@ from ray.tune.suggest._mock import _MockSuggestionAlgorithm
|
||||||
from ray.tune.suggest.suggestion import Searcher, ConcurrencyLimiter
|
from ray.tune.suggest.suggestion import Searcher, ConcurrencyLimiter
|
||||||
from ray.tune.suggest.search_generator import SearchGenerator
|
from ray.tune.suggest.search_generator import SearchGenerator
|
||||||
from ray.tune.syncer import SyncConfig
|
from ray.tune.syncer import SyncConfig
|
||||||
|
from ray.tune.tests.test_trial_runner_utils import TrialResultObserver
|
||||||
|
|
||||||
|
|
||||||
class TrialRunnerTest3(unittest.TestCase):
|
class TrialRunnerTest3(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
# Wait up to five seconds for placement groups when starting a trial
|
|
||||||
os.environ["TUNE_PLACEMENT_GROUP_WAIT_S"] = "5"
|
|
||||||
# Block for results even when placement groups are pending
|
|
||||||
os.environ["TUNE_TRIAL_STARTUP_GRACE_PERIOD"] = "0"
|
|
||||||
os.environ["TUNE_TRIAL_RESULT_WAIT_TIME_S"] = "99999"
|
|
||||||
|
|
||||||
os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "auto" # Reset default
|
os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "auto" # Reset default
|
||||||
|
|
||||||
self.tmpdir = tempfile.mkdtemp()
|
self.tmpdir = tempfile.mkdtemp()
|
||||||
|
@ -131,15 +126,9 @@ class TrialRunnerTest3(unittest.TestCase):
|
||||||
searcher = search_alg.searcher
|
searcher = search_alg.searcher
|
||||||
search_alg.add_configurations(experiments)
|
search_alg.add_configurations(experiments)
|
||||||
runner = TrialRunner(search_alg=search_alg)
|
runner = TrialRunner(search_alg=search_alg)
|
||||||
runner.step()
|
|
||||||
trials = runner.get_trials()
|
|
||||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
|
||||||
|
|
||||||
|
while not runner.is_finished():
|
||||||
runner.step()
|
runner.step()
|
||||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
|
||||||
|
|
||||||
runner.step()
|
|
||||||
self.assertEqual(trials[0].status, Trial.TERMINATED)
|
|
||||||
|
|
||||||
self.assertEqual(searcher.counter["result"], 1)
|
self.assertEqual(searcher.counter["result"], 1)
|
||||||
self.assertEqual(searcher.counter["complete"], 1)
|
self.assertEqual(searcher.counter["complete"], 1)
|
||||||
|
@ -204,18 +193,17 @@ class TrialRunnerTest3(unittest.TestCase):
|
||||||
runner = TrialRunner(search_alg=search_alg)
|
runner = TrialRunner(search_alg=search_alg)
|
||||||
runner.step()
|
runner.step()
|
||||||
trials = runner.get_trials()
|
trials = runner.get_trials()
|
||||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
while trials[0].status != Trial.TERMINATED:
|
||||||
|
runner.step()
|
||||||
|
|
||||||
runner.step()
|
runner.step()
|
||||||
self.assertEqual(trials[0].status, Trial.TERMINATED)
|
|
||||||
|
|
||||||
trials = runner.get_trials()
|
trials = runner.get_trials()
|
||||||
runner.step()
|
|
||||||
self.assertEqual(trials[1].status, Trial.RUNNING)
|
self.assertEqual(trials[1].status, Trial.RUNNING)
|
||||||
self.assertEqual(len(searcher.live_trials), 1)
|
self.assertEqual(len(searcher.live_trials), 1)
|
||||||
|
|
||||||
searcher.stall = True
|
searcher.stall = True
|
||||||
|
|
||||||
|
while trials[1].status != Trial.TERMINATED:
|
||||||
runner.step()
|
runner.step()
|
||||||
self.assertEqual(trials[1].status, Trial.TERMINATED)
|
self.assertEqual(trials[1].status, Trial.TERMINATED)
|
||||||
self.assertEqual(len(searcher.live_trials), 0)
|
self.assertEqual(len(searcher.live_trials), 0)
|
||||||
|
@ -231,8 +219,9 @@ class TrialRunnerTest3(unittest.TestCase):
|
||||||
self.assertEqual(trials[2].status, Trial.RUNNING)
|
self.assertEqual(trials[2].status, Trial.RUNNING)
|
||||||
self.assertEqual(len(searcher.live_trials), 1)
|
self.assertEqual(len(searcher.live_trials), 1)
|
||||||
|
|
||||||
|
while trials[2].status != Trial.TERMINATED:
|
||||||
runner.step()
|
runner.step()
|
||||||
self.assertEqual(trials[2].status, Trial.TERMINATED)
|
|
||||||
self.assertEqual(len(searcher.live_trials), 0)
|
self.assertEqual(len(searcher.live_trials), 0)
|
||||||
self.assertTrue(search_alg.is_finished())
|
self.assertTrue(search_alg.is_finished())
|
||||||
self.assertTrue(runner.is_finished())
|
self.assertTrue(runner.is_finished())
|
||||||
|
@ -445,9 +434,9 @@ class TrialRunnerTest3(unittest.TestCase):
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
runner.add_trial(trials[0])
|
runner.add_trial(trials[0])
|
||||||
runner.step() # Start trial
|
while not runner.is_finished():
|
||||||
runner.step() # Process result, dispatch save
|
# Start trial, process result, dispatch save and process save.
|
||||||
runner.step() # Process save
|
runner.step()
|
||||||
self.assertEqual(trials[0].status, Trial.TERMINATED)
|
self.assertEqual(trials[0].status, Trial.TERMINATED)
|
||||||
|
|
||||||
trials += [
|
trials += [
|
||||||
|
@ -460,10 +449,13 @@ class TrialRunnerTest3(unittest.TestCase):
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
runner.add_trial(trials[1])
|
runner.add_trial(trials[1])
|
||||||
runner.step() # Start trial
|
while not runner.is_finished():
|
||||||
runner.step() # Process result, dispatch save
|
# Start trial,
|
||||||
runner.step() # Process save
|
# Process result,
|
||||||
runner.step() # Error
|
# Dispatch save,
|
||||||
|
# Process save and
|
||||||
|
# Error.
|
||||||
|
runner.step()
|
||||||
self.assertEqual(trials[1].status, Trial.ERROR)
|
self.assertEqual(trials[1].status, Trial.ERROR)
|
||||||
|
|
||||||
trials += [
|
trials += [
|
||||||
|
@ -488,12 +480,14 @@ class TrialRunnerTest3(unittest.TestCase):
|
||||||
restored_trial = runner2.get_trial("trial_succ")
|
restored_trial = runner2.get_trial("trial_succ")
|
||||||
self.assertEqual(Trial.PENDING, restored_trial.status)
|
self.assertEqual(Trial.PENDING, restored_trial.status)
|
||||||
|
|
||||||
runner2.step() # Start trial
|
while not runner2.is_finished():
|
||||||
runner2.step() # Process result, dispatch save
|
# Start trial,
|
||||||
runner2.step() # Process save
|
# Process result, dispatch save
|
||||||
runner2.step() # Process result, dispatch save
|
# Process save
|
||||||
runner2.step() # Process save
|
# Process result, dispatch save
|
||||||
self.assertRaises(TuneError, runner2.step)
|
# Process save.
|
||||||
|
runner2.step()
|
||||||
|
self.assertEqual(restored_trial.status, Trial.TERMINATED)
|
||||||
|
|
||||||
def testTrialNoCheckpointSave(self):
|
def testTrialNoCheckpointSave(self):
|
||||||
"""Check that non-checkpointing trials *are* saved."""
|
"""Check that non-checkpointing trials *are* saved."""
|
||||||
|
@ -533,7 +527,8 @@ class TrialRunnerTest3(unittest.TestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
runner.step()
|
old_trials = runner.get_trials()
|
||||||
|
while not old_trials[2].has_reported_at_least_once:
|
||||||
runner.step()
|
runner.step()
|
||||||
|
|
||||||
runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=self.tmpdir)
|
runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=self.tmpdir)
|
||||||
|
@ -643,29 +638,42 @@ class TrialRunnerTest3(unittest.TestCase):
|
||||||
checkpoint_at_end=True,
|
checkpoint_at_end=True,
|
||||||
stopping_criterion={"training_iteration": 4},
|
stopping_criterion={"training_iteration": 4},
|
||||||
)
|
)
|
||||||
|
observer = TrialResultObserver()
|
||||||
runner = TrialRunner(
|
runner = TrialRunner(
|
||||||
local_checkpoint_dir=self.tmpdir,
|
local_checkpoint_dir=self.tmpdir,
|
||||||
checkpoint_period=0,
|
checkpoint_period=0,
|
||||||
trial_executor=RayTrialExecutor(result_buffer_length=7),
|
trial_executor=RayTrialExecutor(result_buffer_length=7),
|
||||||
|
callbacks=[observer],
|
||||||
)
|
)
|
||||||
runner.add_trial(trial)
|
runner.add_trial(trial)
|
||||||
|
|
||||||
runner.step() # start trial
|
while not observer.just_received_a_result():
|
||||||
|
runner.step()
|
||||||
runner.step() # run iteration 1
|
|
||||||
self.assertEqual(trial.last_result[TRAINING_ITERATION], 1)
|
self.assertEqual(trial.last_result[TRAINING_ITERATION], 1)
|
||||||
self.assertEqual(num_checkpoints(trial), 0)
|
self.assertEqual(num_checkpoints(trial), 0)
|
||||||
|
|
||||||
runner.step() # run iteration 2
|
while True:
|
||||||
|
runner.step()
|
||||||
|
if observer.just_received_a_result():
|
||||||
|
break
|
||||||
self.assertEqual(trial.last_result[TRAINING_ITERATION], 2)
|
self.assertEqual(trial.last_result[TRAINING_ITERATION], 2)
|
||||||
self.assertEqual(num_checkpoints(trial), 0)
|
self.assertEqual(num_checkpoints(trial), 0)
|
||||||
|
|
||||||
runner.step() # run iteration 3
|
while True:
|
||||||
|
runner.step()
|
||||||
|
if observer.just_received_a_result():
|
||||||
|
break
|
||||||
self.assertEqual(trial.last_result[TRAINING_ITERATION], 3)
|
self.assertEqual(trial.last_result[TRAINING_ITERATION], 3)
|
||||||
self.assertEqual(num_checkpoints(trial), 0)
|
self.assertEqual(num_checkpoints(trial), 0)
|
||||||
|
|
||||||
runner.step() # run iteration 4
|
while True:
|
||||||
|
runner.step()
|
||||||
|
if observer.just_received_a_result():
|
||||||
|
break
|
||||||
self.assertEqual(trial.last_result[TRAINING_ITERATION], 4)
|
self.assertEqual(trial.last_result[TRAINING_ITERATION], 4)
|
||||||
|
|
||||||
|
while not runner.is_finished():
|
||||||
|
runner.step()
|
||||||
self.assertEqual(num_checkpoints(trial), 1)
|
self.assertEqual(num_checkpoints(trial), 1)
|
||||||
|
|
||||||
def testUserCheckpoint(self):
|
def testUserCheckpoint(self):
|
||||||
|
|
|
@ -9,11 +9,14 @@ from collections import OrderedDict
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
from ray import tune
|
from ray import tune
|
||||||
from ray.exceptions import RayActorError
|
|
||||||
from ray.rllib import _register_all
|
from ray.rllib import _register_all
|
||||||
from ray.tune.checkpoint_manager import Checkpoint
|
from ray.tune.checkpoint_manager import Checkpoint
|
||||||
from ray.tune.logger import DEFAULT_LOGGERS, LoggerCallback, LegacyLoggerCallback
|
from ray.tune.logger import DEFAULT_LOGGERS, LoggerCallback, LegacyLoggerCallback
|
||||||
from ray.tune.ray_trial_executor import RayTrialExecutor
|
from ray.tune.ray_trial_executor import (
|
||||||
|
RayTrialExecutor,
|
||||||
|
ExecutorEvent,
|
||||||
|
ExecutorEventType,
|
||||||
|
)
|
||||||
from ray.tune.result import TRAINING_ITERATION
|
from ray.tune.result import TRAINING_ITERATION
|
||||||
from ray.tune.syncer import SyncConfig, SyncerCallback
|
from ray.tune.syncer import SyncConfig, SyncerCallback
|
||||||
|
|
||||||
|
@ -65,28 +68,25 @@ class TestCallback(Callback):
|
||||||
self.state["experiment_end"] = info
|
self.state["experiment_end"] = info
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(xwjiang): Move this to a testing util.
|
||||||
class _MockTrialExecutor(RayTrialExecutor):
|
class _MockTrialExecutor(RayTrialExecutor):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.next_trial = None
|
self.next_future_result = None
|
||||||
self.results = {}
|
|
||||||
self.should_fail_in_fetch_result = False
|
|
||||||
|
|
||||||
def fetch_result(self, trial):
|
def start_trial(self, trial: Trial):
|
||||||
if self.should_fail_in_fetch_result:
|
trial.status = Trial.RUNNING
|
||||||
raise RayActorError(
|
return True
|
||||||
"The actor died unexpectedly before finishing this task."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return [self.results.get(trial, {})]
|
|
||||||
|
|
||||||
def get_next_available_trial(self, timeout=None):
|
def continue_training(self, trial: Trial):
|
||||||
return self.next_trial or super().get_next_available_trial()
|
pass
|
||||||
|
|
||||||
|
def get_next_executor_event(self, live_trials, next_trial_exists):
|
||||||
|
return self.next_future_result
|
||||||
|
|
||||||
|
|
||||||
class TrialRunnerCallbacks(unittest.TestCase):
|
class TrialRunnerCallbacks(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
os.environ["TUNE_PLACEMENT_GROUP_WAIT_S"] = "1"
|
|
||||||
|
|
||||||
ray.init()
|
ray.init()
|
||||||
self.tmpdir = tempfile.mkdtemp()
|
self.tmpdir = tempfile.mkdtemp()
|
||||||
|
@ -110,7 +110,9 @@ class TrialRunnerCallbacks(unittest.TestCase):
|
||||||
for t in trials:
|
for t in trials:
|
||||||
self.trial_runner.add_trial(t)
|
self.trial_runner.add_trial(t)
|
||||||
|
|
||||||
self.executor.next_trial = trials[0]
|
self.executor.next_future_result = ExecutorEvent(
|
||||||
|
event_type=ExecutorEventType.PG_READY
|
||||||
|
)
|
||||||
self.trial_runner.step()
|
self.trial_runner.step()
|
||||||
|
|
||||||
# Trial 1 has been started
|
# Trial 1 has been started
|
||||||
|
@ -132,7 +134,9 @@ class TrialRunnerCallbacks(unittest.TestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.executor.next_trial = trials[1]
|
self.executor.next_future_result = ExecutorEvent(
|
||||||
|
event_type=ExecutorEventType.PG_READY
|
||||||
|
)
|
||||||
self.trial_runner.step()
|
self.trial_runner.step()
|
||||||
|
|
||||||
# Iteration not increased yet
|
# Iteration not increased yet
|
||||||
|
@ -148,7 +152,9 @@ class TrialRunnerCallbacks(unittest.TestCase):
|
||||||
cp = Checkpoint(Checkpoint.PERSISTENT, "__checkpoint", {TRAINING_ITERATION: 0})
|
cp = Checkpoint(Checkpoint.PERSISTENT, "__checkpoint", {TRAINING_ITERATION: 0})
|
||||||
|
|
||||||
# Let the first trial save a checkpoint
|
# Let the first trial save a checkpoint
|
||||||
self.executor.next_trial = trials[0]
|
self.executor.next_future_result = ExecutorEvent(
|
||||||
|
event_type=ExecutorEventType.SAVING_RESULT, trial=trials[0]
|
||||||
|
)
|
||||||
trials[0].saving_to = cp
|
trials[0].saving_to = cp
|
||||||
self.trial_runner.step()
|
self.trial_runner.step()
|
||||||
self.assertEqual(self.callback.state["trial_save"]["iteration"], 2)
|
self.assertEqual(self.callback.state["trial_save"]["iteration"], 2)
|
||||||
|
@ -156,8 +162,9 @@ class TrialRunnerCallbacks(unittest.TestCase):
|
||||||
|
|
||||||
# Let the second trial send a result
|
# Let the second trial send a result
|
||||||
result = {TRAINING_ITERATION: 1, "metric": 800, "done": False}
|
result = {TRAINING_ITERATION: 1, "metric": 800, "done": False}
|
||||||
self.executor.results[trials[1]] = result
|
self.executor.next_future_result = ExecutorEvent(
|
||||||
self.executor.next_trial = trials[1]
|
event_type=ExecutorEventType.TRAINING_RESULT, trial=trials[1], result=result
|
||||||
|
)
|
||||||
self.assertTrue(not trials[1].has_reported_at_least_once)
|
self.assertTrue(not trials[1].has_reported_at_least_once)
|
||||||
self.trial_runner.step()
|
self.trial_runner.step()
|
||||||
self.assertEqual(self.callback.state["trial_result"]["iteration"], 3)
|
self.assertEqual(self.callback.state["trial_result"]["iteration"], 3)
|
||||||
|
@ -167,25 +174,28 @@ class TrialRunnerCallbacks(unittest.TestCase):
|
||||||
|
|
||||||
# Let the second trial restore from a checkpoint
|
# Let the second trial restore from a checkpoint
|
||||||
trials[1].restoring_from = cp
|
trials[1].restoring_from = cp
|
||||||
self.executor.results[trials[1]] = trials[1].last_result
|
self.executor.next_future_result = ExecutorEvent(
|
||||||
|
event_type=ExecutorEventType.RESTORING_RESULT, trial=trials[1]
|
||||||
|
)
|
||||||
self.trial_runner.step()
|
self.trial_runner.step()
|
||||||
self.assertEqual(self.callback.state["trial_restore"]["iteration"], 4)
|
self.assertEqual(self.callback.state["trial_restore"]["iteration"], 4)
|
||||||
self.assertEqual(self.callback.state["trial_restore"]["trial"].trial_id, "two")
|
self.assertEqual(self.callback.state["trial_restore"]["trial"].trial_id, "two")
|
||||||
|
|
||||||
# Let the second trial finish
|
# Let the second trial finish
|
||||||
trials[1].restoring_from = None
|
trials[1].restoring_from = None
|
||||||
self.executor.results[trials[1]] = {
|
self.executor.next_future_result = ExecutorEvent(
|
||||||
TRAINING_ITERATION: 2,
|
event_type=ExecutorEventType.TRAINING_RESULT,
|
||||||
"metric": 900,
|
trial=trials[1],
|
||||||
"done": True,
|
result={TRAINING_ITERATION: 2, "metric": 900, "done": True},
|
||||||
}
|
)
|
||||||
self.trial_runner.step()
|
self.trial_runner.step()
|
||||||
self.assertEqual(self.callback.state["trial_complete"]["iteration"], 5)
|
self.assertEqual(self.callback.state["trial_complete"]["iteration"], 5)
|
||||||
self.assertEqual(self.callback.state["trial_complete"]["trial"].trial_id, "two")
|
self.assertEqual(self.callback.state["trial_complete"]["trial"].trial_id, "two")
|
||||||
|
|
||||||
# Let the first trial error
|
# Let the first trial error
|
||||||
self.executor.next_trial = trials[0]
|
self.executor.next_future_result = ExecutorEvent(
|
||||||
self.executor.should_fail_in_fetch_result = True
|
event_type=ExecutorEventType.ERROR, trial=trials[0]
|
||||||
|
)
|
||||||
self.trial_runner.step()
|
self.trial_runner.step()
|
||||||
self.assertEqual(self.callback.state["trial_fail"]["iteration"], 6)
|
self.assertEqual(self.callback.state["trial_fail"]["iteration"], 6)
|
||||||
self.assertEqual(self.callback.state["trial_fail"]["trial"].trial_id, "one")
|
self.assertEqual(self.callback.state["trial_fail"]["trial"].trial_id, "one")
|
||||||
|
|
|
@ -53,7 +53,6 @@ class TrialRunnerPlacementGroupTest(unittest.TestCase):
|
||||||
self.assertFalse(pg_manager._staging[pgf])
|
self.assertFalse(pg_manager._staging[pgf])
|
||||||
for pgf in pg_manager._ready:
|
for pgf in pg_manager._ready:
|
||||||
self.assertFalse(pg_manager._ready[pgf])
|
self.assertFalse(pg_manager._ready[pgf])
|
||||||
self.assertTrue(pg_manager._latest_staging_start_time)
|
|
||||||
|
|
||||||
num_non_removed_pgs = len(
|
num_non_removed_pgs = len(
|
||||||
[p for pid, p in placement_group_table().items() if p["state"] != "REMOVED"]
|
[p for pid, p in placement_group_table().items() if p["state"] != "REMOVED"]
|
||||||
|
@ -109,17 +108,6 @@ class TrialRunnerPlacementGroupTest(unittest.TestCase):
|
||||||
|
|
||||||
total_num_tracked = num_staging + num_ready + num_in_use + num_cached
|
total_num_tracked = num_staging + num_ready + num_in_use + num_cached
|
||||||
|
|
||||||
num_non_removed_pgs = len(
|
|
||||||
[
|
|
||||||
p
|
|
||||||
for pid, p in placement_group_table().items()
|
|
||||||
if p["state"] != "REMOVED"
|
|
||||||
]
|
|
||||||
)
|
|
||||||
num_removal_scheduled_pgs = len(
|
|
||||||
trial_executor._pg_manager._pgs_for_removal
|
|
||||||
)
|
|
||||||
|
|
||||||
# All trials should be scheduled
|
# All trials should be scheduled
|
||||||
this.assertEqual(
|
this.assertEqual(
|
||||||
scheduled,
|
scheduled,
|
||||||
|
@ -141,13 +129,6 @@ class TrialRunnerPlacementGroupTest(unittest.TestCase):
|
||||||
msg=f"Num tracked iter {iteration}",
|
msg=f"Num tracked iter {iteration}",
|
||||||
)
|
)
|
||||||
|
|
||||||
# The number of actual placement groups should match this
|
|
||||||
this.assertGreaterEqual(
|
|
||||||
max(scheduled, len(trials)) - num_finished + num_parallel_reuse,
|
|
||||||
num_non_removed_pgs - num_removal_scheduled_pgs,
|
|
||||||
msg=f"Num actual iter {iteration}",
|
|
||||||
)
|
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
out = tune.run(
|
out = tune.run(
|
||||||
train,
|
train,
|
||||||
|
|
22
python/ray/tune/tests/test_trial_runner_utils.py
Normal file
22
python/ray/tune/tests/test_trial_runner_utils.py
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
from ray.tune import Callback
|
||||||
|
|
||||||
|
|
||||||
|
class TrialResultObserver(Callback):
|
||||||
|
"""Helper class to control runner.step() count."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._counter = 0
|
||||||
|
self._last_counter = 0
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self._last_counter = self._counter
|
||||||
|
|
||||||
|
def just_received_a_result(self):
|
||||||
|
if self._last_counter == self._counter:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
self._last_counter = self._counter
|
||||||
|
return True
|
||||||
|
|
||||||
|
def on_trial_result(self, **kwargs):
|
||||||
|
self._counter += 1
|
|
@ -173,7 +173,6 @@ class PopulationBasedTrainingFileDescriptorTest(unittest.TestCase):
|
||||||
|
|
||||||
class PopulationBasedTrainingSynchTest(unittest.TestCase):
|
class PopulationBasedTrainingSynchTest(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
os.environ["TUNE_TRIAL_STARTUP_GRACE_PERIOD"] = "0"
|
|
||||||
ray.init(num_cpus=2)
|
ray.init(num_cpus=2)
|
||||||
|
|
||||||
def MockTrainingFuncSync(config, checkpoint_dir=None):
|
def MockTrainingFuncSync(config, checkpoint_dir=None):
|
||||||
|
|
|
@ -105,13 +105,6 @@ def _run(local_dir, driver_semaphore, trainer_semaphore):
|
||||||
|
|
||||||
|
|
||||||
class TuneInterruptionTest(unittest.TestCase):
|
class TuneInterruptionTest(unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
|
||||||
# Wait up to five seconds for placement groups when starting a trial
|
|
||||||
os.environ["TUNE_PLACEMENT_GROUP_WAIT_S"] = "5"
|
|
||||||
# Block for results even when placement groups are pending
|
|
||||||
os.environ["TUNE_TRIAL_STARTUP_GRACE_PERIOD"] = "0"
|
|
||||||
os.environ["TUNE_TRIAL_RESULT_WAIT_TIME_S"] = "99999"
|
|
||||||
|
|
||||||
def testExperimentInterrupted(self):
|
def testExperimentInterrupted(self):
|
||||||
local_dir = tempfile.mkdtemp()
|
local_dir = tempfile.mkdtemp()
|
||||||
# Unix platforms may default to "fork", which is problematic with
|
# Unix platforms may default to "fork", which is problematic with
|
||||||
|
@ -214,11 +207,6 @@ class TuneFailResumeGridTest(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.logdir = tempfile.mkdtemp()
|
self.logdir = tempfile.mkdtemp()
|
||||||
os.environ["TUNE_GLOBAL_CHECKPOINT_S"] = "0"
|
os.environ["TUNE_GLOBAL_CHECKPOINT_S"] = "0"
|
||||||
# Wait up to 1.5 seconds for placement groups when starting a trial
|
|
||||||
os.environ["TUNE_PLACEMENT_GROUP_WAIT_S"] = "1.5"
|
|
||||||
# Block for results even when placement groups are pending
|
|
||||||
os.environ["TUNE_TRIAL_STARTUP_GRACE_PERIOD"] = "0"
|
|
||||||
os.environ["TUNE_TRIAL_RESULT_WAIT_TIME_S"] = "99999"
|
|
||||||
|
|
||||||
# Change back to local_mode=True after this is resolved:
|
# Change back to local_mode=True after this is resolved:
|
||||||
# https://github.com/ray-project/ray/issues/13932
|
# https://github.com/ray-project/ray/issues/13932
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
import os
|
|
||||||
import requests
|
import requests
|
||||||
import socket
|
import socket
|
||||||
import subprocess
|
import subprocess
|
||||||
|
@ -28,11 +27,6 @@ def get_valid_port():
|
||||||
|
|
||||||
class TuneServerSuite(unittest.TestCase):
|
class TuneServerSuite(unittest.TestCase):
|
||||||
def basicSetup(self):
|
def basicSetup(self):
|
||||||
# Wait up to five seconds for placement groups when starting a trial
|
|
||||||
os.environ["TUNE_PLACEMENT_GROUP_WAIT_S"] = "5"
|
|
||||||
# Block for results even when placement groups are pending
|
|
||||||
os.environ["TUNE_TRIAL_STARTUP_GRACE_PERIOD"] = "0"
|
|
||||||
os.environ["TUNE_TRIAL_RESULT_WAIT_TIME_S"] = "99999"
|
|
||||||
|
|
||||||
ray.init(num_cpus=4, num_gpus=1)
|
ray.init(num_cpus=4, num_gpus=1)
|
||||||
port = get_valid_port()
|
port = get_valid_port()
|
||||||
|
|
|
@@ -4,7 +4,6 @@ import logging
 from typing import Dict, List, Optional
 import warnings

-from ray.tune.resources import Resources
 from ray.util.annotations import DeveloperAPI
 from ray.tune.trial import Trial, Checkpoint

@@ -79,11 +78,6 @@ class TrialExecutor(metaclass=_WarnOnDirectInheritanceMeta):
         self._trials_to_cache.clear()
         return self._cached_trial_state

-    @abstractmethod
-    def has_resources(self, resources: Resources) -> bool:
-        """Returns whether this runner has at least the specified resources."""
-        pass
-
     @abstractmethod
     def start_trial(self, trial: Trial) -> bool:
         """Starts the trial restoring from checkpoint if checkpoint is provided.
@@ -150,11 +144,6 @@ class TrialExecutor(metaclass=_WarnOnDirectInheritanceMeta):
         """
         pass

-    @abstractmethod
-    def get_running_trials(self) -> List[Trial]:
-        """Returns all running trials."""
-        pass
-
     def on_step_begin(self, trials: List[Trial]) -> None:
         """A hook called before running one step of the trial event loop.

@@ -176,26 +165,6 @@ class TrialExecutor(metaclass=_WarnOnDirectInheritanceMeta):
     def force_reconcilation_on_next_step_end(self) -> None:
         pass

-    @abstractmethod
-    def get_next_available_trial(self) -> Optional[Trial]:
-        """Blocking call that waits until one result is ready.
-
-        Returns:
-            Trial object that is ready for intermediate processing.
-        """
-        pass
-
-    @abstractmethod
-    def fetch_result(self, trial: Trial) -> List[Trial]:
-        """Fetches one result for the trial.
-
-        Assumes the trial is running.
-
-        Returns:
-            Result object for the trial.
-        """
-        pass
-
     @abstractmethod
     def debug_string(self) -> str:
         """Returns a human readable message for printing to the console."""
@@ -260,10 +229,6 @@ class TrialExecutor(metaclass=_WarnOnDirectInheritanceMeta):
         """
         pass

-    def in_staging_grace_period(self) -> bool:
-        """Returns True if trials have recently been staged."""
-        return False
-
     def set_max_pending_trials(self, max_pending: int) -> None:
         """Set the maximum number of allowed pending trials."""
         pass
@@ -15,7 +15,7 @@ from ray.tune import TuneError
 from ray.tune.callback import CallbackList
 from ray.tune.experiment import Experiment
 from ray.tune.insufficient_resources_manager import InsufficientResourcesManager
-from ray.tune.ray_trial_executor import RayTrialExecutor
+from ray.tune.ray_trial_executor import RayTrialExecutor, ExecutorEventType
 from ray.tune.result import (
     DEBUG_METRICS,
     DEFAULT_METRIC,
@@ -338,7 +338,6 @@ class TrialRunner:
         self._cached_trial_decisions = {}
         self._queued_trial_decisions = {}
         self._updated_queue = False
-        self._result_wait_time = int(os.getenv("TUNE_TRIAL_RESULT_WAIT_TIME_S", "1"))

         self._stop_queue = []
         self._should_stop_experiment = False  # used by TuneServer
@@ -685,23 +684,14 @@ class TrialRunner:
         ) and all(trial.is_finished() for trial in self._trials)
         return trials_done and self._search_alg.is_finished()

-    def step(self):
-        """Runs one step of the trial event loop.
+    def _update_trial_queue_and_get_next_trial(self) -> Optional[Trial]:
+        """Adding suggested trials to the live queue of trials (they start as PENDING trials).

-        Callers should typically run this method repeatedly in a loop. They
-        may inspect or modify the runner's state in between calls to step().
+        Returns:
+            next_trial: Trial
         """
         self._updated_queue = False

-        if self.is_finished():
-            raise TuneError("Called step when all trials finished?")
-        with warn_if_slow("on_step_begin"):
-            self.trial_executor.on_step_begin(self.get_trials())
-        with warn_if_slow("callbacks.on_step_begin"):
-            self._callbacks.on_step_begin(
-                iteration=self._iteration, trials=self._trials
-            )
-
         # This will contain the next trial to start
         next_trial = self._get_next_trial()  # blocking
         # Create pending trials. If the queue was updated before, only
@@ -715,46 +705,64 @@ class TrialRunner:
                 break
             num_pending_trials += 1

-        # Update status of staged placement groups
-        self.trial_executor.stage_and_update_status(self._live_trials)
+        return next_trial

-        def _start_trial(trial: Trial) -> bool:
-            """Helper function to start trial and call callbacks"""
-            with warn_if_slow("start_trial"):
-                if self.trial_executor.start_trial(trial):
-                    self._callbacks.on_trial_start(
-                        iteration=self._iteration, trials=self._trials, trial=trial
-                    )
-                    return True
-            return False
-
-        may_handle_events = True
-        if next_trial is not None:
-            if _start_trial(next_trial):
-                may_handle_events = False
-            elif next_trial.status != Trial.ERROR:
-                # Only try to start another trial if previous trial startup
-                # did not error (e.g. it just didn't start because its
-                # placement group is not ready, yet).
-                # Without this clause, this test fails:
-                # test_trial_runner_pg.py::
-                # TrialRunnerPlacementGroupHeterogeneousTest::
-                # testResourceDeadlock
-                next_trial = self.trial_executor.get_staged_trial()
-                if next_trial is not None:
-                    if _start_trial(next_trial):
-                        may_handle_events = False
-
-        if may_handle_events:
-            if self.trial_executor.get_running_trials():
-                timeout = self._result_wait_time
-                if self.trial_executor.in_staging_grace_period():
-                    timeout = 0.1
-                self._process_events(timeout=timeout)
-            else:
+    def _wait_and_handle_event(self, next_trial: Optional[Trial]):
+        try:
+            # Single wait of entire tune loop.
+            future_result = self.trial_executor.get_next_executor_event(
+                self._live_trials, next_trial is not None
+            )
+            if future_result.type == ExecutorEventType.PG_READY:
+                self._on_pg_ready(next_trial)
+            elif future_result.type == ExecutorEventType.NO_RUNNING_TRIAL_TIMEOUT:
                 self._insufficient_resources_manager.on_no_available_trials(
                     self.get_trials()
                 )
+            elif future_result.type == ExecutorEventType.YIELD:
+                pass
+            else:
+                trial = future_result.trial
+                result = future_result.result
+                if future_result.type == ExecutorEventType.ERROR:
+                    self._on_executor_error(trial, result)
+                elif future_result.type == ExecutorEventType.RESTORING_RESULT:
+                    self._on_restoring_result(trial)
+                else:
+                    assert future_result.type in (
+                        ExecutorEventType.SAVING_RESULT,
+                        ExecutorEventType.TRAINING_RESULT,
+                    ), f"Unexpected future type - {future_result.type}"
+                    if future_result.type == ExecutorEventType.TRAINING_RESULT:
+                        self._on_training_result(trial, result)
+                    else:
+                        self._on_saving_result(trial, result)
+                    self._post_process_on_training_saving_result(trial)
+        except Exception as e:
+            if e is TuneError:
+                raise e
+            else:
+                raise TuneError(traceback.format_exc())
+
+    def step(self):
+        """Runs one step of the trial event loop.
+
+        Callers should typically run this method repeatedly in a loop. They
+        may inspect or modify the runner's state in between calls to step().
+        """
+
+        if self.is_finished():
+            raise TuneError("Called step when all trials finished?")
+        with warn_if_slow("on_step_begin"):
+            self.trial_executor.on_step_begin(self.get_trials())
+        with warn_if_slow("callbacks.on_step_begin"):
+            self._callbacks.on_step_begin(
+                iteration=self._iteration, trials=self._trials
+            )
+
+        next_trial = self._update_trial_queue_and_get_next_trial()
+
+        self._wait_and_handle_event(next_trial)

         self._stop_experiment_if_needed()

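For orientation, here is a rough, non-authoritative sketch of the single-wait dispatch that the new `_wait_and_handle_event()` performs. The enum and namedtuple below are stand-ins for the real `ExecutorEventType` and executor event object; only the event names themselves come from this diff.

from collections import namedtuple
from enum import Enum, auto


class EventType(Enum):  # stand-in for ExecutorEventType
    PG_READY = auto()
    NO_RUNNING_TRIAL_TIMEOUT = auto()
    YIELD = auto()
    TRAINING_RESULT = auto()
    SAVING_RESULT = auto()
    RESTORING_RESULT = auto()
    ERROR = auto()


Event = namedtuple("Event", ["type", "trial", "result"])


def dispatch(event, on_pg_ready, on_no_running_trials, result_handlers):
    # Single dispatch point: every outcome of the one blocking wait is
    # routed from here, mirroring _wait_and_handle_event above.
    if event.type == EventType.PG_READY:
        on_pg_ready()
    elif event.type == EventType.NO_RUNNING_TRIAL_TIMEOUT:
        on_no_running_trials()
    elif event.type == EventType.YIELD:
        pass  # nothing to process; yield control back to the caller
    else:
        result_handlers[event.type](event.trial, event.result)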
@@ -770,12 +778,100 @@ class TrialRunner:

         if self.is_finished():
             self._server.shutdown()

+        self._reconcile_live_trials()
+
         with warn_if_slow("on_step_end"):
             self.trial_executor.on_step_end(self.get_trials())
         with warn_if_slow("callbacks.on_step_end"):
             self._callbacks.on_step_end(iteration=self._iteration, trials=self._trials)

-        self._reconcile_live_trials()
+    def _on_pg_ready(self, next_trial: Optional[Trial]):
+        def _start_trial(trial: Trial) -> bool:
+            """Helper function to start trial and call callbacks"""
+            with warn_if_slow("start_trial"):
+                if self.trial_executor.start_trial(trial):
+                    self._callbacks.on_trial_start(
+                        iteration=self._iteration, trials=self._trials, trial=trial
+                    )
+                    return True
+            return False
+
+        assert next_trial is not None
+        logger.info(f"starting {next_trial}")
+        if not _start_trial(next_trial) and next_trial.status != Trial.ERROR:
+            # Only try to start another trial if previous trial startup
+            # did not error (e.g. it just didn't start because its
+            # placement group is not ready, yet).
+            # Without this clause, this test fails:
+            # test_trial_runner_pg.py::
+            # TrialRunnerPlacementGroupHeterogeneousTest::
+            # testResourceDeadlock
+            next_trial = self.trial_executor.get_staged_trial()
+            if next_trial is not None:
+                # Must be able to start.
+                assert _start_trial(next_trial)
+            else:
+                logger.info(f"reconciling {self.get_trials()}")
+                self.trial_executor._pg_manager.reconcile_placement_groups(
+                    self.get_trials()
+                )
+
+    def _on_saving_result(self, trial, result):
+        with warn_if_slow("process_trial_save") as _profile:
+            self._process_trial_save(trial, result)
+        with warn_if_slow("callbacks.on_trial_save"):
+            self._callbacks.on_trial_save(
+                iteration=self._iteration, trials=self._trials, trial=trial
+            )
+        if _profile.too_slow and trial.sync_on_checkpoint:
+            # TODO(ujvl): Suggest using cloud checkpointing once
+            # API has converged.
+
+            msg = (
+                "Consider turning off forced head-worker trial "
+                "checkpoint syncs by setting sync_on_checkpoint=False"
+                ". Note that this may result in faulty trial "
+                "restoration if a failure occurs while the checkpoint "
+                "is being synced from the worker to the head node."
+            )
+
+            if trial.location.hostname and (
+                trial.location.hostname != get_node_ip_address()
+            ):
+                if log_once("tune_head_worker_checkpoint"):
+                    logger.warning(msg)
+
+    def _on_restoring_result(self, trial):
+        with warn_if_slow("process_trial_restore"):
+            self._process_trial_restore(trial)
+        with warn_if_slow("callbacks.on_trial_restore"):
+            self._callbacks.on_trial_restore(
+                iteration=self._iteration, trials=self._trials, trial=trial
+            )
+
+    def _on_training_result(self, trial, result):
+        if not isinstance(result, list):
+            result = [result]
+        with warn_if_slow("process_trial_result"):
+            self._process_trial_results(trial, result)
+
+    def _post_process_on_training_saving_result(self, trial):
+        # `self._queued_trial_decisions` now contains a final decision
+        # based on all results
+        if trial not in self._cached_trial_decisions:
+            final_decision = self._queued_trial_decisions.pop(trial.trial_id, None)
+            if final_decision:
+                self._execute_action(trial, final_decision)
+
+    def _on_executor_error(self, trial, result):
+        error_msg = f"Trial {trial}: Error processing event."
+        if self._fail_fast == TrialRunner.RAISE:
+            logger.error(error_msg)
+            raise
+        else:
+            logger.exception(error_msg)
+            self._process_trial_failure(trial, result)
+
     def get_trial(self, tid):
         trial = [t for t in self._trials if t.trial_id == tid]
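These handlers lean on Tune's `warn_if_slow` context manager and its `too_slow` flag. As a rough mental model only, not Tune's actual implementation, such a timing guard can be approximated like this (the threshold value and names are assumptions):

import time
from contextlib import contextmanager


class _Profile:
    def __init__(self):
        self.too_slow = False


@contextmanager
def warn_if_slow_sketch(name, threshold_s=0.5):
    # Times the wrapped block and flags/report it when it runs long.
    prof = _Profile()
    start = time.monotonic()
    try:
        yield prof
    finally:
        duration = time.monotonic() - start
        if duration > threshold_s:
            prof.too_slow = True
            print(f"{name} took {duration:.3f} s")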
@@ -857,75 +953,8 @@ class TrialRunner:
         logger.debug("Running trial {}".format(trial))
         return trial

-    def _process_events(self, timeout: Optional[float] = None):
-        # TODO(ujvl): Consider combining get_next_available_trial and
-        # fetch_result functionality so that we don't timeout on fetch.
-        trial = self.trial_executor.get_next_available_trial(
-            timeout=timeout
-        )  # blocking
-        if not trial:
-            return
-        if trial.is_restoring:
-            with warn_if_slow("process_trial_restore"):
-                self._process_trial_restore(trial)
-            with warn_if_slow("callbacks.on_trial_restore"):
-                self._callbacks.on_trial_restore(
-                    iteration=self._iteration, trials=self._trials, trial=trial
-                )
-        elif trial.is_saving:
-            with warn_if_slow("process_trial_save") as _profile:
-                self._process_trial_save(trial)
-            with warn_if_slow("callbacks.on_trial_save"):
-                self._callbacks.on_trial_save(
-                    iteration=self._iteration, trials=self._trials, trial=trial
-                )
-            if _profile.too_slow and trial.sync_on_checkpoint:
-                # TODO(ujvl): Suggest using cloud checkpointing once
-                # API has converged.
-
-                msg = (
-                    "Consider turning off forced head-worker trial "
-                    "checkpoint syncs by setting sync_on_checkpoint=False"
-                    ". Note that this may result in faulty trial "
-                    "restoration if a failure occurs while the checkpoint "
-                    "is being synced from the worker to the head node."
-                )
-
-                if trial.location.hostname and (
-                    trial.location.hostname != get_node_ip_address()
-                ):
-                    if log_once("tune_head_worker_checkpoint"):
-                        logger.warning(msg)
-
-        else:
-            with warn_if_slow("process_trial"):
-                self._process_trial(trial)
-
-        # `self._queued_trial_decisions` now contains a final decision
-        # based on all results
-        if trial not in self._cached_trial_decisions:
-            final_decision = self._queued_trial_decisions.pop(trial.trial_id, None)
-            if final_decision:
-                self._execute_action(trial, final_decision)
-
-    def _process_trial(self, trial):
-        """Processes a trial result.
-
-        Fetches the trial's latest result and makes a scheduling decision
-        regarding its next action. If a checkpoint is taken, the decided
-        action is cached and acted on only after the checkpoint is later
-        processed (see `_process_trial_save`). Otherwise the decision is
-        acted on immediately.
-
-        If multiple results are received (e.g. because of buffering), all
-        results are processed and the final action is determined. STOP
-        takes precedence over PAUSE, which takes precedence over CONTINUE.
-
-        Args:
-            trial (Trial): Trial with a result ready to be processed.
-        """
-        try:
-            results = self.trial_executor.fetch_result(trial)
+    def _process_trial_results(self, trial, results):
+        logger.debug(f"process_trial_results {results}")
         with warn_if_slow(
             "process_trial_results",
             message="Processing trial results took {duration:.3f} s, "
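The removed `_process_trial` docstring spells out how buffered results collapse into a single action: STOP takes precedence over PAUSE, which takes precedence over CONTINUE. A toy illustration of that precedence rule, using stand-in strings rather than Tune's scheduler constants:

PRECEDENCE = {"STOP": 2, "PAUSE": 1, "CONTINUE": 0}


def final_decision(decisions):
    # With buffered results, the strongest decision seen wins.
    if not decisions:
        return "CONTINUE"
    return max(decisions, key=lambda d: PRECEDENCE[d])


assert final_decision(["CONTINUE", "PAUSE", "CONTINUE"]) == "PAUSE"
assert final_decision(["PAUSE", "STOP"]) == "STOP"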
@@ -955,14 +984,6 @@ class TrialRunner:
                     # If the decision is to stop the trial,
                     # ignore all results that came after that.
                     break
-        except Exception:
-            error_msg = "Trial %s: Error processing event." % trial
-            if self._fail_fast == TrialRunner.RAISE:
-                logger.error(error_msg)
-                raise
-            else:
-                logger.exception(error_msg)
-                self._process_trial_failure(trial, traceback.format_exc())

     def _process_trial_result(self, trial, result):
         result.update(trial_id=trial.trial_id)
@@ -1014,6 +1035,7 @@ class TrialRunner:
         self._checkpoint_trial_if_needed(trial, force=force_checkpoint)

         if trial.is_saving:
+            logger.info(f"caching trial decision {trial}")
             # Cache decision to execute on after the save is processed.
             # This prevents changing the trial's state or kicking off
             # another training step prematurely.
@@ -1085,7 +1107,7 @@ class TrialRunner:
             )
         )

-    def _process_trial_save(self, trial):
+    def _process_trial_save(self, trial, result):
         """Processes a trial save.

         Acts on the decision cached during the last `_process_trial` call.
@@ -1094,19 +1116,9 @@ class TrialRunner:
             trial (Trial): Trial being saved.
         """
         logger.debug("Trial %s: Processing trial save.", trial)
-        checkpoint_value = None
-        try:
-            results = self.trial_executor.fetch_result(trial)
-            checkpoint_value = results[-1]
-        except Exception:
-            logger.exception("Trial %s: Error processing result.", trial)
-            if self._fail_fast == TrialRunner.RAISE:
-                raise
-            self._process_trial_failure(trial, traceback.format_exc())
-
-        if checkpoint_value:
         try:
-            trial.saving_to.value = checkpoint_value
+            trial.saving_to.value = result
             self._callbacks.on_checkpoint(
                 iteration=self._iteration,
                 trials=self._trials,
@@ -1117,15 +1129,13 @@ class TrialRunner:
             if trial.checkpoint.storage != Checkpoint.MEMORY:
                 self.trial_executor.mark_trial_to_checkpoint(trial)
         except Exception:
-            logger.exception(
-                "Trial %s: Error handling checkpoint %s", trial, checkpoint_value
-            )
+            logger.exception("Trial %s: Error handling checkpoint %s", trial, result)
             if self._fail_fast == TrialRunner.RAISE:
                 raise

         trial.saving_to = None
         decision = self._cached_trial_decisions.pop(trial.trial_id, None)
-        if decision and checkpoint_value:
+        if decision and result:
             self._queue_decision(trial, decision)

     def _process_trial_restore(self, trial):
@@ -1135,18 +1145,11 @@ class TrialRunner:
             trial (Trial): Trial being restored.
         """
         logger.debug("Trial %s: Processing trial restore.", trial)
-        try:
-            self.trial_executor.fetch_result(trial)
-            trial.on_restore()
-            logger.debug("Trial %s: Restore processed successfully", trial)
-            self.trial_executor.set_status(trial, Trial.RUNNING)
-            self.trial_executor.continue_training(trial)
-            self._live_trials.add(trial)
-        except Exception:
-            logger.exception("Trial %s: Error processing restore.", trial)
-            if self._fail_fast == TrialRunner.RAISE:
-                raise
-            self._process_trial_failure(trial, traceback.format_exc())
+        trial.on_restore()
+        logger.debug("Trial %s: Restore processed successfully", trial)
+        self.trial_executor.set_status(trial, Trial.RUNNING)
+        self.trial_executor.continue_training(trial)
+        self._live_trials.add(trial)

     def _process_trial_failure(self, trial, error_msg):
         """Handle trial failure.
@@ -1217,6 +1220,10 @@ class TrialRunner:
             trial (Trial): Trial to recover.
             error_msg (str): Error message from prior to invoking this method.
         """
+        self._cached_trial_decisions.pop(trial.trial_id, None)
+        # Resetting this, in case that the trial is in saving status when it crashes.
+        if trial.is_saving:
+            trial.saving_to = None
         if trial.is_restoring:
             # Restore was unsuccessful, try again without checkpoint.
             trial.clear_checkpoint()
@@ -1229,6 +1236,8 @@ class TrialRunner:
                 "Trial %s: Attempting to restore " "trial state from last checkpoint.",
                 trial,
             )
+            # TODO(xwjiang): For better consistency, consider not starting
+            # trials here. Instead rely on requeuing the trial.
             started = self.trial_executor.start_trial(trial)
             if not started:
                 requeue_trial = True
@@ -306,6 +306,16 @@ def run(
             "removing this argument from your call to `tune.run()`"
         )

+    # Starting deprecation in ray 1.10.
+    if os.environ.get("TUNE_TRIAL_RESULT_WAIT_TIME_S") is not None:
+        warnings.warn("`TUNE_TRIAL_RESULT_WAIT_TIME_S` is deprecated.")
+
+    if os.environ.get("TUNE_TRIAL_STARTUP_GRACE_PERIOD") is not None:
+        warnings.warn("`TUNE_TRIAL_STARTUP_GRACE_PERIOD` is deprecated.")
+
+    if os.environ.get("TUNE_PLACEMENT_GROUP_WAIT_S") is not None:
+        warnings.warn("`TUNE_PLACEMENT_GROUP_WAIT_S` is deprecated.")
+
     # NO CODE IS TO BE ADDED ABOVE THIS COMMENT
     # remote_run_kwargs must be defined before any other
     # code is ran to ensure that at this point,
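The three checks added above follow one pattern: warn when an environment variable that no longer has any effect is still set. A standalone sketch of the same idea (the helper name is made up; the variable names are the ones deprecated by this change):

import os
import warnings

_REMOVED_ENV_VARS = (
    "TUNE_TRIAL_RESULT_WAIT_TIME_S",
    "TUNE_TRIAL_STARTUP_GRACE_PERIOD",
    "TUNE_PLACEMENT_GROUP_WAIT_S",
)


def warn_about_removed_env_vars():
    # Warn for every removed variable that is still set in the environment.
    for name in _REMOVED_ENV_VARS:
        if os.environ.get(name) is not None:
            warnings.warn(f"`{name}` is deprecated.")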
@@ -316,16 +316,16 @@ class PlacementGroupManager:
         self._cached_pgs: Dict[PlacementGroup, PlacementGroupFactory] = {}

         # Placement groups scheduled for delayed removal.
+        # This is used as a damper to filter out some high frequency change
+        # in resources request.
+        # Only PGs that have never been used go here.
+        # TODO(xwjiang): `self._pgs_for_removal` and `self._unstaged_xxx`
+        # are really the same now. We should consolidate to using one.
+        # Also `remove_placement_group` method should just be combined with
+        # `unstage_unused_xxx`.
         self._pgs_for_removal: Dict[PlacementGroup, float] = {}
         self._removal_delay = TUNE_PLACEMENT_GROUP_REMOVAL_DELAY

-        # Latest PG staging time to check if still in grace period.
-        self._latest_staging_start_time = time.time()
-
-        # Seconds we wait for a trial to come up before we make blocking calls
-        # to process events
-        self._grace_period = float(os.getenv("TUNE_TRIAL_STARTUP_GRACE_PERIOD", 10.0))
-
         self._max_staging = max_staging

     def set_max_staging(self, max_staging: int):
@@ -440,7 +440,6 @@ class PlacementGroupManager:

         self._staging[pgf].add(pg)
         self._staging_futures[pg.ready()] = (pgf, pg)
-        self._latest_staging_start_time = time.time()

         return True

@@ -461,11 +460,17 @@ class PlacementGroupManager:
         ready, _ = ray.wait(list(self._staging_futures.keys()), timeout=0)

         for ready_fut in ready:
+            self.handle_ready_future(ready_fut)
+
+    def handle_ready_future(self, ready_fut):
         ready_pgf, ready_pg = self._staging_futures.pop(ready_fut)

         self._staging[ready_pgf].remove(ready_pg)
         self._ready[ready_pgf].add(ready_pg)

+    def get_staging_future_list(self):
+        return list(self._staging_futures.keys())
+
     def get_full_actor_cls(
         self, trial: "Trial", actor_cls: ActorClass
     ) -> Optional[ActorClass]:
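With `get_staging_future_list()` and `handle_ready_future()` exposed, a caller can fold placement group readiness into the same `ray.wait()` that watches training and saving futures. The sketch below shows the shape of that single wait; it is an illustration under assumed names, not how RayTrialExecutor actually implements it.

import ray


def wait_for_next_event(pg_manager, other_futures, timeout_s=5.0):
    # One ray.wait() over staged-PG futures plus everything else.
    staging_futures = pg_manager.get_staging_future_list()
    ready, _ = ray.wait(
        staging_futures + list(other_futures), num_returns=1, timeout=timeout_s
    )
    if not ready:
        return None  # nothing became ready within the timeout
    future = ready[0]
    if future in staging_futures:
        pg_manager.handle_ready_future(future)  # moves the PG into the ready set
        return "PG_READY"
    return future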
@@ -602,7 +607,7 @@ class PlacementGroupManager:
     def clean_cached_pg(self, pg: PlacementGroup):
         self._cached_pgs.pop(pg)

-    def return_pg(self, trial: "Trial"):
+    def remove_from_in_use(self, trial: "Trial") -> PlacementGroup:
         """Return pg back to Core scheduling.

         Args:
@@ -612,7 +617,7 @@ class PlacementGroupManager:
         pg = self._in_use_trials.pop(trial)
         self._in_use_pgs.pop(pg)

-        self.remove_pg(pg)
+        return pg

     def _unstage_unused_pg(
         self, pgf: PlacementGroupFactory
@@ -664,13 +669,6 @@ class PlacementGroupManager:

         return trial_pg

-    def in_staging_grace_period(self):
-        return (
-            self._staging_futures
-            and self._grace_period
-            and time.time() <= self._latest_staging_start_time + self._grace_period
-        )
-
     def reconcile_placement_groups(self, trials: List["Trial"]):
         """Reconcile placement groups to match requirements.
