[1/2][serve] Use GcsClient to replace the kv client to use timeout. (#25633)

Timeouts are only supported in GcsClient because the Ray client does not define timeouts well in its API, and making them work end to end would take a lot of effort. Built-in components should therefore use GcsClient directly.

This PR replaces the old KV client with GcsClient in order to integrate GCS HA with Ray Serve.
This commit is contained in:
Yi Cheng 2022-06-11 06:41:49 +00:00 committed by GitHub
parent d36fd77548
commit 0c527b4502
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 68 additions and 20 deletions

View file

@ -191,6 +191,15 @@ class RuntimeContext(object):
worker.check_connected() worker.check_connected()
return worker.core_worker.get_actor_handle(self.actor_id) return worker.core_worker.get_actor_handle(self.actor_id)
@property
def gcs_address(self):
"""Get the GCS address of the Ray cluster.
The worker's ``check_connected()`` is called first, so that check runs
before the address is read.
Returns:
The GCS address of the cluster, as exposed by the worker's GCS client
(``self.worker.gcs_client.address``).
"""
self.worker.check_connected()
return self.worker.gcs_client.address
def _get_actor_call_stats(self): def _get_actor_call_stats(self):
"""Get the current worker's task counters. """Get the current worker's task counters.

View file

@ -1,3 +1,5 @@
import os
from enum import Enum from enum import Enum
import re import re
@ -99,6 +101,9 @@ ANONYMOUS_NAMESPACE_PATTERN = re.compile(
# Handle metric push interval. (This interval will affect the cold start time period) # Handle metric push interval. (This interval will affect the cold start time period)
HANDLE_METRIC_PUSH_INTERVAL_S = 10 HANDLE_METRIC_PUSH_INTERVAL_S = 10
# Timeout for GCS internal KV service.
# Read from the RAY_SERVE_KV_TIMEOUT_S environment variable and parsed as a
# float. When the variable is unset or parses to 0, the value is None
# (the trailing `or None` maps the falsy 0.0 to None).
# NOTE(review): the unit is presumably seconds, per the `_S` suffix; confirm.
RAY_SERVE_KV_TIMEOUT_S = float(os.environ.get("RAY_SERVE_KV_TIMEOUT_S", "0")) or None
class ServeHandleType(str, Enum): class ServeHandleType(str, Enum):
SYNC = "SYNC" SYNC = "SYNC"

View file

@ -9,10 +9,11 @@ try:
except ImportError: except ImportError:
boto3 = None boto3 = None
import ray
from ray import ray_constants from ray import ray_constants
import ray.experimental.internal_kv as ray_kv from ray._private.gcs_utils import GcsClient
from ray.serve.constants import SERVE_LOGGER_NAME from ray.serve.constants import SERVE_LOGGER_NAME, RAY_SERVE_KV_TIMEOUT_S
from ray.serve.storage.kv_store_base import KVStoreBase from ray.serve.storage.kv_store_base import KVStoreBase
logger = logging.getLogger(SERVE_LOGGER_NAME) logger = logging.getLogger(SERVE_LOGGER_NAME)
@ -29,11 +30,16 @@ class RayInternalKVStore(KVStoreBase):
Supports string keys and bytes values, caller must handle serialization. Supports string keys and bytes values, caller must handle serialization.
""" """
def __init__(self, namespace: str = None): def __init__(
assert ray_kv._internal_kv_initialized() self,
namespace: str = None,
):
if namespace is not None and not isinstance(namespace, str): if namespace is not None and not isinstance(namespace, str):
raise TypeError("namespace must a string, got: {}.".format(type(namespace))) raise TypeError("namespace must a string, got: {}.".format(type(namespace)))
self.gcs_client = GcsClient(address=ray.get_runtime_context().gcs_address)
self.timeout = RAY_SERVE_KV_TIMEOUT_S
self.namespace = namespace or "" self.namespace = namespace or ""
def get_storage_key(self, key: str) -> str: def get_storage_key(self, key: str) -> str:
@ -51,11 +57,12 @@ class RayInternalKVStore(KVStoreBase):
if not isinstance(val, bytes): if not isinstance(val, bytes):
raise TypeError("val must be bytes, got: {}.".format(type(val))) raise TypeError("val must be bytes, got: {}.".format(type(val)))
ray_kv._internal_kv_put( return self.gcs_client.internal_kv_put(
self.get_storage_key(key), self.get_storage_key(key).encode(),
val, val,
overwrite=True, overwrite=True,
namespace=ray_constants.KV_NAMESPACE_SERVE, namespace=ray_constants.KV_NAMESPACE_SERVE,
timeout=self.timeout,
) )
def get(self, key: str) -> Optional[bytes]: def get(self, key: str) -> Optional[bytes]:
@ -70,8 +77,10 @@ class RayInternalKVStore(KVStoreBase):
if not isinstance(key, str): if not isinstance(key, str):
raise TypeError("key must be a string, got: {}.".format(type(key))) raise TypeError("key must be a string, got: {}.".format(type(key)))
return ray_kv._internal_kv_get( return self.gcs_client.internal_kv_get(
self.get_storage_key(key), namespace=ray_constants.KV_NAMESPACE_SERVE self.get_storage_key(key).encode(),
namespace=ray_constants.KV_NAMESPACE_SERVE,
timeout=self.timeout,
) )
def delete(self, key: str): def delete(self, key: str):
@ -83,8 +92,12 @@ class RayInternalKVStore(KVStoreBase):
if not isinstance(key, str): if not isinstance(key, str):
raise TypeError("key must be a string, got: {}.".format(type(key))) raise TypeError("key must be a string, got: {}.".format(type(key)))
return ray_kv._internal_kv_del(
self.get_storage_key(key), namespace=ray_constants.KV_NAMESPACE_SERVE return self.gcs_client.internal_kv_del(
self.get_storage_key(key).encode(),
False,
namespace=ray_constants.KV_NAMESPACE_SERVE,
timeout=self.timeout,
) )
@ -202,6 +215,7 @@ class RayS3KVStore(KVStoreBase):
aws_access_key_id=None, aws_access_key_id=None,
aws_secret_access_key=None, aws_secret_access_key=None,
aws_session_token=None, aws_session_token=None,
endpoint_url=None,
): ):
self._namespace = namespace self._namespace = namespace
self._bucket = bucket self._bucket = bucket
@ -217,6 +231,7 @@ class RayS3KVStore(KVStoreBase):
aws_access_key_id=aws_access_key_id, aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key, aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token, aws_session_token=aws_session_token,
endpoint_url=endpoint_url,
) )
def get_storage_key(self, key: str) -> str: def get_storage_key(self, key: str) -> str:

