[gcs] Make gcs client in python able to auto reconnect (#20299)

## Why are these changes needed?
Since we are using gcs client as kv backend, we need to make it auto-reconnect in case of a failure. This PR adds this feature.

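This PR adds an `_auto_reconnect` decorator to `gcs_utils`. When a request fails because GCS is unavailable, the client reconnects and retries the request, up to `nums_reconnect_retry` times (default 5) before re-raising the error.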

This feature currently also supports resolving the GCS address through Redis. That path should be deleted once the Redis-removal bootstrap work is finished, since KV will then always go to GCS.

## Related issue number
This commit is contained in:
Yi Cheng 2021-11-14 11:27:49 -08:00 committed by GitHub
parent 722f935f9a
commit 87fa56def4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 106 additions and 16 deletions

View file

@ -160,7 +160,7 @@ class DashboardAgent(object):
# TODO: redis-removal bootstrap
gcs_address = await self.aioredis_client.get(
dashboard_consts.REDIS_KEY_GCS_SERVER_ADDRESS)
self.gcs_client = GcsClient(gcs_address.decode())
self.gcs_client = GcsClient(address=gcs_address.decode())
modules = self._load_modules()
# Http server should be initialized after all modules loaded.

View file

@ -75,7 +75,7 @@ class _SubscriberBase:
class GcsPublisher:
"""Publisher to GCS."""
def __init__(self, address: str = None, channel: grpc.Channel = None):
def __init__(self, *, address: str = None, channel: grpc.Channel = None):
if address:
assert channel is None, \
"address and channel cannot both be specified"

View file

@ -1,6 +1,8 @@
import enum
import logging
from typing import List
from typing import List, Optional
from functools import wraps
import time
import grpc
@ -119,6 +121,59 @@ def create_gcs_channel(address: str, aio=False):
return init_grpc_channel(address, options=_GRPC_OPTIONS, asynchronous=aio)
def _auto_reconnect(f):
    """Decorate a client method so it reconnects to GCS and retries.

    The decorated method must belong to an object that provides
    ``self._nums_reconnect_retry`` (maximum number of retries) and
    ``self._connect()`` (re-establishes the connection).

    Only ``grpc.RpcError`` with status ``UNAVAILABLE`` is retried. Any other
    RPC error, or an ``UNAVAILABLE`` error once the retry budget is used up,
    is re-raised to the caller unchanged.

    Args:
        f: The method to wrap; it is called as ``f(self, *args, **kwargs)``.

    Returns:
        The wrapping function, carrying ``f``'s metadata via ``wraps``.
    """

    @wraps(f)
    def wrapper(self, *args, **kwargs):
        remaining_retry = self._nums_reconnect_retry
        while True:
            try:
                return f(self, *args, **kwargs)
            except grpc.RpcError as e:
                if remaining_retry <= 0:
                    raise
                if e.code() == grpc.StatusCode.UNAVAILABLE:
                    logger.error(
                        "Failed to send request to gcs, reconnecting. "
                        f"Error {e}")
                    try:
                        self._connect()
                    except Exception as connect_error:
                        # Report the reconnect failure itself, not the
                        # original RPC error that was already logged above.
                        # The failure is deliberately not fatal: the next
                        # loop iteration retries the request.
                        logger.error("Connecting to gcs failed. "
                                     f"Error {connect_error}")
                    # Back off before the next attempt.
                    time.sleep(1)
                    remaining_retry -= 1
                    continue
                raise

    return wrapper
class GcsChannel:
    """Holds how to reach GCS and the channel created from that information.

    Exactly one of ``redis_client`` (the GCS address is looked up through it
    on every ``connect()``) or ``gcs_address`` (used directly) must be given.
    """

    def __init__(self,
                 redis_client=None,
                 gcs_address: Optional[str] = None,
                 aio: bool = False):
        """Validate and store the connection parameters.

        Args:
            redis_client: Client passed to ``get_gcs_address_from_redis``.
            gcs_address: Address handed to ``create_gcs_channel``.
            aio: Forwarded to ``create_gcs_channel`` as its ``aio`` argument.

        Raises:
            ValueError: If neither or both of ``redis_client`` and
                ``gcs_address`` are set.
        """
        has_redis = redis_client is not None
        has_address = gcs_address is not None
        if not (has_redis or has_address):
            raise ValueError(
                "One of `redis_client` or `gcs_address` has to be set")
        if has_redis and has_address:
            raise ValueError(
                "Only one of `redis_client` or `gcs_address` can be set")
        self._redis_client = redis_client
        self._gcs_address = gcs_address
        self._aio = aio

    def connect(self):
        """Create a channel to GCS and store it, replacing any earlier one."""
        if self._redis_client is None:
            target = self._gcs_address
        else:
            # Looked up on each connect rather than cached at construction.
            target = get_gcs_address_from_redis(self._redis_client)
        self._channel = create_gcs_channel(target, self._aio)

    def channel(self):
        """Return the channel created by the most recent ``connect()``."""
        return self._channel
class GcsCode(enum.IntEnum):
# corresponding to ray/src/ray/common/status.h
OK = 0
@ -128,13 +183,24 @@ class GcsCode(enum.IntEnum):
class GcsClient:
"""Client to GCS using GRPC"""
def __init__(self, address: str = None, channel: grpc.Channel = None):
if address:
assert channel is None, \
"Only one of address and channel can be specified"
channel = create_gcs_channel(address)
self._kv_stub = gcs_service_pb2_grpc.InternalKVGcsServiceStub(channel)
def __init__(self,
channel: Optional[GcsChannel] = None,
address: Optional[str] = None,
nums_reconnect_retry: int = 5):
if channel is None:
assert isinstance(address, str)
channel = GcsChannel(gcs_address=address)
assert isinstance(channel, GcsChannel)
self._channel = channel
self._connect()
self._nums_reconnect_retry = nums_reconnect_retry
def _connect(self):
self._channel.connect()
self._kv_stub = gcs_service_pb2_grpc.InternalKVGcsServiceStub(
self._channel.channel())
@_auto_reconnect
def internal_kv_get(self, key: bytes) -> bytes:
logger.debug(f"internal_kv_get {key}")
req = gcs_service_pb2.InternalKVGetRequest(key=key)
@ -147,6 +213,7 @@ class GcsClient:
raise RuntimeError(f"Failed to get value for key {key} "
f"due to error {reply.status.message}")
@_auto_reconnect
def internal_kv_put(self, key: bytes, value: bytes,
overwrite: bool) -> int:
logger.debug(f"internal_kv_put {key} {value} {overwrite}")
@ -159,6 +226,7 @@ class GcsClient:
raise RuntimeError(f"Failed to put value {value} to key {key} "
f"due to error {reply.status.message}")
@_auto_reconnect
def internal_kv_del(self, key: bytes) -> int:
logger.debug(f"internal_kv_del {key}")
req = gcs_service_pb2.InternalKVDelRequest(key=key)
@ -169,6 +237,7 @@ class GcsClient:
raise RuntimeError(f"Failed to delete key {key} "
f"due to error {reply.status.message}")
@_auto_reconnect
def internal_kv_exists(self, key: bytes) -> bool:
logger.debug(f"internal_kv_exists {key}")
req = gcs_service_pb2.InternalKVExistsRequest(key=key)
@ -179,6 +248,7 @@ class GcsClient:
raise RuntimeError(f"Failed to check existence of key {key} "
f"due to error {reply.status.message}")
@_auto_reconnect
def internal_kv_keys(self, prefix: bytes) -> List[bytes]:
logger.debug(f"internal_kv_keys {prefix}")
req = gcs_service_pb2.InternalKVKeysRequest(prefix=prefix)
@ -191,7 +261,7 @@ class GcsClient:
@staticmethod
def create_from_redis(redis_cli):
return GcsClient(get_gcs_address_from_redis(redis_cli))
return GcsClient(GcsChannel(redis_client=redis_cli))
@staticmethod
def connect_to_gcs_by_redis_address(redis_address, redis_password):

View file

@ -503,7 +503,8 @@ def get_non_head_nodes(cluster):
def init_error_pubsub():
"""Initialize redis error info pub/sub"""
if gcs_pubsub_enabled():
s = GcsSubscriber(channel=ray.worker.global_worker.gcs_channel)
s = GcsSubscriber(
channel=ray.worker.global_worker.gcs_channel.channel())
s.subscribe_error()
else:
s = ray.worker.global_worker.redis_client.pubsub(

View file

@ -165,6 +165,25 @@ def test_del_actor_after_gcs_server_restart(ray_start_regular):
ray.get_actor("abc")
@pytest.mark.parametrize("auto_reconnect", [True, False])
def test_gcs_client_reconnect(ray_start_regular, auto_reconnect):
    """Check GcsClient behavior across a GCS server restart.

    With ``auto_reconnect`` the client uses its default retry count and must
    read the stored value again after the restart. Without it the client is
    built with ``nums_reconnect_retry=0`` and the read after the restart is
    expected to raise.
    """
    redis_client = ray.worker.global_worker.redis_client
    channel = gcs_utils.GcsChannel(redis_client=redis_client)
    if auto_reconnect:
        gcs_client = gcs_utils.GcsClient(channel)
    else:
        gcs_client = gcs_utils.GcsClient(channel, nums_reconnect_retry=0)

    gcs_client.internal_kv_put(b"a", b"b", overwrite=True)
    # Verify the value before the restart; without the assert this
    # comparison was evaluated and silently discarded.
    assert gcs_client.internal_kv_get(b"a") == b"b"

    ray.worker._global_node.kill_gcs_server()
    ray.worker._global_node.start_gcs_server()
    if auto_reconnect is False:
        with pytest.raises(Exception):
            gcs_client.internal_kv_get(b"a")
    else:
        assert gcs_client.internal_kv_get(b"a") == b"b"
# Allow running this test file directly: hand the file to pytest in verbose
# mode and use pytest's return value as the process exit status.
if __name__ == "__main__":
import pytest
sys.exit(pytest.main(["-v", __file__]))

View file

@ -1259,7 +1259,7 @@ def listen_error_messages_from_gcs(worker, threads_stopped):
threads_stopped (threading.Event): A threading event used to signal to
the thread that it should exit.
"""
worker.gcs_subscriber = GcsSubscriber(channel=worker.gcs_channel)
worker.gcs_subscriber = GcsSubscriber(channel=worker.gcs_channel.channel())
# Exports that are published after the call to
# gcs_subscriber.subscribe_error() and before the call to
# gcs_subscriber.poll_error() will still be processed in the loop.
@ -1360,16 +1360,16 @@ def connect(node,
# that is not true of Redis pubsub clients. See the documentation at
# https://github.com/andymccurdy/redis-py#thread-safety.
worker.redis_client = node.create_redis_client()
worker.gcs_channel = gcs_utils.create_gcs_channel(
gcs_utils.get_gcs_address_from_redis(worker.redis_client))
worker.gcs_client = gcs_utils.GcsClient(channel=worker.gcs_channel)
worker.gcs_channel = gcs_utils.GcsChannel(redis_client=worker.redis_client)
worker.gcs_client = gcs_utils.GcsClient(worker.gcs_channel)
_initialize_internal_kv(worker.gcs_client)
ray.state.state._initialize_global_state(
node.redis_address, redis_password=node.redis_password)
worker.gcs_pubsub_enabled = gcs_pubsub_enabled()
worker.gcs_publisher = None
if worker.gcs_pubsub_enabled:
worker.gcs_publisher = GcsPublisher(channel=worker.gcs_channel)
worker.gcs_publisher = GcsPublisher(
channel=worker.gcs_channel.channel())
# Initialize some fields.
if mode in (WORKER_MODE, RESTORE_WORKER_MODE, SPILL_WORKER_MODE):