[Core] Added ability to specify different IP addresses for a core worker and its raylet. (#7985)

This commit is contained in:
Clark Zinzow 2020-04-16 09:32:24 -06:00 committed by GitHub
parent d0fab84e4d
commit d4cae5f632
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
26 changed files with 170 additions and 71 deletions

View file

@ -630,8 +630,8 @@ cdef class CoreWorker:
def __cinit__(self, is_driver, store_socket, raylet_socket,
JobID job_id, GcsClientOptions gcs_options, log_dir,
node_ip_address, node_manager_port, local_mode,
driver_name, stdout_file, stderr_file):
node_ip_address, node_manager_port, raylet_ip_address,
local_mode, driver_name, stdout_file, stderr_file):
self.is_driver = is_driver
self.is_local_mode = local_mode
@ -647,6 +647,7 @@ cdef class CoreWorker:
options.install_failure_signal_handler = True
options.node_ip_address = node_ip_address.encode("utf-8")
options.node_manager_port = node_manager_port
options.raylet_ip_address = raylet_ip_address.encode("utf-8")
options.driver_name = driver_name
options.stdout_file = stdout_file
options.stderr_file = stderr_file

View file

@ -195,6 +195,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
c_bool install_failure_signal_handler
c_string node_ip_address
int node_manager_port
c_string raylet_ip_address
c_string driver_name
c_string stdout_file
c_string stderr_file

View file

