Revert "[core] Resubscribe GCS in python when GCS restarts. (#24887)" (#25168)

This reverts commit 7cf4233858.
This commit is contained in:
mwtian 2022-05-24 18:13:40 -07:00 committed by GitHub
parent f7692e4602
commit fa32cb7c40
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 55 additions and 289 deletions

View file

@ -697,6 +697,32 @@ def test_dashboard_port_conflict(ray_start_with_dashboard):
raise Exception("Timed out while testing.")
@pytest.mark.skipif(
os.environ.get("RAY_MINIMAL") == "1",
reason="This test is not supposed to work for minimal installation.",
)
def test_gcs_check_alive(fast_gcs_failure_detection, ray_start_with_dashboard):
assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
all_processes = ray.worker._global_node.all_processes
dashboard_info = all_processes[ray_constants.PROCESS_TYPE_DASHBOARD][0]
dashboard_proc = psutil.Process(dashboard_info.process.pid)
gcs_server_info = all_processes[ray_constants.PROCESS_TYPE_GCS_SERVER][0]
gcs_server_proc = psutil.Process(gcs_server_info.process.pid)
assert dashboard_proc.status() in [
psutil.STATUS_RUNNING,
psutil.STATUS_SLEEPING,
psutil.STATUS_DISK_SLEEP,
]
gcs_server_proc.kill()
gcs_server_proc.wait()
# The dashboard exits by os._exit(-1)
assert dashboard_proc.wait(10) == 255
@pytest.mark.skipif(
os.environ.get("RAY_DEFAULT") != "1",
reason="This test only works for default installation.",

View file

@ -104,28 +104,16 @@ class _SubscriberBase:
)
return req
def _handle_polling_failure(self, e: grpc.RpcError) -> bool:
if self._close.is_set():
return False
if e.code() in (
grpc.StatusCode.UNAVAILABLE,
grpc.StatusCode.UNKNOWN,
grpc.StatusCode.DEADLINE_EXCEEDED,
):
@staticmethod
def _should_terminate_polling(e: grpc.RpcError) -> None:
# Caller only expects polling to be terminated after deadline exceeded.
if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
return True
raise e
def _handle_subscribe_failure(self, e: grpc.RpcError):
if e.code() in (
grpc.StatusCode.UNAVAILABLE,
grpc.StatusCode.UNKNOWN,
grpc.StatusCode.DEADLINE_EXCEEDED,
):
time.sleep(1)
else:
raise e
# Could be a temporary connection issue. Suppress error.
# TODO: reconnect GRPC channel?
if e.code() == grpc.StatusCode.UNAVAILABLE:
return True
return False
@staticmethod
def _pop_error_info(queue):
@ -223,7 +211,7 @@ class _SyncSubscriber(_SubscriberBase):
# Type of the channel.
self._channel = pubsub_channel_type
# Protects multi-threaded read and write of self._queue.
self._lock = threading.RLock()
self._lock = threading.Lock()
# A queue of received PubMessage.
self._queue = deque()
# Indicates whether the subscriber has closed.
@ -236,24 +224,13 @@ class _SyncSubscriber(_SubscriberBase):
saved for the subscriber.
"""
with self._lock:
if self._close.is_set():
return
req = self._subscribe_request(self._channel)
start = time.time()
from ray._raylet import Config
while True:
try:
if self._close.is_set():
return
return self._stub.GcsSubscriberCommandBatch(req, timeout=30)
except grpc.RpcError as e:
self._handle_subscribe_failure(e)
if (
time.time() - start
> Config.gcs_rpc_server_reconnect_timeout_s()
):
raise e
self._stub.GcsSubscriberCommandBatch(req, timeout=30)
def _poll_locked(self, timeout=None) -> None:
assert self._lock.locked()
# Poll until data becomes available.
while len(self._queue) == 0:
@ -280,18 +257,9 @@ class _SyncSubscriber(_SubscriberBase):
# GRPC has not replied, continue waiting.
continue
except grpc.RpcError as e:
if (
e.code() == grpc.StatusCode.DEADLINE_EXCEEDED
and timeout is not None
):
return
if self._handle_polling_failure(e) is True:
self.subscribe()
fut = self._stub.GcsSubscriberPoll.future(
self._poll_request(), timeout=timeout
)
else:
if self._should_terminate_polling(e):
return
raise
if fut.done():
self._last_batch_size = len(fut.result().pub_messages)
@ -309,13 +277,11 @@ class _SyncSubscriber(_SubscriberBase):
return
self._close.set()
req = self._unsubscribe_request(channels=[self._channel])
try:
self._stub.GcsSubscriberCommandBatch(req, timeout=5)
except Exception:
pass
with self._lock:
self._stub = None
self._stub = None
class GcsErrorSubscriber(_SyncSubscriber):
@ -530,39 +496,18 @@ class _AioSubscriber(_SubscriberBase):
if self._close.is_set():
return
req = self._subscribe_request(self._channel)
start = time.time()
from ray._raylet import Config
while True:
try:
return await self._stub.GcsSubscriberCommandBatch(req, timeout=30)
except grpc.RpcError as e:
self._handle_subscribe_failure(e)
if time.time() - start > Config.gcs_rpc_server_reconnect_timeout_s():
raise
await self._stub.GcsSubscriberCommandBatch(req, timeout=30)
async def _poll_call(self, req, timeout=None):
# Wrap GRPC _AioCall as a coroutine.
while True:
try:
return await self._stub.GcsSubscriberPoll(req, timeout=timeout)
except grpc.RpcError as e:
if (
e.code() == grpc.StatusCode.DEADLINE_EXCEEDED
and timeout is not None
):
return
if self._handle_polling_failure(e) is True:
await self.subscribe()
else:
return
return await self._stub.GcsSubscriberPoll(req, timeout=timeout)
async def _poll(self, timeout=None) -> None:
req = self._poll_request()
while len(self._queue) == 0:
# TODO: use asyncio.create_task() after Python 3.6 is no longer
# supported.
poll = asyncio.ensure_future(self._poll_call(req))
poll = asyncio.ensure_future(self._poll_call(req, timeout=timeout))
close = asyncio.ensure_future(self._close.wait())
done, _ = await asyncio.wait(
[poll, close], timeout=timeout, return_when=asyncio.FIRST_COMPLETED
@ -570,9 +515,14 @@ class _AioSubscriber(_SubscriberBase):
if poll not in done or close in done:
# Request timed out or subscriber closed.
break
self._last_batch_size = len(poll.result().pub_messages)
for msg in poll.result().pub_messages:
self._queue.append(msg)
try:
self._last_batch_size = len(poll.result().pub_messages)
for msg in poll.result().pub_messages:
self._queue.append(msg)
except grpc.RpcError as e:
if self._should_terminate_polling(e):
return
raise
async def close(self) -> None:
"""Closes the subscriber and its active subscription."""

View file

@ -1,5 +1,5 @@
from libcpp cimport bool as c_bool
from libc.stdint cimport int64_t, uint64_t, uint32_t, int32_t
from libc.stdint cimport int64_t, uint64_t, uint32_t
from libcpp.string cimport string as c_string
from libcpp.unordered_map cimport unordered_map
@ -68,5 +68,3 @@ cdef extern from "ray/common/ray_config.h" nogil:
c_bool start_python_importer_thread() const
c_bool use_ray_syncer() const
int32_t gcs_rpc_server_reconnect_timeout_s() const

View file

@ -84,10 +84,6 @@ cdef class Config:
def object_manager_default_chunk_size():
return RayConfig.instance().object_manager_default_chunk_size()
@staticmethod
def gcs_rpc_server_reconnect_timeout_s():
return RayConfig.instance().gcs_rpc_server_reconnect_timeout_s()
@staticmethod
def maximum_gcs_deletion_batch_size():
return RayConfig.instance().maximum_gcs_deletion_batch_size()

View file

@ -9,16 +9,11 @@ from time import sleep
from ray._private.test_utils import (
generate_system_config_map,
run_string_as_driver_nonblocking,
wait_for_condition,
wait_for_pid_to_exit,
convert_actor_state,
)
import logging
logger = logging.getLogger(__name__)
@ray.remote
class Increase:
@ -343,83 +338,6 @@ def test_core_worker_resubscription(tmp_path, ray_start_regular_with_external_re
ray.get(r, timeout=5)
@pytest.mark.parametrize(
"ray_start_regular_with_external_redis",
[
generate_system_config_map(
num_heartbeats_timeout=20, gcs_rpc_server_reconnect_timeout_s=60
)
],
indirect=True,
)
def test_py_resubscription(tmp_path, ray_start_regular_with_external_redis):
# This test is to ensure python pubsub works
from filelock import FileLock
lock_file1 = str(tmp_path / "lock1")
lock1 = FileLock(lock_file1)
lock1.acquire()
lock_file2 = str(tmp_path / "lock2")
lock2 = FileLock(lock_file2)
script = f"""
from filelock import FileLock
import ray
@ray.remote
def f():
print("OK1", flush=True)
# wait until log_monitor push this
from time import sleep
sleep(2)
lock1 = FileLock(r"{lock_file1}")
lock2 = FileLock(r"{lock_file2}")
lock2.acquire()
lock1.acquire()
# wait until log_monitor push this
from time import sleep
sleep(2)
print("OK2", flush=True)
ray.init(address='auto')
ray.get(f.remote())
ray.shutdown()
"""
proc = run_string_as_driver_nonblocking(script)
def condition():
import filelock
try:
lock2.acquire(timeout=1)
except filelock.Timeout:
return True
lock2.release()
return False
# make sure the script has printed "OK1"
wait_for_condition(condition, timeout=10)
ray.worker._global_node.kill_gcs_server()
import time
time.sleep(2)
ray.worker._global_node.start_gcs_server()
lock1.release()
proc.wait()
output = proc.stdout.read()
# Print logs which are useful for debugging in CI
print("=================== OUTPUTS ============")
print(output.decode())
assert b"OK1" in output
assert b"OK2" in output
@pytest.mark.parametrize("auto_reconnect", [True, False])
def test_gcs_client_reconnect(ray_start_regular_with_external_redis, auto_reconnect):
gcs_address = ray.worker.global_worker.gcs_client.address

View file

@ -1,7 +1,6 @@
import sys
import threading
import ray
from ray._private.gcs_pubsub import (
GcsPublisher,
GcsErrorSubscriber,
@ -35,72 +34,6 @@ def test_publish_and_subscribe_error_info(ray_start_regular):
subscriber.close()
def test_publish_and_subscribe_error_info_ft(ray_start_regular_with_external_redis):
address_info = ray_start_regular_with_external_redis
gcs_server_addr = address_info["gcs_address"]
from threading import Barrier, Thread
subscriber = GcsErrorSubscriber(address=gcs_server_addr)
subscriber.subscribe()
publisher = GcsPublisher(address=gcs_server_addr)
err1 = ErrorTableData(error_message="test error message 1")
err2 = ErrorTableData(error_message="test error message 2")
err3 = ErrorTableData(error_message="test error message 3")
err4 = ErrorTableData(error_message="test error message 4")
b = Barrier(3)
def publisher_func():
print("Publisher HERE")
publisher.publish_error(b"aaa_id", err1)
publisher.publish_error(b"bbb_id", err2)
b.wait()
print("Publisher HERE")
# Wait fo subscriber to subscribe first.
# It's ok to loose log messages.
from time import sleep
sleep(5)
publisher.publish_error(b"aaa_id", err3)
print("pub err1")
publisher.publish_error(b"bbb_id", err4)
print("pub err2")
print("DONE")
def subscriber_func():
print("Subscriber HERE")
assert subscriber.poll() == (b"aaa_id", err1)
assert subscriber.poll() == (b"bbb_id", err2)
b.wait()
assert subscriber.poll() == (b"aaa_id", err3)
print("sub err1")
assert subscriber.poll() == (b"bbb_id", err4)
print("sub err2")
subscriber.close()
print("DONE")
t1 = Thread(target=publisher_func)
t2 = Thread(target=subscriber_func)
t1.start()
t2.start()
b.wait()
ray.worker._global_node.kill_gcs_server()
from time import sleep
sleep(1)
ray.worker._global_node.start_gcs_server()
sleep(1)
t1.join()
t2.join()
@pytest.mark.asyncio
async def test_aio_publish_and_subscribe_error_info(ray_start_regular):
address_info = ray_start_regular
@ -121,61 +54,6 @@ async def test_aio_publish_and_subscribe_error_info(ray_start_regular):
await subscriber.close()
@pytest.mark.asyncio
async def test_aio_publish_and_subscribe_error_info_ft(
ray_start_regular_with_external_redis,
):
address_info = ray_start_regular_with_external_redis
gcs_server_addr = address_info["gcs_address"]
subscriber = GcsAioErrorSubscriber(address=gcs_server_addr)
await subscriber.subscribe()
err1 = ErrorTableData(error_message="test error message 1")
err2 = ErrorTableData(error_message="test error message 2")
err3 = ErrorTableData(error_message="test error message 3")
err4 = ErrorTableData(error_message="test error message 4")
def restart_gcs_server():
import asyncio
asyncio.set_event_loop(asyncio.new_event_loop())
from time import sleep
publisher = GcsAioPublisher(address=gcs_server_addr)
asyncio.get_event_loop().run_until_complete(
publisher.publish_error(b"aaa_id", err1)
)
asyncio.get_event_loop().run_until_complete(
publisher.publish_error(b"bbb_id", err2)
)
# wait until subscribe consume everything
sleep(5)
ray.worker._global_node.kill_gcs_server()
sleep(1)
ray.worker._global_node.start_gcs_server()
# wait until subscriber resubscribed
sleep(5)
asyncio.get_event_loop().run_until_complete(
publisher.publish_error(b"aaa_id", err3)
)
asyncio.get_event_loop().run_until_complete(
publisher.publish_error(b"bbb_id", err4)
)
t1 = threading.Thread(target=restart_gcs_server)
t1.start()
assert await subscriber.poll() == (b"aaa_id", err1)
assert await subscriber.poll() == (b"bbb_id", err2)
assert await subscriber.poll() == (b"aaa_id", err3)
assert await subscriber.poll() == (b"bbb_id", err4)
await subscriber.close()
t1.join()
def test_publish_and_subscribe_logs(ray_start_regular):
address_info = ray_start_regular
gcs_server_addr = address_info["gcs_address"]