View file

@ -4,6 +4,7 @@ from typing import Optional
import pytest import pytest
from ray.serve.constants import DEFAULT_CHECKPOINT_PATH from ray.serve.constants import DEFAULT_CHECKPOINT_PATH
from ray._private.test_utils import simulate_storage
from ray.serve.storage.checkpoint_path import make_kv_store from ray.serve.storage.checkpoint_path import make_kv_store
from ray.serve.storage.kv_store import RayInternalKVStore, RayLocalKVStore, RayS3KVStore from ray.serve.storage.kv_store import RayInternalKVStore, RayLocalKVStore, RayS3KVStore
from ray.serve.storage.kv_store_base import KVStoreBase from ray.serve.storage.kv_store_base import KVStoreBase
@ -80,18 +81,36 @@ def test_external_kv_local_disk():
_test_operations(kv_store) _test_operations(kv_store)
@pytest.mark.skip(reason="Need to figure out credentials for testing")
def test_external_kv_aws_s3(): def test_external_kv_aws_s3():
kv_store = RayS3KVStore( with simulate_storage("s3", "serve-test") as uri:
"namespace", from urllib.parse import urlparse, parse_qs
bucket="jiao-test",
s3_path="/checkpoint",
aws_access_key_id=os.environ.get("AWS_ACCESS_KEY_ID", None),
aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY", None),
aws_session_token=os.environ.get("AWS_SESSION_TOKEN", None),
)
_test_operations(kv_store) o = urlparse(uri)
qs = parse_qs(o.query)
region_name = qs["region"][0]
endpoint_url = qs["endpoint_override"][0]
import boto3
s3 = boto3.client(
"s3",
region_name=region_name,
endpoint_url=endpoint_url,
)
s3.create_bucket(
Bucket="serve-test",
CreateBucketConfiguration={"LocationConstraint": "us-west-2"},
)
kv_store = RayS3KVStore(
"namespace",
bucket="serve-test",
prefix="checkpoint",
region_name=region_name,
endpoint_url=endpoint_url,
)
_test_operations(kv_store)
@pytest.mark.skip(reason="Need to figure out credentials for testing") @pytest.mark.skip(reason="Need to figure out credentials for testing")