[tune] Retry cloud sync up/down/delete on fail (#22029)
This commit is contained in:
parent b729a9390f
commit c866131cc0
4 changed files with 88 additions and 13 deletions
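
The gist of the change: Tune's command-based sync client now records the last shell command it launched, and a new wait_or_retry() method re-runs that command with a backoff whenever wait() fails; the Trainable checkpoint upload, download, and deletion paths switch over to it. A minimal usage sketch of the new surface (the S3 path is illustrative, and get_cloud_sync_client assumes the matching CLI, e.g. aws, is installed):

from ray.tune.sync_client import get_cloud_sync_client

client = get_cloud_sync_client("s3://my-bucket/experiment")  # hypothetical bucket
client.sync_up("/tmp/checkpoint_0001", "s3://my-bucket/experiment/checkpoint_0001")
# wait() raises a TuneError on the first failure; wait_or_retry() instead
# relaunches the recorded command, sleeping backoff_s seconds between attempts.
client.wait_or_retry(max_retries=3, backoff_s=5)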
@@ -5,6 +5,7 @@ import logging
 import pathlib
 import subprocess
 import tempfile
+import time
 import types
 import warnings
@@ -30,7 +31,7 @@ def noop(*args):
     return


-def get_sync_client(sync_function, delete_function=None):
+def get_sync_client(sync_function, delete_function=None) -> Optional["SyncClient"]:
     """Returns a sync client.

     Args:
@@ -58,7 +59,7 @@ def get_sync_client(sync_function, delete_function=None):
     return client_cls(sync_function, sync_function, delete_function)


-def get_cloud_sync_client(remote_path):
+def get_cloud_sync_client(remote_path) -> "CommandBasedClient":
     """Returns a CommandBasedClient that can sync to/from remote storage.

     Args:
@@ -158,6 +159,10 @@ class SyncClient:
         """Waits for current sync to complete, if asynchronously started."""
         pass

+    def wait_or_retry(self, max_retries: int = 3, backoff_s: int = 5):
+        """Waits for current sync to complete, or retries on error."""
+        pass
+
     def reset(self):
         """Resets state."""
         pass
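
The base SyncClient grows wait_or_retry() as a no-op next to wait() and reset(), so sync clients that don't shell out (such as the function-based client returned by get_sync_client) accept the call without a behavior change; only CommandBasedClient below overrides it with real retry logic.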
@@ -251,6 +256,8 @@ class CommandBasedClient(SyncClient):
         self.logfile = None
         self._closed = False
         self.cmd_process = None
+        # Keep track of last command for retry
+        self._last_cmd = None

     def set_logdir(self, logdir):
         """Sets the directory to log sync execution output in.
@@ -273,6 +280,11 @@ class CommandBasedClient(SyncClient):
         else:
             return self.logfile

+    def _start_process(self, cmd: str) -> subprocess.Popen:
+        return subprocess.Popen(
+            cmd, shell=True, stderr=subprocess.PIPE, stdout=self._get_logfile()
+        )
+
     def sync_up(self, source, target, exclude: Optional[List] = None):
         return self._execute(self.sync_up_template, source, target, exclude)
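
Factoring process creation into _start_process() gives the retry path a single place to relaunch the stored _last_cmd, and gives tests a seam to override; the new CountingClient test further down uses exactly this hook to count and rewrite commands.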
@@ -284,13 +296,15 @@ class CommandBasedClient(SyncClient):

     def delete(self, target):
         if self.is_running:
-            logger.warning("Last sync client cmd still in progress, skipping.")
+            logger.warning(
+                f"Last sync client cmd still in progress, "
+                f"skipping deletion of {target}"
+            )
             return False
         final_cmd = self.delete_template.format(target=quote(target), options="")
         logger.debug("Running delete: {}".format(final_cmd))
-        self.cmd_process = subprocess.Popen(
-            final_cmd, shell=True, stderr=subprocess.PIPE, stdout=self._get_logfile()
-        )
+        self._last_cmd = final_cmd
+        self.cmd_process = self._start_process(final_cmd)
         return True

     def wait(self):
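
delete() now records _last_cmd before launching the process, so a failed remote deletion can be retried just like a failed sync. Continuing the sketch after the commit header (path again illustrative):

client.delete("s3://my-bucket/experiment/checkpoint_0001")
client.wait_or_retry()  # re-runs the recorded delete command on failure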
@@ -306,10 +320,28 @@ class CommandBasedClient(SyncClient):
                 "Error message ({}): {}".format(args, code, error_msg)
             )

+    def wait_or_retry(self, max_retries: int = 3, backoff_s: int = 5):
+        assert max_retries > 0
+        for i in range(max_retries - 1):
+            try:
+                self.wait()
+            except TuneError as e:
+                logger.error(
+                    f"Caught sync error: {e}. "
+                    f"Retrying after sleeping for {backoff_s} seconds..."
+                )
+                time.sleep(backoff_s)
+                self.cmd_process = self._start_process(self._last_cmd)
+                continue
+            return
+        self.cmd_process = None
+        raise TuneError(f"Failed sync even after {max_retries} retries.")
+
     def reset(self):
         if self.is_running:
             logger.warning("Sync process still running but resetting anyways.")
         self.cmd_process = None
+        self._last_cmd = None

     def close(self):
         if self.logfile:
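
One subtlety in the loop above: with max_retries=3 the command is launched at most three times in total (the initial launch plus two relaunches), and the final relaunch is started but never waited on before the TuneError is raised. A self-contained sketch of the same pattern, with plain subprocess and RuntimeError standing in for Tune's classes:

import subprocess
import time


def wait_or_retry(cmd: str, max_retries: int = 3, backoff_s: int = 5) -> None:
    # Mirrors the committed loop: wait, and on failure sleep and relaunch,
    # up to (max_retries - 1) times after the initial launch.
    proc = subprocess.Popen(cmd, shell=True)
    for _ in range(max_retries - 1):
        if proc.wait() == 0:
            return  # success, stop retrying
        time.sleep(backoff_s)  # back off, then relaunch the same command
        proc = subprocess.Popen(cmd, shell=True)
    raise RuntimeError(f"Failed sync even after {max_retries} retries.")


try:
    wait_or_retry("false", max_retries=3, backoff_s=0)  # "false" always exits 1
except RuntimeError as e:
    print(e)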
@@ -329,7 +361,10 @@ class CommandBasedClient(SyncClient):
     def _execute(self, sync_template, source, target, exclude: Optional[List] = None):
         """Executes sync_template on source and target."""
         if self.is_running:
-            logger.warning("Last sync client cmd still in progress, skipping.")
+            logger.warning(
+                f"Last sync client cmd still in progress, "
+                f"skipping sync from {source} to {target}."
+            )
             return False

         if exclude and self.exclude_template:
@@ -355,9 +390,8 @@ class CommandBasedClient(SyncClient):
             source=quote(source), target=quote(target), options=option_str
         )
         logger.debug("Running sync: {}".format(final_cmd))
-        self.cmd_process = subprocess.Popen(
-            final_cmd, shell=True, stderr=subprocess.PIPE, stdout=self._get_logfile()
-        )
+        self._last_cmd = final_cmd
+        self.cmd_process = self._start_process(final_cmd)
         return True

     @staticmethod
@@ -415,6 +415,41 @@ class TestSyncFunctionality(unittest.TestCase):
         trial_syncer = syncer_callback._get_trial_syncer(trial)
         self.assertEqual(trial_syncer.sync_client, NOOP)

+    def testSyncWaitRetry(self):
+        class CountingClient(CommandBasedClient):
+            def __init__(self, *args, **kwargs):
+                self._sync_ups = 0
+                self._sync_downs = 0
+                super(CountingClient, self).__init__(*args, **kwargs)
+
+            def _start_process(self, cmd):
+                if "UPLOAD" in cmd:
+                    self._sync_ups += 1
+                elif "DOWNLOAD" in cmd:
+                    self._sync_downs += 1
+                    if self._sync_downs == 1:
+                        self._last_cmd = "echo DOWNLOAD && true"
+                return super(CountingClient, self)._start_process(cmd)
+
+        client = CountingClient(
+            "echo UPLOAD {source} {target} && false",
+            "echo DOWNLOAD {source} {target} && false",
+            "echo DELETE {target}",
+        )
+
+        # Fail always
+        with self.assertRaisesRegex(TuneError, "Failed sync even after"):
+            client.sync_up("test_source", "test_target")
+            client.wait_or_retry(max_retries=3, backoff_s=0)
+
+        self.assertEquals(client._sync_ups, 3)
+
+        # Succeed after second try
+        client.sync_down("test_source", "test_target")
+        client.wait_or_retry(max_retries=3, backoff_s=0)
+
+        self.assertEquals(client._sync_downs, 2)
+

 if __name__ == "__main__":
     import pytest
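
The fake templates drive the retry path purely through shell exit codes: "echo UPLOAD ... && false" prints and then exits nonzero, which wait() surfaces as a TuneError, and swapping _last_cmd to "echo DOWNLOAD && true" after the first download makes exactly one retry succeed, hence the expected counts of 3 and 2. The mechanism in isolation:

import subprocess

# "&& false" forces a nonzero exit after the echo, simulating a failed sync;
# "&& true" leaves the exit status at 0, simulating a successful one.
assert subprocess.run("echo UPLOAD && false", shell=True).returncode != 0
assert subprocess.run("echo DOWNLOAD && true", shell=True).returncode == 0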
@@ -441,7 +441,7 @@ class Trainable:
                 self.storage_client.sync_up(
                     checkpoint_dir, self._storage_path(checkpoint_dir)
                 )
-                self.storage_client.wait()
+                self.storage_client.wait_or_retry()

     def save_to_object(self):
         """Saves the current model state to a Python object.
@@ -488,7 +488,7 @@ class Trainable:
                 os.path.join(self.remote_checkpoint_dir, rel_checkpoint_dir),
                 os.path.join(self.logdir, rel_checkpoint_dir),
             )
-            self.storage_client.wait()
+            self.storage_client.wait_or_retry()

         # Ensure TrialCheckpoints are converted
         if isinstance(checkpoint_path, TrialCheckpoint):
@@ -557,6 +557,7 @@ class Trainable:
         else:
             if self.uses_cloud_checkpointing:
                 self.storage_client.delete(self._storage_path(checkpoint_dir))
+                self.storage_client.wait_or_retry()

             if os.path.exists(checkpoint_dir):
                 shutil.rmtree(checkpoint_dir)
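
With these three Trainable call sites switched from wait() to wait_or_retry(), a transient cloud failure during checkpoint upload (save), download (restore), or deletion no longer fails the trial on the first error; the command is retried a few times before a TuneError surfaces.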
@@ -266,7 +266,12 @@ def send_signal_after_wait(process: subprocess.Popen, signal: int, wait: int = 3
     time.sleep(wait)

     if process.poll() is not None:
-        raise RuntimeError(f"Process {process.pid} already terminated.")
+        raise RuntimeError(
+            f"Process {process.pid} already terminated. This usually means "
+            f"that some of the trials ERRORed (e.g. because they couldn't be "
+            f"restored). Try re-running this test to see if this fixes the "
+            f"issue."
+        )

     print(f"Sending signal {signal} to process {process.pid}")
     process.send_signal(signal)