@ -80,6 +80,19 @@ class Node:
node_ip_address = ray.services.get_node_ip_address()
self._node_ip_address = node_ip_address
if ray_params.raylet_ip_address:
raylet_ip_address = ray_params.raylet_ip_address
else:
raylet_ip_address = node_ip_address
if raylet_ip_address != node_ip_address and (not connect_only or head):
raise ValueError(
"The raylet IP address should only be different than the node "
"IP address when connecting to an existing raylet; i.e., when "
"head=False and connect_only=True.")
self._raylet_ip_address = raylet_ip_address
ray_params.update_if_absent(
include_log_monitor=True,
resources={},
@ -122,7 +135,7 @@ class Node:
# from Redis.
address_info = ray.services.get_address_info_from_redis(
self.redis_address,
self._node_ip_address,
self._raylet_ip_address,
redis_password=self.redis_password)
self._plasma_store_socket_name = address_info[
"object_store_address"]
@ -229,9 +242,14 @@ class Node:
@property
def node_ip_address(self):
"""Get the cluster Redis address."""
"""Get the IP address of this node."""
return self._node_ip_address
@property
def raylet_ip_address(self):
"""Get the IP address of the raylet that this node connects to."""
return self._raylet_ip_address
@property
def address(self):
"""Get the cluster address."""
@ -287,6 +305,7 @@ class Node:
"""Get a dictionary of addresses."""
return {
"node_ip_address": self._node_ip_address,
"raylet_ip_address": self._raylet_ip_address,
"redis_address": self._redis_address,
"object_store_address": self._plasma_store_socket_name,
"raylet_socket_name": self._raylet_socket_name,
@ -429,7 +448,7 @@ class Node:
assert ray_constants.PROCESS_TYPE_REAPER not in self.all_processes
if process_info is not None:
self.all_processes[ray_constants.PROCESS_TYPE_REAPER] = [
process_info
process_info,
]
def start_redis(self):
@ -469,7 +488,7 @@ class Node:
fate_share=self.kernel_fate_share)
assert ray_constants.PROCESS_TYPE_LOG_MONITOR not in self.all_processes
self.all_processes[ray_constants.PROCESS_TYPE_LOG_MONITOR] = [
process_info
process_info,
]
def start_reporter(self):
@ -484,7 +503,7 @@ class Node:
assert ray_constants.PROCESS_TYPE_REPORTER not in self.all_processes
if process_info is not None:
self.all_processes[ray_constants.PROCESS_TYPE_REPORTER] = [
process_info
process_info,
]
def start_dashboard(self, require_webui):
@ -508,7 +527,7 @@ class Node:
assert ray_constants.PROCESS_TYPE_DASHBOARD not in self.all_processes
if process_info is not None:
self.all_processes[ray_constants.PROCESS_TYPE_DASHBOARD] = [
process_info
process_info,
]
redis_client = self.create_redis_client()
redis_client.hmset("webui", {"url": self._webui_url})
@ -527,7 +546,7 @@ class Node:
assert (
ray_constants.PROCESS_TYPE_PLASMA_STORE not in self.all_processes)
self.all_processes[ray_constants.PROCESS_TYPE_PLASMA_STORE] = [
process_info
process_info,
]
def start_gcs_server(self):
@ -544,7 +563,7 @@ class Node:
assert (
ray_constants.PROCESS_TYPE_GCS_SERVER not in self.all_processes)
self.all_processes[ray_constants.PROCESS_TYPE_GCS_SERVER] = [
process_info
process_info,
]
def start_raylet(self, use_valgrind=False, use_profiler=False):
@ -617,7 +636,7 @@ class Node:
assert (ray_constants.PROCESS_TYPE_RAYLET_MONITOR not in
self.all_processes)
self.all_processes[ray_constants.PROCESS_TYPE_RAYLET_MONITOR] = [
process_info
process_info,
]
def start_head_processes(self):

View file

@ -33,6 +33,8 @@ class RayParams:
object_manager_port int: The port to use for the object manager.
node_manager_port: The port to use for the node manager.
node_ip_address (str): The IP address of the node that we are on.
raylet_ip_address (str): The IP address of the raylet that this node
connects to.
object_id_seed (int): Used to seed the deterministic generation of
object IDs. The same value can be used across multiple runs of the
same job in order to generate the object IDs in a consistent
@ -95,6 +97,7 @@ class RayParams:
object_manager_port=None,
node_manager_port=None,
node_ip_address=None,
raylet_ip_address=None,
object_id_seed=None,
driver_mode=None,
redirect_worker_output=None,
@ -131,6 +134,7 @@ class RayParams:
self.object_manager_port = object_manager_port
self.node_manager_port = node_manager_port
self.node_ip_address = node_ip_address
self.raylet_ip_address = raylet_ip_address
self.driver_mode = driver_mode
self.redirect_worker_output = redirect_worker_output
self.redirect_output = redirect_output

View file

@ -73,8 +73,14 @@ DEFAULT_JAVA_WORKER_CLASSPATH = [
logger = logging.getLogger(__name__)
ProcessInfo = collections.namedtuple("ProcessInfo", [
"process", "stdout_file", "stderr_file", "use_valgrind", "use_gdb",
"use_valgrind_profiler", "use_perftools_profiler", "use_tmux"
"process",
"stdout_file",
"stderr_file",
"use_valgrind",
"use_gdb",
"use_valgrind_profiler",
"use_perftools_profiler",
"use_tmux",
])
@ -189,7 +195,7 @@ def get_address_info_from_redis_helper(redis_address,
return {
"object_store_address": relevant_client["ObjectStoreSocketName"],
"raylet_socket_name": relevant_client["RayletSocketName"],
"node_manager_port": relevant_client["NodeManagerPort"]
"node_manager_port": relevant_client["NodeManagerPort"],
}
@ -430,9 +436,12 @@ def start_ray_process(command,
logger.info("Detected environment variable '%s'.", gdb_env_var)
use_gdb = True
if sum(
[use_gdb, use_valgrind, use_valgrind_profiler, use_perftools_profiler
]) > 1:
if sum([
use_gdb,
use_valgrind,
use_valgrind_profiler,
use_perftools_profiler,
]) > 1:
raise ValueError(
"At most one of the 'use_gdb', 'use_valgrind', "
"'use_valgrind_profiler', and 'use_perftools_profiler' flags can "
@ -463,9 +472,12 @@ def start_ray_process(command,
if use_valgrind:
command = [
"valgrind", "--track-origins=yes", "--leak-check=full",
"--show-leak-kinds=all", "--leak-check-heuristics=stdstring",
"--error-exitcode=1"
"valgrind",
"--track-origins=yes",
"--leak-check=full",
"--show-leak-kinds=all",
"--leak-check-heuristics=stdstring",
"--error-exitcode=1",
] + command
if use_valgrind_profiler:
@ -1023,9 +1035,11 @@ def start_log_monitor(redis_address,
log_monitor_filepath = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "log_monitor.py")
command = [
sys.executable, "-u", log_monitor_filepath,
sys.executable,
"-u",
log_monitor_filepath,
"--redis-address={}".format(redis_address),
"--logs-dir={}".format(logs_dir)
"--logs-dir={}".format(logs_dir),
]
if redis_password:
command += ["--redis-password", redis_password]
@ -1059,8 +1073,10 @@ def start_reporter(redis_address,
reporter_filepath = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "reporter.py")
command = [
sys.executable, "-u", reporter_filepath,
"--redis-address={}".format(redis_address)
sys.executable,
"-u",
reporter_filepath,
"--redis-address={}".format(redis_address),
]
if redis_password:
command += ["--redis-password", redis_password]
@ -1114,9 +1130,13 @@ def start_dashboard(require_webui,
dashboard_filepath = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "dashboard/dashboard.py")
command = [
sys.executable, "-u", dashboard_filepath, "--host={}".format(host),
"--port={}".format(port), "--redis-address={}".format(redis_address),
"--temp-dir={}".format(temp_dir)
sys.executable,
"-u",
dashboard_filepath,
"--host={}".format(host),
"--port={}".format(port),
"--redis-address={}".format(redis_address),
"--temp-dir={}".format(temp_dir),
]
if redis_password:
command += ["--redis-password", redis_password]
@ -1290,13 +1310,15 @@ def start_raylet(redis_address,
# Create the command that the Raylet will use to start workers.
start_worker_command = [
sys.executable, worker_path,
sys.executable,
worker_path,
"--node-ip-address={}".format(node_ip_address),
"--node-manager-port={}".format(node_manager_port),
"--object-store-name={}".format(plasma_store_name),
"--raylet-name={}".format(raylet_name),
"--redis-address={}".format(redis_address),
"--config-list={}".format(config_str), "--temp-dir={}".format(temp_dir)
"--config-list={}".format(config_str),
"--temp-dir={}".format(temp_dir),
]
if redis_password:
start_worker_command += ["--redis-password={}".format(redis_password)]
@ -1540,8 +1562,11 @@ def _start_plasma_store(plasma_store_memory,
plasma_store_memory = int(plasma_store_memory)
command = [
PLASMA_STORE_EXECUTABLE, "-s", socket_name, "-m",
str(plasma_store_memory)
PLASMA_STORE_EXECUTABLE,
"-s",
socket_name,
"-m",
str(plasma_store_memory),
]
if plasma_directory is not None:
command += ["-d", plasma_directory]
@ -1617,6 +1642,7 @@ def start_worker(node_ip_address,
redis_address,
worker_path,
temp_dir,
raylet_ip_address=None,
stdout_file=None,
stderr_file=None,
fate_share=None):
@ -1631,6 +1657,8 @@ def start_worker(node_ip_address,
worker_path (str): The path of the source code which the worker process
will run.
temp_dir (str): The path of the temp dir.
raylet_ip_address (str): The IP address of the worker's raylet. If not
provided, it defaults to the node_ip_address.
stdout_file: A file handle opened for writing to redirect stdout to. If
no redirection should happen, then this should be None.
stderr_file: A file handle opened for writing to redirect stderr to. If
@ -1640,12 +1668,17 @@ def start_worker(node_ip_address,
ProcessInfo for the process that was started.
"""
command = [
sys.executable, "-u", worker_path,
sys.executable,
"-u",
worker_path,
"--node-ip-address=" + node_ip_address,
"--object-store-name=" + object_store_name,
"--raylet-name=" + raylet_name,
"--redis-address=" + str(redis_address), "--temp-dir=" + temp_dir
"--redis-address=" + str(redis_address),
"--temp-dir=" + temp_dir,
]
if raylet_ip_address is not None:
command.append("--raylet-ip-address=" + raylet_ip_address)
process_info = start_ray_process(
command,
ray_constants.PROCESS_TYPE_WORKER,
@ -1678,8 +1711,10 @@ def start_monitor(redis_address,
monitor_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "monitor.py")
command = [
sys.executable, "-u", monitor_path,
"--redis-address=" + str(redis_address)
sys.executable,
"-u",
monitor_path,
"--redis-address=" + str(redis_address),
]
if autoscaling_config:
command.append("--autoscaling-config=" + str(autoscaling_config))

View file

@ -355,7 +355,7 @@ class Worker:
"job_id": self.current_job_id.binary(),
"function_id": function_to_run_id,
"function": pickled_function,
"run_on_other_drivers": str(run_on_other_drivers)
"run_on_other_drivers": str(run_on_other_drivers),
})
self.redis_client.rpush("Exports", key)
# TODO(rkn): If the worker fails after it calls setnx and before it
@ -689,6 +689,8 @@ def init(address=None,
if node_ip_address is not None:
node_ip_address = services.address_to_ip(node_ip_address)
raylet_ip_address = node_ip_address
_internal_config = (json.loads(_internal_config)
if _internal_config else {})
# Set the internal config options for LRU eviction.
@ -708,6 +710,7 @@ def init(address=None,
redis_address=redis_address,
redis_port=redis_port,
node_ip_address=node_ip_address,
raylet_ip_address=raylet_ip_address,
object_id_seed=object_id_seed,
driver_mode=driver_mode,
redirect_worker_output=redirect_worker_output,
@ -788,6 +791,7 @@ def init(address=None,
# In this case, we only need to connect the node.
ray_params = ray.parameter.RayParams(
node_ip_address=node_ip_address,
raylet_ip_address=raylet_ip_address,
redis_address=redis_address,
redis_password=redis_password,
object_id_seed=object_id_seed,
@ -1053,7 +1057,7 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped):
job_id = error_data.job_id
if job_id not in [
worker.current_job_id.binary(),
JobID.nil().binary()
JobID.nil().binary(),
]:
continue
@ -1226,6 +1230,7 @@ def connect(node,
int(redis_port),
node.redis_password,
)
worker.core_worker = ray._raylet.CoreWorker(
(mode == SCRIPT_MODE or mode == LOCAL_MODE),
node.plasma_store_socket_name,
@ -1235,6 +1240,7 @@ def connect(node,
node.get_logs_dir_path(),
node.node_ip_address,
node.node_manager_port,
node.raylet_ip_address,
(mode == LOCAL_MODE),
driver_name,
log_stdout_file_name,
@ -1575,9 +1581,8 @@ def wait(object_ids, num_returns=1, timeout=None):
blocking_wait_inside_async_warned = True
if isinstance(object_ids, ObjectID):
raise TypeError(
"wait() expected a list of ray.ObjectID, got a single ray.ObjectID"
)
raise TypeError("wait() expected a list of ray.ObjectID, got a single "
"ray.ObjectID")
if not isinstance(object_ids, list):
raise TypeError(

View file

@ -21,6 +21,12 @@ parser.add_argument(
required=True,
type=int,
help="the port of the worker's node")
parser.add_argument(
"--raylet-ip-address",
required=False,
type=str,
default=None,
help="the ip address of the worker's raylet")
parser.add_argument(
"--redis-address",
required=True,
@ -89,8 +95,13 @@ if __name__ == "__main__":
internal_config[config_list[i]] = config_list[i + 1]
i += 2
raylet_ip_address = args.raylet_ip_address
if raylet_ip_address is None:
raylet_ip_address = args.node_ip_address
ray_params = RayParams(
node_ip_address=args.node_ip_address,
raylet_ip_address=raylet_ip_address,
node_manager_port=args.node_manager_port,
redis_address=args.redis_address,
redis_password=args.redis_password,

View file

@ -313,13 +313,13 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
// so that the worker (java/python .etc) can retrieve and handle the error
// instead of crashing.
auto grpc_client = rpc::NodeManagerWorkerClient::make(
options_.node_ip_address, options_.node_manager_port, *client_call_manager_);
options_.raylet_ip_address, options_.node_manager_port, *client_call_manager_);
ClientID local_raylet_id;
local_raylet_client_ = std::shared_ptr<raylet::RayletClient>(new raylet::RayletClient(
io_service_, std::move(grpc_client), options_.raylet_socket, GetWorkerID(),
(options_.worker_type == ray::WorkerType::WORKER),
worker_context_.GetCurrentJobID(), options_.language, &local_raylet_id,
core_worker_server_.GetPort()));
options_.node_ip_address, core_worker_server_.GetPort()));
connected_ = true;
// Set our own address.

View file

@ -84,6 +84,8 @@ struct CoreWorkerOptions {
std::string node_ip_address;
/// Port of the local raylet.
int node_manager_port;
/// IP address of the raylet.
std::string raylet_ip_address;
/// The name of the driver.
std::string driver_name;
/// The stdout file of this process.

View file

@ -13,8 +13,11 @@
// limitations under the License.
#include "ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.h"
#include <jni.h>
#include <sstream>
#include "ray/common/id.h"
#include "ray/core_worker/core_worker.h"
#include "ray/core_worker/lib/java/jni_utils.h"
@ -37,7 +40,6 @@ inline ray::gcs::GcsClientOptions ToGcsClientOptions(JNIEnv *env,
extern "C" {
#endif
JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(
JNIEnv *env, jclass, jint workerMode, jstring nodeIpAddress, jint nodeManagerPort,
jstring driverName, jstring storeSocket, jstring rayletSocket, jbyteArray jobId,
@ -112,6 +114,7 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(
false, // install_failure_signal_handler
JavaStringToNativeString(env, nodeIpAddress), // node_ip_address
static_cast<int>(nodeManagerPort), // node_manager_port
JavaStringToNativeString(env, nodeIpAddress), // raylet_ip_address
JavaStringToNativeString(env, driverName), // driver_name
"", // stdout_file
"", // stderr_file
@ -135,7 +138,7 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeRunTaskExecuto
}
JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeShutdown(JNIEnv *env,
jclass o) {
jclass o) {
ray::CoreWorkerProcess::Shutdown();
}

View file

@ -263,6 +263,7 @@ class CoreWorkerTest : public ::testing::Test {
true, // install_failure_signal_handler
"127.0.0.1", // node_ip_address
node_manager_port, // node_manager_port
"127.0.0.1", // raylet_ip_address
"core_worker_test", // driver_name
"", // stdout_file
"", // stderr_file

View file

@ -45,6 +45,7 @@ class MockWorker {
true, // install_failure_signal_handler
"127.0.0.1", // node_ip_address
node_manager_port, // node_manager_port
"127.0.0.1", // raylet_ip_address
"", // driver_name
"", // stdout_file
"", // stderr_file

View file

@ -16,6 +16,7 @@
#define RAY_GCS_PB_UTIL_H
#include <memory>
#include "ray/common/id.h"
#include "ray/common/task/task_spec.h"
#include "ray/protobuf/gcs.pb.h"
@ -29,17 +30,17 @@ namespace gcs {
/// \param job_id The ID of job that need to be registered or updated.
/// \param is_dead Whether the driver of this job is dead.
/// \param timestamp The UNIX timestamp of corresponding to this event.
/// \param node_manager_address Address of the node this job was started on.
/// \param driver_ip_address IP address of the driver that started this job.
/// \param driver_pid Process ID of the driver running this job.
/// \return The job table data created by this method.
inline std::shared_ptr<ray::rpc::JobTableData> CreateJobTableData(
const ray::JobID &job_id, bool is_dead, int64_t timestamp,
const std::string &node_manager_address, int64_t driver_pid) {
const std::string &driver_ip_address, int64_t driver_pid) {
auto job_info_ptr = std::make_shared<ray::rpc::JobTableData>();
job_info_ptr->set_job_id(job_id.Binary());
job_info_ptr->set_is_dead(is_dead);
job_info_ptr->set_timestamp(timestamp);
job_info_ptr->set_node_manager_address(node_manager_address);
job_info_ptr->set_driver_ip_address(driver_ip_address);
job_info_ptr->set_driver_pid(driver_pid);
return job_info_ptr;
}

View file

@ -13,7 +13,9 @@
// limitations under the License.
#include "ray/gcs/redis_accessor.h"
#include <boost/none.hpp>
#include "ray/gcs/pb_util.h"
#include "ray/gcs/redis_gcs_client.h"
#include "ray/util/logging.h"
@ -304,7 +306,7 @@ Status RedisJobInfoAccessor::AsyncMarkFinished(const JobID &job_id,
const StatusCallback &callback) {
std::shared_ptr<JobTableData> data_ptr =
CreateJobTableData(job_id, /*is_dead*/ true, /*time_stamp*/ std::time(nullptr),
/*node_manager_address*/ "", /*driver_pid*/ -1);
/*driver_ip_address*/ "", /*driver_pid*/ -1);
return DoAsyncAppend(data_ptr, callback);
}

View file

@ -15,6 +15,9 @@
#ifndef RAY_GCS_TEST_UTIL_H
#define RAY_GCS_TEST_UTIL_H
#include <memory>
#include <utility>
#include "src/ray/common/task/task.h"
#include "src/ray/common/task/task_util.h"
#include "src/ray/common/test_util.h"
@ -23,9 +26,6 @@
#include "src/ray/gcs/gcs_server/gcs_node_manager.h"
#include "src/ray/util/asio_util.h"
#include <memory>
#include <utility>
namespace ray {
struct Mocker {
@ -64,7 +64,7 @@ struct Mocker {
job_table_data->set_job_id(job_id.Binary());
job_table_data->set_is_dead(false);
job_table_data->set_timestamp(std::time(nullptr));
job_table_data->set_node_manager_address("127.0.0.1");
job_table_data->set_driver_ip_address("127.0.0.1");
job_table_data->set_driver_pid(5667L);
return job_table_data;
}

View file

@ -31,7 +31,7 @@ class RedisJobInfoAccessorTest : public AccessorTestBase<JobID, JobTableData> {
JobID job_id = JobID::FromInt(i);
std::shared_ptr<JobTableData> job_data_ptr =
CreateJobTableData(job_id, /*is_dead*/ false, /*timestamp*/ 1,
/*node_manager_address*/ "", /*driver_pid*/ i);
/*driver_ip_address*/ "", /*driver_pid*/ i);
id_to_data_[job_id] = job_data_ptr;
}
}

View file

@ -256,8 +256,8 @@ message JobTableData {
bool is_dead = 2;
// The UNIX timestamp corresponding to this event (job added or removed).
int64 timestamp = 3;
// IP of the node this job was started on.
string node_manager_address = 4;
// IP address of the driver that started this job.
string driver_ip_address = 4;
// Process ID of the driver running this job.
int64 driver_pid = 5;
}

View file

@ -151,6 +151,8 @@ table RegisterClientRequest {
// Language of this worker.
// TODO(hchen): Use `Language` in `common.proto`.
language: int;
// IP address of this worker.
ip_address: string;
// Port that this worker is listening on.
port: int;
}

View file

@ -1098,8 +1098,9 @@ void NodeManager::ProcessRegisterClientRequestMessage(
Language language = static_cast<Language>(message->language());
WorkerID worker_id = from_flatbuf<WorkerID>(*message->worker_id());
pid_t pid = message->worker_pid();
auto worker = std::make_shared<Worker>(worker_id, language, message->port(), client,
client_call_manager_);
std::string worker_ip_address = string_from_flatbuf(*message->ip_address());
auto worker = std::make_shared<Worker>(worker_id, language, worker_ip_address,
message->port(), client, client_call_manager_);
if (message->is_worker()) {
// Register the new worker.
if (worker_pool_.RegisterWorker(worker, pid).ok()) {
@ -1117,9 +1118,8 @@ void NodeManager::ProcessRegisterClientRequestMessage(
Status status = worker_pool_.RegisterDriver(worker);
if (status.ok()) {
local_queues_.AddDriverTaskId(driver_task_id);
auto job_data_ptr =
gcs::CreateJobTableData(job_id, /*is_dead*/ false, std::time(nullptr),
initial_config_.node_manager_address, pid);
auto job_data_ptr = gcs::CreateJobTableData(
job_id, /*is_dead*/ false, std::time(nullptr), worker_ip_address, pid);
RAY_CHECK_OK(gcs_client_->Jobs().AsyncAdd(job_data_ptr, nullptr));
}
}
@ -1260,8 +1260,8 @@ void NodeManager::ProcessDisconnectClientMessage(
// Publish the worker failure.
auto worker_failure_data_ptr = gcs::CreateWorkerFailureData(
self_node_id_, worker->WorkerId(), initial_config_.node_manager_address,
worker->Port(), time(nullptr), intentional_disconnect);
self_node_id_, worker->WorkerId(), worker->IpAddress(), worker->Port(),
time(nullptr), intentional_disconnect);
RAY_CHECK_OK(gcs_client_->Workers().AsyncReportWorkerFailure(worker_failure_data_ptr,
nullptr));
}
@ -1687,8 +1687,7 @@ void NodeManager::HandleRequestWorkerLease(const rpc::RequestWorkerLeaseRequest
ClientID spillback_to,
std::string address, int port) {
if (worker != nullptr) {
reply->mutable_worker_address()->set_ip_address(
initial_config_.node_manager_address);
reply->mutable_worker_address()->set_ip_address(worker->IpAddress());
reply->mutable_worker_address()->set_port(worker->Port());
reply->mutable_worker_address()->set_worker_id(worker->WorkerId().Binary());
reply->mutable_worker_address()->set_raylet_id(self_node_id_.Binary());

View file

@ -165,7 +165,8 @@ raylet::RayletClient::RayletClient(
boost::asio::io_service &io_service,
std::shared_ptr<rpc::NodeManagerWorkerClient> grpc_client,
const std::string &raylet_socket, const WorkerID &worker_id, bool is_worker,
const JobID &job_id, const Language &language, ClientID *raylet_id, int port)
const JobID &job_id, const Language &language, ClientID *raylet_id,
const std::string &ip_address, int port)
: grpc_client_(std::move(grpc_client)), worker_id_(worker_id), job_id_(job_id) {
// For C++14, we could use std::make_unique
conn_ = std::unique_ptr<raylet::RayletConnection>(
@ -174,7 +175,7 @@ raylet::RayletClient::RayletClient(
flatbuffers::FlatBufferBuilder fbb;
auto message = protocol::CreateRegisterClientRequest(
fbb, is_worker, to_flatbuf(fbb, worker_id), getpid(), to_flatbuf(fbb, job_id),
language, port);
language, fbb.CreateString(ip_address), port);
fbb.Finish(message);
// Register the process ID with the raylet.
// NOTE(swang): If raylet exits and we are registered as a worker, we will get killed.

View file

@ -154,13 +154,14 @@ class RayletClient : public PinObjectsInterface,
/// \param job_id The ID of the driver. This is non-nil if the client is a driver.
/// \param language Language of the worker.
/// \param raylet_id This will be populated with the local raylet's ClientID.
/// \param ip_address The IP address of the worker.
/// \param port The port that the worker will listen on for gRPC requests, if
/// any.
RayletClient(boost::asio::io_service &io_service,
std::shared_ptr<ray::rpc::NodeManagerWorkerClient> grpc_client,
const std::string &raylet_socket, const WorkerID &worker_id,
bool is_worker, const JobID &job_id, const Language &language,
ClientID *raylet_id, int port = -1);
ClientID *raylet_id, const std::string &ip_address, int port = -1);
/// Connect to the raylet via grpc only.
///

View file

@ -26,11 +26,13 @@ namespace ray {
namespace raylet {
/// A constructor responsible for initializing the state of a worker.
Worker::Worker(const WorkerID &worker_id, const Language &language, int port,
Worker::Worker(const WorkerID &worker_id, const Language &language,
const std::string &ip_address, int port,
std::shared_ptr<ClientConnection> connection,
rpc::ClientCallManager &client_call_manager)
: worker_id_(worker_id),
language_(language),
ip_address_(ip_address),
port_(port),
connection_(connection),
dead_(false),
@ -39,7 +41,7 @@ Worker::Worker(const WorkerID &worker_id, const Language &language, int port,
is_detached_actor_(false) {
if (port_ > 0) {
rpc::Address addr;
addr.set_ip_address("127.0.0.1");
addr.set_ip_address(ip_address_);
addr.set_port(port_);
rpc_client_ = std::unique_ptr<rpc::CoreWorkerClient>(
new rpc::CoreWorkerClient(addr, client_call_manager_));
@ -67,6 +69,8 @@ void Worker::SetProcess(Process proc) {
Language Worker::GetLanguage() const { return language_; }
const std::string Worker::IpAddress() const { return ip_address_; }
int Worker::Port() const { return port_; }
void Worker::AssignTaskId(const TaskID &task_id) { assigned_task_id_ = task_id; }

View file

@ -38,7 +38,8 @@ class Worker {
public:
/// A constructor that initializes a worker object.
/// NOTE: You MUST manually set the worker process.
Worker(const WorkerID &worker_id, const Language &language, int port,
Worker(const WorkerID &worker_id, const Language &language,
const std::string &ip_address, int port,
std::shared_ptr<ClientConnection> connection,
rpc::ClientCallManager &client_call_manager);
/// A destructor responsible for freeing all worker state.
@ -54,6 +55,7 @@ class Worker {
Process GetProcess() const;
void SetProcess(Process proc);
Language GetLanguage() const;
const std::string IpAddress() const;
int Port() const;
void AssignTaskId(const TaskID &task_id);
const TaskID &GetAssignedTaskId() const;
@ -131,6 +133,8 @@ class Worker {
Process proc_;
/// The language type of this worker.
Language language_;
/// IP address of this worker.
std::string ip_address_;
/// Port that this worker listens on.
/// If port <= 0, this indicates that the worker will not listen to a port.
int port_;

View file

@ -115,7 +115,7 @@ class WorkerPoolTest : public ::testing::Test {
ClientConnection::Create(client_handler, message_handler, std::move(socket),
"worker", {}, error_message_type_);
std::shared_ptr<Worker> worker = std::make_shared<Worker>(
WorkerID::FromRandom(), language, -1, client, client_call_manager_);
WorkerID::FromRandom(), language, "127.0.0.1", -1, client, client_call_manager_);
if (!proc.IsNull()) {
worker->SetProcess(proc);
}

View file

@ -304,6 +304,7 @@ class StreamingWorker {
true, // install_failure_signal_handler
"127.0.0.1", // node_ip_address
node_manager_port, // node_manager_port
"127.0.0.1", // raylet_ip_address
"", // driver_name
"", // stdout_file
"", // stderr_file

View file

@ -318,6 +318,7 @@ class StreamingQueueTestBase : public ::testing::TestWithParam<uint64_t> {
true, // install_failure_signal_handler
"127.0.0.1", // node_ip_address
node_manager_port_, // node_manager_port
"127.0.0.1", // raylet_ip_address
"queue_tests", // driver_name
"", // stdout_file
"", // stderr_file