[Core] Add node_name field to GcsNodeInfo (#23543)

Make it easier to identify nodes by a string identifier separate from their IP address.
This commit is contained in:
jon-chuang 2022-04-19 08:03:12 -04:00 committed by GitHub
parent 082baa2342
commit e0c0ea2e59
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 141 additions and 9 deletions

View file

@ -135,6 +135,7 @@ class RayParams:
node_manager_port=0,
gcs_server_port=None,
node_ip_address=None,
node_name=None,
raylet_ip_address=None,
min_worker_port=None,
max_worker_port=None,
@ -186,6 +187,7 @@ class RayParams:
self.node_manager_port = node_manager_port
self.gcs_server_port = gcs_server_port
self.node_ip_address = node_ip_address
self.node_name = node_name
self.raylet_ip_address = raylet_ip_address
self.min_worker_port = min_worker_port
self.max_worker_port = max_worker_port

View file

@ -1575,6 +1575,7 @@ def start_raylet(
backup_count=0,
ray_debugger_external=False,
env_updates=None,
node_name=None,
):
"""Start a raylet, which is a combined local scheduler and object manager.
@ -1799,6 +1800,10 @@ def start_raylet(
command.append("--huge_pages")
if socket_to_use:
socket_to_use.close()
if node_name is not None:
command.append(
f"--node-name={node_name}",
)
process_info = start_ray_process(
command,
ray_constants.PROCESS_TYPE_RAYLET,

View file

@ -55,8 +55,10 @@ class RayTestTimeoutException(Exception):
pass
def make_global_state_accessor(address_info):
gcs_options = GcsClientOptions.from_gcs_address(address_info["gcs_address"])
def make_global_state_accessor(ray_context):
    """Build a connected ``GlobalStateAccessor`` for the cluster behind ``ray_context``.

    Args:
        ray_context: Object exposing an ``address_info`` mapping that contains
            a ``"gcs_address"`` entry.

    Returns:
        A ``GlobalStateAccessor`` on which ``connect()`` has already been called.
    """
    gcs_address = ray_context.address_info["gcs_address"]
    accessor = GlobalStateAccessor(GcsClientOptions.from_gcs_address(gcs_address))
    accessor.connect()
    return accessor

View file

@ -994,6 +994,7 @@ class Node:
start_initial_python_workers_for_first_job=self._ray_params.start_initial_python_workers_for_first_job, # noqa: E501
ray_debugger_external=self._ray_params.ray_debugger_external,
env_updates=self._ray_params.env_vars,
node_name=self._ray_params.node_name,
)
assert ray_constants.PROCESS_TYPE_RAYLET not in self.all_processes
self.all_processes[ray_constants.PROCESS_TYPE_RAYLET] = [process_info]

View file

@ -271,6 +271,14 @@ def debug(address):
f"{ray_constants.DEFAULT_PORT}; if port is set to 0, we will"
f" allocate an available port.",
)
@click.option(
"--node-name",
required=False,
hidden=True,
type=str,
help="the user-provided identifier or name for this node. "
"Defaults to the node's ip_address",
)
@click.option(
"--redis-password",
required=False,
@ -511,6 +519,7 @@ def start(
node_ip_address,
address,
port,
node_name,
redis_password,
redis_shard_ports,
object_manager_port,
@ -582,6 +591,7 @@ def start(
redirect_output = None if not no_redirect_output else True
ray_params = ray._private.parameter.RayParams(
node_ip_address=node_ip_address,
node_name=node_name if node_name else node_ip_address,
min_worker_port=min_worker_port,
max_worker_port=max_worker_port,
worker_port_list=worker_port_list,

View file

@ -186,6 +186,7 @@ class GlobalState:
"ObjectStoreSocketName": item.object_store_socket_name,
"RayletSocketName": item.raylet_socket_name,
"MetricsExportPort": item.metrics_export_port,
"NodeName": item.node_name,
}
node_info["alive"] = node_info["Alive"]
node_info["Resources"] = (

View file

@ -16,8 +16,9 @@ from ray._private.test_utils import (
make_global_state_accessor,
)
# TODO(rliaw): The proper way to do this is to have the pytest config setup.
@pytest.mark.skipif(
pytest_timeout is None,
reason="Timeout package not installed; skipping test that may hang.",
@ -154,6 +155,52 @@ def test_global_state_actor_entry(ray_start_regular):
)
def test_node_name_cluster(ray_start_cluster):
    """Names passed to ``cluster.add_node`` must appear in the GCS node table.

    A head node and a worker node are started with distinct names; every
    entry of the node table is then checked against the name expected for it.
    """
    cluster = ray_start_cluster
    cluster.add_node(node_name="head_node", include_dashboard=False)
    head_context = ray.init(address=cluster.address, include_dashboard=False)
    cluster.add_node(node_name="worker_node", include_dashboard=False)
    cluster.wait_for_nodes()

    accessor = make_global_state_accessor(head_context)
    serialized_nodes = accessor.get_node_table()
    assert len(serialized_nodes) == 2

    # The driver is attached to the head node, so its node id identifies
    # which table entry belongs to the head.
    head_node_id = head_context.address_info["node_id"]
    for raw_entry in serialized_nodes:
        info = gcs_utils.GcsNodeInfo.FromString(raw_entry)
        is_head = ray._private.utils.binary_to_hex(info.node_id) == head_node_id
        expected_name = "head_node" if is_head else "worker_node"
        assert info.node_name == expected_name

    accessor.disconnect()
    ray.shutdown()
    cluster.shutdown()
def test_node_name_init():
    """``ray.init(_node_name=...)`` must register that name in the GCS node table."""
    # Test ray.init with _node_name directly
    new_head_context = ray.init(_node_name="new_head_node", include_dashboard=False)
    try:
        global_state_accessor = make_global_state_accessor(new_head_context)
        try:
            node_data = global_state_accessor.get_node_table()[0]
            node = gcs_utils.GcsNodeInfo.FromString(node_data)
            assert node.node_name == "new_head_node"
        finally:
            # Release the GCS connection, matching test_node_name_cluster.
            global_state_accessor.disconnect()
    finally:
        # Always tear Ray down so a failed assertion cannot leave a live
        # instance behind for the next test that calls ray.init().
        ray.shutdown()
def test_no_node_name():
    """Without an explicit node name, ``node_name`` must equal the node IP address."""
    # Test that starting ray with no node name will result in a node_name=ip_address
    new_head_context = ray.init(include_dashboard=False)
    try:
        global_state_accessor = make_global_state_accessor(new_head_context)
        try:
            node_data = global_state_accessor.get_node_table()[0]
            node = gcs_utils.GcsNodeInfo.FromString(node_data)
            assert node.node_name == ray.util.get_node_ip_address()
        finally:
            # Release the GCS connection, matching test_node_name_cluster.
            global_state_accessor.disconnect()
    finally:
        # Always tear Ray down so a failed assertion does not leak a running
        # Ray instance into later tests.
        ray.shutdown()
@pytest.mark.parametrize("max_shapes", [0, 2, -1])
def test_load_report(shutdown_only, max_shapes):
resource1 = "A"

View file

@ -527,6 +527,30 @@ time.sleep(5)
assert actor_repr not in out
def test_node_name_in_raylet_death():
# Checks that the "has been marked dead" message produced after a raylet is
# killed mentions the user-provided node name.
#
# The driver script below starts Ray with a custom _node_name, sets the
# heartbeat-related RAY_* environment variables, kills the raylet without a
# graceful exit, and then sleeps for NUM_HEARTBEATS * HEARTBEAT_PERIOD
# milliseconds plus a buffer (presumably long enough for the node to be
# declared dead -- confirm against the GCS failure detector settings).
NODE_NAME = "RAY_TEST_RAYLET_DEATH_NODE_NAME"
script = f"""
import ray
import time
import os
NUM_HEARTBEATS=10
HEARTBEAT_PERIOD=500
WAIT_BUFFER_SECONDS=5
os.environ["RAY_num_heartbeats_timeout"]=str(NUM_HEARTBEATS)
os.environ["RAY_raylet_heartbeat_period_milliseconds"]=str(HEARTBEAT_PERIOD)
ray.init(_node_name=\"{NODE_NAME}\")
# This will kill raylet without letting it exit gracefully.
ray.worker._global_node.kill_raylet()
time.sleep(NUM_HEARTBEATS * HEARTBEAT_PERIOD / 1000 + WAIT_BUFFER_SECONDS)
ray.shutdown()
"""
out = run_string_as_driver(script)
# The message must appear exactly once in the value returned by
# run_string_as_driver.
assert out.count(f"node name: {NODE_NAME} has been marked dead") == 1
if __name__ == "__main__":
if len(sys.argv) > 1 and sys.argv[1] == "_ray_instance":
# Set object store memory very low so that it won't complain

View file

@ -714,6 +714,7 @@ def init(
_metrics_export_port: Optional[int] = None,
_system_config: Optional[Dict[str, str]] = None,
_tracing_startup_hook: Optional[Callable] = None,
_node_name: str = None,
**kwargs,
) -> BaseContext:
"""
@ -835,6 +836,8 @@ def init(
(optional) additional instruments. See more at
docs.ray.io/tracing.html. It is currently under active development,
and the API is subject to change.
_node_name (str): User-provided node name or identifier. Defaults to
the node IP address.
Returns:
If the provided address includes a protocol, for example by prepending
@ -1015,6 +1018,7 @@ def init(
enable_object_reconstruction=_enable_object_reconstruction,
metrics_export_port=_metrics_export_port,
tracing_startup_hook=_tracing_startup_hook,
node_name=_node_name,
)
# Start the Ray processes. We set shutdown_at_exit=False because we
# shutdown the node in the ray.shutdown call that happens in the atexit
@ -1050,6 +1054,11 @@ def init(
"When connecting to an existing cluster, "
"_enable_object_reconstruction must not be provided."
)
if _node_name is not None:
raise ValueError(
"_node_name cannot be configured when connecting to "
"an existing cluster."
)
# In this case, we only need to connect the node.
ray_params = ray._private.parameter.RayParams(

View file

@ -131,7 +131,9 @@ TEST_P(GlobalStateAccessorTest, TestNodeTable) {
// It's useful to check if index value will be marked as address suffix.
for (int index = 0; index < node_count; ++index) {
auto node_table_data =
Mocker::GenNodeInfo(index, std::string("127.0.0.") + std::to_string(index));
Mocker::GenNodeInfo(index,
std::string("127.0.0.") + std::to_string(index),
"Mocker_node_" + std::to_string(index * 10));
std::promise<bool> promise;
RAY_CHECK_OK(gcs_client_->Nodes().AsyncRegister(
*node_table_data, [&promise](Status status) { promise.set_value(status.ok()); }));
@ -144,6 +146,9 @@ TEST_P(GlobalStateAccessorTest, TestNodeTable) {
node_data.ParseFromString(node_table[index]);
ASSERT_EQ(node_data.node_manager_address(),
std::string("127.0.0.") + std::to_string(node_data.node_manager_port()));
ASSERT_EQ(
node_data.node_name(),
std::string("Mocker_node_") + std::to_string(node_data.node_manager_port() * 10));
}
}

View file

@ -40,12 +40,14 @@ void GcsNodeManager::HandleRegisterNode(const rpc::RegisterNodeRequest &request,
rpc::SendReplyCallback send_reply_callback) {
NodeID node_id = NodeID::FromBinary(request.node_info().node_id());
RAY_LOG(INFO) << "Registering node info, node id = " << node_id
<< ", address = " << request.node_info().node_manager_address();
<< ", address = " << request.node_info().node_manager_address()
<< ", node name = " << request.node_info().node_name();
auto on_done = [this, node_id, request, reply, send_reply_callback](
const Status &status) {
RAY_CHECK_OK(status);
RAY_LOG(INFO) << "Finished registering node info, node id = " << node_id
<< ", address = " << request.node_info().node_manager_address();
<< ", address = " << request.node_info().node_manager_address()
<< ", node name = " << request.node_info().node_name();
RAY_CHECK_OK(gcs_publisher_->PublishNodeInfo(node_id, request.node_info(), nullptr));
AddNode(std::make_shared<rpc::GcsNodeInfo>(request.node_info()));
GCS_RPC_SEND_REPLY(send_reply_callback, reply, status);
@ -190,11 +192,12 @@ void GcsNodeManager::AddNode(std::shared_ptr<rpc::GcsNodeInfo> node) {
std::shared_ptr<rpc::GcsNodeInfo> GcsNodeManager::RemoveNode(
const ray::NodeID &node_id, bool is_intended /*= false*/) {
RAY_LOG(INFO) << "Removing node, node id = " << node_id;
std::shared_ptr<rpc::GcsNodeInfo> removed_node;
auto iter = alive_nodes_.find(node_id);
if (iter != alive_nodes_.end()) {
removed_node = std::move(iter->second);
RAY_LOG(INFO) << "Removing node, node id = " << node_id
<< ", node name = " << removed_node->node_name();
// Record stats that there's a new removed node.
stats::NodeFailureTotal.Record(1);
// Remove from alive nodes.
@ -206,7 +209,8 @@ std::shared_ptr<rpc::GcsNodeInfo> GcsNodeManager::RemoveNode(
std::string type = "node_removed";
std::ostringstream error_message;
error_message << "The node with node id: " << node_id
<< " and ip: " << removed_node->node_manager_address()
<< " and address: " << removed_node->node_manager_address()
<< " and node name: " << removed_node->node_name()
<< " has been marked dead because the detector"
<< " has missed too many heartbeats from it. This can happen when a "
"raylet crashes unexpectedly or has lagging heartbeats.";
@ -214,6 +218,7 @@ std::shared_ptr<rpc::GcsNodeInfo> GcsNodeManager::RemoveNode(
.WithField("node_id", node_id.Hex())
.WithField("ip", removed_node->node_manager_address())
<< error_message.str();
RAY_LOG(WARNING) << error_message.str();
auto error_data_ptr =
gcs::CreateErrorTableData(type, error_message.str(), current_time_ms());
RAY_CHECK_OK(gcs_publisher_->PublishError(node_id.Hex(), *error_data_ptr, nullptr));

View file

@ -187,11 +187,14 @@ struct Mocker {
return request;
}
/// Build a GcsNodeInfo for tests.
///
/// The returned node gets a randomly generated node id; its node manager
/// port, node manager address and node name are set from the arguments.
///
/// \param port Value stored as the node manager port.
/// \param address Value stored as the node manager address.
/// \param node_name Value stored as the node name.
/// \return A shared pointer to the newly created GcsNodeInfo.
static std::shared_ptr<rpc::GcsNodeInfo> GenNodeInfo(
// NOTE(review): the next line looks like the pre-change parameter list kept
// by the diff view; the three lines after it are the current one -- confirm.
uint16_t port = 0, const std::string address = "127.0.0.1") {
uint16_t port = 0,
const std::string address = "127.0.0.1",
const std::string node_name = "Mocker_node") {
auto node = std::make_shared<rpc::GcsNodeInfo>();
node->set_node_id(NodeID::FromRandom().Binary());
node->set_node_manager_port(port);
node->set_node_manager_address(address);
node->set_node_name(node_name);
return node;
}

View file

@ -245,6 +245,9 @@ message GcsNodeInfo {
// The total resources of this node.
map<string, double> resources_total = 11;
// The user-provided identifier or name for this node.
string node_name = 12;
}
message HeartbeatTableData {

View file

@ -63,6 +63,7 @@ DEFINE_string(resource_dir, "", "The path of this ray resource directory.");
DEFINE_int32(ray_debugger_external, 0, "Make Ray debugger externally accessible.");
// store options
DEFINE_int64(object_store_memory, -1, "The initial memory of the object store.");
DEFINE_string(node_name, "", "The user-provided identifier or name for this node.");
#ifdef __linux__
DEFINE_string(plasma_directory,
"/dev/shm",
@ -86,6 +87,8 @@ int main(int argc, char *argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
const std::string raylet_socket_name = FLAGS_raylet_socket_name;
const std::string store_socket_name = FLAGS_store_socket_name;
const std::string node_name =
(FLAGS_node_name == "") ? FLAGS_node_ip_address : FLAGS_node_name;
const int object_manager_port = static_cast<int>(FLAGS_object_manager_port);
const int node_manager_port = static_cast<int>(FLAGS_node_manager_port);
const int metrics_agent_port = static_cast<int>(FLAGS_metrics_agent_port);
@ -262,6 +265,7 @@ int main(int argc, char *argv[]) {
raylet = std::make_unique<ray::raylet::Raylet>(main_service,
raylet_socket_name,
node_ip_address,
node_name,
node_manager_config,
object_manager_config,
gcs_client,

View file

@ -174,10 +174,12 @@ void HeartbeatSender::Heartbeat() {
NodeManager::NodeManager(instrumented_io_context &io_service,
const NodeID &self_node_id,
const std::string &self_node_name,
const NodeManagerConfig &config,
const ObjectManagerConfig &object_manager_config,
std::shared_ptr<gcs::GcsClient> gcs_client)
: self_node_id_(self_node_id),
self_node_name_(self_node_name),
io_service_(io_service),
gcs_client_(gcs_client),
worker_pool_(
@ -2157,6 +2159,8 @@ std::string NodeManager::DebugString() const {
std::stringstream result;
uint64_t now_ms = current_time_ms();
result << "NodeManager:";
result << "\nNode ID: " << self_node_id_;
result << "\nNode name: " << self_node_name_;
result << "\nInitialConfigResources: " << initial_config_.resource_config.ToString();
if (cluster_task_manager_ != nullptr) {
result << "\nClusterTaskManager:\n";

View file

@ -146,6 +146,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
/// \param object_manager A reference to the local object manager.
NodeManager(instrumented_io_context &io_service,
const NodeID &self_node_id,
const std::string &self_node_name,
const NodeManagerConfig &config,
const ObjectManagerConfig &object_manager_config,
std::shared_ptr<gcs::GcsClient> gcs_client);
@ -617,6 +618,8 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
/// ID of this node.
NodeID self_node_id_;
/// The user-given identifier or name of this node.
std::string self_node_name_;
instrumented_io_context &io_service_;
/// A client connection to the GCS.
std::shared_ptr<gcs::GcsClient> gcs_client_;

View file

@ -58,6 +58,7 @@ namespace raylet {
Raylet::Raylet(instrumented_io_context &main_service,
const std::string &socket_name,
const std::string &node_ip_address,
const std::string &node_name,
const NodeManagerConfig &node_manager_config,
const ObjectManagerConfig &object_manager_config,
std::shared_ptr<gcs::GcsClient> gcs_client,
@ -70,6 +71,7 @@ Raylet::Raylet(instrumented_io_context &main_service,
gcs_client_(gcs_client),
node_manager_(main_service,
self_node_id_,
node_name,
node_manager_config,
object_manager_config,
gcs_client_),
@ -79,6 +81,7 @@ Raylet::Raylet(instrumented_io_context &main_service,
self_node_info_.set_node_id(self_node_id_.Binary());
self_node_info_.set_state(GcsNodeInfo::ALIVE);
self_node_info_.set_node_manager_address(node_ip_address);
self_node_info_.set_node_name(node_name);
self_node_info_.set_raylet_socket_name(socket_name);
self_node_info_.set_object_store_socket_name(object_manager_config.store_socket_name);
self_node_info_.set_object_manager_port(node_manager_.GetObjectManagerPort());

View file

@ -50,6 +50,7 @@ class Raylet {
Raylet(instrumented_io_context &main_service,
const std::string &socket_name,
const std::string &node_ip_address,
const std::string &node_name,
const NodeManagerConfig &node_manager_config,
const ObjectManagerConfig &object_manager_config,
std::shared_ptr<gcs::GcsClient> gcs_client,