From 50784e249660f84011aee464d163784741337d28 Mon Sep 17 00:00:00 2001
From: fyrestone
Date: Thu, 17 Sep 2020 01:17:29 +0800
Subject: [PATCH] [Dashboard] Dashboard node grouping (#10528)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Add RAY_NODE_ID environment variable to the agent
* Node-related data uses node id as the key
* ray.init() returns the node id; passes test_reporter.py
* Fix lint & CI
* Fix comments
* Minor fixes
* Fix CI
* Add const to ClientID in AgentManager::Options
* Use fstring
* Add comments
* Fix lint
* Add test_multi_nodes_info

Co-authored-by: 刘宝
---
 dashboard/agent.py                             |  6 +-
 dashboard/datacenter.py                        | 55 ++++++++-------
 dashboard/head.py                              | 67 ++++++++----------
 dashboard/modules/log/log_head.py              | 16 +++--
 dashboard/modules/log/test_log.py              |  5 +-
 dashboard/modules/reporter/reporter_agent.py   | 21 +++---
 dashboard/modules/reporter/reporter_head.py    | 14 ++--
 dashboard/modules/reporter/test_reporter.py    |  2 +-
 .../stats_collector/stats_collector_head.py    | 31 +++++----
 .../stats_collector/test_stats_collector.py    | 44 +++++++++++-
 dashboard/tests/test_dashboard.py              | 13 ++--
 python/ray/_raylet.pyx                         |  4 ++
 python/ray/includes/libcoreworker.pxd          |  1 +
 python/ray/worker.py                           |  3 +-
 src/ray/core_worker/core_worker.h              |  4 ++
 src/ray/raylet/agent_manager.cc                |  5 +-
 src/ray/raylet/agent_manager.h                 |  2 +
 src/ray/raylet/node_manager.cc                 |  4 +-
 18 files changed, 182 insertions(+), 115 deletions(-)

diff --git a/dashboard/agent.py b/dashboard/agent.py
index b7c379da3..0c69c0da5 100644
--- a/dashboard/agent.py
+++ b/dashboard/agent.py
@@ -55,6 +55,8 @@ class DashboardAgent(object):
         self.node_manager_port = node_manager_port
         self.object_store_name = object_store_name
         self.raylet_name = raylet_name
+        self.node_id = os.environ["RAY_NODE_ID"]
+        assert self.node_id, "Empty node id (RAY_NODE_ID)."
         self.ip = ray.services.get_node_ip_address()
         self.server = aiogrpc.server(options=(("grpc.so_reuseport", 0), ))
         self.grpc_port = self.server.add_insecure_port("[::]:0")
@@ -152,8 +154,8 @@ class DashboardAgent(object):

         # Write the dashboard agent port to redis.
         await self.aioredis_client.set(
-            "{}{}".format(dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX,
-                          self.ip), json.dumps([http_port, self.grpc_port]))
+            f"{dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX}{self.node_id}",
+            json.dumps([http_port, self.grpc_port]))

         # Register agent to agent manager.
         raylet_stub = agent_manager_pb2_grpc.AgentManagerServiceStub(

diff --git a/dashboard/datacenter.py b/dashboard/datacenter.py
index 1ee454917..c7a0fdc67 100644
--- a/dashboard/datacenter.py
+++ b/dashboard/datacenter.py
@@ -11,22 +11,22 @@ class GlobalSignals:


 class DataSource:
-    # {ip address(str): node stats(dict of GetNodeStatsReply
+    # {node id hex(str): node stats(dict of GetNodeStatsReply
     # in node_manager.proto)}
     node_stats = Dict()
-    # {ip address(str): node physical stats(dict from reporter_agent.py)}
+    # {node id hex(str): node physical stats(dict from reporter_agent.py)}
     node_physical_stats = Dict()
     # {actor id hex(str): actor table data(dict of ActorTableData
     # in gcs.proto)}
     actors = Dict()
-    # {ip address(str): dashboard agent [http port(int), grpc port(int)]}
+    # {node id hex(str): dashboard agent [http port(int), grpc port(int)]}
     agents = Dict()
-    # {ip address(str): gcs node info(dict of GcsNodeInfo in gcs.proto)}
+    # {node id hex(str): gcs node info(dict of GcsNodeInfo in gcs.proto)}
     nodes = Dict()
-    # {hostname(str): ip address(str)}
-    hostname_to_ip = Dict()
-    # {ip address(str): hostname(str)}
-    ip_to_hostname = Dict()
+    # {node id hex(str): ip address(str)}
+    node_id_to_ip = Dict()
+    # {node id hex(str): hostname(str)}
+    node_id_to_hostname = Dict()


 class DataOrganizer:
@@ -37,20 +37,23 @@ class DataOrganizer:
         # we do not needs to purge them:
         #   * agents
         #   * nodes
-        #   * hostname_to_ip
-        #   * ip_to_hostname
+        #   * node_id_to_ip
+        #   * node_id_to_hostname
         logger.info("Purge data.")
-        valid_keys = DataSource.ip_to_hostname.keys()
-        for key in DataSource.node_stats.keys() - valid_keys:
+        alive_nodes = {
+            node_id
+            for node_id, node_info in DataSource.nodes.items()
+            if node_info["state"] == "ALIVE"
+        }
+        for key in DataSource.node_stats.keys() - alive_nodes:
             DataSource.node_stats.pop(key)

-        for key in DataSource.node_physical_stats.keys() - valid_keys:
+        for key in DataSource.node_physical_stats.keys() - alive_nodes:
             DataSource.node_physical_stats.pop(key)

     @classmethod
-    async def get_node_actors(cls, hostname):
-        ip = DataSource.hostname_to_ip[hostname]
-        node_stats = DataSource.node_stats.get(ip, {})
+    async def get_node_actors(cls, node_id):
+        node_stats = DataSource.node_stats.get(node_id, {})
         node_worker_id_set = set()
         for worker_stats in node_stats.get("workersStats", []):
             node_worker_id_set.add(worker_stats["workerId"])
@@ -61,10 +64,10 @@ class DataOrganizer:
         return node_actors

     @classmethod
-    async def get_node_info(cls, hostname):
-        ip = DataSource.hostname_to_ip[hostname]
-        node_physical_stats = DataSource.node_physical_stats.get(ip, {})
-        node_stats = DataSource.node_stats.get(ip, {})
+    async def get_node_info(cls, node_id):
+        node_physical_stats = DataSource.node_physical_stats.get(node_id, {})
+        node_stats = DataSource.node_stats.get(node_id, {})
+        node = DataSource.nodes.get(node_id, {})

         # Merge coreWorkerStats (node stats) to workers (node physical stats)
         workers_stats = node_stats.pop("workersStats", {})
@@ -86,11 +89,13 @@ class DataOrganizer:
             worker["language"] = pid_to_language.get(worker["pid"], "")
             worker["jobId"] = pid_to_job_id.get(worker["pid"], "ffff")

-        # Merge node stats to node physical stats
         node_info = node_physical_stats
+        # Merge node stats to node physical stats
         node_info["raylet"] = node_stats
-        node_info["actors"] = await cls.get_node_actors(hostname)
-        node_info["state"] = DataSource.nodes.get(ip, {}).get("state", "DEAD")
+        # Merge GcsNodeInfo to node physical stats
+        node_info["raylet"].update(node)
+        # Merge actors to node physical stats
+        node_info["actors"] = await cls.get_node_actors(node_id)

         await GlobalSignals.node_info_fetched.send(node_info)

@@ -99,8 +104,8 @@ class DataOrganizer:
     @classmethod
     async def get_all_node_summary(cls):
         all_nodes_summary = []
-        for hostname in DataSource.hostname_to_ip.keys():
-            node_info = await cls.get_node_info(hostname)
+        for node_id in DataSource.nodes.keys():
+            node_info = await cls.get_node_info(node_id)
             node_info.pop("workers", None)
             node_info.pop("actors", None)
             node_info["raylet"].pop("workersStats", None)

diff --git a/dashboard/head.py b/dashboard/head.py
index 9a8f27379..5f23b2654 100644
--- a/dashboard/head.py
+++ b/dashboard/head.py
@@ -54,21 +54,17 @@ class DashboardHead:
         """Read the client table.

         Returns:
-            A list of information about the nodes in the cluster.
+            A dict of information about the nodes in the cluster.
         """
         request = gcs_service_pb2.GetAllNodeInfoRequest()
         reply = await self._gcs_node_info_stub.GetAllNodeInfo(
             request, timeout=2)
         if reply.status.code == 0:
-            results = []
-            node_id_set = set()
+            result = {}
             for node_info in reply.node_info_list:
-                if node_info.node_id in node_id_set:
-                    continue
-                node_id_set.add(node_info.node_id)
                 node_info_dict = gcs_node_info_to_dict(node_info)
-                results.append(node_info_dict)
-            return results
+                result[node_info_dict["nodeId"]] = node_info_dict
+            return result
         else:
             logger.error("Failed to GetAllNodeInfo: %s",
                          reply.status.message)
@@ -77,44 +73,37 @@ class DashboardHead:
             try:
                 nodes = await self._get_nodes()

-                # Get correct node info by state,
-                #   1. The node is ALIVE if any ALIVE node info
-                #      of the hostname exists.
-                #   2. The node is DEAD if all node info of the
-                #      hostname are DEAD.
-                hostname_to_node_info = {}
-                for node in nodes:
-                    hostname = node["nodeManagerAddress"]
+                alive_node_ids = []
+                alive_node_infos = []
+                node_id_to_ip = {}
+                node_id_to_hostname = {}
+                for node in nodes.values():
+                    node_id = node["nodeId"]
+                    ip = node["nodeManagerAddress"]
+                    hostname = node["nodeManagerHostname"]
+                    node_id_to_ip[node_id] = ip
+                    node_id_to_hostname[node_id] = hostname
                     assert node["state"] in ["ALIVE", "DEAD"]
-                    choose = hostname_to_node_info.get(hostname)
-                    if choose is not None and choose["state"] == "ALIVE":
-                        continue
-                    hostname_to_node_info[hostname] = node
-                nodes = hostname_to_node_info.values()
-
-                self._gcs_rpc_error_counter = 0
-                node_ips = [node["nodeManagerAddress"] for node in nodes]
-                node_hostnames = [
-                    node["nodeManagerHostname"] for node in nodes
-                ]
+                    if node["state"] == "ALIVE":
+                        alive_node_ids.append(node_id)
+                        alive_node_infos.append(node)

                 agents = dict(DataSource.agents)
-                for node in nodes:
-                    node_ip = node["nodeManagerAddress"]
-                    key = "{}{}".format(
-                        dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX, node_ip)
+                for node_id in alive_node_ids:
+                    key = f"{dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX}" \
+                          f"{node_id}"
                     agent_port = await self.aioredis_client.get(key)
                     if agent_port:
-                        agents[node_ip] = json.loads(agent_port)
-                for ip in agents.keys() - set(node_ips):
-                    agents.pop(ip, None)
+                        agents[node_id] = json.loads(agent_port)
+                for node_id in agents.keys() - set(alive_node_ids):
+                    agents.pop(node_id, None)

+                DataSource.node_id_to_ip.reset(node_id_to_ip)
+                DataSource.node_id_to_hostname.reset(node_id_to_hostname)
                 DataSource.agents.reset(agents)
-                DataSource.nodes.reset(dict(zip(node_ips, nodes)))
-                DataSource.hostname_to_ip.reset(
-                    dict(zip(node_hostnames, node_ips)))
-                DataSource.ip_to_hostname.reset(
-                    dict(zip(node_ips, node_hostnames)))
+                DataSource.nodes.reset(nodes)
+
+                self._gcs_rpc_error_counter = 0
             except aiogrpc.AioRpcError:
                 logger.exception("Got AioRpcError when updating nodes.")
                 self._gcs_rpc_error_counter += 1

diff --git a/dashboard/modules/log/log_head.py b/dashboard/modules/log/log_head.py
index 5f1a1f6dc..abf4f0405 100644
--- a/dashboard/modules/log/log_head.py
+++ b/dashboard/modules/log/log_head.py
@@ -22,23 +22,27 @@ class LogHead(dashboard_utils.DashboardHeadModule):
             self.insert_log_url_to_node_info)

     async def insert_log_url_to_node_info(self, node_info):
