[Core][Pubsub][Logging 1/n] add logging support to GCS pubsub in Python (#20604)

This PR adds support for publishing and subscribing to logs in Python via GCS pubsub. It also refactors the Python threaded subscriber to support subscribing and calling `close()` from multiple threads.

We can also move tests and logging support to another PR, but it will make the purpose of the refactoring seems less obvious.
This commit is contained in:
mwtian 2021-11-29 11:26:01 -08:00 committed by GitHub
parent aabe9229df
commit a4d3898159
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 502 additions and 147 deletions

View file

@ -64,6 +64,7 @@ MOCK_MODULES = [
"ray.core.generated.common_pb2",
"ray.core.generated.runtime_env_common_pb2",
"ray.core.generated.gcs_pb2",
"ray.core.generated.logging_pb2",
"ray.core.generated.ray.protocol.Task",
"ray.serve.generated",
"ray.serve.generated.serve_pb2",

View file

@ -12,10 +12,10 @@ except ImportError:
from grpc.experimental import aio as aiogrpc
import ray._private.gcs_utils as gcs_utils
import ray._private.logging_utils as logging_utils
from ray.core.generated.gcs_pb2 import ErrorTableData
from ray.core.generated import gcs_service_pb2_grpc
from ray.core.generated import gcs_service_pb2
from ray.core.generated.gcs_pb2 import (
ErrorTableData, )
from ray.core.generated import pubsub_pb2
logger = logging.getLogger(__name__)
@ -23,7 +23,7 @@ logger = logging.getLogger(__name__)
def gcs_pubsub_enabled():
"""Checks whether GCS pubsub feature flag is enabled."""
return os.environ.get("RAY_gcs_grpc_based_pubsub") not in\
return os.environ.get("RAY_gcs_grpc_based_pubsub") not in \
[None, "0", "false"]
@ -48,11 +48,28 @@ def construct_error_message(job_id, error_type, message, timestamp):
return data
class _PublisherBase:
@staticmethod
def _create_log_request(log_json: dict):
job_id = log_json.get("job")
return gcs_service_pb2.GcsPublishRequest(pub_messages=[
pubsub_pb2.PubMessage(
channel_type=pubsub_pb2.RAY_LOG_CHANNEL,
key_id=job_id.encode() if job_id else None,
log_batch_message=logging_utils.log_batch_dict_to_proto(
log_json))
])
class _SubscriberBase:
def _subscribe_error_request(self):
cmd = pubsub_pb2.Command(
channel_type=pubsub_pb2.RAY_ERROR_INFO_CHANNEL,
subscribe_message={})
def __init__(self):
# self._subscriber_id needs to match the binary format of a random
# SubscriberID / UniqueID, which is 28 (kUniqueIDSize) random bytes.
self._subscriber_id = bytes(
bytearray(random.getrandbits(8) for _ in range(28)))
def _subscribe_request(self, channel):
cmd = pubsub_pb2.Command(channel_type=channel, subscribe_message={})
req = gcs_service_pb2.GcsSubscriberCommandBatchRequest(
subscriber_id=self._subscriber_id, commands=[cmd])
return req
@ -61,18 +78,31 @@ class _SubscriberBase:
return gcs_service_pb2.GcsSubscriberPollRequest(
subscriber_id=self._subscriber_id)
def _unsubscribe_request(self):
def _unsubscribe_request(self, channels):
req = gcs_service_pb2.GcsSubscriberCommandBatchRequest(
subscriber_id=self._subscriber_id, commands=[])
if self._subscribed_error:
cmd = pubsub_pb2.Command(
channel_type=pubsub_pb2.RAY_ERROR_INFO_CHANNEL,
unsubscribe_message={})
req.commands.append(cmd)
for channel in channels:
req.commands.append(
pubsub_pb2.Command(
channel_type=channel, unsubscribe_message={}))
return req
@staticmethod
def _pop_error_info(queue):
if len(queue) == 0:
return None, None
msg = queue.popleft()
return msg.key_id, msg.error_info_message
class GcsPublisher:
@staticmethod
def _pop_log_batch(queue):
if len(queue) == 0:
return None
msg = queue.popleft()
return logging_utils.log_batch_proto_to_dict(msg.log_batch_message)
class GcsPublisher(_PublisherBase):
"""Publisher to GCS."""
def __init__(self, *, address: str = None, channel: grpc.Channel = None):
@ -94,16 +124,119 @@ class GcsPublisher:
req = gcs_service_pb2.GcsPublishRequest(pub_messages=[msg])
self._stub.GcsPublish(req)
def publish_logs(self, log_batch: dict) -> None:
"""Publishes logs to GCS."""
req = self._create_log_request(log_batch)
self._stub.GcsPublish(req)
class GcsSubscriber(_SubscriberBase):
"""Subscriber to GCS. Thread safe.
class _SyncSubscriber(_SubscriberBase):
def __init__(
self,
pubsub_channel_type,
address: str = None,
channel: grpc.Channel = None,
):
super().__init__()
if address:
assert channel is None, \
"address and channel cannot both be specified"
channel = gcs_utils.create_gcs_channel(address)
else:
assert channel is not None, \
"One of address and channel must be specified"
# GRPC stub to GCS pubsub.
self._stub = gcs_service_pb2_grpc.InternalPubSubGcsServiceStub(channel)
# Type of the channel.
self._channel = pubsub_channel_type
# Protects multi-threaded read and write of self._queue.
self._lock = threading.Lock()
# A queue of received PubMessage.
self._queue = deque()
# Indicates whether the subscriber has closed.
self._close = threading.Event()
def subscribe(self) -> None:
"""Registers a subscription for the subscriber's channel type.
Before the registration, published messages in the channel will not be
saved for the subscriber.
"""
with self._lock:
if self._close.is_set():
return
req = self._subscribe_request(self._channel)
self._stub.GcsSubscriberCommandBatch(req, timeout=30)
def _poll_locked(self, timeout=None) -> None:
assert self._lock.locked()
# Poll until data becomes available.
while len(self._queue) == 0:
if self._close.is_set():
return
fut = self._stub.GcsSubscriberPoll.future(
self._poll_request(), timeout=timeout)
# Wait for result to become available, or cancel if the
# subscriber has closed.
while True:
try:
# Use 1s timeout to check for subscriber closing
# periodically.
fut.result(timeout=1)
break
except grpc.FutureTimeoutError:
# Subscriber has closed. Cancel inflight request and
# return from polling.
if self._close.is_set():
fut.cancel()
return
# GRPC has not replied, continue waiting.
continue
except grpc.RpcError as e:
# Choose to not raise deadline exceeded errors to the
# caller. Instead return None. This can be revisited later.
if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
return
raise
if fut.done():
for msg in fut.result().pub_messages:
if msg.channel_type != self._channel:
logger.warn(
f"Ignoring message from unsubscribed channel {msg}"
)
continue
self._queue.append(msg)
def close(self) -> None:
"""Closes the subscriber and its active subscription."""
# Mark close to terminate inflight polling and prevent future requests.
self._close.set()
req = self._unsubscribe_request(channels=[self._channel])
try:
self._stub.GcsSubscriberCommandBatch(req, timeout=5)
except Exception:
pass
self._stub = None
class GcsErrorSubscriber(_SyncSubscriber):
"""Subscriber to error info. Thread safe.
Usage example:
subscriber = GcsSubscriber()
subscriber.subscribe_error()
subscriber = GcsErrorSubscriber()
# Subscribe to the error channel.
subscriber.subscribe()
...
while running:
error_id, error_data = subscriber.poll_error()
error_id, error_data = subscriber.poll()
......
# Unsubscribe from the error channels.
subscriber.close()
"""
@ -112,90 +245,55 @@ class GcsSubscriber(_SubscriberBase):
address: str = None,
channel: grpc.Channel = None,
):
if address:
assert channel is None, \
"address and channel cannot both be specified"
channel = gcs_utils.create_gcs_channel(address)
else:
assert channel is not None, \
"One of address and channel must be specified"
self._lock = threading.RLock()
self._stub = gcs_service_pb2_grpc.InternalPubSubGcsServiceStub(channel)
self._subscriber_id = bytes(
bytearray(random.getrandbits(8) for _ in range(28)))
# Whether error info has been subscribed.
self._subscribed_error = False
# Buffer for holding error info.
self._errors = deque()
# Future for indicating whether the subscriber has closed.
self._close = threading.Event()
super().__init__(pubsub_pb2.RAY_ERROR_INFO_CHANNEL, address, channel)
def subscribe_error(self) -> None:
"""Registers a subscription for error info.
def poll(self, timeout=None) -> Tuple[bytes, ErrorTableData]:
"""Polls for new error messages.
Before the registration, published errors will not be saved for the
subscriber.
Returns:
A tuple of error message ID and ErrorTableData proto message,
or None, None if polling times out or subscriber closed.
"""
with self._lock:
if self._close.is_set():
return
if not self._subscribed_error:
req = self._subscribe_error_request()
self._stub.GcsSubscriberCommandBatch(req, timeout=30)
self._subscribed_error = True
self._poll_locked(timeout=timeout)
return self._pop_error_info(self._queue)
def poll_error(self, timeout=None) -> Tuple[bytes, ErrorTableData]:
"""Polls for new error messages."""
class GcsLogSubscriber(_SyncSubscriber):
"""Subscriber to logs. Thread safe.
Usage example:
subscriber = GcsLogSubscriber()
# Subscribe to the log channel.
subscriber.subscribe()
...
while running:
log = subscriber.poll()
......
# Unsubscribe from the log channel.
subscriber.close()
"""
def __init__(
self,
address: str = None,
channel: grpc.Channel = None,
):
super().__init__(pubsub_pb2.RAY_LOG_CHANNEL, address, channel)
def poll(self, timeout=None) -> Tuple[bytes, ErrorTableData]:
"""Polls for new log messages.
Returns:
A dict containing a batch of log lines and their metadata,
or None if polling times out or subscriber closed.
"""
with self._lock:
if self._close.is_set():
return
if len(self._errors) == 0:
req = self._poll_request()
fut = self._stub.GcsSubscriberPoll.future(req, timeout=timeout)
# Wait for result to become available, or cancel if the
# subscriber has closed.
while True:
try:
fut.result(timeout=1)
break
except grpc.FutureTimeoutError:
# Subscriber has closed. Cancel inflight the request
# and return from polling.
if self._close.is_set():
fut.cancel()
return None, None
# GRPC has not replied, continue waiting.
continue
except Exception:
# GRPC error, including deadline exceeded.
raise
if fut.done():
for msg in fut.result().pub_messages:
self._errors.append((msg.key_id,
msg.error_info_message))
if len(self._errors) == 0:
return None, None
return self._errors.popleft()
def close(self) -> None:
"""Closes the subscriber and its active subscriptions."""
# Mark close to terminate inflight polling and prevent future requests.
self._close.set()
with self._lock:
if not self._stub:
# Subscriber already closed.
return
req = self._unsubscribe_request()
try:
self._stub.GcsSubscriberCommandBatch(req, timeout=30)
except Exception:
pass
self._stub = None
self._poll_locked(timeout=timeout)
return self._pop_log_batch(self._queue)
class GcsAioPublisher:
class GcsAioPublisher(_PublisherBase):
"""Publisher to GCS. Uses async io."""
def __init__(self, address: str = None, channel: aiogrpc.Channel = None):
@ -218,6 +316,11 @@ class GcsAioPublisher:
req = gcs_service_pb2.GcsPublishRequest(pub_messages=[msg])
await self._stub.GcsPublish(req)
async def publish_logs(self, log_batch: dict) -> None:
"""Publishes logs to GCS."""
req = self._create_log_request(log_batch)
await self._stub.GcsPublish(req)
class GcsAioSubscriber(_SubscriberBase):
"""Async io subscriber to GCS.
@ -232,6 +335,8 @@ class GcsAioSubscriber(_SubscriberBase):
"""
def __init__(self, address: str = None, channel: aiogrpc.Channel = None):
super().__init__()
if address:
assert channel is None, \
"address and channel cannot both be specified"
@ -239,13 +344,9 @@ class GcsAioSubscriber(_SubscriberBase):
else:
assert channel is not None, \
"One of address and channel must be specified"
# Message queue for each channel.
self._messages = {}
self._stub = gcs_service_pb2_grpc.InternalPubSubGcsServiceStub(channel)
self._subscriber_id = bytes(
bytearray(random.getrandbits(8) for _ in range(28)))
# Whether error info has been subscribed.
self._subscribed_error = False
# Buffer for holding error info.
self._errors = deque()
async def subscribe_error(self) -> None:
"""Registers a subscription for error info.
@ -253,28 +354,55 @@ class GcsAioSubscriber(_SubscriberBase):
Before the registration, published errors will not be saved for the
subscriber.
"""
if not self._subscribed_error:
req = self._subscribe_error_request()
if pubsub_pb2.RAY_ERROR_INFO_CHANNEL not in self._messages:
self._messages[pubsub_pb2.RAY_ERROR_INFO_CHANNEL] = deque()
req = self._subscribe_request(pubsub_pb2.RAY_ERROR_INFO_CHANNEL)
await self._stub.GcsSubscriberCommandBatch(req, timeout=30)
self._subscribed_error = True
async def subscribe_logs(self) -> None:
"""Registers a subscription for logs.
Before the registration, published logs will not be saved for the
subscriber.
"""
if pubsub_pb2.RAY_LOG_CHANNEL not in self._messages:
self._messages[pubsub_pb2.RAY_LOG_CHANNEL] = deque()
req = self._subscribe_request(pubsub_pb2.RAY_LOG_CHANNEL)
await self._stub.GcsSubscriberCommandBatch(req, timeout=30)
def _enqueue_poll_response(self, resp):
for msg in resp.pub_messages:
queue = self._messages.get(msg.channel_type)
if queue is not None:
queue.append(msg)
else:
logger.warn(
f"Ignoring message from unsubscribed channel {msg}")
async def poll_error(self, timeout=None) -> Tuple[bytes, ErrorTableData]:
"""Polls for new error messages."""
if len(self._errors) == 0:
queue = self._messages.get(pubsub_pb2.RAY_ERROR_INFO_CHANNEL)
while len(queue) == 0:
req = self._poll_request()
reply = await self._stub.GcsSubscriberPoll(req, timeout=timeout)
for msg in reply.pub_messages:
self._errors.append((msg.key_id, msg.error_info_message))
self._enqueue_poll_response(reply)
if len(self._errors) == 0:
return None, None
return self._errors.popleft()
return self._pop_error_info(queue)
async def poll_logs(self, timeout=None) -> dict:
"""Polls for new error messages."""
queue = self._messages.get(pubsub_pb2.RAY_LOG_CHANNEL)
while len(queue) == 0:
req = self._poll_request()
reply = await self._stub.GcsSubscriberPoll(req, timeout=timeout)
self._enqueue_poll_response(reply)
return self._pop_log_batch(queue)
async def close(self) -> None:
"""Closes the subscriber and its active subscriptions."""
req = self._unsubscribe_request()
req = self._unsubscribe_request(self._messages.keys())
try:
await self._stub.GcsSubscriberCommandBatch(req, timeout=30)
await self._stub.GcsSubscriberCommandBatch(req, timeout=5)
except Exception:
pass
self._subscribed_error = False

View file

@ -0,0 +1,29 @@
from ray.core.generated.logging_pb2 import LogBatch
def log_batch_dict_to_proto(log_json: dict) -> LogBatch:
"""Converts a dict containing a batch of logs to a LogBatch proto."""
return LogBatch(
ip=log_json.get("ip"),
# Cast to support string pid like "gcs".
pid=str(log_json.get("pid")) if log_json.get("pid") else None,
# Job ID as a hex string.
job_id=log_json.get("job"),
is_error=bool(log_json.get("is_err")),
lines=log_json.get("lines"),
actor_name=log_json.get("actor_name"),
task_name=log_json.get("task_name"),
)
def log_batch_proto_to_dict(log_batch: LogBatch) -> dict:
"""Converts a LogBatch proto to a dict containing a batch of logs."""
return {
"ip": log_batch.ip,
"pid": log_batch.pid,
"job": log_batch.job_id,
"is_err": log_batch.is_error,
"lines": log_batch.lines,
"actor_name": log_batch.actor_name,
"task_name": log_batch.task_name,
}

View file

@ -23,10 +23,10 @@ import ray._private.services
import ray._private.utils
import ray._private.gcs_utils as gcs_utils
import ray._private.memory_monitor as memory_monitor
from ray.core.generated import gcs_pb2
from ray.core.generated import node_manager_pb2
from ray.core.generated import node_manager_pb2_grpc
from ray.core.generated import gcs_pb2
from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsSubscriber
from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsErrorSubscriber
from ray._private.tls_utils import generate_self_signed_tls_certs
from ray.util.queue import Queue, _QueueActor, Empty
from ray.scripts.scripts import main as ray_main
@ -513,9 +513,9 @@ def get_non_head_nodes(cluster):
def init_error_pubsub():
"""Initialize redis error info pub/sub"""
if gcs_pubsub_enabled():
s = GcsSubscriber(
s = GcsErrorSubscriber(
channel=ray.worker.global_worker.gcs_channel.channel())
s.subscribe_error()
s.subscribe()
else:
s = ray.worker.global_worker.redis_client.pubsub(
ignore_subscribe_messages=True)
@ -532,17 +532,11 @@ def get_error_message(subscriber, num=1e6, error_type=None, timeout=20):
deadline = time.time() + timeout
msgs = []
while time.time() < deadline and len(msgs) < num:
if isinstance(subscriber, GcsSubscriber):
try:
_, error_data = subscriber.poll_error(timeout=deadline -
time.time())
except grpc.RpcError as e:
# Failed to match error message before timeout.
if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
logging.warning("get_error_message() timed out")
return []
# Otherwise, the error is unexpected.
raise
if isinstance(subscriber, GcsErrorSubscriber):
_, error_data = subscriber.poll(timeout=deadline - time.time())
if not error_data:
# Timed out before any data is received.
break
else:
msg = subscriber.get_message()
if msg is None:

View file

@ -1,9 +1,10 @@
import sys
import threading
import ray
import ray._private.gcs_utils as gcs_utils
from ray._private.gcs_pubsub import GcsPublisher, GcsSubscriber, \
GcsAioPublisher, GcsAioSubscriber
from ray._private.gcs_pubsub import GcsPublisher, GcsErrorSubscriber, \
GcsLogSubscriber, GcsAioPublisher, GcsAioSubscriber
from ray.core.generated.gcs_pb2 import ErrorTableData
import pytest
@ -23,8 +24,8 @@ def test_publish_and_subscribe_error_info(ray_start_regular):
gcs_server_addr = gcs_utils.get_gcs_address_from_redis(redis)
subscriber = GcsSubscriber(address=gcs_server_addr)
subscriber.subscribe_error()
subscriber = GcsErrorSubscriber(address=gcs_server_addr)
subscriber.subscribe()
publisher = GcsPublisher(address=gcs_server_addr)
err1 = ErrorTableData(error_message="test error message 1")
@ -32,8 +33,8 @@ def test_publish_and_subscribe_error_info(ray_start_regular):
publisher.publish_error(b"aaa_id", err1)
publisher.publish_error(b"bbb_id", err2)
assert subscriber.poll_error() == (b"aaa_id", err1)
assert subscriber.poll_error() == (b"bbb_id", err2)
assert subscriber.poll() == (b"aaa_id", err1)
assert subscriber.poll() == (b"bbb_id", err2)
subscriber.close()
@ -69,5 +70,148 @@ async def test_aio_publish_and_subscribe_error_info(ray_start_regular):
await subscriber.close()
@pytest.mark.parametrize(
"ray_start_regular", [{
"_system_config": {
"gcs_grpc_based_pubsub": True
}
}],
indirect=True)
def test_publish_and_subscribe_logs(ray_start_regular):
address_info = ray_start_regular
redis = ray._private.services.create_redis_client(
address_info["redis_address"],
password=ray.ray_constants.REDIS_DEFAULT_PASSWORD)
gcs_server_addr = gcs_utils.get_gcs_address_from_redis(redis)
subscriber = GcsLogSubscriber(address=gcs_server_addr)
subscriber.subscribe()
publisher = GcsPublisher(address=gcs_server_addr)
log_batch = {
"ip": "127.0.0.1",
"pid": 1234,
"job": "0001",
"is_err": False,
"lines": ["line 1", "line 2"],
"actor_name": "test actor",
"task_name": "test task",
}
publisher.publish_logs(log_batch)
# PID is treated as string.
log_batch["pid"] = "1234"
assert subscriber.poll() == log_batch
subscriber.close()
@pytest.mark.asyncio
@pytest.mark.parametrize(
"ray_start_regular", [{
"_system_config": {
"gcs_grpc_based_pubsub": True
}
}],
indirect=True)
async def test_aio_publish_and_subscribe_logs(ray_start_regular):
address_info = ray_start_regular
redis = ray._private.services.create_redis_client(
address_info["redis_address"],
password=ray.ray_constants.REDIS_DEFAULT_PASSWORD)
gcs_server_addr = gcs_utils.get_gcs_address_from_redis(redis)
subscriber = GcsAioSubscriber(address=gcs_server_addr)
await subscriber.subscribe_logs()
publisher = GcsAioPublisher(address=gcs_server_addr)
log_batch = {
"ip": "127.0.0.1",
"pid": "gcs",
"job": "0001",
"is_err": False,
"lines": ["line 1", "line 2"],
"actor_name": "test actor",
"task_name": "test task",
}
await publisher.publish_logs(log_batch)
assert await subscriber.poll_logs() == log_batch
await subscriber.close()
@pytest.mark.parametrize(
"ray_start_regular", [{
"_system_config": {
"gcs_grpc_based_pubsub": True
}
}],
indirect=True)
def test_subscribe_two_channels(ray_start_regular):
"""Tests concurrently subscribing to two channels work."""
address_info = ray_start_regular
redis = ray._private.services.create_redis_client(
address_info["redis_address"],
password=ray.ray_constants.REDIS_DEFAULT_PASSWORD)
gcs_server_addr = gcs_utils.get_gcs_address_from_redis(redis)
num_messages = 100
errors = []
def receive_errors():
subscriber = GcsErrorSubscriber(address=gcs_server_addr)
subscriber.subscribe()
while len(errors) < num_messages:
_, msg = subscriber.poll()
errors.append(msg)
logs = []
def receive_logs():
subscriber = GcsLogSubscriber(address=gcs_server_addr)
subscriber.subscribe()
while len(logs) < num_messages:
log_batch = subscriber.poll()
logs.append(log_batch)
t1 = threading.Thread(target=receive_errors)
t1.start()
t2 = threading.Thread(target=receive_logs)
t2.start()
publisher = GcsPublisher(address=gcs_server_addr)
for i in range(0, num_messages):
publisher.publish_error(
b"msg_id", ErrorTableData(error_message=f"error {i}"))
publisher.publish_logs({
"ip": "127.0.0.1",
"pid": "gcs",
"job": "0001",
"is_err": False,
"lines": [f"line {i}"],
"actor_name": "test actor",
"task_name": "test task",
})
t1.join(timeout=1)
assert not t1.is_alive()
assert len(errors) == num_messages
t2.join(timeout=1)
assert not t2.is_alive()
assert len(logs) == num_messages
for i in range(0, num_messages):
assert errors[i].error_message == f"error {i}"
assert logs[i]["lines"][0] == f"line {i}"
if __name__ == "__main__":
sys.exit(pytest.main(["-v", __file__]))

View file

@ -28,7 +28,7 @@ import ray.serialization as serialization
import ray._private.gcs_utils as gcs_utils
import ray._private.services as services
from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsPublisher, \
GcsSubscriber
GcsErrorSubscriber
from ray._private.runtime_env.py_modules import upload_py_modules_if_needed
from ray._private.runtime_env.working_dir import upload_working_dir_if_needed
from ray._private.runtime_env.constants import RAY_JOB_CONFIG_JSON_ENV_VAR
@ -1266,7 +1266,7 @@ def listen_error_messages_from_gcs(worker, threads_stopped):
# gcs_subscriber.poll_error() will still be processed in the loop.
# TODO: we should just subscribe to the errors for this specific job.
worker.gcs_subscriber.subscribe_error()
worker.gcs_error_subscriber.subscribe()
try:
if _internal_kv_initialized():
@ -1281,7 +1281,7 @@ def listen_error_messages_from_gcs(worker, threads_stopped):
if threads_stopped.is_set():
return
_, error_data = worker.gcs_subscriber.poll_error()
_, error_data = worker.gcs_error_subscriber.poll()
if error_data is None:
continue
if error_data.job_id not in [
@ -1371,7 +1371,7 @@ def connect(node,
if worker.gcs_pubsub_enabled:
worker.gcs_publisher = GcsPublisher(
channel=worker.gcs_channel.channel())
worker.gcs_subscriber = GcsSubscriber(
worker.gcs_error_subscriber = GcsErrorSubscriber(
channel=worker.gcs_channel.channel())
# Initialize some fields.
@ -1575,8 +1575,8 @@ def disconnect(exiting_interpreter=False):
# should be handled cleanly in the worker object's destructor and not
# in this disconnect method.
worker.threads_stopped.set()
if hasattr(worker, "gcs_subscriber"):
worker.gcs_subscriber.close()
if hasattr(worker, "gcs_error_subscriber"):
worker.gcs_error_subscriber.close()
if hasattr(worker, "import_thread"):
worker.import_thread.join_import_thread()
if hasattr(worker, "listener_thread"):

View file

@ -70,13 +70,16 @@ void GcsServer::Start() {
// Init grpc based pubsub on GCS.
// TODO: Move this into GcsPublisher.
inner_publisher = std::make_unique<pubsub::Publisher>(
/*channels=*/std::vector<
rpc::ChannelType>{rpc::ChannelType::GCS_ACTOR_CHANNEL,
rpc::ChannelType::GCS_JOB_CHANNEL,
rpc::ChannelType::GCS_NODE_INFO_CHANNEL,
rpc::ChannelType::GCS_NODE_RESOURCE_CHANNEL,
rpc::ChannelType::GCS_WORKER_DELTA_CHANNEL,
rpc::ChannelType::RAY_ERROR_INFO_CHANNEL},
/*channels=*/
std::vector<rpc::ChannelType>{
rpc::ChannelType::GCS_ACTOR_CHANNEL,
rpc::ChannelType::GCS_JOB_CHANNEL,
rpc::ChannelType::GCS_NODE_INFO_CHANNEL,
rpc::ChannelType::GCS_NODE_RESOURCE_CHANNEL,
rpc::ChannelType::GCS_WORKER_DELTA_CHANNEL,
rpc::ChannelType::RAY_ERROR_INFO_CHANNEL,
rpc::ChannelType::RAY_LOG_CHANNEL,
},
/*periodical_runner=*/&pubsub_periodical_runner_,
/*get_time_ms=*/[]() { return absl::GetCurrentTimeNanos() / 1e6; },
/*subscriber_timeout_ms=*/RayConfig::instance().subscriber_timeout_ms(),

View file

@ -57,6 +57,21 @@ python_grpc_compile(
deps = [":gcs_proto"],
)
proto_library(
name = "logging_proto",
srcs = ["logging.proto"],
)
cc_proto_library(
name = "loggings_cc_proto",
deps = [":logging_proto"],
)
python_grpc_compile(
name = "logging_py_proto",
deps = [":logging_proto"],
)
proto_library(
name = "node_manager_proto",
srcs = ["node_manager.proto"],
@ -238,6 +253,7 @@ proto_library(
deps = [
":common_proto",
":gcs_proto",
":logging_proto",
],
)

View file

@ -0,0 +1,36 @@
// Copyright 2021 The Ray Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
option cc_enable_arenas = true;
package ray.rpc;
// A batch of logs with metadata and multiple log lines.
message LogBatch {
// IP of the log publisher.
string ip = 1;
// Ray uses string for pid sometimes, e.g. autoscaler, gcs.
string pid = 2;
// Job ID in hex.
string job_id = 3;
// Whether this is an error output.
bool is_error = 4;
// Multiple lines of logs.
repeated string lines = 5;
// Name of the actor.
string actor_name = 6;
// Name of the task.
string task_name = 7;
}

View file

@ -19,6 +19,7 @@ package ray.rpc;
import "src/ray/protobuf/common.proto";
import "src/ray/protobuf/gcs.proto";
import "src/ray/protobuf/logging.proto";
/// Each channel is prefixed by the name of its components.
/// For example, for pubsub channels that are used by core workers,
@ -42,6 +43,8 @@ enum ChannelType {
GCS_WORKER_DELTA_CHANNEL = 7;
/// A channel for errors from various Ray components.
RAY_ERROR_INFO_CHANNEL = 8;
/// A channel for logs from various Ray components.
RAY_LOG_CHANNEL = 9;
}
///
@ -64,6 +67,7 @@ message PubMessage {
NodeResourceChange node_resource_message = 10;
WorkerDeltaData worker_delta_message = 11;
ErrorTableData error_info_message = 12;
LogBatch log_batch_message = 13;
// The message that indicates the given key id is not available anymore.
FailureMessage failure_message = 6;