mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
[1/2][serve] Use GcsClient to replace the kv client to use timeout. (#25633)
Timeout is only introduced in GcsClient due to the reason that ray client is not defining the timeout well for their API and it's a lot of effort to make it work e2e. For built-in component, we should use GcsClient directly. This PR use GcsClient to replace the old one to integrate GCS HA with Ray Serve.
This commit is contained in:
parent
d36fd77548
commit
0c527b4502
4 changed files with 68 additions and 20 deletions
|
@ -191,6 +191,15 @@ class RuntimeContext(object):
|
|||
worker.check_connected()
|
||||
return worker.core_worker.get_actor_handle(self.actor_id)
|
||||
|
||||
@property
|
||||
def gcs_address(self):
|
||||
"""Get the GCS address of the ray cluster.
|
||||
Returns:
|
||||
The GCS address of the cluster.
|
||||
"""
|
||||
self.worker.check_connected()
|
||||
return self.worker.gcs_client.address
|
||||
|
||||
def _get_actor_call_stats(self):
|
||||
"""Get the current worker's task counters.
|
||||
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
import os
|
||||
|
||||
from enum import Enum
|
||||
import re
|
||||
|
||||
|
@ -99,6 +101,9 @@ ANONYMOUS_NAMESPACE_PATTERN = re.compile(
|
|||
# Handle metric push interval. (This interval will affect the cold start time period)
|
||||
HANDLE_METRIC_PUSH_INTERVAL_S = 10
|
||||
|
||||
# Timeout for GCS internal KV service
|
||||
RAY_SERVE_KV_TIMEOUT_S = float(os.environ.get("RAY_SERVE_KV_TIMEOUT_S", "0")) or None
|
||||
|
||||
|
||||
class ServeHandleType(str, Enum):
|
||||
SYNC = "SYNC"
|
||||
|
|
|
@ -9,10 +9,11 @@ try:
|
|||
except ImportError:
|
||||
boto3 = None
|
||||
|
||||
import ray
|
||||
from ray import ray_constants
|
||||
import ray.experimental.internal_kv as ray_kv
|
||||
from ray._private.gcs_utils import GcsClient
|
||||
|
||||
from ray.serve.constants import SERVE_LOGGER_NAME
|
||||
from ray.serve.constants import SERVE_LOGGER_NAME, RAY_SERVE_KV_TIMEOUT_S
|
||||
from ray.serve.storage.kv_store_base import KVStoreBase
|
||||
|
||||
logger = logging.getLogger(SERVE_LOGGER_NAME)
|
||||
|
@ -29,11 +30,16 @@ class RayInternalKVStore(KVStoreBase):
|
|||
Supports string keys and bytes values, caller must handle serialization.
|
||||
"""
|
||||
|
||||
def __init__(self, namespace: str = None):
|
||||
assert ray_kv._internal_kv_initialized()
|
||||
def __init__(
|
||||
self,
|
||||
namespace: str = None,
|
||||
):
|
||||
if namespace is not None and not isinstance(namespace, str):
|
||||
raise TypeError("namespace must a string, got: {}.".format(type(namespace)))
|
||||
|
||||
self.gcs_client = GcsClient(address=ray.get_runtime_context().gcs_address)
|
||||
|
||||
self.timeout = RAY_SERVE_KV_TIMEOUT_S
|
||||
self.namespace = namespace or ""
|
||||
|
||||
def get_storage_key(self, key: str) -> str:
|
||||
|
@ -51,11 +57,12 @@ class RayInternalKVStore(KVStoreBase):
|
|||
if not isinstance(val, bytes):
|
||||
raise TypeError("val must be bytes, got: {}.".format(type(val)))
|
||||
|
||||
ray_kv._internal_kv_put(
|
||||
self.get_storage_key(key),
|
||||
return self.gcs_client.internal_kv_put(
|
||||
self.get_storage_key(key).encode(),
|
||||
val,
|
||||
overwrite=True,
|
||||
namespace=ray_constants.KV_NAMESPACE_SERVE,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
|
||||
def get(self, key: str) -> Optional[bytes]:
|
||||
|
@ -70,8 +77,10 @@ class RayInternalKVStore(KVStoreBase):
|
|||
if not isinstance(key, str):
|
||||
raise TypeError("key must be a string, got: {}.".format(type(key)))
|
||||
|
||||
return ray_kv._internal_kv_get(
|
||||
self.get_storage_key(key), namespace=ray_constants.KV_NAMESPACE_SERVE
|
||||
return self.gcs_client.internal_kv_get(
|
||||
self.get_storage_key(key).encode(),
|
||||
namespace=ray_constants.KV_NAMESPACE_SERVE,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
|
||||
def delete(self, key: str):
|
||||
|
@ -83,8 +92,12 @@ class RayInternalKVStore(KVStoreBase):
|
|||
|
||||
if not isinstance(key, str):
|
||||
raise TypeError("key must be a string, got: {}.".format(type(key)))
|
||||
return ray_kv._internal_kv_del(
|
||||
self.get_storage_key(key), namespace=ray_constants.KV_NAMESPACE_SERVE
|
||||
|
||||
return self.gcs_client.internal_kv_del(
|
||||
self.get_storage_key(key).encode(),
|
||||
False,
|
||||
namespace=ray_constants.KV_NAMESPACE_SERVE,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
|
||||
|
||||
|
@ -202,6 +215,7 @@ class RayS3KVStore(KVStoreBase):
|
|||
aws_access_key_id=None,
|
||||
aws_secret_access_key=None,
|
||||
aws_session_token=None,
|
||||
endpoint_url=None,
|
||||
):
|
||||
self._namespace = namespace
|
||||
self._bucket = bucket
|
||||
|
@ -217,6 +231,7 @@ class RayS3KVStore(KVStoreBase):
|
|||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
aws_session_token=aws_session_token,
|
||||
endpoint_url=endpoint_url,
|
||||
)
|
||||
|
||||
def get_storage_key(self, key: str) -> str:
|
||||
|
|
|
@ -4,6 +4,7 @@ from typing import Optional
|
|||
|
||||
import pytest
|
||||
from ray.serve.constants import DEFAULT_CHECKPOINT_PATH
|
||||
from ray._private.test_utils import simulate_storage
|
||||
from ray.serve.storage.checkpoint_path import make_kv_store
|
||||
from ray.serve.storage.kv_store import RayInternalKVStore, RayLocalKVStore, RayS3KVStore
|
||||
from ray.serve.storage.kv_store_base import KVStoreBase
|
||||
|
@ -80,18 +81,36 @@ def test_external_kv_local_disk():
|
|||
_test_operations(kv_store)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Need to figure out credentials for testing")
|
||||
def test_external_kv_aws_s3():
|
||||
kv_store = RayS3KVStore(
|
||||
"namespace",
|
||||
bucket="jiao-test",
|
||||
s3_path="/checkpoint",
|
||||
aws_access_key_id=os.environ.get("AWS_ACCESS_KEY_ID", None),
|
||||
aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY", None),
|
||||
aws_session_token=os.environ.get("AWS_SESSION_TOKEN", None),
|
||||
)
|
||||
with simulate_storage("s3", "serve-test") as uri:
|
||||
from urllib.parse import urlparse, parse_qs
|
||||
|
||||
_test_operations(kv_store)
|
||||
o = urlparse(uri)
|
||||
qs = parse_qs(o.query)
|
||||
region_name = qs["region"][0]
|
||||
endpoint_url = qs["endpoint_override"][0]
|
||||
|
||||
import boto3
|
||||
|
||||
s3 = boto3.client(
|
||||
"s3",
|
||||
region_name=region_name,
|
||||
endpoint_url=endpoint_url,
|
||||
)
|
||||
s3.create_bucket(
|
||||
Bucket="serve-test",
|
||||
CreateBucketConfiguration={"LocationConstraint": "us-west-2"},
|
||||
)
|
||||
|
||||
kv_store = RayS3KVStore(
|
||||
"namespace",
|
||||
bucket="serve-test",
|
||||
prefix="checkpoint",
|
||||
region_name=region_name,
|
||||
endpoint_url=endpoint_url,
|
||||
)
|
||||
|
||||
_test_operations(kv_store)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Need to figure out credentials for testing")
|
||||
|
|
Loading…
Add table
Reference in a new issue