From 7cf42338589d57bafea9933055781fe78e066a72 Mon Sep 17 00:00:00 2001 From: Yi Cheng <74173148+iycheng@users.noreply.github.com> Date: Mon, 23 May 2022 13:06:33 -0700 Subject: [PATCH] [core] Resubscribe GCS in python when GCS restarts. (#24887) This is a follow-up PRs of https://github.com/ray-project/ray/pull/24813 and https://github.com/ray-project/ray/pull/24628 Unlike the change in cpp layer, where the resubscription is done by GCS broadcast a request to raylet/core_worker and the client-side do the resubscription, in the python layer, we detect the failure in the client-side. In case of a failure, the protocol is: 1. call subscribe 2. if timeout when doing resubscribe, throw an exception and this will crash the system. This is ok because when GCS has been down for a time longer than expected, we expect the ray cluster to be down. 3. continue to poll once subscribe ok. However, there is an extreme case where things might be broken: the client might miss detecting a failure. This could happen if the long-polling has been returned and the python layer is doing its own work. And before it sends another long-polling, GCS restarts and recovered. Here we are not going to take care of this case because: 1. usually GCS is going to take several seconds to be up and the python layer's work is simply pushing data into a queue (sync version). For the async version, it's only used in Dashboard which is not a critical component. 2. pubsub in python layer is not doing critical work: it handles logs/errors for ray job; 3. for the dashboard, it can just restart to fix the issue. A known issue here is that we might miss logs in case of GCS failure due to the following reasons: - py's pubsub is only doing best effort publishing. If it failed too many times, it'll skip publishing the message (lose messages from producer side) - if message is pushed to GCS, but the worker hasn't done resubscription yet, the pushed message will be lost (lose messages from consumer side) We think it's reasonable and valid behavior given that the logs are not defined to be a critical component and we'd like to simplify the design of pubsub in GCS. Another things is `run_functions_on_all_workers`. We'll plan to stop using it within ray core and deprecate it in the longer term. But it won't cause a problem for the current cases because: 1. It's only set in driver and we don't support creating a new driver when GCS is down. 2. When GCS is down, we don't support starting new ray workers. And `run_functions_on_all_workers` is only used when we initialize driver/workers. --- dashboard/tests/test_dashboard.py | 26 ---- python/ray/_private/gcs_pubsub.py | 106 +++++++++++----- python/ray/includes/ray_config.pxd | 4 +- python/ray/includes/ray_config.pxi | 4 + python/ray/tests/test_gcs_fault_tolerance.py | 82 +++++++++++++ python/ray/tests/test_gcs_pubsub.py | 122 +++++++++++++++++++ 6 files changed, 289 insertions(+), 55 deletions(-) diff --git a/dashboard/tests/test_dashboard.py b/dashboard/tests/test_dashboard.py index 4064d6da4..6735d4ae4 100644 --- a/dashboard/tests/test_dashboard.py +++ b/dashboard/tests/test_dashboard.py @@ -697,32 +697,6 @@ def test_dashboard_port_conflict(ray_start_with_dashboard): raise Exception("Timed out while testing.") -@pytest.mark.skipif( - os.environ.get("RAY_MINIMAL") == "1", - reason="This test is not supposed to work for minimal installation.", -) -def test_gcs_check_alive(fast_gcs_failure_detection, ray_start_with_dashboard): - assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True - - all_processes = ray.worker._global_node.all_processes - dashboard_info = all_processes[ray_constants.PROCESS_TYPE_DASHBOARD][0] - dashboard_proc = psutil.Process(dashboard_info.process.pid) - gcs_server_info = all_processes[ray_constants.PROCESS_TYPE_GCS_SERVER][0] - gcs_server_proc = psutil.Process(gcs_server_info.process.pid) - - assert dashboard_proc.status() in [ - psutil.STATUS_RUNNING, - psutil.STATUS_SLEEPING, - psutil.STATUS_DISK_SLEEP, - ] - - gcs_server_proc.kill() - gcs_server_proc.wait() - - # The dashboard exits by os._exit(-1) - assert dashboard_proc.wait(10) == 255 - - @pytest.mark.skipif( os.environ.get("RAY_DEFAULT") != "1", reason="This test only works for default installation.", diff --git a/python/ray/_private/gcs_pubsub.py b/python/ray/_private/gcs_pubsub.py index 669e7568e..775ec333c 100644 --- a/python/ray/_private/gcs_pubsub.py +++ b/python/ray/_private/gcs_pubsub.py @@ -104,16 +104,28 @@ class _SubscriberBase: ) return req - @staticmethod - def _should_terminate_polling(e: grpc.RpcError) -> None: - # Caller only expects polling to be terminated after deadline exceeded. - if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED: + def _handle_polling_failure(self, e: grpc.RpcError) -> bool: + if self._close.is_set(): + return False + + if e.code() in ( + grpc.StatusCode.UNAVAILABLE, + grpc.StatusCode.UNKNOWN, + grpc.StatusCode.DEADLINE_EXCEEDED, + ): return True - # Could be a temporary connection issue. Suppress error. - # TODO: reconnect GRPC channel? - if e.code() == grpc.StatusCode.UNAVAILABLE: - return True - return False + + raise e + + def _handle_subscribe_failure(self, e: grpc.RpcError): + if e.code() in ( + grpc.StatusCode.UNAVAILABLE, + grpc.StatusCode.UNKNOWN, + grpc.StatusCode.DEADLINE_EXCEEDED, + ): + time.sleep(1) + else: + raise e @staticmethod def _pop_error_info(queue): @@ -211,7 +223,7 @@ class _SyncSubscriber(_SubscriberBase): # Type of the channel. self._channel = pubsub_channel_type # Protects multi-threaded read and write of self._queue. - self._lock = threading.Lock() + self._lock = threading.RLock() # A queue of received PubMessage. self._queue = deque() # Indicates whether the subscriber has closed. @@ -224,13 +236,24 @@ class _SyncSubscriber(_SubscriberBase): saved for the subscriber. """ with self._lock: - if self._close.is_set(): - return req = self._subscribe_request(self._channel) - self._stub.GcsSubscriberCommandBatch(req, timeout=30) + start = time.time() + from ray._raylet import Config + + while True: + try: + if self._close.is_set(): + return + return self._stub.GcsSubscriberCommandBatch(req, timeout=30) + except grpc.RpcError as e: + self._handle_subscribe_failure(e) + if ( + time.time() - start + > Config.gcs_rpc_server_reconnect_timeout_s() + ): + raise e def _poll_locked(self, timeout=None) -> None: - assert self._lock.locked() # Poll until data becomes available. while len(self._queue) == 0: @@ -257,9 +280,18 @@ class _SyncSubscriber(_SubscriberBase): # GRPC has not replied, continue waiting. continue except grpc.RpcError as e: - if self._should_terminate_polling(e): + if ( + e.code() == grpc.StatusCode.DEADLINE_EXCEEDED + and timeout is not None + ): + return + if self._handle_polling_failure(e) is True: + self.subscribe() + fut = self._stub.GcsSubscriberPoll.future( + self._poll_request(), timeout=timeout + ) + else: return - raise if fut.done(): self._last_batch_size = len(fut.result().pub_messages) @@ -277,11 +309,13 @@ class _SyncSubscriber(_SubscriberBase): return self._close.set() req = self._unsubscribe_request(channels=[self._channel]) + try: self._stub.GcsSubscriberCommandBatch(req, timeout=5) except Exception: pass - self._stub = None + with self._lock: + self._stub = None class GcsErrorSubscriber(_SyncSubscriber): @@ -496,18 +530,39 @@ class _AioSubscriber(_SubscriberBase): if self._close.is_set(): return req = self._subscribe_request(self._channel) - await self._stub.GcsSubscriberCommandBatch(req, timeout=30) + start = time.time() + from ray._raylet import Config + + while True: + try: + return await self._stub.GcsSubscriberCommandBatch(req, timeout=30) + except grpc.RpcError as e: + self._handle_subscribe_failure(e) + if time.time() - start > Config.gcs_rpc_server_reconnect_timeout_s(): + raise async def _poll_call(self, req, timeout=None): # Wrap GRPC _AioCall as a coroutine. - return await self._stub.GcsSubscriberPoll(req, timeout=timeout) + while True: + try: + return await self._stub.GcsSubscriberPoll(req, timeout=timeout) + except grpc.RpcError as e: + if ( + e.code() == grpc.StatusCode.DEADLINE_EXCEEDED + and timeout is not None + ): + return + if self._handle_polling_failure(e) is True: + await self.subscribe() + else: + return async def _poll(self, timeout=None) -> None: req = self._poll_request() while len(self._queue) == 0: # TODO: use asyncio.create_task() after Python 3.6 is no longer # supported. - poll = asyncio.ensure_future(self._poll_call(req, timeout=timeout)) + poll = asyncio.ensure_future(self._poll_call(req)) close = asyncio.ensure_future(self._close.wait()) done, _ = await asyncio.wait( [poll, close], timeout=timeout, return_when=asyncio.FIRST_COMPLETED @@ -515,14 +570,9 @@ class _AioSubscriber(_SubscriberBase): if poll not in done or close in done: # Request timed out or subscriber closed. break - try: - self._last_batch_size = len(poll.result().pub_messages) - for msg in poll.result().pub_messages: - self._queue.append(msg) - except grpc.RpcError as e: - if self._should_terminate_polling(e): - return - raise + self._last_batch_size = len(poll.result().pub_messages) + for msg in poll.result().pub_messages: + self._queue.append(msg) async def close(self) -> None: """Closes the subscriber and its active subscription.""" diff --git a/python/ray/includes/ray_config.pxd b/python/ray/includes/ray_config.pxd index e9ac6d2cc..7b715fe41 100644 --- a/python/ray/includes/ray_config.pxd +++ b/python/ray/includes/ray_config.pxd @@ -1,5 +1,5 @@ from libcpp cimport bool as c_bool -from libc.stdint cimport int64_t, uint64_t, uint32_t +from libc.stdint cimport int64_t, uint64_t, uint32_t, int32_t from libcpp.string cimport string as c_string from libcpp.unordered_map cimport unordered_map @@ -68,3 +68,5 @@ cdef extern from "ray/common/ray_config.h" nogil: c_bool start_python_importer_thread() const c_bool use_ray_syncer() const + + int32_t gcs_rpc_server_reconnect_timeout_s() const diff --git a/python/ray/includes/ray_config.pxi b/python/ray/includes/ray_config.pxi index c65bfbc29..499500953 100644 --- a/python/ray/includes/ray_config.pxi +++ b/python/ray/includes/ray_config.pxi @@ -84,6 +84,10 @@ cdef class Config: def object_manager_default_chunk_size(): return RayConfig.instance().object_manager_default_chunk_size() + @staticmethod + def gcs_rpc_server_reconnect_timeout_s(): + return RayConfig.instance().gcs_rpc_server_reconnect_timeout_s() + @staticmethod def maximum_gcs_deletion_batch_size(): return RayConfig.instance().maximum_gcs_deletion_batch_size() diff --git a/python/ray/tests/test_gcs_fault_tolerance.py b/python/ray/tests/test_gcs_fault_tolerance.py index bcc1b7700..52a3c3c3d 100644 --- a/python/ray/tests/test_gcs_fault_tolerance.py +++ b/python/ray/tests/test_gcs_fault_tolerance.py @@ -9,11 +9,16 @@ from time import sleep from ray._private.test_utils import ( generate_system_config_map, + run_string_as_driver_nonblocking, wait_for_condition, wait_for_pid_to_exit, convert_actor_state, ) +import logging + +logger = logging.getLogger(__name__) + @ray.remote class Increase: @@ -338,6 +343,83 @@ def test_core_worker_resubscription(tmp_path, ray_start_regular_with_external_re ray.get(r, timeout=5) +@pytest.mark.parametrize( + "ray_start_regular_with_external_redis", + [ + generate_system_config_map( + num_heartbeats_timeout=20, gcs_rpc_server_reconnect_timeout_s=60 + ) + ], + indirect=True, +) +def test_py_resubscription(tmp_path, ray_start_regular_with_external_redis): + # This test is to ensure python pubsub works + from filelock import FileLock + + lock_file1 = str(tmp_path / "lock1") + lock1 = FileLock(lock_file1) + lock1.acquire() + + lock_file2 = str(tmp_path / "lock2") + lock2 = FileLock(lock_file2) + + script = f""" +from filelock import FileLock +import ray + +@ray.remote +def f(): + print("OK1", flush=True) + # wait until log_monitor push this + from time import sleep + sleep(2) + lock1 = FileLock(r"{lock_file1}") + lock2 = FileLock(r"{lock_file2}") + + lock2.acquire() + lock1.acquire() + + # wait until log_monitor push this + from time import sleep + sleep(2) + print("OK2", flush=True) + +ray.init(address='auto') +ray.get(f.remote()) +ray.shutdown() +""" + proc = run_string_as_driver_nonblocking(script) + + def condition(): + import filelock + + try: + lock2.acquire(timeout=1) + except filelock.Timeout: + return True + + lock2.release() + return False + + # make sure the script has printed "OK1" + wait_for_condition(condition, timeout=10) + + ray.worker._global_node.kill_gcs_server() + import time + + time.sleep(2) + ray.worker._global_node.start_gcs_server() + + lock1.release() + proc.wait() + output = proc.stdout.read() + # Print logs which are useful for debugging in CI + print("=================== OUTPUTS ============") + print(output.decode()) + assert b"OK1" in output + assert b"OK2" in output + + @pytest.mark.parametrize("auto_reconnect", [True, False]) def test_gcs_client_reconnect(ray_start_regular_with_external_redis, auto_reconnect): gcs_address = ray.worker.global_worker.gcs_client.address diff --git a/python/ray/tests/test_gcs_pubsub.py b/python/ray/tests/test_gcs_pubsub.py index 8d0942abd..b2d47bad7 100644 --- a/python/ray/tests/test_gcs_pubsub.py +++ b/python/ray/tests/test_gcs_pubsub.py @@ -1,6 +1,7 @@ import sys import threading +import ray from ray._private.gcs_pubsub import ( GcsPublisher, GcsErrorSubscriber, @@ -34,6 +35,72 @@ def test_publish_and_subscribe_error_info(ray_start_regular): subscriber.close() +def test_publish_and_subscribe_error_info_ft(ray_start_regular_with_external_redis): + address_info = ray_start_regular_with_external_redis + gcs_server_addr = address_info["gcs_address"] + from threading import Barrier, Thread + + subscriber = GcsErrorSubscriber(address=gcs_server_addr) + subscriber.subscribe() + + publisher = GcsPublisher(address=gcs_server_addr) + + err1 = ErrorTableData(error_message="test error message 1") + err2 = ErrorTableData(error_message="test error message 2") + err3 = ErrorTableData(error_message="test error message 3") + err4 = ErrorTableData(error_message="test error message 4") + b = Barrier(3) + + def publisher_func(): + print("Publisher HERE") + publisher.publish_error(b"aaa_id", err1) + publisher.publish_error(b"bbb_id", err2) + + b.wait() + + print("Publisher HERE") + # Wait fo subscriber to subscribe first. + # It's ok to loose log messages. + from time import sleep + + sleep(5) + publisher.publish_error(b"aaa_id", err3) + print("pub err1") + publisher.publish_error(b"bbb_id", err4) + print("pub err2") + print("DONE") + + def subscriber_func(): + print("Subscriber HERE") + assert subscriber.poll() == (b"aaa_id", err1) + assert subscriber.poll() == (b"bbb_id", err2) + + b.wait() + assert subscriber.poll() == (b"aaa_id", err3) + print("sub err1") + assert subscriber.poll() == (b"bbb_id", err4) + print("sub err2") + + subscriber.close() + print("DONE") + + t1 = Thread(target=publisher_func) + t2 = Thread(target=subscriber_func) + t1.start() + t2.start() + b.wait() + + ray.worker._global_node.kill_gcs_server() + from time import sleep + + sleep(1) + ray.worker._global_node.start_gcs_server() + sleep(1) + + t1.join() + t2.join() + + @pytest.mark.asyncio async def test_aio_publish_and_subscribe_error_info(ray_start_regular): address_info = ray_start_regular @@ -54,6 +121,61 @@ async def test_aio_publish_and_subscribe_error_info(ray_start_regular): await subscriber.close() +@pytest.mark.asyncio +async def test_aio_publish_and_subscribe_error_info_ft( + ray_start_regular_with_external_redis, +): + address_info = ray_start_regular_with_external_redis + gcs_server_addr = address_info["gcs_address"] + + subscriber = GcsAioErrorSubscriber(address=gcs_server_addr) + await subscriber.subscribe() + + err1 = ErrorTableData(error_message="test error message 1") + err2 = ErrorTableData(error_message="test error message 2") + err3 = ErrorTableData(error_message="test error message 3") + err4 = ErrorTableData(error_message="test error message 4") + + def restart_gcs_server(): + import asyncio + + asyncio.set_event_loop(asyncio.new_event_loop()) + from time import sleep + + publisher = GcsAioPublisher(address=gcs_server_addr) + asyncio.get_event_loop().run_until_complete( + publisher.publish_error(b"aaa_id", err1) + ) + asyncio.get_event_loop().run_until_complete( + publisher.publish_error(b"bbb_id", err2) + ) + + # wait until subscribe consume everything + sleep(5) + ray.worker._global_node.kill_gcs_server() + sleep(1) + ray.worker._global_node.start_gcs_server() + # wait until subscriber resubscribed + sleep(5) + + asyncio.get_event_loop().run_until_complete( + publisher.publish_error(b"aaa_id", err3) + ) + asyncio.get_event_loop().run_until_complete( + publisher.publish_error(b"bbb_id", err4) + ) + + t1 = threading.Thread(target=restart_gcs_server) + t1.start() + assert await subscriber.poll() == (b"aaa_id", err1) + assert await subscriber.poll() == (b"bbb_id", err2) + assert await subscriber.poll() == (b"aaa_id", err3) + assert await subscriber.poll() == (b"bbb_id", err4) + + await subscriber.close() + t1.join() + + def test_publish_and_subscribe_logs(ray_start_regular): address_info = ray_start_regular gcs_server_addr = address_info["gcs_address"]