mirror of
https://github.com/vale981/ray
synced 2025-03-06 18:41:40 -05:00
[Pubsub] Batch messages (#15084)
* batch pubsub 1 * Logic done. Tests left. * done.
This commit is contained in:
parent
aea28c53ce
commit
cef6286f63
5 changed files with 111 additions and 45 deletions
|
@ -403,3 +403,6 @@ RAY_CONFIG(int64_t, asio_stats_print_interval_ms, -1)
|
|||
|
||||
/// Maximum amount of memory that will be used by running tasks' args.
|
||||
RAY_CONFIG(float, max_task_args_memory_fraction, 0.7)
|
||||
|
||||
/// The maximum number of objects to publish in a single publish call.
|
||||
RAY_CONFIG(uint64_t, publish_batch_size, 5000)
|
||||
|
|
|
@ -563,7 +563,8 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
|
|||
task_manager_));
|
||||
|
||||
object_status_publisher_ = std::make_shared<Publisher>(
|
||||
/*is_node_dead=*/[this](const NodeID &node_id) {
|
||||
/*is_node_dead=*/
|
||||
[this](const NodeID &node_id) {
|
||||
if (auto node_info =
|
||||
gcs_client_->Nodes().Get(node_id, /*filter_dead_nodes=*/false)) {
|
||||
return node_info->state() ==
|
||||
|
@ -572,7 +573,8 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
|
|||
// Node information is probably not
|
||||
// subscribed yet, so report that the node is alive.
|
||||
return true;
|
||||
});
|
||||
},
|
||||
/*publish_batch_size_=*/RayConfig::instance().publish_batch_size());
|
||||
|
||||
auto node_addr_factory = [this](const NodeID &node_id) {
|
||||
absl::optional<rpc::Address> addr;
|
||||
|
@ -2409,8 +2411,8 @@ void CoreWorker::HandlePubsubLongPolling(const rpc::PubsubLongPollingRequest &re
|
|||
send_reply_callback(Status::OK(), nullptr, nullptr);
|
||||
};
|
||||
RAY_LOG(DEBUG) << "Got long polling request from node " << subscriber_id;
|
||||
object_status_publisher_->Connect(subscriber_id,
|
||||
std::move(long_polling_reply_callback));
|
||||
object_status_publisher_->ConnectToSubscriber(subscriber_id,
|
||||
std::move(long_polling_reply_callback));
|
||||
}
|
||||
|
||||
void CoreWorker::HandleAddObjectLocationOwner(
|
||||
|
|
|
@ -101,7 +101,8 @@ bool SubscriptionIndex::AssertNoLeak() const {
|
|||
return objects_to_subscribers_.size() == 0 && subscribers_to_objects_.size() == 0;
|
||||
}
|
||||
|
||||
bool Subscriber::Connect(LongPollConnectCallback long_polling_reply_callback) {
|
||||
bool Subscriber::ConnectToSubscriber(
|
||||
LongPollConnectCallback long_polling_reply_callback) {
|
||||
if (long_polling_reply_callback_ == nullptr) {
|
||||
long_polling_reply_callback_ = long_polling_reply_callback;
|
||||
return true;
|
||||
|
@ -120,10 +121,16 @@ bool Subscriber::PublishIfPossible(bool force) {
|
|||
if (long_polling_reply_callback_ == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (force || mailbox_.size() > 0) {
|
||||
long_polling_reply_callback_(mailbox_);
|
||||
std::vector<ObjectID> mails_to_post;
|
||||
while (mailbox_.size() > 0 && mails_to_post.size() < publish_batch_size_) {
|
||||
// It ensures that mails are posted in the FIFO order.
|
||||
mails_to_post.push_back(mailbox_.front());
|
||||
mailbox_.pop_front();
|
||||
}
|
||||
long_polling_reply_callback_(mails_to_post);
|
||||
long_polling_reply_callback_ = nullptr;
|
||||
mailbox_.clear();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
|
@ -133,8 +140,8 @@ bool Subscriber::AssertNoLeak() const {
|
|||
return long_polling_reply_callback_ == nullptr && mailbox_.size() == 0;
|
||||
}
|
||||
|
||||
void Publisher::Connect(const NodeID &subscriber_node_id,
|
||||
LongPollConnectCallback long_poll_connect_callback) {
|
||||
void Publisher::ConnectToSubscriber(const NodeID &subscriber_node_id,
|
||||
LongPollConnectCallback long_poll_connect_callback) {
|
||||
RAY_LOG(DEBUG) << "Long polling connection initiated by " << subscriber_node_id;
|
||||
RAY_CHECK(long_poll_connect_callback != nullptr);
|
||||
|
||||
|
@ -148,13 +155,16 @@ void Publisher::Connect(const NodeID &subscriber_node_id,
|
|||
absl::MutexLock lock(&mutex_);
|
||||
auto it = subscribers_.find(subscriber_node_id);
|
||||
if (it == subscribers_.end()) {
|
||||
it = subscribers_.emplace(subscriber_node_id, std::make_shared<Subscriber>()).first;
|
||||
it = subscribers_
|
||||
.emplace(subscriber_node_id,
|
||||
std::make_shared<Subscriber>(publish_batch_size_))
|
||||
.first;
|
||||
}
|
||||
auto &subscriber = it->second;
|
||||
|
||||
// Since the long polling connection is synchronous between the client and coordinator,
|
||||
// when it connects, the connection shouldn't have existed.
|
||||
RAY_CHECK(subscriber->Connect(std::move(long_poll_connect_callback)));
|
||||
RAY_CHECK(subscriber->ConnectToSubscriber(std::move(long_poll_connect_callback)));
|
||||
subscriber->PublishIfPossible();
|
||||
}
|
||||
|
||||
|
@ -168,7 +178,8 @@ void Publisher::RegisterSubscription(const NodeID &subscriber_node_id,
|
|||
|
||||
absl::MutexLock lock(&mutex_);
|
||||
if (subscribers_.count(subscriber_node_id) == 0) {
|
||||
subscribers_.emplace(subscriber_node_id, std::make_shared<Subscriber>());
|
||||
subscribers_.emplace(subscriber_node_id,
|
||||
std::make_shared<Subscriber>(publish_batch_size_));
|
||||
}
|
||||
subscription_index_.AddEntry(object_id, subscriber_node_id);
|
||||
}
|
||||
|
|
|
@ -70,7 +70,9 @@ class SubscriptionIndex {
|
|||
/// Abstraction to each subscriber for the coordinator.
|
||||
class Subscriber {
|
||||
public:
|
||||
explicit Subscriber() {}
|
||||
explicit Subscriber(const uint64_t publish_batch_size)
|
||||
: publish_batch_size_(publish_batch_size) {}
|
||||
|
||||
~Subscriber() = default;
|
||||
|
||||
/// Connect to the subscriber. Currently, it means we cache the long polling request to
|
||||
|
@ -78,7 +80,7 @@ class Subscriber {
|
|||
///
|
||||
/// \param long_polling_reply_callback reply callback to the long polling request.
|
||||
/// \return True if connection is new. False if there were already connections cached.
|
||||
bool Connect(LongPollConnectCallback long_polling_reply_callback);
|
||||
bool ConnectToSubscriber(LongPollConnectCallback long_polling_reply_callback);
|
||||
|
||||
/// Queue the object id to publish to the subscriber.
|
||||
///
|
||||
|
@ -101,7 +103,9 @@ class Subscriber {
|
|||
/// Cached long polling reply callback.
|
||||
LongPollConnectCallback long_polling_reply_callback_ = nullptr;
|
||||
/// Queued messages to publish.
|
||||
std::vector<ObjectID> mailbox_;
|
||||
std::list<ObjectID> mailbox_;
|
||||
/// The maximum number of objects to publish in a single publish call.
|
||||
const uint64_t publish_batch_size_;
|
||||
};
|
||||
|
||||
/// Pubsub server.
|
||||
|
@ -110,8 +114,9 @@ class Publisher {
|
|||
/// Pubsub coordinator constructor.
|
||||
///
|
||||
/// \param is_node_dead A callback that returns true if the given node id is dead.
|
||||
explicit Publisher(std::function<bool(const NodeID &)> is_node_dead)
|
||||
: is_node_dead_(is_node_dead) {}
|
||||
explicit Publisher(std::function<bool(const NodeID &)> is_node_dead,
|
||||
const uint64_t publish_batch_size)
|
||||
: is_node_dead_(is_node_dead), publish_batch_size_(publish_batch_size) {}
|
||||
|
||||
~Publisher() = default;
|
||||
|
||||
|
@ -119,8 +124,8 @@ class Publisher {
|
|||
// TODO(sang): Currently, we need to pass the callback for connection because we are
|
||||
// using long polling internally. This should be changed once the bidirectional grpc
|
||||
// streaming is supported.
|
||||
void Connect(const NodeID &subscriber_node_id,
|
||||
LongPollConnectCallback long_poll_connect_callback);
|
||||
void ConnectToSubscriber(const NodeID &subscriber_node_id,
|
||||
LongPollConnectCallback long_poll_connect_callback);
|
||||
|
||||
/// Register the subscription.
|
||||
///
|
||||
|
@ -167,6 +172,9 @@ class Publisher {
|
|||
|
||||
/// Index that stores the mapping of objects <-> subscribers.
|
||||
SubscriptionIndex subscription_index_ GUARDED_BY(mutex_);
|
||||
|
||||
/// The maximum number of objects to publish in a single publish call.
|
||||
const uint64_t publish_batch_size_;
|
||||
};
|
||||
|
||||
} // namespace ray
|
||||
|
|
|
@ -30,7 +30,8 @@ class PublisherTest : public ::testing::Test {
|
|||
void SetUp() {
|
||||
dead_nodes_.clear();
|
||||
object_status_publisher_ = std::shared_ptr<Publisher>(new Publisher(
|
||||
[this](const NodeID &node_id) { return dead_nodes_.count(node_id) == 1; }));
|
||||
[this](const NodeID &node_id) { return dead_nodes_.count(node_id) == 1; },
|
||||
/*batch_size*/ 100));
|
||||
}
|
||||
|
||||
void TearDown() { subscribers_map_.clear(); }
|
||||
|
@ -195,40 +196,80 @@ TEST_F(PublisherTest, TestSubscriber) {
|
|||
}
|
||||
};
|
||||
|
||||
Subscriber subscriber;
|
||||
std::shared_ptr<Subscriber> subscriber = std::make_shared<Subscriber>(10);
|
||||
// If there's no connection, it will return false.
|
||||
ASSERT_FALSE(subscriber.PublishIfPossible());
|
||||
ASSERT_FALSE(subscriber->PublishIfPossible());
|
||||
// Try connecting it. Should return true.
|
||||
ASSERT_TRUE(subscriber.Connect(reply));
|
||||
ASSERT_TRUE(subscriber->ConnectToSubscriber(reply));
|
||||
// If connecting it again, it should fail the request.
|
||||
ASSERT_FALSE(subscriber.Connect(reply));
|
||||
ASSERT_FALSE(subscriber->ConnectToSubscriber(reply));
|
||||
// Since there's no published objects, it should return false.
|
||||
ASSERT_FALSE(subscriber.PublishIfPossible());
|
||||
ASSERT_FALSE(subscriber->PublishIfPossible());
|
||||
|
||||
std::unordered_set<ObjectID> published_objects;
|
||||
// Make sure publishing one object works as expected.
|
||||
auto oid = ObjectID::FromRandom();
|
||||
subscriber.QueueMessage(oid, /*try_publish=*/false);
|
||||
subscriber->QueueMessage(oid, /*try_publish=*/false);
|
||||
published_objects.emplace(oid);
|
||||
ASSERT_TRUE(subscriber.PublishIfPossible());
|
||||
ASSERT_TRUE(subscriber->PublishIfPossible());
|
||||
ASSERT_TRUE(object_ids_published.count(oid) > 0);
|
||||
// Since the object is published, and there's no connection, it should return false.
|
||||
ASSERT_FALSE(subscriber.PublishIfPossible());
|
||||
ASSERT_FALSE(subscriber->PublishIfPossible());
|
||||
|
||||
// Add 3 oids and see if it works properly.
|
||||
for (int i = 0; i < 3; i++) {
|
||||
oid = ObjectID::FromRandom();
|
||||
subscriber.QueueMessage(oid, /*try_publish=*/false);
|
||||
subscriber->QueueMessage(oid, /*try_publish=*/false);
|
||||
published_objects.emplace(oid);
|
||||
}
|
||||
// Since there's no connection, objects won't be published.
|
||||
ASSERT_FALSE(subscriber.PublishIfPossible());
|
||||
ASSERT_TRUE(subscriber.Connect(reply));
|
||||
ASSERT_TRUE(subscriber.PublishIfPossible());
|
||||
ASSERT_FALSE(subscriber->PublishIfPossible());
|
||||
ASSERT_TRUE(subscriber->ConnectToSubscriber(reply));
|
||||
ASSERT_TRUE(subscriber->PublishIfPossible());
|
||||
for (auto oid : published_objects) {
|
||||
ASSERT_TRUE(object_ids_published.count(oid) > 0);
|
||||
}
|
||||
ASSERT_TRUE(subscriber.AssertNoLeak());
|
||||
ASSERT_TRUE(subscriber->AssertNoLeak());
|
||||
}
|
||||
|
||||
TEST_F(PublisherTest, TestSubscriberBatchSize) {
|
||||
std::unordered_set<ObjectID> object_ids_published;
|
||||
LongPollConnectCallback reply =
|
||||
[&object_ids_published](const std::vector<ObjectID> &object_ids) {
|
||||
for (auto &oid : object_ids) {
|
||||
object_ids_published.emplace(oid);
|
||||
}
|
||||
};
|
||||
|
||||
auto max_publish_size = 5;
|
||||
std::shared_ptr<Subscriber> subscriber = std::make_shared<Subscriber>(max_publish_size);
|
||||
ASSERT_TRUE(subscriber->ConnectToSubscriber(reply));
|
||||
|
||||
std::unordered_set<ObjectID> published_objects;
|
||||
std::vector<ObjectID> oids;
|
||||
for (int i = 0; i < 10; i++) {
|
||||
auto oid = ObjectID::FromRandom();
|
||||
oids.push_back(oid);
|
||||
subscriber->QueueMessage(oid, /*try_publish=*/false);
|
||||
published_objects.emplace(oid);
|
||||
}
|
||||
|
||||
// Make sure only up to batch size is published.
|
||||
ASSERT_TRUE(subscriber->PublishIfPossible());
|
||||
|
||||
for (int i = 0; i < max_publish_size; i++) {
|
||||
ASSERT_TRUE(object_ids_published.count(oids[i]) > 0);
|
||||
}
|
||||
for (int i = max_publish_size; i < 10; i++) {
|
||||
ASSERT_FALSE(object_ids_published.count(oids[i]) > 0);
|
||||
}
|
||||
|
||||
// The remaining objects are published on the next connection.
|
||||
ASSERT_TRUE(subscriber->ConnectToSubscriber(reply));
|
||||
ASSERT_TRUE(subscriber->PublishIfPossible());
|
||||
for (int i = 0; i < 10; i++) {
|
||||
ASSERT_TRUE(object_ids_published.count(oids[i]) > 0);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(PublisherTest, TestBasicSingleSubscriber) {
|
||||
|
@ -241,7 +282,7 @@ TEST_F(PublisherTest, TestBasicSingleSubscriber) {
|
|||
const auto subscriber_node_id = NodeID::FromRandom();
|
||||
const auto oid = ObjectID::FromRandom();
|
||||
|
||||
object_status_publisher_->Connect(subscriber_node_id, long_polling_connect);
|
||||
object_status_publisher_->ConnectToSubscriber(subscriber_node_id, long_polling_connect);
|
||||
object_status_publisher_->RegisterSubscription(subscriber_node_id, oid);
|
||||
object_status_publisher_->Publish(oid);
|
||||
ASSERT_EQ(batched_ids[0], oid);
|
||||
|
@ -261,7 +302,7 @@ TEST_F(PublisherTest, TestNoConnectionWhenRegistered) {
|
|||
object_status_publisher_->Publish(oid);
|
||||
// Nothing has been published because there's no connection.
|
||||
ASSERT_EQ(batched_ids.size(), 0);
|
||||
object_status_publisher_->Connect(subscriber_node_id, long_polling_connect);
|
||||
object_status_publisher_->ConnectToSubscriber(subscriber_node_id, long_polling_connect);
|
||||
// When the connection is coming, it should be published.
|
||||
ASSERT_EQ(batched_ids[0], oid);
|
||||
}
|
||||
|
@ -285,7 +326,7 @@ TEST_F(PublisherTest, TestMultiObjectsFromSingleNode) {
|
|||
ASSERT_EQ(batched_ids.size(), 0);
|
||||
|
||||
// Now connection is initiated, and all oids are published.
|
||||
object_status_publisher_->Connect(subscriber_node_id, long_polling_connect);
|
||||
object_status_publisher_->ConnectToSubscriber(subscriber_node_id, long_polling_connect);
|
||||
for (int i = 0; i < num_oids; i++) {
|
||||
const auto oid_test = oids[i];
|
||||
const auto published_oid = batched_ids[i];
|
||||
|
@ -320,7 +361,8 @@ TEST_F(PublisherTest, TestMultiObjectsFromMultiNodes) {
|
|||
// Check all of nodes are publishing objects properly.
|
||||
for (int i = 0; i < num_nodes; i++) {
|
||||
const auto subscriber_node_id = subscribers[i];
|
||||
object_status_publisher_->Connect(subscriber_node_id, long_polling_connect);
|
||||
object_status_publisher_->ConnectToSubscriber(subscriber_node_id,
|
||||
long_polling_connect);
|
||||
const auto oid_test = oids[i];
|
||||
const auto published_oid = batched_ids[i];
|
||||
ASSERT_EQ(oid_test, published_oid);
|
||||
|
@ -347,7 +389,7 @@ TEST_F(PublisherTest, TestBatch) {
|
|||
ASSERT_EQ(batched_ids.size(), 0);
|
||||
|
||||
// Now connection is initiated, and all oids are published.
|
||||
object_status_publisher_->Connect(subscriber_node_id, long_polling_connect);
|
||||
object_status_publisher_->ConnectToSubscriber(subscriber_node_id, long_polling_connect);
|
||||
for (int i = 0; i < num_oids; i++) {
|
||||
const auto oid_test = oids[i];
|
||||
const auto published_oid = batched_ids[i];
|
||||
|
@ -363,7 +405,7 @@ TEST_F(PublisherTest, TestBatch) {
|
|||
object_status_publisher_->RegisterSubscription(subscriber_node_id, oid);
|
||||
object_status_publisher_->Publish(oid);
|
||||
}
|
||||
object_status_publisher_->Connect(subscriber_node_id, long_polling_connect);
|
||||
object_status_publisher_->ConnectToSubscriber(subscriber_node_id, long_polling_connect);
|
||||
for (int i = 0; i < num_oids; i++) {
|
||||
const auto oid_test = oids[i];
|
||||
const auto published_oid = batched_ids[i];
|
||||
|
@ -380,7 +422,7 @@ TEST_F(PublisherTest, TestNodeFailureWhenConnectionExisted) {
|
|||
|
||||
const auto subscriber_node_id = NodeID::FromRandom();
|
||||
const auto oid = ObjectID::FromRandom();
|
||||
object_status_publisher_->Connect(subscriber_node_id, long_polling_connect);
|
||||
object_status_publisher_->ConnectToSubscriber(subscriber_node_id, long_polling_connect);
|
||||
dead_nodes_.emplace(subscriber_node_id);
|
||||
// All these ops should be no-op.
|
||||
object_status_publisher_->RegisterSubscription(subscriber_node_id, oid);
|
||||
|
@ -414,7 +456,7 @@ TEST_F(PublisherTest, TestNodeFailureWhenConnectionDoesntExist) {
|
|||
ASSERT_EQ(long_polling_connection_replied, false);
|
||||
|
||||
// Connect should reply right away to avoid memory leak.
|
||||
object_status_publisher_->Connect(subscriber_node_id, long_polling_connect);
|
||||
object_status_publisher_->ConnectToSubscriber(subscriber_node_id, long_polling_connect);
|
||||
ASSERT_EQ(long_polling_connection_replied, true);
|
||||
long_polling_connection_replied = false;
|
||||
|
||||
|
@ -429,7 +471,7 @@ TEST_F(PublisherTest, TestNodeFailureWhenConnectionDoesntExist) {
|
|||
///
|
||||
subscriber_node_id = NodeID::FromRandom();
|
||||
oid = ObjectID::FromRandom();
|
||||
object_status_publisher_->Connect(subscriber_node_id, long_polling_connect);
|
||||
object_status_publisher_->ConnectToSubscriber(subscriber_node_id, long_polling_connect);
|
||||
dead_nodes_.emplace(subscriber_node_id);
|
||||
erased = object_status_publisher_->UnregisterSubscriber(subscriber_node_id);
|
||||
ASSERT_EQ(long_polling_connection_replied, true);
|
||||
|
@ -447,7 +489,7 @@ TEST_F(PublisherTest, TestUnregisterSubscription) {
|
|||
|
||||
const auto subscriber_node_id = NodeID::FromRandom();
|
||||
const auto oid = ObjectID::FromRandom();
|
||||
object_status_publisher_->Connect(subscriber_node_id, long_polling_connect);
|
||||
object_status_publisher_->ConnectToSubscriber(subscriber_node_id, long_polling_connect);
|
||||
object_status_publisher_->RegisterSubscription(subscriber_node_id, oid);
|
||||
ASSERT_EQ(long_polling_connection_replied, false);
|
||||
|
||||
|
@ -482,7 +524,7 @@ TEST_F(PublisherTest, TestUnregisterSubscriber) {
|
|||
// Test basic.
|
||||
const auto subscriber_node_id = NodeID::FromRandom();
|
||||
const auto oid = ObjectID::FromRandom();
|
||||
object_status_publisher_->Connect(subscriber_node_id, long_polling_connect);
|
||||
object_status_publisher_->ConnectToSubscriber(subscriber_node_id, long_polling_connect);
|
||||
object_status_publisher_->RegisterSubscription(subscriber_node_id, oid);
|
||||
ASSERT_EQ(long_polling_connection_replied, false);
|
||||
int erased = object_status_publisher_->UnregisterSubscriber(subscriber_node_id);
|
||||
|
@ -492,7 +534,7 @@ TEST_F(PublisherTest, TestUnregisterSubscriber) {
|
|||
|
||||
// Test when registration wasn't done.
|
||||
long_polling_connection_replied = false;
|
||||
object_status_publisher_->Connect(subscriber_node_id, long_polling_connect);
|
||||
object_status_publisher_->ConnectToSubscriber(subscriber_node_id, long_polling_connect);
|
||||
erased = object_status_publisher_->UnregisterSubscriber(subscriber_node_id);
|
||||
ASSERT_FALSE(erased);
|
||||
ASSERT_EQ(long_polling_connection_replied, true);
|
||||
|
|
Loading…
Add table
Reference in a new issue