diff --git a/src/ray/gcs/redis_module/ray_redis_module.cc b/src/ray/gcs/redis_module/ray_redis_module.cc index 8c1c6bd2a..75cc4500b 100644 --- a/src/ray/gcs/redis_module/ray_redis_module.cc +++ b/src/ray/gcs/redis_module/ray_redis_module.cc @@ -39,6 +39,9 @@ extern RedisChainModule module; } \ } +/// Map from pub sub channel to clients that are waiting on that channel. +std::unordered_map> notification_map; + /// Parse a Redis string into a TablePubsub channel. Status ParseTablePubsub(TablePubsub *out, const RedisModuleString *pubsub_channel_str) { long long pubsub_channel_long; @@ -138,14 +141,12 @@ Status OpenPrefixedKey(RedisModuleKey **out, RedisModuleCtx *ctx, /// Open the key used to store the channels that should be published to when an /// update happens at the given keyname. -Status OpenBroadcastKey(RedisModuleKey **out, RedisModuleCtx *ctx, - RedisModuleString *pubsub_channel_str, RedisModuleString *keyname, - int mode) { +Status GetBroadcastKey(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_str, + RedisModuleString *keyname, std::string *out) { RedisModuleString *channel; RAY_RETURN_NOT_OK(FormatPubsubChannel(&channel, ctx, pubsub_channel_str, keyname)); RedisModuleString *prefixed_keyname = RedisString_Format(ctx, "BCAST:%S", channel); - *out = reinterpret_cast( - RedisModule_OpenKey(ctx, prefixed_keyname, mode)); + *out = RedisString_ToString(prefixed_keyname); return Status::OK(); } @@ -192,22 +193,20 @@ int PublishTableAdd(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_str, return RedisModule_ReplyWithError(ctx, "error during PUBLISH"); } + std::string notification_key; + REPLY_AND_RETURN_IF_NOT_OK( + GetBroadcastKey(ctx, pubsub_channel_str, id, ¬ification_key)); // Publish the data to any clients who requested notifications on this key. - RedisModuleKey *notification_key; - REPLY_AND_RETURN_IF_NOT_OK(OpenBroadcastKey(¬ification_key, ctx, pubsub_channel_str, - id, REDISMODULE_READ | REDISMODULE_WRITE)); - if (RedisModule_KeyType(notification_key) != REDISMODULE_KEYTYPE_EMPTY) { - // NOTE(swang): Sets are not implemented yet, so we use ZSETs instead. - REPLY_AND_RETURN_IF_FALSE(RedisModule_ZsetFirstInScoreRange( - notification_key, REDISMODULE_NEGATIVE_INFINITE, - REDISMODULE_POSITIVE_INFINITE, 1, 1) == REDISMODULE_OK, - "Unable to initialize zset iterator"); - for (; !RedisModule_ZsetRangeEndReached(notification_key); - RedisModule_ZsetRangeNext(notification_key)) { - RedisModuleString *client_channel = - RedisModule_ZsetRangeCurrentElement(notification_key, NULL); + auto it = notification_map.find(notification_key); + if (it != notification_map.end()) { + for (const std::string &client_channel : it->second) { + // RedisModule_Call seems to be broken and cannot accept "bb", + // therefore we construct a temporary redis string here, which + // will be garbage collected by redis. + auto channel = + RedisModule_CreateString(ctx, client_channel.data(), client_channel.size()); RedisModuleCallReply *reply = RedisModule_Call( - ctx, "PUBLISH", "sb", client_channel, fbb.GetBufferPointer(), fbb.GetSize()); + ctx, "PUBLISH", "sb", channel, fbb.GetBufferPointer(), fbb.GetSize()); if (reply == NULL) { return RedisModule_ReplyWithError(ctx, "error during PUBLISH"); } @@ -532,12 +531,10 @@ int TableRequestNotifications_RedisCommand(RedisModuleCtx *ctx, RedisModuleStrin // Add this client to the set of clients that should be notified when there // are changes to the key. - RedisModuleKey *notification_key; - REPLY_AND_RETURN_IF_NOT_OK(OpenBroadcastKey(¬ification_key, ctx, pubsub_channel_str, - id, REDISMODULE_READ | REDISMODULE_WRITE)); - REPLY_AND_RETURN_IF_FALSE( - RedisModule_ZsetAdd(notification_key, 0.0, client_channel, NULL) == REDISMODULE_OK, - "ZsetAdd failed."); + std::string notification_key; + REPLY_AND_RETURN_IF_NOT_OK( + GetBroadcastKey(ctx, pubsub_channel_str, id, ¬ification_key)); + notification_map[notification_key].push_back(RedisString_ToString(client_channel)); // Lookup the current value at the key. RedisModuleKey *table_key; @@ -587,17 +584,16 @@ int TableCancelNotifications_RedisCommand(RedisModuleCtx *ctx, RedisModuleString // Remove this client from the set of clients that should be notified when // there are changes to the key. - RedisModuleKey *notification_key; - REPLY_AND_RETURN_IF_NOT_OK(OpenBroadcastKey(¬ification_key, ctx, pubsub_channel_str, - id, REDISMODULE_READ | REDISMODULE_WRITE)); - if (RedisModule_KeyType(notification_key) != REDISMODULE_KEYTYPE_EMPTY) { - REPLY_AND_RETURN_IF_FALSE( - RedisModule_ZsetRem(notification_key, client_channel, NULL) == REDISMODULE_OK, - "not opened for writing or wrong type."); - size_t size = RedisModule_ValueLength(notification_key); - if (size == 0) { - REPLY_AND_RETURN_IF_FALSE(RedisModule_DeleteKey(notification_key) == REDISMODULE_OK, - "Unable to delete zset key."); + std::string notification_key; + REPLY_AND_RETURN_IF_NOT_OK( + GetBroadcastKey(ctx, pubsub_channel_str, id, ¬ification_key)); + auto it = notification_map.find(notification_key); + if (it != notification_map.end()) { + it->second.erase(std::remove(it->second.begin(), it->second.end(), + RedisString_ToString(client_channel)), + it->second.end()); + if (it->second.size() == 0) { + notification_map.erase(it); } } @@ -669,6 +665,25 @@ int TableTestAndUpdate_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **arg return result; } +std::string DebugString() { + std::stringstream result; + result << "RedisModule:"; + result << "\n- NotificationMap.size = " << notification_map.size(); + result << std::endl; + return result.str(); +} + +int DebugString_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + REDISMODULE_NOT_USED(argv); + RedisModule_AutoMemory(ctx); + + if (argc != 1) { + return RedisModule_WrongArity(ctx); + } + std::string debug_string = DebugString(); + return RedisModule_ReplyWithStringBuffer(ctx, debug_string.data(), debug_string.size()); +} + extern "C" { /* This function must be present on each Redis module. It is used in order to @@ -714,6 +729,11 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) return REDISMODULE_ERR; } + if (RedisModule_CreateCommand(ctx, "ray.debug_string", DebugString_RedisCommand, + "readonly", 0, 0, 0) == REDISMODULE_ERR) { + return REDISMODULE_ERR; + } + #if RAY_USE_NEW_GCS // Chain-enabled commands that depend on ray-project/credis. if (RedisModule_CreateCommand(ctx, "ray.chain.table_add", ChainTableAdd_RedisCommand, diff --git a/src/ray/gcs/redis_module/redis_string.h b/src/ray/gcs/redis_module/redis_string.h index 3755b3f3a..aaddc2007 100644 --- a/src/ray/gcs/redis_module/redis_string.h +++ b/src/ray/gcs/redis_module/redis_string.h @@ -70,4 +70,10 @@ RedisModuleString *RedisString_Format(RedisModuleCtx *ctx, const char *fmt, ...) return result; } +std::string RedisString_ToString(RedisModuleString *string) { + size_t size; + const char *data = RedisModule_StringPtrLen(string, &size); + return std::string(data, size); +} + #endif // RAY_REDIS_STRING_H_ diff --git a/test/jenkins_tests/multi_node_tests/test_wait_hanging.py b/test/jenkins_tests/multi_node_tests/test_wait_hanging.py new file mode 100644 index 000000000..911c5979f --- /dev/null +++ b/test/jenkins_tests/multi_node_tests/test_wait_hanging.py @@ -0,0 +1,31 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import ray + + +@ray.remote +def f(): + return 0 + + +@ray.remote +def g(): + import time + start = time.time() + while time.time() < start + 1: + ray.get([f.remote() for _ in range(10)]) + + +# 10MB -> hangs after ~5 iterations +# 20MB -> hangs after ~20 iterations +# 50MB -> hangs after ~50 iterations +ray.init(redis_max_memory=1024 * 1024 * 50) + +i = 0 +for i in range(100): + i += 1 + a = g.remote() + [ok], _ = ray.wait([a]) + print("iter", i)