mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
Revert "Revert "[Job Submission][refactor 1/N] Add AgentInfo to GCSNodeInfo (…" (#27308)
This commit is contained in:
parent
b11d3061d8
commit
ccf411604e
23 changed files with 226 additions and 115 deletions
|
@ -1,7 +1,6 @@
|
||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
import io
|
import io
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import logging.handlers
|
import logging.handlers
|
||||||
import os
|
import os
|
||||||
|
@ -13,11 +12,10 @@ import ray._private.services
|
||||||
import ray._private.utils
|
import ray._private.utils
|
||||||
import ray.dashboard.consts as dashboard_consts
|
import ray.dashboard.consts as dashboard_consts
|
||||||
import ray.dashboard.utils as dashboard_utils
|
import ray.dashboard.utils as dashboard_utils
|
||||||
import ray.experimental.internal_kv as internal_kv
|
|
||||||
from ray._private.gcs_pubsub import GcsAioPublisher, GcsPublisher
|
from ray._private.gcs_pubsub import GcsAioPublisher, GcsPublisher
|
||||||
from ray._private.gcs_utils import GcsAioClient, GcsClient
|
from ray._private.gcs_utils import GcsAioClient, GcsClient
|
||||||
from ray._private.ray_logging import setup_component_logger
|
from ray._private.ray_logging import setup_component_logger
|
||||||
from ray.core.generated import agent_manager_pb2, agent_manager_pb2_grpc
|
from ray.core.generated import agent_manager_pb2, agent_manager_pb2_grpc, common_pb2
|
||||||
from ray.experimental.internal_kv import (
|
from ray.experimental.internal_kv import (
|
||||||
_initialize_internal_kv,
|
_initialize_internal_kv,
|
||||||
_internal_kv_initialized,
|
_internal_kv_initialized,
|
||||||
|
@ -262,22 +260,20 @@ class DashboardAgent:
|
||||||
# TODO: Use async version if performance is an issue
|
# TODO: Use async version if performance is an issue
|
||||||
# -1 should indicate that http server is not started.
|
# -1 should indicate that http server is not started.
|
||||||
http_port = -1 if not self.http_server else self.http_server.http_port
|
http_port = -1 if not self.http_server else self.http_server.http_port
|
||||||
internal_kv._internal_kv_put(
|
|
||||||
f"{dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX}{self.node_id}",
|
|
||||||
json.dumps([http_port, self.grpc_port]),
|
|
||||||
namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Register agent to agent manager.
|
# Register agent to agent manager.
|
||||||
raylet_stub = agent_manager_pb2_grpc.AgentManagerServiceStub(
|
raylet_stub = agent_manager_pb2_grpc.AgentManagerServiceStub(
|
||||||
self.aiogrpc_raylet_channel
|
self.aiogrpc_raylet_channel
|
||||||
)
|
)
|
||||||
|
|
||||||
await raylet_stub.RegisterAgent(
|
await raylet_stub.RegisterAgent(
|
||||||
agent_manager_pb2.RegisterAgentRequest(
|
agent_manager_pb2.RegisterAgentRequest(
|
||||||
agent_id=self.agent_id,
|
agent_info=common_pb2.AgentInfo(
|
||||||
agent_port=self.grpc_port,
|
id=self.agent_id,
|
||||||
agent_ip_address=self.ip,
|
pid=os.getpid(),
|
||||||
|
grpc_port=self.grpc_port,
|
||||||
|
http_port=http_port,
|
||||||
|
ip_address=self.ip,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
from ray._private.ray_constants import env_integer
|
from ray._private.ray_constants import env_integer
|
||||||
|
|
||||||
DASHBOARD_LOG_FILENAME = "dashboard.log"
|
DASHBOARD_LOG_FILENAME = "dashboard.log"
|
||||||
DASHBOARD_AGENT_PORT_PREFIX = "DASHBOARD_AGENT_PORT_PREFIX:"
|
|
||||||
DASHBOARD_AGENT_LOG_FILENAME = "dashboard_agent.log"
|
DASHBOARD_AGENT_LOG_FILENAME = "dashboard_agent.log"
|
||||||
DASHBOARD_AGENT_CHECK_PARENT_INTERVAL_SECONDS = 2
|
DASHBOARD_AGENT_CHECK_PARENT_INTERVAL_SECONDS = 2
|
||||||
RAY_STATE_SERVER_MAX_HTTP_REQUEST_ENV_NAME = "RAY_STATE_SERVER_MAX_HTTP_REQUEST"
|
RAY_STATE_SERVER_MAX_HTTP_REQUEST_ENV_NAME = "RAY_STATE_SERVER_MAX_HTTP_REQUEST"
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
|
@ -7,7 +6,6 @@ import time
|
||||||
import aiohttp.web
|
import aiohttp.web
|
||||||
|
|
||||||
import ray._private.utils
|
import ray._private.utils
|
||||||
import ray.dashboard.consts as dashboard_consts
|
|
||||||
import ray.dashboard.optional_utils as dashboard_optional_utils
|
import ray.dashboard.optional_utils as dashboard_optional_utils
|
||||||
import ray.dashboard.utils as dashboard_utils
|
import ray.dashboard.utils as dashboard_utils
|
||||||
from ray._private import ray_constants
|
from ray._private import ray_constants
|
||||||
|
@ -131,10 +129,9 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
|
||||||
try:
|
try:
|
||||||
nodes = await self._get_nodes()
|
nodes = await self._get_nodes()
|
||||||
|
|
||||||
alive_node_ids = []
|
|
||||||
alive_node_infos = []
|
|
||||||
node_id_to_ip = {}
|
node_id_to_ip = {}
|
||||||
node_id_to_hostname = {}
|
node_id_to_hostname = {}
|
||||||
|
agents = dict(DataSource.agents)
|
||||||
for node in nodes.values():
|
for node in nodes.values():
|
||||||
node_id = node["nodeId"]
|
node_id = node["nodeId"]
|
||||||
ip = node["nodeManagerAddress"]
|
ip = node["nodeManagerAddress"]
|
||||||
|
@ -150,20 +147,10 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
|
||||||
node_id_to_hostname[node_id] = hostname
|
node_id_to_hostname[node_id] = hostname
|
||||||
assert node["state"] in ["ALIVE", "DEAD"]
|
assert node["state"] in ["ALIVE", "DEAD"]
|
||||||
if node["state"] == "ALIVE":
|
if node["state"] == "ALIVE":
|
||||||
alive_node_ids.append(node_id)
|
agents[node_id] = [
|
||||||
alive_node_infos.append(node)
|
node["agentInfo"]["httpPort"],
|
||||||
|
node["agentInfo"]["grpcPort"],
|
||||||
agents = dict(DataSource.agents)
|
]
|
||||||
for node_id in alive_node_ids:
|
|
||||||
key = f"{dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX}" f"{node_id}"
|
|
||||||
# TODO: Use async version if performance is an issue
|
|
||||||
agent_port = ray.experimental.internal_kv._internal_kv_get(
|
|
||||||
key, namespace=ray_constants.KV_NAMESPACE_DASHBOARD
|
|
||||||
)
|
|
||||||
if agent_port:
|
|
||||||
agents[node_id] = json.loads(agent_port)
|
|
||||||
for node_id in agents.keys() - set(alive_node_ids):
|
|
||||||
agents.pop(node_id, None)
|
|
||||||
|
|
||||||
DataSource.node_id_to_ip.reset(node_id_to_ip)
|
DataSource.node_id_to_ip.reset(node_id_to_ip)
|
||||||
DataSource.node_id_to_hostname.reset(node_id_to_hostname)
|
DataSource.node_id_to_hostname.reset(node_id_to_hostname)
|
||||||
|
|
|
@ -6,6 +6,7 @@ import time
|
||||||
import traceback
|
import traceback
|
||||||
import random
|
import random
|
||||||
import pytest
|
import pytest
|
||||||
|
import psutil
|
||||||
import ray
|
import ray
|
||||||
import threading
|
import threading
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
@ -18,6 +19,7 @@ from ray._private.test_utils import (
|
||||||
wait_for_condition,
|
wait_for_condition,
|
||||||
wait_until_succeeded_without_exception,
|
wait_until_succeeded_without_exception,
|
||||||
)
|
)
|
||||||
|
from ray._private.state import state
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -348,5 +350,33 @@ def test_frequent_node_update(
|
||||||
wait_for_condition(verify, timeout=15)
|
wait_for_condition(verify, timeout=15)
|
||||||
|
|
||||||
|
|
||||||
|
# See detail: https://github.com/ray-project/ray/issues/24361
|
||||||
|
@pytest.mark.skipif(sys.platform == "win32", reason="Flaky on Windows.")
|
||||||
|
def test_node_register_with_agent(ray_start_cluster_head):
|
||||||
|
def test_agent_port(pid, port):
|
||||||
|
p = psutil.Process(pid)
|
||||||
|
assert p.cmdline()[2].endswith("dashboard/agent.py")
|
||||||
|
|
||||||
|
for c in p.connections():
|
||||||
|
if c.status == psutil.CONN_LISTEN and c.laddr.port == port:
|
||||||
|
return
|
||||||
|
assert False
|
||||||
|
|
||||||
|
def test_agent_process(pid):
|
||||||
|
p = psutil.Process(pid)
|
||||||
|
assert p.cmdline()[2].endswith("dashboard/agent.py")
|
||||||
|
|
||||||
|
for node_info in state.node_table():
|
||||||
|
agent_info = node_info["AgentInfo"]
|
||||||
|
assert agent_info["IpAddress"] == node_info["NodeManagerAddress"]
|
||||||
|
test_agent_port(agent_info["Pid"], agent_info["GrpcPort"])
|
||||||
|
if agent_info["HttpPort"] >= 0:
|
||||||
|
test_agent_port(agent_info["Pid"], agent_info["HttpPort"])
|
||||||
|
else:
|
||||||
|
# A port conflict may occur if the previous
|
||||||
|
# test did not kill the agent cleanly.
|
||||||
|
assert agent_info["HttpPort"] == -1
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
sys.exit(pytest.main(["-v", __file__]))
|
sys.exit(pytest.main(["-v", __file__]))
|
||||||
|
|
|
@ -110,7 +110,6 @@ def test_basic(ray_start_with_dashboard):
|
||||||
"""Dashboard test that starts a Ray cluster with a dashboard server running,
|
"""Dashboard test that starts a Ray cluster with a dashboard server running,
|
||||||
then hits the dashboard API and asserts that it receives sensible data."""
|
then hits the dashboard API and asserts that it receives sensible data."""
|
||||||
address_info = ray_start_with_dashboard
|
address_info = ray_start_with_dashboard
|
||||||
node_id = address_info["node_id"]
|
|
||||||
gcs_client = make_gcs_client(address_info)
|
gcs_client = make_gcs_client(address_info)
|
||||||
ray.experimental.internal_kv._initialize_internal_kv(gcs_client)
|
ray.experimental.internal_kv._initialize_internal_kv(gcs_client)
|
||||||
|
|
||||||
|
@ -146,11 +145,6 @@ def test_basic(ray_start_with_dashboard):
|
||||||
namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
|
namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
|
||||||
)
|
)
|
||||||
assert dashboard_rpc_address is not None
|
assert dashboard_rpc_address is not None
|
||||||
key = f"{dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX}{node_id}"
|
|
||||||
agent_ports = ray.experimental.internal_kv._internal_kv_get(
|
|
||||||
key, namespace=ray_constants.KV_NAMESPACE_DASHBOARD
|
|
||||||
)
|
|
||||||
assert agent_ports is not None
|
|
||||||
|
|
||||||
|
|
||||||
def test_raylet_and_agent_share_fate(shutdown_only):
|
def test_raylet_and_agent_share_fate(shutdown_only):
|
||||||
|
@ -795,7 +789,6 @@ def test_dashboard_port_conflict(ray_start_with_dashboard):
|
||||||
)
|
)
|
||||||
def test_gcs_check_alive(fast_gcs_failure_detection, ray_start_with_dashboard):
|
def test_gcs_check_alive(fast_gcs_failure_detection, ray_start_with_dashboard):
|
||||||
assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
|
assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
|
||||||
|
|
||||||
all_processes = ray._private.worker._global_node.all_processes
|
all_processes = ray._private.worker._global_node.all_processes
|
||||||
dashboard_info = all_processes[ray_constants.PROCESS_TYPE_DASHBOARD][0]
|
dashboard_info = all_processes[ray_constants.PROCESS_TYPE_DASHBOARD][0]
|
||||||
dashboard_proc = psutil.Process(dashboard_info.process.pid)
|
dashboard_proc = psutil.Process(dashboard_info.process.pid)
|
||||||
|
|
|
@ -164,6 +164,12 @@ class GlobalState:
|
||||||
"RayletSocketName": item.raylet_socket_name,
|
"RayletSocketName": item.raylet_socket_name,
|
||||||
"MetricsExportPort": item.metrics_export_port,
|
"MetricsExportPort": item.metrics_export_port,
|
||||||
"NodeName": item.node_name,
|
"NodeName": item.node_name,
|
||||||
|
"AgentInfo": {
|
||||||
|
"IpAddress": item.agent_info.ip_address,
|
||||||
|
"GrpcPort": item.agent_info.grpc_port,
|
||||||
|
"HttpPort": item.agent_info.http_port,
|
||||||
|
"Pid": item.agent_info.pid,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
node_info["alive"] = node_info["Alive"]
|
node_info["alive"] = node_info["Alive"]
|
||||||
node_info["Resources"] = (
|
node_info["Resources"] = (
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import atexit
|
import atexit
|
||||||
import faulthandler
|
import faulthandler
|
||||||
import functools
|
import functools
|
||||||
|
import grpc
|
||||||
import hashlib
|
import hashlib
|
||||||
import inspect
|
import inspect
|
||||||
import io
|
import io
|
||||||
|
@ -756,7 +757,6 @@ class Worker:
|
||||||
|
|
||||||
def print_logs(self):
|
def print_logs(self):
|
||||||
"""Prints log messages from workers on all nodes in the same job."""
|
"""Prints log messages from workers on all nodes in the same job."""
|
||||||
import grpc
|
|
||||||
|
|
||||||
subscriber = self.gcs_log_subscriber
|
subscriber = self.gcs_log_subscriber
|
||||||
subscriber.subscribe()
|
subscriber.subscribe()
|
||||||
|
@ -1902,6 +1902,11 @@ def connect(
|
||||||
if mode == SCRIPT_MODE:
|
if mode == SCRIPT_MODE:
|
||||||
raise e
|
raise e
|
||||||
elif mode == WORKER_MODE:
|
elif mode == WORKER_MODE:
|
||||||
|
if isinstance(e, grpc.RpcError) and e.code() in (
|
||||||
|
grpc.StatusCode.UNAVAILABLE,
|
||||||
|
grpc.StatusCode.UNKNOWN,
|
||||||
|
):
|
||||||
|
raise e
|
||||||
traceback_str = traceback.format_exc()
|
traceback_str = traceback.format_exc()
|
||||||
ray._private.utils.publish_error_to_driver(
|
ray._private.utils.publish_error_to_driver(
|
||||||
ray_constants.VERSION_MISMATCH_PUSH_ERROR,
|
ray_constants.VERSION_MISMATCH_PUSH_ERROR,
|
||||||
|
|
|
@ -68,7 +68,8 @@ py_test_module_list(
|
||||||
"test_healthcheck.py",
|
"test_healthcheck.py",
|
||||||
"test_kill_raylet_signal_log.py",
|
"test_kill_raylet_signal_log.py",
|
||||||
"test_memstat.py",
|
"test_memstat.py",
|
||||||
"test_protobuf_compatibility.py"
|
"test_protobuf_compatibility.py",
|
||||||
|
"test_scheduling_performance.py"
|
||||||
],
|
],
|
||||||
size = "medium",
|
size = "medium",
|
||||||
tags = ["exclusive", "medium_size_python_tests_a_to_j", "team:core"],
|
tags = ["exclusive", "medium_size_python_tests_a_to_j", "team:core"],
|
||||||
|
@ -120,10 +121,8 @@ py_test_module_list(
|
||||||
"test_multi_node_2.py",
|
"test_multi_node_2.py",
|
||||||
"test_multinode_failures.py",
|
"test_multinode_failures.py",
|
||||||
"test_multinode_failures_2.py",
|
"test_multinode_failures_2.py",
|
||||||
"test_multiprocessing.py",
|
|
||||||
"test_object_assign_owner.py",
|
"test_object_assign_owner.py",
|
||||||
"test_placement_group.py",
|
"test_placement_group.py",
|
||||||
"test_placement_group_2.py",
|
|
||||||
"test_placement_group_3.py",
|
"test_placement_group_3.py",
|
||||||
"test_placement_group_4.py",
|
"test_placement_group_4.py",
|
||||||
"test_placement_group_5.py",
|
"test_placement_group_5.py",
|
||||||
|
@ -184,7 +183,6 @@ py_test_module_list(
|
||||||
"test_cross_language.py",
|
"test_cross_language.py",
|
||||||
"test_environ.py",
|
"test_environ.py",
|
||||||
"test_raylet_output.py",
|
"test_raylet_output.py",
|
||||||
"test_scheduling_performance.py",
|
|
||||||
"test_get_or_create_actor.py",
|
"test_get_or_create_actor.py",
|
||||||
],
|
],
|
||||||
size = "small",
|
size = "small",
|
||||||
|
@ -295,6 +293,7 @@ py_test_module_list(
|
||||||
"test_placement_group_mini_integration.py",
|
"test_placement_group_mini_integration.py",
|
||||||
"test_scheduling_2.py",
|
"test_scheduling_2.py",
|
||||||
"test_multi_node_3.py",
|
"test_multi_node_3.py",
|
||||||
|
"test_placement_group_2.py",
|
||||||
],
|
],
|
||||||
size = "large",
|
size = "large",
|
||||||
tags = ["exclusive", "large_size_python_tests_shard_1", "team:core"],
|
tags = ["exclusive", "large_size_python_tests_shard_1", "team:core"],
|
||||||
|
|
|
@ -15,6 +15,7 @@ from pathlib import Path
|
||||||
from tempfile import gettempdir
|
from tempfile import gettempdir
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
import signal
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
@ -204,10 +205,19 @@ def _ray_start(**kwargs):
|
||||||
init_kwargs.update(kwargs)
|
init_kwargs.update(kwargs)
|
||||||
# Start the Ray processes.
|
# Start the Ray processes.
|
||||||
address_info = ray.init("local", **init_kwargs)
|
address_info = ray.init("local", **init_kwargs)
|
||||||
|
agent_pids = []
|
||||||
|
for node in ray.nodes():
|
||||||
|
agent_pids.append(int(node["AgentInfo"]["Pid"]))
|
||||||
|
|
||||||
yield address_info
|
yield address_info
|
||||||
# The code after the yield will run as teardown code.
|
# The code after the yield will run as teardown code.
|
||||||
ray.shutdown()
|
ray.shutdown()
|
||||||
|
# Make sure the agent process is dead.
|
||||||
|
for pid in agent_pids:
|
||||||
|
try:
|
||||||
|
os.kill(pid, signal.SIGKILL)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
# Delete the cluster address just in case.
|
# Delete the cluster address just in case.
|
||||||
ray._private.utils.reset_ray_address()
|
ray._private.utils.reset_ray_address()
|
||||||
|
|
||||||
|
|
|
@ -834,8 +834,9 @@ def test_ray_status(shutdown_only, monkeypatch):
|
||||||
|
|
||||||
@pytest.mark.xfail(cluster_not_supported, reason="cluster not supported on Windows")
|
@pytest.mark.xfail(cluster_not_supported, reason="cluster not supported on Windows")
|
||||||
def test_ray_status_multinode(ray_start_cluster):
|
def test_ray_status_multinode(ray_start_cluster):
|
||||||
|
NODE_NUMBER = 4
|
||||||
cluster = ray_start_cluster
|
cluster = ray_start_cluster
|
||||||
for _ in range(4):
|
for _ in range(NODE_NUMBER):
|
||||||
cluster.add_node(num_cpus=2)
|
cluster.add_node(num_cpus=2)
|
||||||
runner = CliRunner()
|
runner = CliRunner()
|
||||||
|
|
||||||
|
@ -850,8 +851,12 @@ def test_ray_status_multinode(ray_start_cluster):
|
||||||
|
|
||||||
wait_for_condition(output_ready)
|
wait_for_condition(output_ready)
|
||||||
|
|
||||||
result = runner.invoke(scripts.status, [])
|
def check_result():
|
||||||
_check_output_via_pattern("test_ray_status_multinode.txt", result)
|
result = runner.invoke(scripts.status, [])
|
||||||
|
_check_output_via_pattern("test_ray_status_multinode.txt", result)
|
||||||
|
return True
|
||||||
|
|
||||||
|
wait_for_condition(check_result)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
|
|
|
@ -11,17 +11,29 @@ from ray._private.test_utils import wait_for_condition, run_string_as_driver_non
|
||||||
|
|
||||||
|
|
||||||
def get_all_ray_worker_processes():
|
def get_all_ray_worker_processes():
|
||||||
processes = [
|
processes = psutil.process_iter(attrs=["pid", "name", "cmdline"])
|
||||||
p.info["cmdline"] for p in psutil.process_iter(attrs=["pid", "name", "cmdline"])
|
|
||||||
]
|
|
||||||
|
|
||||||
result = []
|
result = []
|
||||||
for p in processes:
|
for p in processes:
|
||||||
if p is not None and len(p) > 0 and "ray::" in p[0]:
|
cmd_line = p.info["cmdline"]
|
||||||
|
if cmd_line is not None and len(cmd_line) > 0 and "ray::" in cmd_line[0]:
|
||||||
result.append(p)
|
result.append(p)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def kill_all_ray_worker_process():
|
||||||
|
# Kill any worker processes left behind by a previous test,
|
||||||
|
# since they would make the current test fail.
|
||||||
|
ray_process = get_all_ray_worker_processes()
|
||||||
|
for p in ray_process:
|
||||||
|
try:
|
||||||
|
p.kill()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def short_gcs_publish_timeout(monkeypatch):
|
def short_gcs_publish_timeout(monkeypatch):
|
||||||
monkeypatch.setenv("RAY_MAX_GCS_PUBLISH_RETRIES", "3")
|
monkeypatch.setenv("RAY_MAX_GCS_PUBLISH_RETRIES", "3")
|
||||||
|
@ -29,8 +41,11 @@ def short_gcs_publish_timeout(monkeypatch):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(platform.system() == "Windows", reason="Hang on Windows.")
|
@pytest.mark.skipif(platform.system() == "Windows", reason="Hang on Windows.")
|
||||||
def test_ray_shutdown(short_gcs_publish_timeout, shutdown_only):
|
def test_ray_shutdown(
|
||||||
|
kill_all_ray_worker_process, short_gcs_publish_timeout, shutdown_only
|
||||||
|
):
|
||||||
"""Make sure all ray workers are shutdown when driver is done."""
|
"""Make sure all ray workers are shutdown when driver is done."""
|
||||||
|
|
||||||
ray.init()
|
ray.init()
|
||||||
|
|
||||||
@ray.remote
|
@ray.remote
|
||||||
|
@ -45,12 +60,15 @@ def test_ray_shutdown(short_gcs_publish_timeout, shutdown_only):
|
||||||
|
|
||||||
ray.shutdown()
|
ray.shutdown()
|
||||||
|
|
||||||
wait_for_condition(lambda: len(get_all_ray_worker_processes()) == 0)
|
wait_for_condition(lambda: len(get_all_ray_worker_processes()) == 0, timeout=20)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(platform.system() == "Windows", reason="Hang on Windows.")
|
@pytest.mark.skipif(platform.system() == "Windows", reason="Hang on Windows.")
|
||||||
def test_driver_dead(short_gcs_publish_timeout, shutdown_only):
|
def test_driver_dead(
|
||||||
|
kill_all_ray_worker_process, short_gcs_publish_timeout, shutdown_only
|
||||||
|
):
|
||||||
"""Make sure all ray workers are shutdown when driver is killed."""
|
"""Make sure all ray workers are shutdown when driver is killed."""
|
||||||
|
|
||||||
driver = """
|
driver = """
|
||||||
import ray
|
import ray
|
||||||
ray.init(_system_config={"gcs_rpc_server_reconnect_timeout_s": 1})
|
ray.init(_system_config={"gcs_rpc_server_reconnect_timeout_s": 1})
|
||||||
|
@ -74,12 +92,15 @@ tasks = [f.remote() for _ in range(num_cpus)]
|
||||||
p.wait()
|
p.wait()
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
|
|
||||||
wait_for_condition(lambda: len(get_all_ray_worker_processes()) == 0)
|
wait_for_condition(lambda: len(get_all_ray_worker_processes()) == 0, timeout=20)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(platform.system() == "Windows", reason="Hang on Windows.")
|
@pytest.mark.skipif(platform.system() == "Windows", reason="Hang on Windows.")
|
||||||
def test_node_killed(short_gcs_publish_timeout, ray_start_cluster):
|
def test_node_killed(
|
||||||
|
kill_all_ray_worker_process, short_gcs_publish_timeout, ray_start_cluster
|
||||||
|
):
|
||||||
"""Make sure all ray workers are shut down when nodes are dead."""
|
"""Make sure all ray workers are shut down when nodes are dead."""
|
||||||
|
|
||||||
cluster = ray_start_cluster
|
cluster = ray_start_cluster
|
||||||
# head node.
|
# head node.
|
||||||
cluster.add_node(
|
cluster.add_node(
|
||||||
|
@ -106,12 +127,15 @@ def test_node_killed(short_gcs_publish_timeout, ray_start_cluster):
|
||||||
for worker in workers:
|
for worker in workers:
|
||||||
cluster.remove_node(worker)
|
cluster.remove_node(worker)
|
||||||
|
|
||||||
wait_for_condition(lambda: len(get_all_ray_worker_processes()) == 0)
|
wait_for_condition(lambda: len(get_all_ray_worker_processes()) == 0, timeout=20)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(platform.system() == "Windows", reason="Hang on Windows.")
|
@pytest.mark.skipif(platform.system() == "Windows", reason="Hang on Windows.")
|
||||||
def test_head_node_down(short_gcs_publish_timeout, ray_start_cluster):
|
def test_head_node_down(
|
||||||
|
kill_all_ray_worker_process, short_gcs_publish_timeout, ray_start_cluster
|
||||||
|
):
|
||||||
"""Make sure all ray workers are shut down when head node is dead."""
|
"""Make sure all ray workers are shut down when head node is dead."""
|
||||||
|
|
||||||
cluster = ray_start_cluster
|
cluster = ray_start_cluster
|
||||||
# head node.
|
# head node.
|
||||||
head = cluster.add_node(
|
head = cluster.add_node(
|
||||||
|
@ -149,7 +173,7 @@ time.sleep(100)
|
||||||
|
|
||||||
cluster.remove_node(head)
|
cluster.remove_node(head)
|
||||||
|
|
||||||
wait_for_condition(lambda: len(get_all_ray_worker_processes()) == 0)
|
wait_for_condition(lambda: len(get_all_ray_worker_processes()) == 0, timeout=20)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -10,7 +10,6 @@ import yaml
|
||||||
from click.testing import CliRunner
|
from click.testing import CliRunner
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
import ray.dashboard.consts as dashboard_consts
|
|
||||||
import ray._private.state as global_state
|
import ray._private.state as global_state
|
||||||
import ray._private.ray_constants as ray_constants
|
import ray._private.ray_constants as ray_constants
|
||||||
from ray._private.test_utils import (
|
from ray._private.test_utils import (
|
||||||
|
@ -1180,16 +1179,8 @@ async def test_state_data_source_client(ray_start_cluster):
|
||||||
wait_for_condition(lambda: len(ray.nodes()) == 2)
|
wait_for_condition(lambda: len(ray.nodes()) == 2)
|
||||||
for node in ray.nodes():
|
for node in ray.nodes():
|
||||||
node_id = node["NodeID"]
|
node_id = node["NodeID"]
|
||||||
key = f"{dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX}{node_id}"
|
|
||||||
|
|
||||||
def get_port():
|
|
||||||
return ray.experimental.internal_kv._internal_kv_get(
|
|
||||||
key, namespace=ray_constants.KV_NAMESPACE_DASHBOARD
|
|
||||||
)
|
|
||||||
|
|
||||||
wait_for_condition(lambda: get_port() is not None)
|
|
||||||
# The second index is the gRPC port
|
# The second index is the gRPC port
|
||||||
port = json.loads(get_port())[1]
|
port = node["AgentInfo"]["GrpcPort"]
|
||||||
ip = node["NodeManagerAddress"]
|
ip = node["NodeManagerAddress"]
|
||||||
client.register_agent_client(node_id, ip, port)
|
client.register_agent_client(node_id, ip, port)
|
||||||
result = await client.get_runtime_envs_info(node_id)
|
result = await client.get_runtime_envs_info(node_id)
|
||||||
|
@ -1391,16 +1382,8 @@ async def test_state_data_source_client_limit_distributed_sources(ray_start_clus
|
||||||
"""
|
"""
|
||||||
for node in ray.nodes():
|
for node in ray.nodes():
|
||||||
node_id = node["NodeID"]
|
node_id = node["NodeID"]
|
||||||
key = f"{dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX}{node_id}"
|
|
||||||
|
|
||||||
def get_port():
|
|
||||||
return ray.experimental.internal_kv._internal_kv_get(
|
|
||||||
key, namespace=ray_constants.KV_NAMESPACE_DASHBOARD
|
|
||||||
)
|
|
||||||
|
|
||||||
wait_for_condition(lambda: get_port() is not None)
|
|
||||||
# The second index is the gRPC port
|
# The second index is the gRPC port
|
||||||
port = json.loads(get_port())[1]
|
port = node["AgentInfo"]["GrpcPort"]
|
||||||
ip = node["NodeManagerAddress"]
|
ip = node["NodeManagerAddress"]
|
||||||
client.register_agent_client(node_id, ip, port)
|
client.register_agent_client(node_id, ip, port)
|
||||||
|
|
||||||
|
@ -1513,8 +1496,13 @@ def test_cli_apis_sanity_check(ray_start_cluster):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
# Test get workers by id
|
# Test get workers by id
|
||||||
|
|
||||||
|
# Still need a `wait_for_condition`,
|
||||||
|
# because the worker obtained through the api server will not filter the driver,
|
||||||
|
# but `global_state.workers` will filter the driver.
|
||||||
|
wait_for_condition(lambda: len(global_state.workers()) > 0)
|
||||||
workers = global_state.workers()
|
workers = global_state.workers()
|
||||||
assert len(workers) > 0
|
|
||||||
worker_id = list(workers.keys())[0]
|
worker_id = list(workers.keys())[0]
|
||||||
wait_for_condition(
|
wait_for_condition(
|
||||||
lambda: verify_output(ray_get, ["workers", worker_id], ["worker_id", worker_id])
|
lambda: verify_output(ray_get, ["workers", worker_id], ["worker_id", worker_id])
|
||||||
|
|
|
@ -243,7 +243,7 @@ python_grpc_compile(
|
||||||
proto_library(
|
proto_library(
|
||||||
name = "agent_manager_proto",
|
name = "agent_manager_proto",
|
||||||
srcs = ["agent_manager.proto"],
|
srcs = ["agent_manager.proto"],
|
||||||
deps = [],
|
deps = [":common_proto"],
|
||||||
)
|
)
|
||||||
|
|
||||||
python_grpc_compile(
|
python_grpc_compile(
|
||||||
|
|
|
@ -17,6 +17,8 @@ option cc_enable_arenas = true;
|
||||||
|
|
||||||
package ray.rpc;
|
package ray.rpc;
|
||||||
|
|
||||||
|
import "src/ray/protobuf/common.proto";
|
||||||
|
|
||||||
enum AgentRpcStatus {
|
enum AgentRpcStatus {
|
||||||
// OK.
|
// OK.
|
||||||
AGENT_RPC_STATUS_OK = 0;
|
AGENT_RPC_STATUS_OK = 0;
|
||||||
|
@ -25,9 +27,7 @@ enum AgentRpcStatus {
|
||||||
}
|
}
|
||||||
|
|
||||||
message RegisterAgentRequest {
|
message RegisterAgentRequest {
|
||||||
int32 agent_id = 1;
|
AgentInfo agent_info = 1;
|
||||||
int32 agent_port = 2;
|
|
||||||
string agent_ip_address = 3;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
message RegisterAgentReply {
|
message RegisterAgentReply {
|
||||||
|
|
|
@ -677,3 +677,17 @@ message NamedActorInfo {
|
||||||
string ray_namespace = 1;
|
string ray_namespace = 1;
|
||||||
string name = 2;
|
string name = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Info about an agent process.
|
||||||
|
message AgentInfo {
|
||||||
|
// The agent id.
|
||||||
|
int32 id = 1;
|
||||||
|
// The agent process pid.
|
||||||
|
int64 pid = 2;
|
||||||
|
// IP address of the agent process.
|
||||||
|
string ip_address = 3;
|
||||||
|
// The GRPC port number of the agent process.
|
||||||
|
int32 grpc_port = 4;
|
||||||
|
// The http port number of the agent process.
|
||||||
|
int32 http_port = 5;
|
||||||
|
}
|
||||||
|
|
|
@ -247,6 +247,8 @@ message GcsNodeInfo {
|
||||||
|
|
||||||
// The user-provided identifier or name for this node.
|
// The user-provided identifier or name for this node.
|
||||||
string node_name = 12;
|
string node_name = 12;
|
||||||
|
// The information of the agent process.
|
||||||
|
AgentInfo agent_info = 13;
|
||||||
}
|
}
|
||||||
|
|
||||||
message HeartbeatTableData {
|
message HeartbeatTableData {
|
||||||
|
|
|
@ -29,19 +29,20 @@ namespace raylet {
|
||||||
void AgentManager::HandleRegisterAgent(const rpc::RegisterAgentRequest &request,
|
void AgentManager::HandleRegisterAgent(const rpc::RegisterAgentRequest &request,
|
||||||
rpc::RegisterAgentReply *reply,
|
rpc::RegisterAgentReply *reply,
|
||||||
rpc::SendReplyCallback send_reply_callback) {
|
rpc::SendReplyCallback send_reply_callback) {
|
||||||
reported_agent_ip_address_ = request.agent_ip_address();
|
reported_agent_info_.CopyFrom(request.agent_info());
|
||||||
reported_agent_port_ = request.agent_port();
|
|
||||||
reported_agent_id_ = request.agent_id();
|
|
||||||
// TODO(SongGuyang): We should remove this after we find better port resolution.
|
// TODO(SongGuyang): We should remove this after we find better port resolution.
|
||||||
// Note: `agent_port_` should be 0 if the grpc port of agent is in conflict.
|
// Note: `reported_agent_info_.grpc_port()` should be 0 if the grpc port of agent is in
|
||||||
if (reported_agent_port_ != 0) {
|
// conflict.
|
||||||
|
if (reported_agent_info_.grpc_port() != 0) {
|
||||||
runtime_env_agent_client_ = runtime_env_agent_client_factory_(
|
runtime_env_agent_client_ = runtime_env_agent_client_factory_(
|
||||||
reported_agent_ip_address_, reported_agent_port_);
|
reported_agent_info_.ip_address(), reported_agent_info_.grpc_port());
|
||||||
RAY_LOG(INFO) << "HandleRegisterAgent, ip: " << reported_agent_ip_address_
|
RAY_LOG(INFO) << "HandleRegisterAgent, ip: " << reported_agent_info_.ip_address()
|
||||||
<< ", port: " << reported_agent_port_ << ", id: " << reported_agent_id_;
|
<< ", port: " << reported_agent_info_.grpc_port()
|
||||||
|
<< ", id: " << reported_agent_info_.id();
|
||||||
} else {
|
} else {
|
||||||
RAY_LOG(WARNING) << "The GRPC port of the Ray agent is invalid (0), ip: "
|
RAY_LOG(WARNING) << "The GRPC port of the Ray agent is invalid (0), ip: "
|
||||||
<< reported_agent_ip_address_ << ", id: " << reported_agent_id_
|
<< reported_agent_info_.ip_address()
|
||||||
|
<< ", id: " << reported_agent_info_.id()
|
||||||
<< ". The agent client in the raylet has been disabled.";
|
<< ". The agent client in the raylet has been disabled.";
|
||||||
disable_agent_client_ = true;
|
disable_agent_client_ = true;
|
||||||
}
|
}
|
||||||
|
@ -56,16 +57,19 @@ void AgentManager::StartAgent() {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a non-zero random agent_id to pass to the child process
|
// Create a non-zero random agent_id_ to pass to the child process
|
||||||
// We cannot use pid an id because os.getpid() from the python process is not
|
// We cannot use pid an id because os.getpid() from the python process is not
|
||||||
// reliable when using a launcher.
|
// reliable when using a launcher.
|
||||||
// See https://github.com/ray-project/ray/issues/24361 and Python issue
|
// See https://github.com/ray-project/ray/issues/24361 and Python issue
|
||||||
// https://github.com/python/cpython/issues/83086
|
// https://github.com/python/cpython/issues/83086
|
||||||
int agent_id = 0;
|
agent_id_ = 0;
|
||||||
while (agent_id == 0) {
|
while (agent_id_ == 0) {
|
||||||
agent_id = rand();
|
agent_id_ = rand();
|
||||||
}
|
}
|
||||||
const std::string agent_id_str = std::to_string(agent_id);
|
// Make sure reported_agent_info_.id() not equal
|
||||||
|
// `agent_id_` before the agent finished register.
|
||||||
|
reported_agent_info_.set_id(0);
|
||||||
|
const std::string agent_id_str = std::to_string(agent_id_);
|
||||||
std::vector<const char *> argv;
|
std::vector<const char *> argv;
|
||||||
for (const std::string &arg : options_.agent_commands) {
|
for (const std::string &arg : options_.agent_commands) {
|
||||||
argv.push_back(arg.c_str());
|
argv.push_back(arg.c_str());
|
||||||
|
@ -104,21 +108,22 @@ void AgentManager::StartAgent() {
|
||||||
<< ec.message();
|
<< ec.message();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::thread monitor_thread([this, child, agent_id]() mutable {
|
std::thread monitor_thread([this, child]() mutable {
|
||||||
SetThreadName("agent.monitor");
|
SetThreadName("agent.monitor");
|
||||||
RAY_LOG(INFO) << "Monitor agent process with id " << agent_id << ", register timeout "
|
RAY_LOG(INFO) << "Monitor agent process with id " << child.GetId()
|
||||||
|
<< ", register timeout "
|
||||||
<< RayConfig::instance().agent_register_timeout_ms() << "ms.";
|
<< RayConfig::instance().agent_register_timeout_ms() << "ms.";
|
||||||
auto timer = delay_executor_(
|
auto timer = delay_executor_(
|
||||||
[this, child, agent_id]() mutable {
|
[this, child]() mutable {
|
||||||
if (reported_agent_id_ != agent_id) {
|
if (!IsAgentRegistered()) {
|
||||||
if (reported_agent_id_ == 0) {
|
if (reported_agent_info_.id() == 0) {
|
||||||
RAY_LOG(WARNING) << "Agent process expected id " << agent_id
|
RAY_LOG(WARNING) << "Agent process expected id " << agent_id_
|
||||||
<< " timed out before registering. ip "
|
<< " timed out before registering. ip "
|
||||||
<< reported_agent_ip_address_ << ", id "
|
<< reported_agent_info_.ip_address() << ", id "
|
||||||
<< reported_agent_id_;
|
<< reported_agent_info_.id();
|
||||||
} else {
|
} else {
|
||||||
RAY_LOG(WARNING) << "Agent process expected id " << agent_id
|
RAY_LOG(WARNING) << "Agent process expected id " << agent_id_
|
||||||
<< " but got id " << reported_agent_id_
|
<< " but got id " << reported_agent_info_.id()
|
||||||
<< ", this is a fatal error";
|
<< ", this is a fatal error";
|
||||||
}
|
}
|
||||||
child.Kill();
|
child.Kill();
|
||||||
|
@ -128,9 +133,9 @@ void AgentManager::StartAgent() {
|
||||||
|
|
||||||
int exit_code = child.Wait();
|
int exit_code = child.Wait();
|
||||||
timer->cancel();
|
timer->cancel();
|
||||||
RAY_LOG(WARNING) << "Agent process with id " << agent_id << " exited, return value "
|
RAY_LOG(WARNING) << "Agent process with id " << agent_id_ << " exited, return value "
|
||||||
<< exit_code << ". ip " << reported_agent_ip_address_ << ". id "
|
<< exit_code << ". ip " << reported_agent_info_.ip_address()
|
||||||
<< reported_agent_id_;
|
<< ". id " << reported_agent_info_.id();
|
||||||
RAY_LOG(ERROR)
|
RAY_LOG(ERROR)
|
||||||
<< "The raylet exited immediately because the Ray agent failed. "
|
<< "The raylet exited immediately because the Ray agent failed. "
|
||||||
"The raylet fate shares with the agent. This can happen because the "
|
"The raylet fate shares with the agent. This can happen because the "
|
||||||
|
@ -303,5 +308,15 @@ void AgentManager::DeleteRuntimeEnvIfPossible(
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const ray::Status AgentManager::TryToGetAgentInfo(rpc::AgentInfo *agent_info) const {
|
||||||
|
if (IsAgentRegistered()) {
|
||||||
|
*agent_info = reported_agent_info_;
|
||||||
|
return ray::Status::OK();
|
||||||
|
} else {
|
||||||
|
std::string err_msg = "The agent has not finished register yet.";
|
||||||
|
return ray::Status::Invalid(err_msg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace raylet
|
} // namespace raylet
|
||||||
} // namespace ray
|
} // namespace ray
|
||||||
|
|
|
@ -23,6 +23,7 @@
|
||||||
#include "ray/rpc/agent_manager/agent_manager_server.h"
|
#include "ray/rpc/agent_manager/agent_manager_server.h"
|
||||||
#include "ray/rpc/runtime_env/runtime_env_client.h"
|
#include "ray/rpc/runtime_env/runtime_env_client.h"
|
||||||
#include "ray/util/process.h"
|
#include "ray/util/process.h"
|
||||||
|
#include "src/ray/protobuf/gcs.pb.h"
|
||||||
|
|
||||||
namespace ray {
|
namespace ray {
|
||||||
namespace raylet {
|
namespace raylet {
|
||||||
|
@ -88,17 +89,28 @@ class AgentManager : public rpc::AgentManagerServiceHandler {
|
||||||
virtual void DeleteRuntimeEnvIfPossible(const std::string &serialized_runtime_env,
|
virtual void DeleteRuntimeEnvIfPossible(const std::string &serialized_runtime_env,
|
||||||
DeleteRuntimeEnvIfPossibleCallback callback);
|
DeleteRuntimeEnvIfPossibleCallback callback);
|
||||||
|
|
||||||
|
/// Try to Get the information about the agent process.
|
||||||
|
///
|
||||||
|
/// \param[out] agent_info The information of the agent process.
|
||||||
|
/// \return Status, if successful will return `ray::Status::OK`,
|
||||||
|
/// otherwise will return `ray::Status::Invalid`.
|
||||||
|
const ray::Status TryToGetAgentInfo(rpc::AgentInfo *agent_info) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void StartAgent();
|
void StartAgent();
|
||||||
|
|
||||||
|
const bool IsAgentRegistered() const { return reported_agent_info_.id() == agent_id_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Options options_;
|
Options options_;
|
||||||
pid_t reported_agent_id_ = 0;
|
/// we need to make sure `agent_id_` and `reported_agent_info_.id()` are not equal
|
||||||
int reported_agent_port_ = 0;
|
/// until the agent process is finished registering, the initial value of
|
||||||
|
/// `reported_agent_info_.id()` is 0, so I set the initial value of `agent_id_` is -1
|
||||||
|
int agent_id_ = -1;
|
||||||
|
rpc::AgentInfo reported_agent_info_;
|
||||||
/// Whether or not we intend to start the agent. This is false if we
|
/// Whether or not we intend to start the agent. This is false if we
|
||||||
/// are missing Ray Dashboard dependencies, for example.
|
/// are missing Ray Dashboard dependencies, for example.
|
||||||
bool should_start_agent_ = true;
|
bool should_start_agent_ = true;
|
||||||
std::string reported_agent_ip_address_;
|
|
||||||
DelayExecutorFn delay_executor_;
|
DelayExecutorFn delay_executor_;
|
||||||
RuntimeEnvAgentClientFactoryFn runtime_env_agent_client_factory_;
|
RuntimeEnvAgentClientFactoryFn runtime_env_agent_client_factory_;
|
||||||
std::shared_ptr<rpc::RuntimeEnvAgentClientInterface> runtime_env_agent_client_;
|
std::shared_ptr<rpc::RuntimeEnvAgentClientInterface> runtime_env_agent_client_;
|
||||||
|
|
|
@ -2860,6 +2860,10 @@ void NodeManager::PublishInfeasibleTaskError(const RayTask &task) const {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const ray::Status NodeManager::TryToGetAgentInfo(rpc::AgentInfo *agent_info) const {
|
||||||
|
return agent_manager_->TryToGetAgentInfo(agent_info);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace raylet
|
} // namespace raylet
|
||||||
|
|
||||||
} // namespace ray
|
} // namespace ray
|
||||||
|
|
|
@ -237,6 +237,13 @@ class NodeManager : public rpc::NodeManagerServiceHandler,
|
||||||
int64_t limit,
|
int64_t limit,
|
||||||
const std::function<void()> &on_all_replied);
|
const std::function<void()> &on_all_replied);
|
||||||
|
|
||||||
|
/// Try to Get the information about the agent process.
|
||||||
|
///
|
||||||
|
/// \param[out] agent_info The information of the agent process.
|
||||||
|
/// \return Status, if successful will return `ray::Status::OK`,
|
||||||
|
/// otherwise will return `ray::Status::Invalid`.
|
||||||
|
const ray::Status TryToGetAgentInfo(rpc::AgentInfo *agent_info) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/// Methods for handling nodes.
|
/// Methods for handling nodes.
|
||||||
|
|
||||||
|
|
|
@ -97,7 +97,6 @@ Raylet::~Raylet() {}
|
||||||
|
|
||||||
void Raylet::Start() {
|
void Raylet::Start() {
|
||||||
RAY_CHECK_OK(RegisterGcs());
|
RAY_CHECK_OK(RegisterGcs());
|
||||||
|
|
||||||
// Start listening for clients.
|
// Start listening for clients.
|
||||||
DoAccept();
|
DoAccept();
|
||||||
}
|
}
|
||||||
|
@ -109,6 +108,21 @@ void Raylet::Stop() {
|
||||||
}
|
}
|
||||||
|
|
||||||
ray::Status Raylet::RegisterGcs() {
|
ray::Status Raylet::RegisterGcs() {
|
||||||
|
rpc::AgentInfo agent_info;
|
||||||
|
auto status = node_manager_.TryToGetAgentInfo(&agent_info);
|
||||||
|
if (status.ok()) {
|
||||||
|
self_node_info_.mutable_agent_info()->CopyFrom(agent_info);
|
||||||
|
} else {
|
||||||
|
// Because current function and `AgentManager::HandleRegisterAgent`
|
||||||
|
// will be invoke in same thread, so we need post current function
|
||||||
|
// into main_service_ after interval milliseconds.
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(
|
||||||
|
RayConfig::instance().raylet_get_agent_info_interval_ms()));
|
||||||
|
main_service_.post([this]() { RAY_CHECK_OK(RegisterGcs()); },
|
||||||
|
"Raylet.TryToGetAgentInfoAndRegisterGcs");
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
auto register_callback = [this](const Status &status) {
|
auto register_callback = [this](const Status &status) {
|
||||||
RAY_CHECK_OK(status);
|
RAY_CHECK_OK(status);
|
||||||
RAY_LOG(INFO) << "Raylet of id, " << self_node_id_
|
RAY_LOG(INFO) << "Raylet of id, " << self_node_id_
|
||||||
|
|
|
@ -68,7 +68,7 @@ class Raylet {
|
||||||
NodeID GetNodeId() const { return self_node_id_; }
|
NodeID GetNodeId() const { return self_node_id_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/// Register GCS client.
|
/// Try to get agent info, after its success, register the current node to GCS.
|
||||||
ray::Status RegisterGcs();
|
ray::Status RegisterGcs();
|
||||||
|
|
||||||
/// Accept a client connection.
|
/// Accept a client connection.
|
||||||
|
|
|
@ -503,7 +503,8 @@ class WorkerPoolTest : public ::testing::Test {
|
||||||
false);
|
false);
|
||||||
rpc::RegisterAgentRequest request;
|
rpc::RegisterAgentRequest request;
|
||||||
// Set agent port to a nonzero value to avoid invalid agent client.
|
// Set agent port to a nonzero value to avoid invalid agent client.
|
||||||
request.set_agent_port(12345);
|
request.mutable_agent_info()->set_grpc_port(12345);
|
||||||
|
request.mutable_agent_info()->set_http_port(54321);
|
||||||
rpc::RegisterAgentReply reply;
|
rpc::RegisterAgentReply reply;
|
||||||
auto send_reply_callback =
|
auto send_reply_callback =
|
||||||
[](ray::Status status, std::function<void()> f1, std::function<void()> f2) {};
|
[](ray::Status status, std::function<void()> f1, std::function<void()> f2) {};
|
||||||
|
|
Loading…
Add table
Reference in a new issue