Mirror of https://github.com/vale981/ray (synced 2025-03-06 10:31:39 -05:00)

Request and cancel notifications in the new GCS API (#1758)

* Add TableRequestNotifications and TableCancelNotifications to Redis modules
* Add RequestNotifications and CancelNotifications to generic GCS Table
* Add tests for subscribing to specific keys
* Remove TODO
* Return the current value at the key directly from RequestNotifications instead of through publish
* Add unit test for Lookup failure callback
* Modify tests to account for empty subscription response
* Remove ObjectTable notification methods
* Clean up message parsing and doc in redis context
* Use vectors of DataT in all GCS callbacks
* Clean up SubscriptionCallback
* Move Table definitions into tables.cc
* Refactor and document redis modules
* doc
* Fix new GCS build
* Cleanups
* Revert "Fix new GCS build" (this reverts commit 6e3e69090c67ef60aaf22a9cf62be0290d989e96)
* Use vectors for the internal callback interface; the user-facing interface takes a reference to a single item
* Fix new GCS build
* Add unit test for Lookup failure callback
* Fix compiler errors
* Cleanup
* Publish the entry ID with the notification
* Check that the ID for a notification matches in client tests

This commit is contained in:
parent 0c835a379f
commit 8704c8618c

9 changed files with 822 additions and 297 deletions
@@ -53,6 +53,32 @@ static const char *table_prefixes[] = {
     NULL, "TASK:", "CLIENT:", "OBJECT:", "FUNCTION:",
 };
 
+/// Parse a Redis string into a TablePubsub channel.
+TablePubsub ParseTablePubsub(const RedisModuleString *pubsub_channel_str) {
+  long long pubsub_channel_long;
+  RAY_CHECK(RedisModule_StringToLongLong(
+                pubsub_channel_str, &pubsub_channel_long) == REDISMODULE_OK)
+      << "Pubsub channel must be a valid TablePubsub";
+  auto pubsub_channel = static_cast<TablePubsub>(pubsub_channel_long);
+  RAY_CHECK(pubsub_channel >= TablePubsub_MIN &&
+            pubsub_channel <= TablePubsub_MAX)
+      << "Pubsub channel must be a valid TablePubsub";
+  return pubsub_channel;
+}
+
+/// Format a pubsub channel for a specific key. pubsub_channel_str should
+/// contain a valid TablePubsub.
+RedisModuleString *FormatPubsubChannel(
+    RedisModuleCtx *ctx,
+    const RedisModuleString *pubsub_channel_str,
+    const RedisModuleString *id) {
+  // Format the pubsub channel enum to a string. TablePubsub_MAX should be more
+  // than enough digits, but add 1 just in case for the null terminator.
+  char pubsub_channel[TablePubsub_MAX + 1];
+  sprintf(pubsub_channel, "%d", ParseTablePubsub(pubsub_channel_str));
+  return RedisString_Format(ctx, "%s:%S", pubsub_channel, id);
+}
+
 // TODO(swang): This helper function should be deprecated by the version below,
 // which uses enums for table prefixes.
 RedisModuleKey *OpenPrefixedKey(RedisModuleCtx *ctx,
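These two helpers pin down the channel-naming scheme the rest of the patch relies on: a notification channel is the numeric TablePubsub enum value, and a client-specific channel appends the subscriber's ID after a colon. A minimal standalone sketch of the same convention (FormatChannelName and the printable id are illustrative stand-ins, not part of the patch; the real module works on binary RedisModuleStrings):

#include <string>

// Hypothetical mirror of FormatPubsubChannel: builds "<pubsub_channel>:<id>".
std::string FormatChannelName(int pubsub_channel, const std::string &id) {
  return std::to_string(pubsub_channel) + ":" + id;
}
// e.g. FormatChannelName(3, "client-a") == "3:client-a"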
@@ -83,6 +109,23 @@ RedisModuleKey *OpenPrefixedKey(RedisModuleCtx *ctx,
   return OpenPrefixedKey(ctx, table_prefixes[prefix], keyname, mode);
 }
 
+/// Open the key used to store the channels that should be published to when an
+/// update happens at the given keyname.
+RedisModuleKey *OpenBroadcastKey(RedisModuleCtx *ctx,
+                                 RedisModuleString *pubsub_channel_str,
+                                 RedisModuleString *keyname,
+                                 int mode) {
+  RedisModuleString *channel =
+      FormatPubsubChannel(ctx, pubsub_channel_str, keyname);
+  RedisModuleString *prefixed_keyname =
+      RedisString_Format(ctx, "BCAST:%S", channel);
+  RedisModuleKey *key =
+      (RedisModuleKey *) RedisModule_OpenKey(ctx, prefixed_keyname, mode);
+  RedisModule_FreeString(ctx, prefixed_keyname);
+  RedisModule_FreeString(ctx, channel);
+  return key;
+}
+
 /**
  * This is a helper method to convert a redis module string to a flatbuffer
  * string.
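So every watched table entry gets a companion Redis key prefixed with BCAST: that stores the set of subscriber channels. A sketch of the resulting key name, reusing the hypothetical helper above:

// Hypothetical: the broadcast key holding subscriber channels for one table
// entry, mirroring OpenBroadcastKey's RedisString_Format call.
std::string BroadcastKeyName(int pubsub_channel, const std::string &key_id) {
  return "BCAST:" + FormatChannelName(pubsub_channel, key_id);
}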
@@ -411,8 +454,181 @@ bool PublishObjectNotification(RedisModuleCtx *ctx,
   return true;
 }
 
-// This is a temporary redis command that will be removed once
+// NOTE(pcmoritz): This is a temporary redis command that will be removed once
 // the GCS uses https://github.com/pcmoritz/credis.
+int TaskTableAdd(RedisModuleCtx *ctx,
+                 RedisModuleString *id,
+                 RedisModuleString *data) {
+  const char *buf = RedisModule_StringPtrLen(data, NULL);
+  auto message = flatbuffers::GetRoot<TaskTableData>(buf);
+
+  if (message->scheduling_state() == SchedulingState_WAITING ||
+      message->scheduling_state() == SchedulingState_SCHEDULED) {
+    /* Build the PUBLISH topic and message for task table subscribers. The
+     * topic is a string in the format "TASK_PREFIX:<local scheduler ID>:<state>".
+     * The message is a serialized SubscribeToTasksReply flatbuffer object. */
+    std::string state = std::to_string(message->scheduling_state());
+    RedisModuleString *publish_topic = RedisString_Format(
+        ctx, "%s%b:%s", TASK_PREFIX, message->scheduler_id()->str().data(),
+        sizeof(DBClientID), state.c_str());
+
+    /* Construct the flatbuffers object for the payload. */
+    flatbuffers::FlatBufferBuilder fbb;
+    /* Create the flatbuffers message. */
+    auto msg = CreateTaskReply(
+        fbb, RedisStringToFlatbuf(fbb, id), message->scheduling_state(),
+        fbb.CreateString(message->scheduler_id()),
+        fbb.CreateString(message->execution_dependencies()),
+        fbb.CreateString(message->task_info()), message->spillback_count(),
+        true /* not used */);
+    fbb.Finish(msg);
+
+    RedisModuleString *publish_message = RedisModule_CreateString(
+        ctx, (const char *) fbb.GetBufferPointer(), fbb.GetSize());
+
+    RedisModuleCallReply *reply =
+        RedisModule_Call(ctx, "PUBLISH", "ss", publish_topic, publish_message);
+
+    /* See how many clients received this publish. */
+    long long num_clients = RedisModule_CallReplyInteger(reply);
+    RAY_CHECK(num_clients <= 1) << "Published to " << num_clients
+                                << " clients.";
+
+    RedisModule_FreeString(ctx, publish_message);
+    RedisModule_FreeString(ctx, publish_topic);
+  }
+  return RedisModule_ReplyWithSimpleString(ctx, "OK");
+}
+
+// TODO(swang): Implement the client table as an append-only log so that we
+// don't need this special case for client table publication.
+int ClientTableAdd(RedisModuleCtx *ctx,
+                   RedisModuleString *pubsub_channel_str,
+                   RedisModuleString *data) {
+  const char *buf = RedisModule_StringPtrLen(data, NULL);
+  auto client_data = flatbuffers::GetRoot<ClientTableData>(buf);
+
+  RedisModuleKey *clients_key = (RedisModuleKey *) RedisModule_OpenKey(
+      ctx, pubsub_channel_str, REDISMODULE_READ | REDISMODULE_WRITE);
+  // If this is a client addition, send all previous notifications, in order.
+  // NOTE(swang): This will go to all clients, so some clients will get
+  // duplicate notifications.
+  if (client_data->is_insertion() &&
+      RedisModule_KeyType(clients_key) != REDISMODULE_KEYTYPE_EMPTY) {
+    // NOTE(swang): Sets are not implemented yet, so we use ZSETs instead.
+    CHECK_ERROR(RedisModule_ZsetFirstInScoreRange(
+                    clients_key, REDISMODULE_NEGATIVE_INFINITE,
+                    REDISMODULE_POSITIVE_INFINITE, 1, 1),
+                "Unable to initialize zset iterator");
+    do {
+      RedisModuleString *message =
+          RedisModule_ZsetRangeCurrentElement(clients_key, NULL);
+      RedisModuleCallReply *reply =
+          RedisModule_Call(ctx, "PUBLISH", "ss", pubsub_channel_str, message);
+      if (reply == NULL) {
+        RedisModule_CloseKey(clients_key);
+        return RedisModule_ReplyWithError(ctx, "error during PUBLISH");
+      }
+    } while (RedisModule_ZsetRangeNext(clients_key));
+  }
+
+  // Append this notification to the past notifications so that it will get
+  // sent to new clients in the future.
+  size_t index = RedisModule_ValueLength(clients_key);
+  // Serialize the notification to send.
+  flatbuffers::FlatBufferBuilder fbb;
+  auto message = CreateGcsNotification(fbb, fbb.CreateString(""),
+                                       RedisStringToFlatbuf(fbb, data));
+  fbb.Finish(message);
+  auto notification = RedisModule_CreateString(
+      ctx, reinterpret_cast<const char *>(fbb.GetBufferPointer()),
+      fbb.GetSize());
+  RedisModule_ZsetAdd(clients_key, index, notification, NULL);
+  // Publish the notification about this client.
+  RedisModuleCallReply *reply =
+      RedisModule_Call(ctx, "PUBLISH", "ss", pubsub_channel_str, notification);
+  RedisModule_FreeString(ctx, notification);
+  if (reply == NULL) {
+    RedisModule_CloseKey(clients_key);
+    return RedisModule_ReplyWithError(ctx, "error during PUBLISH");
+  }
+
+  RedisModule_CloseKey(clients_key);
+  return RedisModule_ReplyWithSimpleString(ctx, "OK");
+}
+
+/// Publish a notification for a new entry at a key. This publishes a
+/// notification to all subscribers of the table, as well as every client that
+/// has requested notifications for this key.
+///
+/// \param pubsub_channel_str The pubsub channel name that notifications for
+///        this key should be published to. When publishing to a specific
+///        client, the channel name should be <pubsub_channel>:<client_id>.
+/// \param id The ID of the key that the notification is about.
+/// \param data The data to publish.
+/// \return OK if there is no error during a publish.
+int PublishTableAdd(RedisModuleCtx *ctx,
+                    RedisModuleString *pubsub_channel_str,
+                    RedisModuleString *id,
+                    RedisModuleString *data) {
+  // Serialize the notification to send.
+  flatbuffers::FlatBufferBuilder fbb;
+  auto message = CreateGcsNotification(fbb, RedisStringToFlatbuf(fbb, id),
+                                       RedisStringToFlatbuf(fbb, data));
+  fbb.Finish(message);
+
+  // Write the data back to any subscribers that are listening to all table
+  // notifications.
+  RedisModuleCallReply *reply =
+      RedisModule_Call(ctx, "PUBLISH", "sb", pubsub_channel_str,
+                       fbb.GetBufferPointer(), fbb.GetSize());
+  if (reply == NULL) {
+    return RedisModule_ReplyWithError(ctx, "error during PUBLISH");
+  }
+
+  // Publish the data to any clients who requested notifications on this key.
+  RedisModuleKey *notification_key = OpenBroadcastKey(
+      ctx, pubsub_channel_str, id, REDISMODULE_READ | REDISMODULE_WRITE);
+  if (RedisModule_KeyType(notification_key) != REDISMODULE_KEYTYPE_EMPTY) {
+    // NOTE(swang): Sets are not implemented yet, so we use ZSETs instead.
+    CHECK_ERROR(RedisModule_ZsetFirstInScoreRange(
+                    notification_key, REDISMODULE_NEGATIVE_INFINITE,
+                    REDISMODULE_POSITIVE_INFINITE, 1, 1),
+                "Unable to initialize zset iterator");
+    for (; !RedisModule_ZsetRangeEndReached(notification_key);
+         RedisModule_ZsetRangeNext(notification_key)) {
+      RedisModuleString *client_channel =
+          RedisModule_ZsetRangeCurrentElement(notification_key, NULL);
+      RedisModuleCallReply *reply =
+          RedisModule_Call(ctx, "PUBLISH", "sb", client_channel,
+                           fbb.GetBufferPointer(), fbb.GetSize());
+      if (reply == NULL) {
+        RedisModule_CloseKey(notification_key);
+        return RedisModule_ReplyWithError(ctx, "error during PUBLISH");
+      }
+    }
+  }
+  RedisModule_CloseKey(notification_key);
+  return RedisModule_ReplyWithSimpleString(ctx, "OK");
+}
+
+/// Add an entry at a key. This overwrites any existing data at the key.
+/// Publishes a notification about the update to all subscribers, if a pubsub
+/// channel is provided.
+///
+/// This is called from a client with the command:
+///
+///    RAY.TABLE_ADD <table_prefix> <pubsub_channel> <id> <data>
+///
+/// \param table_prefix The prefix string for keys in this table.
+/// \param pubsub_channel The pubsub channel name that notifications for
+///        this key should be published to. When publishing to a specific
+///        client, the channel name should be <pubsub_channel>:<client_id>.
+/// \param id The ID of the key to set.
+/// \param data The data to insert at the key.
+/// \return The current value at the key, or OK if there is no value.
 int TableAdd_RedisCommand(RedisModuleCtx *ctx,
                           RedisModuleString **argv,
                           int argc) {
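PublishTableAdd wraps every payload in a GcsNotification flatbuffer carrying both the key's ID and the new data, which is what lets subscribers check which entry a message refers to. A hedged sketch of the receiving side (the in-tree version lives in Table::Subscribe, shown later in this diff; HandleNotificationBytes is an illustrative name):

// Sketch: decode one published message. `msg` holds the raw bytes received
// on the pubsub channel; GcsNotification is the flatbuffer table added in
// this commit (fields: id, data).
void HandleNotificationBytes(const std::string &msg) {
  auto notification = flatbuffers::GetRoot<GcsNotification>(msg.data());
  // notification->id() names the key; it is empty for client-table messages.
  // notification->data() holds the serialized table entry.
}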
@@ -431,108 +647,22 @@ int TableAdd_RedisCommand(RedisModuleCtx *ctx,
   RedisModule_StringSet(key, data);
   RedisModule_CloseKey(key);
 
-  // Get the requested pubsub channel.
-  long long pubsub_channel_long;
-  RAY_CHECK(RedisModule_StringToLongLong(
-                pubsub_channel_str, &pubsub_channel_long) == REDISMODULE_OK)
-      << "Pubsub channel must be a valid TablePubsub";
-  auto pubsub_channel = static_cast<TablePubsub>(pubsub_channel_long);
-  RAY_CHECK(pubsub_channel >= TablePubsub_MIN &&
-            pubsub_channel <= TablePubsub_MAX)
-      << "Pubsub channel must be a valid TablePubsub";
-
+  // Publish a message on the requested pubsub channel if necessary.
+  TablePubsub pubsub_channel = ParseTablePubsub(pubsub_channel_str);
   if (pubsub_channel == TablePubsub_TASK) {
-    const char *buf = RedisModule_StringPtrLen(data, NULL);
-    auto message = flatbuffers::GetRoot<TaskTableData>(buf);
-
-    if (message->scheduling_state() == SchedulingState_WAITING ||
-        message->scheduling_state() == SchedulingState_SCHEDULED) {
-      /* Build the PUBLISH topic and message for task table subscribers. The
-       * topic is a string in the format "TASK_PREFIX:<local scheduler ID>:<state>".
-       * The message is a serialized SubscribeToTasksReply flatbuffer object. */
-      std::string state = std::to_string(message->scheduling_state());
-      RedisModuleString *publish_topic = RedisString_Format(
-          ctx, "%s%b:%s", TASK_PREFIX, message->scheduler_id()->str().data(),
-          sizeof(DBClientID), state.c_str());
-
-      /* Construct the flatbuffers object for the payload. */
-      flatbuffers::FlatBufferBuilder fbb;
-      /* Create the flatbuffers message. */
-      auto msg = CreateTaskReply(
-          fbb, RedisStringToFlatbuf(fbb, id), message->scheduling_state(),
-          fbb.CreateString(message->scheduler_id()),
-          fbb.CreateString(message->execution_dependencies()),
-          fbb.CreateString(message->task_info()), message->spillback_count(),
-          true /* not used */);
-      fbb.Finish(msg);
-
-      RedisModuleString *publish_message = RedisModule_CreateString(
-          ctx, (const char *) fbb.GetBufferPointer(), fbb.GetSize());
-
-      RedisModuleCallReply *reply = RedisModule_Call(
-          ctx, "PUBLISH", "ss", publish_topic, publish_message);
-
-      /* See how many clients received this publish. */
-      long long num_clients = RedisModule_CallReplyInteger(reply);
-      RAY_CHECK(num_clients <= 1) << "Published to " << num_clients
-                                  << " clients.";
-
-      RedisModule_FreeString(ctx, publish_message);
-      RedisModule_FreeString(ctx, publish_topic);
-    }
+    // Publish the task to its subscribers.
+    // TODO(swang): This is only necessary for legacy Ray and should be removed
+    // once we switch to using the new GCS API for the task table.
+    return TaskTableAdd(ctx, id, data);
   } else if (pubsub_channel == TablePubsub_CLIENT) {
-    const char *buf = RedisModule_StringPtrLen(data, NULL);
-    auto client_data = flatbuffers::GetRoot<ClientTableData>(buf);
-
-    RedisModuleKey *clients_key = (RedisModuleKey *) RedisModule_OpenKey(
-        ctx, pubsub_channel_str, REDISMODULE_READ | REDISMODULE_WRITE);
-    // If this is a client addition, send all previous notifications, in order.
-    // NOTE(swang): This will go to all clients, so some clients will get
-    // duplicate notifications.
-    if (client_data->is_insertion() &&
-        RedisModule_KeyType(clients_key) != REDISMODULE_KEYTYPE_EMPTY) {
-      // NOTE(swang): Sets are not implemented yet, so we use ZSETs instead.
-      CHECK_ERROR(RedisModule_ZsetFirstInScoreRange(
-                      clients_key, REDISMODULE_NEGATIVE_INFINITE,
-                      REDISMODULE_POSITIVE_INFINITE, 1, 1),
-                  "Unable to initialize zset iterator");
-      do {
-        RedisModuleString *message =
-            RedisModule_ZsetRangeCurrentElement(clients_key, NULL);
-        RedisModuleCallReply *reply =
-            RedisModule_Call(ctx, "PUBLISH", "ss", pubsub_channel_str, message);
-        if (reply == NULL) {
-          RedisModule_CloseKey(clients_key);
-          RedisModule_ReplyWithError(ctx, "error during PUBLISH");
-        }
-      } while (RedisModule_ZsetRangeNext(clients_key));
-    }
-
-    // Append this notification to the past notifications so that it will get
-    // sent to new clients in the future.
-    size_t index = RedisModule_ValueLength(key);
-    RedisModule_ZsetAdd(clients_key, index, data, NULL);
-    // Publish the notification about this client.
-    RedisModuleCallReply *reply =
-        RedisModule_Call(ctx, "PUBLISH", "ss", pubsub_channel_str, data);
-    if (reply == NULL) {
-      RedisModule_ReplyWithError(ctx, "error during PUBLISH");
-    }
-
-    RedisModule_CloseKey(clients_key);
+    // Publish all previous client table additions to the new client.
+    return ClientTableAdd(ctx, pubsub_channel_str, data);
   } else if (pubsub_channel != TablePubsub_NO_PUBLISH) {
-    // All other pubsub channels write the data back directly onto the channel.
-    RedisModuleCallReply *reply =
-        RedisModule_Call(ctx, "PUBLISH", "ss", pubsub_channel_str, data);
-    if (reply == NULL) {
-      RedisModule_ReplyWithError(ctx, "error during PUBLISH");
-    }
+    return PublishTableAdd(ctx, pubsub_channel_str, id, data);
+  } else {
+    return RedisModule_ReplyWithSimpleString(ctx, "OK");
   }
-
-  return RedisModule_ReplyWithSimpleString(ctx, "OK");
 }
 
 // This is a temporary redis command that will be removed once
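On the wire, clients drive this command through RedisContext::RunAsync, which (per the redis_context.cc hunk later in this diff) appends " %d %d %b %b" to the command name. A hedged hiredis-level sketch of an equivalent call (AddEntry and its parameters are illustrative; only the command format comes from the patch):

#include <hiredis/async.h>

// Sketch: issue RAY.TABLE_ADD <prefix> <pubsub_channel> <id> <data> on a
// connected redisAsyncContext `ac`. `id_buf`/`id_len` and `buf`/`len` are
// the binary key ID and the serialized flatbuffer entry.
int AddEntry(redisAsyncContext *ac, int prefix, int pubsub_channel,
             const char *id_buf, size_t id_len, const char *buf, size_t len) {
  return redisAsyncCommand(ac, /*fn=*/NULL, /*privdata=*/NULL,
                           "RAY.TABLE_ADD %d %d %b %b", prefix, pubsub_channel,
                           id_buf, id_len, buf, len);
}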
@@ -561,6 +691,114 @@ int TableLookup_RedisCommand(RedisModuleCtx *ctx,
   return REDISMODULE_OK;
 }
 
+/// Request notifications for changes to a key. Returns the current value or
+/// values at the key. Notifications will be sent to the requesting client for
+/// every subsequent TABLE_ADD to the key.
+///
+/// This is called from a client with the command:
+///
+///    RAY.TABLE_REQUEST_NOTIFICATIONS <table_prefix> <pubsub_channel> <id>
+///    <client_id>
+///
+/// \param table_prefix The prefix string for keys in this table.
+/// \param pubsub_channel The pubsub channel name that notifications for
+///        this key should be published to. When publishing to a specific
+///        client, the channel name should be <pubsub_channel>:<client_id>.
+/// \param id The ID of the key to publish notifications for.
+/// \param client_id The ID of the client that is being notified.
+/// \return The current value at the key, or OK if there is no value.
+int TableRequestNotifications_RedisCommand(RedisModuleCtx *ctx,
+                                           RedisModuleString **argv,
+                                           int argc) {
+  if (argc != 5) {
+    return RedisModule_WrongArity(ctx);
+  }
+
+  RedisModuleString *prefix_str = argv[1];
+  RedisModuleString *pubsub_channel_str = argv[2];
+  RedisModuleString *id = argv[3];
+  RedisModuleString *client_id = argv[4];
+  RedisModuleString *client_channel =
+      FormatPubsubChannel(ctx, pubsub_channel_str, client_id);
+
+  // Add this client to the set of clients that should be notified when there
+  // are changes to the key.
+  RedisModuleKey *notification_key = OpenBroadcastKey(
+      ctx, pubsub_channel_str, id, REDISMODULE_READ | REDISMODULE_WRITE);
+  CHECK_ERROR(RedisModule_ZsetAdd(notification_key, 0.0, client_channel, NULL),
+              "ZsetAdd failed.");
+  RedisModule_CloseKey(notification_key);
+  RedisModule_FreeString(ctx, client_channel);
+
+  // Return the current value at the key, if any, to the client that requested
+  // a notification.
+  RedisModuleKey *table_key =
+      OpenPrefixedKey(ctx, prefix_str, id, REDISMODULE_READ);
+  if (table_key != nullptr) {
+    // Serialize the notification to send.
+    size_t data_len = 0;
+    char *data_buf =
+        RedisModule_StringDMA(table_key, &data_len, REDISMODULE_READ);
+    flatbuffers::FlatBufferBuilder fbb;
+    auto message = CreateGcsNotification(fbb, RedisStringToFlatbuf(fbb, id),
+                                         fbb.CreateString(data_buf, data_len));
+    fbb.Finish(message);
+
+    int result = RedisModule_ReplyWithStringBuffer(
+        ctx, reinterpret_cast<const char *>(fbb.GetBufferPointer()),
+        fbb.GetSize());
+    RedisModule_CloseKey(table_key);
+    return result;
+  } else {
+    RedisModule_CloseKey(table_key);
+    RedisModule_ReplyWithSimpleString(ctx, "OK");
+    return REDISMODULE_OK;
+  }
+}
+
+/// Cancel notifications for changes to a key. The client will no longer
+/// receive notifications for this key.
+///
+/// This is called from a client with the command:
+///
+///    RAY.TABLE_CANCEL_NOTIFICATIONS <table_prefix> <pubsub_channel> <id>
+///    <client_id>
+///
+/// \param table_prefix The prefix string for keys in this table.
+/// \param pubsub_channel The pubsub channel name that notifications for
+///        this key should be published to. If publishing to a specific client,
+///        then the channel name should be <pubsub_channel>:<client_id>.
+/// \param id The ID of the key to publish notifications for.
+/// \param client_id The ID of the client that is being notified.
+/// \return OK if the requesting client was removed, or an error if the client
+///         was not found.
+int TableCancelNotifications_RedisCommand(RedisModuleCtx *ctx,
+                                          RedisModuleString **argv,
+                                          int argc) {
+  if (argc < 5) {
+    return RedisModule_WrongArity(ctx);
+  }
+
+  RedisModuleString *pubsub_channel_str = argv[2];
+  RedisModuleString *id = argv[3];
+  RedisModuleString *client_id = argv[4];
+  RedisModuleString *client_channel =
+      FormatPubsubChannel(ctx, pubsub_channel_str, client_id);
+
+  // Remove this client from the set of clients that should be notified when
+  // there are changes to the key.
+  RedisModuleKey *notification_key = OpenBroadcastKey(
+      ctx, pubsub_channel_str, id, REDISMODULE_READ | REDISMODULE_WRITE);
+  RAY_CHECK(RedisModule_KeyType(notification_key) != REDISMODULE_KEYTYPE_EMPTY);
+  int deleted;
+  RedisModule_ZsetRem(notification_key, client_channel, &deleted);
+  RAY_CHECK(deleted);
+  RedisModule_CloseKey(notification_key);
+
+  RedisModule_ReplyWithSimpleString(ctx, "OK");
+  return REDISMODULE_OK;
+}
+
 bool is_nil(const std::string &data) {
   RAY_CHECK(data.size() == kUniqueIDSize);
   const uint8_t *d = reinterpret_cast<const uint8_t *>(data.data());
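Taken together, the two commands give a per-key handshake: a client SUBSCRIBEs once to its private channel <pubsub_channel>:<client_id>, then issues RAY.TABLE_REQUEST_NOTIFICATIONS per key, receiving the key's current value in the command reply and later writes as published GcsNotifications. A hedged hiredis-level sketch (WatchKey is illustrative; the in-tree client goes through RedisContext::SubscribeAsync and RunAsync, whose command formats this mirrors):

#include <hiredis/async.h>
#include <string>

// Sketch: `sub` is a dedicated subscribe connection, `cmd` a regular one.
void WatchKey(redisAsyncContext *sub, redisAsyncContext *cmd,
              redisCallbackFn *on_message, int prefix, int pubsub_channel,
              const std::string &key_id, const std::string &client_id) {
  // 1. Listen on our private channel (same format SubscribeAsync uses).
  redisAsyncCommand(sub, on_message, NULL, "SUBSCRIBE %d:%b", pubsub_channel,
                    client_id.data(), client_id.size());
  // 2. Join the key's broadcast set; the reply carries the current value as
  //    a GcsNotification, or "OK" if the key is empty.
  redisAsyncCommand(cmd, NULL, NULL,
                    "RAY.TABLE_REQUEST_NOTIFICATIONS %d %d %b %b", prefix,
                    pubsub_channel, key_id.data(), key_id.size(),
                    client_id.data(), client_id.size());
}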
@@ -1429,6 +1667,18 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx,
     return REDISMODULE_ERR;
   }
 
+  if (RedisModule_CreateCommand(ctx, "ray.table_request_notifications",
+                                TableRequestNotifications_RedisCommand,
+                                "write pubsub", 0, 0, 0) == REDISMODULE_ERR) {
+    return REDISMODULE_ERR;
+  }
+
+  if (RedisModule_CreateCommand(ctx, "ray.table_cancel_notifications",
+                                TableCancelNotifications_RedisCommand,
+                                "write pubsub", 0, 0, 0) == REDISMODULE_ERR) {
+    return REDISMODULE_ERR;
+  }
+
   if (RedisModule_CreateCommand(ctx, "ray.table_test_and_update",
                                 TableTestAndUpdate_RedisCommand, "write", 0, 0,
                                 0) == REDISMODULE_ERR) {
@@ -1332,10 +1332,10 @@ void log_object_hash_mismatch_error_result_callback(ObjectID object_id,
   RAY_CHECK_OK(state->gcs_client.task_table().Lookup(
       ray::JobID::nil(), task_id,
       [user_context](gcs::AsyncGcsClient *, const TaskID &,
-                     std::shared_ptr<TaskTableDataT> t) {
+                     const TaskTableDataT &t) {
         Task *task = Task_alloc(
-            t->task_info.data(), t->task_info.size(), t->scheduling_state,
-            DBClientID::from_binary(t->scheduler_id), std::vector<ObjectID>());
+            t.task_info.data(), t.task_info.size(), t.scheduling_state,
+            DBClientID::from_binary(t.scheduler_id), std::vector<ObjectID>());
         log_object_hash_mismatch_error_task_callback(task, user_context);
         Task_free(task);
       },
@@ -21,7 +21,7 @@ static inline void flushall_redis(void) {
 
 class TestGcs : public ::testing::Test {
  public:
-  TestGcs() {
+  TestGcs() : num_callbacks_(0) {
    client_ = std::make_shared<gcs::AsyncGcsClient>();
    ClientTableDataT client_info;
    client_info.client_id = ClientID::from_random().binary();
@@ -42,7 +42,12 @@ class TestGcs : public ::testing::Test {
 
   virtual void Stop() = 0;
 
+  int64_t NumCallbacks() const { return num_callbacks_; }
+
+  void IncrementNumCallbacks() { num_callbacks_++; }
+
  protected:
+  int64_t num_callbacks_;
   std::shared_ptr<gcs::AsyncGcsClient> client_;
   JobID job_id_;
 };
@@ -87,14 +92,14 @@ class TestGcsWithAsio : public TestGcs {
 };
 
 void ObjectAdded(gcs::AsyncGcsClient *client, const UniqueID &id,
-                 std::shared_ptr<ObjectTableDataT> data) {
-  ASSERT_EQ(data->managers, std::vector<std::string>({"A", "B"}));
+                 const ObjectTableDataT &data) {
+  ASSERT_EQ(data.managers, std::vector<std::string>({"A", "B"}));
 }
 
 void Lookup(gcs::AsyncGcsClient *client, const UniqueID &id,
-            std::shared_ptr<ObjectTableDataT> data) {
+            const ObjectTableDataT &data) {
   // Check that the object entry was added.
-  ASSERT_EQ(data->managers, std::vector<std::string>({"A", "B"}));
+  ASSERT_EQ(data.managers, std::vector<std::string>({"A", "B"}));
   test->Stop();
 }
@@ -126,14 +131,37 @@ TEST_F(TestGcsWithAsio, TestObjectTable) {
   TestObjectTable(job_id_, client_);
 }
 
+void TestLookupFailure(const JobID &job_id, std::shared_ptr<gcs::AsyncGcsClient> client) {
+  auto object_id = ObjectID::from_random();
+  // Looking up an empty object ID should call the failure callback.
+  auto failure_callback = [](gcs::AsyncGcsClient *client, const UniqueID &id) {
+    test->Stop();
+  };
+  RAY_CHECK_OK(
+      client->object_table().Lookup(job_id, object_id, nullptr, failure_callback));
+  // Run the event loop. The loop will only stop if the failure callback is
+  // called.
+  test->Start();
+}
+
+TEST_F(TestGcsWithAe, TestLookupFailure) {
+  test = this;
+  TestLookupFailure(job_id_, client_);
+}
+
+TEST_F(TestGcsWithAsio, TestLookupFailure) {
+  test = this;
+  TestLookupFailure(job_id_, client_);
+}
+
 void TaskAdded(gcs::AsyncGcsClient *client, const TaskID &id,
-               std::shared_ptr<TaskTableDataT> data) {
-  ASSERT_EQ(data->scheduling_state, SchedulingState_SCHEDULED);
+               const TaskTableDataT &data) {
+  ASSERT_EQ(data.scheduling_state, SchedulingState_SCHEDULED);
 }
 
 void TaskLookup(gcs::AsyncGcsClient *client, const TaskID &id,
-                std::shared_ptr<TaskTableDataT> data) {
-  ASSERT_EQ(data->scheduling_state, SchedulingState_SCHEDULED);
+                const TaskTableDataT &data) {
+  ASSERT_EQ(data.scheduling_state, SchedulingState_SCHEDULED);
 }
 
 void TaskLookupFailure(gcs::AsyncGcsClient *client, const TaskID &id) {
@@ -141,8 +169,8 @@ void TaskLookupFailure(gcs::AsyncGcsClient *client, const TaskID &id) {
 }
 
 void TaskLookupAfterUpdate(gcs::AsyncGcsClient *client, const TaskID &id,
-                           std::shared_ptr<TaskTableDataT> data) {
-  ASSERT_EQ(data->scheduling_state, SchedulingState_LOST);
+                           const TaskTableDataT &data) {
+  ASSERT_EQ(data.scheduling_state, SchedulingState_LOST);
   test->Stop();
 }
@@ -153,8 +181,8 @@ void TaskLookupAfterUpdateFailure(gcs::AsyncGcsClient *client, const TaskID &id)
 
 void TaskUpdateCallback(gcs::AsyncGcsClient *client, const TaskID &task_id,
                         const TaskTableDataT &task, bool updated) {
-  RAY_CHECK_OK(client->task_table().Lookup(
-      DriverID::nil(), task_id, &TaskLookupAfterUpdate, &TaskLookupAfterUpdateFailure));
+  RAY_CHECK_OK(client->task_table().Lookup(DriverID::nil(), task_id,
+                                           &TaskLookupAfterUpdate, &TaskLookupFailure));
 }
 
 void TestTaskTable(const JobID &job_id, std::shared_ptr<gcs::AsyncGcsClient> client) {
@@ -189,28 +217,40 @@ TEST_F(TestGcsWithAsio, TestTaskTable) {
   TestTaskTable(job_id_, client_);
 }
 
-void ObjectTableSubscribed(gcs::AsyncGcsClient *client, const UniqueID &id,
-                           std::shared_ptr<ObjectTableDataT> data) {
-  test->Stop();
-}
-
 void TestSubscribeAll(const JobID &job_id, std::shared_ptr<gcs::AsyncGcsClient> client) {
-  // Subscribe to all object table notifications. The registered callback for
-  // notifications will check whether the object below is added.
-  RAY_CHECK_OK(client->object_table().Subscribe(job_id, ClientID::nil(), &Lookup,
-                                                &ObjectTableSubscribed));
-  // Run the event loop. The loop will only stop if the subscription succeeds.
-  test->Start();
-
-  // We have subscribed. Add an object table entry.
-  auto data = std::make_shared<ObjectTableDataT>();
-  data->managers.push_back("A");
-  data->managers.push_back("B");
   ObjectID object_id = ObjectID::from_random();
-  RAY_CHECK_OK(client->object_table().Add(job_id, object_id, data, &ObjectAdded));
+  // Callback for a notification.
+  auto notification_callback = [object_id](
+      gcs::AsyncGcsClient *client, const UniqueID &id, const ObjectTableDataT &data) {
+    ASSERT_EQ(id, object_id);
+    // Check that the object entry was added.
+    ASSERT_EQ(data.managers, std::vector<std::string>({"A", "B"}));
+    test->IncrementNumCallbacks();
+    test->Stop();
+  };
+
+  // Callback for subscription success. This should only be called once.
+  auto subscribe_callback = [job_id, object_id](gcs::AsyncGcsClient *client) {
+    test->IncrementNumCallbacks();
+    // We have subscribed. Add an object table entry.
+    auto data = std::make_shared<ObjectTableDataT>();
+    data->managers.push_back("A");
+    data->managers.push_back("B");
+    RAY_CHECK_OK(client->object_table().Add(job_id, object_id, data, &ObjectAdded));
+  };
+
+  // Subscribe to all object table notifications. Once we have successfully
+  // subscribed, we will add an object and check that we get notified of the
+  // operation.
+  RAY_CHECK_OK(client->object_table().Subscribe(
+      job_id, ClientID::nil(), notification_callback, subscribe_callback));
+
+  // Run the event loop. The loop will only stop if the registered subscription
+  // callback is called (or an assertion failure).
+  test->Start();
+  // Check that we received one callback for subscription success and one for
+  // the Add notification.
+  ASSERT_EQ(test->NumCallbacks(), 2);
 }
 
 TEST_F(TestGcsWithAe, TestSubscribeAll) {
@@ -223,11 +263,152 @@ TEST_F(TestGcsWithAsio, TestSubscribeAll) {
   TestSubscribeAll(job_id_, client_);
 }
 
+void TestSubscribeId(const JobID &job_id, std::shared_ptr<gcs::AsyncGcsClient> client) {
+  // Add an object table entry.
+  ObjectID object_id1 = ObjectID::from_random();
+  auto data1 = std::make_shared<ObjectTableDataT>();
+  data1->managers.push_back("A");
+  data1->managers.push_back("B");
+  RAY_CHECK_OK(client->object_table().Add(job_id, object_id1, data1, nullptr));
+
+  // Add a second object table entry.
+  ObjectID object_id2 = ObjectID::from_random();
+  auto data2 = std::make_shared<ObjectTableDataT>();
+  data2->managers.push_back("C");
+  RAY_CHECK_OK(client->object_table().Add(job_id, object_id2, data2, nullptr));
+
+  // The callback for subscription success. Once we've subscribed, request
+  // notifications for the second object that was added.
+  auto subscribe_callback = [job_id, object_id2](gcs::AsyncGcsClient *client) {
+    test->IncrementNumCallbacks();
+    // Request notifications for the second object. Since we already added the
+    // entry to the table, we should receive an initial notification for its
+    // current value.
+    RAY_CHECK_OK(client->object_table().RequestNotifications(
+        job_id, object_id2, client->client_table().GetLocalClientId()));
+    // Overwrite the entry for the object. We should receive a second
+    // notification for its new value.
+    auto data = std::make_shared<ObjectTableDataT>();
+    data->managers.push_back("C");
+    data->managers.push_back("D");
+    RAY_CHECK_OK(client->object_table().Add(job_id, object_id2, data, nullptr));
+  };
+
+  // The callback for a notification from the object table. This should only be
+  // received for the object that we requested notifications for.
+  auto notification_callback = [data2, object_id2](
+      gcs::AsyncGcsClient *client, const UniqueID &id, const ObjectTableDataT &data) {
+    ASSERT_EQ(id, object_id2);
+    // Check that we got a notification for the correct object.
+    ASSERT_EQ(data.managers.front(), "C");
+    test->IncrementNumCallbacks();
+    // Stop the loop once we've received notifications for both writes to the
+    // object key.
+    if (test->NumCallbacks() == 3) {
+      test->Stop();
+    }
+  };
+
+  RAY_CHECK_OK(
+      client->object_table().Subscribe(job_id, client->client_table().GetLocalClientId(),
+                                       notification_callback, subscribe_callback));
+
+  // Run the event loop. The loop will only stop if the registered subscription
+  // callback is called for both writes to the object key.
+  test->Start();
+  // Check that we received one callback for subscription success and two
+  // callbacks for the Add notifications.
+  ASSERT_EQ(test->NumCallbacks(), 3);
+}
+
+TEST_F(TestGcsWithAe, TestSubscribeId) {
+  test = this;
+  TestSubscribeId(job_id_, client_);
+}
+
+TEST_F(TestGcsWithAsio, TestSubscribeId) {
+  test = this;
+  TestSubscribeId(job_id_, client_);
+}
+
+void TestSubscribeCancel(const JobID &job_id,
+                         std::shared_ptr<gcs::AsyncGcsClient> client) {
+  // Write the object table once.
+  ObjectID object_id = ObjectID::from_random();
+  auto data = std::make_shared<ObjectTableDataT>();
+  data->managers.push_back("A");
+  RAY_CHECK_OK(client->object_table().Add(job_id, object_id, data, nullptr));
+
+  // The callback for subscription success. Once we've subscribed, request
+  // notifications for the object that was added.
+  auto subscribe_callback = [job_id, object_id](gcs::AsyncGcsClient *client) {
+    test->IncrementNumCallbacks();
+    // Request notifications for the object. We should receive a notification
+    // for the current value at the key.
+    RAY_CHECK_OK(client->object_table().RequestNotifications(
+        job_id, object_id, client->client_table().GetLocalClientId()));
+    // Cancel notifications.
+    RAY_CHECK_OK(client->object_table().CancelNotifications(
+        job_id, object_id, client->client_table().GetLocalClientId()));
+    // Write the object table entry twice. Since we canceled notifications, we
+    // should not get notifications for either of these writes.
+    auto data = std::make_shared<ObjectTableDataT>();
+    data->managers.push_back("B");
+    RAY_CHECK_OK(client->object_table().Add(job_id, object_id, data, nullptr));
+    data = std::make_shared<ObjectTableDataT>();
+    data->managers.push_back("C");
+    RAY_CHECK_OK(client->object_table().Add(job_id, object_id, data, nullptr));
+    // Request notifications for the object again. We should only receive a
+    // notification for the current value at the key.
+    RAY_CHECK_OK(client->object_table().RequestNotifications(
+        job_id, object_id, client->client_table().GetLocalClientId()));
+  };
+
+  // The callback for a notification from the object table.
+  auto notification_callback = [object_id](
+      gcs::AsyncGcsClient *client, const UniqueID &id, const ObjectTableDataT &data) {
+    ASSERT_EQ(id, object_id);
+    // Check that we only receive notifications for the key when we have
+    // requested notifications for it. We should not get a notification for the
+    // entry that began with "B" since we canceled notifications then.
+    if (test->NumCallbacks() == 1) {
+      ASSERT_EQ(data.managers.front(), "A");
+    } else {
+      ASSERT_EQ(data.managers.front(), "C");
+    }
+    test->IncrementNumCallbacks();
+    if (test->NumCallbacks() == 3) {
+      test->Stop();
+    }
+  };
+
+  RAY_CHECK_OK(
+      client->object_table().Subscribe(job_id, client->client_table().GetLocalClientId(),
+                                       notification_callback, subscribe_callback));
+
+  // Run the event loop. The loop will only stop if the registered subscription
+  // callback is called (or an assertion failure).
+  test->Start();
+  // Check that we received one callback for subscription success and two
+  // callbacks for the Add notifications.
+  ASSERT_EQ(test->NumCallbacks(), 3);
+}
+
+TEST_F(TestGcsWithAe, TestSubscribeCancel) {
+  test = this;
+  TestSubscribeCancel(job_id_, client_);
+}
+
+TEST_F(TestGcsWithAsio, TestSubscribeCancel) {
+  test = this;
+  TestSubscribeCancel(job_id_, client_);
+}
+
 void ClientTableNotification(gcs::AsyncGcsClient *client, const UniqueID &id,
-                             std::shared_ptr<ClientTableDataT> data, bool is_insertion) {
+                             const ClientTableDataT &data, bool is_insertion) {
   ClientID added_id = client->client_table().GetLocalClientId();
-  ASSERT_EQ(ClientID::from_binary(data->client_id), added_id);
-  ASSERT_EQ(data->is_insertion, is_insertion);
+  ASSERT_EQ(ClientID::from_binary(data.client_id), added_id);
+  ASSERT_EQ(data.is_insertion, is_insertion);
 
   auto cached_client = client->client_table().GetClient(added_id);
   ASSERT_EQ(ClientID::from_binary(cached_client.client_id), added_id);
@@ -239,8 +420,7 @@ void TestClientTableConnect(const JobID &job_id,
   // Register callbacks for when a client gets added and removed. The latter
   // event will stop the event loop.
   client->client_table().RegisterClientAddedCallback(
-      [](gcs::AsyncGcsClient *client, const UniqueID &id,
-         std::shared_ptr<ClientTableDataT> data) {
+      [](gcs::AsyncGcsClient *client, const UniqueID &id, const ClientTableDataT &data) {
        ClientTableNotification(client, id, data, true);
        test->Stop();
      });
@@ -260,13 +440,11 @@ void TestClientTableDisconnect(const JobID &job_id,
   // Register callbacks for when a client gets added and removed. The latter
   // event will stop the event loop.
   client->client_table().RegisterClientAddedCallback(
-      [](gcs::AsyncGcsClient *client, const UniqueID &id,
-         std::shared_ptr<ClientTableDataT> data) {
+      [](gcs::AsyncGcsClient *client, const UniqueID &id, const ClientTableDataT &data) {
        ClientTableNotification(client, id, data, true);
      });
   client->client_table().RegisterClientRemovedCallback(
-      [](gcs::AsyncGcsClient *client, const UniqueID &id,
-         std::shared_ptr<ClientTableDataT> data) {
+      [](gcs::AsyncGcsClient *client, const UniqueID &id, const ClientTableDataT &data) {
        ClientTableNotification(client, id, data, false);
        test->Stop();
      });
@@ -21,6 +21,11 @@ enum TablePubsub:int {
   ACTOR
 }
 
+table GcsNotification {
+  id: string;
+  data: string;
+}
+
 table FunctionTableData {
   language: Language;
   name: string;
@@ -11,6 +11,22 @@ extern "C" {
 // TODO(pcm): Integrate into the C++ tree.
 #include "state/ray_config.h"
 
+namespace {
+
+/// A helper function to call the callback and delete it from the callback
+/// manager if necessary.
+void ProcessCallback(int64_t callback_index, const std::vector<std::string> &data) {
+  if (callback_index >= 0) {
+    bool delete_callback =
+        ray::gcs::RedisCallbackManager::instance().get(callback_index)(data);
+    // Delete the callback if necessary.
+    if (delete_callback) {
+      ray::gcs::RedisCallbackManager::instance().remove(callback_index);
+    }
+  }
+}
+
+}
+
 namespace ray {
 
 namespace gcs {
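ProcessCallback centralizes the new deletion protocol: a registered callback now reports, via its boolean return value, whether it is one-shot. A sketch of the two kinds of callback under the new RedisCallback signature, std::function<bool(const std::vector<std::string> &)>:

// Sketch: one-shot vs. long-lived callbacks with RedisCallbackManager.
int64_t lookup_cb = ray::gcs::RedisCallbackManager::instance().add(
    [](const std::vector<std::string> &data) {
      // ... consume the single lookup reply ...
      return true;  // one-shot: ProcessCallback removes it after this call
    });
int64_t subscribe_cb = ray::gcs::RedisCallbackManager::instance().add(
    [](const std::vector<std::string> &data) {
      // ... handle one published message ...
      return false;  // keep registered for future messages
    });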
@@ -24,24 +40,25 @@ void GlobalRedisCallback(void *c, void *r, void *privdata) {
   }
   int64_t callback_index = reinterpret_cast<int64_t>(privdata);
   redisReply *reply = reinterpret_cast<redisReply *>(r);
-  std::string data = "";
-  if (reply->type == REDIS_REPLY_NIL) {
-    // Respond with blank string, which triggers a failure callback for lookups.
-  } else if (reply->type == REDIS_REPLY_STRING) {
-    data = std::string(reply->str, reply->len);
-  } else if (reply->type == REDIS_REPLY_ARRAY) {
-    reply = reply->element[reply->elements - 1];
-    data = std::string(reply->str, reply->len);
-  } else if (reply->type == REDIS_REPLY_STATUS) {
-  } else if (reply->type == REDIS_REPLY_ERROR) {
+  std::vector<std::string> data;
+  // Parse the response.
+  switch (reply->type) {
+  case (REDIS_REPLY_NIL): {
+    // Do not add any data for a nil response.
+  } break;
+  case (REDIS_REPLY_STRING): {
+    data.push_back(std::string(reply->str, reply->len));
+  } break;
+  case (REDIS_REPLY_STATUS): {
+  } break;
+  case (REDIS_REPLY_ERROR): {
     RAY_LOG(ERROR) << "Redis error " << reply->str;
-  } else {
+  } break;
+  default:
    RAY_LOG(FATAL) << "Fatal redis error of type " << reply->type
                   << " and with string " << reply->str;
   }
-  RedisCallbackManager::instance().get(callback_index)(data);
-  // Delete the callback.
-  RedisCallbackManager::instance().remove(callback_index);
+  ProcessCallback(callback_index, data);
 }
 
 void SubscribeRedisCallback(void *c, void *r, void *privdata) {
@@ -50,31 +67,35 @@ void SubscribeRedisCallback(void *c, void *r, void *privdata) {
   }
   int64_t callback_index = reinterpret_cast<int64_t>(privdata);
   redisReply *reply = reinterpret_cast<redisReply *>(r);
-  std::string data = "";
-  if (reply->type == REDIS_REPLY_ARRAY) {
-    // Parse the message.
+  std::vector<std::string> data;
+  // Parse the response.
+  switch (reply->type) {
+  case (REDIS_REPLY_ARRAY): {
+    // Parse the published message.
     redisReply *message_type = reply->element[0];
     if (strcmp(message_type->str, "subscribe") == 0) {
-      // If the message is for the initial subscription call, do not fill in
-      // data.
+      // If the message is for the initial subscription call, return the empty
+      // string as a response to signify that subscription was successful.
+      data.push_back("");
    } else if (strcmp(message_type->str, "message") == 0) {
      // If the message is from a PUBLISH, make sure the data is nonempty.
      redisReply *message = reply->element[reply->elements - 1];
-      data = std::string(message->str, message->len);
-      RAY_CHECK(!data.empty()) << "Empty message received on subscribe channel";
+      auto notification = std::string(message->str, message->len);
+      RAY_CHECK(!notification.empty()) << "Empty message received on subscribe channel";
+      data.push_back(notification);
    } else {
      RAY_LOG(FATAL) << "Fatal redis error during subscribe" << message_type->str;
    }
-
-    // NOTE(swang): We do not delete the callback after calling it since there
-    // may be more subscription messages.
-    RedisCallbackManager::instance().get(callback_index)(data);
-  } else if (reply->type == REDIS_REPLY_ERROR) {
+  } break;
+  case (REDIS_REPLY_ERROR): {
    RAY_LOG(ERROR) << "Redis error " << reply->str;
-  } else {
+  } break;
+  default:
    RAY_LOG(FATAL) << "Fatal redis error of type " << reply->type << " and with string "
                   << reply->str;
  }
+  ProcessCallback(callback_index, data);
 }
 
 int64_t RedisCallbackManager::add(const RedisCallback &function) {
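The empty string pushed for the initial "subscribe" reply is the sentinel that Table::Subscribe (in the tables.cc hunk below) uses to tell subscription success apart from a real notification. A consumer-side sketch of the convention:

// Sketch: distinguishing the subscription acknowledgment from notifications,
// given `data` as delivered by SubscribeRedisCallback via ProcessCallback.
bool HandleSubscribeReply(const std::vector<std::string> &data) {
  if (data.size() == 1 && data[0].empty()) {
    // Initial SUBSCRIBE acknowledgment: no payload yet.
  } else {
    // One published GcsNotification payload.
  }
  return false;  // subscription callbacks stay registered
}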
@@ -161,8 +182,9 @@ Status RedisContext::AttachToEventLoop(aeEventLoop *loop) {
 }
 
 Status RedisContext::RunAsync(const std::string &command, const UniqueID &id,
-                              uint8_t *data, int64_t length, const TablePrefix prefix,
-                              const TablePubsub pubsub_channel, int64_t callback_index) {
+                              const uint8_t *data, int64_t length,
+                              const TablePrefix prefix, const TablePubsub pubsub_channel,
+                              int64_t callback_index) {
   if (length > 0) {
     std::string redis_command = command + " %d %d %b %b";
     int status = redisAsyncCommand(
@@ -200,7 +222,6 @@ Status RedisContext::SubscribeAsync(const ClientID &client_id,
         reinterpret_cast<void *>(callback_index), redis_command.c_str(), pubsub_channel);
   } else {
     // Subscribe only to messages sent to this client.
-    // TODO(swang): Nobody sends on this channel yet.
     std::string redis_command = "SUBSCRIBE %d:%b";
     status = redisAsyncCommand(
         subscribe_context_, reinterpret_cast<redisCallbackFn *>(&SubscribeRedisCallback),
@@ -21,7 +21,10 @@ namespace gcs {
 
 class RedisCallbackManager {
  public:
-  using RedisCallback = std::function<void(const std::string &)>;
+  /// Every callback should take in a vector of the results from the Redis
+  /// operation and return a bool indicating whether the callback should be
+  /// deleted once called.
+  using RedisCallback = std::function<bool(const std::vector<std::string> &)>;
 
   static RedisCallbackManager &instance() {
     static RedisCallbackManager instance;
@@ -50,7 +53,7 @@ class RedisContext {
   ~RedisContext();
   Status Connect(const std::string &address, int port);
   Status AttachToEventLoop(aeEventLoop *loop);
-  Status RunAsync(const std::string &command, const UniqueID &id, uint8_t *data,
+  Status RunAsync(const std::string &command, const UniqueID &id, const uint8_t *data,
                   int64_t length, const TablePrefix prefix,
                   const TablePubsub pubsub_channel, int64_t callback_index);
   Status SubscribeAsync(const ClientID &client_id, const TablePubsub pubsub_channel,
@@ -1,36 +1,141 @@
 #include "ray/gcs/tables.h"
 
+#include "common_protocol.h"
 #include "ray/gcs/client.h"
 
 namespace ray {
 
 namespace gcs {
 
-void ClientTable::RegisterClientAddedCallback(const Callback &callback) {
+template <typename ID, typename Data>
+Status Table<ID, Data>::Add(const JobID &job_id, const ID &id,
+                            std::shared_ptr<DataT> data, const Callback &done) {
+  auto d = std::shared_ptr<CallbackData>(
+      new CallbackData({id, data, done, nullptr, nullptr, this, client_}));
+  int64_t callback_index =
+      RedisCallbackManager::instance().add([d](const std::vector<std::string> &data) {
+        if (d->callback != nullptr) {
+          (d->callback)(d->client, d->id, *d->data);
+        }
+        return true;
+      });
+  flatbuffers::FlatBufferBuilder fbb;
+  fbb.ForceDefaults(true);
+  fbb.Finish(Data::Pack(fbb, data.get()));
+  return context_->RunAsync("RAY.TABLE_ADD", id, fbb.GetBufferPointer(), fbb.GetSize(),
+                            prefix_, pubsub_channel_, callback_index);
+}
+
+template <typename ID, typename Data>
+Status Table<ID, Data>::Lookup(const JobID &job_id, const ID &id, const Callback &lookup,
+                               const FailureCallback &failure) {
+  auto d = std::shared_ptr<CallbackData>(
+      new CallbackData({id, nullptr, lookup, failure, nullptr, this, client_}));
+  int64_t callback_index =
+      RedisCallbackManager::instance().add([d](const std::vector<std::string> &data) {
+        if (data.empty()) {
+          if (d->failure != nullptr) {
+            (d->failure)(d->client, d->id);
+          }
+        } else {
+          RAY_CHECK(data.size() == 1);
+          if (d->callback != nullptr) {
+            DataT result;
+            auto root = flatbuffers::GetRoot<Data>(data[0].data());
+            root->UnPackTo(&result);
+            (d->callback)(d->client, d->id, result);
+          }
+        }
+        return true;
+      });
+  std::vector<uint8_t> nil;
+  return context_->RunAsync("RAY.TABLE_LOOKUP", id, nil.data(), nil.size(), prefix_,
+                            pubsub_channel_, callback_index);
+}
+
+template <typename ID, typename Data>
+Status Table<ID, Data>::Subscribe(const JobID &job_id, const ClientID &client_id,
+                                  const Callback &subscribe,
+                                  const SubscriptionCallback &done) {
+  RAY_CHECK(subscribe_callback_index_ == -1)
+      << "Client called Subscribe twice on the same table";
+  auto d = std::shared_ptr<CallbackData>(
+      new CallbackData({client_id, nullptr, subscribe, nullptr, done, this, client_}));
+  int64_t callback_index = RedisCallbackManager::instance().add(
+      [this, d](const std::vector<std::string> &data) {
+        if (data.size() == 1 && data[0] == "") {
+          // No notification data is provided. This is the callback for the
+          // initial subscription request.
+          if (d->subscription_callback != nullptr) {
+            (d->subscription_callback)(d->client);
+          }
+        } else {
+          // Data is provided. This is the callback for a message.
+          RAY_CHECK(data.size() == 1);
+          if (d->callback != nullptr) {
+            // Parse the notification.
+            auto notification = flatbuffers::GetRoot<GcsNotification>(data[0].data());
+            ID id = UniqueID::nil();
+            if (notification->id()->size() > 0) {
+              id = from_flatbuf(*notification->id());
+            }
+            DataT result;
+            auto root = flatbuffers::GetRoot<Data>(notification->data()->data());
+            root->UnPackTo(&result);
+            (d->callback)(d->client, id, result);
+          }
+        }
+        // We do not delete the callback after calling it since there may be
+        // more subscription messages.
+        return false;
+      });
+  subscribe_callback_index_ = callback_index;
+  return context_->SubscribeAsync(client_id, pubsub_channel_, callback_index);
+}
+
+template <typename ID, typename Data>
+Status Table<ID, Data>::RequestNotifications(const JobID &job_id, const ID &id,
+                                             const ClientID &client_id) {
+  RAY_CHECK(subscribe_callback_index_ >= 0)
+      << "Client requested notifications on a key before Subscribe completed";
+  return context_->RunAsync("RAY.TABLE_REQUEST_NOTIFICATIONS", id, client_id.data(),
+                            client_id.size(), prefix_, pubsub_channel_,
+                            subscribe_callback_index_);
+}
+
+template <typename ID, typename Data>
+Status Table<ID, Data>::CancelNotifications(const JobID &job_id, const ID &id,
+                                            const ClientID &client_id) {
+  RAY_CHECK(subscribe_callback_index_ >= 0)
+      << "Client canceled notifications on a key before Subscribe completed";
+  return context_->RunAsync("RAY.TABLE_CANCEL_NOTIFICATIONS", id, client_id.data(),
+                            client_id.size(), prefix_, pubsub_channel_,
+                            /*callback_index=*/-1);
+}
+
+void ClientTable::RegisterClientAddedCallback(const ClientTableCallback &callback) {
   client_added_callback_ = callback;
   // Call the callback for any added clients that are cached.
   for (const auto &entry : client_cache_) {
     if (!entry.first.is_nil() && entry.second.is_insertion) {
-      auto data = std::make_shared<ClientTableDataT>(entry.second);
-      client_added_callback_(client_, entry.first, data);
+      client_added_callback_(client_, ClientID::nil(), entry.second);
     }
   }
 }
 
-void ClientTable::RegisterClientRemovedCallback(const Callback &callback) {
+void ClientTable::RegisterClientRemovedCallback(const ClientTableCallback &callback) {
   client_removed_callback_ = callback;
   // Call the callback for any removed clients that are cached.
   for (const auto &entry : client_cache_) {
     if (!entry.first.is_nil() && !entry.second.is_insertion) {
-      auto data = std::make_shared<ClientTableDataT>(entry.second);
-      client_removed_callback_(client_, entry.first, data);
+      client_removed_callback_(client_, ClientID::nil(), entry.second);
     }
   }
 }
 
 void ClientTable::HandleNotification(AsyncGcsClient *client, const ClientID &channel_id,
-                                     std::shared_ptr<ClientTableDataT> data) {
-  ClientID client_id = ClientID::from_binary(data->client_id);
+                                     const ClientTableDataT &data) {
+  ClientID client_id = ClientID::from_binary(data.client_id);
   // It's possible to get duplicate notifications from the client table, so
   // check whether this notification is new.
   auto entry = client_cache_.find(client_id);
@@ -42,24 +147,24 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, const ClientID &channel_id,
     // If the entry is in the cache, then the notification is new if the client
     // was alive and is now dead.
     bool was_inserted = entry->second.is_insertion;
-    bool is_deleted = !data->is_insertion;
+    bool is_deleted = !data.is_insertion;
     is_new = (was_inserted && is_deleted);
     // Once a client with a given ID has been removed, it should never be added
     // again. If the entry was in the cache and the client was deleted, check
     // that this new notification is not an insertion.
     if (!entry->second.is_insertion) {
-      RAY_CHECK(!data->is_insertion)
+      RAY_CHECK(!data.is_insertion)
          << "Notification for addition of a client that was already removed:"
          << client_id.hex();
    }
  }
 
   // Add the notification to our cache. Notifications are idempotent.
-  client_cache_[client_id] = *data;
+  client_cache_[client_id] = data;
 
   // If the notification is new, call any registered callbacks.
   if (is_new) {
-    if (data->is_insertion) {
+    if (data.is_insertion) {
      if (client_added_callback_ != nullptr) {
        client_added_callback_(client, client_id, data);
      }
@@ -72,7 +177,7 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, const ClientID &channel_id,
 }
 
 void ClientTable::HandleConnected(AsyncGcsClient *client, const ClientID &client_id,
-                                  std::shared_ptr<ClientTableDataT> data) {
+                                  const ClientTableDataT &data) {
   RAY_CHECK(client_id == client_id_) << client_id.hex() << " " << client_id_.hex();
 }
 
@@ -87,18 +192,17 @@ Status ClientTable::Connect() {
   data->is_insertion = true;
   // Callback for a notification from the client table.
   auto notification_callback = [this](AsyncGcsClient *client, const ClientID &channel_id,
-                                      std::shared_ptr<ClientTableDataT> data) {
+                                      const ClientTableDataT &data) {
     return HandleNotification(client, channel_id, data);
   };
   // Callback to handle our own successful connection once we've added
   // ourselves.
   auto add_callback = [this](AsyncGcsClient *client, const ClientID &id,
-                             std::shared_ptr<ClientTableDataT> data) {
+                             const ClientTableDataT &data) {
     HandleConnected(client, id, data);
   };
   // Callback to add ourselves once we've successfully subscribed.
-  auto subscription_callback = [this, data, add_callback](
-      AsyncGcsClient *c, const ClientID &id, std::shared_ptr<ClientTableDataT> d) {
+  auto subscription_callback = [this, data, add_callback](AsyncGcsClient *c) {
    // Mark ourselves as deleted if we called Disconnect() since the last
    // Connect() call.
    if (disconnected_) {
@ -114,7 +218,7 @@ Status ClientTable::Disconnect() {
|
|||
auto data = std::make_shared<ClientTableDataT>(local_client_);
|
||||
data->is_insertion = true;
|
||||
auto add_callback = [this](AsyncGcsClient *client, const ClientID &id,
|
||||
std::shared_ptr<ClientTableDataT> data) {
|
||||
const ClientTableDataT &data) {
|
||||
HandleConnected(client, id, data);
|
||||
};
|
||||
RAY_RETURN_NOT_OK(Add(JobID::nil(), client_id_, data, add_callback));
|
||||
|
@ -135,6 +239,9 @@ const ClientTableDataT &ClientTable::GetClient(const ClientID &client_id) {
|
|||
}
|
||||
}
|
||||
|
||||
template class Table<TaskID, TaskTableData>;
|
||||
template class Table<ObjectID, ObjectTableData>;
|
||||
|
||||
} // namespace gcs
|
||||
|
||||
} // namespace ray
|
||||
|
|
|
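The Connect()/Disconnect() hunks above imply the following lifecycle; a hedged driver-side sketch, where `context` and `async_client` are placeholders for an established RedisContext and AsyncGcsClient, and error handling is elided:

// Sketch only: describe this process, then announce it via the client table.
ClientTableDataT local_client;
local_client.client_id = ClientID::from_random().binary();
gcs::ClientTable client_table(context, &async_client, local_client);
RAY_CHECK_OK(client_table.Connect());
// HandleConnected() fires once our own Add is echoed back through the
// subscription. Later, Disconnect() publishes an entry marking our departure.
RAY_CHECK_OK(client_table.Disconnect());
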
@@ -30,9 +30,14 @@ template <typename ID, typename Data>
class Table {
 public:
  using DataT = typename Data::NativeTableType;
-  using Callback = std::function<void(AsyncGcsClient *client, const ID &id,
-                                      std::shared_ptr<DataT> data)>;
+  using Callback =
+      std::function<void(AsyncGcsClient *client, const ID &id, const DataT &data)>;
+  /// The callback to call when a lookup fails because there is no entry at the
+  /// key.
  using FailureCallback = std::function<void(AsyncGcsClient *client, const ID &id)>;
+  /// The callback to call when a SUBSCRIBE call completes and we are ready to
+  /// request and receive notifications.
+  using SubscriptionCallback = std::function<void(AsyncGcsClient *client)>;

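Concretely, the three callback types can be satisfied with lambdas like the following (a sketch for a hypothetical Table<ObjectID, ObjectTableData>; only the signatures come from this diff):

// Per-entry callback: the data now arrives by const reference, not shared_ptr.
Table<ObjectID, ObjectTableData>::Callback on_data =
    [](AsyncGcsClient *client, const ObjectID &id, const ObjectTableDataT &data) {};
// Called when a lookup finds no entry at the key.
Table<ObjectID, ObjectTableData>::FailureCallback on_failure =
    [](AsyncGcsClient *client, const ObjectID &id) {};
// Called once the SUBSCRIBE completes; no table data accompanies it, which is
// why it receives only the client.
Table<ObjectID, ObjectTableData>::SubscriptionCallback on_subscribed =
    [](AsyncGcsClient *client) {};
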
  struct CallbackData {
    ID id;

@@ -41,7 +46,7 @@ class Table {
    FailureCallback failure;
    // An optional callback to call for subscription operations, where the
    // first message is a notification of subscription success.
-    Callback subscription_callback;
+    SubscriptionCallback subscription_callback;
    Table<ID, Data> *table;
    AsyncGcsClient *client;
  };

@@ -50,7 +55,8 @@ class Table {
      : context_(context),
        client_(client),
        pubsub_channel_(TablePubsub_NO_PUBLISH),
-        prefix_(TablePrefix_UNUSED){};
+        prefix_(TablePrefix_UNUSED),
+        subscribe_callback_index_(-1){};

  /// Add an entry to the table.
  ///

@@ -61,91 +67,63 @@ class Table {
  /// GCS.
  /// \return Status
  Status Add(const JobID &job_id, const ID &id, std::shared_ptr<DataT> data,
-             const Callback &done) {
-    auto d = std::shared_ptr<CallbackData>(
-        new CallbackData({id, data, done, nullptr, nullptr, this, client_}));
-    int64_t callback_index =
-        RedisCallbackManager::instance().add([d](const std::string &data) {
-          if (d->callback != nullptr) {
-            (d->callback)(d->client, d->id, d->data);
-          }
-        });
-    flatbuffers::FlatBufferBuilder fbb;
-    fbb.ForceDefaults(true);
-    fbb.Finish(Data::Pack(fbb, data.get()));
-    RAY_RETURN_NOT_OK(context_->RunAsync("RAY.TABLE_ADD", id, fbb.GetBufferPointer(),
-                                         fbb.GetSize(), prefix_, pubsub_channel_,
-                                         callback_index));
-    return Status::OK();
-  }
+             const Callback &done);

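A hedged call-site sketch for the Add declaration above (the object table, `object_id`, and the lambda body are illustrative):

// Sketch only: append one entry and get notified when the write completes.
auto data = std::make_shared<ObjectTableDataT>();
// ... populate the entry's fields ...
RAY_CHECK_OK(client->object_table().Add(
    JobID::nil(), object_id, data,
    [](gcs::AsyncGcsClient *client, const ObjectID &id,
       const ObjectTableDataT &data) {
      // The entry has been appended to the GCS.
    }));
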
  /// Lookup an entry asynchronously.
  ///
  /// \param job_id The ID of the job (= driver).
  /// \param id The ID of the data that is looked up in the GCS.
-  /// \param lookup Callback that is called after lookup.
+  /// \param lookup Callback that is called after lookup. If the callback is
+  ///        called with an empty vector, then there was no data at the key.
  /// \return Status
  Status Lookup(const JobID &job_id, const ID &id, const Callback &lookup,
-                const FailureCallback &failure) {
-    auto d = std::shared_ptr<CallbackData>(
-        new CallbackData({id, nullptr, lookup, failure, nullptr, this, client_}));
-    int64_t callback_index =
-        RedisCallbackManager::instance().add([d](const std::string &data) {
-          if (data.empty()) {
-            if (d->failure != nullptr) {
-              (d->failure)(d->client, d->id);
-            }
-          } else {
-            auto result = std::make_shared<DataT>();
-            auto root = flatbuffers::GetRoot<Data>(data.data());
-            root->UnPackTo(result.get());
-            if (d->callback != nullptr) {
-              (d->callback)(d->client, d->id, result);
-            }
-          }
-        });
-    std::vector<uint8_t> nil;
-    RAY_RETURN_NOT_OK(context_->RunAsync("RAY.TABLE_LOOKUP", id, nil.data(), nil.size(),
-                                         prefix_, pubsub_channel_, callback_index));
-    return Status::OK();
-  }
+                const FailureCallback &failure);

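A matching sketch for Lookup, under the same illustrative setup; the failure callback fires when there is no entry at the key:

RAY_CHECK_OK(client->object_table().Lookup(
    JobID::nil(), object_id,
    [](gcs::AsyncGcsClient *client, const ObjectID &id,
       const ObjectTableDataT &data) {
      // An entry was found at the key.
    },
    [](gcs::AsyncGcsClient *client, const ObjectID &id) {
      // No entry exists at the key.
    }));
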
-  /// Subscribe to updates of this table
+  /// Subscribe to any Add operations to this table. The caller may choose to
+  /// subscribe to all Adds, or to subscribe only to keys that it requests
+  /// notifications for. This may only be called once per Table instance.
  ///
  /// \param job_id The ID of the job (= driver).
  /// \param client_id The type of update to listen to. If this is nil, then a
  ///        message for each Add to the table will be received. Else, only
-  ///        messages for the given client will be received.
-  /// \param subscribe Callback that is called on each received message.
+  ///        messages for the given client will be received. In the latter
+  ///        case, the client may request notifications on specific keys in the
+  ///        table via `RequestNotifications`.
+  /// \param subscribe Callback that is called on each received message. If the
+  ///        callback is called with an empty vector, then there was no data at
+  ///        the key.
  /// \param done Callback that is called when subscription is complete and we
-  ///        are ready to receive messages..
+  ///        are ready to receive messages.
  /// \return Status
  Status Subscribe(const JobID &job_id, const ClientID &client_id,
-                   const Callback &subscribe, const Callback &done) {
-    auto d = std::shared_ptr<CallbackData>(
-        new CallbackData({client_id, nullptr, subscribe, nullptr, done, this, client_}));
-    int64_t callback_index =
-        RedisCallbackManager::instance().add([d](const std::string &data) {
-          if (data.empty()) {
-            // No data is provided. This is the callback for the initial
-            // subscription request.
-            if (d->subscription_callback != nullptr) {
-              (d->subscription_callback)(d->client, d->id, nullptr);
-            }
-          } else {
-            // Data is provided. This is the callback for a message.
-            auto result = std::make_shared<DataT>();
-            auto root = flatbuffers::GetRoot<Data>(data.data());
-            root->UnPackTo(result.get());
-            (d->callback)(d->client, d->id, result);
-          }
-        });
-    std::vector<uint8_t> nil;
-    return context_->SubscribeAsync(client_id, pubsub_channel_, callback_index);
-  }
+                   const Callback &subscribe, const SubscriptionCallback &done);

-  /// Remove an entry from the table.
+  /// Remove an entry from the table.
  Status Remove(const JobID &job_id, const ID &id, const Callback &done);

+  /// Request notifications about a key in this table.
+  ///
+  /// The notifications will be returned via the subscribe callback that was
+  /// registered by `Subscribe`. An initial notification will be returned for
+  /// the current value(s) at the key, if any, and a subsequent notification
+  /// will be published for every following `Add` to the key. Before
+  /// notifications can be requested, the caller must first call `Subscribe`,
+  /// with the same `client_id`.
+  ///
+  /// \param job_id The ID of the job (= driver).
+  /// \param id The ID of the key to request notifications for.
+  /// \param client_id The client who is requesting notifications. Before
+  ///        notifications can be requested, a call to `Subscribe` to this
+  ///        table with the same `client_id` must complete successfully.
+  /// \return Status
+  Status RequestNotifications(const JobID &job_id, const ID &id,
+                              const ClientID &client_id);

+  /// Cancel notifications about a key in this table.
+  ///
+  /// \param job_id The ID of the job (= driver).
+  /// \param id The ID of the key to cancel notifications for.
+  /// \param client_id The client who originally requested notifications.
+  /// \return Status
+  Status CancelNotifications(const JobID &job_id, const ID &id,
+                             const ClientID &client_id);

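Putting the three declarations together, a hedged end-to-end sketch of the key-specific notification flow (the table choice, `object_id`, and `client_id` are illustrative; the ordering constraint, that Subscribe must complete before RequestNotifications, comes from the doc comments above):

// 1. Subscribe with our own (non-nil) client ID, so only keys we explicitly
//    request are pushed to us.
auto subscribe = [](gcs::AsyncGcsClient *client, const ObjectID &id,
                    const ObjectTableDataT &data) {
  // Called with the current value at the key, then for each subsequent Add.
};
// 2. Only request notifications once the subscription is confirmed.
auto done = [object_id, client_id](gcs::AsyncGcsClient *client) {
  RAY_CHECK_OK(client->object_table().RequestNotifications(JobID::nil(),
                                                           object_id, client_id));
};
RAY_CHECK_OK(
    client->object_table().Subscribe(JobID::nil(), client_id, subscribe, done));
// 3. Later, stop receiving updates for that key.
RAY_CHECK_OK(client->object_table().CancelNotifications(JobID::nil(), object_id,
                                                        client_id));
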
 protected:
  /// The connection to the GCS.

@@ -158,6 +136,10 @@ class Table {
  TablePubsub pubsub_channel_;
  /// The prefix to use for keys in this table.
  TablePrefix prefix_;
+  /// The index in the RedisCallbackManager for the callback that is called
+  /// when we receive notifications. This is >= 0 iff we have subscribed to the
+  /// table, otherwise -1.
+  int64_t subscribe_callback_index_;
};

class ObjectTable : public Table<ObjectID, ObjectTableData> {

@@ -167,31 +149,6 @@ class ObjectTable : public Table<ObjectID, ObjectTableData> {
    pubsub_channel_ = TablePubsub_OBJECT;
    prefix_ = TablePrefix_OBJECT;
  };

-  /// Set up a client-specific channel for receiving notifications about
-  /// available
-  /// objects from the object table. The callback will be called once per
-  /// notification received on this channel.
-  ///
-  /// \param subscribe_all
-  /// \param object_available_callback Callback to be called when new object
-  ///        becomes available.
-  /// \param done_callback Callback to be called when subscription is installed.
-  ///        This is only used for the tests.
-  /// \return Status
-  Status SubscribeToNotifications(const JobID &job_id, bool subscribe_all,
-                                  const Callback &object_available, const Callback &done);
-
-  /// Request notifications about the availability of some objects from the
-  /// object
-  /// table. The notifications will be published to this client's object
-  /// notification channel, which was set up by the method
-  /// ObjectTableSubscribeToNotifications.
-  ///
-  /// \param object_ids The object IDs to receive notifications about.
-  /// \return Status
-  Status RequestNotifications(const JobID &job_id,
-                              const std::vector<ObjectID> &object_ids);
};

class FunctionTable : public Table<ObjectID, FunctionTableData> {

@@ -240,11 +197,13 @@ class TaskTable : public Table<TaskID, TaskTableData> {
                          std::shared_ptr<TaskTableTestAndUpdateT> data,
                          const TestAndUpdateCallback &callback) {
    int64_t callback_index = RedisCallbackManager::instance().add(
-        [this, callback, id](const std::string &data) {
+        [this, callback, id](const std::vector<std::string> &data) {
+          RAY_CHECK(data.size() == 1);
          auto result = std::make_shared<TaskTableDataT>();
-          auto root = flatbuffers::GetRoot<TaskTableData>(data.data());
+          auto root = flatbuffers::GetRoot<TaskTableData>(data[0].data());
          root->UnPackTo(result.get());
          callback(client_, id, *result, root->updated());
+          return true;
        });
    flatbuffers::FlatBufferBuilder fbb;
    fbb.Finish(TaskTableTestAndUpdate::Pack(fbb, data.get()));

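The lambda change above reflects the RedisCallbackManager interface now delivering each Redis reply as a vector of serialized strings; the bool return value and its meaning are assumptions in the sketch below:

// Sketch under assumptions: each element is one flatbuffer-serialized entry,
// and returning true is taken to mean the one-shot callback is finished.
int64_t callback_index = RedisCallbackManager::instance().add(
    [](const std::vector<std::string> &data) {
      for (const auto &serialized : data) {
        // Parse with flatbuffers::GetRoot<...>(serialized.data()) as above.
      }
      return true;
    });
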
@@ -293,6 +252,8 @@ Status TaskTableTestAndUpdate(AsyncGcsClient *gcs_client, const TaskID &task_id,

class ClientTable : private Table<ClientID, ClientTableData> {
 public:
+  using ClientTableCallback = std::function<void(
+      AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data)>;
  ClientTable(const std::shared_ptr<RedisContext> &context, AsyncGcsClient *client,
              const ClientTableDataT &local_client)
      : Table(context, client),

@@ -324,12 +285,12 @@ class ClientTable : private Table<ClientID, ClientTableData> {
  /// Register a callback to call when a new client is added.
  ///
  /// \param callback The callback to register.
-  void RegisterClientAddedCallback(const Callback &callback);
+  void RegisterClientAddedCallback(const ClientTableCallback &callback);

  /// Register a callback to call when a client is removed.
  ///
  /// \param callback The callback to register.
-  void RegisterClientRemovedCallback(const Callback &callback);
+  void RegisterClientRemovedCallback(const ClientTableCallback &callback);

  /// Get a client's information from the cache. The cache only contains
  /// information for clients that we've heard a notification for.

@@ -352,10 +313,10 @@ class ClientTable : private Table<ClientID, ClientTableData> {
 private:
  /// Handle a client table notification.
  void HandleNotification(AsyncGcsClient *client, const ClientID &channel_id,
-                          std::shared_ptr<ClientTableDataT>);
+                          const ClientTableDataT &notifications);
  /// Handle this client's successful connection to the GCS.
  void HandleConnected(AsyncGcsClient *client, const ClientID &client_id,
-                       std::shared_ptr<ClientTableDataT>);
+                       const ClientTableDataT &notifications);

  /// Whether this client has called Disconnect().
  bool disconnected_;

@@ -364,9 +325,9 @@ class ClientTable : private Table<ClientID, ClientTableData> {
  /// Information about this client.
  ClientTableDataT local_client_;
  /// The callback to call when a new client is added.
-  Callback client_added_callback_;
+  ClientTableCallback client_added_callback_;
  /// The callback to call when a client is removed.
-  Callback client_removed_callback_;
+  ClientTableCallback client_removed_callback_;
  /// A cache for information about all clients.
  std::unordered_map<ClientID, ClientTableDataT, UniqueIDHasher> client_cache_;
};

@@ -44,9 +44,9 @@ Status TaskTableAdd(AsyncGcsClient *gcs_client, Task *task) {
  TaskSpec *spec = execution_spec.Spec();
  auto data = MakeTaskTableData(execution_spec, Task_local_scheduler(task),
                                static_cast<SchedulingState>(Task_state(task)));
-  return gcs_client->task_table().Add(ray::JobID::nil(), TaskSpec_task_id(spec), data,
-                                      [](gcs::AsyncGcsClient *client, const TaskID &id,
-                                         std::shared_ptr<TaskTableDataT> data) {});
+  return gcs_client->task_table().Add(
+      ray::JobID::nil(), TaskSpec_task_id(spec), data,
+      [](gcs::AsyncGcsClient *client, const TaskID &id, const TaskTableDataT &data) {});
}

// TODO(pcm): This is a helper method that should go away once we get rid of