ray/dashboard/head.py
2020-08-14 14:06:57 -07:00

170 lines
6.6 KiB
Python

import sys
import asyncio
import logging
import aiohttp
import aioredis
from grpc.experimental import aio as aiogrpc
import ray.services
import ray.new_dashboard.consts as dashboard_consts
import ray.new_dashboard.utils as dashboard_utils
from ray.core.generated import gcs_service_pb2
from ray.core.generated import gcs_service_pb2_grpc
from ray.new_dashboard.datacenter import DataSource, DataOrganizer
# Module-level logger for this dashboard head process.
logger = logging.getLogger(__name__)
# Shared route table: head modules register their HTTP handlers through it.
routes = dashboard_utils.ClassMethodRouteTable
# Initialize gRPC's asyncio ("aio") support once, at import time.
aiogrpc.init_grpc_aio()
def gcs_node_info_to_dict(message):
    """Return a GcsNodeInfo protobuf *message* as a plain dict.

    "nodeId" is passed to ``message_to_dict`` as a decode key
    (presumably hex-decoded from binary — confirm in
    ``dashboard_utils.message_to_dict``), and fields still holding
    their default values are kept in the output.
    """
    decode_keys = {"nodeId"}
    return dashboard_utils.message_to_dict(
        message,
        decode_keys,
        including_default_value_fields=True)