-        ip = node_info.get("ip")
-        if ip is None:
+        node_id = node_info.get("raylet", {}).get("nodeId")
+        if node_id is None:
             return
-        agent_port = DataSource.agents.get(ip)
+        agent_port = DataSource.agents.get(node_id)
         if agent_port is None:
             return
         agent_http_port, _ = agent_port
-        log_url = self.LOG_URL_TEMPLATE.format(ip=ip, port=agent_http_port)
+        log_url = self.LOG_URL_TEMPLATE.format(
+            ip=node_info.get("ip"), port=agent_http_port)
         node_info["logUrl"] = log_url

     @routes.get("/log_index")
     async def get_log_index(self, req) -> aiohttp.web.Response:
         url_list = []
-        for ip, ports in DataSource.agents.items():
+        agent_ips = []
+        for node_id, ports in DataSource.agents.items():
+            ip = DataSource.node_id_to_ip[node_id]
+            agent_ips.append(ip)
             url_list.append(
                 self.LOG_URL_TEMPLATE.format(ip=ip, port=str(ports[0])))
-        if self._dashboard_head.ip not in DataSource.agents:
+        if self._dashboard_head.ip not in agent_ips:
             url_list.append(
                 self.LOG_URL_TEMPLATE.format(
                     ip=self._dashboard_head.ip,

diff --git a/dashboard/modules/log/test_log.py b/dashboard/modules/log/test_log.py
index a40643f90..f73e9bc0c 100644
--- a/dashboard/modules/log/test_log.py
+++ b/dashboard/modules/log/test_log.py
@@ -2,7 +2,6 @@ import os
 import sys
 import logging
 import requests
-import socket
 import time
 import traceback
 import html.parser
@@ -48,6 +47,7 @@ def test_log(ray_start_with_dashboard):
             is True)
     webui_url = ray_start_with_dashboard["webui_url"]
     webui_url = format_web_url(webui_url)
+    node_id = ray_start_with_dashboard["node_id"]

     timeout_seconds = 10
     start_time = time.time()
@@ -91,8 +91,7 @@ def test_log(ray_start_with_dashboard):
     assert response.text == "Dashboard"

     # Test logUrl in node info.
-    response = requests.get(webui_url +
-                            f"/nodes/{socket.gethostname()}")
+    response = requests.get(webui_url + f"/nodes/{node_id}")
     response.raise_for_status()
     node_info = response.json()
     assert node_info["result"] is True

diff --git a/dashboard/modules/reporter/reporter_agent.py b/dashboard/modules/reporter/reporter_agent.py
index f129a702e..7b7e6ef0a 100644
--- a/dashboard/modules/reporter/reporter_agent.py
+++ b/dashboard/modules/reporter/reporter_agent.py
@@ -70,6 +70,8 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
         self._workers = set()
         self._network_stats_hist = [(0, (0.0, 0.0))]  # time, (sent, recv)
         self._metrics_agent = MetricsAgent(dashboard_agent.metrics_export_port)
+        self._key = f"{reporter_consts.REPORTER_PREFIX}" \
+                    f"{self._dashboard_agent.node_id}"

     async def GetProfilingStats(self, request, context):
         pid = request.pid
@@ -186,12 +188,15 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,

     @staticmethod
     def _get_raylet_cmdline():
-        curr_proc = psutil.Process()
-        parent = curr_proc.parent()
-        if parent.pid == 1:
-            return ""
-        else:
-            return parent.cmdline()
+        try:
+            curr_proc = psutil.Process()
+            parent = curr_proc.parent()
+            if parent.pid == 1:
+                return []
+            else:
+                return parent.cmdline()
+        except (psutil.AccessDenied, ProcessLookupError):
+            return []

     def _get_load_avg(self):
         if sys.platform == "win32":
@@ -237,9 +242,7 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
         while True:
             try:
                 stats = self._get_all_stats()
-                await aioredis_client.publish(
-                    "{}{}".format(reporter_consts.REPORTER_PREFIX,
-                                  self._hostname), jsonify_asdict(stats))
+                await aioredis_client.publish(self._key, jsonify_asdict(stats))
             except Exception:
                 logger.exception("Error publishing node physical stats.")
             await asyncio.sleep(

diff --git a/dashboard/modules/reporter/reporter_head.py b/dashboard/modules/reporter/reporter_head.py
index 6046fd325..fdac34c84 100644
--- a/dashboard/modules/reporter/reporter_head.py
+++ b/dashboard/modules/reporter/reporter_head.py
@@ -27,10 +27,12 @@ class ReportHead(dashboard_utils.DashboardHeadModule):

     async def _update_stubs(self, change):
         if change.old:
-            ip, port = change.old
+            node_id, port = change.old
+            ip = DataSource.node_id_to_ip[node_id]
             self._stubs.pop(ip)
         if change.new:
-            ip, ports = change.new
+            node_id, ports = change.new
+            ip = DataSource.node_id_to_ip[node_id]
             channel = aiogrpc.insecure_channel(f"{ip}:{ports[1]}")
             stub = reporter_pb2_grpc.ReporterServiceStub(channel)
             self._stubs[ip] = stub
@@ -60,9 +62,13 @@ class ReportHead(dashboard_utils.DashboardHeadModule):

         async for sender, msg in receiver.iter():
             try:
-                _, data = msg
+                # The key is b'RAY_REPORTER:{node id hex}',
+                # e.g. b'RAY_REPORTER:2b4fbd406898cc86fb88fb0acfd5456b0afd87cf'
+                key, data = msg
                 data = json.loads(ray.utils.decode(data))
-                DataSource.node_physical_stats[data["ip"]] = data
+                key = key.decode("utf-8")
+                node_id = key.split(":")[-1]
+                DataSource.node_physical_stats[node_id] = data
             except Exception:
                 logger.exception(
                     "Error receiving node physical stats from reporter agent.")

diff --git a/dashboard/modules/reporter/test_reporter.py b/dashboard/modules/reporter/test_reporter.py
index 0097bc465..099d0fdea 100644
--- a/dashboard/modules/reporter/test_reporter.py
+++ b/dashboard/modules/reporter/test_reporter.py
@@ -81,7 +81,7 @@ def test_node_physical_stats(enable_test_module, shutdown_only):
             assert result["result"] is True
             node_physical_stats = result["data"]["nodePhysicalStats"]
             assert len(node_physical_stats) == 1
-            current_stats = node_physical_stats[addresses["raylet_ip_address"]]
+            current_stats = node_physical_stats[addresses["node_id"]]
             # Check Actor workers
             current_actor_pids = set()
             for worker in current_stats["workers"]:

diff --git a/dashboard/modules/stats_collector/stats_collector_head.py b/dashboard/modules/stats_collector/stats_collector_head.py
index 5ff210706..c70f057a0 100644
--- a/dashboard/modules/stats_collector/stats_collector_head.py
+++ b/dashboard/modules/stats_collector/stats_collector_head.py
@@ -50,14 +50,15 @@ class StatsCollector(dashboard_utils.DashboardHeadModule):

     async def _update_stubs(self, change):
         if change.old:
-            ip, port = change.old
-            self._stubs.pop(ip)
+            node_id, node_info = change.old
+            self._stubs.pop(node_id)
         if change.new:
-            ip, node_info = change.new
-            address = "{}:{}".format(ip, int(node_info["nodeManagerPort"]))
+            node_id, node_info = change.new
+            address = "{}:{}".format(node_info["nodeManagerAddress"],
+                                     int(node_info["nodeManagerPort"]))
             channel = aiogrpc.insecure_channel(address)
             stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)
-            self._stubs[ip] = stub
+            self._stubs[node_id] = stub

     @routes.get("/nodes")
     async def get_all_nodes(self, req) -> aiohttp.web.Response:
@@ -69,18 +70,22 @@ class StatsCollector(dashboard_utils.DashboardHeadModule):
                 message="Node summary fetched.",
                 summary=all_node_summary)
         elif view is not None and view.lower() == "hostNameList".lower():
+            alive_hostnames = set()
+            for node in DataSource.nodes.values():
+                if node["state"] == "ALIVE":
+                    alive_hostnames.add(node["nodeManagerHostname"])
             return await dashboard_utils.rest_response(
                 success=True,
                 message="Node hostname list fetched.",
-                host_name_list=list(DataSource.hostname_to_ip.keys()))
+                host_name_list=list(alive_hostnames))
         else:
             return await dashboard_utils.rest_response(
                 success=False, message=f"Unknown view {view}")

-    @routes.get("/nodes/{hostname}")
+    @routes.get("/nodes/{node_id}")
     async def get_node(self, req) -> aiohttp.web.Response:
-        hostname = req.match_info.get("hostname")
-        node_info = await DataOrganizer.get_node_info(hostname)
+        node_id = req.match_info.get("node_id")
+        node_info = await DataOrganizer.get_node_info(node_id)
         return await dashboard_utils.rest_response(
             success=True, message="Node detail fetched.", detail=node_info)

@@ -133,17 +138,17 @@ class StatsCollector(dashboard_utils.DashboardHeadModule):
     @async_loop_forever(
        stats_collector_consts.NODE_STATS_UPDATE_INTERVAL_SECONDS)
     async def _update_node_stats(self):
-        for ip, stub in self._stubs.items():
-            node_info = DataSource.nodes.get(ip)
+        for node_id, stub in self._stubs.items():
+            node_info = DataSource.nodes.get(node_id)
             if node_info["state"] != "ALIVE":
                 continue
             try:
                reply = await stub.GetNodeStats(
                     node_manager_pb2.GetNodeStatsRequest(), timeout=2)
                 reply_dict = node_stats_to_dict(reply)
-                DataSource.node_stats[ip] = reply_dict
+                DataSource.node_stats[node_id] = reply_dict
             except Exception:
-                logger.exception(f"Error updating node stats of {ip}.")
+                logger.exception(f"Error updating node stats of {node_id}.")

     async def run(self, server):
         gcs_channel = self._dashboard_head.aiogrpc_gcs_channel

diff --git a/dashboard/modules/stats_collector/test_stats_collector.py b/dashboard/modules/stats_collector/test_stats_collector.py
index 72e614437..0c8305c83 100644
--- a/dashboard/modules/stats_collector/test_stats_collector.py
+++ b/dashboard/modules/stats_collector/test_stats_collector.py
@@ -11,6 +11,7 @@
 from ray.new_dashboard.tests.conftest import *  # noqa
 from ray.test_utils import (
     format_web_url,
     wait_until_server_available,
+    wait_for_condition,
 )

 os.environ["RAY_USE_NEW_DASHBOARD"] = "1"
@@ -32,6 +33,7 @@ def test_node_info(ray_start_with_dashboard):
             is True)
     webui_url = ray_start_with_dashboard["webui_url"]
     webui_url = format_web_url(webui_url)
+    node_id = ray_start_with_dashboard["node_id"]

     timeout_seconds = 10
     start_time = time.time()
