mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[gcs] Make gcs client in python able to auto reconnect (#20299)
## Why are these changes needed? Since we are using the GCS client as the KV backend, it needs to reconnect automatically after a failure. This PR adds an `_auto_reconnect` decorator to `gcs_utils`: when a request fails because GCS is unavailable, the client reconnects and retries the request, up to a configurable number of attempts (`nums_reconnect_retry`, default 5). For now the feature also supports looking up the GCS address through Redis; that path should be deleted once the bootstrap work is finished, since KV will then always go to GCS. ## Related issue number
This commit is contained in:
parent
722f935f9a
commit
87fa56def4
6 changed files with 106 additions and 16 deletions
|
@ -160,7 +160,7 @@ class DashboardAgent(object):
|
|||
# TODO: redis-removal bootstrap
|
||||
gcs_address = await self.aioredis_client.get(
|
||||
dashboard_consts.REDIS_KEY_GCS_SERVER_ADDRESS)
|
||||
self.gcs_client = GcsClient(gcs_address.decode())
|
||||
self.gcs_client = GcsClient(address=gcs_address.decode())
|
||||
modules = self._load_modules()
|
||||
|
||||
# Http server should be initialized after all modules loaded.
|
||||
|
|
|
@ -75,7 +75,7 @@ class _SubscriberBase:
|
|||
class GcsPublisher:
|
||||
"""Publisher to GCS."""
|
||||
|
||||
def __init__(self, address: str = None, channel: grpc.Channel = None):
|
||||
def __init__(self, *, address: str = None, channel: grpc.Channel = None):
|
||||
if address:
|
||||
assert channel is None, \
|
||||
"address and channel cannot both be specified"
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
import enum
|
||||
import logging
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
from functools import wraps
|
||||
import time
|
||||
|
||||
import grpc
|
||||
|
||||
|
@ -119,6 +121,59 @@ def create_gcs_channel(address: str, aio=False):
|
|||
return init_grpc_channel(address, options=_GRPC_OPTIONS, asynchronous=aio)
|
||||
|
||||
|
||||
def _auto_reconnect(f):
|
||||
@wraps(f)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
remaining_retry = self._nums_reconnect_retry
|
||||
while True:
|
||||
try:
|
||||
return f(self, *args, **kwargs)
|
||||
except grpc.RpcError as e:
|
||||
if remaining_retry <= 0:
|
||||
raise
|
||||
if e.code() == grpc.StatusCode.UNAVAILABLE:
|
||||
logger.error(
|
||||
"Failed to send request to gcs, reconnecting. "
|
||||
f"Error {e}")
|
||||
try:
|
||||
self._connect()
|
||||
except Exception:
|
||||
logger.error(f"Connecting to gcs failed. Error {e}")
|
||||
time.sleep(1)
|
||||
remaining_retry -= 1
|
||||
continue
|
||||
raise
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class GcsChannel:
    """Holder for a channel to GCS that can be re-created on demand.

    The GCS address is either supplied directly or looked up through a Redis
    client every time `connect()` runs.
    """

    def __init__(self,
                 redis_client=None,
                 gcs_address: Optional[str] = None,
                 aio: bool = False):
        """Record how the GCS address is obtained; does not connect yet.

        Args:
            redis_client: Client handed to `get_gcs_address_from_redis` on
                each `connect()`. Mutually exclusive with `gcs_address`.
            gcs_address: Address to connect to. Mutually exclusive with
                `redis_client`.
            aio: Forwarded to `create_gcs_channel` as its second argument.

        Raises:
            ValueError: If neither or both of `redis_client` and
                `gcs_address` are given.
        """
        has_redis = redis_client is not None
        has_address = gcs_address is not None
        if not (has_redis or has_address):
            raise ValueError(
                "One of `redis_client` or `gcs_address` has to be set")
        if has_redis and has_address:
            raise ValueError(
                "Only one of `redis_client` or `gcs_address` can be set")
        self._redis_client, self._gcs_address, self._aio = (redis_client,
                                                            gcs_address, aio)

    def connect(self):
        """Resolve the GCS address and create a new channel to it."""
        if self._redis_client is None:
            target = self._gcs_address
        else:
            target = get_gcs_address_from_redis(self._redis_client)
        self._channel = create_gcs_channel(target, self._aio)

    def channel(self):
        """Return the channel created by the most recent `connect()`."""
        return self._channel
|
||||
|
||||
|
||||
class GcsCode(enum.IntEnum):
|
||||
# corresponding to ray/src/ray/common/status.h
|
||||
OK = 0
|
||||
|
@ -128,13 +183,24 @@ class GcsCode(enum.IntEnum):
|
|||
class GcsClient:
|
||||
"""Client to GCS using GRPC"""
|
||||
|
||||
def __init__(self, address: str = None, channel: grpc.Channel = None):
|
||||
if address:
|
||||
assert channel is None, \
|
||||
"Only one of address and channel can be specified"
|
||||
channel = create_gcs_channel(address)
|
||||
self._kv_stub = gcs_service_pb2_grpc.InternalKVGcsServiceStub(channel)
|
||||
def __init__(self,
|
||||
channel: Optional[GcsChannel] = None,
|
||||
address: Optional[str] = None,
|
||||
nums_reconnect_retry: int = 5):
|
||||
if channel is None:
|
||||
assert isinstance(address, str)
|
||||
channel = GcsChannel(gcs_address=address)
|
||||
assert isinstance(channel, GcsChannel)
|
||||
self._channel = channel
|
||||
self._connect()
|
||||
self._nums_reconnect_retry = nums_reconnect_retry
|
||||
|
||||
def _connect(self):
|
||||
self._channel.connect()
|
||||
self._kv_stub = gcs_service_pb2_grpc.InternalKVGcsServiceStub(
|
||||
self._channel.channel())
|
||||
|
||||
@_auto_reconnect
|
||||
def internal_kv_get(self, key: bytes) -> bytes:
|
||||
logger.debug(f"internal_kv_get {key}")
|
||||
req = gcs_service_pb2.InternalKVGetRequest(key=key)
|
||||
|
@ -147,6 +213,7 @@ class GcsClient:
|
|||
raise RuntimeError(f"Failed to get value for key {key} "
|
||||
f"due to error {reply.status.message}")
|
||||
|
||||
@_auto_reconnect
|
||||
def internal_kv_put(self, key: bytes, value: bytes,
|
||||
overwrite: bool) -> int:
|
||||
logger.debug(f"internal_kv_put {key} {value} {overwrite}")
|
||||
|
@ -159,6 +226,7 @@ class GcsClient:
|
|||
raise RuntimeError(f"Failed to put value {value} to key {key} "
|
||||
f"due to error {reply.status.message}")
|
||||
|
||||
@_auto_reconnect
|
||||
def internal_kv_del(self, key: bytes) -> int:
|
||||
logger.debug(f"internal_kv_del {key}")
|
||||
req = gcs_service_pb2.InternalKVDelRequest(key=key)
|
||||
|
@ -169,6 +237,7 @@ class GcsClient:
|
|||
raise RuntimeError(f"Failed to delete key {key} "
|
||||
f"due to error {reply.status.message}")
|
||||
|
||||
@_auto_reconnect
|
||||
def internal_kv_exists(self, key: bytes) -> bool:
|
||||
logger.debug(f"internal_kv_exists {key}")
|
||||
req = gcs_service_pb2.InternalKVExistsRequest(key=key)
|
||||
|
@ -179,6 +248,7 @@ class GcsClient:
|
|||
raise RuntimeError(f"Failed to check existence of key {key} "
|
||||
f"due to error {reply.status.message}")
|
||||
|
||||
@_auto_reconnect
|
||||
def internal_kv_keys(self, prefix: bytes) -> List[bytes]:
|
||||
logger.debug(f"internal_kv_keys {prefix}")
|
||||
req = gcs_service_pb2.InternalKVKeysRequest(prefix=prefix)
|
||||
|
@ -191,7 +261,7 @@ class GcsClient:
|
|||
|
||||
@staticmethod
|
||||
def create_from_redis(redis_cli):
|
||||
return GcsClient(get_gcs_address_from_redis(redis_cli))
|
||||
return GcsClient(GcsChannel(redis_client=redis_cli))
|
||||
|
||||
@staticmethod
|
||||
def connect_to_gcs_by_redis_address(redis_address, redis_password):
|
||||
|
|
|
@ -503,7 +503,8 @@ def get_non_head_nodes(cluster):
|
|||
def init_error_pubsub():
|
||||
"""Initialize redis error info pub/sub"""
|
||||
if gcs_pubsub_enabled():
|
||||
s = GcsSubscriber(channel=ray.worker.global_worker.gcs_channel)
|
||||
s = GcsSubscriber(
|
||||
channel=ray.worker.global_worker.gcs_channel.channel())
|
||||
s.subscribe_error()
|
||||
else:
|
||||
s = ray.worker.global_worker.redis_client.pubsub(
|
||||
|
|
|
@ -165,6 +165,25 @@ def test_del_actor_after_gcs_server_restart(ray_start_regular):
|
|||
ray.get_actor("abc")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("auto_reconnect", [True, False])
def test_gcs_client_reconnect(ray_start_regular, auto_reconnect):
    """Check GcsClient behaviour across a GCS server restart.

    With the default retry budget the client is expected to reconnect and
    still read the stored key; with `nums_reconnect_retry=0` the first call
    after the restart is expected to raise.
    """
    redis_client = ray.worker.global_worker.redis_client
    channel = gcs_utils.GcsChannel(redis_client=redis_client)
    if auto_reconnect:
        gcs_client = gcs_utils.GcsClient(channel)
    else:
        gcs_client = gcs_utils.GcsClient(channel, nums_reconnect_retry=0)

    gcs_client.internal_kv_put(b"a", b"b", overwrite=True)
    # Sanity check before the restart; without `assert` this comparison
    # would be evaluated and silently discarded.
    assert gcs_client.internal_kv_get(b"a") == b"b"

    ray.worker._global_node.kill_gcs_server()
    ray.worker._global_node.start_gcs_server()
    if not auto_reconnect:
        with pytest.raises(Exception):
            gcs_client.internal_kv_get(b"a")
    else:
        assert gcs_client.internal_kv_get(b"a") == b"b"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
|
|
|
@ -1259,7 +1259,7 @@ def listen_error_messages_from_gcs(worker, threads_stopped):
|
|||
threads_stopped (threading.Event): A threading event used to signal to
|
||||
the thread that it should exit.
|
||||
"""
|
||||
worker.gcs_subscriber = GcsSubscriber(channel=worker.gcs_channel)
|
||||
worker.gcs_subscriber = GcsSubscriber(channel=worker.gcs_channel.channel())
|
||||
# Exports that are published after the call to
|
||||
# gcs_subscriber.subscribe_error() and before the call to
|
||||
# gcs_subscriber.poll_error() will still be processed in the loop.
|
||||
|
@ -1360,16 +1360,16 @@ def connect(node,
|
|||
# that is not true of Redis pubsub clients. See the documentation at
|
||||
# https://github.com/andymccurdy/redis-py#thread-safety.
|
||||
worker.redis_client = node.create_redis_client()
|
||||
worker.gcs_channel = gcs_utils.create_gcs_channel(
|
||||
gcs_utils.get_gcs_address_from_redis(worker.redis_client))
|
||||
worker.gcs_client = gcs_utils.GcsClient(channel=worker.gcs_channel)
|
||||
worker.gcs_channel = gcs_utils.GcsChannel(redis_client=worker.redis_client)
|
||||
worker.gcs_client = gcs_utils.GcsClient(worker.gcs_channel)
|
||||
_initialize_internal_kv(worker.gcs_client)
|
||||
ray.state.state._initialize_global_state(
|
||||
node.redis_address, redis_password=node.redis_password)
|
||||
worker.gcs_pubsub_enabled = gcs_pubsub_enabled()
|
||||
worker.gcs_publisher = None
|
||||
if worker.gcs_pubsub_enabled:
|
||||
worker.gcs_publisher = GcsPublisher(channel=worker.gcs_channel)
|
||||
worker.gcs_publisher = GcsPublisher(
|
||||
channel=worker.gcs_channel.channel())
|
||||
|
||||
# Initialize some fields.
|
||||
if mode in (WORKER_MODE, RESTORE_WORKER_MODE, SPILL_WORKER_MODE):
|
||||
|
|
Loading…
Add table
Reference in a new issue