class DashboardHead:
    """Head process of the new dashboard.

    Connects to redis and the GCS, loads every ``DashboardHeadModule``
    subclass it can discover, and runs the background tasks that keep
    ``DataSource`` up to date.
    """

    def __init__(self, redis_address, redis_password):
        """Set up shared state for all head modules.

        Args:
            redis_address (str): Redis server address as "ip:port".
            redis_password: Password for the redis server (may be None).
        """
        # Scan and import head modules for collecting http routes.
        self._head_cls_list = dashboard_utils.get_all_modules(
            dashboard_utils.DashboardHeadModule)
        ip, port = redis_address.split(":")
        # NodeInfoGcsService
        self._gcs_node_info_stub = None
        # Consecutive GCS RPC failures; reset on success, fatal past a limit.
        self._gcs_rpc_error_counter = 0
        # Public attributes are accessible for all head modules.
        self.redis_address = (ip, int(port))
        self.redis_password = redis_password
        self.aioredis_client = None
        self.aiogrpc_gcs_channel = None
        self.http_session = aiohttp.ClientSession(
            loop=asyncio.get_event_loop())
        self.ip = ray.services.get_node_ip_address()

    async def _get_nodes(self):
        """Read the client table.

        Returns:
            A list of information about the nodes in the cluster, or
            None (implicitly) if the GetAllNodeInfo RPC reports failure.
        """
        request = gcs_service_pb2.GetAllNodeInfoRequest()
        reply = await self._gcs_node_info_stub.GetAllNodeInfo(
            request, timeout=2)
        if reply.status.code == 0:
            results = []
            # Deduplicate by node_id in case the GCS reply contains repeats.
            node_id_set = set()
            for node_info in reply.node_info_list:
                if node_info.node_id in node_id_set:
                    continue
                node_id_set.add(node_info.node_id)
                node_info_dict = gcs_node_info_to_dict(node_info)
                results.append(node_info_dict)
            return results
        else:
            # NOTE(review): falls off the end and returns None here; the
            # caller (_update_nodes) currently relies on its broad
            # `except Exception` to absorb the resulting TypeError.
            logger.error("Failed to GetAllNodeInfo: %s", reply.status.message)

    async def _update_nodes(self):
        """Background loop: poll the GCS node table and refresh DataSource.

        Runs forever, sleeping UPDATE_NODES_INTERVAL_SECONDS between
        iterations. Exits the process after too many consecutive GCS
        RPC errors.
        """
        while True:
            try:
                nodes = await self._get_nodes()
                # A successful RPC resets the consecutive-error counter.
                self._gcs_rpc_error_counter = 0
                node_ips = [node["nodeManagerAddress"] for node in nodes]
                node_hostnames = [
                    node["nodeManagerHostname"] for node in nodes
                ]

                # Refresh the agent mapping: fetch ports for newly seen
                # nodes from redis, drop entries for departed nodes.
                agents = dict(DataSource.agents)
                for node in nodes:
                    node_ip = node["nodeManagerAddress"]
                    if node_ip not in agents:
                        key = "{}{}".format(
                            dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX,
                            node_ip)
                        # Raw redis value — presumably the agent's port;
                        # consumers are expected to decode it. TODO confirm.
                        agent_port = await self.aioredis_client.get(key)
                        if agent_port:
                            agents[node_ip] = agent_port
                for ip in agents.keys() - set(node_ips):
                    agents.pop(ip, None)
                DataSource.agents.reset(agents)
                DataSource.nodes.reset(dict(zip(node_ips, nodes)))
                DataSource.hostname_to_ip.reset(
                    dict(zip(node_hostnames, node_ips)))
                DataSource.ip_to_hostname.reset(
                    dict(zip(node_ips, node_hostnames)))
            except aiogrpc.AioRpcError as ex:
                logger.exception(ex)
                self._gcs_rpc_error_counter += 1
                # Too many consecutive GCS failures: assume the GCS is gone
                # and terminate the dashboard process.
                if self._gcs_rpc_error_counter > \
                        dashboard_consts.MAX_COUNT_OF_GCS_RPC_ERROR:
                    logger.error(
                        "Dashboard suicide, the GCS RPC error count %s > %s",
                        self._gcs_rpc_error_counter,
                        dashboard_consts.MAX_COUNT_OF_GCS_RPC_ERROR)
                    sys.exit(-1)
            except Exception as ex:
                logger.exception(ex)
            finally:
                # Always wait between polls, even after a failure.
                await asyncio.sleep(
                    dashboard_consts.UPDATE_NODES_INTERVAL_SECONDS)

    def _load_modules(self):
        """Load dashboard head modules.

        Returns:
            A list of instantiated DashboardHeadModule objects with their
            HTTP routes bound.
        """
        modules = []
        for cls in self._head_cls_list:
            logger.info("Load %s: %s",
                        dashboard_utils.DashboardHeadModule.__name__, cls)
            c = cls(self)
            # Bind the instance's @routes-decorated methods to the table.
            dashboard_utils.ClassMethodRouteTable.bind(c)
            modules.append(c)
        return modules

    async def run(self):
        """Main entry point: connect to redis and the GCS, then run all
        background tasks and loaded modules concurrently, forever."""
        # Create an aioredis client for all modules.
        self.aioredis_client = await aioredis.create_redis_pool(
            address=self.redis_address, password=self.redis_password)

        # Waiting for GCS is ready.
        while True:
            try:
                gcs_address = await self.aioredis_client.get(
                    dashboard_consts.REDIS_KEY_GCS_SERVER_ADDRESS)
                if not gcs_address:
                    raise Exception("GCS address not found.")
                logger.info("Connect to GCS at %s", gcs_address)
                channel = aiogrpc.insecure_channel(gcs_address)
            except Exception as ex:
                logger.error("Connect to GCS failed: %s, retry...", ex)
                await asyncio.sleep(
                    dashboard_consts.CONNECT_GCS_INTERVAL_SECONDS)
            else:
                self.aiogrpc_gcs_channel = channel
                break

        # Create a NodeInfoGcsServiceStub.
        self._gcs_node_info_stub = gcs_service_pb2_grpc.NodeInfoGcsServiceStub(
            self.aiogrpc_gcs_channel)

        async def _async_notify():
            """Notify signals from queue."""
            while True:
                # Each queue item is an awaitable; failures are logged
                # but do not stop the notification loop.
                co = await dashboard_utils.NotifyQueue.get()
                try:
                    await co
                except Exception as e:
                    logger.exception(e)

        async def _purge_data():
            """Purge data in datacenter."""
            while True:
                await asyncio.sleep(
                    dashboard_consts.PURGE_DATA_INTERVAL_SECONDS)
                try:
                    await DataOrganizer.purge()
                except Exception as e:
                    logger.exception(e)

        modules = self._load_modules()
        # Freeze signal after all modules loaded.
        dashboard_utils.SignalManager.freeze()
        # Run forever: node polling, signal notification, data purging,
        # and every loaded module's own run() coroutine.
        await asyncio.gather(self._update_nodes(), _async_notify(),
                             _purge_data(), *(m.run() for m in modules))