mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[tune] Retry cloud sync up/down/delete on fail (#22029)
This commit is contained in:
parent
b729a9390f
commit
c866131cc0
4 changed files with 88 additions and 13 deletions
|
@ -5,6 +5,7 @@ import logging
|
||||||
import pathlib
|
import pathlib
|
||||||
import subprocess
|
import subprocess
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import time
|
||||||
import types
|
import types
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
|
@ -30,7 +31,7 @@ def noop(*args):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
def get_sync_client(sync_function, delete_function=None):
|
def get_sync_client(sync_function, delete_function=None) -> Optional["SyncClient"]:
|
||||||
"""Returns a sync client.
|
"""Returns a sync client.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -58,7 +59,7 @@ def get_sync_client(sync_function, delete_function=None):
|
||||||
return client_cls(sync_function, sync_function, delete_function)
|
return client_cls(sync_function, sync_function, delete_function)
|
||||||
|
|
||||||
|
|
||||||
def get_cloud_sync_client(remote_path):
|
def get_cloud_sync_client(remote_path) -> "CommandBasedClient":
|
||||||
"""Returns a CommandBasedClient that can sync to/from remote storage.
|
"""Returns a CommandBasedClient that can sync to/from remote storage.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -158,6 +159,10 @@ class SyncClient:
|
||||||
"""Waits for current sync to complete, if asynchronously started."""
|
"""Waits for current sync to complete, if asynchronously started."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def wait_or_retry(self, max_retries: int = 3, backoff_s: int = 5):
|
||||||
|
"""Wait for current sync to complete or retries on error."""
|
||||||
|
pass
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""Resets state."""
|
"""Resets state."""
|
||||||
pass
|
pass
|
||||||
|
@ -251,6 +256,8 @@ class CommandBasedClient(SyncClient):
|
||||||
self.logfile = None
|
self.logfile = None
|
||||||
self._closed = False
|
self._closed = False
|
||||||
self.cmd_process = None
|
self.cmd_process = None
|
||||||
|
# Keep track of last command for retry
|
||||||
|
self._last_cmd = None
|
||||||
|
|
||||||
def set_logdir(self, logdir):
|
def set_logdir(self, logdir):
|
||||||
"""Sets the directory to log sync execution output in.
|
"""Sets the directory to log sync execution output in.
|
||||||
|
@ -273,6 +280,11 @@ class CommandBasedClient(SyncClient):
|
||||||
else:
|
else:
|
||||||
return self.logfile
|
return self.logfile
|
||||||
|
|
||||||
|
def _start_process(self, cmd: str) -> subprocess.Popen:
|
||||||
|
return subprocess.Popen(
|
||||||
|
cmd, shell=True, stderr=subprocess.PIPE, stdout=self._get_logfile()
|
||||||
|
)
|
||||||
|
|
||||||
def sync_up(self, source, target, exclude: Optional[List] = None):
|
def sync_up(self, source, target, exclude: Optional[List] = None):
|
||||||
return self._execute(self.sync_up_template, source, target, exclude)
|
return self._execute(self.sync_up_template, source, target, exclude)
|
||||||
|
|
||||||
|
@ -284,13 +296,15 @@ class CommandBasedClient(SyncClient):
|
||||||
|
|
||||||
def delete(self, target):
|
def delete(self, target):
|
||||||
if self.is_running:
|
if self.is_running:
|
||||||
logger.warning("Last sync client cmd still in progress, skipping.")
|
logger.warning(
|
||||||
|
f"Last sync client cmd still in progress, "
|
||||||
|
f"skipping deletion of {target}"
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
final_cmd = self.delete_template.format(target=quote(target), options="")
|
final_cmd = self.delete_template.format(target=quote(target), options="")
|
||||||
logger.debug("Running delete: {}".format(final_cmd))
|
logger.debug("Running delete: {}".format(final_cmd))
|
||||||
self.cmd_process = subprocess.Popen(
|
self._last_cmd = final_cmd
|
||||||
final_cmd, shell=True, stderr=subprocess.PIPE, stdout=self._get_logfile()
|
self.cmd_process = self._start_process(final_cmd)
|
||||||
)
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def wait(self):
|
def wait(self):
|
||||||
|
@ -306,10 +320,28 @@ class CommandBasedClient(SyncClient):
|
||||||
"Error message ({}): {}".format(args, code, error_msg)
|
"Error message ({}): {}".format(args, code, error_msg)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def wait_or_retry(self, max_retries: int = 3, backoff_s: int = 5):
|
||||||
|
assert max_retries > 0
|
||||||
|
for i in range(max_retries - 1):
|
||||||
|
try:
|
||||||
|
self.wait()
|
||||||
|
except TuneError as e:
|
||||||
|
logger.error(
|
||||||
|
f"Caught sync error: {e}. "
|
||||||
|
f"Retrying after sleeping for {backoff_s} seconds..."
|
||||||
|
)
|
||||||
|
time.sleep(backoff_s)
|
||||||
|
self.cmd_process = self._start_process(self._last_cmd)
|
||||||
|
continue
|
||||||
|
return
|
||||||
|
self.cmd_process = None
|
||||||
|
raise TuneError(f"Failed sync even after {max_retries} retries.")
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
if self.is_running:
|
if self.is_running:
|
||||||
logger.warning("Sync process still running but resetting anyways.")
|
logger.warning("Sync process still running but resetting anyways.")
|
||||||
self.cmd_process = None
|
self.cmd_process = None
|
||||||
|
self._last_cmd = None
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
if self.logfile:
|
if self.logfile:
|
||||||
|
@ -329,7 +361,10 @@ class CommandBasedClient(SyncClient):
|
||||||
def _execute(self, sync_template, source, target, exclude: Optional[List] = None):
|
def _execute(self, sync_template, source, target, exclude: Optional[List] = None):
|
||||||
"""Executes sync_template on source and target."""
|
"""Executes sync_template on source and target."""
|
||||||
if self.is_running:
|
if self.is_running:
|
||||||
logger.warning("Last sync client cmd still in progress, skipping.")
|
logger.warning(
|
||||||
|
f"Last sync client cmd still in progress, "
|
||||||
|
f"skipping sync from {source} to {target}."
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if exclude and self.exclude_template:
|
if exclude and self.exclude_template:
|
||||||
|
@ -355,9 +390,8 @@ class CommandBasedClient(SyncClient):
|
||||||
source=quote(source), target=quote(target), options=option_str
|
source=quote(source), target=quote(target), options=option_str
|
||||||
)
|
)
|
||||||
logger.debug("Running sync: {}".format(final_cmd))
|
logger.debug("Running sync: {}".format(final_cmd))
|
||||||
self.cmd_process = subprocess.Popen(
|
self._last_cmd = final_cmd
|
||||||
final_cmd, shell=True, stderr=subprocess.PIPE, stdout=self._get_logfile()
|
self.cmd_process = self._start_process(final_cmd)
|
||||||
)
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
@ -415,6 +415,41 @@ class TestSyncFunctionality(unittest.TestCase):
|
||||||
trial_syncer = syncer_callback._get_trial_syncer(trial)
|
trial_syncer = syncer_callback._get_trial_syncer(trial)
|
||||||
self.assertEqual(trial_syncer.sync_client, NOOP)
|
self.assertEqual(trial_syncer.sync_client, NOOP)
|
||||||
|
|
||||||
|
def testSyncWaitRetry(self):
|
||||||
|
class CountingClient(CommandBasedClient):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
self._sync_ups = 0
|
||||||
|
self._sync_downs = 0
|
||||||
|
super(CountingClient, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def _start_process(self, cmd):
|
||||||
|
if "UPLOAD" in cmd:
|
||||||
|
self._sync_ups += 1
|
||||||
|
elif "DOWNLOAD" in cmd:
|
||||||
|
self._sync_downs += 1
|
||||||
|
if self._sync_downs == 1:
|
||||||
|
self._last_cmd = "echo DOWNLOAD && true"
|
||||||
|
return super(CountingClient, self)._start_process(cmd)
|
||||||
|
|
||||||
|
client = CountingClient(
|
||||||
|
"echo UPLOAD {source} {target} && false",
|
||||||
|
"echo DOWNLOAD {source} {target} && false",
|
||||||
|
"echo DELETE {target}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fail always
|
||||||
|
with self.assertRaisesRegex(TuneError, "Failed sync even after"):
|
||||||
|
client.sync_up("test_source", "test_target")
|
||||||
|
client.wait_or_retry(max_retries=3, backoff_s=0)
|
||||||
|
|
||||||
|
self.assertEquals(client._sync_ups, 3)
|
||||||
|
|
||||||
|
# Succeed after second try
|
||||||
|
client.sync_down("test_source", "test_target")
|
||||||
|
client.wait_or_retry(max_retries=3, backoff_s=0)
|
||||||
|
|
||||||
|
self.assertEquals(client._sync_downs, 2)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import pytest
|
import pytest
|
||||||
|
|
|
@ -441,7 +441,7 @@ class Trainable:
|
||||||
self.storage_client.sync_up(
|
self.storage_client.sync_up(
|
||||||
checkpoint_dir, self._storage_path(checkpoint_dir)
|
checkpoint_dir, self._storage_path(checkpoint_dir)
|
||||||
)
|
)
|
||||||
self.storage_client.wait()
|
self.storage_client.wait_or_retry()
|
||||||
|
|
||||||
def save_to_object(self):
|
def save_to_object(self):
|
||||||
"""Saves the current model state to a Python object.
|
"""Saves the current model state to a Python object.
|
||||||
|
@ -488,7 +488,7 @@ class Trainable:
|
||||||
os.path.join(self.remote_checkpoint_dir, rel_checkpoint_dir),
|
os.path.join(self.remote_checkpoint_dir, rel_checkpoint_dir),
|
||||||
os.path.join(self.logdir, rel_checkpoint_dir),
|
os.path.join(self.logdir, rel_checkpoint_dir),
|
||||||
)
|
)
|
||||||
self.storage_client.wait()
|
self.storage_client.wait_or_retry()
|
||||||
|
|
||||||
# Ensure TrialCheckpoints are converted
|
# Ensure TrialCheckpoints are converted
|
||||||
if isinstance(checkpoint_path, TrialCheckpoint):
|
if isinstance(checkpoint_path, TrialCheckpoint):
|
||||||
|
@ -557,6 +557,7 @@ class Trainable:
|
||||||
else:
|
else:
|
||||||
if self.uses_cloud_checkpointing:
|
if self.uses_cloud_checkpointing:
|
||||||
self.storage_client.delete(self._storage_path(checkpoint_dir))
|
self.storage_client.delete(self._storage_path(checkpoint_dir))
|
||||||
|
self.storage_client.wait_or_retry()
|
||||||
|
|
||||||
if os.path.exists(checkpoint_dir):
|
if os.path.exists(checkpoint_dir):
|
||||||
shutil.rmtree(checkpoint_dir)
|
shutil.rmtree(checkpoint_dir)
|
||||||
|
|
|
@ -266,7 +266,12 @@ def send_signal_after_wait(process: subprocess.Popen, signal: int, wait: int = 3
|
||||||
time.sleep(wait)
|
time.sleep(wait)
|
||||||
|
|
||||||
if process.poll() is not None:
|
if process.poll() is not None:
|
||||||
raise RuntimeError(f"Process {process.pid} already terminated.")
|
raise RuntimeError(
|
||||||
|
f"Process {process.pid} already terminated. This usually means "
|
||||||
|
f"that some of the trials ERRORed (e.g. because they couldn't be "
|
||||||
|
f"restored. Try re-running this test to see if this fixes the "
|
||||||
|
f"issue."
|
||||||
|
)
|
||||||
|
|
||||||
print(f"Sending signal {signal} to process {process.pid}")
|
print(f"Sending signal {signal} to process {process.pid}")
|
||||||
process.send_signal(signal)
|
process.send_signal(signal)
|
||||||
|
|
Loading…
Add table
Reference in a new issue