mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[Core][Pubsub] Implement Python GCS publisher and subscriber (#20111)
## Why are these changes needed? This change adds Python publisher and subscriber in `gcs_utils.py`, and GRPC handler on GCS for publishing iva GCS. Error info is migrated to use the GCS-based pubsub, if feature flag `RAY_gcs_grpc_based_pubsub=true`. Also, add a `--gcs-address` flag to some Python processes. It is not set anywhere yet, but will be set aftering Redis-less bootstrapping work. Unit tests are added for the Python publisher and subscriber. Migrated error info publishers and subscribers are tested with existing unit tests, e.g. tests calling `ray._private.test_utils.get_error_message()` to ensure error info is published. GCS based pubsub has gaps in handling deadline, cancelled requests and GCS restarts. So 3 more unit tests are disabled in the `HA GCS` mode. They will be addressed in a separate change. ## Related issue number
This commit is contained in:
parent
fca851eef5
commit
0330852baf
31 changed files with 763 additions and 188 deletions
|
@ -343,6 +343,7 @@
|
|||
--test_env=RAY_gcs_grpc_based_pubsub=true
|
||||
-- //python/ray/tests/...
|
||||
-//python/ray/tests:test_failure_2
|
||||
-//python/ray/tests:test_job
|
||||
- bazel test --config=ci $(./scripts/bazel_export_options)
|
||||
--test_tag_filters=-kubernetes,client_tests,-flaky
|
||||
--test_env=RAY_CLIENT_MODE=1 --test_env=RAY_PROFILING=1
|
||||
|
@ -358,6 +359,7 @@
|
|||
-- //python/ray/tests/...
|
||||
-//python/ray/tests:test_client_multi -//python/ray/tests:test_component_failures_3
|
||||
-//python/ray/tests:test_healthcheck -//python/ray/tests:test_gcs_fault_tolerance
|
||||
-//python/ray/tests:test_client
|
||||
- label: ":redis: HA GCS (Medium K-Z)"
|
||||
conditions: ["RAY_CI_PYTHON_AFFECTED"]
|
||||
commands:
|
||||
|
@ -368,6 +370,7 @@
|
|||
-- //python/ray/tests/...
|
||||
-//python/ray/tests:test_multinode_failures_2 -//python/ray/tests:test_ray_debugger
|
||||
-//python/ray/tests:test_placement_group_2 -//python/ray/tests:test_placement_group_3
|
||||
-//python/ray/tests:test_multi_node
|
||||
|
||||
- label: ":brain: RLlib: Learning discr. actions TF2-static-graph (from rllib/tuned_examples/*.yaml)"
|
||||
conditions: ["RAY_CI_RLLIB_AFFECTED"]
|
||||
|
|
|
@ -19,7 +19,9 @@ import ray.dashboard.utils as dashboard_utils
|
|||
import ray.ray_constants as ray_constants
|
||||
import ray._private.services
|
||||
import ray._private.utils
|
||||
from ray._private.gcs_utils import GcsClient
|
||||
from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsPublisher
|
||||
from ray._private.gcs_utils import GcsClient, \
|
||||
get_gcs_address_from_redis
|
||||
from ray.core.generated import agent_manager_pb2
|
||||
from ray.core.generated import agent_manager_pb2_grpc
|
||||
from ray._private.ray_logging import setup_component_logger
|
||||
|
@ -228,6 +230,11 @@ if __name__ == "__main__":
|
|||
required=True,
|
||||
type=str,
|
||||
help="the IP address of this node.")
|
||||
parser.add_argument(
|
||||
"--gcs-address",
|
||||
required=False,
|
||||
type=str,
|
||||
help="The address (ip:port) of GCS.")
|
||||
parser.add_argument(
|
||||
"--redis-address",
|
||||
required=True,
|
||||
|
@ -377,6 +384,12 @@ if __name__ == "__main__":
|
|||
# impact of the issue.
|
||||
redis_client = ray._private.services.create_redis_client(
|
||||
args.redis_address, password=args.redis_password)
|
||||
gcs_publisher = None
|
||||
if args.gcs_address:
|
||||
gcs_publisher = GcsPublisher(args.gcs_address)
|
||||
elif gcs_pubsub_enabled():
|
||||
gcs_publisher = GcsPublisher(
|
||||
address=get_gcs_address_from_redis(redis_client))
|
||||
traceback_str = ray._private.utils.format_error_message(
|
||||
traceback.format_exc())
|
||||
message = (
|
||||
|
@ -390,9 +403,11 @@ if __name__ == "__main__":
|
|||
"\n 3. runtime_env APIs won't work."
|
||||
"\nCheck out the `dashboard_agent.log` to see the "
|
||||
"detailed failure messages.")
|
||||
ray._private.utils.push_error_to_driver_through_redis(
|
||||
redis_client, ray_constants.DASHBOARD_AGENT_DIED_ERROR,
|
||||
message)
|
||||
ray._private.utils.publish_error_to_driver(
|
||||
ray_constants.DASHBOARD_AGENT_DIED_ERROR,
|
||||
message,
|
||||
redis_client=redis_client,
|
||||
gcs_publisher=gcs_publisher)
|
||||
logger.error(message)
|
||||
logger.exception(e)
|
||||
exit(1)
|
||||
|
|
|
@ -13,8 +13,10 @@ import ray.dashboard.consts as dashboard_consts
|
|||
import ray.dashboard.head as dashboard_head
|
||||
import ray.dashboard.utils as dashboard_utils
|
||||
import ray.ray_constants as ray_constants
|
||||
import ray._private.gcs_utils as gcs_utils
|
||||
import ray._private.services
|
||||
import ray._private.utils
|
||||
from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsPublisher
|
||||
from ray._private.ray_logging import setup_component_logger
|
||||
from ray._private.metrics_agent import PrometheusServiceDiscoveryWriter
|
||||
|
||||
|
@ -134,6 +136,11 @@ if __name__ == "__main__":
|
|||
type=int,
|
||||
default=0,
|
||||
help="The retry times to select a valid port.")
|
||||
parser.add_argument(
|
||||
"--gcs-address",
|
||||
required=False,
|
||||
type=str,
|
||||
help="The address (ip:port) of GCS.")
|
||||
parser.add_argument(
|
||||
"--redis-address",
|
||||
required=True,
|
||||
|
@ -223,18 +230,29 @@ if __name__ == "__main__":
|
|||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(dashboard.run())
|
||||
except Exception as e:
|
||||
# Something went wrong, so push an error to all drivers.
|
||||
redis_client = ray._private.services.create_redis_client(
|
||||
args.redis_address, password=args.redis_password)
|
||||
traceback_str = ray._private.utils.format_error_message(
|
||||
traceback.format_exc())
|
||||
message = f"The dashboard on node {platform.uname()[1]} " \
|
||||
f"failed with the following " \
|
||||
f"error:\n{traceback_str}"
|
||||
ray._private.utils.push_error_to_driver_through_redis(
|
||||
redis_client, ray_constants.DASHBOARD_DIED_ERROR, message)
|
||||
if isinstance(e, FrontendNotFoundError):
|
||||
logger.warning(message)
|
||||
else:
|
||||
logger.error(message)
|
||||
raise e
|
||||
|
||||
# Something went wrong, so push an error to all drivers.
|
||||
redis_client = ray._private.services.create_redis_client(
|
||||
args.redis_address, password=args.redis_password)
|
||||
gcs_publisher = None
|
||||
if args.gcs_address:
|
||||
gcs_publisher = GcsPublisher(address=args.gcs_address)
|
||||
elif gcs_pubsub_enabled():
|
||||
gcs_publisher = GcsPublisher(
|
||||
address=gcs_utils.get_gcs_address_from_redis(redis_client))
|
||||
ray._private.utils.publish_error_to_driver(
|
||||
redis_client,
|
||||
ray_constants.DASHBOARD_DIED_ERROR,
|
||||
message,
|
||||
redis_client=redis_client,
|
||||
gcs_publisher=gcs_publisher)
|
||||
|
|
|
@ -19,6 +19,7 @@ import ray._private.services
|
|||
import ray.dashboard.consts as dashboard_consts
|
||||
import ray.dashboard.utils as dashboard_utils
|
||||
from ray import ray_constants
|
||||
from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsAioSubscriber
|
||||
from ray.core.generated import gcs_service_pb2
|
||||
from ray.core.generated import gcs_service_pb2_grpc
|
||||
from ray.dashboard.datacenter import DataOrganizer
|
||||
|
@ -44,8 +45,8 @@ GRPC_CHANNEL_OPTIONS = (
|
|||
async def get_gcs_address_with_retry(redis_client) -> str:
|
||||
while True:
|
||||
try:
|
||||
gcs_address = await redis_client.get(
|
||||
dashboard_consts.REDIS_KEY_GCS_SERVER_ADDRESS)
|
||||
gcs_address = (await redis_client.get(
|
||||
dashboard_consts.REDIS_KEY_GCS_SERVER_ADDRESS)).decode()
|
||||
if not gcs_address:
|
||||
raise Exception("GCS address not found.")
|
||||
logger.info("Connect to GCS at %s", gcs_address)
|
||||
|
@ -113,6 +114,7 @@ class DashboardHead:
|
|||
self.log_dir = log_dir
|
||||
self.aioredis_client = None
|
||||
self.aiogrpc_gcs_channel = None
|
||||
self.gcs_subscriber = None
|
||||
self.http_session = None
|
||||
self.ip = ray.util.get_node_ip_address()
|
||||
ip, port = redis_address.split(":")
|
||||
|
@ -192,11 +194,15 @@ class DashboardHead:
|
|||
# Waiting for GCS is ready.
|
||||
# TODO: redis-removal bootstrap
|
||||
gcs_address = await get_gcs_address_with_retry(self.aioredis_client)
|
||||
self.gcs_client = GcsClient(gcs_address)
|
||||
self.gcs_client = GcsClient(address=gcs_address)
|
||||
internal_kv._initialize_internal_kv(self.gcs_client)
|
||||
self.aiogrpc_gcs_channel = ray._private.utils.init_grpc_channel(
|
||||
gcs_address, GRPC_CHANNEL_OPTIONS, asynchronous=True)
|
||||
gcs_client = GcsClient(gcs_address)
|
||||
internal_kv._initialize_internal_kv(gcs_client)
|
||||
self.gcs_subscriber = None
|
||||
if gcs_pubsub_enabled():
|
||||
self.gcs_subscriber = GcsAioSubscriber(
|
||||
channel=self.aiogrpc_gcs_channel)
|
||||
await self.gcs_subscriber.subscribe_error()
|
||||
|
||||
self.health_check_thread = GCSHealthCheckThread(gcs_address)
|
||||
self.health_check_thread.start()
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
NODE_STATS_UPDATE_INTERVAL_SECONDS = 1
|
||||
ERROR_INFO_UPDATE_INTERVAL_SECONDS = 5
|
||||
LOG_INFO_UPDATE_INTERVAL_SECONDS = 5
|
||||
UPDATE_NODES_INTERVAL_SECONDS = 5
|
||||
MAX_COUNT_OF_GCS_RPC_ERROR = 10
|
||||
|
|
|
@ -272,28 +272,15 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
|
|||
logger.exception("Error receiving log info.")
|
||||
|
||||
async def _update_error_info(self):
|
||||
aioredis_client = self._dashboard_head.aioredis_client
|
||||
receiver = Receiver()
|
||||
|
||||
key = gcs_utils.RAY_ERROR_PUBSUB_PATTERN
|
||||
pattern = receiver.pattern(key)
|
||||
await aioredis_client.psubscribe(pattern)
|
||||
logger.info("Subscribed to %s", key)
|
||||
|
||||
async for sender, msg in receiver.iter():
|
||||
try:
|
||||
_, data = msg
|
||||
pubsub_msg = gcs_utils.PubSubMessage.FromString(data)
|
||||
error_data = gcs_utils.ErrorTableData.FromString(
|
||||
pubsub_msg.data)
|
||||
def process_error(error_data):
|
||||
error_data = gcs_utils.ErrorTableData.FromString(pubsub_msg.data)
|
||||
message = error_data.error_message
|
||||
message = re.sub(r"\x1b\[\d+m", "", message)
|
||||
match = re.search(r"\(pid=(\d+), ip=(.*?)\)", message)
|
||||
if match:
|
||||
pid = match.group(1)
|
||||
ip = match.group(2)
|
||||
errs_for_ip = dict(
|
||||
DataSource.ip_and_pid_to_errors.get(ip, {}))
|
||||
errs_for_ip = dict(DataSource.ip_and_pid_to_errors.get(ip, {}))
|
||||
pid_errors = list(errs_for_ip.get(pid, []))
|
||||
pid_errors.append({
|
||||
"message": message,
|
||||
|
@ -303,8 +290,33 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
|
|||
errs_for_ip[pid] = pid_errors
|
||||
DataSource.ip_and_pid_to_errors[ip] = errs_for_ip
|
||||
logger.info(f"Received error entry for {ip} {pid}")
|
||||
|
||||
if self._dashboard_head.gcs_subscriber:
|
||||
while True:
|
||||
_, error_data = await \
|
||||
self._dashboard_head.gcs_subscriber.poll_error()
|
||||
try:
|
||||
process_error(error_data)
|
||||
except Exception:
|
||||
logger.exception("Error receiving error info.")
|
||||
logger.exception("Error receiving error info from GCS.")
|
||||
else:
|
||||
aioredis_client = self._dashboard_head.aioredis_client
|
||||
receiver = Receiver()
|
||||
|
||||
key = gcs_utils.RAY_ERROR_PUBSUB_PATTERN
|
||||
pattern = receiver.pattern(key)
|
||||
await aioredis_client.psubscribe(pattern)
|
||||
logger.info("Subscribed to %s", key)
|
||||
|
||||
async for _, msg in receiver.iter():
|
||||
try:
|
||||
_, data = msg
|
||||
pubsub_msg = gcs_utils.PubSubMessage.FromString(data)
|
||||
error_data = gcs_utils.ErrorTableData.FromString(
|
||||
pubsub_msg.data)
|
||||
process_error(error_data)
|
||||
except Exception:
|
||||
logger.exception("Error receiving error info from Redis.")
|
||||
|
||||
async def run(self, server):
|
||||
gcs_channel = self._dashboard_head.aiogrpc_gcs_channel
|
||||
|
|
276
python/ray/_private/gcs_pubsub.py
Normal file
276
python/ray/_private/gcs_pubsub.py
Normal file
|
@ -0,0 +1,276 @@
|
|||
import os
|
||||
from collections import deque
|
||||
import logging
|
||||
import random
|
||||
import threading
|
||||
from typing import Tuple
|
||||
|
||||
import grpc
|
||||
|
||||
import ray._private.gcs_utils as gcs_utils
|
||||
from ray.core.generated import gcs_service_pb2_grpc
|
||||
from ray.core.generated import gcs_service_pb2
|
||||
from ray.core.generated.gcs_pb2 import (
|
||||
ErrorTableData, )
|
||||
from ray.core.generated import pubsub_pb2
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def gcs_pubsub_enabled():
|
||||
"""Checks whether GCS pubsub feature flag is enabled."""
|
||||
return os.environ.get("RAY_gcs_grpc_based_pubsub") not in\
|
||||
[None, "0", "false"]
|
||||
|
||||
|
||||
def construct_error_message(job_id, error_type, message, timestamp):
|
||||
"""Construct an ErrorTableData object.
|
||||
|
||||
Args:
|
||||
job_id: The ID of the job that the error should go to. If this is
|
||||
nil, then the error will go to all drivers.
|
||||
error_type: The type of the error.
|
||||
message: The error message.
|
||||
timestamp: The time of the error.
|
||||
|
||||
Returns:
|
||||
The ErrorTableData object.
|
||||
"""
|
||||
data = ErrorTableData()
|
||||
data.job_id = job_id.binary()
|
||||
data.type = error_type
|
||||
data.error_message = message
|
||||
data.timestamp = timestamp
|
||||
return data
|
||||
|
||||
|
||||
class _SubscriberBase:
|
||||
def _subscribe_error_request(self):
|
||||
cmd = pubsub_pb2.Command(
|
||||
channel_type=pubsub_pb2.RAY_ERROR_INFO_CHANNEL,
|
||||
subscribe_message={})
|
||||
req = gcs_service_pb2.GcsSubscriberCommandBatchRequest(
|
||||
subscriber_id=self._subscriber_id, commands=[cmd])
|
||||
return req
|
||||
|
||||
def _poll_request(self):
|
||||
return gcs_service_pb2.GcsSubscriberPollRequest(
|
||||
subscriber_id=self._subscriber_id)
|
||||
|
||||
def _unsubscribe_request(self):
|
||||
req = gcs_service_pb2.GcsSubscriberCommandBatchRequest(
|
||||
subscriber_id=self._subscriber_id, commands=[])
|
||||
if self._subscribed_error:
|
||||
cmd = pubsub_pb2.Command(
|
||||
channel_type=pubsub_pb2.RAY_ERROR_INFO_CHANNEL,
|
||||
unsubscribe_message={})
|
||||
req.commands.append(cmd)
|
||||
return req
|
||||
|
||||
|
||||
class GcsPublisher:
|
||||
"""Publisher to GCS."""
|
||||
|
||||
def __init__(self, address: str = None, channel: grpc.Channel = None):
|
||||
if address:
|
||||
assert channel is None, \
|
||||
"address and channel cannot both be specified"
|
||||
channel = gcs_utils.create_gcs_channel(address)
|
||||
else:
|
||||
assert channel is not None, \
|
||||
"One of address and channel must be specified"
|
||||
self._stub = gcs_service_pb2_grpc.InternalPubSubGcsServiceStub(channel)
|
||||
|
||||
def publish_error(self, key_id: bytes, error_info: ErrorTableData) -> None:
|
||||
"""Publishes error info to GCS."""
|
||||
msg = pubsub_pb2.PubMessage(
|
||||
channel_type=pubsub_pb2.RAY_ERROR_INFO_CHANNEL,
|
||||
key_id=key_id,
|
||||
error_info_message=error_info)
|
||||
req = gcs_service_pb2.GcsPublishRequest(pub_messages=[msg])
|
||||
self._stub.GcsPublish(req)
|
||||
|
||||
|
||||
class GcsSubscriber(_SubscriberBase):
|
||||
"""Subscriber to GCS. Thread safe.
|
||||
|
||||
Usage example:
|
||||
subscriber = GcsSubscriber()
|
||||
subscriber.subscribe_error()
|
||||
while running:
|
||||
error_id, error_data = subscriber.poll_error()
|
||||
......
|
||||
subscriber.close()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
address: str = None,
|
||||
channel: grpc.Channel = None,
|
||||
):
|
||||
if address:
|
||||
assert channel is None, \
|
||||
"address and channel cannot both be specified"
|
||||
channel = gcs_utils.create_gcs_channel(address)
|
||||
else:
|
||||
assert channel is not None, \
|
||||
"One of address and channel must be specified"
|
||||
self._lock = threading.RLock()
|
||||
self._stub = gcs_service_pb2_grpc.InternalPubSubGcsServiceStub(channel)
|
||||
self._subscriber_id = bytes(
|
||||
bytearray(random.getrandbits(8) for _ in range(28)))
|
||||
# Whether error info has been subscribed.
|
||||
self._subscribed_error = False
|
||||
# Buffer for holding error info.
|
||||
self._errors = deque()
|
||||
# Future for indicating whether the subscriber has closed.
|
||||
self._close = threading.Event()
|
||||
|
||||
def subscribe_error(self) -> None:
|
||||
"""Registers a subscription for error info.
|
||||
|
||||
Before the registration, published errors will not be saved for the
|
||||
subscriber.
|
||||
"""
|
||||
with self._lock:
|
||||
if self._close.is_set():
|
||||
return
|
||||
if not self._subscribed_error:
|
||||
req = self._subscribe_error_request()
|
||||
self._stub.GcsSubscriberCommandBatch(req, timeout=30)
|
||||
self._subscribed_error = True
|
||||
|
||||
def poll_error(self, timeout=None) -> Tuple[bytes, ErrorTableData]:
|
||||
"""Polls for new error messages."""
|
||||
with self._lock:
|
||||
if self._close.is_set():
|
||||
return
|
||||
|
||||
if len(self._errors) == 0:
|
||||
req = self._poll_request()
|
||||
fut = self._stub.GcsSubscriberPoll.future(req, timeout=timeout)
|
||||
# Wait for result to become available, or cancel if the
|
||||
# subscriber has closed.
|
||||
while True:
|
||||
try:
|
||||
fut.result(timeout=1)
|
||||
break
|
||||
except grpc.FutureTimeoutError:
|
||||
# Subscriber has closed. Cancel inflight the request
|
||||
# and return from polling.
|
||||
if self._close.is_set():
|
||||
fut.cancel()
|
||||
return None, None
|
||||
# GRPC has not replied, continue waiting.
|
||||
continue
|
||||
except Exception:
|
||||
# GRPC error, including deadline exceeded.
|
||||
raise
|
||||
if fut.done():
|
||||
for msg in fut.result().pub_messages:
|
||||
self._errors.append((msg.key_id,
|
||||
msg.error_info_message))
|
||||
|
||||
if len(self._errors) == 0:
|
||||
return None, None
|
||||
return self._errors.popleft()
|
||||
|
||||
def close(self) -> None:
|
||||
"""Closes the subscriber and its active subscriptions."""
|
||||
# Mark close to terminate inflight polling and prevent future requests.
|
||||
self._close.set()
|
||||
with self._lock:
|
||||
if not self._stub:
|
||||
# Subscriber already closed.
|
||||
return
|
||||
req = self._unsubscribe_request()
|
||||
try:
|
||||
self._stub.GcsSubscriberCommandBatch(req, timeout=30)
|
||||
except Exception:
|
||||
pass
|
||||
self._stub = None
|
||||
|
||||
|
||||
class GcsAioPublisher:
|
||||
"""Publisher to GCS. Uses async io."""
|
||||
|
||||
def __init__(self, address: str = None, channel: grpc.aio.Channel = None):
|
||||
if address:
|
||||
assert channel is None, \
|
||||
"address and channel cannot both be specified"
|
||||
channel = gcs_utils.create_gcs_channel(address, aio=True)
|
||||
else:
|
||||
assert channel is not None, \
|
||||
"One of address and channel must be specified"
|
||||
self._stub = gcs_service_pb2_grpc.InternalPubSubGcsServiceStub(channel)
|
||||
|
||||
async def publish_error(self, key_id: bytes,
|
||||
error_info: ErrorTableData) -> None:
|
||||
"""Publishes error info to GCS."""
|
||||
msg = pubsub_pb2.PubMessage(
|
||||
channel_type=pubsub_pb2.RAY_ERROR_INFO_CHANNEL,
|
||||
key_id=key_id,
|
||||
error_info_message=error_info)
|
||||
req = gcs_service_pb2.GcsPublishRequest(pub_messages=[msg])
|
||||
await self._stub.GcsPublish(req)
|
||||
|
||||
|
||||
class GcsAioSubscriber(_SubscriberBase):
|
||||
"""Async io subscriber to GCS.
|
||||
|
||||
Usage example:
|
||||
subscriber = GcsAioSubscriber()
|
||||
await subscriber.subscribe_error()
|
||||
while running:
|
||||
error_id, error_data = await subscriber.poll_error()
|
||||
......
|
||||
await subscriber.close()
|
||||
"""
|
||||
|
||||
def __init__(self, address: str = None, channel: grpc.aio.Channel = None):
|
||||
if address:
|
||||
assert channel is None, \
|
||||
"address and channel cannot both be specified"
|
||||
channel = gcs_utils.create_gcs_channel(address, aio=True)
|
||||
else:
|
||||
assert channel is not None, \
|
||||
"One of address and channel must be specified"
|
||||
self._stub = gcs_service_pb2_grpc.InternalPubSubGcsServiceStub(channel)
|
||||
self._subscriber_id = bytes(
|
||||
bytearray(random.getrandbits(8) for _ in range(28)))
|
||||
# Whether error info has been subscribed.
|
||||
self._subscribed_error = False
|
||||
# Buffer for holding error info.
|
||||
self._errors = deque()
|
||||
|
||||
async def subscribe_error(self) -> None:
|
||||
"""Registers a subscription for error info.
|
||||
|
||||
Before the registration, published errors will not be saved for the
|
||||
subscriber.
|
||||
"""
|
||||
if not self._subscribed_error:
|
||||
req = self._subscribe_error_request()
|
||||
await self._stub.GcsSubscriberCommandBatch(req, timeout=30)
|
||||
self._subscribed_error = True
|
||||
|
||||
async def poll_error(self, timeout=None) -> Tuple[bytes, ErrorTableData]:
|
||||
"""Polls for new error messages."""
|
||||
if len(self._errors) == 0:
|
||||
req = self._poll_request()
|
||||
reply = await self._stub.GcsSubscriberPoll(req, timeout=timeout)
|
||||
for msg in reply.pub_messages:
|
||||
self._errors.append((msg.key_id, msg.error_info_message))
|
||||
|
||||
if len(self._errors) == 0:
|
||||
return None, None
|
||||
return self._errors.popleft()
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Closes the subscriber and its active subscriptions."""
|
||||
req = self._unsubscribe_request()
|
||||
try:
|
||||
await self._stub.GcsSubscriberCommandBatch(req, timeout=30)
|
||||
except Exception:
|
||||
pass
|
||||
self._subscribed_error = False
|
|
@ -1,7 +1,10 @@
|
|||
from ray.core.generated.common_pb2 import ErrorType
|
||||
import enum
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
import grpc
|
||||
|
||||
from ray.core.generated.common_pb2 import ErrorType
|
||||
from ray.core.generated import gcs_service_pb2_grpc
|
||||
from ray.core.generated import gcs_service_pb2
|
||||
from ray.core.generated.gcs_pb2 import (
|
||||
|
@ -51,64 +54,69 @@ __all__ = [
|
|||
"ResourceLoad",
|
||||
"ResourceMap",
|
||||
"ResourceTableData",
|
||||
"construct_error_message",
|
||||
"ObjectLocationInfo",
|
||||
"PubSubMessage",
|
||||
"WorkerTableData",
|
||||
"PlacementGroupTableData",
|
||||
]
|
||||
|
||||
FUNCTION_PREFIX = "RemoteFunction:"
|
||||
LOG_FILE_CHANNEL = "RAY_LOG_CHANNEL"
|
||||
REPORTER_CHANNEL = "RAY_REPORTER"
|
||||
|
||||
# xray resource usages
|
||||
XRAY_RESOURCES_BATCH_PATTERN = "RESOURCES_BATCH:".encode("ascii")
|
||||
|
||||
# xray job updates
|
||||
XRAY_JOB_PATTERN = "JOB:*".encode("ascii")
|
||||
|
||||
# Actor pub/sub updates
|
||||
RAY_ACTOR_PUBSUB_PATTERN = "ACTOR:*".encode("ascii")
|
||||
|
||||
# Reporter pub/sub updates
|
||||
RAY_REPORTER_PUBSUB_PATTERN = "RAY_REPORTER.*".encode("ascii")
|
||||
|
||||
RAY_ERROR_PUBSUB_PATTERN = "ERROR_INFO:*".encode("ascii")
|
||||
|
||||
# These prefixes must be kept up-to-date with the TablePrefix enum in
|
||||
# gcs.proto.
|
||||
# TODO(rkn): We should use scoped enums, in which case we should be able to
|
||||
# just access the flatbuffer generated values.
|
||||
TablePrefix_RAYLET_TASK_string = "RAYLET_TASK"
|
||||
TablePrefix_OBJECT_string = "OBJECT"
|
||||
TablePrefix_PROFILE_string = "PROFILE"
|
||||
TablePrefix_JOB_string = "JOB"
|
||||
TablePrefix_ACTOR_string = "ACTOR"
|
||||
|
||||
WORKER = 0
|
||||
DRIVER = 1
|
||||
|
||||
# Cap messages at 512MB
|
||||
_MAX_MESSAGE_LENGTH = 512 * 1024 * 1024
|
||||
# Send keepalive every 60s
|
||||
_GRPC_KEEPALIVE_TIME_MS = 60 * 1000
|
||||
# Keepalive should be replied < 60s
|
||||
_GRPC_KEEPALIVE_TIMEOUT_MS = 60 * 1000
|
||||
|
||||
def construct_error_message(job_id, error_type, message, timestamp):
|
||||
"""Construct a serialized ErrorTableData object.
|
||||
# Also relying on these defaults:
|
||||
# grpc.keepalive_permit_without_calls=0: No keepalive without inflight calls.
|
||||
# grpc.use_local_subchannel_pool=0: Subchannels are shared.
|
||||
_GRPC_OPTIONS = [("grpc.enable_http_proxy",
|
||||
0), ("grpc.max_send_message_length", _MAX_MESSAGE_LENGTH),
|
||||
("grpc.max_receive_message_length", _MAX_MESSAGE_LENGTH),
|
||||
("grpc.keepalive_time_ms",
|
||||
_GRPC_KEEPALIVE_TIME_MS), ("grpc.keepalive_timeout_ms",
|
||||
_GRPC_KEEPALIVE_TIMEOUT_MS)]
|
||||
|
||||
|
||||
def get_gcs_address_from_redis(redis) -> str:
|
||||
"""Reads GCS address from redis.
|
||||
|
||||
Args:
|
||||
job_id: The ID of the job that the error should go to. If this is
|
||||
nil, then the error will go to all drivers.
|
||||
error_type: The type of the error.
|
||||
message: The error message.
|
||||
timestamp: The time of the error.
|
||||
|
||||
redis: Redis client to fetch GCS address.
|
||||
Returns:
|
||||
The serialized object.
|
||||
GCS address string.
|
||||
"""
|
||||
data = ErrorTableData()
|
||||
data.job_id = job_id.binary()
|
||||
data.type = error_type
|
||||
data.error_message = message
|
||||
data.timestamp = timestamp
|
||||
return data.SerializeToString()
|
||||
gcs_address = redis.get("GcsServerAddress")
|
||||
if gcs_address is None:
|
||||
raise RuntimeError("Failed to look up gcs address through redis")
|
||||
return gcs_address.decode()
|
||||
|
||||
|
||||
def create_gcs_channel(address: str, aio=False):
|
||||
"""Returns a GRPC channel to GCS.
|
||||
|
||||
Args:
|
||||
address: GCS address string, e.g. ip:port
|
||||
aio: Whether using grpc.aio
|
||||
Returns:
|
||||
grpc.Channel or grpc.aio.Channel to GCS
|
||||
"""
|
||||
from ray._private.utils import init_grpc_channel
|
||||
return init_grpc_channel(address, options=_GRPC_OPTIONS, asynchronous=aio)
|
||||
|
||||
|
||||
class GcsCode(enum.IntEnum):
|
||||
|
@ -118,17 +126,13 @@ class GcsCode(enum.IntEnum):
|
|||
|
||||
|
||||
class GcsClient:
|
||||
MAX_MESSAGE_LENGTH = 512 * 1024 * 1024 # 512MB
|
||||
"""Client to GCS using GRPC"""
|
||||
|
||||
def __init__(self, address):
|
||||
from ray._private.utils import init_grpc_channel
|
||||
logger.debug(f"Connecting to gcs address: {address}")
|
||||
options = [("grpc.enable_http_proxy",
|
||||
0), ("grpc.max_send_message_length",
|
||||
GcsClient.MAX_MESSAGE_LENGTH),
|
||||
("grpc.max_receive_message_length",
|
||||
GcsClient.MAX_MESSAGE_LENGTH)]
|
||||
channel = init_grpc_channel(address, options=options)
|
||||
def __init__(self, address: str = None, channel: grpc.Channel = None):
|
||||
if address:
|
||||
assert channel is None, \
|
||||
"Only one of address and channel can be specified"
|
||||
channel = create_gcs_channel(address)
|
||||
self._kv_stub = gcs_service_pb2_grpc.InternalKVGcsServiceStub(channel)
|
||||
|
||||
def internal_kv_get(self, key: bytes) -> bytes:
|
||||
|
@ -187,10 +191,7 @@ class GcsClient:
|
|||
|
||||
@staticmethod
|
||||
def create_from_redis(redis_cli):
|
||||
gcs_address = redis_cli.get("GcsServerAddress")
|
||||
if gcs_address is None:
|
||||
raise RuntimeError("Failed to look up gcs address through redis")
|
||||
return GcsClient(gcs_address.decode())
|
||||
return GcsClient(get_gcs_address_from_redis(redis_cli))
|
||||
|
||||
@staticmethod
|
||||
def connect_to_gcs_by_redis_address(redis_address, redis_password):
|
||||
|
|
|
@ -15,6 +15,7 @@ import ray.ray_constants as ray_constants
|
|||
import ray._private.gcs_utils as gcs_utils
|
||||
import ray._private.services as services
|
||||
import ray._private.utils
|
||||
from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsPublisher
|
||||
from ray._private.ray_logging import setup_component_logger
|
||||
|
||||
# Logger for this module. It should be configured at the entry point
|
||||
|
@ -365,6 +366,11 @@ if __name__ == "__main__":
|
|||
description=("Parse Redis server for the "
|
||||
"log monitor to connect "
|
||||
"to."))
|
||||
parser.add_argument(
|
||||
"--gcs-address",
|
||||
required=False,
|
||||
type=str,
|
||||
help="The address (ip:port) of GCS.")
|
||||
parser.add_argument(
|
||||
"--redis-address",
|
||||
required=True,
|
||||
|
@ -436,11 +442,20 @@ if __name__ == "__main__":
|
|||
# Something went wrong, so push an error to all drivers.
|
||||
redis_client = ray._private.services.create_redis_client(
|
||||
args.redis_address, password=args.redis_password)
|
||||
gcs_publisher = None
|
||||
if args.gcs_address:
|
||||
gcs_publisher = GcsPublisher(address=args.gcs_address)
|
||||
elif gcs_pubsub_enabled():
|
||||
gcs_publisher = GcsPublisher(
|
||||
address=gcs_utils.get_gcs_address_from_redis(redis_client))
|
||||
traceback_str = ray._private.utils.format_error_message(
|
||||
traceback.format_exc())
|
||||
message = (f"The log monitor on node {platform.node()} "
|
||||
f"failed with the following error:\n{traceback_str}")
|
||||
ray._private.utils.push_error_to_driver_through_redis(
|
||||
redis_client, ray_constants.LOG_MONITOR_DIED_ERROR, message)
|
||||
ray._private.utils.publish_error_to_driver(
|
||||
ray_constants.LOG_MONITOR_DIED_ERROR,
|
||||
message,
|
||||
redis_client=redis_client,
|
||||
gcs_publisher=gcs_publisher)
|
||||
logger.error(message)
|
||||
raise e
|
||||
|
|
|
@ -1926,6 +1926,7 @@ def start_monitor(redis_address,
|
|||
|
||||
Args:
|
||||
redis_address (str): The address that the Redis server is listening on.
|
||||
gcs_address (str): The address of GCS server.
|
||||
logs_dir(str): The path to the log directory.
|
||||
stdout_file: A file handle opened for writing to redirect stdout to. If
|
||||
no redirection should happen, then this should be None.
|
||||
|
|
|
@ -26,6 +26,7 @@ import ray._private.gcs_utils as gcs_utils
|
|||
import ray._private.memory_monitor as memory_monitor
|
||||
from ray.core.generated import node_manager_pb2
|
||||
from ray.core.generated import node_manager_pb2_grpc
|
||||
from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsSubscriber
|
||||
from ray._private.tls_utils import generate_self_signed_tls_certs
|
||||
from ray.util.queue import Queue, _QueueActor, Empty
|
||||
from ray.scripts.scripts import main as ray_main
|
||||
|
@ -502,19 +503,34 @@ def get_non_head_nodes(cluster):
|
|||
|
||||
def init_error_pubsub():
|
||||
"""Initialize redis error info pub/sub"""
|
||||
p = ray.worker.global_worker.redis_client.pubsub(
|
||||
if gcs_pubsub_enabled():
|
||||
s = GcsSubscriber(channel=ray.worker.global_worker.gcs_channel)
|
||||
s.subscribe_error()
|
||||
else:
|
||||
s = ray.worker.global_worker.redis_client.pubsub(
|
||||
ignore_subscribe_messages=True)
|
||||
error_pubsub_channel = gcs_utils.RAY_ERROR_PUBSUB_PATTERN
|
||||
p.psubscribe(error_pubsub_channel)
|
||||
return p
|
||||
s.psubscribe(gcs_utils.RAY_ERROR_PUBSUB_PATTERN)
|
||||
return s
|
||||
|
||||
|
||||
def get_error_message(pub_sub, num, error_type=None, timeout=20):
|
||||
"""Get errors through pub/sub."""
|
||||
start_time = time.time()
|
||||
def get_error_message(subscriber, num, error_type=None, timeout=20):
|
||||
"""Get errors through subscriber."""
|
||||
deadline = time.time() + timeout
|
||||
msgs = []
|
||||
while time.time() - start_time < timeout and len(msgs) < num:
|
||||
msg = pub_sub.get_message()
|
||||
while time.time() < deadline and len(msgs) < num:
|
||||
if isinstance(subscriber, GcsSubscriber):
|
||||
try:
|
||||
_, error_data = subscriber.poll_error(timeout=deadline -
|
||||
time.time())
|
||||
except grpc.RpcError as e:
|
||||
# Failed to match error message before timeout.
|
||||
if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
|
||||
logging.warning("get_error_message() timed out")
|
||||
return []
|
||||
# Otherwise, the error is unexpected.
|
||||
raise
|
||||
else:
|
||||
msg = subscriber.get_message()
|
||||
if msg is None:
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
|
|
|
@ -27,6 +27,7 @@ import numpy as np
|
|||
import ray
|
||||
import ray._private.gcs_utils as gcs_utils
|
||||
import ray.ray_constants as ray_constants
|
||||
from ray._private.gcs_pubsub import construct_error_message
|
||||
from ray._private.tls_utils import load_certs_from_env
|
||||
|
||||
# Import psutil after ray so the packaged version is used.
|
||||
|
@ -113,10 +114,11 @@ def push_error_to_driver(worker, error_type, message, job_id=None):
|
|||
worker.core_worker.push_error(job_id, error_type, message, time.time())
|
||||
|
||||
|
||||
def push_error_to_driver_through_redis(redis_client,
|
||||
error_type,
|
||||
def publish_error_to_driver(error_type,
|
||||
message,
|
||||
job_id=None):
|
||||
job_id=None,
|
||||
redis_client=None,
|
||||
gcs_publisher=None):
|
||||
"""Push an error message to the driver to be printed in the background.
|
||||
|
||||
Normally the push_error_to_driver function should be used. However, in some
|
||||
|
@ -125,25 +127,31 @@ def push_error_to_driver_through_redis(redis_client,
|
|||
backend processes.
|
||||
|
||||
Args:
|
||||
redis_client: The redis client to use.
|
||||
error_type (str): The type of the error.
|
||||
message (str): The message that will be printed in the background
|
||||
on the driver.
|
||||
job_id: The ID of the driver to push the error message to. If this
|
||||
is None, then the message will be pushed to all drivers.
|
||||
redis_client: The redis client to use.
|
||||
gcs_publisher: The GCS publisher to use. If specified, ignores
|
||||
redis_client.
|
||||
"""
|
||||
if job_id is None:
|
||||
job_id = ray.JobID.nil()
|
||||
assert isinstance(job_id, ray.JobID)
|
||||
# Do everything in Python and through the Python Redis client instead
|
||||
# of through the raylet.
|
||||
error_data = gcs_utils.construct_error_message(job_id, error_type, message,
|
||||
error_data = construct_error_message(job_id, error_type, message,
|
||||
time.time())
|
||||
if gcs_publisher:
|
||||
gcs_publisher.publish_error(job_id.hex().encode(), error_data)
|
||||
elif redis_client:
|
||||
pubsub_msg = gcs_utils.PubSubMessage()
|
||||
pubsub_msg.id = job_id.binary()
|
||||
pubsub_msg.data = error_data
|
||||
pubsub_msg.data = error_data.SerializeToString()
|
||||
redis_client.publish("ERROR_INFO:" + job_id.hex(),
|
||||
pubsub_msg.SerializeToString())
|
||||
else:
|
||||
raise ValueError(
|
||||
"One of redis_client and gcs_publisher needs to be specified!")
|
||||
|
||||
|
||||
def random_string():
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
"""Autoscaler monitoring loop daemon."""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import logging.handlers
|
||||
import os
|
||||
import sys
|
||||
|
@ -35,7 +34,8 @@ from ray.autoscaler._private.util import DEBUG_AUTOSCALING_STATUS, \
|
|||
from ray.core.generated import gcs_service_pb2, gcs_service_pb2_grpc
|
||||
import ray.ray_constants as ray_constants
|
||||
from ray._private.ray_logging import setup_component_logger
|
||||
from ray._private.gcs_utils import GcsClient
|
||||
from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsPublisher
|
||||
from ray._private.gcs_utils import GcsClient, get_gcs_address_from_redis
|
||||
from ray.experimental.internal_kv import _initialize_internal_kv, \
|
||||
_internal_kv_put, _internal_kv_initialized, _internal_kv_get, \
|
||||
_internal_kv_del
|
||||
|
@ -409,9 +409,18 @@ class Monitor:
|
|||
_internal_kv_put(DEBUG_AUTOSCALING_ERROR, message, overwrite=True)
|
||||
redis_client = ray._private.services.create_redis_client(
|
||||
self.redis_address, password=self.redis_password)
|
||||
from ray._private.utils import push_error_to_driver_through_redis
|
||||
push_error_to_driver_through_redis(
|
||||
redis_client, ray_constants.MONITOR_DIED_ERROR, message)
|
||||
gcs_publisher = None
|
||||
if args.gcs_address:
|
||||
gcs_publisher = GcsPublisher(address=args.gcs_address)
|
||||
elif gcs_pubsub_enabled():
|
||||
gcs_publisher = GcsPublisher(
|
||||
address=get_gcs_address_from_redis(redis_client))
|
||||
from ray._private.utils import publish_error_to_driver
|
||||
publish_error_to_driver(
|
||||
ray_constants.MONITOR_DIED_ERROR,
|
||||
message,
|
||||
redis_client=redis_client,
|
||||
gcs_publisher=gcs_publisher)
|
||||
|
||||
def _signal_handler(self, sig, frame):
|
||||
self._handle_failure(f"Terminated with signal {sig}\n" +
|
||||
|
@ -437,6 +446,11 @@ if __name__ == "__main__":
|
|||
parser = argparse.ArgumentParser(
|
||||
description=("Parse Redis server for the "
|
||||
"monitor to connect to."))
|
||||
parser.add_argument(
|
||||
"--gcs-address",
|
||||
required=False,
|
||||
type=str,
|
||||
help="The address (ip:port) of GCS.")
|
||||
parser.add_argument(
|
||||
"--redis-address",
|
||||
required=True,
|
||||
|
|
|
@ -28,6 +28,7 @@ py_test_module_list(
|
|||
"test_component_failures_2.py",
|
||||
"test_component_failures_3.py",
|
||||
"test_error_ray_not_initialized.py",
|
||||
"test_gcs_pubsub.py",
|
||||
"test_global_gc.py",
|
||||
"test_grpc_client_credentials.py",
|
||||
"test_iter.py",
|
||||
|
|
|
@ -11,6 +11,7 @@ import ray._private.utils
|
|||
import ray._private.gcs_utils as gcs_utils
|
||||
import ray.ray_constants as ray_constants
|
||||
from ray.exceptions import RayTaskError, RayActorError, GetTimeoutError
|
||||
from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsPublisher
|
||||
from ray._private.test_utils import (wait_for_condition, SignalActor,
|
||||
init_error_pubsub, get_error_message)
|
||||
|
||||
|
@ -61,14 +62,21 @@ def test_unhandled_errors(ray_start_regular):
|
|||
del os.environ["RAY_IGNORE_UNHANDLED_ERRORS"]
|
||||
|
||||
|
||||
def test_push_error_to_driver_through_redis(ray_start_regular, error_pubsub):
|
||||
def test_publish_error_to_driver(ray_start_regular, error_pubsub):
|
||||
address_info = ray_start_regular
|
||||
address = address_info["redis_address"]
|
||||
redis_client = ray._private.services.create_redis_client(
|
||||
address, password=ray.ray_constants.REDIS_DEFAULT_PASSWORD)
|
||||
gcs_publisher = None
|
||||
if gcs_pubsub_enabled():
|
||||
gcs_publisher = GcsPublisher(
|
||||
address=gcs_utils.get_gcs_address_from_redis(redis_client))
|
||||
error_message = "Test error message"
|
||||
ray._private.utils.push_error_to_driver_through_redis(
|
||||
redis_client, ray_constants.DASHBOARD_AGENT_DIED_ERROR, error_message)
|
||||
ray._private.utils.publish_error_to_driver(
|
||||
ray_constants.DASHBOARD_AGENT_DIED_ERROR,
|
||||
error_message,
|
||||
redis_client=redis_client,
|
||||
gcs_publisher=gcs_publisher)
|
||||
errors = get_error_message(error_pubsub, 1,
|
||||
ray_constants.DASHBOARD_AGENT_DIED_ERROR)
|
||||
assert errors[0].type == ray_constants.DASHBOARD_AGENT_DIED_ERROR
|
||||
|
|
|
@ -104,9 +104,8 @@ def test_connect_with_disconnected_node(shutdown_only):
|
|||
ray.init(address=cluster.address)
|
||||
p = init_error_pubsub()
|
||||
errors = get_error_message(p, 1, timeout=5)
|
||||
print(errors)
|
||||
assert len(errors) == 0
|
||||
# This node is killed by SIGKILL, ray_monitor will mark it to dead.
|
||||
# This node will be killed by SIGKILL, ray_monitor will mark it to dead.
|
||||
dead_node = cluster.add_node(num_cpus=0)
|
||||
cluster.remove_node(dead_node, allow_graceful=False)
|
||||
errors = get_error_message(p, 1, ray_constants.REMOVED_NODE_ERROR)
|
||||
|
|
73
python/ray/tests/test_gcs_pubsub.py
Normal file
73
python/ray/tests/test_gcs_pubsub.py
Normal file
|
@ -0,0 +1,73 @@
|
|||
import sys
|
||||
|
||||
import ray
|
||||
import ray._private.gcs_utils as gcs_utils
|
||||
from ray._private.gcs_pubsub import GcsPublisher, GcsSubscriber, \
|
||||
GcsAioPublisher, GcsAioSubscriber
|
||||
from ray.core.generated.gcs_pb2 import ErrorTableData
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ray_start_regular", [{
|
||||
"_system_config": {
|
||||
"gcs_grpc_based_pubsub": True
|
||||
}
|
||||
}],
|
||||
indirect=True)
|
||||
def test_publish_and_subscribe_error_info(ray_start_regular):
|
||||
address_info = ray_start_regular
|
||||
redis = ray._private.services.create_redis_client(
|
||||
address_info["redis_address"],
|
||||
password=ray.ray_constants.REDIS_DEFAULT_PASSWORD)
|
||||
|
||||
gcs_server_addr = gcs_utils.get_gcs_address_from_redis(redis)
|
||||
|
||||
subscriber = GcsSubscriber(address=gcs_server_addr)
|
||||
subscriber.subscribe_error()
|
||||
|
||||
publisher = GcsPublisher(address=gcs_server_addr)
|
||||
err1 = ErrorTableData(error_message="test error message 1")
|
||||
err2 = ErrorTableData(error_message="test error message 2")
|
||||
publisher.publish_error(b"aaa_id", err1)
|
||||
publisher.publish_error(b"bbb_id", err2)
|
||||
|
||||
assert subscriber.poll_error() == (b"aaa_id", err1)
|
||||
assert subscriber.poll_error() == (b"bbb_id", err2)
|
||||
|
||||
subscriber.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"ray_start_regular", [{
|
||||
"_system_config": {
|
||||
"gcs_grpc_based_pubsub": True
|
||||
}
|
||||
}],
|
||||
indirect=True)
|
||||
async def test_aio_publish_and_subscribe_error_info(ray_start_regular):
|
||||
address_info = ray_start_regular
|
||||
redis = ray._private.services.create_redis_client(
|
||||
address_info["redis_address"],
|
||||
password=ray.ray_constants.REDIS_DEFAULT_PASSWORD)
|
||||
|
||||
gcs_server_addr = gcs_utils.get_gcs_address_from_redis(redis)
|
||||
|
||||
subscriber = GcsAioSubscriber(address=gcs_server_addr)
|
||||
await subscriber.subscribe_error()
|
||||
|
||||
publisher = GcsAioPublisher(address=gcs_server_addr)
|
||||
err1 = ErrorTableData(error_message="test error message 1")
|
||||
err2 = ErrorTableData(error_message="test error message 2")
|
||||
await publisher.publish_error(b"aaa_id", err1)
|
||||
await publisher.publish_error(b"bbb_id", err2)
|
||||
|
||||
assert await subscriber.poll_error() == (b"aaa_id", err1)
|
||||
assert await subscriber.poll_error() == (b"bbb_id", err2)
|
||||
|
||||
await subscriber.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
|
@ -27,6 +27,8 @@ import ray.remote_function
|
|||
import ray.serialization as serialization
|
||||
import ray._private.gcs_utils as gcs_utils
|
||||
import ray._private.services as services
|
||||
from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsPublisher, \
|
||||
GcsSubscriber
|
||||
from ray._private.runtime_env.py_modules import upload_py_modules_if_needed
|
||||
from ray._private.runtime_env.working_dir import upload_working_dir_if_needed
|
||||
from ray._private.runtime_env.constants import RAY_JOB_CONFIG_JSON_ENV_VAR
|
||||
|
@ -1247,6 +1249,58 @@ def listen_error_messages_raylet(worker, threads_stopped):
|
|||
worker.error_message_pubsub_client.close()
|
||||
|
||||
|
||||
def listen_error_messages_from_gcs(worker, threads_stopped):
|
||||
"""Listen to error messages in the background on the driver.
|
||||
|
||||
This runs in a separate thread on the driver and pushes (error, time)
|
||||
tuples to be published.
|
||||
|
||||
Args:
|
||||
worker: The worker class that this thread belongs to.
|
||||
threads_stopped (threading.Event): A threading event used to signal to
|
||||
the thread that it should exit.
|
||||
"""
|
||||
worker.gcs_subscriber = GcsSubscriber(channel=worker.gcs_channel)
|
||||
# Exports that are published after the call to
|
||||
# gcs_subscriber.subscribe_error() and before the call to
|
||||
# gcs_subscriber.poll_error() will still be processed in the loop.
|
||||
|
||||
# TODO: we should just subscribe to the errors for this specific job.
|
||||
worker.gcs_subscriber.subscribe_error()
|
||||
|
||||
try:
|
||||
if _internal_kv_initialized():
|
||||
# Get any autoscaler errors that occurred before the call to
|
||||
# subscribe.
|
||||
error_message = _internal_kv_get(DEBUG_AUTOSCALING_ERROR)
|
||||
if error_message is not None:
|
||||
logger.warning(error_message.decode())
|
||||
|
||||
while True:
|
||||
# Exit if received a signal that the thread should stop.
|
||||
if threads_stopped.is_set():
|
||||
return
|
||||
|
||||
_, error_data = worker.gcs_subscriber.poll_error()
|
||||
if error_data is None:
|
||||
continue
|
||||
if error_data.job_id not in [
|
||||
worker.current_job_id.binary(),
|
||||
JobID.nil().binary(),
|
||||
]:
|
||||
continue
|
||||
|
||||
error_message = error_data.error_message
|
||||
if error_data.type == ray_constants.TASK_PUSH_ERROR:
|
||||
# TODO(ekl) remove task push errors entirely now that we have
|
||||
# the separate unhandled exception handler.
|
||||
pass
|
||||
else:
|
||||
logger.warning(error_message)
|
||||
except (OSError, ConnectionError) as e:
|
||||
logger.error(f"listen_error_messages_from_gcs: {e}")
|
||||
|
||||
|
||||
@PublicAPI
|
||||
@client_mode_hook(auto_init=False)
|
||||
def is_initialized() -> bool:
|
||||
|
@ -1307,11 +1361,16 @@ def connect(node,
|
|||
# that is not true of Redis pubsub clients. See the documentation at
|
||||
# https://github.com/andymccurdy/redis-py#thread-safety.
|
||||
worker.redis_client = node.create_redis_client()
|
||||
worker.gcs_client = gcs_utils.GcsClient.create_from_redis(
|
||||
worker.redis_client)
|
||||
worker.gcs_channel = gcs_utils.create_gcs_channel(
|
||||
gcs_utils.get_gcs_address_from_redis(worker.redis_client))
|
||||
worker.gcs_client = gcs_utils.GcsClient(channel=worker.gcs_channel)
|
||||
_initialize_internal_kv(worker.gcs_client)
|
||||
ray.state.state._initialize_global_state(
|
||||
node.redis_address, redis_password=node.redis_password)
|
||||
worker.gcs_pubsub_enabled = gcs_pubsub_enabled()
|
||||
worker.gcs_publisher = None
|
||||
if worker.gcs_pubsub_enabled:
|
||||
worker.gcs_publisher = GcsPublisher(channel=worker.gcs_channel)
|
||||
|
||||
# Initialize some fields.
|
||||
if mode in (WORKER_MODE, RESTORE_WORKER_MODE, SPILL_WORKER_MODE):
|
||||
|
@ -1350,11 +1409,12 @@ def connect(node,
|
|||
raise e
|
||||
elif mode == WORKER_MODE:
|
||||
traceback_str = traceback.format_exc()
|
||||
ray._private.utils.push_error_to_driver_through_redis(
|
||||
worker.redis_client,
|
||||
ray._private.utils.publish_error_to_driver(
|
||||
ray_constants.VERSION_MISMATCH_PUSH_ERROR,
|
||||
traceback_str,
|
||||
job_id=None)
|
||||
job_id=None,
|
||||
redis_client=worker.redis_client,
|
||||
gcs_publisher=worker.gcs_publisher)
|
||||
|
||||
worker.lock = threading.RLock()
|
||||
|
||||
|
@ -1445,7 +1505,8 @@ def connect(node,
|
|||
# scheduler for new error messages.
|
||||
if mode == SCRIPT_MODE:
|
||||
worker.listener_thread = threading.Thread(
|
||||
target=listen_error_messages_raylet,
|
||||
target=listen_error_messages_from_gcs
|
||||
if worker.gcs_pubsub_enabled else listen_error_messages_raylet,
|
||||
name="ray_listen_error_messages",
|
||||
args=(worker, worker.threads_stopped))
|
||||
worker.listener_thread.daemon = True
|
||||
|
@ -1519,6 +1580,8 @@ def disconnect(exiting_interpreter=False):
|
|||
if hasattr(worker, "import_thread"):
|
||||
worker.import_thread.join_import_thread()
|
||||
if hasattr(worker, "listener_thread"):
|
||||
if hasattr(worker, "gcs_subscriber"):
|
||||
worker.gcs_subscriber.close()
|
||||
worker.listener_thread.join()
|
||||
if hasattr(worker, "logger_thread"):
|
||||
worker.logger_thread.join()
|
||||
|
|
|
@ -116,17 +116,6 @@ class MockGcsTaskReconstructionTable : public GcsTaskReconstructionTable {
|
|||
namespace ray {
|
||||
namespace gcs {
|
||||
|
||||
class MockGcsObjectTable : public GcsObjectTable {
|
||||
public:
|
||||
MOCK_METHOD(JobID, GetJobIdFromKey, (const ObjectID &key), (override));
|
||||
};
|
||||
|
||||
} // namespace gcs
|
||||
} // namespace ray
|
||||
|
||||
namespace ray {
|
||||
namespace gcs {
|
||||
|
||||
class MockGcsNodeTable : public GcsNodeTable {
|
||||
public:
|
||||
};
|
||||
|
|
|
@ -72,7 +72,8 @@ void GcsServer::Start() {
|
|||
rpc::ChannelType::GCS_JOB_CHANNEL,
|
||||
rpc::ChannelType::GCS_NODE_INFO_CHANNEL,
|
||||
rpc::ChannelType::GCS_NODE_RESOURCE_CHANNEL,
|
||||
rpc::ChannelType::GCS_WORKER_DELTA_CHANNEL},
|
||||
rpc::ChannelType::GCS_WORKER_DELTA_CHANNEL,
|
||||
rpc::ChannelType::RAY_ERROR_INFO_CHANNEL},
|
||||
/*periodical_runner=*/&pubsub_periodical_runner_,
|
||||
/*get_time_ms=*/[]() { return absl::GetCurrentTimeNanos() / 1e6; },
|
||||
/*subscriber_timeout_ms=*/RayConfig::instance().subscriber_timeout_ms(),
|
||||
|
|
|
@ -213,17 +213,6 @@ class GcsTaskReconstructionTable
|
|||
JobID GetJobIdFromKey(const TaskID &key) override { return key.ActorId().JobId(); }
|
||||
};
|
||||
|
||||
class GcsObjectTable : public GcsTableWithJobId<ObjectID, ObjectLocationInfo> {
|
||||
public:
|
||||
explicit GcsObjectTable(std::shared_ptr<StoreClient> store_client)
|
||||
: GcsTableWithJobId(std::move(store_client)) {
|
||||
table_name_ = TablePrefix_Name(TablePrefix::OBJECT);
|
||||
}
|
||||
|
||||
private:
|
||||
JobID GetJobIdFromKey(const ObjectID &key) override { return key.TaskId().JobId(); }
|
||||
};
|
||||
|
||||
class GcsNodeTable : public GcsTable<NodeID, GcsNodeInfo> {
|
||||
public:
|
||||
explicit GcsNodeTable(std::shared_ptr<StoreClient> store_client)
|
||||
|
|
|
@ -17,12 +17,35 @@
|
|||
namespace ray {
|
||||
namespace gcs {
|
||||
|
||||
void InternalPubSubHandler::HandleGcsPublish(const rpc::GcsPublishRequest &request,
|
||||
rpc::GcsPublishReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback) {
|
||||
if (gcs_publisher_ == nullptr) {
|
||||
send_reply_callback(
|
||||
Status::NotImplemented("GCS pubsub is not yet enabled. Please enable it with "
|
||||
"system config `gcs_grpc_based_pubsub=True`"),
|
||||
nullptr, nullptr);
|
||||
return;
|
||||
}
|
||||
for (const auto &msg : request.pub_messages()) {
|
||||
gcs_publisher_->GetPublisher()->Publish(msg);
|
||||
}
|
||||
send_reply_callback(Status::OK(), nullptr, nullptr);
|
||||
}
|
||||
|
||||
// Needs to use rpc::GcsSubscriberPollRequest and rpc::GcsSubscriberPollReply here,
|
||||
// and convert the reply to rpc::PubsubLongPollingReply because GCS RPC services are
|
||||
// required to have the `status` field in replies.
|
||||
void InternalPubSubHandler::HandleGcsSubscriberPoll(
|
||||
const rpc::GcsSubscriberPollRequest &request, rpc::GcsSubscriberPollReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback) {
|
||||
if (gcs_publisher_ == nullptr) {
|
||||
send_reply_callback(
|
||||
Status::NotImplemented("GCS pubsub is not yet enabled. Please enable it with "
|
||||
"system config `gcs_grpc_based_pubsub=True`"),
|
||||
nullptr, nullptr);
|
||||
return;
|
||||
}
|
||||
const auto subscriber_id = UniqueID::FromBinary(request.subscriber_id());
|
||||
auto pubsub_reply = std::make_shared<rpc::PubsubLongPollingReply>();
|
||||
auto pubsub_reply_ptr = pubsub_reply.get();
|
||||
|
@ -44,6 +67,13 @@ void InternalPubSubHandler::HandleGcsSubscriberCommandBatch(
|
|||
const rpc::GcsSubscriberCommandBatchRequest &request,
|
||||
rpc::GcsSubscriberCommandBatchReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback) {
|
||||
if (gcs_publisher_ == nullptr) {
|
||||
send_reply_callback(
|
||||
Status::NotImplemented("GCS pubsub is not yet enabled. Please enable it with "
|
||||
"system config `gcs_grpc_based_pubsub=True`"),
|
||||
nullptr, nullptr);
|
||||
return;
|
||||
}
|
||||
const auto subscriber_id = UniqueID::FromBinary(request.subscriber_id());
|
||||
for (const auto &command : request.commands()) {
|
||||
if (command.has_unsubscribe_message()) {
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
#include "ray/gcs/pubsub/gcs_pub_sub.h"
|
||||
#include "ray/rpc/gcs_server/gcs_rpc_server.h"
|
||||
#include "src/ray/protobuf/gcs_service.grpc.pb.h"
|
||||
|
||||
namespace ray {
|
||||
namespace gcs {
|
||||
|
@ -28,6 +29,10 @@ class InternalPubSubHandler : public rpc::InternalPubSubHandler {
|
|||
explicit InternalPubSubHandler(const std::shared_ptr<gcs::GcsPublisher> &gcs_publisher)
|
||||
: gcs_publisher_(gcs_publisher) {}
|
||||
|
||||
void HandleGcsPublish(const rpc::GcsPublishRequest &request,
|
||||
rpc::GcsPublishReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback) final;
|
||||
|
||||
void HandleGcsSubscriberPoll(const rpc::GcsSubscriberPollRequest &request,
|
||||
rpc::GcsSubscriberPollReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback) final;
|
||||
|
|
|
@ -302,6 +302,17 @@ Status GcsPublisher::PublishTaskLease(const TaskID &id, const rpc::TaskLeaseData
|
|||
Status GcsPublisher::PublishError(const std::string &id,
|
||||
const rpc::ErrorTableData &message,
|
||||
const StatusCallback &done) {
|
||||
if (publisher_ != nullptr) {
|
||||
rpc::PubMessage msg;
|
||||
msg.set_channel_type(rpc::ChannelType::RAY_ERROR_INFO_CHANNEL);
|
||||
msg.set_key_id(id);
|
||||
*msg.mutable_error_info_message() = message;
|
||||
publisher_->Publish(msg);
|
||||
if (done != nullptr) {
|
||||
done(Status::OK());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
return pubsub_->Publish(ERROR_INFO_CHANNEL, id, message.SerializeAsString(), done);
|
||||
}
|
||||
|
||||
|
|
|
@ -563,6 +563,16 @@ service InternalKVGcsService {
|
|||
rpc InternalKVKeys(InternalKVKeysRequest) returns (InternalKVKeysReply);
|
||||
}
|
||||
|
||||
message GcsPublishRequest {
|
||||
/// The messages that are published.
|
||||
repeated PubMessage pub_messages = 1;
|
||||
}
|
||||
|
||||
message GcsPublishReply {
|
||||
// Not populated.
|
||||
GcsStatus status = 100;
|
||||
}
|
||||
|
||||
message GcsSubscriberPollRequest {
|
||||
/// The id of the subscriber.
|
||||
bytes subscriber_id = 1;
|
||||
|
@ -590,6 +600,9 @@ message GcsSubscriberCommandBatchReply {
|
|||
/// This supports subscribing updates from GCS with long poll, and registering /
|
||||
/// de-registering subscribers.
|
||||
service InternalPubSubGcsService {
|
||||
/// The request to sent to GCS to publish messages.
|
||||
/// Currently only supporting error info, logs and Python function messages.
|
||||
rpc GcsPublish(GcsPublishRequest) returns (GcsPublishReply);
|
||||
/// The long polling request sent to GCS for pubsub operations.
|
||||
/// The long poll request will be replied once there are a batch of messages that
|
||||
/// need to be published to the caller (subscriber).
|
||||
|
|
|
@ -40,6 +40,8 @@ enum ChannelType {
|
|||
GCS_NODE_RESOURCE_CHANNEL = 6;
|
||||
/// A channel for worker changes, currently only for worker failures.
|
||||
GCS_WORKER_DELTA_CHANNEL = 7;
|
||||
/// A channel for errors from various Ray components.
|
||||
RAY_ERROR_INFO_CHANNEL = 8;
|
||||
}
|
||||
|
||||
///
|
||||
|
@ -61,6 +63,7 @@ message PubMessage {
|
|||
GcsNodeInfo node_info_message = 9;
|
||||
NodeResourceChange node_resource_message = 10;
|
||||
WorkerDeltaData worker_delta_message = 11;
|
||||
ErrorTableData error_info_message = 12;
|
||||
|
||||
// The message that indicates the given key id is not available anymore.
|
||||
FailureMessage failure_message = 6;
|
||||
|
|
|
@ -145,6 +145,10 @@ bool SubscriptionIndex::CheckNoLeaks() const {
|
|||
|
||||
bool Subscriber::ConnectToSubscriber(rpc::PubsubLongPollingReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback) {
|
||||
if (long_polling_connection_) {
|
||||
// Flush the current subscriber poll with an empty reply.
|
||||
PublishIfPossible(/*force_noop=*/true);
|
||||
}
|
||||
if (!long_polling_connection_) {
|
||||
RAY_CHECK(reply != nullptr);
|
||||
RAY_CHECK(send_reply_callback != nullptr);
|
||||
|
@ -171,30 +175,25 @@ void Subscriber::QueueMessage(const rpc::PubMessage &pub_message, bool try_publi
|
|||
}
|
||||
}
|
||||
|
||||
bool Subscriber::PublishIfPossible(bool force) {
|
||||
bool Subscriber::PublishIfPossible(bool force_noop) {
|
||||
if (!long_polling_connection_) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (force || mailbox_.size() > 0) {
|
||||
// If force publish is invoked, mailbox could be empty. We should always add a reply
|
||||
// here because otherwise, there could be memory leak due to our grpc layer
|
||||
// implementation.
|
||||
if (mailbox_.empty()) {
|
||||
mailbox_.push(absl::make_unique<rpc::PubsubLongPollingReply>());
|
||||
if (!force_noop && mailbox_.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!force_noop) {
|
||||
// Reply to the long polling subscriber. Swap the reply here to avoid extra copy.
|
||||
long_polling_connection_->reply->Swap(mailbox_.front().get());
|
||||
mailbox_.pop();
|
||||
}
|
||||
long_polling_connection_->send_reply_callback(Status::OK(), nullptr, nullptr);
|
||||
|
||||
// Clean up & update metadata.
|
||||
long_polling_connection_.reset();
|
||||
mailbox_.pop();
|
||||
last_connection_update_time_ms_ = get_time_ms_();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool Subscriber::CheckNoLeaks() const {
|
||||
|
@ -311,7 +310,7 @@ int Publisher::UnregisterSubscriberInternal(const SubscriberID &subscriber_id) {
|
|||
}
|
||||
auto &subscriber = it->second;
|
||||
// Remove the long polling connection because otherwise, there's memory leak.
|
||||
subscriber->PublishIfPossible(/*force=*/true);
|
||||
subscriber->PublishIfPossible(/*force_noop=*/true);
|
||||
subscribers_.erase(it);
|
||||
return erased;
|
||||
}
|
||||
|
@ -330,8 +329,8 @@ void Publisher::CheckDeadSubscribers() {
|
|||
if (disconnected) {
|
||||
dead_subscribers.push_back(it.first);
|
||||
} else if (active_connection_timed_out) {
|
||||
// Refresh the long polling connection. The subscriber will send it again.
|
||||
subscriber->PublishIfPossible(/*force*/ true);
|
||||
// Refresh the long polling connection. The subscriber will poll again.
|
||||
subscriber->PublishIfPossible(/*force_noop*/ true);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -123,10 +123,11 @@ class Subscriber {
|
|||
|
||||
/// Publish all queued messages if possible.
|
||||
///
|
||||
/// \param force If true, we publish to the subscriber although there's no queued
|
||||
/// message.
|
||||
/// \param force_noop If true, reply to the subscriber with an empty message, regardless
|
||||
/// of whethere there is any queued message. This is for cases where the current poll
|
||||
/// might have been cancelled, or the subscriber might be dead.
|
||||
/// \return True if it publishes. False otherwise.
|
||||
bool PublishIfPossible(bool force = false);
|
||||
bool PublishIfPossible(bool force_noop = false);
|
||||
|
||||
/// Testing only. Return true if there's no metadata remained in the private attribute.
|
||||
bool CheckNoLeaks() const;
|
||||
|
|
|
@ -285,8 +285,8 @@ TEST_F(PublisherTest, TestSubscriber) {
|
|||
ASSERT_FALSE(subscriber->PublishIfPossible());
|
||||
// Try connecting it. Should return true.
|
||||
ASSERT_TRUE(subscriber->ConnectToSubscriber(&reply, send_reply_callback));
|
||||
// If connecting it again, it should fail the request.
|
||||
ASSERT_FALSE(subscriber->ConnectToSubscriber(&reply, send_reply_callback));
|
||||
// Polling when there is already an inflight polling request should still work.
|
||||
ASSERT_TRUE(subscriber->ConnectToSubscriber(&reply, send_reply_callback));
|
||||
// Since there's no published objects, it should return false.
|
||||
ASSERT_FALSE(subscriber->PublishIfPossible());
|
||||
|
||||
|
|
|
@ -277,6 +277,8 @@ class GcsRpcClient {
|
|||
internal_kv_grpc_client_, )
|
||||
|
||||
/// Operations for pubsub
|
||||
VOID_GCS_RPC_CLIENT_METHOD(InternalPubSubGcsService, GcsPublish,
|
||||
internal_pubsub_grpc_client_, )
|
||||
VOID_GCS_RPC_CLIENT_METHOD(InternalPubSubGcsService, GcsSubscriberPoll,
|
||||
internal_pubsub_grpc_client_, )
|
||||
VOID_GCS_RPC_CLIENT_METHOD(InternalPubSubGcsService, GcsSubscriberCommandBatch,
|
||||
|
|
|
@ -606,6 +606,9 @@ class InternalPubSubGcsServiceHandler {
|
|||
public:
|
||||
virtual ~InternalPubSubGcsServiceHandler() = default;
|
||||
|
||||
virtual void HandleGcsPublish(const GcsPublishRequest &request, GcsPublishReply *reply,
|
||||
SendReplyCallback send_reply_callback) = 0;
|
||||
|
||||
virtual void HandleGcsSubscriberPoll(const GcsSubscriberPollRequest &request,
|
||||
GcsSubscriberPollReply *reply,
|
||||
SendReplyCallback send_reply_callback) = 0;
|
||||
|
@ -626,6 +629,7 @@ class InternalPubSubGrpcService : public GrpcService {
|
|||
void InitServerCallFactories(
|
||||
const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
|
||||
std::vector<std::unique_ptr<ServerCallFactory>> *server_call_factories) override {
|
||||
INTERNAL_PUBSUB_SERVICE_RPC_HANDLER(GcsPublish);
|
||||
INTERNAL_PUBSUB_SERVICE_RPC_HANDLER(GcsSubscriberPoll);
|
||||
INTERNAL_PUBSUB_SERVICE_RPC_HANDLER(GcsSubscriberCommandBatch);
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue