Revert "Revert "[Job Submission][refactor 1/N] Add AgentInfo to GCSNodeInfo (…" (#27308)

This commit is contained in:
Jialing He 2022-08-05 16:32:48 +08:00 committed by GitHub
parent b11d3061d8
commit ccf411604e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
23 changed files with 226 additions and 115 deletions

View file

@ -1,7 +1,6 @@
import argparse import argparse
import asyncio import asyncio
import io import io
import json
import logging import logging
import logging.handlers import logging.handlers
import os import os
@ -13,11 +12,10 @@ import ray._private.services
import ray._private.utils import ray._private.utils
import ray.dashboard.consts as dashboard_consts import ray.dashboard.consts as dashboard_consts
import ray.dashboard.utils as dashboard_utils import ray.dashboard.utils as dashboard_utils
import ray.experimental.internal_kv as internal_kv
from ray._private.gcs_pubsub import GcsAioPublisher, GcsPublisher from ray._private.gcs_pubsub import GcsAioPublisher, GcsPublisher
from ray._private.gcs_utils import GcsAioClient, GcsClient from ray._private.gcs_utils import GcsAioClient, GcsClient
from ray._private.ray_logging import setup_component_logger from ray._private.ray_logging import setup_component_logger
from ray.core.generated import agent_manager_pb2, agent_manager_pb2_grpc from ray.core.generated import agent_manager_pb2, agent_manager_pb2_grpc, common_pb2
from ray.experimental.internal_kv import ( from ray.experimental.internal_kv import (
_initialize_internal_kv, _initialize_internal_kv,
_internal_kv_initialized, _internal_kv_initialized,
@ -262,22 +260,20 @@ class DashboardAgent:
# TODO: Use async version if performance is an issue # TODO: Use async version if performance is an issue
# -1 should indicate that http server is not started. # -1 should indicate that http server is not started.
http_port = -1 if not self.http_server else self.http_server.http_port http_port = -1 if not self.http_server else self.http_server.http_port
internal_kv._internal_kv_put(
f"{dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX}{self.node_id}",
json.dumps([http_port, self.grpc_port]),
namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
)
# Register agent to agent manager. # Register agent to agent manager.
raylet_stub = agent_manager_pb2_grpc.AgentManagerServiceStub( raylet_stub = agent_manager_pb2_grpc.AgentManagerServiceStub(
self.aiogrpc_raylet_channel self.aiogrpc_raylet_channel
) )
await raylet_stub.RegisterAgent( await raylet_stub.RegisterAgent(
agent_manager_pb2.RegisterAgentRequest( agent_manager_pb2.RegisterAgentRequest(
agent_id=self.agent_id, agent_info=common_pb2.AgentInfo(
agent_port=self.grpc_port, id=self.agent_id,
agent_ip_address=self.ip, pid=os.getpid(),
grpc_port=self.grpc_port,
http_port=http_port,
ip_address=self.ip,
)
) )
) )

View file

@ -1,7 +1,6 @@
from ray._private.ray_constants import env_integer from ray._private.ray_constants import env_integer
DASHBOARD_LOG_FILENAME = "dashboard.log" DASHBOARD_LOG_FILENAME = "dashboard.log"
DASHBOARD_AGENT_PORT_PREFIX = "DASHBOARD_AGENT_PORT_PREFIX:"
DASHBOARD_AGENT_LOG_FILENAME = "dashboard_agent.log" DASHBOARD_AGENT_LOG_FILENAME = "dashboard_agent.log"
DASHBOARD_AGENT_CHECK_PARENT_INTERVAL_SECONDS = 2 DASHBOARD_AGENT_CHECK_PARENT_INTERVAL_SECONDS = 2
RAY_STATE_SERVER_MAX_HTTP_REQUEST_ENV_NAME = "RAY_STATE_SERVER_MAX_HTTP_REQUEST" RAY_STATE_SERVER_MAX_HTTP_REQUEST_ENV_NAME = "RAY_STATE_SERVER_MAX_HTTP_REQUEST"

View file

@ -1,5 +1,4 @@
import asyncio import asyncio
import json
import logging import logging
import re import re
import time import time
@ -7,7 +6,6 @@ import time
import aiohttp.web import aiohttp.web
import ray._private.utils import ray._private.utils
import ray.dashboard.consts as dashboard_consts
import ray.dashboard.optional_utils as dashboard_optional_utils import ray.dashboard.optional_utils as dashboard_optional_utils
import ray.dashboard.utils as dashboard_utils import ray.dashboard.utils as dashboard_utils
from ray._private import ray_constants from ray._private import ray_constants
@ -131,10 +129,9 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
try: try:
nodes = await self._get_nodes() nodes = await self._get_nodes()
alive_node_ids = []
alive_node_infos = []
node_id_to_ip = {} node_id_to_ip = {}
node_id_to_hostname = {} node_id_to_hostname = {}
agents = dict(DataSource.agents)
for node in nodes.values(): for node in nodes.values():
node_id = node["nodeId"] node_id = node["nodeId"]
ip = node["nodeManagerAddress"] ip = node["nodeManagerAddress"]
@ -150,20 +147,10 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
node_id_to_hostname[node_id] = hostname node_id_to_hostname[node_id] = hostname
assert node["state"] in ["ALIVE", "DEAD"] assert node["state"] in ["ALIVE", "DEAD"]
if node["state"] == "ALIVE": if node["state"] == "ALIVE":
alive_node_ids.append(node_id) agents[node_id] = [
alive_node_infos.append(node) node["agentInfo"]["httpPort"],
node["agentInfo"]["grpcPort"],
agents = dict(DataSource.agents) ]
for node_id in alive_node_ids:
key = f"{dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX}" f"{node_id}"
# TODO: Use async version if performance is an issue
agent_port = ray.experimental.internal_kv._internal_kv_get(
key, namespace=ray_constants.KV_NAMESPACE_DASHBOARD
)
if agent_port:
agents[node_id] = json.loads(agent_port)
for node_id in agents.keys() - set(alive_node_ids):
agents.pop(node_id, None)
DataSource.node_id_to_ip.reset(node_id_to_ip) DataSource.node_id_to_ip.reset(node_id_to_ip)
DataSource.node_id_to_hostname.reset(node_id_to_hostname) DataSource.node_id_to_hostname.reset(node_id_to_hostname)

View file

@ -6,6 +6,7 @@ import time
import traceback import traceback
import random import random
import pytest import pytest
import psutil
import ray import ray
import threading import threading
from datetime import datetime, timedelta from datetime import datetime, timedelta
@ -18,6 +19,7 @@ from ray._private.test_utils import (
wait_for_condition, wait_for_condition,
wait_until_succeeded_without_exception, wait_until_succeeded_without_exception,
) )
from ray._private.state import state
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -348,5 +350,33 @@ def test_frequent_node_update(
wait_for_condition(verify, timeout=15) wait_for_condition(verify, timeout=15)
# See detail: https://github.com/ray-project/ray/issues/24361
@pytest.mark.skipif(sys.platform == "win32", reason="Flaky on Windows.")
def test_node_register_with_agent(ray_start_cluster_head):
def test_agent_port(pid, port):
p = psutil.Process(pid)
assert p.cmdline()[2].endswith("dashboard/agent.py")
for c in p.connections():
if c.status == psutil.CONN_LISTEN and c.laddr.port == port:
return
assert False
def test_agent_process(pid):
p = psutil.Process(pid)
assert p.cmdline()[2].endswith("dashboard/agent.py")
for node_info in state.node_table():
agent_info = node_info["AgentInfo"]
assert agent_info["IpAddress"] == node_info["NodeManagerAddress"]
test_agent_port(agent_info["Pid"], agent_info["GrpcPort"])
if agent_info["HttpPort"] >= 0:
test_agent_port(agent_info["Pid"], agent_info["HttpPort"])
else:
# Port conflicts may be caused that the previous
# test did not kill the agent cleanly
assert agent_info["HttpPort"] == -1
if __name__ == "__main__": if __name__ == "__main__":
sys.exit(pytest.main(["-v", __file__])) sys.exit(pytest.main(["-v", __file__]))

View file

@ -110,7 +110,6 @@ def test_basic(ray_start_with_dashboard):
"""Dashboard test that starts a Ray cluster with a dashboard server running, """Dashboard test that starts a Ray cluster with a dashboard server running,
then hits the dashboard API and asserts that it receives sensible data.""" then hits the dashboard API and asserts that it receives sensible data."""
address_info = ray_start_with_dashboard address_info = ray_start_with_dashboard
node_id = address_info["node_id"]
gcs_client = make_gcs_client(address_info) gcs_client = make_gcs_client(address_info)
ray.experimental.internal_kv._initialize_internal_kv(gcs_client) ray.experimental.internal_kv._initialize_internal_kv(gcs_client)
@ -146,11 +145,6 @@ def test_basic(ray_start_with_dashboard):
namespace=ray_constants.KV_NAMESPACE_DASHBOARD, namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
) )
assert dashboard_rpc_address is not None assert dashboard_rpc_address is not None
key = f"{dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX}{node_id}"
agent_ports = ray.experimental.internal_kv._internal_kv_get(
key, namespace=ray_constants.KV_NAMESPACE_DASHBOARD
)
assert agent_ports is not None
def test_raylet_and_agent_share_fate(shutdown_only): def test_raylet_and_agent_share_fate(shutdown_only):
@ -795,7 +789,6 @@ def test_dashboard_port_conflict(ray_start_with_dashboard):
) )
def test_gcs_check_alive(fast_gcs_failure_detection, ray_start_with_dashboard): def test_gcs_check_alive(fast_gcs_failure_detection, ray_start_with_dashboard):
assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
all_processes = ray._private.worker._global_node.all_processes all_processes = ray._private.worker._global_node.all_processes
dashboard_info = all_processes[ray_constants.PROCESS_TYPE_DASHBOARD][0] dashboard_info = all_processes[ray_constants.PROCESS_TYPE_DASHBOARD][0]
dashboard_proc = psutil.Process(dashboard_info.process.pid) dashboard_proc = psutil.Process(dashboard_info.process.pid)