@@ -47,13 +49,13 @@ def test_node_info(ray_start_with_dashboard):
             assert len(hostname_list) == 1
             hostname = hostname_list[0]

-            response = requests.get(webui_url + f"/nodes/{hostname}")
+            response = requests.get(webui_url + f"/nodes/{node_id}")
             response.raise_for_status()
             detail = response.json()
             assert detail["result"] is True, detail["msg"]
             detail = detail["data"]["detail"]
             assert detail["hostname"] == hostname
-            assert detail["state"] == "ALIVE"
+            assert detail["raylet"]["state"] == "ALIVE"
             assert "raylet" in detail["cmdline"][0]
             assert len(detail["workers"]) >= 2
             assert len(detail["actors"]) == 2, detail["actors"]
@@ -72,7 +74,7 @@ def test_node_info(ray_start_with_dashboard):
             assert len(summary["data"]["summary"]) == 1
             summary = summary["data"]["summary"][0]
             assert summary["hostname"] == hostname
-            assert summary["state"] == "ALIVE"
+            assert summary["raylet"]["state"] == "ALIVE"
             assert "raylet" in summary["cmdline"][0]
             assert "workers" not in summary
             assert "actors" not in summary
@@ -89,5 +91,41 @@ def test_node_info(ray_start_with_dashboard):
             raise Exception(f"Timed out while testing, {ex_stack}")


