Fix race condition in redis_async_context.cc (#6231)

* dispatch callback to backend thread

* tmp: test in loop

* compiling

* Works using shared_ptrs

* Revert "tmp: test in loop"

This reverts commit faf1f8f74b34a99396906f56827d2691472ae7d4.

* Copy into CallbackReply

* fix comment

* warning

* add nil case
This commit is contained in:
Edward Oakes 2019-11-22 15:51:40 -08:00 committed by GitHub
parent f53f576120
commit ae5abc48a9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 156 additions and 129 deletions

View file

@ -31,7 +31,12 @@ void RedisAsyncContext::ResetRawRedisAsyncContext() {
}
void RedisAsyncContext::RedisAsyncHandleRead() {
// `redisAsyncHandleRead` will mutate `redis_async_context_`, so `mutex_` is held
// for the whole call to protect it.
// This function will execute the callbacks which are registered by
// `redisvAsyncCommand`, `redisAsyncCommandArgv` and so on.
// NOTE(review): those callbacks therefore appear to run while `mutex_` is held;
// confirm that none of them re-enters this context and tries to take it again.
std::lock_guard<std::mutex> lock(mutex_);
redisAsyncHandleRead(redis_async_context_);
}

View file

@ -21,20 +21,20 @@ namespace {
/// A helper function to call the callback and delete it from the callback
/// manager if necessary.
void ProcessCallback(int64_t callback_index,
const ray::gcs::CallbackReply &callback_reply) {
std::shared_ptr<ray::gcs::CallbackReply> callback_reply) {
RAY_CHECK(callback_index >= 0) << "The callback index must be greater than 0, "
<< "but it actually is " << callback_index;
auto callback_item = ray::gcs::RedisCallbackManager::instance().get(callback_index);
if (!callback_item.is_subscription) {
if (!callback_item->is_subscription_) {
// Record the redis latency for non-subscription redis operations.
auto end_time = absl::GetCurrentTimeNanos() / 1000;
ray::stats::RedisLatency().Record(end_time - callback_item.start_time);
ray::stats::RedisLatency().Record(end_time - callback_item->start_time_);
}
// Invoke the callback.
if (callback_item.callback != nullptr) {
callback_item.callback(callback_reply);
}
if (!callback_item.is_subscription) {
// Dispatch the callback.
callback_item->Dispatch(callback_reply);
if (!callback_item->is_subscription_) {
// Delete the callback if it's not a subscription callback.
ray::gcs::RedisCallbackManager::instance().remove(callback_index);
}
@ -46,79 +46,77 @@ namespace ray {
namespace gcs {
// Construct a CallbackReply from a hiredis reply. The switch below copies the
// payload into the member matching the reply type (`int_reply_`,
// `status_reply_` or `string_reply_`), and `reply_type_` records which one is
// meaningful.
// NOTE(review): two constructor signatures and an assignment to `redis_reply_`
// appear together in this span; this looks like the pre- and post-change lines
// of a diff interleaved. Confirm which set is current before relying on it.
CallbackReply::CallbackReply(redisReply *redis_reply) {
// NOTE(review): this member initializer reads `redis_reply->type` before the
// nullptr check on the next line runs, so a null `redis_reply` is dereferenced
// before the check can fire.
CallbackReply::CallbackReply(redisReply *redis_reply) : reply_type_(redis_reply->type) {
RAY_CHECK(nullptr != redis_reply);
RAY_CHECK(redis_reply->type != REDIS_REPLY_ERROR)
<< "Got an error in redis reply: " << redis_reply->str;
this->redis_reply_ = redis_reply;
switch (reply_type_) {
case REDIS_REPLY_NIL: {
// A nil reply carries no payload; only `reply_type_` is recorded.
break;
}
case REDIS_REPLY_ERROR: {
// Error replies are rejected via RAY_CHECK(false), with the error text logged.
RAY_CHECK(false) << "Got an error in redis reply: " << redis_reply->str;
break;
}
case REDIS_REPLY_INTEGER: {
int_reply_ = static_cast<int64_t>(redis_reply->integer);
break;
}
case REDIS_REPLY_STATUS: {
// "OK" maps to Status::OK(); any other status text is wrapped in a RedisError.
const std::string status_str(redis_reply->str, redis_reply->len);
if (status_str == "OK") {
status_reply_ = Status::OK();
} else {
status_reply_ = Status::RedisError(status_str);
}
break;
}
case REDIS_REPLY_STRING: {
// Copy using the explicit length rather than relying on NUL termination.
string_reply_ = std::string(redis_reply->str, redis_reply->len);
break;
}
case REDIS_REPLY_ARRAY: {
// Array replies are only used for pub-sub messages. Parse the published message.
// NOTE(review): `element[0]` is read without checking `redis_reply->elements`;
// presumably pub-sub arrays are never empty, confirm.
redisReply *message_type = redis_reply->element[0];
if (strcmp(message_type->str, "subscribe") == 0) {
// If the message is for the initial subscription call, return the empty
// string as a response to signify that subscription was successful.
// (`string_reply_` is not assigned on this path.)
} else if (strcmp(message_type->str, "message") == 0) {
// If the message is from a PUBLISH, make sure the data is nonempty.
// The payload is taken from the last element of the array.
redisReply *message = redis_reply->element[redis_reply->elements - 1];
// data is a notification message.
string_reply_ = std::string(message->str, message->len);
RAY_CHECK(!string_reply_.empty()) << "Empty message received on subscribe channel.";
} else {
RAY_LOG(FATAL) << "This is not a pubsub reply: data=" << message_type->str;
}
break;
}
default: {
// Any other reply type is only logged; no payload member is populated.
RAY_LOG(WARNING) << "Encountered unexpected redis reply type: " << reply_type_;
}
}
}
bool CallbackReply::IsNil() const { return REDIS_REPLY_NIL == redis_reply_->type; }
bool CallbackReply::IsNil() const { return REDIS_REPLY_NIL == reply_type_; }
int64_t CallbackReply::ReadAsInteger() const {
RAY_CHECK(REDIS_REPLY_INTEGER == redis_reply_->type)
<< "Unexpected type: " << redis_reply_->type;
return static_cast<int64_t>(redis_reply_->integer);
}
std::string CallbackReply::ReadAsString() const {
RAY_CHECK(REDIS_REPLY_STRING == redis_reply_->type)
<< "Unexpected type: " << redis_reply_->type;
return std::string(redis_reply_->str, redis_reply_->len);
RAY_CHECK(reply_type_ == REDIS_REPLY_INTEGER) << "Unexpected type: " << reply_type_;
return int_reply_;
}
Status CallbackReply::ReadAsStatus() const {
RAY_CHECK(REDIS_REPLY_STATUS == redis_reply_->type)
<< "Unexpected type: " << redis_reply_->type;
const std::string status_str(redis_reply_->str, redis_reply_->len);
if ("OK" == status_str) {
return Status::OK();
}
RAY_CHECK(reply_type_ == REDIS_REPLY_STATUS) << "Unexpected type: " << reply_type_;
return status_reply_;
}
return Status::RedisError(status_str);
std::string CallbackReply::ReadAsString() const {
RAY_CHECK(reply_type_ == REDIS_REPLY_STRING) << "Unexpected type: " << reply_type_;
return string_reply_;
}
std::string CallbackReply::ReadAsPubsubData() const {
RAY_CHECK(REDIS_REPLY_ARRAY == redis_reply_->type)
<< "Unexpected type: " << redis_reply_->type;
std::string data = "";
// Parse the published message.
redisReply *message_type = redis_reply_->element[0];
if (strcmp(message_type->str, "subscribe") == 0) {
// If the message is for the initial subscription call, return the empty
// string as a response to signify that subscription was successful.
} else if (strcmp(message_type->str, "message") == 0) {
// If the message is from a PUBLISH, make sure the data is nonempty.
redisReply *message = redis_reply_->element[redis_reply_->elements - 1];
// data is a notification message.
data = std::string(message->str, message->len);
RAY_CHECK(!data.empty()) << "Empty message received on subscribe channel.";
} else {
RAY_LOG(FATAL) << "This is not a pubsub reply: data=" << message_type->str;
}
return data;
}
void CallbackReply::ReadAsStringArray(std::vector<std::string> *array) const {
RAY_CHECK(nullptr != array) << "Argument `array` must not be nullptr.";
RAY_CHECK(REDIS_REPLY_ARRAY == redis_reply_->type);
const auto array_size = static_cast<size_t>(redis_reply_->elements);
if (array_size > 0) {
auto *entry = redis_reply_->element[0];
const bool is_pubsub_reply =
strcmp(entry->str, "subscribe") == 0 || strcmp(entry->str, "message") == 0;
RAY_CHECK(!is_pubsub_reply) << "Subpub reply cannot be read as a string array.";
}
array->reserve(array_size);
for (size_t i = 0; i < array_size; ++i) {
auto *entry = redis_reply_->element[i];
RAY_CHECK(REDIS_REPLY_STRING == entry->type) << "Unexcepted type: " << entry->type;
array->push_back(std::string(entry->str, entry->len));
}
RAY_CHECK(reply_type_ == REDIS_REPLY_ARRAY) << "Unexpected type: " << reply_type_;
return string_reply_;
}
// This is a global redis callback which will be registered for every
@ -130,19 +128,22 @@ void GlobalRedisCallback(void *c, void *r, void *privdata) {
}
int64_t callback_index = reinterpret_cast<int64_t>(privdata);
redisReply *reply = reinterpret_cast<redisReply *>(r);
ProcessCallback(callback_index, CallbackReply(reply));
ProcessCallback(callback_index, std::make_shared<CallbackReply>(reply));
}
int64_t RedisCallbackManager::add(const RedisCallback &function, bool is_subscription) {
int64_t RedisCallbackManager::add(const RedisCallback &function, bool is_subscription,
boost::asio::io_service &io_service) {
auto start_time = absl::GetCurrentTimeNanos() / 1000;
std::lock_guard<std::mutex> lock(mutex_);
callback_items_.emplace(num_callbacks_,
CallbackItem(function, is_subscription, start_time));
callback_items_.emplace(
num_callbacks_,
std::make_shared<CallbackItem>(function, is_subscription, start_time, io_service));
return num_callbacks_++;
}
RedisCallbackManager::CallbackItem &RedisCallbackManager::get(int64_t callback_index) {
std::shared_ptr<RedisCallbackManager::CallbackItem> RedisCallbackManager::get(
int64_t callback_index) {
std::lock_guard<std::mutex> lock(mutex_);
RAY_CHECK(callback_items_.find(callback_index) != callback_items_.end());
return callback_items_[callback_index];
@ -280,7 +281,8 @@ Status RedisContext::SubscribeAsync(const ClientID &client_id,
<< "Client requested subscribe on a table that does not support pubsub";
RAY_CHECK(async_redis_subscribe_context_);
int64_t callback_index = RedisCallbackManager::instance().add(redisCallback, true);
int64_t callback_index =
RedisCallbackManager::instance().add(redisCallback, true, io_service_);
RAY_CHECK(out_callback_index != nullptr);
*out_callback_index = callback_index;
Status status = Status::OK();

View file

@ -1,6 +1,8 @@
#ifndef RAY_GCS_REDIS_CONTEXT_H
#define RAY_GCS_REDIS_CONTEXT_H
#include <boost/asio.hpp>
#include <boost/bind.hpp>
#include <functional>
#include <memory>
#include <mutex>
@ -41,31 +43,36 @@ class CallbackReply {
/// Read this reply data as an integer.
int64_t ReadAsInteger() const;
/// Read this reply data as a status.
Status ReadAsStatus() const;
/// Read this reply data as a string.
///
/// Note that the reply must be of type `REDIS_REPLY_STRING`; the
/// implementation enforces this with `RAY_CHECK`.
std::string ReadAsString() const;
/// Read this reply data as a status.
Status ReadAsStatus() const;
/// Read this reply data as a pub-sub data.
/// Read this reply data as pub-sub data.
std::string ReadAsPubsubData() const;
/// Read this reply data as a string array.
///
/// \param array Since the return-value may be large,
/// make it as an output parameter.
void ReadAsStringArray(std::vector<std::string> *array) const;
private:
redisReply *redis_reply_;
/// Flag indicating the type of reply this represents.
int reply_type_;
/// Reply data if reply_type_ is REDIS_REPLY_INTEGER.
int64_t int_reply_;
/// Reply data if reply_type_ is REDIS_REPLY_STATUS.
Status status_reply_;
/// Reply data if reply_type_ is REDIS_REPLY_STRING or REDIS_REPLY_ARRAY.
/// Note that REDIS_REPLY_ARRAY is only used for pub-sub data.
std::string string_reply_;
};
/// Every callback should take in a vector of the results from the Redis
/// operation.
using RedisCallback = std::function<void(const CallbackReply &)>;
using RedisCallback = std::function<void(std::shared_ptr<CallbackReply>)>;
void GlobalRedisCallback(void *c, void *r, void *privdata);
@ -76,24 +83,33 @@ class RedisCallbackManager {
return instance;
}
struct CallbackItem {
struct CallbackItem : public std::enable_shared_from_this<CallbackItem> {
CallbackItem() = default;
CallbackItem(const RedisCallback &callback, bool is_subscription,
int64_t start_time) {
this->callback = callback;
this->is_subscription = is_subscription;
this->start_time = start_time;
CallbackItem(const RedisCallback &callback, bool is_subscription, int64_t start_time,
boost::asio::io_service &io_service)
: callback_(callback),
is_subscription_(is_subscription),
start_time_(start_time),
io_service_(io_service) {}
/// Hand the stored callback, bound to `reply`, to `io_service_`.
///
/// The callback is not invoked inline; it is wrapped in a handler and passed
/// to `io_service_.post`. Nothing is posted when `callback_` is null.
///
/// \param reply The reply to pass to the callback. It is only copied, never
/// modified, so it is taken by const reference.
void Dispatch(const std::shared_ptr<CallbackReply> &reply) {
  if (callback_ == nullptr) {
    return;
  }
  // The handler holds a shared_ptr to this item so that the item (and its
  // `callback_`) stays alive until the handler has run.
  std::shared_ptr<CallbackItem> self = shared_from_this();
  // `reply` is captured by copy in a non-mutable lambda and is therefore const
  // inside the handler; std::move on it would still copy, so pass it as is.
  io_service_.post([self, reply]() { self->callback_(reply); });
}
RedisCallback callback;
bool is_subscription;
int64_t start_time;
RedisCallback callback_;
bool is_subscription_;
int64_t start_time_;
boost::asio::io_service &io_service_;
};
int64_t add(const RedisCallback &function, bool is_subscription);
int64_t add(const RedisCallback &function, bool is_subscription,
boost::asio::io_service &io_service);
CallbackItem &get(int64_t callback_index);
std::shared_ptr<CallbackItem> get(int64_t callback_index);
/// Remove a callback.
void remove(int64_t callback_index);
@ -106,12 +122,13 @@ class RedisCallbackManager {
std::mutex mutex_;
int64_t num_callbacks_ = 0;
std::unordered_map<int64_t, CallbackItem> callback_items_;
std::unordered_map<int64_t, std::shared_ptr<CallbackItem>> callback_items_;
};
class RedisContext {
public:
RedisContext() : context_(nullptr) {}
RedisContext(boost::asio::io_service &io_service)
: io_service_(io_service), context_(nullptr) {}
~RedisContext();
@ -170,6 +187,7 @@ class RedisContext {
}
private:
boost::asio::io_service &io_service_;
redisContext *context_;
std::unique_ptr<RedisAsyncContext> redis_async_context_;
std::unique_ptr<RedisAsyncContext> async_redis_subscribe_context_;
@ -181,7 +199,8 @@ Status RedisContext::RunAsync(const std::string &command, const ID &id, const vo
const TablePubsub pubsub_channel,
RedisCallback redisCallback, int log_length) {
RAY_CHECK(redis_async_context_);
int64_t callback_index = RedisCallbackManager::instance().add(redisCallback, false);
int64_t callback_index =
RedisCallbackManager::instance().add(redisCallback, false, io_service_);
Status status = Status::OK();
if (length > 0) {
if (log_length >= 0) {

View file

@ -83,7 +83,7 @@ Status RedisGcsClient::Connect(boost::asio::io_service &io_service) {
return Status::Invalid("gcs service address is invalid!");
}
primary_context_ = std::make_shared<RedisContext>();
primary_context_ = std::make_shared<RedisContext>(io_service);
RAY_CHECK_OK(primary_context_->Connect(options_.server_ip_, options_.server_port_,
/*sharding=*/true,
@ -103,12 +103,12 @@ Status RedisGcsClient::Connect(boost::asio::io_service &io_service) {
for (size_t i = 0; i < addresses.size(); ++i) {
// Populate shard_contexts.
shard_contexts_.push_back(std::make_shared<RedisContext>());
shard_contexts_.push_back(std::make_shared<RedisContext>(io_service));
RAY_CHECK_OK(shard_contexts_[i]->Connect(addresses[i], ports[i], /*sharding=*/true,
/*password=*/options_.password_));
}
} else {
shard_contexts_.push_back(std::make_shared<RedisContext>());
shard_contexts_.push_back(std::make_shared<RedisContext>(io_service));
RAY_CHECK_OK(shard_contexts_[0]->Connect(options_.server_ip_, options_.server_port_,
/*sharding=*/true,
/*password=*/options_.password_));

View file

@ -44,8 +44,8 @@ Status Log<ID, Data>::Append(const JobID &job_id, const ID &id,
const std::shared_ptr<Data> &data,
const WriteCallback &done) {
num_appends_++;
auto callback = [this, id, data, done](const CallbackReply &reply) {
const auto status = reply.ReadAsStatus();
auto callback = [this, id, data, done](std::shared_ptr<CallbackReply> reply) {
const auto status = reply->ReadAsStatus();
// Failed to append the entry.
RAY_CHECK(status.ok()) << "Failed to execute command TABLE_APPEND:"
<< status.ToString();
@ -65,8 +65,8 @@ Status Log<ID, Data>::AppendAt(const JobID &job_id, const ID &id,
const WriteCallback &done, const WriteCallback &failure,
int log_length) {
num_appends_++;
auto callback = [this, id, data, done, failure](const CallbackReply &reply) {
const auto status = reply.ReadAsStatus();
auto callback = [this, id, data, done, failure](std::shared_ptr<CallbackReply> reply) {
const auto status = reply->ReadAsStatus();
if (status.ok()) {
if (done != nullptr) {
(done)(client_, id, *data);
@ -86,12 +86,12 @@ Status Log<ID, Data>::AppendAt(const JobID &job_id, const ID &id,
template <typename ID, typename Data>
Status Log<ID, Data>::Lookup(const JobID &job_id, const ID &id, const Callback &lookup) {
num_lookups_++;
auto callback = [this, id, lookup](const CallbackReply &reply) {
auto callback = [this, id, lookup](std::shared_ptr<CallbackReply> reply) {
if (lookup != nullptr) {
std::vector<Data> results;
if (!reply.IsNil()) {
if (!reply->IsNil()) {
GcsEntry gcs_entry;
gcs_entry.ParseFromString(reply.ReadAsString());
gcs_entry.ParseFromString(reply->ReadAsString());
RAY_CHECK(ID::FromBinary(gcs_entry.id()) == id);
for (int64_t i = 0; i < gcs_entry.entries_size(); i++) {
Data data;
@ -126,8 +126,8 @@ Status Log<ID, Data>::Subscribe(const JobID &job_id, const ClientID &client_id,
const SubscriptionCallback &done) {
RAY_CHECK(subscribe_callback_index_ == -1)
<< "Client called Subscribe twice on the same table";
auto callback = [this, subscribe, done](const CallbackReply &reply) {
const auto data = reply.ReadAsPubsubData();
auto callback = [this, subscribe, done](std::shared_ptr<CallbackReply> reply) {
const auto data = reply->ReadAsPubsubData();
if (data.empty()) {
// No notification data is provided. This is the callback for the
@ -170,8 +170,8 @@ Status Log<ID, Data>::RequestNotifications(const JobID &job_id, const ID &id,
RedisCallback callback = nullptr;
if (done != nullptr) {
callback = [done](const CallbackReply &reply) {
const auto status = reply.IsNil()
callback = [done](std::shared_ptr<CallbackReply> reply) {
const auto status = reply->IsNil()
? Status::OK()
: Status::RedisError("request notifications failed.");
done(status);
@ -192,8 +192,8 @@ Status Log<ID, Data>::CancelNotifications(const JobID &job_id, const ID &id,
RedisCallback callback = nullptr;
if (done != nullptr) {
callback = [done](const CallbackReply &reply) {
const auto status = reply.ReadAsStatus();
callback = [done](std::shared_ptr<CallbackReply> reply) {
const auto status = reply->ReadAsStatus();
done(status);
};
}
@ -254,7 +254,7 @@ Status Table<ID, Data>::Add(const JobID &job_id, const ID &id,
const std::shared_ptr<Data> &data,
const WriteCallback &done) {
num_adds_++;
auto callback = [this, id, data, done](const CallbackReply &reply) {
auto callback = [this, id, data, done](std::shared_ptr<CallbackReply> reply) {
if (done != nullptr) {
(done)(client_, id, *data);
}
@ -317,7 +317,7 @@ template <typename ID, typename Data>
Status Set<ID, Data>::Add(const JobID &job_id, const ID &id,
const std::shared_ptr<Data> &data, const WriteCallback &done) {
num_adds_++;
auto callback = [this, id, data, done](const CallbackReply &reply) {
auto callback = [this, id, data, done](std::shared_ptr<CallbackReply> reply) {
if (done != nullptr) {
(done)(client_, id, *data);
}
@ -332,7 +332,7 @@ Status Set<ID, Data>::Remove(const JobID &job_id, const ID &id,
const std::shared_ptr<Data> &data,
const WriteCallback &done) {
num_removes_++;
auto callback = [this, id, data, done](const CallbackReply &reply) {
auto callback = [this, id, data, done](std::shared_ptr<CallbackReply> reply) {
if (done != nullptr) {
(done)(client_, id, *data);
}
@ -354,7 +354,7 @@ template <typename ID, typename Data>
Status Hash<ID, Data>::Update(const JobID &job_id, const ID &id, const DataMap &data_map,
const HashCallback &done) {
num_adds_++;
auto callback = [this, id, data_map, done](const CallbackReply &reply) {
auto callback = [this, id, data_map, done](std::shared_ptr<CallbackReply> reply) {
if (done != nullptr) {
(done)(client_, id, data_map);
}
@ -376,7 +376,8 @@ Status Hash<ID, Data>::RemoveEntries(const JobID &job_id, const ID &id,
const std::vector<std::string> &keys,
const HashRemoveCallback &remove_callback) {
num_removes_++;
auto callback = [this, id, keys, remove_callback](const CallbackReply &reply) {
auto callback = [this, id, keys,
remove_callback](std::shared_ptr<CallbackReply> reply) {
if (remove_callback != nullptr) {
(remove_callback)(client_, id, keys);
}
@ -404,13 +405,13 @@ template <typename ID, typename Data>
Status Hash<ID, Data>::Lookup(const JobID &job_id, const ID &id,
const HashCallback &lookup) {
num_lookups_++;
auto callback = [this, id, lookup](const CallbackReply &reply) {
auto callback = [this, id, lookup](std::shared_ptr<CallbackReply> reply) {
if (lookup != nullptr) {
DataMap results;
if (!reply.IsNil()) {
const auto data = reply.ReadAsString();
if (!reply->IsNil()) {
const auto data = reply->ReadAsString();
GcsEntry gcs_entry;
gcs_entry.ParseFromString(reply.ReadAsString());
gcs_entry.ParseFromString(reply->ReadAsString());
RAY_CHECK(ID::FromBinary(gcs_entry.id()) == id);
RAY_CHECK(gcs_entry.entries_size() % 2 == 0);
for (int i = 0; i < gcs_entry.entries_size(); i += 2) {
@ -434,8 +435,8 @@ Status Hash<ID, Data>::Subscribe(const JobID &job_id, const ClientID &client_id,
const SubscriptionCallback &done) {
RAY_CHECK(subscribe_callback_index_ == -1)
<< "Client called Subscribe twice on the same table";
auto callback = [this, subscribe, done](const CallbackReply &reply) {
const auto data = reply.ReadAsPubsubData();
auto callback = [this, subscribe, done](std::shared_ptr<CallbackReply> reply) {
const auto data = reply->ReadAsPubsubData();
if (data.empty()) {
// No notification data is provided. This is the callback for the
// initial subscription request.