[Core][Pubsub] Implement Python GCS publisher and subscriber (#20111)

## Why are these changes needed?
This change adds Python publisher and subscriber in `gcs_utils.py`, and GRPC handler on GCS for publishing iva GCS. Error info is migrated to use the GCS-based pubsub, if feature flag `RAY_gcs_grpc_based_pubsub=true`.

Also, add a `--gcs-address` flag to some Python processes. It is not set anywhere yet, but will be set aftering Redis-less bootstrapping work.

Unit tests are added for the Python publisher and subscriber. Migrated error info publishers and subscribers are tested with existing unit tests, e.g. tests calling `ray._private.test_utils.get_error_message()` to ensure error info is published.

GCS based pubsub has gaps in handling deadline, cancelled requests and GCS restarts. So 3 more unit tests are disabled in the `HA GCS` mode. They will be addressed in a separate change.

## Related issue number
This commit is contained in:
mwtian 2021-11-11 14:59:57 -08:00 committed by GitHub
parent fca851eef5
commit 0330852baf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
31 changed files with 763 additions and 188 deletions

View file

@ -343,6 +343,7 @@
--test_env=RAY_gcs_grpc_based_pubsub=true --test_env=RAY_gcs_grpc_based_pubsub=true
-- //python/ray/tests/... -- //python/ray/tests/...
-//python/ray/tests:test_failure_2 -//python/ray/tests:test_failure_2
-//python/ray/tests:test_job
- bazel test --config=ci $(./scripts/bazel_export_options) - bazel test --config=ci $(./scripts/bazel_export_options)
--test_tag_filters=-kubernetes,client_tests,-flaky --test_tag_filters=-kubernetes,client_tests,-flaky
--test_env=RAY_CLIENT_MODE=1 --test_env=RAY_PROFILING=1 --test_env=RAY_CLIENT_MODE=1 --test_env=RAY_PROFILING=1
@ -358,6 +359,7 @@
-- //python/ray/tests/... -- //python/ray/tests/...
-//python/ray/tests:test_client_multi -//python/ray/tests:test_component_failures_3 -//python/ray/tests:test_client_multi -//python/ray/tests:test_component_failures_3
-//python/ray/tests:test_healthcheck -//python/ray/tests:test_gcs_fault_tolerance -//python/ray/tests:test_healthcheck -//python/ray/tests:test_gcs_fault_tolerance
-//python/ray/tests:test_client
- label: ":redis: HA GCS (Medium K-Z)" - label: ":redis: HA GCS (Medium K-Z)"
conditions: ["RAY_CI_PYTHON_AFFECTED"] conditions: ["RAY_CI_PYTHON_AFFECTED"]
commands: commands:
@ -368,6 +370,7 @@
-- //python/ray/tests/... -- //python/ray/tests/...
-//python/ray/tests:test_multinode_failures_2 -//python/ray/tests:test_ray_debugger -//python/ray/tests:test_multinode_failures_2 -//python/ray/tests:test_ray_debugger
-//python/ray/tests:test_placement_group_2 -//python/ray/tests:test_placement_group_3 -//python/ray/tests:test_placement_group_2 -//python/ray/tests:test_placement_group_3
-//python/ray/tests:test_multi_node
- label: ":brain: RLlib: Learning discr. actions TF2-static-graph (from rllib/tuned_examples/*.yaml)" - label: ":brain: RLlib: Learning discr. actions TF2-static-graph (from rllib/tuned_examples/*.yaml)"
conditions: ["RAY_CI_RLLIB_AFFECTED"] conditions: ["RAY_CI_RLLIB_AFFECTED"]

View file

@ -19,7 +19,9 @@ import ray.dashboard.utils as dashboard_utils
import ray.ray_constants as ray_constants import ray.ray_constants as ray_constants
import ray._private.services import ray._private.services
import ray._private.utils import ray._private.utils
from ray._private.gcs_utils import GcsClient from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsPublisher
from ray._private.gcs_utils import GcsClient, \
get_gcs_address_from_redis
from ray.core.generated import agent_manager_pb2 from ray.core.generated import agent_manager_pb2
from ray.core.generated import agent_manager_pb2_grpc from ray.core.generated import agent_manager_pb2_grpc
from ray._private.ray_logging import setup_component_logger from ray._private.ray_logging import setup_component_logger
@ -228,6 +230,11 @@ if __name__ == "__main__":
required=True, required=True,
type=str, type=str,
help="the IP address of this node.") help="the IP address of this node.")
parser.add_argument(
"--gcs-address",
required=False,
type=str,
help="The address (ip:port) of GCS.")
parser.add_argument( parser.add_argument(
"--redis-address", "--redis-address",
required=True, required=True,
@ -377,6 +384,12 @@ if __name__ == "__main__":
# impact of the issue. # impact of the issue.
redis_client = ray._private.services.create_redis_client( redis_client = ray._private.services.create_redis_client(
args.redis_address, password=args.redis_password) args.redis_address, password=args.redis_password)
gcs_publisher = None
if args.gcs_address:
gcs_publisher = GcsPublisher(args.gcs_address)
elif gcs_pubsub_enabled():
gcs_publisher = GcsPublisher(
address=get_gcs_address_from_redis(redis_client))
traceback_str = ray._private.utils.format_error_message( traceback_str = ray._private.utils.format_error_message(
traceback.format_exc()) traceback.format_exc())
message = ( message = (
@ -390,9 +403,11 @@ if __name__ == "__main__":
"\n 3. runtime_env APIs won't work." "\n 3. runtime_env APIs won't work."
"\nCheck out the `dashboard_agent.log` to see the " "\nCheck out the `dashboard_agent.log` to see the "
"detailed failure messages.") "detailed failure messages.")
ray._private.utils.push_error_to_driver_through_redis( ray._private.utils.publish_error_to_driver(
redis_client, ray_constants.DASHBOARD_AGENT_DIED_ERROR, ray_constants.DASHBOARD_AGENT_DIED_ERROR,
message) message,
redis_client=redis_client,
gcs_publisher=gcs_publisher)
logger.error(message) logger.error(message)
logger.exception(e) logger.exception(e)
exit(1) exit(1)

View file

@ -13,8 +13,10 @@ import ray.dashboard.consts as dashboard_consts
import ray.dashboard.head as dashboard_head import ray.dashboard.head as dashboard_head
import ray.dashboard.utils as dashboard_utils import ray.dashboard.utils as dashboard_utils
import ray.ray_constants as ray_constants import ray.ray_constants as ray_constants
import ray._private.gcs_utils as gcs_utils
import ray._private.services import ray._private.services
import ray._private.utils import ray._private.utils
from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsPublisher
from ray._private.ray_logging import setup_component_logger from ray._private.ray_logging import setup_component_logger
from ray._private.metrics_agent import PrometheusServiceDiscoveryWriter from ray._private.metrics_agent import PrometheusServiceDiscoveryWriter
@ -134,6 +136,11 @@ if __name__ == "__main__":
type=int, type=int,
default=0, default=0,
help="The retry times to select a valid port.") help="The retry times to select a valid port.")
parser.add_argument(
"--gcs-address",
required=False,
type=str,
help="The address (ip:port) of GCS.")
parser.add_argument( parser.add_argument(
"--redis-address", "--redis-address",
required=True, required=True,
@ -223,18 +230,29 @@ if __name__ == "__main__":
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
loop.run_until_complete(dashboard.run()) loop.run_until_complete(dashboard.run())
except Exception as e: except Exception as e:
# Something went wrong, so push an error to all drivers.
redis_client = ray._private.services.create_redis_client(
args.redis_address, password=args.redis_password)
traceback_str = ray._private.utils.format_error_message( traceback_str = ray._private.utils.format_error_message(
traceback.format_exc()) traceback.format_exc())
message = f"The dashboard on node {platform.uname()[1]} " \ message = f"The dashboard on node {platform.uname()[1]} " \
f"failed with the following " \ f"failed with the following " \
f"error:\n{traceback_str}" f"error:\n{traceback_str}"
ray._private.utils.push_error_to_driver_through_redis(
redis_client, ray_constants.DASHBOARD_DIED_ERROR, message)
if isinstance(e, FrontendNotFoundError): if isinstance(e, FrontendNotFoundError):
logger.warning(message) logger.warning(message)
else: else:
logger.error(message) logger.error(message)
raise e raise e
# Something went wrong, so push an error to all drivers.
redis_client = ray._private.services.create_redis_client(
args.redis_address, password=args.redis_password)
gcs_publisher = None
if args.gcs_address:
gcs_publisher = GcsPublisher(address=args.gcs_address)
elif gcs_pubsub_enabled():
gcs_publisher = GcsPublisher(
address=gcs_utils.get_gcs_address_from_redis(redis_client))
ray._private.utils.publish_error_to_driver(
redis_client,
ray_constants.DASHBOARD_DIED_ERROR,
message,
redis_client=redis_client,
gcs_publisher=gcs_publisher)

View file

@ -19,6 +19,7 @@ import ray._private.services
import ray.dashboard.consts as dashboard_consts import ray.dashboard.consts as dashboard_consts
import ray.dashboard.utils as dashboard_utils import ray.dashboard.utils as dashboard_utils
from ray import ray_constants from ray import ray_constants
from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsAioSubscriber
from ray.core.generated import gcs_service_pb2 from ray.core.generated import gcs_service_pb2
from ray.core.generated import gcs_service_pb2_grpc from ray.core.generated import gcs_service_pb2_grpc
from ray.dashboard.datacenter import DataOrganizer from ray.dashboard.datacenter import DataOrganizer
@ -44,8 +45,8 @@ GRPC_CHANNEL_OPTIONS = (
async def get_gcs_address_with_retry(redis_client) -> str: async def get_gcs_address_with_retry(redis_client) -> str:
while True: while True:
try: try:
gcs_address = await redis_client.get( gcs_address = (await redis_client.get(
dashboard_consts.REDIS_KEY_GCS_SERVER_ADDRESS) dashboard_consts.REDIS_KEY_GCS_SERVER_ADDRESS)).decode()
if not gcs_address: if not gcs_address:
raise Exception("GCS address not found.") raise Exception("GCS address not found.")
logger.info("Connect to GCS at %s", gcs_address) logger.info("Connect to GCS at %s", gcs_address)
@ -113,6 +114,7 @@ class DashboardHead:
self.log_dir = log_dir self.log_dir = log_dir
self.aioredis_client = None self.aioredis_client = None
self.aiogrpc_gcs_channel = None self.aiogrpc_gcs_channel = None
self.gcs_subscriber = None
self.http_session = None self.http_session = None
self.ip = ray.util.get_node_ip_address() self.ip = ray.util.get_node_ip_address()
ip, port = redis_address.split(":") ip, port = redis_address.split(":")
@ -192,11 +194,15 @@ class DashboardHead:
# Waiting for GCS is ready. # Waiting for GCS is ready.
# TODO: redis-removal bootstrap # TODO: redis-removal bootstrap
gcs_address = await get_gcs_address_with_retry(self.aioredis_client) gcs_address = await get_gcs_address_with_retry(self.aioredis_client)
self.gcs_client = GcsClient(gcs_address) self.gcs_client = GcsClient(address=gcs_address)
internal_kv._initialize_internal_kv(self.gcs_client)
self.aiogrpc_gcs_channel = ray._private.utils.init_grpc_channel( self.aiogrpc_gcs_channel = ray._private.utils.init_grpc_channel(
gcs_address, GRPC_CHANNEL_OPTIONS, asynchronous=True) gcs_address, GRPC_CHANNEL_OPTIONS, asynchronous=True)
gcs_client = GcsClient(gcs_address) self.gcs_subscriber = None
internal_kv._initialize_internal_kv(gcs_client) if gcs_pubsub_enabled():
self.gcs_subscriber = GcsAioSubscriber(
channel=self.aiogrpc_gcs_channel)
await self.gcs_subscriber.subscribe_error()
self.health_check_thread = GCSHealthCheckThread(gcs_address) self.health_check_thread = GCSHealthCheckThread(gcs_address)
self.health_check_thread.start() self.health_check_thread.start()

View file

@ -1,5 +1,4 @@
NODE_STATS_UPDATE_INTERVAL_SECONDS = 1 NODE_STATS_UPDATE_INTERVAL_SECONDS = 1
ERROR_INFO_UPDATE_INTERVAL_SECONDS = 5
LOG_INFO_UPDATE_INTERVAL_SECONDS = 5 LOG_INFO_UPDATE_INTERVAL_SECONDS = 5
UPDATE_NODES_INTERVAL_SECONDS = 5 UPDATE_NODES_INTERVAL_SECONDS = 5
MAX_COUNT_OF_GCS_RPC_ERROR = 10 MAX_COUNT_OF_GCS_RPC_ERROR = 10

View file

@ -272,39 +272,51 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
logger.exception("Error receiving log info.") logger.exception("Error receiving log info.")
async def _update_error_info(self): async def _update_error_info(self):
aioredis_client = self._dashboard_head.aioredis_client def process_error(error_data):
receiver = Receiver() error_data = gcs_utils.ErrorTableData.FromString(pubsub_msg.data)
message = error_data.error_message
message = re.sub(r"\x1b\[\d+m", "", message)
match = re.search(r"\(pid=(\d+), ip=(.*?)\)", message)
if match:
pid = match.group(1)
ip = match.group(2)
errs_for_ip = dict(DataSource.ip_and_pid_to_errors.get(ip, {}))
pid_errors = list(errs_for_ip.get(pid, []))
pid_errors.append({
"message": message,
"timestamp": error_data.timestamp,
"type": error_data.type
})
errs_for_ip[pid] = pid_errors
DataSource.ip_and_pid_to_errors[ip] = errs_for_ip
logger.info(f"Received error entry for {ip} {pid}")
key = gcs_utils.RAY_ERROR_PUBSUB_PATTERN if self._dashboard_head.gcs_subscriber:
pattern = receiver.pattern(key) while True:
await aioredis_client.psubscribe(pattern) _, error_data = await \
logger.info("Subscribed to %s", key) self._dashboard_head.gcs_subscriber.poll_error()
try:
process_error(error_data)
except Exception:
logger.exception("Error receiving error info from GCS.")
else:
aioredis_client = self._dashboard_head.aioredis_client
receiver = Receiver()
async for sender, msg in receiver.iter(): key = gcs_utils.RAY_ERROR_PUBSUB_PATTERN
try: pattern = receiver.pattern(key)
_, data = msg await aioredis_client.psubscribe(pattern)
pubsub_msg = gcs_utils.PubSubMessage.FromString(data) logger.info("Subscribed to %s", key)
error_data = gcs_utils.ErrorTableData.FromString(
pubsub_msg.data) async for _, msg in receiver.iter():
message = error_data.error_message try:
message = re.sub(r"\x1b\[\d+m", "", message) _, data = msg
match = re.search(r"\(pid=(\d+), ip=(.*?)\)", message) pubsub_msg = gcs_utils.PubSubMessage.FromString(data)
if match: error_data = gcs_utils.ErrorTableData.FromString(
pid = match.group(1) pubsub_msg.data)
ip = match.group(2) process_error(error_data)
errs_for_ip = dict( except Exception:
DataSource.ip_and_pid_to_errors.get(ip, {})) logger.exception("Error receiving error info from Redis.")
pid_errors = list(errs_for_ip.get(pid, []))
pid_errors.append({
"message": message,
"timestamp": error_data.timestamp,
"type": error_data.type
})
errs_for_ip[pid] = pid_errors
DataSource.ip_and_pid_to_errors[ip] = errs_for_ip
logger.info(f"Received error entry for {ip} {pid}")
except Exception:
logger.exception("Error receiving error info.")
async def run(self, server): async def run(self, server):
gcs_channel = self._dashboard_head.aiogrpc_gcs_channel gcs_channel = self._dashboard_head.aiogrpc_gcs_channel

View file

@ -0,0 +1,276 @@
import os
from collections import deque
import logging
import random
import threading
from typing import Tuple
import grpc
import ray._private.gcs_utils as gcs_utils
from ray.core.generated import gcs_service_pb2_grpc
from ray.core.generated import gcs_service_pb2
from ray.core.generated.gcs_pb2 import (
ErrorTableData, )
from ray.core.generated import pubsub_pb2
logger = logging.getLogger(__name__)
def gcs_pubsub_enabled():
"""Checks whether GCS pubsub feature flag is enabled."""
return os.environ.get("RAY_gcs_grpc_based_pubsub") not in\
[None, "0", "false"]
def construct_error_message(job_id, error_type, message, timestamp):
"""Construct an ErrorTableData object.
Args:
job_id: The ID of the job that the error should go to. If this is
nil, then the error will go to all drivers.
error_type: The type of the error.
message: The error message.
timestamp: The time of the error.
Returns:
The ErrorTableData object.
"""
data = ErrorTableData()
data.job_id = job_id.binary()
data.type = error_type
data.error_message = message
data.timestamp = timestamp
return data
class _SubscriberBase:
def _subscribe_error_request(self):
cmd = pubsub_pb2.Command(
channel_type=pubsub_pb2.RAY_ERROR_INFO_CHANNEL,
subscribe_message={})
req = gcs_service_pb2.GcsSubscriberCommandBatchRequest(
subscriber_id=self._subscriber_id, commands=[cmd])
return req
def _poll_request(self):
return gcs_service_pb2.GcsSubscriberPollRequest(
subscriber_id=self._subscriber_id)
def _unsubscribe_request(self):
req = gcs_service_pb2.GcsSubscriberCommandBatchRequest(
subscriber_id=self._subscriber_id, commands=[])
if self._subscribed_error:
cmd = pubsub_pb2.Command(
channel_type=pubsub_pb2.RAY_ERROR_INFO_CHANNEL,
unsubscribe_message={})
req.commands.append(cmd)
return req
class GcsPublisher:
"""Publisher to GCS."""
def __init__(self, address: str = None, channel: grpc.Channel = None):
if address:
assert channel is None, \
"address and channel cannot both be specified"
channel = gcs_utils.create_gcs_channel(address)
else:
assert channel is not None, \
"One of address and channel must be specified"
self._stub = gcs_service_pb2_grpc.InternalPubSubGcsServiceStub(channel)
def publish_error(self, key_id: bytes, error_info: ErrorTableData) -> None:
"""Publishes error info to GCS."""
msg = pubsub_pb2.PubMessage(
channel_type=pubsub_pb2.RAY_ERROR_INFO_CHANNEL,
key_id=key_id,
error_info_message=error_info)
req = gcs_service_pb2.GcsPublishRequest(pub_messages=[msg])
self._stub.GcsPublish(req)
class GcsSubscriber(_SubscriberBase):
"""Subscriber to GCS. Thread safe.
Usage example:
subscriber = GcsSubscriber()
subscriber.subscribe_error()
while running:
error_id, error_data = subscriber.poll_error()
......
subscriber.close()
"""
def __init__(
self,
address: str = None,
channel: grpc.Channel = None,
):
if address:
assert channel is None, \
"address and channel cannot both be specified"
channel = gcs_utils.create_gcs_channel(address)
else:
assert channel is not None, \
"One of address and channel must be specified"
self._lock = threading.RLock()
self._stub = gcs_service_pb2_grpc.InternalPubSubGcsServiceStub(channel)
self._subscriber_id = bytes(
bytearray(random.getrandbits(8) for _ in range(28)))
# Whether error info has been subscribed.
self._subscribed_error = False
# Buffer for holding error info.
self._errors = deque()
# Future for indicating whether the subscriber has closed.
self._close = threading.Event()
def subscribe_error(self) -> None:
"""Registers a subscription for error info.
Before the registration, published errors will not be saved for the
subscriber.
"""
with self._lock:
if self._close.is_set():
return
if not self._subscribed_error:
req = self._subscribe_error_request()
self._stub.GcsSubscriberCommandBatch(req, timeout=30)
self._subscribed_error = True
def poll_error(self, timeout=None) -> Tuple[bytes, ErrorTableData]:
"""Polls for new error messages."""
with self._lock:
if self._close.is_set():
return
if len(self._errors) == 0:
req = self._poll_request()
fut = self._stub.GcsSubscriberPoll.future(req, timeout=timeout)
# Wait for result to become available, or cancel if the
# subscriber has closed.
while True:
try:
fut.result(timeout=1)
break
except grpc.FutureTimeoutError:
# Subscriber has closed. Cancel inflight the request
# and return from polling.
if self._close.is_set():
fut.cancel()
return None, None
# GRPC has not replied, continue waiting.
continue
except Exception:
# GRPC error, including deadline exceeded.
raise
if fut.done():
for msg in fut.result().pub_messages:
self._errors.append((msg.key_id,
msg.error_info_message))
if len(self._errors) == 0:
return None, None
return self._errors.popleft()
def close(self) -> None:
"""Closes the subscriber and its active subscriptions."""
# Mark close to terminate inflight polling and prevent future requests.
self._close.set()
with self._lock:
if not self._stub:
# Subscriber already closed.
return
req = self._unsubscribe_request()
try:
self._stub.GcsSubscriberCommandBatch(req, timeout=30)
except Exception:
pass
self._stub = None
class GcsAioPublisher:
"""Publisher to GCS. Uses async io."""
def __init__(self, address: str = None, channel: grpc.aio.Channel = None):
if address:
assert channel is None, \
"address and channel cannot both be specified"
channel = gcs_utils.create_gcs_channel(address, aio=True)
else:
assert channel is not None, \
"One of address and channel must be specified"
self._stub = gcs_service_pb2_grpc.InternalPubSubGcsServiceStub(channel)
async def publish_error(self, key_id: bytes,
error_info: ErrorTableData) -> None:
"""Publishes error info to GCS."""
msg = pubsub_pb2.PubMessage(
channel_type=pubsub_pb2.RAY_ERROR_INFO_CHANNEL,
key_id=key_id,
error_info_message=error_info)
req = gcs_service_pb2.GcsPublishRequest(pub_messages=[msg])
await self._stub.GcsPublish(req)
class GcsAioSubscriber(_SubscriberBase):
"""Async io subscriber to GCS.
Usage example:
subscriber = GcsAioSubscriber()
await subscriber.subscribe_error()
while running:
error_id, error_data = await subscriber.poll_error()
......
await subscriber.close()
"""
def __init__(self, address: str = None, channel: grpc.aio.Channel = None):
if address:
assert channel is None, \
"address and channel cannot both be specified"
channel = gcs_utils.create_gcs_channel(address, aio=True)
else:
assert channel is not None, \
"One of address and channel must be specified"
self._stub = gcs_service_pb2_grpc.InternalPubSubGcsServiceStub(channel)
self._subscriber_id = bytes(
bytearray(random.getrandbits(8) for _ in range(28)))
# Whether error info has been subscribed.
self._subscribed_error = False
# Buffer for holding error info.
self._errors = deque()
async def subscribe_error(self) -> None:
"""Registers a subscription for error info.
Before the registration, published errors will not be saved for the
subscriber.
"""
if not self._subscribed_error:
req = self._subscribe_error_request()
await self._stub.GcsSubscriberCommandBatch(req, timeout=30)
self._subscribed_error = True
async def poll_error(self, timeout=None) -> Tuple[bytes, ErrorTableData]:
"""Polls for new error messages."""
if len(self._errors) == 0:
req = self._poll_request()
reply = await self._stub.GcsSubscriberPoll(req, timeout=timeout)
for msg in reply.pub_messages:
self._errors.append((msg.key_id, msg.error_info_message))
if len(self._errors) == 0:
return None, None
return self._errors.popleft()
async def close(self) -> None:
"""Closes the subscriber and its active subscriptions."""
req = self._unsubscribe_request()
try:
await self._stub.GcsSubscriberCommandBatch(req, timeout=30)
except Exception:
pass
self._subscribed_error = False

View file

@ -1,7 +1,10 @@
from ray.core.generated.common_pb2 import ErrorType
import enum import enum
import logging import logging
from typing import List from typing import List
import grpc
from ray.core.generated.common_pb2 import ErrorType
from ray.core.generated import gcs_service_pb2_grpc from ray.core.generated import gcs_service_pb2_grpc
from ray.core.generated import gcs_service_pb2 from ray.core.generated import gcs_service_pb2
from ray.core.generated.gcs_pb2 import ( from ray.core.generated.gcs_pb2 import (
@ -51,64 +54,69 @@ __all__ = [
"ResourceLoad", "ResourceLoad",
"ResourceMap", "ResourceMap",
"ResourceTableData", "ResourceTableData",
"construct_error_message",
"ObjectLocationInfo", "ObjectLocationInfo",
"PubSubMessage", "PubSubMessage",
"WorkerTableData", "WorkerTableData",
"PlacementGroupTableData", "PlacementGroupTableData",
] ]
FUNCTION_PREFIX = "RemoteFunction:"
LOG_FILE_CHANNEL = "RAY_LOG_CHANNEL" LOG_FILE_CHANNEL = "RAY_LOG_CHANNEL"
REPORTER_CHANNEL = "RAY_REPORTER"
# xray resource usages
XRAY_RESOURCES_BATCH_PATTERN = "RESOURCES_BATCH:".encode("ascii")
# xray job updates
XRAY_JOB_PATTERN = "JOB:*".encode("ascii")
# Actor pub/sub updates # Actor pub/sub updates
RAY_ACTOR_PUBSUB_PATTERN = "ACTOR:*".encode("ascii") RAY_ACTOR_PUBSUB_PATTERN = "ACTOR:*".encode("ascii")
# Reporter pub/sub updates
RAY_REPORTER_PUBSUB_PATTERN = "RAY_REPORTER.*".encode("ascii")
RAY_ERROR_PUBSUB_PATTERN = "ERROR_INFO:*".encode("ascii") RAY_ERROR_PUBSUB_PATTERN = "ERROR_INFO:*".encode("ascii")
# These prefixes must be kept up-to-date with the TablePrefix enum in # These prefixes must be kept up-to-date with the TablePrefix enum in
# gcs.proto. # gcs.proto.
# TODO(rkn): We should use scoped enums, in which case we should be able to
# just access the flatbuffer generated values.
TablePrefix_RAYLET_TASK_string = "RAYLET_TASK"
TablePrefix_OBJECT_string = "OBJECT"
TablePrefix_PROFILE_string = "PROFILE"
TablePrefix_JOB_string = "JOB"
TablePrefix_ACTOR_string = "ACTOR" TablePrefix_ACTOR_string = "ACTOR"
WORKER = 0 WORKER = 0
DRIVER = 1 DRIVER = 1
# Cap messages at 512MB
_MAX_MESSAGE_LENGTH = 512 * 1024 * 1024
# Send keepalive every 60s
_GRPC_KEEPALIVE_TIME_MS = 60 * 1000
# Keepalive should be replied < 60s
_GRPC_KEEPALIVE_TIMEOUT_MS = 60 * 1000
def construct_error_message(job_id, error_type, message, timestamp): # Also relying on these defaults:
"""Construct a serialized ErrorTableData object. # grpc.keepalive_permit_without_calls=0: No keepalive without inflight calls.
# grpc.use_local_subchannel_pool=0: Subchannels are shared.
_GRPC_OPTIONS = [("grpc.enable_http_proxy",
0), ("grpc.max_send_message_length", _MAX_MESSAGE_LENGTH),
("grpc.max_receive_message_length", _MAX_MESSAGE_LENGTH),
("grpc.keepalive_time_ms",
_GRPC_KEEPALIVE_TIME_MS), ("grpc.keepalive_timeout_ms",
_GRPC_KEEPALIVE_TIMEOUT_MS)]
def get_gcs_address_from_redis(redis) -> str:
"""Reads GCS address from redis.
Args: Args:
job_id: The ID of the job that the error should go to. If this is redis: Redis client to fetch GCS address.
nil, then the error will go to all drivers.
error_type: The type of the error.
message: The error message.
timestamp: The time of the error.
Returns: Returns:
The serialized object. GCS address string.
""" """
data = ErrorTableData() gcs_address = redis.get("GcsServerAddress")
data.job_id = job_id.binary() if gcs_address is None:
data.type = error_type raise RuntimeError("Failed to look up gcs address through redis")
data.error_message = message return gcs_address.decode()
data.timestamp = timestamp
return data.SerializeToString()
def create_gcs_channel(address: str, aio=False):
"""Returns a GRPC channel to GCS.
Args:
address: GCS address string, e.g. ip:port
aio: Whether using grpc.aio
Returns:
grpc.Channel or grpc.aio.Channel to GCS
"""
from ray._private.utils import init_grpc_channel
return init_grpc_channel(address, options=_GRPC_OPTIONS, asynchronous=aio)
class GcsCode(enum.IntEnum): class GcsCode(enum.IntEnum):
@ -118,17 +126,13 @@ class GcsCode(enum.IntEnum):
class GcsClient: class GcsClient:
MAX_MESSAGE_LENGTH = 512 * 1024 * 1024 # 512MB """Client to GCS using GRPC"""
def __init__(self, address): def __init__(self, address: str = None, channel: grpc.Channel = None):
from ray._private.utils import init_grpc_channel if address:
logger.debug(f"Connecting to gcs address: {address}") assert channel is None, \
options = [("grpc.enable_http_proxy", "Only one of address and channel can be specified"
0), ("grpc.max_send_message_length", channel = create_gcs_channel(address)
GcsClient.MAX_MESSAGE_LENGTH),
("grpc.max_receive_message_length",
GcsClient.MAX_MESSAGE_LENGTH)]
channel = init_grpc_channel(address, options=options)
self._kv_stub = gcs_service_pb2_grpc.InternalKVGcsServiceStub(channel) self._kv_stub = gcs_service_pb2_grpc.InternalKVGcsServiceStub(channel)
def internal_kv_get(self, key: bytes) -> bytes: def internal_kv_get(self, key: bytes) -> bytes:
@ -187,10 +191,7 @@ class GcsClient:
@staticmethod @staticmethod
def create_from_redis(redis_cli): def create_from_redis(redis_cli):
gcs_address = redis_cli.get("GcsServerAddress") return GcsClient(get_gcs_address_from_redis(redis_cli))
if gcs_address is None:
raise RuntimeError("Failed to look up gcs address through redis")
return GcsClient(gcs_address.decode())
@staticmethod @staticmethod
def connect_to_gcs_by_redis_address(redis_address, redis_password): def connect_to_gcs_by_redis_address(redis_address, redis_password):

View file

@ -15,6 +15,7 @@ import ray.ray_constants as ray_constants
import ray._private.gcs_utils as gcs_utils import ray._private.gcs_utils as gcs_utils
import ray._private.services as services import ray._private.services as services
import ray._private.utils import ray._private.utils
from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsPublisher
from ray._private.ray_logging import setup_component_logger from ray._private.ray_logging import setup_component_logger
# Logger for this module. It should be configured at the entry point # Logger for this module. It should be configured at the entry point
@ -365,6 +366,11 @@ if __name__ == "__main__":
description=("Parse Redis server for the " description=("Parse Redis server for the "
"log monitor to connect " "log monitor to connect "
"to.")) "to."))
parser.add_argument(
"--gcs-address",
required=False,
type=str,
help="The address (ip:port) of GCS.")
parser.add_argument( parser.add_argument(
"--redis-address", "--redis-address",
required=True, required=True,
@ -436,11 +442,20 @@ if __name__ == "__main__":
# Something went wrong, so push an error to all drivers. # Something went wrong, so push an error to all drivers.
redis_client = ray._private.services.create_redis_client( redis_client = ray._private.services.create_redis_client(
args.redis_address, password=args.redis_password) args.redis_address, password=args.redis_password)
gcs_publisher = None
if args.gcs_address:
gcs_publisher = GcsPublisher(address=args.gcs_address)
elif gcs_pubsub_enabled():
gcs_publisher = GcsPublisher(
address=gcs_utils.get_gcs_address_from_redis(redis_client))
traceback_str = ray._private.utils.format_error_message( traceback_str = ray._private.utils.format_error_message(
traceback.format_exc()) traceback.format_exc())
message = (f"The log monitor on node {platform.node()} " message = (f"The log monitor on node {platform.node()} "
f"failed with the following error:\n{traceback_str}") f"failed with the following error:\n{traceback_str}")
ray._private.utils.push_error_to_driver_through_redis( ray._private.utils.publish_error_to_driver(
redis_client, ray_constants.LOG_MONITOR_DIED_ERROR, message) ray_constants.LOG_MONITOR_DIED_ERROR,
message,
redis_client=redis_client,
gcs_publisher=gcs_publisher)
logger.error(message) logger.error(message)
raise e raise e

View file

@ -1926,6 +1926,7 @@ def start_monitor(redis_address,
Args: Args:
redis_address (str): The address that the Redis server is listening on. redis_address (str): The address that the Redis server is listening on.
gcs_address (str): The address of GCS server.
logs_dir(str): The path to the log directory. logs_dir(str): The path to the log directory.
stdout_file: A file handle opened for writing to redirect stdout to. If stdout_file: A file handle opened for writing to redirect stdout to. If
no redirection should happen, then this should be None. no redirection should happen, then this should be None.

View file

@ -26,6 +26,7 @@ import ray._private.gcs_utils as gcs_utils
import ray._private.memory_monitor as memory_monitor import ray._private.memory_monitor as memory_monitor
from ray.core.generated import node_manager_pb2 from ray.core.generated import node_manager_pb2
from ray.core.generated import node_manager_pb2_grpc from ray.core.generated import node_manager_pb2_grpc
from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsSubscriber
from ray._private.tls_utils import generate_self_signed_tls_certs from ray._private.tls_utils import generate_self_signed_tls_certs
from ray.util.queue import Queue, _QueueActor, Empty from ray.util.queue import Queue, _QueueActor, Empty
from ray.scripts.scripts import main as ray_main from ray.scripts.scripts import main as ray_main
@ -502,24 +503,39 @@ def get_non_head_nodes(cluster):
def init_error_pubsub(): def init_error_pubsub():
"""Initialize redis error info pub/sub""" """Initialize redis error info pub/sub"""
p = ray.worker.global_worker.redis_client.pubsub( if gcs_pubsub_enabled():
ignore_subscribe_messages=True) s = GcsSubscriber(channel=ray.worker.global_worker.gcs_channel)
error_pubsub_channel = gcs_utils.RAY_ERROR_PUBSUB_PATTERN s.subscribe_error()
p.psubscribe(error_pubsub_channel) else:
return p s = ray.worker.global_worker.redis_client.pubsub(
ignore_subscribe_messages=True)
s.psubscribe(gcs_utils.RAY_ERROR_PUBSUB_PATTERN)
return s
def get_error_message(pub_sub, num, error_type=None, timeout=20): def get_error_message(subscriber, num, error_type=None, timeout=20):
"""Get errors through pub/sub.""" """Get errors through subscriber."""
start_time = time.time() deadline = time.time() + timeout
msgs = [] msgs = []
while time.time() - start_time < timeout and len(msgs) < num: while time.time() < deadline and len(msgs) < num:
msg = pub_sub.get_message() if isinstance(subscriber, GcsSubscriber):
if msg is None: try:
time.sleep(0.01) _, error_data = subscriber.poll_error(timeout=deadline -
continue time.time())
pubsub_msg = gcs_utils.PubSubMessage.FromString(msg["data"]) except grpc.RpcError as e:
error_data = gcs_utils.ErrorTableData.FromString(pubsub_msg.data) # Failed to match error message before timeout.
if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
logging.warning("get_error_message() timed out")
return []
# Otherwise, the error is unexpected.
raise
else:
msg = subscriber.get_message()
if msg is None:
time.sleep(0.01)
continue
pubsub_msg = gcs_utils.PubSubMessage.FromString(msg["data"])
error_data = gcs_utils.ErrorTableData.FromString(pubsub_msg.data)
if error_type is None or error_type == error_data.type: if error_type is None or error_type == error_data.type:
msgs.append(error_data) msgs.append(error_data)
else: else:

View file

@ -27,6 +27,7 @@ import numpy as np
import ray import ray
import ray._private.gcs_utils as gcs_utils import ray._private.gcs_utils as gcs_utils
import ray.ray_constants as ray_constants import ray.ray_constants as ray_constants
from ray._private.gcs_pubsub import construct_error_message
from ray._private.tls_utils import load_certs_from_env from ray._private.tls_utils import load_certs_from_env
# Import psutil after ray so the packaged version is used. # Import psutil after ray so the packaged version is used.
@ -113,10 +114,11 @@ def push_error_to_driver(worker, error_type, message, job_id=None):
worker.core_worker.push_error(job_id, error_type, message, time.time()) worker.core_worker.push_error(job_id, error_type, message, time.time())
def push_error_to_driver_through_redis(redis_client, def publish_error_to_driver(error_type,
error_type, message,
message, job_id=None,
job_id=None): redis_client=None,
gcs_publisher=None):
"""Push an error message to the driver to be printed in the background. """Push an error message to the driver to be printed in the background.
Normally the push_error_to_driver function should be used. However, in some Normally the push_error_to_driver function should be used. However, in some
@ -125,25 +127,31 @@ def push_error_to_driver_through_redis(redis_client,
backend processes. backend processes.
Args: Args:
redis_client: The redis client to use.
error_type (str): The type of the error. error_type (str): The type of the error.
message (str): The message that will be printed in the background message (str): The message that will be printed in the background
on the driver. on the driver.
job_id: The ID of the driver to push the error message to. If this job_id: The ID of the driver to push the error message to. If this
is None, then the message will be pushed to all drivers. is None, then the message will be pushed to all drivers.
redis_client: The redis client to use.
gcs_publisher: The GCS publisher to use. If specified, ignores
redis_client.
""" """
if job_id is None: if job_id is None:
job_id = ray.JobID.nil() job_id = ray.JobID.nil()
assert isinstance(job_id, ray.JobID) assert isinstance(job_id, ray.JobID)
# Do everything in Python and through the Python Redis client instead error_data = construct_error_message(job_id, error_type, message,
# of through the raylet. time.time())
error_data = gcs_utils.construct_error_message(job_id, error_type, message, if gcs_publisher:
time.time()) gcs_publisher.publish_error(job_id.hex().encode(), error_data)
pubsub_msg = gcs_utils.PubSubMessage() elif redis_client:
pubsub_msg.id = job_id.binary() pubsub_msg = gcs_utils.PubSubMessage()
pubsub_msg.data = error_data pubsub_msg.id = job_id.binary()
redis_client.publish("ERROR_INFO:" + job_id.hex(), pubsub_msg.data = error_data.SerializeToString()
pubsub_msg.SerializeToString()) redis_client.publish("ERROR_INFO:" + job_id.hex(),
pubsub_msg.SerializeToString())
else:
raise ValueError(
"One of redis_client and gcs_publisher needs to be specified!")
def random_string(): def random_string():

View file

@ -1,7 +1,6 @@
"""Autoscaler monitoring loop daemon.""" """Autoscaler monitoring loop daemon."""
import argparse import argparse
import logging
import logging.handlers import logging.handlers
import os import os
import sys import sys
@ -35,7 +34,8 @@ from ray.autoscaler._private.util import DEBUG_AUTOSCALING_STATUS, \
from ray.core.generated import gcs_service_pb2, gcs_service_pb2_grpc from ray.core.generated import gcs_service_pb2, gcs_service_pb2_grpc
import ray.ray_constants as ray_constants import ray.ray_constants as ray_constants
from ray._private.ray_logging import setup_component_logger from ray._private.ray_logging import setup_component_logger
from ray._private.gcs_utils import GcsClient from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsPublisher
from ray._private.gcs_utils import GcsClient, get_gcs_address_from_redis
from ray.experimental.internal_kv import _initialize_internal_kv, \ from ray.experimental.internal_kv import _initialize_internal_kv, \
_internal_kv_put, _internal_kv_initialized, _internal_kv_get, \ _internal_kv_put, _internal_kv_initialized, _internal_kv_get, \
_internal_kv_del _internal_kv_del
@ -409,9 +409,18 @@ class Monitor:
_internal_kv_put(DEBUG_AUTOSCALING_ERROR, message, overwrite=True) _internal_kv_put(DEBUG_AUTOSCALING_ERROR, message, overwrite=True)
redis_client = ray._private.services.create_redis_client( redis_client = ray._private.services.create_redis_client(
self.redis_address, password=self.redis_password) self.redis_address, password=self.redis_password)
from ray._private.utils import push_error_to_driver_through_redis gcs_publisher = None
push_error_to_driver_through_redis( if args.gcs_address:
redis_client, ray_constants.MONITOR_DIED_ERROR, message) gcs_publisher = GcsPublisher(address=args.gcs_address)
elif gcs_pubsub_enabled():
gcs_publisher = GcsPublisher(
address=get_gcs_address_from_redis(redis_client))
from ray._private.utils import publish_error_to_driver
publish_error_to_driver(
ray_constants.MONITOR_DIED_ERROR,
message,
redis_client=redis_client,
gcs_publisher=gcs_publisher)
def _signal_handler(self, sig, frame): def _signal_handler(self, sig, frame):
self._handle_failure(f"Terminated with signal {sig}\n" + self._handle_failure(f"Terminated with signal {sig}\n" +
@ -437,6 +446,11 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description=("Parse Redis server for the " description=("Parse Redis server for the "
"monitor to connect to.")) "monitor to connect to."))
parser.add_argument(
"--gcs-address",
required=False,
type=str,
help="The address (ip:port) of GCS.")
parser.add_argument( parser.add_argument(
"--redis-address", "--redis-address",
required=True, required=True,

View file

@ -28,6 +28,7 @@ py_test_module_list(
"test_component_failures_2.py", "test_component_failures_2.py",
"test_component_failures_3.py", "test_component_failures_3.py",
"test_error_ray_not_initialized.py", "test_error_ray_not_initialized.py",
"test_gcs_pubsub.py",
"test_global_gc.py", "test_global_gc.py",
"test_grpc_client_credentials.py", "test_grpc_client_credentials.py",
"test_iter.py", "test_iter.py",

View file

@ -11,6 +11,7 @@ import ray._private.utils
import ray._private.gcs_utils as gcs_utils import ray._private.gcs_utils as gcs_utils
import ray.ray_constants as ray_constants import ray.ray_constants as ray_constants
from ray.exceptions import RayTaskError, RayActorError, GetTimeoutError from ray.exceptions import RayTaskError, RayActorError, GetTimeoutError
from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsPublisher
from ray._private.test_utils import (wait_for_condition, SignalActor, from ray._private.test_utils import (wait_for_condition, SignalActor,
init_error_pubsub, get_error_message) init_error_pubsub, get_error_message)
@ -61,14 +62,21 @@ def test_unhandled_errors(ray_start_regular):
del os.environ["RAY_IGNORE_UNHANDLED_ERRORS"] del os.environ["RAY_IGNORE_UNHANDLED_ERRORS"]
def test_push_error_to_driver_through_redis(ray_start_regular, error_pubsub): def test_publish_error_to_driver(ray_start_regular, error_pubsub):
address_info = ray_start_regular address_info = ray_start_regular
address = address_info["redis_address"] address = address_info["redis_address"]
redis_client = ray._private.services.create_redis_client( redis_client = ray._private.services.create_redis_client(
address, password=ray.ray_constants.REDIS_DEFAULT_PASSWORD) address, password=ray.ray_constants.REDIS_DEFAULT_PASSWORD)
gcs_publisher = None
if gcs_pubsub_enabled():
gcs_publisher = GcsPublisher(
address=gcs_utils.get_gcs_address_from_redis(redis_client))
error_message = "Test error message" error_message = "Test error message"
ray._private.utils.push_error_to_driver_through_redis( ray._private.utils.publish_error_to_driver(
redis_client, ray_constants.DASHBOARD_AGENT_DIED_ERROR, error_message) ray_constants.DASHBOARD_AGENT_DIED_ERROR,
error_message,
redis_client=redis_client,
gcs_publisher=gcs_publisher)
errors = get_error_message(error_pubsub, 1, errors = get_error_message(error_pubsub, 1,
ray_constants.DASHBOARD_AGENT_DIED_ERROR) ray_constants.DASHBOARD_AGENT_DIED_ERROR)
assert errors[0].type == ray_constants.DASHBOARD_AGENT_DIED_ERROR assert errors[0].type == ray_constants.DASHBOARD_AGENT_DIED_ERROR

View file

@ -104,9 +104,8 @@ def test_connect_with_disconnected_node(shutdown_only):
ray.init(address=cluster.address) ray.init(address=cluster.address)
p = init_error_pubsub() p = init_error_pubsub()
errors = get_error_message(p, 1, timeout=5) errors = get_error_message(p, 1, timeout=5)
print(errors)
assert len(errors) == 0 assert len(errors) == 0
# This node is killed by SIGKILL, ray_monitor will mark it to dead. # This node will be killed by SIGKILL, ray_monitor will mark it to dead.
dead_node = cluster.add_node(num_cpus=0) dead_node = cluster.add_node(num_cpus=0)
cluster.remove_node(dead_node, allow_graceful=False) cluster.remove_node(dead_node, allow_graceful=False)
errors = get_error_message(p, 1, ray_constants.REMOVED_NODE_ERROR) errors = get_error_message(p, 1, ray_constants.REMOVED_NODE_ERROR)

View file

@ -0,0 +1,73 @@
import sys
import ray
import ray._private.gcs_utils as gcs_utils
from ray._private.gcs_pubsub import GcsPublisher, GcsSubscriber, \
GcsAioPublisher, GcsAioSubscriber
from ray.core.generated.gcs_pb2 import ErrorTableData
import pytest
@pytest.mark.parametrize(
"ray_start_regular", [{
"_system_config": {
"gcs_grpc_based_pubsub": True
}
}],
indirect=True)
def test_publish_and_subscribe_error_info(ray_start_regular):
address_info = ray_start_regular
redis = ray._private.services.create_redis_client(
address_info["redis_address"],
password=ray.ray_constants.REDIS_DEFAULT_PASSWORD)
gcs_server_addr = gcs_utils.get_gcs_address_from_redis(redis)
subscriber = GcsSubscriber(address=gcs_server_addr)
subscriber.subscribe_error()
publisher = GcsPublisher(address=gcs_server_addr)
err1 = ErrorTableData(error_message="test error message 1")
err2 = ErrorTableData(error_message="test error message 2")
publisher.publish_error(b"aaa_id", err1)
publisher.publish_error(b"bbb_id", err2)
assert subscriber.poll_error() == (b"aaa_id", err1)
assert subscriber.poll_error() == (b"bbb_id", err2)
subscriber.close()
@pytest.mark.asyncio
@pytest.mark.parametrize(
"ray_start_regular", [{
"_system_config": {
"gcs_grpc_based_pubsub": True
}
}],
indirect=True)
async def test_aio_publish_and_subscribe_error_info(ray_start_regular):
address_info = ray_start_regular
redis = ray._private.services.create_redis_client(
address_info["redis_address"],
password=ray.ray_constants.REDIS_DEFAULT_PASSWORD)
gcs_server_addr = gcs_utils.get_gcs_address_from_redis(redis)
subscriber = GcsAioSubscriber(address=gcs_server_addr)
await subscriber.subscribe_error()
publisher = GcsAioPublisher(address=gcs_server_addr)
err1 = ErrorTableData(error_message="test error message 1")
err2 = ErrorTableData(error_message="test error message 2")
await publisher.publish_error(b"aaa_id", err1)
await publisher.publish_error(b"bbb_id", err2)
assert await subscriber.poll_error() == (b"aaa_id", err1)
assert await subscriber.poll_error() == (b"bbb_id", err2)
await subscriber.close()
if __name__ == "__main__":
sys.exit(pytest.main(["-v", __file__]))

View file

@ -27,6 +27,8 @@ import ray.remote_function
import ray.serialization as serialization import ray.serialization as serialization
import ray._private.gcs_utils as gcs_utils import ray._private.gcs_utils as gcs_utils
import ray._private.services as services import ray._private.services as services
from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsPublisher, \
GcsSubscriber
from ray._private.runtime_env.py_modules import upload_py_modules_if_needed from ray._private.runtime_env.py_modules import upload_py_modules_if_needed
from ray._private.runtime_env.working_dir import upload_working_dir_if_needed from ray._private.runtime_env.working_dir import upload_working_dir_if_needed
from ray._private.runtime_env.constants import RAY_JOB_CONFIG_JSON_ENV_VAR from ray._private.runtime_env.constants import RAY_JOB_CONFIG_JSON_ENV_VAR
@ -1247,6 +1249,58 @@ def listen_error_messages_raylet(worker, threads_stopped):
worker.error_message_pubsub_client.close() worker.error_message_pubsub_client.close()
def listen_error_messages_from_gcs(worker, threads_stopped):
"""Listen to error messages in the background on the driver.
This runs in a separate thread on the driver and pushes (error, time)
tuples to be published.
Args:
worker: The worker class that this thread belongs to.
threads_stopped (threading.Event): A threading event used to signal to
the thread that it should exit.
"""
worker.gcs_subscriber = GcsSubscriber(channel=worker.gcs_channel)
# Exports that are published after the call to
# gcs_subscriber.subscribe_error() and before the call to
# gcs_subscriber.poll_error() will still be processed in the loop.
# TODO: we should just subscribe to the errors for this specific job.
worker.gcs_subscriber.subscribe_error()
try:
if _internal_kv_initialized():
# Get any autoscaler errors that occurred before the call to
# subscribe.
error_message = _internal_kv_get(DEBUG_AUTOSCALING_ERROR)
if error_message is not None:
logger.warning(error_message.decode())
while True:
# Exit if received a signal that the thread should stop.
if threads_stopped.is_set():
return
_, error_data = worker.gcs_subscriber.poll_error()
if error_data is None:
continue
if error_data.job_id not in [
worker.current_job_id.binary(),
JobID.nil().binary(),
]:
continue
error_message = error_data.error_message
if error_data.type == ray_constants.TASK_PUSH_ERROR:
# TODO(ekl) remove task push errors entirely now that we have
# the separate unhandled exception handler.
pass
else:
logger.warning(error_message)
except (OSError, ConnectionError) as e:
logger.error(f"listen_error_messages_from_gcs: {e}")
@PublicAPI @PublicAPI
@client_mode_hook(auto_init=False) @client_mode_hook(auto_init=False)
def is_initialized() -> bool: def is_initialized() -> bool:
@ -1307,11 +1361,16 @@ def connect(node,
# that is not true of Redis pubsub clients. See the documentation at # that is not true of Redis pubsub clients. See the documentation at
# https://github.com/andymccurdy/redis-py#thread-safety. # https://github.com/andymccurdy/redis-py#thread-safety.
worker.redis_client = node.create_redis_client() worker.redis_client = node.create_redis_client()
worker.gcs_client = gcs_utils.GcsClient.create_from_redis( worker.gcs_channel = gcs_utils.create_gcs_channel(
worker.redis_client) gcs_utils.get_gcs_address_from_redis(worker.redis_client))
worker.gcs_client = gcs_utils.GcsClient(channel=worker.gcs_channel)
_initialize_internal_kv(worker.gcs_client) _initialize_internal_kv(worker.gcs_client)
ray.state.state._initialize_global_state( ray.state.state._initialize_global_state(
node.redis_address, redis_password=node.redis_password) node.redis_address, redis_password=node.redis_password)
worker.gcs_pubsub_enabled = gcs_pubsub_enabled()
worker.gcs_publisher = None
if worker.gcs_pubsub_enabled:
worker.gcs_publisher = GcsPublisher(channel=worker.gcs_channel)
# Initialize some fields. # Initialize some fields.
if mode in (WORKER_MODE, RESTORE_WORKER_MODE, SPILL_WORKER_MODE): if mode in (WORKER_MODE, RESTORE_WORKER_MODE, SPILL_WORKER_MODE):
@ -1350,11 +1409,12 @@ def connect(node,
raise e raise e
elif mode == WORKER_MODE: elif mode == WORKER_MODE:
traceback_str = traceback.format_exc() traceback_str = traceback.format_exc()
ray._private.utils.push_error_to_driver_through_redis( ray._private.utils.publish_error_to_driver(
worker.redis_client,
ray_constants.VERSION_MISMATCH_PUSH_ERROR, ray_constants.VERSION_MISMATCH_PUSH_ERROR,
traceback_str, traceback_str,
job_id=None) job_id=None,
redis_client=worker.redis_client,
gcs_publisher=worker.gcs_publisher)
worker.lock = threading.RLock() worker.lock = threading.RLock()
@ -1445,7 +1505,8 @@ def connect(node,
# scheduler for new error messages. # scheduler for new error messages.
if mode == SCRIPT_MODE: if mode == SCRIPT_MODE:
worker.listener_thread = threading.Thread( worker.listener_thread = threading.Thread(
target=listen_error_messages_raylet, target=listen_error_messages_from_gcs
if worker.gcs_pubsub_enabled else listen_error_messages_raylet,
name="ray_listen_error_messages", name="ray_listen_error_messages",
args=(worker, worker.threads_stopped)) args=(worker, worker.threads_stopped))
worker.listener_thread.daemon = True worker.listener_thread.daemon = True
@ -1519,6 +1580,8 @@ def disconnect(exiting_interpreter=False):
if hasattr(worker, "import_thread"): if hasattr(worker, "import_thread"):
worker.import_thread.join_import_thread() worker.import_thread.join_import_thread()
if hasattr(worker, "listener_thread"): if hasattr(worker, "listener_thread"):
if hasattr(worker, "gcs_subscriber"):
worker.gcs_subscriber.close()
worker.listener_thread.join() worker.listener_thread.join()
if hasattr(worker, "logger_thread"): if hasattr(worker, "logger_thread"):
worker.logger_thread.join() worker.logger_thread.join()

View file

@ -116,17 +116,6 @@ class MockGcsTaskReconstructionTable : public GcsTaskReconstructionTable {
namespace ray { namespace ray {
namespace gcs { namespace gcs {
class MockGcsObjectTable : public GcsObjectTable {
public:
MOCK_METHOD(JobID, GetJobIdFromKey, (const ObjectID &key), (override));
};
} // namespace gcs
} // namespace ray
namespace ray {
namespace gcs {
class MockGcsNodeTable : public GcsNodeTable { class MockGcsNodeTable : public GcsNodeTable {
public: public:
}; };

View file

@ -72,7 +72,8 @@ void GcsServer::Start() {
rpc::ChannelType::GCS_JOB_CHANNEL, rpc::ChannelType::GCS_JOB_CHANNEL,
rpc::ChannelType::GCS_NODE_INFO_CHANNEL, rpc::ChannelType::GCS_NODE_INFO_CHANNEL,
rpc::ChannelType::GCS_NODE_RESOURCE_CHANNEL, rpc::ChannelType::GCS_NODE_RESOURCE_CHANNEL,
rpc::ChannelType::GCS_WORKER_DELTA_CHANNEL}, rpc::ChannelType::GCS_WORKER_DELTA_CHANNEL,
rpc::ChannelType::RAY_ERROR_INFO_CHANNEL},
/*periodical_runner=*/&pubsub_periodical_runner_, /*periodical_runner=*/&pubsub_periodical_runner_,
/*get_time_ms=*/[]() { return absl::GetCurrentTimeNanos() / 1e6; }, /*get_time_ms=*/[]() { return absl::GetCurrentTimeNanos() / 1e6; },
/*subscriber_timeout_ms=*/RayConfig::instance().subscriber_timeout_ms(), /*subscriber_timeout_ms=*/RayConfig::instance().subscriber_timeout_ms(),

View file

@ -213,17 +213,6 @@ class GcsTaskReconstructionTable
JobID GetJobIdFromKey(const TaskID &key) override { return key.ActorId().JobId(); } JobID GetJobIdFromKey(const TaskID &key) override { return key.ActorId().JobId(); }
}; };
class GcsObjectTable : public GcsTableWithJobId<ObjectID, ObjectLocationInfo> {
public:
explicit GcsObjectTable(std::shared_ptr<StoreClient> store_client)
: GcsTableWithJobId(std::move(store_client)) {
table_name_ = TablePrefix_Name(TablePrefix::OBJECT);
}
private:
JobID GetJobIdFromKey(const ObjectID &key) override { return key.TaskId().JobId(); }
};
class GcsNodeTable : public GcsTable<NodeID, GcsNodeInfo> { class GcsNodeTable : public GcsTable<NodeID, GcsNodeInfo> {
public: public:
explicit GcsNodeTable(std::shared_ptr<StoreClient> store_client) explicit GcsNodeTable(std::shared_ptr<StoreClient> store_client)

View file

@ -17,12 +17,35 @@
namespace ray { namespace ray {
namespace gcs { namespace gcs {
void InternalPubSubHandler::HandleGcsPublish(const rpc::GcsPublishRequest &request,
rpc::GcsPublishReply *reply,
rpc::SendReplyCallback send_reply_callback) {
if (gcs_publisher_ == nullptr) {
send_reply_callback(
Status::NotImplemented("GCS pubsub is not yet enabled. Please enable it with "
"system config `gcs_grpc_based_pubsub=True`"),
nullptr, nullptr);
return;
}
for (const auto &msg : request.pub_messages()) {
gcs_publisher_->GetPublisher()->Publish(msg);
}
send_reply_callback(Status::OK(), nullptr, nullptr);
}
// Needs to use rpc::GcsSubscriberPollRequest and rpc::GcsSubscriberPollReply here, // Needs to use rpc::GcsSubscriberPollRequest and rpc::GcsSubscriberPollReply here,
// and convert the reply to rpc::PubsubLongPollingReply because GCS RPC services are // and convert the reply to rpc::PubsubLongPollingReply because GCS RPC services are
// required to have the `status` field in replies. // required to have the `status` field in replies.
void InternalPubSubHandler::HandleGcsSubscriberPoll( void InternalPubSubHandler::HandleGcsSubscriberPoll(
const rpc::GcsSubscriberPollRequest &request, rpc::GcsSubscriberPollReply *reply, const rpc::GcsSubscriberPollRequest &request, rpc::GcsSubscriberPollReply *reply,
rpc::SendReplyCallback send_reply_callback) { rpc::SendReplyCallback send_reply_callback) {
if (gcs_publisher_ == nullptr) {
send_reply_callback(
Status::NotImplemented("GCS pubsub is not yet enabled. Please enable it with "
"system config `gcs_grpc_based_pubsub=True`"),
nullptr, nullptr);
return;
}
const auto subscriber_id = UniqueID::FromBinary(request.subscriber_id()); const auto subscriber_id = UniqueID::FromBinary(request.subscriber_id());
auto pubsub_reply = std::make_shared<rpc::PubsubLongPollingReply>(); auto pubsub_reply = std::make_shared<rpc::PubsubLongPollingReply>();
auto pubsub_reply_ptr = pubsub_reply.get(); auto pubsub_reply_ptr = pubsub_reply.get();
@ -44,6 +67,13 @@ void InternalPubSubHandler::HandleGcsSubscriberCommandBatch(
const rpc::GcsSubscriberCommandBatchRequest &request, const rpc::GcsSubscriberCommandBatchRequest &request,
rpc::GcsSubscriberCommandBatchReply *reply, rpc::GcsSubscriberCommandBatchReply *reply,
rpc::SendReplyCallback send_reply_callback) { rpc::SendReplyCallback send_reply_callback) {
if (gcs_publisher_ == nullptr) {
send_reply_callback(
Status::NotImplemented("GCS pubsub is not yet enabled. Please enable it with "
"system config `gcs_grpc_based_pubsub=True`"),
nullptr, nullptr);
return;
}
const auto subscriber_id = UniqueID::FromBinary(request.subscriber_id()); const auto subscriber_id = UniqueID::FromBinary(request.subscriber_id());
for (const auto &command : request.commands()) { for (const auto &command : request.commands()) {
if (command.has_unsubscribe_message()) { if (command.has_unsubscribe_message()) {

View file

@ -16,6 +16,7 @@
#include "ray/gcs/pubsub/gcs_pub_sub.h" #include "ray/gcs/pubsub/gcs_pub_sub.h"
#include "ray/rpc/gcs_server/gcs_rpc_server.h" #include "ray/rpc/gcs_server/gcs_rpc_server.h"
#include "src/ray/protobuf/gcs_service.grpc.pb.h"
namespace ray { namespace ray {
namespace gcs { namespace gcs {
@ -28,6 +29,10 @@ class InternalPubSubHandler : public rpc::InternalPubSubHandler {
explicit InternalPubSubHandler(const std::shared_ptr<gcs::GcsPublisher> &gcs_publisher) explicit InternalPubSubHandler(const std::shared_ptr<gcs::GcsPublisher> &gcs_publisher)
: gcs_publisher_(gcs_publisher) {} : gcs_publisher_(gcs_publisher) {}
void HandleGcsPublish(const rpc::GcsPublishRequest &request,
rpc::GcsPublishReply *reply,
rpc::SendReplyCallback send_reply_callback) final;
void HandleGcsSubscriberPoll(const rpc::GcsSubscriberPollRequest &request, void HandleGcsSubscriberPoll(const rpc::GcsSubscriberPollRequest &request,
rpc::GcsSubscriberPollReply *reply, rpc::GcsSubscriberPollReply *reply,
rpc::SendReplyCallback send_reply_callback) final; rpc::SendReplyCallback send_reply_callback) final;

View file

@ -302,6 +302,17 @@ Status GcsPublisher::PublishTaskLease(const TaskID &id, const rpc::TaskLeaseData
Status GcsPublisher::PublishError(const std::string &id, Status GcsPublisher::PublishError(const std::string &id,
const rpc::ErrorTableData &message, const rpc::ErrorTableData &message,
const StatusCallback &done) { const StatusCallback &done) {
if (publisher_ != nullptr) {
rpc::PubMessage msg;
msg.set_channel_type(rpc::ChannelType::RAY_ERROR_INFO_CHANNEL);
msg.set_key_id(id);
*msg.mutable_error_info_message() = message;
publisher_->Publish(msg);
if (done != nullptr) {
done(Status::OK());
}
return Status::OK();
}
return pubsub_->Publish(ERROR_INFO_CHANNEL, id, message.SerializeAsString(), done); return pubsub_->Publish(ERROR_INFO_CHANNEL, id, message.SerializeAsString(), done);
} }

View file

@ -563,6 +563,16 @@ service InternalKVGcsService {
rpc InternalKVKeys(InternalKVKeysRequest) returns (InternalKVKeysReply); rpc InternalKVKeys(InternalKVKeysRequest) returns (InternalKVKeysReply);
} }
message GcsPublishRequest {
/// The messages that are published.
repeated PubMessage pub_messages = 1;
}
message GcsPublishReply {
// Not populated.
GcsStatus status = 100;
}
message GcsSubscriberPollRequest { message GcsSubscriberPollRequest {
/// The id of the subscriber. /// The id of the subscriber.
bytes subscriber_id = 1; bytes subscriber_id = 1;
@ -590,6 +600,9 @@ message GcsSubscriberCommandBatchReply {
/// This supports subscribing updates from GCS with long poll, and registering / /// This supports subscribing updates from GCS with long poll, and registering /
/// de-registering subscribers. /// de-registering subscribers.
service InternalPubSubGcsService { service InternalPubSubGcsService {
/// The request to sent to GCS to publish messages.
/// Currently only supporting error info, logs and Python function messages.
rpc GcsPublish(GcsPublishRequest) returns (GcsPublishReply);
/// The long polling request sent to GCS for pubsub operations. /// The long polling request sent to GCS for pubsub operations.
/// The long poll request will be replied once there are a batch of messages that /// The long poll request will be replied once there are a batch of messages that
/// need to be published to the caller (subscriber). /// need to be published to the caller (subscriber).

View file

@ -40,6 +40,8 @@ enum ChannelType {
GCS_NODE_RESOURCE_CHANNEL = 6; GCS_NODE_RESOURCE_CHANNEL = 6;
/// A channel for worker changes, currently only for worker failures. /// A channel for worker changes, currently only for worker failures.
GCS_WORKER_DELTA_CHANNEL = 7; GCS_WORKER_DELTA_CHANNEL = 7;
/// A channel for errors from various Ray components.
RAY_ERROR_INFO_CHANNEL = 8;
} }
/// ///
@ -61,6 +63,7 @@ message PubMessage {
GcsNodeInfo node_info_message = 9; GcsNodeInfo node_info_message = 9;
NodeResourceChange node_resource_message = 10; NodeResourceChange node_resource_message = 10;
WorkerDeltaData worker_delta_message = 11; WorkerDeltaData worker_delta_message = 11;
ErrorTableData error_info_message = 12;
// The message that indicates the given key id is not available anymore. // The message that indicates the given key id is not available anymore.
FailureMessage failure_message = 6; FailureMessage failure_message = 6;

View file

@ -145,6 +145,10 @@ bool SubscriptionIndex::CheckNoLeaks() const {
bool Subscriber::ConnectToSubscriber(rpc::PubsubLongPollingReply *reply, bool Subscriber::ConnectToSubscriber(rpc::PubsubLongPollingReply *reply,
rpc::SendReplyCallback send_reply_callback) { rpc::SendReplyCallback send_reply_callback) {
if (long_polling_connection_) {
// Flush the current subscriber poll with an empty reply.
PublishIfPossible(/*force_noop=*/true);
}
if (!long_polling_connection_) { if (!long_polling_connection_) {
RAY_CHECK(reply != nullptr); RAY_CHECK(reply != nullptr);
RAY_CHECK(send_reply_callback != nullptr); RAY_CHECK(send_reply_callback != nullptr);
@ -171,30 +175,25 @@ void Subscriber::QueueMessage(const rpc::PubMessage &pub_message, bool try_publi
} }
} }
bool Subscriber::PublishIfPossible(bool force) { bool Subscriber::PublishIfPossible(bool force_noop) {
if (!long_polling_connection_) { if (!long_polling_connection_) {
return false; return false;
} }
if (!force_noop && mailbox_.empty()) {
return false;
}
if (force || mailbox_.size() > 0) { if (!force_noop) {
// If force publish is invoked, mailbox could be empty. We should always add a reply
// here because otherwise, there could be memory leak due to our grpc layer
// implementation.
if (mailbox_.empty()) {
mailbox_.push(absl::make_unique<rpc::PubsubLongPollingReply>());
}
// Reply to the long polling subscriber. Swap the reply here to avoid extra copy. // Reply to the long polling subscriber. Swap the reply here to avoid extra copy.
long_polling_connection_->reply->Swap(mailbox_.front().get()); long_polling_connection_->reply->Swap(mailbox_.front().get());
long_polling_connection_->send_reply_callback(Status::OK(), nullptr, nullptr);
// Clean up & update metadata.
long_polling_connection_.reset();
mailbox_.pop(); mailbox_.pop();
last_connection_update_time_ms_ = get_time_ms_();
return true;
} }
return false; long_polling_connection_->send_reply_callback(Status::OK(), nullptr, nullptr);
// Clean up & update metadata.
long_polling_connection_.reset();
last_connection_update_time_ms_ = get_time_ms_();
return true;
} }
bool Subscriber::CheckNoLeaks() const { bool Subscriber::CheckNoLeaks() const {
@ -311,7 +310,7 @@ int Publisher::UnregisterSubscriberInternal(const SubscriberID &subscriber_id) {
} }
auto &subscriber = it->second; auto &subscriber = it->second;
// Remove the long polling connection because otherwise, there's memory leak. // Remove the long polling connection because otherwise, there's memory leak.
subscriber->PublishIfPossible(/*force=*/true); subscriber->PublishIfPossible(/*force_noop=*/true);
subscribers_.erase(it); subscribers_.erase(it);
return erased; return erased;
} }
@ -330,8 +329,8 @@ void Publisher::CheckDeadSubscribers() {
if (disconnected) { if (disconnected) {
dead_subscribers.push_back(it.first); dead_subscribers.push_back(it.first);
} else if (active_connection_timed_out) { } else if (active_connection_timed_out) {
// Refresh the long polling connection. The subscriber will send it again. // Refresh the long polling connection. The subscriber will poll again.
subscriber->PublishIfPossible(/*force*/ true); subscriber->PublishIfPossible(/*force_noop*/ true);
} }
} }

View file

@ -123,10 +123,11 @@ class Subscriber {
/// Publish all queued messages if possible. /// Publish all queued messages if possible.
/// ///
/// \param force If true, we publish to the subscriber although there's no queued /// \param force_noop If true, reply to the subscriber with an empty message, regardless
/// message. /// of whethere there is any queued message. This is for cases where the current poll
/// might have been cancelled, or the subscriber might be dead.
/// \return True if it publishes. False otherwise. /// \return True if it publishes. False otherwise.
bool PublishIfPossible(bool force = false); bool PublishIfPossible(bool force_noop = false);
/// Testing only. Return true if there's no metadata remained in the private attribute. /// Testing only. Return true if there's no metadata remained in the private attribute.
bool CheckNoLeaks() const; bool CheckNoLeaks() const;

View file

@ -285,8 +285,8 @@ TEST_F(PublisherTest, TestSubscriber) {
ASSERT_FALSE(subscriber->PublishIfPossible()); ASSERT_FALSE(subscriber->PublishIfPossible());
// Try connecting it. Should return true. // Try connecting it. Should return true.
ASSERT_TRUE(subscriber->ConnectToSubscriber(&reply, send_reply_callback)); ASSERT_TRUE(subscriber->ConnectToSubscriber(&reply, send_reply_callback));
// If connecting it again, it should fail the request. // Polling when there is already an inflight polling request should still work.
ASSERT_FALSE(subscriber->ConnectToSubscriber(&reply, send_reply_callback)); ASSERT_TRUE(subscriber->ConnectToSubscriber(&reply, send_reply_callback));
// Since there's no published objects, it should return false. // Since there's no published objects, it should return false.
ASSERT_FALSE(subscriber->PublishIfPossible()); ASSERT_FALSE(subscriber->PublishIfPossible());

View file

@ -277,6 +277,8 @@ class GcsRpcClient {
internal_kv_grpc_client_, ) internal_kv_grpc_client_, )
/// Operations for pubsub /// Operations for pubsub
VOID_GCS_RPC_CLIENT_METHOD(InternalPubSubGcsService, GcsPublish,
internal_pubsub_grpc_client_, )
VOID_GCS_RPC_CLIENT_METHOD(InternalPubSubGcsService, GcsSubscriberPoll, VOID_GCS_RPC_CLIENT_METHOD(InternalPubSubGcsService, GcsSubscriberPoll,
internal_pubsub_grpc_client_, ) internal_pubsub_grpc_client_, )
VOID_GCS_RPC_CLIENT_METHOD(InternalPubSubGcsService, GcsSubscriberCommandBatch, VOID_GCS_RPC_CLIENT_METHOD(InternalPubSubGcsService, GcsSubscriberCommandBatch,

View file

@ -606,6 +606,9 @@ class InternalPubSubGcsServiceHandler {
public: public:
virtual ~InternalPubSubGcsServiceHandler() = default; virtual ~InternalPubSubGcsServiceHandler() = default;
virtual void HandleGcsPublish(const GcsPublishRequest &request, GcsPublishReply *reply,
SendReplyCallback send_reply_callback) = 0;
virtual void HandleGcsSubscriberPoll(const GcsSubscriberPollRequest &request, virtual void HandleGcsSubscriberPoll(const GcsSubscriberPollRequest &request,
GcsSubscriberPollReply *reply, GcsSubscriberPollReply *reply,
SendReplyCallback send_reply_callback) = 0; SendReplyCallback send_reply_callback) = 0;
@ -626,6 +629,7 @@ class InternalPubSubGrpcService : public GrpcService {
void InitServerCallFactories( void InitServerCallFactories(
const std::unique_ptr<grpc::ServerCompletionQueue> &cq, const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
std::vector<std::unique_ptr<ServerCallFactory>> *server_call_factories) override { std::vector<std::unique_ptr<ServerCallFactory>> *server_call_factories) override {
INTERNAL_PUBSUB_SERVICE_RPC_HANDLER(GcsPublish);
INTERNAL_PUBSUB_SERVICE_RPC_HANDLER(GcsSubscriberPoll); INTERNAL_PUBSUB_SERVICE_RPC_HANDLER(GcsSubscriberPoll);
INTERNAL_PUBSUB_SERVICE_RPC_HANDLER(GcsSubscriberCommandBatch); INTERNAL_PUBSUB_SERVICE_RPC_HANDLER(GcsSubscriberCommandBatch);
} }