mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
Improve Object Transfer Performance (#6067)
This commit is contained in:
parent
d3ff2252c4
commit
4f583ec784
7 changed files with 102 additions and 48 deletions
|
@ -18,9 +18,10 @@ ObjectManager::ObjectManager(asio::io_service &main_service,
|
|||
buffer_pool_(config_.store_socket_name, config_.object_chunk_size),
|
||||
rpc_work_(rpc_service_),
|
||||
gen_(std::chrono::high_resolution_clock::now().time_since_epoch().count()),
|
||||
object_manager_server_("ObjectManager", config_.object_manager_port),
|
||||
object_manager_server_("ObjectManager", config_.object_manager_port,
|
||||
config_.rpc_service_threads_number),
|
||||
object_manager_service_(rpc_service_, *this),
|
||||
client_call_manager_(main_service) {
|
||||
client_call_manager_(main_service, config_.rpc_service_threads_number) {
|
||||
RAY_CHECK(config_.rpc_service_threads_number > 0);
|
||||
client_id_ = object_directory_->GetLocalClientID();
|
||||
main_service_ = &main_service;
|
||||
|
@ -443,10 +444,7 @@ ray::Status ObjectManager::SendObjectChunk(
|
|||
RAY_RETURN_NOT_OK(status);
|
||||
}
|
||||
|
||||
std::string buffer;
|
||||
buffer.resize(chunk_info.buffer_length);
|
||||
buffer.assign(chunk_info.data, chunk_info.data + chunk_info.buffer_length);
|
||||
push_request.set_data(std::move(buffer));
|
||||
push_request.set_data(chunk_info.data, chunk_info.buffer_length);
|
||||
|
||||
// record the time cost between send chunk and receive reply
|
||||
rpc::ClientCallback<rpc::PushReply> callback = [this, start_time, object_id, client_id,
|
||||
|
|
|
@ -140,7 +140,8 @@ int main(int argc, char *argv[]) {
|
|||
RayConfig::instance().object_manager_push_timeout_ms();
|
||||
|
||||
int num_cpus = static_cast<int>(static_resource_conf["CPU"]);
|
||||
object_manager_config.rpc_service_threads_number = std::max(2, num_cpus / 2);
|
||||
object_manager_config.rpc_service_threads_number =
|
||||
std::min(std::max(2, num_cpus / 4), 8);
|
||||
object_manager_config.object_chunk_size =
|
||||
RayConfig::instance().object_manager_default_chunk_size();
|
||||
|
||||
|
|
|
@ -123,16 +123,25 @@ class ClientCallManager {
|
|||
///
|
||||
/// \param[in] main_service The main event loop, to which the callback functions will be
|
||||
/// posted.
|
||||
explicit ClientCallManager(boost::asio::io_service &main_service)
|
||||
: main_service_(main_service) {
|
||||
// Start the polling thread.
|
||||
polling_thread_ =
|
||||
std::thread(&ClientCallManager::PollEventsFromCompletionQueue, this);
|
||||
explicit ClientCallManager(boost::asio::io_service &main_service, int num_threads = 1)
|
||||
: main_service_(main_service), num_threads_(num_threads) {
|
||||
rr_index_ = rand() % num_threads_;
|
||||
// Start the polling threads.
|
||||
cqs_.reserve(num_threads_);
|
||||
for (int i = 0; i < num_threads_; i++) {
|
||||
cqs_.emplace_back();
|
||||
polling_threads_.emplace_back(&ClientCallManager::PollEventsFromCompletionQueue,
|
||||
this, i);
|
||||
}
|
||||
}
|
||||
|
||||
~ClientCallManager() {
|
||||
cq_.Shutdown();
|
||||
polling_thread_.join();
|
||||
for (auto &cq : cqs_) {
|
||||
cq.Shutdown();
|
||||
}
|
||||
for (auto &polling_thread : polling_threads_) {
|
||||
polling_thread.join();
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new `ClientCall` and send request.
|
||||
|
@ -155,8 +164,9 @@ class ClientCallManager {
|
|||
const Request &request, const ClientCallback<Reply> &callback) {
|
||||
auto call = std::make_shared<ClientCallImpl<Reply>>(callback);
|
||||
// Send request.
|
||||
call->response_reader_ =
|
||||
(stub.*prepare_async_function)(&call->context_, request, &cq_);
|
||||
// Find the next completion queue to wait for response.
|
||||
call->response_reader_ = (stub.*prepare_async_function)(
|
||||
&call->context_, request, &cqs_[rr_index_++ % num_threads_]);
|
||||
call->response_reader_->StartCall();
|
||||
// Create a new tag object. This object will eventually be deleted in the
|
||||
// `ClientCallManager::PollEventsFromCompletionQueue` when reply is received.
|
||||
|
@ -174,7 +184,7 @@ class ClientCallManager {
|
|||
/// This function runs in a background thread. It keeps polling events from the
|
||||
/// `CompletionQueue`, and dispatches the event to the callbacks via the `ClientCall`
|
||||
/// objects.
|
||||
void PollEventsFromCompletionQueue() {
|
||||
void PollEventsFromCompletionQueue(int index) {
|
||||
void *got_tag;
|
||||
bool ok = false;
|
||||
auto deadline = gpr_inf_future(GPR_CLOCK_REALTIME);
|
||||
|
@ -183,7 +193,7 @@ class ClientCallManager {
|
|||
// synchronous cq_.Next blocks indefinitely in the case that the process
|
||||
// received a SIGTERM.
|
||||
while (true) {
|
||||
auto status = cq_.AsyncNext(&got_tag, &ok, deadline);
|
||||
auto status = cqs_[index].AsyncNext(&got_tag, &ok, deadline);
|
||||
if (status == grpc::CompletionQueue::SHUTDOWN) {
|
||||
break;
|
||||
}
|
||||
|
@ -206,11 +216,17 @@ class ClientCallManager {
|
|||
/// The main event loop, to which the callback functions will be posted.
|
||||
boost::asio::io_service &main_service_;
|
||||
|
||||
/// The gRPC `CompletionQueue` object used to poll events.
|
||||
grpc::CompletionQueue cq_;
|
||||
/// The number of polling threads.
|
||||
int num_threads_;
|
||||
|
||||
/// Polling thread to check the completion queue.
|
||||
std::thread polling_thread_;
|
||||
/// The index to send RPCs in a round-robin fashion
|
||||
std::atomic<unsigned int> rr_index_;
|
||||
|
||||
/// The gRPC `CompletionQueue` object used to poll events.
|
||||
std::vector<grpc::CompletionQueue> cqs_;
|
||||
|
||||
/// Polling threads to check the completion queue.
|
||||
std::vector<std::thread> polling_threads_;
|
||||
};
|
||||
|
||||
} // namespace rpc
|
||||
|
|
|
@ -23,6 +23,11 @@ bool PortNotInUse(int port) {
|
|||
namespace ray {
|
||||
namespace rpc {
|
||||
|
||||
GrpcServer::GrpcServer(std::string name, const uint32_t port, int num_threads)
|
||||
: name_(std::move(name)), port_(port), is_closed_(true), num_threads_(num_threads) {
|
||||
cqs_.reserve(num_threads_);
|
||||
}
|
||||
|
||||
void GrpcServer::Run() {
|
||||
std::string server_address("0.0.0.0:" + std::to_string(port_));
|
||||
// Unfortunately, grpc will not return an error if the specified port is in
|
||||
|
@ -47,7 +52,9 @@ void GrpcServer::Run() {
|
|||
}
|
||||
// Get hold of the completion queue used for the asynchronous communication
|
||||
// with the gRPC runtime.
|
||||
cq_ = builder.AddCompletionQueue();
|
||||
for (int i = 0; i < num_threads_; i++) {
|
||||
cqs_.push_back(builder.AddCompletionQueue());
|
||||
}
|
||||
// Build and start server.
|
||||
server_ = builder.BuildAndStart();
|
||||
RAY_LOG(INFO) << name_ << " server started, listening on port " << port_ << ".";
|
||||
|
@ -59,22 +66,28 @@ void GrpcServer::Run() {
|
|||
entry.first->CreateCall();
|
||||
}
|
||||
}
|
||||
// Start a thread that polls incoming requests.
|
||||
polling_thread_ = std::thread(&GrpcServer::PollEventsFromCompletionQueue, this);
|
||||
// Start threads that polls incoming requests.
|
||||
for (int i = 0; i < num_threads_; i++) {
|
||||
polling_threads_.emplace_back(&GrpcServer::PollEventsFromCompletionQueue, this, i);
|
||||
}
|
||||
// Set the server as running.
|
||||
is_closed_ = false;
|
||||
}
|
||||
|
||||
void GrpcServer::RegisterService(GrpcService &service) {
|
||||
services_.emplace_back(service.GetGrpcService());
|
||||
service.InitServerCallFactories(cq_, &server_call_factories_and_concurrencies_);
|
||||
|
||||
for (int i = 0; i < num_threads_; i++) {
|
||||
service.InitServerCallFactories(cqs_[i], &server_call_factories_and_concurrencies_);
|
||||
}
|
||||
}
|
||||
|
||||
void GrpcServer::PollEventsFromCompletionQueue() {
|
||||
void GrpcServer::PollEventsFromCompletionQueue(int index) {
|
||||
void *tag;
|
||||
bool ok;
|
||||
|
||||
// Keep reading events from the `CompletionQueue` until it's shutdown.
|
||||
while (cq_->Next(&tag, &ok)) {
|
||||
while (cqs_[index]->Next(&tag, &ok)) {
|
||||
auto *server_call = static_cast<ServerCall *>(tag);
|
||||
bool delete_call = false;
|
||||
if (ok) {
|
||||
|
|
|
@ -31,8 +31,7 @@ class GrpcServer {
|
|||
/// \param[in] name Name of this server, used for logging and debugging purpose.
|
||||
/// \param[in] port The port to bind this server to. If it's 0, a random available port
|
||||
/// will be chosen.
|
||||
GrpcServer(std::string name, const uint32_t port)
|
||||
: name_(std::move(name)), port_(port), is_closed_(true) {}
|
||||
GrpcServer(std::string name, const uint32_t port, int num_threads = 1);
|
||||
|
||||
/// Destruct this gRPC server.
|
||||
~GrpcServer() { Shutdown(); }
|
||||
|
@ -44,8 +43,12 @@ class GrpcServer {
|
|||
void Shutdown() {
|
||||
if (!is_closed_) {
|
||||
server_->Shutdown();
|
||||
cq_->Shutdown();
|
||||
polling_thread_.join();
|
||||
for (const auto &cq : cqs_) {
|
||||
cq->Shutdown();
|
||||
}
|
||||
for (auto &polling_thread : polling_threads_) {
|
||||
polling_thread.join();
|
||||
}
|
||||
is_closed_ = true;
|
||||
RAY_LOG(DEBUG) << "gRPC server of " << name_ << " shutdown.";
|
||||
}
|
||||
|
@ -65,7 +68,7 @@ class GrpcServer {
|
|||
/// This function runs in a background thread. It keeps polling events from the
|
||||
/// `ServerCompletionQueue`, and dispaches the event to the `ServiceHandler` instances
|
||||
/// via the `ServerCall` objects.
|
||||
void PollEventsFromCompletionQueue();
|
||||
void PollEventsFromCompletionQueue(int index);
|
||||
|
||||
/// Name of this server, used for logging and debugging purpose.
|
||||
const std::string name_;
|
||||
|
@ -79,12 +82,14 @@ class GrpcServer {
|
|||
/// this gRPC server can handle.
|
||||
std::vector<std::pair<std::unique_ptr<ServerCallFactory>, int>>
|
||||
server_call_factories_and_concurrencies_;
|
||||
/// The number of completion queues the server is polling from.
|
||||
int num_threads_;
|
||||
/// The `ServerCompletionQueue` object used for polling events.
|
||||
std::unique_ptr<grpc::ServerCompletionQueue> cq_;
|
||||
std::vector<std::unique_ptr<grpc::ServerCompletionQueue>> cqs_;
|
||||
/// The `Server` object.
|
||||
std::unique_ptr<grpc::Server> server_;
|
||||
/// The polling thread used to check the completion queue.
|
||||
std::thread polling_thread_;
|
||||
/// The polling threads used to check the completion queues.
|
||||
std::vector<std::thread> polling_threads_;
|
||||
};
|
||||
|
||||
/// Base class that represents an abstract gRPC service.
|
||||
|
|
|
@ -4,6 +4,8 @@
|
|||
#include <thread>
|
||||
|
||||
#include <grpcpp/grpcpp.h>
|
||||
#include <grpcpp/resource_quota.h>
|
||||
#include <grpcpp/support/channel_arguments.h>
|
||||
|
||||
#include "ray/common/status.h"
|
||||
#include "ray/util/logging.h"
|
||||
|
@ -23,11 +25,22 @@ class ObjectManagerClient {
|
|||
/// \param[in] port Port of the node manager server.
|
||||
/// \param[in] client_call_manager The `ClientCallManager` used for managing requests.
|
||||
ObjectManagerClient(const std::string &address, const int port,
|
||||
ClientCallManager &client_call_manager)
|
||||
: client_call_manager_(client_call_manager) {
|
||||
std::shared_ptr<grpc::Channel> channel = grpc::CreateChannel(
|
||||
address + ":" + std::to_string(port), grpc::InsecureChannelCredentials());
|
||||
stub_ = ObjectManagerService::NewStub(channel);
|
||||
ClientCallManager &client_call_manager, int num_connections = 4)
|
||||
: client_call_manager_(client_call_manager), num_connections_(num_connections) {
|
||||
push_rr_index_ = rand() % num_connections_;
|
||||
pull_rr_index_ = rand() % num_connections_;
|
||||
freeobjects_rr_index_ = rand() % num_connections_;
|
||||
stubs_.reserve(num_connections_);
|
||||
for (int i = 0; i < num_connections_; i++) {
|
||||
grpc::ResourceQuota quota;
|
||||
quota.SetMaxThreads(num_connections_);
|
||||
grpc::ChannelArguments argument;
|
||||
argument.SetResourceQuota(quota);
|
||||
std::shared_ptr<grpc::Channel> channel =
|
||||
grpc::CreateCustomChannel(address + ":" + std::to_string(port),
|
||||
grpc::InsecureChannelCredentials(), argument);
|
||||
stubs_.push_back(ObjectManagerService::NewStub(channel));
|
||||
}
|
||||
};
|
||||
|
||||
/// Push object to remote object manager
|
||||
|
@ -36,7 +49,8 @@ class ObjectManagerClient {
|
|||
/// \param callback The callback function that handles reply from server
|
||||
void Push(const PushRequest &request, const ClientCallback<PushReply> &callback) {
|
||||
client_call_manager_.CreateCall<ObjectManagerService, PushRequest, PushReply>(
|
||||
*stub_, &ObjectManagerService::Stub::PrepareAsyncPush, request, callback);
|
||||
*stubs_[push_rr_index_++ % num_connections_],
|
||||
&ObjectManagerService::Stub::PrepareAsyncPush, request, callback);
|
||||
}
|
||||
|
||||
/// Pull object from remote object manager
|
||||
|
@ -45,7 +59,8 @@ class ObjectManagerClient {
|
|||
/// \param callback The callback function that handles reply from server
|
||||
void Pull(const PullRequest &request, const ClientCallback<PullReply> &callback) {
|
||||
client_call_manager_.CreateCall<ObjectManagerService, PullRequest, PullReply>(
|
||||
*stub_, &ObjectManagerService::Stub::PrepareAsyncPull, request, callback);
|
||||
*stubs_[pull_rr_index_++ % num_connections_],
|
||||
&ObjectManagerService::Stub::PrepareAsyncPull, request, callback);
|
||||
}
|
||||
|
||||
/// Tell remote object manager to free objects
|
||||
|
@ -56,13 +71,19 @@ class ObjectManagerClient {
|
|||
const ClientCallback<FreeObjectsReply> &callback) {
|
||||
client_call_manager_
|
||||
.CreateCall<ObjectManagerService, FreeObjectsRequest, FreeObjectsReply>(
|
||||
*stub_, &ObjectManagerService::Stub::PrepareAsyncFreeObjects, request,
|
||||
callback);
|
||||
*stubs_[freeobjects_rr_index_++ % num_connections_],
|
||||
&ObjectManagerService::Stub::PrepareAsyncFreeObjects, request, callback);
|
||||
}
|
||||
|
||||
private:
|
||||
int num_connections_;
|
||||
|
||||
std::atomic<unsigned int> push_rr_index_;
|
||||
std::atomic<unsigned int> pull_rr_index_;
|
||||
std::atomic<unsigned int> freeobjects_rr_index_;
|
||||
|
||||
/// The gRPC-generated stub.
|
||||
std::unique_ptr<ObjectManagerService::Stub> stub_;
|
||||
std::vector<std::unique_ptr<ObjectManagerService::Stub>> stubs_;
|
||||
|
||||
/// The `ClientCallManager` used for managing requests.
|
||||
ClientCallManager &client_call_manager_;
|
||||
|
|
|
@ -57,7 +57,7 @@ class ObjectManagerGrpcService : public GrpcService {
|
|||
service_, &ObjectManagerService::AsyncService::RequestPush, service_handler_,
|
||||
&ObjectManagerServiceHandler::HandlePushRequest, cq, main_service_));
|
||||
server_call_factories_and_concurrencies->emplace_back(std::move(push_call_factory),
|
||||
50);
|
||||
5);
|
||||
|
||||
// Initialize the factory for `Pull` requests.
|
||||
std::unique_ptr<ServerCallFactory> pull_call_factory(
|
||||
|
@ -66,7 +66,7 @@ class ObjectManagerGrpcService : public GrpcService {
|
|||
service_, &ObjectManagerService::AsyncService::RequestPull, service_handler_,
|
||||
&ObjectManagerServiceHandler::HandlePullRequest, cq, main_service_));
|
||||
server_call_factories_and_concurrencies->emplace_back(std::move(pull_call_factory),
|
||||
50);
|
||||
5);
|
||||
|
||||
// Initialize the factory for `FreeObjects` requests.
|
||||
std::unique_ptr<ServerCallFactory> free_objects_call_factory(
|
||||
|
|
Loading…
Add table
Reference in a new issue