ray/dashboard/head.py
Jiajun Yao 2b37c32d43
Auto reconnect for gcs aio client (#26673)
#20299 adds auto reconnect for sync gcs client and this PR does the same thing for async gcs client.
2022-07-19 13:11:09 -07:00

227 lines
8.2 KiB
Python

import asyncio
import logging
import os
import threading
from concurrent.futures import Future
from queue import Queue
import ray._private.services
import ray._private.utils
import ray.dashboard.consts as dashboard_consts
import ray.dashboard.utils as dashboard_utils
import ray.experimental.internal_kv as internal_kv
from ray._private import ray_constants
from ray._private.gcs_pubsub import GcsAioErrorSubscriber, GcsAioLogSubscriber
from ray._private.gcs_utils import GcsClient, GcsAioClient, check_health
from ray.dashboard.datacenter import DataOrganizer
from ray.dashboard.utils import async_loop_forever
try:
from grpc import aio as aiogrpc
except ImportError:
from grpc.experimental import aio as aiogrpc
logger = logging.getLogger(__name__)
aiogrpc.init_grpc_aio()
GRPC_CHANNEL_OPTIONS = (
*ray_constants.GLOBAL_GRPC_OPTIONS,
("grpc.max_send_message_length", ray_constants.GRPC_CPP_MAX_MESSAGE_SIZE),
("grpc.max_receive_message_length", ray_constants.GRPC_CPP_MAX_MESSAGE_SIZE),
)
class GCSHealthCheckThread(threading.Thread):
def __init__(self, gcs_address: str):
self.gcs_address = gcs_address
self.work_queue = Queue()
super().__init__(daemon=True)
def run(self) -> None:
while True:
future = self.work_queue.get()
check_result = check_health(self.gcs_address)
future.set_result(check_result)
async def check_once(self) -> bool:
"""Ask the thread to perform a health check."""
assert (
threading.current_thread != self
), "caller shouldn't be from the same thread as GCSHealthCheckThread."
future = Future()
self.work_queue.put(future)
return await asyncio.wrap_future(future)
class DashboardHead:
def __init__(
self,
http_host,
http_port,
http_port_retries,
gcs_address,
log_dir,
temp_dir,
session_dir,
minimal,
):
self.minimal = minimal
self.health_check_thread: GCSHealthCheckThread = None
self._gcs_rpc_error_counter = 0
# Public attributes are accessible for all head modules.
# Walkaround for issue: https://github.com/ray-project/ray/issues/7084
self.http_host = "127.0.0.1" if http_host == "localhost" else http_host
self.http_port = http_port
self.http_port_retries = http_port_retries
self.gcs_address = None
assert gcs_address is not None
self.gcs_address = gcs_address
self.log_dir = log_dir
self.temp_dir = temp_dir
self.session_dir = session_dir
self.aiogrpc_gcs_channel = None
self.gcs_error_subscriber = None
self.gcs_log_subscriber = None
self.ip = ray.util.get_node_ip_address()
ip, port = gcs_address.split(":")
self.server = aiogrpc.server(options=(("grpc.so_reuseport", 0),))
grpc_ip = "127.0.0.1" if self.ip == "127.0.0.1" else "0.0.0.0"
self.grpc_port = ray._private.tls_utils.add_port_to_grpc_server(
self.server, f"{grpc_ip}:0"
)
logger.info("Dashboard head grpc address: %s:%s", grpc_ip, self.grpc_port)
# If the dashboard is started as non-minimal version, http server should
# be configured to expose APIs.
self.http_server = None
async def _configure_http_server(self, modules):
from ray.dashboard.http_server_head import HttpServerDashboardHead
http_server = HttpServerDashboardHead(
self.ip, self.http_host, self.http_port, self.http_port_retries
)
await http_server.run(modules)
return http_server
@property
def http_session(self):
assert self.http_server, "Accessing unsupported API in a minimal ray."
return self.http_server.http_session
@async_loop_forever(dashboard_consts.GCS_CHECK_ALIVE_INTERVAL_SECONDS)
async def _gcs_check_alive(self):
check_future = self.health_check_thread.check_once()
# NOTE(simon): making sure the check procedure doesn't timeout itself.
# Otherwise, the dashboard will always think that gcs is alive.
try:
is_alive = await asyncio.wait_for(
check_future, dashboard_consts.GCS_CHECK_ALIVE_RPC_TIMEOUT + 1
)
except asyncio.TimeoutError:
logger.error("Failed to check gcs health, client timed out.")
is_alive = False
if is_alive:
self._gcs_rpc_error_counter = 0
else:
self._gcs_rpc_error_counter += 1
if (
self._gcs_rpc_error_counter
> dashboard_consts.GCS_CHECK_ALIVE_MAX_COUNT_OF_RPC_ERROR
):
logger.error(
"Dashboard exiting because it received too many GCS RPC "
"errors count: %s, threshold is %s.",
self._gcs_rpc_error_counter,
dashboard_consts.GCS_CHECK_ALIVE_MAX_COUNT_OF_RPC_ERROR,
)
# TODO(fyrestone): Do not use ray.state in
# PrometheusServiceDiscoveryWriter.
# Currently, we use os._exit() here to avoid hanging at the ray
# shutdown(). Please refer to:
# https://github.com/ray-project/ray/issues/16328
os._exit(-1)
def _load_modules(self):
"""Load dashboard head modules."""
modules = []
head_cls_list = dashboard_utils.get_all_modules(
dashboard_utils.DashboardHeadModule
)
for cls in head_cls_list:
logger.info(
"Loading %s: %s", dashboard_utils.DashboardHeadModule.__name__, cls
)
c = cls(self)
modules.append(c)
logger.info("Loaded %d modules.", len(modules))
return modules
async def run(self):
gcs_address = self.gcs_address
# Dashboard will handle connection failure automatically
self.gcs_client = GcsClient(address=gcs_address, nums_reconnect_retry=0)
internal_kv._initialize_internal_kv(self.gcs_client)
self.gcs_aio_client = GcsAioClient(address=gcs_address, nums_reconnect_retry=0)
self.aiogrpc_gcs_channel = self.gcs_aio_client.channel.channel()
self.gcs_error_subscriber = GcsAioErrorSubscriber(address=gcs_address)
self.gcs_log_subscriber = GcsAioLogSubscriber(address=gcs_address)
await self.gcs_error_subscriber.subscribe()
await self.gcs_log_subscriber.subscribe()
self.health_check_thread = GCSHealthCheckThread(gcs_address)
self.health_check_thread.start()
# Start a grpc asyncio server.
await self.server.start()
async def _async_notify():
"""Notify signals from queue."""
while True:
co = await dashboard_utils.NotifyQueue.get()
try:
await co
except Exception:
logger.exception(f"Error notifying coroutine {co}")
modules = self._load_modules()
http_host, http_port = self.http_host, self.http_port
if not self.minimal:
self.http_server = await self._configure_http_server(modules)
http_host, http_port = self.http_server.get_address()
internal_kv._internal_kv_put(
ray_constants.DASHBOARD_ADDRESS,
f"{http_host}:{http_port}",
namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
)
# TODO: Use async version if performance is an issue
# Write the dashboard head port to gcs kv.
internal_kv._internal_kv_put(
dashboard_consts.DASHBOARD_RPC_ADDRESS,
f"{self.ip}:{self.grpc_port}",
namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
)
# Freeze signal after all modules loaded.
dashboard_utils.SignalManager.freeze()
concurrent_tasks = [
self._gcs_check_alive(),
_async_notify(),
DataOrganizer.purge(),
DataOrganizer.organize(),
]
await asyncio.gather(*concurrent_tasks, *(m.run(self.server) for m in modules))
await self.server.wait_for_termination()
if self.http_server:
await self.http_server.cleanup()