+@pytest.mark.parametrize(
+    "ray_start_cluster_head", [{
+        "include_dashboard": True
+    }], indirect=True)
+def test_multi_nodes_info(enable_test_module, ray_start_cluster_head):
+    cluster = ray_start_cluster_head
+    assert (wait_until_server_available(cluster.webui_url) is True)
+    webui_url = cluster.webui_url
+    webui_url = format_web_url(webui_url)
+    cluster.add_node()
+    cluster.add_node()
+
+    def _check_nodes():
+        try:
+            response = requests.get(webui_url + "/nodes?view=summary")
+            response.raise_for_status()
+            summary = response.json()
+            assert summary["result"] is True, summary["msg"]
+            summary = summary["data"]["summary"]
+            assert len(summary) == 3
+            for node_info in summary:
+                node_id = node_info["raylet"]["nodeId"]
+                response = requests.get(webui_url + f"/nodes/{node_id}")
+                response.raise_for_status()
+                detail = response.json()
+                assert detail["result"] is True, detail["msg"]
+                detail = detail["data"]["detail"]
+                assert detail["raylet"]["state"] == "ALIVE"
+            return True
+        except Exception as ex:
+            logger.info(ex)
+            return False
+
+    wait_for_condition(_check_nodes, timeout=10)
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main(["-v", __file__]))

