Mirror of https://github.com/vale981/ray, synced 2025-03-04 17:41:43 -05:00
[tune] Add timeout to retry_fn to catch hanging syncs (#28155)
Syncing sometimes hangs in pyarrow for unknown reasons. We should introduce a timeout for these syncing operations.

Signed-off-by: Kai Fricke <kai@anyscale.com>
This commit is contained in:
parent 8692eb6208
commit 3590a86db0

8 changed files with 208 additions and 54 deletions
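The user-visible part of this change is the new `sync_timeout` field on `SyncConfig`, which `TrialRunner` now forwards to each `Trainable` via `trial_executor.setup(...)`. Below is a minimal sketch of how a run could lower the timeout; the `tune.run`/`upload_dir` wiring and the bucket path are assumptions for illustration and are not part of this diff.

```python
from ray import tune
from ray.tune.syncer import SyncConfig


def my_trainable(config):
    # Hypothetical trainable; stands in for whatever the experiment runs.
    tune.report(score=1)


# `upload_dir` and the bucket are assumptions for this sketch; `sync_timeout`
# is the new field added by this commit (default DEFAULT_SYNC_TIMEOUT = 1800 s).
tune.run(
    my_trainable,
    sync_config=SyncConfig(
        upload_dir="s3://my-bucket/tune-results",
        sync_timeout=600,  # abort hanging trial-to-cloud syncs after 10 minutes
    ),
)
```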
@@ -217,7 +217,7 @@ class RayTrialExecutor:

         self._has_cleaned_up_pgs = False
         self._reuse_actors = reuse_actors
-        # The maxlen will be updated when `set_max_pending_trials()` is called
+        # The maxlen will be updated when `setup(max_pending_trials)` is called
         self._cached_actor_pg = deque(maxlen=1)
         self._pg_manager = _PlacementGroupManager(prefix=_get_tune_pg_prefix())
         self._staged_trials = set()
@@ -235,16 +235,20 @@ class RayTrialExecutor:
         self._buffer_max_time_s = float(
             os.getenv("TUNE_RESULT_BUFFER_MAX_TIME_S", 100.0)
         )
+        self._trainable_kwargs = {}

-    def set_max_pending_trials(self, max_pending: int) -> None:
+    def setup(
+        self, max_pending_trials: int, trainable_kwargs: Optional[Dict] = None
+    ) -> None:
         if len(self._cached_actor_pg) > 0:
             logger.warning(
                 "Cannot update maximum number of queued actors for reuse "
                 "during a run."
             )
         else:
-            self._cached_actor_pg = deque(maxlen=max_pending)
-            self._pg_manager.set_max_staging(max_pending)
+            self._cached_actor_pg = deque(maxlen=max_pending_trials)
+            self._pg_manager.set_max_staging(max_pending_trials)
+        self._trainable_kwargs = trainable_kwargs or {}

     def set_status(self, trial: Trial, status: str) -> None:
         """Sets status and checkpoints metadata if needed.
@@ -377,6 +381,9 @@ class RayTrialExecutor:
             kwargs["remote_checkpoint_dir"] = trial.remote_checkpoint_dir
             kwargs["custom_syncer"] = trial.custom_syncer

+            if self._trainable_kwargs:
+                kwargs.update(self._trainable_kwargs)
+
             # Throw a meaningful error if trainable does not use the
             # new API
             sig = inspect.signature(trial.get_trainable_cls())
@@ -198,6 +198,8 @@ class _ExperimentCheckpointManager:
         exclude = ["*/checkpoint_*"]

         if self._syncer:
+            # Todo: Implement sync_timeout for experiment-level syncing
+            # (it is currently only used for trainable-to-cloud syncing)
             if force:
                 # Wait until previous sync command finished
                 self._syncer.wait()
@@ -341,7 +343,13 @@ class TrialRunner:
         else:
             # Manual override
             self._max_pending_trials = int(max_pending_trials)
-        self.trial_executor.set_max_pending_trials(self._max_pending_trials)
+
+        sync_config = sync_config or SyncConfig()
+
+        self.trial_executor.setup(
+            max_pending_trials=self._max_pending_trials,
+            trainable_kwargs={"sync_timeout": sync_config.sync_timeout},
+        )

         self._metric = metric

@@ -385,7 +393,6 @@ class TrialRunner:
         if self._local_checkpoint_dir:
             os.makedirs(self._local_checkpoint_dir, exist_ok=True)

-        sync_config = sync_config or SyncConfig()
         self._remote_checkpoint_dir = remote_checkpoint_dir

         self._syncer = get_node_to_storage_syncer(sync_config)
@@ -40,6 +40,9 @@ logger = logging.getLogger(__name__)
 # Syncing period for syncing checkpoints between nodes or to cloud.
 DEFAULT_SYNC_PERIOD = 300

+# Default sync timeout after which syncing processes are aborted
+DEFAULT_SYNC_TIMEOUT = 1800
+
 _EXCLUDE_FROM_SYNC = [
     "./checkpoint_-00001",
     "./checkpoint_tmp*",
@@ -85,6 +88,8 @@ class SyncConfig:
             is asynchronous and best-effort. This does not affect persistent
             storage syncing. Defaults to True.
         sync_period: Syncing period for syncing between nodes.
+        sync_timeout: Timeout after which running sync processes are aborted.
+            Currently only affects trial-to-cloud syncing.

     """

@@ -93,6 +98,7 @@ class SyncConfig:

     sync_on_checkpoint: bool = True
     sync_period: int = DEFAULT_SYNC_PERIOD
+    sync_timeout: int = DEFAULT_SYNC_TIMEOUT

     def _repr_html_(self) -> str:
         """Generate an HTML representation of the SyncConfig.
@@ -499,7 +499,7 @@ class RayExecutorPlacementGroupTest(unittest.TestCase):

         executor = RayTrialExecutor(reuse_actors=True)
         executor._pg_manager = pgm
-        executor.set_max_pending_trials(1)
+        executor.setup(max_pending_trials=1)

         def train(config):
             yield 1
@@ -1,14 +1,16 @@
 import json
 import os
 import tempfile
+import time
 from typing import Dict, Union
+from unittest.mock import patch

 import pytest

 import ray
 from ray import tune
 from ray.air import session, Checkpoint
-from ray.air._internal.remote_storage import download_from_uri
+from ray.air._internal.remote_storage import download_from_uri, upload_to_uri
 from ray.tune.trainable import wrap_function


@@ -188,6 +190,42 @@ def test_checkpoint_object_no_sync(tmpdir):
     trainable.restore_from_object(obj)


+@pytest.mark.parametrize("hanging", [True, False])
+def test_sync_timeout(tmpdir, hanging):
+    orig_upload_fn = upload_to_uri
+
+    def _hanging_upload(*args, **kwargs):
+        time.sleep(200 if hanging else 0)
+        orig_upload_fn(*args, **kwargs)
+
+    trainable = SavingTrainable(
+        "object",
+        remote_checkpoint_dir=f"memory:///test/location_hanging_{hanging}",
+        sync_timeout=0.5,
+    )
+
+    with patch("ray.air.checkpoint.upload_to_uri", _hanging_upload):
+        trainable.save()
+
+    check_dir = tmpdir / "check_save_obj"
+
+    try:
+        download_from_uri(
+            uri=f"memory:///test/location_hanging_{hanging}", local_path=str(check_dir)
+        )
+    except FileNotFoundError:
+        hung = True
+    else:
+        hung = False
+
+    assert hung == hanging
+
+    if hanging:
+        assert not check_dir.exists()
+    else:
+        assert check_dir.listdir()
+
+
 if __name__ == "__main__":
     import sys

@@ -1,45 +1,94 @@
-import unittest
+import time

+import pytest
+
 from ray.tune.search.variant_generator import format_vars
+from ray.tune.utils.util import retry_fn


-class TuneUtilsTest(unittest.TestCase):
-    def testFormatVars(self):
-        # Format brackets correctly
-        self.assertTrue(
-            format_vars(
-                {
-                    ("a", "b", "c"): 8.1234567,
-                    ("a", "b", "d"): [7, 8],
-                    ("a", "b", "e"): [[[3, 4]]],
-                }
-            ),
-            "c=8.12345,d=7_8,e=3_4",
+def test_format_vars():
+
+    # Format brackets correctly
+    assert (
+        format_vars(
+            {
+                ("a", "b", "c"): 8.1234567,
+                ("a", "b", "d"): [7, 8],
+                ("a", "b", "e"): [[[3, 4]]],
+            }
         )
-        # Sorted by full keys, but only last key is reported
-        self.assertTrue(
-            format_vars(
-                {
-                    ("a", "c", "x"): [7, 8],
-                    ("a", "b", "x"): 8.1234567,
-                }
-            ),
-            "x=8.12345,x=7_8",
+        == "c=8.1235,d=7_8,e=3_4"
+    )
+    # Sorted by full keys, but only last key is reported
+    assert (
+        format_vars(
+            {
+                ("a", "c", "x"): [7, 8],
+                ("a", "b", "x"): 8.1234567,
+            }
         )
-        # Filter out invalid chars. It's ok to have empty keys or values.
-        self.assertTrue(
-            format_vars(
-                {
-                    ("a c?x"): " <;%$ok ",
-                    ("some"): " ",
-                }
-            ),
-            "a_c_x=ok,some=",
+        == "x=8.1235,x=7_8"
+    )
+    # Filter out invalid chars. It's ok to have empty keys or values.
+    assert (
+        format_vars(
+            {
+                ("a c?x",): " <;%$ok ",
+                ("some",): " ",
+            }
         )
+        == "a_c_x=ok,some="
+    )
+
+
+def test_retry_fn_repeat(tmpdir):
+    success = tmpdir / "success"
+    marker = tmpdir / "marker"
+
+    def _fail_once():
+        if marker.exists():
+            success.write_text(".", encoding="utf-8")
+            return
+        marker.write_text(".", encoding="utf-8")
+        raise RuntimeError("Failing")
+
+    assert not success.exists()
+    assert not marker.exists()
+
+    assert retry_fn(
+        fn=_fail_once,
+        exception_type=RuntimeError,
+        sleep_time=0,
+    )
+
+    assert success.exists()
+    assert marker.exists()
+
+
+def test_retry_fn_timeout(tmpdir):
+    success = tmpdir / "success"
+    marker = tmpdir / "marker"
+
+    def _fail_once():
+        if not marker.exists():
+            marker.write_text(".", encoding="utf-8")
+            raise RuntimeError("Failing")
+        time.sleep(5)
+        success.write_text(".", encoding="utf-8")
+        return
+
+    assert not success.exists()
+    assert not marker.exists()
+
+    assert not retry_fn(
+        fn=_fail_once, exception_type=RuntimeError, sleep_time=0, timeout=0.1
+    )
+
+    assert not success.exists()
+    assert marker.exists()


 if __name__ == "__main__":
     import pytest
     import sys

     sys.exit(pytest.main(["-v", __file__]))
@@ -101,8 +101,9 @@ class Trainable:
         logger_creator: Callable[[Dict[str, Any]], "Logger"] = None,
         remote_checkpoint_dir: Optional[str] = None,
         custom_syncer: Optional[Syncer] = None,
+        sync_timeout: Optional[int] = None,
     ):
-        """Initialize an Trainable.
+        """Initialize a Trainable.

         Sets up logging and points ``self.logdir`` to a directory in which
         training outputs should be placed.
@@ -120,6 +121,7 @@ class Trainable:
                 which is different from **per checkpoint** directory.
             custom_syncer: Syncer used for synchronizing data from Ray nodes
                 to external storage.
+            sync_timeout: Timeout after which sync processes are aborted.
         """

         self._experiment_id = uuid.uuid4().hex
@@ -171,6 +173,7 @@ class Trainable:

         self.remote_checkpoint_dir = remote_checkpoint_dir
         self.custom_syncer = custom_syncer
+        self.sync_timeout = sync_timeout

     @property
     def uses_cloud_checkpointing(self):
@@ -512,12 +515,22 @@ class Trainable:
             return True

         checkpoint = Checkpoint.from_directory(checkpoint_dir)
-        retry_fn(
-            lambda: checkpoint.to_uri(self._storage_path(checkpoint_dir)),
+        checkpoint_uri = self._storage_path(checkpoint_dir)
+        if not retry_fn(
+            lambda: checkpoint.to_uri(checkpoint_uri),
             subprocess.CalledProcessError,
             num_retries=3,
             sleep_time=1,
-        )
+            timeout=self.sync_timeout,
+        ):
+            logger.error(
+                f"Could not upload checkpoint even after 3 retries."
+                f"Please check if the credentials expired and that the remote "
+                f"filesystem is supported.. For large checkpoints, consider "
+                f"increasing `SyncConfig(sync_timeout)` "
+                f"(current value: {self.sync_timeout} seconds). Checkpoint URI: "
+                f"{checkpoint_uri}"
+            )
         return True

     def _maybe_load_from_cloud(self, checkpoint_path: str) -> bool:
@@ -546,12 +559,17 @@ class Trainable:
             return True

         checkpoint = Checkpoint.from_uri(external_uri)
-        retry_fn(
+        if not retry_fn(
             lambda: checkpoint.to_directory(local_dir),
             subprocess.CalledProcessError,
             num_retries=3,
             sleep_time=1,
-        )
+            timeout=self.sync_timeout,
+        ):
+            logger.error(
+                f"Could not download checkpoint even after 3 retries: "
+                f"{external_uri}"
+            )

         return True

@@ -719,12 +737,17 @@ class Trainable:
             self.custom_syncer.wait_or_retry()
         else:
             checkpoint_uri = self._storage_path(checkpoint_dir)
-            retry_fn(
+            if not retry_fn(
                 lambda: _delete_external_checkpoint(checkpoint_uri),
                 subprocess.CalledProcessError,
                 num_retries=3,
                 sleep_time=1,
-            )
+                timeout=self.sync_timeout,
+            ):
+                logger.error(
+                    f"Could not delete checkpoint even after 3 retries: "
+                    f"{checkpoint_uri}"
+                )

         if os.path.exists(checkpoint_dir):
             shutil.rmtree(checkpoint_dir)
@@ -7,6 +7,7 @@ import threading
 import time
 from collections import defaultdict
 from datetime import datetime
+from numbers import Number
 from threading import Thread
 from typing import Dict, List, Union, Type, Callable, Any, Optional

@@ -124,18 +125,41 @@ class UtilMonitor(Thread):
 @DeveloperAPI
 def retry_fn(
     fn: Callable[[], Any],
-    exception_type: Type[Exception],
+    exception_type: Type[Exception] = Exception,
     num_retries: int = 3,
     sleep_time: int = 1,
-):
-    for i in range(num_retries):
+    timeout: Optional[Number] = None,
+) -> bool:
+    errored = threading.Event()
+
+    def _try_fn():
         try:
             fn()
         except exception_type as e:
             logger.warning(e)
-            time.sleep(sleep_time)
-        else:
-            break
+            errored.set()
+
+    for i in range(num_retries):
+        errored.clear()
+
+        proc = threading.Thread(target=_try_fn)
+        proc.daemon = True
+        proc.start()
+        proc.join(timeout=timeout)
+
+        if proc.is_alive():
+            logger.debug(
+                f"Process timed out (try {i+1}/{num_retries}): "
+                f"{getattr(fn, '__name__', None)}"
+            )
+        elif not errored.is_set():
+            return True
+
+        # Timed out, sleep and try again
+        time.sleep(sleep_time)
+
+    # Timed out, so return False
+    return False


 @ray.remote
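For reference, a usage sketch of the reworked `retry_fn`, mirroring how `Trainable` calls it in the hunks above: each attempt now runs in a daemon thread that is abandoned after `timeout` seconds, and the boolean return value replaces the old fire-and-forget behavior. The `upload_checkpoint` helper and the bucket path are hypothetical.

```python
import subprocess

from ray.tune.utils.util import retry_fn


def upload_checkpoint():
    # Hypothetical sync command; stands in for checkpoint.to_uri(...) in Trainable.
    subprocess.check_call(["aws", "s3", "sync", "./checkpoint", "s3://my-bucket/ckpt"])


# Returns True on success, False after exhausting retries or timing out.
if not retry_fn(
    upload_checkpoint,
    subprocess.CalledProcessError,
    num_retries=3,
    sleep_time=1,
    timeout=1800,
):
    print("Upload did not complete; consider increasing SyncConfig(sync_timeout).")
```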