[tune] Refactor Syncer / deprecate Sync client (#25655)

This PR includes / depends on #25709

The two concepts of Syncer and SyncClient are confusing, as is the current API for passing custom sync functions.

This PR refactors Tune's syncing behavior. The SyncClient concept is hard-deprecated. Instead, we offer a well-defined `Syncer` API that can be extended to provide custom syncing functionality. The default, however, is to use Ray AIR's file transfer utilities.

New API:
- Users can pass `syncer=CustomSyncer()`, where `CustomSyncer` implements the `Syncer` API (see the sketch after this list)
- Otherwise, our off-the-shelf syncing is used
- As before, syncing to the cloud disables syncing to the driver
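
A minimal sketch of the new API (mirroring the updated `doc_code/faq.py` example below; `MyTrainableClass` and the bucket name are placeholders):

```python
from ray import tune
from ray.tune.syncer import Syncer


class CustomSyncer(Syncer):
    def sync_up(self, local_dir: str, remote_dir: str, exclude: list = None) -> bool:
        pass  # upload local_dir to remote_dir

    def sync_down(self, remote_dir: str, local_dir: str, exclude: list = None) -> bool:
        pass  # download remote_dir to local_dir

    def delete(self, remote_dir: str) -> bool:
        pass  # delete remote_dir


tune.run(
    MyTrainableClass,
    sync_config=tune.SyncConfig(
        upload_dir="s3://my-log-dir",
        syncer=CustomSyncer(),  # or syncer="auto" to use the default syncer
    ),
)
```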

Changes:
- Sync client is removed
- Syncer interface introduced
- _DefaultSyncer is a wrapper around the URI upload/download API from Ray AIR
- SyncerCallback only uses remote tasks to synchronize data
- Rsync syncing is fully deprecated and removed
- Docker- and Kubernetes-specific syncing is fully deprecated and removed
- Testing is improved to use `file://` URIs instead of mock sync clients (see the sketch after this list)
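
For example, a test can point syncing at local storage with a `file://` URI instead of mocking a sync client (a sketch; the temporary directory and trainable are illustrative):

```python
import tempfile

from ray import tune

# A local directory acts as the "remote" storage target in tests.
remote_dir = tempfile.mkdtemp()

tune.run(
    MyTrainableClass,
    sync_config=tune.SyncConfig(upload_dir=f"file://{remote_dir}"),
)
```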
Kai Fricke 2022-06-14 14:46:30 +02:00 committed by GitHub
parent f597e21ac8
commit 6313ddc47c
36 changed files with 1787 additions and 2884 deletions


@ -72,7 +72,6 @@ These are the environment variables Ray Tune currently considers:
* **TUNE_RESULT_BUFFER_MAX_TIME_S**: Similarly, Ray Tune buffers results up to ``number_of_trial/10`` seconds,
but never longer than this value. Defaults to 100 (seconds).
* **TUNE_RESULT_BUFFER_MIN_TIME_S**: Additionally, you can specify a minimum time to buffer results. Defaults to 0.
* **TUNE_SYNCER_VERBOSITY**: Amount of command output when using Tune with Docker Syncer. Defaults to 0.
* **TUNE_WARN_THRESHOLD_S**: Threshold for logging if a Tune event loop operation takes too long. Defaults to 0.5 (seconds).
* **TUNE_WARN_INSUFFICENT_RESOURCE_THRESHOLD_S**: Threshold for throwing a warning if no active trials are in ``RUNNING`` state
for this amount of seconds. If the Ray Tune job is stuck in this state (most likely due to insufficient resources),


@ -8,14 +8,6 @@ External library integrations (tune.integration)
:depth: 1
.. _tune-integration-docker:
Docker (tune.integration.docker)
--------------------------------
.. autofunction:: ray.tune.integration.docker.DockerSyncer
.. _tune-integration-keras:
Keras (tune.integration.keras)
@ -25,12 +17,6 @@ Keras (tune.integration.keras)
.. autoclass:: ray.tune.integration.keras.TuneReportCheckpointCallback
.. _tune-integration-kubernetes:
Kubernetes (tune.integration.kubernetes)
----------------------------------------
.. autofunction:: ray.tune.integration.kubernetes.NamespacedKubernetesSyncer
.. _tune-integration-mlflow:


@ -197,7 +197,6 @@ tune.run(tune.with_parameters(f, data=data))
# __large_data_end__
MyTrainableClass = None
custom_sync_str_or_func = ""
if not MOCK:
# __log_1_start__
@ -209,37 +208,31 @@ if not MOCK:
# __log_1_end__
# __log_2_start__
from ray.tune.syncer import Syncer
class CustomSyncer(Syncer):
def sync_up(
self, local_dir: str, remote_dir: str, exclude: list = None
) -> bool:
pass # sync up
def sync_down(
self, remote_dir: str, local_dir: str, exclude: list = None
) -> bool:
pass # sync down
def delete(self, remote_dir: str) -> bool:
pass # delete
tune.run(
MyTrainableClass,
sync_config=tune.SyncConfig(
upload_dir="s3://my-log-dir", syncer=custom_sync_str_or_func
upload_dir="s3://my-log-dir", syncer=CustomSyncer()
),
)
# __log_2_end__
# __sync_start__
import subprocess
def custom_sync_func(source, target):
# run other workload here
sync_cmd = "s3 {source} {target}".format(source=source, target=target)
sync_process = subprocess.Popen(sync_cmd, shell=True)
sync_process.wait()
# __sync_end__
if not MOCK:
# __docker_start__
from ray import tune
from ray.tune.integration.docker import DockerSyncer
sync_config = tune.SyncConfig(syncer=DockerSyncer)
tune.run(train, sync_config=sync_config)
# __docker_end__
# __s3_start__
from ray import tune
@ -264,13 +257,6 @@ if not MOCK:
)
# __sync_config_end__
# __k8s_start__
from ray.tune.integration.kubernetes import NamespacedKubernetesSyncer
sync_config = tune.SyncConfig(syncer=NamespacedKubernetesSyncer("ray"))
tune.run(train, sync_config=sync_config)
# __k8s_end__
import ray


@ -577,7 +577,9 @@ How can I upload my Tune results to cloud storage?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
If an upload directory is provided, Tune will automatically sync results from the ``local_dir`` to the given directory,
natively supporting standard URIs for systems like S3, gsutil or HDFS.
natively supporting standard URIs for systems like S3, gsutil or HDFS. You can add more filesystems by installing
`fsspec <https://filesystem-spec.readthedocs.io/en/latest/>`_-compatible filesystems, e.g. using pip.
Here is an example of uploading to S3, using a bucket called ``my-log-dir``:
.. literalinclude:: doc_code/faq.py
@ -586,8 +588,7 @@ Here is an example of uploading to S3, using a bucket called ``my-log-dir``:
:start-after: __log_1_start__
:end-before: __log_1_end__
You can customize this to specify arbitrary storages with the ``syncer`` argument in ``tune.SyncConfig``.
This argument supports either strings with the same replacement fields OR arbitrary functions.
You can customize synchronization behavior by implementing your own Syncer:
.. literalinclude:: doc_code/faq.py
:dedent:
@ -595,14 +596,6 @@ This argument supports either strings with the same replacement fields OR arbitr
:start-after: __log_2_start__
:end-before: __log_2_end__
If a string is provided, then it must include replacement fields ``{source}`` and ``{target}``, like
``s3 sync {source} {target}``. Alternatively, a function can be provided with the following signature:
.. literalinclude:: doc_code/faq.py
:language: python
:start-after: __sync_start__
:end-before: __sync_end__
By default, syncing occurs every 300 seconds.
To change the frequency of syncing, set the ``sync_period`` attribute of the sync config to the desired syncing period.
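For example, to sync at most once per minute (a minimal sketch; the bucket name is illustrative):

.. code-block:: python

    from ray import tune

    sync_config = tune.SyncConfig(
        upload_dir="s3://my-log-dir",
        sync_period=60,  # seconds between syncs
    )
    tune.run(MyTrainableClass, sync_config=sync_config)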
@ -623,23 +616,13 @@ How can I use Tune with Docker?
Tune automatically syncs files and checkpoints between different remote
containers as needed.
To make this work in your Docker cluster, e.g. when you are using the Ray autoscaler
with docker containers, you will need to pass a
``DockerSyncer`` to the ``syncer`` argument of ``tune.SyncConfig``.
.. literalinclude:: doc_code/faq.py
:dedent:
:language: python
:start-after: __docker_start__
:end-before: __docker_end__
.. _tune-kubernetes:
How can I use Tune with Kubernetes?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Ray Tune automatically synchronizes files and checkpoints between different remote nodes as needed.
This usually happens via SSH, but this can be a :ref:`performance bottleneck <tune-bottlenecks>`,
This usually happens via the Ray object store, but this can be a :ref:`performance bottleneck <tune-bottlenecks>`,
especially when running many trials in parallel.
Instead you should use shared storage for checkpoints so that no additional synchronization across nodes
@ -662,19 +645,9 @@ Second, you can set up a shared file system like NFS. If you do this, disable au
:start-after: __sync_config_start__
:end-before: __sync_config_end__
Lastly, if you still want to use SSH for trial synchronization, but are not running
on the Ray cluster launcher, you might need to pass a
``KubernetesSyncer`` to the ``syncer`` argument of ``tune.SyncConfig``.
You have to specify your Kubernetes namespace explicitly:
.. literalinclude:: doc_code/faq.py
:dedent:
:language: python
:start-after: __k8s_start__
:end-before: __k8s_end__
Please note that we strongly encourage you to use one of the other two options instead, as they will
result in less overhead and don't require pods to SSH into each other.
Please note that we strongly encourage you to use one of these two options, as they will
result in less overhead and provide naturally durable checkpoint storage.
.. _tune-default-search-space:


@ -34,13 +34,6 @@ except (ImportError, ModuleNotFoundError):
from ray import logger
# We keep these constants for legacy compatibility with Tune's sync client
# After Tune fully moved to using pyarrow.fs we can remove these.
S3_PREFIX = "s3://"
GS_PREFIX = "gs://"
HDFS_PREFIX = "hdfs://"
ALLOWED_REMOTE_PREFIXES = (S3_PREFIX, GS_PREFIX, HDFS_PREFIX)
def _assert_pyarrow_installed():
if pyarrow is None:
@ -77,11 +70,7 @@ def is_non_local_path_uri(uri: str) -> bool:
if bool(get_fs_and_path(uri)[0]):
return True
# Keep manual check for prefixes for backwards compatibility with the
# TrialCheckpoint class. Remove once fully deprecated.
# Deprecated: Remove in Ray > 1.13
if any(uri.startswith(p) for p in ALLOWED_REMOTE_PREFIXES):
return True
return False
@ -207,8 +196,8 @@ def _upload_to_uri_with_exclude(
if _should_exclude(candidate):
continue
full_source_path = os.path.join(local_path, candidate)
full_target_path = os.path.join(bucket_path, candidate)
full_source_path = os.path.normpath(os.path.join(local_path, candidate))
full_target_path = os.path.normpath(os.path.join(bucket_path, candidate))
pyarrow.fs.copy_files(
full_source_path, full_target_path, destination_filesystem=fs


@ -17,6 +17,7 @@ from ray.tune import Trainable, PlacementGroupFactory
from ray.tune.logger import Logger
from ray.tune.registry import get_trainable_cls
from ray.tune.resources import Resources
from ray.tune.syncer import Syncer
from ray.util.annotations import PublicAPI
from ray.util.ml_utils.dict import merge_dicts
@ -204,7 +205,7 @@ class RLTrainer(BaseTrainer):
env: Optional[Union[str, EnvType]] = None,
logger_creator: Optional[Callable[[], Logger]] = None,
remote_checkpoint_dir: Optional[str] = None,
sync_function_tpl: Optional[str] = None,
custom_syncer: Optional[Syncer] = None,
):
resolved_config = merge_dicts(base_config, config or {})
param_dict["config"] = resolved_config
@ -217,7 +218,7 @@ class RLTrainer(BaseTrainer):
env=env,
logger_creator=logger_creator,
remote_checkpoint_dir=remote_checkpoint_dir,
sync_function_tpl=sync_function_tpl,
custom_syncer=custom_syncer,
)
def save_checkpoint(self, checkpoint_dir: str):


@ -136,22 +136,6 @@ py_test(
tags = ["team:ml", "exclusive", "tests_dir_I"],
)
py_test(
name = "test_integration_docker",
size = "small",
srcs = ["tests/test_integration_docker.py"],
deps = [":tune_lib"],
tags = ["team:ml", "exclusive", "tests_dir_I"],
)
py_test(
name = "test_integration_kubernetes",
size = "small",
srcs = ["tests/test_integration_kubernetes.py"],
deps = [":tune_lib"],
tags = ["team:ml", "exclusive", "tests_dir_I"],
)
py_test(
name = "test_integration_pytorch_lightning",
size = "small",
@ -281,9 +265,25 @@ py_test(
)
py_test(
name = "test_sync",
name = "test_util_file_transfer",
size = "medium",
srcs = ["tests/test_sync.py"],
srcs = ["tests/test_util_file_transfer.py"],
deps = [":tune_lib"],
tags = ["team:ml", "exclusive", "tests_dir_S"],
)
py_test(
name = "test_syncer",
size = "medium",
srcs = ["tests/test_syncer.py"],
deps = [":tune_lib"],
tags = ["team:ml", "exclusive", "tests_dir_S"],
)
py_test(
name = "test_syncer_callback",
size = "medium",
srcs = ["tests/test_syncer_callback.py"],
deps = [":tune_lib"],
tags = ["team:ml", "exclusive", "tests_dir_S"],
)


@ -9,7 +9,7 @@ from six import string_types
from ray.tune import TuneError
from ray.tune.trial import Trial
from ray.tune.resources import json_to_resources
from ray.tune.syncer import SyncConfig
from ray.tune.syncer import SyncConfig, Syncer
from ray.tune.utils.placement_groups import PlacementGroupFactory
from ray.tune.utils.util import SafeFallbackEncoder
@ -211,16 +211,20 @@ def create_trial_from_spec(
remote_checkpoint_dir = spec.get("remote_checkpoint_dir")
sync_config = spec.get("sync_config", SyncConfig())
if sync_config.syncer is None or isinstance(sync_config.syncer, str):
sync_function_tpl = sync_config.syncer
elif not isinstance(sync_config.syncer, str):
# If a syncer was specified, but not a template, it is a function.
# Functions cannot be used for trial checkpointing on remote nodes,
# so we set the remote checkpoint dir to None to disable this.
sync_function_tpl = None
remote_checkpoint_dir = None
if (
sync_config.syncer is None
or sync_config.syncer == "auto"
or isinstance(sync_config.syncer, Syncer)
):
custom_syncer = sync_config.syncer
else:
sync_function_tpl = None # Auto-detect
raise ValueError(
f"Unknown syncer type passed in SyncConfig: {type(sync_config.syncer)}. "
f"Note that custom sync functions and templates have been deprecated. "
f"Instead you can implement you own `Syncer` class. "
f"Please leave a comment on GitHub if you run into any issues with this: "
f"https://github.com/ray-project/ray/issues"
)
return Trial(
# Submitting trial via server in py2.7 creates Unicode, which does not
@ -232,7 +236,7 @@ def create_trial_from_spec(
# json.load leads to str -> unicode in py2.7
stopping_criterion=spec.get("stop", {}),
remote_checkpoint_dir=remote_checkpoint_dir,
sync_function_tpl=sync_function_tpl,
custom_syncer=custom_syncer,
checkpoint_freq=args.checkpoint_freq,
checkpoint_at_end=args.checkpoint_at_end,
sync_on_checkpoint=sync_config.sync_on_checkpoint,
@ -246,5 +250,5 @@ def create_trial_from_spec(
log_to_file=spec.get("log_to_file"),
# str(None) doesn't create None
max_failures=args.max_failures,
**trial_kwargs
**trial_kwargs,
)


@ -1,154 +1,21 @@
import logging
import os
from typing import Optional, Tuple, List
from ray.autoscaler.sdk import rsync, configure_logging
from ray.util import get_node_ip_address
from ray.util.debug import log_once
from ray.tune.syncer import NodeSyncer
from ray.tune.sync_client import SyncClient
from ray.ray_constants import env_integer
logger = logging.getLogger(__name__)
from ray.util.annotations import Deprecated
class DockerSyncer(NodeSyncer):
"""DockerSyncer used for synchronization between Docker containers.
This syncer extends the node syncer, but is usually instantiated
without a custom sync client. The sync client defaults to
``DockerSyncClient`` instead.
Set the env var `TUNE_SYNCER_VERBOSITY` to increase verbosity
of syncing operations (0, 1, 2, 3). Defaults to 0.
.. note::
This syncer only works with the Ray cluster launcher.
If you use your own Docker setup, make sure the nodes can connect
to each other via SSH, and try the regular SSH-based syncer instead.
Example:
.. code-block:: python
from ray.tune.integration.docker import DockerSyncer
tune.run(train,
sync_config=tune.SyncConfig(
syncer=DockerSyncer))
"""
_cluster_config_file = os.path.expanduser("~/ray_bootstrap_config.yaml")
def __init__(
self, local_dir: str, remote_dir: str, sync_client: Optional[SyncClient] = None
):
configure_logging(
log_style="record", verbosity=env_integer("TUNE_SYNCER_VERBOSITY", 0)
@Deprecated
class DockerSyncer:
def __init__(self, *args, **kwargs):
raise DeprecationWarning(
"DockerSyncer has been fully deprecated. There is no need to "
"use this syncer anymore - data syncing will happen automatically "
"using the Ray object store. You can just remove passing this class."
)
self.local_ip = get_node_ip_address()
self.worker_ip = None
sync_client = sync_client or DockerSyncClient()
sync_client.configure(self._cluster_config_file)
super(NodeSyncer, self).__init__(local_dir, remote_dir, sync_client)
def set_worker_ip(self, worker_ip: str):
self.worker_ip = worker_ip
@property
def _remote_path(self) -> Tuple[str, str]:
return (self.worker_ip, self._remote_dir)
class DockerSyncClient(SyncClient):
"""DockerSyncClient to be used by DockerSyncer.
This client takes care of executing the synchronization
commands for Docker nodes. In its ``sync_down`` and
``sync_up`` commands, it expects tuples for the source
and target, respectively, for compatibility with docker.
Args:
should_bootstrap: Whether to bootstrap the autoscaler
cofiguration. This may be useful when you are
running into authentication problems; i.e.:
https://github.com/ray-project/ray/issues/17756.
"""
def __init__(self, should_bootstrap: bool = True):
self._command_runners = {}
self._cluster_config = None
if os.environ.get("TUNE_SYNC_DISABLE_BOOTSTRAP") == "1":
should_bootstrap = False
logger.debug("Skipping bootstrap for docker sync client.")
self._should_bootstrap = should_bootstrap
def configure(self, cluster_config_file: str):
self._cluster_config_file = cluster_config_file
def sync_up(
self, source: str, target: Tuple[str, str], exclude: Optional[List] = None
) -> bool:
"""Here target is a tuple (target_node, target_dir)"""
target_node, target_dir = target
# Add trailing slashes for rsync
source = os.path.join(source, "")
target_dir = os.path.join(target_dir, "")
import click
try:
rsync(
cluster_config=self._cluster_config_file,
source=source,
target=target_dir,
down=False,
ip_address=target_node,
should_bootstrap=self._should_bootstrap,
use_internal_ip=True,
@Deprecated
class DockerSyncClient:
def __init__(self, *args, **kwargs):
raise DeprecationWarning(
"DockerSyncClient has been fully deprecated. There is no need to "
"use this syncer anymore - data syncing will happen automatically "
"using the Ray object store. You can just remove passing this class."
)
except click.ClickException:
if log_once("docker_rsync_up_fail"):
logger.warning(
"Rsync-up failed. Consider using a durable trainable "
"or setting the `TUNE_SYNC_DISABLE_BOOTSTRAP=1` env var."
)
raise
return True
def sync_down(
self, source: Tuple[str, str], target: str, exclude: Optional[List] = None
) -> bool:
"""Here source is a tuple (source_node, source_dir)"""
source_node, source_dir = source
# Add trailing slashes for rsync
source_dir = os.path.join(source_dir, "")
target = os.path.join(target, "")
import click
try:
rsync(
cluster_config=self._cluster_config_file,
source=source_dir,
target=target,
down=True,
ip_address=source_node,
should_bootstrap=self._should_bootstrap,
use_internal_ip=True,
)
except click.ClickException:
if log_once("docker_rsync_down_fail"):
logger.warning(
"Rsync-down failed. Consider using a durable trainable "
"or setting the `TUNE_SYNC_DISABLE_BOOTSTRAP=1` env var."
)
raise
return True
def delete(self, target: str) -> bool:
raise NotImplementedError


@ -1,176 +1,30 @@
import os
from typing import Any, Optional, Tuple, List
import subprocess
from ray import logger
from ray.autoscaler._private.command_runner import KubernetesCommandRunner
from ray.tune.syncer import NodeSyncer
from ray.tune.sync_client import SyncClient
from ray.util import get_node_ip_address
from ray.util.annotations import Deprecated
def try_import_kubernetes():
try:
import kubernetes
except ImportError:
kubernetes = None
return kubernetes
@Deprecated
class KubernetesSyncer:
def __init__(self, *args, **kwargs):
raise DeprecationWarning(
"KubernetesSyncer has been fully deprecated. There is no need to "
"use this syncer anymore - data syncing will happen automatically "
"using the Ray object store. You can just remove passing this class."
)
kubernetes = try_import_kubernetes()
@Deprecated
class KubernetesSyncClient:
def __init__(self, *args, **kwargs):
raise DeprecationWarning(
"KubernetesSyncClient has been fully deprecated. There is no need to "
"use this syncer anymore - data syncing will happen automatically "
"using the Ray object store. You can just remove passing this class."
)
def NamespacedKubernetesSyncer(namespace: str):
"""Wrapper to return a ``KubernetesSyncer`` for a Kubernetes namespace.
Args:
namespace: Kubernetes namespace.
Returns:
A ``KubernetesSyncer`` class to be passed to ``tune.run()``.
Example:
.. code-block:: python
from ray.tune.integration.kubernetes import NamespacedKubernetesSyncer
tune.run(train,
sync_config=tune.SyncConfig(
syncer=NamespacedKubernetesSyncer("ray")))
"""
class _NamespacedKubernetesSyncer(KubernetesSyncer):
_namespace = namespace
return _NamespacedKubernetesSyncer
class KubernetesSyncer(NodeSyncer):
"""KubernetesSyncer used for synchronization between Kubernetes pods.
This syncer extends the node syncer, but is usually instantiated
without a custom sync client. The sync client defaults to
``KubernetesSyncClient`` instead.
KubernetesSyncer uses the default namespace ``ray``. You should
probably use ``NamespacedKubernetesSyncer`` to return a class
with a custom namespace instead.
"""
_namespace = "ray"
def __init__(
self, local_dir: str, remote_dir: str, sync_client: Optional[SyncClient] = None
):
if not kubernetes:
raise ImportError(
"kubernetes is not installed on this machine/container. "
"Try: pip install kubernetes"
raise DeprecationWarning(
"NamespacedKubernetesSyncer has been fully deprecated. There is no need to "
"use this syncer anymore - data syncing will happen automatically "
"using the Ray object store. You can just remove passing this class."
)
self.local_ip = get_node_ip_address()
self.local_node = self._get_kubernetes_node_by_ip(self.local_ip)
self.worker_ip = None
self.worker_node = None
sync_client = sync_client or KubernetesSyncClient(
namespace=self.__class__._namespace
)
super(NodeSyncer, self).__init__(local_dir, remote_dir, sync_client)
def set_worker_ip(self, worker_ip: str):
self.worker_ip = worker_ip
self.worker_node = self._get_kubernetes_node_by_ip(worker_ip)
def _get_kubernetes_node_by_ip(self, node_ip: str) -> Optional[str]:
"""Return node name by internal or external IP"""
kubernetes.config.load_incluster_config()
api = kubernetes.client.CoreV1Api()
pods = api.list_namespaced_pod(self._namespace)
for pod in pods.items:
if pod.status.host_ip == node_ip or pod.status.pod_ip == node_ip:
return pod.metadata.name
logger.error("Could not find Kubernetes pod name for IP {}".format(node_ip))
return None
@property
def _remote_path(self) -> Tuple[str, str]:
return self.worker_node, self._remote_dir
class KubernetesSyncClient(SyncClient):
"""KubernetesSyncClient to be used by KubernetesSyncer.
This client takes care of executing the synchronization
commands for Kubernetes clients. In its ``sync_down`` and
``sync_up`` commands, it expects tuples for the source
and target, respectively, for compatibility with the
KubernetesCommandRunner.
Args:
namespace: Namespace in which the pods live.
process_runner: How commands should be called.
Defaults to ``subprocess``.
"""
def __init__(self, namespace: str, process_runner: Any = subprocess):
self.namespace = namespace
self._process_runner = process_runner
self._command_runners = {}
def _create_command_runner(self, node_id: str) -> KubernetesCommandRunner:
"""Create a command runner for one Kubernetes node"""
return KubernetesCommandRunner(
log_prefix="KubernetesSyncClient: {}:".format(node_id),
namespace=self.namespace,
node_id=node_id,
auth_config=None,
process_runner=self._process_runner,
)
def _get_command_runner(self, node_id: str) -> KubernetesCommandRunner:
"""Create command runner if it doesn't exist"""
# Todo(krfricke): These cached runners are currently
# never cleaned up. They are cheap so this shouldn't
# cause much problems, but should be addressed if
# the SyncClient is used more extensively in the future.
if node_id not in self._command_runners:
command_runner = self._create_command_runner(node_id)
self._command_runners[node_id] = command_runner
return self._command_runners[node_id]
def sync_up(
self, source: str, target: Tuple[str, str], exclude: Optional[List] = None
) -> bool:
"""Here target is a tuple (target_node, target_dir)"""
target_node, target_dir = target
# Add trailing slashes for rsync
source = os.path.join(source, "")
target_dir = os.path.join(target_dir, "")
command_runner = self._get_command_runner(target_node)
command_runner.run_rsync_up(source, target_dir)
return True
def sync_down(
self, source: Tuple[str, str], target: str, exclude: Optional[List] = None
) -> bool:
"""Here source is a tuple (source_node, source_dir)"""
source_node, source_dir = source
# Add trailing slashes for rsync
source_dir = os.path.join(source_dir, "")
target = os.path.join(target, "")
command_runner = self._get_command_runner(source_node)
command_runner.run_rsync_down(source_dir, target)
return True
def delete(self, target: str) -> bool:
"""No delete function because it is only used by
the KubernetesSyncer, which doesn't call delete."""
return True


@ -379,7 +379,7 @@ class RayTrialExecutor:
# We keep these kwargs separate for backwards compatibility
# with trainables that don't provide these keyword arguments
kwargs["remote_checkpoint_dir"] = trial.remote_checkpoint_dir
kwargs["sync_function_tpl"] = trial.sync_function_tpl
kwargs["custom_syncer"] = trial.custom_syncer
# Throw a meaningful error if trainable does not use the
# new API
@ -389,7 +389,7 @@ class RayTrialExecutor:
except Exception as e:
raise RuntimeError(
"Your trainable class does not accept a "
"`remote_checkpoint_dir` or `sync_function_tpl` argument "
"`remote_checkpoint_dir` or `custom_syncer` argument "
"in its constructor, but you've passed a "
"`upload_dir` to your SyncConfig. Without accepting "
"these parameters and passing them to the base trainable "
@ -775,27 +775,31 @@ class RayTrialExecutor:
raise RuntimeError(
"Trial {}: Unable to restore - no runner found.".format(trial)
)
value = checkpoint.dir_or_data
checkpoint_dir = checkpoint.dir_or_data
node_ip = checkpoint.node_ip
if checkpoint.storage_mode == CheckpointStorage.MEMORY:
logger.debug("Trial %s: Attempting restore from object", trial)
# Note that we don't store the remote since in-memory checkpoints
# don't guarantee fault tolerance and don't need to be waited on.
with self._change_working_directory(trial):
trial.runner.restore_from_object.remote(value)
trial.runner.restore_from_object.remote(checkpoint_dir)
else:
logger.debug("Trial %s: Attempting restore from %s", trial, value)
if trial.uses_cloud_checkpointing or not trial.sync_on_checkpoint:
logger.debug("Trial %s: Attempting restore from %s", trial, checkpoint_dir)
if (
trial.uses_cloud_checkpointing
or not trial.sync_on_checkpoint
or not os.path.exists(checkpoint_dir)
):
# If using cloud checkpointing, trial will get cp from cloud.
# If not syncing to driver, assume it has access to the cp
# on the local fs.
with self._change_working_directory(trial):
remote = trial.runner.restore.remote(value, node_ip)
remote = trial.runner.restore.remote(checkpoint_dir, node_ip)
elif trial.sync_on_checkpoint:
# This provides FT backwards compatibility in the
# case where no cloud checkpoints are provided.
logger.debug("Trial %s: Reading checkpoint into memory", trial)
obj = TrainableUtil.checkpoint_to_object(value)
obj = TrainableUtil.checkpoint_to_object(checkpoint_dir)
with self._change_working_directory(trial):
remote = trial.runner.restore_from_object.remote(obj)
else:


@ -1,606 +1,44 @@
import abc
import distutils
import distutils.spawn
import inspect
import logging
import pathlib
import subprocess
import tempfile
import time
import types
from typing import Optional, List, Callable, Union, Tuple
from shlex import quote
import ray
from ray.tune.error import TuneError
from ray.tune.utils.file_transfer import sync_dir_between_nodes, delete_on_node
from ray.util.annotations import PublicAPI, DeveloperAPI
from ray.air._internal.remote_storage import (
S3_PREFIX,
GS_PREFIX,
HDFS_PREFIX,
ALLOWED_REMOTE_PREFIXES,
)
from ray.util.annotations import Deprecated
logger = logging.getLogger(__name__)
noop_template = ": {target}" # noop in bash
def noop(*args):
return
def get_sync_client(
sync_function: Optional[Union[str, Callable]],
delete_function: Optional[Union[str, Callable]] = None,
) -> Optional["SyncClient"]:
"""Returns a sync client.
Args:
sync_function: Sync function.
delete_function: Delete function. Must be
the same type as sync_function if it is provided.
Raises:
ValueError if sync_function or delete_function are malformed.
"""
if sync_function is None:
return None
if delete_function and type(sync_function) != type(delete_function):
raise ValueError("Sync and delete functions must be of same type.")
if isinstance(sync_function, types.FunctionType):
delete_function = delete_function or noop
client_cls = FunctionBasedClient
elif isinstance(sync_function, str):
delete_function = delete_function or noop_template
client_cls = CommandBasedClient
else:
raise ValueError(
"Sync function {} must be string or function".format(sync_function)
)
return client_cls(sync_function, sync_function, delete_function)
def get_cloud_sync_client(remote_path: str) -> "CommandBasedClient":
"""Returns a CommandBasedClient that can sync to/from remote storage.
Args:
remote_path: Path to remote storage (S3, GS or HDFS).
Raises:
ValueError if malformed remote_dir.
"""
if remote_path.startswith(S3_PREFIX):
if not distutils.spawn.find_executable("aws"):
raise ValueError(
"Upload uri starting with '{}' requires awscli tool"
" to be installed".format(S3_PREFIX)
)
sync_up_template = (
"aws s3 sync {source} {target} "
"--exact-timestamps --only-show-errors {options}"
)
sync_down_template = sync_up_template
delete_template = "aws s3 rm {target} --recursive --only-show-errors {options}"
exclude_template = "--exclude '{pattern}'"
elif remote_path.startswith(GS_PREFIX):
if not distutils.spawn.find_executable("gsutil"):
raise ValueError(
"Upload uri starting with '{}' requires gsutil tool"
" to be installed".format(GS_PREFIX)
)
sync_up_template = "gsutil rsync -r {options} {source} {target}"
sync_down_template = sync_up_template
delete_template = "gsutil rm -r {options} {target}"
exclude_template = "-x '{regex_pattern}'"
elif remote_path.startswith(HDFS_PREFIX):
if not distutils.spawn.find_executable("hdfs"):
raise ValueError(
"Upload uri starting with '{}' requires hdfs tool"
" to be installed".format(HDFS_PREFIX)
)
sync_up_template = "hdfs dfs -put -f {source} {target}"
sync_down_template = "hdfs dfs -get -f {source} {target}"
delete_template = "hdfs dfs -rm -r {target}"
exclude_template = None
else:
raise ValueError(
f"Upload uri must start with one of: {ALLOWED_REMOTE_PREFIXES} "
f"(is: `{remote_path}`)"
)
return CommandBasedClient(
sync_up_template, sync_down_template, delete_template, exclude_template
)
@PublicAPI(stability="beta")
@Deprecated
class SyncClient(abc.ABC):
"""Client interface for interacting with remote storage options."""
def sync_up(self, source: str, target: str, exclude: Optional[List] = None):
"""Syncs up from source to target.
Args:
source: Source path.
target: Target path.
exclude: Pattern of files to exclude, e.g.
``["*/checkpoint_*]`` to exclude trial checkpoints.
Returns:
True if sync initiation successful, False otherwise.
"""
raise NotImplementedError
def sync_down(self, source: str, target: str, exclude: Optional[List] = None):
"""Syncs down from source to target.
Args:
source: Source path.
target: Target path.
exclude: Pattern of files to exclude, e.g.
``["*/checkpoint_*]`` to exclude trial checkpoints.
Returns:
True if sync initiation successful, False otherwise.
"""
raise NotImplementedError
def delete(self, target: str):
"""Deletes target.
Args:
target: Target path.
Returns:
True if delete initiation successful, False otherwise.
"""
raise NotImplementedError
def wait(self):
"""Waits for current sync to complete, if asynchronously started."""
pass
def wait_or_retry(self, max_retries: int = 3, backoff_s: int = 5):
"""Wait for current sync to complete or retries on error."""
pass
def reset(self):
"""Resets state."""
pass
def close(self):
"""Clean up hook."""
pass
def _is_legacy_sync_fn(func) -> bool:
sig = inspect.signature(func)
try:
sig.bind_partial(None, None, None)
return False
except TypeError:
return True
@DeveloperAPI
class FunctionBasedClient(SyncClient):
def __init__(self, sync_up_func, sync_down_func, delete_func=None):
self.sync_up_func = sync_up_func
self._sync_up_legacy = _is_legacy_sync_fn(sync_up_func)
self.sync_down_func = sync_down_func
self._sync_down_legacy = _is_legacy_sync_fn(sync_up_func)
if self._sync_up_legacy or self._sync_down_legacy:
def __init__(self, *args, **kwargs):
raise DeprecationWarning(
"Your sync functions currently only accepts two params "
"(a `source` and a `target`). In the future, we will "
"pass an additional `exclude` parameter. Please adjust "
"your sync function accordingly."
"SyncClient has been deprecated. Please implement a "
"`ray.tune.syncer.Syncer` instead."
)
self.delete_func = delete_func or noop
def sync_up(self, source, target, exclude: Optional[List] = None):
if self._sync_up_legacy:
self.sync_up_func(source, target)
else:
self.sync_up_func(source, target, exclude)
return True
def sync_down(self, source, target, exclude: Optional[List] = None):
if self._sync_down_legacy:
self.sync_down_func(source, target)
else:
self.sync_down_func(source, target, exclude)
return True
def delete(self, target):
self.delete_func(target)
return True
@Deprecated
class FunctionBasedClient(SyncClient):
def __init__(self, *args, **kwargs):
raise DeprecationWarning(
"FunctionBasedClient has been deprecated. Please implement a "
"`ray.tune.syncer.Syncer` instead."
)
NOOP = FunctionBasedClient(noop, noop)
@DeveloperAPI
@Deprecated
class CommandBasedClient(SyncClient):
"""Syncs between two directories with the given command.
If a sync is already in-flight when calling ``sync_down`` or
``sync_up``, a warning will be printed and the new sync command is
ignored. To force a new sync, either use ``wait()``
(or ``wait_or_retry()``) to wait until the previous sync has finished,
or call ``reset()`` to detach from the previous sync. Note that this
will not kill the previous sync command, so it may still be executed.
Arguments:
sync_up_template: A runnable string template; needs to
include replacement fields ``{source}``, ``{target}``, and
``{options}``.
sync_down_template: A runnable string template; needs to
include replacement fields ``{source}``, ``{target}``, and
``{options}``.
delete_template: A runnable string template; needs
to include replacement field ``{target}``. Noop by default.
exclude_template: A pattern with possible
replacement fields ``{pattern}`` and ``{regex_pattern}``.
Will replace ``{options}}`` in the sync up/down templates
if files/directories to exclude are passed.
"""
def __init__(
self,
sync_up_template: str,
sync_down_template: str,
delete_template: Optional[str] = noop_template,
exclude_template: Optional[str] = None,
):
self._validate_sync_string(sync_up_template)
self._validate_sync_string(sync_down_template)
self._validate_exclude_template(exclude_template)
self.sync_up_template = sync_up_template
self.sync_down_template = sync_down_template
self.delete_template = delete_template
self.exclude_template = exclude_template
self.logfile = None
self._closed = False
self.cmd_process = None
# Keep track of last command for retry
self._last_cmd = None
def set_logdir(self, logdir: str):
"""Sets the directory to log sync execution output in.
Args:
logdir: Log directory.
"""
self.logfile = tempfile.NamedTemporaryFile(
prefix="log_sync_out", dir=logdir, suffix=".log", delete=False
)
self._closed = False
def _get_logfile(self):
if self._closed:
raise RuntimeError(
"[internalerror] The client has been closed. "
"Please report this stacktrace + your cluster configuration "
"on Github!"
)
else:
return self.logfile
def _start_process(self, cmd: str) -> subprocess.Popen:
return subprocess.Popen(
cmd, shell=True, stderr=subprocess.PIPE, stdout=self._get_logfile()
)
def sync_up(self, source, target, exclude: Optional[List] = None):
return self._execute(self.sync_up_template, source, target, exclude)
def sync_down(self, source, target, exclude: Optional[List] = None):
# Just in case some command line sync client expects that local
# directory exists.
pathlib.Path(target).mkdir(parents=True, exist_ok=True)
return self._execute(self.sync_down_template, source, target, exclude)
def delete(self, target):
if self.is_running:
logger.warning(
f"Last sync client cmd still in progress, "
f"skipping deletion of {target}"
)
return False
final_cmd = self.delete_template.format(target=quote(target), options="")
logger.debug("Running delete: {}".format(final_cmd))
self._last_cmd = final_cmd
self.cmd_process = self._start_process(final_cmd)
return True
def wait(self):
if self.cmd_process:
_, error_msg = self.cmd_process.communicate()
error_msg = error_msg.decode("ascii")
code = self.cmd_process.returncode
args = self.cmd_process.args
self.cmd_process = None
if code != 0:
raise TuneError(
"Sync error. Ran command: {}\n"
"Error message ({}): {}".format(args, code, error_msg)
)
def wait_or_retry(self, max_retries: int = 3, backoff_s: int = 5):
assert max_retries > 0
for _ in range(max_retries - 1):
try:
self.wait()
except TuneError as e:
logger.error(
f"Caught sync error: {e}. "
f"Retrying after sleeping for {backoff_s} seconds..."
)
time.sleep(backoff_s)
self.cmd_process = self._start_process(self._last_cmd)
continue
return
self.cmd_process = None
raise TuneError(f"Failed sync even after {max_retries} retries.")
def reset(self):
if self.is_running:
logger.warning("Sync process still running but resetting anyways.")
self.cmd_process = None
self._last_cmd = None
def close(self):
if self.logfile:
logger.debug(f"Closing the logfile: {str(self.logfile)}")
self.logfile.close()
self.logfile = None
self._closed = True
@property
def is_running(self):
"""Returns whether a sync or delete process is running."""
if self.cmd_process:
self.cmd_process.poll()
return self.cmd_process.returncode is None
return False
def _execute(self, sync_template, source, target, exclude: Optional[List] = None):
"""Executes sync_template on source and target."""
if self.is_running:
logger.warning(
f"Last sync client cmd still in progress, "
f"skipping sync from {source} to {target}."
)
return False
if exclude and self.exclude_template:
options = []
if "{pattern}" in self.exclude_template:
for excl in exclude:
options.append(self.exclude_template.format(pattern=excl))
elif "{regex_pattern}" in self.exclude_template:
# This is obviously not a great way to convert to regex,
# but it will do for the moment. Todo: Improve.
def _to_regex(pattern: str) -> str:
return f"({pattern.replace('*', '.*')})"
regex_pattern = "|".join(_to_regex(excl) for excl in exclude)
options.append(
self.exclude_template.format(regex_pattern=regex_pattern)
)
option_str = " ".join(options)
else:
option_str = ""
final_cmd = sync_template.format(
source=quote(source), target=quote(target), options=option_str
)
logger.debug("Running sync: {}".format(final_cmd))
self._last_cmd = final_cmd
self.cmd_process = self._start_process(final_cmd)
return True
@staticmethod
def _validate_sync_string(sync_string):
if not isinstance(sync_string, str):
raise ValueError("{} is not a string.".format(sync_string))
if "{source}" not in sync_string:
raise ValueError("Sync template missing `{source}`: " f"{sync_string}.")
if "{target}" not in sync_string:
raise ValueError("Sync template missing `{target}`: " f"{sync_string}.")
@staticmethod
def _validate_exclude_template(exclude_template):
if exclude_template:
if (
"{pattern}" not in exclude_template
and "{regex_pattern}" not in exclude_template
):
raise ValueError(
"Neither `{pattern}` nor `{regex_pattern}` found in "
f"exclude string `{exclude_template}`"
def __init__(self, *args, **kwargs):
raise DeprecationWarning(
"CommandBasedClient has been deprecated. Please implement a "
"`ray.tune.syncer.Syncer` instead."
)
@DeveloperAPI
@Deprecated
class RemoteTaskClient(SyncClient):
"""Sync client that uses remote tasks to synchronize two directories.
This client expects tuples of (ip, path) for remote sources/targets
in sync_down/sync_up.
To avoid unnecessary syncing, the sync client will collect the existing
files with their respective mtimes and sizes on the (possibly remote)
target directory. Only files that are not in the target directory or
differ to those in the target directory by size or mtime will be
transferred. This is similar to most cloud
synchronization implementations (e.g. aws s3 sync).
If a sync is already in-flight when calling ``sync_down`` or
``sync_up``, a warning will be printed and the new sync command is
ignored. To force a new sync, either use ``wait()``
(or ``wait_or_retry()``) to wait until the previous sync has finished,
or call ``reset()`` to detach from the previous sync. Note that this
will not kill the previous sync command, so it may still be executed.
"""
def __init__(self, _store_remotes: bool = False):
# Used for testing
self._store_remotes = _store_remotes
self._stored_pack_actor_ref = None
self._stored_files_stats_future = None
self._sync_future = None
self._last_source_tuple = None
self._last_target_tuple = None
self._max_size_bytes = None # No file size limit
def _sync_still_running(self) -> bool:
if not self._sync_future:
return False
ready, not_ready = ray.wait([self._sync_future], timeout=0.0)
if self._sync_future in ready:
self.wait()
return False
return True
def sync_down(
self, source: Tuple[str, str], target: str, exclude: Optional[List] = None
) -> bool:
if self._sync_still_running():
logger.warning(
f"Last remote task sync still in progress, "
f"skipping sync from {source} to {target}."
def __init__(self, *args, **kwargs):
raise DeprecationWarning(
"RemoteTaskClient has been deprecated. Please implement a "
"`ray.tune.syncer.Syncer` instead."
)
return False
source_ip, source_path = source
target_ip = ray.util.get_node_ip_address()
self._last_source_tuple = source_ip, source_path
self._last_target_tuple = target_ip, target
return self._execute_sync(self._last_source_tuple, self._last_target_tuple)
def sync_up(
self, source: str, target: Tuple[str, str], exclude: Optional[List] = None
) -> bool:
if self._sync_still_running():
logger.warning(
f"Last remote task sync still in progress, "
f"skipping sync from {source} to {target}."
)
return False
source_ip = ray.util.get_node_ip_address()
target_ip, target_path = target
self._last_source_tuple = source_ip, source
self._last_target_tuple = target_ip, target_path
return self._execute_sync(self._last_source_tuple, self._last_target_tuple)
def _sync_function(self, *args, **kwargs):
return sync_dir_between_nodes(*args, **kwargs)
def _execute_sync(
self,
source_tuple: Tuple[str, str],
target_tuple: Tuple[str, str],
) -> bool:
source_ip, source_path = source_tuple
target_ip, target_path = target_tuple
self._sync_future, pack_actor, files_stats = self._sync_function(
source_ip=source_ip,
source_path=source_path,
target_ip=target_ip,
target_path=target_path,
return_futures=True,
max_size_bytes=self._max_size_bytes,
)
if self._store_remotes:
self._stored_pack_actor_ref = pack_actor
self._stored_files_stats = files_stats
return True
def delete(self, target: str):
if not self._last_target_tuple:
logger.warning(
f"Could not delete path {target} as the target node is not known."
)
return
node_ip = self._last_target_tuple[0]
try:
delete_on_node(node_ip=node_ip, path=target)
except Exception as e:
logger.warning(
f"Could not delete path {target} on remote node {node_ip}: {e}"
)
def wait(self):
if self._sync_future:
try:
ray.get(self._sync_future)
except Exception as e:
raise TuneError(
f"Remote task sync failed from "
f"{self._last_source_tuple} to "
f"{self._last_target_tuple}: {e}"
) from e
self._sync_future = None
self._stored_pack_actor_ref = None
self._stored_files_stats_future = None
def wait_or_retry(self, max_retries: int = 3, backoff_s: int = 5):
assert max_retries > 0
for _ in range(max_retries - 1):
try:
self.wait()
except TuneError as e:
logger.error(
f"Caught sync error: {e}. "
f"Retrying after sleeping for {backoff_s} seconds..."
)
time.sleep(backoff_s)
self._execute_sync(
self._last_source_tuple,
self._last_target_tuple,
)
continue
return
self._sync_future = None
self._stored_pack_actor_ref = None
self._stored_files_stats_future = None
raise TuneError(f"Failed sync even after {max_retries} retries.")
def reset(self):
if self._sync_future:
logger.warning("Sync process still running but resetting anyways.")
self._sync_future = None
self._last_source_tuple = None
self._last_target_tuple = None
self._stored_pack_actor_ref = None
self._stored_files_stats_future = None
def close(self):
self._sync_future = None # Avoid warning
self.reset()


@ -1,40 +1,31 @@
import abc
import threading
from typing import (
Callable,
Dict,
List,
TYPE_CHECKING,
Type,
Union,
Optional,
Tuple,
)
import distutils
import logging
import os
import time
from dataclasses import dataclass
from inspect import isclass
from shlex import quote
import ray
import yaml
from ray.air._internal.remote_storage import get_fs_and_path, fs_hint
from ray.air._internal.remote_storage import (
fs_hint,
upload_to_uri,
download_from_uri,
delete_at_uri,
is_non_local_path_uri,
)
from ray.tune import TuneError
from ray.tune.callback import Callback
from ray.tune.result import NODE_IP
from ray.util import get_node_ip_address
from ray.util.debug import log_once
from ray.tune.cluster_info import get_ssh_key, get_ssh_user
from ray.tune.sync_client import (
CommandBasedClient,
get_sync_client,
get_cloud_sync_client,
NOOP,
SyncClient,
RemoteTaskClient,
)
from ray.tune.utils.file_transfer import sync_dir_between_nodes
from ray.util.annotations import PublicAPI, DeveloperAPI
from ray.util.ml_utils.checkpoint_manager import CheckpointStorage, _TrackedCheckpoint
@ -44,71 +35,22 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# Syncing period for syncing checkpoints between nodes or to cloud.
SYNC_PERIOD = 300
CLOUD_CHECKPOINTING_URL = (
"https://docs.ray.io/en/master/tune/user-guide.html#using-cloud-storage"
)
_log_sync_warned = False
_syncers = {}
DEFAULT_SYNC_PERIOD = 300
def wait_for_sync():
for syncer in _syncers.values():
syncer.wait()
def _validate_upload_dir(sync_config: "SyncConfig") -> bool:
if not sync_config.upload_dir:
return True
if sync_config.upload_dir.startswith("file://"):
return True
def validate_upload_dir(sync_config: "SyncConfig"):
if sync_config.upload_dir:
exc = None
try:
fs, _ = get_fs_and_path(sync_config.upload_dir)
except ImportError as e:
fs = None
exc = e
if not fs:
if not is_non_local_path_uri(sync_config.upload_dir):
raise ValueError(
f"Could not identify external storage filesystem for "
f"upload dir `{sync_config.upload_dir}`. "
f"Hint: {fs_hint(sync_config.upload_dir)}"
) from exc
def set_sync_periods(sync_config: "SyncConfig"):
"""Sets sync period from config."""
global SYNC_PERIOD
SYNC_PERIOD = int(sync_config.sync_period)
def get_rsync_template_if_available(options: str = ""):
"""Template enabling syncs between driver and worker when possible.
Requires ray cluster to be started with the autoscaler. Also requires
rsync to be installed.
Args:
options: Additional rsync options.
Returns:
Sync template with source and target parameters. None if rsync
unavailable.
"""
if not distutils.spawn.find_executable("rsync"):
if log_once("tune:rsync"):
logger.error("Log sync requires rsync to be installed.")
return None
global _log_sync_warned
ssh_key = get_ssh_key()
if ssh_key is None:
if not _log_sync_warned:
logger.debug("Log sync requires cluster to be setup with `ray up`.")
_log_sync_warned = True
return None
rsh = "ssh -i {ssh_key} -o ConnectTimeout=120s -o StrictHostKeyChecking=no"
rsh = rsh.format(ssh_key=quote(ssh_key))
options += " --exclude='checkpoint_tmp*'"
template = "rsync {options} -savz -e {rsh} {{source}} {{target}}"
return template.format(options=options, rsh=quote(rsh))
)
@PublicAPI
@ -124,13 +66,8 @@ class SyncConfig:
upload_dir: Optional URI to sync training results and checkpoints
to (e.g. ``s3://bucket``, ``gs://bucket`` or ``hdfs://path``).
Specifying this will enable cloud-based checkpointing.
syncer: Function for syncing the local_dir to and
from remote storage. If string, then it must be a string template
that includes ``{source}`` and ``{target}`` for the syncer to run.
If not provided, it defaults to rsync for non cloud-based storage,
and to standard S3, gsutil or HDFS sync commands for cloud-based
storage.
If set to ``None``, no syncing will take place.
syncer: Syncer class to use for synchronizing checkpoints to/from
cloud storage. If set to ``None``, no syncing will take place.
Defaults to ``"auto"`` (auto detect).
sync_on_checkpoint: Force sync-down of trial checkpoint to
driver (only non cloud-storage).
@ -142,395 +79,411 @@ class SyncConfig:
"""
upload_dir: Optional[str] = None
syncer: Optional[str] = "auto"
syncer: Optional[Union[str, "Syncer"]] = "auto"
sync_on_checkpoint: bool = True
sync_period: int = 300
sync_period: int = DEFAULT_SYNC_PERIOD
class _BackgroundProcess:
def __init__(self, fn: Callable):
self._fn = fn
self._process = None
self._result = {}
@property
def is_running(self):
return self._process and self._process.is_alive()
def start(self, *args, **kwargs):
if self.is_running:
return False
self._result = {}
def entrypoint():
try:
result = self._fn(*args, **kwargs)
except Exception as e:
self._result["exception"] = e
return
self._result["result"] = result
self._process = threading.Thread(target=entrypoint)
self._process.start()
def wait(self):
if not self._process:
return
self._process.join()
self._process = None
exception = self._result.get("exception")
if exception:
raise exception
result = self._result.get("result")
self._result = {}
return result
@DeveloperAPI
class Syncer:
def __init__(self, local_dir: str, remote_dir: str, sync_client: SyncClient = NOOP):
"""Syncs between two directories with the sync_function.
class Syncer(abc.ABC):
"""Syncer class for synchronizing data between Ray nodes and external storage.
This class handles data transfer for two cases:
1. Synchronizing data from the driver to external storage. This affects
experiment-level checkpoints and trial-level checkpoints if no cloud storage
is used.
2. Synchronizing data from remote trainables to external storage.
Synchronizing tasks are usually asynchronous and can be awaited using ``wait()``.
The base class implements a ``wait_or_retry()`` API that will retry a failed
sync command.
The base class also exposes an API to only kick off syncs every ``sync_period``
seconds.
Arguments:
local_dir: Directory to sync. Uniquely identifies the syncer.
remote_dir: Remote directory to sync with.
sync_client: Client for syncing between local_dir and
remote_dir. Defaults to a Noop.
"""
self._local_dir = os.path.join(local_dir, "") if local_dir else local_dir
self._remote_dir = remote_dir
def __init__(self, sync_period: float = 300.0):
self.sync_period = sync_period
self.last_sync_up_time = float("-inf")
self.last_sync_down_time = float("-inf")
self.sync_client = sync_client
@property
def _pass_ip_path_tuples(self) -> False:
"""Return True if the sync client expects (ip, path) tuples instead
of rsync strings (user@ip:/path/)."""
return isinstance(self.sync_client, RemoteTaskClient)
@abc.abstractmethod
def sync_up(
self, local_dir: str, remote_dir: str, exclude: Optional[List] = None
) -> bool:
"""Synchronize local directory to remote directory.
def sync_up_if_needed(self, sync_period: int, exclude: Optional[List] = None):
This function can spawn an asynchronous process that can be awaited in
``wait()``.
Args:
local_dir: Local directory to sync from.
remote_dir: Remote directory to sync up to. This is an URI
(``protocol://remote/path``).
exclude: Pattern of files to exclude, e.g.
``["*/checkpoint_*]`` to exclude trial checkpoints.
Returns:
True if sync process has been spawned, False otherwise.
"""
raise NotImplementedError
@abc.abstractmethod
def sync_down(
self, remote_dir: str, local_dir: str, exclude: Optional[List] = None
) -> bool:
"""Synchronize remote directory to local directory.
This function can spawn an asynchronous process that can be awaited in
``wait()``.
Args:
remote_dir: Remote directory to sync down from. This is an URI
(``protocol://remote/path``).
local_dir: Local directory to sync to.
exclude: Pattern of files to exclude, e.g.
``["*/checkpoint_*]`` to exclude trial checkpoints.
Returns:
True if sync process has been spawned, False otherwise.
"""
raise NotImplementedError
@abc.abstractmethod
def delete(self, remote_dir: str) -> bool:
"""Delete directory on remote storage.
This function can spawn an asynchronous process that can be awaited in
``wait()``.
Args:
remote_dir: Remote directory to delete. This is an URI
(``protocol://remote/path``).
Returns:
True if sync process has been spawned, False otherwise.
"""
raise NotImplementedError
def retry(self):
"""Retry the last sync up, sync down, or delete command.
You should implement this method if you spawn asynchronous syncing
processes.
"""
pass
def wait(self):
"""Wait for asynchronous sync command to finish.
You should implement this method if you spawn asynchronous syncing
processes.
"""
pass
def sync_up_if_needed(
self, local_dir: str, remote_dir: str, exclude: Optional[List] = None
) -> bool:
"""Syncs up if time since last sync up is greater than sync_period.
Args:
sync_period: Time period between subsequent syncs.
local_dir: Local directory to sync from.
remote_dir: Remote directory to sync up to. This is an URI
(``protocol://remote/path``).
exclude: Pattern of files to exclude, e.g.
``["*/checkpoint_*]`` to exclude trial checkpoints.
"""
now = time.time()
if now - self.last_sync_up_time >= self.sync_period:
result = self.sync_up(
local_dir=local_dir, remote_dir=remote_dir, exclude=exclude
)
self.last_sync_up_time = now
return result
if time.time() - self.last_sync_up_time > sync_period:
self.sync_up(exclude)
def sync_down_if_needed(self, sync_period: int, exclude: Optional[List] = None):
def sync_down_if_needed(
self, remote_dir: str, local_dir: str, exclude: Optional[List] = None
):
"""Syncs down if time since last sync down is greater than sync_period.
Args:
sync_period: Time period between subsequent syncs.
remote_dir: Remote directory to sync down from. This is an URI
(``protocol://remote/path``).
local_dir: Local directory to sync to.
exclude: Pattern of files to exclude, e.g.
``["*/checkpoint_*]`` to exclude trial checkpoints.
"""
if time.time() - self.last_sync_down_time > sync_period:
self.sync_down(exclude)
def sync_up(self, exclude: Optional[List] = None):
"""Attempts to start the sync-up to the remote path.
Args:
exclude: Pattern of files to exclude, e.g.
``["*/checkpoint_*]`` to exclude trial checkpoints.
Returns:
Whether the sync (if feasible) was successfully started.
"""
result = False
if self.validate_hosts(self._local_dir, self._remote_path):
try:
result = self.sync_client.sync_up(
self._local_dir, self._remote_path, exclude=exclude
now = time.time()
if now - self.last_sync_down_time >= self.sync_period:
result = self.sync_down(
remote_dir=remote_dir, local_dir=local_dir, exclude=exclude
)
self.last_sync_up_time = time.time()
except Exception:
logger.exception("Sync execution failed.")
self.last_sync_down_time = now
return result
def sync_down(self, exclude: Optional[List] = None):
"""Attempts to start the sync-down from the remote path.
Args:
exclude: Pattern of files to exclude, e.g.
``["*/checkpoint_*]`` to exclude trial checkpoints.
Returns:
Whether the sync (if feasible) was successfully started.
"""
result = False
if self.validate_hosts(self._local_dir, self._remote_path):
def wait_or_retry(self, max_retries: int = 3, backoff_s: int = 5):
assert max_retries > 0
last_error = None
for _ in range(max_retries - 1):
try:
result = self.sync_client.sync_down(
self._remote_path, self._local_dir, exclude=exclude
self.wait()
except Exception as e:
logger.error(
f"Caught sync error: {e}. "
f"Retrying after sleeping for {backoff_s} seconds..."
)
self.last_sync_down_time = time.time()
except Exception:
logger.exception("Sync execution failed.")
return result
def validate_hosts(self, source, target):
if not (source and target):
logger.debug(
"Source or target is empty, skipping log sync for "
"{}".format(self._local_dir)
)
return False
return True
def wait(self):
"""Waits for the sync client to complete the current sync."""
self.sync_client.wait()
last_error = e
time.sleep(backoff_s)
self.retry()
continue
return
raise TuneError(
f"Failed sync even after {max_retries} retries."
) from last_error
def reset(self):
self.last_sync_up_time = float("-inf")
self.last_sync_down_time = float("-inf")
self.sync_client.reset()
def close(self):
self.sync_client.close()
@property
def _remote_path(self) -> Optional[Union[str, Tuple[str, str]]]:
return self._remote_dir
pass
@DeveloperAPI
class CloudSyncer(Syncer):
    """Syncer for syncing files to/from the cloud."""

    def __init__(self, local_dir, remote_dir, sync_client):
        super(CloudSyncer, self).__init__(local_dir, remote_dir, sync_client)

    def sync_up_if_needed(self, exclude: Optional[List] = None):
        return super(CloudSyncer, self).sync_up_if_needed(SYNC_PERIOD, exclude=exclude)

    def sync_down_if_needed(self, exclude: Optional[List] = None):
        return super(CloudSyncer, self).sync_down_if_needed(
            SYNC_PERIOD, exclude=exclude
        )


class _DefaultSyncer(Syncer):
    """Default syncer between local storage and remote URI."""

    def __init__(self, sync_period: float = 300.0):
        super(_DefaultSyncer, self).__init__(sync_period=sync_period)
        self._sync_process = None
        self._current_cmd = None
@DeveloperAPI
class NodeSyncer(Syncer):
"""Syncer for syncing files to/from a remote dir to a local dir."""
def __init__(self, local_dir, remote_dir, sync_client):
self.local_ip = get_node_ip_address()
self.worker_ip = None
super(NodeSyncer, self).__init__(local_dir, remote_dir, sync_client)
def set_worker_ip(self, worker_ip):
"""Sets the worker IP to sync logs from."""
self.worker_ip = worker_ip
def has_remote_target(self):
"""Returns whether the Syncer has a remote target."""
if not self.worker_ip:
logger.debug("Worker IP unknown, skipping sync for %s", self._local_dir)
return False
if self.worker_ip == self.local_ip:
logger.debug("Worker IP is local IP, skipping sync for %s", self._local_dir)
return False
return True
def sync_up_if_needed(self, exclude: Optional[List] = None):
if not self.has_remote_target():
return True
return super(NodeSyncer, self).sync_up_if_needed(SYNC_PERIOD, exclude=exclude)
def sync_down_if_needed(self, exclude: Optional[List] = None):
if not self.has_remote_target():
return True
return super(NodeSyncer, self).sync_down_if_needed(SYNC_PERIOD, exclude=exclude)
def sync_up_to_new_location(self, worker_ip):
if worker_ip != self.worker_ip:
logger.debug("Setting new worker IP to %s", worker_ip)
self.set_worker_ip(worker_ip)
self.reset()
            if not self.sync_up():
                logger.warning(
                    "Sync up to new location skipped. This should not occur."
                )
        else:
            logger.warning("Sync attempted to same IP %s.", worker_ip)

    def sync_up(self, exclude: Optional[List] = None):
        if not self.has_remote_target():
            return True
        return super(NodeSyncer, self).sync_up(exclude=exclude)

    def sync_down(self, exclude: Optional[List] = None):
        if not self.has_remote_target():
            return True
        logger.debug("Syncing from %s to %s", self._remote_path, self._local_dir)
        return super(NodeSyncer, self).sync_down(exclude=exclude)

    @property
    def _remote_path(self) -> Optional[Union[str, Tuple[str, str]]]:
        ssh_user = get_ssh_user()
        global _log_sync_warned
        if not self.has_remote_target():
            return None
        if ssh_user is None:
            if not _log_sync_warned:
                logger.error("Syncer requires cluster to be setup with `ray up`.")
                _log_sync_warned = True
            return None
        if self._pass_ip_path_tuples:
            return self.worker_ip, self._remote_dir
        return "{}@{}:{}/".format(ssh_user, self.worker_ip, self._remote_dir)

    # _DefaultSyncer (continued from above):
    def sync_up(
        self, local_dir: str, remote_dir: str, exclude: Optional[List] = None
    ) -> bool:
        if self._sync_process and self._sync_process.is_running:
            logger.warning(
                f"Last sync still in progress, "
                f"skipping sync up of {local_dir} to {remote_dir}"
            )
            return False
        elif self._sync_process:
            try:
                self._sync_process.wait()
            except Exception as e:
                logger.warning(f"Last sync command failed: {e}")
        self._current_cmd = (
            upload_to_uri,
            dict(local_path=local_dir, uri=remote_dir, exclude=exclude),
        )
        self.retry()
        return True

    def sync_down(
        self, remote_dir: str, local_dir: str, exclude: Optional[List] = None
    ) -> bool:
        if self._sync_process and self._sync_process.is_running:
            logger.warning(
                f"Last sync still in progress, "
                f"skipping sync down of {remote_dir} to {local_dir}"
            )
            return False
        elif self._sync_process:
            try:
                self._sync_process.wait()
            except Exception as e:
                logger.warning(f"Last sync command failed: {e}")
        self._current_cmd = (
            download_from_uri,
            dict(uri=remote_dir, local_path=local_dir),
        )
        self.retry()
        return True

    def delete(self, remote_dir: str) -> bool:
        if self._sync_process and self._sync_process.is_running:
            logger.warning(
                f"Last sync still in progress, skipping deletion of {remote_dir}"
            )
            return False
        self._current_cmd = (delete_at_uri, dict(uri=remote_dir))
        self.retry()
        return True

    def wait(self):
        if self._sync_process:
            try:
                self._sync_process.wait()
            except Exception as e:
                raise TuneError(f"Sync process failed: {e}") from e
            finally:
                self._sync_process = None

    def retry(self):
        if not self._current_cmd:
            raise TuneError("No sync command set, cannot retry.")
        cmd, kwargs = self._current_cmd
        self._sync_process = _BackgroundProcess(cmd)
        self._sync_process.start(**kwargs)
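Since ``_DefaultSyncer`` delegates to the AIR URI transfer utilities, it can be exercised end-to-end against a local ``file://`` target, which is also how the new tests avoid mock sync clients. A minimal sketch (assumes a local Ray session; the paths are placeholders):

import os
import tempfile

import ray
from ray.tune.syncer import _DefaultSyncer

ray.init()

local_dir = tempfile.mkdtemp()
remote_dir = "file://" + tempfile.mkdtemp()

with open(os.path.join(local_dir, "result.json"), "w") as f:
    f.write("{}")

syncer = _DefaultSyncer(sync_period=60)
syncer.sync_up(local_dir=local_dir, remote_dir=remote_dir)
syncer.wait()  # raises TuneError if the background upload failed

restored = tempfile.mkdtemp()
syncer.sync_down(remote_dir=remote_dir, local_dir=restored)
syncer.wait_or_retry()  # retries transient failures with backoff
assert os.path.exists(os.path.join(restored, "result.json"))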
@DeveloperAPI
def get_cloud_syncer(
local_dir: str,
remote_dir: Optional[str] = None,
sync_function: Optional[Union[Callable, str]] = None,
) -> CloudSyncer:
"""Returns a Syncer.
def get_node_to_storage_syncer(sync_config: SyncConfig) -> Optional[Syncer]:
""""""
if sync_config.syncer is None:
return None
This syncer is in charge of syncing the local_dir with upload_dir.
if not sync_config.upload_dir:
return None
If no ``remote_dir`` is provided, it will return a no-op syncer.
if sync_config.syncer == "auto":
return _DefaultSyncer(sync_period=sync_config.sync_period)
If a ``sync_function`` is provided, it will return a CloudSyncer using
a custom SyncClient initialized by the sync function. Otherwise it will
return a CloudSyncer with default templates for s3/gs/hdfs.
if isinstance(sync_config.syncer, Syncer):
return sync_config.syncer
Args:
local_dir: Source directory for syncing.
remote_dir: Target directory for syncing. If not provided, a
no-op Syncer is returned.
sync_function: Function for syncing the local_dir to
remote_dir. If string, then it must be a string template for
syncer to run. If not provided, it defaults
to standard S3, gsutil or HDFS sync commands.
Raises:
ValueError if malformed remote_dir.
"""
key = (local_dir, remote_dir)
if key in _syncers:
return _syncers[key]
if not remote_dir:
_syncers[key] = CloudSyncer(local_dir, remote_dir, NOOP)
return _syncers[key]
if sync_function == "auto":
sync_function = None # Auto-detect
# Maybe get user-provided sync client here
client = get_sync_client(sync_function)
if client:
# If the user provided a sync template or function
_syncers[key] = CloudSyncer(local_dir, remote_dir, client)
else:
# Else, get default cloud sync client (e.g. S3 syncer)
sync_client = get_cloud_sync_client(remote_dir)
_syncers[key] = CloudSyncer(local_dir, remote_dir, sync_client)
return _syncers[key]
@DeveloperAPI
def get_node_syncer(
local_dir: str,
remote_dir: Optional[str] = None,
sync_function: Optional[Union[Callable, str, bool, Type[Syncer]]] = None,
):
"""Returns a NodeSyncer.
Args:
local_dir: Source directory for syncing.
remote_dir: Target directory for syncing. If not provided, a
noop Syncer is returned.
sync_function: Function for syncing the local_dir to
remote_dir. If string, then it must be a string template for
syncer to run. If True or not provided, it defaults rsync
(if available) or otherwise remote-task based syncing. If
False, a noop Syncer is returned.
"""
if sync_function == "auto":
sync_function = None # Auto-detect
key = (local_dir, remote_dir)
if key in _syncers:
# Get cached syncer
return _syncers[key]
elif isclass(sync_function) and issubclass(sync_function, Syncer):
# Type[Syncer]
_syncers[key] = sync_function(local_dir, remote_dir, None)
return _syncers[key]
elif not remote_dir or sync_function is False:
# Do not sync trials if no remote dir specified or syncer=False
sync_client = NOOP
elif sync_function and sync_function is not True:
# String or callable (for function syncers)
sync_client = get_sync_client(sync_function)
else:
# sync_function=True or sync_function=None --> default
rsync_function_str = get_rsync_template_if_available()
if rsync_function_str:
sync_client = CommandBasedClient(rsync_function_str, rsync_function_str)
sync_client.set_logdir(local_dir)
else:
sync_client = RemoteTaskClient()
_syncers[key] = NodeSyncer(local_dir, remote_dir, sync_client)
return _syncers[key]
raise ValueError(
f"Unknown syncer type passed in SyncConfig: {type(sync_config.syncer)}. "
f"Note that custom sync functions and templates have been deprecated. "
f"Instead you can implement you own `Syncer` class. "
f"Please leave a comment on GitHub if you run into any issues with this: "
f"https://github.com/ray-project/ray/issues"
)
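Pieced together, ``get_node_to_storage_syncer`` maps the ``SyncConfig`` settings to a node-to-storage syncer: no ``upload_dir`` or ``syncer=None`` disables it, ``"auto"`` selects the default implementation, and a ``Syncer`` instance is used as-is. A small sketch of that resolution (import paths as used by the tests in this PR):

from ray import tune
from ray.tune.syncer import _DefaultSyncer, get_node_to_storage_syncer

# No upload_dir -> no node-to-storage syncing.
assert get_node_to_storage_syncer(tune.SyncConfig()) is None

# upload_dir with the default "auto" syncer -> _DefaultSyncer.
syncer = get_node_to_storage_syncer(
    tune.SyncConfig(upload_dir="s3://my-bucket/exp", sync_period=60)
)
assert isinstance(syncer, _DefaultSyncer)

# syncer=None disables syncing even if an upload_dir is given.
assert (
    get_node_to_storage_syncer(
        tune.SyncConfig(upload_dir="s3://my-bucket/exp", syncer=None)
    )
    is None
)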
@DeveloperAPI
class SyncerCallback(Callback):
def __init__(self, sync_function: Optional[Union[bool, Callable]]):
self._sync_function = sync_function
self._syncers: Dict["Trial", NodeSyncer] = {}
"""Callback to synchronize trial directories on a worker node with the driver."""
def _get_trial_syncer(self, trial: "Trial"):
if trial not in self._syncers:
self._syncers[trial] = self._create_trial_syncer(trial)
return self._syncers[trial]
def __init__(self, enabled: bool = True, sync_period: float = DEFAULT_SYNC_PERIOD):
self._enabled = enabled
self._sync_processes: Dict[str, _BackgroundProcess] = {}
self._sync_times: Dict[str, float] = {}
self._sync_period = sync_period
def _create_trial_syncer(self, trial: "Trial"):
return get_node_syncer(
trial.logdir, remote_dir=trial.logdir, sync_function=self._sync_function
def _get_trial_sync_process(self, trial: "Trial"):
return self._sync_processes.setdefault(
trial.trial_id, _BackgroundProcess(sync_dir_between_nodes)
)
def _remove_trial_syncer(self, trial: "Trial"):
self._syncers.pop(trial, None)
def _remove_trial_sync_process(self, trial: "Trial"):
self._sync_processes.pop(trial.trial_id, None)
def _sync_trial_checkpoint(self, trial: "Trial", checkpoint: _TrackedCheckpoint):
if checkpoint.storage_mode == CheckpointStorage.MEMORY:
return
def _should_sync(self, trial: "Trial"):
last_sync_time = self._sync_times.setdefault(trial.trial_id, float("-inf"))
return time.time() - last_sync_time >= self._sync_period
def _mark_as_synced(self, trial: "Trial"):
self._sync_times[trial.trial_id] = time.time()
def _local_trial_logdir(self, trial: "Trial"):
return trial.logdir
def _remote_trial_logdir(self, trial: "Trial"):
return trial.logdir
def _sync_trial_dir(
self, trial: "Trial", force: bool = False, wait: bool = True
) -> bool:
if not self._enabled or trial.uses_cloud_checkpointing:
return False
sync_process = self._get_trial_sync_process(trial)
# Always run if force=True
# Otherwise, only run if we should sync (considering sync period)
# or if there is no sync currently still running.
if not force and (not self._should_sync(trial) or sync_process.is_running):
return False
if NODE_IP in trial.last_result:
source_ip = trial.last_result[NODE_IP]
else:
source_ip = ray.get(trial.runner.get_current_ip.remote())
trial_syncer = self._get_trial_syncer(trial)
# If the sync_function is False, syncing to driver is disabled.
# In every other case (valid values include None, True Callable,
# NodeSyncer) syncing to driver is enabled.
if trial.sync_on_checkpoint and self._sync_function is not False:
try:
# Wait for any other syncs to finish. We need to sync again
# after this to handle checkpoints taken mid-sync.
trial_syncer.wait()
sync_process.wait()
except TuneError as e:
# Errors occurring during this wait are not fatal for this
# checkpoint, so it should just be logged.
logger.error(
f"Trial {trial}: An error occurred during the "
f"checkpoint pre-sync wait: {e}"
f"checkpoint syncing of the previous checkpoint: {e}"
)
# Force sync down and wait before tracking the new checkpoint.
sync_process.start(
source_ip=source_ip,
source_path=self._remote_trial_logdir(trial),
target_ip=ray.util.get_node_ip_address(),
target_path=self._local_trial_logdir(trial),
)
self._sync_times[trial.trial_id] = time.time()
if wait:
try:
if trial_syncer.sync_down():
trial_syncer.wait()
else:
logger.error(
f"Trial {trial}: Checkpoint sync skipped. "
f"This should not happen."
)
sync_process.wait()
except TuneError as e:
if trial.uses_cloud_checkpointing:
# Even though rsync failed the trainable can restore
# from remote durable storage.
logger.error(f"Trial {trial}: Sync error: {e}")
else:
# If the trainable didn't have remote storage to upload
# to then this checkpoint may have been lost, so we
# shouldn't track it with the checkpoint_manager.
raise e
if not trial.uses_cloud_checkpointing:
if not os.path.exists(checkpoint.dir_or_data):
raise TuneError(
"Trial {}: Checkpoint path {} not "
"found after successful sync down. "
"Are you running on a Kubernetes or "
"managed cluster? rsync will not function "
"due to a lack of SSH functionality. "
"You'll need to use cloud-checkpointing "
"if that's the case, see instructions "
"here: {} .".format(
trial,
checkpoint.dir_or_data,
CLOUD_CHECKPOINTING_URL,
# Errors occurring during this wait are not fatal for this
# checkpoint, so it should just be logged.
logger.error(
f"Trial {trial}: An error occurred during the "
f"checkpoint syncing of the current checkpoint: {e}"
)
)
def on_trial_start(
self, iteration: int, trials: List["Trial"], trial: "Trial", **info
):
self._get_trial_syncer(trial)
return True
def on_trial_result(
self,
@ -540,23 +493,13 @@ class SyncerCallback(Callback):
result: Dict,
**info,
):
trial_syncer = self._get_trial_syncer(trial)
trial_syncer.set_worker_ip(result.get(NODE_IP))
trial_syncer.sync_down_if_needed()
self._sync_trial_dir(trial, force=False, wait=False)
def on_trial_complete(
self, iteration: int, trials: List["Trial"], trial: "Trial", **info
):
trial_syncer = self._get_trial_syncer(trial)
if NODE_IP in trial.last_result:
trainable_ip = trial.last_result[NODE_IP]
else:
trainable_ip = ray.get(trial.runner.get_current_ip.remote())
trial_syncer.set_worker_ip(trainable_ip)
# Always sync down when trial completed
trial_syncer.sync_down()
trial_syncer.close()
self._remove_trial_syncer(trial)
self._sync_trial_dir(trial, force=True, wait=True)
self._remove_trial_sync_process(trial)
def on_checkpoint(
self,
@ -566,80 +509,30 @@ class SyncerCallback(Callback):
checkpoint: _TrackedCheckpoint,
**info,
):
self._sync_trial_checkpoint(trial, checkpoint)
if checkpoint.storage_mode == CheckpointStorage.MEMORY:
return
def detect_cluster_syncer(
sync_config: Optional[SyncConfig],
cluster_config_file: str = "~/ray_bootstrap_config.yaml",
) -> Union[bool, Type, NodeSyncer]:
"""Detect cluster Syncer given SyncConfig.
Returns False if cloud checkpointing is enabled (when upload dir is
set).
Else, returns sync config syncer if manually specified.
Else, detects cluster environment (e.g. Docker, Kubernetes) and returns
syncer accordingly.
"""
from ray.tune.integration.docker import DockerSyncer
sync_config = sync_config or SyncConfig()
if bool(sync_config.upload_dir) or sync_config.syncer is None:
# No sync to driver for cloud checkpointing or if manually disabled
return False
_syncer = sync_config.syncer
if _syncer == "auto":
_syncer = None
if isinstance(_syncer, Type):
return _syncer
# Else: True or None. Auto-detect.
cluster_config_file = os.path.expanduser(cluster_config_file)
if not os.path.exists(cluster_config_file):
return _syncer
with open(cluster_config_file, "rt") as fp:
config = yaml.safe_load(fp.read())
if config.get("docker"):
logger.debug(
"Detected docker autoscaling environment. Using `DockerSyncer` "
"as sync client. If this is not correct or leads to errors, "
"please pass a `syncer` parameter in the `SyncConfig` to "
"`tune.run().` to manually configure syncing behavior."
)
return DockerSyncer
if config.get("provider", {}).get("type", "") == "kubernetes":
from ray.tune.integration.kubernetes import (
NamespacedKubernetesSyncer,
try_import_kubernetes,
if self._sync_trial_dir(
trial, force=trial.sync_on_checkpoint, wait=True
) and not os.path.exists(checkpoint.dir_or_data):
raise TuneError(
f"Trial {trial}: Checkpoint path {checkpoint.dir_or_data} not "
"found after successful sync down."
)
if not try_import_kubernetes():
logger.warning(
"Detected Ray autoscaling environment on Kubernetes, "
"but Kubernetes Python CLI is not installed. "
"Checkpoint syncing may not work properly across "
"multiple pods. Be sure to install 'kubernetes' on "
"each container."
)
def wait_for_all(self):
failed_syncs = {}
for trial, sync_process in self._sync_processes.items():
try:
sync_process.wait()
except Exception as e:
failed_syncs[trial] = e
namespace = config["provider"].get("namespace", "ray")
logger.debug(
f"Detected Ray autoscaling environment on Kubernetes. Using "
f"`NamespacedKubernetesSyncer` with namespace `{namespace}` "
f"as sync client. If this is not correct or leads to errors, "
f"please pass a `syncer` parameter in the `SyncConfig` "
f"to `tune.run()` to manually configure syncing behavior.."
if failed_syncs:
sync_str = "\n".join(
[f" {trial}: {e}" for trial, e in failed_syncs.items()]
)
raise TuneError(
f"At least one trial failed to sync down when waiting for all "
f"trials to sync: \n{sync_str}"
)
return NamespacedKubernetesSyncer(namespace)
return _syncer
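Both ``_DefaultSyncer`` and the reworked ``SyncerCallback`` hand the actual transfer to a background process wrapper (``_BackgroundProcess``) and only check ``is_running``/``wait()`` at sync time, so training is never blocked on I/O. The wrapper itself is internal; the following is only an illustrative sketch of that pattern, with hypothetical names, not the actual implementation:

import threading
from typing import Callable, Optional


class BackgroundTransfer:
    """Illustrative stand-in for Tune's internal background sync process."""

    def __init__(self, fn: Callable):
        self._fn = fn
        self._thread: Optional[threading.Thread] = None
        self._exception: Optional[Exception] = None

    @property
    def is_running(self) -> bool:
        return self._thread is not None and self._thread.is_alive()

    def start(self, **kwargs):
        def _run():
            try:
                self._fn(**kwargs)
            except Exception as e:  # surface the error on wait()
                self._exception = e

        self._exception = None
        self._thread = threading.Thread(target=_run, daemon=True)
        self._thread.start()

    def wait(self):
        if self._thread:
            self._thread.join()
            self._thread = None
        if self._exception:
            raise self._exception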

@ -61,7 +61,7 @@ from ray.tune.suggest._mock import _MockSuggestionAlgorithm
from ray.tune.suggest.ax import AxSearch
from ray.tune.suggest.hyperopt import HyperOptSearch
from ray.tune.suggest.suggestion import ConcurrencyLimiter
from ray.tune.sync_client import CommandBasedClient
from ray.tune.syncer import Syncer
from ray.tune.trial import Trial
from ray.tune.trial_runner import TrialRunner
from ray.tune.utils import flatten_dict
@ -1037,24 +1037,23 @@ class TrainableFunctionApiTest(unittest.TestCase):
def testDurableTrainableSyncFunction(self):
"""Check custom sync functions in durable trainables"""
class CustomSyncer(Syncer):
def sync_up(
self, local_dir: str, remote_dir: str, exclude: list = None
) -> bool:
pass # sync up
def sync_down(
self, remote_dir: str, local_dir: str, exclude: list = None
) -> bool:
pass # sync down
def delete(self, remote_dir: str) -> bool:
pass # delete
class TestDurable(Trainable):
def __init__(self, *args, **kwargs):
# Mock distutils.spawn.find_executable
# so `aws` command is found
import distutils.spawn
distutils.spawn.find_executable = lambda *_, **__: True
super(TestDurable, self).__init__(*args, **kwargs)
def check(self):
return (
bool(self.sync_function_tpl)
and isinstance(self.storage_client, CommandBasedClient)
and "aws" not in self.storage_client.sync_up_template
)
class TestTplDurable(TestDurable):
_sync_function_tpl = "echo static sync {source} {target}"
def has_custom_syncer(self):
return bool(self.custom_syncer)
upload_dir = "s3://test-bucket/path"
@ -1074,21 +1073,17 @@ class TrainableFunctionApiTest(unittest.TestCase):
cls = trial.get_trainable_cls()
actor = ray.remote(cls).remote(
remote_checkpoint_dir=upload_dir,
sync_function_tpl=trial.sync_function_tpl,
custom_syncer=trial.custom_syncer,
)
return actor
# This actor should create a default aws syncer, so check should fail
actor1 = _create_remote_actor(TestDurable, None)
self.assertFalse(ray.get(actor1.check.remote()))
self.assertFalse(ray.get(actor1.has_custom_syncer.remote()))
# This actor should create a custom syncer, so check should pass
actor2 = _create_remote_actor(TestDurable, "echo test sync {source} {target}")
self.assertTrue(ray.get(actor2.check.remote()))
# This actor should create a custom syncer, so check should pass
actor3 = _create_remote_actor(TestTplDurable, None)
self.assertTrue(ray.get(actor3.check.remote()))
actor2 = _create_remote_actor(TestDurable, CustomSyncer())
self.assertTrue(ray.get(actor2.has_custom_syncer.remote()))
def testCheckpointDict(self):
class TestTrain(Trainable):

@ -5,30 +5,18 @@ import os
import pytest
import shutil
import sys
from unittest.mock import MagicMock, patch
from typing import Callable, Union, Optional
from unittest.mock import MagicMock
import ray
from ray import tune
from ray.rllib import _register_all
from ray.cluster_utils import Cluster
from ray._private.test_utils import run_string_as_driver_nonblocking
from ray.tune import register_trainable
from ray.tune.experiment import Experiment
from ray.tune.error import TuneError
from ray.tune.suggest import BasicVariantGenerator
from ray.tune.syncer import CloudSyncer, SyncerCallback, get_node_syncer
from ray.tune.utils.trainable import TrainableUtil
from ray.tune.syncer import SyncerCallback
from ray.tune.trial import Trial
from ray.tune.trial_runner import TrialRunner
from ray.tune.utils.mock import (
MockDurableTrainer,
MockRemoteTrainer,
MockNodeSyncer,
mock_storage_client,
MOCK_REMOTE_DIR,
)
def _check_trial_running(trial):
@ -51,27 +39,9 @@ def _start_new_cluster():
"_system_config": {"num_heartbeats_timeout": 10},
},
)
# Pytest doesn't play nicely with imports
register_trainable("__fake_remote", MockRemoteTrainer)
register_trainable("__fake_durable", MockDurableTrainer)
_register_all()
return cluster
class _PerTrialSyncerCallback(SyncerCallback):
def __init__(
self, get_sync_fn: Callable[["Trial"], Optional[Union[bool, Callable]]]
):
self._get_sync_fn = get_sync_fn
super(_PerTrialSyncerCallback, self).__init__(None)
def _create_trial_syncer(self, trial: "Trial"):
sync_fn = self._get_sync_fn(trial)
return get_node_syncer(
trial.logdir, remote_dir=trial.logdir, sync_function=sync_fn
)
@pytest.fixture
def start_connected_cluster():
# Start the Ray processes.
@ -94,10 +64,6 @@ def start_connected_emptyhead_cluster():
"_system_config": {"num_heartbeats_timeout": 10},
},
)
# Pytest doesn't play nicely with imports
_register_all()
register_trainable("__fake_remote", MockRemoteTrainer)
register_trainable("__fake_durable", MockDurableTrainer)
os.environ["TUNE_STATE_REFRESH_PERIOD"] = "0.1"
yield cluster
# The code after the yield will run as teardown code.
@ -208,8 +174,16 @@ def test_remove_node_before_result(start_connected_emptyhead_cluster):
runner.step()
@pytest.mark.parametrize("trainable_id", ["__fake", "__fake_durable"])
def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
def custom_driver_logdir_callback(tempdir: str):
class SeparateDriverSyncerCallback(SyncerCallback):
def _local_trial_logdir(self, trial):
return os.path.join(tempdir, trial.relative_logdir)
return SeparateDriverSyncerCallback()
@pytest.mark.parametrize("durable", [False, True])
def test_trial_migration(start_connected_emptyhead_cluster, tmpdir, durable):
"""Removing a node while cluster has space should migrate trial.
The trial state should also be consistent with the checkpoint.
@ -218,21 +192,23 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
node = cluster.add_node(num_cpus=1)
cluster.wait_for_nodes()
syncer_callback = _PerTrialSyncerCallback(
lambda trial: trial.trainable_name == "__fake"
)
if durable:
upload_dir = "file://" + str(tmpdir)
syncer_callback = SyncerCallback()
else:
upload_dir = None
syncer_callback = custom_driver_logdir_callback(str(tmpdir))
runner = TrialRunner(BasicVariantGenerator(), callbacks=[syncer_callback])
kwargs = {
"stopping_criterion": {"training_iteration": 4},
"checkpoint_freq": 2,
"max_failures": 2,
"remote_checkpoint_dir": upload_dir,
}
if trainable_id == "__fake_durable":
kwargs["remote_checkpoint_dir"] = MOCK_REMOTE_DIR
# Test recovery of trial that hasn't been checkpointed
t = Trial(trainable_id, **kwargs)
t = Trial("__fake", **kwargs)
runner.add_trial(t)
runner.step() # Start trial
runner.step() # Process result
@ -257,7 +233,7 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
assert t.status == Trial.TERMINATED, runner.debug_string()
# Test recovery of trial that has been checkpointed
t2 = Trial(trainable_id, **kwargs)
t2 = Trial("__fake", **kwargs)
runner.add_trial(t2)
# Start trial, process result (x2), process save
while not t2.has_checkpoint():
@ -272,12 +248,10 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
# Test recovery of trial that won't be checkpointed
kwargs = {
"stopping_criterion": {"training_iteration": 3},
"remote_checkpoint_dir": upload_dir,
}
if trainable_id == "__fake_durable":
kwargs["remote_checkpoint_dir"] = MOCK_REMOTE_DIR
t3 = Trial(trainable_id, **kwargs)
t3 = Trial("__fake", **kwargs)
runner.add_trial(t3)
runner.step() # Start trial
runner.step() # Process result 1
@ -292,8 +266,8 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
runner.step()
@pytest.mark.parametrize("trainable_id", ["__fake", "__fake_durable"])
def test_trial_requeue(start_connected_emptyhead_cluster, trainable_id):
@pytest.mark.parametrize("durable", [False, True])
def test_trial_requeue(start_connected_emptyhead_cluster, tmpdir, durable):
"""Removing a node in full cluster causes Trial to be requeued."""
os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "1"
@ -301,20 +275,22 @@ def test_trial_requeue(start_connected_emptyhead_cluster, trainable_id):
node = cluster.add_node(num_cpus=1)
cluster.wait_for_nodes()
syncer_callback = _PerTrialSyncerCallback(
lambda trial: trial.trainable_name == "__fake"
)
if durable:
upload_dir = "file://" + str(tmpdir)
syncer_callback = SyncerCallback()
else:
upload_dir = None
syncer_callback = custom_driver_logdir_callback(str(tmpdir))
runner = TrialRunner(BasicVariantGenerator(), callbacks=[syncer_callback]) # noqa
kwargs = {
"stopping_criterion": {"training_iteration": 5},
"checkpoint_freq": 1,
"max_failures": 1,
"remote_checkpoint_dir": upload_dir,
}
if trainable_id == "__fake_durable":
kwargs["remote_checkpoint_dir"] = MOCK_REMOTE_DIR
trials = [Trial(trainable_id, **kwargs), Trial(trainable_id, **kwargs)]
trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)]
for t in trials:
runner.add_trial(t)
@ -333,58 +309,32 @@ def test_trial_requeue(start_connected_emptyhead_cluster, trainable_id):
assert all(t.status == Trial.PENDING for t in trials), runner.debug_string()
@pytest.mark.parametrize("trainable_id", ["__fake_remote", "__fake_durable"])
def test_migration_checkpoint_removal(start_connected_emptyhead_cluster, trainable_id):
@pytest.mark.parametrize("durable", [False, True])
def test_migration_checkpoint_removal(
start_connected_emptyhead_cluster, tmpdir, durable
):
"""Test checks that trial restarts if checkpoint is lost w/ node fail."""
cluster = start_connected_emptyhead_cluster
node = cluster.add_node(num_cpus=1)
cluster.wait_for_nodes()
# Only added for fake_remote case.
# For durable case, we don't do sync to head.
class _SyncerCallback(SyncerCallback):
def _create_trial_syncer(self, trial: "Trial"):
client = mock_storage_client()
return MockNodeSyncer(trial.logdir, trial.logdir, client)
if durable:
upload_dir = "file://" + str(tmpdir)
syncer_callback = SyncerCallback()
else:
upload_dir = None
syncer_callback = custom_driver_logdir_callback(str(tmpdir))
syncer_callback = (
[_SyncerCallback(None)] if trainable_id == "__fake_remote" else None
)
runner = TrialRunner(BasicVariantGenerator(), callbacks=syncer_callback)
runner = TrialRunner(BasicVariantGenerator(), callbacks=[syncer_callback])
kwargs = {
"stopping_criterion": {"training_iteration": 4},
"checkpoint_freq": 2,
"max_failures": 2,
"remote_checkpoint_dir": upload_dir,
}
if trainable_id == "__fake_durable":
kwargs["remote_checkpoint_dir"] = MOCK_REMOTE_DIR
# The following patches only affect __fake_remote.
def hide_remote_path(path_function):
def hidden_path_func(checkpoint_path):
"""Converts back to local path first."""
if MOCK_REMOTE_DIR in checkpoint_path:
checkpoint_path = checkpoint_path[len(MOCK_REMOTE_DIR) :]
checkpoint_path = os.path.join("/", checkpoint_path)
return path_function(checkpoint_path)
return hidden_path_func
trainable_util = "ray.tune.ray_trial_executor.TrainableUtil"
_find_ckpt = trainable_util + ".find_checkpoint_dir"
find_func = TrainableUtil.find_checkpoint_dir
_pickle_ckpt = trainable_util + ".pickle_checkpoint"
pickle_func = TrainableUtil.pickle_checkpoint
with patch(_find_ckpt) as mock_find, patch(_pickle_ckpt) as mock_pkl_ckpt:
# __fake_remote trainables save to a separate "remote" directory.
# TrainableUtil will not check this path unless we mock it.
mock_find.side_effect = hide_remote_path(find_func)
mock_pkl_ckpt.side_effect = hide_remote_path(pickle_func)
# Test recovery of trial that has been checkpointed
t1 = Trial(trainable_id, **kwargs)
t1 = Trial("__fake", **kwargs)
runner.add_trial(t1)
# Start trial, process result (x2), process save
@ -394,84 +344,45 @@ def test_migration_checkpoint_removal(start_connected_emptyhead_cluster, trainab
cluster.add_node(num_cpus=1)
cluster.remove_node(node)
cluster.wait_for_nodes()
# Remove checkpoint on "remote" node
shutil.rmtree(os.path.dirname(t1.checkpoint.dir_or_data))
if not durable:
# Recover from driver file
t1.checkpoint.dir_or_data = os.path.join(
tmpdir,
t1.relative_logdir,
os.path.relpath(t1.checkpoint.dir_or_data, t1.logdir),
)
while not runner.is_finished():
runner.step()
assert t1.status == Trial.TERMINATED, runner.debug_string()
@pytest.mark.skip(reason="Not very consistent.")
@pytest.mark.parametrize("trainable_id", ["__fake", "__fake_durable"])
def test_cluster_down_simple(start_connected_cluster, tmpdir, trainable_id):
"""Tests that TrialRunner save/restore works on cluster shutdown."""
cluster = start_connected_cluster
cluster.add_node(num_cpus=1)
cluster.wait_for_nodes()
dirpath = str(tmpdir)
syncer_callback = _PerTrialSyncerCallback(
lambda trial: trial.trainable_name == "__fake"
)
runner = TrialRunner(
local_checkpoint_dir=dirpath, checkpoint_period=0, callbacks=[syncer_callback]
)
kwargs = {
"stopping_criterion": {"training_iteration": 2},
"checkpoint_freq": 1,
"max_failures": 1,
}
if trainable_id == "__fake_durable":
kwargs["remote_checkpoint_dir"] = MOCK_REMOTE_DIR
trials = [Trial(trainable_id, **kwargs), Trial(trainable_id, **kwargs)]
for t in trials:
runner.add_trial(t)
# Start trial (x2), process result, process save
for _ in range(4):
runner.step()
assert all(t.status == Trial.RUNNING for t in runner.get_trials())
runner.checkpoint()
ray.shutdown()
cluster.shutdown()
cluster = _start_new_cluster()
runner = TrialRunner(resume="LOCAL", local_checkpoint_dir=dirpath)
# Start trial, process restore, process result, process save
for _ in range(4):
runner.step()
# Start trial 2, process result, process save, process result, process save
for i in range(5):
runner.step()
with pytest.raises(TuneError):
runner.step()
assert all(t.status == Trial.TERMINATED for t in runner.get_trials())
ray.shutdown()
cluster.shutdown()
@pytest.mark.parametrize("trainable_id", ["__fake", "__fake_durable"])
def test_cluster_down_full(start_connected_cluster, tmpdir, trainable_id):
@pytest.mark.parametrize("durable", [False, True])
def test_cluster_down_full(start_connected_cluster, tmpdir, durable):
"""Tests that run_experiment restoring works on cluster shutdown."""
cluster = start_connected_cluster
dirpath = str(tmpdir)
use_default_sync = trainable_id == "__fake"
if durable:
upload_dir = "file://" + str(tmpdir)
syncer_callback = SyncerCallback()
else:
upload_dir = None
syncer_callback = custom_driver_logdir_callback(str(tmpdir))
from ray.tune.result import DEFAULT_RESULTS_DIR
local_dir = DEFAULT_RESULTS_DIR
upload_dir = None if use_default_sync else MOCK_REMOTE_DIR
base_dict = dict(
run=trainable_id,
run="__fake",
stop=dict(training_iteration=3),
local_dir=local_dir,
sync_config=dict(upload_dir=upload_dir, syncer=use_default_sync),
sync_config=dict(upload_dir=upload_dir),
)
exp1_args = base_dict
@ -486,19 +397,18 @@ def test_cluster_down_full(start_connected_cluster, tmpdir, trainable_id):
"exp4": exp4_args,
}
mock_get_client = "ray.tune.trial_runner.get_cloud_syncer"
with patch(mock_get_client) as mock_get_cloud_syncer:
mock_syncer = CloudSyncer(local_dir, upload_dir, mock_storage_client())
mock_get_cloud_syncer.return_value = mock_syncer
tune.run_experiments(all_experiments, raise_on_failed_trial=False)
tune.run_experiments(
all_experiments, callbacks=[syncer_callback], raise_on_failed_trial=False
)
ray.shutdown()
cluster.shutdown()
cluster = _start_new_cluster()
trials = tune.run_experiments(
all_experiments, resume=True, raise_on_failed_trial=False
all_experiments,
resume=True,
raise_on_failed_trial=False,
)
assert len(trials) == 4

@ -6,12 +6,9 @@ import subprocess
import sys
import ray
from ray.rllib import _register_all
from ray.cluster_utils import Cluster
from ray.tune import register_trainable
from ray.tune.trial import Trial
from ray.tune.trial_runner import TrialRunner
from ray.tune.utils.mock import MockDurableTrainer, MockRemoteTrainer
from ray.tune.utils.mock_trainable import MyTrainableClass
@ -24,10 +21,6 @@ def _start_new_cluster():
"_system_config": {"num_heartbeats_timeout": 10},
},
)
# Pytest doesn't play nicely with imports
register_trainable("__fake_remote", MockRemoteTrainer)
register_trainable("__fake_durable", MockDurableTrainer)
_register_all()
return cluster

@ -1,108 +0,0 @@
import unittest
from typing import Optional
from ray.tune.integration.docker import DockerSyncer, DockerSyncClient
from ray.tune.sync_client import SyncClient
from ray.tune.syncer import NodeSyncer
class _MockRsync:
def __init__(self):
self.history = []
def __call__(self, *args, **kwargs):
self.history.append(kwargs)
class _MockLookup:
def __init__(self, node_ips):
self.node_to_ip = {}
self.ip_to_node = {}
for node, ip in node_ips.items():
self.node_to_ip[node] = ip
self.ip_to_node[ip] = node
def get_ip(self, node):
return self.node_to_ip[node]
def get_node(self, ip):
return self.ip_to_node[ip]
def _create_mock_syncer(local_ip, local_dir, remote_dir):
class _MockSyncer(DockerSyncer):
def __init__(
self,
local_dir: str,
remote_dir: str,
sync_client: Optional[SyncClient] = None,
):
self.local_ip = local_ip
self.worker_ip = None
sync_client = sync_client or DockerSyncClient()
sync_client.configure("__nofile__")
super(NodeSyncer, self).__init__(local_dir, remote_dir, sync_client)
return _MockSyncer(local_dir, remote_dir)
class DockerIntegrationTest(unittest.TestCase):
def setUp(self):
self.lookup = _MockLookup({"head": "1.0.0.0", "w1": "1.0.0.1", "w2": "1.0.0.2"})
self.local_dir = "/tmp/local"
self.remote_dir = "/tmp/remote"
self.mock_command = _MockRsync()
from ray.tune.integration import docker
docker.rsync = self.mock_command
def tearDown(self):
pass
def testDockerRsyncUpDown(self):
syncer = _create_mock_syncer(
self.lookup.get_ip("head"), self.local_dir, self.remote_dir
)
syncer.set_worker_ip(self.lookup.get_ip("w1"))
# Test sync up. Should add / to the dirs and call rsync
syncer.sync_up()
print(self.mock_command.history[-1])
self.assertEqual(self.mock_command.history[-1]["source"], self.local_dir + "/")
self.assertEqual(self.mock_command.history[-1]["target"], self.remote_dir + "/")
self.assertEqual(self.mock_command.history[-1]["down"], False)
self.assertEqual(
self.mock_command.history[-1]["ip_address"], self.lookup.get_ip("w1")
)
# Test sync down.
syncer.sync_down()
print(self.mock_command.history[-1])
self.assertEqual(self.mock_command.history[-1]["target"], self.local_dir + "/")
self.assertEqual(self.mock_command.history[-1]["source"], self.remote_dir + "/")
self.assertEqual(self.mock_command.history[-1]["down"], True)
self.assertEqual(
self.mock_command.history[-1]["ip_address"], self.lookup.get_ip("w1")
)
# Sync to same node should be ignored
prev = len(self.mock_command.history)
syncer.set_worker_ip(self.lookup.get_ip("head"))
syncer.sync_up()
self.assertEqual(len(self.mock_command.history), prev)
syncer.sync_down()
self.assertEqual(len(self.mock_command.history), prev)
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))

@ -1,114 +0,0 @@
import unittest
from ray.autoscaler._private.command_runner import KUBECTL_RSYNC
from ray.tune.integration.kubernetes import KubernetesSyncer, KubernetesSyncClient
from ray.tune.syncer import NodeSyncer
class _MockProcessRunner:
def __init__(self):
self.history = []
def check_call(self, command):
self.history.append(command)
return True
class _MockLookup:
def __init__(self, node_ips):
self.node_to_ip = {}
self.ip_to_node = {}
for node, ip in node_ips.items():
self.node_to_ip[node] = ip
self.ip_to_node[ip] = node
def get_ip(self, node):
return self.node_to_ip[node]
def get_node(self, ip):
return self.ip_to_node[ip]
def __call__(self, ip):
return self.ip_to_node[ip]
def _create_mock_syncer(
namespace, lookup, process_runner, local_ip, local_dir, remote_dir
):
class _MockSyncer(KubernetesSyncer):
_namespace = namespace
_get_kubernetes_node_by_ip = lookup
def __init__(self, local_dir, remote_dir, sync_client):
self.local_ip = local_ip
self.local_node = self._get_kubernetes_node_by_ip(self.local_ip)
self.worker_ip = None
self.worker_node = None
sync_client = sync_client
super(NodeSyncer, self).__init__(local_dir, remote_dir, sync_client)
return _MockSyncer(
local_dir,
remote_dir,
sync_client=KubernetesSyncClient(
namespace=namespace, process_runner=process_runner
),
)
class KubernetesIntegrationTest(unittest.TestCase):
def setUp(self):
self.namespace = "test_ray"
self.lookup = _MockLookup({"head": "1.0.0.0", "w1": "1.0.0.1", "w2": "1.0.0.2"})
self.process_runner = _MockProcessRunner()
self.local_dir = "/tmp/local"
self.remote_dir = "/tmp/remote"
def tearDown(self):
pass
def testKubernetesRsyncUpDown(self):
syncer = _create_mock_syncer(
self.namespace,
self.lookup,
self.process_runner,
self.lookup.get_ip("head"),
self.local_dir,
self.remote_dir,
)
syncer.set_worker_ip(self.lookup.get_ip("w1"))
# Test sync up. Should add / to the dirs and call rsync
syncer.sync_up()
self.assertEqual(self.process_runner.history[-1][0], KUBECTL_RSYNC)
self.assertEqual(self.process_runner.history[-1][-2], self.local_dir + "/")
self.assertEqual(
self.process_runner.history[-1][-1],
"{}@{}:{}".format("w1", self.namespace, self.remote_dir + "/"),
)
# Test sync down.
syncer.sync_down()
self.assertEqual(self.process_runner.history[-1][0], KUBECTL_RSYNC)
self.assertEqual(
self.process_runner.history[-1][-2],
"{}@{}:{}".format("w1", self.namespace, self.remote_dir + "/"),
)
self.assertEqual(self.process_runner.history[-1][-1], self.local_dir + "/")
# Sync to same node should be ignored
syncer.set_worker_ip(self.lookup.get_ip("head"))
syncer.sync_up()
self.assertTrue(len(self.process_runner.history) == 2)
syncer.sync_down()
self.assertTrue(len(self.process_runner.history) == 2)
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))

@ -1,717 +0,0 @@
import glob
import io
import os
import pickle
import shutil
import sys
import tarfile
import tempfile
import time
import unittest
from unittest.mock import patch
import yaml
from collections import deque
import ray
from ray.exceptions import RayTaskError
from ray.rllib import _register_all
from ray import tune
from ray.tune import TuneError
from ray.tune.integration.docker import DockerSyncer
from ray.tune.integration.kubernetes import KubernetesSyncer
from ray.tune.sync_client import NOOP, RemoteTaskClient
from ray.tune.syncer import (
CommandBasedClient,
detect_cluster_syncer,
get_cloud_sync_client,
SyncerCallback,
)
from ray.tune.utils.callback import create_default_callbacks
from ray.tune.utils.file_transfer import (
delete_on_node,
_sync_dir_on_same_node,
_sync_dir_between_different_nodes,
)
# Default RemoteTaskClient will use _sync_dir_on_same_node in this test,
# as the IPs are the same
class RemoteTaskClientWithSyncDirBetweenDifferentNodes(RemoteTaskClient):
def _sync_function(self, *args, **kwargs):
return _sync_dir_between_different_nodes(*args, **kwargs)
class TestSyncFunctionality(unittest.TestCase):
def setUp(self):
ray.init(num_cpus=2)
def tearDown(self):
ray.shutdown()
_register_all() # re-register the evicted objects
@patch("ray.tune.sync_client.S3_PREFIX", "test")
def testCloudProperString(self):
with self.assertRaises(ValueError):
[trial] = tune.run(
"__fake",
name="foo",
max_failures=0,
stop={"training_iteration": 1},
sync_config=tune.SyncConfig(
**{"upload_dir": "test", "syncer": "ls {target}"}
),
).trials
with self.assertRaises(ValueError):
[trial] = tune.run(
"__fake",
name="foo",
max_failures=0,
stop={"training_iteration": 1},
sync_config=tune.SyncConfig(
**{"upload_dir": "test", "syncer": "ls {source}"}
),
).trials
tmpdir = tempfile.mkdtemp()
logfile = os.path.join(tmpdir, "test.log")
[trial] = tune.run(
"__fake",
name="foo",
max_failures=0,
stop={"training_iteration": 1},
sync_config=tune.SyncConfig(
**{
"upload_dir": "test",
"syncer": "echo {source} {target} > " + logfile,
}
),
).trials
with open(logfile) as f:
lines = f.read()
self.assertTrue("test" in lines)
shutil.rmtree(tmpdir)
def testClusterProperString(self):
"""Tests that invalid commands throw.."""
with self.assertRaises(TuneError):
# This raises ValueError because logger is init in safe zone.
sync_config = tune.SyncConfig(syncer="ls {target}")
[trial] = tune.run(
"__fake",
name="foo",
max_failures=0,
stop={"training_iteration": 1},
sync_config=sync_config,
).trials
with self.assertRaises(TuneError):
# This raises ValueError because logger is init in safe zone.
sync_config = tune.SyncConfig(syncer="ls {source}")
[trial] = tune.run(
"__fake",
name="foo",
max_failures=0,
sync_config=sync_config,
stop={"training_iteration": 1},
).trials
with patch.object(CommandBasedClient, "_execute") as mock_fn:
with patch("ray.tune.syncer.get_node_ip_address") as mock_sync:
sync_config = tune.SyncConfig(syncer="echo {source} {target}")
mock_sync.return_value = "0.0.0.0"
[trial] = tune.run(
"__fake",
name="foo",
max_failures=0,
sync_config=sync_config,
stop={"training_iteration": 1},
).trials
self.assertGreater(mock_fn.call_count, 0)
def testCloudFunctions(self):
tmpdir = tempfile.mkdtemp()
tmpdir2 = tempfile.mkdtemp()
os.mkdir(os.path.join(tmpdir2, "foo"))
def sync_func(local, remote, exclude=None):
for filename in glob.glob(os.path.join(local, "*.json")):
shutil.copy(filename, remote)
sync_config = tune.SyncConfig(upload_dir=tmpdir2, syncer=sync_func)
[trial] = tune.run(
"__fake",
name="foo",
max_failures=0,
local_dir=tmpdir,
stop={"training_iteration": 1},
sync_config=sync_config,
).trials
test_file_path = glob.glob(os.path.join(tmpdir2, "foo", "*.json"))
self.assertTrue(test_file_path)
shutil.rmtree(tmpdir)
shutil.rmtree(tmpdir2)
@patch("ray.tune.sync_client.S3_PREFIX", "test")
def testCloudSyncPeriod(self):
"""Tests that changing SYNC_PERIOD affects syncing frequency."""
tmpdir = tempfile.mkdtemp()
def trainable(config):
for i in range(10):
time.sleep(1)
tune.report(score=i)
def counter(local, remote, exclude=None):
count_file = os.path.join(tmpdir, "count.txt")
if not os.path.exists(count_file):
count = 0
else:
with open(count_file, "rb") as fp:
count = pickle.load(fp)
count += 1
with open(count_file, "wb") as fp:
pickle.dump(count, fp)
sync_config = tune.SyncConfig(upload_dir="test", syncer=counter, sync_period=1)
# This was originally set to 0.5
os.environ["TUNE_GLOBAL_CHECKPOINT_S"] = "0"
self.addCleanup(lambda: os.environ.pop("TUNE_GLOBAL_CHECKPOINT_S", None))
[trial] = tune.run(
trainable,
name="foo",
max_failures=0,
local_dir=tmpdir,
stop={"training_iteration": 10},
sync_config=sync_config,
).trials
count_file = os.path.join(tmpdir, "count.txt")
with open(count_file, "rb") as fp:
count = pickle.load(fp)
self.assertEqual(count, 12)
shutil.rmtree(tmpdir)
def testClusterSyncFunction(self):
def sync_func_driver(source, target, exclude=None):
assert ":" in source, "Source {} not a remote path.".format(source)
assert ":" not in target, "Target is supposed to be local."
with open(os.path.join(target, "test.log2"), "w") as f:
print("writing to", f.name)
f.write(source)
sync_config = tune.SyncConfig(syncer=sync_func_driver, sync_period=5)
[trial] = tune.run(
"__fake",
name="foo",
max_failures=0,
stop={"training_iteration": 1},
sync_config=sync_config,
).trials
test_file_path = os.path.join(trial.logdir, "test.log2")
self.assertFalse(os.path.exists(test_file_path))
with patch("ray.tune.syncer.get_node_ip_address") as mock_sync:
mock_sync.return_value = "0.0.0.0"
sync_config = tune.SyncConfig(syncer=sync_func_driver)
[trial] = tune.run(
"__fake",
name="foo",
max_failures=0,
stop={"training_iteration": 1},
sync_config=sync_config,
).trials
test_file_path = os.path.join(trial.logdir, "test.log2")
self.assertTrue(os.path.exists(test_file_path))
os.remove(test_file_path)
def testNoSync(self):
"""Sync should not run on a single node."""
def sync_func(source, target, exclude=None):
pass
sync_config = tune.SyncConfig(syncer=sync_func)
with patch.object(CommandBasedClient, "_execute") as mock_sync:
[trial] = tune.run(
"__fake",
name="foo",
max_failures=0,
stop={"training_iteration": 1},
sync_config=sync_config,
).trials
self.assertEqual(mock_sync.call_count, 0)
def testCloudSyncExclude(self):
captured = deque(maxlen=1)
captured.append("")
def always_true(*args, **kwargs):
return True
def capture_popen(command, *args, **kwargs):
captured.append(command)
with patch("subprocess.Popen", capture_popen), patch(
"distutils.spawn.find_executable", always_true
):
# S3
s3_client = get_cloud_sync_client("s3://test-bucket/test-dir")
s3_client.sync_down(
"s3://test-bucket/test-dir/remote_source", "local_target"
)
self.assertEqual(
captured[0].strip(),
"aws s3 sync s3://test-bucket/test-dir/remote_source "
"local_target --exact-timestamps --only-show-errors",
)
s3_client.sync_down(
"s3://test-bucket/test-dir/remote_source",
"local_target",
exclude=["*/checkpoint_*"],
)
self.assertEqual(
captured[0].strip(),
"aws s3 sync s3://test-bucket/test-dir/remote_source "
"local_target --exact-timestamps --only-show-errors "
"--exclude '*/checkpoint_*'",
)
s3_client.sync_down(
"s3://test-bucket/test-dir/remote_source",
"local_target",
exclude=["*/checkpoint_*", "*.big"],
)
self.assertEqual(
captured[0].strip(),
"aws s3 sync s3://test-bucket/test-dir/remote_source "
"local_target --exact-timestamps --only-show-errors "
"--exclude '*/checkpoint_*' --exclude '*.big'",
)
# GS
gs_client = get_cloud_sync_client("gs://test-bucket/test-dir")
gs_client.sync_down(
"gs://test-bucket/test-dir/remote_source", "local_target"
)
self.assertEqual(
captured[0].strip(),
"gsutil rsync -r "
"gs://test-bucket/test-dir/remote_source "
"local_target",
)
gs_client.sync_down(
"gs://test-bucket/test-dir/remote_source",
"local_target",
exclude=["*/checkpoint_*"],
)
self.assertEqual(
captured[0].strip(),
"gsutil rsync -r "
"-x '(.*/checkpoint_.*)' "
"gs://test-bucket/test-dir/remote_source "
"local_target",
)
gs_client.sync_down(
"gs://test-bucket/test-dir/remote_source",
"local_target",
exclude=["*/checkpoint_*", "*.big"],
)
self.assertEqual(
captured[0].strip(),
"gsutil rsync -r "
"-x '(.*/checkpoint_.*)|(.*.big)' "
"gs://test-bucket/test-dir/remote_source "
"local_target",
)
def testSyncDetection(self):
kubernetes_conf = {"provider": {"type": "kubernetes", "namespace": "test_ray"}}
docker_conf = {"docker": {"image": "bogus"}, "provider": {"type": "aws"}}
aws_conf = {"provider": {"type": "aws"}}
with tempfile.TemporaryDirectory() as dir:
kubernetes_file = os.path.join(dir, "kubernetes.yaml")
with open(kubernetes_file, "wt") as fp:
yaml.safe_dump(kubernetes_conf, fp)
docker_file = os.path.join(dir, "docker.yaml")
with open(docker_file, "wt") as fp:
yaml.safe_dump(docker_conf, fp)
aws_file = os.path.join(dir, "aws.yaml")
with open(aws_file, "wt") as fp:
yaml.safe_dump(aws_conf, fp)
kubernetes_syncer = detect_cluster_syncer(None, kubernetes_file)
self.assertTrue(issubclass(kubernetes_syncer, KubernetesSyncer))
self.assertEqual(kubernetes_syncer._namespace, "test_ray")
docker_syncer = detect_cluster_syncer(None, docker_file)
self.assertTrue(issubclass(docker_syncer, DockerSyncer))
aws_syncer = detect_cluster_syncer(None, aws_file)
self.assertEqual(aws_syncer, None)
# Should still return DockerSyncer, since it was passed explicitly
syncer = detect_cluster_syncer(
tune.SyncConfig(syncer=DockerSyncer), kubernetes_file
)
self.assertTrue(issubclass(syncer, DockerSyncer))
@patch(
"ray.tune.syncer.get_rsync_template_if_available",
lambda: "rsync {source} {target}",
)
def testNoSyncToDriver(self):
"""Test that sync to driver is disabled"""
class _Trial:
def __init__(self, id, logdir):
self.id = (id,)
self.logdir = logdir
trial = _Trial("0", "some_dir")
sync_config = tune.SyncConfig(syncer=None)
# Create syncer callbacks
callbacks = create_default_callbacks([], sync_config)
syncer_callback = callbacks[-1]
# Sanity check that we got the syncer callback
self.assertTrue(isinstance(syncer_callback, SyncerCallback))
# Sync function should be false (no sync to driver)
self.assertEqual(syncer_callback._sync_function, False)
# Sync to driver is disabled, so this should be no-op
trial_syncer = syncer_callback._get_trial_syncer(trial)
self.assertEqual(trial_syncer.sync_client, NOOP)
def testSyncWaitRetry(self):
class CountingClient(CommandBasedClient):
def __init__(self, *args, **kwargs):
self._sync_ups = 0
self._sync_downs = 0
super(CountingClient, self).__init__(*args, **kwargs)
def _start_process(self, cmd):
if "UPLOAD" in cmd:
self._sync_ups += 1
elif "DOWNLOAD" in cmd:
self._sync_downs += 1
if self._sync_downs == 1:
self._last_cmd = "echo DOWNLOAD && true"
return super(CountingClient, self)._start_process(cmd)
client = CountingClient(
"echo UPLOAD {source} {target} && false",
"echo DOWNLOAD {source} {target} && false",
"echo DELETE {target}",
)
# Fail always
with self.assertRaisesRegex(TuneError, "Failed sync even after"):
client.sync_up("test_source", "test_target")
client.wait_or_retry(max_retries=3, backoff_s=0)
self.assertEquals(client._sync_ups, 3)
# Succeed after second try
client.sync_down("test_source", "test_target")
client.wait_or_retry(max_retries=3, backoff_s=0)
self.assertEquals(client._sync_downs, 2)
def _check_dir_contents(self, path: str):
assert os.path.exists(os.path.join(path, "dir_level0"))
assert os.path.exists(os.path.join(path, "dir_level0", "dir_level1"))
assert os.path.exists(os.path.join(path, "dir_level0", "file_level1.txt"))
with open(os.path.join(path, "dir_level0", "file_level1.txt"), "r") as f:
assert f.read() == "Data\n"
def _testSyncInNodeAndDelete(self, num_workers: int = 1):
temp_source = tempfile.mkdtemp()
temp_up_target = tempfile.mkdtemp()
temp_down_target = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, temp_source)
self.addCleanup(shutil.rmtree, temp_up_target, ignore_errors=True)
self.addCleanup(shutil.rmtree, temp_down_target)
os.makedirs(os.path.join(temp_source, "dir_level0", "dir_level1"))
with open(os.path.join(temp_source, "dir_level0", "file_level1.txt"), "w") as f:
f.write("Data\n")
# Sanity check
self._check_dir_contents(temp_source)
node_ip = ray.util.get_node_ip_address()
futures = [
_sync_dir_on_same_node(
ip=node_ip,
source_path=temp_source,
target_path=temp_up_target,
return_futures=True,
)
for i in range(num_workers)
]
ray.get(futures)
# Check sync up
self._check_dir_contents(temp_up_target)
assert not os.listdir(temp_down_target)
futures = [
_sync_dir_on_same_node(
ip=node_ip,
source_path=temp_up_target,
target_path=temp_down_target,
return_futures=True,
)
for i in range(num_workers)
]
ray.get(futures)
# Check sync down
self._check_dir_contents(temp_down_target)
# Delete in some dir
delete_on_node(node_ip=node_ip, path=temp_up_target)
assert not os.path.exists(temp_up_target)
def testSyncInNodeAndDelete(self):
self._testSyncInNodeAndDelete(num_workers=1)
def testSyncInNodeAndDeleteMultipleWorkers(self):
self._testSyncInNodeAndDelete(num_workers=8)
def _testSyncBetweenNodesAndDelete(self, num_workers: int = 1):
temp_source = tempfile.mkdtemp()
temp_up_target = tempfile.mkdtemp()
temp_down_target = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, temp_source)
self.addCleanup(shutil.rmtree, temp_up_target, ignore_errors=True)
self.addCleanup(shutil.rmtree, temp_down_target)
os.makedirs(os.path.join(temp_source, "dir_level0", "dir_level1"))
with open(os.path.join(temp_source, "dir_level0", "file_level1.txt"), "w") as f:
f.write("Data\n")
# Sanity check
self._check_dir_contents(temp_source)
node_ip = ray.util.get_node_ip_address()
futures = [
_sync_dir_between_different_nodes(
source_ip=node_ip,
source_path=temp_source,
target_ip=node_ip,
target_path=temp_up_target,
return_futures=True,
)[0]
for i in range(num_workers)
]
ray.get(futures)
# Check sync up
self._check_dir_contents(temp_up_target)
# Max size exceeded
with self.assertRaises(RayTaskError):
_sync_dir_between_different_nodes(
source_ip=node_ip,
source_path=temp_up_target,
target_ip=node_ip,
target_path=temp_down_target,
max_size_bytes=2,
)
assert not os.listdir(temp_down_target)
futures = [
_sync_dir_between_different_nodes(
source_ip=node_ip,
source_path=temp_up_target,
target_ip=node_ip,
target_path=temp_down_target,
return_futures=True,
)[0]
for i in range(num_workers)
]
ray.get(futures)
# Check sync down
self._check_dir_contents(temp_down_target)
# Delete in some dir
delete_on_node(node_ip=node_ip, path=temp_up_target)
assert not os.path.exists(temp_up_target)
def testSyncBetweenNodesAndDelete(self):
self._testSyncBetweenNodesAndDelete(num_workers=1)
def testSyncBetweenNodesAndDeleteMultipleWorkers(self):
self._testSyncBetweenNodesAndDelete(num_workers=8)
def _prepareDirForTestSyncRemoteTask(self):
temp_source = tempfile.mkdtemp()
temp_up_target = tempfile.mkdtemp()
temp_down_target = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, temp_source)
self.addCleanup(shutil.rmtree, temp_up_target)
self.addCleanup(shutil.rmtree, temp_down_target)
os.makedirs(os.path.join(temp_source, "A", "a1"))
os.makedirs(os.path.join(temp_source, "A", "a2"))
os.makedirs(os.path.join(temp_source, "B", "b1"))
with open(os.path.join(temp_source, "level_0.txt"), "wt") as fp:
fp.write("Level 0\n")
with open(os.path.join(temp_source, "A", "level_a1.txt"), "wt") as fp:
fp.write("Level A1\n")
with open(os.path.join(temp_source, "A", "a1", "level_a2.txt"), "wt") as fp:
fp.write("Level A2\n")
with open(os.path.join(temp_source, "B", "level_b1.txt"), "wt") as fp:
fp.write("Level B1\n")
return temp_source, temp_up_target, temp_down_target
def testSyncRemoteTaskOnlyDifferencesOnDifferentNodes(self):
"""Tests the RemoteTaskClient sync client with different node logic.
In this test we generate a directory with multiple files.
We then use both ``sync_down`` and ``sync_up`` to synchronize
these to different directories (on the same node). We then assert
that the files have been transferred correctly.
We then edit one of the files and add another one. We then sync
up/down again. In this sync, we assert that only modified and new
files are transferred.
"""
(
temp_source,
temp_up_target,
temp_down_target,
) = self._prepareDirForTestSyncRemoteTask()
this_node_ip = ray.util.get_node_ip_address()
# Sync everything up
client = RemoteTaskClientWithSyncDirBetweenDifferentNodes(_store_remotes=True)
client.sync_up(source=temp_source, target=(this_node_ip, temp_up_target))
client.wait()
# Assume that we synced everything up to second level
self.assertTrue(
os.path.exists(os.path.join(temp_up_target, "A", "a1", "level_a2.txt")),
msg=f"Contents: {os.listdir(temp_up_target)}",
)
with open(os.path.join(temp_up_target, "A", "a1", "level_a2.txt"), "rt") as fp:
self.assertEqual(fp.read(), "Level A2\n")
# Sync everything down
client.sync_down(source=(this_node_ip, temp_source), target=temp_down_target)
client.wait()
# Assume that we synced everything up to second level
self.assertTrue(
os.path.exists(os.path.join(temp_down_target, "A", "a1", "level_a2.txt")),
msg=f"Contents: {os.listdir(temp_down_target)}",
)
with open(
os.path.join(temp_down_target, "A", "a1", "level_a2.txt"), "rt"
) as fp:
self.assertEqual(fp.read(), "Level A2\n")
# Now, edit some stuff in our source. Then confirm only these
# edited files are synced
with open(os.path.join(temp_source, "A", "a1", "level_a2.txt"), "wt") as fp:
fp.write("Level X2\n") # Same length
with open(os.path.join(temp_source, "A", "level_a1x.txt"), "wt") as fp:
fp.write("Level A1X\n") # New file
# Sync up
client.sync_up(source=temp_source, target=(this_node_ip, temp_up_target))
# Hi-jack futures
files_stats = ray.get(client._stored_files_stats)
tarball = ray.get(client._stored_pack_actor_ref.get_full_data.remote())
client.wait()
# Existing file should have new content
with open(os.path.join(temp_up_target, "A", "a1", "level_a2.txt"), "rt") as fp:
self.assertEqual(fp.read(), "Level X2\n")
# New file should be there
with open(os.path.join(temp_up_target, "A", "level_a1x.txt"), "rt") as fp:
self.assertEqual(fp.read(), "Level A1X\n")
# Old file should be there
with open(os.path.join(temp_up_target, "B", "level_b1.txt"), "rt") as fp:
self.assertEqual(fp.read(), "Level B1\n")
# In the target dir, level_a1x was not contained
self.assertIn(os.path.join("A", "a1", "level_a2.txt"), files_stats)
self.assertNotIn(os.path.join("A", "level_a1x.txt"), files_stats)
# Inspect tarball
with tarfile.open(fileobj=io.BytesIO(tarball)) as tar:
files_in_tar = tar.getnames()
self.assertIn(os.path.join("A", "a1", "level_a2.txt"), files_in_tar)
self.assertIn(os.path.join("A", "level_a1x.txt"), files_in_tar)
self.assertNotIn(os.path.join("A", "level_a1.txt"), files_in_tar)
# 6 directories (including root) + 2 files
self.assertEqual(len(files_in_tar), 8, msg=str(files_in_tar))
# Sync down
client.sync_down(source=(this_node_ip, temp_source), target=temp_down_target)
# Hi-jack futures
files_stats = ray.get(client._stored_files_stats)
tarball = ray.get(client._stored_pack_actor_ref.get_full_data.remote())
client.wait()
# Existing file should have new content
with open(
os.path.join(temp_down_target, "A", "a1", "level_a2.txt"), "rt"
) as fp:
self.assertEqual(fp.read(), "Level X2\n")
# New file should be there
with open(os.path.join(temp_down_target, "A", "level_a1x.txt"), "rt") as fp:
self.assertEqual(fp.read(), "Level A1X\n")
# Old file should be there
with open(os.path.join(temp_down_target, "B", "level_b1.txt"), "rt") as fp:
self.assertEqual(fp.read(), "Level B1\n")
# level_a1x.txt did not exist in the target dir, so it is not in the target file stats
self.assertIn(os.path.join("A", "a1", "level_a2.txt"), files_stats)
self.assertNotIn(os.path.join("A", "level_a1x.txt"), files_stats)
# Inspect tarball
with tarfile.open(fileobj=io.BytesIO(tarball)) as tar:
files_in_tar = tar.getnames()
self.assertIn(os.path.join("A", "a1", "level_a2.txt"), files_in_tar)
self.assertIn(os.path.join("A", "level_a1x.txt"), files_in_tar)
self.assertNotIn(os.path.join("A", "level_a1.txt"), files_in_tar)
# 6 directories (including root) + 2 files
self.assertEqual(len(files_in_tar), 8, msg=str(files_in_tar))
if __name__ == "__main__":
import pytest
sys.exit(pytest.main(["-v", __file__]))
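For orientation, the public counterpart of the helper exercised above lives in `ray.tune.utils.file_transfer`. Below is a minimal sketch of copying a directory between two nodes (here, the same node); the keyword arguments mirror the private `_sync_dir_between_different_nodes` helper used in the tests and are an assumption for the public wrapper:
import tempfile
import ray
import ray.util
from ray.tune.utils.file_transfer import sync_dir_between_nodes
ray.init(num_cpus=2)
source = tempfile.mkdtemp()
target = tempfile.mkdtemp()
with open(f"{source}/result.json", "w") as f:
    f.write("{}")
node_ip = ray.util.get_node_ip_address()
# Without return_futures=True this call is expected to block until the
# directory has been packed, transferred and unpacked (as in the tests above).
sync_dir_between_nodes(
    source_ip=node_ip,
    source_path=source,
    target_ip=node_ip,
    target_path=target,
)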

View file

@ -0,0 +1,369 @@
import logging
from typing import List, Optional
import os
import pytest
import shutil
import tempfile
from freezegun import freeze_time
import ray
from ray import tune
from ray.tune import TuneError
from ray.tune.syncer import _DefaultSyncer, Syncer, _validate_upload_dir
from ray.tune.utils.file_transfer import _pack_dir, _unpack_dir
@pytest.fixture
def ray_start_2_cpus():
address_info = ray.init(num_cpus=2, configure_logging=False)
yield address_info
# The code after the yield will run as teardown code.
ray.shutdown()
@pytest.fixture
def temp_data_dirs():
tmp_source = os.path.realpath(tempfile.mkdtemp())
tmp_target = os.path.realpath(tempfile.mkdtemp())
os.makedirs(os.path.join(tmp_source, "subdir", "nested"))
os.makedirs(os.path.join(tmp_source, "subdir_exclude", "something"))
files = [
"level0.txt",
"level0_exclude.txt",
"subdir/level1.txt",
"subdir/level1_exclude.txt",
"subdir/nested/level2.txt",
"subdir_nested_level2_exclude.txt",
"subdir_exclude/something/somewhere.txt",
]
for file in files:
with open(os.path.join(tmp_source, file), "w") as f:
f.write("Data")
yield tmp_source, tmp_target
shutil.rmtree(tmp_source)
shutil.rmtree(tmp_target)
def assert_file(exists: bool, root: str, path: str):
full_path = os.path.join(root, path)
if exists:
assert os.path.exists(full_path)
else:
assert not os.path.exists(full_path)
class TestTrainable(tune.Trainable):
def save_checkpoint(self, checkpoint_dir: str):
with open(os.path.join(checkpoint_dir, "checkpoint.data"), "w") as f:
f.write("Data")
return checkpoint_dir
class CustomSyncer(Syncer):
def __init__(self, sync_period: float = 300.0):
super(CustomSyncer, self).__init__(sync_period=sync_period)
self._sync_status = {}
def sync_up(
self, local_dir: str, remote_dir: str, exclude: Optional[List] = None
) -> bool:
with open(os.path.join(local_dir, "custom_syncer.txt"), "w") as f:
f.write("Data\n")
self._sync_status[remote_dir] = _pack_dir(local_dir)
return True
def sync_down(
self, remote_dir: str, local_dir: str, exclude: Optional[List] = None
) -> bool:
if remote_dir not in self._sync_status:
return False
_unpack_dir(self._sync_status[remote_dir], local_dir)
return True
def delete(self, remote_dir: str) -> bool:
raise NotImplementedError
def retry(self):
raise NotImplementedError
def wait(self):
pass
def test_sync_string_invalid_uri():
with pytest.raises(ValueError):
_validate_upload_dir(tune.SyncConfig(upload_dir="invalid://some/url"))
def test_sync_string_invalid_local():
with pytest.raises(ValueError):
_validate_upload_dir(tune.SyncConfig(upload_dir="/invalid/dir"))
def test_sync_string_valid_local():
_validate_upload_dir(tune.SyncConfig(upload_dir="file:///valid/dir"))
def test_sync_string_valid_s3():
_validate_upload_dir(tune.SyncConfig(upload_dir="s3://valid/bucket"))
def test_syncer_sync_up_down(temp_data_dirs):
"""Check that syncing up and down works"""
tmp_source, tmp_target = temp_data_dirs
syncer = _DefaultSyncer()
syncer.sync_up(
local_dir=tmp_source, remote_dir="memory:///test/test_syncer_sync_up_down"
)
syncer.wait()
syncer.sync_down(
remote_dir="memory:///test/test_syncer_sync_up_down", local_dir=tmp_target
)
syncer.wait()
# Target dir should have all files
assert_file(True, tmp_target, "level0.txt")
assert_file(True, tmp_target, "level0_exclude.txt")
assert_file(True, tmp_target, "subdir/level1.txt")
assert_file(True, tmp_target, "subdir/level1_exclude.txt")
assert_file(True, tmp_target, "subdir/nested/level2.txt")
assert_file(True, tmp_target, "subdir_nested_level2_exclude.txt")
assert_file(True, tmp_target, "subdir_exclude/something/somewhere.txt")
def test_syncer_sync_exclude(temp_data_dirs):
"""Check that the exclude parameter works"""
tmp_source, tmp_target = temp_data_dirs
syncer = _DefaultSyncer()
syncer.sync_up(
local_dir=tmp_source,
remote_dir="memory:///test/test_syncer_sync_exclude",
exclude=["*_exclude*"],
)
syncer.wait()
syncer.sync_down(
remote_dir="memory:///test/test_syncer_sync_exclude", local_dir=tmp_target
)
syncer.wait()
# Excluded files should not be found in target
assert_file(True, tmp_target, "level0.txt")
assert_file(False, tmp_target, "level0_exclude.txt")
assert_file(True, tmp_target, "subdir/level1.txt")
assert_file(False, tmp_target, "subdir/level1_exclude.txt")
assert_file(True, tmp_target, "subdir/nested/level2.txt")
assert_file(False, tmp_target, "subdir_nested_level2_exclude.txt")
assert_file(False, tmp_target, "subdir_exclude/something/somewhere.txt")
def test_sync_up_if_needed(temp_data_dirs):
"""Check that we only sync up again after sync period"""
tmp_source, tmp_target = temp_data_dirs
with freeze_time() as frozen:
syncer = _DefaultSyncer(sync_period=60)
assert syncer.sync_up_if_needed(
local_dir=tmp_source, remote_dir="memory:///test/test_sync_up_not_needed"
)
syncer.wait()
frozen.tick(30)
# Sync period not over yet
assert not syncer.sync_up_if_needed(
local_dir=tmp_source, remote_dir="memory:///test/test_sync_up_not_needed"
)
frozen.tick(30)
# Sync period over, sync again
assert syncer.sync_up_if_needed(
local_dir=tmp_source, remote_dir="memory:///test/test_sync_up_not_needed"
)
def test_sync_down_if_needed(temp_data_dirs):
"""Check that we only sync down again after sync period"""
tmp_source, tmp_target = temp_data_dirs
with freeze_time() as frozen:
syncer = _DefaultSyncer(sync_period=60)
# Populate remote directory
syncer.sync_up(
local_dir=tmp_source, remote_dir="memory:///test/test_sync_down_if_needed"
)
syncer.wait()
assert syncer.sync_down_if_needed(
remote_dir="memory:///test/test_sync_down_if_needed", local_dir=tmp_target
)
syncer.wait()
frozen.tick(30)
# Sync period not over yet
assert not syncer.sync_down_if_needed(
remote_dir="memory:///test/test_sync_down_if_needed", local_dir=tmp_target
)
frozen.tick(30)
# Sync period over, sync again
assert syncer.sync_down_if_needed(
remote_dir="memory:///test/test_sync_down_if_needed", local_dir=tmp_target
)
def test_syncer_still_running_no_sync(temp_data_dirs):
"""Check that no new sync is issued if old sync is still running"""
tmp_source, tmp_target = temp_data_dirs
class FakeSyncProcess:
@property
def is_running(self):
return True
syncer = _DefaultSyncer(sync_period=60)
syncer._sync_process = FakeSyncProcess()
assert not syncer.sync_up_if_needed(
local_dir=tmp_source,
remote_dir="memory:///test/test_syncer_still_running_no_sync",
)
def test_syncer_not_running_sync(temp_data_dirs):
"""Check that new sync is issued if old sync completed"""
tmp_source, tmp_target = temp_data_dirs
class FakeSyncProcess:
@property
def is_running(self):
return False
def wait(self):
return True
syncer = _DefaultSyncer(sync_period=60)
syncer._sync_process = FakeSyncProcess()
assert syncer.sync_up_if_needed(
local_dir=tmp_source,
remote_dir="memory:///test/test_syncer_not_running_sync",
)
def test_syncer_not_running_sync_last_failed(caplog, temp_data_dirs):
"""Check that new sync is issued if old sync completed"""
caplog.set_level(logging.WARNING)
tmp_source, tmp_target = temp_data_dirs
class FakeSyncProcess:
@property
def is_running(self):
return False
def wait(self):
raise RuntimeError("Sync failed")
syncer = _DefaultSyncer(sync_period=60)
syncer._sync_process = FakeSyncProcess()
assert syncer.sync_up_if_needed(
local_dir=tmp_source,
remote_dir="memory:///test/test_syncer_not_running_sync",
)
assert "Last sync command failed" in caplog.text
def test_syncer_delete(temp_data_dirs):
"""Check that deletion on remote storage works"""
tmp_source, tmp_target = temp_data_dirs
syncer = _DefaultSyncer(sync_period=60)
# Populate remote directory
syncer.sync_up(local_dir=tmp_source, remote_dir="memory:///test/test_syncer_delete")
syncer.wait()
syncer.delete(remote_dir="memory:///test/test_syncer_delete")
syncer.sync_down(
remote_dir="memory:///test/test_syncer_delete", local_dir=tmp_target
)
with pytest.raises(TuneError):
syncer.wait()
# Remote storage was deleted, so target should be empty
assert_file(False, tmp_target, "level0.txt")
assert_file(False, tmp_target, "level0_exclude.txt")
assert_file(False, tmp_target, "subdir/level1.txt")
assert_file(False, tmp_target, "subdir/level1_exclude.txt")
assert_file(False, tmp_target, "subdir/nested/level2.txt")
assert_file(False, tmp_target, "subdir_nested_level2_exclude.txt")
assert_file(False, tmp_target, "subdir_exclude/something/somewhere.txt")
def test_syncer_wait_or_retry(temp_data_dirs):
"""Check that the wait or retry API works"""
tmp_source, tmp_target = temp_data_dirs
syncer = _DefaultSyncer(sync_period=60)
# Will fail as dir does not exist
syncer.sync_down(
remote_dir="memory:///test/test_syncer_wait_or_retry", local_dir=tmp_target
)
with pytest.raises(TuneError) as e:
syncer.wait_or_retry(max_retries=3, backoff_s=0)
assert "Failed sync even after 3 retries." in str(e)
def test_trainable_syncer_default(ray_start_2_cpus, temp_data_dirs):
"""Check that Trainable.save() triggers syncing using default syncing"""
tmp_source, tmp_target = temp_data_dirs
trainable = ray.remote(TestTrainable).remote(
remote_checkpoint_dir=f"file://{tmp_target}"
)
checkpoint_dir = ray.get(trainable.save.remote())
assert_file(True, tmp_target, os.path.join(checkpoint_dir, "checkpoint.data"))
assert_file(False, tmp_target, os.path.join(checkpoint_dir, "custom_syncer.txt"))
def test_trainable_syncer_custom(ray_start_2_cpus, temp_data_dirs):
"""Check that Trainable.save() triggers syncing using custom syncer"""
tmp_source, tmp_target = temp_data_dirs
trainable = ray.remote(TestTrainable).remote(
remote_checkpoint_dir=f"file://{tmp_target}",
custom_syncer=CustomSyncer(),
)
checkpoint_dir = ray.get(trainable.save.remote())
assert_file(True, tmp_target, os.path.join(checkpoint_dir, "checkpoint.data"))
assert_file(True, tmp_target, os.path.join(checkpoint_dir, "custom_syncer.txt"))
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", __file__]))
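The same flow can be reproduced outside of pytest. A minimal sketch, assuming the internal `_DefaultSyncer` import path stays as used above, that syncs a directory to a local `file://` URI and back:
import os
import tempfile
from ray.tune.syncer import _DefaultSyncer
local_dir = tempfile.mkdtemp()
remote_uri = f"file://{tempfile.mkdtemp()}"
with open(os.path.join(local_dir, "result.json"), "w") as f:
    f.write("{}")
syncer = _DefaultSyncer(sync_period=60)
syncer.sync_up(local_dir=local_dir, remote_dir=remote_uri)
syncer.wait()
download_dir = tempfile.mkdtemp()
syncer.sync_down(remote_dir=remote_uri, local_dir=download_dir)
syncer.wait()
assert os.path.exists(os.path.join(download_dir, "result.json"))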

View file

@ -0,0 +1,392 @@
import logging
import os
import shutil
import tempfile
from typing import Optional
import pytest
from freezegun import freeze_time
import ray.util
from ray.tune import TuneError
from ray.tune.result import NODE_IP
from ray.tune.syncer import (
DEFAULT_SYNC_PERIOD,
SyncConfig,
SyncerCallback,
_BackgroundProcess,
)
from ray.tune.utils.callback import create_default_callbacks
from ray.tune.utils.file_transfer import sync_dir_between_nodes
from ray.util.ml_utils.checkpoint_manager import CheckpointStorage, _TrackedCheckpoint
@pytest.fixture
def ray_start_2_cpus():
address_info = ray.init(num_cpus=2, configure_logging=False)
yield address_info
# The code after the yield will run as teardown code.
ray.shutdown()
@pytest.fixture
def temp_data_dirs():
tmp_source = os.path.realpath(tempfile.mkdtemp())
tmp_target = os.path.realpath(tempfile.mkdtemp())
os.makedirs(os.path.join(tmp_source, "subdir", "nested"))
os.makedirs(os.path.join(tmp_source, "subdir_exclude", "something"))
files = [
"level0.txt",
"level0_exclude.txt",
"subdir/level1.txt",
"subdir/level1_exclude.txt",
"subdir/nested/level2.txt",
"subdir_nested_level2_exclude.txt",
"subdir_exclude/something/somewhere.txt",
]
for file in files:
with open(os.path.join(tmp_source, file), "w") as f:
f.write("Data")
yield tmp_source, tmp_target
shutil.rmtree(tmp_source)
shutil.rmtree(tmp_target)
def assert_file(exists: bool, root: str, path: str):
full_path = os.path.join(root, path)
if exists:
assert os.path.exists(full_path)
else:
assert not os.path.exists(full_path)
class MockTrial:
def __init__(self, trial_id: str, logdir: str):
self.trial_id = trial_id
self.last_result = {NODE_IP: ray.util.get_node_ip_address()}
self.uses_cloud_checkpointing = False
self.sync_on_checkpoint = True
self.logdir = logdir
class TestSyncerCallback(SyncerCallback):
def __init__(
self,
enabled: bool = True,
sync_period: float = DEFAULT_SYNC_PERIOD,
local_logdir_override: Optional[str] = None,
remote_logdir_override: Optional[str] = None,
):
super(TestSyncerCallback, self).__init__(
enabled=enabled, sync_period=sync_period
)
self.local_logdir_override = local_logdir_override
self.remote_logdir_override = remote_logdir_override
def _local_trial_logdir(self, trial):
if self.local_logdir_override:
return self.local_logdir_override
return super(TestSyncerCallback, self)._local_trial_logdir(trial)
def _remote_trial_logdir(self, trial):
if self.remote_logdir_override:
return self.remote_logdir_override
return super(TestSyncerCallback, self)._remote_trial_logdir(trial)
def _get_trial_sync_process(self, trial):
return self._sync_processes.setdefault(
trial.trial_id, MaybeFailingProcess(sync_dir_between_nodes)
)
class MaybeFailingProcess(_BackgroundProcess):
should_fail = False
def wait(self):
result = super(MaybeFailingProcess, self).wait()
if self.should_fail:
raise TuneError("Syncing failed.")
return result
def test_syncer_callback_disabled():
"""Check that syncer=None disables callback"""
callbacks = create_default_callbacks(
callbacks=[], sync_config=SyncConfig(syncer=None)
)
syncer_callback = None
for cb in callbacks:
if isinstance(cb, SyncerCallback):
syncer_callback = cb
trial1 = MockTrial(trial_id="a", logdir=None)
trial1.uses_cloud_checkpointing = False
assert syncer_callback
assert not syncer_callback._enabled
# Syncer disabled, so no-op
assert not syncer_callback._sync_trial_dir(trial1)
# This should not raise an error for a non-existent directory
syncer_callback.on_checkpoint(
iteration=1,
trials=[],
trial=trial1,
checkpoint=_TrackedCheckpoint(
dir_or_data="/does/not/exist", storage_mode=CheckpointStorage.PERSISTENT
),
)
def test_syncer_callback_noop_on_trial_cloud_checkpointing():
"""Check that trial using cloud checkpointing disables sync to driver"""
callbacks = create_default_callbacks(callbacks=[], sync_config=SyncConfig())
syncer_callback = None
for cb in callbacks:
if isinstance(cb, SyncerCallback):
syncer_callback = cb
trial1 = MockTrial(trial_id="a", logdir=None)
trial1.uses_cloud_checkpointing = True
assert syncer_callback
assert syncer_callback._enabled
# Cloud checkpointing set, so no-op
assert not syncer_callback._sync_trial_dir(trial1)
# This should not raise an error for a non-existent directory
syncer_callback.on_checkpoint(
iteration=1,
trials=[],
trial=trial1,
checkpoint=_TrackedCheckpoint(
dir_or_data="/does/not/exist", storage_mode=CheckpointStorage.PERSISTENT
),
)
def test_syncer_callback_op_on_no_cloud_checkpointing():
"""Check that without cloud checkpointing sync to driver is enabled"""
callbacks = create_default_callbacks(callbacks=[], sync_config=SyncConfig())
syncer_callback = None
for cb in callbacks:
if isinstance(cb, SyncerCallback):
syncer_callback = cb
trial1 = MockTrial(trial_id="a", logdir=None)
trial1.uses_cloud_checkpointing = False
assert syncer_callback
assert syncer_callback._enabled
assert syncer_callback._sync_trial_dir(trial1)
def test_syncer_callback_sync(ray_start_2_cpus, temp_data_dirs):
"""Check that on_trial_result triggers syncing"""
tmp_source, tmp_target = temp_data_dirs
syncer_callback = TestSyncerCallback(local_logdir_override=tmp_target)
trial1 = MockTrial(trial_id="a", logdir=tmp_source)
syncer_callback.on_trial_result(iteration=1, trials=[], trial=trial1, result={})
syncer_callback.wait_for_all()
assert_file(True, tmp_target, "level0.txt")
assert_file(True, tmp_target, "level0_exclude.txt")
assert_file(True, tmp_target, "subdir/level1.txt")
assert_file(True, tmp_target, "subdir/level1_exclude.txt")
assert_file(True, tmp_target, "subdir/nested/level2.txt")
assert_file(True, tmp_target, "subdir_nested_level2_exclude.txt")
assert_file(True, tmp_target, "subdir_exclude/something/somewhere.txt")
def test_syncer_callback_sync_period(ray_start_2_cpus, temp_data_dirs):
"""Check that on_trial_result triggers syncing, obeying sync period"""
tmp_source, tmp_target = temp_data_dirs
with freeze_time() as frozen:
syncer_callback = TestSyncerCallback(
sync_period=60, local_logdir_override=tmp_target
)
trial1 = MockTrial(trial_id="a", logdir=tmp_source)
syncer_callback.on_trial_result(iteration=1, trials=[], trial=trial1, result={})
syncer_callback.wait_for_all()
assert_file(True, tmp_target, "level0.txt")
assert_file(False, tmp_target, "level0_new.txt")
# Add new file to source directory
with open(os.path.join(tmp_source, "level0_new.txt"), "w") as f:
f.write("Data\n")
frozen.tick(30)
# Should not sync after 30 seconds
syncer_callback.on_trial_result(iteration=2, trials=[], trial=trial1, result={})
syncer_callback.wait_for_all()
assert_file(True, tmp_target, "level0.txt")
assert_file(False, tmp_target, "level0_new.txt")
frozen.tick(30)
# Should sync after 60 seconds
syncer_callback.on_trial_result(iteration=3, trials=[], trial=trial1, result={})
syncer_callback.wait_for_all()
assert_file(True, tmp_target, "level0.txt")
assert_file(True, tmp_target, "level0_new.txt")
def test_syncer_callback_force_on_checkpoint(ray_start_2_cpus, temp_data_dirs):
"""Check that on_checkpoint forces syncing"""
tmp_source, tmp_target = temp_data_dirs
with freeze_time() as frozen:
syncer_callback = TestSyncerCallback(
sync_period=60, local_logdir_override=tmp_target
)
trial1 = MockTrial(trial_id="a", logdir=tmp_source)
syncer_callback.on_trial_result(iteration=1, trials=[], trial=trial1, result={})
syncer_callback.wait_for_all()
assert_file(True, tmp_target, "level0.txt")
assert_file(False, tmp_target, "level0_new.txt")
# Add new file to source directory
with open(os.path.join(tmp_source, "level0_new.txt"), "w") as f:
f.write("Data\n")
assert_file(False, tmp_target, "level0_new.txt")
frozen.tick(30)
# Should sync as checkpoint observed
syncer_callback.on_checkpoint(
iteration=2,
trials=[],
trial=trial1,
checkpoint=_TrackedCheckpoint(
dir_or_data=tmp_target, storage_mode=CheckpointStorage.PERSISTENT
),
)
syncer_callback.wait_for_all()
assert_file(True, tmp_target, "level0.txt")
assert_file(True, tmp_target, "level0_new.txt")
def test_syncer_callback_force_on_complete(ray_start_2_cpus, temp_data_dirs):
"""Check that on_trial_complete forces syncing"""
tmp_source, tmp_target = temp_data_dirs
with freeze_time() as frozen:
syncer_callback = TestSyncerCallback(
sync_period=60, local_logdir_override=tmp_target
)
trial1 = MockTrial(trial_id="a", logdir=tmp_source)
syncer_callback.on_trial_result(iteration=1, trials=[], trial=trial1, result={})
syncer_callback.wait_for_all()
assert_file(True, tmp_target, "level0.txt")
assert_file(False, tmp_target, "level0_new.txt")
# Add new file to source directory
with open(os.path.join(tmp_source, "level0_new.txt"), "w") as f:
f.write("Data\n")
assert_file(False, tmp_target, "level0_new.txt")
frozen.tick(30)
# Should sync because the trial completed
syncer_callback.on_trial_complete(
iteration=2,
trials=[],
trial=trial1,
)
syncer_callback.wait_for_all()
assert_file(True, tmp_target, "level0.txt")
assert_file(True, tmp_target, "level0_new.txt")
def test_syncer_callback_wait_for_all_error(ray_start_2_cpus, temp_data_dirs):
"""Check that syncer errors are caught correctly in wait_for_all()"""
tmp_source, tmp_target = temp_data_dirs
syncer_callback = TestSyncerCallback(
sync_period=0,
local_logdir_override=tmp_target,
)
trial1 = MockTrial(trial_id="a", logdir=tmp_source)
# Inject FailingProcess into callback
sync_process = syncer_callback._get_trial_sync_process(trial1)
sync_process.should_fail = True
# This sync will fail because of the injected failing sync process
syncer_callback.on_trial_result(iteration=1, trials=[], trial=trial1, result={})
with pytest.raises(TuneError) as e:
syncer_callback.wait_for_all()
assert "At least one" in e
def test_syncer_callback_log_error(caplog, ray_start_2_cpus, temp_data_dirs):
"""Check that errors in a previous sync are logged correctly"""
caplog.set_level(logging.ERROR, logger="ray.tune.syncer")
tmp_source, tmp_target = temp_data_dirs
syncer_callback = TestSyncerCallback(
sync_period=0,
local_logdir_override=tmp_target,
)
trial1 = MockTrial(trial_id="a", logdir=tmp_source)
# Inject FailingProcess into callback
sync_process = syncer_callback._get_trial_sync_process(trial1)
syncer_callback.on_trial_result(iteration=1, trials=[], trial=trial1, result={})
# We have not wait()ed yet, so no error has been logged so far
assert not caplog.text
assert_file(False, tmp_target, "level0.txt")
sync_process.should_fail = True
# When the previous sync process failed, an error is logged but a new sync is started
syncer_callback.on_trial_complete(iteration=2, trials=[], trial=trial1)
assert (
"An error occurred during the checkpoint syncing of the previous checkpoint"
in caplog.text
)
sync_process.should_fail = False
syncer_callback.wait_for_all()
assert_file(True, tmp_target, "level0.txt")
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", __file__]))
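As a usage note, the throttling tested above is driven by `SyncConfig.sync_period`; a minimal sketch (the trainable function is a placeholder) of lowering the period for experiments that should sync trial folders to the driver more often:
from ray import tune
def my_trainable(config):  # placeholder trainable
    tune.report(score=1)
tune.run(
    my_trainable,
    sync_config=tune.SyncConfig(
        # No upload_dir is set, so trial directories are synced to the driver
        # node via the SyncerCallback, at most every 30 seconds on results
        # (checkpoints and trial completion force a sync).
        sync_period=30,
    ),
)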

View file

@ -27,7 +27,7 @@ from ray.tune.utils.trainable import TrainableUtil
@pytest.mark.parametrize("logdir", ["~/tmp/exp/trial", "~/tmp/exp/trial/"])
def test_find_rel_checkpoint_dir(checkpoint_path, logdir):
assert (
TrainableUtil.find_rel_checkpoint_dir(logdir, checkpoint_path) == "checkpoint0/"
TrainableUtil.find_rel_checkpoint_dir(logdir, checkpoint_path) == "checkpoint0"
)

View file

@ -6,7 +6,6 @@ import shutil
import sys
import tempfile
import unittest
from unittest.mock import patch
import ray
from ray.rllib import _register_all
@ -24,7 +23,7 @@ from ray.tune.suggest.repeater import Repeater
from ray.tune.suggest._mock import _MockSuggestionAlgorithm
from ray.tune.suggest.suggestion import Searcher, ConcurrencyLimiter
from ray.tune.suggest.search_generator import SearchGenerator
from ray.tune.syncer import SyncConfig
from ray.tune.syncer import SyncConfig, Syncer
from ray.tune.tests.utils_for_test_trial_runner import TrialResultObserver
@ -748,19 +747,37 @@ class TrialRunnerTest3(unittest.TestCase):
self.assertTrue(trials[0].has_checkpoint())
self.assertEqual(num_checkpoints(trials[0]), 2)
@patch("ray.tune.syncer.SYNC_PERIOD", 0)
def testCheckpointAutoPeriod(self):
ray.init(num_cpus=3)
# This makes checkpointing take 2 seconds.
def sync_up(source, target, exclude=None):
class CustomSyncer(Syncer):
def __init__(self, sync_period: float = 300.0):
super(CustomSyncer, self).__init__(sync_period=sync_period)
self._sync_status = {}
def sync_up(
self, local_dir: str, remote_dir: str, exclude: list = None
) -> bool:
time.sleep(2)
return True
def sync_down(
self, remote_dir: str, local_dir: str, exclude: list = None
) -> bool:
time.sleep(2)
return True
def delete(self, remote_dir: str) -> bool:
pass
runner = TrialRunner(
local_checkpoint_dir=self.tmpdir,
checkpoint_period="auto",
sync_config=SyncConfig(upload_dir="fake", syncer=sync_up),
sync_config=SyncConfig(
upload_dir="fake", syncer=CustomSyncer(), sync_period=0
),
remote_checkpoint_dir="fake",
)
runner.add_trial(Trial("__fake", config={"user_checkpoint_freq": 1}))

View file

@ -290,12 +290,6 @@ class TrialRunnerCallbacks(unittest.TestCase):
first_logger_pos, last_logger_pos, syncer_pos = get_positions(callbacks)
self.assertLess(last_logger_pos, syncer_pos)
# This should throw an error as the syncer comes before the logger
with self.assertRaises(ValueError):
callbacks = create_default_callbacks(
[SyncerCallback(None), LoggerCallback()], SyncConfig(), None
)
# This should be reordered but preserve the regular callback order
[mc1, mc2, mc3] = [Callback(), Callback(), Callback()]
# Has to be legacy logger to avoid logger callback creation

View file

@ -0,0 +1,227 @@
import io
import tarfile
import os
import pytest
import shutil
import tempfile
from ray.exceptions import RayTaskError
from ray.tune.utils.file_transfer import (
_sync_dir_between_different_nodes,
delete_on_node,
_sync_dir_on_same_node,
)
import ray.util
@pytest.fixture
def ray_start_2_cpus():
address_info = ray.init(num_cpus=2, configure_logging=False)
yield address_info
# The code after the yield will run as teardown code.
ray.shutdown()
@pytest.fixture
def temp_data_dirs():
tmp_source = os.path.realpath(tempfile.mkdtemp())
tmp_target = os.path.realpath(tempfile.mkdtemp())
os.makedirs(os.path.join(tmp_source, "subdir", "nested"))
os.makedirs(os.path.join(tmp_source, "subdir_exclude", "something"))
files = [
"level0.txt",
"level0_exclude.txt",
"subdir/level1.txt",
"subdir/level1_exclude.txt",
"subdir/nested/level2.txt",
"subdir_nested_level2_exclude.txt",
"subdir_exclude/something/somewhere.txt",
]
for file in files:
with open(os.path.join(tmp_source, file), "w") as f:
f.write("Data")
yield tmp_source, tmp_target
shutil.rmtree(tmp_source)
shutil.rmtree(tmp_target)
def assert_file(exists: bool, root: str, path: str):
full_path = os.path.join(root, path)
if exists:
assert os.path.exists(full_path)
else:
assert not os.path.exists(full_path)
def test_sync_nodes(ray_start_2_cpus, temp_data_dirs):
"""Check that syncing between nodes works (data is found in target directory)"""
tmp_source, tmp_target = temp_data_dirs
assert_file(True, tmp_source, "level0.txt")
assert_file(True, tmp_source, "subdir/level1.txt")
assert_file(False, tmp_target, "level0.txt")
assert_file(False, tmp_target, "subdir/level1.txt")
node_ip = ray.util.get_node_ip_address()
_sync_dir_between_different_nodes(
source_ip=node_ip,
source_path=tmp_source,
target_ip=node_ip,
target_path=tmp_target,
)
assert_file(True, tmp_target, "level0.txt")
assert_file(True, tmp_target, "subdir/level1.txt")
def test_sync_nodes_only_diff(ray_start_2_cpus, temp_data_dirs):
"""Check that only differing files are synced between nodes"""
tmp_source, tmp_target = temp_data_dirs
# Sanity check
assert_file(True, tmp_source, "level0.txt")
assert_file(True, tmp_source, "subdir/level1.txt")
assert_file(False, tmp_target, "level0.txt")
assert_file(False, tmp_target, "level0_new.txt")
node_ip = ray.util.get_node_ip_address()
_sync_dir_between_different_nodes(
source_ip=node_ip,
source_path=tmp_source,
target_ip=node_ip,
target_path=tmp_target,
)
assert_file(True, tmp_source, "level0.txt")
assert_file(True, tmp_target, "level0.txt")
assert_file(True, tmp_target, "subdir/level1.txt")
assert_file(False, tmp_target, "level0_new.txt")
# Add new file
with open(os.path.join(tmp_source, "level0_new.txt"), "w") as f:
f.write("Data\n")
# Modify existing file
with open(os.path.join(tmp_source, "subdir", "level1.txt"), "w") as f:
f.write("New data\n")
unpack, pack_actor, files_stats = _sync_dir_between_different_nodes(
source_ip=node_ip,
source_path=tmp_source,
target_ip=node_ip,
target_path=tmp_target,
return_futures=True,
)
files_stats = ray.get(files_stats)
tarball = ray.get(pack_actor.get_full_data.remote())
assert "./level0.txt" in files_stats
assert "./level0_new.txt" not in files_stats # Was not in target dir
assert "subdir/level1.txt" in files_stats
with tarfile.open(fileobj=io.BytesIO(tarball)) as tar:
files_in_tar = tar.getnames()
assert "./level0.txt" not in files_in_tar
assert "./level0_new.txt" in files_in_tar
assert "subdir/level1.txt" in files_in_tar
assert len(files_in_tar) == 7 # 3 files, 4 dirs (including root)
ray.get(unpack) # Wait until finished for teardown
def test_delete_on_node(ray_start_2_cpus, temp_data_dirs):
"""Check that delete on node works."""
tmp_source, tmp_target = temp_data_dirs
assert_file(True, tmp_source, "level0.txt")
assert_file(True, tmp_source, "subdir/level1.txt")
node_ip = ray.util.get_node_ip_address()
delete_on_node(
node_ip=node_ip,
path=tmp_source,
)
assert_file(False, tmp_source, "level0.txt")
assert_file(False, tmp_source, "subdir/level1.txt")
# Re-create dir for teardown
os.makedirs(tmp_source, exist_ok=True)
@pytest.mark.parametrize("num_workers", [1, 8])
def test_multi_sync_same_node(ray_start_2_cpus, temp_data_dirs, num_workers):
"""Check that multiple competing syncs to the same node+dir don't interfere"""
tmp_source, tmp_target = temp_data_dirs
assert_file(True, tmp_source, "level0.txt")
assert_file(True, tmp_source, "subdir/level1.txt")
node_ip = ray.util.get_node_ip_address()
futures = [
_sync_dir_on_same_node(
ip=node_ip,
source_path=tmp_source,
target_path=tmp_target,
return_futures=True,
)
for _ in range(num_workers)
]
ray.get(futures)
assert_file(True, tmp_target, "level0.txt")
assert_file(True, tmp_target, "subdir/level1.txt")
@pytest.mark.parametrize("num_workers", [1, 8])
def test_multi_sync_different_node(ray_start_2_cpus, temp_data_dirs, num_workers):
"""Check that multiple competing syncs to the same dir don't interfere"""
tmp_source, tmp_target = temp_data_dirs
assert_file(True, tmp_source, "level0.txt")
assert_file(True, tmp_source, "subdir/level1.txt")
node_ip = ray.util.get_node_ip_address()
futures = [
_sync_dir_between_different_nodes(
source_ip=node_ip,
source_path=tmp_source,
target_ip=node_ip,
target_path=tmp_target,
return_futures=True,
)[0]
for _ in range(num_workers)
]
ray.get(futures)
assert_file(True, tmp_target, "level0.txt")
assert_file(True, tmp_target, "subdir/level1.txt")
def test_max_size_exceeded(ray_start_2_cpus, temp_data_dirs):
tmp_source, tmp_target = temp_data_dirs
node_ip = ray.util.get_node_ip_address()
with pytest.raises(RayTaskError):
_sync_dir_between_different_nodes(
source_ip=node_ip,
source_path=tmp_source,
target_ip=node_ip,
target_path=tmp_target,
max_size_bytes=2,
)
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", __file__]))

View file

@ -39,7 +39,7 @@ from ray.tune.result import (
STDOUT_FILE,
STDERR_FILE,
)
from ray.tune.sync_client import get_sync_client, get_cloud_sync_client
from ray.tune.syncer import Syncer
from ray.tune.utils import UtilMonitor
from ray.tune.utils.placement_groups import PlacementGroupFactory
from ray.tune.utils.trainable import TrainableUtil
@ -91,14 +91,12 @@ class Trainable:
"""
_sync_function_tpl = None
def __init__(
self,
config: Dict[str, Any] = None,
logger_creator: Callable[[Dict[str, Any]], Logger] = None,
remote_checkpoint_dir: Optional[str] = None,
sync_function_tpl: Optional[str] = None,
custom_syncer: Optional[Syncer] = None,
):
"""Initialize an Trainable.
@ -116,8 +114,8 @@ class Trainable:
remote_checkpoint_dir: Upload directory (S3 or GS path).
This is **per trial** directory,
which is different from **per checkpoint** directory.
sync_function_tpl: Sync function template to use. Defaults
to `cls._sync_function` (which defaults to `None`).
custom_syncer: Syncer used for synchronizing data from Ray nodes
to external storage.
"""
self._experiment_id = uuid.uuid4().hex
@ -168,26 +166,13 @@ class Trainable:
self._monitor = UtilMonitor(start=log_sys_usage)
self.remote_checkpoint_dir = remote_checkpoint_dir
self.sync_function_tpl = sync_function_tpl or self._sync_function_tpl
self.custom_syncer = custom_syncer
self.storage_client = None
if self.uses_cloud_checkpointing and self.sync_function_tpl:
# Keep this only for custom sync functions and
# backwards compatibility.
# Todo (krfricke): We should find a way to register custom
# syncers in Checkpoints rather than passing storage clients
self.storage_client = self._create_storage_client()
@property
def uses_cloud_checkpointing(self):
return bool(self.remote_checkpoint_dir)
def _create_storage_client(self):
"""Returns a storage client."""
return get_sync_client(self.sync_function_tpl) or get_cloud_sync_client(
self.remote_checkpoint_dir
)
def _storage_path(self, local_path):
"""Converts a `local_path` to be based off of
`self.remote_checkpoint_dir`."""
@ -463,16 +448,17 @@ class Trainable:
return checkpoint_path
def _maybe_save_to_cloud(self, checkpoint_dir: str):
def _maybe_save_to_cloud(self, checkpoint_dir: str) -> bool:
# Derived classes like the FunctionRunner might call this
if self.uses_cloud_checkpointing:
if self.storage_client:
# Keep for backwards compatibility, remove after deprecation
self.storage_client.sync_up(
if not self.uses_cloud_checkpointing:
return False
if self.custom_syncer:
self.custom_syncer.sync_up(
checkpoint_dir, self._storage_path(checkpoint_dir)
)
self.storage_client.wait_or_retry()
return
self.custom_syncer.wait_or_retry()
return True
checkpoint = Checkpoint.from_directory(checkpoint_dir)
retry_fn(
@ -481,6 +467,33 @@ class Trainable:
num_retries=3,
sleep_time=1,
)
return True
def _maybe_load_from_cloud(self, checkpoint_path: str) -> bool:
if not self.uses_cloud_checkpointing:
return False
rel_checkpoint_dir = TrainableUtil.find_rel_checkpoint_dir(
self.logdir, checkpoint_path
)
external_uri = os.path.join(self.remote_checkpoint_dir, rel_checkpoint_dir)
local_dir = os.path.join(self.logdir, rel_checkpoint_dir)
if self.custom_syncer:
# Only keep for backwards compatibility
self.custom_syncer.sync_down(remote_dir=external_uri, local_dir=local_dir)
self.custom_syncer.wait_or_retry()
return True
checkpoint = Checkpoint.from_uri(external_uri)
retry_fn(
lambda: checkpoint.to_directory(local_dir),
subprocess.CalledProcessError,
num_retries=3,
sleep_time=1,
)
return True
def save_to_object(self):
"""Saves the current model state to a Python object.
@ -533,26 +546,7 @@ class Trainable:
if isinstance(checkpoint_path, TrialCheckpoint):
checkpoint_path = checkpoint_path.local_path
if self.uses_cloud_checkpointing:
rel_checkpoint_dir = TrainableUtil.find_rel_checkpoint_dir(
self.logdir, checkpoint_path
)
external_uri = os.path.join(self.remote_checkpoint_dir, rel_checkpoint_dir)
local_dir = os.path.join(self.logdir, rel_checkpoint_dir)
if self.storage_client:
# Only keep for backwards compatibility
self.storage_client.sync_down(external_uri, local_dir)
self.storage_client.wait_or_retry()
else:
checkpoint = Checkpoint.from_uri(external_uri)
retry_fn(
lambda: checkpoint.to_directory(local_dir),
subprocess.CalledProcessError,
num_retries=3,
sleep_time=1,
)
elif (
if not self._maybe_load_from_cloud(checkpoint_path) and (
# If a checkpoint source IP is given
checkpoint_node_ip
# And the checkpoint does not currently exist on the local node
@ -566,6 +560,13 @@ class Trainable:
if checkpoint:
checkpoint.to_directory(checkpoint_path)
if not os.path.exists(checkpoint_path):
raise ValueError(
f"Could not recover from checkpoint as it does not exist on local "
f"disk and was not available on cloud storage or another Ray node. "
f"Got checkpoint path: {checkpoint_path} and IP {checkpoint_node_ip}"
)
with open(checkpoint_path + ".tune_metadata", "rb") as f:
metadata = pickle.load(f)
self._experiment_id = metadata["experiment_id"]
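The new constructor arguments can also be exercised directly, mirroring the tests above. A minimal sketch, assuming a user-defined `Trainable` subclass (`MyTrainable` is hypothetical) and a no-op `Syncer` implementation:
import tempfile
import ray
from ray.tune import Trainable
from ray.tune.syncer import Syncer
class MyTrainable(Trainable):
    # Hypothetical minimal trainable, modeled on the TestTrainable above.
    def save_checkpoint(self, checkpoint_dir: str):
        with open(f"{checkpoint_dir}/checkpoint.data", "w") as f:
            f.write("data")
        return checkpoint_dir
class PassthroughSyncer(Syncer):
    # Illustrative syncer that pretends every transfer succeeded.
    def sync_up(self, local_dir, remote_dir, exclude=None):
        return True
    def sync_down(self, remote_dir, local_dir, exclude=None):
        return True
    def delete(self, remote_dir):
        return True
ray.init(num_cpus=2)
trainable = ray.remote(MyTrainable).remote(
    remote_checkpoint_dir=f"file://{tempfile.mkdtemp()}",
    custom_syncer=PassthroughSyncer(),
)
# save() routes the cloud upload through the custom syncer via _maybe_save_to_cloud()
checkpoint_dir = ray.get(trainable.save.remote())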

View file

@ -32,6 +32,7 @@ from ray.tune.result import (
DEBUG_METRICS,
)
from ray.tune.resources import Resources
from ray.tune.syncer import Syncer
from ray.tune.utils.placement_groups import (
PlacementGroupFactory,
resource_dict_to_pg_factory,
@ -251,7 +252,7 @@ class Trial:
placement_group_factory: Optional[PlacementGroupFactory] = None,
stopping_criterion: Optional[Dict[str, float]] = None,
remote_checkpoint_dir: Optional[str] = None,
sync_function_tpl: Optional[str] = None,
custom_syncer: Optional[Syncer] = None,
checkpoint_freq: int = 0,
checkpoint_at_end: bool = False,
sync_on_checkpoint: bool = True,
@ -367,9 +368,9 @@ class Trial:
else:
self.remote_checkpoint_dir_prefix = None
if sync_function_tpl == "auto" or not isinstance(sync_function_tpl, str):
sync_function_tpl = None
self.sync_function_tpl = sync_function_tpl
if custom_syncer == "auto" or not isinstance(custom_syncer, Syncer):
custom_syncer = None
self.custom_syncer = custom_syncer
self.checkpoint_freq = checkpoint_freq
self.checkpoint_at_end = checkpoint_at_end

View file

@ -34,7 +34,7 @@ from ray.tune.result import (
from ray.tune.schedulers import FIFOScheduler, TrialScheduler
from ray.tune.stopper import NoopStopper, Stopper
from ray.tune.suggest import BasicVariantGenerator, SearchAlgorithm
from ray.tune.syncer import CloudSyncer, get_cloud_syncer, SyncConfig
from ray.tune.syncer import SyncConfig, get_node_to_storage_syncer, Syncer
from ray.tune.trial import Trial
from ray.tune.utils import warn_if_slow, flatten_dict
from ray.tune.utils.log import Verbosity, has_verbosity
@ -110,8 +110,10 @@ class _ExperimentCheckpointManager:
checkpoint_period: Union[int, float, str],
start_time: float,
session_str: str,
syncer: CloudSyncer,
sync_trial_checkpoints: bool = True,
syncer: Syncer,
sync_trial_checkpoints: bool,
local_dir: str,
remote_dir: str,
):
self._checkpoint_dir = checkpoint_dir
self._auto_checkpoint_enabled = checkpoint_period == "auto"
@ -125,6 +127,8 @@ class _ExperimentCheckpointManager:
self._syncer = syncer
self._sync_trial_checkpoints = sync_trial_checkpoints
self._local_dir = local_dir
self._remote_dir = remote_dir
self._last_checkpoint_time = 0.0
@ -183,12 +187,22 @@ class _ExperimentCheckpointManager:
else:
exclude = ["*/checkpoint_*"]
if self._syncer:
if force:
# Wait until previous sync command finished
self._syncer.wait()
self._syncer.sync_up(exclude=exclude)
self._syncer.sync_up(
local_dir=self._local_dir,
remote_dir=self._remote_dir,
exclude=exclude,
)
else:
self._syncer.sync_up_if_needed(exclude=exclude)
self._syncer.sync_up_if_needed(
local_dir=self._local_dir,
remote_dir=self._remote_dir,
exclude=exclude,
)
checkpoint_time_taken = time.monotonic() - checkpoint_time_start
if self._auto_checkpoint_enabled:
@ -359,9 +373,8 @@ class TrialRunner:
sync_config = sync_config or SyncConfig()
self._remote_checkpoint_dir = remote_checkpoint_dir
self._syncer = get_cloud_syncer(
local_checkpoint_dir, remote_checkpoint_dir, sync_config.syncer
)
self._syncer = get_node_to_storage_syncer(sync_config)
self._stopper = stopper or NoopStopper()
self._resumed = False
@ -438,6 +451,8 @@ class TrialRunner:
session_str=self._session_str,
syncer=self._syncer,
sync_trial_checkpoints=sync_trial_checkpoints,
local_dir=self._local_checkpoint_dir,
remote_dir=self._remote_checkpoint_dir,
)
@property
@ -471,10 +486,12 @@ class TrialRunner:
)
# Not clear if we need this assertion, since we should always have a
# local checkpoint dir.
assert self._local_checkpoint_dir or self._remote_checkpoint_dir
assert self._local_checkpoint_dir or (
self._remote_checkpoint_dir and self._syncer
)
if resume_type == "AUTO":
if self._remote_checkpoint_dir:
if self._remote_checkpoint_dir and self._syncer:
logger.info(
f"Trying to find and download experiment checkpoint at "
f"{self._remote_checkpoint_dir}"
@ -482,7 +499,10 @@ class TrialRunner:
# Todo: This syncs the entire experiment including trial
# checkpoints. We should exclude these in the future.
try:
self._syncer.sync_down_if_needed()
self._syncer.sync_down_if_needed(
remote_dir=self._remote_checkpoint_dir,
local_dir=self._local_checkpoint_dir,
)
self._syncer.wait()
except TuneError as e:
logger.warning(
@ -548,9 +568,10 @@ class TrialRunner:
f"({self._remote_checkpoint_dir})"
):
return False
if not self._remote_checkpoint_dir:
if not self._remote_checkpoint_dir or not self._syncer:
raise ValueError(
"Called resume from remote without remote directory. "
"Called resume from remote without remote directory or "
"without valid syncer. "
"Fix this by passing a `SyncConfig` object with "
"`upload_dir` set to `tune.run(sync_config=...)`."
)
@ -566,7 +587,11 @@ class TrialRunner:
exclude = ["*/checkpoint_*"]
try:
self._syncer.sync_down_if_needed(exclude=exclude)
self._syncer.sync_down_if_needed(
remote_dir=self._remote_checkpoint_dir,
local_dir=self._local_checkpoint_dir,
exclude=exclude,
)
self._syncer.wait()
except TuneError as e:
raise RuntimeError(
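In practice the resume path shown here is reached through `tune.run`; a minimal sketch (experiment name, trainable and bucket are placeholders) of resuming an experiment whose state lives in cloud storage:
from ray import tune
def my_trainable(config):  # placeholder trainable
    tune.report(score=1)
tune.run(
    my_trainable,
    name="my_experiment",  # must match the experiment that wrote the checkpoint
    resume="AUTO",
    sync_config=tune.SyncConfig(upload_dir="s3://my-bucket/tune-results"),
)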

View file

@ -49,9 +49,8 @@ from ray.tune.schedulers.util import (
from ray.tune.suggest.variant_generator import has_unresolved_values
from ray.tune.syncer import (
SyncConfig,
set_sync_periods,
wait_for_sync,
validate_upload_dir,
_validate_upload_dir,
SyncerCallback,
)
from ray.tune.trainable import Trainable
from ray.tune.trial import Trial
@ -428,8 +427,7 @@ def run(
config = config or {}
sync_config = sync_config or SyncConfig()
validate_upload_dir(sync_config)
set_sync_periods(sync_config)
_validate_upload_dir(sync_config)
if num_samples == -1:
num_samples = sys.maxsize
@ -722,7 +720,14 @@ def run(
if has_verbosity(Verbosity.V1_EXPERIMENT):
_report_progress(runner, progress_reporter, done=True)
wait_for_sync()
# Wait for syncing to finish
for callback in callbacks:
if isinstance(callback, SyncerCallback):
try:
callback.wait_for_all()
except TuneError as e:
logger.error(e)
runner.cleanup()
incomplete_trials = []

View file

@ -5,13 +5,14 @@ import os
from ray.tune.callback import Callback
from ray.tune.progress_reporter import TrialProgressCallback
from ray.tune.syncer import SyncConfig, detect_cluster_syncer
from ray.tune.syncer import SyncConfig
from ray.tune.logger import (
CSVLoggerCallback,
CSVLogger,
JsonLoggerCallback,
JsonLogger,
LegacyLoggerCallback,
LoggerCallback,
TBXLoggerCallback,
TBXLogger,
)
@ -69,7 +70,6 @@ def create_default_callbacks(
# Check if we have a CSV, JSON and TensorboardX logger
for i, callback in enumerate(callbacks):
if isinstance(callback, LegacyLoggerCallback):
last_logger_index = i
if CSVLogger in callback.logger_classes:
has_csv_logger = True
if JsonLogger in callback.logger_classes:
@ -78,17 +78,17 @@ def create_default_callbacks(
has_tbx_logger = True
elif isinstance(callback, CSVLoggerCallback):
has_csv_logger = True
last_logger_index = i
elif isinstance(callback, JsonLoggerCallback):
has_json_logger = True
last_logger_index = i
elif isinstance(callback, TBXLoggerCallback):
has_tbx_logger = True
last_logger_index = i
elif isinstance(callback, SyncerCallback):
syncer_index = i
has_syncer_callback = True
if isinstance(callback, LoggerCallback):
last_logger_index = i
# If CSV, JSON or TensorboardX loggers are missing, add
if os.environ.get("TUNE_DISABLE_AUTO_CALLBACK_LOGGERS", "0") != "1":
if not has_csv_logger:
@ -114,11 +114,9 @@ def create_default_callbacks(
not has_syncer_callback
and os.environ.get("TUNE_DISABLE_AUTO_CALLBACK_SYNCER", "0") != "1"
):
# Detect Docker and Kubernetes environments
_cluster_syncer = detect_cluster_syncer(sync_config)
syncer_callback = SyncerCallback(sync_function=_cluster_syncer)
syncer_callback = SyncerCallback(
enabled=bool(sync_config.syncer), sync_period=sync_config.sync_period
)
callbacks.append(syncer_callback)
syncer_index = len(callbacks) - 1
@ -127,16 +125,7 @@ def create_default_callbacks(
and last_logger_index is not None
and syncer_index < last_logger_index
):
if not has_csv_logger or not has_json_logger or not has_tbx_logger:
raise ValueError(
"The `SyncerCallback` you passed to `tune.run()` came before "
"at least one `LoggerCallback`. Syncing should be done "
"after writing logs. Please re-order the callbacks so that "
"the `SyncerCallback` comes after any `LoggerCallback`."
)
else:
# If these loggers were automatically created. just re-order
# the callbacks
# Re-order callbacks
syncer_obj = callbacks[syncer_index]
callbacks.pop(syncer_index)
callbacks.insert(last_logger_index, syncer_obj)
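Since the hard error for mis-ordered callbacks is gone, a syncer callback passed by the user is simply re-ordered behind the logger callbacks. A minimal sketch (the trainable is a placeholder, and `SyncerCallback` is assumed to be constructible with its defaults):
from ray import tune
from ray.tune.logger import CSVLoggerCallback
from ray.tune.syncer import SyncerCallback
def my_trainable(config):  # placeholder trainable
    tune.report(score=1)
tune.run(
    my_trainable,
    # The SyncerCallback is moved after CSVLoggerCallback automatically
    # instead of raising a ValueError as before.
    callbacks=[SyncerCallback(enabled=True), CSVLoggerCallback()],
)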

View file

@ -4,85 +4,13 @@ import os
import random
import time
from typing import Dict
import uuid
import ray._private.utils
from ray.rllib.algorithms.mock import _MockTrainer
from ray.tune import Trainable
from ray.tune.callback import Callback
from ray.tune.sync_client import get_sync_client
from ray.tune.syncer import NodeSyncer
from ray.tune.trial import Trial
MOCK_REMOTE_DIR = (
os.path.join(ray._private.utils.get_user_temp_dir(), "mock-tune-remote") + os.sep
)
# Sync and delete templates that operate on local directories.
LOCAL_SYNC_TEMPLATE = "mkdir -p {target} && rsync -avz {source}/ {target}/"
LOCAL_DELETE_TEMPLATE = "rm -rf {target}"
logger = logging.getLogger(__name__)
def mock_storage_client():
"""Mocks storage client that treats a local dir as durable storage."""
client = get_sync_client(LOCAL_SYNC_TEMPLATE, LOCAL_DELETE_TEMPLATE)
path = os.path.join(
ray._private.utils.get_user_temp_dir(), f"mock-client-{uuid.uuid4().hex[:4]}"
)
os.makedirs(path, exist_ok=True)
client.set_logdir(path)
return client
class MockNodeSyncer(NodeSyncer):
"""Mock NodeSyncer that syncs to and from /tmp"""
def has_remote_target(self):
return True
@property
def _remote_path(self):
if self._remote_dir.startswith("/"):
self._remote_dir = self._remote_dir[1:]
return os.path.join(MOCK_REMOTE_DIR, self._remote_dir)
class MockRemoteTrainer(_MockTrainer):
"""Mock Trainable that saves at tmp for simulated clusters."""
def __init__(self, *args, **kwargs):
# Tests in test_cluster.py supply a remote checkpoint dir
# We should ignore this here as this is specifically a
# non-durable trainer
kwargs.pop("remote_checkpoint_dir", None)
super(MockRemoteTrainer, self).__init__(*args, **kwargs)
if self._logdir.startswith("/"):
self._logdir = self._logdir[1:]
self._logdir = os.path.join(MOCK_REMOTE_DIR, self._logdir)
if not os.path.exists(self._logdir):
os.makedirs(self._logdir)
class MockDurableTrainer(_MockTrainer):
"""Mock DurableTrainable that saves at tmp for simulated clusters."""
# Evaluate to true to use legacy storage client
_sync_function_tpl = True
def __init__(
self, remote_checkpoint_dir=None, sync_function_tpl=None, *args, **kwargs
):
_MockTrainer.__init__(self, *args, **kwargs)
kwargs["remote_checkpoint_dir"] = remote_checkpoint_dir
Trainable.__init__(self, *args, **kwargs)
def _create_storage_client(self):
return mock_storage_client()
class FailureInjectorCallback(Callback):
"""Adds random failure injection to the TrialExecutor."""

View file

@ -138,14 +138,14 @@ class TrainableUtil:
Note, the assumption here is `logdir` should be the prefix of
`checkpoint_path`.
For example, returns `checkpoint00000/`.
For example, returns `checkpoint00000`.
"""
assert checkpoint_path.startswith(
logdir
), "expecting `logdir` to be a prefix of `checkpoint_path`"
rel_path = os.path.relpath(checkpoint_path, logdir)
tokens = rel_path.split(os.sep)
return os.path.join(tokens[0], "")
return os.path.join(tokens[0])
@staticmethod
def make_checkpoint_dir(

View file

@ -146,7 +146,9 @@ def _serialize_checkpoint(checkpoint_path) -> bytes:
def get_checkpoint_from_remote_node(
checkpoint_path: str, node_ip: str, timeout: float = 300.0
) -> Optional[Checkpoint]:
if not any(node["NodeManagerAddress"] == node_ip for node in ray.nodes()):
if not any(
node["NodeManagerAddress"] == node_ip and node["Alive"] for node in ray.nodes()
):
logger.warning(
f"Could not fetch checkpoint with path {checkpoint_path} from "
f"node with IP {node_ip} because the node is not available "

View file

@ -197,8 +197,7 @@ class Algorithm(Trainable):
config: Optional[Union[PartialAlgorithmConfigDict, AlgorithmConfig]] = None,
env: Optional[Union[str, EnvType]] = None,
logger_creator: Optional[Callable[[], Logger]] = None,
remote_checkpoint_dir: Optional[str] = None,
sync_function_tpl: Optional[str] = None,
**kwargs,
):
"""Initializes an Algorithm instance.
@ -211,6 +210,8 @@ class Algorithm(Trainable):
the "env" key in `config`.
logger_creator: Callable that creates a ray.tune.Logger
object. If unspecified, a default logger is created.
**kwargs: Arguments passed to the Trainable base class.
"""
# User provided (partial) config (this may be w/o the default
@ -288,9 +289,7 @@ class Algorithm(Trainable):
}
}
super().__init__(
config, logger_creator, remote_checkpoint_dir, sync_function_tpl
)
super().__init__(config=config, logger_creator=logger_creator, **kwargs)
# Check, whether `training_iteration` is still a tune.Trainable property
# and has not been overridden by the user in the attempt to implement the

View file

@ -163,8 +163,7 @@ class DDPPO(PPO):
config: Optional[PartialAlgorithmConfigDict] = None,
env: Optional[Union[str, EnvType]] = None,
logger_creator: Optional[Callable[[], Logger]] = None,
remote_checkpoint_dir: Optional[str] = None,
sync_function_tpl: Optional[str] = None,
**kwargs,
):
"""Initializes a DDPPO instance.
@ -177,9 +176,11 @@ class DDPPO(PPO):
the "env" key in `config`.
logger_creator: Callable that creates a ray.tune.Logger
object. If unspecified, a default logger is created.
**kwargs: Arguments passed to the Trainable base class
"""
super().__init__(
config, env, logger_creator, remote_checkpoint_dir, sync_function_tpl
config=config, env=env, logger_creator=logger_creator, **kwargs
)
if "train_batch_size" in config.keys() and config["train_batch_size"] != -1: