[Core][Pubsub] Implement Python GCS publisher and subscriber (#20111)

## Why are these changes needed?
This change adds a Python publisher and subscriber in `gcs_pubsub.py`, and a gRPC handler on GCS for publishing via GCS. Error info is migrated to use the GCS-based pubsub when the feature flag `RAY_gcs_grpc_based_pubsub=true` is set.
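
Concretely, the new API can be exercised roughly as in the sketch below, which mirrors the added `test_gcs_pubsub.py` (here `gcs_server_addr` is assumed to be the `ip:port` of a GCS instance running with the flag enabled):

```python
from ray._private.gcs_pubsub import GcsPublisher, GcsSubscriber
from ray.core.generated.gcs_pb2 import ErrorTableData

# gcs_server_addr is assumed to be the "ip:port" of a GCS instance
# started with RAY_gcs_grpc_based_pubsub=true.
subscriber = GcsSubscriber(address=gcs_server_addr)
subscriber.subscribe_error()

publisher = GcsPublisher(address=gcs_server_addr)
publisher.publish_error(b"key_id", ErrorTableData(error_message="test"))

# Blocks until a message is available (or the optional timeout expires).
key_id, error_data = subscriber.poll_error()
subscriber.close()
```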

Also, add a `--gcs-address` flag to some Python processes. It is not set anywhere yet, but will be set after the Redis-less bootstrapping work.
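
Each of these processes picks its error publisher with the same logic; the sketch below captures the shared pattern (the `make_error_publisher` helper is hypothetical; the real code inlines this logic at each call site):

```python
from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsPublisher
from ray._private.gcs_utils import get_gcs_address_from_redis


def make_error_publisher(gcs_address, redis_client):
    """Return a GcsPublisher if GCS-based pubsub should be used, else None."""
    if gcs_address:
        # An explicit address was passed via the new --gcs-address flag.
        return GcsPublisher(address=gcs_address)
    if gcs_pubsub_enabled():
        # Otherwise look up the GCS address through Redis.
        return GcsPublisher(address=get_gcs_address_from_redis(redis_client))
    # Feature flag off: callers fall back to the Redis publishing path.
    return None
```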

Unit tests are added for the Python publisher and subscriber. The migrated error info publishers and subscribers are covered by existing unit tests, e.g. tests calling `ray._private.test_utils.get_error_message()` to verify that error info is published.
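
For example, a test can assert on a published error roughly as follows (a sketch; it assumes a running cluster from the usual test fixtures):

```python
import ray.ray_constants as ray_constants
from ray._private.test_utils import init_error_pubsub, get_error_message

# Returns a GcsSubscriber when RAY_gcs_grpc_based_pubsub=true,
# otherwise a Redis pubsub client subscribed to the error channel.
p = init_error_pubsub()
errors = get_error_message(p, 1, ray_constants.DASHBOARD_AGENT_DIED_ERROR)
assert errors and errors[0].type == ray_constants.DASHBOARD_AGENT_DIED_ERROR
```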

GCS-based pubsub still has gaps in handling deadlines, cancelled requests, and GCS restarts, so 3 more unit tests are disabled in the `HA GCS` mode. These gaps will be addressed in a separate change.

## Related issue number
mwtian 2021-11-11 14:59:57 -08:00 committed by GitHub
parent fca851eef5
commit 0330852baf
31 changed files with 763 additions and 188 deletions


@ -343,6 +343,7 @@
--test_env=RAY_gcs_grpc_based_pubsub=true
-- //python/ray/tests/...
-//python/ray/tests:test_failure_2
-//python/ray/tests:test_job
- bazel test --config=ci $(./scripts/bazel_export_options)
--test_tag_filters=-kubernetes,client_tests,-flaky
--test_env=RAY_CLIENT_MODE=1 --test_env=RAY_PROFILING=1
@ -358,6 +359,7 @@
-- //python/ray/tests/...
-//python/ray/tests:test_client_multi -//python/ray/tests:test_component_failures_3
-//python/ray/tests:test_healthcheck -//python/ray/tests:test_gcs_fault_tolerance
-//python/ray/tests:test_client
- label: ":redis: HA GCS (Medium K-Z)"
conditions: ["RAY_CI_PYTHON_AFFECTED"]
commands:
@ -368,6 +370,7 @@
-- //python/ray/tests/...
-//python/ray/tests:test_multinode_failures_2 -//python/ray/tests:test_ray_debugger
-//python/ray/tests:test_placement_group_2 -//python/ray/tests:test_placement_group_3
-//python/ray/tests:test_multi_node
- label: ":brain: RLlib: Learning discr. actions TF2-static-graph (from rllib/tuned_examples/*.yaml)"
conditions: ["RAY_CI_RLLIB_AFFECTED"]


@ -19,7 +19,9 @@ import ray.dashboard.utils as dashboard_utils
import ray.ray_constants as ray_constants
import ray._private.services
import ray._private.utils
from ray._private.gcs_utils import GcsClient
from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsPublisher
from ray._private.gcs_utils import GcsClient, \
get_gcs_address_from_redis
from ray.core.generated import agent_manager_pb2
from ray.core.generated import agent_manager_pb2_grpc
from ray._private.ray_logging import setup_component_logger
@ -228,6 +230,11 @@ if __name__ == "__main__":
required=True,
type=str,
help="the IP address of this node.")
parser.add_argument(
"--gcs-address",
required=False,
type=str,
help="The address (ip:port) of GCS.")
parser.add_argument(
"--redis-address",
required=True,
@ -377,6 +384,12 @@ if __name__ == "__main__":
# impact of the issue.
redis_client = ray._private.services.create_redis_client(
args.redis_address, password=args.redis_password)
gcs_publisher = None
if args.gcs_address:
gcs_publisher = GcsPublisher(args.gcs_address)
elif gcs_pubsub_enabled():
gcs_publisher = GcsPublisher(
address=get_gcs_address_from_redis(redis_client))
traceback_str = ray._private.utils.format_error_message(
traceback.format_exc())
message = (
@ -390,9 +403,11 @@ if __name__ == "__main__":
"\n 3. runtime_env APIs won't work."
"\nCheck out the `dashboard_agent.log` to see the "
"detailed failure messages.")
ray._private.utils.push_error_to_driver_through_redis(
redis_client, ray_constants.DASHBOARD_AGENT_DIED_ERROR,
message)
ray._private.utils.publish_error_to_driver(
ray_constants.DASHBOARD_AGENT_DIED_ERROR,
message,
redis_client=redis_client,
gcs_publisher=gcs_publisher)
logger.error(message)
logger.exception(e)
exit(1)


@ -13,8 +13,10 @@ import ray.dashboard.consts as dashboard_consts
import ray.dashboard.head as dashboard_head
import ray.dashboard.utils as dashboard_utils
import ray.ray_constants as ray_constants
import ray._private.gcs_utils as gcs_utils
import ray._private.services
import ray._private.utils
from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsPublisher
from ray._private.ray_logging import setup_component_logger
from ray._private.metrics_agent import PrometheusServiceDiscoveryWriter
@ -134,6 +136,11 @@ if __name__ == "__main__":
type=int,
default=0,
help="The retry times to select a valid port.")
parser.add_argument(
"--gcs-address",
required=False,
type=str,
help="The address (ip:port) of GCS.")
parser.add_argument(
"--redis-address",
required=True,
@ -223,18 +230,29 @@ if __name__ == "__main__":
loop = asyncio.get_event_loop()
loop.run_until_complete(dashboard.run())
except Exception as e:
# Something went wrong, so push an error to all drivers.
redis_client = ray._private.services.create_redis_client(
args.redis_address, password=args.redis_password)
traceback_str = ray._private.utils.format_error_message(
traceback.format_exc())
message = f"The dashboard on node {platform.uname()[1]} " \
f"failed with the following " \
f"error:\n{traceback_str}"
ray._private.utils.push_error_to_driver_through_redis(
redis_client, ray_constants.DASHBOARD_DIED_ERROR, message)
if isinstance(e, FrontendNotFoundError):
logger.warning(message)
else:
logger.error(message)
raise e
# Something went wrong, so push an error to all drivers.
redis_client = ray._private.services.create_redis_client(
args.redis_address, password=args.redis_password)
gcs_publisher = None
if args.gcs_address:
gcs_publisher = GcsPublisher(address=args.gcs_address)
elif gcs_pubsub_enabled():
gcs_publisher = GcsPublisher(
address=gcs_utils.get_gcs_address_from_redis(redis_client))
ray._private.utils.publish_error_to_driver(
ray_constants.DASHBOARD_DIED_ERROR,
message,
redis_client=redis_client,
gcs_publisher=gcs_publisher)


@ -19,6 +19,7 @@ import ray._private.services
import ray.dashboard.consts as dashboard_consts
import ray.dashboard.utils as dashboard_utils
from ray import ray_constants
from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsAioSubscriber
from ray.core.generated import gcs_service_pb2
from ray.core.generated import gcs_service_pb2_grpc
from ray.dashboard.datacenter import DataOrganizer
@ -44,8 +45,8 @@ GRPC_CHANNEL_OPTIONS = (
async def get_gcs_address_with_retry(redis_client) -> str:
while True:
try:
gcs_address = await redis_client.get(
dashboard_consts.REDIS_KEY_GCS_SERVER_ADDRESS)
gcs_address = (await redis_client.get(
dashboard_consts.REDIS_KEY_GCS_SERVER_ADDRESS)).decode()
if not gcs_address:
raise Exception("GCS address not found.")
logger.info("Connect to GCS at %s", gcs_address)
@ -113,6 +114,7 @@ class DashboardHead:
self.log_dir = log_dir
self.aioredis_client = None
self.aiogrpc_gcs_channel = None
self.gcs_subscriber = None
self.http_session = None
self.ip = ray.util.get_node_ip_address()
ip, port = redis_address.split(":")
@ -192,11 +194,15 @@ class DashboardHead:
# Waiting for GCS is ready.
# TODO: redis-removal bootstrap
gcs_address = await get_gcs_address_with_retry(self.aioredis_client)
self.gcs_client = GcsClient(gcs_address)
self.gcs_client = GcsClient(address=gcs_address)
internal_kv._initialize_internal_kv(self.gcs_client)
self.aiogrpc_gcs_channel = ray._private.utils.init_grpc_channel(
gcs_address, GRPC_CHANNEL_OPTIONS, asynchronous=True)
gcs_client = GcsClient(gcs_address)
internal_kv._initialize_internal_kv(gcs_client)
self.gcs_subscriber = None
if gcs_pubsub_enabled():
self.gcs_subscriber = GcsAioSubscriber(
channel=self.aiogrpc_gcs_channel)
await self.gcs_subscriber.subscribe_error()
self.health_check_thread = GCSHealthCheckThread(gcs_address)
self.health_check_thread.start()


@ -1,5 +1,4 @@
NODE_STATS_UPDATE_INTERVAL_SECONDS = 1
ERROR_INFO_UPDATE_INTERVAL_SECONDS = 5
LOG_INFO_UPDATE_INTERVAL_SECONDS = 5
UPDATE_NODES_INTERVAL_SECONDS = 5
MAX_COUNT_OF_GCS_RPC_ERROR = 10


@ -272,28 +272,15 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
logger.exception("Error receiving log info.")
async def _update_error_info(self):
aioredis_client = self._dashboard_head.aioredis_client
receiver = Receiver()
key = gcs_utils.RAY_ERROR_PUBSUB_PATTERN
pattern = receiver.pattern(key)
await aioredis_client.psubscribe(pattern)
logger.info("Subscribed to %s", key)
async for sender, msg in receiver.iter():
try:
_, data = msg
pubsub_msg = gcs_utils.PubSubMessage.FromString(data)
error_data = gcs_utils.ErrorTableData.FromString(
pubsub_msg.data)
def process_error(error_data):
message = error_data.error_message
message = re.sub(r"\x1b\[\d+m", "", message)
match = re.search(r"\(pid=(\d+), ip=(.*?)\)", message)
if match:
pid = match.group(1)
ip = match.group(2)
errs_for_ip = dict(
DataSource.ip_and_pid_to_errors.get(ip, {}))
errs_for_ip = dict(DataSource.ip_and_pid_to_errors.get(ip, {}))
pid_errors = list(errs_for_ip.get(pid, []))
pid_errors.append({
"message": message,
@ -303,8 +290,33 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
errs_for_ip[pid] = pid_errors
DataSource.ip_and_pid_to_errors[ip] = errs_for_ip
logger.info(f"Received error entry for {ip} {pid}")
if self._dashboard_head.gcs_subscriber:
while True:
_, error_data = await \
self._dashboard_head.gcs_subscriber.poll_error()
try:
process_error(error_data)
except Exception:
logger.exception("Error receiving error info.")
logger.exception("Error receiving error info from GCS.")
else:
aioredis_client = self._dashboard_head.aioredis_client
receiver = Receiver()
key = gcs_utils.RAY_ERROR_PUBSUB_PATTERN
pattern = receiver.pattern(key)
await aioredis_client.psubscribe(pattern)
logger.info("Subscribed to %s", key)
async for _, msg in receiver.iter():
try:
_, data = msg
pubsub_msg = gcs_utils.PubSubMessage.FromString(data)
error_data = gcs_utils.ErrorTableData.FromString(
pubsub_msg.data)
process_error(error_data)
except Exception:
logger.exception("Error receiving error info from Redis.")
async def run(self, server):
gcs_channel = self._dashboard_head.aiogrpc_gcs_channel


@ -0,0 +1,276 @@
import os
from collections import deque
import logging
import random
import threading
from typing import Tuple
import grpc
import ray._private.gcs_utils as gcs_utils
from ray.core.generated import gcs_service_pb2_grpc
from ray.core.generated import gcs_service_pb2
from ray.core.generated.gcs_pb2 import (
ErrorTableData, )
from ray.core.generated import pubsub_pb2
logger = logging.getLogger(__name__)
def gcs_pubsub_enabled():
"""Checks whether GCS pubsub feature flag is enabled."""
return os.environ.get("RAY_gcs_grpc_based_pubsub") not in\
[None, "0", "false"]
def construct_error_message(job_id, error_type, message, timestamp):
"""Construct an ErrorTableData object.
Args:
job_id: The ID of the job that the error should go to. If this is
nil, then the error will go to all drivers.
error_type: The type of the error.
message: The error message.
timestamp: The time of the error.
Returns:
The ErrorTableData object.
"""
data = ErrorTableData()
data.job_id = job_id.binary()
data.type = error_type
data.error_message = message
data.timestamp = timestamp
return data
class _SubscriberBase:
def _subscribe_error_request(self):
cmd = pubsub_pb2.Command(
channel_type=pubsub_pb2.RAY_ERROR_INFO_CHANNEL,
subscribe_message={})
req = gcs_service_pb2.GcsSubscriberCommandBatchRequest(
subscriber_id=self._subscriber_id, commands=[cmd])
return req
def _poll_request(self):
return gcs_service_pb2.GcsSubscriberPollRequest(
subscriber_id=self._subscriber_id)
def _unsubscribe_request(self):
req = gcs_service_pb2.GcsSubscriberCommandBatchRequest(
subscriber_id=self._subscriber_id, commands=[])
if self._subscribed_error:
cmd = pubsub_pb2.Command(
channel_type=pubsub_pb2.RAY_ERROR_INFO_CHANNEL,
unsubscribe_message={})
req.commands.append(cmd)
return req
class GcsPublisher:
"""Publisher to GCS."""
def __init__(self, address: str = None, channel: grpc.Channel = None):
if address:
assert channel is None, \
"address and channel cannot both be specified"
channel = gcs_utils.create_gcs_channel(address)
else:
assert channel is not None, \
"One of address and channel must be specified"
self._stub = gcs_service_pb2_grpc.InternalPubSubGcsServiceStub(channel)
def publish_error(self, key_id: bytes, error_info: ErrorTableData) -> None:
"""Publishes error info to GCS."""
msg = pubsub_pb2.PubMessage(
channel_type=pubsub_pb2.RAY_ERROR_INFO_CHANNEL,
key_id=key_id,
error_info_message=error_info)
req = gcs_service_pb2.GcsPublishRequest(pub_messages=[msg])
self._stub.GcsPublish(req)
class GcsSubscriber(_SubscriberBase):
"""Subscriber to GCS. Thread safe.
Usage example:
subscriber = GcsSubscriber()
subscriber.subscribe_error()
while running:
error_id, error_data = subscriber.poll_error()
......
subscriber.close()
"""
def __init__(
self,
address: str = None,
channel: grpc.Channel = None,
):
if address:
assert channel is None, \
"address and channel cannot both be specified"
channel = gcs_utils.create_gcs_channel(address)
else:
assert channel is not None, \
"One of address and channel must be specified"
self._lock = threading.RLock()
self._stub = gcs_service_pb2_grpc.InternalPubSubGcsServiceStub(channel)
self._subscriber_id = bytes(
bytearray(random.getrandbits(8) for _ in range(28)))
# Whether error info has been subscribed.
self._subscribed_error = False
# Buffer for holding error info.
self._errors = deque()
# Future for indicating whether the subscriber has closed.
self._close = threading.Event()
def subscribe_error(self) -> None:
"""Registers a subscription for error info.
Before the registration, published errors will not be saved for the
subscriber.
"""
with self._lock:
if self._close.is_set():
return
if not self._subscribed_error:
req = self._subscribe_error_request()
self._stub.GcsSubscriberCommandBatch(req, timeout=30)
self._subscribed_error = True
def poll_error(self, timeout=None) -> Tuple[bytes, ErrorTableData]:
"""Polls for new error messages."""
with self._lock:
if self._close.is_set():
return
if len(self._errors) == 0:
req = self._poll_request()
fut = self._stub.GcsSubscriberPoll.future(req, timeout=timeout)
# Wait for result to become available, or cancel if the
# subscriber has closed.
while True:
try:
fut.result(timeout=1)
break
except grpc.FutureTimeoutError:
# Subscriber has closed. Cancel the inflight request
# and return from polling.
if self._close.is_set():
fut.cancel()
return None, None
# GRPC has not replied, continue waiting.
continue
except Exception:
# GRPC error, including deadline exceeded.
raise
if fut.done():
for msg in fut.result().pub_messages:
self._errors.append((msg.key_id,
msg.error_info_message))
if len(self._errors) == 0:
return None, None
return self._errors.popleft()
def close(self) -> None:
"""Closes the subscriber and its active subscriptions."""
# Mark close to terminate inflight polling and prevent future requests.
self._close.set()
with self._lock:
if not self._stub:
# Subscriber already closed.
return
req = self._unsubscribe_request()
try:
self._stub.GcsSubscriberCommandBatch(req, timeout=30)
except Exception:
pass
self._stub = None
class GcsAioPublisher:
"""Publisher to GCS. Uses async io."""
def __init__(self, address: str = None, channel: grpc.aio.Channel = None):
if address:
assert channel is None, \
"address and channel cannot both be specified"
channel = gcs_utils.create_gcs_channel(address, aio=True)
else:
assert channel is not None, \
"One of address and channel must be specified"
self._stub = gcs_service_pb2_grpc.InternalPubSubGcsServiceStub(channel)
async def publish_error(self, key_id: bytes,
error_info: ErrorTableData) -> None:
"""Publishes error info to GCS."""
msg = pubsub_pb2.PubMessage(
channel_type=pubsub_pb2.RAY_ERROR_INFO_CHANNEL,
key_id=key_id,
error_info_message=error_info)
req = gcs_service_pb2.GcsPublishRequest(pub_messages=[msg])
await self._stub.GcsPublish(req)
class GcsAioSubscriber(_SubscriberBase):
"""Async io subscriber to GCS.
Usage example:
subscriber = GcsAioSubscriber()
await subscriber.subscribe_error()
while running:
error_id, error_data = await subscriber.poll_error()
......
await subscriber.close()
"""
def __init__(self, address: str = None, channel: grpc.aio.Channel = None):
if address:
assert channel is None, \
"address and channel cannot both be specified"
channel = gcs_utils.create_gcs_channel(address, aio=True)
else:
assert channel is not None, \
"One of address and channel must be specified"
self._stub = gcs_service_pb2_grpc.InternalPubSubGcsServiceStub(channel)
self._subscriber_id = bytes(
bytearray(random.getrandbits(8) for _ in range(28)))
# Whether error info has been subscribed.
self._subscribed_error = False
# Buffer for holding error info.
self._errors = deque()
async def subscribe_error(self) -> None:
"""Registers a subscription for error info.
Before the registration, published errors will not be saved for the
subscriber.
"""
if not self._subscribed_error:
req = self._subscribe_error_request()
await self._stub.GcsSubscriberCommandBatch(req, timeout=30)
self._subscribed_error = True
async def poll_error(self, timeout=None) -> Tuple[bytes, ErrorTableData]:
"""Polls for new error messages."""
if len(self._errors) == 0:
req = self._poll_request()
reply = await self._stub.GcsSubscriberPoll(req, timeout=timeout)
for msg in reply.pub_messages:
self._errors.append((msg.key_id, msg.error_info_message))
if len(self._errors) == 0:
return None, None
return self._errors.popleft()
async def close(self) -> None:
"""Closes the subscriber and its active subscriptions."""
req = self._unsubscribe_request()
try:
await self._stub.GcsSubscriberCommandBatch(req, timeout=30)
except Exception:
pass
self._subscribed_error = False


@ -1,7 +1,10 @@
from ray.core.generated.common_pb2 import ErrorType
import enum
import logging
from typing import List
import grpc
from ray.core.generated.common_pb2 import ErrorType
from ray.core.generated import gcs_service_pb2_grpc
from ray.core.generated import gcs_service_pb2
from ray.core.generated.gcs_pb2 import (
@ -51,64 +54,69 @@ __all__ = [
"ResourceLoad",
"ResourceMap",
"ResourceTableData",
"construct_error_message",
"ObjectLocationInfo",
"PubSubMessage",
"WorkerTableData",
"PlacementGroupTableData",
]
FUNCTION_PREFIX = "RemoteFunction:"
LOG_FILE_CHANNEL = "RAY_LOG_CHANNEL"
REPORTER_CHANNEL = "RAY_REPORTER"
# xray resource usages
XRAY_RESOURCES_BATCH_PATTERN = "RESOURCES_BATCH:".encode("ascii")
# xray job updates
XRAY_JOB_PATTERN = "JOB:*".encode("ascii")
# Actor pub/sub updates
RAY_ACTOR_PUBSUB_PATTERN = "ACTOR:*".encode("ascii")
# Reporter pub/sub updates
RAY_REPORTER_PUBSUB_PATTERN = "RAY_REPORTER.*".encode("ascii")
RAY_ERROR_PUBSUB_PATTERN = "ERROR_INFO:*".encode("ascii")
# These prefixes must be kept up-to-date with the TablePrefix enum in
# gcs.proto.
# TODO(rkn): We should use scoped enums, in which case we should be able to
# just access the flatbuffer generated values.
TablePrefix_RAYLET_TASK_string = "RAYLET_TASK"
TablePrefix_OBJECT_string = "OBJECT"
TablePrefix_PROFILE_string = "PROFILE"
TablePrefix_JOB_string = "JOB"
TablePrefix_ACTOR_string = "ACTOR"
WORKER = 0
DRIVER = 1
# Cap messages at 512MB
_MAX_MESSAGE_LENGTH = 512 * 1024 * 1024
# Send keepalive every 60s
_GRPC_KEEPALIVE_TIME_MS = 60 * 1000
# Keepalive should be replied < 60s
_GRPC_KEEPALIVE_TIMEOUT_MS = 60 * 1000
def construct_error_message(job_id, error_type, message, timestamp):
"""Construct a serialized ErrorTableData object.
# Also relying on these defaults:
# grpc.keepalive_permit_without_calls=0: No keepalive without inflight calls.
# grpc.use_local_subchannel_pool=0: Subchannels are shared.
_GRPC_OPTIONS = [("grpc.enable_http_proxy",
0), ("grpc.max_send_message_length", _MAX_MESSAGE_LENGTH),
("grpc.max_receive_message_length", _MAX_MESSAGE_LENGTH),
("grpc.keepalive_time_ms",
_GRPC_KEEPALIVE_TIME_MS), ("grpc.keepalive_timeout_ms",
_GRPC_KEEPALIVE_TIMEOUT_MS)]
def get_gcs_address_from_redis(redis) -> str:
"""Reads GCS address from redis.
Args:
job_id: The ID of the job that the error should go to. If this is
nil, then the error will go to all drivers.
error_type: The type of the error.
message: The error message.
timestamp: The time of the error.
redis: Redis client to fetch GCS address.
Returns:
The serialized object.
GCS address string.
"""
data = ErrorTableData()
data.job_id = job_id.binary()
data.type = error_type
data.error_message = message
data.timestamp = timestamp
return data.SerializeToString()
gcs_address = redis.get("GcsServerAddress")
if gcs_address is None:
raise RuntimeError("Failed to look up gcs address through redis")
return gcs_address.decode()
def create_gcs_channel(address: str, aio=False):
"""Returns a GRPC channel to GCS.
Args:
address: GCS address string, e.g. ip:port
aio: Whether using grpc.aio
Returns:
grpc.Channel or grpc.aio.Channel to GCS
"""
from ray._private.utils import init_grpc_channel
return init_grpc_channel(address, options=_GRPC_OPTIONS, asynchronous=aio)
class GcsCode(enum.IntEnum):
@ -118,17 +126,13 @@ class GcsCode(enum.IntEnum):
class GcsClient:
MAX_MESSAGE_LENGTH = 512 * 1024 * 1024 # 512MB
"""Client to GCS using GRPC"""
def __init__(self, address):
from ray._private.utils import init_grpc_channel
logger.debug(f"Connecting to gcs address: {address}")
options = [("grpc.enable_http_proxy",
0), ("grpc.max_send_message_length",
GcsClient.MAX_MESSAGE_LENGTH),
("grpc.max_receive_message_length",
GcsClient.MAX_MESSAGE_LENGTH)]
channel = init_grpc_channel(address, options=options)
def __init__(self, address: str = None, channel: grpc.Channel = None):
if address:
assert channel is None, \
"Only one of address and channel can be specified"
channel = create_gcs_channel(address)
self._kv_stub = gcs_service_pb2_grpc.InternalKVGcsServiceStub(channel)
def internal_kv_get(self, key: bytes) -> bytes:
@ -187,10 +191,7 @@ class GcsClient:
@staticmethod
def create_from_redis(redis_cli):
gcs_address = redis_cli.get("GcsServerAddress")
if gcs_address is None:
raise RuntimeError("Failed to look up gcs address through redis")
return GcsClient(gcs_address.decode())
return GcsClient(get_gcs_address_from_redis(redis_cli))
@staticmethod
def connect_to_gcs_by_redis_address(redis_address, redis_password):


@ -15,6 +15,7 @@ import ray.ray_constants as ray_constants
import ray._private.gcs_utils as gcs_utils
import ray._private.services as services
import ray._private.utils
from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsPublisher
from ray._private.ray_logging import setup_component_logger
# Logger for this module. It should be configured at the entry point
@ -365,6 +366,11 @@ if __name__ == "__main__":
description=("Parse Redis server for the "
"log monitor to connect "
"to."))
parser.add_argument(
"--gcs-address",
required=False,
type=str,
help="The address (ip:port) of GCS.")
parser.add_argument(
"--redis-address",
required=True,
@ -436,11 +442,20 @@ if __name__ == "__main__":
# Something went wrong, so push an error to all drivers.
redis_client = ray._private.services.create_redis_client(
args.redis_address, password=args.redis_password)
gcs_publisher = None
if args.gcs_address:
gcs_publisher = GcsPublisher(address=args.gcs_address)
elif gcs_pubsub_enabled():
gcs_publisher = GcsPublisher(
address=gcs_utils.get_gcs_address_from_redis(redis_client))
traceback_str = ray._private.utils.format_error_message(
traceback.format_exc())
message = (f"The log monitor on node {platform.node()} "
f"failed with the following error:\n{traceback_str}")
ray._private.utils.push_error_to_driver_through_redis(
redis_client, ray_constants.LOG_MONITOR_DIED_ERROR, message)
ray._private.utils.publish_error_to_driver(
ray_constants.LOG_MONITOR_DIED_ERROR,
message,
redis_client=redis_client,
gcs_publisher=gcs_publisher)
logger.error(message)
raise e


@ -1926,6 +1926,7 @@ def start_monitor(redis_address,
Args:
redis_address (str): The address that the Redis server is listening on.
gcs_address (str): The address of GCS server.
logs_dir(str): The path to the log directory.
stdout_file: A file handle opened for writing to redirect stdout to. If
no redirection should happen, then this should be None.


@ -26,6 +26,7 @@ import ray._private.gcs_utils as gcs_utils
import ray._private.memory_monitor as memory_monitor
from ray.core.generated import node_manager_pb2
from ray.core.generated import node_manager_pb2_grpc
from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsSubscriber
from ray._private.tls_utils import generate_self_signed_tls_certs
from ray.util.queue import Queue, _QueueActor, Empty
from ray.scripts.scripts import main as ray_main
@ -502,19 +503,34 @@ def get_non_head_nodes(cluster):
def init_error_pubsub():
"""Initialize redis error info pub/sub"""
p = ray.worker.global_worker.redis_client.pubsub(
if gcs_pubsub_enabled():
s = GcsSubscriber(channel=ray.worker.global_worker.gcs_channel)
s.subscribe_error()
else:
s = ray.worker.global_worker.redis_client.pubsub(
ignore_subscribe_messages=True)
error_pubsub_channel = gcs_utils.RAY_ERROR_PUBSUB_PATTERN
p.psubscribe(error_pubsub_channel)
return p
s.psubscribe(gcs_utils.RAY_ERROR_PUBSUB_PATTERN)
return s
def get_error_message(pub_sub, num, error_type=None, timeout=20):
"""Get errors through pub/sub."""
start_time = time.time()
def get_error_message(subscriber, num, error_type=None, timeout=20):
"""Get errors through subscriber."""
deadline = time.time() + timeout
msgs = []
while time.time() - start_time < timeout and len(msgs) < num:
msg = pub_sub.get_message()
while time.time() < deadline and len(msgs) < num:
if isinstance(subscriber, GcsSubscriber):
try:
_, error_data = subscriber.poll_error(timeout=deadline -
time.time())
except grpc.RpcError as e:
# Failed to match error message before timeout.
if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
logging.warning("get_error_message() timed out")
return []
# Otherwise, the error is unexpected.
raise
else:
msg = subscriber.get_message()
if msg is None:
time.sleep(0.01)
continue


@ -27,6 +27,7 @@ import numpy as np
import ray
import ray._private.gcs_utils as gcs_utils
import ray.ray_constants as ray_constants
from ray._private.gcs_pubsub import construct_error_message
from ray._private.tls_utils import load_certs_from_env
# Import psutil after ray so the packaged version is used.
@ -113,10 +114,11 @@ def push_error_to_driver(worker, error_type, message, job_id=None):
worker.core_worker.push_error(job_id, error_type, message, time.time())
def push_error_to_driver_through_redis(redis_client,
error_type,
def publish_error_to_driver(error_type,
message,
job_id=None):
job_id=None,
redis_client=None,
gcs_publisher=None):
"""Push an error message to the driver to be printed in the background.
Normally the push_error_to_driver function should be used. However, in some
@ -125,25 +127,31 @@ def push_error_to_driver_through_redis(redis_client,
backend processes.
Args:
redis_client: The redis client to use.
error_type (str): The type of the error.
message (str): The message that will be printed in the background
on the driver.
job_id: The ID of the driver to push the error message to. If this
is None, then the message will be pushed to all drivers.
redis_client: The redis client to use.
gcs_publisher: The GCS publisher to use. If specified, ignores
redis_client.
"""
if job_id is None:
job_id = ray.JobID.nil()
assert isinstance(job_id, ray.JobID)
# Do everything in Python and through the Python Redis client instead
# of through the raylet.
error_data = gcs_utils.construct_error_message(job_id, error_type, message,
error_data = construct_error_message(job_id, error_type, message,
time.time())
if gcs_publisher:
gcs_publisher.publish_error(job_id.hex().encode(), error_data)
elif redis_client:
pubsub_msg = gcs_utils.PubSubMessage()
pubsub_msg.id = job_id.binary()
pubsub_msg.data = error_data
pubsub_msg.data = error_data.SerializeToString()
redis_client.publish("ERROR_INFO:" + job_id.hex(),
pubsub_msg.SerializeToString())
else:
raise ValueError(
"One of redis_client and gcs_publisher needs to be specified!")
def random_string():


@ -1,7 +1,6 @@
"""Autoscaler monitoring loop daemon."""
import argparse
import logging
import logging.handlers
import os
import sys
@ -35,7 +34,8 @@ from ray.autoscaler._private.util import DEBUG_AUTOSCALING_STATUS, \
from ray.core.generated import gcs_service_pb2, gcs_service_pb2_grpc
import ray.ray_constants as ray_constants
from ray._private.ray_logging import setup_component_logger
from ray._private.gcs_utils import GcsClient
from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsPublisher
from ray._private.gcs_utils import GcsClient, get_gcs_address_from_redis
from ray.experimental.internal_kv import _initialize_internal_kv, \
_internal_kv_put, _internal_kv_initialized, _internal_kv_get, \
_internal_kv_del
@ -409,9 +409,18 @@ class Monitor:
_internal_kv_put(DEBUG_AUTOSCALING_ERROR, message, overwrite=True)
redis_client = ray._private.services.create_redis_client(
self.redis_address, password=self.redis_password)
from ray._private.utils import push_error_to_driver_through_redis
push_error_to_driver_through_redis(
redis_client, ray_constants.MONITOR_DIED_ERROR, message)
gcs_publisher = None
if args.gcs_address:
gcs_publisher = GcsPublisher(address=args.gcs_address)
elif gcs_pubsub_enabled():
gcs_publisher = GcsPublisher(
address=get_gcs_address_from_redis(redis_client))
from ray._private.utils import publish_error_to_driver
publish_error_to_driver(
ray_constants.MONITOR_DIED_ERROR,
message,
redis_client=redis_client,
gcs_publisher=gcs_publisher)
def _signal_handler(self, sig, frame):
self._handle_failure(f"Terminated with signal {sig}\n" +
@ -437,6 +446,11 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(
description=("Parse Redis server for the "
"monitor to connect to."))
parser.add_argument(
"--gcs-address",
required=False,
type=str,
help="The address (ip:port) of GCS.")
parser.add_argument(
"--redis-address",
required=True,


@ -28,6 +28,7 @@ py_test_module_list(
"test_component_failures_2.py",
"test_component_failures_3.py",
"test_error_ray_not_initialized.py",
"test_gcs_pubsub.py",
"test_global_gc.py",
"test_grpc_client_credentials.py",
"test_iter.py",


@ -11,6 +11,7 @@ import ray._private.utils
import ray._private.gcs_utils as gcs_utils
import ray.ray_constants as ray_constants
from ray.exceptions import RayTaskError, RayActorError, GetTimeoutError
from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsPublisher
from ray._private.test_utils import (wait_for_condition, SignalActor,
init_error_pubsub, get_error_message)
@ -61,14 +62,21 @@ def test_unhandled_errors(ray_start_regular):
del os.environ["RAY_IGNORE_UNHANDLED_ERRORS"]
def test_push_error_to_driver_through_redis(ray_start_regular, error_pubsub):
def test_publish_error_to_driver(ray_start_regular, error_pubsub):
address_info = ray_start_regular
address = address_info["redis_address"]
redis_client = ray._private.services.create_redis_client(
address, password=ray.ray_constants.REDIS_DEFAULT_PASSWORD)
gcs_publisher = None
if gcs_pubsub_enabled():
gcs_publisher = GcsPublisher(
address=gcs_utils.get_gcs_address_from_redis(redis_client))
error_message = "Test error message"
ray._private.utils.push_error_to_driver_through_redis(
redis_client, ray_constants.DASHBOARD_AGENT_DIED_ERROR, error_message)
ray._private.utils.publish_error_to_driver(
ray_constants.DASHBOARD_AGENT_DIED_ERROR,
error_message,
redis_client=redis_client,
gcs_publisher=gcs_publisher)
errors = get_error_message(error_pubsub, 1,
ray_constants.DASHBOARD_AGENT_DIED_ERROR)
assert errors[0].type == ray_constants.DASHBOARD_AGENT_DIED_ERROR


@ -104,9 +104,8 @@ def test_connect_with_disconnected_node(shutdown_only):
ray.init(address=cluster.address)
p = init_error_pubsub()
errors = get_error_message(p, 1, timeout=5)
print(errors)
assert len(errors) == 0
# This node is killed by SIGKILL, ray_monitor will mark it to dead.
# This node will be killed by SIGKILL; ray_monitor will mark it as dead.
dead_node = cluster.add_node(num_cpus=0)
cluster.remove_node(dead_node, allow_graceful=False)
errors = get_error_message(p, 1, ray_constants.REMOVED_NODE_ERROR)


@ -0,0 +1,73 @@
import sys
import ray
import ray._private.gcs_utils as gcs_utils
from ray._private.gcs_pubsub import GcsPublisher, GcsSubscriber, \
GcsAioPublisher, GcsAioSubscriber
from ray.core.generated.gcs_pb2 import ErrorTableData
import pytest
@pytest.mark.parametrize(
"ray_start_regular", [{
"_system_config": {
"gcs_grpc_based_pubsub": True
}
}],
indirect=True)
def test_publish_and_subscribe_error_info(ray_start_regular):
address_info = ray_start_regular
redis = ray._private.services.create_redis_client(
address_info["redis_address"],
password=ray.ray_constants.REDIS_DEFAULT_PASSWORD)
gcs_server_addr = gcs_utils.get_gcs_address_from_redis(redis)
subscriber = GcsSubscriber(address=gcs_server_addr)
subscriber.subscribe_error()
publisher = GcsPublisher(address=gcs_server_addr)
err1 = ErrorTableData(error_message="test error message 1")
err2 = ErrorTableData(error_message="test error message 2")
publisher.publish_error(b"aaa_id", err1)
publisher.publish_error(b"bbb_id", err2)
assert subscriber.poll_error() == (b"aaa_id", err1)
assert subscriber.poll_error() == (b"bbb_id", err2)
subscriber.close()
@pytest.mark.asyncio
@pytest.mark.parametrize(
"ray_start_regular", [{
"_system_config": {
"gcs_grpc_based_pubsub": True
}
}],
indirect=True)
async def test_aio_publish_and_subscribe_error_info(ray_start_regular):
address_info = ray_start_regular
redis = ray._private.services.create_redis_client(
address_info["redis_address"],
password=ray.ray_constants.REDIS_DEFAULT_PASSWORD)
gcs_server_addr = gcs_utils.get_gcs_address_from_redis(redis)
subscriber = GcsAioSubscriber(address=gcs_server_addr)
await subscriber.subscribe_error()
publisher = GcsAioPublisher(address=gcs_server_addr)
err1 = ErrorTableData(error_message="test error message 1")
err2 = ErrorTableData(error_message="test error message 2")
await publisher.publish_error(b"aaa_id", err1)
await publisher.publish_error(b"bbb_id", err2)
assert await subscriber.poll_error() == (b"aaa_id", err1)
assert await subscriber.poll_error() == (b"bbb_id", err2)
await subscriber.close()
if __name__ == "__main__":
sys.exit(pytest.main(["-v", __file__]))


@ -27,6 +27,8 @@ import ray.remote_function
import ray.serialization as serialization
import ray._private.gcs_utils as gcs_utils
import ray._private.services as services
from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsPublisher, \
GcsSubscriber
from ray._private.runtime_env.py_modules import upload_py_modules_if_needed
from ray._private.runtime_env.working_dir import upload_working_dir_if_needed
from ray._private.runtime_env.constants import RAY_JOB_CONFIG_JSON_ENV_VAR
@ -1247,6 +1249,58 @@ def listen_error_messages_raylet(worker, threads_stopped):
worker.error_message_pubsub_client.close()
def listen_error_messages_from_gcs(worker, threads_stopped):
"""Listen to error messages in the background on the driver.
This runs in a separate thread on the driver and pushes (error, time)
tuples to be published.
Args:
worker: The worker class that this thread belongs to.
threads_stopped (threading.Event): A threading event used to signal to
the thread that it should exit.
"""
worker.gcs_subscriber = GcsSubscriber(channel=worker.gcs_channel)
# Exports that are published after the call to
# gcs_subscriber.subscribe_error() and before the call to
# gcs_subscriber.poll_error() will still be processed in the loop.
# TODO: we should just subscribe to the errors for this specific job.
worker.gcs_subscriber.subscribe_error()
try:
if _internal_kv_initialized():
# Get any autoscaler errors that occurred before the call to
# subscribe.
error_message = _internal_kv_get(DEBUG_AUTOSCALING_ERROR)
if error_message is not None:
logger.warning(error_message.decode())
while True:
# Exit if received a signal that the thread should stop.
if threads_stopped.is_set():
return
_, error_data = worker.gcs_subscriber.poll_error()
if error_data is None:
continue
if error_data.job_id not in [
worker.current_job_id.binary(),
JobID.nil().binary(),
]:
continue
error_message = error_data.error_message
if error_data.type == ray_constants.TASK_PUSH_ERROR:
# TODO(ekl) remove task push errors entirely now that we have
# the separate unhandled exception handler.
pass
else:
logger.warning(error_message)
except (OSError, ConnectionError) as e:
logger.error(f"listen_error_messages_from_gcs: {e}")
@PublicAPI
@client_mode_hook(auto_init=False)
def is_initialized() -> bool:
@ -1307,11 +1361,16 @@ def connect(node,
# that is not true of Redis pubsub clients. See the documentation at
# https://github.com/andymccurdy/redis-py#thread-safety.
worker.redis_client = node.create_redis_client()
worker.gcs_client = gcs_utils.GcsClient.create_from_redis(
worker.redis_client)
worker.gcs_channel = gcs_utils.create_gcs_channel(
gcs_utils.get_gcs_address_from_redis(worker.redis_client))
worker.gcs_client = gcs_utils.GcsClient(channel=worker.gcs_channel)
_initialize_internal_kv(worker.gcs_client)
ray.state.state._initialize_global_state(
node.redis_address, redis_password=node.redis_password)
worker.gcs_pubsub_enabled = gcs_pubsub_enabled()
worker.gcs_publisher = None
if worker.gcs_pubsub_enabled:
worker.gcs_publisher = GcsPublisher(channel=worker.gcs_channel)
# Initialize some fields.
if mode in (WORKER_MODE, RESTORE_WORKER_MODE, SPILL_WORKER_MODE):
@ -1350,11 +1409,12 @@ def connect(node,
raise e
elif mode == WORKER_MODE:
traceback_str = traceback.format_exc()
ray._private.utils.push_error_to_driver_through_redis(
worker.redis_client,
ray._private.utils.publish_error_to_driver(
ray_constants.VERSION_MISMATCH_PUSH_ERROR,
traceback_str,
job_id=None)
job_id=None,
redis_client=worker.redis_client,
gcs_publisher=worker.gcs_publisher)
worker.lock = threading.RLock()
@ -1445,7 +1505,8 @@ def connect(node,
# scheduler for new error messages.
if mode == SCRIPT_MODE:
worker.listener_thread = threading.Thread(
target=listen_error_messages_raylet,
target=listen_error_messages_from_gcs
if worker.gcs_pubsub_enabled else listen_error_messages_raylet,
name="ray_listen_error_messages",
args=(worker, worker.threads_stopped))
worker.listener_thread.daemon = True
@ -1519,6 +1580,8 @@ def disconnect(exiting_interpreter=False):
if hasattr(worker, "import_thread"):
worker.import_thread.join_import_thread()
if hasattr(worker, "listener_thread"):
if hasattr(worker, "gcs_subscriber"):
worker.gcs_subscriber.close()
worker.listener_thread.join()
if hasattr(worker, "logger_thread"):
worker.logger_thread.join()


@ -116,17 +116,6 @@ class MockGcsTaskReconstructionTable : public GcsTaskReconstructionTable {
namespace ray {
namespace gcs {
class MockGcsObjectTable : public GcsObjectTable {
public:
MOCK_METHOD(JobID, GetJobIdFromKey, (const ObjectID &key), (override));
};
} // namespace gcs
} // namespace ray
namespace ray {
namespace gcs {
class MockGcsNodeTable : public GcsNodeTable {
public:
};


@ -72,7 +72,8 @@ void GcsServer::Start() {
rpc::ChannelType::GCS_JOB_CHANNEL,
rpc::ChannelType::GCS_NODE_INFO_CHANNEL,
rpc::ChannelType::GCS_NODE_RESOURCE_CHANNEL,
rpc::ChannelType::GCS_WORKER_DELTA_CHANNEL},
rpc::ChannelType::GCS_WORKER_DELTA_CHANNEL,
rpc::ChannelType::RAY_ERROR_INFO_CHANNEL},
/*periodical_runner=*/&pubsub_periodical_runner_,
/*get_time_ms=*/[]() { return absl::GetCurrentTimeNanos() / 1e6; },
/*subscriber_timeout_ms=*/RayConfig::instance().subscriber_timeout_ms(),


@ -213,17 +213,6 @@ class GcsTaskReconstructionTable
JobID GetJobIdFromKey(const TaskID &key) override { return key.ActorId().JobId(); }
};
class GcsObjectTable : public GcsTableWithJobId<ObjectID, ObjectLocationInfo> {
public:
explicit GcsObjectTable(std::shared_ptr<StoreClient> store_client)
: GcsTableWithJobId(std::move(store_client)) {
table_name_ = TablePrefix_Name(TablePrefix::OBJECT);
}
private:
JobID GetJobIdFromKey(const ObjectID &key) override { return key.TaskId().JobId(); }
};
class GcsNodeTable : public GcsTable<NodeID, GcsNodeInfo> {
public:
explicit GcsNodeTable(std::shared_ptr<StoreClient> store_client)


@ -17,12 +17,35 @@
namespace ray {
namespace gcs {
void InternalPubSubHandler::HandleGcsPublish(const rpc::GcsPublishRequest &request,
rpc::GcsPublishReply *reply,
rpc::SendReplyCallback send_reply_callback) {
if (gcs_publisher_ == nullptr) {
send_reply_callback(
Status::NotImplemented("GCS pubsub is not yet enabled. Please enable it with "
"system config `gcs_grpc_based_pubsub=True`"),
nullptr, nullptr);
return;
}
for (const auto &msg : request.pub_messages()) {
gcs_publisher_->GetPublisher()->Publish(msg);
}
send_reply_callback(Status::OK(), nullptr, nullptr);
}
// Needs to use rpc::GcsSubscriberPollRequest and rpc::GcsSubscriberPollReply here,
// and convert the reply to rpc::PubsubLongPollingReply because GCS RPC services are
// required to have the `status` field in replies.
void InternalPubSubHandler::HandleGcsSubscriberPoll(
const rpc::GcsSubscriberPollRequest &request, rpc::GcsSubscriberPollReply *reply,
rpc::SendReplyCallback send_reply_callback) {
if (gcs_publisher_ == nullptr) {
send_reply_callback(
Status::NotImplemented("GCS pubsub is not yet enabled. Please enable it with "
"system config `gcs_grpc_based_pubsub=True`"),
nullptr, nullptr);
return;
}
const auto subscriber_id = UniqueID::FromBinary(request.subscriber_id());
auto pubsub_reply = std::make_shared<rpc::PubsubLongPollingReply>();
auto pubsub_reply_ptr = pubsub_reply.get();
@ -44,6 +67,13 @@ void InternalPubSubHandler::HandleGcsSubscriberCommandBatch(
const rpc::GcsSubscriberCommandBatchRequest &request,
rpc::GcsSubscriberCommandBatchReply *reply,
rpc::SendReplyCallback send_reply_callback) {
if (gcs_publisher_ == nullptr) {
send_reply_callback(
Status::NotImplemented("GCS pubsub is not yet enabled. Please enable it with "
"system config `gcs_grpc_based_pubsub=True`"),
nullptr, nullptr);
return;
}
const auto subscriber_id = UniqueID::FromBinary(request.subscriber_id());
for (const auto &command : request.commands()) {
if (command.has_unsubscribe_message()) {


@ -16,6 +16,7 @@
#include "ray/gcs/pubsub/gcs_pub_sub.h"
#include "ray/rpc/gcs_server/gcs_rpc_server.h"
#include "src/ray/protobuf/gcs_service.grpc.pb.h"
namespace ray {
namespace gcs {
@ -28,6 +29,10 @@ class InternalPubSubHandler : public rpc::InternalPubSubHandler {
explicit InternalPubSubHandler(const std::shared_ptr<gcs::GcsPublisher> &gcs_publisher)
: gcs_publisher_(gcs_publisher) {}
void HandleGcsPublish(const rpc::GcsPublishRequest &request,
rpc::GcsPublishReply *reply,
rpc::SendReplyCallback send_reply_callback) final;
void HandleGcsSubscriberPoll(const rpc::GcsSubscriberPollRequest &request,
rpc::GcsSubscriberPollReply *reply,
rpc::SendReplyCallback send_reply_callback) final;


@ -302,6 +302,17 @@ Status GcsPublisher::PublishTaskLease(const TaskID &id, const rpc::TaskLeaseData
Status GcsPublisher::PublishError(const std::string &id,
const rpc::ErrorTableData &message,
const StatusCallback &done) {
if (publisher_ != nullptr) {
rpc::PubMessage msg;
msg.set_channel_type(rpc::ChannelType::RAY_ERROR_INFO_CHANNEL);
msg.set_key_id(id);
*msg.mutable_error_info_message() = message;
publisher_->Publish(msg);
if (done != nullptr) {
done(Status::OK());
}
return Status::OK();
}
return pubsub_->Publish(ERROR_INFO_CHANNEL, id, message.SerializeAsString(), done);
}


@ -563,6 +563,16 @@ service InternalKVGcsService {
rpc InternalKVKeys(InternalKVKeysRequest) returns (InternalKVKeysReply);
}
message GcsPublishRequest {
/// The messages that are published.
repeated PubMessage pub_messages = 1;
}
message GcsPublishReply {
// Not populated.
GcsStatus status = 100;
}
message GcsSubscriberPollRequest {
/// The id of the subscriber.
bytes subscriber_id = 1;
@ -590,6 +600,9 @@ message GcsSubscriberCommandBatchReply {
/// This supports subscribing updates from GCS with long poll, and registering /
/// de-registering subscribers.
service InternalPubSubGcsService {
/// The request sent to GCS to publish messages.
/// Currently only supporting error info, logs and Python function messages.
rpc GcsPublish(GcsPublishRequest) returns (GcsPublishReply);
/// The long polling request sent to GCS for pubsub operations.
/// The long poll request will be replied once there are a batch of messages that
/// need to be published to the caller (subscriber).

View file

@ -40,6 +40,8 @@ enum ChannelType {
GCS_NODE_RESOURCE_CHANNEL = 6;
/// A channel for worker changes, currently only for worker failures.
GCS_WORKER_DELTA_CHANNEL = 7;
/// A channel for errors from various Ray components.
RAY_ERROR_INFO_CHANNEL = 8;
}
///
@ -61,6 +63,7 @@ message PubMessage {
GcsNodeInfo node_info_message = 9;
NodeResourceChange node_resource_message = 10;
WorkerDeltaData worker_delta_message = 11;
ErrorTableData error_info_message = 12;
// The message that indicates the given key id is not available anymore.
FailureMessage failure_message = 6;


@ -145,6 +145,10 @@ bool SubscriptionIndex::CheckNoLeaks() const {
bool Subscriber::ConnectToSubscriber(rpc::PubsubLongPollingReply *reply,
rpc::SendReplyCallback send_reply_callback) {
if (long_polling_connection_) {
// Flush the current subscriber poll with an empty reply.
PublishIfPossible(/*force_noop=*/true);
}
if (!long_polling_connection_) {
RAY_CHECK(reply != nullptr);
RAY_CHECK(send_reply_callback != nullptr);
@ -171,30 +175,25 @@ void Subscriber::QueueMessage(const rpc::PubMessage &pub_message, bool try_publi
}
}
bool Subscriber::PublishIfPossible(bool force) {
bool Subscriber::PublishIfPossible(bool force_noop) {
if (!long_polling_connection_) {
return false;
}
if (force || mailbox_.size() > 0) {
// If force publish is invoked, mailbox could be empty. We should always add a reply
// here because otherwise, there could be memory leak due to our grpc layer
// implementation.
if (mailbox_.empty()) {
mailbox_.push(absl::make_unique<rpc::PubsubLongPollingReply>());
if (!force_noop && mailbox_.empty()) {
return false;
}
if (!force_noop) {
// Reply to the long polling subscriber. Swap the reply here to avoid extra copy.
long_polling_connection_->reply->Swap(mailbox_.front().get());
mailbox_.pop();
}
long_polling_connection_->send_reply_callback(Status::OK(), nullptr, nullptr);
// Clean up & update metadata.
long_polling_connection_.reset();
mailbox_.pop();
last_connection_update_time_ms_ = get_time_ms_();
return true;
}
return false;
}
bool Subscriber::CheckNoLeaks() const {
@ -311,7 +310,7 @@ int Publisher::UnregisterSubscriberInternal(const SubscriberID &subscriber_id) {
}
auto &subscriber = it->second;
// Remove the long polling connection because otherwise, there's memory leak.
subscriber->PublishIfPossible(/*force=*/true);
subscriber->PublishIfPossible(/*force_noop=*/true);
subscribers_.erase(it);
return erased;
}
@ -330,8 +329,8 @@ void Publisher::CheckDeadSubscribers() {
if (disconnected) {
dead_subscribers.push_back(it.first);
} else if (active_connection_timed_out) {
// Refresh the long polling connection. The subscriber will send it again.
subscriber->PublishIfPossible(/*force*/ true);
// Refresh the long polling connection. The subscriber will poll again.
subscriber->PublishIfPossible(/*force_noop*/ true);
}
}


@ -123,10 +123,11 @@ class Subscriber {
/// Publish all queued messages if possible.
///
/// \param force If true, we publish to the subscriber although there's no queued
/// message.
/// \param force_noop If true, reply to the subscriber with an empty message, regardless
/// of whether there is any queued message. This is for cases where the current poll
/// might have been cancelled, or the subscriber might be dead.
/// \return True if it publishes. False otherwise.
bool PublishIfPossible(bool force = false);
bool PublishIfPossible(bool force_noop = false);
/// Testing only. Return true if there's no metadata remained in the private attribute.
bool CheckNoLeaks() const;


@ -285,8 +285,8 @@ TEST_F(PublisherTest, TestSubscriber) {
ASSERT_FALSE(subscriber->PublishIfPossible());
// Try connecting it. Should return true.
ASSERT_TRUE(subscriber->ConnectToSubscriber(&reply, send_reply_callback));
// If connecting it again, it should fail the request.
ASSERT_FALSE(subscriber->ConnectToSubscriber(&reply, send_reply_callback));
// Polling when there is already an inflight polling request should still work.
ASSERT_TRUE(subscriber->ConnectToSubscriber(&reply, send_reply_callback));
// Since there's no published objects, it should return false.
ASSERT_FALSE(subscriber->PublishIfPossible());


@ -277,6 +277,8 @@ class GcsRpcClient {
internal_kv_grpc_client_, )
/// Operations for pubsub
VOID_GCS_RPC_CLIENT_METHOD(InternalPubSubGcsService, GcsPublish,
internal_pubsub_grpc_client_, )
VOID_GCS_RPC_CLIENT_METHOD(InternalPubSubGcsService, GcsSubscriberPoll,
internal_pubsub_grpc_client_, )
VOID_GCS_RPC_CLIENT_METHOD(InternalPubSubGcsService, GcsSubscriberCommandBatch,


@ -606,6 +606,9 @@ class InternalPubSubGcsServiceHandler {
public:
virtual ~InternalPubSubGcsServiceHandler() = default;
virtual void HandleGcsPublish(const GcsPublishRequest &request, GcsPublishReply *reply,
SendReplyCallback send_reply_callback) = 0;
virtual void HandleGcsSubscriberPoll(const GcsSubscriberPollRequest &request,
GcsSubscriberPollReply *reply,
SendReplyCallback send_reply_callback) = 0;
@ -626,6 +629,7 @@ class InternalPubSubGrpcService : public GrpcService {
void InitServerCallFactories(
const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
std::vector<std::unique_ptr<ServerCallFactory>> *server_call_factories) override {
INTERNAL_PUBSUB_SERVICE_RPC_HANDLER(GcsPublish);
INTERNAL_PUBSUB_SERVICE_RPC_HANDLER(GcsSubscriberPoll);
INTERNAL_PUBSUB_SERVICE_RPC_HANDLER(GcsSubscriberCommandBatch);
}