mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
170 lines
6.6 KiB
Python
import sys
|
|
import asyncio
|
|
import logging
|
|
|
|
import aiohttp
|
|
import aioredis
|
|
from grpc.experimental import aio as aiogrpc
|
|
|
|
import ray.services
|
|
import ray.new_dashboard.consts as dashboard_consts
|
|
import ray.new_dashboard.utils as dashboard_utils
|
|
from ray.core.generated import gcs_service_pb2
|
|
from ray.core.generated import gcs_service_pb2_grpc
|
|
from ray.new_dashboard.datacenter import DataSource, DataOrganizer
|
|
|
|
# Module-level logger for the dashboard head process.
logger = logging.getLogger(__name__)

# Route table that head modules use to register their HTTP endpoints.
routes = dashboard_utils.ClassMethodRouteTable

# Initialize gRPC's asyncio support once at import time; required before
# any aio channels/stubs are created.
aiogrpc.init_grpc_aio()
def gcs_node_info_to_dict(message):
    """Convert a GcsNodeInfo protobuf message into a plain dict.

    The ``nodeId`` field name is handed to ``message_to_dict`` as a special
    case (presumably so the binary id is decoded to a readable string —
    confirm against ``dashboard_utils.message_to_dict``), and fields still
    at their default values are kept in the output.
    """
    decode_fields = {"nodeId"}
    return dashboard_utils.message_to_dict(
        message,
        decode_fields,
        including_default_value_fields=True,
    )
