Use gRPC for communication from worker to local raylet (task submission and direct actor args only) (#6118)

* Skeleton for SubmitTask proto

* Pass through node manager port, connect in raylet client

* Switch task submission to gRPC

* Check that the port is not already in use

* doc

* Remove default port, set port randomly from driver

* update

* Fix test

* Fix object manager test
Stephanie Wang 2019-11-11 21:17:25 -08:00 committed by GitHub
parent f48293f96d
commit 35d177f459
23 changed files with 257 additions and 93 deletions

@@ -687,7 +687,7 @@ cdef class CoreWorker:
def __cinit__(self, is_driver, store_socket, raylet_socket,
JobID job_id, GcsClientOptions gcs_options, log_dir,
node_ip_address):
node_ip_address, node_manager_port):
assert pyarrow is not None, ("Expected pyarrow to be imported from "
"outside _raylet. See __init__.py for "
"details.")
@@ -697,8 +697,8 @@ cdef class CoreWorker:
LANGUAGE_PYTHON, store_socket.encode("ascii"),
raylet_socket.encode("ascii"), job_id.native(),
gcs_options.native()[0], log_dir.encode("utf-8"),
node_ip_address.encode("utf-8"), task_execution_handler,
check_signals, exit_handler))
node_ip_address.encode("utf-8"), node_manager_port,
task_execution_handler, check_signals, exit_handler))
def disconnect(self):
with nogil:

@@ -55,6 +55,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
const c_string &raylet_socket, const CJobID &job_id,
const CGcsClientOptions &gcs_options,
const c_string &log_dir, const c_string &node_ip_address,
int node_manager_port,
CRayStatus (
CTaskType task_type,
const CRayFunction &ray_function,

@@ -10,6 +10,7 @@ import json
import os
import logging
import signal
import socket
import sys
import tempfile
import threading
@@ -117,7 +118,8 @@ class Node(object):
# If user does not provide the socket name, get it from Redis.
if (self._plasma_store_socket_name is None
or self._raylet_socket_name is None):
or self._raylet_socket_name is None
or self._ray_params.node_manager_port is None):
# Get the address info of the processes to connect to
# from Redis.
address_info = ray.services.get_address_info_from_redis(
@@ -127,6 +129,8 @@ class Node(object):
self._plasma_store_socket_name = address_info[
"object_store_address"]
self._raylet_socket_name = address_info["raylet_socket_name"]
self._ray_params.node_manager_port = address_info[
"node_manager_port"]
else:
# If the user specified a socket name, use it.
self._plasma_store_socket_name = self._prepare_socket_file(
@@ -144,6 +148,16 @@ class Node(object):
ray_params.include_java = (
ray.services.include_java_from_redis(redis_client))
if head or not connect_only:
# We need to start a local raylet.
if (self._ray_params.node_manager_port is None
or self._ray_params.node_manager_port == 0):
# No port specified. Pick a random port for the raylet to use.
# NOTE: There is a possible but unlikely race condition where
# the port is bound by another process between now and when the
# raylet starts.
self._ray_params.node_manager_port = self._get_unused_port()
# Start processes.
if head:
self.start_head_processes()
@@ -294,6 +308,11 @@ class Node(object):
"""Get the node's raylet socket name."""
return self._raylet_socket_name
@property
def node_manager_port(self):
"""Get the node manager's port."""
return self._ray_params.node_manager_port
@property
def address_info(self):
"""Get a dictionary of addresses."""
@@ -390,6 +409,13 @@ class Node(object):
log_stderr_file = open(log_stderr, "a", buffering=1)
return log_stdout_file, log_stderr_file
def _get_unused_port(self):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(("", 0))
port = s.getsockname()[1]
s.close()
return port
def _prepare_socket_file(self, socket_path, default_prefix):
"""Prepare the socket file for raylet and plasma.
@@ -508,6 +534,7 @@ class Node(object):
process_info = ray.services.start_raylet(
self._redis_address,
self._node_ip_address,
self._ray_params.node_manager_port,
self._raylet_socket_name,
self._plasma_store_socket_name,
self._ray_params.worker_path,
@@ -515,7 +542,6 @@ class Node(object):
self._session_dir,
self.get_resource_spec(),
self._ray_params.object_manager_port,
self._ray_params.node_manager_port,
self._ray_params.redis_password,
use_valgrind=use_valgrind,
use_profiler=use_profiler,
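
The ephemeral-port trick used by _get_unused_port above is worth seeing on its own; a minimal standalone sketch (the print is illustrative):

    import socket

    # Binding to port 0 asks the OS for a free ephemeral port. As the NOTE
    # above says, another process can still grab the port between close()
    # and the raylet binding it, so the raylet's gRPC server re-checks the
    # port when it starts (see the PortNotInUse check in grpc_server.cc).
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.bind(("", 0))
    port = s.getsockname()[1]
    s.close()
    print("port to hand to the raylet:", port)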

@@ -153,6 +153,7 @@ def get_address_info_from_redis_helper(redis_address,
return {
"object_store_address": relevant_client["ObjectStoreSocketName"],
"raylet_socket_name": relevant_client["RayletSocketName"],
"node_manager_port": relevant_client["NodeManagerPort"]
}
@@ -1045,6 +1046,7 @@ def start_dashboard(host,
def start_raylet(redis_address,
node_ip_address,
node_manager_port,
raylet_name,
plasma_store_name,
worker_path,
@@ -1052,7 +1054,6 @@ def start_raylet(redis_address,
session_dir,
resource_spec,
object_manager_port=None,
node_manager_port=None,
redis_password=None,
use_valgrind=False,
use_profiler=False,
@@ -1068,6 +1069,8 @@ def start_raylet(redis_address,
Args:
redis_address (str): The address of the primary Redis server.
node_ip_address (str): The IP address of this node.
node_manager_port (int): The port to use for the node manager. This must
not be 0.
raylet_name (str): The name of the raylet socket to create.
plasma_store_name (str): The name of the plasma store socket to connect
to.
@@ -1078,8 +1081,6 @@ def start_raylet(redis_address,
resource_spec (ResourceSpec): Resources for this raylet.
object_manager_port: The port to use for the object manager. If this is
None, then the object manager will choose its own port.
node_manager_port: The port to use for the node manager. If this is
None, then the node manager will choose its own port.
redis_password: The password to use when connecting to Redis.
use_valgrind (bool): True if the raylet should be started inside
of valgrind. If this is True, use_profiler must be False.
@@ -1098,6 +1099,9 @@ def start_raylet(redis_address,
Returns:
ProcessInfo for the process that was started.
"""
# The caller must provide a node manager port so that we can correctly
# populate the command to start a worker.
assert node_manager_port is not None and node_manager_port != 0
config = config or {}
config_str = ",".join(["{},{}".format(*kv) for kv in config.items()])
@@ -1137,13 +1141,14 @@ def start_raylet(redis_address,
# Create the command that the Raylet will use to start workers.
start_worker_command = ("{} {} "
"--node-ip-address={} "
"--node-manager-port={} "
"--object-store-name={} "
"--raylet-name={} "
"--redis-address={} "
"--temp-dir={}".format(
sys.executable, worker_path, node_ip_address,
plasma_store_name, raylet_name, redis_address,
temp_dir))
node_manager_port, plasma_store_name,
raylet_name, redis_address, temp_dir))
if redis_password:
start_worker_command += " --redis-password {}".format(redis_password)
@@ -1151,10 +1156,6 @@ def start_raylet(redis_address,
# manager to choose its own port.
if object_manager_port is None:
object_manager_port = 0
# If the node manager port is None, then use 0 to cause the node manager
# to choose its own port.
if node_manager_port is None:
node_manager_port = 0
if load_code_from_local:
start_worker_command += " --load-code-from-local "
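
For concreteness, here is roughly what start_worker_command expands to once the port is threaded through (a sketch; every path, address, and port below is made up):

    start_worker_command = (
        "/usr/bin/python /ray/python/ray/workers/default_worker.py "
        "--node-ip-address=127.0.0.1 "
        "--node-manager-port=62137 "
        "--object-store-name=/tmp/ray/session_xyz/sockets/plasma_store "
        "--raylet-name=/tmp/ray/session_xyz/sockets/raylet "
        "--redis-address=127.0.0.1:6379 "
        "--temp-dir=/tmp/ray")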

@@ -92,6 +92,8 @@ class Cluster(object):
self.webui_url = self.head_node.webui_url
else:
ray_params.update_if_absent(redis_address=self.redis_address)
# Let grpc pick a port.
ray_params.update(node_manager_port=0)
node = ray.node.Node(
ray_params,
head=False,

@@ -1215,6 +1215,7 @@ def connect(node,
gcs_options,
node.get_logs_dir_path(),
node.node_ip_address,
node.node_manager_port,
)
worker.raylet_client = ray._raylet.RayletClient(worker.core_worker)

@@ -19,6 +19,11 @@ parser.add_argument(
required=True,
type=str,
help="the ip address of the worker's node")
parser.add_argument(
"--node-manager-port",
required=True,
type=int,
help="the port of the worker's node")
parser.add_argument(
"--redis-address",
required=True,
@@ -74,6 +79,7 @@ if __name__ == "__main__":
ray_params = RayParams(
node_ip_address=args.node_ip_address,
node_manager_port=args.node_manager_port,
redis_address=args.redis_address,
redis_password=args.redis_password,
plasma_store_socket_name=args.object_store_name,

@@ -60,6 +60,7 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
const std::string &store_socket, const std::string &raylet_socket,
const JobID &job_id, const gcs::GcsClientOptions &gcs_options,
const std::string &log_dir, const std::string &node_ip_address,
int node_manager_port,
const TaskExecutionCallback &task_execution_callback,
std::function<Status()> check_signals,
const std::function<void()> exit_handler)
@@ -72,6 +73,7 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
heartbeat_timer_(io_service_),
worker_server_(WorkerTypeString(worker_type), 0 /* let grpc choose a port */),
gcs_client_(gcs_options),
client_call_manager_(io_service_),
memory_store_(std::make_shared<CoreWorkerMemoryStore>()),
task_execution_service_work_(task_execution_service_),
task_execution_callback_(task_execution_callback),
@@ -117,8 +119,11 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
// connect to Raylet after a number of retries, this can be changed later
// so that the worker (Java/Python, etc.) can retrieve and handle the error
// instead of crashing.
auto grpc_client = rpc::NodeManagerWorkerClient::make(
node_ip_address, node_manager_port, client_call_manager_);
raylet_client_ = std::unique_ptr<RayletClient>(new RayletClient(
raylet_socket, WorkerID::FromBinary(worker_context_.GetWorkerID().Binary()),
std::move(grpc_client), raylet_socket,
WorkerID::FromBinary(worker_context_.GetWorkerID().Binary()),
(worker_type_ == ray::WorkerType::WORKER), worker_context_.GetCurrentJobID(),
language_, worker_server_.GetPort()));
// Unfortunately the raylet client has to be constructed after the receivers.
@@ -489,7 +494,8 @@ Status CoreWorker::SubmitTaskToRaylet(const TaskSpecification &task_spec) {
if (task_deps->size() > 0) {
for (size_t i = 0; i < num_returns; i++) {
reference_counter_.SetDependencies(task_spec.ReturnId(i, TaskTransportType::RAYLET), task_deps);
reference_counter_.SetDependencies(task_spec.ReturnId(i, TaskTransportType::RAYLET),
task_deps);
}
}

@@ -17,6 +17,7 @@
#include "ray/core_worker/transport/raylet_transport.h"
#include "ray/gcs/redis_gcs_client.h"
#include "ray/raylet/raylet_client.h"
#include "ray/rpc/node_manager/node_manager_client.h"
#include "ray/rpc/worker/worker_client.h"
#include "ray/rpc/worker/worker_server.h"
@@ -58,6 +59,7 @@ class CoreWorker {
/// \param[in] log_dir Directory to write logs to. If this is empty, logs
/// won't be written to a file.
/// \param[in] node_ip_address IP address of the node.
/// \param[in] node_manager_port Port of the local raylet.
/// \param[in] task_execution_callback Language worker callback to execute tasks.
/// \param[in] check_signals Language worker function to check for signals and handle
/// them. If the function returns anything but StatusOK, any long-running
@@ -70,7 +72,7 @@ class CoreWorker {
const std::string &store_socket, const std::string &raylet_socket,
const JobID &job_id, const gcs::GcsClientOptions &gcs_options,
const std::string &log_dir, const std::string &node_ip_address,
const TaskExecutionCallback &task_execution_callback,
int node_manager_port, const TaskExecutionCallback &task_execution_callback,
std::function<Status()> check_signals = nullptr,
std::function<void()> exit_handler = nullptr);
@@ -454,6 +456,9 @@ class CoreWorker {
// Client to the GCS shared by core worker interfaces.
gcs::RedisGcsClient gcs_client_;
/// The `ClientCallManager` object that is shared by all `NodeManagerClient`s.
rpc::ClientCallManager client_call_manager_;
// Client to the raylet shared by core worker interfaces.
std::unique_ptr<RayletClient> raylet_client_;

@@ -25,12 +25,17 @@
#include "ray/thirdparty/hiredis/hiredis.h"
#include "ray/util/test_util.h"
namespace ray {
namespace {
std::string store_executable;
std::string raylet_executable;
int node_manager_port = 0;
std::string mock_worker_executable;
} // namespace
namespace ray {
static void flushall_redis(void) {
redisContext *context = redisConnect("127.0.0.1", 6379);
freeReplyObject(redisCommand(context, "FLUSHALL"));
@@ -92,8 +97,8 @@ class CoreWorkerTest : public ::testing::Test {
// a task can be scheduled to the desired node.
for (int i = 0; i < num_nodes; i++) {
raylet_socket_names_[i] =
StartRaylet(raylet_store_socket_names_[i], "127.0.0.1", "127.0.0.1",
"\"CPU,4.0,resource" + std::to_string(i) + ",10\"");
StartRaylet(raylet_store_socket_names_[i], "127.0.0.1", node_manager_port + i,
"127.0.0.1", "\"CPU,4.0,resource" + std::to_string(i) + ",10\"");
}
}
@@ -134,12 +139,12 @@ class CoreWorkerTest : public ::testing::Test {
}
std::string StartRaylet(std::string store_socket_name, std::string node_ip_address,
std::string redis_address, std::string resource) {
int port, std::string redis_address, std::string resource) {
std::string raylet_socket_name = "/tmp/raylet" + ObjectID::FromRandom().Hex();
std::string ray_start_cmd = raylet_executable;
ray_start_cmd.append(" --raylet_socket_name=" + raylet_socket_name)
.append(" --store_socket_name=" + store_socket_name)
.append(" --object_manager_port=0 --node_manager_port=0")
.append(" --object_manager_port=0 --node_manager_port=" + std::to_string(port))
.append(" --node_ip_address=" + node_ip_address)
.append(" --redis_address=" + redis_address)
.append(" --redis_port=6379")
@@ -147,7 +152,8 @@ class CoreWorkerTest : public ::testing::Test {
.append(" --maximum_startup_concurrency=10")
.append(" --static_resource_list=" + resource)
.append(" --python_worker_command=\"" + mock_worker_executable + " " +
store_socket_name + " " + raylet_socket_name + "\"")
store_socket_name + " " + raylet_socket_name + " " +
std::to_string(port) + "\"")
.append(" --config_list=initial_reconstruction_timeout_milliseconds,2000")
.append(" & echo $! > " + raylet_socket_name + ".pid");
@@ -212,7 +218,7 @@ bool CoreWorkerTest::WaitForDirectCallActorState(CoreWorker &worker,
void CoreWorkerTest::TestNormalTask(std::unordered_map<std::string, double> &resources) {
CoreWorker driver(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0],
raylet_socket_names_[0], NextJobId(), gcs_options_, "", "127.0.0.1",
nullptr);
node_manager_port, nullptr);
// Test for tasks with by-value and by-ref args.
{
@@ -255,7 +261,7 @@ void CoreWorkerTest::TestActorTask(std::unordered_map<std::string, double> &reso
bool is_direct_call) {
CoreWorker driver(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0],
raylet_socket_names_[0], NextJobId(), gcs_options_, "", "127.0.0.1",
nullptr);
node_manager_port, nullptr);
auto actor_id = CreateActorHelper(driver, resources, is_direct_call, 1000);
@@ -338,7 +344,7 @@ void CoreWorkerTest::TestActorReconstruction(
std::unordered_map<std::string, double> &resources, bool is_direct_call) {
CoreWorker driver(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0],
raylet_socket_names_[0], NextJobId(), gcs_options_, "", "127.0.0.1",
nullptr);
node_manager_port, nullptr);
// creating actor.
auto actor_id = CreateActorHelper(driver, resources, is_direct_call, 1000);
@@ -394,7 +400,7 @@ void CoreWorkerTest::TestActorFailure(std::unordered_map<std::string, double> &r
bool is_direct_call) {
CoreWorker driver(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0],
raylet_socket_names_[0], NextJobId(), gcs_options_, "", "127.0.0.1",
nullptr);
node_manager_port, nullptr);
// creating actor.
auto actor_id =
@@ -539,7 +545,7 @@ TEST_F(ZeroNodeTest, TestTaskSpecPerf) {
TEST_F(SingleNodeTest, TestDirectActorTaskSubmissionPerf) {
CoreWorker driver(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0],
raylet_socket_names_[0], JobID::FromInt(1), gcs_options_, "",
"127.0.0.1", nullptr);
"127.0.0.1", node_manager_port, nullptr);
std::vector<ObjectID> object_ids;
// Create an actor.
std::unordered_map<std::string, double> resources;
@@ -753,7 +759,8 @@ TEST_F(SingleNodeTest, TestMemoryStoreProvider) {
TEST_F(SingleNodeTest, TestObjectInterface) {
CoreWorker core_worker(WorkerType::DRIVER, Language::PYTHON,
raylet_store_socket_names_[0], raylet_socket_names_[0],
JobID::FromInt(1), gcs_options_, "", "127.0.0.1", nullptr);
JobID::FromInt(1), gcs_options_, "", "127.0.0.1",
node_manager_port, nullptr);
uint8_t array1[] = {1, 2, 3, 4, 5, 6, 7, 8};
uint8_t array2[] = {10, 11, 12, 13, 14, 15};
@@ -824,11 +831,11 @@ TEST_F(SingleNodeTest, TestObjectInterface) {
TEST_F(TwoNodeTest, TestObjectInterfaceCrossNodes) {
CoreWorker worker1(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0],
raylet_socket_names_[0], NextJobId(), gcs_options_, "", "127.0.0.1",
nullptr);
node_manager_port, nullptr);
CoreWorker worker2(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[1],
raylet_socket_names_[1], NextJobId(), gcs_options_, "", "127.0.0.1",
nullptr);
node_manager_port, nullptr);
uint8_t array1[] = {1, 2, 3, 4, 5, 6, 7, 8};
uint8_t array2[] = {10, 11, 12, 13, 14, 15};
@@ -946,9 +953,10 @@ TEST_F(TwoNodeTest, TestDirectActorTaskCrossNodesFailure) {
int main(int argc, char **argv) {
::testing::InitGoogleTest(&argc, argv);
RAY_CHECK(argc == 4);
ray::store_executable = std::string(argv[1]);
ray::raylet_executable = std::string(argv[2]);
ray::mock_worker_executable = std::string(argv[3]);
RAY_CHECK(argc == 5);
store_executable = std::string(argv[1]);
raylet_executable = std::string(argv[2]);
node_manager_port = std::stoi(std::string(argv[3]));
mock_worker_executable = std::string(argv[4]);
return RUN_ALL_TESTS();
}

@@ -19,10 +19,10 @@ namespace ray {
class MockWorker {
public:
MockWorker(const std::string &store_socket, const std::string &raylet_socket,
const gcs::GcsClientOptions &gcs_options)
int node_manager_port, const gcs::GcsClientOptions &gcs_options)
: worker_(WorkerType::WORKER, Language::PYTHON, store_socket, raylet_socket,
JobID::FromInt(1), gcs_options, /*log_dir=*/"",
/*node_ip_address=*/"127.0.0.1",
/*node_ip_address=*/"127.0.0.1", node_manager_port,
std::bind(&MockWorker::ExecuteTask, this, _1, _2, _3, _4, _5, _6, _7)) {}
void StartExecutingTasks() { worker_.StartExecutingTasks(); }
@@ -71,12 +71,13 @@ class MockWorker {
} // namespace ray
int main(int argc, char **argv) {
RAY_CHECK(argc == 3);
RAY_CHECK(argc == 4);
auto store_socket = std::string(argv[1]);
auto raylet_socket = std::string(argv[2]);
auto node_manager_port = std::stoi(std::string(argv[3]));
ray::gcs::GcsClientOptions gcs_options("127.0.0.1", 6379, "");
ray::MockWorker worker(store_socket, raylet_socket, gcs_options);
ray::MockWorker worker(store_socket, raylet_socket, node_manager_port, gcs_options);
worker.StartExecutingTasks();
return 0;
}

@@ -44,7 +44,7 @@ class MockServer {
private:
ray::Status RegisterGcs(boost::asio::io_service &io_service) {
auto object_manager_port = config_.object_manager_port;
auto object_manager_port = object_manager_.GetServerPort();
GcsNodeInfo node_info = gcs_client_->client_table().GetLocalClient();
node_info.set_node_manager_address("127.0.0.1");
node_info.set_node_manager_port(object_manager_port);
@@ -110,7 +110,7 @@ class TestObjectManagerBase : public ::testing::Test {
om_config_1.pull_timeout_ms = pull_timeout_ms;
om_config_1.object_chunk_size = object_chunk_size;
om_config_1.push_timeout_ms = push_timeout_ms;
om_config_1.object_manager_port = 12345;
om_config_1.object_manager_port = 0;
om_config_1.rpc_service_threads_number = 3;
server1.reset(new MockServer(main_service, om_config_1, gcs_client_1));
@@ -123,7 +123,7 @@ class TestObjectManagerBase : public ::testing::Test {
om_config_2.pull_timeout_ms = pull_timeout_ms;
om_config_2.object_chunk_size = object_chunk_size;
om_config_2.push_timeout_ms = push_timeout_ms;
om_config_2.object_manager_port = 23456;
om_config_2.object_manager_port = 0;
om_config_2.rpc_service_threads_number = 3;
server2.reset(new MockServer(main_service, om_config_2, gcs_client_2));

@@ -38,7 +38,7 @@ class MockServer {
private:
ray::Status RegisterGcs(boost::asio::io_service &io_service) {
auto object_manager_port = config_.object_manager_port;
auto object_manager_port = object_manager_.GetServerPort();
GcsNodeInfo node_info = gcs_client_->client_table().GetLocalClient();
node_info.set_node_manager_address("127.0.0.1");
node_info.set_node_manager_port(object_manager_port);
@@ -102,7 +102,7 @@ class TestObjectManagerBase : public ::testing::Test {
om_config_1.pull_timeout_ms = pull_timeout_ms;
om_config_1.object_chunk_size = object_chunk_size;
om_config_1.push_timeout_ms = push_timeout_ms;
om_config_1.object_manager_port = 12345;
om_config_1.object_manager_port = 0;
om_config_1.rpc_service_threads_number = 3;
server1.reset(new MockServer(main_service, om_config_1, gcs_client_1));
@@ -115,7 +115,7 @@ class TestObjectManagerBase : public ::testing::Test {
om_config_2.pull_timeout_ms = pull_timeout_ms;
om_config_2.object_chunk_size = object_chunk_size;
om_config_2.push_timeout_ms = push_timeout_ms;
om_config_2.object_manager_port = 23456;
om_config_2.object_manager_port = 0;
om_config_2.rpc_service_threads_number = 3;
server2.reset(new MockServer(main_service, om_config_2, gcs_client_2));

@@ -4,6 +4,14 @@ package ray.rpc;
import "src/ray/protobuf/common.proto";
// Submit a task for execution.
message SubmitTaskRequest {
TaskSpec task_spec = 1;
}
message SubmitTaskReply {
}
message ForwardTaskRequest {
// The ID of the task to be forwarded.
bytes task_id = 1;
@@ -56,6 +64,8 @@ message NodeStatsReply {
// Service for inter-node-manager communication.
service NodeManagerService {
// Submit a task (from a local or remote worker) to the node manager.
rpc SubmitTask(SubmitTaskRequest) returns (SubmitTaskReply);
// Forward a task and its uncommitted lineage to the remote node manager.
rpc ForwardTask(ForwardTaskRequest) returns (ForwardTaskReply);
// Get the current node stats.
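
To illustrate the wire protocol this adds, here is a hedged Python sketch of calling the new RPC directly, assuming stubs generated by protoc from node_manager.proto and common.proto (the module names, address, and port are assumptions; in Ray itself the worker goes through the C++ NodeManagerWorkerClient):

    import grpc

    # Hypothetical generated modules.
    import common_pb2
    import node_manager_pb2
    import node_manager_pb2_grpc

    channel = grpc.insecure_channel("127.0.0.1:62137")  # raylet host:port
    stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)

    # An empty TaskSpec just to show the request shape; a real worker
    # fills in the full task specification.
    request = node_manager_pb2.SubmitTaskRequest(task_spec=common_pb2.TaskSpec())
    reply = stub.SubmitTask(request)  # SubmitTaskReply is empty; success == no exception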

@@ -5,12 +5,9 @@
namespace ray.protocol;
enum MessageType:int {
// Task is submitted to the raylet. This is sent from a worker to a
// raylet.
SubmitTask = 1,
// Notify the raylet that a task has finished. This is sent from a
// worker to a raylet.
TaskDone,
TaskDone = 1,
// Log a message to the event table. This is sent from a worker to a raylet.
EventLogMessage,
// Send an initial connection message to the raylet. This is sent
@@ -94,10 +91,6 @@ table Task {
task_execution_spec: TaskExecutionSpecification;
}
table SubmitTaskRequest {
task_spec: string;
}
// This message describes a given resource that is reserved for a worker.
table ResourceIdSetInfo {
// The name of the resource.

@@ -895,9 +895,6 @@ void NodeManager::ProcessClientMessage(
// because it's already disconnected.
return;
} break;
case protocol::MessageType::SubmitTask: {
ProcessSubmitTaskMessage(message_data);
} break;
case protocol::MessageType::SetResourceRequest: {
ProcessSetResourceRequest(client, message_data);
} break;
@@ -1175,18 +1172,6 @@ void NodeManager::ProcessDisconnectClientMessage(
// these can be leaked.
}
void NodeManager::ProcessSubmitTaskMessage(const uint8_t *message_data) {
// Read the task submitted by the client.
auto fbs_message = flatbuffers::GetRoot<protocol::SubmitTaskRequest>(message_data);
rpc::Task task_message;
RAY_CHECK(task_message.mutable_task_spec()->ParseFromArray(
fbs_message->task_spec()->data(), fbs_message->task_spec()->size()));
// Submit the task to the raylet. Since the task was submitted
// locally, there is no uncommitted lineage.
SubmitTask(Task(task_message), Lineage());
}
void NodeManager::ProcessFetchOrReconstructMessage(
const std::shared_ptr<LocalClientConnection> &client, const uint8_t *message_data) {
auto message = flatbuffers::GetRoot<protocol::FetchOrReconstruct>(message_data);
@@ -1390,6 +1375,18 @@ void NodeManager::ProcessReportActiveObjectIDs(
unordered_set_from_flatbuf<ObjectID>(*message->object_ids()));
}
void NodeManager::HandleSubmitTask(const rpc::SubmitTaskRequest &request,
rpc::SubmitTaskReply *reply,
rpc::SendReplyCallback send_reply_callback) {
rpc::Task task;
task.mutable_task_spec()->CopyFrom(request.task_spec());
// Submit the task to the raylet. Since the task was submitted
// locally, there is no uncommitted lineage.
SubmitTask(Task(task), Lineage());
send_reply_callback(Status::OK(), nullptr, nullptr);
}
void NodeManager::HandleForwardTask(const rpc::ForwardTaskRequest &request,
rpc::ForwardTaskReply *reply,
rpc::SendReplyCallback send_reply_callback) {

@@ -406,12 +406,6 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
const std::shared_ptr<LocalClientConnection> &client,
bool intentional_disconnect = false);
/// Process client message of SubmitTask
///
/// \param message_data A pointer to the message data.
/// \return Void.
void ProcessSubmitTaskMessage(const uint8_t *message_data);
/// Process client message of FetchOrReconstruct
///
/// \param client The client that sent the message.
@@ -495,6 +489,11 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
/// \return void.
void FinishAssignTask(const TaskID &task_id, Worker &worker, bool success);
/// Handle a `SubmitTask` request.
void HandleSubmitTask(const rpc::SubmitTaskRequest &request,
rpc::SubmitTaskReply *reply,
rpc::SendReplyCallback send_reply_callback) override;
/// Handle a `ForwardTask` request.
void HandleForwardTask(const rpc::ForwardTaskRequest &request,
rpc::ForwardTaskReply *reply,

@@ -201,10 +201,15 @@ ray::Status RayletConnection::AtomicRequestReply(
return ReadMessage(reply_type, reply_message);
}
RayletClient::RayletClient(const std::string &raylet_socket, const WorkerID &worker_id,
RayletClient::RayletClient(std::shared_ptr<ray::rpc::NodeManagerWorkerClient> grpc_client,
const std::string &raylet_socket, const WorkerID &worker_id,
bool is_worker, const JobID &job_id, const Language &language,
int port)
: worker_id_(worker_id), is_worker_(is_worker), job_id_(job_id), language_(language) {
: grpc_client_(std::move(grpc_client)),
worker_id_(worker_id),
is_worker_(is_worker),
job_id_(job_id),
language_(language) {
// For C++14, we could use std::make_unique
conn_ = std::unique_ptr<RayletConnection>(new RayletConnection(raylet_socket, -1, -1));
@@ -220,11 +225,9 @@ RayletClient::RayletClient(const std::string &raylet_socket, const WorkerID &wor
}
ray::Status RayletClient::SubmitTask(const ray::TaskSpecification &task_spec) {
flatbuffers::FlatBufferBuilder fbb;
auto message = ray::protocol::CreateSubmitTaskRequest(
fbb, fbb.CreateString(task_spec.Serialize()));
fbb.Finish(message);
return conn_->WriteMessage(MessageType::SubmitTask, &fbb);
ray::rpc::SubmitTaskRequest request;
request.mutable_task_spec()->CopyFrom(task_spec.GetMessage());
return grpc_client_->SubmitTask(request, /*callback=*/nullptr);
}
ray::Status RayletClient::TaskDone() {

@@ -9,6 +9,7 @@
#include "ray/common/status.h"
#include "ray/common/task/task_spec.h"
#include "ray/rpc/node_manager/node_manager_client.h"
using ray::ActorCheckpointID;
using ray::ActorID;
@@ -66,13 +67,15 @@ class RayletClient {
public:
/// Connect to the raylet.
///
/// \param grpc_client gRPC client to the raylet.
/// \param raylet_socket The name of the socket to use to connect to the raylet.
/// \param worker_id A unique ID to represent the worker.
/// \param is_worker Whether this client is a worker. If it is a worker, an
/// additional message will be sent to register as one.
/// \param job_id The ID of the driver. This is non-nil if the client is a driver.
/// \return The connection information.
RayletClient(const std::string &raylet_socket, const WorkerID &worker_id,
RayletClient(std::shared_ptr<ray::rpc::NodeManagerWorkerClient> grpc_client,
const std::string &raylet_socket, const WorkerID &worker_id,
bool is_worker, const JobID &job_id, const Language &language,
int port = -1);
@@ -193,6 +196,9 @@ class RayletClient {
const ResourceMappingType &GetResourceIDs() const { return resource_ids_; }
private:
/// gRPC client to the raylet. Right now, this is only used for a couple
/// request types.
std::shared_ptr<ray::rpc::NodeManagerWorkerClient> grpc_client_;
const WorkerID worker_id_;
const bool is_worker_;
const JobID job_id_;

@@ -2,11 +2,38 @@
#include "src/ray/rpc/grpc_server.h"
#include <grpcpp/impl/service_type.h>
namespace {
bool PortNotInUse(int port) {
int fd = socket(AF_INET, SOCK_STREAM, 0);
if (fd == -1) {
return false;
}
struct sockaddr_in server_addr = {0};
server_addr.sin_family = AF_INET;
server_addr.sin_addr.s_addr = htonl(INADDR_ANY);
server_addr.sin_port = htons(port);
int err = bind(fd, (struct sockaddr *)&server_addr, sizeof(server_addr));
close(fd);
return err == 0;
}
} // namespace
namespace ray {
namespace rpc {
void GrpcServer::Run() {
std::string server_address("0.0.0.0:" + std::to_string(port_));
// Unfortunately, grpc will not return an error if the specified port is in
// use. There is a race condition here where two servers could check the same
// port, but only one would succeed in binding.
if (port_ > 0) {
RAY_CHECK(PortNotInUse(port_))
<< "Port " << port_
<< " specified by caller already in use. Try passing node_manager_port=... into "
"ray.init() to pick a specific port";
}
grpc::ServerBuilder builder;
// TODO(hchen): Add options for authentication.
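
As the new error message suggests, a caller who needs a deterministic raylet port can pass one in instead of relying on the driver's random choice (a sketch; the port value is arbitrary):

    import ray

    # Pin the node manager port. If another process already holds it, the
    # bind check above makes the raylet's gRPC server fail fast instead of
    # silently reusing the port.
    ray.init(node_manager_port=8076)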

@@ -57,6 +57,53 @@ class NodeManagerClient {
ClientCallManager &client_call_manager_;
};
/// Client used by workers for communicating with a node manager server.
class NodeManagerWorkerClient
: public std::enable_shared_from_this<NodeManagerWorkerClient> {
public:
/// Constructor.
///
/// \param[in] address Address of the node manager server.
/// \param[in] port Port of the node manager server.
/// \param[in] client_call_manager The `ClientCallManager` used for managing requests.
static std::shared_ptr<NodeManagerWorkerClient> make(
const std::string &address, const int port,
ClientCallManager &client_call_manager) {
auto instance = new NodeManagerWorkerClient(address, port, client_call_manager);
return std::shared_ptr<NodeManagerWorkerClient>(instance);
}
/// Submit a task.
ray::Status SubmitTask(const SubmitTaskRequest &request,
const ClientCallback<SubmitTaskReply> &callback) {
auto call = client_call_manager_
.CreateCall<NodeManagerService, SubmitTaskRequest, SubmitTaskReply>(
*stub_, &NodeManagerService::Stub::PrepareAsyncSubmitTask,
request, callback);
return call->GetStatus();
}
private:
/// Constructor.
///
/// \param[in] address Address of the node manager server.
/// \param[in] port Port of the node manager server.
/// \param[in] client_call_manager The `ClientCallManager` used for managing requests.
NodeManagerWorkerClient(const std::string &address, const int port,
ClientCallManager &client_call_manager)
: client_call_manager_(client_call_manager) {
std::shared_ptr<grpc::Channel> channel = grpc::CreateChannel(
address + ":" + std::to_string(port), grpc::InsecureChannelCredentials());
stub_ = NodeManagerService::NewStub(channel);
};
/// The gRPC-generated stub.
std::unique_ptr<NodeManagerService::Stub> stub_;
/// The `ClientCallManager` used for managing requests.
ClientCallManager &client_call_manager_;
};
} // namespace rpc
} // namespace ray

@@ -13,24 +13,24 @@ namespace rpc {
/// Interface of the `NodeManagerService`, see `src/ray/protobuf/node_manager.proto`.
class NodeManagerServiceHandler {
public:
/// Handle a `ForwardTask` request.
/// The implementation can handle this request asynchronously. When handling is done,
/// the `send_reply_callback` should be called.
/// Handlers. For all of the following handlers, the implementations can
/// handle the request asynchronously. When handling is done, the
/// `send_reply_callback` should be called. See
/// src/ray/rpc/node_manager/node_manager_client.h and
/// src/ray/protobuf/node_manager.proto for a description of the
/// functionality of each handler.
///
/// \param[in] request The request message.
/// \param[out] reply The reply message.
/// \param[in] send_reply_callback The callback to be called when the request is done.
virtual void HandleSubmitTask(const SubmitTaskRequest &request, SubmitTaskReply *reply,
SendReplyCallback send_reply_callback) = 0;
virtual void HandleForwardTask(const ForwardTaskRequest &request,
ForwardTaskReply *reply,
SendReplyCallback send_reply_callback) = 0;
/// Handle a `GetNodeStats` request.
/// The implementation can handle this request asynchronously. When handling is done,
/// the `send_reply_callback` should be called.
///
/// \param[in] request The request message.
/// \param[out] reply The reply message.
/// \param[in] send_reply_callback The callback to be called when the request is done.
virtual void HandleNodeStatsRequest(const NodeStatsRequest &request,
NodeStatsReply *reply,
SendReplyCallback send_reply_callback) = 0;
@@ -55,6 +55,13 @@ class NodeManagerGrpcService : public GrpcService {
std::vector<std::pair<std::unique_ptr<ServerCallFactory>, int>>
*server_call_factories_and_concurrencies) override {
// Initialize the factory for requests.
std::unique_ptr<ServerCallFactory> submit_task_call_factory(
new ServerCallFactoryImpl<NodeManagerService, NodeManagerServiceHandler,
SubmitTaskRequest, SubmitTaskReply>(
service_, &NodeManagerService::AsyncService::RequestSubmitTask,
service_handler_, &NodeManagerServiceHandler::HandleSubmitTask, cq,
main_service_));
std::unique_ptr<ServerCallFactory> forward_task_call_factory(
new ServerCallFactoryImpl<NodeManagerService, NodeManagerServiceHandler,
ForwardTaskRequest, ForwardTaskReply>(
@@ -70,6 +77,8 @@ class NodeManagerGrpcService : public GrpcService {
main_service_));
// Set accept concurrency.
server_call_factories_and_concurrencies->emplace_back(
std::move(submit_task_call_factory), 100);
server_call_factories_and_concurrencies->emplace_back(
std::move(forward_task_call_factory), 100);
server_call_factories_and_concurrencies->emplace_back(

@@ -2,6 +2,22 @@
# This needs to be run in the root directory.
# Try to find an unused port for raylet to use.
PORTS="2000 2001 2002 2003 2004 2005 2006 2007 2008 2009"
RAYLET_PORT=0
for port in $PORTS; do
nc -z localhost $port
if [[ $? != 0 ]]; then
RAYLET_PORT=$port
break
fi
done
if [[ $RAYLET_PORT == 0 ]]; then
echo "WARNING: Could not find unused port for raylet to use. Exiting without running tests."
exit
fi
# Cause the script to exit if a single command fails.
set -e
set -x
@@ -38,7 +54,7 @@ sleep 2s
bazel run //:redis-server -- --loglevel warning ${LOAD_MODULE_ARGS} --port 6380 &
sleep 2s
# Run tests.
./bazel-bin/core_worker_test $STORE_EXEC $RAYLET_EXEC $MOCK_WORKER_EXEC
./bazel-bin/core_worker_test $STORE_EXEC $RAYLET_EXEC $RAYLET_PORT $MOCK_WORKER_EXEC
sleep 1s
bazel run //:redis-cli -- -p 6379 shutdown
bazel run //:redis-cli -- -p 6380 shutdown
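
Incidentally, the nc-based scan at the top of this script maps to a rough Python equivalent like the following (host and candidate range are illustrative):

    import socket

    def find_unused_port(candidates=range(2000, 2010), host="localhost"):
        # connect_ex() returns a nonzero errno when nothing is listening on
        # the port, which mirrors `nc -z` exiting nonzero for a closed port.
        for port in candidates:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                if s.connect_ex((host, port)) != 0:
                    return port
        return 0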