diff --git a/dashboard/tests/test_dashboard.py b/dashboard/tests/test_dashboard.py
index faa3c1335..a88061edf 100644
--- a/dashboard/tests/test_dashboard.py
+++ b/dashboard/tests/test_dashboard.py
@@ -63,6 +63,7 @@ def test_basic(ray_start_with_dashboard):
     assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
             is True)
     address_info = ray_start_with_dashboard
+    node_id = address_info["node_id"]
     address = address_info["redis_address"]
     address = address.split(":")
     assert len(address) == 2
@@ -139,8 +140,7 @@ def test_basic(ray_start_with_dashboard):
     dashboard_rpc_address = client.get(
         dashboard_consts.REDIS_KEY_DASHBOARD_RPC)
     assert dashboard_rpc_address is not None
-    key = "{}{}".format(dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX,
-                        address[0])
+    key = f"{dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX}{node_id}"
     agent_ports = client.get(key)
     assert agent_ports is not None

@@ -167,10 +167,10 @@ def test_nodes_update(enable_test_module, ray_start_with_dashboard):
             dump_data = dump_info["data"]
             assert len(dump_data["nodes"]) == 1
             assert len(dump_data["agents"]) == 1
-            assert len(dump_data["hostnameToIp"]) == 1
-            assert len(dump_data["ipToHostname"]) == 1
+            assert len(dump_data["nodeIdToIp"]) == 1
+            assert len(dump_data["nodeIdToHostname"]) == 1
             assert dump_data["nodes"].keys() == dump_data[
-                "ipToHostname"].keys()
+                "nodeIdToHostname"].keys()

             response = requests.get(webui_url + "/test/notified_agents")
             response.raise_for_status()
@@ -215,7 +215,8 @@ def test_http_get(enable_test_module, ray_start_with_dashboard):
     assert dump_info["result"] is True
     dump_data = dump_info["data"]
     assert len(dump_data["agents"]) == 1