class DashboardHead:
    """State and main loop of the dashboard head process.

    Owns the shared clients (aioredis client, GCS gRPC channel, aiohttp
    session) that dynamically loaded head modules use, keeps the node and
    agent tables in ``DataSource`` up to date, and drives the module
    coroutines from :meth:`run`.
    """

    def __init__(self, redis_address, redis_password):
        """Set up state; no connections are made until :meth:`run`.

        Args:
            redis_address (str): Redis server address as ``"ip:port"``.
            redis_password: Password for the redis server (may be None).
        """
        # Scan and import head modules for collecting http routes.
        self._head_cls_list = dashboard_utils.get_all_modules(
            dashboard_utils.DashboardHeadModule)
        ip, port = redis_address.split(":")
        # NodeInfoGcsService stub; created in run() once the GCS channel
        # is connected.
        self._gcs_node_info_stub = None
        # Count of consecutive GCS RPC failures; the process exits once it
        # exceeds MAX_COUNT_OF_GCS_RPC_ERROR (see _update_nodes()).
        self._gcs_rpc_error_counter = 0
        # Public attributes are accessible for all head modules.
        self.redis_address = (ip, int(port))
        self.redis_password = redis_password
        self.aioredis_client = None
        self.aiogrpc_gcs_channel = None
        self.http_session = aiohttp.ClientSession(
            loop=asyncio.get_event_loop())
        self.ip = ray.services.get_node_ip_address()

    async def _get_nodes(self):
        """Read the client table from GCS.

        Returns:
            A list of node info dicts, de-duplicated by node id (first
            occurrence wins).

        Raises:
            Exception: If the GCS reply status is not OK. (Previously this
                method implicitly returned None on failure, which made the
                caller fail with an unrelated TypeError; raising keeps the
                real error message in the logs and lets the caller's retry
                loop handle it.)
        """
        request = gcs_service_pb2.GetAllNodeInfoRequest()
        reply = await self._gcs_node_info_stub.GetAllNodeInfo(
            request, timeout=2)
        if reply.status.code != 0:
            raise Exception(
                "Failed to GetAllNodeInfo: {}".format(reply.status.message))
        results = []
        node_id_set = set()
        for node_info in reply.node_info_list:
            # Skip duplicate entries for the same node id.
            if node_info.node_id in node_id_set:
                continue
            node_id_set.add(node_info.node_id)
            node_info_dict = gcs_node_info_to_dict(node_info)
            results.append(node_info_dict)
        return results

    async def _update_nodes(self):
        """Loop forever, refreshing node/agent tables in ``DataSource``.

        GCS RPC errors are counted; after MAX_COUNT_OF_GCS_RPC_ERROR
        consecutive failures the process exits so a supervisor can restart
        it. Any other exception is logged and the loop keeps going.
        """
        while True:
            try:
                nodes = await self._get_nodes()
                # A successful fetch resets the consecutive-error counter.
                self._gcs_rpc_error_counter = 0
                node_ips = [node["nodeManagerAddress"] for node in nodes]
                node_hostnames = [
                    node["nodeManagerHostname"] for node in nodes
                ]

                agents = dict(DataSource.agents)
                for node in nodes:
                    node_ip = node["nodeManagerAddress"]
                    if node_ip not in agents:
                        # Each node's agent registers its port in redis;
                        # look it up lazily for nodes not seen before.
                        key = "{}{}".format(
                            dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX,
                            node_ip)
                        agent_port = await self.aioredis_client.get(key)
                        if agent_port:
                            agents[node_ip] = agent_port
                # Drop agents whose node is no longer in the cluster.
                for ip in agents.keys() - set(node_ips):
                    agents.pop(ip, None)

                DataSource.agents.reset(agents)
                DataSource.nodes.reset(dict(zip(node_ips, nodes)))
                DataSource.hostname_to_ip.reset(
                    dict(zip(node_hostnames, node_ips)))
                DataSource.ip_to_hostname.reset(
                    dict(zip(node_ips, node_hostnames)))
            except aiogrpc.AioRpcError as ex:
                logger.exception(ex)
                self._gcs_rpc_error_counter += 1
                if self._gcs_rpc_error_counter > \
                        dashboard_consts.MAX_COUNT_OF_GCS_RPC_ERROR:
                    logger.error(
                        "Dashboard suicide, the GCS RPC error count %s > %s",
                        self._gcs_rpc_error_counter,
                        dashboard_consts.MAX_COUNT_OF_GCS_RPC_ERROR)
                    sys.exit(-1)
            except Exception as ex:
                logger.exception(ex)
            finally:
                await asyncio.sleep(
                    dashboard_consts.UPDATE_NODES_INTERVAL_SECONDS)

    def _load_modules(self):
        """Instantiate every discovered head module and bind its routes.

        Returns:
            The list of instantiated DashboardHeadModule objects.
        """
        modules = []
        for cls in self._head_cls_list:
            logger.info("Load %s: %s",
                        dashboard_utils.DashboardHeadModule.__name__, cls)
            # Each module receives this head instance to access the shared
            # redis/grpc/http clients.
            c = cls(self)
            dashboard_utils.ClassMethodRouteTable.bind(c)
            modules.append(c)
        return modules

    async def run(self):
        """Connect to redis and GCS, then run all head coroutines forever."""
        # Create an aioredis client for all modules.
        self.aioredis_client = await aioredis.create_redis_pool(
            address=self.redis_address, password=self.redis_password)
        # Waiting for GCS is ready: poll redis for the GCS address and
        # retry until a channel can be created.
        while True:
            try:
                gcs_address = await self.aioredis_client.get(
                    dashboard_consts.REDIS_KEY_GCS_SERVER_ADDRESS)
                if not gcs_address:
                    raise Exception("GCS address not found.")
                logger.info("Connect to GCS at %s", gcs_address)
                channel = aiogrpc.insecure_channel(gcs_address)
            except Exception as ex:
                logger.error("Connect to GCS failed: %s, retry...", ex)
                await asyncio.sleep(
                    dashboard_consts.CONNECT_GCS_INTERVAL_SECONDS)
            else:
                self.aiogrpc_gcs_channel = channel
                break
        # Create a NodeInfoGcsServiceStub.
        self._gcs_node_info_stub = gcs_service_pb2_grpc.NodeInfoGcsServiceStub(
            self.aiogrpc_gcs_channel)

        async def _async_notify():
            """Notify signals from queue."""
            while True:
                co = await dashboard_utils.NotifyQueue.get()
                try:
                    await co
                except Exception as e:
                    logger.exception(e)

        async def _purge_data():
            """Purge data in datacenter."""
            while True:
                await asyncio.sleep(
                    dashboard_consts.PURGE_DATA_INTERVAL_SECONDS)
                try:
                    await DataOrganizer.purge()
                except Exception as e:
                    logger.exception(e)

        modules = self._load_modules()
        # Freeze signal after all modules loaded.
        dashboard_utils.SignalManager.freeze()
        await asyncio.gather(self._update_nodes(), _async_notify(),
                             _purge_data(), *(m.run() for m in modules))