[core] Queue subscription/unsubscription commands in the GCS (#8756)

* Only remove callback index if in map

* test

* Queue subscription commands

* lint

* Check status

* update

* update

* update

* Disable GCS restart tests

* lint
This commit is contained in:
Stephanie Wang 2020-06-05 19:49:19 -07:00 committed by GitHub
parent 54189bca5a
commit b160b83d3e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 232 additions and 63 deletions

View file

@ -1160,6 +1160,8 @@ cc_test(
"//:redis-cli",
"//:redis-server",
],
# TODO(swang): Enable again once pubsub client supports GCS server restart.
tags = ["manual"],
deps = [
":gcs_server_lib",
":gcs_test_util_lib",

View file

@ -459,6 +459,7 @@ py_test(
name = "test_gcs_fault_tolerance",
size = "medium",
srcs = SRCS + ["test_gcs_fault_tolerance.py"],
tags = ["exclusive"],
# TODO(swang): Enable again once pubsub client supports GCS server restart.
tags = ["exclusive", "manual"],
deps = ["//:ray_lib"],
)

View file

@ -44,52 +44,120 @@ Status GcsPubSub::SubscribeAll(const std::string &channel, const Callback &subsc
return SubscribeInternal(channel, subscribe, done);
}
Status GcsPubSub::Unsubscribe(const std::string &channel, const std::string &id) {
std::string pattern = GenChannelPattern(channel, id);
{
absl::MutexLock lock(&mutex_);
auto it = subscribe_callback_index_.find(pattern);
RAY_CHECK(it != subscribe_callback_index_.end());
unsubscribe_callback_index_[pattern] = it->second;
subscribe_callback_index_.erase(it);
}
return redis_client_->GetPrimaryContext()->PUnsubscribeAsync(pattern);
Status GcsPubSub::Unsubscribe(const std::string &channel_name, const std::string &id) {
std::string pattern = GenChannelPattern(channel_name, id);
absl::MutexLock lock(&mutex_);
// Add the UNSUBSCRIBE command to the queue.
auto channel = channels_.find(pattern);
RAY_CHECK(channel != channels_.end());
channel->second.command_queue.push_back(Command());
// Process the first command on the queue, if possible.
return ExecuteCommandIfPossible(channel->first, channel->second);
}
Status GcsPubSub::SubscribeInternal(const std::string &channel, const Callback &subscribe,
const StatusCallback &done,
Status GcsPubSub::SubscribeInternal(const std::string &channel_name,
const Callback &subscribe, const StatusCallback &done,
const boost::optional<std::string> &id) {
std::string pattern = GenChannelPattern(channel, id);
auto callback = [this, pattern, done, subscribe](std::shared_ptr<CallbackReply> reply) {
if (!reply->IsNil()) {
std::string pattern = GenChannelPattern(channel_name, id);
absl::MutexLock lock(&mutex_);
auto channel = channels_.find(pattern);
if (channel == channels_.end()) {
// There were no pending commands for this channel and we were not already
// subscribed.
channel = channels_.emplace(pattern, Channel()).first;
}
// Add the SUBSCRIBE command to the queue.
channel->second.command_queue.push_back(Command(subscribe, done));
// Process the first command on the queue, if possible.
return ExecuteCommandIfPossible(channel->first, channel->second);
}
Status GcsPubSub::ExecuteCommandIfPossible(const std::string &channel_key,
GcsPubSub::Channel &channel) {
// Process the first command on the queue, if possible.
Status status;
auto &command = channel.command_queue.front();
if (command.is_subscribe && channel.callback_index == -1) {
// The next command is SUBSCRIBE and we are currently unsubscribed, so we
// can execute the command.
int64_t callback_index =
ray::gcs::RedisCallbackManager::instance().AllocateCallbackIndex();
const auto &command_done_callback = command.done_callback;
const auto &command_subscribe_callback = command.subscribe_callback;
auto callback = [this, channel_key, command_done_callback, command_subscribe_callback,
callback_index](std::shared_ptr<CallbackReply> reply) {
if (reply->IsNil()) {
return;
}
if (reply->IsUnsubscribeCallback()) {
// Unset the callback index.
absl::MutexLock lock(&mutex_);
ray::gcs::RedisCallbackManager::instance().remove(
unsubscribe_callback_index_[pattern]);
unsubscribe_callback_index_.erase(pattern);
auto channel = channels_.find(channel_key);
RAY_CHECK(channel != channels_.end());
ray::gcs::RedisCallbackManager::instance().RemoveCallback(
channel->second.callback_index);
channel->second.callback_index = -1;
channel->second.pending_reply = false;
if (channel->second.command_queue.empty()) {
// We are unsubscribed and there are no more commands to process.
// Delete the channel.
channels_.erase(channel);
} else {
// Process the next item in the queue.
RAY_CHECK(channel->second.command_queue.front().is_subscribe);
RAY_CHECK_OK(ExecuteCommandIfPossible(channel_key, channel->second));
}
} else if (reply->IsSubscribeCallback()) {
if (done) {
done(Status::OK());
{
// Set the callback index.
absl::MutexLock lock(&mutex_);
auto channel = channels_.find(channel_key);
RAY_CHECK(channel != channels_.end());
channel->second.callback_index = callback_index;
channel->second.pending_reply = false;
// Process the next item in the queue, if any.
if (!channel->second.command_queue.empty()) {
RAY_CHECK(!channel->second.command_queue.front().is_subscribe);
RAY_CHECK_OK(ExecuteCommandIfPossible(channel_key, channel->second));
}
}
if (command_done_callback) {
command_done_callback(Status::OK());
}
} else {
const auto reply_data = reply->ReadAsPubsubData();
if (!reply_data.empty()) {
rpc::PubSubMessage message;
message.ParseFromString(reply_data);
subscribe(message.id(), message.data());
command_subscribe_callback(message.id(), message.data());
}
}
}
};
int64_t out_callback_index;
auto status = redis_client_->GetPrimaryContext()->PSubscribeAsync(pattern, callback,
&out_callback_index);
if (id) {
absl::MutexLock lock(&mutex_);
// If the same pattern has been subscribed more than once, the last subscription takes
// effect.
subscribe_callback_index_[pattern] = out_callback_index;
};
status = redis_client_->GetPrimaryContext()->PSubscribeAsync(channel_key, callback,
callback_index);
channel.pending_reply = true;
channel.command_queue.pop_front();
} else if (!command.is_subscribe && channel.callback_index != -1) {
// The next command is UNSUBSCRIBE and we are currently subscribed, so we
// can execute the command. The reply for will be received through the
// SUBSCRIBE command's callback.
status = redis_client_->GetPrimaryContext()->PUnsubscribeAsync(channel_key);
channel.pending_reply = true;
channel.command_queue.pop_front();
} else if (!channel.pending_reply) {
// There is no in-flight command, but the next command to execute is not
// runnable. The caller must have sent a command out-of-order.
// TODO(swang): This can cause a fatal error if the GCS server restarts and
// the client attempts to subscribe again.
RAY_LOG(FATAL) << "Caller attempted a duplicate subscribe or unsubscribe to channel "
<< channel_key;
}
return status;
}

View file

@ -15,6 +15,7 @@
#ifndef RAY_GCS_GCS_PUB_SUB_H_
#define RAY_GCS_GCS_PUB_SUB_H_
#include "absl/container/flat_hash_map.h"
#include "absl/synchronization/mutex.h"
#include "ray/gcs/callback.h"
@ -89,6 +90,57 @@ class GcsPubSub {
Status Unsubscribe(const std::string &channel, const std::string &id);
private:
/// Represents a caller's command to subscribe or unsubscribe to a given
/// channel.
struct Command {
/// SUBSCRIBE constructor.
Command(const Callback &subscribe_callback, const StatusCallback &done_callback)
: is_subscribe(true),
subscribe_callback(subscribe_callback),
done_callback(done_callback) {}
/// UNSUBSCRIBE constructor.
Command() : is_subscribe(false) {}
/// True if this is a SUBSCRIBE command and false if UNSUBSCRIBE.
const bool is_subscribe;
/// Callback that is called whenever a new pubsub message is received from
/// Redis. This should only be set if is_subscribe is true.
const Callback subscribe_callback;
/// Callback that is called once we have successfully subscribed to a
/// channel. This should only be set if is_subscribe is true.
const StatusCallback done_callback;
};
struct Channel {
Channel() {}
/// Queue of subscribe/unsubscribe commands to this channel. The queue
/// asserts that subscribe and unsubscribe commands alternate, i.e. there
/// cannot be more than one subscribe/unsubscribe command in a row. A
/// subscribe command can execute if the callback index below is not set,
/// i.e. this is the first subscribe command or the last unsubscribe
/// command's reply has been received. An unsubscribe command can execute
/// if the callback index is set, i.e. the last subscribe command's reply
/// has been received.
std::deque<Command> command_queue;
/// The current Redis callback index stored in the RedisContext for this
/// channel. This callback index is used to identify any pubsub
/// notifications meant for this channel. The callback index is set once we
/// have received a reply from Redis that we have subscribed. The callback
/// index is set back to -1 if we receive a reply from Redis that we have
/// unsubscribed.
int64_t callback_index = -1;
/// Whether we are pending a reply from Redis. We cannot send another
/// command from the queue until this has been reset to false.
bool pending_reply = false;
};
/// Execute the first queued command for the given channel, if possible. A
/// subscribe command can execute if the channel's callback index is not set.
/// An unsubscribe command can execute if the channel's callback index is
/// set.
Status ExecuteCommandIfPossible(const std::string &channel_key,
GcsPubSub::Channel &channel)
EXCLUSIVE_LOCKS_REQUIRED(mutex_);
Status SubscribeInternal(const std::string &channel, const Callback &subscribe,
const StatusCallback &done,
const boost::optional<std::string> &id = boost::none);
@ -101,8 +153,7 @@ class GcsPubSub {
/// Mutex to protect the subscribe_callback_index_ field.
absl::Mutex mutex_;
std::unordered_map<std::string, int64_t> subscribe_callback_index_ GUARDED_BY(mutex_);
std::unordered_map<std::string, int64_t> unsubscribe_callback_index_ GUARDED_BY(mutex_);
absl::flat_hash_map<std::string, Channel> channels_ GUARDED_BY(mutex_);
};
} // namespace gcs

View file

@ -92,6 +92,8 @@ class GcsPubSubTest : public ::testing::Test {
template <typename Data>
void WaitPendingDone(const std::vector<Data> &data, int expected_count) {
auto condition = [&data, expected_count]() {
RAY_CHECK((int)data.size() <= expected_count)
<< "Expected " << expected_count << " data " << data.size();
return (int)data.size() == expected_count;
};
EXPECT_TRUE(WaitForCondition(condition, timeout_ms_.count()));
@ -129,6 +131,31 @@ TEST_F(GcsPubSubTest, TestPubSubApi) {
WaitPendingDone(all_result, 3);
}
TEST_F(GcsPubSubTest, TestManyPubsub) {
std::string channel("channel");
std::string id("id");
std::string data("data");
std::vector<std::pair<std::string, std::string>> all_result;
SubscribeAll(channel, all_result);
// Test many concurrent subscribes and unsubscribes.
for (int i = 0; i < 1000; i++) {
auto subscribe = [](const std::string &id, const std::string &data) {};
RAY_CHECK_OK((pub_sub_->Subscribe(channel, id, subscribe, nullptr)));
RAY_CHECK_OK((pub_sub_->Unsubscribe(channel, id)));
}
for (int i = 0; i < 1000; i++) {
std::vector<std::string> result;
// Use the synchronous subscribe to make sure our SUBSCRIBE message reaches
// Redis before the PUBLISH.
Subscribe(channel, id, result);
Publish(channel, id, data);
WaitPendingDone(result, 1);
WaitPendingDone(all_result, i + 1);
RAY_CHECK_OK((pub_sub_->Unsubscribe(channel, id)));
}
}
TEST_F(GcsPubSubTest, TestMultithreading) {
std::string channel("channel");
auto sub_message_count = std::make_shared<std::atomic<int>>(0);

View file

@ -37,7 +37,8 @@ void ProcessCallback(int64_t callback_index,
std::shared_ptr<ray::gcs::CallbackReply> callback_reply) {
RAY_CHECK(callback_index >= 0) << "The callback index must be greater than 0, "
<< "but it actually is " << callback_index;
auto callback_item = ray::gcs::RedisCallbackManager::instance().get(callback_index);
auto callback_item =
ray::gcs::RedisCallbackManager::instance().GetCallback(callback_index);
if (!callback_item->is_subscription_) {
// Record the redis latency for non-subscription redis operations.
auto end_time = absl::GetCurrentTimeNanos() / 1000;
@ -49,7 +50,7 @@ void ProcessCallback(int64_t callback_index,
if (!callback_item->is_subscription_) {
// Delete the callback if it's not a subscription callback.
ray::gcs::RedisCallbackManager::instance().remove(callback_index);
ray::gcs::RedisCallbackManager::instance().RemoveCallback(callback_index);
}
}
@ -203,25 +204,38 @@ void GlobalRedisCallback(void *c, void *r, void *privdata) {
ProcessCallback(callback_index, std::make_shared<CallbackReply>(reply));
}
int64_t RedisCallbackManager::add(const RedisCallback &function, bool is_subscription,
boost::asio::io_service &io_service) {
auto start_time = absl::GetCurrentTimeNanos() / 1000;
int64_t RedisCallbackManager::AllocateCallbackIndex() {
std::lock_guard<std::mutex> lock(mutex_);
callback_items_.emplace(
num_callbacks_,
std::make_shared<CallbackItem>(function, is_subscription, start_time, io_service));
return num_callbacks_++;
}
std::shared_ptr<RedisCallbackManager::CallbackItem> RedisCallbackManager::get(
int64_t callback_index) {
int64_t RedisCallbackManager::AddCallback(const RedisCallback &function,
bool is_subscription,
boost::asio::io_service &io_service,
int64_t callback_index) {
auto start_time = absl::GetCurrentTimeNanos() / 1000;
std::lock_guard<std::mutex> lock(mutex_);
RAY_CHECK(callback_items_.find(callback_index) != callback_items_.end());
return callback_items_[callback_index];
if (callback_index == -1) {
// No callback index was specified. Allocate a new callback index.
callback_index = num_callbacks_;
num_callbacks_++;
}
callback_items_.emplace(
callback_index,
std::make_shared<CallbackItem>(function, is_subscription, start_time, io_service));
return callback_index;
}
void RedisCallbackManager::remove(int64_t callback_index) {
std::shared_ptr<RedisCallbackManager::CallbackItem> RedisCallbackManager::GetCallback(
int64_t callback_index) const {
std::lock_guard<std::mutex> lock(mutex_);
auto it = callback_items_.find(callback_index);
RAY_CHECK(it != callback_items_.end()) << callback_index;
return it->second;
}
void RedisCallbackManager::RemoveCallback(int64_t callback_index) {
std::lock_guard<std::mutex> lock(mutex_);
callback_items_.erase(callback_index);
}
@ -362,7 +376,7 @@ Status RedisContext::RunArgvAsync(const std::vector<std::string> &args,
argc.push_back(args[i].size());
}
int64_t callback_index =
RedisCallbackManager::instance().add(redis_callback, false, io_service_);
RedisCallbackManager::instance().AddCallback(redis_callback, false, io_service_);
// Run the Redis command.
Status status = redis_async_context_->RedisAsyncCommandArgv(
reinterpret_cast<redisCallbackFn *>(&GlobalRedisCallback),
@ -379,7 +393,7 @@ Status RedisContext::SubscribeAsync(const ClientID &client_id,
RAY_CHECK(async_redis_subscribe_context_);
int64_t callback_index =
RedisCallbackManager::instance().add(redisCallback, true, io_service_);
RedisCallbackManager::instance().AddCallback(redisCallback, true, io_service_);
RAY_CHECK(out_callback_index != nullptr);
*out_callback_index = callback_index;
Status status = Status::OK();
@ -403,12 +417,11 @@ Status RedisContext::SubscribeAsync(const ClientID &client_id,
Status RedisContext::PSubscribeAsync(const std::string &pattern,
const RedisCallback &redisCallback,
int64_t *out_callback_index) {
int64_t callback_index) {
RAY_CHECK(async_redis_subscribe_context_);
int64_t callback_index =
RedisCallbackManager::instance().add(redisCallback, true, io_service_);
*out_callback_index = callback_index;
RAY_UNUSED(RedisCallbackManager::instance().AddCallback(redisCallback, true,
io_service_, callback_index));
std::string redis_command = "PSUBSCRIBE %b";
return async_redis_subscribe_context_->RedisAsyncCommand(
reinterpret_cast<redisCallbackFn *>(&GlobalRedisCallback),

View file

@ -146,20 +146,25 @@ class RedisCallbackManager {
boost::asio::io_service *io_service_;
};
int64_t add(const RedisCallback &function, bool is_subscription,
boost::asio::io_service &io_service);
/// Allocate an index at which we can add a callback later on.
int64_t AllocateCallbackIndex();
std::shared_ptr<CallbackItem> get(int64_t callback_index);
/// Add a callback at an optionally specified index.
int64_t AddCallback(const RedisCallback &function, bool is_subscription,
boost::asio::io_service &io_service, int64_t callback_index = -1);
/// Remove a callback.
void remove(int64_t callback_index);
void RemoveCallback(int64_t callback_index);
/// Get a callback.
std::shared_ptr<CallbackItem> GetCallback(int64_t callback_index) const;
private:
RedisCallbackManager() : num_callbacks_(0){};
~RedisCallbackManager() {}
std::mutex mutex_;
mutable std::mutex mutex_;
int64_t num_callbacks_ = 0;
std::unordered_map<int64_t, std::shared_ptr<CallbackItem>> callback_items_;
@ -245,10 +250,12 @@ class RedisContext {
///
/// \param pattern The pattern of subscription channel.
/// \param redisCallback The callback function that the notification calls.
/// \param out_callback_index The output pointer to callback index.
/// \param callback_index The index at which to add the callback. This index
/// must already be allocated in the callback manager via
/// RedisCallbackManager::AllocateCallbackIndex.
/// \return Status.
Status PSubscribeAsync(const std::string &pattern, const RedisCallback &redisCallback,
int64_t *out_callback_index);
int64_t callback_index);
/// Unsubscribes the client from the given pattern.
///
@ -296,7 +303,7 @@ Status RedisContext::RunAsync(const std::string &command, const ID &id, const vo
RedisCallback redisCallback, int log_length) {
RAY_CHECK(redis_async_context_);
int64_t callback_index =
RedisCallbackManager::instance().add(redisCallback, false, io_service_);
RedisCallbackManager::instance().AddCallback(redisCallback, false, io_service_);
Status status = Status::OK();
if (length > 0) {
if (log_length >= 0) {