mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[Core][Pubsub][Logging 1/n] add logging support to GCS pubsub in Python (#20604)
This PR adds support for publishing and subscribing to logs in Python via GCS pubsub. It also refactors the Python threaded subscriber to support subscribing and calling `close()` from multiple threads. We can also move tests and logging support to another PR, but it will make the purpose of the refactoring seems less obvious.
This commit is contained in:
parent
aabe9229df
commit
a4d3898159
10 changed files with 502 additions and 147 deletions
|
@ -64,6 +64,7 @@ MOCK_MODULES = [
|
|||
"ray.core.generated.common_pb2",
|
||||
"ray.core.generated.runtime_env_common_pb2",
|
||||
"ray.core.generated.gcs_pb2",
|
||||
"ray.core.generated.logging_pb2",
|
||||
"ray.core.generated.ray.protocol.Task",
|
||||
"ray.serve.generated",
|
||||
"ray.serve.generated.serve_pb2",
|
||||
|
|
|
@ -12,10 +12,10 @@ except ImportError:
|
|||
from grpc.experimental import aio as aiogrpc
|
||||
|
||||
import ray._private.gcs_utils as gcs_utils
|
||||
import ray._private.logging_utils as logging_utils
|
||||
from ray.core.generated.gcs_pb2 import ErrorTableData
|
||||
from ray.core.generated import gcs_service_pb2_grpc
|
||||
from ray.core.generated import gcs_service_pb2
|
||||
from ray.core.generated.gcs_pb2 import (
|
||||
ErrorTableData, )
|
||||
from ray.core.generated import pubsub_pb2
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -23,7 +23,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
def gcs_pubsub_enabled():
|
||||
"""Checks whether GCS pubsub feature flag is enabled."""
|
||||
return os.environ.get("RAY_gcs_grpc_based_pubsub") not in\
|
||||
return os.environ.get("RAY_gcs_grpc_based_pubsub") not in \
|
||||
[None, "0", "false"]
|
||||
|
||||
|
||||
|
@ -48,11 +48,28 @@ def construct_error_message(job_id, error_type, message, timestamp):
|
|||
return data
|
||||
|
||||
|
||||
class _PublisherBase:
|
||||
@staticmethod
|
||||
def _create_log_request(log_json: dict):
|
||||
job_id = log_json.get("job")
|
||||
return gcs_service_pb2.GcsPublishRequest(pub_messages=[
|
||||
pubsub_pb2.PubMessage(
|
||||
channel_type=pubsub_pb2.RAY_LOG_CHANNEL,
|
||||
key_id=job_id.encode() if job_id else None,
|
||||
log_batch_message=logging_utils.log_batch_dict_to_proto(
|
||||
log_json))
|
||||
])
|
||||
|
||||
|
||||
class _SubscriberBase:
|
||||
def _subscribe_error_request(self):
|
||||
cmd = pubsub_pb2.Command(
|
||||
channel_type=pubsub_pb2.RAY_ERROR_INFO_CHANNEL,
|
||||
subscribe_message={})
|
||||
def __init__(self):
|
||||
# self._subscriber_id needs to match the binary format of a random
|
||||
# SubscriberID / UniqueID, which is 28 (kUniqueIDSize) random bytes.
|
||||
self._subscriber_id = bytes(
|
||||
bytearray(random.getrandbits(8) for _ in range(28)))
|
||||
|
||||
def _subscribe_request(self, channel):
|
||||
cmd = pubsub_pb2.Command(channel_type=channel, subscribe_message={})
|
||||
req = gcs_service_pb2.GcsSubscriberCommandBatchRequest(
|
||||
subscriber_id=self._subscriber_id, commands=[cmd])
|
||||
return req
|
||||
|
@ -61,18 +78,31 @@ class _SubscriberBase:
|
|||
return gcs_service_pb2.GcsSubscriberPollRequest(
|
||||
subscriber_id=self._subscriber_id)
|
||||
|
||||
def _unsubscribe_request(self):
|
||||
def _unsubscribe_request(self, channels):
|
||||
req = gcs_service_pb2.GcsSubscriberCommandBatchRequest(
|
||||
subscriber_id=self._subscriber_id, commands=[])
|
||||
if self._subscribed_error:
|
||||
cmd = pubsub_pb2.Command(
|
||||
channel_type=pubsub_pb2.RAY_ERROR_INFO_CHANNEL,
|
||||
unsubscribe_message={})
|
||||
req.commands.append(cmd)
|
||||
for channel in channels:
|
||||
req.commands.append(
|
||||
pubsub_pb2.Command(
|
||||
channel_type=channel, unsubscribe_message={}))
|
||||
return req
|
||||
|
||||
@staticmethod
|
||||
def _pop_error_info(queue):
|
||||
if len(queue) == 0:
|
||||
return None, None
|
||||
msg = queue.popleft()
|
||||
return msg.key_id, msg.error_info_message
|
||||
|
||||
class GcsPublisher:
|
||||
@staticmethod
|
||||
def _pop_log_batch(queue):
|
||||
if len(queue) == 0:
|
||||
return None
|
||||
msg = queue.popleft()
|
||||
return logging_utils.log_batch_proto_to_dict(msg.log_batch_message)
|
||||
|
||||
|
||||
class GcsPublisher(_PublisherBase):
|
||||
"""Publisher to GCS."""
|
||||
|
||||
def __init__(self, *, address: str = None, channel: grpc.Channel = None):
|
||||
|
@ -94,16 +124,119 @@ class GcsPublisher:
|
|||
req = gcs_service_pb2.GcsPublishRequest(pub_messages=[msg])
|
||||
self._stub.GcsPublish(req)
|
||||
|
||||
def publish_logs(self, log_batch: dict) -> None:
|
||||
"""Publishes logs to GCS."""
|
||||
req = self._create_log_request(log_batch)
|
||||
self._stub.GcsPublish(req)
|
||||
|
||||
class GcsSubscriber(_SubscriberBase):
|
||||
"""Subscriber to GCS. Thread safe.
|
||||
|
||||
class _SyncSubscriber(_SubscriberBase):
|
||||
def __init__(
|
||||
self,
|
||||
pubsub_channel_type,
|
||||
address: str = None,
|
||||
channel: grpc.Channel = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if address:
|
||||
assert channel is None, \
|
||||
"address and channel cannot both be specified"
|
||||
channel = gcs_utils.create_gcs_channel(address)
|
||||
else:
|
||||
assert channel is not None, \
|
||||
"One of address and channel must be specified"
|
||||
# GRPC stub to GCS pubsub.
|
||||
self._stub = gcs_service_pb2_grpc.InternalPubSubGcsServiceStub(channel)
|
||||
|
||||
# Type of the channel.
|
||||
self._channel = pubsub_channel_type
|
||||
# Protects multi-threaded read and write of self._queue.
|
||||
self._lock = threading.Lock()
|
||||
# A queue of received PubMessage.
|
||||
self._queue = deque()
|
||||
# Indicates whether the subscriber has closed.
|
||||
self._close = threading.Event()
|
||||
|
||||
def subscribe(self) -> None:
|
||||
"""Registers a subscription for the subscriber's channel type.
|
||||
|
||||
Before the registration, published messages in the channel will not be
|
||||
saved for the subscriber.
|
||||
"""
|
||||
with self._lock:
|
||||
if self._close.is_set():
|
||||
return
|
||||
req = self._subscribe_request(self._channel)
|
||||
self._stub.GcsSubscriberCommandBatch(req, timeout=30)
|
||||
|
||||
def _poll_locked(self, timeout=None) -> None:
|
||||
assert self._lock.locked()
|
||||
|
||||
# Poll until data becomes available.
|
||||
while len(self._queue) == 0:
|
||||
if self._close.is_set():
|
||||
return
|
||||
|
||||
fut = self._stub.GcsSubscriberPoll.future(
|
||||
self._poll_request(), timeout=timeout)
|
||||
# Wait for result to become available, or cancel if the
|
||||
# subscriber has closed.
|
||||
while True:
|
||||
try:
|
||||
# Use 1s timeout to check for subscriber closing
|
||||
# periodically.
|
||||
fut.result(timeout=1)
|
||||
break
|
||||
except grpc.FutureTimeoutError:
|
||||
# Subscriber has closed. Cancel inflight request and
|
||||
# return from polling.
|
||||
if self._close.is_set():
|
||||
fut.cancel()
|
||||
return
|
||||
# GRPC has not replied, continue waiting.
|
||||
continue
|
||||
except grpc.RpcError as e:
|
||||
# Choose to not raise deadline exceeded errors to the
|
||||
# caller. Instead return None. This can be revisited later.
|
||||
if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
|
||||
return
|
||||
raise
|
||||
|
||||
if fut.done():
|
||||
for msg in fut.result().pub_messages:
|
||||
if msg.channel_type != self._channel:
|
||||
logger.warn(
|
||||
f"Ignoring message from unsubscribed channel {msg}"
|
||||
)
|
||||
continue
|
||||
self._queue.append(msg)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Closes the subscriber and its active subscription."""
|
||||
|
||||
# Mark close to terminate inflight polling and prevent future requests.
|
||||
self._close.set()
|
||||
req = self._unsubscribe_request(channels=[self._channel])
|
||||
try:
|
||||
self._stub.GcsSubscriberCommandBatch(req, timeout=5)
|
||||
except Exception:
|
||||
pass
|
||||
self._stub = None
|
||||
|
||||
|
||||
class GcsErrorSubscriber(_SyncSubscriber):
|
||||
"""Subscriber to error info. Thread safe.
|
||||
|
||||
Usage example:
|
||||
subscriber = GcsSubscriber()
|
||||
subscriber.subscribe_error()
|
||||
subscriber = GcsErrorSubscriber()
|
||||
# Subscribe to the error channel.
|
||||
subscriber.subscribe()
|
||||
...
|
||||
while running:
|
||||
error_id, error_data = subscriber.poll_error()
|
||||
error_id, error_data = subscriber.poll()
|
||||
......
|
||||
# Unsubscribe from the error channels.
|
||||
subscriber.close()
|
||||
"""
|
||||
|
||||
|
@ -112,90 +245,55 @@ class GcsSubscriber(_SubscriberBase):
|
|||
address: str = None,
|
||||
channel: grpc.Channel = None,
|
||||
):
|
||||
if address:
|
||||
assert channel is None, \
|
||||
"address and channel cannot both be specified"
|
||||
channel = gcs_utils.create_gcs_channel(address)
|
||||
else:
|
||||
assert channel is not None, \
|
||||
"One of address and channel must be specified"
|
||||
self._lock = threading.RLock()
|
||||
self._stub = gcs_service_pb2_grpc.InternalPubSubGcsServiceStub(channel)
|
||||
self._subscriber_id = bytes(
|
||||
bytearray(random.getrandbits(8) for _ in range(28)))
|
||||
# Whether error info has been subscribed.
|
||||
self._subscribed_error = False
|
||||
# Buffer for holding error info.
|
||||
self._errors = deque()
|
||||
# Future for indicating whether the subscriber has closed.
|
||||
self._close = threading.Event()
|
||||
super().__init__(pubsub_pb2.RAY_ERROR_INFO_CHANNEL, address, channel)
|
||||
|
||||
def subscribe_error(self) -> None:
|
||||
"""Registers a subscription for error info.
|
||||
def poll(self, timeout=None) -> Tuple[bytes, ErrorTableData]:
|
||||
"""Polls for new error messages.
|
||||
|
||||
Before the registration, published errors will not be saved for the
|
||||
subscriber.
|
||||
Returns:
|
||||
A tuple of error message ID and ErrorTableData proto message,
|
||||
or None, None if polling times out or subscriber closed.
|
||||
"""
|
||||
with self._lock:
|
||||
if self._close.is_set():
|
||||
return
|
||||
if not self._subscribed_error:
|
||||
req = self._subscribe_error_request()
|
||||
self._stub.GcsSubscriberCommandBatch(req, timeout=30)
|
||||
self._subscribed_error = True
|
||||
self._poll_locked(timeout=timeout)
|
||||
return self._pop_error_info(self._queue)
|
||||
|
||||
def poll_error(self, timeout=None) -> Tuple[bytes, ErrorTableData]:
|
||||
"""Polls for new error messages."""
|
||||
|
||||
class GcsLogSubscriber(_SyncSubscriber):
|
||||
"""Subscriber to logs. Thread safe.
|
||||
|
||||
Usage example:
|
||||
subscriber = GcsLogSubscriber()
|
||||
# Subscribe to the log channel.
|
||||
subscriber.subscribe()
|
||||
...
|
||||
while running:
|
||||
log = subscriber.poll()
|
||||
......
|
||||
# Unsubscribe from the log channel.
|
||||
subscriber.close()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
address: str = None,
|
||||
channel: grpc.Channel = None,
|
||||
):
|
||||
super().__init__(pubsub_pb2.RAY_LOG_CHANNEL, address, channel)
|
||||
|
||||
def poll(self, timeout=None) -> Tuple[bytes, ErrorTableData]:
|
||||
"""Polls for new log messages.
|
||||
|
||||
Returns:
|
||||
A dict containing a batch of log lines and their metadata,
|
||||
or None if polling times out or subscriber closed.
|
||||
"""
|
||||
with self._lock:
|
||||
if self._close.is_set():
|
||||
return
|
||||
|
||||
if len(self._errors) == 0:
|
||||
req = self._poll_request()
|
||||
fut = self._stub.GcsSubscriberPoll.future(req, timeout=timeout)
|
||||
# Wait for result to become available, or cancel if the
|
||||
# subscriber has closed.
|
||||
while True:
|
||||
try:
|
||||
fut.result(timeout=1)
|
||||
break
|
||||
except grpc.FutureTimeoutError:
|
||||
# Subscriber has closed. Cancel inflight the request
|
||||
# and return from polling.
|
||||
if self._close.is_set():
|
||||
fut.cancel()
|
||||
return None, None
|
||||
# GRPC has not replied, continue waiting.
|
||||
continue
|
||||
except Exception:
|
||||
# GRPC error, including deadline exceeded.
|
||||
raise
|
||||
if fut.done():
|
||||
for msg in fut.result().pub_messages:
|
||||
self._errors.append((msg.key_id,
|
||||
msg.error_info_message))
|
||||
|
||||
if len(self._errors) == 0:
|
||||
return None, None
|
||||
return self._errors.popleft()
|
||||
|
||||
def close(self) -> None:
|
||||
"""Closes the subscriber and its active subscriptions."""
|
||||
# Mark close to terminate inflight polling and prevent future requests.
|
||||
self._close.set()
|
||||
with self._lock:
|
||||
if not self._stub:
|
||||
# Subscriber already closed.
|
||||
return
|
||||
req = self._unsubscribe_request()
|
||||
try:
|
||||
self._stub.GcsSubscriberCommandBatch(req, timeout=30)
|
||||
except Exception:
|
||||
pass
|
||||
self._stub = None
|
||||
self._poll_locked(timeout=timeout)
|
||||
return self._pop_log_batch(self._queue)
|
||||
|
||||
|
||||
class GcsAioPublisher:
|
||||
class GcsAioPublisher(_PublisherBase):
|
||||
"""Publisher to GCS. Uses async io."""
|
||||
|
||||
def __init__(self, address: str = None, channel: aiogrpc.Channel = None):
|
||||
|
@ -218,6 +316,11 @@ class GcsAioPublisher:
|
|||
req = gcs_service_pb2.GcsPublishRequest(pub_messages=[msg])
|
||||
await self._stub.GcsPublish(req)
|
||||
|
||||
async def publish_logs(self, log_batch: dict) -> None:
|
||||
"""Publishes logs to GCS."""
|
||||
req = self._create_log_request(log_batch)
|
||||
await self._stub.GcsPublish(req)
|
||||
|
||||
|
||||
class GcsAioSubscriber(_SubscriberBase):
|
||||
"""Async io subscriber to GCS.
|
||||
|
@ -232,6 +335,8 @@ class GcsAioSubscriber(_SubscriberBase):
|
|||
"""
|
||||
|
||||
def __init__(self, address: str = None, channel: aiogrpc.Channel = None):
|
||||
super().__init__()
|
||||
|
||||
if address:
|
||||
assert channel is None, \
|
||||
"address and channel cannot both be specified"
|
||||
|
@ -239,13 +344,9 @@ class GcsAioSubscriber(_SubscriberBase):
|
|||
else:
|
||||
assert channel is not None, \
|
||||
"One of address and channel must be specified"
|
||||
# Message queue for each channel.
|
||||
self._messages = {}
|
||||
self._stub = gcs_service_pb2_grpc.InternalPubSubGcsServiceStub(channel)
|
||||
self._subscriber_id = bytes(
|
||||
bytearray(random.getrandbits(8) for _ in range(28)))
|
||||
# Whether error info has been subscribed.
|
||||
self._subscribed_error = False
|
||||
# Buffer for holding error info.
|
||||
self._errors = deque()
|
||||
|
||||
async def subscribe_error(self) -> None:
|
||||
"""Registers a subscription for error info.
|
||||
|
@ -253,28 +354,55 @@ class GcsAioSubscriber(_SubscriberBase):
|
|||
Before the registration, published errors will not be saved for the
|
||||
subscriber.
|
||||
"""
|
||||
if not self._subscribed_error:
|
||||
req = self._subscribe_error_request()
|
||||
if pubsub_pb2.RAY_ERROR_INFO_CHANNEL not in self._messages:
|
||||
self._messages[pubsub_pb2.RAY_ERROR_INFO_CHANNEL] = deque()
|
||||
req = self._subscribe_request(pubsub_pb2.RAY_ERROR_INFO_CHANNEL)
|
||||
await self._stub.GcsSubscriberCommandBatch(req, timeout=30)
|
||||
self._subscribed_error = True
|
||||
|
||||
async def subscribe_logs(self) -> None:
|
||||
"""Registers a subscription for logs.
|
||||
|
||||
Before the registration, published logs will not be saved for the
|
||||
subscriber.
|
||||
"""
|
||||
if pubsub_pb2.RAY_LOG_CHANNEL not in self._messages:
|
||||
self._messages[pubsub_pb2.RAY_LOG_CHANNEL] = deque()
|
||||
req = self._subscribe_request(pubsub_pb2.RAY_LOG_CHANNEL)
|
||||
await self._stub.GcsSubscriberCommandBatch(req, timeout=30)
|
||||
|
||||
def _enqueue_poll_response(self, resp):
|
||||
for msg in resp.pub_messages:
|
||||
queue = self._messages.get(msg.channel_type)
|
||||
if queue is not None:
|
||||
queue.append(msg)
|
||||
else:
|
||||
logger.warn(
|
||||
f"Ignoring message from unsubscribed channel {msg}")
|
||||
|
||||
async def poll_error(self, timeout=None) -> Tuple[bytes, ErrorTableData]:
|
||||
"""Polls for new error messages."""
|
||||
if len(self._errors) == 0:
|
||||
queue = self._messages.get(pubsub_pb2.RAY_ERROR_INFO_CHANNEL)
|
||||
while len(queue) == 0:
|
||||
req = self._poll_request()
|
||||
reply = await self._stub.GcsSubscriberPoll(req, timeout=timeout)
|
||||
for msg in reply.pub_messages:
|
||||
self._errors.append((msg.key_id, msg.error_info_message))
|
||||
self._enqueue_poll_response(reply)
|
||||
|
||||
if len(self._errors) == 0:
|
||||
return None, None
|
||||
return self._errors.popleft()
|
||||
return self._pop_error_info(queue)
|
||||
|
||||
async def poll_logs(self, timeout=None) -> dict:
|
||||
"""Polls for new error messages."""
|
||||
queue = self._messages.get(pubsub_pb2.RAY_LOG_CHANNEL)
|
||||
while len(queue) == 0:
|
||||
req = self._poll_request()
|
||||
reply = await self._stub.GcsSubscriberPoll(req, timeout=timeout)
|
||||
self._enqueue_poll_response(reply)
|
||||
|
||||
return self._pop_log_batch(queue)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Closes the subscriber and its active subscriptions."""
|
||||
req = self._unsubscribe_request()
|
||||
req = self._unsubscribe_request(self._messages.keys())
|
||||
try:
|
||||
await self._stub.GcsSubscriberCommandBatch(req, timeout=30)
|
||||
await self._stub.GcsSubscriberCommandBatch(req, timeout=5)
|
||||
except Exception:
|
||||
pass
|
||||
self._subscribed_error = False
|
||||
|
|
29
python/ray/_private/logging_utils.py
Normal file
29
python/ray/_private/logging_utils.py
Normal file
|
@ -0,0 +1,29 @@
|
|||
from ray.core.generated.logging_pb2 import LogBatch
|
||||
|
||||
|
||||
def log_batch_dict_to_proto(log_json: dict) -> LogBatch:
|
||||
"""Converts a dict containing a batch of logs to a LogBatch proto."""
|
||||
return LogBatch(
|
||||
ip=log_json.get("ip"),
|
||||
# Cast to support string pid like "gcs".
|
||||
pid=str(log_json.get("pid")) if log_json.get("pid") else None,
|
||||
# Job ID as a hex string.
|
||||
job_id=log_json.get("job"),
|
||||
is_error=bool(log_json.get("is_err")),
|
||||
lines=log_json.get("lines"),
|
||||
actor_name=log_json.get("actor_name"),
|
||||
task_name=log_json.get("task_name"),
|
||||
)
|
||||
|
||||
|
||||
def log_batch_proto_to_dict(log_batch: LogBatch) -> dict:
|
||||
"""Converts a LogBatch proto to a dict containing a batch of logs."""
|
||||
return {
|
||||
"ip": log_batch.ip,
|
||||
"pid": log_batch.pid,
|
||||
"job": log_batch.job_id,
|
||||
"is_err": log_batch.is_error,
|
||||
"lines": log_batch.lines,
|
||||
"actor_name": log_batch.actor_name,
|
||||
"task_name": log_batch.task_name,
|
||||
}
|
|
@ -23,10 +23,10 @@ import ray._private.services
|
|||
import ray._private.utils
|
||||
import ray._private.gcs_utils as gcs_utils
|
||||
import ray._private.memory_monitor as memory_monitor
|
||||
from ray.core.generated import gcs_pb2
|
||||
from ray.core.generated import node_manager_pb2
|
||||
from ray.core.generated import node_manager_pb2_grpc
|
||||
from ray.core.generated import gcs_pb2
|
||||
from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsSubscriber
|
||||
from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsErrorSubscriber
|
||||
from ray._private.tls_utils import generate_self_signed_tls_certs
|
||||
from ray.util.queue import Queue, _QueueActor, Empty
|
||||
from ray.scripts.scripts import main as ray_main
|
||||
|
@ -513,9 +513,9 @@ def get_non_head_nodes(cluster):
|
|||
def init_error_pubsub():
|
||||
"""Initialize redis error info pub/sub"""
|
||||
if gcs_pubsub_enabled():
|
||||
s = GcsSubscriber(
|
||||
s = GcsErrorSubscriber(
|
||||
channel=ray.worker.global_worker.gcs_channel.channel())
|
||||
s.subscribe_error()
|
||||
s.subscribe()
|
||||
else:
|
||||
s = ray.worker.global_worker.redis_client.pubsub(
|
||||
ignore_subscribe_messages=True)
|
||||
|
@ -532,17 +532,11 @@ def get_error_message(subscriber, num=1e6, error_type=None, timeout=20):
|
|||
deadline = time.time() + timeout
|
||||
msgs = []
|
||||
while time.time() < deadline and len(msgs) < num:
|
||||
if isinstance(subscriber, GcsSubscriber):
|
||||
try:
|
||||
_, error_data = subscriber.poll_error(timeout=deadline -
|
||||
time.time())
|
||||
except grpc.RpcError as e:
|
||||
# Failed to match error message before timeout.
|
||||
if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
|
||||
logging.warning("get_error_message() timed out")
|
||||
return []
|
||||
# Otherwise, the error is unexpected.
|
||||
raise
|
||||
if isinstance(subscriber, GcsErrorSubscriber):
|
||||
_, error_data = subscriber.poll(timeout=deadline - time.time())
|
||||
if not error_data:
|
||||
# Timed out before any data is received.
|
||||
break
|
||||
else:
|
||||
msg = subscriber.get_message()
|
||||
if msg is None:
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
import sys
|
||||
import threading
|
||||
|
||||
import ray
|
||||
import ray._private.gcs_utils as gcs_utils
|
||||
from ray._private.gcs_pubsub import GcsPublisher, GcsSubscriber, \
|
||||
GcsAioPublisher, GcsAioSubscriber
|
||||
from ray._private.gcs_pubsub import GcsPublisher, GcsErrorSubscriber, \
|
||||
GcsLogSubscriber, GcsAioPublisher, GcsAioSubscriber
|
||||
from ray.core.generated.gcs_pb2 import ErrorTableData
|
||||
import pytest
|
||||
|
||||
|
@ -23,8 +24,8 @@ def test_publish_and_subscribe_error_info(ray_start_regular):
|
|||
|
||||
gcs_server_addr = gcs_utils.get_gcs_address_from_redis(redis)
|
||||
|
||||
subscriber = GcsSubscriber(address=gcs_server_addr)
|
||||
subscriber.subscribe_error()
|
||||
subscriber = GcsErrorSubscriber(address=gcs_server_addr)
|
||||
subscriber.subscribe()
|
||||
|
||||
publisher = GcsPublisher(address=gcs_server_addr)
|
||||
err1 = ErrorTableData(error_message="test error message 1")
|
||||
|
@ -32,8 +33,8 @@ def test_publish_and_subscribe_error_info(ray_start_regular):
|
|||
publisher.publish_error(b"aaa_id", err1)
|
||||
publisher.publish_error(b"bbb_id", err2)
|
||||
|
||||
assert subscriber.poll_error() == (b"aaa_id", err1)
|
||||
assert subscriber.poll_error() == (b"bbb_id", err2)
|
||||
assert subscriber.poll() == (b"aaa_id", err1)
|
||||
assert subscriber.poll() == (b"bbb_id", err2)
|
||||
|
||||
subscriber.close()
|
||||
|
||||
|
@ -69,5 +70,148 @@ async def test_aio_publish_and_subscribe_error_info(ray_start_regular):
|
|||
await subscriber.close()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ray_start_regular", [{
|
||||
"_system_config": {
|
||||
"gcs_grpc_based_pubsub": True
|
||||
}
|
||||
}],
|
||||
indirect=True)
|
||||
def test_publish_and_subscribe_logs(ray_start_regular):
|
||||
address_info = ray_start_regular
|
||||
redis = ray._private.services.create_redis_client(
|
||||
address_info["redis_address"],
|
||||
password=ray.ray_constants.REDIS_DEFAULT_PASSWORD)
|
||||
|
||||
gcs_server_addr = gcs_utils.get_gcs_address_from_redis(redis)
|
||||
|
||||
subscriber = GcsLogSubscriber(address=gcs_server_addr)
|
||||
subscriber.subscribe()
|
||||
|
||||
publisher = GcsPublisher(address=gcs_server_addr)
|
||||
log_batch = {
|
||||
"ip": "127.0.0.1",
|
||||
"pid": 1234,
|
||||
"job": "0001",
|
||||
"is_err": False,
|
||||
"lines": ["line 1", "line 2"],
|
||||
"actor_name": "test actor",
|
||||
"task_name": "test task",
|
||||
}
|
||||
publisher.publish_logs(log_batch)
|
||||
|
||||
# PID is treated as string.
|
||||
log_batch["pid"] = "1234"
|
||||
assert subscriber.poll() == log_batch
|
||||
|
||||
subscriber.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"ray_start_regular", [{
|
||||
"_system_config": {
|
||||
"gcs_grpc_based_pubsub": True
|
||||
}
|
||||
}],
|
||||
indirect=True)
|
||||
async def test_aio_publish_and_subscribe_logs(ray_start_regular):
|
||||
address_info = ray_start_regular
|
||||
redis = ray._private.services.create_redis_client(
|
||||
address_info["redis_address"],
|
||||
password=ray.ray_constants.REDIS_DEFAULT_PASSWORD)
|
||||
|
||||
gcs_server_addr = gcs_utils.get_gcs_address_from_redis(redis)
|
||||
|
||||
subscriber = GcsAioSubscriber(address=gcs_server_addr)
|
||||
await subscriber.subscribe_logs()
|
||||
|
||||
publisher = GcsAioPublisher(address=gcs_server_addr)
|
||||
log_batch = {
|
||||
"ip": "127.0.0.1",
|
||||
"pid": "gcs",
|
||||
"job": "0001",
|
||||
"is_err": False,
|
||||
"lines": ["line 1", "line 2"],
|
||||
"actor_name": "test actor",
|
||||
"task_name": "test task",
|
||||
}
|
||||
await publisher.publish_logs(log_batch)
|
||||
|
||||
assert await subscriber.poll_logs() == log_batch
|
||||
|
||||
await subscriber.close()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ray_start_regular", [{
|
||||
"_system_config": {
|
||||
"gcs_grpc_based_pubsub": True
|
||||
}
|
||||
}],
|
||||
indirect=True)
|
||||
def test_subscribe_two_channels(ray_start_regular):
|
||||
"""Tests concurrently subscribing to two channels work."""
|
||||
|
||||
address_info = ray_start_regular
|
||||
redis = ray._private.services.create_redis_client(
|
||||
address_info["redis_address"],
|
||||
password=ray.ray_constants.REDIS_DEFAULT_PASSWORD)
|
||||
|
||||
gcs_server_addr = gcs_utils.get_gcs_address_from_redis(redis)
|
||||
|
||||
num_messages = 100
|
||||
|
||||
errors = []
|
||||
|
||||
def receive_errors():
|
||||
subscriber = GcsErrorSubscriber(address=gcs_server_addr)
|
||||
subscriber.subscribe()
|
||||
while len(errors) < num_messages:
|
||||
_, msg = subscriber.poll()
|
||||
errors.append(msg)
|
||||
|
||||
logs = []
|
||||
|
||||
def receive_logs():
|
||||
subscriber = GcsLogSubscriber(address=gcs_server_addr)
|
||||
subscriber.subscribe()
|
||||
while len(logs) < num_messages:
|
||||
log_batch = subscriber.poll()
|
||||
logs.append(log_batch)
|
||||
|
||||
t1 = threading.Thread(target=receive_errors)
|
||||
t1.start()
|
||||
|
||||
t2 = threading.Thread(target=receive_logs)
|
||||
t2.start()
|
||||
|
||||
publisher = GcsPublisher(address=gcs_server_addr)
|
||||
for i in range(0, num_messages):
|
||||
publisher.publish_error(
|
||||
b"msg_id", ErrorTableData(error_message=f"error {i}"))
|
||||
publisher.publish_logs({
|
||||
"ip": "127.0.0.1",
|
||||
"pid": "gcs",
|
||||
"job": "0001",
|
||||
"is_err": False,
|
||||
"lines": [f"line {i}"],
|
||||
"actor_name": "test actor",
|
||||
"task_name": "test task",
|
||||
})
|
||||
|
||||
t1.join(timeout=1)
|
||||
assert not t1.is_alive()
|
||||
assert len(errors) == num_messages
|
||||
|
||||
t2.join(timeout=1)
|
||||
assert not t2.is_alive()
|
||||
assert len(logs) == num_messages
|
||||
|
||||
for i in range(0, num_messages):
|
||||
assert errors[i].error_message == f"error {i}"
|
||||
assert logs[i]["lines"][0] == f"line {i}"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
|
|
|
@ -28,7 +28,7 @@ import ray.serialization as serialization
|
|||
import ray._private.gcs_utils as gcs_utils
|
||||
import ray._private.services as services
|
||||
from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsPublisher, \
|
||||
GcsSubscriber
|
||||
GcsErrorSubscriber
|
||||
from ray._private.runtime_env.py_modules import upload_py_modules_if_needed
|
||||
from ray._private.runtime_env.working_dir import upload_working_dir_if_needed
|
||||
from ray._private.runtime_env.constants import RAY_JOB_CONFIG_JSON_ENV_VAR
|
||||
|
@ -1266,7 +1266,7 @@ def listen_error_messages_from_gcs(worker, threads_stopped):
|
|||
# gcs_subscriber.poll_error() will still be processed in the loop.
|
||||
|
||||
# TODO: we should just subscribe to the errors for this specific job.
|
||||
worker.gcs_subscriber.subscribe_error()
|
||||
worker.gcs_error_subscriber.subscribe()
|
||||
|
||||
try:
|
||||
if _internal_kv_initialized():
|
||||
|
@ -1281,7 +1281,7 @@ def listen_error_messages_from_gcs(worker, threads_stopped):
|
|||
if threads_stopped.is_set():
|
||||
return
|
||||
|
||||
_, error_data = worker.gcs_subscriber.poll_error()
|
||||
_, error_data = worker.gcs_error_subscriber.poll()
|
||||
if error_data is None:
|
||||
continue
|
||||
if error_data.job_id not in [
|
||||
|
@ -1371,7 +1371,7 @@ def connect(node,
|
|||
if worker.gcs_pubsub_enabled:
|
||||
worker.gcs_publisher = GcsPublisher(
|
||||
channel=worker.gcs_channel.channel())
|
||||
worker.gcs_subscriber = GcsSubscriber(
|
||||
worker.gcs_error_subscriber = GcsErrorSubscriber(
|
||||
channel=worker.gcs_channel.channel())
|
||||
|
||||
# Initialize some fields.
|
||||
|
@ -1575,8 +1575,8 @@ def disconnect(exiting_interpreter=False):
|
|||
# should be handled cleanly in the worker object's destructor and not
|
||||
# in this disconnect method.
|
||||
worker.threads_stopped.set()
|
||||
if hasattr(worker, "gcs_subscriber"):
|
||||
worker.gcs_subscriber.close()
|
||||
if hasattr(worker, "gcs_error_subscriber"):
|
||||
worker.gcs_error_subscriber.close()
|
||||
if hasattr(worker, "import_thread"):
|
||||
worker.import_thread.join_import_thread()
|
||||
if hasattr(worker, "listener_thread"):
|
||||
|
|
|
@ -70,13 +70,16 @@ void GcsServer::Start() {
|
|||
// Init grpc based pubsub on GCS.
|
||||
// TODO: Move this into GcsPublisher.
|
||||
inner_publisher = std::make_unique<pubsub::Publisher>(
|
||||
/*channels=*/std::vector<
|
||||
rpc::ChannelType>{rpc::ChannelType::GCS_ACTOR_CHANNEL,
|
||||
rpc::ChannelType::GCS_JOB_CHANNEL,
|
||||
rpc::ChannelType::GCS_NODE_INFO_CHANNEL,
|
||||
rpc::ChannelType::GCS_NODE_RESOURCE_CHANNEL,
|
||||
rpc::ChannelType::GCS_WORKER_DELTA_CHANNEL,
|
||||
rpc::ChannelType::RAY_ERROR_INFO_CHANNEL},
|
||||
/*channels=*/
|
||||
std::vector<rpc::ChannelType>{
|
||||
rpc::ChannelType::GCS_ACTOR_CHANNEL,
|
||||
rpc::ChannelType::GCS_JOB_CHANNEL,
|
||||
rpc::ChannelType::GCS_NODE_INFO_CHANNEL,
|
||||
rpc::ChannelType::GCS_NODE_RESOURCE_CHANNEL,
|
||||
rpc::ChannelType::GCS_WORKER_DELTA_CHANNEL,
|
||||
rpc::ChannelType::RAY_ERROR_INFO_CHANNEL,
|
||||
rpc::ChannelType::RAY_LOG_CHANNEL,
|
||||
},
|
||||
/*periodical_runner=*/&pubsub_periodical_runner_,
|
||||
/*get_time_ms=*/[]() { return absl::GetCurrentTimeNanos() / 1e6; },
|
||||
/*subscriber_timeout_ms=*/RayConfig::instance().subscriber_timeout_ms(),
|
||||
|
|
|
@ -57,6 +57,21 @@ python_grpc_compile(
|
|||
deps = [":gcs_proto"],
|
||||
)
|
||||
|
||||
proto_library(
|
||||
name = "logging_proto",
|
||||
srcs = ["logging.proto"],
|
||||
)
|
||||
|
||||
cc_proto_library(
|
||||
name = "loggings_cc_proto",
|
||||
deps = [":logging_proto"],
|
||||
)
|
||||
|
||||
python_grpc_compile(
|
||||
name = "logging_py_proto",
|
||||
deps = [":logging_proto"],
|
||||
)
|
||||
|
||||
proto_library(
|
||||
name = "node_manager_proto",
|
||||
srcs = ["node_manager.proto"],
|
||||
|
@ -238,6 +253,7 @@ proto_library(
|
|||
deps = [
|
||||
":common_proto",
|
||||
":gcs_proto",
|
||||
":logging_proto",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
36
src/ray/protobuf/logging.proto
Normal file
36
src/ray/protobuf/logging.proto
Normal file
|
@ -0,0 +1,36 @@
|
|||
// Copyright 2021 The Ray Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
syntax = "proto3";
|
||||
option cc_enable_arenas = true;
|
||||
|
||||
package ray.rpc;
|
||||
|
||||
// A batch of logs with metadata and multiple log lines.
|
||||
message LogBatch {
|
||||
// IP of the log publisher.
|
||||
string ip = 1;
|
||||
// Ray uses string for pid sometimes, e.g. autoscaler, gcs.
|
||||
string pid = 2;
|
||||
// Job ID in hex.
|
||||
string job_id = 3;
|
||||
// Whether this is an error output.
|
||||
bool is_error = 4;
|
||||
// Multiple lines of logs.
|
||||
repeated string lines = 5;
|
||||
// Name of the actor.
|
||||
string actor_name = 6;
|
||||
// Name of the task.
|
||||
string task_name = 7;
|
||||
}
|
|
@ -19,6 +19,7 @@ package ray.rpc;
|
|||
|
||||
import "src/ray/protobuf/common.proto";
|
||||
import "src/ray/protobuf/gcs.proto";
|
||||
import "src/ray/protobuf/logging.proto";
|
||||
|
||||
/// Each channel is prefixed by the name of its components.
|
||||
/// For example, for pubsub channels that are used by core workers,
|
||||
|
@ -42,6 +43,8 @@ enum ChannelType {
|
|||
GCS_WORKER_DELTA_CHANNEL = 7;
|
||||
/// A channel for errors from various Ray components.
|
||||
RAY_ERROR_INFO_CHANNEL = 8;
|
||||
/// A channel for logs from various Ray components.
|
||||
RAY_LOG_CHANNEL = 9;
|
||||
}
|
||||
|
||||
///
|
||||
|
@ -64,6 +67,7 @@ message PubMessage {
|
|||
NodeResourceChange node_resource_message = 10;
|
||||
WorkerDeltaData worker_delta_message = 11;
|
||||
ErrorTableData error_info_message = 12;
|
||||
LogBatch log_batch_message = 13;
|
||||
|
||||
// The message that indicates the given key id is not available anymore.
|
||||
FailureMessage failure_message = 6;
|
||||
|
|
Loading…
Add table
Reference in a new issue