-    ip, ports = next(iter(dump_data["agents"].items()))
+    node_id, ports = next(iter(dump_data["agents"].items()))
+    ip = ray_start_with_dashboard["node_ip_address"]
     http_port, grpc_port = ports

     response = requests.get(

diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx
index 110e6f881..211563072 100644
--- a/python/ray/_raylet.pyx
+++ b/python/ray/_raylet.pyx
@@ -783,6 +783,10 @@ cdef class CoreWorker:
         return JobID(
             CCoreWorkerProcess.GetCoreWorker().GetCurrentJobId().Binary())

+    def get_current_node_id(self):
+        return ClientID(
+            CCoreWorkerProcess.GetCoreWorker().GetCurrentNodeId().Binary())
+
     def get_actor_id(self):
         return ActorID(
             CCoreWorkerProcess.GetCoreWorker().GetActorId().Binary())

diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd
index f7a0d14c5..3acc77501 100644
--- a/python/ray/includes/libcoreworker.pxd
+++ b/python/ray/includes/libcoreworker.pxd
@@ -121,6 +121,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
         CJobID GetCurrentJobId()
         CTaskID GetCurrentTaskId()
+        CClientID GetCurrentNodeId()
         const CActorID &GetActorId()
         void SetActorTitle(const c_string &title)
         void SetWebuiDisplay(const c_string &key, const c_string &message)

diff --git a/python/ray/worker.py b/python/ray/worker.py
index 8061c5522..5a90073c6 100644
--- a/python/ray/worker.py
+++ b/python/ray/worker.py
@@ -753,7 +753,8 @@ def init(
     for hook in _post_init_hooks:
         hook()

-    return _global_node.address_info
+    node_id = global_worker.core_worker.get_current_node_id()
+    return dict(_global_node.address_info, node_id=node_id.hex())


 # Functions to run as callback after a successful ray init.
diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h
index 900c4c482..c1e42d59b 100644
--- a/src/ray/core_worker/core_worker.h
+++ b/src/ray/core_worker/core_worker.h
@@ -350,6 +350,10 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {

   const JobID &GetCurrentJobId() const { return worker_context_.GetCurrentJobID(); }

+  ClientID GetCurrentNodeId() const {
+    return ClientID::FromBinary(rpc_address_.raylet_id());
+  }
+
   void SetWebuiDisplay(const std::string &key, const std::string &message);

   void SetActorTitle(const std::string &title);

diff --git a/src/ray/raylet/agent_manager.cc b/src/ray/raylet/agent_manager.cc
index af4c482d3..23b8769c8 100644
--- a/src/ray/raylet/agent_manager.cc
+++ b/src/ray/raylet/agent_manager.cc
@@ -57,7 +57,10 @@ void AgentManager::StartAgent() {
     argv.push_back(arg.c_str());
   }
   argv.push_back(NULL);
-  Process child(argv.data(), nullptr, ec);
+  // Set node id to agent.
+  ProcessEnvironment env;
+  env.insert({"RAY_NODE_ID", options_.node_id.Hex()});
+  Process child(argv.data(), nullptr, ec, false, env);
   if (!child.IsValid() || ec) {
     // The worker failed to start. This is a fatal error.
     RAY_LOG(FATAL) << "Failed to start agent with return value " << ec << ": "

diff --git a/src/ray/raylet/agent_manager.h b/src/ray/raylet/agent_manager.h
index 29f19d4c6..3c79b31ba 100644
--- a/src/ray/raylet/agent_manager.h
+++ b/src/ray/raylet/agent_manager.h
@@ -18,6 +18,7 @@
 #include
 #include

+#include "ray/common/id.h"
 #include "ray/rpc/agent_manager/agent_manager_client.h"
 #include "ray/rpc/agent_manager/agent_manager_server.h"
 #include "ray/util/process.h"
@@ -32,6 +33,7 @@ typedef std::function(std::function
 class AgentManager : public rpc::AgentManagerServiceHandler {
  public:
   struct Options {
+    const ClientID node_id;
     std::vector agent_commands;
   };

diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc
index 6eab5a345..159cf3e34 100644
--- a/src/ray/raylet/node_manager.cc
+++ b/src/ray/raylet/node_manager.cc
@@ -208,8 +208,8 @@ NodeManager::NodeManager(boost::asio::io_service &io_service,
   node_manager_server_.RegisterService(agent_manager_service_);
   node_manager_server_.Run();

-  AgentManager::Options options;
-  options.agent_commands = ParseCommandLine(config.agent_command);
+  auto options =
+      AgentManager::Options({self_node_id, ParseCommandLine(config.agent_command)});
   agent_manager_.reset(
       new AgentManager(std::move(options),
                        /*delay_executor=*/