Mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00
[Core][Pubsub][Logging 2/n] Use GCS pubsub for logs (#20492)
Use Ray pubsub for publishing and subscribing logs via GCS, from the Python worker, log importer, dashboard, and unit tests. This change is guarded behind the RAY_gcs_grpc_based_pubsub feature flag.
commit 95c26eec26
parent bec719d823
6 changed files with 91 additions and 42 deletions
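
The change applies the same feature-flag branch in every touched component: when RAY_gcs_grpc_based_pubsub is enabled, logs flow through GCS pubsub; otherwise they fall back to the existing Redis LOG_FILE_CHANNEL. Below is a minimal sketch of the publisher side of that pattern, mirroring the LogMonitor hunk further down; the helper name publish_log_batch and its parameters are illustrative, not part of this diff.

import json

import ray._private.gcs_pubsub as gcs_pubsub
import ray._private.gcs_utils as gcs_utils


def publish_log_batch(data, redis_client, gcs_publisher=None):
    # Illustrative helper, not Ray API: route one log batch (a dict) to
    # GCS pubsub when the feature flag is on, else to Redis.
    if gcs_pubsub.gcs_pubsub_enabled() and gcs_publisher is not None:
        # GCS path: GcsPublisher.publish_logs() takes the dict directly.
        gcs_publisher.publish_logs(data)
    else:
        # Redis path: JSON-encode the batch onto the shared log channel.
        redis_client.publish(gcs_utils.LOG_FILE_CHANNEL, json.dumps(data))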
@@ -208,6 +208,7 @@ class DashboardHead:
         self.gcs_subscriber = GcsAioSubscriber(
             channel=self.aiogrpc_gcs_channel)
         await self.gcs_subscriber.subscribe_error()
+        await self.gcs_subscriber.subscribe_logs()

         self.health_check_thread = GCSHealthCheckThread(gcs_address)
         self.health_check_thread.start()
@@ -259,20 +259,29 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
             DataSource.ip_and_pid_to_logs[ip] = logs_for_ip
             logger.info(f"Received a log for {ip} and {pid}")

-        aioredis_client = self._dashboard_head.aioredis_client
-        receiver = Receiver()
+        if self._dashboard_head.gcs_subscriber:
+            while True:
+                log_batch = await \
+                    self._dashboard_head.gcs_subscriber.poll_logs()
+                try:
+                    process_log_batch(log_batch)
+                except Exception:
+                    logger.exception("Error receiving log from GCS.")
+        else:
+            aioredis_client = self._dashboard_head.aioredis_client
+            receiver = Receiver()

-        channel = receiver.channel(gcs_utils.LOG_FILE_CHANNEL)
-        await aioredis_client.subscribe(channel)
-        logger.info("Subscribed to %s", channel)
+            channel = receiver.channel(gcs_utils.LOG_FILE_CHANNEL)
+            await aioredis_client.subscribe(channel)
+            logger.info("Subscribed to %s", channel)

-        async for sender, msg in receiver.iter():
-            try:
-                data = json.loads(ray._private.utils.decode(msg))
-                data["pid"] = str(data["pid"])
-                process_log_batch(data)
-            except Exception:
-                logger.exception("Error receiving log from Redis.")
+            async for sender, msg in receiver.iter():
+                try:
+                    data = json.loads(ray._private.utils.decode(msg))
+                    data["pid"] = str(data["pid"])
+                    process_log_batch(data)
+                except Exception:
+                    logger.exception("Error receiving log from Redis.")

     async def _update_error_info(self):
         def process_error(error_data):
@@ -12,6 +12,7 @@ import time
 import traceback

 import ray.ray_constants as ray_constants
+import ray._private.gcs_pubsub as gcs_pubsub
 import ray._private.gcs_utils as gcs_utils
 import ray._private.services as services
 import ray._private.utils
@@ -98,6 +99,10 @@ class LogMonitor:
         self.logs_dir = logs_dir
         self.redis_client = ray._private.services.create_redis_client(
             redis_address, password=redis_password)
+        self.publisher = None
+        if gcs_pubsub.gcs_pubsub_enabled():
+            gcs_addr = gcs_utils.get_gcs_address_from_redis(self.redis_client)
+            self.publisher = gcs_pubsub.GcsPublisher(address=gcs_addr)
         self.log_filenames = set()
         self.open_file_infos = []
         self.closed_file_infos = []
@@ -270,8 +275,11 @@ class LogMonitor:
                     "actor_name": file_info.actor_name,
                     "task_name": file_info.task_name,
                 }
-                self.redis_client.publish(gcs_utils.LOG_FILE_CHANNEL,
-                                          json.dumps(data))
+                if self.publisher:
+                    self.publisher.publish_logs(data)
+                else:
+                    self.redis_client.publish(gcs_utils.LOG_FILE_CHANNEL,
+                                              json.dumps(data))
                 anything_published = True
                 lines_to_publish = []

@@ -26,7 +26,8 @@ import ray._private.memory_monitor as memory_monitor
 from ray.core.generated import gcs_pb2
 from ray.core.generated import node_manager_pb2
 from ray.core.generated import node_manager_pb2_grpc
-from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsErrorSubscriber
+from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsErrorSubscriber, \
+    GcsLogSubscriber
 from ray._private.tls_utils import generate_self_signed_tls_certs
 from ray.util.queue import Queue, _QueueActor, Empty
 from ray.scripts.scripts import main as ray_main
@@ -554,11 +555,15 @@ def get_error_message(subscriber, num=1e6, error_type=None, timeout=20):

 def init_log_pubsub():
     """Initialize redis error info pub/sub"""
-    p = ray.worker.global_worker.redis_client.pubsub(
-        ignore_subscribe_messages=True)
-    log_pubsub_channel = gcs_utils.LOG_FILE_CHANNEL
-    p.psubscribe(log_pubsub_channel)
-    return p
+    if gcs_pubsub_enabled():
+        s = GcsLogSubscriber(
+            channel=ray.worker.global_worker.gcs_channel.channel())
+        s.subscribe()
+    else:
+        s = ray.worker.global_worker.redis_client.pubsub(
+            ignore_subscribe_messages=True)
+        s.psubscribe(gcs_utils.LOG_FILE_CHANNEL)
+    return s


 def get_log_message(subscriber,
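
For tests, the updated helpers hide which transport is in use: init_log_pubsub() returns either a GcsLogSubscriber or a Redis pubsub object, and the get_log_* helpers below branch on the type. A hedged usage sketch; any keyword argument names beyond what the hunks show are assumptions about the elided parts of the signatures.

import ray
from ray._private.test_utils import init_log_pubsub, get_log_message

ray.init()
s = init_log_pubsub()  # GcsLogSubscriber if the flag is on, else Redis pubsub
# Drain up to two published log payloads, bounded by the helper's timeout.
messages = get_log_message(s, num=2)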
@@ -576,11 +581,17 @@ def get_log_message(subscriber,
     deadline = time.time() + timeout
     msgs = []
     while time.time() < deadline and len(msgs) < num:
-        msg = subscriber.get_message()
-        if msg is None:
-            time.sleep(0.01)
-            continue
-        logs_data = json.loads(ray._private.utils.decode(msg["data"]))
+        if isinstance(subscriber, GcsLogSubscriber):
+            logs_data = subscriber.poll(timeout=deadline - time.time())
+            if not logs_data:
+                # Timed out before any data is received.
+                break
+        else:
+            msg = subscriber.get_message()
+            if msg is None:
+                time.sleep(0.01)
+                continue
+            logs_data = json.loads(ray._private.utils.decode(msg["data"]))

         if job_id and job_id != logs_data["job"]:
             continue
@@ -607,11 +618,17 @@ def get_log_batch(subscriber,
     deadline = time.time() + timeout
     batches = []
     while time.time() < deadline and len(batches) < num:
-        msg = subscriber.get_message()
-        if msg is None:
-            time.sleep(0.01)
-            continue
-        logs_data = json.loads(ray._private.utils.decode(msg["data"]))
+        if isinstance(subscriber, GcsLogSubscriber):
+            logs_data = subscriber.poll(timeout=deadline - time.time())
+            if not logs_data:
+                # Timed out before any data is received.
+                break
+        else:
+            msg = subscriber.get_message()
+            if msg is None:
+                time.sleep(0.01)
+                continue
+            logs_data = json.loads(ray._private.utils.decode(msg["data"]))

         if job_id and job_id != logs_data["job"]:
             continue
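
Both helpers share one detail worth noting: each GCS poll is bounded by the remaining time, deadline - time.time(), so the overall call still honors its timeout across repeated polls. A generic sketch of that drain loop, independent of Ray; all names here are illustrative.

import time


def drain(poll, num, timeout):
    # Collect up to `num` items without blocking past `timeout` seconds.
    # `poll(timeout=...)` is any blocking fetch that returns None on timeout.
    deadline = time.time() + timeout
    items = []
    while time.time() < deadline and len(items) < num:
        item = poll(timeout=deadline - time.time())
        if item is None:
            break  # timed out before the next item arrived
        items.append(item)
    return items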
@@ -200,13 +200,13 @@ def test_subscribe_two_channels(ray_start_regular):
         "task_name": "test task",
     })

-    t1.join(timeout=1)
-    assert not t1.is_alive()
-    assert len(errors) == num_messages
+    t1.join(timeout=10)
+    assert not t1.is_alive(), len(errors)
+    assert len(errors) == num_messages, len(errors)

-    t2.join(timeout=1)
-    assert not t2.is_alive()
-    assert len(logs) == num_messages
+    t2.join(timeout=10)
+    assert not t2.is_alive(), len(logs)
+    assert len(logs) == num_messages, len(logs)

     for i in range(0, num_messages):
         assert errors[i].error_message == f"error {i}"
@@ -28,7 +28,7 @@ import ray.serialization as serialization
 import ray._private.gcs_utils as gcs_utils
 import ray._private.services as services
 from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsPublisher, \
-    GcsErrorSubscriber
+    GcsErrorSubscriber, GcsLogSubscriber
 from ray._private.runtime_env.py_modules import upload_py_modules_if_needed
 from ray._private.runtime_env.working_dir import upload_working_dir_if_needed
 from ray._private.runtime_env.constants import RAY_JOB_CONFIG_JSON_ENV_VAR
@@ -434,8 +434,13 @@ class Worker:
     def print_logs(self):
         """Prints log messages from workers on all nodes in the same job.
         """
-        subscriber = self.redis_client.pubsub(ignore_subscribe_messages=True)
-        subscriber.subscribe(gcs_utils.LOG_FILE_CHANNEL)
+        if self.gcs_pubsub_enabled:
+            subscriber = self.gcs_log_subscriber
+            subscriber.subscribe()
+        else:
+            subscriber = self.redis_client.pubsub(
+                ignore_subscribe_messages=True)
+            subscriber.subscribe(gcs_utils.LOG_FILE_CHANNEL)
         localhost = services.get_node_ip_address()
         try:
             # Keep track of the number of consecutive log messages that have
@@ -449,7 +454,10 @@ class Worker:
                 if self.threads_stopped.is_set():
                     return

-                msg = subscriber.get_message()
+                if self.gcs_pubsub_enabled:
+                    msg = subscriber.poll()
+                else:
+                    msg = subscriber.get_message()
                 if msg is None:
                     num_consecutive_messages_received = 0
                     self.threads_stopped.wait(timeout=0.01)
@@ -463,7 +471,10 @@ class Worker:
                         "logs to the driver, use "
                         "'ray.init(log_to_driver=False)'.")

-                data = json.loads(ray._private.utils.decode(msg["data"]))
+                if self.gcs_pubsub_enabled:
+                    data = msg
+                else:
+                    data = json.loads(ray._private.utils.decode(msg["data"]))

                 # Don't show logs from other drivers.
                 if (self.filter_logs_by_job and data["job"]
@@ -1373,6 +1384,8 @@ def connect(node,
             channel=worker.gcs_channel.channel())
         worker.gcs_error_subscriber = GcsErrorSubscriber(
             channel=worker.gcs_channel.channel())
+        worker.gcs_log_subscriber = GcsLogSubscriber(
+            channel=worker.gcs_channel.channel())

     # Initialize some fields.
     if mode in (WORKER_MODE, RESTORE_WORKER_MODE, SPILL_WORKER_MODE):
@@ -1575,8 +1588,9 @@ def disconnect(exiting_interpreter=False):
     # should be handled cleanly in the worker object's destructor and not
     # in this disconnect method.
     worker.threads_stopped.set()
-    if hasattr(worker, "gcs_error_subscriber"):
+    if worker.gcs_pubsub_enabled:
         worker.gcs_error_subscriber.close()
+        worker.gcs_log_subscriber.close()
     if hasattr(worker, "import_thread"):
         worker.import_thread.join_import_thread()
     if hasattr(worker, "listener_thread"):
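
Taken together, the worker hunks above give the GCS log subscriber a full lifecycle: connect() constructs it on the worker's GCS gRPC channel, the print_logs thread subscribes and polls it, and disconnect() closes it. A condensed sketch of that flow under the flag; the function below and its handle_batch callback are illustrative, not Ray API, and the Redis fallback is elided.

import threading

from ray._private.gcs_pubsub import GcsLogSubscriber


def log_printer_loop(gcs_grpc_channel, threads_stopped: threading.Event,
                     handle_batch):
    subscriber = GcsLogSubscriber(channel=gcs_grpc_channel)  # in connect()
    subscriber.subscribe()  # at the top of print_logs()
    try:
        while not threads_stopped.is_set():
            batch = subscriber.poll()  # already a dict; no JSON decode step
            if batch is None:  # assumed empty-poll behavior
                continue
            handle_batch(batch)  # real code prints via the driver's stdout
    finally:
        subscriber.close()  # in disconnect()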