[Core][Pubsub][Logging 2/n] Use GCS pubsub for logs (#20492)

Use Ray pubsub to publish and subscribe to logs via GCS, from the Python worker, log monitor, dashboard, and unit tests.

This change is gated behind the RAY_gcs_grpc_based_pubsub feature flag.
commit 95c26eec26 (parent bec719d823)
Author: mwtian
Date: 2021-11-30 12:00:44 -08:00 (committed by GitHub)
6 changed files with 91 additions and 42 deletions
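
For reference, the new code path can be exercised roughly as follows. This is a minimal sketch based only on the APIs that appear in the diffs below (GcsPublisher.publish_logs, GcsLogSubscriber.subscribe/poll, gcs_pubsub_enabled); the GCS address, the gRPC channel setup, and the payload fields are illustrative assumptions, not part of this commit:

import grpc
from ray._private.gcs_pubsub import (gcs_pubsub_enabled, GcsPublisher,
                                     GcsLogSubscriber)

# Assumes RAY_gcs_grpc_based_pubsub=1 and a GCS server at this
# (hypothetical) address.
gcs_address = "127.0.0.1:6379"

if gcs_pubsub_enabled():
    # Publisher side (e.g. the log monitor): publish a structured log batch.
    publisher = GcsPublisher(address=gcs_address)
    publisher.publish_logs({
        # Illustrative payload fields, mirroring the dict built by LogMonitor.
        "ip": "127.0.0.1",
        "pid": 1234,
        "job": "01000000",
        "lines": ["hello world"],
        "actor_name": None,
        "task_name": None,
    })

    # Subscriber side (e.g. the driver or dashboard): poll for log batches.
    subscriber = GcsLogSubscriber(channel=grpc.insecure_channel(gcs_address))
    subscriber.subscribe()
    log_batch = subscriber.poll(timeout=10)
    print(log_batch)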


@@ -208,6 +208,7 @@ class DashboardHead:
             self.gcs_subscriber = GcsAioSubscriber(
                 channel=self.aiogrpc_gcs_channel)
             await self.gcs_subscriber.subscribe_error()
+            await self.gcs_subscriber.subscribe_logs()

         self.health_check_thread = GCSHealthCheckThread(gcs_address)
         self.health_check_thread.start()


@@ -259,20 +259,29 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
             DataSource.ip_and_pid_to_logs[ip] = logs_for_ip
             logger.info(f"Received a log for {ip} and {pid}")

-        aioredis_client = self._dashboard_head.aioredis_client
-        receiver = Receiver()
+        if self._dashboard_head.gcs_subscriber:
+            while True:
+                log_batch = await \
+                    self._dashboard_head.gcs_subscriber.poll_logs()
+                try:
+                    process_log_batch(log_batch)
+                except Exception:
+                    logger.exception("Error receiving log from GCS.")
+        else:
+            aioredis_client = self._dashboard_head.aioredis_client
+            receiver = Receiver()

-        channel = receiver.channel(gcs_utils.LOG_FILE_CHANNEL)
-        await aioredis_client.subscribe(channel)
-        logger.info("Subscribed to %s", channel)
+            channel = receiver.channel(gcs_utils.LOG_FILE_CHANNEL)
+            await aioredis_client.subscribe(channel)
+            logger.info("Subscribed to %s", channel)

-        async for sender, msg in receiver.iter():
-            try:
-                data = json.loads(ray._private.utils.decode(msg))
-                data["pid"] = str(data["pid"])
-                process_log_batch(data)
-            except Exception:
-                logger.exception("Error receiving log from Redis.")
+            async for sender, msg in receiver.iter():
+                try:
+                    data = json.loads(ray._private.utils.decode(msg))
+                    data["pid"] = str(data["pid"])
+                    process_log_batch(data)
+                except Exception:
+                    logger.exception("Error receiving log from Redis.")

     async def _update_error_info(self):
         def process_error(error_data):


@@ -12,6 +12,7 @@ import time
 import traceback

 import ray.ray_constants as ray_constants
+import ray._private.gcs_pubsub as gcs_pubsub
 import ray._private.gcs_utils as gcs_utils
 import ray._private.services as services
 import ray._private.utils
@@ -98,6 +99,10 @@ class LogMonitor:
         self.logs_dir = logs_dir
         self.redis_client = ray._private.services.create_redis_client(
             redis_address, password=redis_password)
+        self.publisher = None
+        if gcs_pubsub.gcs_pubsub_enabled():
+            gcs_addr = gcs_utils.get_gcs_address_from_redis(self.redis_client)
+            self.publisher = gcs_pubsub.GcsPublisher(address=gcs_addr)
         self.log_filenames = set()
         self.open_file_infos = []
         self.closed_file_infos = []
@@ -270,8 +275,11 @@ class LogMonitor:
                     "actor_name": file_info.actor_name,
                     "task_name": file_info.task_name,
                 }
-                self.redis_client.publish(gcs_utils.LOG_FILE_CHANNEL,
-                                          json.dumps(data))
+                if self.publisher:
+                    self.publisher.publish_logs(data)
+                else:
+                    self.redis_client.publish(gcs_utils.LOG_FILE_CHANNEL,
+                                              json.dumps(data))
                 anything_published = True
                 lines_to_publish = []


@@ -26,7 +26,8 @@ import ray._private.memory_monitor as memory_monitor
 from ray.core.generated import gcs_pb2
 from ray.core.generated import node_manager_pb2
 from ray.core.generated import node_manager_pb2_grpc
-from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsErrorSubscriber
+from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsErrorSubscriber, \
+    GcsLogSubscriber
 from ray._private.tls_utils import generate_self_signed_tls_certs
 from ray.util.queue import Queue, _QueueActor, Empty
 from ray.scripts.scripts import main as ray_main
@@ -554,11 +555,15 @@ def get_error_message(subscriber, num=1e6, error_type=None, timeout=20):

 def init_log_pubsub():
     """Initialize redis error info pub/sub"""
-    p = ray.worker.global_worker.redis_client.pubsub(
-        ignore_subscribe_messages=True)
-    log_pubsub_channel = gcs_utils.LOG_FILE_CHANNEL
-    p.psubscribe(log_pubsub_channel)
-    return p
+    if gcs_pubsub_enabled():
+        s = GcsLogSubscriber(
+            channel=ray.worker.global_worker.gcs_channel.channel())
+        s.subscribe()
+    else:
+        s = ray.worker.global_worker.redis_client.pubsub(
+            ignore_subscribe_messages=True)
+        s.psubscribe(gcs_utils.LOG_FILE_CHANNEL)
+    return s


 def get_log_message(subscriber,
@@ -576,11 +581,17 @@ def get_log_message(subscriber,
     deadline = time.time() + timeout
     msgs = []
     while time.time() < deadline and len(msgs) < num:
-        msg = subscriber.get_message()
-        if msg is None:
-            time.sleep(0.01)
-            continue
-        logs_data = json.loads(ray._private.utils.decode(msg["data"]))
+        if isinstance(subscriber, GcsLogSubscriber):
+            logs_data = subscriber.poll(timeout=deadline - time.time())
+            if not logs_data:
+                # Timed out before any data is received.
+                break
+        else:
+            msg = subscriber.get_message()
+            if msg is None:
+                time.sleep(0.01)
+                continue
+            logs_data = json.loads(ray._private.utils.decode(msg["data"]))

         if job_id and job_id != logs_data["job"]:
             continue
@@ -607,11 +618,17 @@ def get_log_batch(subscriber,
     deadline = time.time() + timeout
     batches = []
     while time.time() < deadline and len(batches) < num:
-        msg = subscriber.get_message()
-        if msg is None:
-            time.sleep(0.01)
-            continue
-        logs_data = json.loads(ray._private.utils.decode(msg["data"]))
+        if isinstance(subscriber, GcsLogSubscriber):
+            logs_data = subscriber.poll(timeout=deadline - time.time())
+            if not logs_data:
+                # Timed out before any data is received.
+                break
+        else:
+            msg = subscriber.get_message()
+            if msg is None:
+                time.sleep(0.01)
+                continue
+            logs_data = json.loads(ray._private.utils.decode(msg["data"]))

         if job_id and job_id != logs_data["job"]:
             continue


@@ -200,13 +200,13 @@ def test_subscribe_two_channels(ray_start_regular):
             "task_name": "test task",
         })

-    t1.join(timeout=1)
-    assert not t1.is_alive()
-    assert len(errors) == num_messages
+    t1.join(timeout=10)
+    assert not t1.is_alive(), len(errors)
+    assert len(errors) == num_messages, len(errors)

-    t2.join(timeout=1)
-    assert not t2.is_alive()
-    assert len(logs) == num_messages
+    t2.join(timeout=10)
+    assert not t2.is_alive(), len(logs)
+    assert len(logs) == num_messages, len(logs)

     for i in range(0, num_messages):
         assert errors[i].error_message == f"error {i}"


@@ -28,7 +28,7 @@ import ray.serialization as serialization
 import ray._private.gcs_utils as gcs_utils
 import ray._private.services as services
 from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsPublisher, \
-    GcsErrorSubscriber
+    GcsErrorSubscriber, GcsLogSubscriber
 from ray._private.runtime_env.py_modules import upload_py_modules_if_needed
 from ray._private.runtime_env.working_dir import upload_working_dir_if_needed
 from ray._private.runtime_env.constants import RAY_JOB_CONFIG_JSON_ENV_VAR
@@ -434,8 +434,13 @@ class Worker:
     def print_logs(self):
         """Prints log messages from workers on all nodes in the same job.
         """
-        subscriber = self.redis_client.pubsub(ignore_subscribe_messages=True)
-        subscriber.subscribe(gcs_utils.LOG_FILE_CHANNEL)
+        if self.gcs_pubsub_enabled:
+            subscriber = self.gcs_log_subscriber
+            subscriber.subscribe()
+        else:
+            subscriber = self.redis_client.pubsub(
+                ignore_subscribe_messages=True)
+            subscriber.subscribe(gcs_utils.LOG_FILE_CHANNEL)
         localhost = services.get_node_ip_address()
         try:
             # Keep track of the number of consecutive log messages that have
@@ -449,7 +454,10 @@ class Worker:
                 if self.threads_stopped.is_set():
                     return

-                msg = subscriber.get_message()
+                if self.gcs_pubsub_enabled:
+                    msg = subscriber.poll()
+                else:
+                    msg = subscriber.get_message()
                 if msg is None:
                     num_consecutive_messages_received = 0
                     self.threads_stopped.wait(timeout=0.01)
@@ -463,7 +471,10 @@ class Worker:
                     "logs to the driver, use "
                     "'ray.init(log_to_driver=False)'.")

-                data = json.loads(ray._private.utils.decode(msg["data"]))
+                if self.gcs_pubsub_enabled:
+                    data = msg
+                else:
+                    data = json.loads(ray._private.utils.decode(msg["data"]))

                 # Don't show logs from other drivers.
                 if (self.filter_logs_by_job and data["job"]
@@ -1373,6 +1384,8 @@ def connect(node,
             channel=worker.gcs_channel.channel())
         worker.gcs_error_subscriber = GcsErrorSubscriber(
             channel=worker.gcs_channel.channel())
+        worker.gcs_log_subscriber = GcsLogSubscriber(
+            channel=worker.gcs_channel.channel())

     # Initialize some fields.
     if mode in (WORKER_MODE, RESTORE_WORKER_MODE, SPILL_WORKER_MODE):
@@ -1575,8 +1588,9 @@ def disconnect(exiting_interpreter=False):
     # should be handled cleanly in the worker object's destructor and not
     # in this disconnect method.
     worker.threads_stopped.set()
-    if hasattr(worker, "gcs_error_subscriber"):
+    if worker.gcs_pubsub_enabled:
         worker.gcs_error_subscriber.close()
+        worker.gcs_log_subscriber.close()
     if hasattr(worker, "import_thread"):
         worker.import_thread.join_import_thread()
     if hasattr(worker, "listener_thread"):