mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
[core] Resubscribe GCS in python when GCS restarts. (#24887)
This is a follow-up PRs of https://github.com/ray-project/ray/pull/24813 and https://github.com/ray-project/ray/pull/24628 Unlike the change in cpp layer, where the resubscription is done by GCS broadcast a request to raylet/core_worker and the client-side do the resubscription, in the python layer, we detect the failure in the client-side. In case of a failure, the protocol is: 1. call subscribe 2. if timeout when doing resubscribe, throw an exception and this will crash the system. This is ok because when GCS has been down for a time longer than expected, we expect the ray cluster to be down. 3. continue to poll once subscribe ok. However, there is an extreme case where things might be broken: the client might miss detecting a failure. This could happen if the long-polling has been returned and the python layer is doing its own work. And before it sends another long-polling, GCS restarts and recovered. Here we are not going to take care of this case because: 1. usually GCS is going to take several seconds to be up and the python layer's work is simply pushing data into a queue (sync version). For the async version, it's only used in Dashboard which is not a critical component. 2. pubsub in python layer is not doing critical work: it handles logs/errors for ray job; 3. for the dashboard, it can just restart to fix the issue. A known issue here is that we might miss logs in case of GCS failure due to the following reasons: - py's pubsub is only doing best effort publishing. If it failed too many times, it'll skip publishing the message (lose messages from producer side) - if message is pushed to GCS, but the worker hasn't done resubscription yet, the pushed message will be lost (lose messages from consumer side) We think it's reasonable and valid behavior given that the logs are not defined to be a critical component and we'd like to simplify the design of pubsub in GCS. Another things is `run_functions_on_all_workers`. We'll plan to stop using it within ray core and deprecate it in the longer term. But it won't cause a problem for the current cases because: 1. It's only set in driver and we don't support creating a new driver when GCS is down. 2. When GCS is down, we don't support starting new ray workers. And `run_functions_on_all_workers` is only used when we initialize driver/workers.
This commit is contained in:
parent
36b1b4ce0c
commit
7cf4233858
6 changed files with 289 additions and 55 deletions
|
@ -697,32 +697,6 @@ def test_dashboard_port_conflict(ray_start_with_dashboard):
|
|||
raise Exception("Timed out while testing.")
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("RAY_MINIMAL") == "1",
|
||||
reason="This test is not supposed to work for minimal installation.",
|
||||
)
|
||||
def test_gcs_check_alive(fast_gcs_failure_detection, ray_start_with_dashboard):
|
||||
assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
|
||||
|
||||
all_processes = ray.worker._global_node.all_processes
|
||||
dashboard_info = all_processes[ray_constants.PROCESS_TYPE_DASHBOARD][0]
|
||||
dashboard_proc = psutil.Process(dashboard_info.process.pid)
|
||||
gcs_server_info = all_processes[ray_constants.PROCESS_TYPE_GCS_SERVER][0]
|
||||
gcs_server_proc = psutil.Process(gcs_server_info.process.pid)
|
||||
|
||||
assert dashboard_proc.status() in [
|
||||
psutil.STATUS_RUNNING,
|
||||
psutil.STATUS_SLEEPING,
|
||||
psutil.STATUS_DISK_SLEEP,
|
||||
]
|
||||
|
||||
gcs_server_proc.kill()
|
||||
gcs_server_proc.wait()
|
||||
|
||||
# The dashboard exits by os._exit(-1)
|
||||
assert dashboard_proc.wait(10) == 255
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("RAY_DEFAULT") != "1",
|
||||
reason="This test only works for default installation.",
|
||||
|
|
|
@ -104,16 +104,28 @@ class _SubscriberBase:
|
|||
)
|
||||
return req
|
||||
|
||||
@staticmethod
|
||||
def _should_terminate_polling(e: grpc.RpcError) -> None:
|
||||
# Caller only expects polling to be terminated after deadline exceeded.
|
||||
if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
|
||||
def _handle_polling_failure(self, e: grpc.RpcError) -> bool:
|
||||
if self._close.is_set():
|
||||
return False
|
||||
|
||||
if e.code() in (
|
||||
grpc.StatusCode.UNAVAILABLE,
|
||||
grpc.StatusCode.UNKNOWN,
|
||||
grpc.StatusCode.DEADLINE_EXCEEDED,
|
||||
):
|
||||
return True
|
||||
# Could be a temporary connection issue. Suppress error.
|
||||
# TODO: reconnect GRPC channel?
|
||||
if e.code() == grpc.StatusCode.UNAVAILABLE:
|
||||
return True
|
||||
return False
|
||||
|
||||
raise e
|
||||
|
||||
def _handle_subscribe_failure(self, e: grpc.RpcError):
|
||||
if e.code() in (
|
||||
grpc.StatusCode.UNAVAILABLE,
|
||||
grpc.StatusCode.UNKNOWN,
|
||||
grpc.StatusCode.DEADLINE_EXCEEDED,
|
||||
):
|
||||
time.sleep(1)
|
||||
else:
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def _pop_error_info(queue):
|
||||
|
@ -211,7 +223,7 @@ class _SyncSubscriber(_SubscriberBase):
|
|||
# Type of the channel.
|
||||
self._channel = pubsub_channel_type
|
||||
# Protects multi-threaded read and write of self._queue.
|
||||
self._lock = threading.Lock()
|
||||
self._lock = threading.RLock()
|
||||
# A queue of received PubMessage.
|
||||
self._queue = deque()
|
||||
# Indicates whether the subscriber has closed.
|
||||
|
@ -224,13 +236,24 @@ class _SyncSubscriber(_SubscriberBase):
|
|||
saved for the subscriber.
|
||||
"""
|
||||
with self._lock:
|
||||
if self._close.is_set():
|
||||
return
|
||||
req = self._subscribe_request(self._channel)
|
||||
self._stub.GcsSubscriberCommandBatch(req, timeout=30)
|
||||
start = time.time()
|
||||
from ray._raylet import Config
|
||||
|
||||
while True:
|
||||
try:
|
||||
if self._close.is_set():
|
||||
return
|
||||
return self._stub.GcsSubscriberCommandBatch(req, timeout=30)
|
||||
except grpc.RpcError as e:
|
||||
self._handle_subscribe_failure(e)
|
||||
if (
|
||||
time.time() - start
|
||||
> Config.gcs_rpc_server_reconnect_timeout_s()
|
||||
):
|
||||
raise e
|
||||
|
||||
def _poll_locked(self, timeout=None) -> None:
|
||||
assert self._lock.locked()
|
||||
|
||||
# Poll until data becomes available.
|
||||
while len(self._queue) == 0:
|
||||
|
@ -257,9 +280,18 @@ class _SyncSubscriber(_SubscriberBase):
|
|||
# GRPC has not replied, continue waiting.
|
||||
continue
|
||||
except grpc.RpcError as e:
|
||||
if self._should_terminate_polling(e):
|
||||
if (
|
||||
e.code() == grpc.StatusCode.DEADLINE_EXCEEDED
|
||||
and timeout is not None
|
||||
):
|
||||
return
|
||||
if self._handle_polling_failure(e) is True:
|
||||
self.subscribe()
|
||||
fut = self._stub.GcsSubscriberPoll.future(
|
||||
self._poll_request(), timeout=timeout
|
||||
)
|
||||
else:
|
||||
return
|
||||
raise
|
||||
|
||||
if fut.done():
|
||||
self._last_batch_size = len(fut.result().pub_messages)
|
||||
|
@ -277,11 +309,13 @@ class _SyncSubscriber(_SubscriberBase):
|
|||
return
|
||||
self._close.set()
|
||||
req = self._unsubscribe_request(channels=[self._channel])
|
||||
|
||||
try:
|
||||
self._stub.GcsSubscriberCommandBatch(req, timeout=5)
|
||||
except Exception:
|
||||
pass
|
||||
self._stub = None
|
||||
with self._lock:
|
||||
self._stub = None
|
||||
|
||||
|
||||
class GcsErrorSubscriber(_SyncSubscriber):
|
||||
|
@ -496,18 +530,39 @@ class _AioSubscriber(_SubscriberBase):
|
|||
if self._close.is_set():
|
||||
return
|
||||
req = self._subscribe_request(self._channel)
|
||||
await self._stub.GcsSubscriberCommandBatch(req, timeout=30)
|
||||
start = time.time()
|
||||
from ray._raylet import Config
|
||||
|
||||
while True:
|
||||
try:
|
||||
return await self._stub.GcsSubscriberCommandBatch(req, timeout=30)
|
||||
except grpc.RpcError as e:
|
||||
self._handle_subscribe_failure(e)
|
||||
if time.time() - start > Config.gcs_rpc_server_reconnect_timeout_s():
|
||||
raise
|
||||
|
||||
async def _poll_call(self, req, timeout=None):
|
||||
# Wrap GRPC _AioCall as a coroutine.
|
||||
return await self._stub.GcsSubscriberPoll(req, timeout=timeout)
|
||||
while True:
|
||||
try:
|
||||
return await self._stub.GcsSubscriberPoll(req, timeout=timeout)
|
||||
except grpc.RpcError as e:
|
||||
if (
|
||||
e.code() == grpc.StatusCode.DEADLINE_EXCEEDED
|
||||
and timeout is not None
|
||||
):
|
||||
return
|
||||
if self._handle_polling_failure(e) is True:
|
||||
await self.subscribe()
|
||||
else:
|
||||
return
|
||||
|
||||
async def _poll(self, timeout=None) -> None:
|
||||
req = self._poll_request()
|
||||
while len(self._queue) == 0:
|
||||
# TODO: use asyncio.create_task() after Python 3.6 is no longer
|
||||
# supported.
|
||||
poll = asyncio.ensure_future(self._poll_call(req, timeout=timeout))
|
||||
poll = asyncio.ensure_future(self._poll_call(req))
|
||||
close = asyncio.ensure_future(self._close.wait())
|
||||
done, _ = await asyncio.wait(
|
||||
[poll, close], timeout=timeout, return_when=asyncio.FIRST_COMPLETED
|
||||
|
@ -515,14 +570,9 @@ class _AioSubscriber(_SubscriberBase):
|
|||
if poll not in done or close in done:
|
||||
# Request timed out or subscriber closed.
|
||||
break
|
||||
try:
|
||||
self._last_batch_size = len(poll.result().pub_messages)
|
||||
for msg in poll.result().pub_messages:
|
||||
self._queue.append(msg)
|
||||
except grpc.RpcError as e:
|
||||
if self._should_terminate_polling(e):
|
||||
return
|
||||
raise
|
||||
self._last_batch_size = len(poll.result().pub_messages)
|
||||
for msg in poll.result().pub_messages:
|
||||
self._queue.append(msg)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Closes the subscriber and its active subscription."""
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from libcpp cimport bool as c_bool
|
||||
from libc.stdint cimport int64_t, uint64_t, uint32_t
|
||||
from libc.stdint cimport int64_t, uint64_t, uint32_t, int32_t
|
||||
from libcpp.string cimport string as c_string
|
||||
from libcpp.unordered_map cimport unordered_map
|
||||
|
||||
|
@ -68,3 +68,5 @@ cdef extern from "ray/common/ray_config.h" nogil:
|
|||
c_bool start_python_importer_thread() const
|
||||
|
||||
c_bool use_ray_syncer() const
|
||||
|
||||
int32_t gcs_rpc_server_reconnect_timeout_s() const
|
||||
|
|
|
@ -84,6 +84,10 @@ cdef class Config:
|
|||
def object_manager_default_chunk_size():
|
||||
return RayConfig.instance().object_manager_default_chunk_size()
|
||||
|
||||
@staticmethod
|
||||
def gcs_rpc_server_reconnect_timeout_s():
|
||||
return RayConfig.instance().gcs_rpc_server_reconnect_timeout_s()
|
||||
|
||||
@staticmethod
|
||||
def maximum_gcs_deletion_batch_size():
|
||||
return RayConfig.instance().maximum_gcs_deletion_batch_size()
|
||||
|
|
|
@ -9,11 +9,16 @@ from time import sleep
|
|||
|
||||
from ray._private.test_utils import (
|
||||
generate_system_config_map,
|
||||
run_string_as_driver_nonblocking,
|
||||
wait_for_condition,
|
||||
wait_for_pid_to_exit,
|
||||
convert_actor_state,
|
||||
)
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ray.remote
|
||||
class Increase:
|
||||
|
@ -338,6 +343,83 @@ def test_core_worker_resubscription(tmp_path, ray_start_regular_with_external_re
|
|||
ray.get(r, timeout=5)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ray_start_regular_with_external_redis",
|
||||
[
|
||||
generate_system_config_map(
|
||||
num_heartbeats_timeout=20, gcs_rpc_server_reconnect_timeout_s=60
|
||||
)
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
def test_py_resubscription(tmp_path, ray_start_regular_with_external_redis):
|
||||
# This test is to ensure python pubsub works
|
||||
from filelock import FileLock
|
||||
|
||||
lock_file1 = str(tmp_path / "lock1")
|
||||
lock1 = FileLock(lock_file1)
|
||||
lock1.acquire()
|
||||
|
||||
lock_file2 = str(tmp_path / "lock2")
|
||||
lock2 = FileLock(lock_file2)
|
||||
|
||||
script = f"""
|
||||
from filelock import FileLock
|
||||
import ray
|
||||
|
||||
@ray.remote
|
||||
def f():
|
||||
print("OK1", flush=True)
|
||||
# wait until log_monitor push this
|
||||
from time import sleep
|
||||
sleep(2)
|
||||
lock1 = FileLock(r"{lock_file1}")
|
||||
lock2 = FileLock(r"{lock_file2}")
|
||||
|
||||
lock2.acquire()
|
||||
lock1.acquire()
|
||||
|
||||
# wait until log_monitor push this
|
||||
from time import sleep
|
||||
sleep(2)
|
||||
print("OK2", flush=True)
|
||||
|
||||
ray.init(address='auto')
|
||||
ray.get(f.remote())
|
||||
ray.shutdown()
|
||||
"""
|
||||
proc = run_string_as_driver_nonblocking(script)
|
||||
|
||||
def condition():
|
||||
import filelock
|
||||
|
||||
try:
|
||||
lock2.acquire(timeout=1)
|
||||
except filelock.Timeout:
|
||||
return True
|
||||
|
||||
lock2.release()
|
||||
return False
|
||||
|
||||
# make sure the script has printed "OK1"
|
||||
wait_for_condition(condition, timeout=10)
|
||||
|
||||
ray.worker._global_node.kill_gcs_server()
|
||||
import time
|
||||
|
||||
time.sleep(2)
|
||||
ray.worker._global_node.start_gcs_server()
|
||||
|
||||
lock1.release()
|
||||
proc.wait()
|
||||
output = proc.stdout.read()
|
||||
# Print logs which are useful for debugging in CI
|
||||
print("=================== OUTPUTS ============")
|
||||
print(output.decode())
|
||||
assert b"OK1" in output
|
||||
assert b"OK2" in output
|
||||
|
||||
|
||||
@pytest.mark.parametrize("auto_reconnect", [True, False])
|
||||
def test_gcs_client_reconnect(ray_start_regular_with_external_redis, auto_reconnect):
|
||||
gcs_address = ray.worker.global_worker.gcs_client.address
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import sys
|
||||
import threading
|
||||
|
||||
import ray
|
||||
from ray._private.gcs_pubsub import (
|
||||
GcsPublisher,
|
||||
GcsErrorSubscriber,
|
||||
|
@ -34,6 +35,72 @@ def test_publish_and_subscribe_error_info(ray_start_regular):
|
|||
subscriber.close()
|
||||
|
||||
|
||||
def test_publish_and_subscribe_error_info_ft(ray_start_regular_with_external_redis):
|
||||
address_info = ray_start_regular_with_external_redis
|
||||
gcs_server_addr = address_info["gcs_address"]
|
||||
from threading import Barrier, Thread
|
||||
|
||||
subscriber = GcsErrorSubscriber(address=gcs_server_addr)
|
||||
subscriber.subscribe()
|
||||
|
||||
publisher = GcsPublisher(address=gcs_server_addr)
|
||||
|
||||
err1 = ErrorTableData(error_message="test error message 1")
|
||||
err2 = ErrorTableData(error_message="test error message 2")
|
||||
err3 = ErrorTableData(error_message="test error message 3")
|
||||
err4 = ErrorTableData(error_message="test error message 4")
|
||||
b = Barrier(3)
|
||||
|
||||
def publisher_func():
|
||||
print("Publisher HERE")
|
||||
publisher.publish_error(b"aaa_id", err1)
|
||||
publisher.publish_error(b"bbb_id", err2)
|
||||
|
||||
b.wait()
|
||||
|
||||
print("Publisher HERE")
|
||||
# Wait fo subscriber to subscribe first.
|
||||
# It's ok to loose log messages.
|
||||
from time import sleep
|
||||
|
||||
sleep(5)
|
||||
publisher.publish_error(b"aaa_id", err3)
|
||||
print("pub err1")
|
||||
publisher.publish_error(b"bbb_id", err4)
|
||||
print("pub err2")
|
||||
print("DONE")
|
||||
|
||||
def subscriber_func():
|
||||
print("Subscriber HERE")
|
||||
assert subscriber.poll() == (b"aaa_id", err1)
|
||||
assert subscriber.poll() == (b"bbb_id", err2)
|
||||
|
||||
b.wait()
|
||||
assert subscriber.poll() == (b"aaa_id", err3)
|
||||
print("sub err1")
|
||||
assert subscriber.poll() == (b"bbb_id", err4)
|
||||
print("sub err2")
|
||||
|
||||
subscriber.close()
|
||||
print("DONE")
|
||||
|
||||
t1 = Thread(target=publisher_func)
|
||||
t2 = Thread(target=subscriber_func)
|
||||
t1.start()
|
||||
t2.start()
|
||||
b.wait()
|
||||
|
||||
ray.worker._global_node.kill_gcs_server()
|
||||
from time import sleep
|
||||
|
||||
sleep(1)
|
||||
ray.worker._global_node.start_gcs_server()
|
||||
sleep(1)
|
||||
|
||||
t1.join()
|
||||
t2.join()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aio_publish_and_subscribe_error_info(ray_start_regular):
|
||||
address_info = ray_start_regular
|
||||
|
@ -54,6 +121,61 @@ async def test_aio_publish_and_subscribe_error_info(ray_start_regular):
|
|||
await subscriber.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aio_publish_and_subscribe_error_info_ft(
|
||||
ray_start_regular_with_external_redis,
|
||||
):
|
||||
address_info = ray_start_regular_with_external_redis
|
||||
gcs_server_addr = address_info["gcs_address"]
|
||||
|
||||
subscriber = GcsAioErrorSubscriber(address=gcs_server_addr)
|
||||
await subscriber.subscribe()
|
||||
|
||||
err1 = ErrorTableData(error_message="test error message 1")
|
||||
err2 = ErrorTableData(error_message="test error message 2")
|
||||
err3 = ErrorTableData(error_message="test error message 3")
|
||||
err4 = ErrorTableData(error_message="test error message 4")
|
||||
|
||||
def restart_gcs_server():
|
||||
import asyncio
|
||||
|
||||
asyncio.set_event_loop(asyncio.new_event_loop())
|
||||
from time import sleep
|
||||
|
||||
publisher = GcsAioPublisher(address=gcs_server_addr)
|
||||
asyncio.get_event_loop().run_until_complete(
|
||||
publisher.publish_error(b"aaa_id", err1)
|
||||
)
|
||||
asyncio.get_event_loop().run_until_complete(
|
||||
publisher.publish_error(b"bbb_id", err2)
|
||||
)
|
||||
|
||||
# wait until subscribe consume everything
|
||||
sleep(5)
|
||||
ray.worker._global_node.kill_gcs_server()
|
||||
sleep(1)
|
||||
ray.worker._global_node.start_gcs_server()
|
||||
# wait until subscriber resubscribed
|
||||
sleep(5)
|
||||
|
||||
asyncio.get_event_loop().run_until_complete(
|
||||
publisher.publish_error(b"aaa_id", err3)
|
||||
)
|
||||
asyncio.get_event_loop().run_until_complete(
|
||||
publisher.publish_error(b"bbb_id", err4)
|
||||
)
|
||||
|
||||
t1 = threading.Thread(target=restart_gcs_server)
|
||||
t1.start()
|
||||
assert await subscriber.poll() == (b"aaa_id", err1)
|
||||
assert await subscriber.poll() == (b"bbb_id", err2)
|
||||
assert await subscriber.poll() == (b"aaa_id", err3)
|
||||
assert await subscriber.poll() == (b"bbb_id", err4)
|
||||
|
||||
await subscriber.close()
|
||||
t1.join()
|
||||
|
||||
|
||||
def test_publish_and_subscribe_logs(ray_start_regular):
|
||||
address_info = ray_start_regular
|
||||
gcs_server_addr = address_info["gcs_address"]
|
||||
|
|
Loading…
Add table
Reference in a new issue