diff --git a/BUILD.bazel b/BUILD.bazel index 0e9f1b04a..46d27bdd2 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -1160,6 +1160,8 @@ cc_test( "//:redis-cli", "//:redis-server", ], + # TODO(swang): Enable again once pubsub client supports GCS server restart. + tags = ["manual"], deps = [ ":gcs_server_lib", ":gcs_test_util_lib", diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index 0531796b2..f86445a76 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -459,6 +459,7 @@ py_test( name = "test_gcs_fault_tolerance", size = "medium", srcs = SRCS + ["test_gcs_fault_tolerance.py"], - tags = ["exclusive"], + # TODO(swang): Enable again once pubsub client supports GCS server restart. + tags = ["exclusive", "manual"], deps = ["//:ray_lib"], ) diff --git a/src/ray/gcs/pubsub/gcs_pub_sub.cc b/src/ray/gcs/pubsub/gcs_pub_sub.cc index 52277112f..579296db9 100644 --- a/src/ray/gcs/pubsub/gcs_pub_sub.cc +++ b/src/ray/gcs/pubsub/gcs_pub_sub.cc @@ -44,52 +44,120 @@ Status GcsPubSub::SubscribeAll(const std::string &channel, const Callback &subsc return SubscribeInternal(channel, subscribe, done); } -Status GcsPubSub::Unsubscribe(const std::string &channel, const std::string &id) { - std::string pattern = GenChannelPattern(channel, id); - { - absl::MutexLock lock(&mutex_); - auto it = subscribe_callback_index_.find(pattern); - RAY_CHECK(it != subscribe_callback_index_.end()); - unsubscribe_callback_index_[pattern] = it->second; - subscribe_callback_index_.erase(it); - } - return redis_client_->GetPrimaryContext()->PUnsubscribeAsync(pattern); +Status GcsPubSub::Unsubscribe(const std::string &channel_name, const std::string &id) { + std::string pattern = GenChannelPattern(channel_name, id); + + absl::MutexLock lock(&mutex_); + // Add the UNSUBSCRIBE command to the queue. + auto channel = channels_.find(pattern); + RAY_CHECK(channel != channels_.end()); + channel->second.command_queue.push_back(Command()); + + // Process the first command on the queue, if possible. + return ExecuteCommandIfPossible(channel->first, channel->second); } -Status GcsPubSub::SubscribeInternal(const std::string &channel, const Callback &subscribe, - const StatusCallback &done, +Status GcsPubSub::SubscribeInternal(const std::string &channel_name, + const Callback &subscribe, const StatusCallback &done, const boost::optional &id) { - std::string pattern = GenChannelPattern(channel, id); - auto callback = [this, pattern, done, subscribe](std::shared_ptr reply) { - if (!reply->IsNil()) { + std::string pattern = GenChannelPattern(channel_name, id); + + absl::MutexLock lock(&mutex_); + auto channel = channels_.find(pattern); + if (channel == channels_.end()) { + // There were no pending commands for this channel and we were not already + // subscribed. + channel = channels_.emplace(pattern, Channel()).first; + } + + // Add the SUBSCRIBE command to the queue. + channel->second.command_queue.push_back(Command(subscribe, done)); + + // Process the first command on the queue, if possible. + return ExecuteCommandIfPossible(channel->first, channel->second); +} + +Status GcsPubSub::ExecuteCommandIfPossible(const std::string &channel_key, + GcsPubSub::Channel &channel) { + // Process the first command on the queue, if possible. + Status status; + auto &command = channel.command_queue.front(); + if (command.is_subscribe && channel.callback_index == -1) { + // The next command is SUBSCRIBE and we are currently unsubscribed, so we + // can execute the command. + int64_t callback_index = + ray::gcs::RedisCallbackManager::instance().AllocateCallbackIndex(); + const auto &command_done_callback = command.done_callback; + const auto &command_subscribe_callback = command.subscribe_callback; + auto callback = [this, channel_key, command_done_callback, command_subscribe_callback, + callback_index](std::shared_ptr reply) { + if (reply->IsNil()) { + return; + } if (reply->IsUnsubscribeCallback()) { + // Unset the callback index. absl::MutexLock lock(&mutex_); - ray::gcs::RedisCallbackManager::instance().remove( - unsubscribe_callback_index_[pattern]); - unsubscribe_callback_index_.erase(pattern); + auto channel = channels_.find(channel_key); + RAY_CHECK(channel != channels_.end()); + ray::gcs::RedisCallbackManager::instance().RemoveCallback( + channel->second.callback_index); + channel->second.callback_index = -1; + channel->second.pending_reply = false; + + if (channel->second.command_queue.empty()) { + // We are unsubscribed and there are no more commands to process. + // Delete the channel. + channels_.erase(channel); + } else { + // Process the next item in the queue. + RAY_CHECK(channel->second.command_queue.front().is_subscribe); + RAY_CHECK_OK(ExecuteCommandIfPossible(channel_key, channel->second)); + } } else if (reply->IsSubscribeCallback()) { - if (done) { - done(Status::OK()); + { + // Set the callback index. + absl::MutexLock lock(&mutex_); + auto channel = channels_.find(channel_key); + RAY_CHECK(channel != channels_.end()); + channel->second.callback_index = callback_index; + channel->second.pending_reply = false; + // Process the next item in the queue, if any. + if (!channel->second.command_queue.empty()) { + RAY_CHECK(!channel->second.command_queue.front().is_subscribe); + RAY_CHECK_OK(ExecuteCommandIfPossible(channel_key, channel->second)); + } + } + + if (command_done_callback) { + command_done_callback(Status::OK()); } } else { const auto reply_data = reply->ReadAsPubsubData(); if (!reply_data.empty()) { rpc::PubSubMessage message; message.ParseFromString(reply_data); - subscribe(message.id(), message.data()); + command_subscribe_callback(message.id(), message.data()); } } - } - }; - - int64_t out_callback_index; - auto status = redis_client_->GetPrimaryContext()->PSubscribeAsync(pattern, callback, - &out_callback_index); - if (id) { - absl::MutexLock lock(&mutex_); - // If the same pattern has been subscribed more than once, the last subscription takes - // effect. - subscribe_callback_index_[pattern] = out_callback_index; + }; + status = redis_client_->GetPrimaryContext()->PSubscribeAsync(channel_key, callback, + callback_index); + channel.pending_reply = true; + channel.command_queue.pop_front(); + } else if (!command.is_subscribe && channel.callback_index != -1) { + // The next command is UNSUBSCRIBE and we are currently subscribed, so we + // can execute the command. The reply for will be received through the + // SUBSCRIBE command's callback. + status = redis_client_->GetPrimaryContext()->PUnsubscribeAsync(channel_key); + channel.pending_reply = true; + channel.command_queue.pop_front(); + } else if (!channel.pending_reply) { + // There is no in-flight command, but the next command to execute is not + // runnable. The caller must have sent a command out-of-order. + // TODO(swang): This can cause a fatal error if the GCS server restarts and + // the client attempts to subscribe again. + RAY_LOG(FATAL) << "Caller attempted a duplicate subscribe or unsubscribe to channel " + << channel_key; } return status; } diff --git a/src/ray/gcs/pubsub/gcs_pub_sub.h b/src/ray/gcs/pubsub/gcs_pub_sub.h index 8342bb26f..8b41e00ac 100644 --- a/src/ray/gcs/pubsub/gcs_pub_sub.h +++ b/src/ray/gcs/pubsub/gcs_pub_sub.h @@ -15,6 +15,7 @@ #ifndef RAY_GCS_GCS_PUB_SUB_H_ #define RAY_GCS_GCS_PUB_SUB_H_ +#include "absl/container/flat_hash_map.h" #include "absl/synchronization/mutex.h" #include "ray/gcs/callback.h" @@ -89,6 +90,57 @@ class GcsPubSub { Status Unsubscribe(const std::string &channel, const std::string &id); private: + /// Represents a caller's command to subscribe or unsubscribe to a given + /// channel. + struct Command { + /// SUBSCRIBE constructor. + Command(const Callback &subscribe_callback, const StatusCallback &done_callback) + : is_subscribe(true), + subscribe_callback(subscribe_callback), + done_callback(done_callback) {} + /// UNSUBSCRIBE constructor. + Command() : is_subscribe(false) {} + /// True if this is a SUBSCRIBE command and false if UNSUBSCRIBE. + const bool is_subscribe; + /// Callback that is called whenever a new pubsub message is received from + /// Redis. This should only be set if is_subscribe is true. + const Callback subscribe_callback; + /// Callback that is called once we have successfully subscribed to a + /// channel. This should only be set if is_subscribe is true. + const StatusCallback done_callback; + }; + + struct Channel { + Channel() {} + /// Queue of subscribe/unsubscribe commands to this channel. The queue + /// asserts that subscribe and unsubscribe commands alternate, i.e. there + /// cannot be more than one subscribe/unsubscribe command in a row. A + /// subscribe command can execute if the callback index below is not set, + /// i.e. this is the first subscribe command or the last unsubscribe + /// command's reply has been received. An unsubscribe command can execute + /// if the callback index is set, i.e. the last subscribe command's reply + /// has been received. + std::deque command_queue; + /// The current Redis callback index stored in the RedisContext for this + /// channel. This callback index is used to identify any pubsub + /// notifications meant for this channel. The callback index is set once we + /// have received a reply from Redis that we have subscribed. The callback + /// index is set back to -1 if we receive a reply from Redis that we have + /// unsubscribed. + int64_t callback_index = -1; + /// Whether we are pending a reply from Redis. We cannot send another + /// command from the queue until this has been reset to false. + bool pending_reply = false; + }; + + /// Execute the first queued command for the given channel, if possible. A + /// subscribe command can execute if the channel's callback index is not set. + /// An unsubscribe command can execute if the channel's callback index is + /// set. + Status ExecuteCommandIfPossible(const std::string &channel_key, + GcsPubSub::Channel &channel) + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + Status SubscribeInternal(const std::string &channel, const Callback &subscribe, const StatusCallback &done, const boost::optional &id = boost::none); @@ -101,8 +153,7 @@ class GcsPubSub { /// Mutex to protect the subscribe_callback_index_ field. absl::Mutex mutex_; - std::unordered_map subscribe_callback_index_ GUARDED_BY(mutex_); - std::unordered_map unsubscribe_callback_index_ GUARDED_BY(mutex_); + absl::flat_hash_map channels_ GUARDED_BY(mutex_); }; } // namespace gcs diff --git a/src/ray/gcs/pubsub/test/gcs_pub_sub_test.cc b/src/ray/gcs/pubsub/test/gcs_pub_sub_test.cc index 0e21d285a..d0aa3943e 100644 --- a/src/ray/gcs/pubsub/test/gcs_pub_sub_test.cc +++ b/src/ray/gcs/pubsub/test/gcs_pub_sub_test.cc @@ -92,6 +92,8 @@ class GcsPubSubTest : public ::testing::Test { template void WaitPendingDone(const std::vector &data, int expected_count) { auto condition = [&data, expected_count]() { + RAY_CHECK((int)data.size() <= expected_count) + << "Expected " << expected_count << " data " << data.size(); return (int)data.size() == expected_count; }; EXPECT_TRUE(WaitForCondition(condition, timeout_ms_.count())); @@ -129,6 +131,31 @@ TEST_F(GcsPubSubTest, TestPubSubApi) { WaitPendingDone(all_result, 3); } +TEST_F(GcsPubSubTest, TestManyPubsub) { + std::string channel("channel"); + std::string id("id"); + std::string data("data"); + std::vector> all_result; + SubscribeAll(channel, all_result); + // Test many concurrent subscribes and unsubscribes. + for (int i = 0; i < 1000; i++) { + auto subscribe = [](const std::string &id, const std::string &data) {}; + RAY_CHECK_OK((pub_sub_->Subscribe(channel, id, subscribe, nullptr))); + RAY_CHECK_OK((pub_sub_->Unsubscribe(channel, id))); + } + for (int i = 0; i < 1000; i++) { + std::vector result; + // Use the synchronous subscribe to make sure our SUBSCRIBE message reaches + // Redis before the PUBLISH. + Subscribe(channel, id, result); + Publish(channel, id, data); + + WaitPendingDone(result, 1); + WaitPendingDone(all_result, i + 1); + RAY_CHECK_OK((pub_sub_->Unsubscribe(channel, id))); + } +} + TEST_F(GcsPubSubTest, TestMultithreading) { std::string channel("channel"); auto sub_message_count = std::make_shared>(0); diff --git a/src/ray/gcs/redis_context.cc b/src/ray/gcs/redis_context.cc index 486abbb31..8d7bf516b 100644 --- a/src/ray/gcs/redis_context.cc +++ b/src/ray/gcs/redis_context.cc @@ -37,7 +37,8 @@ void ProcessCallback(int64_t callback_index, std::shared_ptr callback_reply) { RAY_CHECK(callback_index >= 0) << "The callback index must be greater than 0, " << "but it actually is " << callback_index; - auto callback_item = ray::gcs::RedisCallbackManager::instance().get(callback_index); + auto callback_item = + ray::gcs::RedisCallbackManager::instance().GetCallback(callback_index); if (!callback_item->is_subscription_) { // Record the redis latency for non-subscription redis operations. auto end_time = absl::GetCurrentTimeNanos() / 1000; @@ -49,7 +50,7 @@ void ProcessCallback(int64_t callback_index, if (!callback_item->is_subscription_) { // Delete the callback if it's not a subscription callback. - ray::gcs::RedisCallbackManager::instance().remove(callback_index); + ray::gcs::RedisCallbackManager::instance().RemoveCallback(callback_index); } } @@ -203,25 +204,38 @@ void GlobalRedisCallback(void *c, void *r, void *privdata) { ProcessCallback(callback_index, std::make_shared(reply)); } -int64_t RedisCallbackManager::add(const RedisCallback &function, bool is_subscription, - boost::asio::io_service &io_service) { - auto start_time = absl::GetCurrentTimeNanos() / 1000; - +int64_t RedisCallbackManager::AllocateCallbackIndex() { std::lock_guard lock(mutex_); - callback_items_.emplace( - num_callbacks_, - std::make_shared(function, is_subscription, start_time, io_service)); return num_callbacks_++; } -std::shared_ptr RedisCallbackManager::get( - int64_t callback_index) { +int64_t RedisCallbackManager::AddCallback(const RedisCallback &function, + bool is_subscription, + boost::asio::io_service &io_service, + int64_t callback_index) { + auto start_time = absl::GetCurrentTimeNanos() / 1000; + std::lock_guard lock(mutex_); - RAY_CHECK(callback_items_.find(callback_index) != callback_items_.end()); - return callback_items_[callback_index]; + if (callback_index == -1) { + // No callback index was specified. Allocate a new callback index. + callback_index = num_callbacks_; + num_callbacks_++; + } + callback_items_.emplace( + callback_index, + std::make_shared(function, is_subscription, start_time, io_service)); + return callback_index; } -void RedisCallbackManager::remove(int64_t callback_index) { +std::shared_ptr RedisCallbackManager::GetCallback( + int64_t callback_index) const { + std::lock_guard lock(mutex_); + auto it = callback_items_.find(callback_index); + RAY_CHECK(it != callback_items_.end()) << callback_index; + return it->second; +} + +void RedisCallbackManager::RemoveCallback(int64_t callback_index) { std::lock_guard lock(mutex_); callback_items_.erase(callback_index); } @@ -362,7 +376,7 @@ Status RedisContext::RunArgvAsync(const std::vector &args, argc.push_back(args[i].size()); } int64_t callback_index = - RedisCallbackManager::instance().add(redis_callback, false, io_service_); + RedisCallbackManager::instance().AddCallback(redis_callback, false, io_service_); // Run the Redis command. Status status = redis_async_context_->RedisAsyncCommandArgv( reinterpret_cast(&GlobalRedisCallback), @@ -379,7 +393,7 @@ Status RedisContext::SubscribeAsync(const ClientID &client_id, RAY_CHECK(async_redis_subscribe_context_); int64_t callback_index = - RedisCallbackManager::instance().add(redisCallback, true, io_service_); + RedisCallbackManager::instance().AddCallback(redisCallback, true, io_service_); RAY_CHECK(out_callback_index != nullptr); *out_callback_index = callback_index; Status status = Status::OK(); @@ -403,12 +417,11 @@ Status RedisContext::SubscribeAsync(const ClientID &client_id, Status RedisContext::PSubscribeAsync(const std::string &pattern, const RedisCallback &redisCallback, - int64_t *out_callback_index) { + int64_t callback_index) { RAY_CHECK(async_redis_subscribe_context_); - int64_t callback_index = - RedisCallbackManager::instance().add(redisCallback, true, io_service_); - *out_callback_index = callback_index; + RAY_UNUSED(RedisCallbackManager::instance().AddCallback(redisCallback, true, + io_service_, callback_index)); std::string redis_command = "PSUBSCRIBE %b"; return async_redis_subscribe_context_->RedisAsyncCommand( reinterpret_cast(&GlobalRedisCallback), diff --git a/src/ray/gcs/redis_context.h b/src/ray/gcs/redis_context.h index 7bcb651bb..317277fb9 100644 --- a/src/ray/gcs/redis_context.h +++ b/src/ray/gcs/redis_context.h @@ -146,20 +146,25 @@ class RedisCallbackManager { boost::asio::io_service *io_service_; }; - int64_t add(const RedisCallback &function, bool is_subscription, - boost::asio::io_service &io_service); + /// Allocate an index at which we can add a callback later on. + int64_t AllocateCallbackIndex(); - std::shared_ptr get(int64_t callback_index); + /// Add a callback at an optionally specified index. + int64_t AddCallback(const RedisCallback &function, bool is_subscription, + boost::asio::io_service &io_service, int64_t callback_index = -1); /// Remove a callback. - void remove(int64_t callback_index); + void RemoveCallback(int64_t callback_index); + + /// Get a callback. + std::shared_ptr GetCallback(int64_t callback_index) const; private: RedisCallbackManager() : num_callbacks_(0){}; ~RedisCallbackManager() {} - std::mutex mutex_; + mutable std::mutex mutex_; int64_t num_callbacks_ = 0; std::unordered_map> callback_items_; @@ -245,10 +250,12 @@ class RedisContext { /// /// \param pattern The pattern of subscription channel. /// \param redisCallback The callback function that the notification calls. - /// \param out_callback_index The output pointer to callback index. + /// \param callback_index The index at which to add the callback. This index + /// must already be allocated in the callback manager via + /// RedisCallbackManager::AllocateCallbackIndex. /// \return Status. Status PSubscribeAsync(const std::string &pattern, const RedisCallback &redisCallback, - int64_t *out_callback_index); + int64_t callback_index); /// Unsubscribes the client from the given pattern. /// @@ -296,7 +303,7 @@ Status RedisContext::RunAsync(const std::string &command, const ID &id, const vo RedisCallback redisCallback, int log_length) { RAY_CHECK(redis_async_context_); int64_t callback_index = - RedisCallbackManager::instance().add(redisCallback, false, io_service_); + RedisCallbackManager::instance().AddCallback(redisCallback, false, io_service_); Status status = Status::OK(); if (length > 0) { if (log_length >= 0) {