diff --git a/python/ray/runtime_context.py b/python/ray/runtime_context.py index d57b67113..e56473d1d 100644 --- a/python/ray/runtime_context.py +++ b/python/ray/runtime_context.py @@ -191,6 +191,15 @@ class RuntimeContext(object): worker.check_connected() return worker.core_worker.get_actor_handle(self.actor_id) + @property + def gcs_address(self): + """Get the GCS address of the ray cluster. + Returns: + The GCS address of the cluster. + """ + self.worker.check_connected() + return self.worker.gcs_client.address + def _get_actor_call_stats(self): """Get the current worker's task counters. diff --git a/python/ray/serve/constants.py b/python/ray/serve/constants.py index ed017afaf..2f7979679 100644 --- a/python/ray/serve/constants.py +++ b/python/ray/serve/constants.py @@ -1,3 +1,5 @@ +import os + from enum import Enum import re @@ -99,6 +101,9 @@ ANONYMOUS_NAMESPACE_PATTERN = re.compile( # Handle metric push interval. (This interval will affect the cold start time period) HANDLE_METRIC_PUSH_INTERVAL_S = 10 +# Timeout for GCS internal KV service +RAY_SERVE_KV_TIMEOUT_S = float(os.environ.get("RAY_SERVE_KV_TIMEOUT_S", "0")) or None + class ServeHandleType(str, Enum): SYNC = "SYNC" diff --git a/python/ray/serve/storage/kv_store.py b/python/ray/serve/storage/kv_store.py index 8c79636f0..d74891ff4 100644 --- a/python/ray/serve/storage/kv_store.py +++ b/python/ray/serve/storage/kv_store.py @@ -9,10 +9,11 @@ try: except ImportError: boto3 = None +import ray from ray import ray_constants -import ray.experimental.internal_kv as ray_kv +from ray._private.gcs_utils import GcsClient -from ray.serve.constants import SERVE_LOGGER_NAME +from ray.serve.constants import SERVE_LOGGER_NAME, RAY_SERVE_KV_TIMEOUT_S from ray.serve.storage.kv_store_base import KVStoreBase logger = logging.getLogger(SERVE_LOGGER_NAME) @@ -29,11 +30,16 @@ class RayInternalKVStore(KVStoreBase): Supports string keys and bytes values, caller must handle serialization. """ - def __init__(self, namespace: str = None): - assert ray_kv._internal_kv_initialized() + def __init__( + self, + namespace: str = None, + ): if namespace is not None and not isinstance(namespace, str): raise TypeError("namespace must a string, got: {}.".format(type(namespace))) + self.gcs_client = GcsClient(address=ray.get_runtime_context().gcs_address) + + self.timeout = RAY_SERVE_KV_TIMEOUT_S self.namespace = namespace or "" def get_storage_key(self, key: str) -> str: @@ -51,11 +57,12 @@ class RayInternalKVStore(KVStoreBase): if not isinstance(val, bytes): raise TypeError("val must be bytes, got: {}.".format(type(val))) - ray_kv._internal_kv_put( - self.get_storage_key(key), + return self.gcs_client.internal_kv_put( + self.get_storage_key(key).encode(), val, overwrite=True, namespace=ray_constants.KV_NAMESPACE_SERVE, + timeout=self.timeout, ) def get(self, key: str) -> Optional[bytes]: @@ -70,8 +77,10 @@ class RayInternalKVStore(KVStoreBase): if not isinstance(key, str): raise TypeError("key must be a string, got: {}.".format(type(key))) - return ray_kv._internal_kv_get( - self.get_storage_key(key), namespace=ray_constants.KV_NAMESPACE_SERVE + return self.gcs_client.internal_kv_get( + self.get_storage_key(key).encode(), + namespace=ray_constants.KV_NAMESPACE_SERVE, + timeout=self.timeout, ) def delete(self, key: str): @@ -83,8 +92,12 @@ class RayInternalKVStore(KVStoreBase): if not isinstance(key, str): raise TypeError("key must be a string, got: {}.".format(type(key))) - return ray_kv._internal_kv_del( - self.get_storage_key(key), namespace=ray_constants.KV_NAMESPACE_SERVE + + return self.gcs_client.internal_kv_del( + self.get_storage_key(key).encode(), + False, + namespace=ray_constants.KV_NAMESPACE_SERVE, + timeout=self.timeout, ) @@ -202,6 +215,7 @@ class RayS3KVStore(KVStoreBase): aws_access_key_id=None, aws_secret_access_key=None, aws_session_token=None, + endpoint_url=None, ): self._namespace = namespace self._bucket = bucket @@ -217,6 +231,7 @@ class RayS3KVStore(KVStoreBase): aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, aws_session_token=aws_session_token, + endpoint_url=endpoint_url, ) def get_storage_key(self, key: str) -> str: diff --git a/python/ray/serve/tests/storage_tests/test_kv_store.py b/python/ray/serve/tests/storage_tests/test_kv_store.py index 45fe43052..8921836d5 100644 --- a/python/ray/serve/tests/storage_tests/test_kv_store.py +++ b/python/ray/serve/tests/storage_tests/test_kv_store.py @@ -4,6 +4,7 @@ from typing import Optional import pytest from ray.serve.constants import DEFAULT_CHECKPOINT_PATH +from ray._private.test_utils import simulate_storage from ray.serve.storage.checkpoint_path import make_kv_store from ray.serve.storage.kv_store import RayInternalKVStore, RayLocalKVStore, RayS3KVStore from ray.serve.storage.kv_store_base import KVStoreBase @@ -80,18 +81,36 @@ def test_external_kv_local_disk(): _test_operations(kv_store) -@pytest.mark.skip(reason="Need to figure out credentials for testing") def test_external_kv_aws_s3(): - kv_store = RayS3KVStore( - "namespace", - bucket="jiao-test", - s3_path="/checkpoint", - aws_access_key_id=os.environ.get("AWS_ACCESS_KEY_ID", None), - aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY", None), - aws_session_token=os.environ.get("AWS_SESSION_TOKEN", None), - ) + with simulate_storage("s3", "serve-test") as uri: + from urllib.parse import urlparse, parse_qs - _test_operations(kv_store) + o = urlparse(uri) + qs = parse_qs(o.query) + region_name = qs["region"][0] + endpoint_url = qs["endpoint_override"][0] + + import boto3 + + s3 = boto3.client( + "s3", + region_name=region_name, + endpoint_url=endpoint_url, + ) + s3.create_bucket( + Bucket="serve-test", + CreateBucketConfiguration={"LocationConstraint": "us-west-2"}, + ) + + kv_store = RayS3KVStore( + "namespace", + bucket="serve-test", + prefix="checkpoint", + region_name=region_name, + endpoint_url=endpoint_url, + ) + + _test_operations(kv_store) @pytest.mark.skip(reason="Need to figure out credentials for testing")