[Pubsub] Batch messages (#15084)

* batch pubsub 1

* Logic done. Tests left.

* done.
This commit is contained in:
SangBin Cho 2021-04-02 16:42:18 -07:00 committed by GitHub
parent aea28c53ce
commit cef6286f63
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 111 additions and 45 deletions

View file

@ -403,3 +403,6 @@ RAY_CONFIG(int64_t, asio_stats_print_interval_ms, -1)
/// Maximum amount of memory that will be used by running tasks' args.
RAY_CONFIG(float, max_task_args_memory_fraction, 0.7)
/// The maximum number of objects to publish for each publish calls.
RAY_CONFIG(uint64_t, publish_batch_size, 5000)

View file

@ -563,7 +563,8 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
task_manager_));
object_status_publisher_ = std::make_shared<Publisher>(
/*is_node_dead=*/[this](const NodeID &node_id) {
/*is_node_dead=*/
[this](const NodeID &node_id) {
if (auto node_info =
gcs_client_->Nodes().Get(node_id, /*filter_dead_nodes=*/false)) {
return node_info->state() ==
@ -572,7 +573,8 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
// Node information is probably not
// subscribed yet, so report that the node is alive.
return true;
});
},
/*publish_batch_size_=*/RayConfig::instance().publish_batch_size());
auto node_addr_factory = [this](const NodeID &node_id) {
absl::optional<rpc::Address> addr;
@ -2409,8 +2411,8 @@ void CoreWorker::HandlePubsubLongPolling(const rpc::PubsubLongPollingRequest &re
send_reply_callback(Status::OK(), nullptr, nullptr);
};
RAY_LOG(DEBUG) << "Got long polling request from node " << subscriber_id;
object_status_publisher_->Connect(subscriber_id,
std::move(long_polling_reply_callback));
object_status_publisher_->ConnectToSubscriber(subscriber_id,
std::move(long_polling_reply_callback));
}
void CoreWorker::HandleAddObjectLocationOwner(

View file

@ -101,7 +101,8 @@ bool SubscriptionIndex::AssertNoLeak() const {
return objects_to_subscribers_.size() == 0 && subscribers_to_objects_.size() == 0;
}
bool Subscriber::Connect(LongPollConnectCallback long_polling_reply_callback) {
bool Subscriber::ConnectToSubscriber(
LongPollConnectCallback long_polling_reply_callback) {
if (long_polling_reply_callback_ == nullptr) {
long_polling_reply_callback_ = long_polling_reply_callback;
return true;
@ -120,10 +121,16 @@ bool Subscriber::PublishIfPossible(bool force) {
if (long_polling_reply_callback_ == nullptr) {
return false;
}
if (force || mailbox_.size() > 0) {
long_polling_reply_callback_(mailbox_);
std::vector<ObjectID> mails_to_post;
while (mailbox_.size() > 0 && mails_to_post.size() < publish_batch_size_) {
// It ensures that mails are posted in the FIFO order.
mails_to_post.push_back(mailbox_.front());
mailbox_.pop_front();
}
long_polling_reply_callback_(mails_to_post);
long_polling_reply_callback_ = nullptr;
mailbox_.clear();
return true;
}
return false;
@ -133,8 +140,8 @@ bool Subscriber::AssertNoLeak() const {
return long_polling_reply_callback_ == nullptr && mailbox_.size() == 0;
}
void Publisher::Connect(const NodeID &subscriber_node_id,
LongPollConnectCallback long_poll_connect_callback) {
void Publisher::ConnectToSubscriber(const NodeID &subscriber_node_id,
LongPollConnectCallback long_poll_connect_callback) {
RAY_LOG(DEBUG) << "Long polling connection initiated by " << subscriber_node_id;
RAY_CHECK(long_poll_connect_callback != nullptr);
@ -148,13 +155,16 @@ void Publisher::Connect(const NodeID &subscriber_node_id,
absl::MutexLock lock(&mutex_);
auto it = subscribers_.find(subscriber_node_id);
if (it == subscribers_.end()) {
it = subscribers_.emplace(subscriber_node_id, std::make_shared<Subscriber>()).first;
it = subscribers_
.emplace(subscriber_node_id,
std::make_shared<Subscriber>(publish_batch_size_))
.first;
}
auto &subscriber = it->second;
// Since the long polling connection is synchronous between the client and coordinator,
// when it connects, the connection shouldn't have existed.
RAY_CHECK(subscriber->Connect(std::move(long_poll_connect_callback)));
RAY_CHECK(subscriber->ConnectToSubscriber(std::move(long_poll_connect_callback)));
subscriber->PublishIfPossible();
}
@ -168,7 +178,8 @@ void Publisher::RegisterSubscription(const NodeID &subscriber_node_id,
absl::MutexLock lock(&mutex_);
if (subscribers_.count(subscriber_node_id) == 0) {
subscribers_.emplace(subscriber_node_id, std::make_shared<Subscriber>());
subscribers_.emplace(subscriber_node_id,
std::make_shared<Subscriber>(publish_batch_size_));
}
subscription_index_.AddEntry(object_id, subscriber_node_id);
}

View file

@ -70,7 +70,9 @@ class SubscriptionIndex {
/// Abstraction to each subscriber for the coordinator.
class Subscriber {
public:
explicit Subscriber() {}
explicit Subscriber(const uint64_t publish_batch_size)
: publish_batch_size_(publish_batch_size) {}
~Subscriber() = default;
/// Connect to the subscriber. Currently, it means we cache the long polling request to
@ -78,7 +80,7 @@ class Subscriber {
///
/// \param long_polling_reply_callback reply callback to the long polling request.
/// \return True if connection is new. False if there were already connections cached.
bool Connect(LongPollConnectCallback long_polling_reply_callback);
bool ConnectToSubscriber(LongPollConnectCallback long_polling_reply_callback);
/// Queue the object id to publish to the subscriber.
///
@ -101,7 +103,9 @@ class Subscriber {
/// Cached long polling reply callback.
LongPollConnectCallback long_polling_reply_callback_ = nullptr;
/// Queued messages to publish.
std::vector<ObjectID> mailbox_;
std::list<ObjectID> mailbox_;
/// The maximum number of objects to publish for each publish calls.
const uint64_t publish_batch_size_;
};
/// Pubsub server.
@ -110,8 +114,9 @@ class Publisher {
/// Pubsub coordinator constructor.
///
/// \param is_node_dead A callback that returns true if the given node id is dead.
explicit Publisher(std::function<bool(const NodeID &)> is_node_dead)
: is_node_dead_(is_node_dead) {}
explicit Publisher(std::function<bool(const NodeID &)> is_node_dead,
const uint64_t publish_batch_size)
: is_node_dead_(is_node_dead), publish_batch_size_(publish_batch_size) {}
~Publisher() = default;
@ -119,8 +124,8 @@ class Publisher {
// TODO(sang): Currently, we need to pass the callback for connection because we are
// using long polling internally. This should be changed once the bidirectional grpc
// streaming is supported.
void Connect(const NodeID &subscriber_node_id,
LongPollConnectCallback long_poll_connect_callback);
void ConnectToSubscriber(const NodeID &subscriber_node_id,
LongPollConnectCallback long_poll_connect_callback);
/// Register the subscription.
///
@ -167,6 +172,9 @@ class Publisher {
/// Index that stores the mapping of objects <-> subscribers.
SubscriptionIndex subscription_index_ GUARDED_BY(mutex_);
/// The maximum number of objects to publish for each publish calls.
const uint64_t publish_batch_size_;
};
} // namespace ray

View file

@ -30,7 +30,8 @@ class PublisherTest : public ::testing::Test {
void SetUp() {
dead_nodes_.clear();
object_status_publisher_ = std::shared_ptr<Publisher>(new Publisher(
[this](const NodeID &node_id) { return dead_nodes_.count(node_id) == 1; }));
[this](const NodeID &node_id) { return dead_nodes_.count(node_id) == 1; },
/*batch_size*/ 100));
}
void TearDown() { subscribers_map_.clear(); }
@ -195,40 +196,80 @@ TEST_F(PublisherTest, TestSubscriber) {
}
};
Subscriber subscriber;
std::shared_ptr<Subscriber> subscriber = std::make_shared<Subscriber>(10);
// If there's no connection, it will return false.
ASSERT_FALSE(subscriber.PublishIfPossible());
ASSERT_FALSE(subscriber->PublishIfPossible());
// Try connecting it. Should return true.
ASSERT_TRUE(subscriber.Connect(reply));
ASSERT_TRUE(subscriber->ConnectToSubscriber(reply));
// If connecting it again, it should fail the request.
ASSERT_FALSE(subscriber.Connect(reply));
ASSERT_FALSE(subscriber->ConnectToSubscriber(reply));
// Since there's no published objects, it should return false.
ASSERT_FALSE(subscriber.PublishIfPossible());
ASSERT_FALSE(subscriber->PublishIfPossible());
std::unordered_set<ObjectID> published_objects;
// Make sure publishing one object works as expected.
auto oid = ObjectID::FromRandom();
subscriber.QueueMessage(oid, /*try_publish=*/false);
subscriber->QueueMessage(oid, /*try_publish=*/false);
published_objects.emplace(oid);
ASSERT_TRUE(subscriber.PublishIfPossible());
ASSERT_TRUE(subscriber->PublishIfPossible());
ASSERT_TRUE(object_ids_published.count(oid) > 0);
// Since the object is published, and there's no connection, it should return false.
ASSERT_FALSE(subscriber.PublishIfPossible());
ASSERT_FALSE(subscriber->PublishIfPossible());
// Add 3 oids and see if it works properly.
for (int i = 0; i < 3; i++) {
oid = ObjectID::FromRandom();
subscriber.QueueMessage(oid, /*try_publish=*/false);
subscriber->QueueMessage(oid, /*try_publish=*/false);
published_objects.emplace(oid);
}
// Since there's no connection, objects won't be published.
ASSERT_FALSE(subscriber.PublishIfPossible());
ASSERT_TRUE(subscriber.Connect(reply));
ASSERT_TRUE(subscriber.PublishIfPossible());
ASSERT_FALSE(subscriber->PublishIfPossible());
ASSERT_TRUE(subscriber->ConnectToSubscriber(reply));
ASSERT_TRUE(subscriber->PublishIfPossible());
for (auto oid : published_objects) {
ASSERT_TRUE(object_ids_published.count(oid) > 0);
}
ASSERT_TRUE(subscriber.AssertNoLeak());
ASSERT_TRUE(subscriber->AssertNoLeak());
}
TEST_F(PublisherTest, TestSubscriberBatchSize) {
std::unordered_set<ObjectID> object_ids_published;
LongPollConnectCallback reply =
[&object_ids_published](const std::vector<ObjectID> &object_ids) {
for (auto &oid : object_ids) {
object_ids_published.emplace(oid);
}
};
auto max_publish_size = 5;
std::shared_ptr<Subscriber> subscriber = std::make_shared<Subscriber>(max_publish_size);
ASSERT_TRUE(subscriber->ConnectToSubscriber(reply));
std::unordered_set<ObjectID> published_objects;
std::vector<ObjectID> oids;
for (int i = 0; i < 10; i++) {
auto oid = ObjectID::FromRandom();
oids.push_back(oid);
subscriber->QueueMessage(oid, /*try_publish=*/false);
published_objects.emplace(oid);
}
// Make sure only up to batch size is published.
ASSERT_TRUE(subscriber->PublishIfPossible());
for (int i = 0; i < max_publish_size; i++) {
ASSERT_TRUE(object_ids_published.count(oids[i]) > 0);
}
for (int i = max_publish_size; i < 10; i++) {
ASSERT_FALSE(object_ids_published.count(oids[i]) > 0);
}
// Remainings are published.
ASSERT_TRUE(subscriber->ConnectToSubscriber(reply));
ASSERT_TRUE(subscriber->PublishIfPossible());
for (int i = 0; i < 10; i++) {
ASSERT_TRUE(object_ids_published.count(oids[i]) > 0);
}
}
TEST_F(PublisherTest, TestBasicSingleSubscriber) {
@ -241,7 +282,7 @@ TEST_F(PublisherTest, TestBasicSingleSubscriber) {
const auto subscriber_node_id = NodeID::FromRandom();
const auto oid = ObjectID::FromRandom();
object_status_publisher_->Connect(subscriber_node_id, long_polling_connect);
object_status_publisher_->ConnectToSubscriber(subscriber_node_id, long_polling_connect);
object_status_publisher_->RegisterSubscription(subscriber_node_id, oid);
object_status_publisher_->Publish(oid);
ASSERT_EQ(batched_ids[0], oid);
@ -261,7 +302,7 @@ TEST_F(PublisherTest, TestNoConnectionWhenRegistered) {
object_status_publisher_->Publish(oid);
// Nothing has been published because there's no connection.
ASSERT_EQ(batched_ids.size(), 0);
object_status_publisher_->Connect(subscriber_node_id, long_polling_connect);
object_status_publisher_->ConnectToSubscriber(subscriber_node_id, long_polling_connect);
// When the connection is coming, it should be published.
ASSERT_EQ(batched_ids[0], oid);
}
@ -285,7 +326,7 @@ TEST_F(PublisherTest, TestMultiObjectsFromSingleNode) {
ASSERT_EQ(batched_ids.size(), 0);
// Now connection is initiated, and all oids are published.
object_status_publisher_->Connect(subscriber_node_id, long_polling_connect);
object_status_publisher_->ConnectToSubscriber(subscriber_node_id, long_polling_connect);
for (int i = 0; i < num_oids; i++) {
const auto oid_test = oids[i];
const auto published_oid = batched_ids[i];
@ -320,7 +361,8 @@ TEST_F(PublisherTest, TestMultiObjectsFromMultiNodes) {
// Check all of nodes are publishing objects properly.
for (int i = 0; i < num_nodes; i++) {
const auto subscriber_node_id = subscribers[i];
object_status_publisher_->Connect(subscriber_node_id, long_polling_connect);
object_status_publisher_->ConnectToSubscriber(subscriber_node_id,
long_polling_connect);
const auto oid_test = oids[i];
const auto published_oid = batched_ids[i];
ASSERT_EQ(oid_test, published_oid);
@ -347,7 +389,7 @@ TEST_F(PublisherTest, TestBatch) {
ASSERT_EQ(batched_ids.size(), 0);
// Now connection is initiated, and all oids are published.
object_status_publisher_->Connect(subscriber_node_id, long_polling_connect);
object_status_publisher_->ConnectToSubscriber(subscriber_node_id, long_polling_connect);
for (int i = 0; i < num_oids; i++) {
const auto oid_test = oids[i];
const auto published_oid = batched_ids[i];
@ -363,7 +405,7 @@ TEST_F(PublisherTest, TestBatch) {
object_status_publisher_->RegisterSubscription(subscriber_node_id, oid);
object_status_publisher_->Publish(oid);
}
object_status_publisher_->Connect(subscriber_node_id, long_polling_connect);
object_status_publisher_->ConnectToSubscriber(subscriber_node_id, long_polling_connect);
for (int i = 0; i < num_oids; i++) {
const auto oid_test = oids[i];
const auto published_oid = batched_ids[i];
@ -380,7 +422,7 @@ TEST_F(PublisherTest, TestNodeFailureWhenConnectionExisted) {
const auto subscriber_node_id = NodeID::FromRandom();
const auto oid = ObjectID::FromRandom();
object_status_publisher_->Connect(subscriber_node_id, long_polling_connect);
object_status_publisher_->ConnectToSubscriber(subscriber_node_id, long_polling_connect);
dead_nodes_.emplace(subscriber_node_id);
// All these ops should be no-op.
object_status_publisher_->RegisterSubscription(subscriber_node_id, oid);
@ -414,7 +456,7 @@ TEST_F(PublisherTest, TestNodeFailureWhenConnectionDoesntExist) {
ASSERT_EQ(long_polling_connection_replied, false);
// Connect should reply right away to avoid memory leak.
object_status_publisher_->Connect(subscriber_node_id, long_polling_connect);
object_status_publisher_->ConnectToSubscriber(subscriber_node_id, long_polling_connect);
ASSERT_EQ(long_polling_connection_replied, true);
long_polling_connection_replied = false;
@ -429,7 +471,7 @@ TEST_F(PublisherTest, TestNodeFailureWhenConnectionDoesntExist) {
///
subscriber_node_id = NodeID::FromRandom();
oid = ObjectID::FromRandom();
object_status_publisher_->Connect(subscriber_node_id, long_polling_connect);
object_status_publisher_->ConnectToSubscriber(subscriber_node_id, long_polling_connect);
dead_nodes_.emplace(subscriber_node_id);
erased = object_status_publisher_->UnregisterSubscriber(subscriber_node_id);
ASSERT_EQ(long_polling_connection_replied, true);
@ -447,7 +489,7 @@ TEST_F(PublisherTest, TestUnregisterSubscription) {
const auto subscriber_node_id = NodeID::FromRandom();
const auto oid = ObjectID::FromRandom();
object_status_publisher_->Connect(subscriber_node_id, long_polling_connect);
object_status_publisher_->ConnectToSubscriber(subscriber_node_id, long_polling_connect);
object_status_publisher_->RegisterSubscription(subscriber_node_id, oid);
ASSERT_EQ(long_polling_connection_replied, false);
@ -482,7 +524,7 @@ TEST_F(PublisherTest, TestUnregisterSubscriber) {
// Test basic.
const auto subscriber_node_id = NodeID::FromRandom();
const auto oid = ObjectID::FromRandom();
object_status_publisher_->Connect(subscriber_node_id, long_polling_connect);
object_status_publisher_->ConnectToSubscriber(subscriber_node_id, long_polling_connect);
object_status_publisher_->RegisterSubscription(subscriber_node_id, oid);
ASSERT_EQ(long_polling_connection_replied, false);
int erased = object_status_publisher_->UnregisterSubscriber(subscriber_node_id);
@ -492,7 +534,7 @@ TEST_F(PublisherTest, TestUnregisterSubscriber) {
// Test when registration wasn't done.
long_polling_connection_replied = false;
object_status_publisher_->Connect(subscriber_node_id, long_polling_connect);
object_status_publisher_->ConnectToSubscriber(subscriber_node_id, long_polling_connect);
erased = object_status_publisher_->UnregisterSubscriber(subscriber_node_id);
ASSERT_FALSE(erased);
ASSERT_EQ(long_polling_connection_replied, true);