Request and cancel notifications in the new GCS API (#1758)

* Add TableRequestNotifications and TableCancelNotifications to Redis modules

* Add RequestNotifications and CancelNotifications to generic GCS Table

* Add tests for subscribing to specific keys

* Remove TODO!

* Return the current value at the key directly from RequestNotifications instead of through publish

* Add unit test for Lookup failure callback

* Modify tests to account for empty subscription response

* Remove ObjectTable notification methods

* Clean up message parsing and doc in redis context

* Use vectors of DataT in all GCS callbacks

* Clean up SubscriptionCallback

* Move Table definitions into tables.cc

* Refactor and document redis modules

* doc

* Fix new GCS build

* Cleanups

* Revert "Fix new GCS build"

This reverts commit 6e3e69090c67ef60aaf22a9cf62be0290d989e96.

* Use vectors for internal callback interface, user-facing interface takes a reference to a single item

* Fix new GCS build

* Add unit test for Lookup failure callback

* Fix compiler errors

* Cleanup

* Publish the entry ID with the notification

* Check that the ID for a notification matches in client tests
This commit is contained in:
Stephanie Wang 2018-03-22 10:31:07 -07:00 committed by GitHub
parent 0c835a379f
commit 8704c8618c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 822 additions and 297 deletions

View file

@ -53,6 +53,32 @@ static const char *table_prefixes[] = {
NULL, "TASK:", "CLIENT:", "OBJECT:", "FUNCTION:",
};
/// Parse a Redis string into a TablePubsub channel.
///
/// \param pubsub_channel_str A Redis string holding the integer value of a
///        TablePubsub enum member.
/// \return The parsed channel. A RAY_CHECK fails if the string is not an
///         integer or the integer lies outside
///         [TablePubsub_MIN, TablePubsub_MAX].
TablePubsub ParseTablePubsub(const RedisModuleString *pubsub_channel_str) {
  long long pubsub_channel_long;
  RAY_CHECK(RedisModule_StringToLongLong(
                pubsub_channel_str, &pubsub_channel_long) == REDISMODULE_OK)
      << "Pubsub channel must be a valid TablePubsub";
  // Validate the range on the wide integer *before* narrowing it to the enum.
  // Casting first would let a large value wrap into the valid range (or yield
  // an unspecified enum value) and slip past the check.
  RAY_CHECK(pubsub_channel_long >= TablePubsub_MIN &&
            pubsub_channel_long <= TablePubsub_MAX)
      << "Pubsub channel must be a valid TablePubsub";
  return static_cast<TablePubsub>(pubsub_channel_long);
}
/// Format a pubsub channel for a specific key.
///
/// \param ctx The Redis module context used to allocate the result.
/// \param pubsub_channel_str A Redis string that must contain a valid
///        TablePubsub value; it is validated with ParseTablePubsub.
/// \param id The ID to append to the channel.
/// \return A newly created Redis string of the form "<pubsub_channel>:<id>".
///         Callers in this file release it with RedisModule_FreeString.
RedisModuleString *FormatPubsubChannel(
    RedisModuleCtx *ctx,
    const RedisModuleString *pubsub_channel_str,
    const RedisModuleString *id) {
  // Let the standard library size the decimal representation instead of
  // writing into a fixed stack buffer whose length was derived from the enum's
  // largest *value* rather than from its number of digits.
  const std::string pubsub_channel =
      std::to_string(static_cast<int>(ParseTablePubsub(pubsub_channel_str)));
  return RedisString_Format(ctx, "%s:%S", pubsub_channel.c_str(), id);
}
// TODO(swang): This helper function should be deprecated by the version below,
// which uses enums for table prefixes.
RedisModuleKey *OpenPrefixedKey(RedisModuleCtx *ctx,
@ -83,6 +109,23 @@ RedisModuleKey *OpenPrefixedKey(RedisModuleCtx *ctx,
return OpenPrefixedKey(ctx, table_prefixes[prefix], keyname, mode);
}
/// Open the key that stores the set of channels to publish to whenever the
/// entry at `keyname` is updated.
///
/// \param ctx The Redis module context.
/// \param pubsub_channel_str A Redis string containing a TablePubsub value.
/// \param keyname The ID of the table entry being watched.
/// \param mode The mode flags forwarded to RedisModule_OpenKey.
/// \return The opened key, named "BCAST:<pubsub_channel>:<keyname>".
RedisModuleKey *OpenBroadcastKey(RedisModuleCtx *ctx,
                                 RedisModuleString *pubsub_channel_str,
                                 RedisModuleString *keyname,
                                 int mode) {
  // Build "<pubsub_channel>:<keyname>", then prefix it to get the key name.
  RedisModuleString *key_channel =
      FormatPubsubChannel(ctx, pubsub_channel_str, keyname);
  RedisModuleString *broadcast_keyname =
      RedisString_Format(ctx, "BCAST:%S", key_channel);

  auto *broadcast_key = static_cast<RedisModuleKey *>(
      RedisModule_OpenKey(ctx, broadcast_keyname, mode));

  // The temporary name strings are no longer needed once the key is open.
  RedisModule_FreeString(ctx, broadcast_keyname);
  RedisModule_FreeString(ctx, key_channel);
  return broadcast_key;
}
/**
* This is a helper method to convert a redis module string to a flatbuffer
* string.
@ -411,8 +454,181 @@ bool PublishObjectNotification(RedisModuleCtx *ctx,
return true;
}
// This is a temporary redis command that will be removed once
// NOTE(pcmoritz): This is a temporary redis command that will be removed once
// the GCS uses https://github.com/pcmoritz/credis.
/// Publish a task table entry to the subscriber for its scheduler and state.
/// Nothing is published unless the task is WAITING or SCHEDULED.
///
/// \param ctx The Redis module context.
/// \param id The ID of the task.
/// \param data A serialized TaskTableData flatbuffer.
/// \return OK on success, or an error reply if the PUBLISH call fails.
int TaskTableAdd(RedisModuleCtx *ctx,
                 RedisModuleString *id,
                 RedisModuleString *data) {
  const char *buf = RedisModule_StringPtrLen(data, NULL);
  auto message = flatbuffers::GetRoot<TaskTableData>(buf);

  if (message->scheduling_state() != SchedulingState_WAITING &&
      message->scheduling_state() != SchedulingState_SCHEDULED) {
    // No subscribers are notified for any other scheduling state.
    return RedisModule_ReplyWithSimpleString(ctx, "OK");
  }

  /* Build the PUBLISH topic and message for task table subscribers. The topic
   * is a string in the format "TASK_PREFIX:<local scheduler ID>:<state>". The
   * message is a serialized SubscribeToTasksReply flatbuffer object. */
  std::string state = std::to_string(message->scheduling_state());
  RedisModuleString *publish_topic = RedisString_Format(
      ctx, "%s%b:%s", TASK_PREFIX, message->scheduler_id()->str().data(),
      sizeof(DBClientID), state.c_str());

  /* Construct the flatbuffers object for the payload. */
  flatbuffers::FlatBufferBuilder fbb;
  auto msg = CreateTaskReply(
      fbb, RedisStringToFlatbuf(fbb, id), message->scheduling_state(),
      fbb.CreateString(message->scheduler_id()),
      fbb.CreateString(message->execution_dependencies()),
      fbb.CreateString(message->task_info()), message->spillback_count(),
      true /* not used */);
  fbb.Finish(msg);

  RedisModuleString *publish_message = RedisModule_CreateString(
      ctx, (const char *) fbb.GetBufferPointer(), fbb.GetSize());
  RedisModuleCallReply *reply =
      RedisModule_Call(ctx, "PUBLISH", "ss", publish_topic, publish_message);

  // The temporaries are not needed once the call has been issued; release
  // them here so that every exit path below is leak-free.
  RedisModule_FreeString(ctx, publish_message);
  RedisModule_FreeString(ctx, publish_topic);

  // A failed call yields a null reply, which must not be inspected. Report it
  // the same way the other publish helpers in this file do.
  if (reply == NULL) {
    return RedisModule_ReplyWithError(ctx, "error during PUBLISH");
  }

  /* See how many clients received this publish. */
  long long num_clients = RedisModule_CallReplyInteger(reply);
  RAY_CHECK(num_clients <= 1) << "Published to " << num_clients
                              << " clients.";
  return RedisModule_ReplyWithSimpleString(ctx, "OK");
}
// TODO(swang): Implement the client table as an append-only log so that we
// don't need this special case for client table publication.
/// Append a client table notification to the stored log of notifications and
/// publish it. If the entry is an insertion, every previously stored
/// notification is re-published first, in order.
///
/// \param ctx The Redis module context.
/// \param pubsub_channel_str The channel to publish on. The same string is
///        also used as the name of the key that stores past notifications.
/// \param data A serialized ClientTableData flatbuffer.
/// \return OK on success, or an error reply if a PUBLISH call fails.
int ClientTableAdd(RedisModuleCtx *ctx,
                   RedisModuleString *pubsub_channel_str,
                   RedisModuleString *data) {
  const char *buf = RedisModule_StringPtrLen(data, NULL);
  auto client_data = flatbuffers::GetRoot<ClientTableData>(buf);

  RedisModuleKey *clients_key = (RedisModuleKey *) RedisModule_OpenKey(
      ctx, pubsub_channel_str, REDISMODULE_READ | REDISMODULE_WRITE);
  // If this is a client addition, send all previous notifications, in order.
  // NOTE(swang): This will go to all clients, so some clients will get
  // duplicate notifications.
  if (client_data->is_insertion() &&
      RedisModule_KeyType(clients_key) != REDISMODULE_KEYTYPE_EMPTY) {
    // NOTE(swang): Sets are not implemented yet, so we use ZSETs instead.
    // NOTE(review): CHECK_ERROR is defined outside this view; if it returns
    // early, clients_key is not closed on that path -- confirm.
    CHECK_ERROR(RedisModule_ZsetFirstInScoreRange(
                    clients_key, REDISMODULE_NEGATIVE_INFINITE,
                    REDISMODULE_POSITIVE_INFINITE, 1, 1),
                "Unable to initialize zset iterator");
    do {
      RedisModuleString *message =
          RedisModule_ZsetRangeCurrentElement(clients_key, NULL);
      RedisModuleCallReply *reply =
          RedisModule_Call(ctx, "PUBLISH", "ss", pubsub_channel_str, message);
      // The iterator hands back a freshly created string for every element;
      // release it so replaying the log does not leak one string per entry.
      RedisModule_FreeString(ctx, message);
      if (reply == NULL) {
        RedisModule_CloseKey(clients_key);
        return RedisModule_ReplyWithError(ctx, "error during PUBLISH");
      }
    } while (RedisModule_ZsetRangeNext(clients_key));
  }

  // Append this notification to the past notifications so that it will get
  // sent to new clients in the future. The current length of the log is used
  // as the score so that entries stay in insertion order.
  size_t index = RedisModule_ValueLength(clients_key);
  // Serialize the notification to send. The ID field is left empty.
  flatbuffers::FlatBufferBuilder fbb;
  auto message = CreateGcsNotification(fbb, fbb.CreateString(""),
                                       RedisStringToFlatbuf(fbb, data));
  fbb.Finish(message);
  auto notification = RedisModule_CreateString(
      ctx, reinterpret_cast<const char *>(fbb.GetBufferPointer()),
      fbb.GetSize());
  RedisModule_ZsetAdd(clients_key, index, notification, NULL);

  // Publish the notification about this client.
  RedisModuleCallReply *reply =
      RedisModule_Call(ctx, "PUBLISH", "ss", pubsub_channel_str, notification);
  RedisModule_FreeString(ctx, notification);
  RedisModule_CloseKey(clients_key);
  if (reply == NULL) {
    return RedisModule_ReplyWithError(ctx, "error during PUBLISH");
  }
  return RedisModule_ReplyWithSimpleString(ctx, "OK");
}
/// Publish a notification for a new entry at a key. This publishes a
/// notification to all subscribers of the table, as well as every client that
/// has requested notifications for this key.
///
/// \param ctx The Redis module context.
/// \param pubsub_channel_str The pubsub channel name that notifications for
///        this key should be published to. Per-client channels, of the form
///        <pubsub_channel>:<client_id>, are read from the broadcast key for
///        this ID.
/// \param id The ID of the key that the notification is about.
/// \param data The data to publish.
/// \return OK if there is no error during a publish.
int PublishTableAdd(RedisModuleCtx *ctx,
                    RedisModuleString *pubsub_channel_str,
                    RedisModuleString *id,
                    RedisModuleString *data) {
  // Serialize the notification to send.
  flatbuffers::FlatBufferBuilder fbb;
  auto message = CreateGcsNotification(fbb, RedisStringToFlatbuf(fbb, id),
                                       RedisStringToFlatbuf(fbb, data));
  fbb.Finish(message);

  // Write the data back to any subscribers that are listening to all table
  // notifications.
  RedisModuleCallReply *table_reply =
      RedisModule_Call(ctx, "PUBLISH", "sb", pubsub_channel_str,
                       fbb.GetBufferPointer(), fbb.GetSize());
  if (table_reply == NULL) {
    return RedisModule_ReplyWithError(ctx, "error during PUBLISH");
  }

  // Publish the data to any clients who requested notifications on this key.
  RedisModuleKey *notification_key = OpenBroadcastKey(
      ctx, pubsub_channel_str, id, REDISMODULE_READ | REDISMODULE_WRITE);
  if (RedisModule_KeyType(notification_key) != REDISMODULE_KEYTYPE_EMPTY) {
    // NOTE(swang): Sets are not implemented yet, so we use ZSETs instead.
    // NOTE(review): CHECK_ERROR is defined outside this view; if it returns
    // early, notification_key is not closed on that path -- confirm.
    CHECK_ERROR(RedisModule_ZsetFirstInScoreRange(
                    notification_key, REDISMODULE_NEGATIVE_INFINITE,
                    REDISMODULE_POSITIVE_INFINITE, 1, 1),
                "Unable to initialize zset iterator");
    for (; !RedisModule_ZsetRangeEndReached(notification_key);
         RedisModule_ZsetRangeNext(notification_key)) {
      RedisModuleString *client_channel =
          RedisModule_ZsetRangeCurrentElement(notification_key, NULL);
      RedisModuleCallReply *client_reply =
          RedisModule_Call(ctx, "PUBLISH", "sb", client_channel,
                           fbb.GetBufferPointer(), fbb.GetSize());
      // The iterator returns a freshly created string for each element;
      // release it so that each notified client does not leak a string.
      RedisModule_FreeString(ctx, client_channel);
      if (client_reply == NULL) {
        RedisModule_CloseKey(notification_key);
        return RedisModule_ReplyWithError(ctx, "error during PUBLISH");
      }
    }
  }
  RedisModule_CloseKey(notification_key);
  return RedisModule_ReplyWithSimpleString(ctx, "OK");
}
/// Add an entry at a key. This overwrites any existing data at the key.
/// Publishes a notification about the update to all subscribers, if a pubsub
/// channel is provided.
///
/// This is called from a client with the command:
///
/// RAY.TABLE_ADD <table_prefix> <pubsub_channel> <id> <data>
///
/// \param table_prefix The prefix string for keys in this table.
/// \param pubsub_channel The pubsub channel name that notifications for
/// this key should be published to. When publishing to a specific
/// client, the channel name should be <pubsub_channel>:<client_id>.
/// \param id The ID of the key to set.
/// \param data The data to insert at the key.
/// \return OK if the entry was written and there was no error during a
/// publish.
int TableAdd_RedisCommand(RedisModuleCtx *ctx,
RedisModuleString **argv,
int argc) {
@ -431,108 +647,22 @@ int TableAdd_RedisCommand(RedisModuleCtx *ctx,
RedisModule_StringSet(key, data);
RedisModule_CloseKey(key);
// Get the requested pubsub channel.
long long pubsub_channel_long;
RAY_CHECK(RedisModule_StringToLongLong(
pubsub_channel_str, &pubsub_channel_long) == REDISMODULE_OK)
<< "Pubsub channel must be a valid TablePubsub";
auto pubsub_channel = static_cast<TablePubsub>(pubsub_channel_long);
RAY_CHECK(pubsub_channel >= TablePubsub_MIN &&
pubsub_channel <= TablePubsub_MAX)
<< "Pubsub channel must be a valid TablePubsub";
// Publish a message on the requested pubsub channel if necessary.
TablePubsub pubsub_channel = ParseTablePubsub(pubsub_channel_str);
if (pubsub_channel == TablePubsub_TASK) {
const char *buf = RedisModule_StringPtrLen(data, NULL);
auto message = flatbuffers::GetRoot<TaskTableData>(buf);
if (message->scheduling_state() == SchedulingState_WAITING ||
message->scheduling_state() == SchedulingState_SCHEDULED) {
/* Build the PUBLISH topic and message for task table subscribers. The
* topic
* is a string in the format "TASK_PREFIX:<local scheduler ID>:<state>".
* The
* message is a serialized SubscribeToTasksReply flatbuffer object. */
std::string state = std::to_string(message->scheduling_state());
RedisModuleString *publish_topic = RedisString_Format(
ctx, "%s%b:%s", TASK_PREFIX, message->scheduler_id()->str().data(),
sizeof(DBClientID), state.c_str());
/* Construct the flatbuffers object for the payload. */
flatbuffers::FlatBufferBuilder fbb;
/* Create the flatbuffers message. */
auto msg = CreateTaskReply(
fbb, RedisStringToFlatbuf(fbb, id), message->scheduling_state(),
fbb.CreateString(message->scheduler_id()),
fbb.CreateString(message->execution_dependencies()),
fbb.CreateString(message->task_info()), message->spillback_count(),
true /* not used */);
fbb.Finish(msg);
RedisModuleString *publish_message = RedisModule_CreateString(
ctx, (const char *) fbb.GetBufferPointer(), fbb.GetSize());
RedisModuleCallReply *reply = RedisModule_Call(
ctx, "PUBLISH", "ss", publish_topic, publish_message);
/* See how many clients received this publish. */
long long num_clients = RedisModule_CallReplyInteger(reply);
RAY_CHECK(num_clients <= 1) << "Published to " << num_clients
<< " clients.";
RedisModule_FreeString(ctx, publish_message);
RedisModule_FreeString(ctx, publish_topic);
}
// Publish the task to its subscribers.
// TODO(swang): This is only necessary for legacy Ray and should be removed
// once we switch to using the new GCS API for the task table.
return TaskTableAdd(ctx, id, data);
} else if (pubsub_channel == TablePubsub_CLIENT) {
const char *buf = RedisModule_StringPtrLen(data, NULL);
auto client_data = flatbuffers::GetRoot<ClientTableData>(buf);
RedisModuleKey *clients_key = (RedisModuleKey *) RedisModule_OpenKey(
ctx, pubsub_channel_str, REDISMODULE_READ | REDISMODULE_WRITE);
// If this is a client addition, send all previous notifications, in order.
// NOTE(swang): This will go to all clients, so some clients will get
// duplicate notifications.
if (client_data->is_insertion() &&
RedisModule_KeyType(clients_key) != REDISMODULE_KEYTYPE_EMPTY) {
// NOTE(swang): Sets are not implemented yet, so we use ZSETs instead.
CHECK_ERROR(RedisModule_ZsetFirstInScoreRange(
clients_key, REDISMODULE_NEGATIVE_INFINITE,
REDISMODULE_POSITIVE_INFINITE, 1, 1),
"Unable to initialize zset iterator");
do {
RedisModuleString *message =
RedisModule_ZsetRangeCurrentElement(clients_key, NULL);
RedisModuleCallReply *reply =
RedisModule_Call(ctx, "PUBLISH", "ss", pubsub_channel_str, message);
if (reply == NULL) {
RedisModule_CloseKey(clients_key);
RedisModule_ReplyWithError(ctx, "error during PUBLISH");
}
} while (RedisModule_ZsetRangeNext(clients_key));
}
// Append this notification to the past notifications so that it will get
// sent to new clients in the future.
size_t index = RedisModule_ValueLength(key);
RedisModule_ZsetAdd(clients_key, index, data, NULL);
// Publish the notification about this client.
RedisModuleCallReply *reply =
RedisModule_Call(ctx, "PUBLISH", "ss", pubsub_channel_str, data);
if (reply == NULL) {
RedisModule_ReplyWithError(ctx, "error during PUBLISH");
}
RedisModule_CloseKey(clients_key);
// Publish all previous client table additions to the new client.
return ClientTableAdd(ctx, pubsub_channel_str, data);
} else if (pubsub_channel != TablePubsub_NO_PUBLISH) {
// All other pubsub channels write the data back directly onto the channel.
RedisModuleCallReply *reply =
RedisModule_Call(ctx, "PUBLISH", "ss", pubsub_channel_str, data);
if (reply == NULL) {
RedisModule_ReplyWithError(ctx, "error during PUBLISH");
}
return PublishTableAdd(ctx, pubsub_channel_str, id, data);
} else {
return RedisModule_ReplyWithSimpleString(ctx, "OK");
}
return RedisModule_ReplyWithSimpleString(ctx, "OK");
}
// This is a temporary redis command that will be removed once
@ -561,6 +691,114 @@ int TableLookup_RedisCommand(RedisModuleCtx *ctx,
return REDISMODULE_OK;
}
/// Request notifications for changes to a key. Returns the current value at
/// the key, if any. The requesting client's channel is added to the set of
/// channels that PublishTableAdd publishes to for this key.
///
/// This is called from a client with the command:
///
///    RAY.TABLE_REQUEST_NOTIFICATIONS <table_prefix> <pubsub_channel> <id>
///        <client_id>
///
/// \param table_prefix The prefix string for keys in this table.
/// \param pubsub_channel The pubsub channel name that notifications for
///        this key should be published to. The requesting client is notified
///        on the channel <pubsub_channel>:<client_id>.
/// \param id The ID of the key to publish notifications for.
/// \param client_id The ID of the client that is being notified.
/// \return A serialized GcsNotification holding the current value at the key,
///         or OK if there is no value.
int TableRequestNotifications_RedisCommand(RedisModuleCtx *ctx,
                                           RedisModuleString **argv,
                                           int argc) {
  if (argc != 5) {
    return RedisModule_WrongArity(ctx);
  }

  RedisModuleString *prefix_str = argv[1];
  RedisModuleString *pubsub_channel_str = argv[2];
  RedisModuleString *id = argv[3];
  RedisModuleString *client_id = argv[4];
  RedisModuleString *client_channel =
      FormatPubsubChannel(ctx, pubsub_channel_str, client_id);

  // Add this client to the set of clients that should be notified when there
  // are changes to the key. Every member gets the same score (0.0) because the
  // ZSET is only used as a set.
  RedisModuleKey *notification_key = OpenBroadcastKey(
      ctx, pubsub_channel_str, id, REDISMODULE_READ | REDISMODULE_WRITE);
  CHECK_ERROR(RedisModule_ZsetAdd(notification_key, 0.0, client_channel, NULL),
              "ZsetAdd failed.");
  RedisModule_CloseKey(notification_key);
  RedisModule_FreeString(ctx, client_channel);

  // Return the current value at the key, if any, to the client that requested
  // a notification.
  RedisModuleKey *table_key =
      OpenPrefixedKey(ctx, prefix_str, id, REDISMODULE_READ);
  if (table_key == nullptr) {
    // There is no entry yet, so there is no key handle to close either.
    RedisModule_ReplyWithSimpleString(ctx, "OK");
    return REDISMODULE_OK;
  }

  // Serialize the notification to send.
  size_t data_len = 0;
  char *data_buf =
      RedisModule_StringDMA(table_key, &data_len, REDISMODULE_READ);
  flatbuffers::FlatBufferBuilder fbb;
  auto message = CreateGcsNotification(fbb, RedisStringToFlatbuf(fbb, id),
                                       fbb.CreateString(data_buf, data_len));
  fbb.Finish(message);

  int result = RedisModule_ReplyWithStringBuffer(
      ctx, reinterpret_cast<const char *>(fbb.GetBufferPointer()),
      fbb.GetSize());
  // The DMA buffer belongs to the key, so close the key only after the reply
  // has been built from it.
  RedisModule_CloseKey(table_key);
  return result;
}
/// Cancel notifications for changes to a key. The client will no longer
/// receive notifications for this key.
///
/// This is called from a client with the command:
///
///    RAY.TABLE_CANCEL_NOTIFICATIONS <table_prefix> <pubsub_channel> <id>
///        <client_id>
///
/// \param table_prefix The prefix string for keys in this table.
/// \param pubsub_channel The pubsub channel name that notifications for
///        this key should be published to. The client's own channel is
///        <pubsub_channel>:<client_id>.
/// \param id The ID of the key to publish notifications for.
/// \param client_id The ID of the client that is being notified.
/// \return OK if the requesting client was removed.
///         NOTE(review): a missing subscription currently fails a RAY_CHECK
///         instead of producing an error reply -- confirm that is intended.
int TableCancelNotifications_RedisCommand(RedisModuleCtx *ctx,
                                          RedisModuleString **argv,
                                          int argc) {
  if (argc < 5) {
    return RedisModule_WrongArity(ctx);
  }

  RedisModuleString *pubsub_channel_str = argv[2];
  RedisModuleString *id = argv[3];
  RedisModuleString *client_id = argv[4];
  RedisModuleString *client_channel =
      FormatPubsubChannel(ctx, pubsub_channel_str, client_id);

  // Remove this client from the set of clients that should be notified when
  // there are changes to the key.
  RedisModuleKey *notification_key = OpenBroadcastKey(
      ctx, pubsub_channel_str, id, REDISMODULE_READ | REDISMODULE_WRITE);
  RAY_CHECK(RedisModule_KeyType(notification_key) != REDISMODULE_KEYTYPE_EMPTY);
  // Start from "nothing deleted" so that the check below never reads an
  // indeterminate value if the removal does not write the flag.
  int deleted = 0;
  RedisModule_ZsetRem(notification_key, client_channel, &deleted);
  // The formatted channel name is only needed for the removal; free it here,
  // as the request command does, instead of leaking it on every cancel.
  RedisModule_FreeString(ctx, client_channel);
  RAY_CHECK(deleted);
  RedisModule_CloseKey(notification_key);

  RedisModule_ReplyWithSimpleString(ctx, "OK");
  return REDISMODULE_OK;
}
bool is_nil(const std::string &data) {
RAY_CHECK(data.size() == kUniqueIDSize);
const uint8_t *d = reinterpret_cast<const uint8_t *>(data.data());
@ -1429,6 +1667,18 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx,
return REDISMODULE_ERR;
}
if (RedisModule_CreateCommand(ctx, "ray.table_request_notifications",
TableRequestNotifications_RedisCommand,
"write pubsub", 0, 0, 0) == REDISMODULE_ERR) {
return REDISMODULE_ERR;
}
if (RedisModule_CreateCommand(ctx, "ray.table_cancel_notifications",
TableCancelNotifications_RedisCommand,
"write pubsub", 0, 0, 0) == REDISMODULE_ERR) {
return REDISMODULE_ERR;
}
if (RedisModule_CreateCommand(ctx, "ray.table_test_and_update",
TableTestAndUpdate_RedisCommand, "write", 0, 0,
0) == REDISMODULE_ERR) {

View file

@ -1332,10 +1332,10 @@ void log_object_hash_mismatch_error_result_callback(ObjectID object_id,
RAY_CHECK_OK(state->gcs_client.task_table().Lookup(
ray::JobID::nil(), task_id,
[user_context](gcs::AsyncGcsClient *, const TaskID &,
std::shared_ptr<TaskTableDataT> t) {
const TaskTableDataT &t) {
Task *task = Task_alloc(
t->task_info.data(), t->task_info.size(), t->scheduling_state,
DBClientID::from_binary(t->scheduler_id), std::vector<ObjectID>());
t.task_info.data(), t.task_info.size(), t.scheduling_state,
DBClientID::from_binary(t.scheduler_id), std::vector<ObjectID>());
log_object_hash_mismatch_error_task_callback(task, user_context);
Task_free(task);
},

View file

@ -21,7 +21,7 @@ static inline void flushall_redis(void) {
class TestGcs : public ::testing::Test {
public:
TestGcs() {
TestGcs() : num_callbacks_(0) {
client_ = std::make_shared<gcs::AsyncGcsClient>();
ClientTableDataT client_info;
client_info.client_id = ClientID::from_random().binary();
@ -42,7 +42,12 @@ class TestGcs : public ::testing::Test {
virtual void Stop() = 0;
int64_t NumCallbacks() const { return num_callbacks_; }
void IncrementNumCallbacks() { num_callbacks_++; }
protected:
int64_t num_callbacks_;
std::shared_ptr<gcs::AsyncGcsClient> client_;
JobID job_id_;
};
@ -87,14 +92,14 @@ class TestGcsWithAsio : public TestGcs {
};
void ObjectAdded(gcs::AsyncGcsClient *client, const UniqueID &id,
std::shared_ptr<ObjectTableDataT> data) {
ASSERT_EQ(data->managers, std::vector<std::string>({"A", "B"}));
const ObjectTableDataT &data) {
ASSERT_EQ(data.managers, std::vector<std::string>({"A", "B"}));
}
void Lookup(gcs::AsyncGcsClient *client, const UniqueID &id,
std::shared_ptr<ObjectTableDataT> data) {
const ObjectTableDataT &data) {
// Check that the object entry was added.
ASSERT_EQ(data->managers, std::vector<std::string>({"A", "B"}));
ASSERT_EQ(data.managers, std::vector<std::string>({"A", "B"}));
test->Stop();
}
@ -126,14 +131,37 @@ TEST_F(TestGcsWithAsio, TestObjectTable) {
TestObjectTable(job_id_, client_);
}
// Verify that looking up an ID with no table entry invokes the failure
// callback. The ID is freshly generated and nothing is added for it first.
void TestLookupFailure(const JobID &job_id, std::shared_ptr<gcs::AsyncGcsClient> client) {
  auto missing_object_id = ObjectID::from_random();

  // The failure callback is the only thing that stops the event loop.
  auto on_lookup_failure = [](gcs::AsyncGcsClient *, const UniqueID &) {
    test->Stop();
  };

  // No success callback is registered (nullptr).
  RAY_CHECK_OK(client->object_table().Lookup(job_id, missing_object_id, nullptr,
                                             on_lookup_failure));

  // Run the event loop until the failure callback fires.
  test->Start();
}
// Runs TestLookupFailure with the TestGcsWithAe fixture.
TEST_F(TestGcsWithAe, TestLookupFailure) {
  // The callbacks reach the fixture through the `test` pointer.
  test = this;
  TestLookupFailure(job_id_, client_);
}
// Runs TestLookupFailure with the TestGcsWithAsio fixture.
TEST_F(TestGcsWithAsio, TestLookupFailure) {
  // The callbacks reach the fixture through the `test` pointer.
  test = this;
  TestLookupFailure(job_id_, client_);
}
void TaskAdded(gcs::AsyncGcsClient *client, const TaskID &id,
std::shared_ptr<TaskTableDataT> data) {
ASSERT_EQ(data->scheduling_state, SchedulingState_SCHEDULED);
const TaskTableDataT &data) {
ASSERT_EQ(data.scheduling_state, SchedulingState_SCHEDULED);
}
void TaskLookup(gcs::AsyncGcsClient *client, const TaskID &id,
std::shared_ptr<TaskTableDataT> data) {
ASSERT_EQ(data->scheduling_state, SchedulingState_SCHEDULED);
const TaskTableDataT &data) {
ASSERT_EQ(data.scheduling_state, SchedulingState_SCHEDULED);
}
void TaskLookupFailure(gcs::AsyncGcsClient *client, const TaskID &id) {
@ -141,8 +169,8 @@ void TaskLookupFailure(gcs::AsyncGcsClient *client, const TaskID &id) {
}
void TaskLookupAfterUpdate(gcs::AsyncGcsClient *client, const TaskID &id,
std::shared_ptr<TaskTableDataT> data) {
ASSERT_EQ(data->scheduling_state, SchedulingState_LOST);
const TaskTableDataT &data) {
ASSERT_EQ(data.scheduling_state, SchedulingState_LOST);
test->Stop();
}
@ -153,8 +181,8 @@ void TaskLookupAfterUpdateFailure(gcs::AsyncGcsClient *client, const TaskID &id)
void TaskUpdateCallback(gcs::AsyncGcsClient *client, const TaskID &task_id,
const TaskTableDataT &task, bool updated) {
RAY_CHECK_OK(client->task_table().Lookup(
DriverID::nil(), task_id, &TaskLookupAfterUpdate, &TaskLookupAfterUpdateFailure));
RAY_CHECK_OK(client->task_table().Lookup(DriverID::nil(), task_id,
&TaskLookupAfterUpdate, &TaskLookupFailure));
}
void TestTaskTable(const JobID &job_id, std::shared_ptr<gcs::AsyncGcsClient> client) {
@ -189,28 +217,40 @@ TEST_F(TestGcsWithAsio, TestTaskTable) {
TestTaskTable(job_id_, client_);
}
void ObjectTableSubscribed(gcs::AsyncGcsClient *client, const UniqueID &id,
std::shared_ptr<ObjectTableDataT> data) {
test->Stop();
}
void TestSubscribeAll(const JobID &job_id, std::shared_ptr<gcs::AsyncGcsClient> client) {
// Subscribe to all object table notifications. The registered callback for
// notifications will check whether the object below is added.
RAY_CHECK_OK(client->object_table().Subscribe(job_id, ClientID::nil(), &Lookup,
&ObjectTableSubscribed));
// Run the event loop. The loop will only stop if the subscription succeeds.
test->Start();
// We have subscribed. Add an object table entry.
auto data = std::make_shared<ObjectTableDataT>();
data->managers.push_back("A");
data->managers.push_back("B");
ObjectID object_id = ObjectID::from_random();
RAY_CHECK_OK(client->object_table().Add(job_id, object_id, data, &ObjectAdded));
// Callback for a notification.
auto notification_callback = [object_id](
gcs::AsyncGcsClient *client, const UniqueID &id, const ObjectTableDataT &data) {
ASSERT_EQ(id, object_id);
// Check that the object entry was added.
ASSERT_EQ(data.managers, std::vector<std::string>({"A", "B"}));
test->IncrementNumCallbacks();
test->Stop();
};
// Callback for subscription success. This should only be called once.
auto subscribe_callback = [job_id, object_id](gcs::AsyncGcsClient *client) {
test->IncrementNumCallbacks();
// We have subscribed. Add an object table entry.
auto data = std::make_shared<ObjectTableDataT>();
data->managers.push_back("A");
data->managers.push_back("B");
RAY_CHECK_OK(client->object_table().Add(job_id, object_id, data, &ObjectAdded));
};
// Subscribe to all object table notifications. Once we have successfully
// subscribed, we will add an object and check that we get notified of the
// operation.
RAY_CHECK_OK(client->object_table().Subscribe(
job_id, ClientID::nil(), notification_callback, subscribe_callback));
// Run the event loop. The loop will only stop if the registered subscription
// callback is called (or an assertion failure).
test->Start();
// Check that we received one callback for subscription success and one for
// the Add notification.
ASSERT_EQ(test->NumCallbacks(), 2);
}
TEST_F(TestGcsWithAe, TestSubscribeAll) {
@ -223,11 +263,152 @@ TEST_F(TestGcsWithAsio, TestSubscribeAll) {
TestSubscribeAll(job_id_, client_);
}
// Verify that a client which requests notifications for one specific key
// receives the current value at that key plus every later write to it, and
// nothing for other keys.
void TestSubscribeId(const JobID &job_id, std::shared_ptr<gcs::AsyncGcsClient> client) {
  // Add an object table entry.
  ObjectID object_id1 = ObjectID::from_random();
  auto data1 = std::make_shared<ObjectTableDataT>();
  data1->managers.push_back("A");
  data1->managers.push_back("B");
  RAY_CHECK_OK(client->object_table().Add(job_id, object_id1, data1, nullptr));

  // Add a second object table entry.
  ObjectID object_id2 = ObjectID::from_random();
  auto data2 = std::make_shared<ObjectTableDataT>();
  data2->managers.push_back("C");
  RAY_CHECK_OK(client->object_table().Add(job_id, object_id2, data2, nullptr));

  // The callback for subscription success. Once we've subscribed, request
  // notifications for the second object that was added.
  auto subscribe_callback = [job_id, object_id2](gcs::AsyncGcsClient *client) {
    test->IncrementNumCallbacks();
    // Request notifications for the second object. Since we already added the
    // entry to the table, we should receive an initial notification for its
    // current value.
    RAY_CHECK_OK(client->object_table().RequestNotifications(
        job_id, object_id2, client->client_table().GetLocalClientId()));
    // Overwrite the entry for the object. We should receive a second
    // notification for its new value.
    auto data = std::make_shared<ObjectTableDataT>();
    data->managers.push_back("C");
    data->managers.push_back("D");
    RAY_CHECK_OK(client->object_table().Add(job_id, object_id2, data, nullptr));
  };

  // The callback for a notification from the object table. This should only be
  // received for the object that we requested notifications for. Only the ID
  // is captured; the originally written data is not needed here.
  auto notification_callback = [object_id2](
      gcs::AsyncGcsClient *client, const UniqueID &id, const ObjectTableDataT &data) {
    ASSERT_EQ(id, object_id2);
    // Both values written to this key start with manager "C".
    ASSERT_EQ(data.managers.front(), "C");
    test->IncrementNumCallbacks();
    // Stop the loop once we've received notifications for both writes to the
    // object key (one subscription callback + two notifications).
    if (test->NumCallbacks() == 3) {
      test->Stop();
    }
  };

  RAY_CHECK_OK(
      client->object_table().Subscribe(job_id, client->client_table().GetLocalClientId(),
                                       notification_callback, subscribe_callback));

  // Run the event loop. The loop will only stop once the notification callback
  // has been called for both writes to the object key.
  test->Start();
  // Check that we received one callback for subscription success and two
  // callbacks for the notifications.
  ASSERT_EQ(test->NumCallbacks(), 3);
}
// Runs TestSubscribeId with the TestGcsWithAe fixture.
TEST_F(TestGcsWithAe, TestSubscribeId) {
  // The callbacks reach the fixture through the `test` pointer.
  test = this;
  TestSubscribeId(job_id_, client_);
}
// Runs TestSubscribeId with the TestGcsWithAsio fixture.
TEST_F(TestGcsWithAsio, TestSubscribeId) {
  // The callbacks reach the fixture through the `test` pointer.
  test = this;
  TestSubscribeId(job_id_, client_);
}
// Verify that canceling notifications for a key suppresses notifications for
// writes made while canceled, and that re-requesting notifications delivers
// the value that is current at that point.
void TestSubscribeCancel(const JobID &job_id,
                         std::shared_ptr<gcs::AsyncGcsClient> client) {
  // Write the object table once, before subscribing.
  ObjectID object_id = ObjectID::from_random();
  auto initial_entry = std::make_shared<ObjectTableDataT>();
  initial_entry->managers.push_back("A");
  RAY_CHECK_OK(client->object_table().Add(job_id, object_id, initial_entry, nullptr));

  // Called once the subscription succeeds. Drives the request / cancel /
  // write / re-request sequence for the single object written above.
  auto on_subscribed = [job_id, object_id](gcs::AsyncGcsClient *client) {
    test->IncrementNumCallbacks();

    // First request: we should be notified of the current value ("A").
    RAY_CHECK_OK(client->object_table().RequestNotifications(
        job_id, object_id, client->client_table().GetLocalClientId()));

    // Cancel, so that the two writes below produce no notifications.
    RAY_CHECK_OK(client->object_table().CancelNotifications(
        job_id, object_id, client->client_table().GetLocalClientId()));

    auto entry = std::make_shared<ObjectTableDataT>();
    entry->managers.push_back("B");
    RAY_CHECK_OK(client->object_table().Add(job_id, object_id, entry, nullptr));
    entry = std::make_shared<ObjectTableDataT>();
    entry->managers.push_back("C");
    RAY_CHECK_OK(client->object_table().Add(job_id, object_id, entry, nullptr));

    // Second request: we should only be notified of the value that is current
    // now ("C").
    RAY_CHECK_OK(client->object_table().RequestNotifications(
        job_id, object_id, client->client_table().GetLocalClientId()));
  };

  // Called for every notification from the object table.
  auto on_notification = [object_id](
      gcs::AsyncGcsClient *client, const UniqueID &id, const ObjectTableDataT &data) {
    ASSERT_EQ(id, object_id);
    // The first notification (after the subscription callback) carries "A";
    // any later one must carry "C". "B" must never be seen because
    // notifications were canceled when it was written.
    const char *expected_manager = (test->NumCallbacks() == 1) ? "A" : "C";
    ASSERT_EQ(data.managers.front(), expected_manager);
    test->IncrementNumCallbacks();
    // One subscription callback plus two notifications ends the test.
    if (test->NumCallbacks() == 3) {
      test->Stop();
    }
  };

  RAY_CHECK_OK(
      client->object_table().Subscribe(job_id, client->client_table().GetLocalClientId(),
                                       on_notification, on_subscribed));

  // Run the event loop. It stops once the notification callback has run twice
  // (or hangs if an assertion inside a callback fails before Stop is reached).
  test->Start();
  // One callback for subscription success and two for the notifications
  // triggered by the two RequestNotifications calls.
  ASSERT_EQ(test->NumCallbacks(), 3);
}
// Runs the shared TestSubscribeCancel scenario using the TestGcsWithAe fixture.
TEST_F(TestGcsWithAe, TestSubscribeCancel) {
// The scenario's callbacks reach the fixture through the file-level `test`
// pointer, so it must be set before the scenario runs.
test = this;
TestSubscribeCancel(job_id_, client_);
}
// Runs the shared TestSubscribeCancel scenario using the TestGcsWithAsio
// fixture.
TEST_F(TestGcsWithAsio, TestSubscribeCancel) {
// The scenario's callbacks reach the fixture through the file-level `test`
// pointer, so it must be set before the scenario runs.
test = this;
TestSubscribeCancel(job_id_, client_);
}
void ClientTableNotification(gcs::AsyncGcsClient *client, const UniqueID &id,
std::shared_ptr<ClientTableDataT> data, bool is_insertion) {
const ClientTableDataT &data, bool is_insertion) {
ClientID added_id = client->client_table().GetLocalClientId();
ASSERT_EQ(ClientID::from_binary(data->client_id), added_id);
ASSERT_EQ(data->is_insertion, is_insertion);
ASSERT_EQ(ClientID::from_binary(data.client_id), added_id);
ASSERT_EQ(data.is_insertion, is_insertion);
auto cached_client = client->client_table().GetClient(added_id);
ASSERT_EQ(ClientID::from_binary(cached_client.client_id), added_id);
@ -239,8 +420,7 @@ void TestClientTableConnect(const JobID &job_id,
// Register callbacks for when a client gets added and removed. The latter
// event will stop the event loop.
client->client_table().RegisterClientAddedCallback(
[](gcs::AsyncGcsClient *client, const UniqueID &id,
std::shared_ptr<ClientTableDataT> data) {
[](gcs::AsyncGcsClient *client, const UniqueID &id, const ClientTableDataT &data) {
ClientTableNotification(client, id, data, true);
test->Stop();
});
@ -260,13 +440,11 @@ void TestClientTableDisconnect(const JobID &job_id,
// Register callbacks for when a client gets added and removed. The latter
// event will stop the event loop.
client->client_table().RegisterClientAddedCallback(
[](gcs::AsyncGcsClient *client, const UniqueID &id,
std::shared_ptr<ClientTableDataT> data) {
[](gcs::AsyncGcsClient *client, const UniqueID &id, const ClientTableDataT &data) {
ClientTableNotification(client, id, data, true);
});
client->client_table().RegisterClientRemovedCallback(
[](gcs::AsyncGcsClient *client, const UniqueID &id,
std::shared_ptr<ClientTableDataT> data) {
[](gcs::AsyncGcsClient *client, const UniqueID &id, const ClientTableDataT &data) {
ClientTableNotification(client, id, data, false);
test->Stop();
});

View file

@ -21,6 +21,11 @@ enum TablePubsub:int {
ACTOR
}
// A notification delivered to a subscriber of a GCS table.
table GcsNotification {
// The ID of the table entry that this notification is about. Receivers treat
// an empty ID as the nil ID.
id: string;
// The entry's data, as a serialized flatbuffer of the table's data type.
data: string;
}
table FunctionTableData {
language: Language;
name: string;

View file

@ -11,6 +11,22 @@ extern "C" {
// TODO(pcm): Integrate into the C++ tree.
#include "state/ray_config.h"
namespace {

/// Invoke the callback registered under the given index and, if the callback
/// asks for it, unregister it from the callback manager.
///
/// \param callback_index The index of the callback in the
///        RedisCallbackManager. A negative index means there is no callback
///        to run, in which case this function does nothing.
/// \param data The parsed Redis reply to hand to the callback.
void ProcessCallback(int64_t callback_index, const std::vector<std::string> &data) {
  if (callback_index < 0) {
    return;
  }
  auto &callback_manager = ray::gcs::RedisCallbackManager::instance();
  // The callback's return value tells us whether it should be removed.
  const bool should_remove = callback_manager.get(callback_index)(data);
  if (should_remove) {
    callback_manager.remove(callback_index);
  }
}

}  // namespace
namespace ray {
namespace gcs {
@ -24,24 +40,25 @@ void GlobalRedisCallback(void *c, void *r, void *privdata) {
}
int64_t callback_index = reinterpret_cast<int64_t>(privdata);
redisReply *reply = reinterpret_cast<redisReply *>(r);
std::string data = "";
if (reply->type == REDIS_REPLY_NIL) {
// Respond with blank string, which triggers a failure callback for lookups.
} else if (reply->type == REDIS_REPLY_STRING) {
data = std::string(reply->str, reply->len);
} else if (reply->type == REDIS_REPLY_ARRAY) {
reply = reply->element[reply->elements - 1];
data = std::string(reply->str, reply->len);
} else if (reply->type == REDIS_REPLY_STATUS) {
} else if (reply->type == REDIS_REPLY_ERROR) {
std::vector<std::string> data;
// Parse the response.
switch (reply->type) {
case (REDIS_REPLY_NIL): {
// Do not add any data for a nil response.
} break;
case (REDIS_REPLY_STRING): {
data.push_back(std::string(reply->str, reply->len));
} break;
case (REDIS_REPLY_STATUS): {
} break;
case (REDIS_REPLY_ERROR): {
RAY_LOG(ERROR) << "Redis error " << reply->str;
} else {
} break;
default:
RAY_LOG(FATAL) << "Fatal redis error of type " << reply->type
<< " and with string " << reply->str;
}
RedisCallbackManager::instance().get(callback_index)(data);
// Delete the callback.
RedisCallbackManager::instance().remove(callback_index);
ProcessCallback(callback_index, data);
}
void SubscribeRedisCallback(void *c, void *r, void *privdata) {
@ -50,31 +67,35 @@ void SubscribeRedisCallback(void *c, void *r, void *privdata) {
}
int64_t callback_index = reinterpret_cast<int64_t>(privdata);
redisReply *reply = reinterpret_cast<redisReply *>(r);
std::string data = "";
if (reply->type == REDIS_REPLY_ARRAY) {
// Parse the message.
std::vector<std::string> data;
// Parse the response.
switch (reply->type) {
case (REDIS_REPLY_ARRAY): {
// Parse the published message.
redisReply *message_type = reply->element[0];
if (strcmp(message_type->str, "subscribe") == 0) {
// If the message is for the initial subscription call, do not fill in
// data.
// If the message is for the initial subscription call, return the empty
// string as a response to signify that subscription was successful.
data.push_back("");
} else if (strcmp(message_type->str, "message") == 0) {
// If the message is from a PUBLISH, make sure the data is nonempty.
redisReply *message = reply->element[reply->elements - 1];
data = std::string(message->str, message->len);
RAY_CHECK(!data.empty()) << "Empty message received on subscribe channel";
auto notification = std::string(message->str, message->len);
RAY_CHECK(!notification.empty()) << "Empty message received on subscribe channel";
data.push_back(notification);
} else {
RAY_LOG(FATAL) << "Fatal redis error during subscribe" << message_type->str;
}
// NOTE(swang): We do not delete the callback after calling it since there
// may be more subscription messages.
RedisCallbackManager::instance().get(callback_index)(data);
} else if (reply->type == REDIS_REPLY_ERROR) {
} break;
case (REDIS_REPLY_ERROR): {
RAY_LOG(ERROR) << "Redis error " << reply->str;
} else {
} break;
default:
RAY_LOG(FATAL) << "Fatal redis error of type " << reply->type << " and with string "
<< reply->str;
}
ProcessCallback(callback_index, data);
}
int64_t RedisCallbackManager::add(const RedisCallback &function) {
@ -161,8 +182,9 @@ Status RedisContext::AttachToEventLoop(aeEventLoop *loop) {
}
Status RedisContext::RunAsync(const std::string &command, const UniqueID &id,
uint8_t *data, int64_t length, const TablePrefix prefix,
const TablePubsub pubsub_channel, int64_t callback_index) {
const uint8_t *data, int64_t length,
const TablePrefix prefix, const TablePubsub pubsub_channel,
int64_t callback_index) {
if (length > 0) {
std::string redis_command = command + " %d %d %b %b";
int status = redisAsyncCommand(
@ -200,7 +222,6 @@ Status RedisContext::SubscribeAsync(const ClientID &client_id,
reinterpret_cast<void *>(callback_index), redis_command.c_str(), pubsub_channel);
} else {
// Subscribe only to messages sent to this client.
// TODO(swang): Nobody sends on this channel yet.
std::string redis_command = "SUBSCRIBE %d:%b";
status = redisAsyncCommand(
subscribe_context_, reinterpret_cast<redisCallbackFn *>(&SubscribeRedisCallback),

View file

@ -21,7 +21,10 @@ namespace gcs {
class RedisCallbackManager {
public:
using RedisCallback = std::function<void(const std::string &)>;
/// Every callback should take in a vector of the results from the Redis
/// operation and return a bool indicating whether the callback should be
/// deleted once called.
using RedisCallback = std::function<bool(const std::vector<std::string> &)>;
static RedisCallbackManager &instance() {
static RedisCallbackManager instance;
@ -50,7 +53,7 @@ class RedisContext {
~RedisContext();
Status Connect(const std::string &address, int port);
Status AttachToEventLoop(aeEventLoop *loop);
Status RunAsync(const std::string &command, const UniqueID &id, uint8_t *data,
Status RunAsync(const std::string &command, const UniqueID &id, const uint8_t *data,
int64_t length, const TablePrefix prefix,
const TablePubsub pubsub_channel, int64_t callback_index);
Status SubscribeAsync(const ClientID &client_id, const TablePubsub pubsub_channel,

View file

@ -1,36 +1,141 @@
#include "ray/gcs/tables.h"
#include "common_protocol.h"
#include "ray/gcs/client.h"
namespace ray {
namespace gcs {
void ClientTable::RegisterClientAddedCallback(const Callback &callback) {
/// Add an entry to the table.
///
/// \param job_id The ID of the job (= driver). Not used by this function.
/// \param id The ID of the entry to write.
/// \param data The data to serialize and write at the key.
/// \param done Callback that is called, with the data that was passed in,
///        once Redis has replied to the write. May be empty.
/// \return The status of issuing the asynchronous Redis command.
template <typename ID, typename Data>
Status Table<ID, Data>::Add(const JobID &job_id, const ID &id,
                            std::shared_ptr<DataT> data, const Callback &done) {
  // Keep everything that the reply handler needs alive until it runs.
  auto callback_data = std::make_shared<CallbackData>(
      CallbackData{id, data, done, nullptr, nullptr, this, client_});
  const auto on_reply = [callback_data](const std::vector<std::string> & /*reply*/) {
    if (callback_data->callback != nullptr) {
      (callback_data->callback)(callback_data->client, callback_data->id,
                                *callback_data->data);
    }
    // The reply to an Add arrives once, so the callback can be deleted.
    return true;
  };
  const int64_t callback_index = RedisCallbackManager::instance().add(on_reply);

  // Serialize the entry. Default values are written out explicitly.
  flatbuffers::FlatBufferBuilder builder;
  builder.ForceDefaults(true);
  builder.Finish(Data::Pack(builder, data.get()));
  return context_->RunAsync("RAY.TABLE_ADD", id, builder.GetBufferPointer(),
                            builder.GetSize(), prefix_, pubsub_channel_, callback_index);
}
/// Look up an entry asynchronously.
///
/// \param job_id The ID of the job (= driver). Not used by this function.
/// \param id The ID of the entry to look up.
/// \param lookup Callback that is called with the entry's data if Redis
///        returned a value for the key. May be empty.
/// \param failure Callback that is called instead of `lookup` if Redis
///        returned no data for the key. May be empty.
/// \return The status of issuing the asynchronous Redis command.
template <typename ID, typename Data>
Status Table<ID, Data>::Lookup(const JobID &job_id, const ID &id, const Callback &lookup,
                               const FailureCallback &failure) {
  // Keep everything that the reply handler needs alive until it runs.
  auto callback_data = std::make_shared<CallbackData>(
      CallbackData{id, nullptr, lookup, failure, nullptr, this, client_});
  const auto on_reply = [callback_data](const std::vector<std::string> &reply) {
    if (!reply.empty()) {
      // A value was found. There must be exactly one serialized entry.
      RAY_CHECK(reply.size() == 1);
      if (callback_data->callback != nullptr) {
        DataT entry;
        flatbuffers::GetRoot<Data>(reply[0].data())->UnPackTo(&entry);
        (callback_data->callback)(callback_data->client, callback_data->id, entry);
      }
    } else if (callback_data->failure != nullptr) {
      // No data came back for the key.
      (callback_data->failure)(callback_data->client, callback_data->id);
    }
    // The reply to a Lookup arrives once, so the callback can be deleted.
    return true;
  };
  const int64_t callback_index = RedisCallbackManager::instance().add(on_reply);

  // A lookup carries no payload besides the key.
  std::vector<uint8_t> no_payload;
  return context_->RunAsync("RAY.TABLE_LOOKUP", id, no_payload.data(),
                            no_payload.size(), prefix_, pubsub_channel_, callback_index);
}
/// Subscribe to notifications from this table. This may only be called once
/// per Table instance.
///
/// The callback registered here is also the one that receives the replies to
/// `RequestNotifications`, so it has to cope with every kind of reply that
/// can be routed to it.
///
/// \param job_id The ID of the job (= driver). Not used by this function.
/// \param client_id The client channel to subscribe to, forwarded to the
///        Redis SUBSCRIBE call.
/// \param subscribe Callback that is called for each notification, with the
///        ID and data carried by the notification. May be empty.
/// \param done Callback that is called once the subscription has been
///        acknowledged. May be empty.
/// \return The status of issuing the asynchronous Redis command.
template <typename ID, typename Data>
Status Table<ID, Data>::Subscribe(const JobID &job_id, const ClientID &client_id,
                                  const Callback &subscribe,
                                  const SubscriptionCallback &done) {
  RAY_CHECK(subscribe_callback_index_ == -1)
      << "Client called Subscribe twice on the same table";
  auto d = std::shared_ptr<CallbackData>(
      new CallbackData({client_id, nullptr, subscribe, nullptr, done, this, client_}));
  int64_t callback_index =
      RedisCallbackManager::instance().add([d](const std::vector<std::string> &data) {
        if (data.empty()) {
          // Replies without a payload (e.g. a nil reply to a
          // RequestNotifications call on a key that has no value yet) are
          // delivered as an empty vector. There is nothing to notify about, so
          // ignore them instead of failing the size check below.
          return false;
        }
        RAY_CHECK(data.size() == 1);
        if (data[0].empty()) {
          // No notification data is provided. This is the callback for the
          // initial subscription request.
          if (d->subscription_callback != nullptr) {
            (d->subscription_callback)(d->client);
          }
        } else if (d->callback != nullptr) {
          // Data is provided. This is the callback for a message, so parse the
          // notification.
          auto notification = flatbuffers::GetRoot<GcsNotification>(data[0].data());
          // String fields that are absent from a flatbuffer are returned as
          // null pointers, so check before dereferencing. A missing or empty
          // ID is reported as the nil ID.
          ID id = UniqueID::nil();
          const auto *notification_id = notification->id();
          if (notification_id != nullptr && notification_id->size() > 0) {
            id = from_flatbuf(*notification_id);
          }
          RAY_CHECK(notification->data() != nullptr)
              << "Notification received without any data";
          DataT result;
          auto root = flatbuffers::GetRoot<Data>(notification->data()->data());
          root->UnPackTo(&result);
          (d->callback)(d->client, id, result);
        }
        // We do not delete the callback after calling it since there may be
        // more subscription messages.
        return false;
      });
  // Remember the callback so that RequestNotifications can route its replies
  // to it.
  subscribe_callback_index_ = callback_index;
  return context_->SubscribeAsync(client_id, pubsub_channel_, callback_index);
}
/// Request notifications about a key in this table.
///
/// The reply to the request is routed to the callback that was registered by
/// `Subscribe`. Note that the check below only verifies that `Subscribe` has
/// been called on this table, since the callback index is recorded as soon as
/// the subscription is issued.
///
/// \param job_id The ID of the job (= driver). Not used by this function.
/// \param id The ID of the key to request notifications for.
/// \param client_id The client who is requesting notifications.
/// \return The status of issuing the asynchronous Redis command.
template <typename ID, typename Data>
Status Table<ID, Data>::RequestNotifications(const JobID &job_id, const ID &id,
                                             const ClientID &client_id) {
  RAY_CHECK(subscribe_callback_index_ >= 0)
      << "Client requested notifications on a key before Subscribe completed";
  // The requesting client's ID is sent as the payload of the command.
  const auto *requester = client_id.data();
  const auto requester_size = client_id.size();
  return context_->RunAsync("RAY.TABLE_REQUEST_NOTIFICATIONS", id, requester,
                            requester_size, prefix_, pubsub_channel_,
                            subscribe_callback_index_);
}
/// Cancel notifications about a key in this table.
///
/// No callback is registered for the reply to the cancellation. Note that the
/// check below only verifies that `Subscribe` has been called on this table.
///
/// \param job_id The ID of the job (= driver). Not used by this function.
/// \param id The ID of the key to cancel notifications for.
/// \param client_id The client who originally requested notifications.
/// \return The status of issuing the asynchronous Redis command.
template <typename ID, typename Data>
Status Table<ID, Data>::CancelNotifications(const JobID &job_id, const ID &id,
                                            const ClientID &client_id) {
  RAY_CHECK(subscribe_callback_index_ >= 0)
      << "Client canceled notifications on a key before Subscribe completed";
  // The canceling client's ID is sent as the payload of the command.
  const auto *requester = client_id.data();
  const auto requester_size = client_id.size();
  return context_->RunAsync("RAY.TABLE_CANCEL_NOTIFICATIONS", id, requester,
                            requester_size, prefix_, pubsub_channel_,
                            /*callback_index=*/-1);
}
void ClientTable::RegisterClientAddedCallback(const ClientTableCallback &callback) {
client_added_callback_ = callback;
// Call the callback for any added clients that are cached.
for (const auto &entry : client_cache_) {
if (!entry.first.is_nil() && entry.second.is_insertion) {
auto data = std::make_shared<ClientTableDataT>(entry.second);
client_added_callback_(client_, entry.first, data);
client_added_callback_(client_, ClientID::nil(), entry.second);
}
}
}
void ClientTable::RegisterClientRemovedCallback(const Callback &callback) {
void ClientTable::RegisterClientRemovedCallback(const ClientTableCallback &callback) {
client_removed_callback_ = callback;
// Call the callback for any removed clients that are cached.
for (const auto &entry : client_cache_) {
if (!entry.first.is_nil() && !entry.second.is_insertion) {
auto data = std::make_shared<ClientTableDataT>(entry.second);
client_removed_callback_(client_, entry.first, data);
client_removed_callback_(client_, ClientID::nil(), entry.second);
}
}
}
void ClientTable::HandleNotification(AsyncGcsClient *client, const ClientID &channel_id,
std::shared_ptr<ClientTableDataT> data) {
ClientID client_id = ClientID::from_binary(data->client_id);
const ClientTableDataT &data) {
ClientID client_id = ClientID::from_binary(data.client_id);
// It's possible to get duplicate notifications from the client table, so
// check whether this notification is new.
auto entry = client_cache_.find(client_id);
@ -42,24 +147,24 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, const ClientID &cha
// If the entry is in the cache, then the notification is new if the client
// was alive and is now dead.
bool was_inserted = entry->second.is_insertion;
bool is_deleted = !data->is_insertion;
bool is_deleted = !data.is_insertion;
is_new = (was_inserted && is_deleted);
// Once a client with a given ID has been removed, it should never be added
// again. If the entry was in the cache and the client was deleted, check
// that this new notification is not an insertion.
if (!entry->second.is_insertion) {
RAY_CHECK(!data->is_insertion)
RAY_CHECK(!data.is_insertion)
<< "Notification for addition of a client that was already removed:"
<< client_id.hex();
}
}
// Add the notification to our cache. Notifications are idempotent.
client_cache_[client_id] = *data;
client_cache_[client_id] = data;
// If the notification is new, call any registered callbacks.
if (is_new) {
if (data->is_insertion) {
if (data.is_insertion) {
if (client_added_callback_ != nullptr) {
client_added_callback_(client, client_id, data);
}
@ -72,7 +177,7 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, const ClientID &cha
}
void ClientTable::HandleConnected(AsyncGcsClient *client, const ClientID &client_id,
std::shared_ptr<ClientTableDataT> data) {
const ClientTableDataT &data) {
RAY_CHECK(client_id == client_id_) << client_id.hex() << " " << client_id_.hex();
}
@ -87,18 +192,17 @@ Status ClientTable::Connect() {
data->is_insertion = true;
// Callback for a notification from the client table.
auto notification_callback = [this](AsyncGcsClient *client, const ClientID &channel_id,
std::shared_ptr<ClientTableDataT> data) {
const ClientTableDataT &data) {
return HandleNotification(client, channel_id, data);
};
// Callback to handle our own successful connection once we've added
// ourselves.
auto add_callback = [this](AsyncGcsClient *client, const ClientID &id,
std::shared_ptr<ClientTableDataT> data) {
const ClientTableDataT &data) {
HandleConnected(client, id, data);
};
// Callback to add ourselves once we've successfully subscribed.
auto subscription_callback = [this, data, add_callback](
AsyncGcsClient *c, const ClientID &id, std::shared_ptr<ClientTableDataT> d) {
auto subscription_callback = [this, data, add_callback](AsyncGcsClient *c) {
// Mark ourselves as deleted if we called Disconnect() since the last
// Connect() call.
if (disconnected_) {
@ -114,7 +218,7 @@ Status ClientTable::Disconnect() {
auto data = std::make_shared<ClientTableDataT>(local_client_);
data->is_insertion = true;
auto add_callback = [this](AsyncGcsClient *client, const ClientID &id,
std::shared_ptr<ClientTableDataT> data) {
const ClientTableDataT &data) {
HandleConnected(client, id, data);
};
RAY_RETURN_NOT_OK(Add(JobID::nil(), client_id_, data, add_callback));
@ -135,6 +239,9 @@ const ClientTableDataT &ClientTable::GetClient(const ClientID &client_id) {
}
}
// Explicitly instantiate the Table template for these ID/data combinations.
// The Table member functions are defined in this file, so these
// instantiations emit their definitions for the listed types.
template class Table<TaskID, TaskTableData>;
template class Table<ObjectID, ObjectTableData>;
} // namespace gcs
} // namespace ray

View file

@ -30,9 +30,14 @@ template <typename ID, typename Data>
class Table {
public:
using DataT = typename Data::NativeTableType;
using Callback = std::function<void(AsyncGcsClient *client, const ID &id,
std::shared_ptr<DataT> data)>;
using Callback =
std::function<void(AsyncGcsClient *client, const ID &id, const DataT &data)>;
/// The callback to call when a lookup fails because there is no entry at the
/// key.
using FailureCallback = std::function<void(AsyncGcsClient *client, const ID &id)>;
/// The callback to call when a SUBSCRIBE call completes and we are ready to
/// request and receive notifications.
using SubscriptionCallback = std::function<void(AsyncGcsClient *client)>;
struct CallbackData {
ID id;
@ -41,7 +46,7 @@ class Table {
FailureCallback failure;
// An optional callback to call for subscription operations, where the
// first message is a notification of subscription success.
Callback subscription_callback;
SubscriptionCallback subscription_callback;
Table<ID, Data> *table;
AsyncGcsClient *client;
};
@ -50,7 +55,8 @@ class Table {
: context_(context),
client_(client),
pubsub_channel_(TablePubsub_NO_PUBLISH),
prefix_(TablePrefix_UNUSED){};
prefix_(TablePrefix_UNUSED),
subscribe_callback_index_(-1){};
/// Add an entry to the table.
///
@ -61,91 +67,63 @@ class Table {
/// GCS.
/// \return Status
Status Add(const JobID &job_id, const ID &id, std::shared_ptr<DataT> data,
const Callback &done) {
auto d = std::shared_ptr<CallbackData>(
new CallbackData({id, data, done, nullptr, nullptr, this, client_}));
int64_t callback_index =
RedisCallbackManager::instance().add([d](const std::string &data) {
if (d->callback != nullptr) {
(d->callback)(d->client, d->id, d->data);
}
});
flatbuffers::FlatBufferBuilder fbb;
fbb.ForceDefaults(true);
fbb.Finish(Data::Pack(fbb, data.get()));
RAY_RETURN_NOT_OK(context_->RunAsync("RAY.TABLE_ADD", id, fbb.GetBufferPointer(),
fbb.GetSize(), prefix_, pubsub_channel_,
callback_index));
return Status::OK();
}
const Callback &done);
/// Lookup an entry asynchronously.
///
/// \param job_id The ID of the job (= driver).
/// \param id The ID of the data that is looked up in the GCS.
/// \param lookup Callback that is called after lookup.
/// \param lookup Callback that is called with the entry's data if there is an
/// entry at the key.
/// \param failure Callback that is called instead of `lookup` if there is no
/// entry at the key.
/// \return Status
Status Lookup(const JobID &job_id, const ID &id, const Callback &lookup,
const FailureCallback &failure) {
auto d = std::shared_ptr<CallbackData>(
new CallbackData({id, nullptr, lookup, failure, nullptr, this, client_}));
int64_t callback_index =
RedisCallbackManager::instance().add([d](const std::string &data) {
if (data.empty()) {
if (d->failure != nullptr) {
(d->failure)(d->client, d->id);
}
} else {
auto result = std::make_shared<DataT>();
auto root = flatbuffers::GetRoot<Data>(data.data());
root->UnPackTo(result.get());
if (d->callback != nullptr) {
(d->callback)(d->client, d->id, result);
}
}
});
std::vector<uint8_t> nil;
RAY_RETURN_NOT_OK(context_->RunAsync("RAY.TABLE_LOOKUP", id, nil.data(), nil.size(),
prefix_, pubsub_channel_, callback_index));
return Status::OK();
}
const FailureCallback &failure);
/// Subscribe to updates of this table
/// Subscribe to any Add operations to this table. The caller may choose to
/// subscribe to all Adds, or to subscribe only to keys that it requests
/// notifications for. This may only be called once per Table instance.
///
/// \param job_id The ID of the job (= driver).
/// \param client_id The type of update to listen to. If this is nil, then a
/// message for each Add to the table will be received. Else, only
/// messages for the given client will be received.
/// \param subscribe Callback that is called on each received message.
/// messages for the given client will be received. In the latter
/// case, the client may request notifications on specific keys in the
/// table via `RequestNotifications`.
/// \param subscribe Callback that is called on each received notification,
/// with the ID and the data carried by the notification.
/// \param done Callback that is called when subscription is complete and we
/// are ready to receive messages..
/// are ready to receive messages.
/// \return Status
Status Subscribe(const JobID &job_id, const ClientID &client_id,
const Callback &subscribe, const Callback &done) {
auto d = std::shared_ptr<CallbackData>(
new CallbackData({client_id, nullptr, subscribe, nullptr, done, this, client_}));
int64_t callback_index =
RedisCallbackManager::instance().add([d](const std::string &data) {
if (data.empty()) {
// No data is provided. This is the callback for the initial
// subscription request.
if (d->subscription_callback != nullptr) {
(d->subscription_callback)(d->client, d->id, nullptr);
}
} else {
// Data is provided. This is the callback for a message.
auto result = std::make_shared<DataT>();
auto root = flatbuffers::GetRoot<Data>(data.data());
root->UnPackTo(result.get());
(d->callback)(d->client, d->id, result);
}
});
std::vector<uint8_t> nil;
return context_->SubscribeAsync(client_id, pubsub_channel_, callback_index);
}
const Callback &subscribe, const SubscriptionCallback &done);
/// Remove an entry from the table.
Status Remove(const JobID &job_id, const ID &id, const Callback &done);
/// Request notifications about a key in this table.
///
/// The notifications will be returned via the subscribe callback that was
/// registered by `Subscribe`. An initial notification will be returned for
/// the current value(s) at the key, if any, and a subsequent notification
/// will be published for every following `Add` to the key. Before
/// notifications can be requested, the caller must first call `Subscribe`,
/// with the same `client_id`.
///
/// \param job_id The ID of the job (= driver).
/// \param id The ID of the key to request notifications for.
/// \param client_id The client who is requesting notifications. Before
/// notifications can be requested, a call to `Subscribe` to this
/// table with the same `client_id` must complete successfully.
/// \return Status
Status RequestNotifications(const JobID &job_id, const ID &id,
const ClientID &client_id);
/// Cancel notifications about a key in this table.
///
/// \param job_id The ID of the job (= driver).
/// \param id The ID of the key to cancel notifications for.
/// \param client_id The client who originally requested notifications.
/// \return Status
Status CancelNotifications(const JobID &job_id, const ID &id,
const ClientID &client_id);
protected:
/// The connection to the GCS.
@ -158,6 +136,10 @@ class Table {
TablePubsub pubsub_channel_;
/// The prefix to use for keys in this table.
TablePrefix prefix_;
/// The index in the RedisCallbackManager for the callback that is called
/// when we receive notifications. This is >= 0 iff we have subscribed to the
/// table, otherwise -1.
int64_t subscribe_callback_index_;
};
class ObjectTable : public Table<ObjectID, ObjectTableData> {
@ -167,31 +149,6 @@ class ObjectTable : public Table<ObjectID, ObjectTableData> {
pubsub_channel_ = TablePubsub_OBJECT;
prefix_ = TablePrefix_OBJECT;
};
/// Set up a client-specific channel for receiving notifications about
/// available
/// objects from the object table. The callback will be called once per
/// notification received on this channel.
///
/// \param subscribe_all
/// \param object_available_callback Callback to be called when new object
/// becomes available.
/// \param done_callback Callback to be called when subscription is installed.
/// This is only used for the tests.
/// \return Status
Status SubscribeToNotifications(const JobID &job_id, bool subscribe_all,
const Callback &object_available, const Callback &done);
/// Request notifications about the availability of some objects from the
/// object
/// table. The notifications will be published to this client's object
/// notification channel, which was set up by the method
/// ObjectTableSubscribeToNotifications.
///
/// \param object_ids The object IDs to receive notifications about.
/// \return Status
Status RequestNotifications(const JobID &job_id,
const std::vector<ObjectID> &object_ids);
};
class FunctionTable : public Table<ObjectID, FunctionTableData> {
@ -240,11 +197,13 @@ class TaskTable : public Table<TaskID, TaskTableData> {
std::shared_ptr<TaskTableTestAndUpdateT> data,
const TestAndUpdateCallback &callback) {
int64_t callback_index = RedisCallbackManager::instance().add(
[this, callback, id](const std::string &data) {
[this, callback, id](const std::vector<std::string> &data) {
RAY_CHECK(data.size() == 1);
auto result = std::make_shared<TaskTableDataT>();
auto root = flatbuffers::GetRoot<TaskTableData>(data.data());
auto root = flatbuffers::GetRoot<TaskTableData>(data[0].data());
root->UnPackTo(result.get());
callback(client_, id, *result, root->updated());
return true;
});
flatbuffers::FlatBufferBuilder fbb;
fbb.Finish(TaskTableTestAndUpdate::Pack(fbb, data.get()));
@ -293,6 +252,8 @@ Status TaskTableTestAndUpdate(AsyncGcsClient *gcs_client, const TaskID &task_id,
class ClientTable : private Table<ClientID, ClientTableData> {
public:
using ClientTableCallback = std::function<void(
AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data)>;
ClientTable(const std::shared_ptr<RedisContext> &context, AsyncGcsClient *client,
const ClientTableDataT &local_client)
: Table(context, client),
@ -324,12 +285,12 @@ class ClientTable : private Table<ClientID, ClientTableData> {
/// Register a callback to call when a new client is added.
///
/// \param callback The callback to register.
void RegisterClientAddedCallback(const Callback &callback);
void RegisterClientAddedCallback(const ClientTableCallback &callback);
/// Register a callback to call when a client is removed.
///
/// \param callback The callback to register.
void RegisterClientRemovedCallback(const Callback &callback);
void RegisterClientRemovedCallback(const ClientTableCallback &callback);
/// Get a client's information from the cache. The cache only contains
/// information for clients that we've heard a notification for.
@ -352,10 +313,10 @@ class ClientTable : private Table<ClientID, ClientTableData> {
private:
/// Handle a client table notification.
void HandleNotification(AsyncGcsClient *client, const ClientID &channel_id,
std::shared_ptr<ClientTableDataT>);
const ClientTableDataT &notifications);
/// Handle this client's successful connection to the GCS.
void HandleConnected(AsyncGcsClient *client, const ClientID &client_id,
std::shared_ptr<ClientTableDataT>);
const ClientTableDataT &notifications);
/// Whether this client has called Disconnect().
bool disconnected_;
@ -364,9 +325,9 @@ class ClientTable : private Table<ClientID, ClientTableData> {
/// Information about this client.
ClientTableDataT local_client_;
/// The callback to call when a new client is added.
Callback client_added_callback_;
ClientTableCallback client_added_callback_;
/// The callback to call when a client is removed.
Callback client_removed_callback_;
ClientTableCallback client_removed_callback_;
/// A cache for information about all clients.
std::unordered_map<ClientID, ClientTableDataT, UniqueIDHasher> client_cache_;
};

View file

@ -44,9 +44,9 @@ Status TaskTableAdd(AsyncGcsClient *gcs_client, Task *task) {
TaskSpec *spec = execution_spec.Spec();
auto data = MakeTaskTableData(execution_spec, Task_local_scheduler(task),
static_cast<SchedulingState>(Task_state(task)));
return gcs_client->task_table().Add(ray::JobID::nil(), TaskSpec_task_id(spec), data,
[](gcs::AsyncGcsClient *client, const TaskID &id,
std::shared_ptr<TaskTableDataT> data) {});
return gcs_client->task_table().Add(
ray::JobID::nil(), TaskSpec_task_id(spec), data,
[](gcs::AsyncGcsClient *client, const TaskID &id, const TaskTableDataT &data) {});
}
// TODO(pcm): This is a helper method that should go away once we get rid of