[scheduling][5] Refactor resource syncer. (#23270)

## Why are these changes needed?

This PR refactors the resource syncer to decouple it from GCS and raylet. GCS and raylet will use the same module to sync data. The integration will happen in the next PR.

There are several newly introduced components:

* RaySyncer: the place where remote and local information sits. It's a coordinator layer.
* NodeState: keeps track of the local status, similar to NodeSyncConnection.
* NodeSyncConnection: keeps track of the sending and receiving information and makes sure not to send information the remote node already knows.

The core protocol is that each node will send {what it has} - {what the target has} to the target.
For example, think about node A <-> B. A will send to B everything A has, excluding what B already has.

Whenever there is new information (from NodeState or NodeSyncConnection), it is passed to RaySyncer's BroadcastMessage to be broadcast.

NodeSyncConnection is for the communication layer. It has two implementations, Client and Server:

* Server => Client: the client sends a long-polling request and the server responds every 100ms if there is data to be sent.
* Client => Server: the client checks every 100ms whether there is new data to be sent. If there is, it just uses an RPC call to send the data.

Here is one example:

```mermaid
flowchart LR;
    A-->B;
    B-->C;
    B-->D;
```

It means A initiates the connection to B, and B initiates the connections to C and D.

Now C generates a message M:

1. [C] RaySyncer checks whether a new message has been generated in C and gets M
2. [C] RaySyncer pushes M to the NodeSyncConnection in the local component (B)
3. [C] ServerSyncConnection waits until B sends a long-polling request and then sends the data to B
4. [B] B receives the message from C and pushes it to its local sync connections (C, A, D)
5. [B] ClientSyncConnection of C will not push it to its local queue since it was received through this channel.
6. [B] ClientSyncConnection of D will send this message to D
7. [B] ServerSyncConnection of A will be used to send this message to A (long-polling here)
8. [B] B will update NodeState (local component) with this message M
9. [D] D's pipeline is similar to 5) (with ServerSyncConnection) and 8)
10. [A] A's pipeline is similar to 5) and 8)
This commit is contained in:
Yi Cheng 2022-03-29 23:52:39 -07:00 committed by GitHub
parent 5aead0bb91
commit 781c46ae44
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 1957 additions and 0 deletions

View file

@ -404,6 +404,13 @@ cc_library(
visibility = ["//visibility:public"],
)
# C++ gRPC service code generated from the ray_syncer proto.
# grpc_only = True: presumably only the service stubs are generated here and the
# message classes come from the ray_syncer_cc_proto dep -- confirm against the
# cc_grpc_library rule documentation.
cc_grpc_library(
name = "ray_syncer_cc_grpc",
srcs = ["//src/ray/protobuf:ray_syncer_proto"],
grpc_only = True,
deps = ["//src/ray/protobuf:ray_syncer_cc_proto"],
)
cc_library(
name = "ray_common",
srcs = glob(
@ -433,6 +440,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":node_manager_fbs",
":ray_syncer_cc_grpc",
":ray_util",
":stats_metric",
"//src/ray/protobuf:common_cc_proto",
@ -2380,6 +2388,27 @@ cc_library(
],
)
# Stand-alone executable built from syncer_service_e2e_test.cc.
# It is declared as a cc_binary rather than a cc_test, so it is built but not
# executed as part of `bazel test`.
cc_binary(
name = "syncer_service_e2e_test",
srcs = ["src/ray/common/test/syncer_service_e2e_test.cc"],
copts = COPTS,
deps = [
":ray_common",
],
)
# Unit tests for the ray syncer (ray_syncer_test.cc).
# Depends on :ray_mock in addition to :ray_common and googletest.
cc_test(
name = "ray_syncer_test",
srcs = ["src/ray/common/test/ray_syncer_test.cc"],
copts = COPTS,
tags = ["team:core"],
deps = [
":ray_common",
":ray_mock",
"@com_google_googletest//:gtest",
],
)
cc_test(
name = "callback_reply_test",
size = "small",

View file

@ -0,0 +1,50 @@
// Copyright The Ray Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
namespace ray {
namespace syncer {
/// gMock implementation of ReporterInterface, so tests can script the value
/// returned by Snapshot() and set expectations on how it is called.
class MockReporterInterface : public ReporterInterface {
public:
/// Mock of the const Snapshot() override.
MOCK_METHOD(std::optional<RaySyncMessage>,
Snapshot,
(int64_t current_version, RayComponentId component_id),
(const, override));
};
} // namespace syncer
} // namespace ray
namespace ray {
namespace syncer {
/// gMock implementation of ReceiverInterface, so tests can observe the messages
/// passed to Update().
class MockReceiverInterface : public ReceiverInterface {
public:
/// Mock of the Update() override.
MOCK_METHOD(void, Update, (std::shared_ptr<const RaySyncMessage> message), (override));
};
} // namespace syncer
} // namespace ray
namespace ray {
namespace syncer {
/// gMock subclass of NodeSyncConnection. Only the pure-virtual DoSend() is mocked;
/// everything else is the real NodeSyncConnection implementation.
class MockNodeSyncConnection : public NodeSyncConnection {
public:
/// Reuse the base-class constructors unchanged.
using NodeSyncConnection::NodeSyncConnection;
/// Mock of the DoSend() override.
MOCK_METHOD(void, DoSend, (), (override));
};
} // namespace syncer
} // namespace ray

View file

@ -502,6 +502,9 @@ RAY_CONFIG(std::string, custom_unit_instance_resources, "")
// Maximum size of the batches when broadcasting resources to raylet.
RAY_CONFIG(uint64_t, resource_broadcast_batch_size, 512);
// Soft cap, in bytes, on the payload carried by one ray sync message batch between
// nodes (1MB by default). The senders compare the running total against this value
// before adding each message, so a batch can exceed it by up to one message.
RAY_CONFIG(uint64_t, max_sync_message_batch_bytes, 1 * 1024 * 1024);
// If enabled and worker started in container, the container will add
// resource limit.
RAY_CONFIG(bool, worker_resource_limits_enabled, false)

View file

@ -0,0 +1,197 @@
// Copyright 2022 The Ray Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
namespace ray {
namespace syncer {
/// NodeState keeps track of the modules in the local node.
/// It contains the local components for receiving and reporting.
/// It also keeps the raw messages the receivers got.
class NodeState {
public:
/// Constructor of NodeState.
NodeState();
/// Set the local component.
///
/// \param cid The component id.
/// \param reporter The reporter is defined to be the local module which wants to
/// broadcast its internal status to the whole cluster. When it's null, it means there
/// is no reporter in the local node for this component. This is the place where
/// messages are generated.
/// \param receiver The receiver is defined to be the module which eventually
/// will have the view of the cluster for this component. It's the place where
/// received messages are consumed.
///
/// \return true if set successfully.
bool SetComponent(RayComponentId cid,
const ReporterInterface *reporter,
ReceiverInterface *receiver);
/// Get the snapshot of a component for a newer version.
///
/// \param cid The component id to take the snapshot.
///
/// \return If a snapshot is taken, return the message, otherwise std::nullopt.
std::optional<RaySyncMessage> GetSnapshot(RayComponentId cid);
/// Consume a message. Receiver will consume this message if it doesn't have
/// this message.
///
/// \param message The message received.
///
/// \return true if the local node doesn't have a message with the same or a newer
/// version.
bool ConsumeMessage(std::shared_ptr<const RaySyncMessage> message);
/// Return the cluster view of this local node: for every node id, the latest
/// message recorded for each component (null when none has been recorded).
const absl::flat_hash_map<
std::string,
std::array<std::shared_ptr<const RaySyncMessage>, kComponentArraySize>>
&GetClusterView() const {
return cluster_view_;
}
private:
/// For local nodes. Both arrays are indexed by component id; a null entry means
/// nothing is registered for that component.
std::array<const ReporterInterface *, kComponentArraySize> reporters_ = {nullptr};
std::array<ReceiverInterface *, kComponentArraySize> receivers_ = {nullptr};
/// This field records the version of the snapshot that has been taken.
std::array<int64_t, kComponentArraySize> snapshots_versions_taken_;
/// Keep track of the latest messages received.
/// Use shared pointer for easier liveness management since these messages might be
/// sending via rpc.
absl::flat_hash_map<
std::string,
std::array<std::shared_ptr<const RaySyncMessage>, kComponentArraySize>>
cluster_view_;
};
/// Base class of a connection to one remote node. It buffers the messages to be
/// sent, tracks per-node component versions so that stale messages are filtered in
/// both directions, and leaves the actual transport (DoSend) to subclasses.
class NodeSyncConnection {
public:
/// \param io_context The io context stored for use by this connection.
/// \param remote_node_id The id of the node on the other side of the connection.
/// \param message_processor Callback invoked for every received message that is
/// newer than what this connection has recorded.
NodeSyncConnection(
instrumented_io_context &io_context,
std::string remote_node_id,
std::function<void(std::shared_ptr<RaySyncMessage>)> message_processor);
/// Push a message to the sending queue to be sent later. Some messages
/// might be dropped if the module thinks the target node has already got the
/// information. Usually it'll happen when the message has the source node id
/// as the target or the message is sent from the remote node.
///
/// \param message The message to be sent.
///
/// \return true if push to queue successfully.
bool PushToSendingQueue(std::shared_ptr<const RaySyncMessage> message);
/// Send the message queued.
virtual void DoSend() = 0;
virtual ~NodeSyncConnection() {}
/// Return the remote node id of this connection.
const std::string &GetRemoteNodeID() const { return remote_node_id_; }
/// Handle the updates sent from the remote node.
///
/// \param messages The message received.
void ReceiveUpdate(RaySyncMessages messages);
protected:
// For testing
FRIEND_TEST(RaySyncerTest, NodeSyncConnection);
friend struct SyncerServerTest;
/// Return the per-component version array recorded for `node_id`, creating it
/// (with every version set to -1) on first use.
std::array<int64_t, kComponentArraySize> &GetNodeComponentVersions(
const std::string &node_id);
/// The io context
instrumented_io_context &io_context_;
/// The remote node id.
std::string remote_node_id_;
/// Handler of a message update.
std::function<void(std::shared_ptr<RaySyncMessage>)> message_processor_;
/// Buffering all the updates. Sending will be done in an async way.
/// Keyed by (node id, component id), so a newer message replaces a queued older one.
absl::flat_hash_map<std::pair<std::string, RayComponentId>,
std::shared_ptr<const RaySyncMessage>>
sending_buffer_;
/// Keep track of the versions of components in the remote node.
/// This field will be updated when messages are received or sent.
/// We'll filter the received or sent messages when the message is stale.
absl::flat_hash_map<std::string, std::array<int64_t, kComponentArraySize>>
node_versions_;
};
/// SyncConnection for gRPC server side. It has customized logic for sending:
/// data is only sent as the reply to a pending long-polling request.
class ServerSyncConnection : public NodeSyncConnection {
public:
ServerSyncConnection(
instrumented_io_context &io_context,
const std::string &remote_node_id,
std::function<void(std::shared_ptr<RaySyncMessage>)> message_processor);
/// Finishes a still-pending long-polling request with CANCELLED.
~ServerSyncConnection() override;
/// Record a long-polling request so that a later DoSend() can answer it.
/// At most one request may be pending at a time.
///
/// \param reactor The reactor used to finish the RPC.
/// \param response The response message to be filled before finishing.
void HandleLongPollingRequest(grpc::ServerUnaryReactor *reactor,
RaySyncMessages *response);
protected:
/// Send the message from the pending queue to the target node.
/// It'll send nothing unless there is a long-polling request.
/// TODO (iycheng): Unify the sending algorithm when we migrate to gRPC streaming
void DoSend() override;
/// These two fields are RPC related. When the server got long-polling requests,
/// these two fields will be set so that it can be used to send message.
/// After the message being sent, these two fields will be set to be empty again.
/// When the periodical timer wake up, it'll check whether these two fields are set
/// and it'll only send data when these are set.
RaySyncMessages *response_ = nullptr;
grpc::ServerUnaryReactor *unary_reactor_ = nullptr;
};
/// SyncConnection for gRPC client side. It has customized logic for sending:
/// outgoing data is pushed with the Update RPC, incoming data is fetched with a
/// long-polling RPC that is re-issued after every successful reply.
class ClientSyncConnection : public NodeSyncConnection {
public:
/// \param io_context The io context stored for use by this connection.
/// \param node_id The id of the remote node.
/// \param message_processor Callback invoked for newly received messages.
/// \param channel The gRPC channel used to create the stub.
ClientSyncConnection(
instrumented_io_context &io_context,
const std::string &node_id,
std::function<void(std::shared_ptr<RaySyncMessage>)> message_processor,
std::shared_ptr<grpc::Channel> channel);
protected:
/// Send the message from the pending queue to the target node.
/// It'll use gRPC to send the message directly.
void DoSend() override;
/// Start to send long-polling request to remote nodes.
void StartLongPolling();
/// Stub for this connection.
std::unique_ptr<ray::rpc::syncer::RaySyncer::Stub> stub_;
/// Where the received message is stored.
ray::rpc::syncer::RaySyncMessages in_message_;
/// Dummy request for long-polling.
DummyRequest dummy_;
};
} // namespace syncer
} // namespace ray

View file

@ -0,0 +1,448 @@
// Copyright 2022 The Ray Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ray/common/ray_syncer/ray_syncer.h"
#include <functional>
#include "ray/common/ray_config.h"
namespace ray {
namespace syncer {
/// Start with "no snapshot taken yet" (-1) for every component.
NodeState::NodeState() {
  for (auto &version_taken : snapshots_versions_taken_) {
    version_taken = -1;
  }
}
/// Register the reporter/receiver pair of a component.
///
/// Registration only succeeds when `cid` is inside the component array and the slot
/// has neither a reporter nor a receiver yet; anything else is logged at FATAL level.
///
/// \return true if the pair was stored, false otherwise.
bool NodeState::SetComponent(RayComponentId cid,
                             const ReporterInterface *reporter,
                             ReceiverInterface *receiver) {
  const bool in_range = cid < static_cast<RayComponentId>(kComponentArraySize);
  // Only index into the arrays once the range check has passed.
  const bool slot_is_free =
      in_range && reporters_[cid] == nullptr && receivers_[cid] == nullptr;
  if (!slot_is_free) {
    RAY_LOG(FATAL) << "Fail to set components, component_id:" << cid
                   << ", reporter:" << reporter << ", receiver:" << receiver;
    return false;
  }
  reporters_[cid] = reporter;
  receivers_[cid] = receiver;
  return true;
}
/// Ask the reporter of `cid` for a snapshot newer than the last one taken.
///
/// \return std::nullopt when no reporter is registered or the reporter has nothing
/// newer; otherwise the snapshot, whose version is remembered for the next call.
std::optional<RaySyncMessage> NodeState::GetSnapshot(RayComponentId cid) {
  const ReporterInterface *reporter = reporters_[cid];
  if (reporter == nullptr) {
    return std::nullopt;
  }
  std::optional<RaySyncMessage> snapshot =
      reporter->Snapshot(snapshots_versions_taken_[cid], cid);
  if (snapshot.has_value()) {
    // Remember the version so the next request only asks for something newer.
    snapshots_versions_taken_[cid] = snapshot->version();
    RAY_LOG(DEBUG) << "Snapshot taken: cid:" << cid << ", version:" << snapshot->version()
                   << ", node:" << NodeID::FromBinary(snapshot->node_id());
  }
  return snapshot;
}
/// Record `message` in the cluster view and hand it to the local receiver.
///
/// \return false if a message with the same or a newer version is already recorded
/// for that (node, component); true if the message was accepted.
bool NodeState::ConsumeMessage(std::shared_ptr<const RaySyncMessage> message) {
  const auto component = message->component_id();
  // operator[] creates the node entry (with null slots) on first sight of a node.
  auto &latest = cluster_view_[message->node_id()][component];
  const auto known_version = latest ? latest->version() : -1;
  RAY_LOG(DEBUG) << "ConsumeMessage: " << known_version
                 << " message_version: " << message->version()
                 << ", message_from: " << NodeID::FromBinary(message->node_id());
  // Check whether newer version of this message has been received.
  const bool is_stale = latest != nullptr && latest->version() >= message->version();
  if (is_stale) {
    return false;
  }
  latest = message;
  if (auto *receiver = receivers_[component]; receiver != nullptr) {
    receiver->Update(message);
  }
  return true;
}
/// Store the io context reference, the remote node id and the callback that is
/// invoked for newly received messages. No I/O is started here.
NodeSyncConnection::NodeSyncConnection(
instrumented_io_context &io_context,
std::string remote_node_id,
std::function<void(std::shared_ptr<RaySyncMessage>)> message_processor)
: io_context_(io_context),
remote_node_id_(std::move(remote_node_id)),
message_processor_(std::move(message_processor)) {}
/// Process a batch received from the remote node.
///
/// Every message that is newer than the version recorded for its (node, component)
/// bumps that version and is handed to the message processor; the rest are dropped.
void NodeSyncConnection::ReceiveUpdate(RaySyncMessages messages) {
  for (auto &message : *messages.mutable_sync_messages()) {
    const auto component = message.component_id();
    const auto incoming_version = message.version();
    auto &known_version = GetNodeComponentVersions(message.node_id())[component];
    RAY_LOG(DEBUG) << "Receive update: "
                   << " component_id=" << component
                   << ", message_version=" << incoming_version
                   << ", local_message_version=" << known_version;
    if (known_version >= incoming_version) {
      continue;
    }
    known_version = incoming_version;
    // `known_version` must not be used after this call: the processor may touch
    // the version table again.
    message_processor_(std::make_shared<RaySyncMessage>(std::move(message)));
  }
}
/// Queue `message` for the remote node unless the remote node is known to have it.
///
/// A message is skipped when it originates from the remote node itself, or when the
/// version recorded for its (node, component) is already at least as new.
///
/// \return true if the message was queued.
bool NodeSyncConnection::PushToSendingQueue(
    std::shared_ptr<const RaySyncMessage> message) {
  const std::string &origin = message->node_id();
  // Skip the message when it's about the node of this connection.
  if (origin == GetRemoteNodeID()) {
    return false;
  }
  const auto component = message->component_id();
  int64_t &known_version = GetNodeComponentVersions(origin)[component];
  if (known_version >= message->version()) {
    return false;
  }
  known_version = message->version();
  // Keyed by (node, component): a newer message replaces a still-queued older one.
  sending_buffer_[std::make_pair(origin, component)] = message;
  return true;
}
/// Look up the per-component versions recorded for `node_id`.
/// A node seen for the first time gets an array with every version set to -1.
std::array<int64_t, kComponentArraySize> &NodeSyncConnection::GetNodeComponentVersions(
    const std::string &node_id) {
  auto [entry, inserted] = node_versions_.try_emplace(node_id);
  if (inserted) {
    entry->second.fill(-1);
  }
  return entry->second;
}
/// Build the client side of a connection: create the RaySyncer stub on `channel`
/// and immediately issue the first long-polling request.
ClientSyncConnection::ClientSyncConnection(
instrumented_io_context &io_context,
const std::string &node_id,
std::function<void(std::shared_ptr<RaySyncMessage>)> message_processor,
std::shared_ptr<grpc::Channel> channel)
: NodeSyncConnection(io_context, node_id, std::move(message_processor)),
stub_(ray::rpc::syncer::RaySyncer::NewStub(channel)) {
StartLongPolling();
}
void ClientSyncConnection::StartLongPolling() {
// This will be a long-polling request. The node will only reply if
// 1. there is a new version of message
// 2. and it has passed X ms since last update.
auto client_context = std::make_shared<grpc::ClientContext>();
stub_->async()->LongPolling(
client_context.get(),
&dummy_,
&in_message_,
[this, client_context](grpc::Status status) {
if (status.ok()) {
RAY_CHECK(in_message_.GetArena() == nullptr);
io_context_.post(
[this, messages = std::move(in_message_)]() mutable {
ReceiveUpdate(std::move(messages));
},
"LongPollingCallback");
in_message_.Clear();
// Start the next polling.
StartLongPolling();
}
});
}
/// Send the queued messages to the remote node with one Update RPC.
///
/// Messages are drained from the sending buffer until it is empty or the running
/// payload size reaches max_sync_message_batch_bytes. The size is tested before each
/// message is added, so the batch can exceed the limit by one message.
void ClientSyncConnection::DoSend() {
if (sending_buffer_.empty()) {
return;
}
auto client_context = std::make_shared<grpc::ClientContext>();
// Request and response live on an arena that the completion callback keeps alive.
auto arena = std::make_shared<google::protobuf::Arena>();
auto request = google::protobuf::Arena::CreateMessage<RaySyncMessages>(arena.get());
auto response = google::protobuf::Arena::CreateMessage<DummyResponse>(arena.get());
// Keeps the queued messages alive until the RPC completes: the request below only
// references them, it does not copy them.
std::vector<std::shared_ptr<const RaySyncMessage>> holder;
size_t message_bytes = 0;
auto iter = sending_buffer_.begin();
while (message_bytes < RayConfig::instance().max_sync_message_batch_bytes() &&
iter != sending_buffer_.end()) {
message_bytes += iter->second->sync_message().size();
// TODO (iycheng): Use arena allocator for optimization
// The message is added by pointer without a copy; it stays owned by the
// shared_ptr stored in `holder`.
request->mutable_sync_messages()->UnsafeArenaAddAllocated(
const_cast<RaySyncMessage *>(iter->second.get()));
holder.push_back(iter->second);
// Post-increment so that `iter` has already moved on when the element is erased.
sending_buffer_.erase(iter++);
}
if (request->sync_messages_size() != 0) {
stub_->async()->Update(
client_context.get(),
request,
response,
// The captures exist only to extend lifetimes until the call has finished.
[arena, client_context, holder = std::move(holder)](grpc::Status status) {
if (!status.ok()) {
RAY_LOG(ERROR) << "Sending request failed because of "
<< status.error_message();
}
});
}
}
/// Forward all arguments to NodeSyncConnection; the server side has no extra state
/// to set up until a long-polling request arrives.
ServerSyncConnection::ServerSyncConnection(
instrumented_io_context &io_context,
const std::string &remote_node_id,
std::function<void(std::shared_ptr<RaySyncMessage>)> message_processor)
: NodeSyncConnection(io_context, remote_node_id, std::move(message_processor)) {}
/// Cancel the pending long-polling request, if any. Otherwise, the rpc would hang
/// there forever.
ServerSyncConnection::~ServerSyncConnection() {
  if (unary_reactor_ == nullptr) {
    return;
  }
  unary_reactor_->Finish(grpc::Status::CANCELLED);
}
/// Remember a long-polling request so that a later DoSend() can reply to it.
/// It is a fatal error if a previous request is still pending.
///
/// \param reactor The reactor used to finish the RPC.
/// \param response The response message that DoSend() fills in.
void ServerSyncConnection::HandleLongPollingRequest(grpc::ServerUnaryReactor *reactor,
                                                    RaySyncMessages *response) {
  // Only one outstanding long-polling request is supported per connection.
  RAY_CHECK(response_ == nullptr);
  RAY_CHECK(unary_reactor_ == nullptr);
  response_ = response;
  unary_reactor_ = reactor;
}
/// Answer the pending long-polling request with the queued messages.
///
/// Nothing happens when no request is pending or nothing is queued. Otherwise
/// messages are copied into the response until the buffer is empty or the running
/// payload size reaches max_sync_message_batch_bytes, and the request is finished.
void ServerSyncConnection::DoSend() {
  // There is no receive request
  if (unary_reactor_ == nullptr || sending_buffer_.empty()) {
    return;
  }
  RAY_CHECK(unary_reactor_ != nullptr && response_ != nullptr);

  const auto byte_budget = RayConfig::instance().max_sync_message_batch_bytes();
  size_t queued_bytes = 0;
  for (auto it = sending_buffer_.begin();
       queued_bytes < byte_budget && it != sending_buffer_.end();) {
    const auto &pending = it->second;
    queued_bytes += pending->sync_message().size();
    // TODO (iycheng): Use arena allocator for optimization
    response_->add_sync_messages()->CopyFrom(*pending);
    // Post-increment so that `it` has already moved on when the element is erased.
    sending_buffer_.erase(it++);
  }

  if (response_->sync_messages_size() == 0) {
    return;
  }
  unary_reactor_->Finish(grpc::Status::OK);
  // The request has been answered; wait for the next long-polling request.
  unary_reactor_ = nullptr;
  response_ = nullptr;
}
/// Create the syncer and schedule the periodic job that asks every connection to
/// send what it has queued. The period is
/// raylet_report_resources_period_milliseconds.
RaySyncer::RaySyncer(instrumented_io_context &io_context,
                     const std::string &local_node_id)
    : io_context_(io_context),
      local_node_id_(local_node_id),
      node_state_(std::make_unique<NodeState>()),
      timer_(io_context) {
  stopped_ = std::make_shared<bool>(false);
  upward_only_.fill(false);

  auto flush_all_connections = [this]() {
    for (auto &entry : sync_connections_) {
      entry.second->DoSend();
    }
  };
  timer_.RunFnPeriodically(
      std::move(flush_all_connections),
      RayConfig::instance().raylet_report_resources_period_milliseconds());
}
/// Flip the shared `stopped_` flag so that StartSync callbacks completing after
/// destruction return early instead of touching this object.
RaySyncer::~RaySyncer() { *stopped_ = true; }
/// Connect to a remote node over `channel`.
///
/// Sends a StartSync RPC carrying the local node id. When it succeeds, a
/// ClientSyncConnection for the node id returned by the remote side is created on
/// the io context and registered through Connect(std::unique_ptr<...>).
/// A failed StartSync is ignored.
void RaySyncer::Connect(std::shared_ptr<grpc::Channel> channel) {
auto stub = ray::rpc::syncer::RaySyncer::NewStub(channel);
auto request = std::make_shared<StartSyncRequest>();
request->set_node_id(local_node_id_);
auto response = std::make_shared<StartSyncResponse>();
auto client_context = std::make_shared<grpc::ClientContext>();
stub->async()->StartSync(
client_context.get(),
request.get(),
response.get(),
// request/response/client_context are captured only to keep them alive until
// the call completes; `stopped` is a copy of the shared flag, so it can be read
// even after the syncer has been destroyed.
[this, channel, request, response, client_context, stopped = this->stopped_](
grpc::Status status) {
if (*stopped) {
return;
}
if (status.ok()) {
// NOTE(review): the posted task captures `this` but does not re-check
// `stopped` when it runs -- confirm the syncer cannot be destroyed in between.
io_context_.post(
[this, channel, response]() {
auto connection = std::make_unique<ClientSyncConnection>(
io_context_,
response->node_id(),
[this](auto msg) { BroadcastMessage(msg); },
channel);
Connect(std::move(connection));
},
"StartSyncCallback");
}
});
}
/// Register an established connection and queue the current cluster view on it.
///
/// The work runs on the io context. A connection that is not a
/// ClientSyncConnection is also recorded as an upward connection. Every message
/// in the local cluster view is then queued on the new connection, except:
///  - messages of upward-only components when the connection is not upward;
///  - messages that originate from the remote node itself.
///
/// \param connection The connection to the remote node. Must not be null, and no
/// connection to the same remote node may be registered yet.
void RaySyncer::Connect(std::unique_ptr<NodeSyncConnection> connection) {
  // Somehow connection=std::move(connection) won't be compiled here.
  // Potentially it might have a leak here if the function is not executed.
  io_context_.dispatch(
      [this, connection = connection.release()]() mutable {
        RAY_CHECK(connection != nullptr);
        RAY_CHECK(sync_connections_[connection->GetRemoteNodeID()] == nullptr);
        auto &conn = *connection;
        bool is_upward_conn = false;
        if (dynamic_cast<ClientSyncConnection *>(connection) == nullptr) {
          upward_connections_.insert(connection);
          is_upward_conn = true;
        }
        // Ownership of the raw pointer is taken over by the map from here on.
        sync_connections_[connection->GetRemoteNodeID()].reset(connection);
        for (const auto &[_, messages] : node_state_->GetClusterView()) {
          for (auto &message : messages) {
            if (!message) {
              continue;
            }
            if (upward_only_[message->component_id()] && !is_upward_conn) {
              continue;
            }
            // PushToSendingQueue always rejects messages that originate from the
            // remote node. The cluster view can already hold such messages (it is
            // not pruned on Disconnect, so a reconnecting node finds its old
            // entries), and the check below must not abort the process for them.
            if (message->node_id() == conn.GetRemoteNodeID()) {
              continue;
            }
            // The connection is new, so it has no recorded versions and every
            // remaining message must be accepted.
            RAY_CHECK(conn.PushToSendingQueue(message));
          }
        }
      },
      "RaySyncer::Connect");
}
void RaySyncer::Disconnect(const std::string &node_id) {
io_context_.post([this, node_id]() { sync_connections_.erase(node_id); },
"RaySyncerDisconnect");
}
/// Register a component with the syncer.
///
/// When a reporter is given, a periodic job is scheduled that pulls a snapshot from
/// it every `pull_from_reporter_interval_ms` and broadcasts it.
///
/// \return false if the component could not be set on the node state, true
/// otherwise.
bool RaySyncer::Register(RayComponentId component_id,
                         const ReporterInterface *reporter,
                         ReceiverInterface *receiver,
                         bool upward_only,
                         int64_t pull_from_reporter_interval_ms) {
  const bool component_set = node_state_->SetComponent(component_id, reporter, receiver);
  if (!component_set) {
    return false;
  }
  upward_only_[component_id] = upward_only;

  // Set job to pull from reporter periodically
  const bool has_reporter = reporter != nullptr;
  if (has_reporter) {
    RAY_CHECK(pull_from_reporter_interval_ms > 0);
    auto pull_and_broadcast = [this, component_id]() {
      auto snapshot = node_state_->GetSnapshot(component_id);
      if (!snapshot) {
        return;
      }
      // A local reporter must only describe the local node.
      RAY_CHECK(snapshot->node_id() == GetLocalNodeID());
      BroadcastMessage(std::make_shared<RaySyncMessage>(std::move(*snapshot)));
    };
    timer_.RunFnPeriodically(std::move(pull_and_broadcast),
                             pull_from_reporter_interval_ms);
  }

  RAY_LOG(DEBUG) << "Registered components: "
                 << "component_id:" << component_id << ", reporter:" << reporter
                 << ", receiver:" << receiver
                 << ", pull_from_reporter_interval_ms:" << pull_from_reporter_interval_ms
                 << ", upward_only:" << upward_only_[component_id];
  return true;
}
/// Apply `message` locally and queue it on the connections that should get it.
///
/// Stale messages (rejected by the node state) are dropped. Messages of upward-only
/// components go to the upward connections only; everything else goes to all
/// connections. Each connection still filters what its remote node already knows.
void RaySyncer::BroadcastMessage(std::shared_ptr<const RaySyncMessage> message) {
  // The message is stale. Just skip this one.
  const bool accepted = node_state_->ConsumeMessage(message);
  if (!accepted) {
    return;
  }

  if (!upward_only_[message->component_id()]) {
    for (auto &entry : sync_connections_) {
      entry.second->PushToSendingQueue(message);
    }
    return;
  }

  for (NodeSyncConnection *upward : upward_connections_) {
    upward->PushToSendingQueue(message);
  }
}
/// Handle the StartSync RPC: remember the caller's node id and, on the syncer's io
/// context, register a ServerSyncConnection for it and reply with the local node id.
/// A service instance accepts exactly one StartSync; a second one is a fatal error.
grpc::ServerUnaryReactor *RaySyncerService::StartSync(
grpc::CallbackServerContext *context,
const StartSyncRequest *request,
StartSyncResponse *response) {
auto *reactor = context->DefaultReactor();
// Make sure server only have one client
RAY_CHECK(remote_node_id_.empty());
remote_node_id_ = request->node_id();
RAY_LOG(DEBUG) << "Get connect from: " << NodeID::FromBinary(remote_node_id_);
syncer_.GetIOContext().post(
[this, response, reactor, context]() {
// The caller may have gone away before this task got to run.
if (context->IsCancelled()) {
reactor->Finish(grpc::Status::CANCELLED);
return;
}
syncer_.Connect(std::make_unique<ServerSyncConnection>(
syncer_.GetIOContext(), remote_node_id_, [this](auto msg) {
syncer_.BroadcastMessage(msg);
}));
response->set_node_id(syncer_.GetLocalNodeID());
reactor->Finish(grpc::Status::OK);
},
"RaySyncer::StartSync");
return reactor;
}
/// Handle the Update RPC: move the received batch onto the syncer's io context and
/// feed it to the ServerSyncConnection of the remote node.
/// The RPC is finished with OK right after the task is posted, i.e. without waiting
/// for the batch to be processed.
grpc::ServerUnaryReactor *RaySyncerService::Update(grpc::CallbackServerContext *context,
const RaySyncMessages *request,
DummyResponse *) {
auto *reactor = context->DefaultReactor();
// Make sure request is allocated from heap so that it can be moved safely.
RAY_CHECK(request->GetArena() == nullptr);
syncer_.GetIOContext().post(
// The const_cast lets the payload be moved out of the request instead of copied.
[this, request = std::move(*const_cast<RaySyncMessages *>(request))]() mutable {
auto *sync_connection = dynamic_cast<ServerSyncConnection *>(
syncer_.GetSyncConnection(remote_node_id_));
if (sync_connection != nullptr) {
sync_connection->ReceiveUpdate(std::move(request));
} else {
// NOTE(review): a missing connection is fatal here but only an ERROR in
// LongPolling -- confirm the difference is intended.
RAY_LOG(FATAL) << "Fail to get the sync context";
}
},
"SyncerUpdate");
reactor->Finish(grpc::Status::OK);
return reactor;
}
/// Handle the LongPolling RPC: park the request on the ServerSyncConnection of the
/// remote node so that it is answered once there is data to send.
/// When no such connection exists the request is finished with CANCELLED.
grpc::ServerUnaryReactor *RaySyncerService::LongPolling(
    grpc::CallbackServerContext *context,
    const DummyRequest *,
    RaySyncMessages *response) {
  auto *reactor = context->DefaultReactor();
  auto park_request = [this, reactor, response]() mutable {
    auto *server_connection = dynamic_cast<ServerSyncConnection *>(
        syncer_.GetSyncConnection(remote_node_id_));
    if (server_connection == nullptr) {
      RAY_LOG(ERROR) << "Fail to setup long-polling";
      reactor->Finish(grpc::Status::CANCELLED);
      return;
    }
    server_connection->HandleLongPollingRequest(reactor, response);
  };
  syncer_.GetIOContext().post(std::move(park_request), "SyncLongPolling");
  return reactor;
}
/// Ask the syncer to drop the connection of the node served by this service.
/// `remote_node_id_` is still empty if StartSync was never called.
RaySyncerService::~RaySyncerService() { syncer_.Disconnect(remote_node_id_); }
} // namespace syncer
} // namespace ray

View file

@ -0,0 +1,238 @@
// Copyright 2022 The Ray Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <grpcpp/server.h>
#include <gtest/gtest_prod.h>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "boost/functional/hash.hpp"
#include "ray/common/asio/instrumented_io_context.h"
#include "ray/common/asio/periodical_runner.h"
#include "ray/common/id.h"
#include "src/ray/protobuf/ray_syncer.grpc.pb.h"
namespace ray {
namespace syncer {
using ray::rpc::syncer::DummyRequest;
using ray::rpc::syncer::DummyResponse;
using ray::rpc::syncer::RayComponentId;
using ray::rpc::syncer::RaySyncMessage;
using ray::rpc::syncer::RaySyncMessages;
using ray::rpc::syncer::StartSyncRequest;
using ray::rpc::syncer::StartSyncResponse;
/// Size of the arrays that are indexed by RayComponentId, taken from the generated
/// RayComponentId_ARRAYSIZE constant.
static constexpr size_t kComponentArraySize =
static_cast<size_t>(ray::rpc::syncer::RayComponentId_ARRAYSIZE);
/// The interface for a reporter. Reporter is defined to be a local module which would
/// like to let the other nodes know its state. For example, local cluster resource
/// manager.
struct ReporterInterface {
/// Interface to get the snapshot of the component. It asks the module to take a
/// snapshot of the current state. Each snapshot is versioned and it should return
/// std::nullopt if it doesn't have a qualified version.
///
/// \param version_after Request snapshot with version after `version_after`. If the
/// reporter doesn't have the qualified version, just return std::nullopt
/// \param component_id The component id asked for.
///
/// \return std::nullopt if the reporter doesn't have such component or the current
/// snapshot of the component is not newer than the requested one. Otherwise, return
/// the actual message.
virtual std::optional<RaySyncMessage> Snapshot(int64_t version_after,
RayComponentId component_id) const = 0;
virtual ~ReporterInterface() {}
};
/// The interface for a receiver. Receiver is defined to be a module which would like
/// to get the state of other nodes. For example, cluster resource manager.
struct ReceiverInterface {
/// Interface to update a module. The module should read the `sync_message` field and
/// deserialize it to update its internal state.
///
/// \param message The message received from remote node.
virtual void Update(std::shared_ptr<const RaySyncMessage> message) = 0;
virtual ~ReceiverInterface() {}
};
// Forward declaration of internal structures
class NodeState;
class NodeSyncConnection;
/// RaySyncer is an embedding service for component synchronization.
/// All operations in this class needs to be finished GetIOContext()
/// for thread-safety.
/// RaySyncer is the control plane to make sure all connections eventually
/// have the latest view of the cluster components registered.
/// RaySyncer has two components:
/// 1. NodeSyncConnection: keeps track of the sending and receiving information
/// and make sure not sending the information the remote node knows.
/// 2. NodeState: keeps track of the local status, similar to NodeSyncConnection,
// but it's for local node.
class RaySyncer {
public:
/// Constructor of RaySyncer
///
/// \param io_context The io context for this component.
/// \param node_id The id of current node.
RaySyncer(instrumented_io_context &io_context, const std::string &node_id);
~RaySyncer();
/// Connect to a node.
/// TODO (iycheng): Introduce grpc channel pool and use node_id
/// for the connection.
///
/// \param connection The connection to the remote node.
void Connect(std::unique_ptr<NodeSyncConnection> connection);
/// Connect to a node.
/// TODO (iycheng): Introduce grpc channel pool and use node_id
/// for the connection.
///
/// \param connection The connection to the remote node.
void Connect(std::shared_ptr<grpc::Channel> channel);
void Disconnect(const std::string &node_id);
/// Register the components to the syncer module. Syncer will make sure eventually
/// it'll have a global view of the cluster.
///
/// Right now there are two types of components. One type of components will
/// try to broadcast the messages to make sure eventually the cluster will reach
/// an agreement (upward_only=false). The other type of components will only
/// send the message to upward (upward_only=true). Right now, upward is defined
/// to be the place which received the connection. In Ray, one type of this message
/// is resource load which only GCS needs.
/// TODO (iycheng): 1) Revisit this and come with a better solution; or 2) implement
/// resource loads in another way to avoid this feature; or 3) broadcast resource
/// loads so the scheduler can also use this.
///
/// \param component_id The component to sync.
/// \param reporter The local component to be broadcasted.
/// \param receiver The snapshot of the component in the cluster.
/// \param upward_only Only send the message to the upward of this node.
/// component.
/// \param pull_from_reporter_interval_ms The frequence to pull a message
/// from reporter and push it to sending queue.
bool Register(RayComponentId component_id,
const ReporterInterface *reporter,
ReceiverInterface *receiver,
bool upward_only = false,
int64_t pull_from_reporter_interval_ms = 100);
/// Broadcast a message to other nodes.
/// A message will be sent to a node only if that node doesn't have this message yet.
/// The message can be generated by a local reporter or received from another node.
///
/// \param message The message to be broadcasted.
void BroadcastMessage(std::shared_ptr<const RaySyncMessage> message);
/// Get the id of the local node.
///
/// \return A reference to the stored local node id.
const std::string &GetLocalNodeID() const { return local_node_id_; }
/// Get the io_context used by RaySyncer.
///
/// \return A reference to the io_context held by this syncer.
instrumented_io_context &GetIOContext() { return io_context_; }
/// Look up the connection associated with a node.
///
/// \param node_id The id of the node whose connection is requested.
///
/// \return The connection registered for node_id, or nullptr when no connection
/// is registered for that node.
NodeSyncConnection *GetSyncConnection(const std::string &node_id) const {
  const auto found = sync_connections_.find(node_id);
  return found != sync_connections_.end() ? found->second.get() : nullptr;
}
private:
/// io_context for this thread
instrumented_io_context &io_context_;
/// The current node id.
const std::string local_node_id_;
/// Manage connections. Here the key is the NodeID in binary form.
/// The map owns the connection objects.
absl::flat_hash_map<std::string, std::unique_ptr<NodeSyncConnection>> sync_connections_;
/// Upward connections. These are connections initialized not by the local node.
/// The pointers are non-owning.
absl::flat_hash_set<NodeSyncConnection *> upward_connections_;
/// The local node state
std::unique_ptr<NodeState> node_state_;
/// Each component will define a flag to indicate whether the message should be sent
/// to ClientSyncConnection only.
/// NOTE(review): upward_connections_ above is described as holding connections
/// that were not initialized by the local node; confirm that "ClientSyncConnection"
/// is the right connection type to name here.
std::array<bool, kComponentArraySize> upward_only_;
/// Timer is used to do broadcasting.
ray::PeriodicalRunner timer_;
/// NOTE(review): presumably a flag shared with pending callbacks to signal that
/// the syncer has been stopped -- confirm against the implementation.
std::shared_ptr<bool> stopped_;
/// Test purpose
friend struct SyncerServerTest;
FRIEND_TEST(SyncerTest, Broadcast);
FRIEND_TEST(SyncerTest, Test1To1);
FRIEND_TEST(SyncerTest, Test1ToN);
FRIEND_TEST(SyncerTest, TestMToN);
};
class ClientSyncConnection;
class ServerSyncConnection;
/// RaySyncerService is a service to take care of resource synchronization
/// related operations.
/// Right now only raylet needs to setup this service. But in the future,
/// we can use this to construct more complicated resource reporting algorithm,
/// like tree-based one.
///
/// The service implements the gRPC callback API; each handler returns the
/// reactor for one unary call.
class RaySyncerService : public ray::rpc::syncer::RaySyncer::CallbackService {
public:
/// \param syncer The syncer this service forwards to. Only a reference is kept,
/// so the syncer must outlive the service.
RaySyncerService(RaySyncer &syncer) : syncer_(syncer) {}
~RaySyncerService();
/// Handler for the StartSync RPC. Request and response each carry a node id.
grpc::ServerUnaryReactor *StartSync(grpc::CallbackServerContext *context,
const StartSyncRequest *request,
StartSyncResponse *response) override;
/// Handler for the Update RPC. The request carries a batch of sync messages;
/// the response is empty.
grpc::ServerUnaryReactor *Update(grpc::CallbackServerContext *context,
const RaySyncMessages *request,
DummyResponse *) override;
/// Handler for the LongPolling RPC. The request is empty; the response carries
/// a batch of sync messages.
grpc::ServerUnaryReactor *LongPolling(grpc::CallbackServerContext *context,
const DummyRequest *,
RaySyncMessages *response) override;
private:
// This will be created after connection is established.
// Ideally this should be owned by RaySyncer, but since we are doing
// long-polling right now, we have to put it here so that when
// long-polling request comes, we can set it up.
// NOTE(review): the comment above reads as if it describes a connection
// object, but the member below only stores the id of the remote node --
// confirm and reword.
std::string remote_node_id_;
// The ray syncer this RPC wrappers of.
RaySyncer &syncer_;
};
} // namespace syncer
} // namespace ray
#include "ray/common/ray_syncer/ray_syncer-inl.h"

View file

@ -0,0 +1,779 @@
// Copyright 2022 The Ray Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// clang-format off
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include <chrono>
#include <sstream>
#include <grpc/grpc.h>
#include <grpcpp/create_channel.h>
#include <google/protobuf/util/message_differencer.h>
#include <google/protobuf/util/json_util.h>
#include <grpcpp/security/credentials.h>
#include <grpcpp/security/server_credentials.h>
#include <grpcpp/server.h>
#include <grpcpp/server_builder.h>
#include "ray/common/ray_syncer/ray_syncer.h"
#include "mock/ray/common/ray_syncer/ray_syncer.h"
// clang-format on
using namespace std::chrono;
using namespace ray::syncer;
using ray::NodeID;
using ::testing::_;
using ::testing::Eq;
using ::testing::Invoke;
using ::testing::Return;
using ::testing::WithArg;
namespace ray {
namespace syncer {
/// Build a RaySyncMessage that carries only metadata (no payload).
///
/// \param cid The component the message belongs to.
/// \param version The version to stamp on the message.
/// \param id The node recorded as the sender of the message.
RaySyncMessage MakeMessage(RayComponentId cid, int64_t version, const NodeID &id) {
  RaySyncMessage message;
  message.set_node_id(id.Binary());
  message.set_component_id(cid);
  message.set_version(version);
  return message;
}
/// Fixture providing one mock reporter and one mock receiver per component,
/// an io_context driven by a background thread, and a RaySyncer bound to it.
class RaySyncerTest : public ::testing::Test {
 protected:
  void SetUp() override {
    local_versions_.fill(0);
    for (size_t cid = 0; cid < reporters_.size(); ++cid) {
      receivers_[cid] = std::make_unique<MockReceiverInterface>();
      reporters_[cid] = std::make_unique<MockReporterInterface>();
      // Default reporter behaviour: when the caller's version is behind the
      // local one, bump the local version and report it; otherwise report
      // nothing.
      auto take_snapshot =
          [this, cid](int64_t curr_version) mutable -> std::optional<RaySyncMessage> {
        if (curr_version >= local_versions_[cid]) {
          return std::nullopt;
        }
        RaySyncMessage snapshot;
        snapshot.set_component_id(static_cast<RayComponentId>(cid));
        snapshot.set_version(++local_versions_[cid]);
        return std::make_optional(std::move(snapshot));
      };
      ON_CALL(*reporters_[cid], Snapshot(_, _))
          .WillByDefault(WithArg<0>(Invoke(take_snapshot)));
    }
    // Drive the io_context on a dedicated thread; the work guard keeps run()
    // from returning while the queue is empty.
    thread_ = std::make_unique<std::thread>([this]() {
      boost::asio::io_context::work keep_alive(io_context_);
      io_context_.run();
    });
    local_id_ = NodeID::FromRandom();
    syncer_ = std::make_unique<RaySyncer>(io_context_, local_id_.Binary());
  }

  /// Mock reporter registered for the given component.
  MockReporterInterface *GetReporter(RayComponentId cid) {
    const auto index = static_cast<size_t>(cid);
    return reporters_[index].get();
  }

  /// Mock receiver registered for the given component.
  MockReceiverInterface *GetReceiver(RayComponentId cid) {
    const auto index = static_cast<size_t>(cid);
    return receivers_[index].get();
  }

  /// Mutable reference to the local version of the given component.
  int64_t &LocalVersion(RayComponentId cid) {
    const auto index = static_cast<size_t>(cid);
    return local_versions_[index];
  }

  void TearDown() override {
    // Stop the loop first so the background thread can be joined.
    io_context_.stop();
    thread_->join();
  }

  std::array<int64_t, kComponentArraySize> local_versions_;
  std::array<std::unique_ptr<MockReporterInterface>, kComponentArraySize> reporters_ = {
      nullptr};
  std::array<std::unique_ptr<MockReceiverInterface>, kComponentArraySize> receivers_ = {
      nullptr};
  instrumented_io_context io_context_;
  std::unique_ptr<std::thread> thread_;
  std::unique_ptr<RaySyncer> syncer_;
  NodeID local_id_;
};
/// Verifies NodeState::GetSnapshot: nothing is returned without a reporter, a
/// registered reporter yields a snapshot carrying the local version, and no
/// snapshot is produced once the local version falls behind the one already taken.
TEST_F(RaySyncerTest, NodeStateGetSnapshot) {
  auto node_status = std::make_unique<NodeState>();
  // Without a reporter no component can produce a snapshot.
  node_status->SetComponent(RayComponentId::RESOURCE_MANAGER, nullptr, nullptr);
  ASSERT_EQ(std::nullopt, node_status->GetSnapshot(RayComponentId::RESOURCE_MANAGER));
  ASSERT_EQ(std::nullopt, node_status->GetSnapshot(RayComponentId::SCHEDULER));

  ASSERT_TRUE(node_status->SetComponent(RayComponentId::RESOURCE_MANAGER,
                                        GetReporter(RayComponentId::RESOURCE_MANAGER),
                                        nullptr));

  // Take a snapshot
  ASSERT_EQ(std::nullopt, node_status->GetSnapshot(RayComponentId::SCHEDULER));
  auto msg = node_status->GetSnapshot(RayComponentId::RESOURCE_MANAGER);
  // Fail cleanly instead of dereferencing an empty optional below.
  ASSERT_TRUE(msg.has_value());
  ASSERT_EQ(LocalVersion(RayComponentId::RESOURCE_MANAGER), msg->version());

  // Revert one version back.
  LocalVersion(RayComponentId::RESOURCE_MANAGER) -= 1;
  msg = node_status->GetSnapshot(RayComponentId::RESOURCE_MANAGER);
  ASSERT_EQ(std::nullopt, msg);
}
/// Verifies NodeState::ConsumeMessage: a given version is accepted once and
/// rejected when presented again; a later version is accepted once more.
TEST_F(RaySyncerTest, NodeStateConsume) {
  auto node_status = std::make_unique<NodeState>();
  node_status->SetComponent(RayComponentId::RESOURCE_MANAGER,
                            nullptr,
                            GetReceiver(RayComponentId::RESOURCE_MANAGER));

  const auto sender_id = NodeID::FromRandom();
  auto msg = MakeMessage(RayComponentId::RESOURCE_MANAGER, 0, sender_id);

  // Hands a fresh copy of the current `msg` to the node state.
  auto consume = [&node_status, &msg]() {
    return node_status->ConsumeMessage(std::make_shared<RaySyncMessage>(msg));
  };

  // The first time the message is received it is consumed; the repeat is not.
  ASSERT_TRUE(consume());
  ASSERT_FALSE(consume());

  msg.set_version(1);
  ASSERT_TRUE(consume());
  ASSERT_FALSE(consume());
}
/// Verifies the sending queue of a NodeSyncConnection: duplicates are rejected
/// and a newer version replaces the buffered older one.
TEST_F(RaySyncerTest, NodeSyncConnection) {
  const auto remote_id = NodeID::FromRandom();
  MockNodeSyncConnection sync_connection(
      io_context_,
      remote_id.Binary(),
      [](std::shared_ptr<ray::rpc::syncer::RaySyncMessage>) {});

  const auto from_node_id = NodeID::FromRandom();
  auto msg = MakeMessage(RayComponentId::RESOURCE_MANAGER, 0, from_node_id);

  // Pushes a fresh copy of the current `msg` to the sending queue.
  auto push = [&sync_connection, &msg]() {
    return sync_connection.PushToSendingQueue(std::make_shared<RaySyncMessage>(msg));
  };

  // First push will succeed and the second one will be deduplicated.
  ASSERT_TRUE(push());
  ASSERT_FALSE(push());

  ASSERT_EQ(1, sync_connection.sending_buffer_.size());
  ASSERT_EQ(0, sync_connection.sending_buffer_.begin()->second->version());
  ASSERT_EQ(1, sync_connection.node_versions_.size());
  ASSERT_EQ(0,
            sync_connection
                .node_versions_[from_node_id.Binary()][RayComponentId::RESOURCE_MANAGER]);

  msg.set_version(2);
  ASSERT_TRUE(push());
  ASSERT_FALSE(push());

  // The previous message is deleted.
  ASSERT_EQ(1, sync_connection.sending_buffer_.size());
  ASSERT_EQ(1, sync_connection.node_versions_.size());
  ASSERT_EQ(2, sync_connection.sending_buffer_.begin()->second->version());
  ASSERT_EQ(2,
            sync_connection
                .node_versions_[from_node_id.Binary()][RayComponentId::RESOURCE_MANAGER]);
}
struct SyncerServerTest {
SyncerServerTest(std::string port, bool has_scheduler_reporter = true) {
this->server_port = port;
bool has_scheduler_receiver = !has_scheduler_reporter;
// Setup io context
auto node_id = NodeID::FromRandom();
for (auto &v : local_versions) {
v = 0;
}
// Setup syncer and grpc server
syncer = std::make_unique<RaySyncer>(io_context, node_id.Binary());
auto server_address = std::string("0.0.0.0:") + port;
grpc::ServerBuilder builder;
service = std::make_unique<RaySyncerService>(*syncer);
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
builder.RegisterService(service.get());
server = builder.BuildAndStart();
for (size_t cid = 0; cid < reporters.size(); ++cid) {
auto snapshot_received = [this](std::shared_ptr<const RaySyncMessage> message) {
auto iter = received_versions.find(message->node_id());
if (iter == received_versions.end()) {
for (auto &v : received_versions[message->node_id()]) {
v = 0;
}
iter = received_versions.find(message->node_id());
}
received_versions[message->node_id()][message->component_id()] =
message->version();
message_consumed[message->node_id()]++;
};
if (has_scheduler_receiver ||
static_cast<RayComponentId>(cid) != RayComponentId::SCHEDULER) {
receivers[cid] = std::make_unique<MockReceiverInterface>();
EXPECT_CALL(*receivers[cid], Update(_))
.WillRepeatedly(WithArg<0>(Invoke(snapshot_received)));
}
auto &reporter = reporters[cid];
auto take_snapshot =
[this, cid](int64_t version_after) mutable -> std::optional<RaySyncMessage> {
if (local_versions[cid] <= version_after) {
return std::nullopt;
} else {
auto msg = RaySyncMessage();
msg.set_component_id(static_cast<RayComponentId>(cid));
msg.set_version(local_versions[cid]);
msg.set_node_id(syncer->GetLocalNodeID());
std::string dbg_message;
google::protobuf::util::MessageToJsonString(msg, &dbg_message);
snapshot_taken++;
return std::make_optional(std::move(msg));
}
};
if (has_scheduler_reporter ||
static_cast<RayComponentId>(cid) != RayComponentId::SCHEDULER) {
reporter = std::make_unique<MockReporterInterface>();
EXPECT_CALL(*reporter, Snapshot(_, Eq(cid)))
.WillRepeatedly(WithArg<0>(Invoke(take_snapshot)));
}
syncer->Register(static_cast<RayComponentId>(cid),
reporter.get(),
receivers[cid].get(),
static_cast<RayComponentId>(cid) == RayComponentId::SCHEDULER);
}
thread = std::make_unique<std::thread>([this] {
boost::asio::io_context::work work(io_context);
io_context.run();
});
}
void WaitSendingFlush() {
while (true) {
std::promise<bool> p;
auto f = p.get_future();
io_context.post(
[&p, this]() mutable {
for (const auto &[node_id, conn] : syncer->sync_connections_) {
if (!conn->sending_buffer_.empty()) {
p.set_value(false);
RAY_LOG(INFO) << NodeID::FromBinary(syncer->GetLocalNodeID()) << ": "
<< "Waiting for message on " << NodeID::FromBinary(node_id)
<< " to be sent."
<< " Remainings " << conn->sending_buffer_.size();
return;
}
}
p.set_value(true);
},
"TEST");
if (f.get()) {
return;
} else {
std::this_thread::sleep_for(1s);
}
}
}
bool WaitUntil(std::function<bool()> predicate, int64_t time_s) {
auto start = steady_clock::now();
while (duration_cast<seconds>(steady_clock::now() - start).count() <= time_s) {
std::promise<bool> p;
auto f = p.get_future();
io_context.post([&p, predicate]() mutable { p.set_value(predicate()); }, "TEST");
if (f.get()) {
return true;
} else {
std::this_thread::sleep_for(1s);
}
}
return false;
}
~SyncerServerTest() {
service.reset();
server.reset();
io_context.stop();
thread->join();
syncer.reset();
}
int64_t GetNumConsumedMessages(const std::string &node_id) const {
auto iter = message_consumed.find(node_id);
if (iter == message_consumed.end()) {
return 0;
} else {
return iter->second;
}
}
std::array<std::atomic<int64_t>, kComponentArraySize> _v;
const std::array<std::atomic<int64_t>, kComponentArraySize> &GetReceivedVersions(
const std::string &node_id) {
auto iter = received_versions.find(node_id);
if (iter == received_versions.end()) {
for (auto &v : _v) {
v.store(-1);
}
return _v;
}
return iter->second;
}
std::unique_ptr<RaySyncerService> service;
std::unique_ptr<RaySyncer> syncer;
std::unique_ptr<grpc::Server> server;
std::unique_ptr<std::thread> thread;
instrumented_io_context io_context;
std::string server_port;
std::array<std::atomic<int64_t>, kComponentArraySize> local_versions;
std::array<std::unique_ptr<MockReporterInterface>, kComponentArraySize> reporters = {
nullptr};
int64_t snapshot_taken = 0;
std::unordered_map<std::string, std::array<std::atomic<int64_t>, kComponentArraySize>>
received_versions;
std::unordered_map<std::string, std::atomic<int64_t>> message_consumed;
std::array<std::unique_ptr<MockReceiverInterface>, kComponentArraySize> receivers = {
nullptr};
};
// Useful for debugging
// std::ostream &operator<<(std::ostream &os, const SyncerServerTest &server) {
// auto dump_array = [&os](const std::array<int64_t, kComponentArraySize> &v,
// std::string label,
// int indent) mutable -> std::ostream & {
// os << std::string(indent, '\t');
// os << label << ": ";
// for (size_t i = 0; i < v.size(); ++i) {
// os << v[i];
// if (i + 1 != v.size()) {
// os << ", ";
// }
// }
// return os;
// };
// os << "NodeID: " << NodeID::FromBinary(server.syncer->GetLocalNodeID()) << std::endl;
// dump_array(server.local_versions, "LocalVersions:", 1) << std::endl;
// for (auto [node_id, versions] : server.received_versions) {
// os << "\tFromNodeID: " << NodeID::FromBinary(node_id) << std::endl;
// dump_array(versions, "RemoteVersions:", 2) << std::endl;
// }
// return os;
// }
std::shared_ptr<grpc::Channel> MakeChannel(std::string port) {
grpc::ChannelArguments argument;
// Disable http proxy since it disrupts local connections. TODO(ekl) we should make
// this configurable, or selectively set it for known local connections only.
argument.SetInt(GRPC_ARG_ENABLE_HTTP_PROXY, 0);
argument.SetMaxSendMessageSize(::RayConfig::instance().max_grpc_message_size());
argument.SetMaxReceiveMessageSize(::RayConfig::instance().max_grpc_message_size());
return grpc::CreateCustomChannel(
"localhost:" + port, grpc::InsecureChannelCredentials(), argument);
}
// A node's view of the cluster: a string key (the tests look entries up by node
// id) mapped to one message slot per component. A slot may hold nullptr.
using TClusterView = absl::flat_hash_map<
std::string,
std::array<std::shared_ptr<const RaySyncMessage>, kComponentArraySize>>;
// End-to-end test with two servers where s1 connects to s2. It checks the
// initial exchange, propagation of a version bump, that a version going
// backwards is not re-sent, and that message counts stay bounded under random
// updates.
TEST(SyncerTest, Test1To1) {
// s1: reporter: RayComponentId::RESOURCE_MANAGER
// s1: receiver: RayComponentId::SCHEDULER, RayComponentId::RESOURCE_MANAGER
auto s1 = SyncerServerTest("19990", false);
// s2: reporter: RayComponentId::RESOURCE_MANAGER, RayComponentId::SCHEDULER
// s2: receiver: RayComponentId::RESOURCE_MANAGER
auto s2 = SyncerServerTest("19991", true);
// Make sure the setup is correct
ASSERT_NE(nullptr, s1.receivers[RayComponentId::SCHEDULER]);
ASSERT_EQ(nullptr, s2.receivers[RayComponentId::SCHEDULER]);
ASSERT_EQ(nullptr, s1.reporters[RayComponentId::SCHEDULER]);
ASSERT_NE(nullptr, s2.reporters[RayComponentId::SCHEDULER]);
ASSERT_NE(nullptr, s1.receivers[RayComponentId::RESOURCE_MANAGER]);
ASSERT_NE(nullptr, s2.receivers[RayComponentId::RESOURCE_MANAGER]);
ASSERT_NE(nullptr, s1.reporters[RayComponentId::RESOURCE_MANAGER]);
ASSERT_NE(nullptr, s2.reporters[RayComponentId::RESOURCE_MANAGER]);
auto channel_to_s2 = MakeChannel("19991");
s1.syncer->Connect(channel_to_s2);
// Make sure s2 adds s1 and has taken one snapshot per reporter (two).
ASSERT_TRUE(s2.WaitUntil(
[&s2]() {
return s2.syncer->sync_connections_.size() == 1 && s2.snapshot_taken == 2;
},
5));
// Make sure s1 adds s2 and has taken the snapshot of its single reporter.
ASSERT_TRUE(s1.WaitUntil(
[&s1]() {
return s1.syncer->sync_connections_.size() == 1 && s1.snapshot_taken == 1;
},
5));
// s1 will only send 1 message to s2 because it only has one reporter
ASSERT_TRUE(s2.WaitUntil(
[&s2, node_id = s1.syncer->GetLocalNodeID()]() {
return s2.GetNumConsumedMessages(node_id) == 1;
},
5));
// s2 will send 2 messages to s1 because it has two reporters.
ASSERT_TRUE(s1.WaitUntil(
[&s1, node_id = s2.syncer->GetLocalNodeID()]() {
return s1.GetNumConsumedMessages(node_id) == 2;
},
5));
// Advance the local version of component 0 on s2.
s2.local_versions[0] = 1;
ASSERT_TRUE(s2.WaitUntil([&s2]() { return s2.snapshot_taken == 3; }, 2));
// Make sure s2 sends the new message to s1.
ASSERT_TRUE(s1.WaitUntil(
[&s1, node_id = s2.syncer->GetLocalNodeID()]() {
return s1.GetReceivedVersions(node_id)[RayComponentId::RESOURCE_MANAGER] == 1 &&
s1.GetNumConsumedMessages(node_id) == 3;
},
5));
// Make sure no new messages are sent when the local version goes backwards.
s2.local_versions[0] = 0;
std::this_thread::sleep_for(1s);
ASSERT_TRUE(s1.GetNumConsumedMessages(s2.syncer->GetLocalNodeID()) == 3);
ASSERT_TRUE(s2.GetNumConsumedMessages(s1.syncer->GetLocalNodeID()) == 1);
// Change it back
s2.local_versions[0] = 1;
// Generate random version bumps on both servers.
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<> rand_sleep(0, 10000);
std::uniform_int_distribution<> choose_component(0, 1);
size_t s1_updated = 0;
size_t s2_updated = 0;
auto start = steady_clock::now();
for (int i = 0; i < 10000; ++i) {
// s1 only reports component 0; s2 picks one of its two components at random.
if (choose_component(gen) == 0) {
s1.local_versions[0]++;
++s1_updated;
} else {
s2.local_versions[choose_component(gen)]++;
++s2_updated;
}
// Sleep occasionally so the updates span several reporting periods.
if (rand_sleep(gen) < 5) {
std::this_thread::sleep_for(1s);
}
}
auto end = steady_clock::now();
// Maximum number of messages that can be sent during this period of time.
// +1 is for corner cases.
auto max_sends =
duration_cast<milliseconds>(end - start).count() /
RayConfig::instance().raylet_report_resources_period_milliseconds() +
1;
// Both sides must converge to the final local versions of the other side.
ASSERT_TRUE(s1.WaitUntil(
[&s1, &s2]() {
return s1.GetReceivedVersions(s2.syncer->GetLocalNodeID()) == s2.local_versions &&
s2.GetReceivedVersions(s1.syncer->GetLocalNodeID())[0] ==
s1.local_versions[0];
},
5));
// s2 has two reporters + 3 for the ones sent before the measurement
ASSERT_LT(s1.GetNumConsumedMessages(s2.syncer->GetLocalNodeID()), max_sends * 2 + 3);
// s1 has one reporter + 1 for the one sent before the measurement
// NOTE(review): the comment says "+ 1" while the bound below uses "+ 3" --
// confirm which one is intended.
ASSERT_LT(s2.GetNumConsumedMessages(s1.syncer->GetLocalNodeID()), max_sends + 3);
}
/// Covers the broadcast feature of ray syncer with the topology s1 -> s2 and
/// s1 -> s3: a change on s2 must reach s1 for both components and reach s3
/// only for the component that is not upward-only.
///
/// Every predicate is posted through WaitUntil() of the server whose state it
/// inspects, so that state is only read on the io_context thread that also
/// modifies it.
TEST(SyncerTest, Broadcast) {
  auto s1 = SyncerServerTest("19990", false);
  auto s2 = SyncerServerTest("19991", true);
  auto s3 = SyncerServerTest("19992", true);
  // We need to make sure s1 is sending data to s3 for s2
  s1.syncer->Connect(MakeChannel("19991"));
  s1.syncer->Connect(MakeChannel("19992"));

  // Make sure the setup is correct
  ASSERT_TRUE(s1.WaitUntil(
      [&s1]() {
        return s1.syncer->sync_connections_.size() == 2 && s1.snapshot_taken == 1;
      },
      5));

  ASSERT_TRUE(s2.WaitUntil(
      [&s2]() {
        return s2.syncer->sync_connections_.size() == 1 && s2.snapshot_taken == 2;
      },
      5));

  ASSERT_TRUE(s3.WaitUntil(
      [&s3]() {
        return s3.syncer->sync_connections_.size() == 1 && s3.snapshot_taken == 2;
      },
      5));

  // Change the resource in s2 and make sure s1 && s3 are correct
  s2.local_versions[0] = 1;
  s2.local_versions[1] = 1;

  ASSERT_TRUE(s1.WaitUntil(
      [&s1, node_id = s2.syncer->GetLocalNodeID()]() mutable {
        return s1.received_versions[node_id][0] == 1 &&
               s1.received_versions[node_id][1] == 1;
      },
      5));

  ASSERT_TRUE(s3.WaitUntil(
      [&s3, node_id = s2.syncer->GetLocalNodeID()]() mutable {
        return s3.received_versions[node_id][0] == 1 &&
               // Make sure SCHEDULE information is not sent to s3
               s3.received_versions[node_id][1] == 0;
      },
      5));
}
bool CompareViews(const std::vector<std::unique_ptr<SyncerServerTest>> &servers,
const std::vector<TClusterView> &views,
const std::vector<std::set<size_t>> &g) {
// Check broadcasting is working
// component id = 0
// simply compare everything with server 0
for (size_t i = 1; i < views.size(); ++i) {
if (views[i].size() != views[0].size()) {
RAY_LOG(ERROR) << "View size wrong: (" << i << ") :" << views[i].size() << " vs "
<< views[0].size();
return false;
}
for (const auto &[k, v] : views[0]) {
auto iter = views[i].find(k);
if (iter == views[i].end()) {
return false;
}
const auto &vv = iter->second;
if (!google::protobuf::util::MessageDifferencer::Equals(*v[0], *vv[0])) {
RAY_LOG(ERROR) << i << ": FAIL RESOURCE: " << v[0] << ", " << vv[0] << ", "
<< v[1] << ", " << vv[1];
std::string dbg_message;
google::protobuf::util::MessageToJsonString(*v[0], &dbg_message);
RAY_LOG(ERROR) << "server[0] >> "
<< NodeID::FromBinary(servers[0]->syncer->GetLocalNodeID()) << ": "
<< dbg_message << " - " << NodeID::FromBinary(v[0]->node_id());
dbg_message.clear();
google::protobuf::util::MessageToJsonString(*vv[0], &dbg_message);
RAY_LOG(ERROR) << "server[i] << "
<< NodeID::FromBinary(servers[i]->syncer->GetLocalNodeID()) << ": "
<< dbg_message << " - " << NodeID::FromBinary(vv[0]->node_id());
return false;
}
}
}
std::map<std::string, size_t> node_id_to_idx;
for (size_t i = 0; i < servers.size(); ++i) {
node_id_to_idx[servers[i]->syncer->GetLocalNodeID()] = i;
}
// Check whether j is reachable from i
auto reachable = [&g](size_t i, size_t j) {
if (i == j) {
return true;
}
std::deque<size_t> q;
q.push_back(i);
while (!q.empty()) {
auto f = q.front();
q.pop_front();
for (auto m : g[f]) {
if (m == j) {
return true;
}
q.push_back(m);
}
}
return false;
};
// Check scheduler which is aggregating only
for (size_t i = 0; i < servers.size(); ++i) {
const auto &view = views[i];
// view: node_id -> msg
for (auto [node_id, msgs] : view) {
if (node_id_to_idx[node_id] == i) {
continue;
}
auto msg = msgs[1];
auto is_reachable = reachable(i, node_id_to_idx[node_id]);
if (msg == nullptr) {
if (is_reachable) {
RAY_LOG(ERROR) << i << " is null, but it can reach " << node_id_to_idx[node_id];
return false;
}
} else {
if (!is_reachable) {
RAY_LOG(ERROR) << i << " is not null, but it can't reachable "
<< node_id_to_idx[node_id];
return false;
}
auto iter = views[node_id_to_idx[node_id]].find(node_id);
if (iter == views[node_id_to_idx[node_id]].end()) {
return false;
}
auto msg2 = iter->second[1];
if (msg2 == nullptr) {
return false;
}
if (!google::protobuf::util::MessageDifferencer::Equals(*msg, *msg2)) {
std::string dbg_message;
google::protobuf::util::MessageToJsonString(*msg, &dbg_message);
RAY_LOG(ERROR) << "server[" << i << "] >> "
<< NodeID::FromBinary(servers[i]->syncer->GetLocalNodeID())
<< ": " << dbg_message;
dbg_message.clear();
google::protobuf::util::MessageToJsonString(*msg2, &dbg_message);
RAY_LOG(ERROR) << "server[" << node_id_to_idx[node_id] << "] << "
<< NodeID::FromBinary(servers[node_id_to_idx[node_id]]
->syncer->GetLocalNodeID())
<< ": " << dbg_message;
return false;
}
}
}
}
return true;
}
bool TestCorrectness(std::function<TClusterView(RaySyncer &syncer)> get_cluster_view,
std::vector<std::unique_ptr<SyncerServerTest>> &servers,
const std::vector<std::set<size_t>> &g) {
auto check = [&servers, get_cluster_view, &g]() {
std::vector<TClusterView> views;
for (auto &s : servers) {
views.push_back(get_cluster_view(*(s->syncer)));
}
return CompareViews(servers, views, g);
};
for (auto &server : servers) {
server->WaitSendingFlush();
}
for (size_t i = 0; i < 10; ++i) {
if (!check()) {
std::this_thread::sleep_for(1s);
} else {
break;
}
}
if (!check()) {
RAY_LOG(ERROR) << "Initial check failed";
return false;
}
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<> rand_sleep(0, 1000000);
std::uniform_int_distribution<> choose_component(0, 1);
std::uniform_int_distribution<> choose_server(0, servers.size() - 1);
for (size_t i = 0; i < 1000000; ++i) {
auto server_idx = choose_server(gen);
auto component_id = choose_component(gen);
if (server_idx == 0) {
component_id = 0;
}
servers[server_idx]->local_versions[component_id]++;
// expect to sleep for 100 times for the whole loop.
if (rand_sleep(gen) < 100) {
std::this_thread::sleep_for(100ms);
}
}
for (auto &server : servers) {
server->WaitSendingFlush();
}
// Make sure everything is synced.
for (size_t i = 0; i < 10; ++i) {
if (!check()) {
std::this_thread::sleep_for(1s);
} else {
break;
}
}
return check();
}
/// Star topology: server 0 connects to each of the other 19 servers.
TEST(SyncerTest, Test1ToN) {
  const size_t base_port = 18990;
  std::vector<std::unique_ptr<SyncerServerTest>> servers;
  for (int i = 0; i < 20; ++i) {
    // Only server 0 lacks a scheduler reporter (it has the receiver instead).
    const bool has_scheduler_reporter = (i != 0);
    servers.push_back(std::make_unique<SyncerServerTest>(std::to_string(i + base_port),
                                                         has_scheduler_reporter));
  }

  std::vector<std::set<size_t>> g(servers.size());
  for (size_t peer = 1; peer < servers.size(); ++peer) {
    servers[0]->syncer->Connect(MakeChannel(servers[peer]->server_port));
    g[0].insert(peer);
  }

  // Read the cluster view on the syncer's own io_context thread.
  auto get_cluster_view = [](RaySyncer &syncer) {
    std::promise<TClusterView> view_promise;
    auto view_future = view_promise.get_future();
    syncer.GetIOContext().post(
        [&view_promise, &syncer]() mutable {
          view_promise.set_value(syncer.node_state_->GetClusterView());
        },
        "TEST");
    return view_future.get();
  };

  ASSERT_TRUE(TestCorrectness(get_cluster_view, servers, g));
}
/// Tree topology: every server connects to (at most) two further servers.
TEST(SyncerTest, TestMToN) {
  const size_t base_port = 18990;
  std::vector<std::unique_ptr<SyncerServerTest>> servers;
  for (int i = 0; i < 20; ++i) {
    // Only server 0 lacks a scheduler reporter (it has the receiver instead).
    const bool has_scheduler_reporter = (i != 0);
    servers.push_back(std::make_unique<SyncerServerTest>(std::to_string(i + base_port),
                                                         has_scheduler_reporter));
  }

  std::vector<std::set<size_t>> g(servers.size());
  // Try to construct a tree based structure: servers 1 and 2 are connected
  // from server 0, servers 3 and 4 from server 1, and so on.
  for (size_t child = 1; child < servers.size(); ++child) {
    const size_t parent = (child - 1) / 2;
    servers[parent]->syncer->Connect(MakeChannel(servers[child]->server_port));
    g[parent].insert(child);
  }

  // Read the cluster view on the syncer's own io_context thread.
  auto get_cluster_view = [](RaySyncer &syncer) {
    std::promise<TClusterView> view_promise;
    auto view_future = view_promise.get_future();
    syncer.GetIOContext().post(
        [&view_promise, &syncer]() mutable {
          view_promise.set_value(syncer.node_state_->GetClusterView());
        },
        "TEST");
    return view_future.get();
  };

  ASSERT_TRUE(TestCorrectness(get_cluster_view, servers, g));
}
} // namespace syncer
} // namespace ray

View file

@ -0,0 +1,135 @@
// Copyright 2022 The Ray Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <grpc/grpc.h>
#include <grpcpp/create_channel.h>
#include <grpcpp/security/credentials.h>
#include <grpcpp/security/server_credentials.h>
#include <grpcpp/server.h>
#include <grpcpp/server_builder.h>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include "ray/common/asio/periodical_runner.h"
#include "ray/common/id.h"
#include "ray/common/ray_syncer/ray_syncer.h"
using namespace std;
using namespace ray::syncer;
/// Demo reporter: holds an integer state that a timer changes at random, and
/// reports it as the RESOURCE_MANAGER component of the local node.
class LocalNode : public ReporterInterface {
 public:
  /// \param io_context The io_context the periodic timer runs on.
  /// \param node_id The id stamped on every reported message.
  LocalNode(instrumented_io_context &io_context, ray::NodeID node_id)
      : node_id_(node_id), timer_(io_context) {
    // Each time the timer fires (period argument: 1000), change the state and
    // bump the version with a probability of 30%.
    timer_.RunFnPeriodically(
        [this]() {
          auto v = static_cast<double>(std::rand()) / RAND_MAX;
          if (v < 0.3) {
            int old_state = state_;
            state_ += std::rand() % 10;
            ++version_;
            RAY_LOG(INFO) << node_id_ << " change from (" << old_state
                          << ", v:" << (version_ - 1) << ") to (" << state_
                          << ", v:" << version_ << ")";
          }
        },
        1000);
  }

  /// Report the local state when it is newer than what the caller already has.
  ///
  /// \param current_version The version the caller already holds.
  /// \return std::nullopt when the local version is not newer than
  /// `current_version`; otherwise a message with the current version, the raw
  /// bytes of the state as payload, and the local node id.
  std::optional<RaySyncMessage> Snapshot(int64_t current_version,
                                         RayComponentId) const override {
    // An equal version means nothing changed, so there is nothing to report.
    if (current_version >= version_) {
      return std::nullopt;
    }
    ray::rpc::syncer::RaySyncMessage msg;
    msg.set_component_id(ray::rpc::syncer::RayComponentId::RESOURCE_MANAGER);
    msg.set_version(version_);
    msg.set_sync_message(
        std::string(reinterpret_cast<const char *>(&state_), sizeof(state_)));
    msg.set_node_id(node_id_.Binary());
    return msg;
  }

 private:
  // Bumped every time state_ changes.
  int64_t version_ = 0;
  // The value being synchronized.
  int state_ = 0;
  ray::NodeID node_id_;
  ray::PeriodicalRunner timer_;
};
class RemoteNodes : public ReceiverInterface {
public:
RemoteNodes() {}
bool NeedBroadcast() const override { return true; }
void Update(std::shared_ptr<const ray::rpc::syncer::RaySyncMessage> msg) override {
auto version = msg->version();
int state = *reinterpret_cast<const int *>(msg->sync_message().data());
auto iter = infos_.find(msg->node_id());
if (iter == infos_.end() || iter->second.second < version) {
RAY_LOG(INFO) << "Update node " << ray::NodeID::FromBinary(msg->node_id()).Hex()
<< " to (" << state << ", v:" << version << ")";
infos_[msg->node_id()] = std::make_pair(state, version);
}
}
private:
absl::flat_hash_map<std::string, std::pair<int, int>> infos_;
};
int main(int argc, char *argv[]) {
std::srand(std::time(nullptr));
instrumented_io_context io_context;
RAY_CHECK(argc == 3) << "./test_syncer_service server_port leader_port";
auto node_id = ray::NodeID::FromRandom();
auto server_port = std::string(argv[1]);
auto leader_port = std::string(argv[2]);
auto local_node = std::make_unique<LocalNode>(io_context, node_id);
auto remote_node = std::make_unique<RemoteNodes>();
RaySyncer syncer(io_context, node_id.Binary());
// RPC related field
grpc::ServerBuilder builder;
std::unique_ptr<RaySyncerService> service;
std::unique_ptr<grpc::Server> server;
std::shared_ptr<grpc::Channel> channel;
syncer.Register(ray::rpc::syncer::RayComponentId::RESOURCE_MANAGER,
local_node.get(),
remote_node.get());
if (server_port != ".") {
RAY_LOG(INFO) << "Start server on port " << server_port;
auto server_address = "0.0.0.0:" + server_port;
service = std::make_unique<RaySyncerService>(syncer);
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
builder.RegisterService(service.get());
server = builder.BuildAndStart();
}
if (leader_port != ".") {
grpc::ChannelArguments argument;
// Disable http proxy since it disrupts local connections. TODO(ekl) we should make
// this configurable, or selectively set it for known local connections only.
argument.SetInt(GRPC_ARG_ENABLE_HTTP_PROXY, 0);
argument.SetMaxSendMessageSize(::RayConfig::instance().max_grpc_message_size());
argument.SetMaxReceiveMessageSize(::RayConfig::instance().max_grpc_message_size());
channel = grpc::CreateCustomChannel(
"localhost:" + leader_port, grpc::InsecureChannelCredentials(), argument);
syncer.Connect(channel);
}
boost::asio::io_context::work work(io_context);
io_context.run();
return 0;
}

View file

@ -37,6 +37,16 @@ proto_library(
visibility = ["//java:__subpackages__"],
)
# Protobuf definitions from ray_syncer.proto.
proto_library(
name = "ray_syncer_proto",
srcs = ["ray_syncer.proto"],
)
# C++ bindings generated from :ray_syncer_proto.
cc_proto_library(
name = "ray_syncer_cc_proto",
deps = [":ray_syncer_proto"],
)
cc_proto_library(
name = "runtime_env_common_cc_proto",
deps = [":runtime_env_common_proto"],

View file

@ -0,0 +1,68 @@
// Copyright 2022 The Ray Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
option cc_enable_arenas = true;
package ray.rpc.syncer;
// Identifies the component a sync message belongs to.
// NOTE(review): C++ code indexes per-component arrays directly with these
// values, so they presumably have to stay small and contiguous from 0 --
// confirm before adding or renumbering values.
enum RayComponentId {
// The resource manager component.
RESOURCE_MANAGER = 0;
// The scheduler component.
SCHEDULER = 1;
}
// A single versioned message of one component of one node.
message RaySyncMessage {
// The version of the message. -1 means the version is not set.
// NOTE(review): an int64 field left unset in proto3 reads back as 0, not -1,
// so -1 has to be written explicitly by the sender -- confirm that all
// producers do so.
int64 version = 1;
// The component this message is for.
RayComponentId component_id = 2;
// The actual payload. It is opaque at this level.
bytes sync_message = 3;
// The node id which initially sent this message.
bytes node_id = 4;
}
// A batch of sync messages. Used as the request of Update and as the response
// of LongPolling.
message RaySyncMessages {
// The batched messages.
repeated RaySyncMessage sync_messages = 1;
}
// Request of the StartSync RPC.
message StartSyncRequest {
// NOTE(review): presumably the id of the node sending the request -- confirm.
bytes node_id = 1;
}
// Response of the StartSync RPC.
message StartSyncResponse {
// NOTE(review): presumably the id of the node answering the request -- confirm.
bytes node_id = 1;
}
// Empty request; used by the LongPolling RPC.
message DummyRequest {}
// Empty response; used by the Update RPC.
message DummyResponse {}
// Service used by nodes to exchange sync messages.
service RaySyncer {
// Ideally these should be a streaming API like this
// rpc StartSync(stream RaySyncMessages) returns (stream RaySyncMessages);
// But to make sure it's the same as the current protocol, we still use
// unary RPCs.
// TODO (iycheng): Use grpc streaming for the protocol.
// This is the first message that should be sent. It will initialize
// some structures between the nodes.
rpc StartSync(StartSyncRequest) returns (StartSyncResponse);
// These two RPCs are for reporting and broadcasting messages.
// Update is used by the client to send its updates to the server.
rpc Update(RaySyncMessages) returns (DummyResponse);
// LongPolling is used by the server to send messages to the client: the
// messages travel in the response of this call.
rpc LongPolling(DummyRequest) returns (RaySyncMessages);
}