Mirror of https://github.com/vale981/ray (synced 2025-03-06 02:21:39 -05:00)
[Core][Pubsub] Implement Python GCS publisher and subscriber (#20111)
## Why are these changes needed?

This change adds a Python publisher and subscriber in `gcs_pubsub.py`, along with a gRPC handler on GCS for publishing via GCS. Error info is migrated to the GCS-based pubsub when the feature flag `RAY_gcs_grpc_based_pubsub=true` is set. A `--gcs-address` flag is also added to some Python processes; it is not set anywhere yet, but will be set after the Redis-less bootstrapping work.

Unit tests are added for the Python publisher and subscriber. The migrated error info publishers and subscribers are covered by existing unit tests, e.g. tests calling `ray._private.test_utils.get_error_message()` to ensure error info is published. GCS-based pubsub still has gaps in handling deadlines, cancelled requests, and GCS restarts, so 3 more unit tests are disabled in the `HA GCS` mode; they will be addressed in a separate change.

## Related issue number
This commit is contained in:
parent fca851eef5 · commit 0330852baf
31 changed files with 763 additions and 188 deletions
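A rough usage sketch of the API this change introduces (see `python/ray/_private/gcs_pubsub.py` in the diff below). The GCS address and error type used here are placeholders, not values taken from the commit:

```python
# Sketch only: assumes a running GCS server reachable at the placeholder
# address, and that RAY_gcs_grpc_based_pubsub is honored as in this commit.
import time

import ray
from ray._private.gcs_pubsub import (gcs_pubsub_enabled, GcsPublisher,
                                     GcsSubscriber, construct_error_message)

gcs_address = "127.0.0.1:6379"  # placeholder ip:port of the GCS server

if gcs_pubsub_enabled():  # checks the RAY_gcs_grpc_based_pubsub env flag
    # Publish one error entry on the GCS error info channel.
    publisher = GcsPublisher(address=gcs_address)
    error_data = construct_error_message(ray.JobID.nil(), "custom_error",
                                         "something went wrong", time.time())
    publisher.publish_error(b"error_key", error_data)

    # Subscribe and poll the same channel.
    subscriber = GcsSubscriber(address=gcs_address)
    subscriber.subscribe_error()
    key_id, received = subscriber.poll_error(timeout=10)
    subscriber.close()
```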
@@ -343,6 +343,7 @@
       --test_env=RAY_gcs_grpc_based_pubsub=true
       -- //python/ray/tests/...
       -//python/ray/tests:test_failure_2
+      -//python/ray/tests:test_job
   - bazel test --config=ci $(./scripts/bazel_export_options)
       --test_tag_filters=-kubernetes,client_tests,-flaky
       --test_env=RAY_CLIENT_MODE=1 --test_env=RAY_PROFILING=1
@@ -358,6 +359,7 @@
       -- //python/ray/tests/...
       -//python/ray/tests:test_client_multi -//python/ray/tests:test_component_failures_3
       -//python/ray/tests:test_healthcheck -//python/ray/tests:test_gcs_fault_tolerance
+      -//python/ray/tests:test_client
 - label: ":redis: HA GCS (Medium K-Z)"
   conditions: ["RAY_CI_PYTHON_AFFECTED"]
   commands:
@@ -368,6 +370,7 @@
       -- //python/ray/tests/...
       -//python/ray/tests:test_multinode_failures_2 -//python/ray/tests:test_ray_debugger
       -//python/ray/tests:test_placement_group_2 -//python/ray/tests:test_placement_group_3
+      -//python/ray/tests:test_multi_node

 - label: ":brain: RLlib: Learning discr. actions TF2-static-graph (from rllib/tuned_examples/*.yaml)"
   conditions: ["RAY_CI_RLLIB_AFFECTED"]
@@ -19,7 +19,9 @@ import ray.dashboard.utils as dashboard_utils
 import ray.ray_constants as ray_constants
 import ray._private.services
 import ray._private.utils
-from ray._private.gcs_utils import GcsClient
+from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsPublisher
+from ray._private.gcs_utils import GcsClient, \
+    get_gcs_address_from_redis
 from ray.core.generated import agent_manager_pb2
 from ray.core.generated import agent_manager_pb2_grpc
 from ray._private.ray_logging import setup_component_logger
@@ -228,6 +230,11 @@ if __name__ == "__main__":
         required=True,
         type=str,
         help="the IP address of this node.")
+    parser.add_argument(
+        "--gcs-address",
+        required=False,
+        type=str,
+        help="The address (ip:port) of GCS.")
     parser.add_argument(
         "--redis-address",
         required=True,
@@ -377,6 +384,12 @@ if __name__ == "__main__":
         # impact of the issue.
         redis_client = ray._private.services.create_redis_client(
             args.redis_address, password=args.redis_password)
+        gcs_publisher = None
+        if args.gcs_address:
+            gcs_publisher = GcsPublisher(args.gcs_address)
+        elif gcs_pubsub_enabled():
+            gcs_publisher = GcsPublisher(
+                address=get_gcs_address_from_redis(redis_client))
         traceback_str = ray._private.utils.format_error_message(
             traceback.format_exc())
         message = (
@@ -390,9 +403,11 @@ if __name__ == "__main__":
             "\n 3. runtime_env APIs won't work."
             "\nCheck out the `dashboard_agent.log` to see the "
             "detailed failure messages.")
-        ray._private.utils.push_error_to_driver_through_redis(
-            redis_client, ray_constants.DASHBOARD_AGENT_DIED_ERROR,
-            message)
+        ray._private.utils.publish_error_to_driver(
+            ray_constants.DASHBOARD_AGENT_DIED_ERROR,
+            message,
+            redis_client=redis_client,
+            gcs_publisher=gcs_publisher)
         logger.error(message)
         logger.exception(e)
         exit(1)
@@ -13,8 +13,10 @@ import ray.dashboard.consts as dashboard_consts
 import ray.dashboard.head as dashboard_head
 import ray.dashboard.utils as dashboard_utils
 import ray.ray_constants as ray_constants
+import ray._private.gcs_utils as gcs_utils
 import ray._private.services
 import ray._private.utils
+from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsPublisher
 from ray._private.ray_logging import setup_component_logger
 from ray._private.metrics_agent import PrometheusServiceDiscoveryWriter
@@ -134,6 +136,11 @@ if __name__ == "__main__":
         type=int,
         default=0,
         help="The retry times to select a valid port.")
+    parser.add_argument(
+        "--gcs-address",
+        required=False,
+        type=str,
+        help="The address (ip:port) of GCS.")
     parser.add_argument(
         "--redis-address",
         required=True,
@@ -223,18 +230,29 @@ if __name__ == "__main__":
         loop = asyncio.get_event_loop()
         loop.run_until_complete(dashboard.run())
     except Exception as e:
-        # Something went wrong, so push an error to all drivers.
-        redis_client = ray._private.services.create_redis_client(
-            args.redis_address, password=args.redis_password)
         traceback_str = ray._private.utils.format_error_message(
             traceback.format_exc())
         message = f"The dashboard on node {platform.uname()[1]} " \
                   f"failed with the following " \
                   f"error:\n{traceback_str}"
-        ray._private.utils.push_error_to_driver_through_redis(
-            redis_client, ray_constants.DASHBOARD_DIED_ERROR, message)
         if isinstance(e, FrontendNotFoundError):
             logger.warning(message)
         else:
             logger.error(message)
             raise e
+
+        # Something went wrong, so push an error to all drivers.
+        redis_client = ray._private.services.create_redis_client(
+            args.redis_address, password=args.redis_password)
+        gcs_publisher = None
+        if args.gcs_address:
+            gcs_publisher = GcsPublisher(address=args.gcs_address)
+        elif gcs_pubsub_enabled():
+            gcs_publisher = GcsPublisher(
+                address=gcs_utils.get_gcs_address_from_redis(redis_client))
+        ray._private.utils.publish_error_to_driver(
+            redis_client,
+            ray_constants.DASHBOARD_DIED_ERROR,
+            message,
+            redis_client=redis_client,
+            gcs_publisher=gcs_publisher)
@@ -19,6 +19,7 @@ import ray._private.services
 import ray.dashboard.consts as dashboard_consts
 import ray.dashboard.utils as dashboard_utils
 from ray import ray_constants
+from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsAioSubscriber
 from ray.core.generated import gcs_service_pb2
 from ray.core.generated import gcs_service_pb2_grpc
 from ray.dashboard.datacenter import DataOrganizer
@@ -44,8 +45,8 @@ GRPC_CHANNEL_OPTIONS = (
 async def get_gcs_address_with_retry(redis_client) -> str:
     while True:
         try:
-            gcs_address = await redis_client.get(
-                dashboard_consts.REDIS_KEY_GCS_SERVER_ADDRESS)
+            gcs_address = (await redis_client.get(
+                dashboard_consts.REDIS_KEY_GCS_SERVER_ADDRESS)).decode()
             if not gcs_address:
                 raise Exception("GCS address not found.")
             logger.info("Connect to GCS at %s", gcs_address)
@@ -113,6 +114,7 @@ class DashboardHead:
         self.log_dir = log_dir
         self.aioredis_client = None
         self.aiogrpc_gcs_channel = None
+        self.gcs_subscriber = None
         self.http_session = None
         self.ip = ray.util.get_node_ip_address()
         ip, port = redis_address.split(":")
@@ -192,11 +194,15 @@ class DashboardHead:
         # Waiting for GCS is ready.
         # TODO: redis-removal bootstrap
         gcs_address = await get_gcs_address_with_retry(self.aioredis_client)
-        self.gcs_client = GcsClient(gcs_address)
+        self.gcs_client = GcsClient(address=gcs_address)
+        internal_kv._initialize_internal_kv(self.gcs_client)
         self.aiogrpc_gcs_channel = ray._private.utils.init_grpc_channel(
             gcs_address, GRPC_CHANNEL_OPTIONS, asynchronous=True)
-        gcs_client = GcsClient(gcs_address)
-        internal_kv._initialize_internal_kv(gcs_client)
+        self.gcs_subscriber = None
+        if gcs_pubsub_enabled():
+            self.gcs_subscriber = GcsAioSubscriber(
+                channel=self.aiogrpc_gcs_channel)
+            await self.gcs_subscriber.subscribe_error()

         self.health_check_thread = GCSHealthCheckThread(gcs_address)
         self.health_check_thread.start()
@@ -1,5 +1,4 @@
 NODE_STATS_UPDATE_INTERVAL_SECONDS = 1
-ERROR_INFO_UPDATE_INTERVAL_SECONDS = 5
 LOG_INFO_UPDATE_INTERVAL_SECONDS = 5
 UPDATE_NODES_INTERVAL_SECONDS = 5
 MAX_COUNT_OF_GCS_RPC_ERROR = 10
@@ -272,39 +272,51 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
             logger.exception("Error receiving log info.")

     async def _update_error_info(self):
-        aioredis_client = self._dashboard_head.aioredis_client
-        receiver = Receiver()
+        def process_error(error_data):
+            error_data = gcs_utils.ErrorTableData.FromString(pubsub_msg.data)
+            message = error_data.error_message
+            message = re.sub(r"\x1b\[\d+m", "", message)
+            match = re.search(r"\(pid=(\d+), ip=(.*?)\)", message)
+            if match:
+                pid = match.group(1)
+                ip = match.group(2)
+                errs_for_ip = dict(DataSource.ip_and_pid_to_errors.get(ip, {}))
+                pid_errors = list(errs_for_ip.get(pid, []))
+                pid_errors.append({
+                    "message": message,
+                    "timestamp": error_data.timestamp,
+                    "type": error_data.type
+                })
+                errs_for_ip[pid] = pid_errors
+                DataSource.ip_and_pid_to_errors[ip] = errs_for_ip
+                logger.info(f"Received error entry for {ip} {pid}")

-        key = gcs_utils.RAY_ERROR_PUBSUB_PATTERN
-        pattern = receiver.pattern(key)
-        await aioredis_client.psubscribe(pattern)
-        logger.info("Subscribed to %s", key)
+        if self._dashboard_head.gcs_subscriber:
+            while True:
+                _, error_data = await \
+                    self._dashboard_head.gcs_subscriber.poll_error()
+                try:
+                    process_error(error_data)
+                except Exception:
+                    logger.exception("Error receiving error info from GCS.")
+        else:
+            aioredis_client = self._dashboard_head.aioredis_client
+            receiver = Receiver()

-        async for sender, msg in receiver.iter():
-            try:
-                _, data = msg
-                pubsub_msg = gcs_utils.PubSubMessage.FromString(data)
-                error_data = gcs_utils.ErrorTableData.FromString(
-                    pubsub_msg.data)
-                message = error_data.error_message
-                message = re.sub(r"\x1b\[\d+m", "", message)
-                match = re.search(r"\(pid=(\d+), ip=(.*?)\)", message)
-                if match:
-                    pid = match.group(1)
-                    ip = match.group(2)
-                    errs_for_ip = dict(
-                        DataSource.ip_and_pid_to_errors.get(ip, {}))
-                    pid_errors = list(errs_for_ip.get(pid, []))
-                    pid_errors.append({
-                        "message": message,
-                        "timestamp": error_data.timestamp,
-                        "type": error_data.type
-                    })
-                    errs_for_ip[pid] = pid_errors
-                    DataSource.ip_and_pid_to_errors[ip] = errs_for_ip
-                    logger.info(f"Received error entry for {ip} {pid}")
-            except Exception:
-                logger.exception("Error receiving error info.")
+            key = gcs_utils.RAY_ERROR_PUBSUB_PATTERN
+            pattern = receiver.pattern(key)
+            await aioredis_client.psubscribe(pattern)
+            logger.info("Subscribed to %s", key)
+
+            async for _, msg in receiver.iter():
+                try:
+                    _, data = msg
+                    pubsub_msg = gcs_utils.PubSubMessage.FromString(data)
+                    error_data = gcs_utils.ErrorTableData.FromString(
+                        pubsub_msg.data)
+                    process_error(error_data)
+                except Exception:
+                    logger.exception("Error receiving error info from Redis.")

     async def run(self, server):
         gcs_channel = self._dashboard_head.aiogrpc_gcs_channel
python/ray/_private/gcs_pubsub.py (Normal file, 276 lines)
@@ -0,0 +1,276 @@
import os
from collections import deque
import logging
import random
import threading
from typing import Tuple

import grpc

import ray._private.gcs_utils as gcs_utils
from ray.core.generated import gcs_service_pb2_grpc
from ray.core.generated import gcs_service_pb2
from ray.core.generated.gcs_pb2 import (
    ErrorTableData, )
from ray.core.generated import pubsub_pb2

logger = logging.getLogger(__name__)


def gcs_pubsub_enabled():
    """Checks whether GCS pubsub feature flag is enabled."""
    return os.environ.get("RAY_gcs_grpc_based_pubsub") not in\
        [None, "0", "false"]


def construct_error_message(job_id, error_type, message, timestamp):
    """Construct an ErrorTableData object.

    Args:
        job_id: The ID of the job that the error should go to. If this is
            nil, then the error will go to all drivers.
        error_type: The type of the error.
        message: The error message.
        timestamp: The time of the error.

    Returns:
        The ErrorTableData object.
    """
    data = ErrorTableData()
    data.job_id = job_id.binary()
    data.type = error_type
    data.error_message = message
    data.timestamp = timestamp
    return data


class _SubscriberBase:
    def _subscribe_error_request(self):
        cmd = pubsub_pb2.Command(
            channel_type=pubsub_pb2.RAY_ERROR_INFO_CHANNEL,
            subscribe_message={})
        req = gcs_service_pb2.GcsSubscriberCommandBatchRequest(
            subscriber_id=self._subscriber_id, commands=[cmd])
        return req

    def _poll_request(self):
        return gcs_service_pb2.GcsSubscriberPollRequest(
            subscriber_id=self._subscriber_id)

    def _unsubscribe_request(self):
        req = gcs_service_pb2.GcsSubscriberCommandBatchRequest(
            subscriber_id=self._subscriber_id, commands=[])
        if self._subscribed_error:
            cmd = pubsub_pb2.Command(
                channel_type=pubsub_pb2.RAY_ERROR_INFO_CHANNEL,
                unsubscribe_message={})
            req.commands.append(cmd)
        return req


class GcsPublisher:
    """Publisher to GCS."""

    def __init__(self, address: str = None, channel: grpc.Channel = None):
        if address:
            assert channel is None, \
                "address and channel cannot both be specified"
            channel = gcs_utils.create_gcs_channel(address)
        else:
            assert channel is not None, \
                "One of address and channel must be specified"
        self._stub = gcs_service_pb2_grpc.InternalPubSubGcsServiceStub(channel)

    def publish_error(self, key_id: bytes, error_info: ErrorTableData) -> None:
        """Publishes error info to GCS."""
        msg = pubsub_pb2.PubMessage(
            channel_type=pubsub_pb2.RAY_ERROR_INFO_CHANNEL,
            key_id=key_id,
            error_info_message=error_info)
        req = gcs_service_pb2.GcsPublishRequest(pub_messages=[msg])
        self._stub.GcsPublish(req)


class GcsSubscriber(_SubscriberBase):
    """Subscriber to GCS. Thread safe.

    Usage example:
        subscriber = GcsSubscriber()
        subscriber.subscribe_error()
        while running:
            error_id, error_data = subscriber.poll_error()
            ......
        subscriber.close()
    """

    def __init__(
            self,
            address: str = None,
            channel: grpc.Channel = None,
    ):
        if address:
            assert channel is None, \
                "address and channel cannot both be specified"
            channel = gcs_utils.create_gcs_channel(address)
        else:
            assert channel is not None, \
                "One of address and channel must be specified"
        self._lock = threading.RLock()
        self._stub = gcs_service_pb2_grpc.InternalPubSubGcsServiceStub(channel)
        self._subscriber_id = bytes(
            bytearray(random.getrandbits(8) for _ in range(28)))
        # Whether error info has been subscribed.
        self._subscribed_error = False
        # Buffer for holding error info.
        self._errors = deque()
        # Future for indicating whether the subscriber has closed.
        self._close = threading.Event()

    def subscribe_error(self) -> None:
        """Registers a subscription for error info.

        Before the registration, published errors will not be saved for the
        subscriber.
        """
        with self._lock:
            if self._close.is_set():
                return
            if not self._subscribed_error:
                req = self._subscribe_error_request()
                self._stub.GcsSubscriberCommandBatch(req, timeout=30)
                self._subscribed_error = True

    def poll_error(self, timeout=None) -> Tuple[bytes, ErrorTableData]:
        """Polls for new error messages."""
        with self._lock:
            if self._close.is_set():
                return

            if len(self._errors) == 0:
                req = self._poll_request()
                fut = self._stub.GcsSubscriberPoll.future(req, timeout=timeout)
                # Wait for result to become available, or cancel if the
                # subscriber has closed.
                while True:
                    try:
                        fut.result(timeout=1)
                        break
                    except grpc.FutureTimeoutError:
                        # Subscriber has closed. Cancel inflight the request
                        # and return from polling.
                        if self._close.is_set():
                            fut.cancel()
                            return None, None
                        # GRPC has not replied, continue waiting.
                        continue
                    except Exception:
                        # GRPC error, including deadline exceeded.
                        raise
                if fut.done():
                    for msg in fut.result().pub_messages:
                        self._errors.append((msg.key_id,
                                             msg.error_info_message))

            if len(self._errors) == 0:
                return None, None
            return self._errors.popleft()

    def close(self) -> None:
        """Closes the subscriber and its active subscriptions."""
        # Mark close to terminate inflight polling and prevent future requests.
        self._close.set()
        with self._lock:
            if not self._stub:
                # Subscriber already closed.
                return
            req = self._unsubscribe_request()
            try:
                self._stub.GcsSubscriberCommandBatch(req, timeout=30)
            except Exception:
                pass
            self._stub = None


class GcsAioPublisher:
    """Publisher to GCS. Uses async io."""

    def __init__(self, address: str = None, channel: grpc.aio.Channel = None):
        if address:
            assert channel is None, \
                "address and channel cannot both be specified"
            channel = gcs_utils.create_gcs_channel(address, aio=True)
        else:
            assert channel is not None, \
                "One of address and channel must be specified"
        self._stub = gcs_service_pb2_grpc.InternalPubSubGcsServiceStub(channel)

    async def publish_error(self, key_id: bytes,
                            error_info: ErrorTableData) -> None:
        """Publishes error info to GCS."""
        msg = pubsub_pb2.PubMessage(
            channel_type=pubsub_pb2.RAY_ERROR_INFO_CHANNEL,
            key_id=key_id,
            error_info_message=error_info)
        req = gcs_service_pb2.GcsPublishRequest(pub_messages=[msg])
        await self._stub.GcsPublish(req)


class GcsAioSubscriber(_SubscriberBase):
    """Async io subscriber to GCS.

    Usage example:
        subscriber = GcsAioSubscriber()
        await subscriber.subscribe_error()
        while running:
            error_id, error_data = await subscriber.poll_error()
            ......
        await subscriber.close()
    """

    def __init__(self, address: str = None, channel: grpc.aio.Channel = None):
        if address:
            assert channel is None, \
                "address and channel cannot both be specified"
            channel = gcs_utils.create_gcs_channel(address, aio=True)
        else:
            assert channel is not None, \
                "One of address and channel must be specified"
        self._stub = gcs_service_pb2_grpc.InternalPubSubGcsServiceStub(channel)
        self._subscriber_id = bytes(
            bytearray(random.getrandbits(8) for _ in range(28)))
        # Whether error info has been subscribed.
        self._subscribed_error = False
        # Buffer for holding error info.
        self._errors = deque()

    async def subscribe_error(self) -> None:
        """Registers a subscription for error info.

        Before the registration, published errors will not be saved for the
        subscriber.
        """
        if not self._subscribed_error:
            req = self._subscribe_error_request()
            await self._stub.GcsSubscriberCommandBatch(req, timeout=30)
            self._subscribed_error = True

    async def poll_error(self, timeout=None) -> Tuple[bytes, ErrorTableData]:
        """Polls for new error messages."""
        if len(self._errors) == 0:
            req = self._poll_request()
            reply = await self._stub.GcsSubscriberPoll(req, timeout=timeout)
            for msg in reply.pub_messages:
                self._errors.append((msg.key_id, msg.error_info_message))

        if len(self._errors) == 0:
            return None, None
        return self._errors.popleft()

    async def close(self) -> None:
        """Closes the subscriber and its active subscriptions."""
        req = self._unsubscribe_request()
        try:
            await self._stub.GcsSubscriberCommandBatch(req, timeout=30)
        except Exception:
            pass
        self._subscribed_error = False
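For completeness, a minimal sketch of the asyncio variant, mirroring the usage examples in the docstrings of the new module above. The GCS address is a placeholder, not a value from the commit:

```python
# Sketch only: assumes a running GCS server at the placeholder address and
# that gcs_pubsub.py is importable as added by this commit.
import asyncio

from ray._private.gcs_pubsub import GcsAioPublisher, GcsAioSubscriber
from ray.core.generated.gcs_pb2 import ErrorTableData


async def main(gcs_address: str = "127.0.0.1:6379"):
    subscriber = GcsAioSubscriber(address=gcs_address)
    await subscriber.subscribe_error()

    publisher = GcsAioPublisher(address=gcs_address)
    await publisher.publish_error(b"key", ErrorTableData(error_message="boom"))

    key_id, error_data = await subscriber.poll_error()
    print(key_id, error_data.error_message)
    await subscriber.close()


if __name__ == "__main__":
    asyncio.run(main())
```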
@@ -1,7 +1,10 @@
-from ray.core.generated.common_pb2 import ErrorType
 import enum
 import logging
 from typing import List
+
+import grpc
+
+from ray.core.generated.common_pb2 import ErrorType
 from ray.core.generated import gcs_service_pb2_grpc
 from ray.core.generated import gcs_service_pb2
 from ray.core.generated.gcs_pb2 import (
@@ -51,64 +54,69 @@ __all__ = [
     "ResourceLoad",
     "ResourceMap",
     "ResourceTableData",
-    "construct_error_message",
     "ObjectLocationInfo",
     "PubSubMessage",
     "WorkerTableData",
     "PlacementGroupTableData",
 ]

-FUNCTION_PREFIX = "RemoteFunction:"
 LOG_FILE_CHANNEL = "RAY_LOG_CHANNEL"
-REPORTER_CHANNEL = "RAY_REPORTER"
-
-# xray resource usages
-XRAY_RESOURCES_BATCH_PATTERN = "RESOURCES_BATCH:".encode("ascii")
-
-# xray job updates
-XRAY_JOB_PATTERN = "JOB:*".encode("ascii")

 # Actor pub/sub updates
 RAY_ACTOR_PUBSUB_PATTERN = "ACTOR:*".encode("ascii")

-# Reporter pub/sub updates
-RAY_REPORTER_PUBSUB_PATTERN = "RAY_REPORTER.*".encode("ascii")
-
 RAY_ERROR_PUBSUB_PATTERN = "ERROR_INFO:*".encode("ascii")

 # These prefixes must be kept up-to-date with the TablePrefix enum in
 # gcs.proto.
-# TODO(rkn): We should use scoped enums, in which case we should be able to
-# just access the flatbuffer generated values.
-TablePrefix_RAYLET_TASK_string = "RAYLET_TASK"
-TablePrefix_OBJECT_string = "OBJECT"
-TablePrefix_PROFILE_string = "PROFILE"
-TablePrefix_JOB_string = "JOB"
 TablePrefix_ACTOR_string = "ACTOR"

 WORKER = 0
 DRIVER = 1

-
-def construct_error_message(job_id, error_type, message, timestamp):
-    """Construct a serialized ErrorTableData object.
-
-    Args:
-        job_id: The ID of the job that the error should go to. If this is
-            nil, then the error will go to all drivers.
-        error_type: The type of the error.
-        message: The error message.
-        timestamp: The time of the error.
-
-    Returns:
-        The serialized object.
-    """
-    data = ErrorTableData()
-    data.job_id = job_id.binary()
-    data.type = error_type
-    data.error_message = message
-    data.timestamp = timestamp
-    return data.SerializeToString()
+# Cap messages at 512MB
+_MAX_MESSAGE_LENGTH = 512 * 1024 * 1024
+# Send keepalive every 60s
+_GRPC_KEEPALIVE_TIME_MS = 60 * 1000
+# Keepalive should be replied < 60s
+_GRPC_KEEPALIVE_TIMEOUT_MS = 60 * 1000
+# Also relying on these defaults:
+# grpc.keepalive_permit_without_calls=0: No keepalive without inflight calls.
+# grpc.use_local_subchannel_pool=0: Subchannels are shared.
+_GRPC_OPTIONS = [("grpc.enable_http_proxy",
+                  0), ("grpc.max_send_message_length", _MAX_MESSAGE_LENGTH),
+                 ("grpc.max_receive_message_length", _MAX_MESSAGE_LENGTH),
+                 ("grpc.keepalive_time_ms",
+                  _GRPC_KEEPALIVE_TIME_MS), ("grpc.keepalive_timeout_ms",
+                                             _GRPC_KEEPALIVE_TIMEOUT_MS)]
+
+
+def get_gcs_address_from_redis(redis) -> str:
+    """Reads GCS address from redis.
+
+    Args:
+        redis: Redis client to fetch GCS address.
+
+    Returns:
+        GCS address string.
+    """
+    gcs_address = redis.get("GcsServerAddress")
+    if gcs_address is None:
+        raise RuntimeError("Failed to look up gcs address through redis")
+    return gcs_address.decode()
+
+
+def create_gcs_channel(address: str, aio=False):
+    """Returns a GRPC channel to GCS.
+
+    Args:
+        address: GCS address string, e.g. ip:port
+        aio: Whether using grpc.aio
+
+    Returns:
+        grpc.Channel or grpc.aio.Channel to GCS
+    """
+    from ray._private.utils import init_grpc_channel
+    return init_grpc_channel(address, options=_GRPC_OPTIONS, asynchronous=aio)


 class GcsCode(enum.IntEnum):
@@ -118,17 +126,13 @@ class GcsCode(enum.IntEnum):

 class GcsClient:
-    MAX_MESSAGE_LENGTH = 512 * 1024 * 1024  # 512MB
+    """Client to GCS using GRPC"""

-    def __init__(self, address):
-        from ray._private.utils import init_grpc_channel
-        logger.debug(f"Connecting to gcs address: {address}")
-        options = [("grpc.enable_http_proxy",
-                    0), ("grpc.max_send_message_length",
-                         GcsClient.MAX_MESSAGE_LENGTH),
-                   ("grpc.max_receive_message_length",
-                    GcsClient.MAX_MESSAGE_LENGTH)]
-        channel = init_grpc_channel(address, options=options)
+    def __init__(self, address: str = None, channel: grpc.Channel = None):
+        if address:
+            assert channel is None, \
+                "Only one of address and channel can be specified"
+            channel = create_gcs_channel(address)
         self._kv_stub = gcs_service_pb2_grpc.InternalKVGcsServiceStub(channel)

     def internal_kv_get(self, key: bytes) -> bytes:
@@ -187,10 +191,7 @@ class GcsClient:

     @staticmethod
     def create_from_redis(redis_cli):
-        gcs_address = redis_cli.get("GcsServerAddress")
-        if gcs_address is None:
-            raise RuntimeError("Failed to look up gcs address through redis")
-        return GcsClient(gcs_address.decode())
+        return GcsClient(get_gcs_address_from_redis(redis_cli))

     @staticmethod
     def connect_to_gcs_by_redis_address(redis_address, redis_password):
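A minimal sketch of how the new helpers in this file fit together, assuming a reachable Redis instance for a running cluster; the host, port, and password below are placeholders:

```python
# Sketch only: every connection parameter here is an assumption, not taken
# from the commit.
import redis

import ray._private.gcs_utils as gcs_utils

redis_client = redis.StrictRedis(
    host="127.0.0.1", port=6379, password="5241590000000000")
# Redis stores the GCS address under the "GcsServerAddress" key.
gcs_address = gcs_utils.get_gcs_address_from_redis(redis_client)

# Either hand the address to GcsClient directly...
gcs_client = gcs_utils.GcsClient(address=gcs_address)
# ...or create one channel and share it between the KV client and pubsub stubs.
channel = gcs_utils.create_gcs_channel(gcs_address)
gcs_client_shared = gcs_utils.GcsClient(channel=channel)
```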
@@ -15,6 +15,7 @@ import ray.ray_constants as ray_constants
 import ray._private.gcs_utils as gcs_utils
 import ray._private.services as services
 import ray._private.utils
+from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsPublisher
 from ray._private.ray_logging import setup_component_logger

 # Logger for this module. It should be configured at the entry point
@@ -365,6 +366,11 @@ if __name__ == "__main__":
         description=("Parse Redis server for the "
                      "log monitor to connect "
                      "to."))
+    parser.add_argument(
+        "--gcs-address",
+        required=False,
+        type=str,
+        help="The address (ip:port) of GCS.")
     parser.add_argument(
         "--redis-address",
         required=True,
@@ -436,11 +442,20 @@ if __name__ == "__main__":
         # Something went wrong, so push an error to all drivers.
         redis_client = ray._private.services.create_redis_client(
             args.redis_address, password=args.redis_password)
+        gcs_publisher = None
+        if args.gcs_address:
+            gcs_publisher = GcsPublisher(address=args.gcs_address)
+        elif gcs_pubsub_enabled():
+            gcs_publisher = GcsPublisher(
+                address=gcs_utils.get_gcs_address_from_redis(redis_client))
         traceback_str = ray._private.utils.format_error_message(
             traceback.format_exc())
         message = (f"The log monitor on node {platform.node()} "
                    f"failed with the following error:\n{traceback_str}")
-        ray._private.utils.push_error_to_driver_through_redis(
-            redis_client, ray_constants.LOG_MONITOR_DIED_ERROR, message)
+        ray._private.utils.publish_error_to_driver(
+            ray_constants.LOG_MONITOR_DIED_ERROR,
+            message,
+            redis_client=redis_client,
+            gcs_publisher=gcs_publisher)
         logger.error(message)
         raise e
@@ -1926,6 +1926,7 @@ def start_monitor(redis_address,

     Args:
         redis_address (str): The address that the Redis server is listening on.
+        gcs_address (str): The address of GCS server.
         logs_dir(str): The path to the log directory.
         stdout_file: A file handle opened for writing to redirect stdout to. If
             no redirection should happen, then this should be None.
@@ -26,6 +26,7 @@ import ray._private.gcs_utils as gcs_utils
 import ray._private.memory_monitor as memory_monitor
 from ray.core.generated import node_manager_pb2
 from ray.core.generated import node_manager_pb2_grpc
+from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsSubscriber
 from ray._private.tls_utils import generate_self_signed_tls_certs
 from ray.util.queue import Queue, _QueueActor, Empty
 from ray.scripts.scripts import main as ray_main
@@ -502,24 +503,39 @@ def get_non_head_nodes(cluster):

 def init_error_pubsub():
     """Initialize redis error info pub/sub"""
-    p = ray.worker.global_worker.redis_client.pubsub(
-        ignore_subscribe_messages=True)
-    error_pubsub_channel = gcs_utils.RAY_ERROR_PUBSUB_PATTERN
-    p.psubscribe(error_pubsub_channel)
-    return p
+    if gcs_pubsub_enabled():
+        s = GcsSubscriber(channel=ray.worker.global_worker.gcs_channel)
+        s.subscribe_error()
+    else:
+        s = ray.worker.global_worker.redis_client.pubsub(
+            ignore_subscribe_messages=True)
+        s.psubscribe(gcs_utils.RAY_ERROR_PUBSUB_PATTERN)
+    return s


-def get_error_message(pub_sub, num, error_type=None, timeout=20):
-    """Get errors through pub/sub."""
-    start_time = time.time()
+def get_error_message(subscriber, num, error_type=None, timeout=20):
+    """Get errors through subscriber."""
+    deadline = time.time() + timeout
     msgs = []
-    while time.time() - start_time < timeout and len(msgs) < num:
-        msg = pub_sub.get_message()
-        if msg is None:
-            time.sleep(0.01)
-            continue
-        pubsub_msg = gcs_utils.PubSubMessage.FromString(msg["data"])
-        error_data = gcs_utils.ErrorTableData.FromString(pubsub_msg.data)
+    while time.time() < deadline and len(msgs) < num:
+        if isinstance(subscriber, GcsSubscriber):
+            try:
+                _, error_data = subscriber.poll_error(timeout=deadline -
+                                                      time.time())
+            except grpc.RpcError as e:
+                # Failed to match error message before timeout.
+                if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
+                    logging.warning("get_error_message() timed out")
+                    return []
+                # Otherwise, the error is unexpected.
+                raise
+        else:
+            msg = subscriber.get_message()
+            if msg is None:
+                time.sleep(0.01)
+                continue
+            pubsub_msg = gcs_utils.PubSubMessage.FromString(msg["data"])
+            error_data = gcs_utils.ErrorTableData.FromString(pubsub_msg.data)
         if error_type is None or error_type == error_data.type:
             msgs.append(error_data)
         else:
@@ -27,6 +27,7 @@ import numpy as np
 import ray
 import ray._private.gcs_utils as gcs_utils
 import ray.ray_constants as ray_constants
+from ray._private.gcs_pubsub import construct_error_message
 from ray._private.tls_utils import load_certs_from_env

 # Import psutil after ray so the packaged version is used.
@@ -113,10 +114,11 @@ def push_error_to_driver(worker, error_type, message, job_id=None):
     worker.core_worker.push_error(job_id, error_type, message, time.time())


-def push_error_to_driver_through_redis(redis_client,
-                                       error_type,
-                                       message,
-                                       job_id=None):
+def publish_error_to_driver(error_type,
+                            message,
+                            job_id=None,
+                            redis_client=None,
+                            gcs_publisher=None):
     """Push an error message to the driver to be printed in the background.

     Normally the push_error_to_driver function should be used. However, in some
@@ -125,25 +127,31 @@ def push_error_to_driver_through_redis(redis_client,
     backend processes.

     Args:
-        redis_client: The redis client to use.
         error_type (str): The type of the error.
         message (str): The message that will be printed in the background
             on the driver.
         job_id: The ID of the driver to push the error message to. If this
             is None, then the message will be pushed to all drivers.
+        redis_client: The redis client to use.
+        gcs_publisher: The GCS publisher to use. If specified, ignores
+            redis_client.
     """
     if job_id is None:
         job_id = ray.JobID.nil()
     assert isinstance(job_id, ray.JobID)
-    # Do everything in Python and through the Python Redis client instead
-    # of through the raylet.
-    error_data = gcs_utils.construct_error_message(job_id, error_type, message,
-                                                   time.time())
-    pubsub_msg = gcs_utils.PubSubMessage()
-    pubsub_msg.id = job_id.binary()
-    pubsub_msg.data = error_data
-    redis_client.publish("ERROR_INFO:" + job_id.hex(),
-                         pubsub_msg.SerializeToString())
+    error_data = construct_error_message(job_id, error_type, message,
+                                         time.time())
+    if gcs_publisher:
+        gcs_publisher.publish_error(job_id.hex().encode(), error_data)
+    elif redis_client:
+        pubsub_msg = gcs_utils.PubSubMessage()
+        pubsub_msg.id = job_id.binary()
+        pubsub_msg.data = error_data.SerializeToString()
+        redis_client.publish("ERROR_INFO:" + job_id.hex(),
+                             pubsub_msg.SerializeToString())
+    else:
+        raise ValueError(
+            "One of redis_client and gcs_publisher needs to be specified!")


 def random_string():
@@ -1,7 +1,6 @@
 """Autoscaler monitoring loop daemon."""

 import argparse
-import logging
 import logging.handlers
 import os
 import sys
@@ -35,7 +34,8 @@ from ray.autoscaler._private.util import DEBUG_AUTOSCALING_STATUS, \
 from ray.core.generated import gcs_service_pb2, gcs_service_pb2_grpc
 import ray.ray_constants as ray_constants
 from ray._private.ray_logging import setup_component_logger
-from ray._private.gcs_utils import GcsClient
+from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsPublisher
+from ray._private.gcs_utils import GcsClient, get_gcs_address_from_redis
 from ray.experimental.internal_kv import _initialize_internal_kv, \
     _internal_kv_put, _internal_kv_initialized, _internal_kv_get, \
     _internal_kv_del
@@ -409,9 +409,18 @@ class Monitor:
         _internal_kv_put(DEBUG_AUTOSCALING_ERROR, message, overwrite=True)
         redis_client = ray._private.services.create_redis_client(
             self.redis_address, password=self.redis_password)
-        from ray._private.utils import push_error_to_driver_through_redis
-        push_error_to_driver_through_redis(
-            redis_client, ray_constants.MONITOR_DIED_ERROR, message)
+        gcs_publisher = None
+        if args.gcs_address:
+            gcs_publisher = GcsPublisher(address=args.gcs_address)
+        elif gcs_pubsub_enabled():
+            gcs_publisher = GcsPublisher(
+                address=get_gcs_address_from_redis(redis_client))
+        from ray._private.utils import publish_error_to_driver
+        publish_error_to_driver(
+            ray_constants.MONITOR_DIED_ERROR,
+            message,
+            redis_client=redis_client,
+            gcs_publisher=gcs_publisher)

     def _signal_handler(self, sig, frame):
         self._handle_failure(f"Terminated with signal {sig}\n" +
@@ -437,6 +446,11 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description=("Parse Redis server for the "
                      "monitor to connect to."))
+    parser.add_argument(
+        "--gcs-address",
+        required=False,
+        type=str,
+        help="The address (ip:port) of GCS.")
     parser.add_argument(
         "--redis-address",
         required=True,
@@ -28,6 +28,7 @@ py_test_module_list(
         "test_component_failures_2.py",
         "test_component_failures_3.py",
         "test_error_ray_not_initialized.py",
+        "test_gcs_pubsub.py",
         "test_global_gc.py",
         "test_grpc_client_credentials.py",
         "test_iter.py",
@@ -11,6 +11,7 @@ import ray._private.utils
 import ray._private.gcs_utils as gcs_utils
 import ray.ray_constants as ray_constants
 from ray.exceptions import RayTaskError, RayActorError, GetTimeoutError
+from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsPublisher
 from ray._private.test_utils import (wait_for_condition, SignalActor,
                                      init_error_pubsub, get_error_message)

@@ -61,14 +62,21 @@ def test_unhandled_errors(ray_start_regular):
     del os.environ["RAY_IGNORE_UNHANDLED_ERRORS"]


-def test_push_error_to_driver_through_redis(ray_start_regular, error_pubsub):
+def test_publish_error_to_driver(ray_start_regular, error_pubsub):
     address_info = ray_start_regular
     address = address_info["redis_address"]
     redis_client = ray._private.services.create_redis_client(
         address, password=ray.ray_constants.REDIS_DEFAULT_PASSWORD)
+    gcs_publisher = None
+    if gcs_pubsub_enabled():
+        gcs_publisher = GcsPublisher(
+            address=gcs_utils.get_gcs_address_from_redis(redis_client))
     error_message = "Test error message"
-    ray._private.utils.push_error_to_driver_through_redis(
-        redis_client, ray_constants.DASHBOARD_AGENT_DIED_ERROR, error_message)
+    ray._private.utils.publish_error_to_driver(
+        ray_constants.DASHBOARD_AGENT_DIED_ERROR,
+        error_message,
+        redis_client=redis_client,
+        gcs_publisher=gcs_publisher)
     errors = get_error_message(error_pubsub, 1,
                                ray_constants.DASHBOARD_AGENT_DIED_ERROR)
     assert errors[0].type == ray_constants.DASHBOARD_AGENT_DIED_ERROR
@@ -104,9 +104,8 @@ def test_connect_with_disconnected_node(shutdown_only):
     ray.init(address=cluster.address)
     p = init_error_pubsub()
     errors = get_error_message(p, 1, timeout=5)
-    print(errors)
     assert len(errors) == 0
-    # This node is killed by SIGKILL, ray_monitor will mark it to dead.
+    # This node will be killed by SIGKILL, ray_monitor will mark it to dead.
     dead_node = cluster.add_node(num_cpus=0)
     cluster.remove_node(dead_node, allow_graceful=False)
     errors = get_error_message(p, 1, ray_constants.REMOVED_NODE_ERROR)
python/ray/tests/test_gcs_pubsub.py (Normal file, 73 lines)
@@ -0,0 +1,73 @@
import sys

import ray
import ray._private.gcs_utils as gcs_utils
from ray._private.gcs_pubsub import GcsPublisher, GcsSubscriber, \
    GcsAioPublisher, GcsAioSubscriber
from ray.core.generated.gcs_pb2 import ErrorTableData
import pytest


@pytest.mark.parametrize(
    "ray_start_regular", [{
        "_system_config": {
            "gcs_grpc_based_pubsub": True
        }
    }],
    indirect=True)
def test_publish_and_subscribe_error_info(ray_start_regular):
    address_info = ray_start_regular
    redis = ray._private.services.create_redis_client(
        address_info["redis_address"],
        password=ray.ray_constants.REDIS_DEFAULT_PASSWORD)

    gcs_server_addr = gcs_utils.get_gcs_address_from_redis(redis)

    subscriber = GcsSubscriber(address=gcs_server_addr)
    subscriber.subscribe_error()

    publisher = GcsPublisher(address=gcs_server_addr)
    err1 = ErrorTableData(error_message="test error message 1")
    err2 = ErrorTableData(error_message="test error message 2")
    publisher.publish_error(b"aaa_id", err1)
    publisher.publish_error(b"bbb_id", err2)

    assert subscriber.poll_error() == (b"aaa_id", err1)
    assert subscriber.poll_error() == (b"bbb_id", err2)

    subscriber.close()


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "ray_start_regular", [{
        "_system_config": {
            "gcs_grpc_based_pubsub": True
        }
    }],
    indirect=True)
async def test_aio_publish_and_subscribe_error_info(ray_start_regular):
    address_info = ray_start_regular
    redis = ray._private.services.create_redis_client(
        address_info["redis_address"],
        password=ray.ray_constants.REDIS_DEFAULT_PASSWORD)

    gcs_server_addr = gcs_utils.get_gcs_address_from_redis(redis)

    subscriber = GcsAioSubscriber(address=gcs_server_addr)
    await subscriber.subscribe_error()

    publisher = GcsAioPublisher(address=gcs_server_addr)
    err1 = ErrorTableData(error_message="test error message 1")
    err2 = ErrorTableData(error_message="test error message 2")
    await publisher.publish_error(b"aaa_id", err1)
    await publisher.publish_error(b"bbb_id", err2)

    assert await subscriber.poll_error() == (b"aaa_id", err1)
    assert await subscriber.poll_error() == (b"bbb_id", err2)

    await subscriber.close()


if __name__ == "__main__":
    sys.exit(pytest.main(["-v", __file__]))
@ -27,6 +27,8 @@ import ray.remote_function
|
||||||
import ray.serialization as serialization
|
import ray.serialization as serialization
|
||||||
import ray._private.gcs_utils as gcs_utils
|
import ray._private.gcs_utils as gcs_utils
|
||||||
import ray._private.services as services
|
import ray._private.services as services
|
||||||
|
from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsPublisher, \
|
||||||
|
GcsSubscriber
|
||||||
from ray._private.runtime_env.py_modules import upload_py_modules_if_needed
|
from ray._private.runtime_env.py_modules import upload_py_modules_if_needed
|
||||||
from ray._private.runtime_env.working_dir import upload_working_dir_if_needed
|
from ray._private.runtime_env.working_dir import upload_working_dir_if_needed
|
||||||
from ray._private.runtime_env.constants import RAY_JOB_CONFIG_JSON_ENV_VAR
|
from ray._private.runtime_env.constants import RAY_JOB_CONFIG_JSON_ENV_VAR
|
||||||
|
@@ -1247,6 +1249,58 @@ def listen_error_messages_raylet(worker, threads_stopped):
         worker.error_message_pubsub_client.close()


+def listen_error_messages_from_gcs(worker, threads_stopped):
+    """Listen to error messages in the background on the driver.
+
+    This runs in a separate thread on the driver and pushes (error, time)
+    tuples to be published.
+
+    Args:
+        worker: The worker class that this thread belongs to.
+        threads_stopped (threading.Event): A threading event used to signal to
+            the thread that it should exit.
+    """
+    worker.gcs_subscriber = GcsSubscriber(channel=worker.gcs_channel)
+    # Exports that are published after the call to
+    # gcs_subscriber.subscribe_error() and before the call to
+    # gcs_subscriber.poll_error() will still be processed in the loop.
+
+    # TODO: we should just subscribe to the errors for this specific job.
+    worker.gcs_subscriber.subscribe_error()
+
+    try:
+        if _internal_kv_initialized():
+            # Get any autoscaler errors that occurred before the call to
+            # subscribe.
+            error_message = _internal_kv_get(DEBUG_AUTOSCALING_ERROR)
+            if error_message is not None:
+                logger.warning(error_message.decode())
+
+        while True:
+            # Exit if received a signal that the thread should stop.
+            if threads_stopped.is_set():
+                return
+
+            _, error_data = worker.gcs_subscriber.poll_error()
+            if error_data is None:
+                continue
+            if error_data.job_id not in [
+                    worker.current_job_id.binary(),
+                    JobID.nil().binary(),
+            ]:
+                continue
+
+            error_message = error_data.error_message
+            if error_data.type == ray_constants.TASK_PUSH_ERROR:
+                # TODO(ekl) remove task push errors entirely now that we have
+                # the separate unhandled exception handler.
+                pass
+            else:
+                logger.warning(error_message)
+    except (OSError, ConnectionError) as e:
+        logger.error(f"listen_error_messages_from_gcs: {e}")
+
+
 @PublicAPI
 @client_mode_hook(auto_init=False)
 def is_initialized() -> bool:

@@ -1307,11 +1361,16 @@ def connect(node,
     # that is not true of Redis pubsub clients. See the documentation at
     # https://github.com/andymccurdy/redis-py#thread-safety.
     worker.redis_client = node.create_redis_client()
-    worker.gcs_client = gcs_utils.GcsClient.create_from_redis(
-        worker.redis_client)
+    worker.gcs_channel = gcs_utils.create_gcs_channel(
+        gcs_utils.get_gcs_address_from_redis(worker.redis_client))
+    worker.gcs_client = gcs_utils.GcsClient(channel=worker.gcs_channel)
     _initialize_internal_kv(worker.gcs_client)
     ray.state.state._initialize_global_state(
         node.redis_address, redis_password=node.redis_password)
+    worker.gcs_pubsub_enabled = gcs_pubsub_enabled()
+    worker.gcs_publisher = None
+    if worker.gcs_pubsub_enabled:
+        worker.gcs_publisher = GcsPublisher(channel=worker.gcs_channel)

     # Initialize some fields.
     if mode in (WORKER_MODE, RESTORE_WORKER_MODE, SPILL_WORKER_MODE):

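Outside of `connect()`, the same wiring can be reproduced standalone. This is only a sketch of the plumbing introduced above, assuming a reachable Redis instance; the address and password below are placeholders.

```python
import ray._private.services as services
import ray._private.gcs_utils as gcs_utils
from ray._private.gcs_pubsub import GcsPublisher

# Placeholder Redis address/password; in a real cluster these come from the node.
redis_client = services.create_redis_client("127.0.0.1:6379", password=None)

# Resolve the GCS address from Redis, open a shared GRPC channel, and build
# both the KV client and the GRPC-based publisher on top of it.
gcs_address = gcs_utils.get_gcs_address_from_redis(redis_client)
gcs_channel = gcs_utils.create_gcs_channel(gcs_address)
gcs_client = gcs_utils.GcsClient(channel=gcs_channel)
gcs_publisher = GcsPublisher(channel=gcs_channel)
```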
@@ -1350,11 +1409,12 @@ def connect(node,
             raise e
     elif mode == WORKER_MODE:
         traceback_str = traceback.format_exc()
-        ray._private.utils.push_error_to_driver_through_redis(
-            worker.redis_client,
+        ray._private.utils.publish_error_to_driver(
             ray_constants.VERSION_MISMATCH_PUSH_ERROR,
             traceback_str,
-            job_id=None)
+            job_id=None,
+            redis_client=worker.redis_client,
+            gcs_publisher=worker.gcs_publisher)

     worker.lock = threading.RLock()

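The body of `publish_error_to_driver` is not part of this hunk. Purely as a hypothetical sketch of the dispatch it needs to perform (the two keyword arguments are real, everything inside the function is illustrative):

```python
from ray.core.generated.gcs_pb2 import ErrorTableData


def publish_error_to_driver_sketch(error_type, message, job_id=None,
                                   redis_client=None, gcs_publisher=None):
    """Hypothetical: prefer the GRPC-based GCS publisher when one was created
    under RAY_gcs_grpc_based_pubsub, otherwise fall back to Redis pubsub."""
    error_data = ErrorTableData(type=error_type, error_message=message)
    if gcs_publisher is not None:
        # New path: publish via GCS; the key id scopes the error to a job.
        gcs_publisher.publish_error(job_id or b"", error_data)
    else:
        # Legacy Redis pubsub path (details omitted from this sketch).
        ...
```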
@@ -1445,7 +1505,8 @@ def connect(node,
     # scheduler for new error messages.
     if mode == SCRIPT_MODE:
         worker.listener_thread = threading.Thread(
-            target=listen_error_messages_raylet,
+            target=listen_error_messages_from_gcs
+            if worker.gcs_pubsub_enabled else listen_error_messages_raylet,
             name="ray_listen_error_messages",
             args=(worker, worker.threads_stopped))
         worker.listener_thread.daemon = True

@@ -1519,6 +1580,8 @@ def disconnect(exiting_interpreter=False):
     if hasattr(worker, "import_thread"):
         worker.import_thread.join_import_thread()
     if hasattr(worker, "listener_thread"):
+        if hasattr(worker, "gcs_subscriber"):
+            worker.gcs_subscriber.close()
         worker.listener_thread.join()
     if hasattr(worker, "logger_thread"):
         worker.logger_thread.join()

@@ -116,17 +116,6 @@ class MockGcsTaskReconstructionTable : public GcsTaskReconstructionTable {
 namespace ray {
 namespace gcs {

-class MockGcsObjectTable : public GcsObjectTable {
- public:
-  MOCK_METHOD(JobID, GetJobIdFromKey, (const ObjectID &key), (override));
-};
-
-}  // namespace gcs
-}  // namespace ray
-
-namespace ray {
-namespace gcs {
-
 class MockGcsNodeTable : public GcsNodeTable {
  public:
 };

@@ -72,7 +72,8 @@ void GcsServer::Start() {
             rpc::ChannelType::GCS_JOB_CHANNEL,
             rpc::ChannelType::GCS_NODE_INFO_CHANNEL,
             rpc::ChannelType::GCS_NODE_RESOURCE_CHANNEL,
-            rpc::ChannelType::GCS_WORKER_DELTA_CHANNEL},
+            rpc::ChannelType::GCS_WORKER_DELTA_CHANNEL,
+            rpc::ChannelType::RAY_ERROR_INFO_CHANNEL},
         /*periodical_runner=*/&pubsub_periodical_runner_,
         /*get_time_ms=*/[]() { return absl::GetCurrentTimeNanos() / 1e6; },
         /*subscriber_timeout_ms=*/RayConfig::instance().subscriber_timeout_ms(),

@@ -213,17 +213,6 @@ class GcsTaskReconstructionTable
   JobID GetJobIdFromKey(const TaskID &key) override { return key.ActorId().JobId(); }
 };

-class GcsObjectTable : public GcsTableWithJobId<ObjectID, ObjectLocationInfo> {
- public:
-  explicit GcsObjectTable(std::shared_ptr<StoreClient> store_client)
-      : GcsTableWithJobId(std::move(store_client)) {
-    table_name_ = TablePrefix_Name(TablePrefix::OBJECT);
-  }
-
- private:
-  JobID GetJobIdFromKey(const ObjectID &key) override { return key.TaskId().JobId(); }
-};
-
 class GcsNodeTable : public GcsTable<NodeID, GcsNodeInfo> {
  public:
   explicit GcsNodeTable(std::shared_ptr<StoreClient> store_client)

@@ -17,12 +17,35 @@
 namespace ray {
 namespace gcs {

+void InternalPubSubHandler::HandleGcsPublish(const rpc::GcsPublishRequest &request,
+                                             rpc::GcsPublishReply *reply,
+                                             rpc::SendReplyCallback send_reply_callback) {
+  if (gcs_publisher_ == nullptr) {
+    send_reply_callback(
+        Status::NotImplemented("GCS pubsub is not yet enabled. Please enable it with "
+                               "system config `gcs_grpc_based_pubsub=True`"),
+        nullptr, nullptr);
+    return;
+  }
+  for (const auto &msg : request.pub_messages()) {
+    gcs_publisher_->GetPublisher()->Publish(msg);
+  }
+  send_reply_callback(Status::OK(), nullptr, nullptr);
+}
+
 // Needs to use rpc::GcsSubscriberPollRequest and rpc::GcsSubscriberPollReply here,
 // and convert the reply to rpc::PubsubLongPollingReply because GCS RPC services are
 // required to have the `status` field in replies.
 void InternalPubSubHandler::HandleGcsSubscriberPoll(
     const rpc::GcsSubscriberPollRequest &request, rpc::GcsSubscriberPollReply *reply,
     rpc::SendReplyCallback send_reply_callback) {
+  if (gcs_publisher_ == nullptr) {
+    send_reply_callback(
+        Status::NotImplemented("GCS pubsub is not yet enabled. Please enable it with "
+                               "system config `gcs_grpc_based_pubsub=True`"),
+        nullptr, nullptr);
+    return;
+  }
   const auto subscriber_id = UniqueID::FromBinary(request.subscriber_id());
   auto pubsub_reply = std::make_shared<rpc::PubsubLongPollingReply>();
   auto pubsub_reply_ptr = pubsub_reply.get();

@@ -44,6 +67,13 @@ void InternalPubSubHandler::HandleGcsSubscriberCommandBatch(
     const rpc::GcsSubscriberCommandBatchRequest &request,
     rpc::GcsSubscriberCommandBatchReply *reply,
     rpc::SendReplyCallback send_reply_callback) {
+  if (gcs_publisher_ == nullptr) {
+    send_reply_callback(
+        Status::NotImplemented("GCS pubsub is not yet enabled. Please enable it with "
+                               "system config `gcs_grpc_based_pubsub=True`"),
+        nullptr, nullptr);
+    return;
+  }
   const auto subscriber_id = UniqueID::FromBinary(request.subscriber_id());
   for (const auto &command : request.commands()) {
     if (command.has_unsubscribe_message()) {

@@ -16,6 +16,7 @@

 #include "ray/gcs/pubsub/gcs_pub_sub.h"
 #include "ray/rpc/gcs_server/gcs_rpc_server.h"
+#include "src/ray/protobuf/gcs_service.grpc.pb.h"

 namespace ray {
 namespace gcs {

@@ -28,6 +29,10 @@ class InternalPubSubHandler : public rpc::InternalPubSubHandler {
   explicit InternalPubSubHandler(const std::shared_ptr<gcs::GcsPublisher> &gcs_publisher)
       : gcs_publisher_(gcs_publisher) {}

+  void HandleGcsPublish(const rpc::GcsPublishRequest &request,
+                        rpc::GcsPublishReply *reply,
+                        rpc::SendReplyCallback send_reply_callback) final;
+
   void HandleGcsSubscriberPoll(const rpc::GcsSubscriberPollRequest &request,
                                rpc::GcsSubscriberPollReply *reply,
                                rpc::SendReplyCallback send_reply_callback) final;

@@ -302,6 +302,17 @@ Status GcsPublisher::PublishTaskLease(const TaskID &id, const rpc::TaskLeaseData
 Status GcsPublisher::PublishError(const std::string &id,
                                   const rpc::ErrorTableData &message,
                                   const StatusCallback &done) {
+  if (publisher_ != nullptr) {
+    rpc::PubMessage msg;
+    msg.set_channel_type(rpc::ChannelType::RAY_ERROR_INFO_CHANNEL);
+    msg.set_key_id(id);
+    *msg.mutable_error_info_message() = message;
+    publisher_->Publish(msg);
+    if (done != nullptr) {
+      done(Status::OK());
+    }
+    return Status::OK();
+  }
   return pubsub_->Publish(ERROR_INFO_CHANNEL, id, message.SerializeAsString(), done);
 }

@@ -563,6 +563,16 @@ service InternalKVGcsService {
   rpc InternalKVKeys(InternalKVKeysRequest) returns (InternalKVKeysReply);
 }

+message GcsPublishRequest {
+  /// The messages that are published.
+  repeated PubMessage pub_messages = 1;
+}
+
+message GcsPublishReply {
+  // Not populated.
+  GcsStatus status = 100;
+}
+
 message GcsSubscriberPollRequest {
   /// The id of the subscriber.
   bytes subscriber_id = 1;

@@ -590,6 +600,9 @@ message GcsSubscriberCommandBatchReply {
 /// This supports subscribing updates from GCS with long poll, and registering /
 /// de-registering subscribers.
 service InternalPubSubGcsService {
+  /// The request sent to GCS to publish messages.
+  /// Currently only supporting error info, logs and Python function messages.
+  rpc GcsPublish(GcsPublishRequest) returns (GcsPublishReply);
   /// The long polling request sent to GCS for pubsub operations.
   /// The long poll request will be replied once there are a batch of messages that
   /// need to be published to the caller (subscriber).

@@ -40,6 +40,8 @@ enum ChannelType {
   GCS_NODE_RESOURCE_CHANNEL = 6;
   /// A channel for worker changes, currently only for worker failures.
   GCS_WORKER_DELTA_CHANNEL = 7;
+  /// A channel for errors from various Ray components.
+  RAY_ERROR_INFO_CHANNEL = 8;
 }

 ///

@@ -61,6 +63,7 @@ message PubMessage {
     GcsNodeInfo node_info_message = 9;
     NodeResourceChange node_resource_message = 10;
     WorkerDeltaData worker_delta_message = 11;
+    ErrorTableData error_info_message = 12;

     // The message that indicates the given key id is not available anymore.
     FailureMessage failure_message = 6;

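Putting the proto changes together: publishing an error through the new `GcsPublish` RPC amounts to wrapping an `ErrorTableData` in a `PubMessage` on `RAY_ERROR_INFO_CHANNEL`. The sketch below is illustrative only; the generated-module paths under `ray.core.generated` are assumptions, and the Python `GcsPublisher` added in this change does this for you.

```python
import grpc

# Assumed locations of the generated protobuf/grpc bindings.
from ray.core.generated import gcs_pb2, pubsub_pb2, gcs_service_pb2
from ray.core.generated import gcs_service_pb2_grpc


def publish_error_raw(gcs_address: str, key_id: bytes, message: str) -> None:
    # Build the PubMessage the same way the error publishing path does.
    msg = pubsub_pb2.PubMessage(
        channel_type=pubsub_pb2.RAY_ERROR_INFO_CHANNEL,
        key_id=key_id,
        error_info_message=gcs_pb2.ErrorTableData(error_message=message))
    request = gcs_service_pb2.GcsPublishRequest(pub_messages=[msg])

    # Call the GcsPublish RPC on InternalPubSubGcsService directly.
    channel = grpc.insecure_channel(gcs_address)
    stub = gcs_service_pb2_grpc.InternalPubSubGcsServiceStub(channel)
    stub.GcsPublish(request)
```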
@@ -145,6 +145,10 @@ bool SubscriptionIndex::CheckNoLeaks() const {

 bool Subscriber::ConnectToSubscriber(rpc::PubsubLongPollingReply *reply,
                                      rpc::SendReplyCallback send_reply_callback) {
+  if (long_polling_connection_) {
+    // Flush the current subscriber poll with an empty reply.
+    PublishIfPossible(/*force_noop=*/true);
+  }
   if (!long_polling_connection_) {
     RAY_CHECK(reply != nullptr);
     RAY_CHECK(send_reply_callback != nullptr);

@@ -171,30 +175,25 @@ void Subscriber::QueueMessage(const rpc::PubMessage &pub_message, bool try_publi
   }
 }

-bool Subscriber::PublishIfPossible(bool force) {
+bool Subscriber::PublishIfPossible(bool force_noop) {
   if (!long_polling_connection_) {
     return false;
   }
+  if (!force_noop && mailbox_.empty()) {
+    return false;
+  }

-  if (force || mailbox_.size() > 0) {
-    // If force publish is invoked, mailbox could be empty. We should always add a reply
-    // here because otherwise, there could be memory leak due to our grpc layer
-    // implementation.
-    if (mailbox_.empty()) {
-      mailbox_.push(absl::make_unique<rpc::PubsubLongPollingReply>());
-    }
-
+  if (!force_noop) {
     // Reply to the long polling subscriber. Swap the reply here to avoid extra copy.
     long_polling_connection_->reply->Swap(mailbox_.front().get());
-    long_polling_connection_->send_reply_callback(Status::OK(), nullptr, nullptr);
-
-    // Clean up & update metadata.
-    long_polling_connection_.reset();
     mailbox_.pop();
-    last_connection_update_time_ms_ = get_time_ms_();
-    return true;
   }
-  return false;
+  long_polling_connection_->send_reply_callback(Status::OK(), nullptr, nullptr);
+
+  // Clean up & update metadata.
+  long_polling_connection_.reset();
+  last_connection_update_time_ms_ = get_time_ms_();
+  return true;
 }

 bool Subscriber::CheckNoLeaks() const {

@@ -311,7 +310,7 @@ int Publisher::UnregisterSubscriberInternal(const SubscriberID &subscriber_id) {
   }
   auto &subscriber = it->second;
   // Remove the long polling connection because otherwise, there's memory leak.
-  subscriber->PublishIfPossible(/*force=*/true);
+  subscriber->PublishIfPossible(/*force_noop=*/true);
   subscribers_.erase(it);
   return erased;
 }

@@ -330,8 +329,8 @@ void Publisher::CheckDeadSubscribers() {
     if (disconnected) {
       dead_subscribers.push_back(it.first);
     } else if (active_connection_timed_out) {
-      // Refresh the long polling connection. The subscriber will send it again.
-      subscriber->PublishIfPossible(/*force*/ true);
+      // Refresh the long polling connection. The subscriber will poll again.
+      subscriber->PublishIfPossible(/*force_noop*/ true);
     }
   }

@@ -123,10 +123,11 @@ class Subscriber {

   /// Publish all queued messages if possible.
   ///
-  /// \param force If true, we publish to the subscriber although there's no queued
-  /// message.
+  /// \param force_noop If true, reply to the subscriber with an empty message, regardless
+  /// of whether there is any queued message. This is for cases where the current poll
+  /// might have been cancelled, or the subscriber might be dead.
   /// \return True if it publishes. False otherwise.
-  bool PublishIfPossible(bool force = false);
+  bool PublishIfPossible(bool force_noop = false);

   /// Testing only. Return true if there's no metadata remained in the private attribute.
   bool CheckNoLeaks() const;

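To make the `force_noop` semantics concrete, here is a small Python model of the mailbox and long-poll reply path. It mirrors the behavior documented above but is a toy model, not Ray code.

```python
from collections import deque


class LongPollSubscriberModel:
    """Toy model of Subscriber::PublishIfPossible (illustrative, not Ray code)."""

    def __init__(self):
        self.mailbox = deque()     # queued replies waiting for a long poll
        self.inflight_poll = None  # reply callback of the pending long poll, if any

    def connect(self, callback):
        if self.inflight_poll is not None:
            # A new poll supersedes the old one: flush it with an empty reply.
            self.publish_if_possible(force_noop=True)
        self.inflight_poll = callback

    def publish_if_possible(self, force_noop=False):
        if self.inflight_poll is None:
            return False           # nobody is polling, nothing to reply to
        if not force_noop and not self.mailbox:
            return False           # keep the poll open until a message arrives
        reply = self.mailbox.popleft() if not force_noop else None
        callback, self.inflight_poll = self.inflight_poll, None
        callback(reply)            # None models the empty "noop" reply
        return True
```

A timed-out or superseded poll maps to `publish_if_possible(force_noop=True)`, which is how the new `ConnectToSubscriber` and `CheckDeadSubscribers` call sites flush an open connection.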
@@ -285,8 +285,8 @@ TEST_F(PublisherTest, TestSubscriber) {
   ASSERT_FALSE(subscriber->PublishIfPossible());
   // Try connecting it. Should return true.
   ASSERT_TRUE(subscriber->ConnectToSubscriber(&reply, send_reply_callback));
-  // If connecting it again, it should fail the request.
-  ASSERT_FALSE(subscriber->ConnectToSubscriber(&reply, send_reply_callback));
+  // Polling when there is already an inflight polling request should still work.
+  ASSERT_TRUE(subscriber->ConnectToSubscriber(&reply, send_reply_callback));
   // Since there's no published objects, it should return false.
   ASSERT_FALSE(subscriber->PublishIfPossible());

@@ -277,6 +277,8 @@ class GcsRpcClient {
                              internal_kv_grpc_client_, )

   /// Operations for pubsub
+  VOID_GCS_RPC_CLIENT_METHOD(InternalPubSubGcsService, GcsPublish,
+                             internal_pubsub_grpc_client_, )
   VOID_GCS_RPC_CLIENT_METHOD(InternalPubSubGcsService, GcsSubscriberPoll,
                              internal_pubsub_grpc_client_, )
   VOID_GCS_RPC_CLIENT_METHOD(InternalPubSubGcsService, GcsSubscriberCommandBatch,

@@ -606,6 +606,9 @@ class InternalPubSubGcsServiceHandler {
  public:
   virtual ~InternalPubSubGcsServiceHandler() = default;

+  virtual void HandleGcsPublish(const GcsPublishRequest &request, GcsPublishReply *reply,
+                                SendReplyCallback send_reply_callback) = 0;
+
   virtual void HandleGcsSubscriberPoll(const GcsSubscriberPollRequest &request,
                                        GcsSubscriberPollReply *reply,
                                        SendReplyCallback send_reply_callback) = 0;

@@ -626,6 +629,7 @@ class InternalPubSubGrpcService : public GrpcService {
   void InitServerCallFactories(
       const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
       std::vector<std::unique_ptr<ServerCallFactory>> *server_call_factories) override {
+    INTERNAL_PUBSUB_SERVICE_RPC_HANDLER(GcsPublish);
     INTERNAL_PUBSUB_SERVICE_RPC_HANDLER(GcsSubscriberPoll);
     INTERNAL_PUBSUB_SERVICE_RPC_HANDLER(GcsSubscriberCommandBatch);
   }