[tune] Add timeout to retry_fn to catch hanging syncs (#28155)

Syncing sometimes hangs in pyarrow for unknown reasons. We should introduce a timeout for these syncing operations.

Signed-off-by: Kai Fricke <kai@anyscale.com>
Kai Fricke 2022-09-02 12:52:26 +01:00 committed by GitHub
parent 8692eb6208
commit 3590a86db0
8 changed files with 208 additions and 54 deletions
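As context for the diff below, a minimal usage sketch of the new option (hedged: the bucket URI and the trainable are placeholders; per the docs added in this diff, sync_timeout currently only affects trial-to-cloud syncing):

from ray import tune

# Sync attempts that exceed sync_timeout are aborted and retried;
# once the retries are exhausted, an error is logged instead of the
# trial hanging indefinitely.
sync_config = tune.SyncConfig(
    upload_dir="s3://my-bucket/tune-results",  # placeholder bucket
    sync_timeout=600,  # seconds; default is DEFAULT_SYNC_TIMEOUT (1800)
)

tune.run(
    my_trainable,  # placeholder trainable
    sync_config=sync_config,
)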


@@ -217,7 +217,7 @@ class RayTrialExecutor:
self._has_cleaned_up_pgs = False
self._reuse_actors = reuse_actors
# The maxlen will be updated when `set_max_pending_trials()` is called
# The maxlen will be updated when `setup(max_pending_trials)` is called
self._cached_actor_pg = deque(maxlen=1)
self._pg_manager = _PlacementGroupManager(prefix=_get_tune_pg_prefix())
self._staged_trials = set()
@@ -235,16 +235,20 @@ class RayTrialExecutor:
self._buffer_max_time_s = float(
os.getenv("TUNE_RESULT_BUFFER_MAX_TIME_S", 100.0)
)
self._trainable_kwargs = {}
def set_max_pending_trials(self, max_pending: int) -> None:
def setup(
self, max_pending_trials: int, trainable_kwargs: Optional[Dict] = None
) -> None:
if len(self._cached_actor_pg) > 0:
logger.warning(
"Cannot update maximum number of queued actors for reuse "
"during a run."
)
else:
self._cached_actor_pg = deque(maxlen=max_pending)
self._pg_manager.set_max_staging(max_pending)
self._cached_actor_pg = deque(maxlen=max_pending_trials)
self._pg_manager.set_max_staging(max_pending_trials)
self._trainable_kwargs = trainable_kwargs or {}
def set_status(self, trial: Trial, status: str) -> None:
"""Sets status and checkpoints metadata if needed.
@@ -377,6 +381,9 @@ class RayTrialExecutor:
kwargs["remote_checkpoint_dir"] = trial.remote_checkpoint_dir
kwargs["custom_syncer"] = trial.custom_syncer
if self._trainable_kwargs:
kwargs.update(self._trainable_kwargs)
# Throw a meaningful error if trainable does not use the
# new API
sig = inspect.signature(trial.get_trainable_cls())


@@ -198,6 +198,8 @@ class _ExperimentCheckpointManager:
exclude = ["*/checkpoint_*"]
if self._syncer:
# Todo: Implement sync_timeout for experiment-level syncing
# (it is currently only used for trainable-to-cloud syncing)
if force:
# Wait until previous sync command finished
self._syncer.wait()
@@ -341,7 +343,13 @@ class TrialRunner:
else:
# Manual override
self._max_pending_trials = int(max_pending_trials)
self.trial_executor.set_max_pending_trials(self._max_pending_trials)
sync_config = sync_config or SyncConfig()
self.trial_executor.setup(
max_pending_trials=self._max_pending_trials,
trainable_kwargs={"sync_timeout": sync_config.sync_timeout},
)
self._metric = metric
@@ -385,7 +393,6 @@ class TrialRunner:
if self._local_checkpoint_dir:
os.makedirs(self._local_checkpoint_dir, exist_ok=True)
sync_config = sync_config or SyncConfig()
self._remote_checkpoint_dir = remote_checkpoint_dir
self._syncer = get_node_to_storage_syncer(sync_config)


@@ -40,6 +40,9 @@ logger = logging.getLogger(__name__)
# Syncing period for syncing checkpoints between nodes or to cloud.
DEFAULT_SYNC_PERIOD = 300
# Default sync timeout after which syncing processes are aborted
DEFAULT_SYNC_TIMEOUT = 1800
_EXCLUDE_FROM_SYNC = [
"./checkpoint_-00001",
"./checkpoint_tmp*",
@@ -85,6 +88,8 @@ class SyncConfig:
is asynchronous and best-effort. This does not affect persistent
storage syncing. Defaults to True.
sync_period: Syncing period for syncing between nodes.
sync_timeout: Timeout after which running sync processes are aborted.
Currently only affects trial-to-cloud syncing.
"""
@@ -93,6 +98,7 @@ class SyncConfig:
sync_on_checkpoint: bool = True
sync_period: int = DEFAULT_SYNC_PERIOD
sync_timeout: int = DEFAULT_SYNC_TIMEOUT
def _repr_html_(self) -> str:
"""Generate an HTML representation of the SyncConfig.


@@ -499,7 +499,7 @@ class RayExecutorPlacementGroupTest(unittest.TestCase):
executor = RayTrialExecutor(reuse_actors=True)
executor._pg_manager = pgm
executor.set_max_pending_trials(1)
executor.setup(max_pending_trials=1)
def train(config):
yield 1


@@ -1,14 +1,16 @@
import json
import os
import tempfile
import time
from typing import Dict, Union
from unittest.mock import patch
import pytest
import ray
from ray import tune
from ray.air import session, Checkpoint
from ray.air._internal.remote_storage import download_from_uri
from ray.air._internal.remote_storage import download_from_uri, upload_to_uri
from ray.tune.trainable import wrap_function
@@ -188,6 +190,42 @@ def test_checkpoint_object_no_sync(tmpdir):
trainable.restore_from_object(obj)
@pytest.mark.parametrize("hanging", [True, False])
def test_sync_timeout(tmpdir, hanging):
orig_upload_fn = upload_to_uri
def _hanging_upload(*args, **kwargs):
time.sleep(200 if hanging else 0)
orig_upload_fn(*args, **kwargs)
trainable = SavingTrainable(
"object",
remote_checkpoint_dir=f"memory:///test/location_hanging_{hanging}",
sync_timeout=0.5,
)
with patch("ray.air.checkpoint.upload_to_uri", _hanging_upload):
trainable.save()
check_dir = tmpdir / "check_save_obj"
try:
download_from_uri(
uri=f"memory:///test/location_hanging_{hanging}", local_path=str(check_dir)
)
except FileNotFoundError:
hung = True
else:
hung = False
assert hung == hanging
if hanging:
assert not check_dir.exists()
else:
assert check_dir.listdir()
if __name__ == "__main__":
import sys


@@ -1,45 +1,94 @@
import unittest
import time
import pytest
from ray.tune.search.variant_generator import format_vars
from ray.tune.utils.util import retry_fn
class TuneUtilsTest(unittest.TestCase):
def testFormatVars(self):
# Format brackets correctly
self.assertTrue(
format_vars(
{
("a", "b", "c"): 8.1234567,
("a", "b", "d"): [7, 8],
("a", "b", "e"): [[[3, 4]]],
}
),
"c=8.12345,d=7_8,e=3_4",
def test_format_vars():
# Format brackets correctly
assert (
format_vars(
{
("a", "b", "c"): 8.1234567,
("a", "b", "d"): [7, 8],
("a", "b", "e"): [[[3, 4]]],
}
)
# Sorted by full keys, but only last key is reported
self.assertTrue(
format_vars(
{
("a", "c", "x"): [7, 8],
("a", "b", "x"): 8.1234567,
}
),
"x=8.12345,x=7_8",
== "c=8.1235,d=7_8,e=3_4"
)
# Sorted by full keys, but only last key is reported
assert (
format_vars(
{
("a", "c", "x"): [7, 8],
("a", "b", "x"): 8.1234567,
}
)
# Filter out invalid chars. It's ok to have empty keys or values.
self.assertTrue(
format_vars(
{
("a c?x"): " <;%$ok ",
("some"): " ",
}
),
"a_c_x=ok,some=",
== "x=8.1235,x=7_8"
)
# Filter out invalid chars. It's ok to have empty keys or values.
assert (
format_vars(
{
("a c?x",): " <;%$ok ",
("some",): " ",
}
)
== "a_c_x=ok,some="
)
def test_retry_fn_repeat(tmpdir):
success = tmpdir / "success"
marker = tmpdir / "marker"
def _fail_once():
if marker.exists():
success.write_text(".", encoding="utf-8")
return
marker.write_text(".", encoding="utf-8")
raise RuntimeError("Failing")
assert not success.exists()
assert not marker.exists()
assert retry_fn(
fn=_fail_once,
exception_type=RuntimeError,
sleep_time=0,
)
assert success.exists()
assert marker.exists()
def test_retry_fn_timeout(tmpdir):
success = tmpdir / "success"
marker = tmpdir / "marker"
def _fail_once():
if not marker.exists():
marker.write_text(".", encoding="utf-8")
raise RuntimeError("Failing")
time.sleep(5)
success.write_text(".", encoding="utf-8")
return
assert not success.exists()
assert not marker.exists()
assert not retry_fn(
fn=_fail_once, exception_type=RuntimeError, sleep_time=0, timeout=0.1
)
assert not success.exists()
assert marker.exists()
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))


@@ -101,8 +101,9 @@ class Trainable:
logger_creator: Callable[[Dict[str, Any]], "Logger"] = None,
remote_checkpoint_dir: Optional[str] = None,
custom_syncer: Optional[Syncer] = None,
sync_timeout: Optional[int] = None,
):
"""Initialize an Trainable.
"""Initialize a Trainable.
Sets up logging and points ``self.logdir`` to a directory in which
training outputs should be placed.
@@ -120,6 +121,7 @@ class Trainable:
which is different from **per checkpoint** directory.
custom_syncer: Syncer used for synchronizing data from Ray nodes
to external storage.
sync_timeout: Timeout after which sync processes are aborted.
"""
self._experiment_id = uuid.uuid4().hex
@@ -171,6 +173,7 @@ class Trainable:
self.remote_checkpoint_dir = remote_checkpoint_dir
self.custom_syncer = custom_syncer
self.sync_timeout = sync_timeout
@property
def uses_cloud_checkpointing(self):
@@ -512,12 +515,22 @@ class Trainable:
return True
checkpoint = Checkpoint.from_directory(checkpoint_dir)
retry_fn(
lambda: checkpoint.to_uri(self._storage_path(checkpoint_dir)),
checkpoint_uri = self._storage_path(checkpoint_dir)
if not retry_fn(
lambda: checkpoint.to_uri(checkpoint_uri),
subprocess.CalledProcessError,
num_retries=3,
sleep_time=1,
)
timeout=self.sync_timeout,
):
logger.error(
f"Could not upload checkpoint even after 3 retries."
f"Please check if the credentials expired and that the remote "
f"filesystem is supported.. For large checkpoints, consider "
f"increasing `SyncConfig(sync_timeout)` "
f"(current value: {self.sync_timeout} seconds). Checkpoint URI: "
f"{checkpoint_uri}"
)
return True
def _maybe_load_from_cloud(self, checkpoint_path: str) -> bool:
@@ -546,12 +559,17 @@ class Trainable:
return True
checkpoint = Checkpoint.from_uri(external_uri)
retry_fn(
if not retry_fn(
lambda: checkpoint.to_directory(local_dir),
subprocess.CalledProcessError,
num_retries=3,
sleep_time=1,
)
timeout=self.sync_timeout,
):
logger.error(
f"Could not download checkpoint even after 3 retries: "
f"{external_uri}"
)
return True
@@ -719,12 +737,17 @@ class Trainable:
self.custom_syncer.wait_or_retry()
else:
checkpoint_uri = self._storage_path(checkpoint_dir)
retry_fn(
if not retry_fn(
lambda: _delete_external_checkpoint(checkpoint_uri),
subprocess.CalledProcessError,
num_retries=3,
sleep_time=1,
)
timeout=self.sync_timeout,
):
logger.error(
f"Could not delete checkpoint even after 3 retries: "
f"{checkpoint_uri}"
)
if os.path.exists(checkpoint_dir):
shutil.rmtree(checkpoint_dir)


@@ -7,6 +7,7 @@ import threading
import time
from collections import defaultdict
from datetime import datetime
from numbers import Number
from threading import Thread
from typing import Dict, List, Union, Type, Callable, Any, Optional
@@ -124,18 +125,41 @@ class UtilMonitor(Thread):
@DeveloperAPI
def retry_fn(
fn: Callable[[], Any],
exception_type: Type[Exception],
exception_type: Type[Exception] = Exception,
num_retries: int = 3,
sleep_time: int = 1,
):
for i in range(num_retries):
timeout: Optional[Number] = None,
) -> bool:
errored = threading.Event()
def _try_fn():
try:
fn()
except exception_type as e:
logger.warning(e)
time.sleep(sleep_time)
else:
break
errored.set()
for i in range(num_retries):
errored.clear()
proc = threading.Thread(target=_try_fn)
proc.daemon = True
proc.start()
proc.join(timeout=timeout)
if proc.is_alive():
logger.debug(
f"Process timed out (try {i+1}/{num_retries}): "
f"{getattr(fn, '__name__', None)}"
)
elif not errored.is_set():
return True
# Timed out, sleep and try again
time.sleep(sleep_time)
# Timed out, so return False
return False
@ray.remote
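To illustrate the new retry_fn semantics end to end, a hedged standalone sketch (the flaky upload function is a stand-in for the checkpoint.to_uri call used in Trainable): each attempt runs in a daemon thread, an attempt that outlives `timeout` is abandoned, and the function returns True on the first successful attempt or False once num_retries attempts have failed or timed out.

import time

from ray.tune.utils.util import retry_fn

attempts = []

def flaky_upload():
    # Stand-in for a real upload: the first attempt hangs longer than
    # the timeout allows, the second attempt returns immediately.
    attempts.append(1)
    if len(attempts) == 1:
        time.sleep(10)

# The first try is abandoned after 1 second, the second try succeeds,
# so retry_fn returns True.
ok = retry_fn(
    fn=flaky_upload,
    exception_type=RuntimeError,
    num_retries=3,
    sleep_time=0,
    timeout=1,
)
print("upload succeeded:", ok)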