rpc: Core Worker client pool (#9934)

This commit is contained in:
Barak Michener 2020-08-07 16:34:29 -07:00 committed by GitHub
parent dee3322ab0
commit 1d01c668f0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 172 additions and 84 deletions

View file

@ -1,6 +1,8 @@
# Bazel build
# C/C++ documentation: https://docs.bazel.build/versions/master/be/c-cpp.html
load("@rules_python//python:defs.bzl", "py_library")
load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library", "cc_test")
load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", "cc_grpc_library")
load("@com_github_grpc_grpc//bazel:cython_library.bzl", "pyx_library")
load("@com_github_google_flatbuffers//:build_defs.bzl", "flatbuffer_cc_library")
@ -129,6 +131,9 @@ cc_grpc_library(
# worker server and client.
cc_library(
name = "worker_rpc",
srcs = glob([
"src/ray/rpc/worker/*.cc",
]),
hdrs = glob([
"src/ray/rpc/worker/*.h",
]),

View file

@ -18,24 +18,17 @@ namespace ray {
void FutureResolver::ResolveFutureAsync(const ObjectID &object_id,
const rpc::Address &owner_address) {
absl::MutexLock lock(&mu_);
const auto owner_worker_id = WorkerID::FromBinary(owner_address.worker_id());
if (rpc_address_.worker_id() == owner_address.worker_id()) {
// We do not need to resolve objects that we own. This can happen if a task
// with a borrowed reference executes on the object's owning worker.
return;
}
auto it = owner_clients_.find(owner_worker_id);
if (it == owner_clients_.end()) {
auto client =
std::shared_ptr<rpc::CoreWorkerClientInterface>(client_factory_(owner_address));
it = owner_clients_.emplace(owner_worker_id, std::move(client)).first;
}
auto conn = owner_clients_.GetOrConnect(owner_address);
rpc::GetObjectStatusRequest request;
request.set_object_id(object_id.Binary());
request.set_owner_worker_id(owner_worker_id.Binary());
it->second->GetObjectStatus(
request.set_owner_worker_id(owner_address.worker_id());
conn->GetObjectStatus(
request,
[this, object_id](const Status &status, const rpc::GetObjectStatusReply &reply) {
if (!status.ok()) {

View file

@ -19,6 +19,7 @@
#include "ray/common/id.h"
#include "ray/core_worker/store_provider/memory_store/memory_store.h"
#include "ray/rpc/worker/core_worker_client.h"
#include "ray/rpc/worker/core_worker_client_pool.h"
#include "src/ray/protobuf/core_worker.pb.h"
namespace ray {
@ -30,7 +31,7 @@ class FutureResolver {
FutureResolver(std::shared_ptr<CoreWorkerMemoryStore> store,
rpc::ClientFactoryFn client_factory, const rpc::Address &rpc_address)
: in_memory_store_(store),
client_factory_(client_factory),
owner_clients_(client_factory),
rpc_address_(rpc_address) {}
/// Resolve the value for a future. This will periodically contact the given
@ -47,20 +48,12 @@ class FutureResolver {
/// Used to store values of resolved futures.
std::shared_ptr<CoreWorkerMemoryStore> in_memory_store_;
/// Factory for producing new core worker clients.
const rpc::ClientFactoryFn client_factory_;
rpc::CoreWorkerClientPool owner_clients_;
/// Address of our RPC server. Used to notify borrowed objects' owners of our
/// address, so the owner can contact us to ask when our reference to the
/// object has gone out of scope.
const rpc::Address rpc_address_;
/// Protects against concurrent access to internal state.
absl::Mutex mu_;
/// Cache of gRPC clients to the objects' owners.
absl::flat_hash_map<WorkerID, std::shared_ptr<rpc::CoreWorkerClientInterface>>
owner_clients_ GUARDED_BY(mu_);
};
} // namespace ray

View file

@ -732,17 +732,13 @@ void ReferenceCounter::WaitForRefRemoved(const ReferenceTable::iterator &ref_it,
request.set_contained_in_id(contained_in_id.Binary());
request.set_intended_worker_id(addr.worker_id.Binary());
auto it = borrower_cache_.find(addr);
if (it == borrower_cache_.end()) {
RAY_CHECK(client_factory_ != nullptr);
it = borrower_cache_.emplace(addr, client_factory_(addr.ToProto())).first;
}
auto conn = borrower_pool_.GetOrConnect(addr.ToProto());
RAY_LOG(DEBUG) << "Sending WaitForRefRemoved to borrower " << addr.ip_address << ":"
<< addr.port << " for object " << object_id;
// Send the borrower a message about this object. The borrower responds once
// it is no longer using the object ID.
it->second->WaitForRefRemoved(
conn->WaitForRefRemoved(
request, [this, object_id, addr](const Status &status,
const rpc::WaitForRefRemovedReply &reply) {
RAY_LOG(DEBUG) << "Received reply from borrower " << addr.ip_address << ":"

View file

@ -23,6 +23,7 @@
#include "ray/common/id.h"
#include "ray/rpc/grpc_server.h"
#include "ray/rpc/worker/core_worker_client.h"
#include "ray/rpc/worker/core_worker_client_pool.h"
#include "ray/util/logging.h"
#include "src/ray/protobuf/common.pb.h"
@ -66,7 +67,7 @@ class ReferenceCounter : public ReferenceCounterInterface {
: rpc_address_(rpc_address),
distributed_ref_counting_enabled_(distributed_ref_counting_enabled),
lineage_pinning_enabled_(lineage_pinning_enabled),
client_factory_(client_factory) {}
borrower_pool_(client_factory) {}
~ReferenceCounter() {}
@ -647,11 +648,10 @@ class ReferenceCounter : public ReferenceCounterInterface {
/// Factory for producing new core worker clients.
rpc::ClientFactoryFn client_factory_;
/// Map from worker address to core worker client. The owner of an object
/// Pool from worker address to core worker client. The owner of an object
/// uses this client to request a notification from borrowers once the
/// borrower's ref count for the ID goes to 0.
absl::flat_hash_map<rpc::WorkerAddress, std::shared_ptr<rpc::CoreWorkerClientInterface>>
borrower_cache_ GUARDED_BY(mutex_);
rpc::CoreWorkerClientPool borrower_pool_;
/// Protects access to the reference counting state.
mutable absl::Mutex mutex_;

View file

@ -94,12 +94,7 @@ Status CoreWorkerDirectTaskSubmitter::SubmitTask(TaskSpecification task_spec) {
void CoreWorkerDirectTaskSubmitter::AddWorkerLeaseClient(
const rpc::WorkerAddress &addr, std::shared_ptr<WorkerLeaseInterface> lease_client) {
auto it = client_cache_.find(addr);
if (it == client_cache_.end()) {
client_cache_[addr] =
std::shared_ptr<rpc::CoreWorkerClientInterface>(client_factory_(addr.ToProto()));
RAY_LOG(INFO) << "Connected to " << addr.ip_address << ":" << addr.port;
}
client_cache_.GetOrConnect(addr.ToProto());
int64_t expiration = current_time_ms() + lease_timeout_ms_;
LeaseEntry new_lease_entry = LeaseEntry(std::move(lease_client), expiration, 0);
worker_to_lease_entry_.emplace(addr, new_lease_entry);
@ -131,7 +126,7 @@ void CoreWorkerDirectTaskSubmitter::OnWorkerIdle(
}
} else {
auto &client = *client_cache_[addr];
auto &client = *client_cache_.GetOrConnect(addr.ToProto());
while (!queue_entry->second.empty() &&
lease_entry.tasks_in_flight_ < max_tasks_in_flight_per_worker_) {
@ -368,17 +363,23 @@ Status CoreWorkerDirectTaskSubmitter::CancelTask(TaskSpecification task_spec,
// or when all dependencies are resolved.
RAY_CHECK(cancelled_tasks_.emplace(task_spec.TaskId()).second);
auto rpc_client = executing_tasks_.find(task_spec.TaskId());
// Looks for an RPC handle for the worker executing the task.
if (rpc_client != executing_tasks_.end() &&
client_cache_.find(rpc_client->second) != client_cache_.end()) {
client = client_cache_.find(rpc_client->second)->second;
if (rpc_client == executing_tasks_.end()) {
// This case is reached for tasks that have unresolved dependencies.
// No executing tasks, so cancelling is a noop.
return Status::OK();
}
// Looks for an RPC handle for the worker executing the task.
auto maybe_client = client_cache_.GetByID(rpc_client->second.worker_id);
if (!maybe_client.has_value()) {
// If we don't have a connection to that worker, we can't cancel it.
// This case is reached for tasks that have unresolved dependencies.
return Status::OK();
}
client = maybe_client.value();
}
// This case is reached for tasks that have unresolved dependencies.
if (client == nullptr) {
return Status::OK();
}
RAY_CHECK(client != nullptr);
auto request = rpc::CancelTaskRequest();
request.set_intended_task_id(task_spec.TaskId().Binary());
@ -408,15 +409,16 @@ Status CoreWorkerDirectTaskSubmitter::CancelTask(TaskSpecification task_spec,
Status CoreWorkerDirectTaskSubmitter::CancelRemoteTask(const ObjectID &object_id,
const rpc::Address &worker_addr,
bool force_kill) {
absl::MutexLock lock(&mu_);
auto client = client_cache_.find(rpc::WorkerAddress(worker_addr));
if (client == client_cache_.end()) {
auto maybe_client = client_cache_.GetByID(rpc::WorkerAddress(worker_addr).worker_id);
if (!maybe_client.has_value()) {
return Status::Invalid("No remote worker found");
}
auto client = maybe_client.value();
auto request = rpc::RemoteCancelTaskRequest();
request.set_force_kill(force_kill);
request.set_remote_object_id(object_id.Binary());
client->second->RemoteCancelTask(request, nullptr);
client->RemoteCancelTask(request, nullptr);
return Status::OK();
}

View file

@ -28,6 +28,7 @@
#include "ray/core_worker/transport/direct_actor_transport.h"
#include "ray/raylet_client/raylet_client.h"
#include "ray/rpc/worker/core_worker_client.h"
#include "ray/rpc/worker/core_worker_client_pool.h"
namespace ray {
@ -60,13 +61,13 @@ class CoreWorkerDirectTaskSubmitter {
absl::optional<boost::asio::steady_timer> cancel_timer = absl::nullopt)
: rpc_address_(rpc_address),
local_lease_client_(lease_client),
client_factory_(client_factory),
lease_client_factory_(lease_client_factory),
resolver_(store, task_finisher),
task_finisher_(task_finisher),
lease_timeout_ms_(lease_timeout_ms),
local_raylet_id_(local_raylet_id),
actor_creator_(std::move(actor_creator)),
client_cache_(client_factory),
max_tasks_in_flight_per_worker_(max_tasks_in_flight_per_worker),
cancel_retry_timer_(std::move(cancel_timer)) {}
@ -142,9 +143,6 @@ class CoreWorkerDirectTaskSubmitter {
absl::flat_hash_map<ClientID, std::shared_ptr<WorkerLeaseInterface>>
remote_lease_clients_ GUARDED_BY(mu_);
/// Factory for producing new core worker clients.
rpc::ClientFactoryFn client_factory_;
/// Factory for producing new clients to request leases from remote nodes.
LeaseClientFactoryFn lease_client_factory_;
@ -169,8 +167,7 @@ class CoreWorkerDirectTaskSubmitter {
absl::Mutex mu_;
/// Cache of gRPC clients to other workers.
absl::flat_hash_map<rpc::WorkerAddress, std::shared_ptr<rpc::CoreWorkerClientInterface>>
client_cache_ GUARDED_BY(mu_);
rpc::CoreWorkerClientPool client_cache_;
// max_tasks_in_flight_per_worker_ limits the number of tasks that can be pipelined to a
// worker using a single lease.

View file

@ -36,7 +36,7 @@ GcsActorScheduler::GcsActorScheduler(
schedule_failure_handler_(std::move(schedule_failure_handler)),
schedule_success_handler_(std::move(schedule_success_handler)),
lease_client_factory_(std::move(lease_client_factory)),
client_factory_(std::move(client_factory)) {
core_worker_clients_(client_factory) {
RAY_CHECK(schedule_failure_handler_ != nullptr && schedule_success_handler_ != nullptr);
}
@ -110,7 +110,7 @@ std::vector<ActorID> GcsActorScheduler::CancelOnNode(const ClientID &node_id) {
for (auto &entry : iter->second) {
actor_ids.emplace_back(entry.second->GetAssignedActorID());
// Remove core worker client.
RAY_CHECK(core_worker_clients_.erase(entry.first) != 0);
core_worker_clients_.Disconnect(entry.first);
}
node_to_workers_when_creating_.erase(iter);
}
@ -145,7 +145,7 @@ ActorID GcsActorScheduler::CancelOnWorker(const ClientID &node_id,
if (actor_iter != iter->second.end()) {
assigned_actor_id = actor_iter->second->GetAssignedActorID();
// Remove core worker client.
RAY_CHECK(core_worker_clients_.erase(worker_id) != 0);
core_worker_clients_.Disconnect(worker_id);
iter->second.erase(actor_iter);
if (iter->second.empty()) {
node_to_workers_when_creating_.erase(iter);
@ -307,7 +307,7 @@ void GcsActorScheduler::HandleWorkerLeasedReply(
// Make sure to connect to the client before persisting actor info to GCS.
// Without this, there could be a possible race condition. Related issues:
// https://github.com/ray-project/ray/pull/9215/files#r449469320
GetOrConnectCoreWorkerClient(leased_worker->GetAddress());
core_worker_clients_.GetOrConnect(leased_worker->GetAddress());
RAY_CHECK_OK(gcs_actor_table_.Put(actor->GetActorID(), actor->GetActorTableData(),
[this, actor, leased_worker](Status status) {
RAY_CHECK_OK(status);
@ -332,7 +332,7 @@ void GcsActorScheduler::CreateActorOnWorker(std::shared_ptr<GcsActor> actor,
}
request->mutable_resource_mapping()->CopyFrom(resources);
auto client = GetOrConnectCoreWorkerClient(worker->GetAddress());
auto client = core_worker_clients_.GetOrConnect(worker->GetAddress());
client->PushNormalTask(
std::move(request),
[this, actor, worker](Status status, const rpc::PushTaskReply &reply) {
@ -350,7 +350,7 @@ void GcsActorScheduler::CreateActorOnWorker(std::shared_ptr<GcsActor> actor,
// The worker is still in the creating map.
if (status.ok()) {
// Remove related core worker client.
RAY_CHECK(core_worker_clients_.erase(actor->GetWorkerID()) != 0);
core_worker_clients_.Disconnect(actor->GetWorkerID());
// Remove related worker in phase of creating.
iter->second.erase(worker_iter);
if (iter->second.empty()) {
@ -419,15 +419,5 @@ std::shared_ptr<WorkerLeaseInterface> GcsActorScheduler::GetOrConnectLeaseClient
return iter->second;
}
std::shared_ptr<rpc::CoreWorkerClientInterface>
GcsActorScheduler::GetOrConnectCoreWorkerClient(const rpc::Address &worker_address) {
auto worker_id = WorkerID::FromBinary(worker_address.worker_id());
auto iter = core_worker_clients_.find(worker_id);
if (iter == core_worker_clients_.end()) {
iter = core_worker_clients_.emplace(worker_id, client_factory_(worker_address)).first;
}
return iter->second;
}
} // namespace gcs
} // namespace ray

View file

@ -27,6 +27,7 @@
#include "ray/raylet_client/raylet_client.h"
#include "ray/rpc/node_manager/node_manager_client.h"
#include "ray/rpc/worker/core_worker_client.h"
#include "ray/rpc/worker/core_worker_client_pool.h"
#include "src/ray/protobuf/gcs_service.pb.h"
namespace ray {
@ -259,10 +260,6 @@ class GcsActorScheduler : public GcsActorSchedulerInterface {
std::shared_ptr<WorkerLeaseInterface> GetOrConnectLeaseClient(
const rpc::Address &raylet_address);
/// Get or create CoreWorkerClient to communicate with the remote leased worker.
std::shared_ptr<rpc::CoreWorkerClientInterface> GetOrConnectCoreWorkerClient(
const rpc::Address &worker_address);
protected:
/// The io loop that is used to delay execution of tasks (e.g.,
/// execute_after).
@ -282,9 +279,6 @@ class GcsActorScheduler : public GcsActorSchedulerInterface {
/// The cached node clients which are used to communicate with raylet to lease workers.
absl::flat_hash_map<ClientID, std::shared_ptr<WorkerLeaseInterface>>
remote_lease_clients_;
/// The cached core worker clients which are used to communicate with leased worker.
absl::flat_hash_map<WorkerID, std::shared_ptr<rpc::CoreWorkerClientInterface>>
core_worker_clients_;
/// Reference of GcsNodeManager.
const GcsNodeManager &gcs_node_manager_;
/// A publisher for publishing gcs messages.
@ -295,10 +289,10 @@ class GcsActorScheduler : public GcsActorSchedulerInterface {
std::function<void(std::shared_ptr<GcsActor>)> schedule_success_handler_;
/// Factory for producing new clients to request leases from remote nodes.
LeaseClientFactoryFn lease_client_factory_;
/// Factory for producing new core worker clients.
rpc::ClientFactoryFn client_factory_;
/// The nodes which are releasing unused workers.
absl::flat_hash_set<ClientID> nodes_of_releasing_unused_workers_;
/// The cached core worker clients which are used to communicate with leased worker.
rpc::CoreWorkerClientPool core_worker_clients_;
};
} // namespace gcs

View file

@ -205,10 +205,6 @@ struct GcsServerMocker {
lease_client_factory_ = std::move(lease_client_factory);
}
void ResetClientFactory(rpc::ClientFactoryFn client_factory) {
client_factory_ = std::move(client_factory);
}
void TryLeaseWorkerFromNodeAgain(std::shared_ptr<gcs::GcsActor> actor,
std::shared_ptr<rpc::GcsNodeInfo> node) {
DoRetryLeasingWorkerFromNode(std::move(actor), std::move(node));

View file

@ -0,0 +1,42 @@
#include "ray/rpc/worker/core_worker_client_pool.h"
namespace ray {
namespace rpc {
optional<shared_ptr<CoreWorkerClientInterface>> CoreWorkerClientPool::GetByID(
ray::WorkerID id) {
absl::MutexLock lock(&mu_);
auto it = client_map_.find(id);
if (it == client_map_.end()) {
return {};
}
return it->second;
}
shared_ptr<CoreWorkerClientInterface> CoreWorkerClientPool::GetOrConnect(
const Address &addr_proto) {
RAY_CHECK(addr_proto.worker_id() != "");
absl::MutexLock lock(&mu_);
auto id = WorkerID::FromBinary(addr_proto.worker_id());
auto it = client_map_.find(id);
if (it != client_map_.end()) {
return it->second;
}
auto connection = client_factory_(addr_proto);
client_map_[id] = connection;
RAY_LOG(INFO) << "Connected to " << addr_proto.ip_address() << ":" << addr_proto.port();
return connection;
}
void CoreWorkerClientPool::Disconnect(ray::WorkerID id) {
absl::MutexLock lock(&mu_);
auto it = client_map_.find(id);
if (it == client_map_.end()) {
return;
}
client_map_.erase(it);
}
} // namespace rpc
} // namespace ray

View file

@ -0,0 +1,80 @@
// Copyright 2020 The Ray Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/synchronization/mutex.h"
#include "ray/common/id.h"
#include "ray/rpc/worker/core_worker_client.h"
using absl::optional;
using std::shared_ptr;
namespace ray {
namespace rpc {
class CoreWorkerClientPool {
public:
CoreWorkerClientPool() = delete;
/// Creates a CoreWorkerClientPool based on the low-level ClientCallManager.
CoreWorkerClientPool(rpc::ClientCallManager &ccm)
: client_factory_(defaultClientFactory(ccm)){};
/// Creates a CoreWorkerClientPool by a given connection function.
CoreWorkerClientPool(ClientFactoryFn client_factory)
: client_factory_(client_factory){};
/// Returns an existing Interface if one exists, or an empty optional
/// otherwise.
/// Any returned pointer is borrowed, and expected to be used briefly.
optional<shared_ptr<CoreWorkerClientInterface>> GetByID(ray::WorkerID id);
/// Returns an open CoreWorkerClientInterface if one exists, and connect to one
/// if it does not. The returned pointer is borrowed, and expected to be used
/// briefly.
shared_ptr<CoreWorkerClientInterface> GetOrConnect(const Address &addr_proto);
/// Removes a connection to the worker from the pool, if one exists. Since the
/// shared pointer will no longer be retained in the pool, the connection will
/// be open until it's no longer used, at which time it will disconnect.
void Disconnect(ray::WorkerID id);
private:
/// Provides the default client factory function. Providing this function to the
/// construtor aids migration but is ultimately a thing that should be
/// deprecated and brought internal to the pool, so this is our bridge.
ClientFactoryFn defaultClientFactory(rpc::ClientCallManager &ccm) const {
return [&](const rpc::Address &addr) {
return std::shared_ptr<rpc::CoreWorkerClient>(new rpc::CoreWorkerClient(addr, ccm));
};
};
/// This factory function does the connection to CoreWorkerClient, and is
/// provided by the constructor (either the default implementation, above, or a
/// provided one)
ClientFactoryFn client_factory_;
absl::Mutex mu_;
/// A pool of open connections by WorkerID. Clients can reuse the connection
/// objects in this pool by requesting them.
absl::flat_hash_map<ray::WorkerID, shared_ptr<CoreWorkerClientInterface>> client_map_
GUARDED_BY(mu_);
};
} // namespace rpc
} // namespace ray