[Remove Redis Pubsub 1/n] Remove enable_gcs_pubsub() (#23189)

GCS pubsub has been the default for a while, and it is unlikely that we will need to revert to Redis pubsub in the future. This is the first step in removing Redis pubsub: removing the `enable_gcs_pubsub()` feature guard and the Redis fallback branches behind it.
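The change is mechanical at every call site: delete the `gcs_pubsub_enabled()` check together with its Redis fallback branch, and call the GCS pubsub API unconditionally. A minimal sketch of the resulting shape, assuming a running Ray head node; the address literal below is a placeholder and not taken from this diff:

```python
# Hedged sketch only: mirrors how the hunks in this commit use the GCS pubsub
# API once the feature guard is gone; it is not part of the commit itself.
from ray._private.gcs_pubsub import GcsLogSubscriber, GcsPublisher

# Placeholder address; real call sites use the running node's GCS address,
# e.g. worker.gcs_client.address or args.gcs_address in the hunks below.
gcs_address = "127.0.0.1:6379"

# Before: callers branched on the feature flag and fell back to Redis pubsub.
#     if gcs_pubsub_enabled():
#         publisher = GcsPublisher(address=gcs_address)
#     else:
#         publisher = <Redis-based publisher>
#
# After: the GCS path is unconditional.
publisher = GcsPublisher(address=gcs_address)

subscriber = GcsLogSubscriber(address=gcs_address)
subscriber.subscribe()
log_batch = subscriber.poll()  # returns None only when GCS becomes unavailable
```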
mwtian 2022-03-15 23:56:15 -07:00 committed by GitHub
parent 678d23fe42
commit 72ef9f91aa
15 changed files with 84 additions and 374 deletions


@@ -7,7 +7,6 @@ import traceback
import ray
import pytest
import ray.dashboard.utils as dashboard_utils
import ray._private.gcs_utils as gcs_utils
import ray._private.gcs_pubsub as gcs_pubsub
from ray.dashboard.tests.conftest import * # noqa
from ray.dashboard.modules.actor import actor_consts
@@ -228,15 +227,7 @@ def test_actor_pubsub(disable_aiohttp_cache, ray_start_with_dashboard):
def handle_pub_messages(msgs, timeout, expect_num):
start_time = time.time()
while time.time() - start_time < timeout and len(msgs) < expect_num:
if gcs_pubsub.gcs_pubsub_enabled():
_, actor_data = sub.poll(timeout=timeout)
else:
msg = sub.get_message()
if msg is None:
time.sleep(0.01)
continue
pubsub_msg = gcs_utils.PubSubMessage.FromString(msg["data"])
actor_data = gcs_utils.ActorTableData.FromString(pubsub_msg.data)
_, actor_data = sub.poll(timeout=timeout)
if actor_data is None:
continue
msgs.append(actor_data)


@@ -5,7 +5,6 @@ import os
import aiohttp.web
import ray
import ray.dashboard.modules.reporter.reporter_consts as reporter_consts
import ray.dashboard.utils as dashboard_utils
import ray.dashboard.optional_utils as dashboard_optional_utils
import ray.experimental.internal_kv as internal_kv
@@ -18,7 +17,7 @@ from ray.ray_constants import (
)
from ray.core.generated import reporter_pb2
from ray.core.generated import reporter_pb2_grpc
from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsAioResourceUsageSubscriber
from ray._private.gcs_pubsub import GcsAioResourceUsageSubscriber
from ray._private.metrics_agent import PrometheusServiceDiscoveryWriter
from ray.dashboard.datacenter import DataSource
@@ -148,45 +147,24 @@ class ReportHead(dashboard_utils.DashboardHeadModule):
# Need daemon True to avoid dashboard hangs at exit.
self.service_discovery.daemon = True
self.service_discovery.start()
if gcs_pubsub_enabled():
gcs_addr = self._dashboard_head.gcs_address
subscriber = GcsAioResourceUsageSubscriber(gcs_addr)
await subscriber.subscribe()
gcs_addr = self._dashboard_head.gcs_address
subscriber = GcsAioResourceUsageSubscriber(gcs_addr)
await subscriber.subscribe()
while True:
try:
# The key is b'RAY_REPORTER:{node id hex}',
# e.g. b'RAY_REPORTER:2b4fbd...'
key, data = await subscriber.poll()
if key is None:
continue
data = json.loads(data)
node_id = key.split(":")[-1]
DataSource.node_physical_stats[node_id] = data
except Exception:
logger.exception(
"Error receiving node physical stats " "from reporter agent."
)
else:
from aioredis.pubsub import Receiver
receiver = Receiver()
aioredis_client = self._dashboard_head.aioredis_client
reporter_key = "{}*".format(reporter_consts.REPORTER_PREFIX)
await aioredis_client.psubscribe(receiver.pattern(reporter_key))
logger.info(f"Subscribed to {reporter_key}")
async for sender, msg in receiver.iter():
try:
key, data = msg
data = json.loads(ray._private.utils.decode(data))
key = key.decode("utf-8")
node_id = key.split(":")[-1]
DataSource.node_physical_stats[node_id] = data
except Exception:
logger.exception(
"Error receiving node physical stats " "from reporter agent."
)
while True:
try:
# The key is b'RAY_REPORTER:{node id hex}',
# e.g. b'RAY_REPORTER:2b4fbd...'
key, data = await subscriber.poll()
if key is None:
continue
data = json.loads(data)
node_id = key.split(":")[-1]
DataSource.node_physical_stats[node_id] = data
except Exception:
logger.exception(
"Error receiving node physical stats from reporter agent."
)
@staticmethod
def is_minimal_module():


@@ -23,7 +23,6 @@ from ray._private.test_utils import (
run_string_as_driver,
wait_until_succeeded_without_exception,
)
from ray._private.gcs_pubsub import gcs_pubsub_enabled
from ray.ray_constants import DEBUG_AUTOSCALING_STATUS_LEGACY, DEBUG_AUTOSCALING_ERROR
from ray.dashboard import dashboard
import ray.dashboard.consts as dashboard_consts
@@ -701,13 +700,9 @@ def test_gcs_check_alive(fast_gcs_failure_detection, ray_start_with_dashboard):
gcs_server_proc.kill()
gcs_server_proc.wait()
if gcs_pubsub_enabled():
# When pubsub enabled, the exits comes from pubsub errored.
# TODO: Fix this exits logic for pubsub
assert dashboard_proc.wait(10) != 0
else:
# The dashboard exits by os._exit(-1)
assert dashboard_proc.wait(10) == 255
# The dashboard exits by os._exit(-1)
assert dashboard_proc.wait(10) == 255
if __name__ == "__main__":


@@ -178,13 +178,8 @@ class FunctionActorManager:
break
# Notify all subscribers that there is a new function exported. Note
# that the notification doesn't include any actual data.
if self._worker.gcs_pubsub_enabled:
# TODO(mwtian) implement per-job notification here.
self._worker.gcs_publisher.publish_function_key(key)
else:
self._worker.redis_client.lpush(
make_exports_prefix(self._worker.current_job_id), "a"
)
# TODO(mwtian) implement per-job notification here.
self._worker.gcs_publisher.publish_function_key(key)
def export(self, remote_function):
"""Pickle a remote function and export it to redis.


@@ -16,7 +16,6 @@ except ImportError:
import ray._private.gcs_utils as gcs_utils
import ray._private.logging_utils as logging_utils
from ray._raylet import Config
from ray.core.generated.gcs_pb2 import ErrorTableData
from ray.core.generated import dependency_pb2
from ray.core.generated import gcs_service_pb2_grpc
@@ -30,11 +29,6 @@ logger = logging.getLogger(__name__)
MAX_GCS_PUBLISH_RETRIES = 60
def gcs_pubsub_enabled():
"""Checks whether GCS pubsub feature flag is enabled."""
return Config.gcs_grpc_based_pubsub()
def construct_error_message(job_id, error_type, message, timestamp):
"""Construct an ErrorTableData object.


@@ -94,13 +94,9 @@ _GRPC_OPTIONS = [
def use_gcs_for_bootstrap():
from ray._private.gcs_pubsub import gcs_pubsub_enabled
from ray._raylet import Config
ret = Config.bootstrap_with_gcs()
if ret:
assert gcs_pubsub_enabled()
return ret
return Config.bootstrap_with_gcs()
def get_gcs_address_from_redis(redis) -> str:


@@ -33,21 +33,9 @@ class ImportThread:
self.worker = worker
self.mode = mode
self.gcs_client = worker.gcs_client
if worker.gcs_pubsub_enabled:
self.subscriber = worker.gcs_function_key_subscriber
self.subscriber.subscribe()
self.exception_type = grpc.RpcError
else:
import redis
self.subscriber = worker.redis_client.pubsub()
self.subscriber.subscribe(
b"__keyspace@0__:"
+ ray._private.function_manager.make_exports_prefix(
self.worker.current_job_id
)
)
self.exception_type = redis.exceptions.ConnectionError
self.subscriber = worker.gcs_function_key_subscriber
self.subscriber.subscribe()
self.exception_type = grpc.RpcError
self.threads_stopped = threads_stopped
self.imported_collision_identifiers = defaultdict(int)
# Keep track of the number of imports that we've imported.
@@ -72,20 +60,10 @@ class ImportThread:
# Exit if we received a signal that we should stop.
if self.threads_stopped.is_set():
return
if self.worker.gcs_pubsub_enabled:
key = self.subscriber.poll()
if key is None:
# subscriber has closed.
break
else:
msg = self.subscriber.get_message()
if msg is None:
self.threads_stopped.wait(timeout=0.01)
continue
if msg["type"] == "subscribe":
continue
key = self.subscriber.poll()
if key is None:
# subscriber has closed.
break
self._do_importing()
except (OSError, self.exception_type) as e:
logger.error(f"ImportThread: {e}")


@@ -1,7 +1,6 @@
import argparse
import errno
import glob
import json
import logging
import logging.handlers
import os
@@ -13,10 +12,9 @@ import traceback
import ray.ray_constants as ray_constants
import ray._private.gcs_pubsub as gcs_pubsub
import ray._private.gcs_utils as gcs_utils
import ray._private.services as services
import ray._private.utils
from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsPublisher
from ray._private.gcs_pubsub import GcsPublisher
from ray._private.ray_logging import setup_component_logger
# Logger for this module. It should be configured at the entry point
@@ -91,7 +89,6 @@ class LogMonitor:
host (str): The hostname of this machine. Used to improve the log
messages published to Redis.
logs_dir (str): The directory that the log files are in.
redis_client: A client used to communicate with the Redis server.
log_filenames (set): This is the set of filenames of all files in
open_file_infos and closed_file_infos.
open_file_infos (list[LogFileInfo]): Info for all of the open files.
@@ -105,10 +102,7 @@ class LogMonitor:
"""Initialize the log monitor object."""
self.ip = services.get_node_ip_address()
self.logs_dir = logs_dir
self.redis_client = None
self.publisher = None
if gcs_pubsub.gcs_pubsub_enabled():
self.publisher = gcs_pubsub.GcsPublisher(address=gcs_address)
self.publisher = gcs_pubsub.GcsPublisher(address=gcs_address)
self.log_filenames = set()
self.open_file_infos = []
self.closed_file_infos = []
@@ -293,12 +287,7 @@
"actor_name": file_info.actor_name,
"task_name": file_info.task_name,
}
if self.publisher:
self.publisher.publish_logs(data)
else:
self.redis_client.publish(
gcs_utils.LOG_FILE_CHANNEL, json.dumps(data)
)
self.publisher.publish_logs(data)
anything_published = True
lines_to_publish = []
@@ -477,12 +466,7 @@ if __name__ == "__main__":
log_monitor.run()
except Exception as e:
# Something went wrong, so push an error to all drivers.
redis_client = ray._private.services.create_redis_client(
args.redis_address, password=args.redis_password
)
gcs_publisher = None
if gcs_pubsub_enabled():
gcs_publisher = GcsPublisher(address=args.gcs_address)
gcs_publisher = GcsPublisher(address=args.gcs_address)
traceback_str = ray._private.utils.format_error_message(traceback.format_exc())
message = (
f"The log monitor on node {platform.node()} "
@@ -491,7 +475,6 @@
ray._private.utils.publish_error_to_driver(
ray_constants.LOG_MONITOR_DIED_ERROR,
message,
redis_client=redis_client,
gcs_publisher=gcs_publisher,
)
logger.error(message)


@@ -31,7 +31,6 @@ from ray.core.generated import gcs_pb2
from ray.core.generated import node_manager_pb2
from ray.core.generated import node_manager_pb2_grpc
from ray._private.gcs_pubsub import (
gcs_pubsub_enabled,
GcsErrorSubscriber,
GcsLogSubscriber,
)
@@ -218,7 +217,6 @@ def run_string_as_driver(driver_script: str, env: Dict = None, encode: str = "ut
Returns:
The script's output.
"""
proc = subprocess.Popen(
[sys.executable, "-"],
stdin=subprocess.PIPE,
@@ -581,12 +579,8 @@ def get_non_head_nodes(cluster):
def init_error_pubsub():
"""Initialize redis error info pub/sub"""
if gcs_pubsub_enabled():
s = GcsErrorSubscriber(address=ray.worker.global_worker.gcs_client.address)
s.subscribe()
else:
s = ray.worker.global_worker.redis_client.pubsub(ignore_subscribe_messages=True)
s.psubscribe(gcs_utils.RAY_ERROR_PUBSUB_PATTERN)
s = GcsErrorSubscriber(address=ray.worker.global_worker.gcs_client.address)
s.subscribe()
return s
@@ -621,12 +615,8 @@ def get_error_message(subscriber, num=1e6, error_type=None, timeout=20):
def init_log_pubsub():
"""Initialize redis error info pub/sub"""
if gcs_pubsub_enabled():
s = GcsLogSubscriber(address=ray.worker.global_worker.gcs_client.address)
s.subscribe()
else:
s = ray.worker.global_worker.redis_client.pubsub(ignore_subscribe_messages=True)
s.psubscribe(gcs_utils.LOG_FILE_CHANNEL)
s = GcsLogSubscriber(address=ray.worker.global_worker.gcs_client.address)
s.subscribe()
return s


@@ -8,7 +8,7 @@ import numpy as np
import pytest
import ray.cluster_utils
from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsFunctionKeySubscriber
from ray._private.gcs_pubsub import GcsFunctionKeySubscriber
from ray._private.test_utils import wait_for_condition
from ray.autoscaler._private.constants import RAY_PROCESSES
from pathlib import Path
@@ -108,33 +108,27 @@ def test_function_unique_export(ray_start_regular):
def g():
ray.get(f.remote())
if gcs_pubsub_enabled():
subscriber = GcsFunctionKeySubscriber(
address=ray.worker.global_worker.gcs_client.address
)
subscriber.subscribe()
subscriber = GcsFunctionKeySubscriber(
address=ray.worker.global_worker.gcs_client.address
)
subscriber.subscribe()
ray.get(g.remote())
# Poll pubsub channel for messages generated from running task g().
num_exports = 0
while True:
key = subscriber.poll(timeout=1)
if key is None:
break
else:
num_exports += 1
print(f"num_exports after running g(): {num_exports}")
ray.get([g.remote() for _ in range(5)])
ray.get(g.remote())
# Poll pubsub channel for messages generated from running task g().
num_exports = 0
while True:
key = subscriber.poll(timeout=1)
assert key is None, f"Unexpected function key export: {key}"
else:
ray.get(g.remote())
num_exports = ray.worker.global_worker.redis_client.llen("Exports")
ray.get([g.remote() for _ in range(5)])
assert ray.worker.global_worker.redis_client.llen("Exports") == num_exports
if key is None:
break
else:
num_exports += 1
print(f"num_exports after running g(): {num_exports}")
ray.get([g.remote() for _ in range(5)])
key = subscriber.poll(timeout=1)
assert key is None, f"Unexpected function key export: {key}"
@pytest.mark.skipif(


@@ -11,7 +11,7 @@ import ray._private.utils
import ray._private.gcs_utils as gcs_utils
import ray.ray_constants as ray_constants
from ray.exceptions import RayTaskError, RayActorError, GetTimeoutError
from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsPublisher
from ray._private.gcs_pubsub import GcsPublisher
from ray._private.test_utils import (
wait_for_condition,
SignalActor,
@@ -69,21 +69,12 @@ def test_unhandled_errors(ray_start_regular):
def test_publish_error_to_driver(ray_start_regular, error_pubsub):
address_info = ray_start_regular
redis_client = None
gcs_publisher = None
if gcs_pubsub_enabled():
gcs_publisher = GcsPublisher(address=address_info["gcs_address"])
else:
redis_client = ray._private.services.create_redis_client(
address_info["redis_address"],
password=ray.ray_constants.REDIS_DEFAULT_PASSWORD,
)
gcs_publisher = GcsPublisher(address=address_info["gcs_address"])
error_message = "Test error message"
ray._private.utils.publish_error_to_driver(
ray_constants.DASHBOARD_AGENT_DIED_ERROR,
error_message,
redis_client=redis_client,
gcs_publisher=gcs_publisher,
)
errors = get_error_message(


@@ -14,7 +14,6 @@ from ray.ray_constants import DEBUG_AUTOSCALING_ERROR
import ray._private.utils
import ray.ray_constants as ray_constants
from ray.cluster_utils import cluster_not_supported
import ray._private.gcs_pubsub as gcs_pubsub
from ray._private.test_utils import (
init_error_pubsub,
get_error_message,
@@ -425,37 +424,6 @@ def test_fate_sharing(ray_start_cluster, use_actors, node_failure):
test_process_failure(use_actors)
@pytest.mark.parametrize(
"ray_start_regular",
[{"_system_config": {"gcs_rpc_server_reconnect_timeout_s": 100}}],
indirect=True,
)
@pytest.mark.skipif(
gcs_pubsub.gcs_pubsub_enabled(),
reason="Logs are streamed via GCS pubsub when it is enabled, so logs "
"cannot be delivered after GCS is killed.",
)
def test_gcs_server_failiure_report(ray_start_regular, log_pubsub):
# Get gcs server pid to send a signal.
all_processes = ray.worker._global_node.all_processes
gcs_server_process = all_processes["gcs_server"][0].process
gcs_server_pid = gcs_server_process.pid
# TODO(mwtian): make sure logs are delivered after GCS is restarted.
if sys.platform == "win32":
sig = 9
else:
sig = signal.SIGBUS
os.kill(gcs_server_pid, sig)
# wait for 30 seconds, for the 1st batch of logs.
batches = get_log_batch(log_pubsub, 1, timeout=30)
assert gcs_server_process.poll() is not None
if sys.platform != "win32":
# Windows signal handler does not run when process is terminated
assert len(batches) == 1
assert batches[0]["pid"] == "gcs_server", batches
def test_list_named_actors_timeout(monkeypatch, shutdown_only):
with monkeypatch.context() as m:
# defer for 3s


@@ -1,7 +1,6 @@
import sys
import ray
import ray._private.gcs_pubsub as gcs_pubsub
import ray._private.gcs_utils as gcs_utils
import pytest
from ray._private.test_utils import (
@@ -63,11 +62,11 @@ def test_gcs_server_restart(ray_start_regular_with_external_redis):
],
indirect=True,
)
@pytest.mark.skipif(
gcs_pubsub.gcs_pubsub_enabled(),
@pytest.mark.skip(
reason="GCS pubsub may lose messages after GCS restarts. Need to "
"implement re-fetching state in GCS client.",
)
# TODO(mwtian): re-enable after fixing https://github.com/ray-project/ray/issues/22340
def test_gcs_server_restart_during_actor_creation(
ray_start_regular_with_external_redis,
):


@@ -15,11 +15,6 @@ from ray.core.generated.gcs_pb2 import ErrorTableData
import pytest
@pytest.mark.parametrize(
"ray_start_regular",
[{"_system_config": {"gcs_grpc_based_pubsub": True}}],
indirect=True,
)
def test_publish_and_subscribe_error_info(ray_start_regular):
address_info = ray_start_regular
gcs_server_addr = address_info["gcs_address"]
@@ -40,11 +35,6 @@ def test_publish_and_subscribe_error_info(ray_start_regular):
@pytest.mark.asyncio
@pytest.mark.parametrize(
"ray_start_regular",
[{"_system_config": {"gcs_grpc_based_pubsub": True}}],
indirect=True,
)
async def test_aio_publish_and_subscribe_error_info(ray_start_regular):
address_info = ray_start_regular
gcs_server_addr = address_info["gcs_address"]
@@ -64,11 +54,6 @@ async def test_aio_publish_and_subscribe_error_info(ray_start_regular):
await subscriber.close()
@pytest.mark.parametrize(
"ray_start_regular",
[{"_system_config": {"gcs_grpc_based_pubsub": True}}],
indirect=True,
)
def test_publish_and_subscribe_logs(ray_start_regular):
address_info = ray_start_regular
gcs_server_addr = address_info["gcs_address"]
@@ -96,11 +81,6 @@ def test_publish_and_subscribe_logs(ray_start_regular):
@pytest.mark.asyncio
@pytest.mark.parametrize(
"ray_start_regular",
[{"_system_config": {"gcs_grpc_based_pubsub": True}}],
indirect=True,
)
async def test_aio_publish_and_subscribe_logs(ray_start_regular):
address_info = ray_start_regular
gcs_server_addr = address_info["gcs_address"]
@@ -125,11 +105,6 @@ async def test_aio_publish_and_subscribe_logs(ray_start_regular):
await subscriber.close()
@pytest.mark.parametrize(
"ray_start_regular",
[{"_system_config": {"gcs_grpc_based_pubsub": True}}],
indirect=True,
)
def test_publish_and_subscribe_function_keys(ray_start_regular):
address_info = ray_start_regular
gcs_server_addr = address_info["gcs_address"]
@@ -148,11 +123,6 @@ def test_publish_and_subscribe_function_keys(ray_start_regular):
@pytest.mark.asyncio
@pytest.mark.parametrize(
"ray_start_regular",
[{"_system_config": {"gcs_grpc_based_pubsub": True}}],
indirect=True,
)
async def test_aio_publish_and_subscribe_resource_usage(ray_start_regular):
address_info = ray_start_regular
gcs_server_addr = address_info["gcs_address"]
@@ -170,11 +140,6 @@ async def test_aio_publish_and_subscribe_resource_usage(ray_start_regular):
await subscriber.close()
@pytest.mark.parametrize(
"ray_start_regular",
[{"_system_config": {"gcs_grpc_based_pubsub": True}}],
indirect=True,
)
def test_two_subscribers(ray_start_regular):
"""Tests concurrently subscribing to two channels work."""


@@ -30,7 +30,6 @@ import ray._private.gcs_utils as gcs_utils
import ray._private.services as services
from ray.util.scheduling_strategies import SchedulingStrategyT
from ray._private.gcs_pubsub import (
gcs_pubsub_enabled,
GcsPublisher,
GcsErrorSubscriber,
GcsLogSubscriber,
@@ -453,26 +452,13 @@ class Worker:
def print_logs(self):
"""Prints log messages from workers on all nodes in the same job."""
if self.gcs_pubsub_enabled:
import grpc
import grpc
subscriber = self.gcs_log_subscriber
subscriber.subscribe()
exception_type = grpc.RpcError
else:
import redis
subscriber = self.redis_client.pubsub(ignore_subscribe_messages=True)
subscriber.subscribe(gcs_utils.LOG_FILE_CHANNEL)
exception_type = redis.exceptions.ConnectionError
subscriber = self.gcs_log_subscriber
subscriber.subscribe()
exception_type = grpc.RpcError
localhost = services.get_node_ip_address()
try:
# Keep track of the number of consecutive log messages that have
# been received with no break in between. If this number grows
# continually, then the worker is probably not able to process the
# log messages as rapidly as they are coming in.
# This is meaningful only for Redis subscriber.
num_consecutive_messages_received = 0
# Number of messages received from the last polling. When the batch
# size exceeds 100 and keeps increasing, the worker and the user
# probably will not be able to consume the log messages as rapidly
@@ -485,43 +471,21 @@ class Worker:
if self.threads_stopped.is_set():
return
if self.gcs_pubsub_enabled:
msg = subscriber.poll()
else:
msg = subscriber.get_message()
data = subscriber.poll()
# GCS subscriber only returns None on unavailability.
# Redis subscriber returns None when there is no new message.
if msg is None:
num_consecutive_messages_received = 0
if data is None:
last_polling_batch_size = 0
self.threads_stopped.wait(timeout=0.01)
continue
if self.gcs_pubsub_enabled:
data = msg
else:
data = json.loads(ray._private.utils.decode(msg["data"]))
# Don't show logs from other drivers.
if data["job"] and data["job"] != job_id_hex:
num_consecutive_messages_received = 0
last_polling_batch_size = 0
continue
data["localhost"] = localhost
global_worker_stdstream_dispatcher.emit(data)
if self.gcs_pubsub_enabled:
lagging = (
100 <= last_polling_batch_size < subscriber.last_batch_size
)
last_polling_batch_size = subscriber.last_batch_size
else:
num_consecutive_messages_received += 1
lagging = (
num_consecutive_messages_received % 100 == 0
and num_consecutive_messages_received > 0
)
lagging = 100 <= last_polling_batch_size < subscriber.last_batch_size
if lagging:
logger.warning(
"The driver may not be able to keep up with the "
@@ -530,6 +494,8 @@
"'ray.init(log_to_driver=False)'."
)
last_polling_batch_size = subscriber.last_batch_size
except (OSError, exception_type) as e:
logger.error(f"print_logs: {e}")
finally:
@@ -1366,72 +1332,7 @@ def print_worker_logs(data: Dict[str, str], print_file: Any):
)
def listen_error_messages_raylet(worker, threads_stopped):
"""Listen to error messages in the background on the driver.
This runs in a separate thread on the driver and pushes (error, time)
tuples to the output queue.
Args:
worker: The worker class that this thread belongs to.
threads_stopped (threading.Event): A threading event used to signal to
the thread that it should exit.
"""
import redis
worker.error_message_pubsub_client = worker.redis_client.pubsub(
ignore_subscribe_messages=True
)
# Exports that are published after the call to
# error_message_pubsub_client.subscribe and before the call to
# error_message_pubsub_client.listen will still be processed in the loop.
# Really we should just subscribe to the errors for this specific job.
# However, currently all errors seem to be published on the same channel.
error_pubsub_channel = gcs_utils.RAY_ERROR_PUBSUB_PATTERN
worker.error_message_pubsub_client.psubscribe(error_pubsub_channel)
try:
if _internal_kv_initialized():
# Get any autoscaler errors that occurred before the call to
# subscribe.
error_message = _internal_kv_get(ray_constants.DEBUG_AUTOSCALING_ERROR)
if error_message is not None:
logger.warning(error_message.decode())
while True:
# Exit if we received a signal that we should stop.
if threads_stopped.is_set():
return
msg = worker.error_message_pubsub_client.get_message()
if msg is None:
threads_stopped.wait(timeout=0.01)
continue
pubsub_msg = gcs_utils.PubSubMessage.FromString(msg["data"])
error_data = gcs_utils.ErrorTableData.FromString(pubsub_msg.data)
job_id = error_data.job_id
if job_id not in [
worker.current_job_id.binary(),
JobID.nil().binary(),
]:
continue
error_message = error_data.error_message
if error_data.type == ray_constants.TASK_PUSH_ERROR:
# TODO(ekl) remove task push errors entirely now that we have
# the separate unhandled exception handler.
pass
else:
logger.warning(error_message)
except (OSError, redis.exceptions.ConnectionError) as e:
logger.error(f"listen_error_messages_raylet: {e}")
finally:
# Close the pubsub client to avoid leaking file descriptors.
worker.error_message_pubsub_client.close()
def listen_error_messages_from_gcs(worker, threads_stopped):
def listen_error_messages(worker, threads_stopped):
"""Listen to error messages in the background on the driver.
This runs in a separate thread on the driver and pushes (error, time)
@@ -1476,7 +1377,7 @@ def listen_error_messages_from_gcs(worker, threads_stopped):
else:
logger.warning(error_message)
except (OSError, ConnectionError) as e:
logger.error(f"listen_error_messages_from_gcs: {e}")
logger.error(f"listen_error_messages: {e}")
@PublicAPI
@@ -1543,17 +1444,12 @@ def connect(
ray.state.state._initialize_global_state(
ray._raylet.GcsClientOptions.from_gcs_address(node.gcs_address)
)
worker.gcs_pubsub_enabled = gcs_pubsub_enabled()
worker.gcs_publisher = None
if worker.gcs_pubsub_enabled:
worker.gcs_publisher = GcsPublisher(address=worker.gcs_client.address)
worker.gcs_error_subscriber = GcsErrorSubscriber(
address=worker.gcs_client.address
)
worker.gcs_log_subscriber = GcsLogSubscriber(address=worker.gcs_client.address)
worker.gcs_function_key_subscriber = GcsFunctionKeySubscriber(
address=worker.gcs_client.address
)
worker.gcs_publisher = GcsPublisher(address=worker.gcs_client.address)
worker.gcs_error_subscriber = GcsErrorSubscriber(address=worker.gcs_client.address)
worker.gcs_log_subscriber = GcsLogSubscriber(address=worker.gcs_client.address)
worker.gcs_function_key_subscriber = GcsFunctionKeySubscriber(
address=worker.gcs_client.address
)
# Initialize some fields.
if mode in (WORKER_MODE, RESTORE_WORKER_MODE, SPILL_WORKER_MODE):
@@ -1701,9 +1597,7 @@ def connect(
# scheduler for new error messages.
if mode == SCRIPT_MODE:
worker.listener_thread = threading.Thread(
target=listen_error_messages_from_gcs
if worker.gcs_pubsub_enabled
else listen_error_messages_raylet,
target=listen_error_messages,
name="ray_listen_error_messages",
args=(worker, worker.threads_stopped),
)
@@ -1775,10 +1669,9 @@ def disconnect(exiting_interpreter=False):
# should be handled cleanly in the worker object's destructor and not
# in this disconnect method.
worker.threads_stopped.set()
if worker.gcs_pubsub_enabled:
worker.gcs_function_key_subscriber.close()
worker.gcs_error_subscriber.close()
worker.gcs_log_subscriber.close()
worker.gcs_function_key_subscriber.close()
worker.gcs_error_subscriber.close()
worker.gcs_log_subscriber.close()
if hasattr(worker, "import_thread"):
worker.import_thread.join_import_thread()
if hasattr(worker, "listener_thread"):