# ray/dashboard/modules/node/node_head.py


import asyncio
import re
import logging
import json

import aiohttp.web

import ray._private.utils
import ray.experimental.internal_kv  # used below for the agent-port lookup
from ray import ray_constants
from ray.dashboard.modules.node import node_consts
from ray.dashboard.modules.node.node_consts import (
    MAX_LOGS_TO_CACHE,
    LOG_PRUNE_THREASHOLD,  # sic -- matches the constant's spelling in node_consts
)
import ray.dashboard.utils as dashboard_utils
import ray.dashboard.optional_utils as dashboard_optional_utils
import ray.dashboard.consts as dashboard_consts
from ray.dashboard.utils import async_loop_forever
from ray.dashboard.memory_utils import GroupByType, SortingType
from ray.core.generated import node_manager_pb2
from ray.core.generated import node_manager_pb2_grpc
from ray.core.generated import gcs_service_pb2
from ray.core.generated import gcs_service_pb2_grpc
from ray.dashboard.datacenter import DataSource, DataOrganizer

logger = logging.getLogger(__name__)
routes = dashboard_optional_utils.ClassMethodRouteTable
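
# Helpers that convert protobuf replies from the GCS and the raylets into
# JSON-serializable dicts for the REST responses below. The ID fields listed
# in `decode_keys` are binary in the protos and are decoded to hex strings.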


def gcs_node_info_to_dict(message):
    return dashboard_utils.message_to_dict(
        message, {"nodeId"}, including_default_value_fields=True
    )


def node_stats_to_dict(message):
    decode_keys = {
        "actorId",
        "jobId",
        "taskId",
        "parentTaskId",
        "sourceActorId",
        "callerId",
        "rayletId",
        "workerId",
        "placementGroupId",
    }
    # Temporarily detach the (potentially large) per-worker stats so the
    # top-level message and each worker's stats can be converted separately;
    # the field is restored in the `finally` block.
    core_workers_stats = message.core_workers_stats
    message.ClearField("core_workers_stats")
    try:
        result = dashboard_utils.message_to_dict(message, decode_keys)
        result["coreWorkersStats"] = [
            dashboard_utils.message_to_dict(
                m, decode_keys, including_default_value_fields=True
            )
            for m in core_workers_stats
        ]
        return result
    finally:
        message.core_workers_stats.extend(core_workers_stats)


class NodeHead(dashboard_utils.DashboardHeadModule):
    """Dashboard head module that tracks cluster nodes.

    Polls the GCS for node membership and per-node stats, subscribes to log
    and error streams, and exposes the collected data over REST routes.
    """

    def __init__(self, dashboard_head):
        super().__init__(dashboard_head)
        self._stubs = {}
        # NodeInfoGcsService
        self._gcs_node_info_stub = None
        self._collect_memory_info = False
        DataSource.nodes.signal.append(self._update_stubs)

    async def _update_stubs(self, change):
        # Invoked on every change to DataSource.nodes; keeps one NodeManager
        # gRPC stub per known node.
        if change.old:
            node_id, node_info = change.old
            self._stubs.pop(node_id)
        if change.new:
            # TODO(fyrestone): Handle exceptions.
            node_id, node_info = change.new
            address = "{}:{}".format(
                node_info["nodeManagerAddress"], int(node_info["nodeManagerPort"])
            )
            options = ray_constants.GLOBAL_GRPC_OPTIONS
            channel = ray._private.utils.init_grpc_channel(
                address, options, asynchronous=True
            )
            stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)
            self._stubs[node_id] = stub

    async def _get_nodes(self):
        """Read the client table.

        Returns:
            A dict of information about the nodes in the cluster.
        """
        request = gcs_service_pb2.GetAllNodeInfoRequest()
        reply = await self._gcs_node_info_stub.GetAllNodeInfo(request, timeout=2)
        if reply.status.code == 0:
            result = {}
            for node_info in reply.node_info_list:
                node_info_dict = gcs_node_info_to_dict(node_info)
                result[node_info_dict["nodeId"]] = node_info_dict
            return result
        else:
            logger.error("Failed to GetAllNodeInfo: %s", reply.status.message)
            # Return an empty dict rather than None so callers can always
            # iterate over the result.
            return {}

    async def _update_nodes(self):
        # TODO(fyrestone): Refactor code for updating actor / node / job.
        # Poll the GCS for the full node list on a fixed interval.
        while True:
            try:
                nodes = await self._get_nodes()

                alive_node_ids = []
                alive_node_infos = []
                node_id_to_ip = {}
                node_id_to_hostname = {}
                for node in nodes.values():
                    node_id = node["nodeId"]
                    ip = node["nodeManagerAddress"]
                    hostname = node["nodeManagerHostname"]
                    node_id_to_ip[node_id] = ip
                    node_id_to_hostname[node_id] = hostname
                    assert node["state"] in ["ALIVE", "DEAD"]
                    if node["state"] == "ALIVE":
                        alive_node_ids.append(node_id)
                        alive_node_infos.append(node)

                # Look up the dashboard agent port for every live node and
                # drop entries for nodes that are no longer alive.
                agents = dict(DataSource.agents)
                for node_id in alive_node_ids:
                    key = f"{dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX}{node_id}"
                    # TODO: Use async version if performance is an issue
                    agent_port = ray.experimental.internal_kv._internal_kv_get(
                        key, namespace=ray_constants.KV_NAMESPACE_DASHBOARD
                    )
                    if agent_port:
                        agents[node_id] = json.loads(agent_port)
                for node_id in agents.keys() - set(alive_node_ids):
                    agents.pop(node_id, None)

                DataSource.node_id_to_ip.reset(node_id_to_ip)
                DataSource.node_id_to_hostname.reset(node_id_to_hostname)
                DataSource.agents.reset(agents)
                DataSource.nodes.reset(nodes)
            except Exception:
                logger.exception("Error updating nodes.")
            finally:
                await asyncio.sleep(node_consts.UPDATE_NODES_INTERVAL_SECONDS)
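
    # REST routes. The `view` query parameter on /nodes selects the response
    # shape: "summary", "details", or "hostNameList". For example, assuming a
    # dashboard listening on the default localhost:8265:
    #   curl "http://localhost:8265/nodes?view=summary"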
@routes.get("/nodes")
@dashboard_optional_utils.aiohttp_cache
async def get_all_nodes(self, req) -> aiohttp.web.Response:
view = req.query.get("view")
if view == "summary":
all_node_summary = await DataOrganizer.get_all_node_summary()
return dashboard_optional_utils.rest_response(
success=True, message="Node summary fetched.", summary=all_node_summary
)
elif view == "details":
all_node_details = await DataOrganizer.get_all_node_details()
return dashboard_optional_utils.rest_response(
success=True,
message="All node details fetched",
clients=all_node_details,
)
elif view is not None and view.lower() == "hostNameList".lower():
alive_hostnames = set()
for node in DataSource.nodes.values():
if node["state"] == "ALIVE":
alive_hostnames.add(node["nodeManagerHostname"])
return dashboard_optional_utils.rest_response(
success=True,
message="Node hostname list fetched.",
host_name_list=list(alive_hostnames),
)
else:
return dashboard_optional_utils.rest_response(
success=False, message=f"Unknown view {view}"
)
@routes.get("/nodes/{node_id}")
@dashboard_optional_utils.aiohttp_cache
async def get_node(self, req) -> aiohttp.web.Response:
node_id = req.match_info.get("node_id")
node_info = await DataOrganizer.get_node_info(node_id)
return dashboard_optional_utils.rest_response(
success=True, message="Node details fetched.", detail=node_info
)
@routes.get("/memory/memory_table")
async def get_memory_table(self, req) -> aiohttp.web.Response:
group_by = req.query.get("group_by")
sort_by = req.query.get("sort_by")
kwargs = {}
if group_by:
kwargs["group_by"] = GroupByType(group_by)
if sort_by:
kwargs["sort_by"] = SortingType(sort_by)
memory_table = await DataOrganizer.get_memory_table(**kwargs)
return dashboard_optional_utils.rest_response(
success=True,
message="Fetched memory table",
memory_table=memory_table.as_dict(),
)
@routes.get("/memory/set_fetch")
async def set_fetch_memory_info(self, req) -> aiohttp.web.Response:
should_fetch = req.query["shouldFetch"]
if should_fetch == "true":
self._collect_memory_info = True
elif should_fetch == "false":
self._collect_memory_info = False
else:
return dashboard_optional_utils.rest_response(
success=False, message=f"Unknown argument to set_fetch {should_fetch}"
)
return dashboard_optional_utils.rest_response(
success=True, message=f"Successfully set fetching to {should_fetch}"
)
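
    # The two routes below serve the per-IP, per-PID caches populated by the
    # `_update_log_info` and `_update_error_info` loops further down.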
@routes.get("/node_logs")
async def get_logs(self, req) -> aiohttp.web.Response:
ip = req.query["ip"]
pid = str(req.query.get("pid", ""))
node_logs = DataSource.ip_and_pid_to_logs.get(ip, {})
if pid:
node_logs = {str(pid): node_logs.get(pid, [])}
return dashboard_optional_utils.rest_response(
success=True, message="Fetched logs.", logs=node_logs
)
@routes.get("/node_errors")
async def get_errors(self, req) -> aiohttp.web.Response:
ip = req.query["ip"]
pid = str(req.query.get("pid", ""))
node_errors = DataSource.ip_and_pid_to_errors.get(ip, {})
if pid:
node_errors = {str(pid): node_errors.get(pid, [])}
return dashboard_optional_utils.rest_response(
success=True, message="Fetched errors.", errors=node_errors
)
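
    # Background pollers. Each of these runs for the lifetime of the
    # dashboard head and feeds the DataSource caches read by the routes above.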

    @async_loop_forever(node_consts.NODE_STATS_UPDATE_INTERVAL_SECONDS)
    async def _update_node_stats(self):
        # Copy self._stubs to avoid `dictionary changed size during iteration`.
        for node_id, stub in list(self._stubs.items()):
            node_info = DataSource.nodes.get(node_id)
            # Skip nodes that have disappeared from the node table or died.
            if node_info is None or node_info["state"] != "ALIVE":
                continue
            try:
                reply = await stub.GetNodeStats(
                    node_manager_pb2.GetNodeStatsRequest(
                        include_memory_info=self._collect_memory_info
                    ),
                    timeout=2,
                )
                reply_dict = node_stats_to_dict(reply)
                DataSource.node_stats[node_id] = reply_dict
            except Exception:
                logger.exception(f"Error updating node stats of {node_id}.")

    async def _update_log_info(self):
        if ray_constants.DISABLE_DASHBOARD_LOG_INFO:
            return

        def process_log_batch(log_batch):
            ip = log_batch["ip"]
            pid = str(log_batch["pid"])
            if pid != "autoscaler":
                logs_for_ip = dict(DataSource.ip_and_pid_to_logs.get(ip, {}))
                logs_for_pid = list(logs_for_ip.get(pid, []))
                logs_for_pid.extend(log_batch["lines"])

                # Only cache up to MAX_LOGS_TO_CACHE lines per pid; prune in
                # batches once the threshold multiple is exceeded.
                logs_length = len(logs_for_pid)
                if logs_length > MAX_LOGS_TO_CACHE * LOG_PRUNE_THREASHOLD:
                    offset = logs_length - MAX_LOGS_TO_CACHE
                    del logs_for_pid[:offset]

                logs_for_ip[pid] = logs_for_pid
                DataSource.ip_and_pid_to_logs[ip] = logs_for_ip
            logger.debug(f"Received a log for {ip} and {pid}")

        while True:
            try:
                log_batch = await self._dashboard_head.gcs_log_subscriber.poll()
                if log_batch is None:
                    continue
                process_log_batch(log_batch)
            except Exception:
                logger.exception("Error receiving log from GCS.")

    async def _update_error_info(self):
        def process_error(error_data):
            message = error_data.error_message
            # Strip ANSI color escape sequences before parsing.
            message = re.sub(r"\x1b\[\d+m", "", message)
            match = re.search(r"\(pid=(\d+), ip=(.*?)\)", message)
            if match:
                pid = match.group(1)
                ip = match.group(2)
                errs_for_ip = dict(DataSource.ip_and_pid_to_errors.get(ip, {}))
                pid_errors = list(errs_for_ip.get(pid, []))
                pid_errors.append(
                    {
                        "message": message,
                        "timestamp": error_data.timestamp,
                        "type": error_data.type,
                    }
                )
                errs_for_ip[pid] = pid_errors
                DataSource.ip_and_pid_to_errors[ip] = errs_for_ip
                logger.info(f"Received error entry for {ip} {pid}")

        while True:
            try:
                (
                    _,
                    error_data,
                ) = await self._dashboard_head.gcs_error_subscriber.poll()
                if error_data is None:
                    continue
                process_error(error_data)
            except Exception:
                logger.exception("Error receiving error info from GCS.")

    async def run(self, server):
        gcs_channel = self._dashboard_head.aiogrpc_gcs_channel
        self._gcs_node_info_stub = gcs_service_pb2_grpc.NodeInfoGcsServiceStub(
            gcs_channel
        )

        await asyncio.gather(
            self._update_nodes(),
            self._update_node_stats(),
            self._update_log_info(),
            self._update_error_info(),
        )

    @staticmethod
    def is_minimal_module():
        return False