From 4f583ec784ec163bf82cf7e3b9a5ac2fe6113b0a Mon Sep 17 00:00:00 2001
From: Danyang Zhuo
Date: Mon, 18 Nov 2019 14:40:34 -0800
Subject: [PATCH] Improve Object Transfer Performance (#6067)

---
 src/ray/object_manager/object_manager.cc    | 10 ++--
 src/ray/raylet/main.cc                      |  3 +-
 src/ray/rpc/client_call.h                   | 46 +++++++++++++------
 src/ray/rpc/grpc_server.cc                  | 25 +++++++---
 src/ray/rpc/grpc_server.h                   | 21 +++++----
 .../object_manager/object_manager_client.h  | 41 +++++++++++++----
 .../object_manager/object_manager_server.h  |  4 +-
 7 files changed, 102 insertions(+), 48 deletions(-)

diff --git a/src/ray/object_manager/object_manager.cc b/src/ray/object_manager/object_manager.cc
index 0b3243c1b..ae8714232 100644
--- a/src/ray/object_manager/object_manager.cc
+++ b/src/ray/object_manager/object_manager.cc
@@ -18,9 +18,10 @@ ObjectManager::ObjectManager(asio::io_service &main_service,
       buffer_pool_(config_.store_socket_name, config_.object_chunk_size),
       rpc_work_(rpc_service_),
       gen_(std::chrono::high_resolution_clock::now().time_since_epoch().count()),
-      object_manager_server_("ObjectManager", config_.object_manager_port),
+      object_manager_server_("ObjectManager", config_.object_manager_port,
+                             config_.rpc_service_threads_number),
       object_manager_service_(rpc_service_, *this),
-      client_call_manager_(main_service) {
+      client_call_manager_(main_service, config_.rpc_service_threads_number) {
   RAY_CHECK(config_.rpc_service_threads_number > 0);
   client_id_ = object_directory_->GetLocalClientID();
   main_service_ = &main_service;
@@ -443,10 +444,7 @@ ray::Status ObjectManager::SendObjectChunk(
     RAY_RETURN_NOT_OK(status);
   }
 
-  std::string buffer;
-  buffer.resize(chunk_info.buffer_length);
-  buffer.assign(chunk_info.data, chunk_info.data + chunk_info.buffer_length);
-  push_request.set_data(std::move(buffer));
+  push_request.set_data(chunk_info.data, chunk_info.buffer_length);
 
   // record the time cost between send chunk and receive reply
   rpc::ClientCallback<rpc::PushReply> callback = [this, start_time, object_id, client_id,
diff --git a/src/ray/raylet/main.cc b/src/ray/raylet/main.cc
index e2261a0c8..13191b97b 100644
--- a/src/ray/raylet/main.cc
+++ b/src/ray/raylet/main.cc
@@ -140,7 +140,8 @@ int main(int argc, char *argv[]) {
       RayConfig::instance().object_manager_push_timeout_ms();
 
   int num_cpus = static_cast<int>(static_resource_conf["CPU"]);
-  object_manager_config.rpc_service_threads_number = std::max(2, num_cpus / 2);
+  object_manager_config.rpc_service_threads_number =
+      std::min(std::max(2, num_cpus / 4), 8);
   object_manager_config.object_chunk_size =
       RayConfig::instance().object_manager_default_chunk_size();
 
diff --git a/src/ray/rpc/client_call.h b/src/ray/rpc/client_call.h
index 72384ca3f..0763d2b0c 100644
--- a/src/ray/rpc/client_call.h
+++ b/src/ray/rpc/client_call.h
@@ -123,16 +123,25 @@ class ClientCallManager {
   ///
   /// \param[in] main_service The main event loop, to which the callback functions will be
   /// posted.
-  explicit ClientCallManager(boost::asio::io_service &main_service)
-      : main_service_(main_service) {
-    // Start the polling thread.
-    polling_thread_ =
-        std::thread(&ClientCallManager::PollEventsFromCompletionQueue, this);
+  explicit ClientCallManager(boost::asio::io_service &main_service, int num_threads = 1)
+      : main_service_(main_service), num_threads_(num_threads) {
+    rr_index_ = rand() % num_threads_;
+    // Start the polling threads.
+    cqs_.reserve(num_threads_);
+    for (int i = 0; i < num_threads_; i++) {
+      cqs_.emplace_back();
+      polling_threads_.emplace_back(&ClientCallManager::PollEventsFromCompletionQueue,
+                                    this, i);
+    }
   }
 
   ~ClientCallManager() {
-    cq_.Shutdown();
-    polling_thread_.join();
+    for (auto &cq : cqs_) {
+      cq.Shutdown();
+    }
+    for (auto &polling_thread : polling_threads_) {
+      polling_thread.join();
+    }
   }
 
   /// Create a new `ClientCall` and send request.
@@ -155,8 +164,9 @@ class ClientCallManager {
       const Request &request, const ClientCallback<Reply> &callback) {
     auto call = std::make_shared<ClientCallImpl<Reply>>(callback);
     // Send request.
-    call->response_reader_ =
-        (stub.*prepare_async_function)(&call->context_, request, &cq_);
+    // Find the next completion queue to wait for response.
+    call->response_reader_ = (stub.*prepare_async_function)(
+        &call->context_, request, &cqs_[rr_index_++ % num_threads_]);
     call->response_reader_->StartCall();
     // Create a new tag object. This object will eventually be deleted in the
     // `ClientCallManager::PollEventsFromCompletionQueue` when reply is received.
@@ -174,7 +184,7 @@ class ClientCallManager {
   /// This function runs in a background thread. It keeps polling events from the
   /// `CompletionQueue`, and dispatches the event to the callbacks via the `ClientCall`
   /// objects.
-  void PollEventsFromCompletionQueue() {
+  void PollEventsFromCompletionQueue(int index) {
     void *got_tag;
     bool ok = false;
     auto deadline = gpr_inf_future(GPR_CLOCK_REALTIME);
@@ -183,7 +193,7 @@ class ClientCallManager {
     // synchronous cq_.Next blocks indefinitely in the case that the process
     // received a SIGTERM.
     while (true) {
-      auto status = cq_.AsyncNext(&got_tag, &ok, deadline);
+      auto status = cqs_[index].AsyncNext(&got_tag, &ok, deadline);
       if (status == grpc::CompletionQueue::SHUTDOWN) {
         break;
       }
@@ -206,11 +216,17 @@ class ClientCallManager {
   /// The main event loop, to which the callback functions will be posted.
   boost::asio::io_service &main_service_;
 
-  /// The gRPC `CompletionQueue` object used to poll events.
-  grpc::CompletionQueue cq_;
+  /// The number of polling threads.
+  int num_threads_;
 
-  /// Polling thread to check the completion queue.
-  std::thread polling_thread_;
+  /// The index used to send RPCs in a round-robin fashion.
+  std::atomic<unsigned int> rr_index_;
+
+  /// The gRPC `CompletionQueue` objects used to poll events.
+  std::vector<grpc::CompletionQueue> cqs_;
+
+  /// Polling threads to check the completion queues.
+  std::vector<std::thread> polling_threads_;
 };
 
 }  // namespace rpc
diff --git a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc
index ef6fea9c5..4061e3598 100644
--- a/src/ray/rpc/grpc_server.cc
+++ b/src/ray/rpc/grpc_server.cc
@@ -23,6 +23,11 @@ bool PortNotInUse(int port) {
 namespace ray {
 namespace rpc {
 
+GrpcServer::GrpcServer(std::string name, const uint32_t port, int num_threads)
+    : name_(std::move(name)), port_(port), is_closed_(true), num_threads_(num_threads) {
+  cqs_.reserve(num_threads_);
+}
+
 void GrpcServer::Run() {
   std::string server_address("0.0.0.0:" + std::to_string(port_));
   // Unfortunately, grpc will not return an error if the specified port is in
@@ -47,7 +52,9 @@ void GrpcServer::Run() {
   }
   // Get hold of the completion queue used for the asynchronous communication
   // with the gRPC runtime.
-  cq_ = builder.AddCompletionQueue();
+  for (int i = 0; i < num_threads_; i++) {
+    cqs_.push_back(builder.AddCompletionQueue());
+  }
   // Build and start server.
   server_ = builder.BuildAndStart();
   RAY_LOG(INFO) << name_ << " server started, listening on port " << port_ << ".";
@@ -59,22 +66,28 @@ void GrpcServer::Run() {
       entry.first->CreateCall();
     }
   }
-  // Start a thread that polls incoming requests.
-  polling_thread_ = std::thread(&GrpcServer::PollEventsFromCompletionQueue, this);
+  // Start threads that poll incoming requests.
+  for (int i = 0; i < num_threads_; i++) {
+    polling_threads_.emplace_back(&GrpcServer::PollEventsFromCompletionQueue, this, i);
+  }
   // Set the server as running.
   is_closed_ = false;
 }
 
 void GrpcServer::RegisterService(GrpcService &service) {
   services_.emplace_back(service.GetGrpcService());
-  service.InitServerCallFactories(cq_, &server_call_factories_and_concurrencies_);
+
+  for (int i = 0; i < num_threads_; i++) {
+    service.InitServerCallFactories(cqs_[i], &server_call_factories_and_concurrencies_);
+  }
 }
 
-void GrpcServer::PollEventsFromCompletionQueue() {
+void GrpcServer::PollEventsFromCompletionQueue(int index) {
   void *tag;
   bool ok;
+
   // Keep reading events from the `CompletionQueue` until it's shutdown.
-  while (cq_->Next(&tag, &ok)) {
+  while (cqs_[index]->Next(&tag, &ok)) {
     auto *server_call = static_cast<ServerCall *>(tag);
     bool delete_call = false;
     if (ok) {
diff --git a/src/ray/rpc/grpc_server.h b/src/ray/rpc/grpc_server.h
index b2e884445..bd28ae3c1 100644
--- a/src/ray/rpc/grpc_server.h
+++ b/src/ray/rpc/grpc_server.h
@@ -31,8 +31,7 @@ class GrpcServer {
   /// \param[in] name Name of this server, used for logging and debugging purpose.
   /// \param[in] port The port to bind this server to. If it's 0, a random available port
   /// will be chosen.
-  GrpcServer(std::string name, const uint32_t port)
-      : name_(std::move(name)), port_(port), is_closed_(true) {}
+  GrpcServer(std::string name, const uint32_t port, int num_threads = 1);
 
   /// Destruct this gRPC server.
   ~GrpcServer() { Shutdown(); }
@@ -44,8 +43,12 @@ class GrpcServer {
   void Shutdown() {
     if (!is_closed_) {
       server_->Shutdown();
-      cq_->Shutdown();
-      polling_thread_.join();
+      for (const auto &cq : cqs_) {
+        cq->Shutdown();
+      }
+      for (auto &polling_thread : polling_threads_) {
+        polling_thread.join();
+      }
       is_closed_ = true;
       RAY_LOG(DEBUG) << "gRPC server of " << name_ << " shutdown.";
     }
@@ -65,7 +68,7 @@ class GrpcServer {
   /// This function runs in a background thread. It keeps polling events from the
   /// `ServerCompletionQueue`, and dispatches the event to the `ServiceHandler` instances
   /// via the `ServerCall` objects.
-  void PollEventsFromCompletionQueue();
+  void PollEventsFromCompletionQueue(int index);
 
   /// Name of this server, used for logging and debugging purpose.
   const std::string name_;
@@ -79,12 +82,14 @@ class GrpcServer {
   /// this gRPC server can handle.
   std::vector<std::pair<std::unique_ptr<ServerCallFactory>, int>>
       server_call_factories_and_concurrencies_;
+  /// The number of completion queues the server is polling from.
+  int num_threads_;
   /// The `ServerCompletionQueue` object used for polling events.
-  std::unique_ptr<grpc::ServerCompletionQueue> cq_;
+  std::vector<std::unique_ptr<grpc::ServerCompletionQueue>> cqs_;
   /// The `Server` object.
   std::unique_ptr<grpc::Server> server_;
-  /// The polling thread used to check the completion queue.
-  std::thread polling_thread_;
+  /// The polling threads used to check the completion queues.
+  std::vector<std::thread> polling_threads_;
 };
 
 /// Base class that represents an abstract gRPC service.
diff --git a/src/ray/rpc/object_manager/object_manager_client.h b/src/ray/rpc/object_manager/object_manager_client.h
index f37a081e6..b7e711f8d 100644
--- a/src/ray/rpc/object_manager/object_manager_client.h
+++ b/src/ray/rpc/object_manager/object_manager_client.h
@@ -4,6 +4,8 @@
 #include
 
 #include
+#include
+#include
 
 #include "ray/common/status.h"
 #include "ray/util/logging.h"
@@ -23,11 +25,22 @@ class ObjectManagerClient {
   /// \param[in] port Port of the node manager server.
   /// \param[in] client_call_manager The `ClientCallManager` used for managing requests.
   ObjectManagerClient(const std::string &address, const int port,
-                      ClientCallManager &client_call_manager)
-      : client_call_manager_(client_call_manager) {
-    std::shared_ptr<grpc::Channel> channel = grpc::CreateChannel(
-        address + ":" + std::to_string(port), grpc::InsecureChannelCredentials());
-    stub_ = ObjectManagerService::NewStub(channel);
+                      ClientCallManager &client_call_manager, int num_connections = 4)
+      : client_call_manager_(client_call_manager), num_connections_(num_connections) {
+    push_rr_index_ = rand() % num_connections_;
+    pull_rr_index_ = rand() % num_connections_;
+    freeobjects_rr_index_ = rand() % num_connections_;
+    stubs_.reserve(num_connections_);
+    for (int i = 0; i < num_connections_; i++) {
+      grpc::ResourceQuota quota;
+      quota.SetMaxThreads(num_connections_);
+      grpc::ChannelArguments argument;
+      argument.SetResourceQuota(quota);
+      std::shared_ptr<grpc::Channel> channel =
+          grpc::CreateCustomChannel(address + ":" + std::to_string(port),
+                                    grpc::InsecureChannelCredentials(), argument);
+      stubs_.push_back(ObjectManagerService::NewStub(channel));
+    }
   };
 
   /// Push object to remote object manager
   ///
   /// \param request The request message.
   /// \param callback The callback function that handles reply from server
   void Push(const PushRequest &request, const ClientCallback<PushReply> &callback) {
     client_call_manager_.CreateCall<ObjectManagerService, PushRequest, PushReply>(
-        *stub_, &ObjectManagerService::Stub::PrepareAsyncPush, request, callback);
+        *stubs_[push_rr_index_++ % num_connections_],
+        &ObjectManagerService::Stub::PrepareAsyncPush, request, callback);
   }
 
   /// Pull object from remote object manager
   ///
   /// \param request The request message.
   /// \param callback The callback function that handles reply from server
   void Pull(const PullRequest &request, const ClientCallback<PullReply> &callback) {
     client_call_manager_.CreateCall<ObjectManagerService, PullRequest, PullReply>(
-        *stub_, &ObjectManagerService::Stub::PrepareAsyncPull, request, callback);
+        *stubs_[pull_rr_index_++ % num_connections_],
+        &ObjectManagerService::Stub::PrepareAsyncPull, request, callback);
   }
 
   /// Tell remote object manager to free objects
   ///
   /// \param request The request message.
   /// \param callback The callback function that handles reply from server
   void FreeObjects(const FreeObjectsRequest &request,
                    const ClientCallback<FreeObjectsReply> &callback) {
     client_call_manager_
         .CreateCall<ObjectManagerService, FreeObjectsRequest, FreeObjectsReply>(
-            *stub_, &ObjectManagerService::Stub::PrepareAsyncFreeObjects, request,
-            callback);
+            *stubs_[freeobjects_rr_index_++ % num_connections_],
+            &ObjectManagerService::Stub::PrepareAsyncFreeObjects, request, callback);
   }
 
  private:
+  int num_connections_;
+
+  std::atomic<unsigned int> push_rr_index_;
+  std::atomic<unsigned int> pull_rr_index_;
+  std::atomic<unsigned int> freeobjects_rr_index_;
+
   /// The gRPC-generated stub.
-  std::unique_ptr<ObjectManagerService::Stub> stub_;
+  std::vector<std::unique_ptr<ObjectManagerService::Stub>> stubs_;
 
   /// The `ClientCallManager` used for managing requests.
   ClientCallManager &client_call_manager_;
diff --git a/src/ray/rpc/object_manager/object_manager_server.h b/src/ray/rpc/object_manager/object_manager_server.h
index c0af15ffb..496d22591 100644
--- a/src/ray/rpc/object_manager/object_manager_server.h
+++ b/src/ray/rpc/object_manager/object_manager_server.h
@@ -57,7 +57,7 @@ class ObjectManagerGrpcService : public GrpcService {
             service_, &ObjectManagerService::AsyncService::RequestPush, service_handler_,
             &ObjectManagerServiceHandler::HandlePushRequest, cq, main_service_));
     server_call_factories_and_concurrencies->emplace_back(std::move(push_call_factory),
-                                                           50);
+                                                           5);
 
     // Initialize the factory for `Pull` requests.
     std::unique_ptr<ServerCallFactory> pull_call_factory(
@@ -66,7 +66,7 @@ class ObjectManagerGrpcService : public GrpcService {
             service_, &ObjectManagerService::AsyncService::RequestPull, service_handler_,
             &ObjectManagerServiceHandler::HandlePullRequest, cq, main_service_));
     server_call_factories_and_concurrencies->emplace_back(std::move(pull_call_factory),
-                                                           50);
+                                                           5);
 
     // Initialize the factory for `FreeObjects` requests.
     std::unique_ptr<ServerCallFactory> free_objects_call_factory(