[tune] Refactor Syncer / deprecate Sync client (#25655)
This PR includes / depends on #25709.

The two concepts of Syncer and SyncClient are confusing, as is the current API for passing custom sync functions. This PR refactors Tune's syncing behavior. The sync client concept is hard-deprecated. Instead, we offer a well-defined Syncer API that can be extended to provide your own syncing functionality. The default, however, will be to use Ray AIR's file transfer utilities.

New API:
- Users can pass `syncer=CustomSyncer`, where `CustomSyncer` implements the `Syncer` API
- Otherwise, our off-the-shelf syncing is used
- As before, syncing to cloud disables syncing to the driver

Changes:
- Sync client is removed
- Syncer interface is introduced
- _DefaultSyncer is a wrapper around the URI upload/download API from Ray AIR
- SyncerCallback only uses remote tasks to synchronize data
- Rsync syncing is fully deprecated and removed
- Docker- and Kubernetes-specific syncing is fully deprecated and removed
- Testing is improved to use `file://` URIs instead of mock sync clients
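For illustration, a minimal sketch of what the new API can look like (mirroring the updated `doc_code/faq.py` example below; the trainable class and the `s3://my-log-dir` bucket are placeholders):

```python
from ray import tune
from ray.tune.syncer import Syncer


class CustomSyncer(Syncer):
    def sync_up(self, local_dir: str, remote_dir: str, exclude: list = None) -> bool:
        ...  # upload the contents of local_dir to remote_dir
        return True

    def sync_down(self, remote_dir: str, local_dir: str, exclude: list = None) -> bool:
        ...  # download the contents of remote_dir into local_dir
        return True

    def delete(self, remote_dir: str) -> bool:
        ...  # remove remote_dir from the remote storage
        return True


tune.run(
    MyTrainableClass,  # placeholder trainable
    sync_config=tune.SyncConfig(
        upload_dir="s3://my-log-dir", syncer=CustomSyncer()
    ),
)
```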
This commit is contained in:
parent
f597e21ac8
commit
6313ddc47c
36 changed files with 1787 additions and 2884 deletions
|
@ -72,7 +72,6 @@ These are the environment variables Ray Tune currently considers:
|
|||
* **TUNE_RESULT_BUFFER_MAX_TIME_S**: Similarly, Ray Tune buffers results up to ``number_of_trial/10`` seconds,
|
||||
but never longer than this value. Defaults to 100 (seconds).
|
||||
* **TUNE_RESULT_BUFFER_MIN_TIME_S**: Additionally, you can specify a minimum time to buffer results. Defaults to 0.
|
||||
* **TUNE_SYNCER_VERBOSITY**: Amount of command output when using Tune with Docker Syncer. Defaults to 0.
|
||||
* **TUNE_WARN_THRESHOLD_S**: Threshold for logging if a Tune event loop operation takes too long. Defaults to 0.5 (seconds).
|
||||
* **TUNE_WARN_INSUFFICENT_RESOURCE_THRESHOLD_S**: Threshold for throwing a warning if no active trials are in ``RUNNING`` state
|
||||
for this amount of seconds. If the Ray Tune job is stuck in this state (most likely due to insufficient resources),
|
||||
|
|
|
@ -8,14 +8,6 @@ External library integrations (tune.integration)
|
|||
:depth: 1
|
||||
|
||||
|
||||
.. _tune-integration-docker:
|
||||
|
||||
Docker (tune.integration.docker)
|
||||
--------------------------------
|
||||
|
||||
.. autofunction:: ray.tune.integration.docker.DockerSyncer
|
||||
|
||||
|
||||
.. _tune-integration-keras:
|
||||
|
||||
Keras (tune.integration.keras)
|
||||
|
@ -25,12 +17,6 @@ Keras (tune.integration.keras)
|
|||
|
||||
.. autoclass:: ray.tune.integration.keras.TuneReportCheckpointCallback
|
||||
|
||||
.. _tune-integration-kubernetes:
|
||||
|
||||
Kubernetes (tune.integration.kubernetes)
|
||||
----------------------------------------
|
||||
|
||||
.. autofunction:: ray.tune.integration.kubernetes.NamespacedKubernetesSyncer
|
||||
|
||||
.. _tune-integration-mlflow:
|
||||
|
||||
|
|
|
@ -197,7 +197,6 @@ tune.run(tune.with_parameters(f, data=data))
|
|||
# __large_data_end__
|
||||
|
||||
MyTrainableClass = None
|
||||
custom_sync_str_or_func = ""
|
||||
|
||||
if not MOCK:
|
||||
# __log_1_start__
|
||||
|
@ -209,37 +208,31 @@ if not MOCK:
|
|||
# __log_1_end__
|
||||
|
||||
# __log_2_start__
|
||||
from ray.tune.syncer import Syncer
|
||||
|
||||
class CustomSyncer(Syncer):
|
||||
def sync_up(
|
||||
self, local_dir: str, remote_dir: str, exclude: list = None
|
||||
) -> bool:
|
||||
pass # sync up
|
||||
|
||||
def sync_down(
|
||||
self, remote_dir: str, local_dir: str, exclude: list = None
|
||||
) -> bool:
|
||||
pass # sync down
|
||||
|
||||
def delete(self, remote_dir: str) -> bool:
|
||||
pass # delete
|
||||
|
||||
tune.run(
|
||||
MyTrainableClass,
|
||||
sync_config=tune.SyncConfig(
|
||||
upload_dir="s3://my-log-dir", syncer=custom_sync_str_or_func
|
||||
upload_dir="s3://my-log-dir", syncer=CustomSyncer()
|
||||
),
|
||||
)
|
||||
# __log_2_end__
|
||||
|
||||
# __sync_start__
|
||||
import subprocess
|
||||
|
||||
|
||||
def custom_sync_func(source, target):
|
||||
# run other workload here
|
||||
sync_cmd = "s3 {source} {target}".format(source=source, target=target)
|
||||
sync_process = subprocess.Popen(sync_cmd, shell=True)
|
||||
sync_process.wait()
|
||||
|
||||
|
||||
# __sync_end__
|
||||
|
||||
if not MOCK:
|
||||
# __docker_start__
|
||||
from ray import tune
|
||||
from ray.tune.integration.docker import DockerSyncer
|
||||
|
||||
sync_config = tune.SyncConfig(syncer=DockerSyncer)
|
||||
|
||||
tune.run(train, sync_config=sync_config)
|
||||
# __docker_end__
|
||||
|
||||
# __s3_start__
|
||||
from ray import tune
|
||||
|
||||
|
@ -264,13 +257,6 @@ if not MOCK:
|
|||
)
|
||||
# __sync_config_end__
|
||||
|
||||
# __k8s_start__
|
||||
from ray.tune.integration.kubernetes import NamespacedKubernetesSyncer
|
||||
|
||||
sync_config = tune.SyncConfig(syncer=NamespacedKubernetesSyncer("ray"))
|
||||
|
||||
tune.run(train, sync_config=sync_config)
|
||||
# __k8s_end__
|
||||
|
||||
import ray
|
||||
|
||||
|
|
|
@ -577,7 +577,9 @@ How can I upload my Tune results to cloud storage?
|
|||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
If an upload directory is provided, Tune will automatically sync results from the ``local_dir`` to the given directory,
|
||||
natively supporting standard URIs for systems like S3, gsutil or HDFS.
|
||||
natively supporting standard URIs for systems like S3, gsutil or HDFS. You can add more filesystems by installing
|
||||
`fsspec <https://filesystem-spec.readthedocs.io/en/latest/>`_-compatible filesystems, e.g. using pip.
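For example, a sketch of pointing the upload directory at a Google Cloud Storage bucket (the bucket name is illustrative, and ``gcsfs`` is just one example of an fsspec-compatible filesystem you could install):

.. code-block:: python

    # pip install gcsfs  (an fsspec-compatible filesystem for Google Cloud Storage)
    from ray import tune

    sync_config = tune.SyncConfig(upload_dir="gs://my-bucket/tune-results")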
|
||||
|
||||
Here is an example of uploading to S3, using a bucket called ``my-log-dir``:
|
||||
|
||||
.. literalinclude:: doc_code/faq.py
|
||||
|
@ -586,8 +588,7 @@ Here is an example of uploading to S3, using a bucket called ``my-log-dir``:
|
|||
:start-after: __log_1_start__
|
||||
:end-before: __log_1_end__
|
||||
|
||||
You can customize this to specify arbitrary storages with the ``syncer`` argument in ``tune.SyncConfig``.
|
||||
This argument supports either strings with the same replacement fields OR arbitrary functions.
|
||||
You can customize synchronization behavior by implementing your own Syncer:
|
||||
|
||||
.. literalinclude:: doc_code/faq.py
|
||||
:dedent:
|
||||
|
@ -595,14 +596,6 @@ This argument supports either strings with the same replacement fields OR arbitr
|
|||
:start-after: __log_2_start__
|
||||
:end-before: __log_2_end__
|
||||
|
||||
If a string is provided, then it must include replacement fields ``{source}`` and ``{target}``, like
|
||||
``s3 sync {source} {target}``. Alternatively, a function can be provided with the following signature:
|
||||
|
||||
.. literalinclude:: doc_code/faq.py
|
||||
:language: python
|
||||
:start-after: __sync_start__
|
||||
:end-before: __sync_end__
|
||||
|
||||
By default, syncing occurs every 300 seconds.
|
||||
To change the frequency of syncing, set the ``sync_period`` attribute of the sync config to the desired syncing period.
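For instance, a sketch with an illustrative bucket name that lowers the period to 60 seconds:

.. code-block:: python

    from ray import tune

    sync_config = tune.SyncConfig(
        upload_dir="s3://my-log-dir",
        sync_period=60,  # sync every 60 seconds instead of the 300-second default
    )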
|
||||
|
||||
|
@ -623,23 +616,13 @@ How can I use Tune with Docker?
|
|||
Tune automatically syncs files and checkpoints between different remote
|
||||
containers as needed.
|
||||
|
||||
To make this work in your Docker cluster, e.g. when you are using the Ray autoscaler
|
||||
with docker containers, you will need to pass a
|
||||
``DockerSyncer`` to the ``syncer`` argument of ``tune.SyncConfig``.
|
||||
|
||||
.. literalinclude:: doc_code/faq.py
|
||||
:dedent:
|
||||
:language: python
|
||||
:start-after: __docker_start__
|
||||
:end-before: __docker_end__
|
||||
|
||||
.. _tune-kubernetes:
|
||||
|
||||
How can I use Tune with Kubernetes?
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Ray Tune automatically synchronizes files and checkpoints between different remote nodes as needed.
|
||||
This usually happens via SSH, but this can be a :ref:`performance bottleneck <tune-bottlenecks>`,
|
||||
This usually happens via the Ray object store, but this can be a :ref:`performance bottleneck <tune-bottlenecks>`,
|
||||
especially when running many trials in parallel.
|
||||
|
||||
Instead you should use shared storage for checkpoints so that no additional synchronization across nodes
|
||||
|
@ -662,19 +645,9 @@ Second, you can set up a shared file system like NFS. If you do this, disable au
|
|||
:start-after: __sync_config_start__
|
||||
:end-before: __sync_config_end__
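In sketch form, disabling syncing amounts to passing ``syncer=None`` (here ``train`` stands in for your trainable):

.. code-block:: python

    from ray import tune

    # With a shared file system (e.g. NFS) mounted on all nodes,
    # no additional synchronization between nodes is needed.
    tune.run(train, sync_config=tune.SyncConfig(syncer=None))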
|
||||
|
||||
Lastly, if you still want to use SSH for trial synchronization, but are not running
|
||||
on the Ray cluster launcher, you might need to pass a
|
||||
``KubernetesSyncer`` to the ``syncer`` argument of ``tune.SyncConfig``.
|
||||
You have to specify your Kubernetes namespace explicitly:
|
||||
|
||||
.. literalinclude:: doc_code/faq.py
|
||||
:dedent:
|
||||
:language: python
|
||||
:start-after: __k8s_start__
|
||||
:end-before: __k8s_end__
|
||||
|
||||
Please note that we strongly encourage you to use one of the other two options instead, as they will
|
||||
result in less overhead and don't require pods to SSH into each other.
|
||||
Please note that we strongly encourage you to use one of these two options, as they will
|
||||
result in less overhead and provide naturally durable checkpoint storage.
|
||||
|
||||
.. _tune-default-search-space:
|
||||
|
||||
|
|
|
@ -34,13 +34,6 @@ except (ImportError, ModuleNotFoundError):
|
|||
|
||||
from ray import logger
|
||||
|
||||
# We keep these constants for legacy compatibility with Tune's sync client
|
||||
# After Tune fully moved to using pyarrow.fs we can remove these.
|
||||
S3_PREFIX = "s3://"
|
||||
GS_PREFIX = "gs://"
|
||||
HDFS_PREFIX = "hdfs://"
|
||||
ALLOWED_REMOTE_PREFIXES = (S3_PREFIX, GS_PREFIX, HDFS_PREFIX)
|
||||
|
||||
|
||||
def _assert_pyarrow_installed():
|
||||
if pyarrow is None:
|
||||
|
@ -77,11 +70,7 @@ def is_non_local_path_uri(uri: str) -> bool:
|
|||
|
||||
if bool(get_fs_and_path(uri)[0]):
|
||||
return True
|
||||
# Keep manual check for prefixes for backwards compatibility with the
|
||||
# TrialCheckpoint class. Remove once fully deprecated.
|
||||
# Deprecated: Remove in Ray > 1.13
|
||||
if any(uri.startswith(p) for p in ALLOWED_REMOTE_PREFIXES):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
|
@ -207,8 +196,8 @@ def _upload_to_uri_with_exclude(
|
|||
if _should_exclude(candidate):
|
||||
continue
|
||||
|
||||
full_source_path = os.path.join(local_path, candidate)
|
||||
full_target_path = os.path.join(bucket_path, candidate)
|
||||
full_source_path = os.path.normpath(os.path.join(local_path, candidate))
|
||||
full_target_path = os.path.normpath(os.path.join(bucket_path, candidate))
|
||||
|
||||
pyarrow.fs.copy_files(
|
||||
full_source_path, full_target_path, destination_filesystem=fs
|
||||
|
|
|
@ -17,6 +17,7 @@ from ray.tune import Trainable, PlacementGroupFactory
|
|||
from ray.tune.logger import Logger
|
||||
from ray.tune.registry import get_trainable_cls
|
||||
from ray.tune.resources import Resources
|
||||
from ray.tune.syncer import Syncer
|
||||
from ray.util.annotations import PublicAPI
|
||||
from ray.util.ml_utils.dict import merge_dicts
|
||||
|
||||
|
@ -204,7 +205,7 @@ class RLTrainer(BaseTrainer):
|
|||
env: Optional[Union[str, EnvType]] = None,
|
||||
logger_creator: Optional[Callable[[], Logger]] = None,
|
||||
remote_checkpoint_dir: Optional[str] = None,
|
||||
sync_function_tpl: Optional[str] = None,
|
||||
custom_syncer: Optional[Syncer] = None,
|
||||
):
|
||||
resolved_config = merge_dicts(base_config, config or {})
|
||||
param_dict["config"] = resolved_config
|
||||
|
@ -217,7 +218,7 @@ class RLTrainer(BaseTrainer):
|
|||
env=env,
|
||||
logger_creator=logger_creator,
|
||||
remote_checkpoint_dir=remote_checkpoint_dir,
|
||||
sync_function_tpl=sync_function_tpl,
|
||||
custom_syncer=custom_syncer,
|
||||
)
|
||||
|
||||
def save_checkpoint(self, checkpoint_dir: str):
|
||||
|
|
|
@ -136,22 +136,6 @@ py_test(
|
|||
tags = ["team:ml", "exclusive", "tests_dir_I"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_integration_docker",
|
||||
size = "small",
|
||||
srcs = ["tests/test_integration_docker.py"],
|
||||
deps = [":tune_lib"],
|
||||
tags = ["team:ml", "exclusive", "tests_dir_I"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_integration_kubernetes",
|
||||
size = "small",
|
||||
srcs = ["tests/test_integration_kubernetes.py"],
|
||||
deps = [":tune_lib"],
|
||||
tags = ["team:ml", "exclusive", "tests_dir_I"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_integration_pytorch_lightning",
|
||||
size = "small",
|
||||
|
@ -281,9 +265,25 @@ py_test(
|
|||
)
|
||||
|
||||
py_test(
|
||||
name = "test_sync",
|
||||
name = "test_util_file_transfer",
|
||||
size = "medium",
|
||||
srcs = ["tests/test_sync.py"],
|
||||
srcs = ["tests/test_util_file_transfer.py"],
|
||||
deps = [":tune_lib"],
|
||||
tags = ["team:ml", "exclusive", "tests_dir_S"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_syncer",
|
||||
size = "medium",
|
||||
srcs = ["tests/test_syncer.py"],
|
||||
deps = [":tune_lib"],
|
||||
tags = ["team:ml", "exclusive", "tests_dir_S"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_syncer_callback",
|
||||
size = "medium",
|
||||
srcs = ["tests/test_syncer_callback.py"],
|
||||
deps = [":tune_lib"],
|
||||
tags = ["team:ml", "exclusive", "tests_dir_S"],
|
||||
)
|
||||
|
|
|
@ -9,7 +9,7 @@ from six import string_types
|
|||
from ray.tune import TuneError
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.resources import json_to_resources
|
||||
from ray.tune.syncer import SyncConfig
|
||||
from ray.tune.syncer import SyncConfig, Syncer
|
||||
from ray.tune.utils.placement_groups import PlacementGroupFactory
|
||||
from ray.tune.utils.util import SafeFallbackEncoder
|
||||
|
||||
|
@ -211,16 +211,20 @@ def create_trial_from_spec(
|
|||
remote_checkpoint_dir = spec.get("remote_checkpoint_dir")
|
||||
|
||||
sync_config = spec.get("sync_config", SyncConfig())
|
||||
if sync_config.syncer is None or isinstance(sync_config.syncer, str):
|
||||
sync_function_tpl = sync_config.syncer
|
||||
elif not isinstance(sync_config.syncer, str):
|
||||
# If a syncer was specified, but not a template, it is a function.
|
||||
# Functions cannot be used for trial checkpointing on remote nodes,
|
||||
# so we set the remote checkpoint dir to None to disable this.
|
||||
sync_function_tpl = None
|
||||
remote_checkpoint_dir = None
|
||||
if (
|
||||
sync_config.syncer is None
|
||||
or sync_config.syncer == "auto"
|
||||
or isinstance(sync_config.syncer, Syncer)
|
||||
):
|
||||
custom_syncer = sync_config.syncer
|
||||
else:
|
||||
sync_function_tpl = None # Auto-detect
|
||||
raise ValueError(
|
||||
f"Unknown syncer type passed in SyncConfig: {type(sync_config.syncer)}. "
|
||||
f"Note that custom sync functions and templates have been deprecated. "
|
||||
f"Instead you can implement you own `Syncer` class. "
|
||||
f"Please leave a comment on GitHub if you run into any issues with this: "
|
||||
f"https://github.com/ray-project/ray/issues"
|
||||
)
|
||||
|
||||
return Trial(
|
||||
# Submitting trial via server in py2.7 creates Unicode, which does not
|
||||
|
@ -232,7 +236,7 @@ def create_trial_from_spec(
|
|||
# json.load leads to str -> unicode in py2.7
|
||||
stopping_criterion=spec.get("stop", {}),
|
||||
remote_checkpoint_dir=remote_checkpoint_dir,
|
||||
sync_function_tpl=sync_function_tpl,
|
||||
custom_syncer=custom_syncer,
|
||||
checkpoint_freq=args.checkpoint_freq,
|
||||
checkpoint_at_end=args.checkpoint_at_end,
|
||||
sync_on_checkpoint=sync_config.sync_on_checkpoint,
|
||||
|
@ -246,5 +250,5 @@ def create_trial_from_spec(
|
|||
log_to_file=spec.get("log_to_file"),
|
||||
# str(None) doesn't create None
|
||||
max_failures=args.max_failures,
|
||||
**trial_kwargs
|
||||
**trial_kwargs,
|
||||
)
|
||||
|
|
|
@ -1,154 +1,21 @@
|
|||
import logging
|
||||
import os
|
||||
from typing import Optional, Tuple, List
|
||||
|
||||
from ray.autoscaler.sdk import rsync, configure_logging
|
||||
from ray.util import get_node_ip_address
|
||||
from ray.util.debug import log_once
|
||||
from ray.tune.syncer import NodeSyncer
|
||||
from ray.tune.sync_client import SyncClient
|
||||
from ray.ray_constants import env_integer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from ray.util.annotations import Deprecated
|
||||
|
||||
|
||||
class DockerSyncer(NodeSyncer):
|
||||
"""DockerSyncer used for synchronization between Docker containers.
|
||||
This syncer extends the node syncer, but is usually instantiated
|
||||
without a custom sync client. The sync client defaults to
|
||||
``DockerSyncClient`` instead.
|
||||
|
||||
Set the env var `TUNE_SYNCER_VERBOSITY` to increase verbosity
|
||||
of syncing operations (0, 1, 2, 3). Defaults to 0.
|
||||
|
||||
.. note::
|
||||
This syncer only works with the Ray cluster launcher.
|
||||
If you use your own Docker setup, make sure the nodes can connect
|
||||
to each other via SSH, and try the regular SSH-based syncer instead.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from ray.tune.integration.docker import DockerSyncer
|
||||
tune.run(train,
|
||||
sync_config=tune.SyncConfig(
|
||||
syncer=DockerSyncer))
|
||||
|
||||
"""
|
||||
|
||||
_cluster_config_file = os.path.expanduser("~/ray_bootstrap_config.yaml")
|
||||
|
||||
def __init__(
|
||||
self, local_dir: str, remote_dir: str, sync_client: Optional[SyncClient] = None
|
||||
):
|
||||
configure_logging(
|
||||
log_style="record", verbosity=env_integer("TUNE_SYNCER_VERBOSITY", 0)
|
||||
@Deprecated
|
||||
class DockerSyncer:
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise DeprecationWarning(
|
||||
"DockerSyncer has been fully deprecated. There is no need to "
|
||||
"use this syncer anymore - data syncing will happen automatically "
|
||||
"using the Ray object store. You can just remove passing this class."
|
||||
)
|
||||
self.local_ip = get_node_ip_address()
|
||||
self.worker_ip = None
|
||||
|
||||
sync_client = sync_client or DockerSyncClient()
|
||||
sync_client.configure(self._cluster_config_file)
|
||||
|
||||
super(NodeSyncer, self).__init__(local_dir, remote_dir, sync_client)
|
||||
|
||||
def set_worker_ip(self, worker_ip: str):
|
||||
self.worker_ip = worker_ip
|
||||
|
||||
@property
|
||||
def _remote_path(self) -> Tuple[str, str]:
|
||||
return (self.worker_ip, self._remote_dir)
|
||||
|
||||
|
||||
class DockerSyncClient(SyncClient):
|
||||
"""DockerSyncClient to be used by DockerSyncer.
|
||||
|
||||
This client takes care of executing the synchronization
|
||||
commands for Docker nodes. In its ``sync_down`` and
|
||||
``sync_up`` commands, it expects tuples for the source
|
||||
and target, respectively, for compatibility with docker.
|
||||
|
||||
Args:
|
||||
should_bootstrap: Whether to bootstrap the autoscaler
|
||||
cofiguration. This may be useful when you are
|
||||
running into authentication problems; i.e.:
|
||||
https://github.com/ray-project/ray/issues/17756.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, should_bootstrap: bool = True):
|
||||
self._command_runners = {}
|
||||
self._cluster_config = None
|
||||
if os.environ.get("TUNE_SYNC_DISABLE_BOOTSTRAP") == "1":
|
||||
should_bootstrap = False
|
||||
logger.debug("Skipping bootstrap for docker sync client.")
|
||||
self._should_bootstrap = should_bootstrap
|
||||
|
||||
def configure(self, cluster_config_file: str):
|
||||
self._cluster_config_file = cluster_config_file
|
||||
|
||||
def sync_up(
|
||||
self, source: str, target: Tuple[str, str], exclude: Optional[List] = None
|
||||
) -> bool:
|
||||
"""Here target is a tuple (target_node, target_dir)"""
|
||||
target_node, target_dir = target
|
||||
|
||||
# Add trailing slashes for rsync
|
||||
source = os.path.join(source, "")
|
||||
target_dir = os.path.join(target_dir, "")
|
||||
import click
|
||||
|
||||
try:
|
||||
rsync(
|
||||
cluster_config=self._cluster_config_file,
|
||||
source=source,
|
||||
target=target_dir,
|
||||
down=False,
|
||||
ip_address=target_node,
|
||||
should_bootstrap=self._should_bootstrap,
|
||||
use_internal_ip=True,
|
||||
@Deprecated
|
||||
class DockerSyncClient:
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise DeprecationWarning(
|
||||
"DockerSyncClient has been fully deprecated. There is no need to "
|
||||
"use this syncer anymore - data syncing will happen automatically "
|
||||
"using the Ray object store. You can just remove passing this class."
|
||||
)
|
||||
except click.ClickException:
|
||||
if log_once("docker_rsync_up_fail"):
|
||||
logger.warning(
|
||||
"Rsync-up failed. Consider using a durable trainable "
|
||||
"or setting the `TUNE_SYNC_DISABLE_BOOTSTRAP=1` env var."
|
||||
)
|
||||
raise
|
||||
|
||||
return True
|
||||
|
||||
def sync_down(
|
||||
self, source: Tuple[str, str], target: str, exclude: Optional[List] = None
|
||||
) -> bool:
|
||||
"""Here source is a tuple (source_node, source_dir)"""
|
||||
source_node, source_dir = source
|
||||
|
||||
# Add trailing slashes for rsync
|
||||
source_dir = os.path.join(source_dir, "")
|
||||
target = os.path.join(target, "")
|
||||
import click
|
||||
|
||||
try:
|
||||
rsync(
|
||||
cluster_config=self._cluster_config_file,
|
||||
source=source_dir,
|
||||
target=target,
|
||||
down=True,
|
||||
ip_address=source_node,
|
||||
should_bootstrap=self._should_bootstrap,
|
||||
use_internal_ip=True,
|
||||
)
|
||||
except click.ClickException:
|
||||
if log_once("docker_rsync_down_fail"):
|
||||
logger.warning(
|
||||
"Rsync-down failed. Consider using a durable trainable "
|
||||
"or setting the `TUNE_SYNC_DISABLE_BOOTSTRAP=1` env var."
|
||||
)
|
||||
raise
|
||||
|
||||
return True
|
||||
|
||||
def delete(self, target: str) -> bool:
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -1,176 +1,30 @@
|
|||
import os
|
||||
from typing import Any, Optional, Tuple, List
|
||||
import subprocess
|
||||
|
||||
from ray import logger
|
||||
from ray.autoscaler._private.command_runner import KubernetesCommandRunner
|
||||
from ray.tune.syncer import NodeSyncer
|
||||
from ray.tune.sync_client import SyncClient
|
||||
from ray.util import get_node_ip_address
|
||||
from ray.util.annotations import Deprecated
|
||||
|
||||
|
||||
def try_import_kubernetes():
|
||||
try:
|
||||
import kubernetes
|
||||
except ImportError:
|
||||
kubernetes = None
|
||||
return kubernetes
|
||||
@Deprecated
|
||||
class KubernetesSyncer:
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
||||
raise DeprecationWarning(
|
||||
"KubernetesSyncer has been fully deprecated. There is no need to "
|
||||
"use this syncer anymore - data syncing will happen automatically "
|
||||
"using the Ray object store. You can just remove passing this class."
|
||||
)
|
||||
|
||||
|
||||
kubernetes = try_import_kubernetes()
|
||||
@Deprecated
|
||||
class KubernetesSyncClient:
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise DeprecationWarning(
|
||||
"KubernetesSyncClient has been fully deprecated. There is no need to "
|
||||
"use this syncer anymore - data syncing will happen automatically "
|
||||
"using the Ray object store. You can just remove passing this class."
|
||||
)
|
||||
|
||||
|
||||
def NamespacedKubernetesSyncer(namespace: str):
|
||||
"""Wrapper to return a ``KubernetesSyncer`` for a Kubernetes namespace.
|
||||
|
||||
Args:
|
||||
namespace: Kubernetes namespace.
|
||||
|
||||
Returns:
|
||||
A ``KubernetesSyncer`` class to be passed to ``tune.run()``.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from ray.tune.integration.kubernetes import NamespacedKubernetesSyncer
|
||||
tune.run(train,
|
||||
sync_config=tune.SyncConfig(
|
||||
syncer=NamespacedKubernetesSyncer("ray")))
|
||||
|
||||
"""
|
||||
|
||||
class _NamespacedKubernetesSyncer(KubernetesSyncer):
|
||||
_namespace = namespace
|
||||
|
||||
return _NamespacedKubernetesSyncer
|
||||
|
||||
|
||||
class KubernetesSyncer(NodeSyncer):
|
||||
"""KubernetesSyncer used for synchronization between Kubernetes pods.
|
||||
|
||||
This syncer extends the node syncer, but is usually instantiated
|
||||
without a custom sync client. The sync client defaults to
|
||||
``KubernetesSyncClient`` instead.
|
||||
|
||||
KubernetesSyncer uses the default namespace ``ray``. You should
|
||||
probably use ``NamespacedKubernetesSyncer`` to return a class
|
||||
with a custom namespace instead.
|
||||
"""
|
||||
|
||||
_namespace = "ray"
|
||||
|
||||
def __init__(
|
||||
self, local_dir: str, remote_dir: str, sync_client: Optional[SyncClient] = None
|
||||
):
|
||||
if not kubernetes:
|
||||
raise ImportError(
|
||||
"kubernetes is not installed on this machine/container. "
|
||||
"Try: pip install kubernetes"
|
||||
raise DeprecationWarning(
|
||||
"NamespacedKubernetesSyncer has been fully deprecated. There is no need to "
|
||||
"use this syncer anymore - data syncing will happen automatically "
|
||||
"using the Ray object store. You can just remove passing this class."
|
||||
)
|
||||
self.local_ip = get_node_ip_address()
|
||||
self.local_node = self._get_kubernetes_node_by_ip(self.local_ip)
|
||||
self.worker_ip = None
|
||||
self.worker_node = None
|
||||
|
||||
sync_client = sync_client or KubernetesSyncClient(
|
||||
namespace=self.__class__._namespace
|
||||
)
|
||||
|
||||
super(NodeSyncer, self).__init__(local_dir, remote_dir, sync_client)
|
||||
|
||||
def set_worker_ip(self, worker_ip: str):
|
||||
self.worker_ip = worker_ip
|
||||
self.worker_node = self._get_kubernetes_node_by_ip(worker_ip)
|
||||
|
||||
def _get_kubernetes_node_by_ip(self, node_ip: str) -> Optional[str]:
|
||||
"""Return node name by internal or external IP"""
|
||||
kubernetes.config.load_incluster_config()
|
||||
api = kubernetes.client.CoreV1Api()
|
||||
pods = api.list_namespaced_pod(self._namespace)
|
||||
for pod in pods.items:
|
||||
if pod.status.host_ip == node_ip or pod.status.pod_ip == node_ip:
|
||||
return pod.metadata.name
|
||||
|
||||
logger.error("Could not find Kubernetes pod name for IP {}".format(node_ip))
|
||||
return None
|
||||
|
||||
@property
|
||||
def _remote_path(self) -> Tuple[str, str]:
|
||||
return self.worker_node, self._remote_dir
|
||||
|
||||
|
||||
class KubernetesSyncClient(SyncClient):
|
||||
"""KubernetesSyncClient to be used by KubernetesSyncer.
|
||||
|
||||
This client takes care of executing the synchronization
|
||||
commands for Kubernetes clients. In its ``sync_down`` and
|
||||
``sync_up`` commands, it expects tuples for the source
|
||||
and target, respectively, for compatibility with the
|
||||
KubernetesCommandRunner.
|
||||
|
||||
Args:
|
||||
namespace: Namespace in which the pods live.
|
||||
process_runner: How commands should be called.
|
||||
Defaults to ``subprocess``.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, namespace: str, process_runner: Any = subprocess):
|
||||
self.namespace = namespace
|
||||
self._process_runner = process_runner
|
||||
self._command_runners = {}
|
||||
|
||||
def _create_command_runner(self, node_id: str) -> KubernetesCommandRunner:
|
||||
"""Create a command runner for one Kubernetes node"""
|
||||
return KubernetesCommandRunner(
|
||||
log_prefix="KubernetesSyncClient: {}:".format(node_id),
|
||||
namespace=self.namespace,
|
||||
node_id=node_id,
|
||||
auth_config=None,
|
||||
process_runner=self._process_runner,
|
||||
)
|
||||
|
||||
def _get_command_runner(self, node_id: str) -> KubernetesCommandRunner:
|
||||
"""Create command runner if it doesn't exist"""
|
||||
# Todo(krfricke): These cached runners are currently
|
||||
# never cleaned up. They are cheap so this shouldn't
|
||||
# cause much problems, but should be addressed if
|
||||
# the SyncClient is used more extensively in the future.
|
||||
if node_id not in self._command_runners:
|
||||
command_runner = self._create_command_runner(node_id)
|
||||
self._command_runners[node_id] = command_runner
|
||||
return self._command_runners[node_id]
|
||||
|
||||
def sync_up(
|
||||
self, source: str, target: Tuple[str, str], exclude: Optional[List] = None
|
||||
) -> bool:
|
||||
"""Here target is a tuple (target_node, target_dir)"""
|
||||
target_node, target_dir = target
|
||||
|
||||
# Add trailing slashes for rsync
|
||||
source = os.path.join(source, "")
|
||||
target_dir = os.path.join(target_dir, "")
|
||||
|
||||
command_runner = self._get_command_runner(target_node)
|
||||
command_runner.run_rsync_up(source, target_dir)
|
||||
return True
|
||||
|
||||
def sync_down(
|
||||
self, source: Tuple[str, str], target: str, exclude: Optional[List] = None
|
||||
) -> bool:
|
||||
"""Here source is a tuple (source_node, source_dir)"""
|
||||
source_node, source_dir = source
|
||||
|
||||
# Add trailing slashes for rsync
|
||||
source_dir = os.path.join(source_dir, "")
|
||||
target = os.path.join(target, "")
|
||||
|
||||
command_runner = self._get_command_runner(source_node)
|
||||
command_runner.run_rsync_down(source_dir, target)
|
||||
return True
|
||||
|
||||
def delete(self, target: str) -> bool:
|
||||
"""No delete function because it is only used by
|
||||
the KubernetesSyncer, which doesn't call delete."""
|
||||
return True
|
||||
|
|
|
@ -379,7 +379,7 @@ class RayTrialExecutor:
|
|||
# We keep these kwargs separate for backwards compatibility
|
||||
# with trainables that don't provide these keyword arguments
|
||||
kwargs["remote_checkpoint_dir"] = trial.remote_checkpoint_dir
|
||||
kwargs["sync_function_tpl"] = trial.sync_function_tpl
|
||||
kwargs["custom_syncer"] = trial.custom_syncer
|
||||
|
||||
# Throw a meaningful error if trainable does not use the
|
||||
# new API
|
||||
|
@ -389,7 +389,7 @@ class RayTrialExecutor:
|
|||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
"Your trainable class does not accept a "
|
||||
"`remote_checkpoint_dir` or `sync_function_tpl` argument "
|
||||
"`remote_checkpoint_dir` or `custom_syncer` argument "
|
||||
"in its constructor, but you've passed a "
|
||||
"`upload_dir` to your SyncConfig. Without accepting "
|
||||
"these parameters and passing them to the base trainable "
|
||||
|
@ -775,27 +775,31 @@ class RayTrialExecutor:
|
|||
raise RuntimeError(
|
||||
"Trial {}: Unable to restore - no runner found.".format(trial)
|
||||
)
|
||||
value = checkpoint.dir_or_data
|
||||
checkpoint_dir = checkpoint.dir_or_data
|
||||
node_ip = checkpoint.node_ip
|
||||
if checkpoint.storage_mode == CheckpointStorage.MEMORY:
|
||||
logger.debug("Trial %s: Attempting restore from object", trial)
|
||||
# Note that we don't store the remote since in-memory checkpoints
|
||||
# don't guarantee fault tolerance and don't need to be waited on.
|
||||
with self._change_working_directory(trial):
|
||||
trial.runner.restore_from_object.remote(value)
|
||||
trial.runner.restore_from_object.remote(checkpoint_dir)
|
||||
else:
|
||||
logger.debug("Trial %s: Attempting restore from %s", trial, value)
|
||||
if trial.uses_cloud_checkpointing or not trial.sync_on_checkpoint:
|
||||
logger.debug("Trial %s: Attempting restore from %s", trial, checkpoint_dir)
|
||||
if (
|
||||
trial.uses_cloud_checkpointing
|
||||
or not trial.sync_on_checkpoint
|
||||
or not os.path.exists(checkpoint_dir)
|
||||
):
|
||||
# If using cloud checkpointing, trial will get cp from cloud.
|
||||
# If not syncing to driver, assume it has access to the cp
|
||||
# on the local fs.
|
||||
with self._change_working_directory(trial):
|
||||
remote = trial.runner.restore.remote(value, node_ip)
|
||||
remote = trial.runner.restore.remote(checkpoint_dir, node_ip)
|
||||
elif trial.sync_on_checkpoint:
|
||||
# This provides FT backwards compatibility in the
|
||||
# case where no cloud checkpoints are provided.
|
||||
logger.debug("Trial %s: Reading checkpoint into memory", trial)
|
||||
obj = TrainableUtil.checkpoint_to_object(value)
|
||||
obj = TrainableUtil.checkpoint_to_object(checkpoint_dir)
|
||||
with self._change_working_directory(trial):
|
||||
remote = trial.runner.restore_from_object.remote(obj)
|
||||
else:
|
||||
|
|
|
@ -1,606 +1,44 @@
|
|||
import abc
|
||||
import distutils
|
||||
import distutils.spawn
|
||||
import inspect
|
||||
import logging
|
||||
import pathlib
|
||||
import subprocess
|
||||
import tempfile
|
||||
import time
|
||||
import types
|
||||
|
||||
from typing import Optional, List, Callable, Union, Tuple
|
||||
|
||||
from shlex import quote
|
||||
|
||||
import ray
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.utils.file_transfer import sync_dir_between_nodes, delete_on_node
|
||||
from ray.util.annotations import PublicAPI, DeveloperAPI
|
||||
from ray.air._internal.remote_storage import (
|
||||
S3_PREFIX,
|
||||
GS_PREFIX,
|
||||
HDFS_PREFIX,
|
||||
ALLOWED_REMOTE_PREFIXES,
|
||||
)
|
||||
from ray.util.annotations import Deprecated
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
noop_template = ": {target}" # noop in bash
|
||||
|
||||
|
||||
def noop(*args):
|
||||
return
|
||||
|
||||
|
||||
def get_sync_client(
|
||||
sync_function: Optional[Union[str, Callable]],
|
||||
delete_function: Optional[Union[str, Callable]] = None,
|
||||
) -> Optional["SyncClient"]:
|
||||
"""Returns a sync client.
|
||||
|
||||
Args:
|
||||
sync_function: Sync function.
|
||||
delete_function: Delete function. Must be
|
||||
the same type as sync_function if it is provided.
|
||||
|
||||
Raises:
|
||||
ValueError if sync_function or delete_function are malformed.
|
||||
"""
|
||||
if sync_function is None:
|
||||
return None
|
||||
if delete_function and type(sync_function) != type(delete_function):
|
||||
raise ValueError("Sync and delete functions must be of same type.")
|
||||
if isinstance(sync_function, types.FunctionType):
|
||||
delete_function = delete_function or noop
|
||||
client_cls = FunctionBasedClient
|
||||
elif isinstance(sync_function, str):
|
||||
delete_function = delete_function or noop_template
|
||||
client_cls = CommandBasedClient
|
||||
else:
|
||||
raise ValueError(
|
||||
"Sync function {} must be string or function".format(sync_function)
|
||||
)
|
||||
return client_cls(sync_function, sync_function, delete_function)
|
||||
|
||||
|
||||
def get_cloud_sync_client(remote_path: str) -> "CommandBasedClient":
|
||||
"""Returns a CommandBasedClient that can sync to/from remote storage.
|
||||
|
||||
Args:
|
||||
remote_path: Path to remote storage (S3, GS or HDFS).
|
||||
|
||||
Raises:
|
||||
ValueError if malformed remote_dir.
|
||||
"""
|
||||
if remote_path.startswith(S3_PREFIX):
|
||||
if not distutils.spawn.find_executable("aws"):
|
||||
raise ValueError(
|
||||
"Upload uri starting with '{}' requires awscli tool"
|
||||
" to be installed".format(S3_PREFIX)
|
||||
)
|
||||
sync_up_template = (
|
||||
"aws s3 sync {source} {target} "
|
||||
"--exact-timestamps --only-show-errors {options}"
|
||||
)
|
||||
sync_down_template = sync_up_template
|
||||
delete_template = "aws s3 rm {target} --recursive --only-show-errors {options}"
|
||||
exclude_template = "--exclude '{pattern}'"
|
||||
elif remote_path.startswith(GS_PREFIX):
|
||||
if not distutils.spawn.find_executable("gsutil"):
|
||||
raise ValueError(
|
||||
"Upload uri starting with '{}' requires gsutil tool"
|
||||
" to be installed".format(GS_PREFIX)
|
||||
)
|
||||
sync_up_template = "gsutil rsync -r {options} {source} {target}"
|
||||
sync_down_template = sync_up_template
|
||||
delete_template = "gsutil rm -r {options} {target}"
|
||||
exclude_template = "-x '{regex_pattern}'"
|
||||
elif remote_path.startswith(HDFS_PREFIX):
|
||||
if not distutils.spawn.find_executable("hdfs"):
|
||||
raise ValueError(
|
||||
"Upload uri starting with '{}' requires hdfs tool"
|
||||
" to be installed".format(HDFS_PREFIX)
|
||||
)
|
||||
sync_up_template = "hdfs dfs -put -f {source} {target}"
|
||||
sync_down_template = "hdfs dfs -get -f {source} {target}"
|
||||
delete_template = "hdfs dfs -rm -r {target}"
|
||||
exclude_template = None
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Upload uri must start with one of: {ALLOWED_REMOTE_PREFIXES} "
|
||||
f"(is: `{remote_path}`)"
|
||||
)
|
||||
return CommandBasedClient(
|
||||
sync_up_template, sync_down_template, delete_template, exclude_template
|
||||
)
|
||||
|
||||
|
||||
@PublicAPI(stability="beta")
|
||||
@Deprecated
|
||||
class SyncClient(abc.ABC):
|
||||
"""Client interface for interacting with remote storage options."""
|
||||
|
||||
def sync_up(self, source: str, target: str, exclude: Optional[List] = None):
|
||||
"""Syncs up from source to target.
|
||||
|
||||
Args:
|
||||
source: Source path.
|
||||
target: Target path.
|
||||
exclude: Pattern of files to exclude, e.g.
|
||||
``["*/checkpoint_*]`` to exclude trial checkpoints.
|
||||
|
||||
Returns:
|
||||
True if sync initiation successful, False otherwise.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def sync_down(self, source: str, target: str, exclude: Optional[List] = None):
|
||||
"""Syncs down from source to target.
|
||||
|
||||
Args:
|
||||
source: Source path.
|
||||
target: Target path.
|
||||
exclude: Pattern of files to exclude, e.g.
|
||||
``["*/checkpoint_*]`` to exclude trial checkpoints.
|
||||
|
||||
Returns:
|
||||
True if sync initiation successful, False otherwise.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def delete(self, target: str):
|
||||
"""Deletes target.
|
||||
|
||||
Args:
|
||||
target: Target path.
|
||||
|
||||
Returns:
|
||||
True if delete initiation successful, False otherwise.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def wait(self):
|
||||
"""Waits for current sync to complete, if asynchronously started."""
|
||||
pass
|
||||
|
||||
def wait_or_retry(self, max_retries: int = 3, backoff_s: int = 5):
|
||||
"""Wait for current sync to complete or retries on error."""
|
||||
pass
|
||||
|
||||
def reset(self):
|
||||
"""Resets state."""
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
"""Clean up hook."""
|
||||
pass
|
||||
|
||||
|
||||
def _is_legacy_sync_fn(func) -> bool:
|
||||
sig = inspect.signature(func)
|
||||
try:
|
||||
sig.bind_partial(None, None, None)
|
||||
return False
|
||||
except TypeError:
|
||||
return True
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class FunctionBasedClient(SyncClient):
|
||||
def __init__(self, sync_up_func, sync_down_func, delete_func=None):
|
||||
self.sync_up_func = sync_up_func
|
||||
self._sync_up_legacy = _is_legacy_sync_fn(sync_up_func)
|
||||
|
||||
self.sync_down_func = sync_down_func
|
||||
self._sync_down_legacy = _is_legacy_sync_fn(sync_up_func)
|
||||
|
||||
if self._sync_up_legacy or self._sync_down_legacy:
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise DeprecationWarning(
|
||||
"Your sync functions currently only accepts two params "
|
||||
"(a `source` and a `target`). In the future, we will "
|
||||
"pass an additional `exclude` parameter. Please adjust "
|
||||
"your sync function accordingly."
|
||||
"SyncClient has been deprecated. Please implement a "
|
||||
"`ray.tune.syncer.Syncer` instead."
|
||||
)
|
||||
|
||||
self.delete_func = delete_func or noop
|
||||
|
||||
def sync_up(self, source, target, exclude: Optional[List] = None):
|
||||
if self._sync_up_legacy:
|
||||
self.sync_up_func(source, target)
|
||||
else:
|
||||
self.sync_up_func(source, target, exclude)
|
||||
return True
|
||||
|
||||
def sync_down(self, source, target, exclude: Optional[List] = None):
|
||||
if self._sync_down_legacy:
|
||||
self.sync_down_func(source, target)
|
||||
else:
|
||||
self.sync_down_func(source, target, exclude)
|
||||
return True
|
||||
|
||||
def delete(self, target):
|
||||
self.delete_func(target)
|
||||
return True
|
||||
@Deprecated
|
||||
class FunctionBasedClient(SyncClient):
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise DeprecationWarning(
|
||||
"FunctionBasedClient has been deprecated. Please implement a "
|
||||
"`ray.tune.syncer.Syncer` instead."
|
||||
)
|
||||
|
||||
|
||||
NOOP = FunctionBasedClient(noop, noop)
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
@Deprecated
|
||||
class CommandBasedClient(SyncClient):
|
||||
"""Syncs between two directories with the given command.
|
||||
|
||||
If a sync is already in-flight when calling ``sync_down`` or
|
||||
``sync_up``, a warning will be printed and the new sync command is
|
||||
ignored. To force a new sync, either use ``wait()``
|
||||
(or ``wait_or_retry()``) to wait until the previous sync has finished,
|
||||
or call ``reset()`` to detach from the previous sync. Note that this
|
||||
will not kill the previous sync command, so it may still be executed.
|
||||
|
||||
Arguments:
|
||||
sync_up_template: A runnable string template; needs to
|
||||
include replacement fields ``{source}``, ``{target}``, and
|
||||
``{options}``.
|
||||
sync_down_template: A runnable string template; needs to
|
||||
include replacement fields ``{source}``, ``{target}``, and
|
||||
``{options}``.
|
||||
delete_template: A runnable string template; needs
|
||||
to include replacement field ``{target}``. Noop by default.
|
||||
exclude_template: A pattern with possible
|
||||
replacement fields ``{pattern}`` and ``{regex_pattern}``.
|
||||
Will replace ``{options}}`` in the sync up/down templates
|
||||
if files/directories to exclude are passed.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sync_up_template: str,
|
||||
sync_down_template: str,
|
||||
delete_template: Optional[str] = noop_template,
|
||||
exclude_template: Optional[str] = None,
|
||||
):
|
||||
self._validate_sync_string(sync_up_template)
|
||||
self._validate_sync_string(sync_down_template)
|
||||
self._validate_exclude_template(exclude_template)
|
||||
self.sync_up_template = sync_up_template
|
||||
self.sync_down_template = sync_down_template
|
||||
self.delete_template = delete_template
|
||||
self.exclude_template = exclude_template
|
||||
self.logfile = None
|
||||
self._closed = False
|
||||
self.cmd_process = None
|
||||
# Keep track of last command for retry
|
||||
self._last_cmd = None
|
||||
|
||||
def set_logdir(self, logdir: str):
|
||||
"""Sets the directory to log sync execution output in.
|
||||
|
||||
Args:
|
||||
logdir: Log directory.
|
||||
"""
|
||||
self.logfile = tempfile.NamedTemporaryFile(
|
||||
prefix="log_sync_out", dir=logdir, suffix=".log", delete=False
|
||||
)
|
||||
self._closed = False
|
||||
|
||||
def _get_logfile(self):
|
||||
if self._closed:
|
||||
raise RuntimeError(
|
||||
"[internalerror] The client has been closed. "
|
||||
"Please report this stacktrace + your cluster configuration "
|
||||
"on Github!"
|
||||
)
|
||||
else:
|
||||
return self.logfile
|
||||
|
||||
def _start_process(self, cmd: str) -> subprocess.Popen:
|
||||
return subprocess.Popen(
|
||||
cmd, shell=True, stderr=subprocess.PIPE, stdout=self._get_logfile()
|
||||
)
|
||||
|
||||
def sync_up(self, source, target, exclude: Optional[List] = None):
|
||||
return self._execute(self.sync_up_template, source, target, exclude)
|
||||
|
||||
def sync_down(self, source, target, exclude: Optional[List] = None):
|
||||
# Just in case some command line sync client expects that local
|
||||
# directory exists.
|
||||
pathlib.Path(target).mkdir(parents=True, exist_ok=True)
|
||||
return self._execute(self.sync_down_template, source, target, exclude)
|
||||
|
||||
def delete(self, target):
|
||||
if self.is_running:
|
||||
logger.warning(
|
||||
f"Last sync client cmd still in progress, "
|
||||
f"skipping deletion of {target}"
|
||||
)
|
||||
return False
|
||||
final_cmd = self.delete_template.format(target=quote(target), options="")
|
||||
logger.debug("Running delete: {}".format(final_cmd))
|
||||
self._last_cmd = final_cmd
|
||||
self.cmd_process = self._start_process(final_cmd)
|
||||
return True
|
||||
|
||||
def wait(self):
|
||||
if self.cmd_process:
|
||||
_, error_msg = self.cmd_process.communicate()
|
||||
error_msg = error_msg.decode("ascii")
|
||||
code = self.cmd_process.returncode
|
||||
args = self.cmd_process.args
|
||||
self.cmd_process = None
|
||||
if code != 0:
|
||||
raise TuneError(
|
||||
"Sync error. Ran command: {}\n"
|
||||
"Error message ({}): {}".format(args, code, error_msg)
|
||||
)
|
||||
|
||||
def wait_or_retry(self, max_retries: int = 3, backoff_s: int = 5):
|
||||
assert max_retries > 0
|
||||
for _ in range(max_retries - 1):
|
||||
try:
|
||||
self.wait()
|
||||
except TuneError as e:
|
||||
logger.error(
|
||||
f"Caught sync error: {e}. "
|
||||
f"Retrying after sleeping for {backoff_s} seconds..."
|
||||
)
|
||||
time.sleep(backoff_s)
|
||||
self.cmd_process = self._start_process(self._last_cmd)
|
||||
continue
|
||||
return
|
||||
self.cmd_process = None
|
||||
raise TuneError(f"Failed sync even after {max_retries} retries.")
|
||||
|
||||
def reset(self):
|
||||
if self.is_running:
|
||||
logger.warning("Sync process still running but resetting anyways.")
|
||||
self.cmd_process = None
|
||||
self._last_cmd = None
|
||||
|
||||
def close(self):
|
||||
if self.logfile:
|
||||
logger.debug(f"Closing the logfile: {str(self.logfile)}")
|
||||
self.logfile.close()
|
||||
self.logfile = None
|
||||
self._closed = True
|
||||
|
||||
@property
|
||||
def is_running(self):
|
||||
"""Returns whether a sync or delete process is running."""
|
||||
if self.cmd_process:
|
||||
self.cmd_process.poll()
|
||||
return self.cmd_process.returncode is None
|
||||
return False
|
||||
|
||||
def _execute(self, sync_template, source, target, exclude: Optional[List] = None):
|
||||
"""Executes sync_template on source and target."""
|
||||
if self.is_running:
|
||||
logger.warning(
|
||||
f"Last sync client cmd still in progress, "
|
||||
f"skipping sync from {source} to {target}."
|
||||
)
|
||||
return False
|
||||
|
||||
if exclude and self.exclude_template:
|
||||
options = []
|
||||
if "{pattern}" in self.exclude_template:
|
||||
for excl in exclude:
|
||||
options.append(self.exclude_template.format(pattern=excl))
|
||||
elif "{regex_pattern}" in self.exclude_template:
|
||||
# This is obviously not a great way to convert to regex,
|
||||
# but it will do for the moment. Todo: Improve.
|
||||
def _to_regex(pattern: str) -> str:
|
||||
return f"({pattern.replace('*', '.*')})"
|
||||
|
||||
regex_pattern = "|".join(_to_regex(excl) for excl in exclude)
|
||||
options.append(
|
||||
self.exclude_template.format(regex_pattern=regex_pattern)
|
||||
)
|
||||
option_str = " ".join(options)
|
||||
else:
|
||||
option_str = ""
|
||||
|
||||
final_cmd = sync_template.format(
|
||||
source=quote(source), target=quote(target), options=option_str
|
||||
)
|
||||
logger.debug("Running sync: {}".format(final_cmd))
|
||||
self._last_cmd = final_cmd
|
||||
self.cmd_process = self._start_process(final_cmd)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _validate_sync_string(sync_string):
|
||||
if not isinstance(sync_string, str):
|
||||
raise ValueError("{} is not a string.".format(sync_string))
|
||||
if "{source}" not in sync_string:
|
||||
raise ValueError("Sync template missing `{source}`: " f"{sync_string}.")
|
||||
if "{target}" not in sync_string:
|
||||
raise ValueError("Sync template missing `{target}`: " f"{sync_string}.")
|
||||
|
||||
@staticmethod
|
||||
def _validate_exclude_template(exclude_template):
|
||||
if exclude_template:
|
||||
if (
|
||||
"{pattern}" not in exclude_template
|
||||
and "{regex_pattern}" not in exclude_template
|
||||
):
|
||||
raise ValueError(
|
||||
"Neither `{pattern}` nor `{regex_pattern}` found in "
|
||||
f"exclude string `{exclude_template}`"
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise DeprecationWarning(
|
||||
"CommandBasedClient has been deprecated. Please implement a "
|
||||
"`ray.tune.syncer.Syncer` instead."
|
||||
)
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
@Deprecated
|
||||
class RemoteTaskClient(SyncClient):
|
||||
"""Sync client that uses remote tasks to synchronize two directories.
|
||||
|
||||
This client expects tuples of (ip, path) for remote sources/targets
|
||||
in sync_down/sync_up.
|
||||
|
||||
To avoid unnecessary syncing, the sync client will collect the existing
|
||||
files with their respective mtimes and sizes on the (possibly remote)
|
||||
target directory. Only files that are not in the target directory or
|
||||
differ to those in the target directory by size or mtime will be
|
||||
transferred. This is similar to most cloud
|
||||
synchronization implementations (e.g. aws s3 sync).
|
||||
|
||||
If a sync is already in-flight when calling ``sync_down`` or
|
||||
``sync_up``, a warning will be printed and the new sync command is
|
||||
ignored. To force a new sync, either use ``wait()``
|
||||
(or ``wait_or_retry()``) to wait until the previous sync has finished,
|
||||
or call ``reset()`` to detach from the previous sync. Note that this
|
||||
will not kill the previous sync command, so it may still be executed.
|
||||
"""
|
||||
|
||||
def __init__(self, _store_remotes: bool = False):
|
||||
# Used for testing
|
||||
self._store_remotes = _store_remotes
|
||||
self._stored_pack_actor_ref = None
|
||||
self._stored_files_stats_future = None
|
||||
|
||||
self._sync_future = None
|
||||
|
||||
self._last_source_tuple = None
|
||||
self._last_target_tuple = None
|
||||
|
||||
self._max_size_bytes = None # No file size limit
|
||||
|
||||
def _sync_still_running(self) -> bool:
|
||||
if not self._sync_future:
|
||||
return False
|
||||
|
||||
ready, not_ready = ray.wait([self._sync_future], timeout=0.0)
|
||||
if self._sync_future in ready:
|
||||
self.wait()
|
||||
return False
|
||||
return True
|
||||
|
||||
def sync_down(
|
||||
self, source: Tuple[str, str], target: str, exclude: Optional[List] = None
|
||||
) -> bool:
|
||||
if self._sync_still_running():
|
||||
logger.warning(
|
||||
f"Last remote task sync still in progress, "
|
||||
f"skipping sync from {source} to {target}."
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise DeprecationWarning(
|
||||
"RemoteTaskClient has been deprecated. Please implement a "
|
||||
"`ray.tune.syncer.Syncer` instead."
|
||||
)
|
||||
return False
|
||||
|
||||
source_ip, source_path = source
|
||||
target_ip = ray.util.get_node_ip_address()
|
||||
|
||||
self._last_source_tuple = source_ip, source_path
|
||||
self._last_target_tuple = target_ip, target
|
||||
|
||||
return self._execute_sync(self._last_source_tuple, self._last_target_tuple)
|
||||
|
||||
def sync_up(
|
||||
self, source: str, target: Tuple[str, str], exclude: Optional[List] = None
|
||||
) -> bool:
|
||||
if self._sync_still_running():
|
||||
logger.warning(
|
||||
f"Last remote task sync still in progress, "
|
||||
f"skipping sync from {source} to {target}."
|
||||
)
|
||||
return False
|
||||
|
||||
source_ip = ray.util.get_node_ip_address()
|
||||
target_ip, target_path = target
|
||||
|
||||
self._last_source_tuple = source_ip, source
|
||||
self._last_target_tuple = target_ip, target_path
|
||||
|
||||
return self._execute_sync(self._last_source_tuple, self._last_target_tuple)
|
||||
|
||||
def _sync_function(self, *args, **kwargs):
|
||||
return sync_dir_between_nodes(*args, **kwargs)
|
||||
|
||||
def _execute_sync(
|
||||
self,
|
||||
source_tuple: Tuple[str, str],
|
||||
target_tuple: Tuple[str, str],
|
||||
) -> bool:
|
||||
source_ip, source_path = source_tuple
|
||||
target_ip, target_path = target_tuple
|
||||
|
||||
self._sync_future, pack_actor, files_stats = self._sync_function(
|
||||
source_ip=source_ip,
|
||||
source_path=source_path,
|
||||
target_ip=target_ip,
|
||||
target_path=target_path,
|
||||
return_futures=True,
|
||||
max_size_bytes=self._max_size_bytes,
|
||||
)
|
||||
|
||||
if self._store_remotes:
|
||||
self._stored_pack_actor_ref = pack_actor
|
||||
self._stored_files_stats = files_stats
|
||||
|
||||
return True
|
||||
|
||||
def delete(self, target: str):
|
||||
if not self._last_target_tuple:
|
||||
logger.warning(
|
||||
f"Could not delete path {target} as the target node is not known."
|
||||
)
|
||||
return
|
||||
|
||||
node_ip = self._last_target_tuple[0]
|
||||
|
||||
try:
|
||||
delete_on_node(node_ip=node_ip, path=target)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Could not delete path {target} on remote node {node_ip}: {e}"
|
||||
)
|
||||
|
||||
def wait(self):
|
||||
if self._sync_future:
|
||||
try:
|
||||
ray.get(self._sync_future)
|
||||
except Exception as e:
|
||||
raise TuneError(
|
||||
f"Remote task sync failed from "
|
||||
f"{self._last_source_tuple} to "
|
||||
f"{self._last_target_tuple}: {e}"
|
||||
) from e
|
||||
self._sync_future = None
|
||||
self._stored_pack_actor_ref = None
|
||||
self._stored_files_stats_future = None
|
||||
|
||||
def wait_or_retry(self, max_retries: int = 3, backoff_s: int = 5):
|
||||
assert max_retries > 0
|
||||
|
||||
for _ in range(max_retries - 1):
|
||||
try:
|
||||
self.wait()
|
||||
except TuneError as e:
|
||||
logger.error(
|
||||
f"Caught sync error: {e}. "
|
||||
f"Retrying after sleeping for {backoff_s} seconds..."
|
||||
)
|
||||
time.sleep(backoff_s)
|
||||
|
||||
self._execute_sync(
|
||||
self._last_source_tuple,
|
||||
self._last_target_tuple,
|
||||
)
|
||||
continue
|
||||
return
|
||||
self._sync_future = None
|
||||
self._stored_pack_actor_ref = None
|
||||
self._stored_files_stats_future = None
|
||||
raise TuneError(f"Failed sync even after {max_retries} retries.")
|
||||
|
||||
def reset(self):
|
||||
if self._sync_future:
|
||||
logger.warning("Sync process still running but resetting anyways.")
|
||||
self._sync_future = None
|
||||
self._last_source_tuple = None
|
||||
self._last_target_tuple = None
|
||||
self._stored_pack_actor_ref = None
|
||||
self._stored_files_stats_future = None
|
||||
|
||||
def close(self):
|
||||
self._sync_future = None # Avoid warning
|
||||
self.reset()
|
||||
|
|
|
@ -1,40 +1,31 @@
|
|||
import abc
|
||||
import threading
|
||||
from typing import (
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
TYPE_CHECKING,
|
||||
Type,
|
||||
Union,
|
||||
Optional,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
import distutils
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
from inspect import isclass
|
||||
from shlex import quote
|
||||
|
||||
import ray
|
||||
import yaml
|
||||
from ray.air._internal.remote_storage import get_fs_and_path, fs_hint
|
||||
from ray.air._internal.remote_storage import (
|
||||
fs_hint,
|
||||
upload_to_uri,
|
||||
download_from_uri,
|
||||
delete_at_uri,
|
||||
is_non_local_path_uri,
|
||||
)
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.callback import Callback
|
||||
from ray.tune.result import NODE_IP
|
||||
from ray.util import get_node_ip_address
|
||||
from ray.util.debug import log_once
|
||||
from ray.tune.cluster_info import get_ssh_key, get_ssh_user
|
||||
from ray.tune.sync_client import (
|
||||
CommandBasedClient,
|
||||
get_sync_client,
|
||||
get_cloud_sync_client,
|
||||
NOOP,
|
||||
SyncClient,
|
||||
RemoteTaskClient,
|
||||
)
|
||||
from ray.tune.utils.file_transfer import sync_dir_between_nodes
|
||||
from ray.util.annotations import PublicAPI, DeveloperAPI
|
||||
from ray.util.ml_utils.checkpoint_manager import CheckpointStorage, _TrackedCheckpoint
|
||||
|
||||
|
@ -44,71 +35,22 @@ if TYPE_CHECKING:
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Syncing period for syncing checkpoints between nodes or to cloud.
|
||||
SYNC_PERIOD = 300
|
||||
|
||||
CLOUD_CHECKPOINTING_URL = (
|
||||
"https://docs.ray.io/en/master/tune/user-guide.html#using-cloud-storage"
|
||||
)
|
||||
_log_sync_warned = False
|
||||
_syncers = {}
|
||||
DEFAULT_SYNC_PERIOD = 300
|
||||
|
||||
|
||||
def wait_for_sync():
|
||||
for syncer in _syncers.values():
|
||||
syncer.wait()
|
||||
def _validate_upload_dir(sync_config: "SyncConfig") -> bool:
|
||||
if not sync_config.upload_dir:
|
||||
return True
|
||||
|
||||
if sync_config.upload_dir.startswith("file://"):
|
||||
return True
|
||||
|
||||
def validate_upload_dir(sync_config: "SyncConfig"):
|
||||
if sync_config.upload_dir:
|
||||
exc = None
|
||||
try:
|
||||
fs, _ = get_fs_and_path(sync_config.upload_dir)
|
||||
except ImportError as e:
|
||||
fs = None
|
||||
exc = e
|
||||
if not fs:
|
||||
if not is_non_local_path_uri(sync_config.upload_dir):
|
||||
raise ValueError(
|
||||
f"Could not identify external storage filesystem for "
|
||||
f"upload dir `{sync_config.upload_dir}`. "
|
||||
f"Hint: {fs_hint(sync_config.upload_dir)}"
|
||||
) from exc
|
||||
|
||||
|
||||
def set_sync_periods(sync_config: "SyncConfig"):
|
||||
"""Sets sync period from config."""
|
||||
global SYNC_PERIOD
|
||||
SYNC_PERIOD = int(sync_config.sync_period)
|
||||
|
||||
|
||||
def get_rsync_template_if_available(options: str = ""):
|
||||
"""Template enabling syncs between driver and worker when possible.
|
||||
Requires ray cluster to be started with the autoscaler. Also requires
|
||||
rsync to be installed.
|
||||
|
||||
Args:
|
||||
options: Additional rsync options.
|
||||
|
||||
Returns:
|
||||
Sync template with source and target parameters. None if rsync
|
||||
unavailable.
|
||||
"""
|
||||
if not distutils.spawn.find_executable("rsync"):
|
||||
if log_once("tune:rsync"):
|
||||
logger.error("Log sync requires rsync to be installed.")
|
||||
return None
|
||||
global _log_sync_warned
|
||||
ssh_key = get_ssh_key()
|
||||
if ssh_key is None:
|
||||
if not _log_sync_warned:
|
||||
logger.debug("Log sync requires cluster to be setup with `ray up`.")
|
||||
_log_sync_warned = True
|
||||
return None
|
||||
|
||||
rsh = "ssh -i {ssh_key} -o ConnectTimeout=120s -o StrictHostKeyChecking=no"
|
||||
rsh = rsh.format(ssh_key=quote(ssh_key))
|
||||
options += " --exclude='checkpoint_tmp*'"
|
||||
template = "rsync {options} -savz -e {rsh} {{source}} {{target}}"
|
||||
return template.format(options=options, rsh=quote(rsh))
|
||||
)
|
||||
|
||||
|
||||
@PublicAPI
|
||||
|
@ -124,13 +66,8 @@ class SyncConfig:
|
|||
upload_dir: Optional URI to sync training results and checkpoints
|
||||
to (e.g. ``s3://bucket``, ``gs://bucket`` or ``hdfs://path``).
|
||||
Specifying this will enable cloud-based checkpointing.
|
||||
syncer: Function for syncing the local_dir to and
|
||||
from remote storage. If string, then it must be a string template
|
||||
that includes ``{source}`` and ``{target}`` for the syncer to run.
|
||||
If not provided, it defaults to rsync for non cloud-based storage,
|
||||
and to standard S3, gsutil or HDFS sync commands for cloud-based
|
||||
storage.
|
||||
If set to ``None``, no syncing will take place.
|
||||
syncer: Syncer class to use for synchronizing checkpoints to/from
|
||||
cloud storage. If set to ``None``, no syncing will take place.
|
||||
Defaults to ``"auto"`` (auto detect).
|
||||
sync_on_checkpoint: Force sync-down of trial checkpoint to
|
||||
driver (only if not using cloud storage).
|
||||
|
@ -142,395 +79,411 @@ class SyncConfig:
|
|||
"""
|
||||
|
||||
upload_dir: Optional[str] = None
|
||||
syncer: Optional[str] = "auto"
|
||||
syncer: Optional[Union[str, "Syncer"]] = "auto"
|
||||
|
||||
sync_on_checkpoint: bool = True
|
||||
sync_period: int = 300
|
||||
sync_period: int = DEFAULT_SYNC_PERIOD
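# Usage sketch (illustrative values only): passing the refactored SyncConfig to a
# Tune run. This assumes `SyncConfig` is exposed as `ray.tune.SyncConfig`; the
# bucket URI and trainable below are placeholders.
from ray import tune


def trainable(config):
    tune.report(score=config["x"] ** 2)


sync_config = tune.SyncConfig(
    upload_dir="s3://my-bucket/tune-results",  # enables cloud-based checkpointing
    syncer="auto",  # default syncer; a custom Syncer instance also works
    sync_period=300,  # sync at most every 5 minutes
)

tune.run(trainable, config={"x": 1}, sync_config=sync_config)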
|
||||
|
||||
|
||||
class _BackgroundProcess:
|
||||
def __init__(self, fn: Callable):
|
||||
self._fn = fn
|
||||
self._process = None
|
||||
self._result = {}
|
||||
|
||||
@property
|
||||
def is_running(self):
|
||||
return self._process and self._process.is_alive()
|
||||
|
||||
def start(self, *args, **kwargs):
|
||||
if self.is_running:
|
||||
return False
|
||||
|
||||
self._result = {}
|
||||
|
||||
def entrypoint():
|
||||
try:
|
||||
result = self._fn(*args, **kwargs)
|
||||
except Exception as e:
|
||||
self._result["exception"] = e
|
||||
return
|
||||
|
||||
self._result["result"] = result
|
||||
|
||||
self._process = threading.Thread(target=entrypoint)
|
||||
self._process.start()
|
||||
|
||||
def wait(self):
|
||||
if not self._process:
|
||||
return
|
||||
|
||||
self._process.join()
|
||||
self._process = None
|
||||
|
||||
exception = self._result.get("exception")
|
||||
if exception:
|
||||
raise exception
|
||||
|
||||
result = self._result.get("result")
|
||||
|
||||
self._result = {}
|
||||
return result
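# Usage sketch for the helper above: run a placeholder upload function in a
# background thread and surface its result (or re-raise its exception) via
# wait(). The function, paths, and URI are illustrative only.
def _example_upload(local_path: str, uri: str) -> str:
    # A real implementation would transfer `local_path` to `uri` here.
    return uri


_process = _BackgroundProcess(_example_upload)
_process.start(local_path="/tmp/experiment", uri="s3://bucket/experiment")
# ... other work can happen while the thread runs ...
uploaded_uri = _process.wait()  # re-raises any exception from the thread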
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class Syncer:
|
||||
def __init__(self, local_dir: str, remote_dir: str, sync_client: SyncClient = NOOP):
|
||||
"""Syncs between two directories with the sync_function.
|
||||
class Syncer(abc.ABC):
|
||||
"""Syncer class for synchronizing data between Ray nodes and external storage.
|
||||
|
||||
This class handles data transfer for two cases:
|
||||
|
||||
1. Synchronizing data from the driver to external storage. This affects
|
||||
experiment-level checkpoints and trial-level checkpoints if no cloud storage
|
||||
is used.
|
||||
2. Synchronizing data from remote trainables to external storage.
|
||||
|
||||
Synchronizing tasks are usually asynchronous and can be awaited using ``wait()``.
|
||||
The base class implements a ``wait_or_retry()`` API that will retry a failed
|
||||
sync command.
|
||||
|
||||
The base class also exposes an API to only kick off syncs every ``sync_period``
|
||||
seconds.
|
||||
|
||||
Arguments:
|
||||
local_dir: Directory to sync. Uniquely identifies the syncer.
|
||||
remote_dir: Remote directory to sync with.
|
||||
sync_client: Client for syncing between local_dir and
|
||||
remote_dir. Defaults to a Noop.
|
||||
"""
|
||||
self._local_dir = os.path.join(local_dir, "") if local_dir else local_dir
|
||||
self._remote_dir = remote_dir
|
||||
|
||||
def __init__(self, sync_period: float = 300.0):
|
||||
self.sync_period = sync_period
|
||||
self.last_sync_up_time = float("-inf")
|
||||
self.last_sync_down_time = float("-inf")
|
||||
self.sync_client = sync_client
|
||||
|
||||
@property
|
||||
def _pass_ip_path_tuples(self) -> bool:
|
||||
"""Return True if the sync client expects (ip, path) tuples instead
|
||||
of rsync strings (user@ip:/path/)."""
|
||||
return isinstance(self.sync_client, RemoteTaskClient)
|
||||
@abc.abstractmethod
|
||||
def sync_up(
|
||||
self, local_dir: str, remote_dir: str, exclude: Optional[List] = None
|
||||
) -> bool:
|
||||
"""Synchronize local directory to remote directory.
|
||||
|
||||
def sync_up_if_needed(self, sync_period: int, exclude: Optional[List] = None):
|
||||
This function can spawn an asynchronous process that can be awaited in
|
||||
``wait()``.
|
||||
|
||||
Args:
|
||||
local_dir: Local directory to sync from.
|
||||
remote_dir: Remote directory to sync up to. This is a URI
|
||||
(``protocol://remote/path``).
|
||||
exclude: Pattern of files to exclude, e.g.
|
||||
``["*/checkpoint_*]`` to exclude trial checkpoints.
|
||||
|
||||
Returns:
|
||||
True if sync process has been spawned, False otherwise.
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def sync_down(
|
||||
self, remote_dir: str, local_dir: str, exclude: Optional[List] = None
|
||||
) -> bool:
|
||||
"""Synchronize remote directory to local directory.
|
||||
|
||||
This function can spawn an asynchronous process that can be awaited in
|
||||
``wait()``.
|
||||
|
||||
Args:
|
||||
remote_dir: Remote directory to sync down from. This is a URI
|
||||
(``protocol://remote/path``).
|
||||
local_dir: Local directory to sync to.
|
||||
exclude: Pattern of files to exclude, e.g.
|
||||
``["*/checkpoint_*]`` to exclude trial checkpoints.
|
||||
|
||||
Returns:
|
||||
True if sync process has been spawned, False otherwise.
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def delete(self, remote_dir: str) -> bool:
|
||||
"""Delete directory on remote storage.
|
||||
|
||||
This function can spawn an asynchronous process that can be awaited in
|
||||
``wait()``.
|
||||
|
||||
Args:
|
||||
remote_dir: Remote directory to delete. This is a URI
|
||||
(``protocol://remote/path``).
|
||||
|
||||
Returns:
|
||||
True if sync process has been spawned, False otherwise.
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def retry(self):
|
||||
"""Retry the last sync up, sync down, or delete command.
|
||||
|
||||
You should implement this method if you spawn asynchronous syncing
|
||||
processes.
|
||||
"""
|
||||
pass
|
||||
|
||||
def wait(self):
|
||||
"""Wait for asynchronous sync command to finish.
|
||||
|
||||
You should implement this method if you spawn asynchronous syncing
|
||||
processes.
|
||||
"""
|
||||
pass
|
||||
|
||||
def sync_up_if_needed(
|
||||
self, local_dir: str, remote_dir: str, exclude: Optional[List] = None
|
||||
) -> bool:
|
||||
"""Syncs up if time since last sync up is greater than sync_period.
|
||||
|
||||
Args:
|
||||
sync_period: Time period between subsequent syncs.
|
||||
local_dir: Local directory to sync from.
|
||||
remote_dir: Remote directory to sync up to. This is a URI
|
||||
(``protocol://remote/path``).
|
||||
exclude: Pattern of files to exclude, e.g.
|
||||
``["*/checkpoint_*]`` to exclude trial checkpoints.
|
||||
"""
|
||||
now = time.time()
|
||||
if now - self.last_sync_up_time >= self.sync_period:
|
||||
result = self.sync_up(
|
||||
local_dir=local_dir, remote_dir=remote_dir, exclude=exclude
|
||||
)
|
||||
self.last_sync_up_time = now
|
||||
return result
|
||||
|
||||
if time.time() - self.last_sync_up_time > sync_period:
|
||||
self.sync_up(exclude)
|
||||
|
||||
def sync_down_if_needed(self, sync_period: int, exclude: Optional[List] = None):
|
||||
def sync_down_if_needed(
|
||||
self, remote_dir: str, local_dir: str, exclude: Optional[List] = None
|
||||
):
|
||||
"""Syncs down if time since last sync down is greater than sync_period.
|
||||
|
||||
Args:
|
||||
sync_period: Time period between subsequent syncs.
|
||||
remote_dir: Remote directory to sync down from. This is a URI
|
||||
(``protocol://remote/path``).
|
||||
local_dir: Local directory to sync to.
|
||||
exclude: Pattern of files to exclude, e.g.
|
||||
``["*/checkpoint_*]`` to exclude trial checkpoints.
|
||||
"""
|
||||
if time.time() - self.last_sync_down_time > sync_period:
|
||||
self.sync_down(exclude)
|
||||
|
||||
def sync_up(self, exclude: Optional[List] = None):
|
||||
"""Attempts to start the sync-up to the remote path.
|
||||
|
||||
Args:
|
||||
exclude: Pattern of files to exclude, e.g.
|
||||
``["*/checkpoint_*]`` to exclude trial checkpoints.
|
||||
|
||||
Returns:
|
||||
Whether the sync (if feasible) was successfully started.
|
||||
"""
|
||||
result = False
|
||||
if self.validate_hosts(self._local_dir, self._remote_path):
|
||||
try:
|
||||
result = self.sync_client.sync_up(
|
||||
self._local_dir, self._remote_path, exclude=exclude
|
||||
now = time.time()
|
||||
if now - self.last_sync_down_time >= self.sync_period:
|
||||
result = self.sync_down(
|
||||
remote_dir=remote_dir, local_dir=local_dir, exclude=exclude
|
||||
)
|
||||
self.last_sync_up_time = time.time()
|
||||
except Exception:
|
||||
logger.exception("Sync execution failed.")
|
||||
self.last_sync_down_time = now
|
||||
return result
|
||||
|
||||
def sync_down(self, exclude: Optional[List] = None):
|
||||
"""Attempts to start the sync-down from the remote path.
|
||||
|
||||
Args:
|
||||
exclude: Pattern of files to exclude, e.g.
|
||||
``["*/checkpoint_*]`` to exclude trial checkpoints.
|
||||
|
||||
Returns:
|
||||
Whether the sync (if feasible) was successfully started.
|
||||
"""
|
||||
result = False
|
||||
if self.validate_hosts(self._local_dir, self._remote_path):
|
||||
def wait_or_retry(self, max_retries: int = 3, backoff_s: int = 5):
|
||||
assert max_retries > 0
|
||||
last_error = None
|
||||
for _ in range(max_retries - 1):
|
||||
try:
|
||||
result = self.sync_client.sync_down(
|
||||
self._remote_path, self._local_dir, exclude=exclude
|
||||
self.wait()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Caught sync error: {e}. "
|
||||
f"Retrying after sleeping for {backoff_s} seconds..."
|
||||
)
|
||||
self.last_sync_down_time = time.time()
|
||||
except Exception:
|
||||
logger.exception("Sync execution failed.")
|
||||
return result
|
||||
|
||||
def validate_hosts(self, source, target):
|
||||
if not (source and target):
|
||||
logger.debug(
|
||||
"Source or target is empty, skipping log sync for "
|
||||
"{}".format(self._local_dir)
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
def wait(self):
|
||||
"""Waits for the sync client to complete the current sync."""
|
||||
self.sync_client.wait()
|
||||
last_error = e
|
||||
time.sleep(backoff_s)
|
||||
self.retry()
|
||||
continue
|
||||
return
|
||||
raise TuneError(
|
||||
f"Failed sync even after {max_retries} retries."
|
||||
) from last_error
|
||||
|
||||
def reset(self):
|
||||
self.last_sync_up_time = float("-inf")
|
||||
self.last_sync_down_time = float("-inf")
|
||||
self.sync_client.reset()
|
||||
|
||||
def close(self):
|
||||
self.sync_client.close()
|
||||
|
||||
@property
|
||||
def _remote_path(self) -> Optional[Union[str, Tuple[str, str]]]:
|
||||
return self._remote_dir
|
||||
pass
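# Implementation sketch of the Syncer API defined above: a hypothetical syncer
# that treats `remote_dir` as a plain filesystem path (e.g. a mounted NFS share)
# and copies data with shutil. A real syncer would call into its storage backend
# instead; the `exclude` patterns are ignored here for brevity, and
# `dirs_exist_ok` requires Python 3.8+.
import shutil
from typing import List, Optional


class LocalCopySyncer(Syncer):
    def sync_up(
        self, local_dir: str, remote_dir: str, exclude: Optional[List] = None
    ) -> bool:
        shutil.copytree(local_dir, remote_dir, dirs_exist_ok=True)
        return True

    def sync_down(
        self, remote_dir: str, local_dir: str, exclude: Optional[List] = None
    ) -> bool:
        shutil.copytree(remote_dir, local_dir, dirs_exist_ok=True)
        return True

    def delete(self, remote_dir: str) -> bool:
        shutil.rmtree(remote_dir, ignore_errors=True)
        return True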
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class CloudSyncer(Syncer):
|
||||
"""Syncer for syncing files to/from the cloud."""
|
||||
class _DefaultSyncer(Syncer):
|
||||
"""Default syncer between local storage and remote URI."""
|
||||
|
||||
def __init__(self, local_dir, remote_dir, sync_client):
|
||||
super(CloudSyncer, self).__init__(local_dir, remote_dir, sync_client)
|
||||
def __init__(self, sync_period: float = 300.0):
|
||||
super(_DefaultSyncer, self).__init__(sync_period=sync_period)
|
||||
self._sync_process = None
|
||||
self._current_cmd = None
|
||||
|
||||
def sync_up_if_needed(self, exclude: Optional[List] = None):
|
||||
return super(CloudSyncer, self).sync_up_if_needed(SYNC_PERIOD, exclude=exclude)
|
||||
|
||||
def sync_down_if_needed(self, exclude: Optional[List] = None):
|
||||
return super(CloudSyncer, self).sync_down_if_needed(
|
||||
SYNC_PERIOD, exclude=exclude
|
||||
)
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class NodeSyncer(Syncer):
|
||||
"""Syncer for syncing files to/from a remote dir to a local dir."""
|
||||
|
||||
def __init__(self, local_dir, remote_dir, sync_client):
|
||||
self.local_ip = get_node_ip_address()
|
||||
self.worker_ip = None
|
||||
super(NodeSyncer, self).__init__(local_dir, remote_dir, sync_client)
|
||||
|
||||
def set_worker_ip(self, worker_ip):
|
||||
"""Sets the worker IP to sync logs from."""
|
||||
self.worker_ip = worker_ip
|
||||
|
||||
def has_remote_target(self):
|
||||
"""Returns whether the Syncer has a remote target."""
|
||||
if not self.worker_ip:
|
||||
logger.debug("Worker IP unknown, skipping sync for %s", self._local_dir)
|
||||
return False
|
||||
if self.worker_ip == self.local_ip:
|
||||
logger.debug("Worker IP is local IP, skipping sync for %s", self._local_dir)
|
||||
return False
|
||||
return True
|
||||
|
||||
def sync_up_if_needed(self, exclude: Optional[List] = None):
|
||||
if not self.has_remote_target():
|
||||
return True
|
||||
return super(NodeSyncer, self).sync_up_if_needed(SYNC_PERIOD, exclude=exclude)
|
||||
|
||||
def sync_down_if_needed(self, exclude: Optional[List] = None):
|
||||
if not self.has_remote_target():
|
||||
return True
|
||||
return super(NodeSyncer, self).sync_down_if_needed(SYNC_PERIOD, exclude=exclude)
|
||||
|
||||
def sync_up_to_new_location(self, worker_ip):
|
||||
if worker_ip != self.worker_ip:
|
||||
logger.debug("Setting new worker IP to %s", worker_ip)
|
||||
self.set_worker_ip(worker_ip)
|
||||
self.reset()
|
||||
if not self.sync_up():
|
||||
def sync_up(
|
||||
self, local_dir: str, remote_dir: str, exclude: Optional[List] = None
|
||||
) -> bool:
|
||||
if self._sync_process and self._sync_process.is_running:
|
||||
logger.warning(
|
||||
"Sync up to new location skipped. This should not occur."
|
||||
f"Last sync still in progress, "
|
||||
f"skipping sync up of {local_dir} to {remote_dir}"
|
||||
)
|
||||
else:
|
||||
logger.warning("Sync attempted to same IP %s.", worker_ip)
|
||||
return False
|
||||
elif self._sync_process:
|
||||
try:
|
||||
self._sync_process.wait()
|
||||
except Exception as e:
|
||||
logger.warning(f"Last sync command failed: {e}")
|
||||
|
||||
self._current_cmd = (
|
||||
upload_to_uri,
|
||||
dict(local_path=local_dir, uri=remote_dir, exclude=exclude),
|
||||
)
|
||||
self.retry()
|
||||
|
||||
def sync_up(self, exclude: Optional[List] = None):
|
||||
if not self.has_remote_target():
|
||||
return True
|
||||
return super(NodeSyncer, self).sync_up(exclude=exclude)
|
||||
|
||||
def sync_down(self, exclude: Optional[List] = None):
|
||||
if not self.has_remote_target():
|
||||
def sync_down(
|
||||
self, remote_dir: str, local_dir: str, exclude: Optional[List] = None
|
||||
) -> bool:
|
||||
if self._sync_process and self._sync_process.is_running:
|
||||
logger.warning(
|
||||
f"Last sync still in progress, "
|
||||
f"skipping sync down of {remote_dir} to {local_dir}"
|
||||
)
|
||||
return False
|
||||
elif self._sync_process:
|
||||
try:
|
||||
self._sync_process.wait()
|
||||
except Exception as e:
|
||||
logger.warning(f"Last sync command failed: {e}")
|
||||
|
||||
self._current_cmd = (
|
||||
download_from_uri,
|
||||
dict(uri=remote_dir, local_path=local_dir),
|
||||
)
|
||||
self.retry()
|
||||
|
||||
return True
|
||||
logger.debug("Syncing from %s to %s", self._remote_path, self._local_dir)
|
||||
return super(NodeSyncer, self).sync_down(exclude=exclude)
|
||||
|
||||
@property
|
||||
def _remote_path(self) -> Optional[Union[str, Tuple[str, str]]]:
|
||||
ssh_user = get_ssh_user()
|
||||
global _log_sync_warned
|
||||
if not self.has_remote_target():
|
||||
return None
|
||||
if ssh_user is None:
|
||||
if not _log_sync_warned:
|
||||
logger.error("Syncer requires cluster to be setup with `ray up`.")
|
||||
_log_sync_warned = True
|
||||
return None
|
||||
if self._pass_ip_path_tuples:
|
||||
return self.worker_ip, self._remote_dir
|
||||
return "{}@{}:{}/".format(ssh_user, self.worker_ip, self._remote_dir)
|
||||
def delete(self, remote_dir: str) -> bool:
|
||||
if self._sync_process and self._sync_process.is_running:
|
||||
logger.warning(
|
||||
f"Last sync still in progress, skipping deletion of {remote_dir}"
|
||||
)
|
||||
return False
|
||||
|
||||
self._current_cmd = (delete_at_uri, dict(uri=remote_dir))
|
||||
self.retry()
|
||||
|
||||
return True
|
||||
|
||||
def wait(self):
|
||||
if self._sync_process:
|
||||
try:
|
||||
self._sync_process.wait()
|
||||
except Exception as e:
|
||||
raise TuneError(f"Sync process failed: {e}") from e
|
||||
finally:
|
||||
self._sync_process = None
|
||||
|
||||
def retry(self):
|
||||
if not self._current_cmd:
|
||||
raise TuneError("No sync command set, cannot retry.")
|
||||
cmd, kwargs = self._current_cmd
|
||||
self._sync_process = _BackgroundProcess(cmd)
|
||||
self._sync_process.start(**kwargs)
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
def get_cloud_syncer(
|
||||
local_dir: str,
|
||||
remote_dir: Optional[str] = None,
|
||||
sync_function: Optional[Union[Callable, str]] = None,
|
||||
) -> CloudSyncer:
|
||||
"""Returns a Syncer.
|
||||
def get_node_to_storage_syncer(sync_config: SyncConfig) -> Optional[Syncer]:
|
||||
""""""
|
||||
if sync_config.syncer is None:
|
||||
return None
|
||||
|
||||
This syncer is in charge of syncing the local_dir with upload_dir.
|
||||
if not sync_config.upload_dir:
|
||||
return None
|
||||
|
||||
If no ``remote_dir`` is provided, it will return a no-op syncer.
|
||||
if sync_config.syncer == "auto":
|
||||
return _DefaultSyncer(sync_period=sync_config.sync_period)
|
||||
|
||||
If a ``sync_function`` is provided, it will return a CloudSyncer using
|
||||
a custom SyncClient initialized by the sync function. Otherwise it will
|
||||
return a CloudSyncer with default templates for s3/gs/hdfs.
|
||||
if isinstance(sync_config.syncer, Syncer):
|
||||
return sync_config.syncer
|
||||
|
||||
Args:
|
||||
local_dir: Source directory for syncing.
|
||||
remote_dir: Target directory for syncing. If not provided, a
|
||||
no-op Syncer is returned.
|
||||
sync_function: Function for syncing the local_dir to
|
||||
remote_dir. If string, then it must be a string template for
|
||||
syncer to run. If not provided, it defaults
|
||||
to standard S3, gsutil or HDFS sync commands.
|
||||
|
||||
Raises:
|
||||
ValueError if malformed remote_dir.
|
||||
"""
|
||||
key = (local_dir, remote_dir)
|
||||
|
||||
if key in _syncers:
|
||||
return _syncers[key]
|
||||
|
||||
if not remote_dir:
|
||||
_syncers[key] = CloudSyncer(local_dir, remote_dir, NOOP)
|
||||
return _syncers[key]
|
||||
|
||||
if sync_function == "auto":
|
||||
sync_function = None # Auto-detect
|
||||
|
||||
# Maybe get user-provided sync client here
|
||||
client = get_sync_client(sync_function)
|
||||
|
||||
if client:
|
||||
# If the user provided a sync template or function
|
||||
_syncers[key] = CloudSyncer(local_dir, remote_dir, client)
|
||||
else:
|
||||
# Else, get default cloud sync client (e.g. S3 syncer)
|
||||
sync_client = get_cloud_sync_client(remote_dir)
|
||||
_syncers[key] = CloudSyncer(local_dir, remote_dir, sync_client)
|
||||
|
||||
return _syncers[key]
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
def get_node_syncer(
|
||||
local_dir: str,
|
||||
remote_dir: Optional[str] = None,
|
||||
sync_function: Optional[Union[Callable, str, bool, Type[Syncer]]] = None,
|
||||
):
|
||||
"""Returns a NodeSyncer.
|
||||
|
||||
Args:
|
||||
local_dir: Source directory for syncing.
|
||||
remote_dir: Target directory for syncing. If not provided, a
|
||||
noop Syncer is returned.
|
||||
sync_function: Function for syncing the local_dir to
|
||||
remote_dir. If string, then it must be a string template for
|
||||
syncer to run. If True or not provided, it defaults to rsync
|
||||
(if available) or otherwise remote-task based syncing. If
|
||||
False, a noop Syncer is returned.
|
||||
"""
|
||||
if sync_function == "auto":
|
||||
sync_function = None # Auto-detect
|
||||
|
||||
key = (local_dir, remote_dir)
|
||||
if key in _syncers:
|
||||
# Get cached syncer
|
||||
return _syncers[key]
|
||||
elif isclass(sync_function) and issubclass(sync_function, Syncer):
|
||||
# Type[Syncer]
|
||||
_syncers[key] = sync_function(local_dir, remote_dir, None)
|
||||
return _syncers[key]
|
||||
elif not remote_dir or sync_function is False:
|
||||
# Do not sync trials if no remote dir specified or syncer=False
|
||||
sync_client = NOOP
|
||||
elif sync_function and sync_function is not True:
|
||||
# String or callable (for function syncers)
|
||||
sync_client = get_sync_client(sync_function)
|
||||
else:
|
||||
# sync_function=True or sync_function=None --> default
|
||||
rsync_function_str = get_rsync_template_if_available()
|
||||
if rsync_function_str:
|
||||
sync_client = CommandBasedClient(rsync_function_str, rsync_function_str)
|
||||
sync_client.set_logdir(local_dir)
|
||||
else:
|
||||
sync_client = RemoteTaskClient()
|
||||
|
||||
_syncers[key] = NodeSyncer(local_dir, remote_dir, sync_client)
|
||||
return _syncers[key]
|
||||
raise ValueError(
|
||||
f"Unknown syncer type passed in SyncConfig: {type(sync_config.syncer)}. "
|
||||
f"Note that custom sync functions and templates have been deprecated. "
|
||||
f"Instead you can implement you own `Syncer` class. "
|
||||
f"Please leave a comment on GitHub if you run into any issues with this: "
|
||||
f"https://github.com/ray-project/ray/issues"
|
||||
)
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class SyncerCallback(Callback):
|
||||
def __init__(self, sync_function: Optional[Union[bool, Callable]]):
|
||||
self._sync_function = sync_function
|
||||
self._syncers: Dict["Trial", NodeSyncer] = {}
|
||||
"""Callback to synchronize trial directories on a worker node with the driver."""
|
||||
|
||||
def _get_trial_syncer(self, trial: "Trial"):
|
||||
if trial not in self._syncers:
|
||||
self._syncers[trial] = self._create_trial_syncer(trial)
|
||||
return self._syncers[trial]
|
||||
def __init__(self, enabled: bool = True, sync_period: float = DEFAULT_SYNC_PERIOD):
|
||||
self._enabled = enabled
|
||||
self._sync_processes: Dict[str, _BackgroundProcess] = {}
|
||||
self._sync_times: Dict[str, float] = {}
|
||||
self._sync_period = sync_period
|
||||
|
||||
def _create_trial_syncer(self, trial: "Trial"):
|
||||
return get_node_syncer(
|
||||
trial.logdir, remote_dir=trial.logdir, sync_function=self._sync_function
|
||||
def _get_trial_sync_process(self, trial: "Trial"):
|
||||
return self._sync_processes.setdefault(
|
||||
trial.trial_id, _BackgroundProcess(sync_dir_between_nodes)
|
||||
)
|
||||
|
||||
def _remove_trial_syncer(self, trial: "Trial"):
|
||||
self._syncers.pop(trial, None)
|
||||
def _remove_trial_sync_process(self, trial: "Trial"):
|
||||
self._sync_processes.pop(trial.trial_id, None)
|
||||
|
||||
def _sync_trial_checkpoint(self, trial: "Trial", checkpoint: _TrackedCheckpoint):
|
||||
if checkpoint.storage_mode == CheckpointStorage.MEMORY:
|
||||
return
|
||||
def _should_sync(self, trial: "Trial"):
|
||||
last_sync_time = self._sync_times.setdefault(trial.trial_id, float("-inf"))
|
||||
return time.time() - last_sync_time >= self._sync_period
|
||||
|
||||
def _mark_as_synced(self, trial: "Trial"):
|
||||
self._sync_times[trial.trial_id] = time.time()
|
||||
|
||||
def _local_trial_logdir(self, trial: "Trial"):
|
||||
return trial.logdir
|
||||
|
||||
def _remote_trial_logdir(self, trial: "Trial"):
|
||||
return trial.logdir
|
||||
|
||||
def _sync_trial_dir(
|
||||
self, trial: "Trial", force: bool = False, wait: bool = True
|
||||
) -> bool:
|
||||
if not self._enabled or trial.uses_cloud_checkpointing:
|
||||
return False
|
||||
|
||||
sync_process = self._get_trial_sync_process(trial)
|
||||
|
||||
# Always run if force=True
|
||||
# Otherwise, only run if we should sync (considering sync period)
|
||||
# and there is no sync currently running.
|
||||
if not force and (not self._should_sync(trial) or sync_process.is_running):
|
||||
return False
|
||||
|
||||
if NODE_IP in trial.last_result:
|
||||
source_ip = trial.last_result[NODE_IP]
|
||||
else:
|
||||
source_ip = ray.get(trial.runner.get_current_ip.remote())
|
||||
|
||||
trial_syncer = self._get_trial_syncer(trial)
|
||||
# If the sync_function is False, syncing to driver is disabled.
|
||||
# In every other case (valid values include None, True, Callable,
|
||||
# NodeSyncer) syncing to driver is enabled.
|
||||
if trial.sync_on_checkpoint and self._sync_function is not False:
|
||||
try:
|
||||
# Wait for any other syncs to finish. We need to sync again
|
||||
# after this to handle checkpoints taken mid-sync.
|
||||
trial_syncer.wait()
|
||||
sync_process.wait()
|
||||
except TuneError as e:
|
||||
# Errors occurring during this wait are not fatal for this
|
||||
# checkpoint, so it should just be logged.
|
||||
logger.error(
|
||||
f"Trial {trial}: An error occurred during the "
|
||||
f"checkpoint pre-sync wait: {e}"
|
||||
f"checkpoint syncing of the previous checkpoint: {e}"
|
||||
)
|
||||
# Force sync down and wait before tracking the new checkpoint.
|
||||
sync_process.start(
|
||||
source_ip=source_ip,
|
||||
source_path=self._remote_trial_logdir(trial),
|
||||
target_ip=ray.util.get_node_ip_address(),
|
||||
target_path=self._local_trial_logdir(trial),
|
||||
)
|
||||
self._sync_times[trial.trial_id] = time.time()
|
||||
if wait:
|
||||
try:
|
||||
if trial_syncer.sync_down():
|
||||
trial_syncer.wait()
|
||||
else:
|
||||
logger.error(
|
||||
f"Trial {trial}: Checkpoint sync skipped. "
|
||||
f"This should not happen."
|
||||
)
|
||||
sync_process.wait()
|
||||
except TuneError as e:
|
||||
if trial.uses_cloud_checkpointing:
|
||||
# Even though rsync failed the trainable can restore
|
||||
# from remote durable storage.
|
||||
logger.error(f"Trial {trial}: Sync error: {e}")
|
||||
else:
|
||||
# If the trainable didn't have remote storage to upload
|
||||
# to then this checkpoint may have been lost, so we
|
||||
# shouldn't track it with the checkpoint_manager.
|
||||
raise e
|
||||
if not trial.uses_cloud_checkpointing:
|
||||
if not os.path.exists(checkpoint.dir_or_data):
|
||||
raise TuneError(
|
||||
"Trial {}: Checkpoint path {} not "
|
||||
"found after successful sync down. "
|
||||
"Are you running on a Kubernetes or "
|
||||
"managed cluster? rsync will not function "
|
||||
"due to a lack of SSH functionality. "
|
||||
"You'll need to use cloud-checkpointing "
|
||||
"if that's the case, see instructions "
|
||||
"here: {} .".format(
|
||||
trial,
|
||||
checkpoint.dir_or_data,
|
||||
CLOUD_CHECKPOINTING_URL,
|
||||
# Errors occurring during this wait are not fatal for this
|
||||
# checkpoint, so it should just be logged.
|
||||
logger.error(
|
||||
f"Trial {trial}: An error occurred during the "
|
||||
f"checkpoint syncing of the current checkpoint: {e}"
|
||||
)
|
||||
)
|
||||
|
||||
def on_trial_start(
|
||||
self, iteration: int, trials: List["Trial"], trial: "Trial", **info
|
||||
):
|
||||
self._get_trial_syncer(trial)
|
||||
return True
|
||||
|
||||
def on_trial_result(
|
||||
self,
|
||||
|
@ -540,23 +493,13 @@ class SyncerCallback(Callback):
|
|||
result: Dict,
|
||||
**info,
|
||||
):
|
||||
trial_syncer = self._get_trial_syncer(trial)
|
||||
trial_syncer.set_worker_ip(result.get(NODE_IP))
|
||||
trial_syncer.sync_down_if_needed()
|
||||
self._sync_trial_dir(trial, force=False, wait=False)
|
||||
|
||||
def on_trial_complete(
|
||||
self, iteration: int, trials: List["Trial"], trial: "Trial", **info
|
||||
):
|
||||
trial_syncer = self._get_trial_syncer(trial)
|
||||
if NODE_IP in trial.last_result:
|
||||
trainable_ip = trial.last_result[NODE_IP]
|
||||
else:
|
||||
trainable_ip = ray.get(trial.runner.get_current_ip.remote())
|
||||
trial_syncer.set_worker_ip(trainable_ip)
|
||||
# Always sync down when trial completed
|
||||
trial_syncer.sync_down()
|
||||
trial_syncer.close()
|
||||
self._remove_trial_syncer(trial)
|
||||
self._sync_trial_dir(trial, force=True, wait=True)
|
||||
self._remove_trial_sync_process(trial)
|
||||
|
||||
def on_checkpoint(
|
||||
self,
|
||||
|
@ -566,80 +509,30 @@ class SyncerCallback(Callback):
|
|||
checkpoint: _TrackedCheckpoint,
|
||||
**info,
|
||||
):
|
||||
self._sync_trial_checkpoint(trial, checkpoint)
|
||||
if checkpoint.storage_mode == CheckpointStorage.MEMORY:
|
||||
return
|
||||
|
||||
|
||||
def detect_cluster_syncer(
|
||||
sync_config: Optional[SyncConfig],
|
||||
cluster_config_file: str = "~/ray_bootstrap_config.yaml",
|
||||
) -> Union[bool, Type, NodeSyncer]:
|
||||
"""Detect cluster Syncer given SyncConfig.
|
||||
|
||||
Returns False if cloud checkpointing is enabled (when upload dir is
|
||||
set).
|
||||
|
||||
Else, returns sync config syncer if manually specified.
|
||||
|
||||
Else, detects cluster environment (e.g. Docker, Kubernetes) and returns
|
||||
syncer accordingly.
|
||||
|
||||
"""
|
||||
from ray.tune.integration.docker import DockerSyncer
|
||||
|
||||
sync_config = sync_config or SyncConfig()
|
||||
|
||||
if bool(sync_config.upload_dir) or sync_config.syncer is None:
|
||||
# No sync to driver for cloud checkpointing or if manually disabled
|
||||
return False
|
||||
|
||||
_syncer = sync_config.syncer
|
||||
|
||||
if _syncer == "auto":
|
||||
_syncer = None
|
||||
|
||||
if isinstance(_syncer, Type):
|
||||
return _syncer
|
||||
|
||||
# Else: True or None. Auto-detect.
|
||||
cluster_config_file = os.path.expanduser(cluster_config_file)
|
||||
if not os.path.exists(cluster_config_file):
|
||||
return _syncer
|
||||
|
||||
with open(cluster_config_file, "rt") as fp:
|
||||
config = yaml.safe_load(fp.read())
|
||||
|
||||
if config.get("docker"):
|
||||
logger.debug(
|
||||
"Detected docker autoscaling environment. Using `DockerSyncer` "
|
||||
"as sync client. If this is not correct or leads to errors, "
|
||||
"please pass a `syncer` parameter in the `SyncConfig` to "
|
||||
"`tune.run().` to manually configure syncing behavior."
|
||||
)
|
||||
return DockerSyncer
|
||||
|
||||
if config.get("provider", {}).get("type", "") == "kubernetes":
|
||||
from ray.tune.integration.kubernetes import (
|
||||
NamespacedKubernetesSyncer,
|
||||
try_import_kubernetes,
|
||||
if self._sync_trial_dir(
|
||||
trial, force=trial.sync_on_checkpoint, wait=True
|
||||
) and not os.path.exists(checkpoint.dir_or_data):
|
||||
raise TuneError(
|
||||
f"Trial {trial}: Checkpoint path {checkpoint.dir_or_data} not "
|
||||
"found after successful sync down."
|
||||
)
|
||||
|
||||
if not try_import_kubernetes():
|
||||
logger.warning(
|
||||
"Detected Ray autoscaling environment on Kubernetes, "
|
||||
"but Kubernetes Python CLI is not installed. "
|
||||
"Checkpoint syncing may not work properly across "
|
||||
"multiple pods. Be sure to install 'kubernetes' on "
|
||||
"each container."
|
||||
)
|
||||
def wait_for_all(self):
|
||||
failed_syncs = {}
|
||||
for trial, sync_process in self._sync_processes.items():
|
||||
try:
|
||||
sync_process.wait()
|
||||
except Exception as e:
|
||||
failed_syncs[trial] = e
|
||||
|
||||
namespace = config["provider"].get("namespace", "ray")
|
||||
logger.debug(
|
||||
f"Detected Ray autoscaling environment on Kubernetes. Using "
|
||||
f"`NamespacedKubernetesSyncer` with namespace `{namespace}` "
|
||||
f"as sync client. If this is not correct or leads to errors, "
|
||||
f"please pass a `syncer` parameter in the `SyncConfig` "
|
||||
f"to `tune.run()` to manually configure syncing behavior.."
|
||||
if failed_syncs:
|
||||
sync_str = "\n".join(
|
||||
[f" {trial}: {e}" for trial, e in failed_syncs.items()]
|
||||
)
|
||||
raise TuneError(
|
||||
f"At least one trial failed to sync down when waiting for all "
|
||||
f"trials to sync: \n{sync_str}"
|
||||
)
|
||||
return NamespacedKubernetesSyncer(namespace)
|
||||
|
||||
return _syncer
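# Usage sketch: the refactored SyncerCallback is constructed with only an
# `enabled` flag and a `sync_period`, and attached to a run via `callbacks`,
# mirroring how the tests below pass it to TrialRunner / run_experiments.
# The trainable and stopping criterion are placeholders.
from ray import tune
from ray.tune.syncer import SyncerCallback


def trainable(config):
    tune.report(metric=1)


tune.run(
    trainable,
    stop={"training_iteration": 2},
    callbacks=[SyncerCallback(enabled=True, sync_period=60)],
)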
|
||||
|
|
0
python/ray/tune/tests/__init__.py
Normal file
|
@ -61,7 +61,7 @@ from ray.tune.suggest._mock import _MockSuggestionAlgorithm
|
|||
from ray.tune.suggest.ax import AxSearch
|
||||
from ray.tune.suggest.hyperopt import HyperOptSearch
|
||||
from ray.tune.suggest.suggestion import ConcurrencyLimiter
|
||||
from ray.tune.sync_client import CommandBasedClient
|
||||
from ray.tune.syncer import Syncer
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.trial_runner import TrialRunner
|
||||
from ray.tune.utils import flatten_dict
|
||||
|
@ -1037,24 +1037,23 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
|||
def testDurableTrainableSyncFunction(self):
|
||||
"""Check custom sync functions in durable trainables"""
|
||||
|
||||
class CustomSyncer(Syncer):
|
||||
def sync_up(
|
||||
self, local_dir: str, remote_dir: str, exclude: list = None
|
||||
) -> bool:
|
||||
pass # sync up
|
||||
|
||||
def sync_down(
|
||||
self, remote_dir: str, local_dir: str, exclude: list = None
|
||||
) -> bool:
|
||||
pass # sync down
|
||||
|
||||
def delete(self, remote_dir: str) -> bool:
|
||||
pass # delete
|
||||
|
||||
class TestDurable(Trainable):
|
||||
def __init__(self, *args, **kwargs):
|
||||
# Mock distutils.spawn.find_executable
|
||||
# so `aws` command is found
|
||||
import distutils.spawn
|
||||
|
||||
distutils.spawn.find_executable = lambda *_, **__: True
|
||||
super(TestDurable, self).__init__(*args, **kwargs)
|
||||
|
||||
def check(self):
|
||||
return (
|
||||
bool(self.sync_function_tpl)
|
||||
and isinstance(self.storage_client, CommandBasedClient)
|
||||
and "aws" not in self.storage_client.sync_up_template
|
||||
)
|
||||
|
||||
class TestTplDurable(TestDurable):
|
||||
_sync_function_tpl = "echo static sync {source} {target}"
|
||||
def has_custom_syncer(self):
|
||||
return bool(self.custom_syncer)
|
||||
|
||||
upload_dir = "s3://test-bucket/path"
|
||||
|
||||
|
@ -1074,21 +1073,17 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
|||
cls = trial.get_trainable_cls()
|
||||
actor = ray.remote(cls).remote(
|
||||
remote_checkpoint_dir=upload_dir,
|
||||
sync_function_tpl=trial.sync_function_tpl,
|
||||
custom_syncer=trial.custom_syncer,
|
||||
)
|
||||
return actor
|
||||
|
||||
# This actor should create a default aws syncer, so check should fail
|
||||
actor1 = _create_remote_actor(TestDurable, None)
|
||||
self.assertFalse(ray.get(actor1.check.remote()))
|
||||
self.assertFalse(ray.get(actor1.has_custom_syncer.remote()))
|
||||
|
||||
# This actor should create a custom syncer, so check should pass
|
||||
actor2 = _create_remote_actor(TestDurable, "echo test sync {source} {target}")
|
||||
self.assertTrue(ray.get(actor2.check.remote()))
|
||||
|
||||
# This actor should create a custom syncer, so check should pass
|
||||
actor3 = _create_remote_actor(TestTplDurable, None)
|
||||
self.assertTrue(ray.get(actor3.check.remote()))
|
||||
actor2 = _create_remote_actor(TestDurable, CustomSyncer())
|
||||
self.assertTrue(ray.get(actor2.has_custom_syncer.remote()))
|
||||
|
||||
def testCheckpointDict(self):
|
||||
class TestTrain(Trainable):
|
||||
|
|
|
@ -5,30 +5,18 @@ import os
|
|||
import pytest
|
||||
import shutil
|
||||
import sys
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from typing import Callable, Union, Optional
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.rllib import _register_all
|
||||
from ray.cluster_utils import Cluster
|
||||
from ray._private.test_utils import run_string_as_driver_nonblocking
|
||||
from ray.tune import register_trainable
|
||||
from ray.tune.experiment import Experiment
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.suggest import BasicVariantGenerator
|
||||
from ray.tune.syncer import CloudSyncer, SyncerCallback, get_node_syncer
|
||||
from ray.tune.utils.trainable import TrainableUtil
|
||||
from ray.tune.syncer import SyncerCallback
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.trial_runner import TrialRunner
|
||||
from ray.tune.utils.mock import (
|
||||
MockDurableTrainer,
|
||||
MockRemoteTrainer,
|
||||
MockNodeSyncer,
|
||||
mock_storage_client,
|
||||
MOCK_REMOTE_DIR,
|
||||
)
|
||||
|
||||
|
||||
def _check_trial_running(trial):
|
||||
|
@ -51,27 +39,9 @@ def _start_new_cluster():
|
|||
"_system_config": {"num_heartbeats_timeout": 10},
|
||||
},
|
||||
)
|
||||
# Pytest doesn't play nicely with imports
|
||||
register_trainable("__fake_remote", MockRemoteTrainer)
|
||||
register_trainable("__fake_durable", MockDurableTrainer)
|
||||
_register_all()
|
||||
return cluster
|
||||
|
||||
|
||||
class _PerTrialSyncerCallback(SyncerCallback):
|
||||
def __init__(
|
||||
self, get_sync_fn: Callable[["Trial"], Optional[Union[bool, Callable]]]
|
||||
):
|
||||
self._get_sync_fn = get_sync_fn
|
||||
super(_PerTrialSyncerCallback, self).__init__(None)
|
||||
|
||||
def _create_trial_syncer(self, trial: "Trial"):
|
||||
sync_fn = self._get_sync_fn(trial)
|
||||
return get_node_syncer(
|
||||
trial.logdir, remote_dir=trial.logdir, sync_function=sync_fn
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def start_connected_cluster():
|
||||
# Start the Ray processes.
|
||||
|
@ -94,10 +64,6 @@ def start_connected_emptyhead_cluster():
|
|||
"_system_config": {"num_heartbeats_timeout": 10},
|
||||
},
|
||||
)
|
||||
# Pytest doesn't play nicely with imports
|
||||
_register_all()
|
||||
register_trainable("__fake_remote", MockRemoteTrainer)
|
||||
register_trainable("__fake_durable", MockDurableTrainer)
|
||||
os.environ["TUNE_STATE_REFRESH_PERIOD"] = "0.1"
|
||||
yield cluster
|
||||
# The code after the yield will run as teardown code.
|
||||
|
@ -208,8 +174,16 @@ def test_remove_node_before_result(start_connected_emptyhead_cluster):
|
|||
runner.step()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("trainable_id", ["__fake", "__fake_durable"])
|
||||
def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
|
||||
def custom_driver_logdir_callback(tempdir: str):
|
||||
class SeparateDriverSyncerCallback(SyncerCallback):
|
||||
def _local_trial_logdir(self, trial):
|
||||
return os.path.join(tempdir, trial.relative_logdir)
|
||||
|
||||
return SeparateDriverSyncerCallback()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("durable", [False, True])
|
||||
def test_trial_migration(start_connected_emptyhead_cluster, tmpdir, durable):
|
||||
"""Removing a node while cluster has space should migrate trial.
|
||||
|
||||
The trial state should also be consistent with the checkpoint.
|
||||
|
@ -218,21 +192,23 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
|
|||
node = cluster.add_node(num_cpus=1)
|
||||
cluster.wait_for_nodes()
|
||||
|
||||
syncer_callback = _PerTrialSyncerCallback(
|
||||
lambda trial: trial.trainable_name == "__fake"
|
||||
)
|
||||
if durable:
|
||||
upload_dir = "file://" + str(tmpdir)
|
||||
syncer_callback = SyncerCallback()
|
||||
else:
|
||||
upload_dir = None
|
||||
syncer_callback = custom_driver_logdir_callback(str(tmpdir))
|
||||
|
||||
runner = TrialRunner(BasicVariantGenerator(), callbacks=[syncer_callback])
|
||||
kwargs = {
|
||||
"stopping_criterion": {"training_iteration": 4},
|
||||
"checkpoint_freq": 2,
|
||||
"max_failures": 2,
|
||||
"remote_checkpoint_dir": upload_dir,
|
||||
}
|
||||
|
||||
if trainable_id == "__fake_durable":
|
||||
kwargs["remote_checkpoint_dir"] = MOCK_REMOTE_DIR
|
||||
|
||||
# Test recovery of trial that hasn't been checkpointed
|
||||
t = Trial(trainable_id, **kwargs)
|
||||
t = Trial("__fake", **kwargs)
|
||||
runner.add_trial(t)
|
||||
runner.step() # Start trial
|
||||
runner.step() # Process result
|
||||
|
@ -257,7 +233,7 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
|
|||
assert t.status == Trial.TERMINATED, runner.debug_string()
|
||||
|
||||
# Test recovery of trial that has been checkpointed
|
||||
t2 = Trial(trainable_id, **kwargs)
|
||||
t2 = Trial("__fake", **kwargs)
|
||||
runner.add_trial(t2)
|
||||
# Start trial, process result (x2), process save
|
||||
while not t2.has_checkpoint():
|
||||
|
@ -272,12 +248,10 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
|
|||
# Test recovery of trial that won't be checkpointed
|
||||
kwargs = {
|
||||
"stopping_criterion": {"training_iteration": 3},
|
||||
"remote_checkpoint_dir": upload_dir,
|
||||
}
|
||||
|
||||
if trainable_id == "__fake_durable":
|
||||
kwargs["remote_checkpoint_dir"] = MOCK_REMOTE_DIR
|
||||
|
||||
t3 = Trial(trainable_id, **kwargs)
|
||||
t3 = Trial("__fake", **kwargs)
|
||||
runner.add_trial(t3)
|
||||
runner.step() # Start trial
|
||||
runner.step() # Process result 1
|
||||
|
@ -292,8 +266,8 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
|
|||
runner.step()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("trainable_id", ["__fake", "__fake_durable"])
|
||||
def test_trial_requeue(start_connected_emptyhead_cluster, trainable_id):
|
||||
@pytest.mark.parametrize("durable", [False, True])
|
||||
def test_trial_requeue(start_connected_emptyhead_cluster, tmpdir, durable):
|
||||
"""Removing a node in full cluster causes Trial to be requeued."""
|
||||
os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "1"
|
||||
|
||||
|
@ -301,20 +275,22 @@ def test_trial_requeue(start_connected_emptyhead_cluster, trainable_id):
|
|||
node = cluster.add_node(num_cpus=1)
|
||||
cluster.wait_for_nodes()
|
||||
|
||||
syncer_callback = _PerTrialSyncerCallback(
|
||||
lambda trial: trial.trainable_name == "__fake"
|
||||
)
|
||||
if durable:
|
||||
upload_dir = "file://" + str(tmpdir)
|
||||
syncer_callback = SyncerCallback()
|
||||
else:
|
||||
upload_dir = None
|
||||
syncer_callback = custom_driver_logdir_callback(str(tmpdir))
|
||||
|
||||
runner = TrialRunner(BasicVariantGenerator(), callbacks=[syncer_callback]) # noqa
|
||||
kwargs = {
|
||||
"stopping_criterion": {"training_iteration": 5},
|
||||
"checkpoint_freq": 1,
|
||||
"max_failures": 1,
|
||||
"remote_checkpoint_dir": upload_dir,
|
||||
}
|
||||
|
||||
if trainable_id == "__fake_durable":
|
||||
kwargs["remote_checkpoint_dir"] = MOCK_REMOTE_DIR
|
||||
|
||||
trials = [Trial(trainable_id, **kwargs), Trial(trainable_id, **kwargs)]
|
||||
trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)]
|
||||
for t in trials:
|
||||
runner.add_trial(t)
|
||||
|
||||
|
@ -333,58 +309,32 @@ def test_trial_requeue(start_connected_emptyhead_cluster, trainable_id):
|
|||
assert all(t.status == Trial.PENDING for t in trials), runner.debug_string()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("trainable_id", ["__fake_remote", "__fake_durable"])
|
||||
def test_migration_checkpoint_removal(start_connected_emptyhead_cluster, trainable_id):
|
||||
@pytest.mark.parametrize("durable", [False, True])
|
||||
def test_migration_checkpoint_removal(
|
||||
start_connected_emptyhead_cluster, tmpdir, durable
|
||||
):
|
||||
"""Test checks that trial restarts if checkpoint is lost w/ node fail."""
|
||||
cluster = start_connected_emptyhead_cluster
|
||||
node = cluster.add_node(num_cpus=1)
|
||||
cluster.wait_for_nodes()
|
||||
|
||||
# Only added for fake_remote case.
|
||||
# For durable case, we don't do sync to head.
|
||||
class _SyncerCallback(SyncerCallback):
|
||||
def _create_trial_syncer(self, trial: "Trial"):
|
||||
client = mock_storage_client()
|
||||
return MockNodeSyncer(trial.logdir, trial.logdir, client)
|
||||
if durable:
|
||||
upload_dir = "file://" + str(tmpdir)
|
||||
syncer_callback = SyncerCallback()
|
||||
else:
|
||||
upload_dir = None
|
||||
syncer_callback = custom_driver_logdir_callback(str(tmpdir))
|
||||
|
||||
syncer_callback = (
|
||||
[_SyncerCallback(None)] if trainable_id == "__fake_remote" else None
|
||||
)
|
||||
runner = TrialRunner(BasicVariantGenerator(), callbacks=syncer_callback)
|
||||
runner = TrialRunner(BasicVariantGenerator(), callbacks=[syncer_callback])
|
||||
kwargs = {
|
||||
"stopping_criterion": {"training_iteration": 4},
|
||||
"checkpoint_freq": 2,
|
||||
"max_failures": 2,
|
||||
"remote_checkpoint_dir": upload_dir,
|
||||
}
|
||||
|
||||
if trainable_id == "__fake_durable":
|
||||
kwargs["remote_checkpoint_dir"] = MOCK_REMOTE_DIR
|
||||
|
||||
# The following patches only affect __fake_remote.
|
||||
def hide_remote_path(path_function):
|
||||
def hidden_path_func(checkpoint_path):
|
||||
"""Converts back to local path first."""
|
||||
if MOCK_REMOTE_DIR in checkpoint_path:
|
||||
checkpoint_path = checkpoint_path[len(MOCK_REMOTE_DIR) :]
|
||||
checkpoint_path = os.path.join("/", checkpoint_path)
|
||||
return path_function(checkpoint_path)
|
||||
|
||||
return hidden_path_func
|
||||
|
||||
trainable_util = "ray.tune.ray_trial_executor.TrainableUtil"
|
||||
_find_ckpt = trainable_util + ".find_checkpoint_dir"
|
||||
find_func = TrainableUtil.find_checkpoint_dir
|
||||
_pickle_ckpt = trainable_util + ".pickle_checkpoint"
|
||||
pickle_func = TrainableUtil.pickle_checkpoint
|
||||
|
||||
with patch(_find_ckpt) as mock_find, patch(_pickle_ckpt) as mock_pkl_ckpt:
|
||||
# __fake_remote trainables save to a separate "remote" directory.
|
||||
# TrainableUtil will not check this path unless we mock it.
|
||||
mock_find.side_effect = hide_remote_path(find_func)
|
||||
mock_pkl_ckpt.side_effect = hide_remote_path(pickle_func)
|
||||
|
||||
# Test recovery of trial that has been checkpointed
|
||||
t1 = Trial(trainable_id, **kwargs)
|
||||
t1 = Trial("__fake", **kwargs)
|
||||
runner.add_trial(t1)
|
||||
|
||||
# Start trial, process result (x2), process save
|
||||
|
@ -394,84 +344,45 @@ def test_migration_checkpoint_removal(start_connected_emptyhead_cluster, trainab
|
|||
cluster.add_node(num_cpus=1)
|
||||
cluster.remove_node(node)
|
||||
cluster.wait_for_nodes()
|
||||
|
||||
# Remove checkpoint on "remote" node
|
||||
shutil.rmtree(os.path.dirname(t1.checkpoint.dir_or_data))
|
||||
|
||||
if not durable:
|
||||
# Recover from driver file
|
||||
t1.checkpoint.dir_or_data = os.path.join(
|
||||
tmpdir,
|
||||
t1.relative_logdir,
|
||||
os.path.relpath(t1.checkpoint.dir_or_data, t1.logdir),
|
||||
)
|
||||
|
||||
while not runner.is_finished():
|
||||
runner.step()
|
||||
assert t1.status == Trial.TERMINATED, runner.debug_string()
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Not very consistent.")
|
||||
@pytest.mark.parametrize("trainable_id", ["__fake", "__fake_durable"])
|
||||
def test_cluster_down_simple(start_connected_cluster, tmpdir, trainable_id):
|
||||
"""Tests that TrialRunner save/restore works on cluster shutdown."""
|
||||
cluster = start_connected_cluster
|
||||
cluster.add_node(num_cpus=1)
|
||||
cluster.wait_for_nodes()
|
||||
|
||||
dirpath = str(tmpdir)
|
||||
syncer_callback = _PerTrialSyncerCallback(
|
||||
lambda trial: trial.trainable_name == "__fake"
|
||||
)
|
||||
runner = TrialRunner(
|
||||
local_checkpoint_dir=dirpath, checkpoint_period=0, callbacks=[syncer_callback]
|
||||
)
|
||||
kwargs = {
|
||||
"stopping_criterion": {"training_iteration": 2},
|
||||
"checkpoint_freq": 1,
|
||||
"max_failures": 1,
|
||||
}
|
||||
|
||||
if trainable_id == "__fake_durable":
|
||||
kwargs["remote_checkpoint_dir"] = MOCK_REMOTE_DIR
|
||||
|
||||
trials = [Trial(trainable_id, **kwargs), Trial(trainable_id, **kwargs)]
|
||||
for t in trials:
|
||||
runner.add_trial(t)
|
||||
|
||||
# Start trial (x2), process result, process save
|
||||
for _ in range(4):
|
||||
runner.step()
|
||||
assert all(t.status == Trial.RUNNING for t in runner.get_trials())
|
||||
runner.checkpoint()
|
||||
|
||||
ray.shutdown()
|
||||
cluster.shutdown()
|
||||
|
||||
cluster = _start_new_cluster()
|
||||
runner = TrialRunner(resume="LOCAL", local_checkpoint_dir=dirpath)
|
||||
# Start trial, process restore, process result, process save
|
||||
for _ in range(4):
|
||||
runner.step()
|
||||
|
||||
# Start trial 2, process result, process save, process result, process save
|
||||
for i in range(5):
|
||||
runner.step()
|
||||
|
||||
with pytest.raises(TuneError):
|
||||
runner.step()
|
||||
|
||||
assert all(t.status == Trial.TERMINATED for t in runner.get_trials())
|
||||
ray.shutdown()
|
||||
cluster.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("trainable_id", ["__fake", "__fake_durable"])
|
||||
def test_cluster_down_full(start_connected_cluster, tmpdir, trainable_id):
|
||||
@pytest.mark.parametrize("durable", [False, True])
|
||||
def test_cluster_down_full(start_connected_cluster, tmpdir, durable):
|
||||
"""Tests that run_experiment restoring works on cluster shutdown."""
|
||||
cluster = start_connected_cluster
|
||||
dirpath = str(tmpdir)
|
||||
|
||||
use_default_sync = trainable_id == "__fake"
|
||||
if durable:
|
||||
upload_dir = "file://" + str(tmpdir)
|
||||
syncer_callback = SyncerCallback()
|
||||
else:
|
||||
upload_dir = None
|
||||
syncer_callback = custom_driver_logdir_callback(str(tmpdir))
|
||||
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
|
||||
local_dir = DEFAULT_RESULTS_DIR
|
||||
upload_dir = None if use_default_sync else MOCK_REMOTE_DIR
|
||||
|
||||
base_dict = dict(
|
||||
run=trainable_id,
|
||||
run="__fake",
|
||||
stop=dict(training_iteration=3),
|
||||
local_dir=local_dir,
|
||||
sync_config=dict(upload_dir=upload_dir, syncer=use_default_sync),
|
||||
sync_config=dict(upload_dir=upload_dir),
|
||||
)
|
||||
|
||||
exp1_args = base_dict
|
||||
|
@ -486,19 +397,18 @@ def test_cluster_down_full(start_connected_cluster, tmpdir, trainable_id):
|
|||
"exp4": exp4_args,
|
||||
}
|
||||
|
||||
mock_get_client = "ray.tune.trial_runner.get_cloud_syncer"
|
||||
with patch(mock_get_client) as mock_get_cloud_syncer:
|
||||
mock_syncer = CloudSyncer(local_dir, upload_dir, mock_storage_client())
|
||||
mock_get_cloud_syncer.return_value = mock_syncer
|
||||
|
||||
tune.run_experiments(all_experiments, raise_on_failed_trial=False)
|
||||
tune.run_experiments(
|
||||
all_experiments, callbacks=[syncer_callback], raise_on_failed_trial=False
|
||||
)
|
||||
|
||||
ray.shutdown()
|
||||
cluster.shutdown()
|
||||
cluster = _start_new_cluster()
|
||||
|
||||
trials = tune.run_experiments(
|
||||
all_experiments, resume=True, raise_on_failed_trial=False
|
||||
all_experiments,
|
||||
resume=True,
|
||||
raise_on_failed_trial=False,
|
||||
)
|
||||
|
||||
assert len(trials) == 4
|
||||
|
|
|
@ -6,12 +6,9 @@ import subprocess
|
|||
import sys
|
||||
|
||||
import ray
|
||||
from ray.rllib import _register_all
|
||||
from ray.cluster_utils import Cluster
|
||||
from ray.tune import register_trainable
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.trial_runner import TrialRunner
|
||||
from ray.tune.utils.mock import MockDurableTrainer, MockRemoteTrainer
|
||||
from ray.tune.utils.mock_trainable import MyTrainableClass
|
||||
|
||||
|
||||
|
@ -24,10 +21,6 @@ def _start_new_cluster():
|
|||
"_system_config": {"num_heartbeats_timeout": 10},
|
||||
},
|
||||
)
|
||||
# Pytest doesn't play nicely with imports
|
||||
register_trainable("__fake_remote", MockRemoteTrainer)
|
||||
register_trainable("__fake_durable", MockDurableTrainer)
|
||||
_register_all()
|
||||
return cluster
|
||||
|
||||
|
||||
|
|
|
@ -1,108 +0,0 @@
|
|||
import unittest
|
||||
from typing import Optional
|
||||
|
||||
from ray.tune.integration.docker import DockerSyncer, DockerSyncClient
|
||||
from ray.tune.sync_client import SyncClient
|
||||
from ray.tune.syncer import NodeSyncer
|
||||
|
||||
|
||||
class _MockRsync:
|
||||
def __init__(self):
|
||||
self.history = []
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
self.history.append(kwargs)
|
||||
|
||||
|
||||
class _MockLookup:
|
||||
def __init__(self, node_ips):
|
||||
self.node_to_ip = {}
|
||||
self.ip_to_node = {}
|
||||
for node, ip in node_ips.items():
|
||||
self.node_to_ip[node] = ip
|
||||
self.ip_to_node[ip] = node
|
||||
|
||||
def get_ip(self, node):
|
||||
return self.node_to_ip[node]
|
||||
|
||||
def get_node(self, ip):
|
||||
return self.ip_to_node[ip]
|
||||
|
||||
|
||||
def _create_mock_syncer(local_ip, local_dir, remote_dir):
|
||||
class _MockSyncer(DockerSyncer):
|
||||
def __init__(
|
||||
self,
|
||||
local_dir: str,
|
||||
remote_dir: str,
|
||||
sync_client: Optional[SyncClient] = None,
|
||||
):
|
||||
self.local_ip = local_ip
|
||||
self.worker_ip = None
|
||||
|
||||
sync_client = sync_client or DockerSyncClient()
|
||||
sync_client.configure("__nofile__")
|
||||
|
||||
super(NodeSyncer, self).__init__(local_dir, remote_dir, sync_client)
|
||||
|
||||
return _MockSyncer(local_dir, remote_dir)
|
||||
|
||||
|
||||
class DockerIntegrationTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.lookup = _MockLookup({"head": "1.0.0.0", "w1": "1.0.0.1", "w2": "1.0.0.2"})
|
||||
self.local_dir = "/tmp/local"
|
||||
self.remote_dir = "/tmp/remote"
|
||||
|
||||
self.mock_command = _MockRsync()
|
||||
|
||||
from ray.tune.integration import docker
|
||||
|
||||
docker.rsync = self.mock_command
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
|
||||
def testDockerRsyncUpDown(self):
|
||||
syncer = _create_mock_syncer(
|
||||
self.lookup.get_ip("head"), self.local_dir, self.remote_dir
|
||||
)
|
||||
|
||||
syncer.set_worker_ip(self.lookup.get_ip("w1"))
|
||||
|
||||
# Test sync up. Should add / to the dirs and call rsync
|
||||
syncer.sync_up()
|
||||
print(self.mock_command.history[-1])
|
||||
self.assertEqual(self.mock_command.history[-1]["source"], self.local_dir + "/")
|
||||
self.assertEqual(self.mock_command.history[-1]["target"], self.remote_dir + "/")
|
||||
self.assertEqual(self.mock_command.history[-1]["down"], False)
|
||||
self.assertEqual(
|
||||
self.mock_command.history[-1]["ip_address"], self.lookup.get_ip("w1")
|
||||
)
|
||||
|
||||
# Test sync down.
|
||||
syncer.sync_down()
|
||||
print(self.mock_command.history[-1])
|
||||
|
||||
self.assertEqual(self.mock_command.history[-1]["target"], self.local_dir + "/")
|
||||
self.assertEqual(self.mock_command.history[-1]["source"], self.remote_dir + "/")
|
||||
self.assertEqual(self.mock_command.history[-1]["down"], True)
|
||||
self.assertEqual(
|
||||
self.mock_command.history[-1]["ip_address"], self.lookup.get_ip("w1")
|
||||
)
|
||||
|
||||
# Sync to same node should be ignored
|
||||
prev = len(self.mock_command.history)
|
||||
syncer.set_worker_ip(self.lookup.get_ip("head"))
|
||||
syncer.sync_up()
|
||||
self.assertEqual(len(self.mock_command.history), prev)
|
||||
|
||||
syncer.sync_down()
|
||||
self.assertEqual(len(self.mock_command.history), prev)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
|
@ -1,114 +0,0 @@
|
|||
import unittest
|
||||
|
||||
from ray.autoscaler._private.command_runner import KUBECTL_RSYNC
|
||||
from ray.tune.integration.kubernetes import KubernetesSyncer, KubernetesSyncClient
|
||||
from ray.tune.syncer import NodeSyncer
|
||||
|
||||
|
||||
class _MockProcessRunner:
|
||||
def __init__(self):
|
||||
self.history = []
|
||||
|
||||
def check_call(self, command):
|
||||
self.history.append(command)
|
||||
return True
|
||||
|
||||
|
||||
class _MockLookup:
|
||||
def __init__(self, node_ips):
|
||||
self.node_to_ip = {}
|
||||
self.ip_to_node = {}
|
||||
for node, ip in node_ips.items():
|
||||
self.node_to_ip[node] = ip
|
||||
self.ip_to_node[ip] = node
|
||||
|
||||
def get_ip(self, node):
|
||||
return self.node_to_ip[node]
|
||||
|
||||
def get_node(self, ip):
|
||||
return self.ip_to_node[ip]
|
||||
|
||||
def __call__(self, ip):
|
||||
return self.ip_to_node[ip]
|
||||
|
||||
|
||||
def _create_mock_syncer(
|
||||
namespace, lookup, process_runner, local_ip, local_dir, remote_dir
|
||||
):
|
||||
class _MockSyncer(KubernetesSyncer):
|
||||
_namespace = namespace
|
||||
_get_kubernetes_node_by_ip = lookup
|
||||
|
||||
def __init__(self, local_dir, remote_dir, sync_client):
|
||||
self.local_ip = local_ip
|
||||
self.local_node = self._get_kubernetes_node_by_ip(self.local_ip)
|
||||
self.worker_ip = None
|
||||
self.worker_node = None
|
||||
|
||||
sync_client = sync_client
|
||||
super(NodeSyncer, self).__init__(local_dir, remote_dir, sync_client)
|
||||
|
||||
return _MockSyncer(
|
||||
local_dir,
|
||||
remote_dir,
|
||||
sync_client=KubernetesSyncClient(
|
||||
namespace=namespace, process_runner=process_runner
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class KubernetesIntegrationTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.namespace = "test_ray"
|
||||
self.lookup = _MockLookup({"head": "1.0.0.0", "w1": "1.0.0.1", "w2": "1.0.0.2"})
|
||||
self.process_runner = _MockProcessRunner()
|
||||
self.local_dir = "/tmp/local"
|
||||
self.remote_dir = "/tmp/remote"
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
|
||||
def testKubernetesRsyncUpDown(self):
|
||||
syncer = _create_mock_syncer(
|
||||
self.namespace,
|
||||
self.lookup,
|
||||
self.process_runner,
|
||||
self.lookup.get_ip("head"),
|
||||
self.local_dir,
|
||||
self.remote_dir,
|
||||
)
|
||||
|
||||
syncer.set_worker_ip(self.lookup.get_ip("w1"))
|
||||
|
||||
# Test sync up. Should add / to the dirs and call rsync
|
||||
syncer.sync_up()
|
||||
self.assertEqual(self.process_runner.history[-1][0], KUBECTL_RSYNC)
|
||||
self.assertEqual(self.process_runner.history[-1][-2], self.local_dir + "/")
|
||||
self.assertEqual(
|
||||
self.process_runner.history[-1][-1],
|
||||
"{}@{}:{}".format("w1", self.namespace, self.remote_dir + "/"),
|
||||
)
|
||||
|
||||
# Test sync down.
|
||||
syncer.sync_down()
|
||||
self.assertEqual(self.process_runner.history[-1][0], KUBECTL_RSYNC)
|
||||
self.assertEqual(
|
||||
self.process_runner.history[-1][-2],
|
||||
"{}@{}:{}".format("w1", self.namespace, self.remote_dir + "/"),
|
||||
)
|
||||
self.assertEqual(self.process_runner.history[-1][-1], self.local_dir + "/")
|
||||
|
||||
# Sync to same node should be ignored
|
||||
syncer.set_worker_ip(self.lookup.get_ip("head"))
|
||||
syncer.sync_up()
|
||||
self.assertTrue(len(self.process_runner.history) == 2)
|
||||
|
||||
syncer.sync_down()
|
||||
self.assertTrue(len(self.process_runner.history) == 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
|
@@ -1,717 +0,0 @@
|
|||
import glob
|
||||
import io
|
||||
import os
|
||||
import pickle
|
||||
import shutil
|
||||
import sys
|
||||
import tarfile
|
||||
import tempfile
|
||||
import time
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
import yaml
|
||||
|
||||
from collections import deque
|
||||
|
||||
import ray
|
||||
from ray.exceptions import RayTaskError
|
||||
from ray.rllib import _register_all
|
||||
|
||||
from ray import tune
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.integration.docker import DockerSyncer
|
||||
from ray.tune.integration.kubernetes import KubernetesSyncer
|
||||
from ray.tune.sync_client import NOOP, RemoteTaskClient
|
||||
from ray.tune.syncer import (
|
||||
CommandBasedClient,
|
||||
detect_cluster_syncer,
|
||||
get_cloud_sync_client,
|
||||
SyncerCallback,
|
||||
)
|
||||
from ray.tune.utils.callback import create_default_callbacks
|
||||
from ray.tune.utils.file_transfer import (
|
||||
delete_on_node,
|
||||
_sync_dir_on_same_node,
|
||||
_sync_dir_between_different_nodes,
|
||||
)
|
||||
|
||||
|
||||
# Default RemoteTaskClient will use _sync_dir_on_same_node in this test,
|
||||
# as the IPs are the same
|
||||
class RemoteTaskClientWithSyncDirBetweenDifferentNodes(RemoteTaskClient):
|
||||
def _sync_function(self, *args, **kwargs):
|
||||
return _sync_dir_between_different_nodes(*args, **kwargs)
|
||||
|
||||
|
||||
class TestSyncFunctionality(unittest.TestCase):
|
||||
def setUp(self):
|
||||
ray.init(num_cpus=2)
|
||||
|
||||
def tearDown(self):
|
||||
ray.shutdown()
|
||||
_register_all() # re-register the evicted objects
|
||||
|
||||
@patch("ray.tune.sync_client.S3_PREFIX", "test")
|
||||
def testCloudProperString(self):
|
||||
with self.assertRaises(ValueError):
|
||||
[trial] = tune.run(
|
||||
"__fake",
|
||||
name="foo",
|
||||
max_failures=0,
|
||||
stop={"training_iteration": 1},
|
||||
sync_config=tune.SyncConfig(
|
||||
**{"upload_dir": "test", "syncer": "ls {target}"}
|
||||
),
|
||||
).trials
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
[trial] = tune.run(
|
||||
"__fake",
|
||||
name="foo",
|
||||
max_failures=0,
|
||||
stop={"training_iteration": 1},
|
||||
sync_config=tune.SyncConfig(
|
||||
**{"upload_dir": "test", "syncer": "ls {source}"}
|
||||
),
|
||||
).trials
|
||||
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
logfile = os.path.join(tmpdir, "test.log")
|
||||
|
||||
[trial] = tune.run(
|
||||
"__fake",
|
||||
name="foo",
|
||||
max_failures=0,
|
||||
stop={"training_iteration": 1},
|
||||
sync_config=tune.SyncConfig(
|
||||
**{
|
||||
"upload_dir": "test",
|
||||
"syncer": "echo {source} {target} > " + logfile,
|
||||
}
|
||||
),
|
||||
).trials
|
||||
with open(logfile) as f:
|
||||
lines = f.read()
|
||||
self.assertTrue("test" in lines)
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
def testClusterProperString(self):
|
||||
"""Tests that invalid commands throw.."""
|
||||
with self.assertRaises(TuneError):
|
||||
# This raises ValueError because logger is init in safe zone.
|
||||
sync_config = tune.SyncConfig(syncer="ls {target}")
|
||||
[trial] = tune.run(
|
||||
"__fake",
|
||||
name="foo",
|
||||
max_failures=0,
|
||||
stop={"training_iteration": 1},
|
||||
sync_config=sync_config,
|
||||
).trials
|
||||
|
||||
with self.assertRaises(TuneError):
|
||||
# This raises ValueError because logger is init in safe zone.
|
||||
sync_config = tune.SyncConfig(syncer="ls {source}")
|
||||
[trial] = tune.run(
|
||||
"__fake",
|
||||
name="foo",
|
||||
max_failures=0,
|
||||
sync_config=sync_config,
|
||||
stop={"training_iteration": 1},
|
||||
).trials
|
||||
|
||||
with patch.object(CommandBasedClient, "_execute") as mock_fn:
|
||||
with patch("ray.tune.syncer.get_node_ip_address") as mock_sync:
|
||||
sync_config = tune.SyncConfig(syncer="echo {source} {target}")
|
||||
mock_sync.return_value = "0.0.0.0"
|
||||
[trial] = tune.run(
|
||||
"__fake",
|
||||
name="foo",
|
||||
max_failures=0,
|
||||
sync_config=sync_config,
|
||||
stop={"training_iteration": 1},
|
||||
).trials
|
||||
self.assertGreater(mock_fn.call_count, 0)
|
||||
|
||||
def testCloudFunctions(self):
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
tmpdir2 = tempfile.mkdtemp()
|
||||
os.mkdir(os.path.join(tmpdir2, "foo"))
|
||||
|
||||
def sync_func(local, remote, exclude=None):
|
||||
for filename in glob.glob(os.path.join(local, "*.json")):
|
||||
shutil.copy(filename, remote)
|
||||
|
||||
sync_config = tune.SyncConfig(upload_dir=tmpdir2, syncer=sync_func)
|
||||
[trial] = tune.run(
|
||||
"__fake",
|
||||
name="foo",
|
||||
max_failures=0,
|
||||
local_dir=tmpdir,
|
||||
stop={"training_iteration": 1},
|
||||
sync_config=sync_config,
|
||||
).trials
|
||||
test_file_path = glob.glob(os.path.join(tmpdir2, "foo", "*.json"))
|
||||
self.assertTrue(test_file_path)
|
||||
shutil.rmtree(tmpdir)
|
||||
shutil.rmtree(tmpdir2)
|
||||
|
||||
@patch("ray.tune.sync_client.S3_PREFIX", "test")
|
||||
def testCloudSyncPeriod(self):
|
||||
"""Tests that changing SYNC_PERIOD affects syncing frequency."""
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
|
||||
def trainable(config):
|
||||
for i in range(10):
|
||||
time.sleep(1)
|
||||
tune.report(score=i)
|
||||
|
||||
def counter(local, remote, exclude=None):
|
||||
count_file = os.path.join(tmpdir, "count.txt")
|
||||
if not os.path.exists(count_file):
|
||||
count = 0
|
||||
else:
|
||||
with open(count_file, "rb") as fp:
|
||||
count = pickle.load(fp)
|
||||
count += 1
|
||||
with open(count_file, "wb") as fp:
|
||||
pickle.dump(count, fp)
|
||||
|
||||
sync_config = tune.SyncConfig(upload_dir="test", syncer=counter, sync_period=1)
|
||||
# This was originally set to 0.5
|
||||
os.environ["TUNE_GLOBAL_CHECKPOINT_S"] = "0"
|
||||
self.addCleanup(lambda: os.environ.pop("TUNE_GLOBAL_CHECKPOINT_S", None))
|
||||
[trial] = tune.run(
|
||||
trainable,
|
||||
name="foo",
|
||||
max_failures=0,
|
||||
local_dir=tmpdir,
|
||||
stop={"training_iteration": 10},
|
||||
sync_config=sync_config,
|
||||
).trials
|
||||
|
||||
count_file = os.path.join(tmpdir, "count.txt")
|
||||
with open(count_file, "rb") as fp:
|
||||
count = pickle.load(fp)
|
||||
|
||||
self.assertEqual(count, 12)
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
def testClusterSyncFunction(self):
|
||||
def sync_func_driver(source, target, exclude=None):
|
||||
assert ":" in source, "Source {} not a remote path.".format(source)
|
||||
assert ":" not in target, "Target is supposed to be local."
|
||||
with open(os.path.join(target, "test.log2"), "w") as f:
|
||||
print("writing to", f.name)
|
||||
f.write(source)
|
||||
|
||||
sync_config = tune.SyncConfig(syncer=sync_func_driver, sync_period=5)
|
||||
|
||||
[trial] = tune.run(
|
||||
"__fake",
|
||||
name="foo",
|
||||
max_failures=0,
|
||||
stop={"training_iteration": 1},
|
||||
sync_config=sync_config,
|
||||
).trials
|
||||
test_file_path = os.path.join(trial.logdir, "test.log2")
|
||||
self.assertFalse(os.path.exists(test_file_path))
|
||||
|
||||
with patch("ray.tune.syncer.get_node_ip_address") as mock_sync:
|
||||
mock_sync.return_value = "0.0.0.0"
|
||||
sync_config = tune.SyncConfig(syncer=sync_func_driver)
|
||||
[trial] = tune.run(
|
||||
"__fake",
|
||||
name="foo",
|
||||
max_failures=0,
|
||||
stop={"training_iteration": 1},
|
||||
sync_config=sync_config,
|
||||
).trials
|
||||
test_file_path = os.path.join(trial.logdir, "test.log2")
|
||||
self.assertTrue(os.path.exists(test_file_path))
|
||||
os.remove(test_file_path)
|
||||
|
||||
def testNoSync(self):
|
||||
"""Sync should not run on a single node."""
|
||||
|
||||
def sync_func(source, target, exclude=None):
|
||||
pass
|
||||
|
||||
sync_config = tune.SyncConfig(syncer=sync_func)
|
||||
|
||||
with patch.object(CommandBasedClient, "_execute") as mock_sync:
|
||||
[trial] = tune.run(
|
||||
"__fake",
|
||||
name="foo",
|
||||
max_failures=0,
|
||||
stop={"training_iteration": 1},
|
||||
sync_config=sync_config,
|
||||
).trials
|
||||
self.assertEqual(mock_sync.call_count, 0)
|
||||
|
||||
def testCloudSyncExclude(self):
|
||||
captured = deque(maxlen=1)
|
||||
captured.append("")
|
||||
|
||||
def always_true(*args, **kwargs):
|
||||
return True
|
||||
|
||||
def capture_popen(command, *args, **kwargs):
|
||||
captured.append(command)
|
||||
|
||||
with patch("subprocess.Popen", capture_popen), patch(
|
||||
"distutils.spawn.find_executable", always_true
|
||||
):
|
||||
# S3
|
||||
s3_client = get_cloud_sync_client("s3://test-bucket/test-dir")
|
||||
s3_client.sync_down(
|
||||
"s3://test-bucket/test-dir/remote_source", "local_target"
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
captured[0].strip(),
|
||||
"aws s3 sync s3://test-bucket/test-dir/remote_source "
|
||||
"local_target --exact-timestamps --only-show-errors",
|
||||
)
|
||||
|
||||
s3_client.sync_down(
|
||||
"s3://test-bucket/test-dir/remote_source",
|
||||
"local_target",
|
||||
exclude=["*/checkpoint_*"],
|
||||
)
|
||||
self.assertEqual(
|
||||
captured[0].strip(),
|
||||
"aws s3 sync s3://test-bucket/test-dir/remote_source "
|
||||
"local_target --exact-timestamps --only-show-errors "
|
||||
"--exclude '*/checkpoint_*'",
|
||||
)
|
||||
|
||||
s3_client.sync_down(
|
||||
"s3://test-bucket/test-dir/remote_source",
|
||||
"local_target",
|
||||
exclude=["*/checkpoint_*", "*.big"],
|
||||
)
|
||||
self.assertEqual(
|
||||
captured[0].strip(),
|
||||
"aws s3 sync s3://test-bucket/test-dir/remote_source "
|
||||
"local_target --exact-timestamps --only-show-errors "
|
||||
"--exclude '*/checkpoint_*' --exclude '*.big'",
|
||||
)
|
||||
|
||||
# GS
|
||||
gs_client = get_cloud_sync_client("gs://test-bucket/test-dir")
|
||||
gs_client.sync_down(
|
||||
"gs://test-bucket/test-dir/remote_source", "local_target"
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
captured[0].strip(),
|
||||
"gsutil rsync -r "
|
||||
"gs://test-bucket/test-dir/remote_source "
|
||||
"local_target",
|
||||
)
|
||||
|
||||
gs_client.sync_down(
|
||||
"gs://test-bucket/test-dir/remote_source",
|
||||
"local_target",
|
||||
exclude=["*/checkpoint_*"],
|
||||
)
|
||||
self.assertEqual(
|
||||
captured[0].strip(),
|
||||
"gsutil rsync -r "
|
||||
"-x '(.*/checkpoint_.*)' "
|
||||
"gs://test-bucket/test-dir/remote_source "
|
||||
"local_target",
|
||||
)
|
||||
|
||||
gs_client.sync_down(
|
||||
"gs://test-bucket/test-dir/remote_source",
|
||||
"local_target",
|
||||
exclude=["*/checkpoint_*", "*.big"],
|
||||
)
|
||||
self.assertEqual(
|
||||
captured[0].strip(),
|
||||
"gsutil rsync -r "
|
||||
"-x '(.*/checkpoint_.*)|(.*.big)' "
|
||||
"gs://test-bucket/test-dir/remote_source "
|
||||
"local_target",
|
||||
)
|
||||
|
||||
def testSyncDetection(self):
|
||||
kubernetes_conf = {"provider": {"type": "kubernetes", "namespace": "test_ray"}}
|
||||
docker_conf = {"docker": {"image": "bogus"}, "provider": {"type": "aws"}}
|
||||
aws_conf = {"provider": {"type": "aws"}}
|
||||
|
||||
with tempfile.TemporaryDirectory() as dir:
|
||||
kubernetes_file = os.path.join(dir, "kubernetes.yaml")
|
||||
with open(kubernetes_file, "wt") as fp:
|
||||
yaml.safe_dump(kubernetes_conf, fp)
|
||||
|
||||
docker_file = os.path.join(dir, "docker.yaml")
|
||||
with open(docker_file, "wt") as fp:
|
||||
yaml.safe_dump(docker_conf, fp)
|
||||
|
||||
aws_file = os.path.join(dir, "aws.yaml")
|
||||
with open(aws_file, "wt") as fp:
|
||||
yaml.safe_dump(aws_conf, fp)
|
||||
|
||||
kubernetes_syncer = detect_cluster_syncer(None, kubernetes_file)
|
||||
self.assertTrue(issubclass(kubernetes_syncer, KubernetesSyncer))
|
||||
self.assertEqual(kubernetes_syncer._namespace, "test_ray")
|
||||
|
||||
docker_syncer = detect_cluster_syncer(None, docker_file)
|
||||
self.assertTrue(issubclass(docker_syncer, DockerSyncer))
|
||||
|
||||
aws_syncer = detect_cluster_syncer(None, aws_file)
|
||||
self.assertEqual(aws_syncer, None)
|
||||
|
||||
# Should still return DockerSyncer, since it was passed explicitly
|
||||
syncer = detect_cluster_syncer(
|
||||
tune.SyncConfig(syncer=DockerSyncer), kubernetes_file
|
||||
)
|
||||
self.assertTrue(issubclass(syncer, DockerSyncer))
|
||||
|
||||
@patch(
|
||||
"ray.tune.syncer.get_rsync_template_if_available",
|
||||
lambda: "rsync {source} {target}",
|
||||
)
|
||||
def testNoSyncToDriver(self):
|
||||
"""Test that sync to driver is disabled"""
|
||||
|
||||
class _Trial:
|
||||
def __init__(self, id, logdir):
|
||||
self.id = (id,)
|
||||
self.logdir = logdir
|
||||
|
||||
trial = _Trial("0", "some_dir")
|
||||
|
||||
sync_config = tune.SyncConfig(syncer=None)
|
||||
|
||||
# Create syncer callbacks
|
||||
callbacks = create_default_callbacks([], sync_config)
|
||||
syncer_callback = callbacks[-1]
|
||||
|
||||
# Sanity check that we got the syncer callback
|
||||
self.assertTrue(isinstance(syncer_callback, SyncerCallback))
|
||||
|
||||
# Sync function should be false (no sync to driver)
|
||||
self.assertEqual(syncer_callback._sync_function, False)
|
||||
|
||||
# Sync to driver is disabled, so this should be no-op
|
||||
trial_syncer = syncer_callback._get_trial_syncer(trial)
|
||||
self.assertEqual(trial_syncer.sync_client, NOOP)
|
||||
|
||||
def testSyncWaitRetry(self):
|
||||
class CountingClient(CommandBasedClient):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._sync_ups = 0
|
||||
self._sync_downs = 0
|
||||
super(CountingClient, self).__init__(*args, **kwargs)
|
||||
|
||||
def _start_process(self, cmd):
|
||||
if "UPLOAD" in cmd:
|
||||
self._sync_ups += 1
|
||||
elif "DOWNLOAD" in cmd:
|
||||
self._sync_downs += 1
|
||||
if self._sync_downs == 1:
|
||||
self._last_cmd = "echo DOWNLOAD && true"
|
||||
return super(CountingClient, self)._start_process(cmd)
|
||||
|
||||
client = CountingClient(
|
||||
"echo UPLOAD {source} {target} && false",
|
||||
"echo DOWNLOAD {source} {target} && false",
|
||||
"echo DELETE {target}",
|
||||
)
|
||||
|
||||
# Fail always
|
||||
with self.assertRaisesRegex(TuneError, "Failed sync even after"):
|
||||
client.sync_up("test_source", "test_target")
|
||||
client.wait_or_retry(max_retries=3, backoff_s=0)
|
||||
|
||||
self.assertEquals(client._sync_ups, 3)
|
||||
|
||||
# Succeed after second try
|
||||
client.sync_down("test_source", "test_target")
|
||||
client.wait_or_retry(max_retries=3, backoff_s=0)
|
||||
|
||||
self.assertEquals(client._sync_downs, 2)
|
||||
|
||||
def _check_dir_contents(self, path: str):
|
||||
assert os.path.exists(os.path.join(path, "dir_level0"))
|
||||
assert os.path.exists(os.path.join(path, "dir_level0", "dir_level1"))
|
||||
assert os.path.exists(os.path.join(path, "dir_level0", "file_level1.txt"))
|
||||
with open(os.path.join(path, "dir_level0", "file_level1.txt"), "r") as f:
|
||||
assert f.read() == "Data\n"
|
||||
|
||||
def _testSyncInNodeAndDelete(self, num_workers: int = 1):
|
||||
temp_source = tempfile.mkdtemp()
|
||||
temp_up_target = tempfile.mkdtemp()
|
||||
temp_down_target = tempfile.mkdtemp()
|
||||
self.addCleanup(shutil.rmtree, temp_source)
|
||||
self.addCleanup(shutil.rmtree, temp_up_target, ignore_errors=True)
|
||||
self.addCleanup(shutil.rmtree, temp_down_target)
|
||||
|
||||
os.makedirs(os.path.join(temp_source, "dir_level0", "dir_level1"))
|
||||
with open(os.path.join(temp_source, "dir_level0", "file_level1.txt"), "w") as f:
|
||||
f.write("Data\n")
|
||||
|
||||
# Sanity check
|
||||
self._check_dir_contents(temp_source)
|
||||
node_ip = ray.util.get_node_ip_address()
|
||||
|
||||
futures = [
|
||||
_sync_dir_on_same_node(
|
||||
ip=node_ip,
|
||||
source_path=temp_source,
|
||||
target_path=temp_up_target,
|
||||
return_futures=True,
|
||||
)
|
||||
for i in range(num_workers)
|
||||
]
|
||||
ray.get(futures)
|
||||
|
||||
# Check sync up
|
||||
self._check_dir_contents(temp_up_target)
|
||||
|
||||
assert not os.listdir(temp_down_target)
|
||||
|
||||
futures = [
|
||||
_sync_dir_on_same_node(
|
||||
ip=node_ip,
|
||||
source_path=temp_up_target,
|
||||
target_path=temp_down_target,
|
||||
return_futures=True,
|
||||
)
|
||||
for i in range(num_workers)
|
||||
]
|
||||
ray.get(futures)
|
||||
|
||||
# Check sync down
|
||||
self._check_dir_contents(temp_down_target)
|
||||
|
||||
# Delete in some dir
|
||||
delete_on_node(node_ip=node_ip, path=temp_up_target)
|
||||
|
||||
assert not os.path.exists(temp_up_target)
|
||||
|
||||
def testSyncInNodeAndDelete(self):
|
||||
self._testSyncInNodeAndDelete(num_workers=1)
|
||||
|
||||
def testSyncInNodeAndDeleteMultipleWorkers(self):
|
||||
self._testSyncInNodeAndDelete(num_workers=8)
|
||||
|
||||
def _testSyncBetweenNodesAndDelete(self, num_workers: int = 1):
|
||||
temp_source = tempfile.mkdtemp()
|
||||
temp_up_target = tempfile.mkdtemp()
|
||||
temp_down_target = tempfile.mkdtemp()
|
||||
self.addCleanup(shutil.rmtree, temp_source)
|
||||
self.addCleanup(shutil.rmtree, temp_up_target, ignore_errors=True)
|
||||
self.addCleanup(shutil.rmtree, temp_down_target)
|
||||
|
||||
os.makedirs(os.path.join(temp_source, "dir_level0", "dir_level1"))
|
||||
with open(os.path.join(temp_source, "dir_level0", "file_level1.txt"), "w") as f:
|
||||
f.write("Data\n")
|
||||
|
||||
# Sanity check
|
||||
self._check_dir_contents(temp_source)
|
||||
node_ip = ray.util.get_node_ip_address()
|
||||
|
||||
futures = [
|
||||
_sync_dir_between_different_nodes(
|
||||
source_ip=node_ip,
|
||||
source_path=temp_source,
|
||||
target_ip=node_ip,
|
||||
target_path=temp_up_target,
|
||||
return_futures=True,
|
||||
)[0]
|
||||
for i in range(num_workers)
|
||||
]
|
||||
ray.get(futures)
|
||||
|
||||
# Check sync up
|
||||
self._check_dir_contents(temp_up_target)
|
||||
|
||||
# Max size exceeded
|
||||
with self.assertRaises(RayTaskError):
|
||||
_sync_dir_between_different_nodes(
|
||||
source_ip=node_ip,
|
||||
source_path=temp_up_target,
|
||||
target_ip=node_ip,
|
||||
target_path=temp_down_target,
|
||||
max_size_bytes=2,
|
||||
)
|
||||
|
||||
assert not os.listdir(temp_down_target)
|
||||
|
||||
futures = [
|
||||
_sync_dir_between_different_nodes(
|
||||
source_ip=node_ip,
|
||||
source_path=temp_up_target,
|
||||
target_ip=node_ip,
|
||||
target_path=temp_down_target,
|
||||
return_futures=True,
|
||||
)[0]
|
||||
for i in range(num_workers)
|
||||
]
|
||||
ray.get(futures)
|
||||
|
||||
# Check sync down
|
||||
self._check_dir_contents(temp_down_target)
|
||||
|
||||
# Delete in some dir
|
||||
delete_on_node(node_ip=node_ip, path=temp_up_target)
|
||||
|
||||
assert not os.path.exists(temp_up_target)
|
||||
|
||||
def testSyncBetweenNodesAndDelete(self):
|
||||
self._testSyncBetweenNodesAndDelete(num_workers=1)
|
||||
|
||||
def testSyncBetweenNodesAndDeleteMultipleWorkers(self):
|
||||
self._testSyncBetweenNodesAndDelete(num_workers=8)
|
||||
|
||||
def _prepareDirForTestSyncRemoteTask(self):
|
||||
temp_source = tempfile.mkdtemp()
|
||||
temp_up_target = tempfile.mkdtemp()
|
||||
temp_down_target = tempfile.mkdtemp()
|
||||
self.addCleanup(shutil.rmtree, temp_source)
|
||||
self.addCleanup(shutil.rmtree, temp_up_target)
|
||||
self.addCleanup(shutil.rmtree, temp_down_target)
|
||||
|
||||
os.makedirs(os.path.join(temp_source, "A", "a1"))
|
||||
os.makedirs(os.path.join(temp_source, "A", "a2"))
|
||||
os.makedirs(os.path.join(temp_source, "B", "b1"))
|
||||
with open(os.path.join(temp_source, "level_0.txt"), "wt") as fp:
|
||||
fp.write("Level 0\n")
|
||||
with open(os.path.join(temp_source, "A", "level_a1.txt"), "wt") as fp:
|
||||
fp.write("Level A1\n")
|
||||
with open(os.path.join(temp_source, "A", "a1", "level_a2.txt"), "wt") as fp:
|
||||
fp.write("Level A2\n")
|
||||
with open(os.path.join(temp_source, "B", "level_b1.txt"), "wt") as fp:
|
||||
fp.write("Level B1\n")
|
||||
return temp_source, temp_up_target, temp_down_target
|
||||
|
||||
def testSyncRemoteTaskOnlyDifferencesOnDifferentNodes(self):
|
||||
"""Tests the RemoteTaskClient sync client with different node logic.
|
||||
|
||||
In this test we generate a directory with multiple files.
|
||||
We then use both ``sync_down`` and ``sync_up`` to synchronize
|
||||
these to different directories (on the same node). We then assert
|
||||
that the files have been transferred correctly.
|
||||
|
||||
We then edit one of the files and add another one. We then sync
|
||||
up/down again. In this sync, we assert that only modified and new
|
||||
files are transferred.
|
||||
"""
|
||||
(
|
||||
temp_source,
|
||||
temp_up_target,
|
||||
temp_down_target,
|
||||
) = self._prepareDirForTestSyncRemoteTask()
|
||||
this_node_ip = ray.util.get_node_ip_address()
|
||||
|
||||
# Sync everything up
|
||||
client = RemoteTaskClientWithSyncDirBetweenDifferentNodes(_store_remotes=True)
|
||||
client.sync_up(source=temp_source, target=(this_node_ip, temp_up_target))
|
||||
client.wait()
|
||||
|
||||
# Assume that we synced everything up to second level
|
||||
self.assertTrue(
|
||||
os.path.exists(os.path.join(temp_up_target, "A", "a1", "level_a2.txt")),
|
||||
msg=f"Contents: {os.listdir(temp_up_target)}",
|
||||
)
|
||||
with open(os.path.join(temp_up_target, "A", "a1", "level_a2.txt"), "rt") as fp:
|
||||
self.assertEqual(fp.read(), "Level A2\n")
|
||||
|
||||
# Sync everything down
|
||||
client.sync_down(source=(this_node_ip, temp_source), target=temp_down_target)
|
||||
client.wait()
|
||||
|
||||
# Assume that we synced everything up to second level
|
||||
self.assertTrue(
|
||||
os.path.exists(os.path.join(temp_down_target, "A", "a1", "level_a2.txt")),
|
||||
msg=f"Contents: {os.listdir(temp_down_target)}",
|
||||
)
|
||||
with open(
|
||||
os.path.join(temp_down_target, "A", "a1", "level_a2.txt"), "rt"
|
||||
) as fp:
|
||||
self.assertEqual(fp.read(), "Level A2\n")
|
||||
|
||||
# Now, edit some stuff in our source. Then confirm only these
|
||||
# edited files are synced
|
||||
with open(os.path.join(temp_source, "A", "a1", "level_a2.txt"), "wt") as fp:
|
||||
fp.write("Level X2\n") # Same length
|
||||
with open(os.path.join(temp_source, "A", "level_a1x.txt"), "wt") as fp:
|
||||
fp.write("Level A1X\n") # New file
|
||||
|
||||
# Sync up
|
||||
client.sync_up(source=temp_source, target=(this_node_ip, temp_up_target))
|
||||
|
||||
# Hi-jack futures
|
||||
files_stats = ray.get(client._stored_files_stats)
|
||||
tarball = ray.get(client._stored_pack_actor_ref.get_full_data.remote())
|
||||
client.wait()
|
||||
|
||||
# Existing file should have new content
|
||||
with open(os.path.join(temp_up_target, "A", "a1", "level_a2.txt"), "rt") as fp:
|
||||
self.assertEqual(fp.read(), "Level X2\n")
|
||||
|
||||
# New file should be there
|
||||
with open(os.path.join(temp_up_target, "A", "level_a1x.txt"), "rt") as fp:
|
||||
self.assertEqual(fp.read(), "Level A1X\n")
|
||||
|
||||
# Old file should be there
|
||||
with open(os.path.join(temp_up_target, "B", "level_b1.txt"), "rt") as fp:
|
||||
self.assertEqual(fp.read(), "Level B1\n")
|
||||
|
||||
# In the target dir, level_a1x was not contained
|
||||
self.assertIn(os.path.join("A", "a1", "level_a2.txt"), files_stats)
|
||||
self.assertNotIn(os.path.join("A", "level_a1x.txt"), files_stats)
|
||||
|
||||
# Inspect tarball
|
||||
with tarfile.open(fileobj=io.BytesIO(tarball)) as tar:
|
||||
files_in_tar = tar.getnames()
|
||||
self.assertIn(os.path.join("A", "a1", "level_a2.txt"), files_in_tar)
|
||||
self.assertIn(os.path.join("A", "level_a1x.txt"), files_in_tar)
|
||||
self.assertNotIn(os.path.join("A", "level_a1.txt"), files_in_tar)
|
||||
# 6 directories (including root) + 2 files
|
||||
self.assertEqual(len(files_in_tar), 8, msg=str(files_in_tar))
|
||||
|
||||
# Sync down
|
||||
client.sync_down(source=(this_node_ip, temp_source), target=temp_down_target)
|
||||
|
||||
# Hi-jack futures
|
||||
files_stats = ray.get(client._stored_files_stats)
|
||||
tarball = ray.get(client._stored_pack_actor_ref.get_full_data.remote())
|
||||
client.wait()
|
||||
|
||||
# Existing file should have new content
|
||||
with open(
|
||||
os.path.join(temp_down_target, "A", "a1", "level_a2.txt"), "rt"
|
||||
) as fp:
|
||||
self.assertEqual(fp.read(), "Level X2\n")
|
||||
|
||||
# New file should be there
|
||||
with open(os.path.join(temp_down_target, "A", "level_a1x.txt"), "rt") as fp:
|
||||
self.assertEqual(fp.read(), "Level A1X\n")
|
||||
|
||||
# Old file should be there
|
||||
with open(os.path.join(temp_down_target, "B", "level_b1.txt"), "rt") as fp:
|
||||
self.assertEqual(fp.read(), "Level B1\n")
|
||||
|
||||
# In the target dir, level_a1x was not contained
|
||||
self.assertIn(os.path.join("A", "a1", "level_a2.txt"), files_stats)
|
||||
self.assertNotIn(os.path.join("A", "level_a1x.txt"), files_stats)
|
||||
|
||||
# Inspect tarball
|
||||
with tarfile.open(fileobj=io.BytesIO(tarball)) as tar:
|
||||
files_in_tar = tar.getnames()
|
||||
self.assertIn(os.path.join("A", "a1", "level_a2.txt"), files_in_tar)
|
||||
self.assertIn(os.path.join("A", "level_a1x.txt"), files_in_tar)
|
||||
self.assertNotIn(os.path.join("A", "level_a1.txt"), files_in_tar)
|
||||
# 6 directories (including root) + 2 files
|
||||
self.assertEqual(len(files_in_tar), 8, msg=str(files_in_tar))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
python/ray/tune/tests/test_syncer.py (new file, 369 lines)
|
@@ -0,0 +1,369 @@
|
|||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
from freezegun import freeze_time
|
||||
|
||||
import ray
|
||||
|
||||
from ray import tune
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.syncer import _DefaultSyncer, Syncer, _validate_upload_dir
|
||||
from ray.tune.utils.file_transfer import _pack_dir, _unpack_dir
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ray_start_2_cpus():
|
||||
address_info = ray.init(num_cpus=2, configure_logging=False)
|
||||
yield address_info
|
||||
# The code after the yield will run as teardown code.
|
||||
ray.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_data_dirs():
|
||||
tmp_source = os.path.realpath(tempfile.mkdtemp())
|
||||
tmp_target = os.path.realpath(tempfile.mkdtemp())
|
||||
|
||||
os.makedirs(os.path.join(tmp_source, "subdir", "nested"))
|
||||
os.makedirs(os.path.join(tmp_source, "subdir_exclude", "something"))
|
||||
|
||||
files = [
|
||||
"level0.txt",
|
||||
"level0_exclude.txt",
|
||||
"subdir/level1.txt",
|
||||
"subdir/level1_exclude.txt",
|
||||
"subdir/nested/level2.txt",
|
||||
"subdir_nested_level2_exclude.txt",
|
||||
"subdir_exclude/something/somewhere.txt",
|
||||
]
|
||||
|
||||
for file in files:
|
||||
with open(os.path.join(tmp_source, file), "w") as f:
|
||||
f.write("Data")
|
||||
|
||||
yield tmp_source, tmp_target
|
||||
|
||||
shutil.rmtree(tmp_source)
|
||||
shutil.rmtree(tmp_target)
|
||||
|
||||
|
||||
def assert_file(exists: bool, root: str, path: str):
|
||||
full_path = os.path.join(root, path)
|
||||
|
||||
if exists:
|
||||
assert os.path.exists(full_path)
|
||||
else:
|
||||
assert not os.path.exists(full_path)
|
||||
|
||||
|
||||
class TestTrainable(tune.Trainable):
|
||||
def save_checkpoint(self, checkpoint_dir: str):
|
||||
with open(os.path.join(checkpoint_dir, "checkpoint.data"), "w") as f:
|
||||
f.write("Data")
|
||||
return checkpoint_dir
|
||||
|
||||
|
||||
class CustomSyncer(Syncer):
|
||||
def __init__(self, sync_period: float = 300.0):
|
||||
super(CustomSyncer, self).__init__(sync_period=sync_period)
|
||||
self._sync_status = {}
|
||||
|
||||
def sync_up(
|
||||
self, local_dir: str, remote_dir: str, exclude: Optional[List] = None
|
||||
) -> bool:
|
||||
with open(os.path.join(local_dir, "custom_syncer.txt"), "w") as f:
|
||||
f.write("Data\n")
|
||||
self._sync_status[remote_dir] = _pack_dir(local_dir)
|
||||
return True
|
||||
|
||||
def sync_down(
|
||||
self, remote_dir: str, local_dir: str, exclude: Optional[List] = None
|
||||
) -> bool:
|
||||
if remote_dir not in self._sync_status:
|
||||
return False
|
||||
_unpack_dir(self._sync_status[remote_dir], local_dir)
|
||||
return True
|
||||
|
||||
def delete(self, remote_dir: str) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def retry(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def wait(self):
|
||||
pass
|
||||
|
||||
|
||||
def test_sync_string_invalid_uri():
|
||||
with pytest.raises(ValueError):
|
||||
_validate_upload_dir(tune.SyncConfig(upload_dir="invalid://some/url"))
|
||||
|
||||
|
||||
def test_sync_string_invalid_local():
|
||||
with pytest.raises(ValueError):
|
||||
_validate_upload_dir(tune.SyncConfig(upload_dir="/invalid/dir"))
|
||||
|
||||
|
||||
def test_sync_string_valid_local():
|
||||
_validate_upload_dir(tune.SyncConfig(upload_dir="file:///valid/dir"))
|
||||
|
||||
|
||||
def test_sync_string_valid_s3():
|
||||
_validate_upload_dir(tune.SyncConfig(upload_dir="s3://valid/bucket"))
|
||||
|
||||
|
||||
def test_syncer_sync_up_down(temp_data_dirs):
|
||||
"""Check that syncing up and down works"""
|
||||
tmp_source, tmp_target = temp_data_dirs
|
||||
|
||||
syncer = _DefaultSyncer()
|
||||
|
||||
syncer.sync_up(
|
||||
local_dir=tmp_source, remote_dir="memory:///test/test_syncer_sync_up_down"
|
||||
)
|
||||
syncer.wait()
|
||||
|
||||
syncer.sync_down(
|
||||
remote_dir="memory:///test/test_syncer_sync_up_down", local_dir=tmp_target
|
||||
)
|
||||
syncer.wait()
|
||||
|
||||
# Target dir should have all files
|
||||
assert_file(True, tmp_target, "level0.txt")
|
||||
assert_file(True, tmp_target, "level0_exclude.txt")
|
||||
assert_file(True, tmp_target, "subdir/level1.txt")
|
||||
assert_file(True, tmp_target, "subdir/level1_exclude.txt")
|
||||
assert_file(True, tmp_target, "subdir/nested/level2.txt")
|
||||
assert_file(True, tmp_target, "subdir_nested_level2_exclude.txt")
|
||||
assert_file(True, tmp_target, "subdir_exclude/something/somewhere.txt")
|
||||
|
||||
|
||||
def test_syncer_sync_exclude(temp_data_dirs):
|
||||
"""Check that the exclude parameter works"""
|
||||
tmp_source, tmp_target = temp_data_dirs
|
||||
|
||||
syncer = _DefaultSyncer()
|
||||
|
||||
syncer.sync_up(
|
||||
local_dir=tmp_source,
|
||||
remote_dir="memory:///test/test_syncer_sync_exclude",
|
||||
exclude=["*_exclude*"],
|
||||
)
|
||||
syncer.wait()
|
||||
|
||||
syncer.sync_down(
|
||||
remote_dir="memory:///test/test_syncer_sync_exclude", local_dir=tmp_target
|
||||
)
|
||||
syncer.wait()
|
||||
|
||||
# Excluded files should not be found in target
|
||||
assert_file(True, tmp_target, "level0.txt")
|
||||
assert_file(False, tmp_target, "level0_exclude.txt")
|
||||
assert_file(True, tmp_target, "subdir/level1.txt")
|
||||
assert_file(False, tmp_target, "subdir/level1_exclude.txt")
|
||||
assert_file(True, tmp_target, "subdir/nested/level2.txt")
|
||||
assert_file(False, tmp_target, "subdir_nested_level2_exclude.txt")
|
||||
assert_file(False, tmp_target, "subdir_exclude/something/somewhere.txt")
|
||||
|
||||
|
||||
def test_sync_up_if_needed(temp_data_dirs):
|
||||
"""Check that we only sync up again after sync period"""
|
||||
tmp_source, tmp_target = temp_data_dirs
|
||||
|
||||
with freeze_time() as frozen:
|
||||
syncer = _DefaultSyncer(sync_period=60)
|
||||
assert syncer.sync_up_if_needed(
|
||||
local_dir=tmp_source, remote_dir="memory:///test/test_sync_up_not_needed"
|
||||
)
|
||||
syncer.wait()
|
||||
|
||||
frozen.tick(30)
|
||||
|
||||
# Sync period not over, yet
|
||||
assert not syncer.sync_up_if_needed(
|
||||
local_dir=tmp_source, remote_dir="memory:///test/test_sync_up_not_needed"
|
||||
)
|
||||
|
||||
frozen.tick(30)
|
||||
|
||||
# Sync period over, sync again
|
||||
assert syncer.sync_up_if_needed(
|
||||
local_dir=tmp_source, remote_dir="memory:///test/test_sync_up_not_needed"
|
||||
)
|
||||
|
||||
|
||||
def test_sync_down_if_needed(temp_data_dirs):
|
||||
"""Check that we only sync down again after sync period"""
|
||||
tmp_source, tmp_target = temp_data_dirs
|
||||
|
||||
with freeze_time() as frozen:
|
||||
syncer = _DefaultSyncer(sync_period=60)
|
||||
|
||||
# Populate remote directory
|
||||
syncer.sync_up(
|
||||
local_dir=tmp_source, remote_dir="memory:///test/test_sync_down_if_needed"
|
||||
)
|
||||
syncer.wait()
|
||||
|
||||
assert syncer.sync_down_if_needed(
|
||||
remote_dir="memory:///test/test_sync_down_if_needed", local_dir=tmp_target
|
||||
)
|
||||
syncer.wait()
|
||||
|
||||
frozen.tick(30)
|
||||
|
||||
# Sync period not over, yet
|
||||
assert not syncer.sync_down_if_needed(
|
||||
remote_dir="memory:///test/test_sync_down_if_needed", local_dir=tmp_target
|
||||
)
|
||||
|
||||
frozen.tick(30)
|
||||
|
||||
# Sync period over, sync again
|
||||
assert syncer.sync_down_if_needed(
|
||||
remote_dir="memory:///test/test_sync_down_if_needed", local_dir=tmp_target
|
||||
)
|
||||
|
||||
|
||||
def test_syncer_still_running_no_sync(temp_data_dirs):
|
||||
"""Check that no new sync is issued if old sync is still running"""
|
||||
tmp_source, tmp_target = temp_data_dirs
|
||||
|
||||
class FakeSyncProcess:
|
||||
@property
|
||||
def is_running(self):
|
||||
return True
|
||||
|
||||
syncer = _DefaultSyncer(sync_period=60)
|
||||
syncer._sync_process = FakeSyncProcess()
|
||||
assert not syncer.sync_up_if_needed(
|
||||
local_dir=tmp_source,
|
||||
remote_dir="memory:///test/test_syncer_still_running_no_sync",
|
||||
)
|
||||
|
||||
|
||||
def test_syncer_not_running_sync(temp_data_dirs):
|
||||
"""Check that new sync is issued if old sync completed"""
|
||||
tmp_source, tmp_target = temp_data_dirs
|
||||
|
||||
class FakeSyncProcess:
|
||||
@property
|
||||
def is_running(self):
|
||||
return False
|
||||
|
||||
def wait(self):
|
||||
return True
|
||||
|
||||
syncer = _DefaultSyncer(sync_period=60)
|
||||
syncer._sync_process = FakeSyncProcess()
|
||||
assert syncer.sync_up_if_needed(
|
||||
local_dir=tmp_source,
|
||||
remote_dir="memory:///test/test_syncer_not_running_sync",
|
||||
)
|
||||
|
||||
|
||||
def test_syncer_not_running_sync_last_failed(caplog, temp_data_dirs):
|
||||
"""Check that new sync is issued if old sync completed"""
|
||||
caplog.set_level(logging.WARNING)
|
||||
|
||||
tmp_source, tmp_target = temp_data_dirs
|
||||
|
||||
class FakeSyncProcess:
|
||||
@property
|
||||
def is_running(self):
|
||||
return False
|
||||
|
||||
def wait(self):
|
||||
raise RuntimeError("Sync failed")
|
||||
|
||||
syncer = _DefaultSyncer(sync_period=60)
|
||||
syncer._sync_process = FakeSyncProcess()
|
||||
assert syncer.sync_up_if_needed(
|
||||
local_dir=tmp_source,
|
||||
remote_dir="memory:///test/test_syncer_not_running_sync",
|
||||
)
|
||||
assert "Last sync command failed" in caplog.text
|
||||
|
||||
|
||||
def test_syncer_delete(temp_data_dirs):
|
||||
"""Check that deletion on remote storage works"""
|
||||
tmp_source, tmp_target = temp_data_dirs
|
||||
|
||||
syncer = _DefaultSyncer(sync_period=60)
|
||||
|
||||
# Populate remote directory
|
||||
syncer.sync_up(local_dir=tmp_source, remote_dir="memory:///test/test_syncer_delete")
|
||||
syncer.wait()
|
||||
|
||||
syncer.delete(remote_dir="memory:///test/test_syncer_delete")
|
||||
|
||||
syncer.sync_down(
|
||||
remote_dir="memory:///test/test_syncer_delete", local_dir=tmp_target
|
||||
)
|
||||
with pytest.raises(TuneError):
|
||||
syncer.wait()
|
||||
|
||||
# Remote storage was deleted, so target should be empty
|
||||
assert_file(False, tmp_target, "level0.txt")
|
||||
assert_file(False, tmp_target, "level0_exclude.txt")
|
||||
assert_file(False, tmp_target, "subdir/level1.txt")
|
||||
assert_file(False, tmp_target, "subdir/level1_exclude.txt")
|
||||
assert_file(False, tmp_target, "subdir/nested/level2.txt")
|
||||
assert_file(False, tmp_target, "subdir_nested_level2_exclude.txt")
|
||||
assert_file(False, tmp_target, "subdir_exclude/something/somewhere.txt")
|
||||
|
||||
|
||||
def test_syncer_wait_or_retry(temp_data_dirs):
|
||||
"""Check that the wait or retry API works"""
|
||||
tmp_source, tmp_target = temp_data_dirs
|
||||
|
||||
syncer = _DefaultSyncer(sync_period=60)
|
||||
|
||||
# Will fail as dir does not exist
|
||||
syncer.sync_down(
|
||||
remote_dir="memory:///test/test_syncer_wait_or_retry", local_dir=tmp_target
|
||||
)
|
||||
with pytest.raises(TuneError) as e:
|
||||
syncer.wait_or_retry(max_retries=3, backoff_s=0)
|
||||
assert "Failed sync even after 3 retries." in str(e)
|
||||
|
||||
|
||||
def test_trainable_syncer_default(ray_start_2_cpus, temp_data_dirs):
|
||||
"""Check that Trainable.save() triggers syncing using default syncing"""
|
||||
tmp_source, tmp_target = temp_data_dirs
|
||||
|
||||
trainable = ray.remote(TestTrainable).remote(
|
||||
remote_checkpoint_dir=f"file://{tmp_target}"
|
||||
)
|
||||
|
||||
checkpoint_dir = ray.get(trainable.save.remote())
|
||||
|
||||
assert_file(True, tmp_target, os.path.join(checkpoint_dir, "checkpoint.data"))
|
||||
assert_file(False, tmp_target, os.path.join(checkpoint_dir, "custom_syncer.txt"))
|
||||
|
||||
|
||||
def test_trainable_syncer_custom(ray_start_2_cpus, temp_data_dirs):
|
||||
"""Check that Trainable.save() triggers syncing using custom syncer"""
|
||||
tmp_source, tmp_target = temp_data_dirs
|
||||
|
||||
trainable = ray.remote(TestTrainable).remote(
|
||||
remote_checkpoint_dir=f"file://{tmp_target}",
|
||||
custom_syncer=CustomSyncer(),
|
||||
)
|
||||
|
||||
checkpoint_dir = ray.get(trainable.save.remote())
|
||||
|
||||
assert_file(True, tmp_target, os.path.join(checkpoint_dir, "checkpoint.data"))
|
||||
assert_file(True, tmp_target, os.path.join(checkpoint_dir, "custom_syncer.txt"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
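The file above demonstrates the new `Syncer` interface end to end: subclass it, implement `sync_up`, `sync_down`, and `delete`, and hand an instance to Tune. As a rough illustration (not part of this diff), such a syncer could be plugged into a run through `tune.SyncConfig`, mirroring the trial-runner test further down; the bucket URI is a placeholder and the trainable is a trivial function defined only for this sketch:

from typing import List, Optional

from ray import tune
from ray.tune.syncer import Syncer


class LoggingSyncer(Syncer):
    """Toy syncer that only logs what it would transfer."""

    def sync_up(
        self, local_dir: str, remote_dir: str, exclude: Optional[List] = None
    ) -> bool:
        print(f"would upload {local_dir} -> {remote_dir}")
        return True

    def sync_down(
        self, remote_dir: str, local_dir: str, exclude: Optional[List] = None
    ) -> bool:
        print(f"would download {remote_dir} -> {local_dir}")
        return True

    def delete(self, remote_dir: str) -> bool:
        print(f"would delete {remote_dir}")
        return True


def train_fn(config):
    tune.report(score=1)


tune.run(
    train_fn,
    sync_config=tune.SyncConfig(
        upload_dir="s3://my-bucket/experiments",  # placeholder URI
        syncer=LoggingSyncer(),
    ),
)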
python/ray/tune/tests/test_syncer_callback.py (new file, 392 lines)
|
@@ -0,0 +1,392 @@
|
|||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from freezegun import freeze_time
|
||||
|
||||
import ray.util
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.result import NODE_IP
|
||||
from ray.tune.syncer import (
|
||||
DEFAULT_SYNC_PERIOD,
|
||||
SyncConfig,
|
||||
SyncerCallback,
|
||||
_BackgroundProcess,
|
||||
)
|
||||
from ray.tune.utils.callback import create_default_callbacks
|
||||
from ray.tune.utils.file_transfer import sync_dir_between_nodes
|
||||
from ray.util.ml_utils.checkpoint_manager import CheckpointStorage, _TrackedCheckpoint
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ray_start_2_cpus():
|
||||
address_info = ray.init(num_cpus=2, configure_logging=False)
|
||||
yield address_info
|
||||
# The code after the yield will run as teardown code.
|
||||
ray.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_data_dirs():
|
||||
tmp_source = os.path.realpath(tempfile.mkdtemp())
|
||||
tmp_target = os.path.realpath(tempfile.mkdtemp())
|
||||
|
||||
os.makedirs(os.path.join(tmp_source, "subdir", "nested"))
|
||||
os.makedirs(os.path.join(tmp_source, "subdir_exclude", "something"))
|
||||
|
||||
files = [
|
||||
"level0.txt",
|
||||
"level0_exclude.txt",
|
||||
"subdir/level1.txt",
|
||||
"subdir/level1_exclude.txt",
|
||||
"subdir/nested/level2.txt",
|
||||
"subdir_nested_level2_exclude.txt",
|
||||
"subdir_exclude/something/somewhere.txt",
|
||||
]
|
||||
|
||||
for file in files:
|
||||
with open(os.path.join(tmp_source, file), "w") as f:
|
||||
f.write("Data")
|
||||
|
||||
yield tmp_source, tmp_target
|
||||
|
||||
shutil.rmtree(tmp_source)
|
||||
shutil.rmtree(tmp_target)
|
||||
|
||||
|
||||
def assert_file(exists: bool, root: str, path: str):
|
||||
full_path = os.path.join(root, path)
|
||||
|
||||
if exists:
|
||||
assert os.path.exists(full_path)
|
||||
else:
|
||||
assert not os.path.exists(full_path)
|
||||
|
||||
|
||||
class MockTrial:
|
||||
def __init__(self, trial_id: str, logdir: str):
|
||||
self.trial_id = trial_id
|
||||
self.last_result = {NODE_IP: ray.util.get_node_ip_address()}
|
||||
self.uses_cloud_checkpointing = False
|
||||
self.sync_on_checkpoint = True
|
||||
|
||||
self.logdir = logdir
|
||||
|
||||
|
||||
class TestSyncerCallback(SyncerCallback):
|
||||
def __init__(
|
||||
self,
|
||||
enabled: bool = True,
|
||||
sync_period: float = DEFAULT_SYNC_PERIOD,
|
||||
local_logdir_override: Optional[str] = None,
|
||||
remote_logdir_override: Optional[str] = None,
|
||||
):
|
||||
super(TestSyncerCallback, self).__init__(
|
||||
enabled=enabled, sync_period=sync_period
|
||||
)
|
||||
self.local_logdir_override = local_logdir_override
|
||||
self.remote_logdir_override = remote_logdir_override
|
||||
|
||||
def _local_trial_logdir(self, trial):
|
||||
if self.local_logdir_override:
|
||||
return self.local_logdir_override
|
||||
return super(TestSyncerCallback, self)._local_trial_logdir(trial)
|
||||
|
||||
def _remote_trial_logdir(self, trial):
|
||||
if self.remote_logdir_override:
|
||||
return self.remote_logdir_override
|
||||
return super(TestSyncerCallback, self)._remote_trial_logdir(trial)
|
||||
|
||||
def _get_trial_sync_process(self, trial):
|
||||
return self._sync_processes.setdefault(
|
||||
trial.trial_id, MaybeFailingProcess(sync_dir_between_nodes)
|
||||
)
|
||||
|
||||
|
||||
class MaybeFailingProcess(_BackgroundProcess):
|
||||
should_fail = False
|
||||
|
||||
def wait(self):
|
||||
result = super(MaybeFailingProcess, self).wait()
|
||||
if self.should_fail:
|
||||
raise TuneError("Syncing failed.")
|
||||
return result
|
||||
|
||||
|
||||
def test_syncer_callback_disabled():
|
||||
"""Check that syncer=None disables callback"""
|
||||
callbacks = create_default_callbacks(
|
||||
callbacks=[], sync_config=SyncConfig(syncer=None)
|
||||
)
|
||||
syncer_callback = None
|
||||
for cb in callbacks:
|
||||
if isinstance(cb, SyncerCallback):
|
||||
syncer_callback = cb
|
||||
|
||||
trial1 = MockTrial(trial_id="a", logdir=None)
|
||||
trial1.uses_cloud_checkpointing = False
|
||||
|
||||
assert syncer_callback
|
||||
assert not syncer_callback._enabled
|
||||
# Syncer disabled, so no-op
|
||||
assert not syncer_callback._sync_trial_dir(trial1)
|
||||
|
||||
# This should not raise an error for a non-existent directory
|
||||
syncer_callback.on_checkpoint(
|
||||
iteration=1,
|
||||
trials=[],
|
||||
trial=trial1,
|
||||
checkpoint=_TrackedCheckpoint(
|
||||
dir_or_data="/does/not/exist", storage_mode=CheckpointStorage.PERSISTENT
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_syncer_callback_noop_on_trial_cloud_checkpointing():
|
||||
"""Check that trial using cloud checkpointing disables sync to driver"""
|
||||
callbacks = create_default_callbacks(callbacks=[], sync_config=SyncConfig())
|
||||
syncer_callback = None
|
||||
for cb in callbacks:
|
||||
if isinstance(cb, SyncerCallback):
|
||||
syncer_callback = cb
|
||||
|
||||
trial1 = MockTrial(trial_id="a", logdir=None)
|
||||
trial1.uses_cloud_checkpointing = True
|
||||
|
||||
assert syncer_callback
|
||||
assert syncer_callback._enabled
|
||||
# Cloud checkpointing set, so no-op
|
||||
assert not syncer_callback._sync_trial_dir(trial1)
|
||||
|
||||
# This should not raise an error for a non-existent directory
|
||||
syncer_callback.on_checkpoint(
|
||||
iteration=1,
|
||||
trials=[],
|
||||
trial=trial1,
|
||||
checkpoint=_TrackedCheckpoint(
|
||||
dir_or_data="/does/not/exist", storage_mode=CheckpointStorage.PERSISTENT
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_syncer_callback_op_on_no_cloud_checkpointing():
|
||||
"""Check that without cloud checkpointing sync to driver is enabled"""
|
||||
callbacks = create_default_callbacks(callbacks=[], sync_config=SyncConfig())
|
||||
syncer_callback = None
|
||||
for cb in callbacks:
|
||||
if isinstance(cb, SyncerCallback):
|
||||
syncer_callback = cb
|
||||
|
||||
trial1 = MockTrial(trial_id="a", logdir=None)
|
||||
trial1.uses_cloud_checkpointing = False
|
||||
|
||||
assert syncer_callback
|
||||
assert syncer_callback._enabled
|
||||
assert syncer_callback._sync_trial_dir(trial1)
|
||||
|
||||
|
||||
def test_syncer_callback_sync(ray_start_2_cpus, temp_data_dirs):
|
||||
"""Check that on_trial_result triggers syncing"""
|
||||
tmp_source, tmp_target = temp_data_dirs
|
||||
|
||||
syncer_callback = TestSyncerCallback(local_logdir_override=tmp_target)
|
||||
|
||||
trial1 = MockTrial(trial_id="a", logdir=tmp_source)
|
||||
|
||||
syncer_callback.on_trial_result(iteration=1, trials=[], trial=trial1, result={})
|
||||
syncer_callback.wait_for_all()
|
||||
|
||||
assert_file(True, tmp_target, "level0.txt")
|
||||
assert_file(True, tmp_target, "level0_exclude.txt")
|
||||
assert_file(True, tmp_target, "subdir/level1.txt")
|
||||
assert_file(True, tmp_target, "subdir/level1_exclude.txt")
|
||||
assert_file(True, tmp_target, "subdir/nested/level2.txt")
|
||||
assert_file(True, tmp_target, "subdir_nested_level2_exclude.txt")
|
||||
assert_file(True, tmp_target, "subdir_exclude/something/somewhere.txt")
|
||||
|
||||
|
||||
def test_syncer_callback_sync_period(ray_start_2_cpus, temp_data_dirs):
|
||||
"""Check that on_trial_result triggers syncing, obeying sync period"""
|
||||
tmp_source, tmp_target = temp_data_dirs
|
||||
|
||||
with freeze_time() as frozen:
|
||||
syncer_callback = TestSyncerCallback(
|
||||
sync_period=60, local_logdir_override=tmp_target
|
||||
)
|
||||
|
||||
trial1 = MockTrial(trial_id="a", logdir=tmp_source)
|
||||
|
||||
syncer_callback.on_trial_result(iteration=1, trials=[], trial=trial1, result={})
|
||||
syncer_callback.wait_for_all()
|
||||
|
||||
assert_file(True, tmp_target, "level0.txt")
|
||||
assert_file(False, tmp_target, "level0_new.txt")
|
||||
|
||||
# Add new file to source directory
|
||||
with open(os.path.join(tmp_source, "level0_new.txt"), "w") as f:
|
||||
f.write("Data\n")
|
||||
|
||||
frozen.tick(30)
|
||||
|
||||
# Should not sync after 30 seconds
|
||||
syncer_callback.on_trial_result(iteration=2, trials=[], trial=trial1, result={})
|
||||
syncer_callback.wait_for_all()
|
||||
|
||||
assert_file(True, tmp_target, "level0.txt")
|
||||
assert_file(False, tmp_target, "level0_new.txt")
|
||||
|
||||
frozen.tick(30)
|
||||
|
||||
# Should sync after 60 seconds
|
||||
syncer_callback.on_trial_result(iteration=3, trials=[], trial=trial1, result={})
|
||||
syncer_callback.wait_for_all()
|
||||
|
||||
assert_file(True, tmp_target, "level0.txt")
|
||||
assert_file(True, tmp_target, "level0_new.txt")
|
||||
|
||||
|
||||
def test_syncer_callback_force_on_checkpoint(ray_start_2_cpus, temp_data_dirs):
|
||||
"""Check that on_checkpoint forces syncing"""
|
||||
tmp_source, tmp_target = temp_data_dirs
|
||||
|
||||
with freeze_time() as frozen:
|
||||
syncer_callback = TestSyncerCallback(
|
||||
sync_period=60, local_logdir_override=tmp_target
|
||||
)
|
||||
|
||||
trial1 = MockTrial(trial_id="a", logdir=tmp_source)
|
||||
|
||||
syncer_callback.on_trial_result(iteration=1, trials=[], trial=trial1, result={})
|
||||
syncer_callback.wait_for_all()
|
||||
|
||||
assert_file(True, tmp_target, "level0.txt")
|
||||
assert_file(False, tmp_target, "level0_new.txt")
|
||||
|
||||
# Add new file to source directory
|
||||
with open(os.path.join(tmp_source, "level0_new.txt"), "w") as f:
|
||||
f.write("Data\n")
|
||||
|
||||
assert_file(False, tmp_target, "level0_new.txt")
|
||||
|
||||
frozen.tick(30)
|
||||
|
||||
# Should sync as checkpoint observed
|
||||
syncer_callback.on_checkpoint(
|
||||
iteration=2,
|
||||
trials=[],
|
||||
trial=trial1,
|
||||
checkpoint=_TrackedCheckpoint(
|
||||
dir_or_data=tmp_target, storage_mode=CheckpointStorage.PERSISTENT
|
||||
),
|
||||
)
|
||||
syncer_callback.wait_for_all()
|
||||
|
||||
assert_file(True, tmp_target, "level0.txt")
|
||||
assert_file(True, tmp_target, "level0_new.txt")
|
||||
|
||||
|
||||
def test_syncer_callback_force_on_complete(ray_start_2_cpus, temp_data_dirs):
|
||||
"""Check that on_trial_complete forces syncing"""
|
||||
tmp_source, tmp_target = temp_data_dirs
|
||||
|
||||
with freeze_time() as frozen:
|
||||
syncer_callback = TestSyncerCallback(
|
||||
sync_period=60, local_logdir_override=tmp_target
|
||||
)
|
||||
|
||||
trial1 = MockTrial(trial_id="a", logdir=tmp_source)
|
||||
|
||||
syncer_callback.on_trial_result(iteration=1, trials=[], trial=trial1, result={})
|
||||
syncer_callback.wait_for_all()
|
||||
|
||||
assert_file(True, tmp_target, "level0.txt")
|
||||
assert_file(False, tmp_target, "level0_new.txt")
|
||||
|
||||
# Add new file to source directory
|
||||
with open(os.path.join(tmp_source, "level0_new.txt"), "w") as f:
|
||||
f.write("Data\n")
|
||||
|
||||
assert_file(False, tmp_target, "level0_new.txt")
|
||||
|
||||
frozen.tick(30)
|
||||
|
||||
# Should sync because the trial completed
|
||||
syncer_callback.on_trial_complete(
|
||||
iteration=2,
|
||||
trials=[],
|
||||
trial=trial1,
|
||||
)
|
||||
syncer_callback.wait_for_all()
|
||||
|
||||
assert_file(True, tmp_target, "level0.txt")
|
||||
assert_file(True, tmp_target, "level0_new.txt")
|
||||
|
||||
|
||||
def test_syncer_callback_wait_for_all_error(ray_start_2_cpus, temp_data_dirs):
|
||||
"""Check that syncer errors are caught correctly in wait_for_all()"""
|
||||
tmp_source, tmp_target = temp_data_dirs
|
||||
|
||||
syncer_callback = TestSyncerCallback(
|
||||
sync_period=0,
|
||||
local_logdir_override=tmp_target,
|
||||
)
|
||||
|
||||
trial1 = MockTrial(trial_id="a", logdir=tmp_source)
|
||||
|
||||
# Inject FailingProcess into callback
|
||||
sync_process = syncer_callback._get_trial_sync_process(trial1)
|
||||
sync_process.should_fail = True
|
||||
|
||||
# This sync will fail because the remote location does not exist
|
||||
syncer_callback.on_trial_result(iteration=1, trials=[], trial=trial1, result={})
|
||||
|
||||
with pytest.raises(TuneError) as e:
|
||||
syncer_callback.wait_for_all()
|
||||
assert "At least one" in e
|
||||
|
||||
|
||||
def test_syncer_callback_log_error(caplog, ray_start_2_cpus, temp_data_dirs):
|
||||
"""Check that errors in a previous sync are logged correctly"""
|
||||
caplog.set_level(logging.ERROR, logger="ray.tune.syncer")
|
||||
|
||||
tmp_source, tmp_target = temp_data_dirs
|
||||
|
||||
syncer_callback = TestSyncerCallback(
|
||||
sync_period=0,
|
||||
local_logdir_override=tmp_target,
|
||||
)
|
||||
|
||||
trial1 = MockTrial(trial_id="a", logdir=tmp_source)
|
||||
|
||||
# Inject FailingProcess into callback
|
||||
sync_process = syncer_callback._get_trial_sync_process(trial1)
|
||||
|
||||
syncer_callback.on_trial_result(iteration=1, trials=[], trial=trial1, result={})
|
||||
|
||||
# So far we haven't wait()ed, so no error, yet
|
||||
assert not caplog.text
|
||||
assert_file(False, tmp_target, "level0.txt")
|
||||
|
||||
sync_process.should_fail = True
|
||||
|
||||
# When the previous sync processes fails, an error is logged but sync is restarted
|
||||
syncer_callback.on_trial_complete(iteration=2, trials=[], trial=trial1)
|
||||
|
||||
assert (
|
||||
"An error occurred during the checkpoint syncing of the previous checkpoint"
|
||||
in caplog.text
|
||||
)
|
||||
|
||||
sync_process.should_fail = False
|
||||
|
||||
syncer_callback.wait_for_all()
|
||||
assert_file(True, tmp_target, "level0.txt")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
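For context, the syncer callback exercised above is normally not constructed directly; the tests in this file obtain it by letting Tune assemble its default callback stack from a `SyncConfig`. A condensed sketch of that lookup (assuming a plain local run with no cloud checkpointing):

from ray.tune.syncer import SyncConfig, SyncerCallback
from ray.tune.utils.callback import create_default_callbacks

# Build the default callbacks; with no cloud upload_dir configured,
# syncing trial directories to the driver stays enabled.
callbacks = create_default_callbacks(callbacks=[], sync_config=SyncConfig())
syncer_callback = next(cb for cb in callbacks if isinstance(cb, SyncerCallback))
assert syncer_callback._enabled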
|
@@ -27,7 +27,7 @@ from ray.tune.utils.trainable import TrainableUtil
|
|||
@pytest.mark.parametrize("logdir", ["~/tmp/exp/trial", "~/tmp/exp/trial/"])
|
||||
def test_find_rel_checkpoint_dir(checkpoint_path, logdir):
|
||||
assert (
|
||||
TrainableUtil.find_rel_checkpoint_dir(logdir, checkpoint_path) == "checkpoint0/"
|
||||
TrainableUtil.find_rel_checkpoint_dir(logdir, checkpoint_path) == "checkpoint0"
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@@ -6,7 +6,6 @@ import shutil
|
|||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
import ray
|
||||
from ray.rllib import _register_all
|
||||
|
@@ -24,7 +23,7 @@ from ray.tune.suggest.repeater import Repeater
|
|||
from ray.tune.suggest._mock import _MockSuggestionAlgorithm
|
||||
from ray.tune.suggest.suggestion import Searcher, ConcurrencyLimiter
|
||||
from ray.tune.suggest.search_generator import SearchGenerator
|
||||
from ray.tune.syncer import SyncConfig
|
||||
from ray.tune.syncer import SyncConfig, Syncer
|
||||
from ray.tune.tests.utils_for_test_trial_runner import TrialResultObserver
|
||||
|
||||
|
||||
|
@@ -748,19 +747,37 @@ class TrialRunnerTest3(unittest.TestCase):
|
|||
self.assertTrue(trials[0].has_checkpoint())
|
||||
self.assertEqual(num_checkpoints(trials[0]), 2)
|
||||
|
||||
@patch("ray.tune.syncer.SYNC_PERIOD", 0)
|
||||
def testCheckpointAutoPeriod(self):
|
||||
ray.init(num_cpus=3)
|
||||
|
||||
# This makes checkpointing take 2 seconds.
|
||||
def sync_up(source, target, exclude=None):
|
||||
|
||||
class CustomSyncer(Syncer):
|
||||
def __init__(self, sync_period: float = 300.0):
|
||||
super(CustomSyncer, self).__init__(sync_period=sync_period)
|
||||
self._sync_status = {}
|
||||
|
||||
def sync_up(
|
||||
self, local_dir: str, remote_dir: str, exclude: list = None
|
||||
) -> bool:
|
||||
time.sleep(2)
|
||||
return True
|
||||
|
||||
def sync_down(
|
||||
self, remote_dir: str, local_dir: str, exclude: list = None
|
||||
) -> bool:
|
||||
time.sleep(2)
|
||||
return True
|
||||
|
||||
def delete(self, remote_dir: str) -> bool:
|
||||
pass
|
||||
|
||||
runner = TrialRunner(
|
||||
local_checkpoint_dir=self.tmpdir,
|
||||
checkpoint_period="auto",
|
||||
sync_config=SyncConfig(upload_dir="fake", syncer=sync_up),
|
||||
sync_config=SyncConfig(
|
||||
upload_dir="fake", syncer=CustomSyncer(), sync_period=0
|
||||
),
|
||||
remote_checkpoint_dir="fake",
|
||||
)
|
||||
runner.add_trial(Trial("__fake", config={"user_checkpoint_freq": 1}))
|
||||
|
|
|
@@ -290,12 +290,6 @@ class TrialRunnerCallbacks(unittest.TestCase):
|
|||
first_logger_pos, last_logger_pos, syncer_pos = get_positions(callbacks)
|
||||
self.assertLess(last_logger_pos, syncer_pos)
|
||||
|
||||
# This should throw an error as the syncer comes before the logger
|
||||
with self.assertRaises(ValueError):
|
||||
callbacks = create_default_callbacks(
|
||||
[SyncerCallback(None), LoggerCallback()], SyncConfig(), None
|
||||
)
|
||||
|
||||
# This should be reordered but preserve the regular callback order
|
||||
[mc1, mc2, mc3] = [Callback(), Callback(), Callback()]
|
||||
# Has to be legacy logger to avoid logger callback creation
|
||||
|
|
python/ray/tune/tests/test_util_file_transfer.py (new file, 227 lines)
@@ -0,0 +1,227 @@
import io
import tarfile
import os

import pytest
import shutil
import tempfile

from ray.exceptions import RayTaskError

from ray.tune.utils.file_transfer import (
    _sync_dir_between_different_nodes,
    delete_on_node,
    _sync_dir_on_same_node,
)
import ray.util


@pytest.fixture
def ray_start_2_cpus():
    address_info = ray.init(num_cpus=2, configure_logging=False)
    yield address_info
    # The code after the yield will run as teardown code.
    ray.shutdown()


@pytest.fixture
def temp_data_dirs():
    tmp_source = os.path.realpath(tempfile.mkdtemp())
    tmp_target = os.path.realpath(tempfile.mkdtemp())

    os.makedirs(os.path.join(tmp_source, "subdir", "nested"))
    os.makedirs(os.path.join(tmp_source, "subdir_exclude", "something"))

    files = [
        "level0.txt",
        "level0_exclude.txt",
        "subdir/level1.txt",
        "subdir/level1_exclude.txt",
        "subdir/nested/level2.txt",
        "subdir_nested_level2_exclude.txt",
        "subdir_exclude/something/somewhere.txt",
    ]

    for file in files:
        with open(os.path.join(tmp_source, file), "w") as f:
            f.write("Data")

    yield tmp_source, tmp_target

    shutil.rmtree(tmp_source)
    shutil.rmtree(tmp_target)


def assert_file(exists: bool, root: str, path: str):
    full_path = os.path.join(root, path)

    if exists:
        assert os.path.exists(full_path)
    else:
        assert not os.path.exists(full_path)


def test_sync_nodes(ray_start_2_cpus, temp_data_dirs):
    """Check that syncing between nodes works (data is found in target directory)"""
    tmp_source, tmp_target = temp_data_dirs

    assert_file(True, tmp_source, "level0.txt")
    assert_file(True, tmp_source, "subdir/level1.txt")
    assert_file(False, tmp_target, "level0.txt")
    assert_file(False, tmp_target, "subdir/level1.txt")

    node_ip = ray.util.get_node_ip_address()
    _sync_dir_between_different_nodes(
        source_ip=node_ip,
        source_path=tmp_source,
        target_ip=node_ip,
        target_path=tmp_target,
    )

    assert_file(True, tmp_target, "level0.txt")
    assert_file(True, tmp_target, "subdir/level1.txt")


def test_sync_nodes_only_diff(ray_start_2_cpus, temp_data_dirs):
    """Check that only differing files are synced between nodes"""
    tmp_source, tmp_target = temp_data_dirs

    # Sanity check
    assert_file(True, tmp_source, "level0.txt")
    assert_file(True, tmp_source, "subdir/level1.txt")
    assert_file(False, tmp_target, "level0.txt")
    assert_file(False, tmp_target, "level0_new.txt")

    node_ip = ray.util.get_node_ip_address()
    _sync_dir_between_different_nodes(
        source_ip=node_ip,
        source_path=tmp_source,
        target_ip=node_ip,
        target_path=tmp_target,
    )

    assert_file(True, tmp_source, "level0.txt")
    assert_file(True, tmp_target, "level0.txt")
    assert_file(True, tmp_target, "subdir/level1.txt")
    assert_file(False, tmp_target, "level0_new.txt")

    # Add new file
    with open(os.path.join(tmp_source, "level0_new.txt"), "w") as f:
        f.write("Data\n")

    # Modify existing file
    with open(os.path.join(tmp_source, "subdir", "level1.txt"), "w") as f:
        f.write("New data\n")

    unpack, pack_actor, files_stats = _sync_dir_between_different_nodes(
        source_ip=node_ip,
        source_path=tmp_source,
        target_ip=node_ip,
        target_path=tmp_target,
        return_futures=True,
    )

    files_stats = ray.get(files_stats)
    tarball = ray.get(pack_actor.get_full_data.remote())

    assert "./level0.txt" in files_stats
    assert "./level0_new.txt" not in files_stats  # Was not in target dir
    assert "subdir/level1.txt" in files_stats

    with tarfile.open(fileobj=io.BytesIO(tarball)) as tar:
        files_in_tar = tar.getnames()
        assert "./level0.txt" not in files_in_tar
        assert "./level0_new.txt" in files_in_tar
        assert "subdir/level1.txt" in files_in_tar
        assert len(files_in_tar) == 7  # 3 files, 4 dirs (including root)

    ray.get(unpack)  # Wait until finished for teardown


def test_delete_on_node(ray_start_2_cpus, temp_data_dirs):
    """Check that delete on node works."""
    tmp_source, tmp_target = temp_data_dirs

    assert_file(True, tmp_source, "level0.txt")
    assert_file(True, tmp_source, "subdir/level1.txt")

    node_ip = ray.util.get_node_ip_address()
    delete_on_node(
        node_ip=node_ip,
        path=tmp_source,
    )

    assert_file(False, tmp_source, "level0.txt")
    assert_file(False, tmp_source, "subdir/level1.txt")

    # Re-create dir for teardown
    os.makedirs(tmp_source, exist_ok=True)


@pytest.mark.parametrize("num_workers", [1, 8])
def test_multi_sync_same_node(ray_start_2_cpus, temp_data_dirs, num_workers):
    """Check that multiple competing syncs to the same node+dir don't interfere"""
    tmp_source, tmp_target = temp_data_dirs

    assert_file(True, tmp_source, "level0.txt")
    assert_file(True, tmp_source, "subdir/level1.txt")

    node_ip = ray.util.get_node_ip_address()
    futures = [
        _sync_dir_on_same_node(
            ip=node_ip,
            source_path=tmp_source,
            target_path=tmp_target,
            return_futures=True,
        )
        for _ in range(num_workers)
    ]
    ray.get(futures)

    assert_file(True, tmp_target, "level0.txt")
    assert_file(True, tmp_target, "subdir/level1.txt")


@pytest.mark.parametrize("num_workers", [1, 8])
def test_multi_sync_different_node(ray_start_2_cpus, temp_data_dirs, num_workers):
    """Check that multiple competing syncs to the same dir don't interfere"""
    tmp_source, tmp_target = temp_data_dirs

    assert_file(True, tmp_source, "level0.txt")
    assert_file(True, tmp_source, "subdir/level1.txt")

    node_ip = ray.util.get_node_ip_address()
    futures = [
        _sync_dir_between_different_nodes(
            source_ip=node_ip,
            source_path=tmp_source,
            target_ip=node_ip,
            target_path=tmp_target,
            return_futures=True,
        )[0]
        for _ in range(num_workers)
    ]
    ray.get(futures)

    assert_file(True, tmp_target, "level0.txt")
    assert_file(True, tmp_target, "subdir/level1.txt")


def test_max_size_exceeded(ray_start_2_cpus, temp_data_dirs):
    tmp_source, tmp_target = temp_data_dirs

    node_ip = ray.util.get_node_ip_address()
    with pytest.raises(RayTaskError):
        _sync_dir_between_different_nodes(
            source_ip=node_ip,
            source_path=tmp_source,
            target_ip=node_ip,
            target_path=tmp_target,
            max_size_bytes=2,
        )


if __name__ == "__main__":
    import sys

    sys.exit(pytest.main(["-v", __file__]))

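Illustrative sketch (not part of the diff): the new tests above drive the Ray AIR file-transfer utilities directly. The snippet below shows the same calls outside of pytest; the directory paths are hypothetical placeholders, and both "nodes" are the driver node so the sketch runs on a single machine.

import os

import ray
import ray.util
from ray.tune.utils.file_transfer import _sync_dir_between_different_nodes

ray.init()

source_dir = "/tmp/file_transfer_demo_src"     # hypothetical path
target_dir = "/tmp/file_transfer_demo_target"  # hypothetical path
os.makedirs(source_dir, exist_ok=True)
with open(os.path.join(source_dir, "result.json"), "w") as f:
    f.write("{}")

node_ip = ray.util.get_node_ip_address()

# Blocking call: packs the source dir (only files missing or different on the
# target, per test_sync_nodes_only_diff above) and unpacks it on the target.
_sync_dir_between_different_nodes(
    source_ip=node_ip,
    source_path=source_dir,
    target_ip=node_ip,
    target_path=target_dir,
)

# Non-blocking variant: returns the unpack future, the packing actor, and a
# future with the target's file stats, as unpacked in the test above.
unpack_future, pack_actor, files_stats_future = _sync_dir_between_different_nodes(
    source_ip=node_ip,
    source_path=source_dir,
    target_ip=node_ip,
    target_path=target_dir,
    return_futures=True,
)
ray.get(unpack_future)  # wait for the transfer to finish
ray.shutdown()
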
@@ -39,7 +39,7 @@ from ray.tune.result import (
    STDOUT_FILE,
    STDERR_FILE,
)
-from ray.tune.sync_client import get_sync_client, get_cloud_sync_client
+from ray.tune.syncer import Syncer
from ray.tune.utils import UtilMonitor
from ray.tune.utils.placement_groups import PlacementGroupFactory
from ray.tune.utils.trainable import TrainableUtil

@@ -91,14 +91,12 @@ class Trainable:

    """

-    _sync_function_tpl = None
-
    def __init__(
        self,
        config: Dict[str, Any] = None,
        logger_creator: Callable[[Dict[str, Any]], Logger] = None,
        remote_checkpoint_dir: Optional[str] = None,
-        sync_function_tpl: Optional[str] = None,
+        custom_syncer: Optional[Syncer] = None,
    ):
        """Initialize an Trainable.

@@ -116,8 +114,8 @@ class Trainable:
            remote_checkpoint_dir: Upload directory (S3 or GS path).
                This is **per trial** directory,
                which is different from **per checkpoint** directory.
-            sync_function_tpl: Sync function template to use. Defaults
-                to `cls._sync_function` (which defaults to `None`).
+            custom_syncer: Syncer used for synchronizing data from Ray nodes
+                to external storage.
        """

        self._experiment_id = uuid.uuid4().hex

@@ -168,26 +166,13 @@ class Trainable:
        self._monitor = UtilMonitor(start=log_sys_usage)

        self.remote_checkpoint_dir = remote_checkpoint_dir
-        self.sync_function_tpl = sync_function_tpl or self._sync_function_tpl
+        self.custom_syncer = custom_syncer
-        self.storage_client = None
-
-        if self.uses_cloud_checkpointing and self.sync_function_tpl:
-            # Keep this only for custom sync functions and
-            # backwards compatibility.
-            # Todo (krfricke): We should find a way to register custom
-            # syncers in Checkpoints rather than passing storage clients
-            self.storage_client = self._create_storage_client()

    @property
    def uses_cloud_checkpointing(self):
        return bool(self.remote_checkpoint_dir)

-    def _create_storage_client(self):
-        """Returns a storage client."""
-        return get_sync_client(self.sync_function_tpl) or get_cloud_sync_client(
-            self.remote_checkpoint_dir
-        )
-
    def _storage_path(self, local_path):
        """Converts a `local_path` to be based off of
        `self.remote_checkpoint_dir`."""

@@ -463,16 +448,17 @@ class Trainable:

        return checkpoint_path

-    def _maybe_save_to_cloud(self, checkpoint_dir: str):
+    def _maybe_save_to_cloud(self, checkpoint_dir: str) -> bool:
        # Derived classes like the FunctionRunner might call this
-        if self.uses_cloud_checkpointing:
-            if self.storage_client:
-                # Keep for backwards compatibility, remove after deprecation
-                self.storage_client.sync_up(
+        if not self.uses_cloud_checkpointing:
+            return False
+
+        if self.custom_syncer:
+            self.custom_syncer.sync_up(
                checkpoint_dir, self._storage_path(checkpoint_dir)
            )
-                self.storage_client.wait_or_retry()
-                return
+            self.custom_syncer.wait_or_retry()
+            return True

        checkpoint = Checkpoint.from_directory(checkpoint_dir)
        retry_fn(

@@ -481,6 +467,33 @@ class Trainable:
            num_retries=3,
            sleep_time=1,
        )
+        return True
+
+    def _maybe_load_from_cloud(self, checkpoint_path: str) -> bool:
+        if not self.uses_cloud_checkpointing:
+            return False
+
+        rel_checkpoint_dir = TrainableUtil.find_rel_checkpoint_dir(
+            self.logdir, checkpoint_path
+        )
+        external_uri = os.path.join(self.remote_checkpoint_dir, rel_checkpoint_dir)
+        local_dir = os.path.join(self.logdir, rel_checkpoint_dir)
+
+        if self.custom_syncer:
+            # Only keep for backwards compatibility
+            self.custom_syncer.sync_down(remote_dir=external_uri, local_dir=local_dir)
+            self.custom_syncer.wait_or_retry()
+            return True
+
+        checkpoint = Checkpoint.from_uri(external_uri)
+        retry_fn(
+            lambda: checkpoint.to_directory(local_dir),
+            subprocess.CalledProcessError,
+            num_retries=3,
+            sleep_time=1,
+        )
+
+        return True

    def save_to_object(self):
        """Saves the current model state to a Python object.

@@ -533,26 +546,7 @@ class Trainable:
        if isinstance(checkpoint_path, TrialCheckpoint):
            checkpoint_path = checkpoint_path.local_path

-        if self.uses_cloud_checkpointing:
-            rel_checkpoint_dir = TrainableUtil.find_rel_checkpoint_dir(
-                self.logdir, checkpoint_path
-            )
-            external_uri = os.path.join(self.remote_checkpoint_dir, rel_checkpoint_dir)
-            local_dir = os.path.join(self.logdir, rel_checkpoint_dir)
-
-            if self.storage_client:
-                # Only keep for backwards compatibility
-                self.storage_client.sync_down(external_uri, local_dir)
-                self.storage_client.wait_or_retry()
-            else:
-                checkpoint = Checkpoint.from_uri(external_uri)
-                retry_fn(
-                    lambda: checkpoint.to_directory(local_dir),
-                    subprocess.CalledProcessError,
-                    num_retries=3,
-                    sleep_time=1,
-                )
-        elif (
+        if not self._maybe_load_from_cloud(checkpoint_path) and (
            # If a checkpoint source IP is given
            checkpoint_node_ip
            # And the checkpoint does not currently exist on the local node

@@ -566,6 +560,13 @@ class Trainable:
            if checkpoint:
                checkpoint.to_directory(checkpoint_path)

+        if not os.path.exists(checkpoint_path):
+            raise ValueError(
+                f"Could not recover from checkpoint as it does not exist on local "
+                f"disk and was not available on cloud storage or another Ray node. "
+                f"Got checkpoint path: {checkpoint_path} and IP {checkpoint_node_ip}"
+            )
+
        with open(checkpoint_path + ".tune_metadata", "rb") as f:
            metadata = pickle.load(f)
        self._experiment_id = metadata["experiment_id"]

@@ -32,6 +32,7 @@ from ray.tune.result import (
    DEBUG_METRICS,
)
from ray.tune.resources import Resources
+from ray.tune.syncer import Syncer
from ray.tune.utils.placement_groups import (
    PlacementGroupFactory,
    resource_dict_to_pg_factory,

@@ -251,7 +252,7 @@ class Trial:
        placement_group_factory: Optional[PlacementGroupFactory] = None,
        stopping_criterion: Optional[Dict[str, float]] = None,
        remote_checkpoint_dir: Optional[str] = None,
-        sync_function_tpl: Optional[str] = None,
+        custom_syncer: Optional[Syncer] = None,
        checkpoint_freq: int = 0,
        checkpoint_at_end: bool = False,
        sync_on_checkpoint: bool = True,

@@ -367,9 +368,9 @@ class Trial:
        else:
            self.remote_checkpoint_dir_prefix = None

-        if sync_function_tpl == "auto" or not isinstance(sync_function_tpl, str):
-            sync_function_tpl = None
-        self.sync_function_tpl = sync_function_tpl
+        if custom_syncer == "auto" or not isinstance(custom_syncer, Syncer):
+            custom_syncer = None
+        self.custom_syncer = custom_syncer

        self.checkpoint_freq = checkpoint_freq
        self.checkpoint_at_end = checkpoint_at_end

@@ -34,7 +34,7 @@ from ray.tune.result import (
from ray.tune.schedulers import FIFOScheduler, TrialScheduler
from ray.tune.stopper import NoopStopper, Stopper
from ray.tune.suggest import BasicVariantGenerator, SearchAlgorithm
-from ray.tune.syncer import CloudSyncer, get_cloud_syncer, SyncConfig
+from ray.tune.syncer import SyncConfig, get_node_to_storage_syncer, Syncer
from ray.tune.trial import Trial
from ray.tune.utils import warn_if_slow, flatten_dict
from ray.tune.utils.log import Verbosity, has_verbosity

@@ -110,8 +110,10 @@ class _ExperimentCheckpointManager:
        checkpoint_period: Union[int, float, str],
        start_time: float,
        session_str: str,
-        syncer: CloudSyncer,
-        sync_trial_checkpoints: bool = True,
+        syncer: Syncer,
+        sync_trial_checkpoints: bool,
+        local_dir: str,
+        remote_dir: str,
    ):
        self._checkpoint_dir = checkpoint_dir
        self._auto_checkpoint_enabled = checkpoint_period == "auto"

@@ -125,6 +127,8 @@ class _ExperimentCheckpointManager:

        self._syncer = syncer
        self._sync_trial_checkpoints = sync_trial_checkpoints
+        self._local_dir = local_dir
+        self._remote_dir = remote_dir

        self._last_checkpoint_time = 0.0

@@ -183,12 +187,22 @@ class _ExperimentCheckpointManager:
        else:
            exclude = ["*/checkpoint_*"]

+        if self._syncer:
            if force:
                # Wait until previous sync command finished
                self._syncer.wait()
-                self._syncer.sync_up(exclude=exclude)
+                self._syncer.sync_up(
+                    local_dir=self._local_dir,
+                    remote_dir=self._remote_dir,
+                    exclude=exclude,
+                )
            else:
-                self._syncer.sync_up_if_needed(exclude=exclude)
+                self._syncer.sync_up_if_needed(
+                    local_dir=self._local_dir,
+                    remote_dir=self._remote_dir,
+                    exclude=exclude,
+                )

        checkpoint_time_taken = time.monotonic() - checkpoint_time_start

        if self._auto_checkpoint_enabled:

@@ -359,9 +373,8 @@ class TrialRunner:

        sync_config = sync_config or SyncConfig()
        self._remote_checkpoint_dir = remote_checkpoint_dir
-        self._syncer = get_cloud_syncer(
-            local_checkpoint_dir, remote_checkpoint_dir, sync_config.syncer
-        )
+        self._syncer = get_node_to_storage_syncer(sync_config)
        self._stopper = stopper or NoopStopper()
        self._resumed = False

@@ -438,6 +451,8 @@ class TrialRunner:
            session_str=self._session_str,
            syncer=self._syncer,
            sync_trial_checkpoints=sync_trial_checkpoints,
+            local_dir=self._local_checkpoint_dir,
+            remote_dir=self._remote_checkpoint_dir,
        )

    @property

@@ -471,10 +486,12 @@ class TrialRunner:
        )
        # Not clear if we need this assertion, since we should always have a
        # local checkpoint dir.
-        assert self._local_checkpoint_dir or self._remote_checkpoint_dir
+        assert self._local_checkpoint_dir or (
+            self._remote_checkpoint_dir and self._syncer
+        )

        if resume_type == "AUTO":
-            if self._remote_checkpoint_dir:
+            if self._remote_checkpoint_dir and self._syncer:
                logger.info(
                    f"Trying to find and download experiment checkpoint at "
                    f"{self._remote_checkpoint_dir}"

@@ -482,7 +499,10 @@ class TrialRunner:
                # Todo: This syncs the entire experiment including trial
                # checkpoints. We should exclude these in the future.
                try:
-                    self._syncer.sync_down_if_needed()
+                    self._syncer.sync_down_if_needed(
+                        remote_dir=self._remote_checkpoint_dir,
+                        local_dir=self._local_checkpoint_dir,
+                    )
                    self._syncer.wait()
                except TuneError as e:
                    logger.warning(

@@ -548,9 +568,10 @@ class TrialRunner:
                f"({self._remote_checkpoint_dir})"
            ):
                return False
-        if not self._remote_checkpoint_dir:
+        if not self._remote_checkpoint_dir or not self._syncer:
            raise ValueError(
-                "Called resume from remote without remote directory. "
+                "Called resume from remote without remote directory or "
+                "without valid syncer. "
                "Fix this by passing a `SyncConfig` object with "
                "`upload_dir` set to `tune.run(sync_config=...)`."
            )

@@ -566,7 +587,11 @@ class TrialRunner:
        exclude = ["*/checkpoint_*"]

        try:
-            self._syncer.sync_down_if_needed(exclude=exclude)
+            self._syncer.sync_down_if_needed(
+                remote_dir=self._remote_checkpoint_dir,
+                local_dir=self._local_checkpoint_dir,
+                exclude=exclude,
+            )
            self._syncer.wait()
        except TuneError as e:
            raise RuntimeError(

@@ -49,9 +49,8 @@ from ray.tune.schedulers.util import (
from ray.tune.suggest.variant_generator import has_unresolved_values
from ray.tune.syncer import (
    SyncConfig,
-    set_sync_periods,
-    wait_for_sync,
-    validate_upload_dir,
+    _validate_upload_dir,
+    SyncerCallback,
)
from ray.tune.trainable import Trainable
from ray.tune.trial import Trial

@@ -428,8 +427,7 @@ def run(

    config = config or {}
    sync_config = sync_config or SyncConfig()
-    validate_upload_dir(sync_config)
-    set_sync_periods(sync_config)
+    _validate_upload_dir(sync_config)

    if num_samples == -1:
        num_samples = sys.maxsize

@@ -722,7 +720,14 @@ def run(
    if has_verbosity(Verbosity.V1_EXPERIMENT):
        _report_progress(runner, progress_reporter, done=True)

-    wait_for_sync()
+    # Wait for syncing to finish
+    for callback in callbacks:
+        if isinstance(callback, SyncerCallback):
+            try:
+                callback.wait_for_all()
+            except TuneError as e:
+                logger.error(e)

    runner.cleanup()

    incomplete_trials = []

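Illustrative sketch (not part of the diff): with the refactored `tune.run` path above, node-to-storage syncing is configured purely through `SyncConfig`. The trainable, search space, and bucket URI below are hypothetical placeholders.

from ray import tune
from ray.tune.syncer import SyncConfig


def train_fn(config):
    # Hypothetical toy objective.
    tune.report(score=config["x"] ** 2)


tune.run(
    train_fn,
    config={"x": tune.grid_search([1, 2, 3])},
    sync_config=SyncConfig(
        upload_dir="s3://my-bucket/tune-results",  # hypothetical bucket
        sync_period=300,  # sync experiment state at most every 300 seconds
    ),
)
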
@@ -5,13 +5,14 @@ import os

from ray.tune.callback import Callback
from ray.tune.progress_reporter import TrialProgressCallback
-from ray.tune.syncer import SyncConfig, detect_cluster_syncer
+from ray.tune.syncer import SyncConfig
from ray.tune.logger import (
    CSVLoggerCallback,
    CSVLogger,
    JsonLoggerCallback,
    JsonLogger,
    LegacyLoggerCallback,
+    LoggerCallback,
    TBXLoggerCallback,
    TBXLogger,
)

@@ -69,7 +70,6 @@ def create_default_callbacks(
    # Check if we have a CSV, JSON and TensorboardX logger
    for i, callback in enumerate(callbacks):
        if isinstance(callback, LegacyLoggerCallback):
-            last_logger_index = i
            if CSVLogger in callback.logger_classes:
                has_csv_logger = True
            if JsonLogger in callback.logger_classes:

@@ -78,17 +78,17 @@ def create_default_callbacks(
                has_tbx_logger = True
        elif isinstance(callback, CSVLoggerCallback):
            has_csv_logger = True
-            last_logger_index = i
        elif isinstance(callback, JsonLoggerCallback):
            has_json_logger = True
-            last_logger_index = i
        elif isinstance(callback, TBXLoggerCallback):
            has_tbx_logger = True
-            last_logger_index = i
        elif isinstance(callback, SyncerCallback):
            syncer_index = i
            has_syncer_callback = True

+        if isinstance(callback, LoggerCallback):
+            last_logger_index = i

    # If CSV, JSON or TensorboardX loggers are missing, add
    if os.environ.get("TUNE_DISABLE_AUTO_CALLBACK_LOGGERS", "0") != "1":
        if not has_csv_logger:

@@ -114,11 +114,9 @@ def create_default_callbacks(
        not has_syncer_callback
        and os.environ.get("TUNE_DISABLE_AUTO_CALLBACK_SYNCER", "0") != "1"
    ):
-        # Detect Docker and Kubernetes environments
-        _cluster_syncer = detect_cluster_syncer(sync_config)
-
-        syncer_callback = SyncerCallback(sync_function=_cluster_syncer)
+        syncer_callback = SyncerCallback(
+            enabled=bool(sync_config.syncer), sync_period=sync_config.sync_period
+        )
        callbacks.append(syncer_callback)
        syncer_index = len(callbacks) - 1

@@ -127,16 +125,7 @@ def create_default_callbacks(
        and last_logger_index is not None
        and syncer_index < last_logger_index
    ):
-        if not has_csv_logger or not has_json_logger or not has_tbx_logger:
-            raise ValueError(
-                "The `SyncerCallback` you passed to `tune.run()` came before "
-                "at least one `LoggerCallback`. Syncing should be done "
-                "after writing logs. Please re-order the callbacks so that "
-                "the `SyncerCallback` comes after any `LoggerCallback`."
-            )
-        else:
-            # If these loggers were automatically created. just re-order
-            # the callbacks
+        # Re-order callbacks
        syncer_obj = callbacks[syncer_index]
        callbacks.pop(syncer_index)
        callbacks.insert(last_logger_index, syncer_obj)

@@ -4,85 +4,13 @@ import os
import random
import time
from typing import Dict
import uuid

import ray._private.utils

from ray.rllib.algorithms.mock import _MockTrainer
from ray.tune import Trainable
from ray.tune.callback import Callback
-from ray.tune.sync_client import get_sync_client
-from ray.tune.syncer import NodeSyncer
from ray.tune.trial import Trial

-MOCK_REMOTE_DIR = (
-    os.path.join(ray._private.utils.get_user_temp_dir(), "mock-tune-remote") + os.sep
-)
-# Sync and delete templates that operate on local directories.
-LOCAL_SYNC_TEMPLATE = "mkdir -p {target} && rsync -avz {source}/ {target}/"
-LOCAL_DELETE_TEMPLATE = "rm -rf {target}"

logger = logging.getLogger(__name__)


-def mock_storage_client():
-    """Mocks storage client that treats a local dir as durable storage."""
-    client = get_sync_client(LOCAL_SYNC_TEMPLATE, LOCAL_DELETE_TEMPLATE)
-    path = os.path.join(
-        ray._private.utils.get_user_temp_dir(), f"mock-client-{uuid.uuid4().hex[:4]}"
-    )
-    os.makedirs(path, exist_ok=True)
-    client.set_logdir(path)
-    return client
-
-
-class MockNodeSyncer(NodeSyncer):
-    """Mock NodeSyncer that syncs to and from /tmp"""
-
-    def has_remote_target(self):
-        return True
-
-    @property
-    def _remote_path(self):
-        if self._remote_dir.startswith("/"):
-            self._remote_dir = self._remote_dir[1:]
-        return os.path.join(MOCK_REMOTE_DIR, self._remote_dir)
-
-
-class MockRemoteTrainer(_MockTrainer):
-    """Mock Trainable that saves at tmp for simulated clusters."""
-
-    def __init__(self, *args, **kwargs):
-        # Tests in test_cluster.py supply a remote checkpoint dir
-        # We should ignore this here as this is specifically a
-        # non-durable trainer
-        kwargs.pop("remote_checkpoint_dir", None)
-
-        super(MockRemoteTrainer, self).__init__(*args, **kwargs)
-        if self._logdir.startswith("/"):
-            self._logdir = self._logdir[1:]
-        self._logdir = os.path.join(MOCK_REMOTE_DIR, self._logdir)
-        if not os.path.exists(self._logdir):
-            os.makedirs(self._logdir)
-
-
-class MockDurableTrainer(_MockTrainer):
-    """Mock DurableTrainable that saves at tmp for simulated clusters."""
-
-    # Evaluate to true to use legacy storage client
-    _sync_function_tpl = True
-
-    def __init__(
-        self, remote_checkpoint_dir=None, sync_function_tpl=None, *args, **kwargs
-    ):
-        _MockTrainer.__init__(self, *args, **kwargs)
-        kwargs["remote_checkpoint_dir"] = remote_checkpoint_dir
-        Trainable.__init__(self, *args, **kwargs)
-
-    def _create_storage_client(self):
-        return mock_storage_client()


class FailureInjectorCallback(Callback):
    """Adds random failure injection to the TrialExecutor."""

@@ -138,14 +138,14 @@ class TrainableUtil:

        Note, the assumption here is `logdir` should be the prefix of
        `checkpoint_path`.
-        For example, returns `checkpoint00000/`.
+        For example, returns `checkpoint00000`.
        """
        assert checkpoint_path.startswith(
            logdir
        ), "expecting `logdir` to be a prefix of `checkpoint_path`"
        rel_path = os.path.relpath(checkpoint_path, logdir)
        tokens = rel_path.split(os.sep)
-        return os.path.join(tokens[0], "")
+        return os.path.join(tokens[0])

    @staticmethod
    def make_checkpoint_dir(

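Illustrative note (not part of the diff): the change above drops the trailing separator from the relative checkpoint directory. A quick sketch of what the helper now computes, with hypothetical paths:

import os

logdir = "/tmp/exp/trial"
checkpoint_path = "/tmp/exp/trial/checkpoint00000/checkpoint"

rel_path = os.path.relpath(checkpoint_path, logdir)  # "checkpoint00000/checkpoint"
tokens = rel_path.split(os.sep)

old_result = os.path.join(tokens[0], "")  # "checkpoint00000/" (old behavior)
new_result = os.path.join(tokens[0])      # "checkpoint00000"  (new behavior)
assert new_result == "checkpoint00000"
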
@@ -146,7 +146,9 @@ def _serialize_checkpoint(checkpoint_path) -> bytes:
def get_checkpoint_from_remote_node(
    checkpoint_path: str, node_ip: str, timeout: float = 300.0
) -> Optional[Checkpoint]:
-    if not any(node["NodeManagerAddress"] == node_ip for node in ray.nodes()):
+    if not any(
+        node["NodeManagerAddress"] == node_ip and node["Alive"] for node in ray.nodes()
+    ):
        logger.warning(
            f"Could not fetch checkpoint with path {checkpoint_path} from "
            f"node with IP {node_ip} because the node is not available "

@@ -197,8 +197,7 @@ class Algorithm(Trainable):
        config: Optional[Union[PartialAlgorithmConfigDict, AlgorithmConfig]] = None,
        env: Optional[Union[str, EnvType]] = None,
        logger_creator: Optional[Callable[[], Logger]] = None,
-        remote_checkpoint_dir: Optional[str] = None,
-        sync_function_tpl: Optional[str] = None,
+        **kwargs,
    ):
        """Initializes an Algorithm instance.

@@ -211,6 +210,8 @@ class Algorithm(Trainable):
                the "env" key in `config`.
            logger_creator: Callable that creates a ray.tune.Logger
                object. If unspecified, a default logger is created.
+            **kwargs: Arguments passed to the Trainable base class.
+
        """

        # User provided (partial) config (this may be w/o the default

@@ -288,9 +289,7 @@ class Algorithm(Trainable):
            }
        }

-        super().__init__(
-            config, logger_creator, remote_checkpoint_dir, sync_function_tpl
-        )
+        super().__init__(config=config, logger_creator=logger_creator, **kwargs)

        # Check, whether `training_iteration` is still a tune.Trainable property
        # and has not been overridden by the user in the attempt to implement the

@@ -163,8 +163,7 @@ class DDPPO(PPO):
        config: Optional[PartialAlgorithmConfigDict] = None,
        env: Optional[Union[str, EnvType]] = None,
        logger_creator: Optional[Callable[[], Logger]] = None,
-        remote_checkpoint_dir: Optional[str] = None,
-        sync_function_tpl: Optional[str] = None,
+        **kwargs,
    ):
        """Initializes a DDPPO instance.

@@ -177,9 +176,11 @@ class DDPPO(PPO):
                the "env" key in `config`.
            logger_creator: Callable that creates a ray.tune.Logger
                object. If unspecified, a default logger is created.
+            **kwargs: Arguments passed to the Trainable base class
+
        """
        super().__init__(
-            config, env, logger_creator, remote_checkpoint_dir, sync_function_tpl
+            config=config, env=env, logger_creator=logger_creator, **kwargs
        )

        if "train_batch_size" in config.keys() and config["train_batch_size"] != -1: