Improve Object Transfer Performance (#6067)

This commit is contained in:
Danyang Zhuo 2019-11-18 14:40:34 -08:00 committed by Philipp Moritz
parent d3ff2252c4
commit 4f583ec784
7 changed files with 102 additions and 48 deletions

View file

@ -18,9 +18,10 @@ ObjectManager::ObjectManager(asio::io_service &main_service,
buffer_pool_(config_.store_socket_name, config_.object_chunk_size),
rpc_work_(rpc_service_),
gen_(std::chrono::high_resolution_clock::now().time_since_epoch().count()),
object_manager_server_("ObjectManager", config_.object_manager_port),
object_manager_server_("ObjectManager", config_.object_manager_port,
config_.rpc_service_threads_number),
object_manager_service_(rpc_service_, *this),
client_call_manager_(main_service) {
client_call_manager_(main_service, config_.rpc_service_threads_number) {
RAY_CHECK(config_.rpc_service_threads_number > 0);
client_id_ = object_directory_->GetLocalClientID();
main_service_ = &main_service;
@ -443,10 +444,7 @@ ray::Status ObjectManager::SendObjectChunk(
RAY_RETURN_NOT_OK(status);
}
std::string buffer;
buffer.resize(chunk_info.buffer_length);
buffer.assign(chunk_info.data, chunk_info.data + chunk_info.buffer_length);
push_request.set_data(std::move(buffer));
push_request.set_data(chunk_info.data, chunk_info.buffer_length);
// record the time cost between send chunk and receive reply
rpc::ClientCallback<rpc::PushReply> callback = [this, start_time, object_id, client_id,

View file

@ -140,7 +140,8 @@ int main(int argc, char *argv[]) {
RayConfig::instance().object_manager_push_timeout_ms();
int num_cpus = static_cast<int>(static_resource_conf["CPU"]);
object_manager_config.rpc_service_threads_number = std::max(2, num_cpus / 2);
object_manager_config.rpc_service_threads_number =
std::min(std::max(2, num_cpus / 4), 8);
object_manager_config.object_chunk_size =
RayConfig::instance().object_manager_default_chunk_size();

View file

@ -123,16 +123,25 @@ class ClientCallManager {
///
/// \param[in] main_service The main event loop, to which the callback functions will be
/// posted.
explicit ClientCallManager(boost::asio::io_service &main_service)
: main_service_(main_service) {
// Start the polling thread.
polling_thread_ =
std::thread(&ClientCallManager::PollEventsFromCompletionQueue, this);
explicit ClientCallManager(boost::asio::io_service &main_service, int num_threads = 1)
: main_service_(main_service), num_threads_(num_threads) {
rr_index_ = rand() % num_threads_;
// Start the polling threads.
cqs_.reserve(num_threads_);
for (int i = 0; i < num_threads_; i++) {
cqs_.emplace_back();
polling_threads_.emplace_back(&ClientCallManager::PollEventsFromCompletionQueue,
this, i);
}
}
~ClientCallManager() {
cq_.Shutdown();
polling_thread_.join();
for (auto &cq : cqs_) {
cq.Shutdown();
}
for (auto &polling_thread : polling_threads_) {
polling_thread.join();
}
}
/// Create a new `ClientCall` and send request.
@ -155,8 +164,9 @@ class ClientCallManager {
const Request &request, const ClientCallback<Reply> &callback) {
auto call = std::make_shared<ClientCallImpl<Reply>>(callback);
// Send request.
call->response_reader_ =
(stub.*prepare_async_function)(&call->context_, request, &cq_);
// Find the next completion queue to wait for response.
call->response_reader_ = (stub.*prepare_async_function)(
&call->context_, request, &cqs_[rr_index_++ % num_threads_]);
call->response_reader_->StartCall();
// Create a new tag object. This object will eventually be deleted in the
// `ClientCallManager::PollEventsFromCompletionQueue` when reply is received.
@ -174,7 +184,7 @@ class ClientCallManager {
/// This function runs in a background thread. It keeps polling events from the
/// `CompletionQueue`, and dispatches the event to the callbacks via the `ClientCall`
/// objects.
void PollEventsFromCompletionQueue() {
void PollEventsFromCompletionQueue(int index) {
void *got_tag;
bool ok = false;
auto deadline = gpr_inf_future(GPR_CLOCK_REALTIME);
@ -183,7 +193,7 @@ class ClientCallManager {
// synchronous cq_.Next blocks indefinitely in the case that the process
// received a SIGTERM.
while (true) {
auto status = cq_.AsyncNext(&got_tag, &ok, deadline);
auto status = cqs_[index].AsyncNext(&got_tag, &ok, deadline);
if (status == grpc::CompletionQueue::SHUTDOWN) {
break;
}
@ -206,11 +216,17 @@ class ClientCallManager {
/// The main event loop, to which the callback functions will be posted.
boost::asio::io_service &main_service_;
/// The gRPC `CompletionQueue` object used to poll events.
grpc::CompletionQueue cq_;
/// The number of polling threads.
int num_threads_;
/// Polling thread to check the completion queue.
std::thread polling_thread_;
/// The index to send RPCs in a round-robin fashion
std::atomic<unsigned int> rr_index_;
/// The gRPC `CompletionQueue` object used to poll events.
std::vector<grpc::CompletionQueue> cqs_;
/// Polling threads to check the completion queue.
std::vector<std::thread> polling_threads_;
};
} // namespace rpc

View file

@ -23,6 +23,11 @@ bool PortNotInUse(int port) {
namespace ray {
namespace rpc {
GrpcServer::GrpcServer(std::string name, const uint32_t port, int num_threads)
: name_(std::move(name)), port_(port), is_closed_(true), num_threads_(num_threads) {
cqs_.reserve(num_threads_);
}
void GrpcServer::Run() {
std::string server_address("0.0.0.0:" + std::to_string(port_));
// Unfortunately, grpc will not return an error if the specified port is in
@ -47,7 +52,9 @@ void GrpcServer::Run() {
}
// Get hold of the completion queue used for the asynchronous communication
// with the gRPC runtime.
cq_ = builder.AddCompletionQueue();
for (int i = 0; i < num_threads_; i++) {
cqs_.push_back(builder.AddCompletionQueue());
}
// Build and start server.
server_ = builder.BuildAndStart();
RAY_LOG(INFO) << name_ << " server started, listening on port " << port_ << ".";
@ -59,22 +66,28 @@ void GrpcServer::Run() {
entry.first->CreateCall();
}
}
// Start a thread that polls incoming requests.
polling_thread_ = std::thread(&GrpcServer::PollEventsFromCompletionQueue, this);
// Start threads that polls incoming requests.
for (int i = 0; i < num_threads_; i++) {
polling_threads_.emplace_back(&GrpcServer::PollEventsFromCompletionQueue, this, i);
}
// Set the server as running.
is_closed_ = false;
}
void GrpcServer::RegisterService(GrpcService &service) {
services_.emplace_back(service.GetGrpcService());
service.InitServerCallFactories(cq_, &server_call_factories_and_concurrencies_);
for (int i = 0; i < num_threads_; i++) {
service.InitServerCallFactories(cqs_[i], &server_call_factories_and_concurrencies_);
}
}
void GrpcServer::PollEventsFromCompletionQueue() {
void GrpcServer::PollEventsFromCompletionQueue(int index) {
void *tag;
bool ok;
// Keep reading events from the `CompletionQueue` until it's shutdown.
while (cq_->Next(&tag, &ok)) {
while (cqs_[index]->Next(&tag, &ok)) {
auto *server_call = static_cast<ServerCall *>(tag);
bool delete_call = false;
if (ok) {

View file

@ -31,8 +31,7 @@ class GrpcServer {
/// \param[in] name Name of this server, used for logging and debugging purpose.
/// \param[in] port The port to bind this server to. If it's 0, a random available port
/// will be chosen.
GrpcServer(std::string name, const uint32_t port)
: name_(std::move(name)), port_(port), is_closed_(true) {}
GrpcServer(std::string name, const uint32_t port, int num_threads = 1);
/// Destruct this gRPC server.
~GrpcServer() { Shutdown(); }
@ -44,8 +43,12 @@ class GrpcServer {
void Shutdown() {
if (!is_closed_) {
server_->Shutdown();
cq_->Shutdown();
polling_thread_.join();
for (const auto &cq : cqs_) {
cq->Shutdown();
}
for (auto &polling_thread : polling_threads_) {
polling_thread.join();
}
is_closed_ = true;
RAY_LOG(DEBUG) << "gRPC server of " << name_ << " shutdown.";
}
@ -65,7 +68,7 @@ class GrpcServer {
/// This function runs in a background thread. It keeps polling events from the
/// `ServerCompletionQueue`, and dispaches the event to the `ServiceHandler` instances
/// via the `ServerCall` objects.
void PollEventsFromCompletionQueue();
void PollEventsFromCompletionQueue(int index);
/// Name of this server, used for logging and debugging purpose.
const std::string name_;
@ -79,12 +82,14 @@ class GrpcServer {
/// this gRPC server can handle.
std::vector<std::pair<std::unique_ptr<ServerCallFactory>, int>>
server_call_factories_and_concurrencies_;
/// The number of completion queues the server is polling from.
int num_threads_;
/// The `ServerCompletionQueue` object used for polling events.
std::unique_ptr<grpc::ServerCompletionQueue> cq_;
std::vector<std::unique_ptr<grpc::ServerCompletionQueue>> cqs_;
/// The `Server` object.
std::unique_ptr<grpc::Server> server_;
/// The polling thread used to check the completion queue.
std::thread polling_thread_;
/// The polling threads used to check the completion queues.
std::vector<std::thread> polling_threads_;
};
/// Base class that represents an abstract gRPC service.

View file

@ -4,6 +4,8 @@
#include <thread>
#include <grpcpp/grpcpp.h>
#include <grpcpp/resource_quota.h>
#include <grpcpp/support/channel_arguments.h>
#include "ray/common/status.h"
#include "ray/util/logging.h"
@ -23,11 +25,22 @@ class ObjectManagerClient {
/// \param[in] port Port of the node manager server.
/// \param[in] client_call_manager The `ClientCallManager` used for managing requests.
ObjectManagerClient(const std::string &address, const int port,
ClientCallManager &client_call_manager)
: client_call_manager_(client_call_manager) {
std::shared_ptr<grpc::Channel> channel = grpc::CreateChannel(
address + ":" + std::to_string(port), grpc::InsecureChannelCredentials());
stub_ = ObjectManagerService::NewStub(channel);
ClientCallManager &client_call_manager, int num_connections = 4)
: client_call_manager_(client_call_manager), num_connections_(num_connections) {
push_rr_index_ = rand() % num_connections_;
pull_rr_index_ = rand() % num_connections_;
freeobjects_rr_index_ = rand() % num_connections_;
stubs_.reserve(num_connections_);
for (int i = 0; i < num_connections_; i++) {
grpc::ResourceQuota quota;
quota.SetMaxThreads(num_connections_);
grpc::ChannelArguments argument;
argument.SetResourceQuota(quota);
std::shared_ptr<grpc::Channel> channel =
grpc::CreateCustomChannel(address + ":" + std::to_string(port),
grpc::InsecureChannelCredentials(), argument);
stubs_.push_back(ObjectManagerService::NewStub(channel));
}
};
/// Push object to remote object manager
@ -36,7 +49,8 @@ class ObjectManagerClient {
/// \param callback The callback function that handles reply from server
void Push(const PushRequest &request, const ClientCallback<PushReply> &callback) {
client_call_manager_.CreateCall<ObjectManagerService, PushRequest, PushReply>(
*stub_, &ObjectManagerService::Stub::PrepareAsyncPush, request, callback);
*stubs_[push_rr_index_++ % num_connections_],
&ObjectManagerService::Stub::PrepareAsyncPush, request, callback);
}
/// Pull object from remote object manager
@ -45,7 +59,8 @@ class ObjectManagerClient {
/// \param callback The callback function that handles reply from server
void Pull(const PullRequest &request, const ClientCallback<PullReply> &callback) {
client_call_manager_.CreateCall<ObjectManagerService, PullRequest, PullReply>(
*stub_, &ObjectManagerService::Stub::PrepareAsyncPull, request, callback);
*stubs_[pull_rr_index_++ % num_connections_],
&ObjectManagerService::Stub::PrepareAsyncPull, request, callback);
}
/// Tell remote object manager to free objects
@ -56,13 +71,19 @@ class ObjectManagerClient {
const ClientCallback<FreeObjectsReply> &callback) {
client_call_manager_
.CreateCall<ObjectManagerService, FreeObjectsRequest, FreeObjectsReply>(
*stub_, &ObjectManagerService::Stub::PrepareAsyncFreeObjects, request,
callback);
*stubs_[freeobjects_rr_index_++ % num_connections_],
&ObjectManagerService::Stub::PrepareAsyncFreeObjects, request, callback);
}
private:
int num_connections_;
std::atomic<unsigned int> push_rr_index_;
std::atomic<unsigned int> pull_rr_index_;
std::atomic<unsigned int> freeobjects_rr_index_;
/// The gRPC-generated stub.
std::unique_ptr<ObjectManagerService::Stub> stub_;
std::vector<std::unique_ptr<ObjectManagerService::Stub>> stubs_;
/// The `ClientCallManager` used for managing requests.
ClientCallManager &client_call_manager_;

View file

@ -57,7 +57,7 @@ class ObjectManagerGrpcService : public GrpcService {
service_, &ObjectManagerService::AsyncService::RequestPush, service_handler_,
&ObjectManagerServiceHandler::HandlePushRequest, cq, main_service_));
server_call_factories_and_concurrencies->emplace_back(std::move(push_call_factory),
50);
5);
// Initialize the factory for `Pull` requests.
std::unique_ptr<ServerCallFactory> pull_call_factory(
@ -66,7 +66,7 @@ class ObjectManagerGrpcService : public GrpcService {
service_, &ObjectManagerService::AsyncService::RequestPull, service_handler_,
&ObjectManagerServiceHandler::HandlePullRequest, cq, main_service_));
server_call_factories_and_concurrencies->emplace_back(std::move(pull_call_factory),
50);
5);
// Initialize the factory for `FreeObjects` requests.
std::unique_ptr<ServerCallFactory> free_objects_call_factory(