mirror of
https://github.com/vale981/ray
synced 2025-03-04 17:41:43 -05:00
rpc: Core Worker client pool (#9934)
This commit is contained in:
parent
dee3322ab0
commit
1d01c668f0
12 changed files with 172 additions and 84 deletions
|
@ -1,6 +1,8 @@
|
|||
# Bazel build
|
||||
# C/C++ documentation: https://docs.bazel.build/versions/master/be/c-cpp.html
|
||||
|
||||
load("@rules_python//python:defs.bzl", "py_library")
|
||||
load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library", "cc_test")
|
||||
load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", "cc_grpc_library")
|
||||
load("@com_github_grpc_grpc//bazel:cython_library.bzl", "pyx_library")
|
||||
load("@com_github_google_flatbuffers//:build_defs.bzl", "flatbuffer_cc_library")
|
||||
|
@ -129,6 +131,9 @@ cc_grpc_library(
|
|||
# worker server and client.
|
||||
cc_library(
|
||||
name = "worker_rpc",
|
||||
srcs = glob([
|
||||
"src/ray/rpc/worker/*.cc",
|
||||
]),
|
||||
hdrs = glob([
|
||||
"src/ray/rpc/worker/*.h",
|
||||
]),
|
||||
|
|
|
@ -18,24 +18,17 @@ namespace ray {
|
|||
|
||||
void FutureResolver::ResolveFutureAsync(const ObjectID &object_id,
|
||||
const rpc::Address &owner_address) {
|
||||
absl::MutexLock lock(&mu_);
|
||||
const auto owner_worker_id = WorkerID::FromBinary(owner_address.worker_id());
|
||||
if (rpc_address_.worker_id() == owner_address.worker_id()) {
|
||||
// We do not need to resolve objects that we own. This can happen if a task
|
||||
// with a borrowed reference executes on the object's owning worker.
|
||||
return;
|
||||
}
|
||||
auto it = owner_clients_.find(owner_worker_id);
|
||||
if (it == owner_clients_.end()) {
|
||||
auto client =
|
||||
std::shared_ptr<rpc::CoreWorkerClientInterface>(client_factory_(owner_address));
|
||||
it = owner_clients_.emplace(owner_worker_id, std::move(client)).first;
|
||||
}
|
||||
auto conn = owner_clients_.GetOrConnect(owner_address);
|
||||
|
||||
rpc::GetObjectStatusRequest request;
|
||||
request.set_object_id(object_id.Binary());
|
||||
request.set_owner_worker_id(owner_worker_id.Binary());
|
||||
it->second->GetObjectStatus(
|
||||
request.set_owner_worker_id(owner_address.worker_id());
|
||||
conn->GetObjectStatus(
|
||||
request,
|
||||
[this, object_id](const Status &status, const rpc::GetObjectStatusReply &reply) {
|
||||
if (!status.ok()) {
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include "ray/common/id.h"
|
||||
#include "ray/core_worker/store_provider/memory_store/memory_store.h"
|
||||
#include "ray/rpc/worker/core_worker_client.h"
|
||||
#include "ray/rpc/worker/core_worker_client_pool.h"
|
||||
#include "src/ray/protobuf/core_worker.pb.h"
|
||||
|
||||
namespace ray {
|
||||
|
@ -30,7 +31,7 @@ class FutureResolver {
|
|||
FutureResolver(std::shared_ptr<CoreWorkerMemoryStore> store,
|
||||
rpc::ClientFactoryFn client_factory, const rpc::Address &rpc_address)
|
||||
: in_memory_store_(store),
|
||||
client_factory_(client_factory),
|
||||
owner_clients_(client_factory),
|
||||
rpc_address_(rpc_address) {}
|
||||
|
||||
/// Resolve the value for a future. This will periodically contact the given
|
||||
|
@ -47,20 +48,12 @@ class FutureResolver {
|
|||
/// Used to store values of resolved futures.
|
||||
std::shared_ptr<CoreWorkerMemoryStore> in_memory_store_;
|
||||
|
||||
/// Factory for producing new core worker clients.
|
||||
const rpc::ClientFactoryFn client_factory_;
|
||||
rpc::CoreWorkerClientPool owner_clients_;
|
||||
|
||||
/// Address of our RPC server. Used to notify borrowed objects' owners of our
|
||||
/// address, so the owner can contact us to ask when our reference to the
|
||||
/// object has gone out of scope.
|
||||
const rpc::Address rpc_address_;
|
||||
|
||||
/// Protects against concurrent access to internal state.
|
||||
absl::Mutex mu_;
|
||||
|
||||
/// Cache of gRPC clients to the objects' owners.
|
||||
absl::flat_hash_map<WorkerID, std::shared_ptr<rpc::CoreWorkerClientInterface>>
|
||||
owner_clients_ GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
} // namespace ray
|
||||
|
|
|
@ -732,17 +732,13 @@ void ReferenceCounter::WaitForRefRemoved(const ReferenceTable::iterator &ref_it,
|
|||
request.set_contained_in_id(contained_in_id.Binary());
|
||||
request.set_intended_worker_id(addr.worker_id.Binary());
|
||||
|
||||
auto it = borrower_cache_.find(addr);
|
||||
if (it == borrower_cache_.end()) {
|
||||
RAY_CHECK(client_factory_ != nullptr);
|
||||
it = borrower_cache_.emplace(addr, client_factory_(addr.ToProto())).first;
|
||||
}
|
||||
auto conn = borrower_pool_.GetOrConnect(addr.ToProto());
|
||||
|
||||
RAY_LOG(DEBUG) << "Sending WaitForRefRemoved to borrower " << addr.ip_address << ":"
|
||||
<< addr.port << " for object " << object_id;
|
||||
// Send the borrower a message about this object. The borrower responds once
|
||||
// it is no longer using the object ID.
|
||||
it->second->WaitForRefRemoved(
|
||||
conn->WaitForRefRemoved(
|
||||
request, [this, object_id, addr](const Status &status,
|
||||
const rpc::WaitForRefRemovedReply &reply) {
|
||||
RAY_LOG(DEBUG) << "Received reply from borrower " << addr.ip_address << ":"
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include "ray/common/id.h"
|
||||
#include "ray/rpc/grpc_server.h"
|
||||
#include "ray/rpc/worker/core_worker_client.h"
|
||||
#include "ray/rpc/worker/core_worker_client_pool.h"
|
||||
#include "ray/util/logging.h"
|
||||
#include "src/ray/protobuf/common.pb.h"
|
||||
|
||||
|
@ -66,7 +67,7 @@ class ReferenceCounter : public ReferenceCounterInterface {
|
|||
: rpc_address_(rpc_address),
|
||||
distributed_ref_counting_enabled_(distributed_ref_counting_enabled),
|
||||
lineage_pinning_enabled_(lineage_pinning_enabled),
|
||||
client_factory_(client_factory) {}
|
||||
borrower_pool_(client_factory) {}
|
||||
|
||||
~ReferenceCounter() {}
|
||||
|
||||
|
@ -647,11 +648,10 @@ class ReferenceCounter : public ReferenceCounterInterface {
|
|||
/// Factory for producing new core worker clients.
|
||||
rpc::ClientFactoryFn client_factory_;
|
||||
|
||||
/// Map from worker address to core worker client. The owner of an object
|
||||
/// Pool from worker address to core worker client. The owner of an object
|
||||
/// uses this client to request a notification from borrowers once the
|
||||
/// borrower's ref count for the ID goes to 0.
|
||||
absl::flat_hash_map<rpc::WorkerAddress, std::shared_ptr<rpc::CoreWorkerClientInterface>>
|
||||
borrower_cache_ GUARDED_BY(mutex_);
|
||||
rpc::CoreWorkerClientPool borrower_pool_;
|
||||
|
||||
/// Protects access to the reference counting state.
|
||||
mutable absl::Mutex mutex_;
|
||||
|
|
|
@ -94,12 +94,7 @@ Status CoreWorkerDirectTaskSubmitter::SubmitTask(TaskSpecification task_spec) {
|
|||
|
||||
void CoreWorkerDirectTaskSubmitter::AddWorkerLeaseClient(
|
||||
const rpc::WorkerAddress &addr, std::shared_ptr<WorkerLeaseInterface> lease_client) {
|
||||
auto it = client_cache_.find(addr);
|
||||
if (it == client_cache_.end()) {
|
||||
client_cache_[addr] =
|
||||
std::shared_ptr<rpc::CoreWorkerClientInterface>(client_factory_(addr.ToProto()));
|
||||
RAY_LOG(INFO) << "Connected to " << addr.ip_address << ":" << addr.port;
|
||||
}
|
||||
client_cache_.GetOrConnect(addr.ToProto());
|
||||
int64_t expiration = current_time_ms() + lease_timeout_ms_;
|
||||
LeaseEntry new_lease_entry = LeaseEntry(std::move(lease_client), expiration, 0);
|
||||
worker_to_lease_entry_.emplace(addr, new_lease_entry);
|
||||
|
@ -131,7 +126,7 @@ void CoreWorkerDirectTaskSubmitter::OnWorkerIdle(
|
|||
}
|
||||
|
||||
} else {
|
||||
auto &client = *client_cache_[addr];
|
||||
auto &client = *client_cache_.GetOrConnect(addr.ToProto());
|
||||
|
||||
while (!queue_entry->second.empty() &&
|
||||
lease_entry.tasks_in_flight_ < max_tasks_in_flight_per_worker_) {
|
||||
|
@ -368,17 +363,23 @@ Status CoreWorkerDirectTaskSubmitter::CancelTask(TaskSpecification task_spec,
|
|||
// or when all dependencies are resolved.
|
||||
RAY_CHECK(cancelled_tasks_.emplace(task_spec.TaskId()).second);
|
||||
auto rpc_client = executing_tasks_.find(task_spec.TaskId());
|
||||
// Looks for an RPC handle for the worker executing the task.
|
||||
if (rpc_client != executing_tasks_.end() &&
|
||||
client_cache_.find(rpc_client->second) != client_cache_.end()) {
|
||||
client = client_cache_.find(rpc_client->second)->second;
|
||||
|
||||
if (rpc_client == executing_tasks_.end()) {
|
||||
// This case is reached for tasks that have unresolved dependencies.
|
||||
// No executing tasks, so cancelling is a noop.
|
||||
return Status::OK();
|
||||
}
|
||||
// Looks for an RPC handle for the worker executing the task.
|
||||
auto maybe_client = client_cache_.GetByID(rpc_client->second.worker_id);
|
||||
if (!maybe_client.has_value()) {
|
||||
// If we don't have a connection to that worker, we can't cancel it.
|
||||
// This case is reached for tasks that have unresolved dependencies.
|
||||
return Status::OK();
|
||||
}
|
||||
client = maybe_client.value();
|
||||
}
|
||||
|
||||
// This case is reached for tasks that have unresolved dependencies.
|
||||
if (client == nullptr) {
|
||||
return Status::OK();
|
||||
}
|
||||
RAY_CHECK(client != nullptr);
|
||||
|
||||
auto request = rpc::CancelTaskRequest();
|
||||
request.set_intended_task_id(task_spec.TaskId().Binary());
|
||||
|
@ -408,15 +409,16 @@ Status CoreWorkerDirectTaskSubmitter::CancelTask(TaskSpecification task_spec,
|
|||
Status CoreWorkerDirectTaskSubmitter::CancelRemoteTask(const ObjectID &object_id,
|
||||
const rpc::Address &worker_addr,
|
||||
bool force_kill) {
|
||||
absl::MutexLock lock(&mu_);
|
||||
auto client = client_cache_.find(rpc::WorkerAddress(worker_addr));
|
||||
if (client == client_cache_.end()) {
|
||||
auto maybe_client = client_cache_.GetByID(rpc::WorkerAddress(worker_addr).worker_id);
|
||||
|
||||
if (!maybe_client.has_value()) {
|
||||
return Status::Invalid("No remote worker found");
|
||||
}
|
||||
auto client = maybe_client.value();
|
||||
auto request = rpc::RemoteCancelTaskRequest();
|
||||
request.set_force_kill(force_kill);
|
||||
request.set_remote_object_id(object_id.Binary());
|
||||
client->second->RemoteCancelTask(request, nullptr);
|
||||
client->RemoteCancelTask(request, nullptr);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -28,6 +28,7 @@
|
|||
#include "ray/core_worker/transport/direct_actor_transport.h"
|
||||
#include "ray/raylet_client/raylet_client.h"
|
||||
#include "ray/rpc/worker/core_worker_client.h"
|
||||
#include "ray/rpc/worker/core_worker_client_pool.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
|
@ -60,13 +61,13 @@ class CoreWorkerDirectTaskSubmitter {
|
|||
absl::optional<boost::asio::steady_timer> cancel_timer = absl::nullopt)
|
||||
: rpc_address_(rpc_address),
|
||||
local_lease_client_(lease_client),
|
||||
client_factory_(client_factory),
|
||||
lease_client_factory_(lease_client_factory),
|
||||
resolver_(store, task_finisher),
|
||||
task_finisher_(task_finisher),
|
||||
lease_timeout_ms_(lease_timeout_ms),
|
||||
local_raylet_id_(local_raylet_id),
|
||||
actor_creator_(std::move(actor_creator)),
|
||||
client_cache_(client_factory),
|
||||
max_tasks_in_flight_per_worker_(max_tasks_in_flight_per_worker),
|
||||
cancel_retry_timer_(std::move(cancel_timer)) {}
|
||||
|
||||
|
@ -142,9 +143,6 @@ class CoreWorkerDirectTaskSubmitter {
|
|||
absl::flat_hash_map<ClientID, std::shared_ptr<WorkerLeaseInterface>>
|
||||
remote_lease_clients_ GUARDED_BY(mu_);
|
||||
|
||||
/// Factory for producing new core worker clients.
|
||||
rpc::ClientFactoryFn client_factory_;
|
||||
|
||||
/// Factory for producing new clients to request leases from remote nodes.
|
||||
LeaseClientFactoryFn lease_client_factory_;
|
||||
|
||||
|
@ -169,8 +167,7 @@ class CoreWorkerDirectTaskSubmitter {
|
|||
absl::Mutex mu_;
|
||||
|
||||
/// Cache of gRPC clients to other workers.
|
||||
absl::flat_hash_map<rpc::WorkerAddress, std::shared_ptr<rpc::CoreWorkerClientInterface>>
|
||||
client_cache_ GUARDED_BY(mu_);
|
||||
rpc::CoreWorkerClientPool client_cache_;
|
||||
|
||||
// max_tasks_in_flight_per_worker_ limits the number of tasks that can be pipelined to a
|
||||
// worker using a single lease.
|
||||
|
|
|
@ -36,7 +36,7 @@ GcsActorScheduler::GcsActorScheduler(
|
|||
schedule_failure_handler_(std::move(schedule_failure_handler)),
|
||||
schedule_success_handler_(std::move(schedule_success_handler)),
|
||||
lease_client_factory_(std::move(lease_client_factory)),
|
||||
client_factory_(std::move(client_factory)) {
|
||||
core_worker_clients_(client_factory) {
|
||||
RAY_CHECK(schedule_failure_handler_ != nullptr && schedule_success_handler_ != nullptr);
|
||||
}
|
||||
|
||||
|
@ -110,7 +110,7 @@ std::vector<ActorID> GcsActorScheduler::CancelOnNode(const ClientID &node_id) {
|
|||
for (auto &entry : iter->second) {
|
||||
actor_ids.emplace_back(entry.second->GetAssignedActorID());
|
||||
// Remove core worker client.
|
||||
RAY_CHECK(core_worker_clients_.erase(entry.first) != 0);
|
||||
core_worker_clients_.Disconnect(entry.first);
|
||||
}
|
||||
node_to_workers_when_creating_.erase(iter);
|
||||
}
|
||||
|
@ -145,7 +145,7 @@ ActorID GcsActorScheduler::CancelOnWorker(const ClientID &node_id,
|
|||
if (actor_iter != iter->second.end()) {
|
||||
assigned_actor_id = actor_iter->second->GetAssignedActorID();
|
||||
// Remove core worker client.
|
||||
RAY_CHECK(core_worker_clients_.erase(worker_id) != 0);
|
||||
core_worker_clients_.Disconnect(worker_id);
|
||||
iter->second.erase(actor_iter);
|
||||
if (iter->second.empty()) {
|
||||
node_to_workers_when_creating_.erase(iter);
|
||||
|
@ -307,7 +307,7 @@ void GcsActorScheduler::HandleWorkerLeasedReply(
|
|||
// Make sure to connect to the client before persisting actor info to GCS.
|
||||
// Without this, there could be a possible race condition. Related issues:
|
||||
// https://github.com/ray-project/ray/pull/9215/files#r449469320
|
||||
GetOrConnectCoreWorkerClient(leased_worker->GetAddress());
|
||||
core_worker_clients_.GetOrConnect(leased_worker->GetAddress());
|
||||
RAY_CHECK_OK(gcs_actor_table_.Put(actor->GetActorID(), actor->GetActorTableData(),
|
||||
[this, actor, leased_worker](Status status) {
|
||||
RAY_CHECK_OK(status);
|
||||
|
@ -332,7 +332,7 @@ void GcsActorScheduler::CreateActorOnWorker(std::shared_ptr<GcsActor> actor,
|
|||
}
|
||||
request->mutable_resource_mapping()->CopyFrom(resources);
|
||||
|
||||
auto client = GetOrConnectCoreWorkerClient(worker->GetAddress());
|
||||
auto client = core_worker_clients_.GetOrConnect(worker->GetAddress());
|
||||
client->PushNormalTask(
|
||||
std::move(request),
|
||||
[this, actor, worker](Status status, const rpc::PushTaskReply &reply) {
|
||||
|
@ -350,7 +350,7 @@ void GcsActorScheduler::CreateActorOnWorker(std::shared_ptr<GcsActor> actor,
|
|||
// The worker is still in the creating map.
|
||||
if (status.ok()) {
|
||||
// Remove related core worker client.
|
||||
RAY_CHECK(core_worker_clients_.erase(actor->GetWorkerID()) != 0);
|
||||
core_worker_clients_.Disconnect(actor->GetWorkerID());
|
||||
// Remove related worker in phase of creating.
|
||||
iter->second.erase(worker_iter);
|
||||
if (iter->second.empty()) {
|
||||
|
@ -419,15 +419,5 @@ std::shared_ptr<WorkerLeaseInterface> GcsActorScheduler::GetOrConnectLeaseClient
|
|||
return iter->second;
|
||||
}
|
||||
|
||||
std::shared_ptr<rpc::CoreWorkerClientInterface>
|
||||
GcsActorScheduler::GetOrConnectCoreWorkerClient(const rpc::Address &worker_address) {
|
||||
auto worker_id = WorkerID::FromBinary(worker_address.worker_id());
|
||||
auto iter = core_worker_clients_.find(worker_id);
|
||||
if (iter == core_worker_clients_.end()) {
|
||||
iter = core_worker_clients_.emplace(worker_id, client_factory_(worker_address)).first;
|
||||
}
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
} // namespace gcs
|
||||
} // namespace ray
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
#include "ray/raylet_client/raylet_client.h"
|
||||
#include "ray/rpc/node_manager/node_manager_client.h"
|
||||
#include "ray/rpc/worker/core_worker_client.h"
|
||||
#include "ray/rpc/worker/core_worker_client_pool.h"
|
||||
#include "src/ray/protobuf/gcs_service.pb.h"
|
||||
|
||||
namespace ray {
|
||||
|
@ -259,10 +260,6 @@ class GcsActorScheduler : public GcsActorSchedulerInterface {
|
|||
std::shared_ptr<WorkerLeaseInterface> GetOrConnectLeaseClient(
|
||||
const rpc::Address &raylet_address);
|
||||
|
||||
/// Get or create CoreWorkerClient to communicate with the remote leased worker.
|
||||
std::shared_ptr<rpc::CoreWorkerClientInterface> GetOrConnectCoreWorkerClient(
|
||||
const rpc::Address &worker_address);
|
||||
|
||||
protected:
|
||||
/// The io loop that is used to delay execution of tasks (e.g.,
|
||||
/// execute_after).
|
||||
|
@ -282,9 +279,6 @@ class GcsActorScheduler : public GcsActorSchedulerInterface {
|
|||
/// The cached node clients which are used to communicate with raylet to lease workers.
|
||||
absl::flat_hash_map<ClientID, std::shared_ptr<WorkerLeaseInterface>>
|
||||
remote_lease_clients_;
|
||||
/// The cached core worker clients which are used to communicate with leased worker.
|
||||
absl::flat_hash_map<WorkerID, std::shared_ptr<rpc::CoreWorkerClientInterface>>
|
||||
core_worker_clients_;
|
||||
/// Reference of GcsNodeManager.
|
||||
const GcsNodeManager &gcs_node_manager_;
|
||||
/// A publisher for publishing gcs messages.
|
||||
|
@ -295,10 +289,10 @@ class GcsActorScheduler : public GcsActorSchedulerInterface {
|
|||
std::function<void(std::shared_ptr<GcsActor>)> schedule_success_handler_;
|
||||
/// Factory for producing new clients to request leases from remote nodes.
|
||||
LeaseClientFactoryFn lease_client_factory_;
|
||||
/// Factory for producing new core worker clients.
|
||||
rpc::ClientFactoryFn client_factory_;
|
||||
/// The nodes which are releasing unused workers.
|
||||
absl::flat_hash_set<ClientID> nodes_of_releasing_unused_workers_;
|
||||
/// The cached core worker clients which are used to communicate with leased worker.
|
||||
rpc::CoreWorkerClientPool core_worker_clients_;
|
||||
};
|
||||
|
||||
} // namespace gcs
|
||||
|
|
|
@ -205,10 +205,6 @@ struct GcsServerMocker {
|
|||
lease_client_factory_ = std::move(lease_client_factory);
|
||||
}
|
||||
|
||||
void ResetClientFactory(rpc::ClientFactoryFn client_factory) {
|
||||
client_factory_ = std::move(client_factory);
|
||||
}
|
||||
|
||||
void TryLeaseWorkerFromNodeAgain(std::shared_ptr<gcs::GcsActor> actor,
|
||||
std::shared_ptr<rpc::GcsNodeInfo> node) {
|
||||
DoRetryLeasingWorkerFromNode(std::move(actor), std::move(node));
|
||||
|
|
42
src/ray/rpc/worker/core_worker_client_pool.cc
Normal file
42
src/ray/rpc/worker/core_worker_client_pool.cc
Normal file
|
@ -0,0 +1,42 @@
|
|||
#include "ray/rpc/worker/core_worker_client_pool.h"
|
||||
|
||||
namespace ray {
|
||||
namespace rpc {
|
||||
|
||||
optional<shared_ptr<CoreWorkerClientInterface>> CoreWorkerClientPool::GetByID(
|
||||
ray::WorkerID id) {
|
||||
absl::MutexLock lock(&mu_);
|
||||
auto it = client_map_.find(id);
|
||||
if (it == client_map_.end()) {
|
||||
return {};
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
shared_ptr<CoreWorkerClientInterface> CoreWorkerClientPool::GetOrConnect(
|
||||
const Address &addr_proto) {
|
||||
RAY_CHECK(addr_proto.worker_id() != "");
|
||||
absl::MutexLock lock(&mu_);
|
||||
auto id = WorkerID::FromBinary(addr_proto.worker_id());
|
||||
auto it = client_map_.find(id);
|
||||
if (it != client_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
auto connection = client_factory_(addr_proto);
|
||||
client_map_[id] = connection;
|
||||
|
||||
RAY_LOG(INFO) << "Connected to " << addr_proto.ip_address() << ":" << addr_proto.port();
|
||||
return connection;
|
||||
}
|
||||
|
||||
void CoreWorkerClientPool::Disconnect(ray::WorkerID id) {
|
||||
absl::MutexLock lock(&mu_);
|
||||
auto it = client_map_.find(id);
|
||||
if (it == client_map_.end()) {
|
||||
return;
|
||||
}
|
||||
client_map_.erase(it);
|
||||
}
|
||||
|
||||
} // namespace rpc
|
||||
} // namespace ray
|
80
src/ray/rpc/worker/core_worker_client_pool.h
Normal file
80
src/ray/rpc/worker/core_worker_client_pool.h
Normal file
|
@ -0,0 +1,80 @@
|
|||
// Copyright 2020 The Ray Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "absl/base/thread_annotations.h"
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "ray/common/id.h"
|
||||
#include "ray/rpc/worker/core_worker_client.h"
|
||||
|
||||
using absl::optional;
|
||||
using std::shared_ptr;
|
||||
|
||||
namespace ray {
|
||||
namespace rpc {
|
||||
|
||||
class CoreWorkerClientPool {
|
||||
public:
|
||||
CoreWorkerClientPool() = delete;
|
||||
|
||||
/// Creates a CoreWorkerClientPool based on the low-level ClientCallManager.
|
||||
CoreWorkerClientPool(rpc::ClientCallManager &ccm)
|
||||
: client_factory_(defaultClientFactory(ccm)){};
|
||||
|
||||
/// Creates a CoreWorkerClientPool by a given connection function.
|
||||
CoreWorkerClientPool(ClientFactoryFn client_factory)
|
||||
: client_factory_(client_factory){};
|
||||
|
||||
/// Returns an existing Interface if one exists, or an empty optional
|
||||
/// otherwise.
|
||||
/// Any returned pointer is borrowed, and expected to be used briefly.
|
||||
optional<shared_ptr<CoreWorkerClientInterface>> GetByID(ray::WorkerID id);
|
||||
|
||||
/// Returns an open CoreWorkerClientInterface if one exists, and connect to one
|
||||
/// if it does not. The returned pointer is borrowed, and expected to be used
|
||||
/// briefly.
|
||||
shared_ptr<CoreWorkerClientInterface> GetOrConnect(const Address &addr_proto);
|
||||
|
||||
/// Removes a connection to the worker from the pool, if one exists. Since the
|
||||
/// shared pointer will no longer be retained in the pool, the connection will
|
||||
/// be open until it's no longer used, at which time it will disconnect.
|
||||
void Disconnect(ray::WorkerID id);
|
||||
|
||||
private:
|
||||
/// Provides the default client factory function. Providing this function to the
|
||||
/// construtor aids migration but is ultimately a thing that should be
|
||||
/// deprecated and brought internal to the pool, so this is our bridge.
|
||||
ClientFactoryFn defaultClientFactory(rpc::ClientCallManager &ccm) const {
|
||||
return [&](const rpc::Address &addr) {
|
||||
return std::shared_ptr<rpc::CoreWorkerClient>(new rpc::CoreWorkerClient(addr, ccm));
|
||||
};
|
||||
};
|
||||
|
||||
/// This factory function does the connection to CoreWorkerClient, and is
|
||||
/// provided by the constructor (either the default implementation, above, or a
|
||||
/// provided one)
|
||||
ClientFactoryFn client_factory_;
|
||||
|
||||
absl::Mutex mu_;
|
||||
|
||||
/// A pool of open connections by WorkerID. Clients can reuse the connection
|
||||
/// objects in this pool by requesting them.
|
||||
absl::flat_hash_map<ray::WorkerID, shared_ptr<CoreWorkerClientInterface>> client_map_
|
||||
GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
} // namespace rpc
|
||||
} // namespace ray
|
Loading…
Add table
Reference in a new issue