diff --git a/dashboard/tests/test_dashboard.py b/dashboard/tests/test_dashboard.py
index 4064d6da4..6735d4ae4 100644
--- a/dashboard/tests/test_dashboard.py
+++ b/dashboard/tests/test_dashboard.py
@@ -697,32 +697,6 @@ def test_dashboard_port_conflict(ray_start_with_dashboard):
         raise Exception("Timed out while testing.")
 
 
-@pytest.mark.skipif(
-    os.environ.get("RAY_MINIMAL") == "1",
-    reason="This test is not supposed to work for minimal installation.",
-)
-def test_gcs_check_alive(fast_gcs_failure_detection, ray_start_with_dashboard):
-    assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
-
-    all_processes = ray.worker._global_node.all_processes
-    dashboard_info = all_processes[ray_constants.PROCESS_TYPE_DASHBOARD][0]
-    dashboard_proc = psutil.Process(dashboard_info.process.pid)
-    gcs_server_info = all_processes[ray_constants.PROCESS_TYPE_GCS_SERVER][0]
-    gcs_server_proc = psutil.Process(gcs_server_info.process.pid)
-
-    assert dashboard_proc.status() in [
-        psutil.STATUS_RUNNING,
-        psutil.STATUS_SLEEPING,
-        psutil.STATUS_DISK_SLEEP,
-    ]
-
-    gcs_server_proc.kill()
-    gcs_server_proc.wait()
-
-    # The dashboard exits by os._exit(-1)
-    assert dashboard_proc.wait(10) == 255
-
-
 @pytest.mark.skipif(
     os.environ.get("RAY_DEFAULT") != "1",
     reason="This test only works for default installation.",
diff --git a/python/ray/_private/gcs_pubsub.py b/python/ray/_private/gcs_pubsub.py
index 669e7568e..775ec333c 100644
--- a/python/ray/_private/gcs_pubsub.py
+++ b/python/ray/_private/gcs_pubsub.py
@@ -104,16 +104,28 @@ class _SubscriberBase:
         )
         return req
 
-    @staticmethod
-    def _should_terminate_polling(e: grpc.RpcError) -> None:
-        # Caller only expects polling to be terminated after deadline exceeded.
-        if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
+    def _handle_polling_failure(self, e: grpc.RpcError) -> bool:
+        if self._close.is_set():
+            return False
+
+        if e.code() in (
+            grpc.StatusCode.UNAVAILABLE,
+            grpc.StatusCode.UNKNOWN,
+            grpc.StatusCode.DEADLINE_EXCEEDED,
+        ):
             return True
-        # Could be a temporary connection issue. Suppress error.
-        # TODO: reconnect GRPC channel?
-        if e.code() == grpc.StatusCode.UNAVAILABLE:
-            return True
-        return False
+
+        raise e
+
+    def _handle_subscribe_failure(self, e: grpc.RpcError):
+        if e.code() in (
+            grpc.StatusCode.UNAVAILABLE,
+            grpc.StatusCode.UNKNOWN,
+            grpc.StatusCode.DEADLINE_EXCEEDED,
+        ):
+            time.sleep(1)
+        else:
+            raise e
 
     @staticmethod
     def _pop_error_info(queue):
@@ -211,7 +223,7 @@ class _SyncSubscriber(_SubscriberBase):
         # Type of the channel.
         self._channel = pubsub_channel_type
         # Protects multi-threaded read and write of self._queue.
-        self._lock = threading.Lock()
+        self._lock = threading.RLock()
         # A queue of received PubMessage.
         self._queue = deque()
         # Indicates whether the subscriber has closed.
@@ -224,13 +236,24 @@ class _SyncSubscriber(_SubscriberBase):
         saved for the subscriber.
         """
         with self._lock:
-            if self._close.is_set():
-                return
             req = self._subscribe_request(self._channel)
-            self._stub.GcsSubscriberCommandBatch(req, timeout=30)
+            start = time.time()
+            from ray._raylet import Config
+
+            while True:
+                try:
+                    if self._close.is_set():
+                        return
+                    return self._stub.GcsSubscriberCommandBatch(req, timeout=30)
+                except grpc.RpcError as e:
+                    self._handle_subscribe_failure(e)
+                    if (
+                        time.time() - start
+                        > Config.gcs_rpc_server_reconnect_timeout_s()
+                    ):
+                        raise e
 
     def _poll_locked(self, timeout=None) -> None:
-        assert self._lock.locked()
 
         # Poll until data becomes available.
         while len(self._queue) == 0:
@@ -257,9 +280,18 @@
                     # GRPC has not replied, continue waiting.
                     continue
                 except grpc.RpcError as e:
-                    if self._should_terminate_polling(e):
+                    if (
+                        e.code() == grpc.StatusCode.DEADLINE_EXCEEDED
+                        and timeout is not None
+                    ):
+                        return
+                    if self._handle_polling_failure(e) is True:
+                        self.subscribe()
+                        fut = self._stub.GcsSubscriberPoll.future(
+                            self._poll_request(), timeout=timeout
+                        )
+                    else:
                         return
-                    raise
 
             if fut.done():
                 self._last_batch_size = len(fut.result().pub_messages)
@@ -277,11 +309,13 @@
             return
         self._close.set()
         req = self._unsubscribe_request(channels=[self._channel])
+
         try:
             self._stub.GcsSubscriberCommandBatch(req, timeout=5)
         except Exception:
             pass
-        self._stub = None
+        with self._lock:
+            self._stub = None
 
 
 class GcsErrorSubscriber(_SyncSubscriber):
@@ -496,18 +530,39 @@
         if self._close.is_set():
             return
         req = self._subscribe_request(self._channel)
-        await self._stub.GcsSubscriberCommandBatch(req, timeout=30)
+        start = time.time()
+        from ray._raylet import Config
+
+        while True:
+            try:
+                return await self._stub.GcsSubscriberCommandBatch(req, timeout=30)
+            except grpc.RpcError as e:
+                self._handle_subscribe_failure(e)
+                if time.time() - start > Config.gcs_rpc_server_reconnect_timeout_s():
+                    raise
 
     async def _poll_call(self, req, timeout=None):
         # Wrap GRPC _AioCall as a coroutine.
-        return await self._stub.GcsSubscriberPoll(req, timeout=timeout)
+        while True:
+            try:
+                return await self._stub.GcsSubscriberPoll(req, timeout=timeout)
+            except grpc.RpcError as e:
+                if (
+                    e.code() == grpc.StatusCode.DEADLINE_EXCEEDED
+                    and timeout is not None
+                ):
+                    return
+                if self._handle_polling_failure(e) is True:
+                    await self.subscribe()
+                else:
+                    return
 
     async def _poll(self, timeout=None) -> None:
         req = self._poll_request()
         while len(self._queue) == 0:
             # TODO: use asyncio.create_task() after Python 3.6 is no longer
             # supported.
-            poll = asyncio.ensure_future(self._poll_call(req, timeout=timeout))
+            poll = asyncio.ensure_future(self._poll_call(req))
             close = asyncio.ensure_future(self._close.wait())
             done, _ = await asyncio.wait(
                 [poll, close], timeout=timeout, return_when=asyncio.FIRST_COMPLETED
@@ -515,14 +570,9 @@
             if poll not in done or close in done:
                 # Request timed out or subscriber closed.
                 break
-            try:
-                self._last_batch_size = len(poll.result().pub_messages)
-                for msg in poll.result().pub_messages:
-                    self._queue.append(msg)
-            except grpc.RpcError as e:
-                if self._should_terminate_polling(e):
-                    return
-                raise
+            self._last_batch_size = len(poll.result().pub_messages)
+            for msg in poll.result().pub_messages:
+                self._queue.append(msg)
 
     async def close(self) -> None:
         """Closes the subscriber and its active subscription."""
diff --git a/python/ray/includes/ray_config.pxd b/python/ray/includes/ray_config.pxd
index e9ac6d2cc..7b715fe41 100644
--- a/python/ray/includes/ray_config.pxd
+++ b/python/ray/includes/ray_config.pxd
@@ -1,5 +1,5 @@
 from libcpp cimport bool as c_bool
-from libc.stdint cimport int64_t, uint64_t, uint32_t
+from libc.stdint cimport int64_t, uint64_t, uint32_t, int32_t
 from libcpp.string cimport string as c_string
 from libcpp.unordered_map cimport unordered_map
 
@@ -68,3 +68,5 @@ cdef extern from "ray/common/ray_config.h" nogil:
         c_bool start_python_importer_thread() const
 
         c_bool use_ray_syncer() const
+
+        int32_t gcs_rpc_server_reconnect_timeout_s() const
diff --git a/python/ray/includes/ray_config.pxi b/python/ray/includes/ray_config.pxi
index c65bfbc29..499500953 100644
--- a/python/ray/includes/ray_config.pxi
+++ b/python/ray/includes/ray_config.pxi
@@ -84,6 +84,10 @@ cdef class Config:
     def object_manager_default_chunk_size():
         return RayConfig.instance().object_manager_default_chunk_size()
 
+    @staticmethod
+    def gcs_rpc_server_reconnect_timeout_s():
+        return RayConfig.instance().gcs_rpc_server_reconnect_timeout_s()
+
     @staticmethod
     def maximum_gcs_deletion_batch_size():
         return RayConfig.instance().maximum_gcs_deletion_batch_size()
diff --git a/python/ray/tests/test_gcs_fault_tolerance.py b/python/ray/tests/test_gcs_fault_tolerance.py
index bcc1b7700..52a3c3c3d 100644
--- a/python/ray/tests/test_gcs_fault_tolerance.py
+++ b/python/ray/tests/test_gcs_fault_tolerance.py
@@ -9,11 +9,16 @@ from time import sleep
 
 from ray._private.test_utils import (
     generate_system_config_map,
+    run_string_as_driver_nonblocking,
     wait_for_condition,
     wait_for_pid_to_exit,
     convert_actor_state,
 )
 
+import logging
+
+logger = logging.getLogger(__name__)
+
 
 @ray.remote
 class Increase:
@@ -338,6 +343,83 @@ def test_core_worker_resubscription(tmp_path, ray_start_regular_with_external_re
         ray.get(r, timeout=5)
 
 
+@pytest.mark.parametrize(
+    "ray_start_regular_with_external_redis",
+    [
+        generate_system_config_map(
+            num_heartbeats_timeout=20, gcs_rpc_server_reconnect_timeout_s=60
+        )
+    ],
+    indirect=True,
+)
+def test_py_resubscription(tmp_path, ray_start_regular_with_external_redis):
+    # This test ensures Python pubsub keeps working across a GCS restart.
+    from filelock import FileLock
+
+    lock_file1 = str(tmp_path / "lock1")
+    lock1 = FileLock(lock_file1)
+    lock1.acquire()
+
+    lock_file2 = str(tmp_path / "lock2")
+    lock2 = FileLock(lock_file2)
+
+    script = f"""
+from filelock import FileLock
+import ray
+
+@ray.remote
+def f():
+    print("OK1", flush=True)
+    # wait until log_monitor pushes this
+    from time import sleep
+    sleep(2)
+    lock1 = FileLock(r"{lock_file1}")
+    lock2 = FileLock(r"{lock_file2}")
+
+    lock2.acquire()
+    lock1.acquire()
+
+    # wait until log_monitor pushes this
+    from time import sleep
+    sleep(2)
+    print("OK2", flush=True)
+
+ray.init(address='auto')
+ray.get(f.remote())
+ray.shutdown()
+"""
+    proc = run_string_as_driver_nonblocking(script)
+
+    def condition():
+        import filelock
+
+        try:
+            lock2.acquire(timeout=1)
+        except filelock.Timeout:
+            return True
+
+        lock2.release()
+        return False
+
+    # make sure the script has printed "OK1"
+    wait_for_condition(condition, timeout=10)
+
+    ray.worker._global_node.kill_gcs_server()
+    import time
+
+    time.sleep(2)
+    ray.worker._global_node.start_gcs_server()
+
+    lock1.release()
+    proc.wait()
+    output = proc.stdout.read()
+    # Print logs which are useful for debugging in CI
+    print("=================== OUTPUTS ============")
+    print(output.decode())
+    assert b"OK1" in output
+    assert b"OK2" in output
+
+
 @pytest.mark.parametrize("auto_reconnect", [True, False])
 def test_gcs_client_reconnect(ray_start_regular_with_external_redis, auto_reconnect):
     gcs_address = ray.worker.global_worker.gcs_client.address
diff --git a/python/ray/tests/test_gcs_pubsub.py b/python/ray/tests/test_gcs_pubsub.py
index 8d0942abd..b2d47bad7 100644
--- a/python/ray/tests/test_gcs_pubsub.py
+++ b/python/ray/tests/test_gcs_pubsub.py
@@ -1,6 +1,7 @@
 import sys
 import threading
 
+import ray
 from ray._private.gcs_pubsub import (
     GcsPublisher,
     GcsErrorSubscriber,
@@ -34,6 +35,72 @@ def test_publish_and_subscribe_error_info(ray_start_regular):
     subscriber.close()
 
 
+def test_publish_and_subscribe_error_info_ft(ray_start_regular_with_external_redis):
+    address_info = ray_start_regular_with_external_redis
+    gcs_server_addr = address_info["gcs_address"]
+    from threading import Barrier, Thread
+
+    subscriber = GcsErrorSubscriber(address=gcs_server_addr)
+    subscriber.subscribe()
+
+    publisher = GcsPublisher(address=gcs_server_addr)
+
+    err1 = ErrorTableData(error_message="test error message 1")
+    err2 = ErrorTableData(error_message="test error message 2")
+    err3 = ErrorTableData(error_message="test error message 3")
+    err4 = ErrorTableData(error_message="test error message 4")
+    b = Barrier(3)
+
+    def publisher_func():
+        print("Publisher HERE")
+        publisher.publish_error(b"aaa_id", err1)
+        publisher.publish_error(b"bbb_id", err2)
+
+        b.wait()
+
+        print("Publisher HERE")
+        # Wait for the subscriber to subscribe first.
+        # It's ok to lose log messages.
+        from time import sleep
+
+        sleep(5)
+        publisher.publish_error(b"aaa_id", err3)
+        print("pub err1")
+        publisher.publish_error(b"bbb_id", err4)
+        print("pub err2")
+        print("DONE")
+
+    def subscriber_func():
+        print("Subscriber HERE")
+        assert subscriber.poll() == (b"aaa_id", err1)
+        assert subscriber.poll() == (b"bbb_id", err2)
+
+        b.wait()
+        assert subscriber.poll() == (b"aaa_id", err3)
+        print("sub err1")
+        assert subscriber.poll() == (b"bbb_id", err4)
+        print("sub err2")
+
+        subscriber.close()
+        print("DONE")
+
+    t1 = Thread(target=publisher_func)
+    t2 = Thread(target=subscriber_func)
+    t1.start()
+    t2.start()
+    b.wait()
+
+    ray.worker._global_node.kill_gcs_server()
+    from time import sleep
+
+    sleep(1)
+    ray.worker._global_node.start_gcs_server()
+    sleep(1)
+
+    t1.join()
+    t2.join()
+
+
 @pytest.mark.asyncio
 async def test_aio_publish_and_subscribe_error_info(ray_start_regular):
     address_info = ray_start_regular
@@ -54,6 +121,61 @@
     await subscriber.close()
 
 
+@pytest.mark.asyncio
+async def test_aio_publish_and_subscribe_error_info_ft(
+    ray_start_regular_with_external_redis,
+):
+    address_info = ray_start_regular_with_external_redis
+    gcs_server_addr = address_info["gcs_address"]
+
+    subscriber = GcsAioErrorSubscriber(address=gcs_server_addr)
+    await subscriber.subscribe()
+
+    err1 = ErrorTableData(error_message="test error message 1")
+    err2 = ErrorTableData(error_message="test error message 2")
+    err3 = ErrorTableData(error_message="test error message 3")
+    err4 = ErrorTableData(error_message="test error message 4")
+
+    def restart_gcs_server():
+        import asyncio
+
+        asyncio.set_event_loop(asyncio.new_event_loop())
+        from time import sleep
+
+        publisher = GcsAioPublisher(address=gcs_server_addr)
+        asyncio.get_event_loop().run_until_complete(
+            publisher.publish_error(b"aaa_id", err1)
+        )
+        asyncio.get_event_loop().run_until_complete(
+            publisher.publish_error(b"bbb_id", err2)
+        )
+
+        # wait until the subscriber has consumed everything
+        sleep(5)
+        ray.worker._global_node.kill_gcs_server()
+        sleep(1)
+        ray.worker._global_node.start_gcs_server()
+        # wait until the subscriber has resubscribed
+        sleep(5)
+
+        asyncio.get_event_loop().run_until_complete(
+            publisher.publish_error(b"aaa_id", err3)
+        )
+        asyncio.get_event_loop().run_until_complete(
+            publisher.publish_error(b"bbb_id", err4)
+        )
+
+    t1 = threading.Thread(target=restart_gcs_server)
+    t1.start()
+    assert await subscriber.poll() == (b"aaa_id", err1)
+    assert await subscriber.poll() == (b"bbb_id", err2)
+    assert await subscriber.poll() == (b"aaa_id", err3)
+    assert await subscriber.poll() == (b"bbb_id", err4)
+
+    await subscriber.close()
+    t1.join()
+
+
 def test_publish_and_subscribe_logs(ray_start_regular):
     address_info = ray_start_regular
     gcs_server_addr = address_info["gcs_address"]
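
The new `subscribe()` loops on `GcsSubscriberCommandBatch` until it succeeds, treating `UNAVAILABLE`, `UNKNOWN`, and `DEADLINE_EXCEEDED` as transient and giving up once `Config.gcs_rpc_server_reconnect_timeout_s()` has elapsed. A minimal standalone sketch of that pattern; `RECONNECT_TIMEOUT_S` and `rpc_call` are hypothetical stand-ins, not Ray APIs:

```python
import time

import grpc

# Hypothetical stand-ins: RECONNECT_TIMEOUT_S plays the role of
# Config.gcs_rpc_server_reconnect_timeout_s(), and rpc_call the role of
# stub.GcsSubscriberCommandBatch.
RECONNECT_TIMEOUT_S = 60
TRANSIENT_CODES = (
    grpc.StatusCode.UNAVAILABLE,
    grpc.StatusCode.UNKNOWN,
    grpc.StatusCode.DEADLINE_EXCEEDED,
)


def call_with_retry(rpc_call, *args, **kwargs):
    """Retry an RPC on transient errors until an overall deadline expires."""
    start = time.time()
    while True:
        try:
            return rpc_call(*args, **kwargs)
        except grpc.RpcError as e:
            if e.code() not in TRANSIENT_CODES:
                raise  # Non-transient errors surface immediately.
            if time.time() - start > RECONNECT_TIMEOUT_S:
                raise  # Give up once the reconnect window has elapsed.
            time.sleep(1)  # Brief backoff, mirroring _handle_subscribe_failure.
```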
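On the polling side, `_handle_polling_failure` separates three cases: a caller-requested timeout expiring (a normal return), a transient failure (resubscribe and poll again), and anything else (re-raise). A sketch of that control flow, ignoring the closed-subscriber case for brevity; `poll_once` and `resubscribe` are hypothetical stand-ins for the `GcsSubscriberPoll` RPC and the subscribe command batch:

```python
import grpc


def poll_with_recovery(poll_once, resubscribe, timeout=None):
    """Poll until a batch arrives, resubscribing across transient failures."""
    while True:
        try:
            return poll_once(timeout=timeout)
        except grpc.RpcError as e:
            # A deadline hit while the caller asked for a timeout is a
            # normal empty return, not a failure.
            if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED and timeout is not None:
                return None
            if e.code() in (
                grpc.StatusCode.UNAVAILABLE,
                grpc.StatusCode.UNKNOWN,
                grpc.StatusCode.DEADLINE_EXCEEDED,
            ):
                resubscribe()  # Replay the subscription, then poll again.
            else:
                raise
```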
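The Cython changes expose the C++ `RayConfig` value to Python. Assuming a Ray build that contains this diff, it can be read like any other `Config` accessor:

```python
from ray._raylet import Config

# Settable through the system config, as the tests do via
# generate_system_config_map(gcs_rpc_server_reconnect_timeout_s=60).
print(Config.gcs_rpc_server_reconnect_timeout_s())
```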
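The sync fault-tolerance test coordinates three parties with a `threading.Barrier(3)`: the publisher and subscriber threads exchange one batch, all three sides meet at the barrier, the main thread restarts the GCS, and the second batch then exercises the resubscribe path. A self-contained sketch of the same rendezvous, with a `queue.Queue` standing in for the pubsub channel:

```python
import queue
from threading import Barrier, Thread

channel = queue.Queue()  # stands in for the GCS pubsub channel
barrier = Barrier(3)     # publisher thread + subscriber thread + main thread


def publisher():
    channel.put("batch-1")
    barrier.wait()           # rendezvous before the simulated restart
    channel.put("batch-2")   # in the real test this must survive the restart


def subscriber():
    assert channel.get(timeout=5) == "batch-1"
    barrier.wait()
    assert channel.get(timeout=5) == "batch-2"


t1, t2 = Thread(target=publisher), Thread(target=subscriber)
t1.start()
t2.start()
barrier.wait()  # main thread joins the rendezvous...
# ...and here the real test calls kill_gcs_server()/start_gcs_server().
t1.join()
t2.join()
```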
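In the aio test, `restart_gcs_server()` runs in a plain thread, so it must create its own event loop before it can drive the async publisher with `run_until_complete`. A minimal sketch of that pattern, with `fake_publish` as a hypothetical stand-in for `GcsAioPublisher.publish_error`:

```python
import asyncio
import threading


async def fake_publish(msg: str) -> str:
    await asyncio.sleep(0.1)  # stands in for an async gRPC publish
    return msg


def worker():
    # A non-main thread has no event loop by default; create one.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    print(loop.run_until_complete(fake_publish("err1")))
    print(loop.run_until_complete(fake_publish("err2")))
    loop.close()


t = threading.Thread(target=worker)
t.start()
t.join()
```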