From d4cae5f63298b71ec21d3b4bcc370e94d1cff48c Mon Sep 17 00:00:00 2001 From: Clark Zinzow Date: Thu, 16 Apr 2020 09:32:24 -0600 Subject: [PATCH] [Core] Added ability to specify different IP addresses for a core worker and its raylet. (#7985) --- python/ray/_raylet.pyx | 5 +- python/ray/includes/libcoreworker.pxd | 1 + python/ray/node.py | 37 +++++++-- python/ray/parameter.py | 4 + python/ray/services.py | 83 +++++++++++++------ python/ray/worker.py | 15 ++-- python/ray/workers/default_worker.py | 11 +++ src/ray/core_worker/core_worker.cc | 4 +- src/ray/core_worker/core_worker.h | 2 + .../java/io_ray_runtime_RayNativeRuntime.cc | 7 +- src/ray/core_worker/test/core_worker_test.cc | 1 + src/ray/core_worker/test/mock_worker.cc | 1 + src/ray/gcs/pb_util.h | 7 +- src/ray/gcs/redis_accessor.cc | 4 +- src/ray/gcs/test/gcs_test_util.h | 8 +- .../gcs/test/redis_job_info_accessor_test.cc | 2 +- src/ray/protobuf/gcs.proto | 4 +- src/ray/raylet/format/node_manager.fbs | 2 + src/ray/raylet/node_manager.cc | 17 ++-- src/ray/raylet/raylet_client.cc | 5 +- src/ray/raylet/raylet_client.h | 3 +- src/ray/raylet/worker.cc | 8 +- src/ray/raylet/worker.h | 6 +- src/ray/raylet/worker_pool_test.cc | 2 +- streaming/src/test/mock_actor.cc | 1 + streaming/src/test/queue_tests_base.h | 1 + 26 files changed, 170 insertions(+), 71 deletions(-) diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 31e74d820..e537df4da 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -630,8 +630,8 @@ cdef class CoreWorker: def __cinit__(self, is_driver, store_socket, raylet_socket, JobID job_id, GcsClientOptions gcs_options, log_dir, - node_ip_address, node_manager_port, local_mode, - driver_name, stdout_file, stderr_file): + node_ip_address, node_manager_port, raylet_ip_address, + local_mode, driver_name, stdout_file, stderr_file): self.is_driver = is_driver self.is_local_mode = local_mode @@ -647,6 +647,7 @@ cdef class CoreWorker: options.install_failure_signal_handler = True options.node_ip_address = node_ip_address.encode("utf-8") options.node_manager_port = node_manager_port + options.raylet_ip_address = raylet_ip_address.encode("utf-8") options.driver_name = driver_name options.stdout_file = stdout_file options.stderr_file = stderr_file diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index d54e7a015..477678997 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -195,6 +195,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: c_bool install_failure_signal_handler c_string node_ip_address int node_manager_port + c_string raylet_ip_address c_string driver_name c_string stdout_file c_string stderr_file diff --git a/python/ray/node.py b/python/ray/node.py index 884c3f220..12fbd5f1e 100644 --- a/python/ray/node.py +++ b/python/ray/node.py @@ -80,6 +80,19 @@ class Node: node_ip_address = ray.services.get_node_ip_address() self._node_ip_address = node_ip_address + if ray_params.raylet_ip_address: + raylet_ip_address = ray_params.raylet_ip_address + else: + raylet_ip_address = node_ip_address + + if raylet_ip_address != node_ip_address and (not connect_only or head): + raise ValueError( + "The raylet IP address should only be different than the node " + "IP address when connecting to an existing raylet; i.e., when " + "head=False and connect_only=True.") + + self._raylet_ip_address = raylet_ip_address + ray_params.update_if_absent( include_log_monitor=True, resources={}, @@ -122,7 +135,7 @@ class Node: # from Redis. address_info = ray.services.get_address_info_from_redis( self.redis_address, - self._node_ip_address, + self._raylet_ip_address, redis_password=self.redis_password) self._plasma_store_socket_name = address_info[ "object_store_address"] @@ -229,9 +242,14 @@ class Node: @property def node_ip_address(self): - """Get the cluster Redis address.""" + """Get the IP address of this node.""" return self._node_ip_address + @property + def raylet_ip_address(self): + """Get the IP address of the raylet that this node connects to.""" + return self._raylet_ip_address + @property def address(self): """Get the cluster address.""" @@ -287,6 +305,7 @@ class Node: """Get a dictionary of addresses.""" return { "node_ip_address": self._node_ip_address, + "raylet_ip_address": self._raylet_ip_address, "redis_address": self._redis_address, "object_store_address": self._plasma_store_socket_name, "raylet_socket_name": self._raylet_socket_name, @@ -429,7 +448,7 @@ class Node: assert ray_constants.PROCESS_TYPE_REAPER not in self.all_processes if process_info is not None: self.all_processes[ray_constants.PROCESS_TYPE_REAPER] = [ - process_info + process_info, ] def start_redis(self): @@ -469,7 +488,7 @@ class Node: fate_share=self.kernel_fate_share) assert ray_constants.PROCESS_TYPE_LOG_MONITOR not in self.all_processes self.all_processes[ray_constants.PROCESS_TYPE_LOG_MONITOR] = [ - process_info + process_info, ] def start_reporter(self): @@ -484,7 +503,7 @@ class Node: assert ray_constants.PROCESS_TYPE_REPORTER not in self.all_processes if process_info is not None: self.all_processes[ray_constants.PROCESS_TYPE_REPORTER] = [ - process_info + process_info, ] def start_dashboard(self, require_webui): @@ -508,7 +527,7 @@ class Node: assert ray_constants.PROCESS_TYPE_DASHBOARD not in self.all_processes if process_info is not None: self.all_processes[ray_constants.PROCESS_TYPE_DASHBOARD] = [ - process_info + process_info, ] redis_client = self.create_redis_client() redis_client.hmset("webui", {"url": self._webui_url}) @@ -527,7 +546,7 @@ class Node: assert ( ray_constants.PROCESS_TYPE_PLASMA_STORE not in self.all_processes) self.all_processes[ray_constants.PROCESS_TYPE_PLASMA_STORE] = [ - process_info + process_info, ] def start_gcs_server(self): @@ -544,7 +563,7 @@ class Node: assert ( ray_constants.PROCESS_TYPE_GCS_SERVER not in self.all_processes) self.all_processes[ray_constants.PROCESS_TYPE_GCS_SERVER] = [ - process_info + process_info, ] def start_raylet(self, use_valgrind=False, use_profiler=False): @@ -617,7 +636,7 @@ class Node: assert (ray_constants.PROCESS_TYPE_RAYLET_MONITOR not in self.all_processes) self.all_processes[ray_constants.PROCESS_TYPE_RAYLET_MONITOR] = [ - process_info + process_info, ] def start_head_processes(self): diff --git a/python/ray/parameter.py b/python/ray/parameter.py index 6cd4f6d4d..031b337bd 100644 --- a/python/ray/parameter.py +++ b/python/ray/parameter.py @@ -33,6 +33,8 @@ class RayParams: object_manager_port int: The port to use for the object manager. node_manager_port: The port to use for the node manager. node_ip_address (str): The IP address of the node that we are on. + raylet_ip_address (str): The IP address of the raylet that this node + connects to. object_id_seed (int): Used to seed the deterministic generation of object IDs. The same value can be used across multiple runs of the same job in order to generate the object IDs in a consistent @@ -95,6 +97,7 @@ class RayParams: object_manager_port=None, node_manager_port=None, node_ip_address=None, + raylet_ip_address=None, object_id_seed=None, driver_mode=None, redirect_worker_output=None, @@ -131,6 +134,7 @@ class RayParams: self.object_manager_port = object_manager_port self.node_manager_port = node_manager_port self.node_ip_address = node_ip_address + self.raylet_ip_address = raylet_ip_address self.driver_mode = driver_mode self.redirect_worker_output = redirect_worker_output self.redirect_output = redirect_output diff --git a/python/ray/services.py b/python/ray/services.py index 18c788269..64f2b7504 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -73,8 +73,14 @@ DEFAULT_JAVA_WORKER_CLASSPATH = [ logger = logging.getLogger(__name__) ProcessInfo = collections.namedtuple("ProcessInfo", [ - "process", "stdout_file", "stderr_file", "use_valgrind", "use_gdb", - "use_valgrind_profiler", "use_perftools_profiler", "use_tmux" + "process", + "stdout_file", + "stderr_file", + "use_valgrind", + "use_gdb", + "use_valgrind_profiler", + "use_perftools_profiler", + "use_tmux", ]) @@ -189,7 +195,7 @@ def get_address_info_from_redis_helper(redis_address, return { "object_store_address": relevant_client["ObjectStoreSocketName"], "raylet_socket_name": relevant_client["RayletSocketName"], - "node_manager_port": relevant_client["NodeManagerPort"] + "node_manager_port": relevant_client["NodeManagerPort"], } @@ -430,9 +436,12 @@ def start_ray_process(command, logger.info("Detected environment variable '%s'.", gdb_env_var) use_gdb = True - if sum( - [use_gdb, use_valgrind, use_valgrind_profiler, use_perftools_profiler - ]) > 1: + if sum([ + use_gdb, + use_valgrind, + use_valgrind_profiler, + use_perftools_profiler, + ]) > 1: raise ValueError( "At most one of the 'use_gdb', 'use_valgrind', " "'use_valgrind_profiler', and 'use_perftools_profiler' flags can " @@ -463,9 +472,12 @@ def start_ray_process(command, if use_valgrind: command = [ - "valgrind", "--track-origins=yes", "--leak-check=full", - "--show-leak-kinds=all", "--leak-check-heuristics=stdstring", - "--error-exitcode=1" + "valgrind", + "--track-origins=yes", + "--leak-check=full", + "--show-leak-kinds=all", + "--leak-check-heuristics=stdstring", + "--error-exitcode=1", ] + command if use_valgrind_profiler: @@ -1023,9 +1035,11 @@ def start_log_monitor(redis_address, log_monitor_filepath = os.path.join( os.path.dirname(os.path.abspath(__file__)), "log_monitor.py") command = [ - sys.executable, "-u", log_monitor_filepath, + sys.executable, + "-u", + log_monitor_filepath, "--redis-address={}".format(redis_address), - "--logs-dir={}".format(logs_dir) + "--logs-dir={}".format(logs_dir), ] if redis_password: command += ["--redis-password", redis_password] @@ -1059,8 +1073,10 @@ def start_reporter(redis_address, reporter_filepath = os.path.join( os.path.dirname(os.path.abspath(__file__)), "reporter.py") command = [ - sys.executable, "-u", reporter_filepath, - "--redis-address={}".format(redis_address) + sys.executable, + "-u", + reporter_filepath, + "--redis-address={}".format(redis_address), ] if redis_password: command += ["--redis-password", redis_password] @@ -1114,9 +1130,13 @@ def start_dashboard(require_webui, dashboard_filepath = os.path.join( os.path.dirname(os.path.abspath(__file__)), "dashboard/dashboard.py") command = [ - sys.executable, "-u", dashboard_filepath, "--host={}".format(host), - "--port={}".format(port), "--redis-address={}".format(redis_address), - "--temp-dir={}".format(temp_dir) + sys.executable, + "-u", + dashboard_filepath, + "--host={}".format(host), + "--port={}".format(port), + "--redis-address={}".format(redis_address), + "--temp-dir={}".format(temp_dir), ] if redis_password: command += ["--redis-password", redis_password] @@ -1290,13 +1310,15 @@ def start_raylet(redis_address, # Create the command that the Raylet will use to start workers. start_worker_command = [ - sys.executable, worker_path, + sys.executable, + worker_path, "--node-ip-address={}".format(node_ip_address), "--node-manager-port={}".format(node_manager_port), "--object-store-name={}".format(plasma_store_name), "--raylet-name={}".format(raylet_name), "--redis-address={}".format(redis_address), - "--config-list={}".format(config_str), "--temp-dir={}".format(temp_dir) + "--config-list={}".format(config_str), + "--temp-dir={}".format(temp_dir), ] if redis_password: start_worker_command += ["--redis-password={}".format(redis_password)] @@ -1540,8 +1562,11 @@ def _start_plasma_store(plasma_store_memory, plasma_store_memory = int(plasma_store_memory) command = [ - PLASMA_STORE_EXECUTABLE, "-s", socket_name, "-m", - str(plasma_store_memory) + PLASMA_STORE_EXECUTABLE, + "-s", + socket_name, + "-m", + str(plasma_store_memory), ] if plasma_directory is not None: command += ["-d", plasma_directory] @@ -1617,6 +1642,7 @@ def start_worker(node_ip_address, redis_address, worker_path, temp_dir, + raylet_ip_address=None, stdout_file=None, stderr_file=None, fate_share=None): @@ -1631,6 +1657,8 @@ def start_worker(node_ip_address, worker_path (str): The path of the source code which the worker process will run. temp_dir (str): The path of the temp dir. + raylet_ip_address (str): The IP address of the worker's raylet. If not + provided, it defaults to the node_ip_address. stdout_file: A file handle opened for writing to redirect stdout to. If no redirection should happen, then this should be None. stderr_file: A file handle opened for writing to redirect stderr to. If @@ -1640,12 +1668,17 @@ def start_worker(node_ip_address, ProcessInfo for the process that was started. """ command = [ - sys.executable, "-u", worker_path, + sys.executable, + "-u", + worker_path, "--node-ip-address=" + node_ip_address, "--object-store-name=" + object_store_name, "--raylet-name=" + raylet_name, - "--redis-address=" + str(redis_address), "--temp-dir=" + temp_dir + "--redis-address=" + str(redis_address), + "--temp-dir=" + temp_dir, ] + if raylet_ip_address is not None: + command.append("--raylet-ip-address=" + raylet_ip_address) process_info = start_ray_process( command, ray_constants.PROCESS_TYPE_WORKER, @@ -1678,8 +1711,10 @@ def start_monitor(redis_address, monitor_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), "monitor.py") command = [ - sys.executable, "-u", monitor_path, - "--redis-address=" + str(redis_address) + sys.executable, + "-u", + monitor_path, + "--redis-address=" + str(redis_address), ] if autoscaling_config: command.append("--autoscaling-config=" + str(autoscaling_config)) diff --git a/python/ray/worker.py b/python/ray/worker.py index e699b7230..2209761d4 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -355,7 +355,7 @@ class Worker: "job_id": self.current_job_id.binary(), "function_id": function_to_run_id, "function": pickled_function, - "run_on_other_drivers": str(run_on_other_drivers) + "run_on_other_drivers": str(run_on_other_drivers), }) self.redis_client.rpush("Exports", key) # TODO(rkn): If the worker fails after it calls setnx and before it @@ -689,6 +689,8 @@ def init(address=None, if node_ip_address is not None: node_ip_address = services.address_to_ip(node_ip_address) + raylet_ip_address = node_ip_address + _internal_config = (json.loads(_internal_config) if _internal_config else {}) # Set the internal config options for LRU eviction. @@ -708,6 +710,7 @@ def init(address=None, redis_address=redis_address, redis_port=redis_port, node_ip_address=node_ip_address, + raylet_ip_address=raylet_ip_address, object_id_seed=object_id_seed, driver_mode=driver_mode, redirect_worker_output=redirect_worker_output, @@ -788,6 +791,7 @@ def init(address=None, # In this case, we only need to connect the node. ray_params = ray.parameter.RayParams( node_ip_address=node_ip_address, + raylet_ip_address=raylet_ip_address, redis_address=redis_address, redis_password=redis_password, object_id_seed=object_id_seed, @@ -1053,7 +1057,7 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped): job_id = error_data.job_id if job_id not in [ worker.current_job_id.binary(), - JobID.nil().binary() + JobID.nil().binary(), ]: continue @@ -1226,6 +1230,7 @@ def connect(node, int(redis_port), node.redis_password, ) + worker.core_worker = ray._raylet.CoreWorker( (mode == SCRIPT_MODE or mode == LOCAL_MODE), node.plasma_store_socket_name, @@ -1235,6 +1240,7 @@ def connect(node, node.get_logs_dir_path(), node.node_ip_address, node.node_manager_port, + node.raylet_ip_address, (mode == LOCAL_MODE), driver_name, log_stdout_file_name, @@ -1575,9 +1581,8 @@ def wait(object_ids, num_returns=1, timeout=None): blocking_wait_inside_async_warned = True if isinstance(object_ids, ObjectID): - raise TypeError( - "wait() expected a list of ray.ObjectID, got a single ray.ObjectID" - ) + raise TypeError("wait() expected a list of ray.ObjectID, got a single " + "ray.ObjectID") if not isinstance(object_ids, list): raise TypeError( diff --git a/python/ray/workers/default_worker.py b/python/ray/workers/default_worker.py index 5f0285e27..8587c62ae 100644 --- a/python/ray/workers/default_worker.py +++ b/python/ray/workers/default_worker.py @@ -21,6 +21,12 @@ parser.add_argument( required=True, type=int, help="the port of the worker's node") +parser.add_argument( + "--raylet-ip-address", + required=False, + type=str, + default=None, + help="the ip address of the worker's raylet") parser.add_argument( "--redis-address", required=True, @@ -89,8 +95,13 @@ if __name__ == "__main__": internal_config[config_list[i]] = config_list[i + 1] i += 2 + raylet_ip_address = args.raylet_ip_address + if raylet_ip_address is None: + raylet_ip_address = args.node_ip_address + ray_params = RayParams( node_ip_address=args.node_ip_address, + raylet_ip_address=raylet_ip_address, node_manager_port=args.node_manager_port, redis_address=args.redis_address, redis_password=args.redis_password, diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 3d54c0f4e..81840996b 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -313,13 +313,13 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ // so that the worker (java/python .etc) can retrieve and handle the error // instead of crashing. auto grpc_client = rpc::NodeManagerWorkerClient::make( - options_.node_ip_address, options_.node_manager_port, *client_call_manager_); + options_.raylet_ip_address, options_.node_manager_port, *client_call_manager_); ClientID local_raylet_id; local_raylet_client_ = std::shared_ptr(new raylet::RayletClient( io_service_, std::move(grpc_client), options_.raylet_socket, GetWorkerID(), (options_.worker_type == ray::WorkerType::WORKER), worker_context_.GetCurrentJobID(), options_.language, &local_raylet_id, - core_worker_server_.GetPort())); + options_.node_ip_address, core_worker_server_.GetPort())); connected_ = true; // Set our own address. diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 1a803cc63..733f91484 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -84,6 +84,8 @@ struct CoreWorkerOptions { std::string node_ip_address; /// Port of the local raylet. int node_manager_port; + /// IP address of the raylet. + std::string raylet_ip_address; /// The name of the driver. std::string driver_name; /// The stdout file of this process. diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc index 94acbad06..e441f0d18 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc @@ -13,8 +13,11 @@ // limitations under the License. #include "ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.h" + #include + #include + #include "ray/common/id.h" #include "ray/core_worker/core_worker.h" #include "ray/core_worker/lib/java/jni_utils.h" @@ -37,7 +40,6 @@ inline ray::gcs::GcsClientOptions ToGcsClientOptions(JNIEnv *env, extern "C" { #endif - JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize( JNIEnv *env, jclass, jint workerMode, jstring nodeIpAddress, jint nodeManagerPort, jstring driverName, jstring storeSocket, jstring rayletSocket, jbyteArray jobId, @@ -112,6 +114,7 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize( false, // install_failure_signal_handler JavaStringToNativeString(env, nodeIpAddress), // node_ip_address static_cast(nodeManagerPort), // node_manager_port + JavaStringToNativeString(env, nodeIpAddress), // raylet_ip_address JavaStringToNativeString(env, driverName), // driver_name "", // stdout_file "", // stderr_file @@ -135,7 +138,7 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeRunTaskExecuto } JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeShutdown(JNIEnv *env, - jclass o) { + jclass o) { ray::CoreWorkerProcess::Shutdown(); } diff --git a/src/ray/core_worker/test/core_worker_test.cc b/src/ray/core_worker/test/core_worker_test.cc index 3aba27809..7f5369710 100644 --- a/src/ray/core_worker/test/core_worker_test.cc +++ b/src/ray/core_worker/test/core_worker_test.cc @@ -263,6 +263,7 @@ class CoreWorkerTest : public ::testing::Test { true, // install_failure_signal_handler "127.0.0.1", // node_ip_address node_manager_port, // node_manager_port + "127.0.0.1", // raylet_ip_address "core_worker_test", // driver_name "", // stdout_file "", // stderr_file diff --git a/src/ray/core_worker/test/mock_worker.cc b/src/ray/core_worker/test/mock_worker.cc index 8c28b87ed..cdb82187f 100644 --- a/src/ray/core_worker/test/mock_worker.cc +++ b/src/ray/core_worker/test/mock_worker.cc @@ -45,6 +45,7 @@ class MockWorker { true, // install_failure_signal_handler "127.0.0.1", // node_ip_address node_manager_port, // node_manager_port + "127.0.0.1", // raylet_ip_address "", // driver_name "", // stdout_file "", // stderr_file diff --git a/src/ray/gcs/pb_util.h b/src/ray/gcs/pb_util.h index 94b416967..66959b721 100644 --- a/src/ray/gcs/pb_util.h +++ b/src/ray/gcs/pb_util.h @@ -16,6 +16,7 @@ #define RAY_GCS_PB_UTIL_H #include + #include "ray/common/id.h" #include "ray/common/task/task_spec.h" #include "ray/protobuf/gcs.pb.h" @@ -29,17 +30,17 @@ namespace gcs { /// \param job_id The ID of job that need to be registered or updated. /// \param is_dead Whether the driver of this job is dead. /// \param timestamp The UNIX timestamp of corresponding to this event. -/// \param node_manager_address Address of the node this job was started on. +/// \param driver_ip_address IP address of the driver that started this job. /// \param driver_pid Process ID of the driver running this job. /// \return The job table data created by this method. inline std::shared_ptr CreateJobTableData( const ray::JobID &job_id, bool is_dead, int64_t timestamp, - const std::string &node_manager_address, int64_t driver_pid) { + const std::string &driver_ip_address, int64_t driver_pid) { auto job_info_ptr = std::make_shared(); job_info_ptr->set_job_id(job_id.Binary()); job_info_ptr->set_is_dead(is_dead); job_info_ptr->set_timestamp(timestamp); - job_info_ptr->set_node_manager_address(node_manager_address); + job_info_ptr->set_driver_ip_address(driver_ip_address); job_info_ptr->set_driver_pid(driver_pid); return job_info_ptr; } diff --git a/src/ray/gcs/redis_accessor.cc b/src/ray/gcs/redis_accessor.cc index 1eeb191ef..6eb7c8df2 100644 --- a/src/ray/gcs/redis_accessor.cc +++ b/src/ray/gcs/redis_accessor.cc @@ -13,7 +13,9 @@ // limitations under the License. #include "ray/gcs/redis_accessor.h" + #include + #include "ray/gcs/pb_util.h" #include "ray/gcs/redis_gcs_client.h" #include "ray/util/logging.h" @@ -304,7 +306,7 @@ Status RedisJobInfoAccessor::AsyncMarkFinished(const JobID &job_id, const StatusCallback &callback) { std::shared_ptr data_ptr = CreateJobTableData(job_id, /*is_dead*/ true, /*time_stamp*/ std::time(nullptr), - /*node_manager_address*/ "", /*driver_pid*/ -1); + /*driver_ip_address*/ "", /*driver_pid*/ -1); return DoAsyncAppend(data_ptr, callback); } diff --git a/src/ray/gcs/test/gcs_test_util.h b/src/ray/gcs/test/gcs_test_util.h index 9a6a837bf..c7a310835 100644 --- a/src/ray/gcs/test/gcs_test_util.h +++ b/src/ray/gcs/test/gcs_test_util.h @@ -15,6 +15,9 @@ #ifndef RAY_GCS_TEST_UTIL_H #define RAY_GCS_TEST_UTIL_H +#include +#include + #include "src/ray/common/task/task.h" #include "src/ray/common/task/task_util.h" #include "src/ray/common/test_util.h" @@ -23,9 +26,6 @@ #include "src/ray/gcs/gcs_server/gcs_node_manager.h" #include "src/ray/util/asio_util.h" -#include -#include - namespace ray { struct Mocker { @@ -64,7 +64,7 @@ struct Mocker { job_table_data->set_job_id(job_id.Binary()); job_table_data->set_is_dead(false); job_table_data->set_timestamp(std::time(nullptr)); - job_table_data->set_node_manager_address("127.0.0.1"); + job_table_data->set_driver_ip_address("127.0.0.1"); job_table_data->set_driver_pid(5667L); return job_table_data; } diff --git a/src/ray/gcs/test/redis_job_info_accessor_test.cc b/src/ray/gcs/test/redis_job_info_accessor_test.cc index 1c55adeab..f1752f224 100644 --- a/src/ray/gcs/test/redis_job_info_accessor_test.cc +++ b/src/ray/gcs/test/redis_job_info_accessor_test.cc @@ -31,7 +31,7 @@ class RedisJobInfoAccessorTest : public AccessorTestBase { JobID job_id = JobID::FromInt(i); std::shared_ptr job_data_ptr = CreateJobTableData(job_id, /*is_dead*/ false, /*timestamp*/ 1, - /*node_manager_address*/ "", /*driver_pid*/ i); + /*driver_ip_address*/ "", /*driver_pid*/ i); id_to_data_[job_id] = job_data_ptr; } } diff --git a/src/ray/protobuf/gcs.proto b/src/ray/protobuf/gcs.proto index 47184d9f1..3829a74e9 100644 --- a/src/ray/protobuf/gcs.proto +++ b/src/ray/protobuf/gcs.proto @@ -256,8 +256,8 @@ message JobTableData { bool is_dead = 2; // The UNIX timestamp corresponding to this event (job added or removed). int64 timestamp = 3; - // IP of the node this job was started on. - string node_manager_address = 4; + // IP address of the driver that started this job. + string driver_ip_address = 4; // Process ID of the driver running this job. int64 driver_pid = 5; } diff --git a/src/ray/raylet/format/node_manager.fbs b/src/ray/raylet/format/node_manager.fbs index 65b02e630..3942cae5b 100644 --- a/src/ray/raylet/format/node_manager.fbs +++ b/src/ray/raylet/format/node_manager.fbs @@ -151,6 +151,8 @@ table RegisterClientRequest { // Language of this worker. // TODO(hchen): Use `Language` in `common.proto`. language: int; + // IP address of this worker. + ip_address: string; // Port that this worker is listening on. port: int; } diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index e83c867d2..0c3b0d1f5 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -1098,8 +1098,9 @@ void NodeManager::ProcessRegisterClientRequestMessage( Language language = static_cast(message->language()); WorkerID worker_id = from_flatbuf(*message->worker_id()); pid_t pid = message->worker_pid(); - auto worker = std::make_shared(worker_id, language, message->port(), client, - client_call_manager_); + std::string worker_ip_address = string_from_flatbuf(*message->ip_address()); + auto worker = std::make_shared(worker_id, language, worker_ip_address, + message->port(), client, client_call_manager_); if (message->is_worker()) { // Register the new worker. if (worker_pool_.RegisterWorker(worker, pid).ok()) { @@ -1117,9 +1118,8 @@ void NodeManager::ProcessRegisterClientRequestMessage( Status status = worker_pool_.RegisterDriver(worker); if (status.ok()) { local_queues_.AddDriverTaskId(driver_task_id); - auto job_data_ptr = - gcs::CreateJobTableData(job_id, /*is_dead*/ false, std::time(nullptr), - initial_config_.node_manager_address, pid); + auto job_data_ptr = gcs::CreateJobTableData( + job_id, /*is_dead*/ false, std::time(nullptr), worker_ip_address, pid); RAY_CHECK_OK(gcs_client_->Jobs().AsyncAdd(job_data_ptr, nullptr)); } } @@ -1260,8 +1260,8 @@ void NodeManager::ProcessDisconnectClientMessage( // Publish the worker failure. auto worker_failure_data_ptr = gcs::CreateWorkerFailureData( - self_node_id_, worker->WorkerId(), initial_config_.node_manager_address, - worker->Port(), time(nullptr), intentional_disconnect); + self_node_id_, worker->WorkerId(), worker->IpAddress(), worker->Port(), + time(nullptr), intentional_disconnect); RAY_CHECK_OK(gcs_client_->Workers().AsyncReportWorkerFailure(worker_failure_data_ptr, nullptr)); } @@ -1687,8 +1687,7 @@ void NodeManager::HandleRequestWorkerLease(const rpc::RequestWorkerLeaseRequest ClientID spillback_to, std::string address, int port) { if (worker != nullptr) { - reply->mutable_worker_address()->set_ip_address( - initial_config_.node_manager_address); + reply->mutable_worker_address()->set_ip_address(worker->IpAddress()); reply->mutable_worker_address()->set_port(worker->Port()); reply->mutable_worker_address()->set_worker_id(worker->WorkerId().Binary()); reply->mutable_worker_address()->set_raylet_id(self_node_id_.Binary()); diff --git a/src/ray/raylet/raylet_client.cc b/src/ray/raylet/raylet_client.cc index a68b9c35e..80c9be933 100644 --- a/src/ray/raylet/raylet_client.cc +++ b/src/ray/raylet/raylet_client.cc @@ -165,7 +165,8 @@ raylet::RayletClient::RayletClient( boost::asio::io_service &io_service, std::shared_ptr grpc_client, const std::string &raylet_socket, const WorkerID &worker_id, bool is_worker, - const JobID &job_id, const Language &language, ClientID *raylet_id, int port) + const JobID &job_id, const Language &language, ClientID *raylet_id, + const std::string &ip_address, int port) : grpc_client_(std::move(grpc_client)), worker_id_(worker_id), job_id_(job_id) { // For C++14, we could use std::make_unique conn_ = std::unique_ptr( @@ -174,7 +175,7 @@ raylet::RayletClient::RayletClient( flatbuffers::FlatBufferBuilder fbb; auto message = protocol::CreateRegisterClientRequest( fbb, is_worker, to_flatbuf(fbb, worker_id), getpid(), to_flatbuf(fbb, job_id), - language, port); + language, fbb.CreateString(ip_address), port); fbb.Finish(message); // Register the process ID with the raylet. // NOTE(swang): If raylet exits and we are registered as a worker, we will get killed. diff --git a/src/ray/raylet/raylet_client.h b/src/ray/raylet/raylet_client.h index e82d6cc4d..c0c7a2026 100644 --- a/src/ray/raylet/raylet_client.h +++ b/src/ray/raylet/raylet_client.h @@ -154,13 +154,14 @@ class RayletClient : public PinObjectsInterface, /// \param job_id The ID of the driver. This is non-nil if the client is a driver. /// \param language Language of the worker. /// \param raylet_id This will be populated with the local raylet's ClientID. + /// \param ip_address The IP address of the worker. /// \param port The port that the worker will listen on for gRPC requests, if /// any. RayletClient(boost::asio::io_service &io_service, std::shared_ptr grpc_client, const std::string &raylet_socket, const WorkerID &worker_id, bool is_worker, const JobID &job_id, const Language &language, - ClientID *raylet_id, int port = -1); + ClientID *raylet_id, const std::string &ip_address, int port = -1); /// Connect to the raylet via grpc only. /// diff --git a/src/ray/raylet/worker.cc b/src/ray/raylet/worker.cc index 6d92c31c7..92fb2e3b8 100644 --- a/src/ray/raylet/worker.cc +++ b/src/ray/raylet/worker.cc @@ -26,11 +26,13 @@ namespace ray { namespace raylet { /// A constructor responsible for initializing the state of a worker. -Worker::Worker(const WorkerID &worker_id, const Language &language, int port, +Worker::Worker(const WorkerID &worker_id, const Language &language, + const std::string &ip_address, int port, std::shared_ptr connection, rpc::ClientCallManager &client_call_manager) : worker_id_(worker_id), language_(language), + ip_address_(ip_address), port_(port), connection_(connection), dead_(false), @@ -39,7 +41,7 @@ Worker::Worker(const WorkerID &worker_id, const Language &language, int port, is_detached_actor_(false) { if (port_ > 0) { rpc::Address addr; - addr.set_ip_address("127.0.0.1"); + addr.set_ip_address(ip_address_); addr.set_port(port_); rpc_client_ = std::unique_ptr( new rpc::CoreWorkerClient(addr, client_call_manager_)); @@ -67,6 +69,8 @@ void Worker::SetProcess(Process proc) { Language Worker::GetLanguage() const { return language_; } +const std::string Worker::IpAddress() const { return ip_address_; } + int Worker::Port() const { return port_; } void Worker::AssignTaskId(const TaskID &task_id) { assigned_task_id_ = task_id; } diff --git a/src/ray/raylet/worker.h b/src/ray/raylet/worker.h index 7b6b4e084..02c42eb14 100644 --- a/src/ray/raylet/worker.h +++ b/src/ray/raylet/worker.h @@ -38,7 +38,8 @@ class Worker { public: /// A constructor that initializes a worker object. /// NOTE: You MUST manually set the worker process. - Worker(const WorkerID &worker_id, const Language &language, int port, + Worker(const WorkerID &worker_id, const Language &language, + const std::string &ip_address, int port, std::shared_ptr connection, rpc::ClientCallManager &client_call_manager); /// A destructor responsible for freeing all worker state. @@ -54,6 +55,7 @@ class Worker { Process GetProcess() const; void SetProcess(Process proc); Language GetLanguage() const; + const std::string IpAddress() const; int Port() const; void AssignTaskId(const TaskID &task_id); const TaskID &GetAssignedTaskId() const; @@ -131,6 +133,8 @@ class Worker { Process proc_; /// The language type of this worker. Language language_; + /// IP address of this worker. + std::string ip_address_; /// Port that this worker listens on. /// If port <= 0, this indicates that the worker will not listen to a port. int port_; diff --git a/src/ray/raylet/worker_pool_test.cc b/src/ray/raylet/worker_pool_test.cc index e48987b61..3c1d9e4fa 100644 --- a/src/ray/raylet/worker_pool_test.cc +++ b/src/ray/raylet/worker_pool_test.cc @@ -115,7 +115,7 @@ class WorkerPoolTest : public ::testing::Test { ClientConnection::Create(client_handler, message_handler, std::move(socket), "worker", {}, error_message_type_); std::shared_ptr worker = std::make_shared( - WorkerID::FromRandom(), language, -1, client, client_call_manager_); + WorkerID::FromRandom(), language, "127.0.0.1", -1, client, client_call_manager_); if (!proc.IsNull()) { worker->SetProcess(proc); } diff --git a/streaming/src/test/mock_actor.cc b/streaming/src/test/mock_actor.cc index f5997d8d6..266a4897b 100644 --- a/streaming/src/test/mock_actor.cc +++ b/streaming/src/test/mock_actor.cc @@ -304,6 +304,7 @@ class StreamingWorker { true, // install_failure_signal_handler "127.0.0.1", // node_ip_address node_manager_port, // node_manager_port + "127.0.0.1", // raylet_ip_address "", // driver_name "", // stdout_file "", // stderr_file diff --git a/streaming/src/test/queue_tests_base.h b/streaming/src/test/queue_tests_base.h index f1b2029c6..8e7ef168e 100644 --- a/streaming/src/test/queue_tests_base.h +++ b/streaming/src/test/queue_tests_base.h @@ -318,6 +318,7 @@ class StreamingQueueTestBase : public ::testing::TestWithParam { true, // install_failure_signal_handler "127.0.0.1", // node_ip_address node_manager_port_, // node_manager_port + "127.0.0.1", // raylet_ip_address "queue_tests", // driver_name "", // stdout_file "", // stderr_file