[tune] Retry cloud sync up/down/delete on fail (#22029)

This commit is contained in:
Kai Fricke 2022-02-15 12:27:29 +00:00 committed by GitHub
parent b729a9390f
commit c866131cc0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 88 additions and 13 deletions

View file

@ -5,6 +5,7 @@ import logging
import pathlib
import subprocess
import tempfile
import time
import types
import warnings
@ -30,7 +31,7 @@ def noop(*args):
return
def get_sync_client(sync_function, delete_function=None):
def get_sync_client(sync_function, delete_function=None) -> Optional["SyncClient"]:
"""Returns a sync client.
Args:
@ -58,7 +59,7 @@ def get_sync_client(sync_function, delete_function=None):
return client_cls(sync_function, sync_function, delete_function)
def get_cloud_sync_client(remote_path):
def get_cloud_sync_client(remote_path) -> "CommandBasedClient":
"""Returns a CommandBasedClient that can sync to/from remote storage.
Args:
@ -158,6 +159,10 @@ class SyncClient:
"""Waits for current sync to complete, if asynchronously started."""
pass
def wait_or_retry(self, max_retries: int = 3, backoff_s: int = 5):
"""Wait for current sync to complete or retries on error."""
pass
def reset(self):
"""Resets state."""
pass
@ -251,6 +256,8 @@ class CommandBasedClient(SyncClient):
self.logfile = None
self._closed = False
self.cmd_process = None
# Keep track of last command for retry
self._last_cmd = None
def set_logdir(self, logdir):
"""Sets the directory to log sync execution output in.
@ -273,6 +280,11 @@ class CommandBasedClient(SyncClient):
else:
return self.logfile
def _start_process(self, cmd: str) -> subprocess.Popen:
return subprocess.Popen(
cmd, shell=True, stderr=subprocess.PIPE, stdout=self._get_logfile()
)
def sync_up(self, source, target, exclude: Optional[List] = None):
return self._execute(self.sync_up_template, source, target, exclude)
@ -284,13 +296,15 @@ class CommandBasedClient(SyncClient):
def delete(self, target):
if self.is_running:
logger.warning("Last sync client cmd still in progress, skipping.")
logger.warning(
f"Last sync client cmd still in progress, "
f"skipping deletion of {target}"
)
return False
final_cmd = self.delete_template.format(target=quote(target), options="")
logger.debug("Running delete: {}".format(final_cmd))
self.cmd_process = subprocess.Popen(
final_cmd, shell=True, stderr=subprocess.PIPE, stdout=self._get_logfile()
)
self._last_cmd = final_cmd
self.cmd_process = self._start_process(final_cmd)
return True
def wait(self):
@ -306,10 +320,28 @@ class CommandBasedClient(SyncClient):
"Error message ({}): {}".format(args, code, error_msg)
)
def wait_or_retry(self, max_retries: int = 3, backoff_s: int = 5):
assert max_retries > 0
for i in range(max_retries - 1):
try:
self.wait()
except TuneError as e:
logger.error(
f"Caught sync error: {e}. "
f"Retrying after sleeping for {backoff_s} seconds..."
)
time.sleep(backoff_s)
self.cmd_process = self._start_process(self._last_cmd)
continue
return
self.cmd_process = None
raise TuneError(f"Failed sync even after {max_retries} retries.")
def reset(self):
if self.is_running:
logger.warning("Sync process still running but resetting anyways.")
self.cmd_process = None
self._last_cmd = None
def close(self):
if self.logfile:
@ -329,7 +361,10 @@ class CommandBasedClient(SyncClient):
def _execute(self, sync_template, source, target, exclude: Optional[List] = None):
"""Executes sync_template on source and target."""
if self.is_running:
logger.warning("Last sync client cmd still in progress, skipping.")
logger.warning(
f"Last sync client cmd still in progress, "
f"skipping sync from {source} to {target}."
)
return False
if exclude and self.exclude_template:
@ -355,9 +390,8 @@ class CommandBasedClient(SyncClient):
source=quote(source), target=quote(target), options=option_str
)
logger.debug("Running sync: {}".format(final_cmd))
self.cmd_process = subprocess.Popen(
final_cmd, shell=True, stderr=subprocess.PIPE, stdout=self._get_logfile()
)
self._last_cmd = final_cmd
self.cmd_process = self._start_process(final_cmd)
return True
@staticmethod

View file

@ -415,6 +415,41 @@ class TestSyncFunctionality(unittest.TestCase):
trial_syncer = syncer_callback._get_trial_syncer(trial)
self.assertEqual(trial_syncer.sync_client, NOOP)
def testSyncWaitRetry(self):
class CountingClient(CommandBasedClient):
def __init__(self, *args, **kwargs):
self._sync_ups = 0
self._sync_downs = 0
super(CountingClient, self).__init__(*args, **kwargs)
def _start_process(self, cmd):
if "UPLOAD" in cmd:
self._sync_ups += 1
elif "DOWNLOAD" in cmd:
self._sync_downs += 1
if self._sync_downs == 1:
self._last_cmd = "echo DOWNLOAD && true"
return super(CountingClient, self)._start_process(cmd)
client = CountingClient(
"echo UPLOAD {source} {target} && false",
"echo DOWNLOAD {source} {target} && false",
"echo DELETE {target}",
)
# Fail always
with self.assertRaisesRegex(TuneError, "Failed sync even after"):
client.sync_up("test_source", "test_target")
client.wait_or_retry(max_retries=3, backoff_s=0)
self.assertEquals(client._sync_ups, 3)
# Succeed after second try
client.sync_down("test_source", "test_target")
client.wait_or_retry(max_retries=3, backoff_s=0)
self.assertEquals(client._sync_downs, 2)
if __name__ == "__main__":
import pytest

View file

@ -441,7 +441,7 @@ class Trainable:
self.storage_client.sync_up(
checkpoint_dir, self._storage_path(checkpoint_dir)
)
self.storage_client.wait()
self.storage_client.wait_or_retry()
def save_to_object(self):
"""Saves the current model state to a Python object.
@ -488,7 +488,7 @@ class Trainable:
os.path.join(self.remote_checkpoint_dir, rel_checkpoint_dir),
os.path.join(self.logdir, rel_checkpoint_dir),
)
self.storage_client.wait()
self.storage_client.wait_or_retry()
# Ensure TrialCheckpoints are converted
if isinstance(checkpoint_path, TrialCheckpoint):
@ -557,6 +557,7 @@ class Trainable:
else:
if self.uses_cloud_checkpointing:
self.storage_client.delete(self._storage_path(checkpoint_dir))
self.storage_client.wait_or_retry()
if os.path.exists(checkpoint_dir):
shutil.rmtree(checkpoint_dir)

View file

@ -266,7 +266,12 @@ def send_signal_after_wait(process: subprocess.Popen, signal: int, wait: int = 3
time.sleep(wait)
if process.poll() is not None:
raise RuntimeError(f"Process {process.pid} already terminated.")
raise RuntimeError(
f"Process {process.pid} already terminated. This usually means "
f"that some of the trials ERRORed (e.g. because they couldn't be "
f"restored. Try re-running this test to see if this fixes the "
f"issue."
)
print(f"Sending signal {signal} to process {process.pid}")
process.send_signal(signal)