View file

@ -164,6 +164,12 @@ class GlobalState:
"RayletSocketName": item.raylet_socket_name, "RayletSocketName": item.raylet_socket_name,
"MetricsExportPort": item.metrics_export_port, "MetricsExportPort": item.metrics_export_port,
"NodeName": item.node_name, "NodeName": item.node_name,
"AgentInfo": {
"IpAddress": item.agent_info.ip_address,
"GrpcPort": item.agent_info.grpc_port,
"HttpPort": item.agent_info.http_port,
"Pid": item.agent_info.pid,
},
} }
node_info["alive"] = node_info["Alive"] node_info["alive"] = node_info["Alive"]
node_info["Resources"] = ( node_info["Resources"] = (

View file

@ -1,6 +1,7 @@
import atexit import atexit
import faulthandler import faulthandler
import functools import functools
import grpc
import hashlib import hashlib
import inspect import inspect
import io import io
@ -756,7 +757,6 @@ class Worker:
def print_logs(self): def print_logs(self):
"""Prints log messages from workers on all nodes in the same job.""" """Prints log messages from workers on all nodes in the same job."""
import grpc
subscriber = self.gcs_log_subscriber subscriber = self.gcs_log_subscriber
subscriber.subscribe() subscriber.subscribe()
@ -1902,6 +1902,11 @@ def connect(
if mode == SCRIPT_MODE: if mode == SCRIPT_MODE:
raise e raise e
elif mode == WORKER_MODE: elif mode == WORKER_MODE:
if isinstance(e, grpc.RpcError) and e.code() in (
grpc.StatusCode.UNAVAILABLE,
grpc.StatusCode.UNKNOWN,
):
raise e
traceback_str = traceback.format_exc() traceback_str = traceback.format_exc()
ray._private.utils.publish_error_to_driver( ray._private.utils.publish_error_to_driver(
ray_constants.VERSION_MISMATCH_PUSH_ERROR, ray_constants.VERSION_MISMATCH_PUSH_ERROR,

View file

@ -68,7 +68,8 @@ py_test_module_list(
"test_healthcheck.py", "test_healthcheck.py",
"test_kill_raylet_signal_log.py", "test_kill_raylet_signal_log.py",
"test_memstat.py", "test_memstat.py",
"test_protobuf_compatibility.py" "test_protobuf_compatibility.py",
"test_scheduling_performance.py"
], ],
size = "medium", size = "medium",
tags = ["exclusive", "medium_size_python_tests_a_to_j", "team:core"], tags = ["exclusive", "medium_size_python_tests_a_to_j", "team:core"],
@ -120,10 +121,8 @@ py_test_module_list(
"test_multi_node_2.py", "test_multi_node_2.py",
"test_multinode_failures.py", "test_multinode_failures.py",
"test_multinode_failures_2.py", "test_multinode_failures_2.py",
"test_multiprocessing.py",
"test_object_assign_owner.py", "test_object_assign_owner.py",
"test_placement_group.py", "test_placement_group.py",
"test_placement_group_2.py",
"test_placement_group_3.py", "test_placement_group_3.py",
"test_placement_group_4.py", "test_placement_group_4.py",
"test_placement_group_5.py", "test_placement_group_5.py",
@ -184,7 +183,6 @@ py_test_module_list(
"test_cross_language.py", "test_cross_language.py",
"test_environ.py", "test_environ.py",
"test_raylet_output.py", "test_raylet_output.py",
"test_scheduling_performance.py",
"test_get_or_create_actor.py", "test_get_or_create_actor.py",
], ],
size = "small", size = "small",
@ -295,6 +293,7 @@ py_test_module_list(
"test_placement_group_mini_integration.py", "test_placement_group_mini_integration.py",
"test_scheduling_2.py", "test_scheduling_2.py",
"test_multi_node_3.py", "test_multi_node_3.py",
"test_placement_group_2.py",
], ],
size = "large", size = "large",
tags = ["exclusive", "large_size_python_tests_shard_1", "team:core"], tags = ["exclusive", "large_size_python_tests_shard_1", "team:core"],

View file

@ -15,6 +15,7 @@ from pathlib import Path
from tempfile import gettempdir from tempfile import gettempdir
from typing import List, Tuple from typing import List, Tuple
from unittest import mock from unittest import mock
import signal
import pytest import pytest
@ -204,10 +205,19 @@ def _ray_start(**kwargs):
init_kwargs.update(kwargs) init_kwargs.update(kwargs)
# Start the Ray processes. # Start the Ray processes.
address_info = ray.init("local", **init_kwargs) address_info = ray.init("local", **init_kwargs)
agent_pids = []
for node in ray.nodes():
agent_pids.append(int(node["AgentInfo"]["Pid"]))
yield address_info yield address_info
# The code after the yield will run as teardown code. # The code after the yield will run as teardown code.
ray.shutdown() ray.shutdown()
# Make sure the agent process is dead.
for pid in agent_pids:
try:
os.kill(pid, signal.SIGKILL)
except Exception:
pass
# Delete the cluster address just in case. # Delete the cluster address just in case.
ray._private.utils.reset_ray_address() ray._private.utils.reset_ray_address()

View file

@ -834,8 +834,9 @@ def test_ray_status(shutdown_only, monkeypatch):
@pytest.mark.xfail(cluster_not_supported, reason="cluster not supported on Windows") @pytest.mark.xfail(cluster_not_supported, reason="cluster not supported on Windows")
def test_ray_status_multinode(ray_start_cluster): def test_ray_status_multinode(ray_start_cluster):
NODE_NUMBER = 4
cluster = ray_start_cluster cluster = ray_start_cluster
for _ in range(4): for _ in range(NODE_NUMBER):
cluster.add_node(num_cpus=2) cluster.add_node(num_cpus=2)
runner = CliRunner() runner = CliRunner()
@ -850,8 +851,12 @@ def test_ray_status_multinode(ray_start_cluster):
wait_for_condition(output_ready) wait_for_condition(output_ready)
result = runner.invoke(scripts.status, []) def check_result():
_check_output_via_pattern("test_ray_status_multinode.txt", result) result = runner.invoke(scripts.status, [])
_check_output_via_pattern("test_ray_status_multinode.txt", result)
return True
wait_for_condition(check_result)
@pytest.mark.skipif( @pytest.mark.skipif(

View file

@ -11,17 +11,29 @@ from ray._private.test_utils import wait_for_condition, run_string_as_driver_non
def get_all_ray_worker_processes(): def get_all_ray_worker_processes():
processes = [ processes = psutil.process_iter(attrs=["pid", "name", "cmdline"])
p.info["cmdline"] for p in psutil.process_iter(attrs=["pid", "name", "cmdline"])
]
result = [] result = []
for p in processes: for p in processes:
if p is not None and len(p) > 0 and "ray::" in p[0]: cmd_line = p.info["cmdline"]
if cmd_line is not None and len(cmd_line) > 0 and "ray::" in cmd_line[0]:
result.append(p) result.append(p)
return result return result
@pytest.fixture
def kill_all_ray_worker_process():
# Avoiding the previous test doesn't kill the relevant process,
# thus making the current test fail.
ray_process = get_all_ray_worker_processes()
for p in ray_process:
try:
p.kill()
except Exception:
pass
yield
@pytest.fixture @pytest.fixture
def short_gcs_publish_timeout(monkeypatch): def short_gcs_publish_timeout(monkeypatch):
monkeypatch.setenv("RAY_MAX_GCS_PUBLISH_RETRIES", "3") monkeypatch.setenv("RAY_MAX_GCS_PUBLISH_RETRIES", "3")
@ -29,8 +41,11 @@ def short_gcs_publish_timeout(monkeypatch):
@pytest.mark.skipif(platform.system() == "Windows", reason="Hang on Windows.") @pytest.mark.skipif(platform.system() == "Windows", reason="Hang on Windows.")
def test_ray_shutdown(short_gcs_publish_timeout, shutdown_only): def test_ray_shutdown(
kill_all_ray_worker_process, short_gcs_publish_timeout, shutdown_only
):
"""Make sure all ray workers are shutdown when driver is done.""" """Make sure all ray workers are shutdown when driver is done."""
ray.init() ray.init()
@ray.remote @ray.remote
@ -45,12 +60,15 @@ def test_ray_shutdown(short_gcs_publish_timeout, shutdown_only):
ray.shutdown() ray.shutdown()
wait_for_condition(lambda: len(get_all_ray_worker_processes()) == 0) wait_for_condition(lambda: len(get_all_ray_worker_processes()) == 0, timeout=20)
@pytest.mark.skipif(platform.system() == "Windows", reason="Hang on Windows.") @pytest.mark.skipif(platform.system() == "Windows", reason="Hang on Windows.")
def test_driver_dead(short_gcs_publish_timeout, shutdown_only): def test_driver_dead(
kill_all_ray_worker_process, short_gcs_publish_timeout, shutdown_only
):
"""Make sure all ray workers are shutdown when driver is killed.""" """Make sure all ray workers are shutdown when driver is killed."""
driver = """ driver = """
import ray import ray
ray.init(_system_config={"gcs_rpc_server_reconnect_timeout_s": 1}) ray.init(_system_config={"gcs_rpc_server_reconnect_timeout_s": 1})
@ -74,12 +92,15 @@ tasks = [f.remote() for _ in range(num_cpus)]
p.wait() p.wait()
time.sleep(0.1) time.sleep(0.1)
wait_for_condition(lambda: len(get_all_ray_worker_processes()) == 0) wait_for_condition(lambda: len(get_all_ray_worker_processes()) == 0, timeout=20)
@pytest.mark.skipif(platform.system() == "Windows", reason="Hang on Windows.") @pytest.mark.skipif(platform.system() == "Windows", reason="Hang on Windows.")
def test_node_killed(short_gcs_publish_timeout, ray_start_cluster): def test_node_killed(
kill_all_ray_worker_process, short_gcs_publish_timeout, ray_start_cluster
):
"""Make sure all ray workers when nodes are dead.""" """Make sure all ray workers when nodes are dead."""
cluster = ray_start_cluster cluster = ray_start_cluster
# head node. # head node.
cluster.add_node( cluster.add_node(
@ -106,12 +127,15 @@ def test_node_killed(short_gcs_publish_timeout, ray_start_cluster):
for worker in workers: for worker in workers:
cluster.remove_node(worker) cluster.remove_node(worker)
wait_for_condition(lambda: len(get_all_ray_worker_processes()) == 0) wait_for_condition(lambda: len(get_all_ray_worker_processes()) == 0, timeout=20)
@pytest.mark.skipif(platform.system() == "Windows", reason="Hang on Windows.") @pytest.mark.skipif(platform.system() == "Windows", reason="Hang on Windows.")
def test_head_node_down(short_gcs_publish_timeout, ray_start_cluster): def test_head_node_down(
kill_all_ray_worker_process, short_gcs_publish_timeout, ray_start_cluster
):
"""Make sure all ray workers when head node is dead.""" """Make sure all ray workers when head node is dead."""
cluster = ray_start_cluster cluster = ray_start_cluster
# head node. # head node.
head = cluster.add_node( head = cluster.add_node(
@ -149,7 +173,7 @@ time.sleep(100)
cluster.remove_node(head) cluster.remove_node(head)
wait_for_condition(lambda: len(get_all_ray_worker_processes()) == 0) wait_for_condition(lambda: len(get_all_ray_worker_processes()) == 0, timeout=20)
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -10,7 +10,6 @@ import yaml
from click.testing import CliRunner from click.testing import CliRunner
import ray import ray
import ray.dashboard.consts as dashboard_consts
import ray._private.state as global_state import ray._private.state as global_state
import ray._private.ray_constants as ray_constants import ray._private.ray_constants as ray_constants
from ray._private.test_utils import ( from ray._private.test_utils import (
@ -1180,16 +1179,8 @@ async def test_state_data_source_client(ray_start_cluster):
wait_for_condition(lambda: len(ray.nodes()) == 2) wait_for_condition(lambda: len(ray.nodes()) == 2)
for node in ray.nodes(): for node in ray.nodes():
node_id = node["NodeID"] node_id = node["NodeID"]
key = f"{dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX}{node_id}"
def get_port():
return ray.experimental.internal_kv._internal_kv_get(
key, namespace=ray_constants.KV_NAMESPACE_DASHBOARD
)
wait_for_condition(lambda: get_port() is not None)
# The second index is the gRPC port # The second index is the gRPC port
port = json.loads(get_port())[1] port = node["AgentInfo"]["GrpcPort"]
ip = node["NodeManagerAddress"] ip = node["NodeManagerAddress"]
client.register_agent_client(node_id, ip, port) client.register_agent_client(node_id, ip, port)
result = await client.get_runtime_envs_info(node_id) result = await client.get_runtime_envs_info(node_id)
@ -1391,16 +1382,8 @@ async def test_state_data_source_client_limit_distributed_sources(ray_start_clus
""" """
for node in ray.nodes(): for node in ray.nodes():
node_id = node["NodeID"] node_id = node["NodeID"]
key = f"{dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX}{node_id}"
def get_port():
return ray.experimental.internal_kv._internal_kv_get(
key, namespace=ray_constants.KV_NAMESPACE_DASHBOARD
)
wait_for_condition(lambda: get_port() is not None)
# The second index is the gRPC port # The second index is the gRPC port
port = json.loads(get_port())[1] port = node["AgentInfo"]["GrpcPort"]
ip = node["NodeManagerAddress"] ip = node["NodeManagerAddress"]
client.register_agent_client(node_id, ip, port) client.register_agent_client(node_id, ip, port)
@ -1513,8 +1496,13 @@ def test_cli_apis_sanity_check(ray_start_cluster):
) )
) )
# Test get workers by id # Test get workers by id
# Still need a `wait_for_condition`,
# because the worker obtained through the api server will not filter the driver,
# but `global_state.workers` will filter the driver.
wait_for_condition(lambda: len(global_state.workers()) > 0)
workers = global_state.workers() workers = global_state.workers()
assert len(workers) > 0
worker_id = list(workers.keys())[0] worker_id = list(workers.keys())[0]
wait_for_condition( wait_for_condition(
lambda: verify_output(ray_get, ["workers", worker_id], ["worker_id", worker_id]) lambda: verify_output(ray_get, ["workers", worker_id], ["worker_id", worker_id])

View file

@ -243,7 +243,7 @@ python_grpc_compile(
proto_library( proto_library(
name = "agent_manager_proto", name = "agent_manager_proto",
srcs = ["agent_manager.proto"], srcs = ["agent_manager.proto"],
deps = [], deps = [":common_proto"],
) )
python_grpc_compile( python_grpc_compile(

View file

@ -17,6 +17,8 @@ option cc_enable_arenas = true;
package ray.rpc; package ray.rpc;
import "src/ray/protobuf/common.proto";
enum AgentRpcStatus { enum AgentRpcStatus {
// OK. // OK.
AGENT_RPC_STATUS_OK = 0; AGENT_RPC_STATUS_OK = 0;
@ -25,9 +27,7 @@ enum AgentRpcStatus {
} }
message RegisterAgentRequest { message RegisterAgentRequest {
int32 agent_id = 1; AgentInfo agent_info = 1;
int32 agent_port = 2;
string agent_ip_address = 3;
} }
message RegisterAgentReply { message RegisterAgentReply {

View file

@ -677,3 +677,17 @@ message NamedActorInfo {
string ray_namespace = 1; string ray_namespace = 1;
string name = 2; string name = 2;
} }
// Info about a agent process.
message AgentInfo {
// The agent id.
int32 id = 1;
// The agent process pid.
int64 pid = 2;
// IP address of the agent process.
string ip_address = 3;
// The GRPC port number of the agent process.
int32 grpc_port = 4;
// The http port number of the agent process.
int32 http_port = 5;
}

View file

@ -247,6 +247,8 @@ message GcsNodeInfo {
// The user-provided identifier or name for this node. // The user-provided identifier or name for this node.
string node_name = 12; string node_name = 12;
// The information of the agent process.
AgentInfo agent_info = 13;
} }
message HeartbeatTableData { message HeartbeatTableData {

View file

@ -29,19 +29,20 @@ namespace raylet {
void AgentManager::HandleRegisterAgent(const rpc::RegisterAgentRequest &request, void AgentManager::HandleRegisterAgent(const rpc::RegisterAgentRequest &request,
rpc::RegisterAgentReply *reply, rpc::RegisterAgentReply *reply,
rpc::SendReplyCallback send_reply_callback) { rpc::SendReplyCallback send_reply_callback) {
reported_agent_ip_address_ = request.agent_ip_address(); reported_agent_info_.CopyFrom(request.agent_info());
reported_agent_port_ = request.agent_port();
reported_agent_id_ = request.agent_id();
// TODO(SongGuyang): We should remove this after we find better port resolution. // TODO(SongGuyang): We should remove this after we find better port resolution.
// Note: `agent_port_` should be 0 if the grpc port of agent is in conflict. // Note: `reported_agent_info_.grpc_port()` should be 0 if the grpc port of agent is in
if (reported_agent_port_ != 0) { // conflict.
if (reported_agent_info_.grpc_port() != 0) {
runtime_env_agent_client_ = runtime_env_agent_client_factory_( runtime_env_agent_client_ = runtime_env_agent_client_factory_(
reported_agent_ip_address_, reported_agent_port_); reported_agent_info_.ip_address(), reported_agent_info_.grpc_port());
RAY_LOG(INFO) << "HandleRegisterAgent, ip: " << reported_agent_ip_address_ RAY_LOG(INFO) << "HandleRegisterAgent, ip: " << reported_agent_info_.ip_address()
<< ", port: " << reported_agent_port_ << ", id: " << reported_agent_id_; << ", port: " << reported_agent_info_.grpc_port()
<< ", id: " << reported_agent_info_.id();
} else { } else {
RAY_LOG(WARNING) << "The GRPC port of the Ray agent is invalid (0), ip: " RAY_LOG(WARNING) << "The GRPC port of the Ray agent is invalid (0), ip: "
<< reported_agent_ip_address_ << ", id: " << reported_agent_id_ << reported_agent_info_.ip_address()
<< ", id: " << reported_agent_info_.id()
<< ". The agent client in the raylet has been disabled."; << ". The agent client in the raylet has been disabled.";
disable_agent_client_ = true; disable_agent_client_ = true;
} }
@ -56,16 +57,19 @@ void AgentManager::StartAgent() {
return; return;
} }
// Create a non-zero random agent_id to pass to the child process // Create a non-zero random agent_id_ to pass to the child process
// We cannot use pid an id because os.getpid() from the python process is not // We cannot use pid an id because os.getpid() from the python process is not
// reliable when using a launcher. // reliable when using a launcher.
// See https://github.com/ray-project/ray/issues/24361 and Python issue // See https://github.com/ray-project/ray/issues/24361 and Python issue
// https://github.com/python/cpython/issues/83086 // https://github.com/python/cpython/issues/83086
int agent_id = 0; agent_id_ = 0;
while (agent_id == 0) { while (agent_id_ == 0) {
agent_id = rand(); agent_id_ = rand();
} }
const std::string agent_id_str = std::to_string(agent_id); // Make sure reported_agent_info_.id() not equal
// `agent_id_` before the agent finished register.
reported_agent_info_.set_id(0);
const std::string agent_id_str = std::to_string(agent_id_);
std::vector<const char *> argv; std::vector<const char *> argv;
for (const std::string &arg : options_.agent_commands) { for (const std::string &arg : options_.agent_commands) {
argv.push_back(arg.c_str()); argv.push_back(arg.c_str());
@ -104,21 +108,22 @@ void AgentManager::StartAgent() {
<< ec.message(); << ec.message();
} }
std::thread monitor_thread([this, child, agent_id]() mutable { std::thread monitor_thread([this, child]() mutable {
SetThreadName("agent.monitor"); SetThreadName("agent.monitor");
RAY_LOG(INFO) << "Monitor agent process with id " << agent_id << ", register timeout " RAY_LOG(INFO) << "Monitor agent process with id " << child.GetId()
<< ", register timeout "
<< RayConfig::instance().agent_register_timeout_ms() << "ms."; << RayConfig::instance().agent_register_timeout_ms() << "ms.";
auto timer = delay_executor_( auto timer = delay_executor_(
[this, child, agent_id]() mutable { [this, child]() mutable {
if (reported_agent_id_ != agent_id) { if (!IsAgentRegistered()) {
if (reported_agent_id_ == 0) { if (reported_agent_info_.id() == 0) {
RAY_LOG(WARNING) << "Agent process expected id " << agent_id RAY_LOG(WARNING) << "Agent process expected id " << agent_id_
<< " timed out before registering. ip " << " timed out before registering. ip "
<< reported_agent_ip_address_ << ", id " << reported_agent_info_.ip_address() << ", id "
<< reported_agent_id_; << reported_agent_info_.id();
} else { } else {
RAY_LOG(WARNING) << "Agent process expected id " << agent_id RAY_LOG(WARNING) << "Agent process expected id " << agent_id_
<< " but got id " << reported_agent_id_ << " but got id " << reported_agent_info_.id()
<< ", this is a fatal error"; << ", this is a fatal error";
} }
child.Kill(); child.Kill();
@ -128,9 +133,9 @@ void AgentManager::StartAgent() {
int exit_code = child.Wait(); int exit_code = child.Wait();
timer->cancel(); timer->cancel();
RAY_LOG(WARNING) << "Agent process with id " << agent_id << " exited, return value " RAY_LOG(WARNING) << "Agent process with id " << agent_id_ << " exited, return value "
<< exit_code << ". ip " << reported_agent_ip_address_ << ". id " << exit_code << ". ip " << reported_agent_info_.ip_address()
<< reported_agent_id_; << ". id " << reported_agent_info_.id();
RAY_LOG(ERROR) RAY_LOG(ERROR)
<< "The raylet exited immediately because the Ray agent failed. " << "The raylet exited immediately because the Ray agent failed. "
"The raylet fate shares with the agent. This can happen because the " "The raylet fate shares with the agent. This can happen because the "
@ -303,5 +308,15 @@ void AgentManager::DeleteRuntimeEnvIfPossible(
}); });
} }
const ray::Status AgentManager::TryToGetAgentInfo(rpc::AgentInfo *agent_info) const {
if (IsAgentRegistered()) {
*agent_info = reported_agent_info_;
return ray::Status::OK();
} else {
std::string err_msg = "The agent has not finished register yet.";
return ray::Status::Invalid(err_msg);
}
}
} // namespace raylet } // namespace raylet
} // namespace ray } // namespace ray

View file

@ -23,6 +23,7 @@
#include "ray/rpc/agent_manager/agent_manager_server.h" #include "ray/rpc/agent_manager/agent_manager_server.h"
#include "ray/rpc/runtime_env/runtime_env_client.h" #include "ray/rpc/runtime_env/runtime_env_client.h"
#include "ray/util/process.h" #include "ray/util/process.h"
#include "src/ray/protobuf/gcs.pb.h"
namespace ray { namespace ray {
namespace raylet { namespace raylet {
@ -88,17 +89,28 @@ class AgentManager : public rpc::AgentManagerServiceHandler {
virtual void DeleteRuntimeEnvIfPossible(const std::string &serialized_runtime_env, virtual void DeleteRuntimeEnvIfPossible(const std::string &serialized_runtime_env,
DeleteRuntimeEnvIfPossibleCallback callback); DeleteRuntimeEnvIfPossibleCallback callback);
/// Try to Get the information about the agent process.
///
/// \param[out] agent_info The information of the agent process.
/// \return Status, if successful will return `ray::Status::OK`,
/// otherwise will return `ray::Status::Invalid`.
const ray::Status TryToGetAgentInfo(rpc::AgentInfo *agent_info) const;
private: private:
void StartAgent(); void StartAgent();
const bool IsAgentRegistered() const { return reported_agent_info_.id() == agent_id_; }
private: private:
Options options_; Options options_;
pid_t reported_agent_id_ = 0; /// we need to make sure `agent_id_` and `reported_agent_info_.id()` are not equal
int reported_agent_port_ = 0; /// until the agent process is finished registering, the initial value of
/// `reported_agent_info_.id()` is 0, so I set the initial value of `agent_id_` is -1
int agent_id_ = -1;
rpc::AgentInfo reported_agent_info_;
/// Whether or not we intend to start the agent. This is false if we /// Whether or not we intend to start the agent. This is false if we
/// are missing Ray Dashboard dependencies, for example. /// are missing Ray Dashboard dependencies, for example.
bool should_start_agent_ = true; bool should_start_agent_ = true;
std::string reported_agent_ip_address_;
DelayExecutorFn delay_executor_; DelayExecutorFn delay_executor_;
RuntimeEnvAgentClientFactoryFn runtime_env_agent_client_factory_; RuntimeEnvAgentClientFactoryFn runtime_env_agent_client_factory_;
std::shared_ptr<rpc::RuntimeEnvAgentClientInterface> runtime_env_agent_client_; std::shared_ptr<rpc::RuntimeEnvAgentClientInterface> runtime_env_agent_client_;

View file

@ -2860,6 +2860,10 @@ void NodeManager::PublishInfeasibleTaskError(const RayTask &task) const {
} }
} }
const ray::Status NodeManager::TryToGetAgentInfo(rpc::AgentInfo *agent_info) const {
return agent_manager_->TryToGetAgentInfo(agent_info);
}
} // namespace raylet } // namespace raylet
} // namespace ray } // namespace ray

View file

@ -237,6 +237,13 @@ class NodeManager : public rpc::NodeManagerServiceHandler,
int64_t limit, int64_t limit,
const std::function<void()> &on_all_replied); const std::function<void()> &on_all_replied);
/// Try to Get the information about the agent process.
///
/// \param[out] agent_info The information of the agent process.
/// \return Status, if successful will return `ray::Status::OK`,
/// otherwise will return `ray::Status::Invalid`.
const ray::Status TryToGetAgentInfo(rpc::AgentInfo *agent_info) const;
private: private:
/// Methods for handling nodes. /// Methods for handling nodes.

View file

@ -97,7 +97,6 @@ Raylet::~Raylet() {}
void Raylet::Start() { void Raylet::Start() {
RAY_CHECK_OK(RegisterGcs()); RAY_CHECK_OK(RegisterGcs());
// Start listening for clients. // Start listening for clients.
DoAccept(); DoAccept();
} }
@ -109,6 +108,21 @@ void Raylet::Stop() {
} }
ray::Status Raylet::RegisterGcs() { ray::Status Raylet::RegisterGcs() {
rpc::AgentInfo agent_info;
auto status = node_manager_.TryToGetAgentInfo(&agent_info);
if (status.ok()) {
self_node_info_.mutable_agent_info()->CopyFrom(agent_info);
} else {
// Because current function and `AgentManager::HandleRegisterAgent`
// will be invoke in same thread, so we need post current function
// into main_service_ after interval milliseconds.
std::this_thread::sleep_for(std::chrono::milliseconds(
RayConfig::instance().raylet_get_agent_info_interval_ms()));
main_service_.post([this]() { RAY_CHECK_OK(RegisterGcs()); },
"Raylet.TryToGetAgentInfoAndRegisterGcs");
return Status::OK();
}
auto register_callback = [this](const Status &status) { auto register_callback = [this](const Status &status) {
RAY_CHECK_OK(status); RAY_CHECK_OK(status);
RAY_LOG(INFO) << "Raylet of id, " << self_node_id_ RAY_LOG(INFO) << "Raylet of id, " << self_node_id_

View file

@ -68,7 +68,7 @@ class Raylet {
NodeID GetNodeId() const { return self_node_id_; } NodeID GetNodeId() const { return self_node_id_; }
private: private:
/// Register GCS client. /// Try to get agent info, after its success, register the current node to GCS.
ray::Status RegisterGcs(); ray::Status RegisterGcs();
/// Accept a client connection. /// Accept a client connection.

View file

@ -503,7 +503,8 @@ class WorkerPoolTest : public ::testing::Test {
false); false);
rpc::RegisterAgentRequest request; rpc::RegisterAgentRequest request;
// Set agent port to a nonzero value to avoid invalid agent client. // Set agent port to a nonzero value to avoid invalid agent client.
request.set_agent_port(12345); request.mutable_agent_info()->set_grpc_port(12345);
request.mutable_agent_info()->set_http_port(54321);
rpc::RegisterAgentReply reply; rpc::RegisterAgentReply reply;
auto send_reply_callback = auto send_reply_callback =
[](ray::Status status, std::function<void()> f1, std::function<void()> f2) {}; [](ray::Status status, std::function<void()> f1, std::function<void()> f2) {};