[core] Resubscribe GCS in python when GCS restarts. (#24887)

This is a follow-up PR to https://github.com/ray-project/ray/pull/24813 and https://github.com/ray-project/ray/pull/24628

Unlike the change in the C++ layer, where GCS broadcasts a request to the raylet/core_worker and the client side then resubscribes, in the Python layer we detect the failure on the client side.

In case of a failure, the protocol is:

1. call subscribe
2. if the resubscribe times out, throw an exception, which crashes the process. This is acceptable because when GCS has been down longer than expected, we expect the whole Ray cluster to be down anyway.
3. continue to poll once the subscribe succeeds (a minimal sketch of this loop follows).
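
For illustration, the retry loop looks roughly like this minimal sketch (the `subscribe_with_retry` helper and the hard-coded deadline are stand-ins; the real logic lives in `_SyncSubscriber.subscribe` in the diff below):

```python
import time

import grpc

SUBSCRIBE_DEADLINE_S = 60  # stands in for Config.gcs_rpc_server_reconnect_timeout_s()


def subscribe_with_retry(stub, request):
    # Steps 1-2 of the protocol: retry subscribing until GCS comes back,
    # and re-raise (crashing the process) once the deadline is exceeded.
    start = time.time()
    while True:
        try:
            return stub.GcsSubscriberCommandBatch(request, timeout=30)
        except grpc.RpcError as e:
            if e.code() not in (
                grpc.StatusCode.UNAVAILABLE,
                grpc.StatusCode.UNKNOWN,
                grpc.StatusCode.DEADLINE_EXCEEDED,
            ):
                raise  # not a GCS outage; surface immediately
            if time.time() - start > SUBSCRIBE_DEADLINE_S:
                raise  # GCS has been down longer than expected; give up
            time.sleep(1)  # brief backoff before retrying
```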

However, there is an extreme case where this can break: the client might fail to detect the failure at all.

This can happen if a long-polling request has returned and the Python layer is doing its own work; before it sends the next long-polling request, GCS restarts and recovers.

Here we are not going to handle this case, because:
1. GCS usually takes several seconds to come back up, while the Python layer's work between polls is simply pushing data into a queue (sync version). The async version is only used in the dashboard, which is not a critical component.
2. pubsub in the Python layer does no critical work: it handles logs/errors for Ray jobs;
3. the dashboard can simply be restarted to fix any issue.


A known issue is that we might miss logs during a GCS failure, for the following reasons:

- the Python-layer publisher is only best-effort: if publishing fails too many times, it skips the message (messages are lost on the producer side)
- if a message was pushed to GCS but the worker has not resubscribed yet, the pushed message is lost (messages are lost on the consumer side)

We think this is reasonable and valid behavior, given that logs are not defined to be a critical component and we would like to keep the design of pubsub in GCS simple.
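
To make the producer-side behavior concrete, best-effort publishing amounts to something like the following sketch (hypothetical `publish_best_effort` helper and retry cap, assuming a `GcsPublish` RPC on the pubsub stub — not the exact Ray implementation):

```python
import grpc

MAX_PUBLISH_RETRIES = 3  # hypothetical cap, for illustration only


def publish_best_effort(stub, request):
    # Try a few times, then drop the message rather than block the
    # producer on an unavailable GCS (producer-side message loss).
    for _ in range(MAX_PUBLISH_RETRIES):
        try:
            stub.GcsPublish(request, timeout=1)
            return True
        except grpc.RpcError:
            continue  # GCS may be restarting; retry
    return False  # give up: the message is dropped
```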

Another thing is `run_functions_on_all_workers`. We plan to stop using it within Ray core and to deprecate it in the longer term, but it won't cause a problem for the current cases because:

1. It's only set in the driver, and we don't support creating a new driver while GCS is down.
2. When GCS is down, we don't support starting new Ray workers.

And `run_functions_on_all_workers` is only used when we initialize the driver and workers.
Yi Cheng, committed via GitHub on 2022-05-23 13:06:33 -07:00 (commit 7cf4233858, parent 36b1b4ce0c)
6 changed files with 289 additions and 55 deletions


@@ -697,32 +697,6 @@ def test_dashboard_port_conflict(ray_start_with_dashboard):
         raise Exception("Timed out while testing.")
 
 
-@pytest.mark.skipif(
-    os.environ.get("RAY_MINIMAL") == "1",
-    reason="This test is not supposed to work for minimal installation.",
-)
-def test_gcs_check_alive(fast_gcs_failure_detection, ray_start_with_dashboard):
-    assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
-
-    all_processes = ray.worker._global_node.all_processes
-    dashboard_info = all_processes[ray_constants.PROCESS_TYPE_DASHBOARD][0]
-    dashboard_proc = psutil.Process(dashboard_info.process.pid)
-    gcs_server_info = all_processes[ray_constants.PROCESS_TYPE_GCS_SERVER][0]
-    gcs_server_proc = psutil.Process(gcs_server_info.process.pid)
-
-    assert dashboard_proc.status() in [
-        psutil.STATUS_RUNNING,
-        psutil.STATUS_SLEEPING,
-        psutil.STATUS_DISK_SLEEP,
-    ]
-
-    gcs_server_proc.kill()
-    gcs_server_proc.wait()
-
-    # The dashboard exits by os._exit(-1)
-    assert dashboard_proc.wait(10) == 255
-
-
 @pytest.mark.skipif(
     os.environ.get("RAY_DEFAULT") != "1",
     reason="This test only works for default installation.",


@@ -104,16 +104,28 @@ class _SubscriberBase:
         )
         return req
 
-    @staticmethod
-    def _should_terminate_polling(e: grpc.RpcError) -> None:
-        # Caller only expects polling to be terminated after deadline exceeded.
-        if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
+    def _handle_polling_failure(self, e: grpc.RpcError) -> bool:
+        if self._close.is_set():
+            return False
+
+        if e.code() in (
+            grpc.StatusCode.UNAVAILABLE,
+            grpc.StatusCode.UNKNOWN,
+            grpc.StatusCode.DEADLINE_EXCEEDED,
+        ):
             return True
-        # Could be a temporary connection issue. Suppress error.
-        # TODO: reconnect GRPC channel?
-        if e.code() == grpc.StatusCode.UNAVAILABLE:
-            return True
-        return False
+        raise e
+
+    def _handle_subscribe_failure(self, e: grpc.RpcError):
+        if e.code() in (
+            grpc.StatusCode.UNAVAILABLE,
+            grpc.StatusCode.UNKNOWN,
+            grpc.StatusCode.DEADLINE_EXCEEDED,
+        ):
+            time.sleep(1)
+        else:
+            raise e
 
     @staticmethod
     def _pop_error_info(queue):
@@ -211,7 +223,7 @@ class _SyncSubscriber(_SubscriberBase):
         # Type of the channel.
         self._channel = pubsub_channel_type
         # Protects multi-threaded read and write of self._queue.
-        self._lock = threading.Lock()
+        self._lock = threading.RLock()
         # A queue of received PubMessage.
         self._queue = deque()
         # Indicates whether the subscriber has closed.
@@ -224,13 +236,24 @@ class _SyncSubscriber(_SubscriberBase):
         saved for the subscriber.
         """
         with self._lock:
             if self._close.is_set():
                 return
             req = self._subscribe_request(self._channel)
-            self._stub.GcsSubscriberCommandBatch(req, timeout=30)
+            start = time.time()
+            from ray._raylet import Config
+
+            while True:
+                try:
+                    if self._close.is_set():
+                        return
+                    return self._stub.GcsSubscriberCommandBatch(req, timeout=30)
+                except grpc.RpcError as e:
+                    self._handle_subscribe_failure(e)
+                    if (
+                        time.time() - start
+                        > Config.gcs_rpc_server_reconnect_timeout_s()
+                    ):
+                        raise e
 
     def _poll_locked(self, timeout=None) -> None:
-        assert self._lock.locked()
         # Poll until data becomes available.
         while len(self._queue) == 0:
@@ -257,9 +280,18 @@ class _SyncSubscriber(_SubscriberBase):
                 # GRPC has not replied, continue waiting.
                 continue
             except grpc.RpcError as e:
-                if self._should_terminate_polling(e):
+                if (
+                    e.code() == grpc.StatusCode.DEADLINE_EXCEEDED
+                    and timeout is not None
+                ):
                     return
-                raise
+                if self._handle_polling_failure(e) is True:
+                    self.subscribe()
+                    fut = self._stub.GcsSubscriberPoll.future(
+                        self._poll_request(), timeout=timeout
+                    )
+                else:
+                    return
 
         if fut.done():
             self._last_batch_size = len(fut.result().pub_messages)
@@ -277,11 +309,13 @@ class _SyncSubscriber(_SubscriberBase):
             return
         self._close.set()
         req = self._unsubscribe_request(channels=[self._channel])
         try:
             self._stub.GcsSubscriberCommandBatch(req, timeout=5)
         except Exception:
             pass
-        self._stub = None
+        with self._lock:
+            self._stub = None
 
 
 class GcsErrorSubscriber(_SyncSubscriber):
@@ -496,18 +530,39 @@ class _AioSubscriber(_SubscriberBase):
         if self._close.is_set():
             return
         req = self._subscribe_request(self._channel)
-        await self._stub.GcsSubscriberCommandBatch(req, timeout=30)
+        start = time.time()
+        from ray._raylet import Config
+
+        while True:
+            try:
+                return await self._stub.GcsSubscriberCommandBatch(req, timeout=30)
+            except grpc.RpcError as e:
+                self._handle_subscribe_failure(e)
+                if time.time() - start > Config.gcs_rpc_server_reconnect_timeout_s():
+                    raise
 
     async def _poll_call(self, req, timeout=None):
         # Wrap GRPC _AioCall as a coroutine.
-        return await self._stub.GcsSubscriberPoll(req, timeout=timeout)
+        while True:
+            try:
+                return await self._stub.GcsSubscriberPoll(req, timeout=timeout)
+            except grpc.RpcError as e:
+                if (
+                    e.code() == grpc.StatusCode.DEADLINE_EXCEEDED
+                    and timeout is not None
+                ):
+                    return
+                if self._handle_polling_failure(e) is True:
+                    await self.subscribe()
+                else:
+                    return
 
     async def _poll(self, timeout=None) -> None:
         req = self._poll_request()
         while len(self._queue) == 0:
             # TODO: use asyncio.create_task() after Python 3.6 is no longer
             # supported.
-            poll = asyncio.ensure_future(self._poll_call(req, timeout=timeout))
+            poll = asyncio.ensure_future(self._poll_call(req))
             close = asyncio.ensure_future(self._close.wait())
             done, _ = await asyncio.wait(
                 [poll, close], timeout=timeout, return_when=asyncio.FIRST_COMPLETED
@@ -515,14 +570,9 @@ class _AioSubscriber(_SubscriberBase):
             if poll not in done or close in done:
                 # Request timed out or subscriber closed.
                 break
-            try:
-                self._last_batch_size = len(poll.result().pub_messages)
-                for msg in poll.result().pub_messages:
-                    self._queue.append(msg)
-            except grpc.RpcError as e:
-                if self._should_terminate_polling(e):
-                    return
-                raise
+            self._last_batch_size = len(poll.result().pub_messages)
+            for msg in poll.result().pub_messages:
+                self._queue.append(msg)
 
     async def close(self) -> None:
         """Closes the subscriber and its active subscription."""


@@ -1,5 +1,5 @@
 from libcpp cimport bool as c_bool
-from libc.stdint cimport int64_t, uint64_t, uint32_t
+from libc.stdint cimport int64_t, uint64_t, uint32_t, int32_t
 from libcpp.string cimport string as c_string
 from libcpp.unordered_map cimport unordered_map

@@ -68,3 +68,5 @@ cdef extern from "ray/common/ray_config.h" nogil:
         c_bool start_python_importer_thread() const
 
         c_bool use_ray_syncer() const
+
+        int32_t gcs_rpc_server_reconnect_timeout_s() const


@@ -84,6 +84,10 @@ cdef class Config:
     def object_manager_default_chunk_size():
         return RayConfig.instance().object_manager_default_chunk_size()
 
+    @staticmethod
+    def gcs_rpc_server_reconnect_timeout_s():
+        return RayConfig.instance().gcs_rpc_server_reconnect_timeout_s()
+
     @staticmethod
     def maximum_gcs_deletion_batch_size():
         return RayConfig.instance().maximum_gcs_deletion_batch_size()
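
These two Cython changes expose the existing C++ `RayConfig` entry to Python; with them in place, the subscribe loops above can read the reconnect deadline directly:

```python
# Reads the C++ RayConfig value through the Cython binding added above.
from ray._raylet import Config

timeout_s = Config.gcs_rpc_server_reconnect_timeout_s()
```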


@@ -9,11 +9,16 @@ from time import sleep
 from ray._private.test_utils import (
     generate_system_config_map,
+    run_string_as_driver_nonblocking,
     wait_for_condition,
     wait_for_pid_to_exit,
     convert_actor_state,
 )
 
+import logging
+
+logger = logging.getLogger(__name__)
+
 
 @ray.remote
 class Increase:
@@ -338,6 +343,83 @@ def test_core_worker_resubscription(tmp_path, ray_start_regular_with_external_redis):
     ray.get(r, timeout=5)
 
 
+@pytest.mark.parametrize(
+    "ray_start_regular_with_external_redis",
+    [
+        generate_system_config_map(
+            num_heartbeats_timeout=20, gcs_rpc_server_reconnect_timeout_s=60
+        )
+    ],
+    indirect=True,
+)
+def test_py_resubscription(tmp_path, ray_start_regular_with_external_redis):
+    # This test is to ensure that Python pubsub works
+    from filelock import FileLock
+
+    lock_file1 = str(tmp_path / "lock1")
+    lock1 = FileLock(lock_file1)
+    lock1.acquire()
+
+    lock_file2 = str(tmp_path / "lock2")
+    lock2 = FileLock(lock_file2)
+
+    script = f"""
+from filelock import FileLock
+import ray
+
+@ray.remote
+def f():
+    print("OK1", flush=True)
+    # wait until log_monitor pushes this
+    from time import sleep
+    sleep(2)
+
+    lock1 = FileLock(r"{lock_file1}")
+    lock2 = FileLock(r"{lock_file2}")
+    lock2.acquire()
+    lock1.acquire()
+
+    # wait until log_monitor pushes this
+    from time import sleep
+    sleep(2)
+    print("OK2", flush=True)
+
+ray.init(address='auto')
+ray.get(f.remote())
+ray.shutdown()
+"""
+    proc = run_string_as_driver_nonblocking(script)
+
+    def condition():
+        import filelock
+
+        try:
+            lock2.acquire(timeout=1)
+        except filelock.Timeout:
+            return True
+        lock2.release()
+        return False
+
+    # make sure the script has printed "OK1"
+    wait_for_condition(condition, timeout=10)
+
+    ray.worker._global_node.kill_gcs_server()
+    import time
+
+    time.sleep(2)
+    ray.worker._global_node.start_gcs_server()
+
+    lock1.release()
+    proc.wait()
+
+    output = proc.stdout.read()
+    # Print logs which are useful for debugging in CI
+    print("=================== OUTPUTS ============")
+    print(output.decode())
+
+    assert b"OK1" in output
+    assert b"OK2" in output
+
+
 @pytest.mark.parametrize("auto_reconnect", [True, False])
 def test_gcs_client_reconnect(ray_start_regular_with_external_redis, auto_reconnect):
     gcs_address = ray.worker.global_worker.gcs_client.address


@@ -1,6 +1,7 @@
 import sys
 import threading
 
+import ray
 from ray._private.gcs_pubsub import (
     GcsPublisher,
     GcsErrorSubscriber,
@@ -34,6 +35,72 @@ def test_publish_and_subscribe_error_info(ray_start_regular):
     subscriber.close()
 
 
+def test_publish_and_subscribe_error_info_ft(ray_start_regular_with_external_redis):
+    address_info = ray_start_regular_with_external_redis
+    gcs_server_addr = address_info["gcs_address"]
+    from threading import Barrier, Thread
+
+    subscriber = GcsErrorSubscriber(address=gcs_server_addr)
+    subscriber.subscribe()
+
+    publisher = GcsPublisher(address=gcs_server_addr)
+    err1 = ErrorTableData(error_message="test error message 1")
+    err2 = ErrorTableData(error_message="test error message 2")
+    err3 = ErrorTableData(error_message="test error message 3")
+    err4 = ErrorTableData(error_message="test error message 4")
+
+    b = Barrier(3)
+
+    def publisher_func():
+        print("Publisher HERE")
+        publisher.publish_error(b"aaa_id", err1)
+        publisher.publish_error(b"bbb_id", err2)
+        b.wait()
+        print("Publisher HERE")
+        # Wait for the subscriber to subscribe first.
+        # It's ok to lose log messages.
+        from time import sleep
+
+        sleep(5)
+        publisher.publish_error(b"aaa_id", err3)
+        print("pub err1")
+        publisher.publish_error(b"bbb_id", err4)
+        print("pub err2")
+        print("DONE")
+
+    def subscriber_func():
+        print("Subscriber HERE")
+        assert subscriber.poll() == (b"aaa_id", err1)
+        assert subscriber.poll() == (b"bbb_id", err2)
+        b.wait()
+        assert subscriber.poll() == (b"aaa_id", err3)
+        print("sub err1")
+        assert subscriber.poll() == (b"bbb_id", err4)
+        print("sub err2")
+        subscriber.close()
+        print("DONE")
+
+    t1 = Thread(target=publisher_func)
+    t2 = Thread(target=subscriber_func)
+    t1.start()
+    t2.start()
+
+    b.wait()
+    ray.worker._global_node.kill_gcs_server()
+    from time import sleep
+
+    sleep(1)
+    ray.worker._global_node.start_gcs_server()
+    sleep(1)
+
+    t1.join()
+    t2.join()
+
+
 @pytest.mark.asyncio
 async def test_aio_publish_and_subscribe_error_info(ray_start_regular):
     address_info = ray_start_regular
@@ -54,6 +121,61 @@ async def test_aio_publish_and_subscribe_error_info(ray_start_regular):
     await subscriber.close()
 
 
+@pytest.mark.asyncio
+async def test_aio_publish_and_subscribe_error_info_ft(
+    ray_start_regular_with_external_redis,
+):
+    address_info = ray_start_regular_with_external_redis
+    gcs_server_addr = address_info["gcs_address"]
+
+    subscriber = GcsAioErrorSubscriber(address=gcs_server_addr)
+    await subscriber.subscribe()
+
+    err1 = ErrorTableData(error_message="test error message 1")
+    err2 = ErrorTableData(error_message="test error message 2")
+    err3 = ErrorTableData(error_message="test error message 3")
+    err4 = ErrorTableData(error_message="test error message 4")
+
+    def restart_gcs_server():
+        import asyncio
+
+        asyncio.set_event_loop(asyncio.new_event_loop())
+        from time import sleep
+
+        publisher = GcsAioPublisher(address=gcs_server_addr)
+        asyncio.get_event_loop().run_until_complete(
+            publisher.publish_error(b"aaa_id", err1)
+        )
+        asyncio.get_event_loop().run_until_complete(
+            publisher.publish_error(b"bbb_id", err2)
+        )
+        # wait until the subscriber consumes everything
+        sleep(5)
+        ray.worker._global_node.kill_gcs_server()
+        sleep(1)
+        ray.worker._global_node.start_gcs_server()
+        # wait until the subscriber has resubscribed
+        sleep(5)
+        asyncio.get_event_loop().run_until_complete(
+            publisher.publish_error(b"aaa_id", err3)
+        )
+        asyncio.get_event_loop().run_until_complete(
+            publisher.publish_error(b"bbb_id", err4)
+        )
+
+    t1 = threading.Thread(target=restart_gcs_server)
+    t1.start()
+
+    assert await subscriber.poll() == (b"aaa_id", err1)
+    assert await subscriber.poll() == (b"bbb_id", err2)
+    assert await subscriber.poll() == (b"aaa_id", err3)
+    assert await subscriber.poll() == (b"bbb_id", err4)
+
+    await subscriber.close()
+    t1.join()
+
+
 def test_publish_and_subscribe_logs(ray_start_regular):
     address_info = ray_start_regular
     gcs_server_addr = address_info["gcs_address"]