diff --git a/python/ray/experimental/state.py b/python/ray/experimental/state.py index e3fd4a4fa..108ee0a68 100644 --- a/python/ray/experimental/state.py +++ b/python/ray/experimental/state.py @@ -353,9 +353,9 @@ class GlobalState(object): "Deleted": bool(int(decode(client_info[b"deleted"]))), "DBClientID": binary_to_hex(client_info[b"ray_client_id"]) } - if b"aux_address" in client_info: + if b"manager_address" in client_info: client_info_parsed["AuxAddress"] = decode( - client_info[b"aux_address"]) + client_info[b"manager_address"]) if b"num_cpus" in client_info: client_info_parsed["NumCPUs"] = float( decode(client_info[b"num_cpus"])) diff --git a/python/ray/worker.py b/python/ray/worker.py index 04833d03d..436289701 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -1096,7 +1096,7 @@ def get_address_info_from_redis_helper(redis_address, node_ip_address): # Build the address information. object_store_addresses = [] for manager in plasma_managers: - address = manager[b"address"].decode("ascii") + address = manager[b"manager_address"].decode("ascii") port = services.get_port(address) object_store_addresses.append( services.ObjectStoreAddress( diff --git a/src/common/common.cc b/src/common/common.cc index 82e861444..3588ab7d3 100644 --- a/src/common/common.cc +++ b/src/common/common.cc @@ -44,6 +44,10 @@ bool DBClientID_equal(DBClientID first_id, DBClientID second_id) { return UNIQUE_ID_EQ(first_id, second_id); } +bool DBClientID_is_nil(DBClientID id) { + return IS_NIL_ID(id); +} + bool WorkerID_equal(WorkerID first_id, WorkerID second_id) { return UNIQUE_ID_EQ(first_id, second_id); } diff --git a/src/common/common.h b/src/common/common.h index 5f5668f90..1b7b71b50 100644 --- a/src/common/common.h +++ b/src/common/common.h @@ -218,6 +218,14 @@ typedef UniqueID DBClientID; */ bool DBClientID_equal(DBClientID first_id, DBClientID second_id); +/** + * Compare a db client ID to the nil ID. + * + * @param id The db client ID to compare to nil. + * @return True if the db client ID is equal to nil. + */ +bool DBClientID_is_nil(ObjectID id); + #define MAX(x, y) ((x) >= (y) ? (x) : (y)) #define MIN(x, y) ((x) <= (y) ? (x) : (y)) diff --git a/src/common/format/common.fbs b/src/common/format/common.fbs index 7d74c9ab7..58ed7c356 100644 --- a/src/common/format/common.fbs +++ b/src/common/format/common.fbs @@ -111,7 +111,7 @@ table SubscribeToDBClientTableReply { client_type: string; // If the client is a local scheduler, this is the address of the plasma // manager that the local scheduler is connected to. Otherwise, it is empty. - aux_address: string; + manager_address: string; // True if the message is about the addition of a client and false if it is // about the deletion of a client. is_insertion: bool; diff --git a/src/common/redis_module/ray_redis_module.cc b/src/common/redis_module/ray_redis_module.cc index 75c0eee2e..ab14fd06f 100644 --- a/src/common/redis_module/ray_redis_module.cc +++ b/src/common/redis_module/ray_redis_module.cc @@ -83,16 +83,16 @@ flatbuffers::Offset RedisStringToFlatbuf( * * TODO(swang): Use flatbuffers for the notification message. * The format for the published notification is: - * : - * If no auxiliary address is provided, aux_address will be set to ":". If + * : + * If no manager address is provided, manager_address will be set to ":". If * is_insertion is true, then the last field will be "1", else "0". * * @param ctx The Redis context. * @param ray_client_id The ID of the database client that was inserted or * deleted. 
* @param client_type The type of client that was inserted or deleted. - * @param aux_address An optional secondary address associated with the - * database client. + * @param manager_address An optional secondary address for the object manager + * associated with this database client. * @param is_insertion A boolean that's true if the update was an insertion and * false if deletion. * @return True if the publish was successful and false otherwise. @@ -100,7 +100,7 @@ flatbuffers::Offset RedisStringToFlatbuf( bool PublishDBClientNotification(RedisModuleCtx *ctx, RedisModuleString *ray_client_id, RedisModuleString *client_type, - RedisModuleString *aux_address, + RedisModuleString *manager_address, bool is_insertion) { /* Construct strings to publish on the db client channel. */ RedisModuleString *channel_name = @@ -108,16 +108,17 @@ bool PublishDBClientNotification(RedisModuleCtx *ctx, /* Construct the flatbuffers object to publish over the channel. */ flatbuffers::FlatBufferBuilder fbb; /* Use an empty aux address if one is not passed in. */ - flatbuffers::Offset aux_address_str; - if (aux_address != NULL) { - aux_address_str = RedisStringToFlatbuf(fbb, aux_address); + flatbuffers::Offset manager_address_str; + if (manager_address != NULL) { + manager_address_str = RedisStringToFlatbuf(fbb, manager_address); } else { - aux_address_str = fbb.CreateString("", strlen("")); + manager_address_str = fbb.CreateString("", strlen("")); } /* Create the flatbuffers message. */ auto message = CreateSubscribeToDBClientTableReply( fbb, RedisStringToFlatbuf(fbb, ray_client_id), - RedisStringToFlatbuf(fbb, client_type), aux_address_str, is_insertion); + RedisStringToFlatbuf(fbb, client_type), manager_address_str, + is_insertion); fbb.Finish(message); /* Create a Redis string to publish by serializing the flatbuffers object. */ RedisModuleString *client_info = RedisModule_CreateString( @@ -141,14 +142,10 @@ bool PublishDBClientNotification(RedisModuleCtx *ctx, * and these will be stored in a hashmap associated with this client. Several * fields are singled out for special treatment: * - * address: This is provided by plasma managers and it should be an address - * like "127.0.0.1:1234". It is returned by RAY.GET_CLIENT_ADDRESS so - * that other plasma managers know how to fetch objects. - * aux_address: This is provided by local schedulers and should be the - * address of the plasma manager that the local scheduler is connected - * to. This is published to the "db_clients" channel by the RAY.CONNECT - * command and is used by the global scheduler to determine which plasma - * managers and local schedulers are connected. + * manager_address: This is provided by local schedulers and plasma + * managers and should be the address of the plasma manager that the + * client is associated with. This is published to the "db_clients" + * channel by the RAY.CONNECT command. * * @param ray_client_id The db client ID of the client. * @param node_ip_address The IP address of the node the client is on. @@ -178,9 +175,9 @@ int Connect_RedisCommand(RedisModuleCtx *ctx, } /* This will be used to construct a publish message. 
*/ - RedisModuleString *aux_address = NULL; - RedisModuleString *aux_address_key = - RedisModule_CreateString(ctx, "aux_address", strlen("aux_address")); + RedisModuleString *manager_address = NULL; + RedisModuleString *manager_address_key = RedisModule_CreateString( + ctx, "manager_address", strlen("manager_address")); RedisModuleString *deleted = RedisModule_CreateString(ctx, "0", strlen("0")); RedisModule_HashSet(db_client_table_key, REDISMODULE_HASH_CFIELDS, @@ -193,16 +190,16 @@ int Connect_RedisCommand(RedisModuleCtx *ctx, RedisModuleString *value = argv[i + 1]; RedisModule_HashSet(db_client_table_key, REDISMODULE_HASH_NONE, key, value, NULL); - if (RedisModule_StringCompare(key, aux_address_key) == 0) { - aux_address = value; + if (RedisModule_StringCompare(key, manager_address_key) == 0) { + manager_address = value; } } /* Clean up. */ RedisModule_FreeString(ctx, deleted); - RedisModule_FreeString(ctx, aux_address_key); + RedisModule_FreeString(ctx, manager_address_key); RedisModule_CloseKey(db_client_table_key); - if (!PublishDBClientNotification(ctx, ray_client_id, client_type, aux_address, - true)) { + if (!PublishDBClientNotification(ctx, ray_client_id, client_type, + manager_address, true)) { return RedisModule_ReplyWithError(ctx, "PUBLISH unsuccessful"); } @@ -256,16 +253,16 @@ int Disconnect_RedisCommand(RedisModuleCtx *ctx, RedisModule_FreeString(ctx, deleted); RedisModuleString *client_type; - RedisModuleString *aux_address; + RedisModuleString *manager_address; RedisModule_HashGet(db_client_table_key, REDISMODULE_HASH_CFIELDS, - "client_type", &client_type, "aux_address", - &aux_address, NULL); + "client_type", &client_type, "manager_address", + &manager_address, NULL); /* Publish the deletion notification on the db client channel. */ published = PublishDBClientNotification(ctx, ray_client_id, client_type, - aux_address, false); - if (aux_address != NULL) { - RedisModule_FreeString(ctx, aux_address); + manager_address, false); + if (manager_address != NULL) { + RedisModule_FreeString(ctx, manager_address); } RedisModule_FreeString(ctx, client_type); } @@ -282,50 +279,6 @@ int Disconnect_RedisCommand(RedisModuleCtx *ctx, return REDISMODULE_OK; } -/** - * Get the address of a client from its db client ID. This is called from a - * client with the command: - * - * RAY.GET_CLIENT_ADDRESS - * - * @param ray_client_id The db client ID of the client. - * @return The address of the client if the operation was successful. - */ -int GetClientAddress_RedisCommand(RedisModuleCtx *ctx, - RedisModuleString **argv, - int argc) { - if (argc != 2) { - return RedisModule_WrongArity(ctx); - } - - RedisModuleString *ray_client_id = argv[1]; - /* Get the request client address from the db client table. */ - RedisModuleKey *db_client_table_key = - OpenPrefixedKey(ctx, DB_CLIENT_PREFIX, ray_client_id, REDISMODULE_READ); - if (db_client_table_key == NULL) { - /* There is no client with this ID. */ - RedisModule_CloseKey(db_client_table_key); - return RedisModule_ReplyWithError(ctx, "invalid client ID"); - } - RedisModuleString *address; - RedisModule_HashGet(db_client_table_key, REDISMODULE_HASH_CFIELDS, "address", - &address, NULL); - if (address == NULL) { - /* The key did not exist. This should not happen. */ - RedisModule_CloseKey(db_client_table_key); - return RedisModule_ReplyWithError( - ctx, "Client does not have an address field. This shouldn't happen."); - } - - RedisModule_ReplyWithString(ctx, address); - - /* Cleanup. 
*/ - RedisModule_CloseKey(db_client_table_key); - RedisModule_FreeString(ctx, address); - - return REDISMODULE_OK; -} - /** * Lookup an entry in the object table. * @@ -1054,7 +1007,7 @@ int TaskTableUpdate_RedisCommand(RedisModuleCtx *ctx, * This is called from a client with the command: * * RAY.TASK_TABLE_TEST_AND_UPDATE - * + * * * @param task_id A string that is the ID of the task. * @param test_state_bitmask A string that is the test bitmask for the @@ -1064,19 +1017,28 @@ int TaskTableUpdate_RedisCommand(RedisModuleCtx *ctx, * instance) to update the task entry with. * @param ray_client_id A string that is the ray client ID of the associated * local scheduler, if any, to update the task entry with. + * @param test_local_scheduler_id A string to test the local scheduler ID. If + * provided, and if the current local scheduler ID does not match it, + * then the update does not happen. * @return Returns the task entry as a TaskReply. The reply will reflect the * update, if it happened. */ int TaskTableTestAndUpdate_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { - if (argc != 5) { + if (argc < 5 || argc > 6) { return RedisModule_WrongArity(ctx); } + /* If a sixth argument was provided, then we should also test the current + * local scheduler ID. */ + bool test_local_scheduler = (argc == 6); - RedisModuleString *state = argv[3]; + RedisModuleString *task_id = argv[1]; + RedisModuleString *test_state = argv[2]; + RedisModuleString *update_state = argv[3]; + RedisModuleString *local_scheduler_id = argv[4]; - RedisModuleKey *key = OpenPrefixedKey(ctx, TASK_PREFIX, argv[1], + RedisModuleKey *key = OpenPrefixedKey(ctx, TASK_PREFIX, task_id, REDISMODULE_READ | REDISMODULE_WRITE); if (RedisModule_KeyType(key) == REDISMODULE_KEYTYPE_EMPTY) { RedisModule_CloseKey(key); @@ -1085,8 +1047,10 @@ int TaskTableTestAndUpdate_RedisCommand(RedisModuleCtx *ctx, /* If the key exists, look up the fields and return them in an array. */ RedisModuleString *current_state = NULL; + RedisModuleString *current_local_scheduler_id = NULL; RedisModule_HashGet(key, REDISMODULE_HASH_CFIELDS, "state", ¤t_state, - NULL); + "local_scheduler_id", ¤t_local_scheduler_id, NULL); + long long current_state_integer; if (RedisModule_StringToLongLong(current_state, ¤t_state_integer) != REDISMODULE_OK) { @@ -1098,25 +1062,42 @@ int TaskTableTestAndUpdate_RedisCommand(RedisModuleCtx *ctx, return RedisModule_ReplyWithError(ctx, "Found invalid scheduling state."); } long long test_state_bitmask; - int status = RedisModule_StringToLongLong(argv[2], &test_state_bitmask); + int status = RedisModule_StringToLongLong(test_state, &test_state_bitmask); if (status != REDISMODULE_OK) { RedisModule_CloseKey(key); return RedisModule_ReplyWithError( ctx, "Invalid test value for scheduling state"); } - bool updated = false; + bool update = false; if (current_state_integer & test_state_bitmask) { - /* The test passed, so perform the update. */ - RedisModule_HashSet(key, REDISMODULE_HASH_CFIELDS, "state", state, - "local_scheduler_id", argv[4], NULL); - updated = true; + if (test_local_scheduler) { + /* A test local scheduler ID was provided. Test whether it is equal to + * the current local scheduler ID before performing the update. */ + RedisModuleString *test_local_scheduler_id = argv[5]; + if (RedisModule_StringCompare(current_local_scheduler_id, + test_local_scheduler_id) == 0) { + /* If the current local scheduler ID does matches the test ID, then + * perform the update. 
*/ + update = true; + } + } else { + /* No test local scheduler ID was provided. Perform the update. */ + update = true; + } + } + + /* If the scheduling state and local scheduler ID tests passed, then perform + * the update. */ + if (update) { + RedisModule_HashSet(key, REDISMODULE_HASH_CFIELDS, "state", update_state, + "local_scheduler_id", local_scheduler_id, NULL); } /* Clean up. */ RedisModule_CloseKey(key); /* Construct a reply by getting the task from the task ID. */ - return ReplyWithTask(ctx, argv[1], updated); + return ReplyWithTask(ctx, task_id, update); } /** @@ -1168,12 +1149,6 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, return REDISMODULE_ERR; } - if (RedisModule_CreateCommand(ctx, "ray.get_client_address", - GetClientAddress_RedisCommand, "write", 0, 0, - 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - if (RedisModule_CreateCommand(ctx, "ray.object_table_lookup", ObjectTableLookup_RedisCommand, "readonly", 0, 0, 0) == REDISMODULE_ERR) { diff --git a/src/common/state/db_client_table.cc b/src/common/state/db_client_table.cc index 1e2b6439f..0a01c3b59 100644 --- a/src/common/state/db_client_table.cc +++ b/src/common/state/db_client_table.cc @@ -28,6 +28,51 @@ void db_client_table_subscribe( redis_db_client_table_subscribe, user_context); } +const std::vector db_client_table_get_ip_addresses( + DBHandle *db_handle, + const std::vector &manager_ids) { + /* We time this function because in the past this loop has taken multiple + * seconds under stressful situations on hundreds of machines causing the + * plasma manager to die (because it went too long without sending + * heartbeats). */ + int64_t start_time = current_time_ms(); + + /* Construct the manager vector from the flatbuffers object. */ + std::vector manager_vector; + + for (auto const &manager_id : manager_ids) { + DBClient client = redis_cache_get_db_client(db_handle, manager_id); + CHECK(!client.manager_address.empty()); + manager_vector.push_back(client.manager_address); + } + + int64_t end_time = current_time_ms(); + if (end_time - start_time > RayConfig::instance().max_time_for_loop()) { + LOG_WARN( + "calling redis_get_cached_db_client in a loop in with %zu manager IDs " + "took %" PRId64 " milliseconds.", + manager_ids.size(), end_time - start_time); + } + + return manager_vector; +} + +void db_client_table_update_cache_callback(DBClient *db_client, + void *user_context) { + DBHandle *db_handle = (DBHandle *) user_context; + redis_cache_set_db_client(db_handle, *db_client); +} + +void db_client_table_cache_init(DBHandle *db_handle) { + db_client_table_subscribe(db_handle, db_client_table_update_cache_callback, + db_handle, NULL, NULL, NULL); +} + +DBClient db_client_table_cache_get(DBHandle *db_handle, DBClientID client_id) { + CHECK(!DBClientID_is_nil(client_id)); + return redis_cache_get_db_client(db_handle, client_id); +} + void plasma_manager_send_heartbeat(DBHandle *db_handle) { RetryInfo heartbeat_retry; heartbeat_retry.num_retries = 0; diff --git a/src/common/state/db_client_table.h b/src/common/state/db_client_table.h index b82582d63..d140ba770 100644 --- a/src/common/state/db_client_table.h +++ b/src/common/state/db_client_table.h @@ -1,6 +1,8 @@ #ifndef DB_CLIENT_TABLE_H #define DB_CLIENT_TABLE_H +#include + #include "db.h" #include "table.h" @@ -34,13 +36,13 @@ typedef struct { /** The database client ID. */ DBClientID id; /** The database client type. */ - const char *client_type; - /** An optional auxiliary address for an associated database client on the - * same node. 
*/ - const char *aux_address; + std::string client_type; + /** An optional auxiliary address for the plasma manager associated with this + * database client. */ + std::string manager_address; /** Whether or not the database client exists. If this is false for an entry, * then it will never again be true. */ - bool is_insertion; + bool is_alive; } DBClient; /* Callback for subscribing to the db client table. */ @@ -76,6 +78,29 @@ typedef struct { void *subscribe_context; } DBClientTableSubscribeData; +const std::vector db_client_table_get_ip_addresses( + DBHandle *db, + const std::vector &manager_ids); + +/** + * Initialize the db client cache. The cache is updated with each notification + * from the db client table. + * + * @param db_handle Database handle. + * @return Void. + */ +void db_client_table_cache_init(DBHandle *db_handle); + +/** + * Get a db client from the cache. If the requested client is not there, + * request the latest entry from the db client table. + * + * @param db_handle Database handle. + * @param client_id The ID of the client to look up in the cache. + * @return The database client in the cache. + */ +DBClient db_client_table_cache_get(DBHandle *db_handle, DBClientID client_id); + /* * ==== Plasma manager heartbeats ==== */ diff --git a/src/common/state/object_table.h b/src/common/state/object_table.h index b8233620d..b77920d8b 100644 --- a/src/common/state/object_table.h +++ b/src/common/state/object_table.h @@ -18,14 +18,14 @@ typedef void (*object_table_lookup_done_callback)( ObjectID object_id, bool never_created, - const std::vector &manager_vector, + const std::vector &manager_ids, void *user_context); /* Callback called when object ObjectID is available. */ typedef void (*object_table_object_available_callback)( ObjectID object_id, int64_t data_size, - const std::vector &manager_vector, + const std::vector &manager_ids, void *user_context); /** diff --git a/src/common/state/redis.cc b/src/common/state/redis.cc index b89024578..e2e92fb26 100644 --- a/src/common/state/redis.cc +++ b/src/common/state/redis.cc @@ -323,12 +323,6 @@ void DBHandle_free(DBHandle *db) { redisAsyncFree(db->subscribe_contexts[i]); } - /* Clean up memory. */ - for (auto it = db->db_client_cache.begin(); it != db->db_client_cache.end(); - it = db->db_client_cache.erase(it)) { - free(it->second); - } - free(db->client_type); delete db; } @@ -596,6 +590,46 @@ void redis_result_table_lookup(TableCallbackData *callback_data) { } } +DBClient redis_db_client_table_get(DBHandle *db, + unsigned char *client_id, + size_t client_id_len) { + redisReply *reply = + (redisReply *) redisCommand(db->sync_context, "HGETALL %s%b", + DB_CLIENT_PREFIX, client_id, client_id_len); + CHECK(reply->type == REDIS_REPLY_ARRAY); + CHECK(reply->elements > 0); + DBClient db_client; + int num_fields = 0; + /* Parse the fields into a DBClient. 
*/ + for (size_t j = 0; j < reply->elements; j = j + 2) { + const char *key = reply->element[j]->str; + const char *value = reply->element[j + 1]->str; + if (strcmp(key, "ray_client_id") == 0) { + memcpy(db_client.id.id, value, sizeof(db_client.id)); + num_fields++; + } else if (strcmp(key, "client_type") == 0) { + db_client.client_type = std::string(value); + num_fields++; + } else if (strcmp(key, "manager_address") == 0) { + db_client.manager_address = std::string(value); + num_fields++; + } else if (strcmp(key, "deleted") == 0) { + bool is_deleted = atoi(value); + db_client.is_alive = !is_deleted; + num_fields++; + } + } + freeReplyObject(reply); + /* The client ID, type, and whether it is deleted are all + * mandatory fields. Auxiliary address is optional. */ + CHECK(num_fields >= 3); + return db_client; +} + +void redis_cache_set_db_client(DBHandle *db, DBClient client) { + db->db_client_cache[client.id] = client; +} + /** * Get an entry from the plasma manager table in redis. * @@ -603,56 +637,15 @@ void redis_result_table_lookup(TableCallbackData *callback_data) { * @param index The index of the plasma manager. * @return The IP address and port of the manager. */ -const std::string redis_get_cached_db_client(DBHandle *db, - DBClientID db_client_id) { +DBClient redis_cache_get_db_client(DBHandle *db, DBClientID db_client_id) { auto it = db->db_client_cache.find(db_client_id); - - char *manager; if (it == db->db_client_cache.end()) { - /* This is a very rare case. It should happen at most once per db client. */ - redisReply *reply = (redisReply *) redisCommand( - db->sync_context, "RAY.GET_CLIENT_ADDRESS %b", (char *) db_client_id.id, - sizeof(db_client_id.id)); - CHECKM(reply->type == REDIS_REPLY_STRING, "REDIS reply type=%d, str=%s", - reply->type, reply->str); - char *addr = strdup(reply->str); - freeReplyObject(reply); - db->db_client_cache[db_client_id] = addr; - manager = addr; - } else { - manager = it->second; + DBClient db_client = + redis_db_client_table_get(db, db_client_id.id, sizeof(db_client_id.id)); + db->db_client_cache[db_client_id] = db_client; + it = db->db_client_cache.find(db_client_id); } - std::string manager_address(manager); - return manager_address; -} - -const std::vector redis_get_cached_db_clients( - DBHandle *db, - const std::vector &manager_ids) { - /* We time this function because in the past this loop has taken multiple - * seconds under stressful situations on hundreds of machines causing the - * plasma manager to die (because it went too long without sending - * heartbeats). */ - int64_t start_time = current_time_ms(); - - /* Construct the manager vector from the flatbuffers object. */ - std::vector manager_vector; - - for (auto const &manager_id : manager_ids) { - const std::string manager_address = - redis_get_cached_db_client(db, manager_id); - manager_vector.push_back(manager_address); - } - - int64_t end_time = current_time_ms(); - if (end_time - start_time > RayConfig::instance().max_time_for_loop()) { - LOG_WARN( - "calling redis_get_cached_db_client in a loop in with %zu manager IDs " - "took %" PRId64 " milliseconds.", - manager_ids.size(), end_time - start_time); - } - - return manager_vector; + return it->second; } void redis_object_table_lookup_callback(redisAsyncContext *c, @@ -672,7 +665,7 @@ void redis_object_table_lookup_callback(redisAsyncContext *c, if (reply->type == REDIS_REPLY_NIL) { /* The object entry did not exist. 
*/ if (done_callback) { - done_callback(obj_id, true, std::vector(), + done_callback(obj_id, true, std::vector(), callback_data->user_context); } } else if (reply->type == REDIS_REPLY_ARRAY) { @@ -686,18 +679,15 @@ void redis_object_table_lookup_callback(redisAsyncContext *c, manager_ids.push_back(manager_id); } - const std::vector manager_vector = - redis_get_cached_db_clients(db, manager_ids); - if (done_callback) { - done_callback(obj_id, false, manager_vector, callback_data->user_context); + done_callback(obj_id, false, manager_ids, callback_data->user_context); } } else { LOG_FATAL("Unexpected reply type from object table lookup."); } /* Clean up timer and callback. */ - destroy_timer_callback(callback_data->db_handle->loop, callback_data); + destroy_timer_callback(db->loop, callback_data); } void object_table_redis_subscribe_to_notifications_callback( @@ -742,14 +732,11 @@ void object_table_redis_subscribe_to_notifications_callback( manager_ids.push_back(manager_id); } - const std::vector manager_vector = - redis_get_cached_db_clients(db, manager_ids); - /* Call the subscribe callback. */ ObjectTableSubscribeData *data = (ObjectTableSubscribeData *) callback_data->data; if (data->object_available_callback) { - data->object_available_callback(obj_id, data_size, manager_vector, + data->object_available_callback(obj_id, data_size, manager_ids, data->subscribe_context); } } else if (strcmp(message_type->str, "subscribe") == 0) { @@ -759,12 +746,12 @@ void object_table_redis_subscribe_to_notifications_callback( if (callback_data->done_callback != NULL) { object_table_lookup_done_callback done_callback = (object_table_lookup_done_callback) callback_data->done_callback; - done_callback(NIL_ID, false, std::vector(), + done_callback(NIL_ID, false, std::vector(), callback_data->user_context); } /* If the initial SUBSCRIBE was successful, clean up the timer, but don't * destroy the callback data. */ - remove_timer_callback(callback_data->db_handle->loop, callback_data); + remove_timer_callback(db->loop, callback_data); } else { LOG_FATAL( "Unexpected reply type from object table subscribe to notifications."); @@ -1047,13 +1034,28 @@ void redis_task_table_test_and_update(TableCallbackData *callback_data) { TaskTableTestAndUpdateData *update_data = (TaskTableTestAndUpdateData *) callback_data->data; - int status = redisAsyncCommand( - context, redis_task_table_test_and_update_callback, - (void *) callback_data->timer_id, - "RAY.TASK_TABLE_TEST_AND_UPDATE %b %d %d %b", task_id.id, - sizeof(task_id.id), update_data->test_state_bitmask, - update_data->update_state, update_data->local_scheduler_id.id, - sizeof(update_data->local_scheduler_id.id)); + int status; + /* If the test local scheduler ID is NIL, then ignore it. 
*/ + if (IS_NIL_ID(update_data->test_local_scheduler_id)) { + status = redisAsyncCommand( + context, redis_task_table_test_and_update_callback, + (void *) callback_data->timer_id, + "RAY.TASK_TABLE_TEST_AND_UPDATE %b %d %d %b", task_id.id, + sizeof(task_id.id), update_data->test_state_bitmask, + update_data->update_state, update_data->local_scheduler_id.id, + sizeof(update_data->local_scheduler_id.id)); + } else { + status = redisAsyncCommand( + context, redis_task_table_test_and_update_callback, + (void *) callback_data->timer_id, + "RAY.TASK_TABLE_TEST_AND_UPDATE %b %d %d %b %b", task_id.id, + sizeof(task_id.id), update_data->test_state_bitmask, + update_data->update_state, update_data->local_scheduler_id.id, + sizeof(update_data->local_scheduler_id.id), + update_data->test_local_scheduler_id.id, + sizeof(update_data->test_local_scheduler_id.id)); + } + if ((status == REDIS_ERR) || context->err) { LOG_REDIS_DEBUG(context, "error in redis_task_table_test_and_update"); } @@ -1191,37 +1193,13 @@ void redis_db_client_table_scan(DBHandle *db, /* Get all the database client information. */ CHECK(reply->type == REDIS_REPLY_ARRAY); for (size_t i = 0; i < reply->elements; ++i) { - redisReply *client_reply = (redisReply *) redisCommand( - db->sync_context, "HGETALL %b", reply->element[i]->str, - reply->element[i]->len); - CHECK(reply->type == REDIS_REPLY_ARRAY); - CHECK(reply->elements > 0); - DBClient db_client; - memset(&db_client, 0, sizeof(db_client)); - int num_fields = 0; - /* Parse the fields into a DBClient. */ - for (size_t j = 0; j < client_reply->elements; j = j + 2) { - const char *key = client_reply->element[j]->str; - const char *value = client_reply->element[j + 1]->str; - if (strcmp(key, "ray_client_id") == 0) { - memcpy(db_client.id.id, value, sizeof(db_client.id)); - num_fields++; - } else if (strcmp(key, "client_type") == 0) { - db_client.client_type = strdup(value); - num_fields++; - } else if (strcmp(key, "aux_address") == 0) { - db_client.aux_address = strdup(value); - num_fields++; - } else if (strcmp(key, "deleted") == 0) { - bool is_deleted = atoi(value); - db_client.is_insertion = !is_deleted; - num_fields++; - } - } - freeReplyObject(client_reply); - /* The client ID, type, and whether it is deleted are all mandatory fields. - * Auxiliary address is optional. */ - CHECK(num_fields >= 3); + /* Strip the database client table prefix. */ + unsigned char *key = (unsigned char *) reply->element[i]->str; + key += strlen(DB_CLIENT_PREFIX); + size_t key_len = reply->element[i]->len; + key_len -= strlen(DB_CLIENT_PREFIX); + /* Get the database client's information. */ + DBClient db_client = redis_db_client_table_get(db, key, key_len); db_clients.push_back(db_client); } freeReplyObject(reply); @@ -1261,12 +1239,6 @@ void redis_db_client_table_subscribe_callback(redisAsyncContext *c, (DBClientTableSubscribeData *) callback_data->data; for (auto db_client : db_clients) { data->subscribe_callback(&db_client, data->subscribe_context); - if (db_client.client_type != NULL) { - free((void *) db_client.client_type); - } - if (db_client.aux_address != NULL) { - free((void *) db_client.aux_address); - } } return; } @@ -1278,9 +1250,9 @@ void redis_db_client_table_subscribe_callback(redisAsyncContext *c, * only client type, then the update was a delete. 
*/ DBClient db_client; db_client.id = from_flatbuf(message->db_client_id()); - db_client.client_type = (char *) message->client_type()->data(); - db_client.aux_address = message->aux_address()->data(); - db_client.is_insertion = message->is_insertion(); + db_client.client_type = std::string(message->client_type()->data()); + db_client.manager_address = std::string(message->manager_address()->data()); + db_client.is_alive = message->is_insertion(); /* Call the subscription callback. */ DBClientTableSubscribeData *data = diff --git a/src/common/state/redis.h b/src/common/state/redis.h index f400a0b3e..f422c9baa 100644 --- a/src/common/state/redis.h +++ b/src/common/state/redis.h @@ -4,6 +4,7 @@ #include #include "db.h" +#include "db_client_table.h" #include "object_table.h" #include "task_table.h" @@ -45,7 +46,7 @@ struct DBHandle { int64_t db_index; /** Cache for the IP addresses of db clients. This is an unordered map mapping * client IDs to addresses. */ - std::unordered_map db_client_cache; + std::unordered_map db_client_cache; /** Redis context for synchronous connections. This should only be used very * rarely, it is not asynchronous. */ redisContext *sync_context; @@ -85,6 +86,10 @@ void get_redis_shards(redisContext *context, std::vector &db_shards_addresses, std::vector &db_shards_ports); +void redis_cache_set_db_client(DBHandle *db, DBClient client); + +DBClient redis_cache_get_db_client(DBHandle *db, DBClientID db_client_id); + void redis_object_table_get_entry(redisAsyncContext *c, void *r, void *privdata); diff --git a/src/common/state/task_table.cc b/src/common/state/task_table.cc index 6eb8bee27..2c0d3faea 100644 --- a/src/common/state/task_table.cc +++ b/src/common/state/task_table.cc @@ -36,6 +36,7 @@ void task_table_update(DBHandle *db_handle, void task_table_test_and_update( DBHandle *db_handle, TaskID task_id, + DBClientID test_local_scheduler_id, int test_state_bitmask, int update_state, RetryInfo *retry, @@ -43,6 +44,7 @@ void task_table_test_and_update( void *user_context) { TaskTableTestAndUpdateData *update_data = (TaskTableTestAndUpdateData *) malloc(sizeof(TaskTableTestAndUpdateData)); + update_data->test_local_scheduler_id = test_local_scheduler_id; update_data->test_state_bitmask = test_state_bitmask; update_data->update_state = update_state; /* Update the task entry's local scheduler with this client's ID. */ diff --git a/src/common/state/task_table.h b/src/common/state/task_table.h index dd5cf5d5d..ee98fe5ff 100644 --- a/src/common/state/task_table.h +++ b/src/common/state/task_table.h @@ -103,9 +103,13 @@ void task_table_update(DBHandle *db_handle, * * @param db_handle Database handle. * @param task_id The task ID of the task entry to update. + * @param test_local_scheduler_id The local scheduler ID to test the current + * local scheduler ID against. If not NIL_ID, and if the current local + * scheduler ID does not match it, then the update will not happen. * @param test_state_bitmask The bitmask to apply to the task entry's current * scheduling state. The update happens if and only if the current - * scheduling state AND-ed with the bitmask is greater than 0. + * scheduling state AND-ed with the bitmask is greater than 0 and the + * local scheduler ID test passes. * @param update_state The value to update the task entry's scheduling state * with, if the current state matches test_state_bitmask. * @param retry Information about retrying the request to the database. 
@@ -117,6 +121,7 @@ void task_table_update(DBHandle *db_handle, void task_table_test_and_update( DBHandle *db_handle, TaskID task_id, + DBClientID test_local_scheduler_id, int test_state_bitmask, int update_state, RetryInfo *retry, @@ -125,6 +130,9 @@ void task_table_test_and_update( /* Data that is needed to test and set the task's scheduling state. */ typedef struct { + /** The value to test the current local scheduler ID against. This field is + * ignored if equal to NIL_ID. */ + DBClientID test_local_scheduler_id; int test_state_bitmask; int update_state; DBClientID local_scheduler_id; diff --git a/src/common/test/db_tests.cc b/src/common/test/db_tests.cc index 782cae0a1..82bad9b99 100644 --- a/src/common/test/db_tests.cc +++ b/src/common/test/db_tests.cc @@ -7,7 +7,9 @@ #include "event_loop.h" #include "test_common.h" #include "example_task.h" +#include "net.h" #include "state/db.h" +#include "state/db_client_table.h" #include "state/object_table.h" #include "state/task_table.h" #include "state/redis.h" @@ -27,9 +29,9 @@ const char *manager_addr = "127.0.0.1"; int manager_port1 = 12345; int manager_port2 = 12346; char received_addr1[16] = {0}; -char received_port1[6] = {0}; +int received_port1; char received_addr2[16] = {0}; -char received_port2[6] = {0}; +int received_port2; typedef struct { int test_number; } user_context; @@ -39,17 +41,16 @@ const int TEST_NUMBER = 10; void lookup_done_callback(ObjectID object_id, bool never_created, - const std::vector &manager_vector, + const std::vector &manager_ids, void *user_context) { - CHECK(manager_vector.size() == 2); - if (sscanf(manager_vector.at(0).c_str(), "%15[0-9.]:%5[0-9]", received_addr1, - received_port1) != 2) { - CHECK(0); - } - if (sscanf(manager_vector.at(1).c_str(), "%15[0-9.]:%5[0-9]", received_addr2, - received_port2) != 2) { - CHECK(0); - } + DBHandle *db = (DBHandle *) user_context; + CHECK(manager_ids.size() == 2); + const std::vector managers = + db_client_table_get_ip_addresses(db, manager_ids); + CHECK(parse_ip_addr_port(managers.at(0).c_str(), received_addr1, + &received_port1) == 0); + CHECK(parse_ip_addr_port(managers.at(1).c_str(), received_addr2, + &received_port2) == 0); } /* Entry added to database successfully. */ @@ -69,11 +70,11 @@ int64_t timeout_handler(event_loop *loop, int64_t id, void *context) { TEST object_table_lookup_test(void) { event_loop *loop = event_loop_create(); /* This uses manager_port1. */ - const char *db_connect_args1[] = {"address", "127.0.0.1:12345"}; + const char *db_connect_args1[] = {"manager_address", "127.0.0.1:12345"}; DBHandle *db1 = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", manager_addr, 2, db_connect_args1); /* This uses manager_port2. 
*/ - const char *db_connect_args2[] = {"address", "127.0.0.1:12346"}; + const char *db_connect_args2[] = {"manager_address", "127.0.0.1:12346"}; DBHandle *db2 = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", manager_addr, 2, db_connect_args2); db_attach(db1, loop, false); @@ -91,15 +92,13 @@ TEST object_table_lookup_test(void) { event_loop_add_timer(loop, 200, (event_loop_timer_handler) timeout_handler, NULL); event_loop_run(loop); - object_table_lookup(db1, id, &retry, lookup_done_callback, NULL); + object_table_lookup(db1, id, &retry, lookup_done_callback, db1); event_loop_add_timer(loop, 200, (event_loop_timer_handler) timeout_handler, NULL); event_loop_run(loop); - int port1 = atoi(received_port1); - int port2 = atoi(received_port2); ASSERT_STR_EQ(&received_addr1[0], manager_addr); - ASSERT((port1 == manager_port1 && port2 == manager_port2) || - (port2 == manager_port1 && port1 == manager_port2)); + ASSERT((received_port1 == manager_port1 && received_port2 == manager_port2) || + (received_port2 == manager_port1 && received_port1 == manager_port2)); db_disconnect(db1); db_disconnect(db2); diff --git a/src/common/test/object_table_tests.cc b/src/common/test/object_table_tests.cc index 4bd2e0edb..8aeb0b843 100644 --- a/src/common/test/object_table_tests.cc +++ b/src/common/test/object_table_tests.cc @@ -4,6 +4,7 @@ #include "example_task.h" #include "test_common.h" #include "common.h" +#include "state/db_client_table.h" #include "state/object_table.h" #include "state/redis.h" @@ -146,7 +147,7 @@ int lookup_failed = 0; void lookup_done_callback(ObjectID object_id, bool never_created, - const std::vector &manager_vector, + const std::vector &manager_vector, void *context) { /* The done callback should not be called. */ CHECK(0); @@ -226,7 +227,7 @@ int subscribe_failed = 0; void subscribe_done_callback(ObjectID object_id, int64_t data_size, - const std::vector &manager_vector, + const std::vector &manager_vector, void *user_context) { /* The done callback should not be called. */ CHECK(0); @@ -308,11 +309,13 @@ int add_retry_succeeded = 0; void add_lookup_done_callback(ObjectID object_id, bool never_created, - const std::vector &manager_vector, + const std::vector &manager_ids, void *context) { - CHECK(context == (void *) lookup_retry_context); - CHECK(manager_vector.size() == 1); - CHECK(manager_vector.at(0) == "127.0.0.1:11235"); + DBHandle *db = (DBHandle *) context; + CHECK(manager_ids.size() == 1); + const std::vector managers = + db_client_table_get_ip_addresses(db, manager_ids); + CHECK(managers.at(0) == "127.0.0.1:11235"); lookup_retry_succeeded = 1; } @@ -325,14 +328,14 @@ void add_lookup_callback(ObjectID object_id, bool success, void *user_context) { .fail_callback = lookup_retry_fail_callback, }; object_table_lookup(db, NIL_ID, &retry, add_lookup_done_callback, - (void *) lookup_retry_context); + (void *) db); } TEST add_lookup_test(void) { g_loop = event_loop_create(); lookup_retry_succeeded = 0; /* Construct the arguments to db_connect. 
*/ - const char *db_connect_args[] = {"address", "127.0.0.1:11235"}; + const char *db_connect_args[] = {"manager_address", "127.0.0.1:11235"}; DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", "127.0.0.1", 2, db_connect_args); db_attach(db, g_loop, true); @@ -359,7 +362,7 @@ TEST add_lookup_test(void) { void add_remove_lookup_done_callback( ObjectID object_id, bool never_created, - const std::vector &manager_vector, + const std::vector &manager_vector, void *context) { CHECK(context == (void *) lookup_retry_context); CHECK(manager_vector.size() == 0); @@ -433,7 +436,7 @@ void lookup_late_fail_callback(UniqueID id, void lookup_late_done_callback(ObjectID object_id, bool never_created, - const std::vector &manager_vector, + const std::vector &manager_vector, void *context) { /* This function should never be called. */ CHECK(0); @@ -520,11 +523,10 @@ void subscribe_late_fail_callback(UniqueID id, subscribe_late_failed = 1; } -void subscribe_late_done_callback( - ObjectID object_id, - bool never_created, - const std::vector &manager_vector, - void *user_context) { +void subscribe_late_done_callback(ObjectID object_id, + bool never_created, + const std::vector &manager_vector, + void *user_context) { /* This function should never be called. */ CHECK(0); } @@ -574,7 +576,7 @@ void subscribe_success_fail_callback(UniqueID id, void subscribe_success_done_callback( ObjectID object_id, bool never_created, - const std::vector &manager_vector, + const std::vector &manager_vector, void *user_context) { RetryInfo retry = { .num_retries = 0, .timeout = 750, .fail_callback = NULL, @@ -587,7 +589,7 @@ void subscribe_success_done_callback( void subscribe_success_object_available_callback( ObjectID object_id, int64_t data_size, - const std::vector &manager_vector, + const std::vector &manager_vector, void *user_context) { CHECK(user_context == (void *) subscribe_success_context); CHECK(ObjectID_equal(object_id, subscribe_id)); @@ -599,7 +601,7 @@ TEST subscribe_success_test(void) { g_loop = event_loop_create(); /* Construct the arguments to db_connect. */ - const char *db_connect_args[] = {"address", "127.0.0.1:11236"}; + const char *db_connect_args[] = {"manager_address", "127.0.0.1:11236"}; DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", "127.0.0.1", 2, db_connect_args); db_attach(db, g_loop, false); @@ -645,7 +647,7 @@ int subscribe_object_present_succeeded = 0; void subscribe_object_present_object_available_callback( ObjectID object_id, int64_t data_size, - const std::vector &manager_vector, + const std::vector &manager_vector, void *user_context) { subscribe_object_present_context_t *ctx = (subscribe_object_present_context_t *) user_context; @@ -667,7 +669,7 @@ TEST subscribe_object_present_test(void) { g_loop = event_loop_create(); /* Construct the arguments to db_connect. */ - const char *db_connect_args[] = {"address", "127.0.0.1:11236"}; + const char *db_connect_args[] = {"manager_address", "127.0.0.1:11236"}; DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", "127.0.0.1", 2, db_connect_args); db_attach(db, g_loop, false); @@ -711,7 +713,7 @@ const char *subscribe_object_not_present_context = void subscribe_object_not_present_object_available_callback( ObjectID object_id, int64_t data_size, - const std::vector &manager_vector, + const std::vector &manager_vector, void *user_context) { /* This should not be called. 
*/ CHECK(0); @@ -760,7 +762,7 @@ int subscribe_object_available_later_succeeded = 0; void subscribe_object_available_later_object_available_callback( ObjectID object_id, int64_t data_size, - const std::vector &manager_vector, + const std::vector &manager_vector, void *user_context) { subscribe_object_present_context_t *myctx = (subscribe_object_present_context_t *) user_context; @@ -781,7 +783,7 @@ TEST subscribe_object_available_later_test(void) { g_loop = event_loop_create(); /* Construct the arguments to db_connect. */ - const char *db_connect_args[] = {"address", "127.0.0.1:11236"}; + const char *db_connect_args[] = {"manager_address", "127.0.0.1:11236"}; DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", "127.0.0.1", 2, db_connect_args); db_attach(db, g_loop, false); @@ -834,7 +836,7 @@ TEST subscribe_object_available_subscribe_all(void) { subscribe_object_available_later_context, data_size}; g_loop = event_loop_create(); /* Construct the arguments to db_connect. */ - const char *db_connect_args[] = {"address", "127.0.0.1:11236"}; + const char *db_connect_args[] = {"manager_address", "127.0.0.1:11236"}; DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", "127.0.0.1", 2, db_connect_args); db_attach(db, g_loop, false); diff --git a/src/global_scheduler/global_scheduler.cc b/src/global_scheduler/global_scheduler.cc index ce14c27f5..070aa8c0b 100644 --- a/src/global_scheduler/global_scheduler.cc +++ b/src/global_scheduler/global_scheduler.cc @@ -192,14 +192,16 @@ void process_task_waiting(Task *waiting_task, void *user_context) { void add_local_scheduler(GlobalSchedulerState *state, DBClientID db_client_id, - const char *aux_address) { + const char *manager_address) { /* Add plasma_manager ip:port -> local_scheduler_db_client_id association to * state. */ - state->plasma_local_scheduler_map[std::string(aux_address)] = db_client_id; + state->plasma_local_scheduler_map[std::string(manager_address)] = + db_client_id; /* Add local_scheduler_db_client_id -> plasma_manager ip:port association to * state. */ - state->local_scheduler_plasma_map[db_client_id] = std::string(aux_address); + state->local_scheduler_plasma_map[db_client_id] = + std::string(manager_address); /* Add new local scheduler to the state. */ LocalScheduler local_scheduler; @@ -231,10 +233,10 @@ remove_local_scheduler( /* Remove the local scheduler from the mappings. This code only makes sense if * there is a one-to-one mapping between local schedulers and plasma managers. */ - std::string aux_address = + std::string manager_address = state->local_scheduler_plasma_map[local_scheduler_id]; state->local_scheduler_plasma_map.erase(local_scheduler_id); - state->plasma_local_scheduler_map.erase(aux_address); + state->plasma_local_scheduler_map.erase(manager_address); handle_local_scheduler_removed(state, state->policy_state, local_scheduler_id); @@ -244,7 +246,7 @@ remove_local_scheduler( /** * Process a notification about a new DB client connecting to Redis. * - * @param aux_address An ip:port pair for the plasma manager associated with + * @param manager_address An ip:port pair for the plasma manager associated with * this db client. * @return Void. 
*/ @@ -254,17 +256,18 @@ void process_new_db_client(DBClient *db_client, void *user_context) { LOG_DEBUG("db client table callback for db client = %s", ObjectID_to_string(db_client->id, id_string, ID_STRING_SIZE)); ARROW_UNUSED(id_string); - if (strncmp(db_client->client_type, "local_scheduler", + if (strncmp(db_client->client_type.c_str(), "local_scheduler", strlen("local_scheduler")) == 0) { bool local_scheduler_present = (state->local_schedulers.find(db_client->id) != state->local_schedulers.end()); - if (db_client->is_insertion) { + if (db_client->is_alive) { /* This is a notification for an insert. We may receive duplicate * notifications since we read the entire table before processing * notifications. Filter out local schedulers that we already added. */ if (!local_scheduler_present) { - add_local_scheduler(state, db_client->id, db_client->aux_address); + add_local_scheduler(state, db_client->id, + db_client->manager_address.c_str()); } } else { if (local_scheduler_present) { @@ -281,24 +284,26 @@ void process_new_db_client(DBClient *db_client, void *user_context) { * @param object_id ID of the object that the notification is about. * @param data_size The object size. * @param manager_count The number of locations for this object. - * @param manager_vector The vector of Plasma Manager locations. + * @param manager_ids The vector of Plasma Manager client IDs. * @param user_context The user context. * @return Void. */ -void object_table_subscribe_callback( - ObjectID object_id, - int64_t data_size, - const std::vector &manager_vector, - void *user_context) { +void object_table_subscribe_callback(ObjectID object_id, + int64_t data_size, + const std::vector &manager_ids, + void *user_context) { /* Extract global scheduler state from the callback context. */ GlobalSchedulerState *state = (GlobalSchedulerState *) user_context; char id_string[ID_STRING_SIZE]; LOG_DEBUG("object table subscribe callback for OBJECT = %s", ObjectID_to_string(object_id, id_string, ID_STRING_SIZE)); ARROW_UNUSED(id_string); - LOG_DEBUG("\tManagers<%d>:", manager_vector.size()); - for (size_t i = 0; i < manager_vector.size(); i++) { - LOG_DEBUG("\t\t%s", manager_vector[i]); + + const std::vector managers = + db_client_table_get_ip_addresses(state->db, manager_ids); + LOG_DEBUG("\tManagers<%d>:", managers.size()); + for (size_t i = 0; i < managers.size(); i++) { + LOG_DEBUG("\t\t%s", managers[i]); } if (state->scheduler_object_info_table.find(object_id) == @@ -311,8 +316,8 @@ void object_table_subscribe_callback( LOG_DEBUG("New object added to object_info_table with id = %s", ObjectID_to_string(object_id, id_string, ID_STRING_SIZE)); LOG_DEBUG("\tmanager locations:"); - for (size_t i = 0; i < manager_vector.size(); i++) { - LOG_DEBUG("\t\t%s", manager_vector[i]); + for (size_t i = 0; i < managers.size(); i++) { + LOG_DEBUG("\t\t%s", managers[i]); } } @@ -321,8 +326,8 @@ void object_table_subscribe_callback( /* In all cases, replace the object location vector on each callback. 
*/ obj_info_entry.object_locations.clear(); - for (size_t i = 0; i < manager_vector.size(); i++) { - obj_info_entry.object_locations.push_back(std::string(manager_vector[i])); + for (size_t i = 0; i < managers.size(); i++) { + obj_info_entry.object_locations.push_back(managers[i]); } } diff --git a/src/local_scheduler/local_scheduler.cc b/src/local_scheduler/local_scheduler.cc index 1442bb059..a29c7a3a6 100644 --- a/src/local_scheduler/local_scheduler.cc +++ b/src/local_scheduler/local_scheduler.cc @@ -18,6 +18,7 @@ #include "net.h" #include "state/actor_notification_table.h" #include "state/db.h" +#include "state/db_client_table.h" #include "state/driver_table.h" #include "state/task_table.h" #include "state/object_table.h" @@ -149,6 +150,11 @@ void LocalSchedulerState_free(LocalSchedulerState *state) { * local scheduler at most once. If a SIGTERM is caught afterwards, there is * the possibility of orphan worker processes. */ signal(SIGTERM, SIG_DFL); + /* Send a null heartbeat that tells the global scheduler that we are dead to + * avoid waiting for the heartbeat timeout. */ + if (state->db != NULL) { + local_scheduler_table_disconnect(state->db); + } /* Kill any child processes that didn't register as a worker yet. */ for (auto const &worker_pid : state->child_pids) { @@ -176,9 +182,6 @@ void LocalSchedulerState_free(LocalSchedulerState *state) { * responsible for deleting our entry from the db_client table, so do not * delete it here. */ if (state->db != NULL) { - /* Send a null heartbeat that tells the global scheduler that we are dead - * to avoid waiting for the heartbeat timeout. */ - local_scheduler_table_disconnect(state->db); DBHandle_free(state->db); } @@ -357,7 +360,7 @@ LocalSchedulerState *LocalSchedulerState_init( db_connect_args[3] = utstring_body(num_cpus); db_connect_args[4] = "num_gpus"; db_connect_args[5] = utstring_body(num_gpus); - db_connect_args[6] = "aux_address"; + db_connect_args[6] = "manager_address"; db_connect_args[7] = plasma_manager_address; } else { num_args = 6; @@ -635,16 +638,31 @@ void process_plasma_notification(event_loop *loop, void reconstruct_task_update_callback(Task *task, void *user_context, bool updated) { + LocalSchedulerState *state = (LocalSchedulerState *) user_context; if (!updated) { - /* The test-and-set of the task's scheduling state failed, so the task was - * either not finished yet, or it was already being reconstructed. - * Suppress the reconstruction request. */ + /* The test-and-set failed. The task is either: (1) not finished yet, (2) + * lost, but not yet updated, or (3) already being reconstructed. */ + DBClientID current_local_scheduler_id = Task_local_scheduler(task); + if (!DBClientID_is_nil(current_local_scheduler_id)) { + DBClient current_local_scheduler = + db_client_table_cache_get(state->db, current_local_scheduler_id); + if (!current_local_scheduler.is_alive) { + /* (2) The current local scheduler for the task is dead. The task is + * lost, but the task table hasn't received the update yet. Retry the + * test-and-set. */ + task_table_test_and_update(state->db, Task_task_id(task), + current_local_scheduler_id, Task_state(task), + TASK_STATUS_RECONSTRUCTING, NULL, + reconstruct_task_update_callback, state); + } + } + /* The test-and-set failed, so it is not safe to resubmit the task for + * execution. Suppress the request. */ return; } /* Otherwise, the test-and-set succeeded, so resubmit the task for execution * to ensure that reconstruction will happen. 
*/ - LocalSchedulerState *state = (LocalSchedulerState *) user_context; TaskSpec *spec = Task_task_spec(task); if (ActorID_equal(TaskSpec_actor_id(spec), NIL_ACTOR_ID)) { handle_task_submitted(state, state->algorithm_state, Task_task_spec(task), @@ -667,20 +685,46 @@ void reconstruct_task_update_callback(Task *task, void reconstruct_put_task_update_callback(Task *task, void *user_context, bool updated) { - if (updated) { + LocalSchedulerState *state = (LocalSchedulerState *) user_context; + if (!updated) { + /* The test-and-set failed. The task is either: (1) not finished yet, (2) + * lost, but not yet updated, or (3) already being reconstructed. */ + DBClientID current_local_scheduler_id = Task_local_scheduler(task); + if (!DBClientID_is_nil(current_local_scheduler_id)) { + DBClient current_local_scheduler = + db_client_table_cache_get(state->db, current_local_scheduler_id); + if (!current_local_scheduler.is_alive) { + /* (2) The current local scheduler for the task is dead. The task is + * lost, but the task table hasn't received the update yet. Retry the + * test-and-set. */ + task_table_test_and_update(state->db, Task_task_id(task), + current_local_scheduler_id, Task_state(task), + TASK_STATUS_RECONSTRUCTING, NULL, + reconstruct_put_task_update_callback, state); + } else if (Task_state(task) == TASK_STATUS_RUNNING) { + /* (1) The task is still executing on a live node. The object created + * by `ray.put` was not able to be reconstructed, and the workload will + * likely hang. Push an error to the appropriate driver. */ + TaskSpec *spec = Task_task_spec(task); + FunctionID function = TaskSpec_function(spec); + push_error(state->db, TaskSpec_driver_id(spec), + PUT_RECONSTRUCTION_ERROR_INDEX, sizeof(function), + function.id); + } + } else { + /* (1) The task is still executing and it is the driver task. We cannot + * restart the driver task, so the workload will hang. Push an error to + * the appropriate driver. */ + TaskSpec *spec = Task_task_spec(task); + FunctionID function = TaskSpec_function(spec); + push_error(state->db, TaskSpec_driver_id(spec), + PUT_RECONSTRUCTION_ERROR_INDEX, sizeof(function), function.id); + } + } else { /* The update to TASK_STATUS_RECONSTRUCTING succeeded, so continue with * reconstruction as usual. */ reconstruct_task_update_callback(task, user_context, updated); - return; } - - /* An object created by `ray.put` was not able to be reconstructed, and the - * workload will likely hang. Push an error to the appropriate driver. */ - LocalSchedulerState *state = (LocalSchedulerState *) user_context; - TaskSpec *spec = Task_task_spec(task); - FunctionID function = TaskSpec_function(spec); - push_error(state->db, TaskSpec_driver_id(spec), - PUT_RECONSTRUCTION_ERROR_INDEX, sizeof(function), function.id); } void reconstruct_evicted_result_lookup_callback(ObjectID reconstruct_object_id, @@ -705,7 +749,7 @@ void reconstruct_evicted_result_lookup_callback(ObjectID reconstruct_object_id, /* If there are no other instances of the task running, it's safe for us to * claim responsibility for reconstruction. 
*/ task_table_test_and_update( - state->db, task_id, (TASK_STATUS_DONE | TASK_STATUS_LOST), + state->db, task_id, NIL_ID, (TASK_STATUS_DONE | TASK_STATUS_LOST), TASK_STATUS_RECONSTRUCTING, NULL, done_callback, state); } @@ -726,7 +770,7 @@ void reconstruct_failed_result_lookup_callback(ObjectID reconstruct_object_id, LocalSchedulerState *state = (LocalSchedulerState *) user_context; /* If the task failed to finish, it's safe for us to claim responsibility for * reconstruction. */ - task_table_test_and_update(state->db, task_id, TASK_STATUS_LOST, + task_table_test_and_update(state->db, task_id, NIL_ID, TASK_STATUS_LOST, TASK_STATUS_RECONSTRUCTING, NULL, reconstruct_task_update_callback, state); } @@ -734,9 +778,9 @@ void reconstruct_failed_result_lookup_callback(ObjectID reconstruct_object_id, void reconstruct_object_lookup_callback( ObjectID reconstruct_object_id, bool never_created, - const std::vector &manager_vector, + const std::vector &manager_ids, void *user_context) { - LOG_DEBUG("Manager count was %d", manager_count); + LOG_DEBUG("Manager count was %d", manager_ids.size()); /* Only continue reconstruction if we find that the object doesn't exist on * any nodes. NOTE: This codepath is not responsible for checking if the * object table entry is up-to-date. */ @@ -748,12 +792,24 @@ void reconstruct_object_lookup_callback( result_table_lookup(state->db, reconstruct_object_id, NULL, reconstruct_failed_result_lookup_callback, (void *) state); - } else if (manager_vector.size() == 0) { - /* If the object was created and later evicted, we reconstruct the object - * if and only if there are no other instances of the task running. */ - result_table_lookup(state->db, reconstruct_object_id, NULL, - reconstruct_evicted_result_lookup_callback, - (void *) state); + } else { + /* If the object has been created, filter out the dead plasma managers that + * have it. */ + size_t num_live_managers = 0; + for (auto manager_id : manager_ids) { + DBClient manager = db_client_table_cache_get(state->db, manager_id); + if (manager.is_alive) { + num_live_managers++; + } + } + /* If the object was created, but all plasma managers that had the object + * either evicted it or failed, we reconstruct the object if and only if + * there are no other instances of the task running. */ + if (num_live_managers == 0) { + result_table_lookup(state->db, reconstruct_object_id, NULL, + reconstruct_evicted_result_lookup_callback, + (void *) state); + } } } @@ -1292,6 +1348,10 @@ void start_server(const char *node_ip_address, RayConfig::instance().heartbeat_timeout_milliseconds(), heartbeat_handler, g_state); } + /* Listen for new and deleted db clients. */ + if (g_state->db != NULL) { + db_client_table_cache_init(g_state->db); + } /* Create a timer for fetching queued tasks' missing object dependencies. */ event_loop_add_timer( loop, RayConfig::instance().local_scheduler_fetch_timeout_milliseconds(), diff --git a/src/local_scheduler/test/local_scheduler_tests.cc b/src/local_scheduler/test/local_scheduler_tests.cc index 11f634014..3b17b8646 100644 --- a/src/local_scheduler/test/local_scheduler_tests.cc +++ b/src/local_scheduler/test/local_scheduler_tests.cc @@ -426,7 +426,7 @@ TEST object_reconstruction_suppression_test(void) { exit(0); } else { /* Connect a plasma manager client so we can call object_table_add. 
*/ - const char *db_connect_args[] = {"address", "127.0.0.1:12346"}; + const char *db_connect_args[] = {"manager_address", "127.0.0.1:12346"}; DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", "127.0.0.1", 2, db_connect_args); db_attach(db, local_scheduler->loop, false); diff --git a/src/plasma/plasma_manager.cc b/src/plasma/plasma_manager.cc index 46b216fe9..816063308 100644 --- a/src/plasma/plasma_manager.cc +++ b/src/plasma/plasma_manager.cc @@ -101,7 +101,7 @@ void process_status_request(ClientConnection *client_conn, ObjectID object_id); * @return Status of object_id as defined in plasma.h */ int request_status(ObjectID object_id, - const std::vector &manager_vector, + const std::vector &manager_vector, void *context); /** @@ -292,12 +292,6 @@ ClientConnection *ClientConnection_init(PlasmaManagerState *state, */ void ClientConnection_free(ClientConnection *client_conn); -void object_table_subscribe_callback(ObjectID object_id, - int64_t data_size, - int manager_count, - const char *manager_vector[], - void *context); - std::unordered_map, UniqueIDHasher> & object_wait_requests_from_type(PlasmaManagerState *manager_state, int type) { /* We use different types of hash tables for different requests. */ @@ -464,7 +458,7 @@ PlasmaManagerState *PlasmaManagerState_init(const char *store_socket_name, db_connect_args[1] = store_socket_name; db_connect_args[2] = "manager_socket_name"; db_connect_args[3] = manager_socket_name; - db_connect_args[4] = "address"; + db_connect_args[4] = "manager_address"; db_connect_args[5] = manager_address_str.c_str(); state->db = db_connect(std::string(redis_primary_addr), redis_primary_port, @@ -1003,35 +997,24 @@ void fatal_table_callback(ObjectID id, void *user_context, void *user_data) { CHECK(0); } -void object_present_callback(ObjectID object_id, - const std::vector &manager_vector, - void *context) { - PlasmaManagerState *manager_state = (PlasmaManagerState *) context; - /* This callback is called from object_table_subscribe, which guarantees that - * the manager vector contains at least one element. */ - CHECK(manager_vector.size() >= 1); - - /* Update the in-progress remote wait requests. */ - update_object_wait_requests(manager_state, object_id, - plasma::PLASMA_QUERY_ANYWHERE, - ObjectStatus_Remote); -} - /* This callback is used by both fetch and wait. Therefore, it may have to * handle outstanding fetch and wait requests. */ -void object_table_subscribe_callback( - ObjectID object_id, - int64_t data_size, - const std::vector &manager_vector, - void *context) { +void object_table_subscribe_callback(ObjectID object_id, + int64_t data_size, + const std::vector &manager_ids, + void *context) { PlasmaManagerState *manager_state = (PlasmaManagerState *) context; + const std::vector managers = + db_client_table_get_ip_addresses(manager_state->db, manager_ids); /* Run the callback for fetch requests if there is a fetch request. */ auto it = manager_state->fetch_requests.find(object_id); if (it != manager_state->fetch_requests.end()) { - request_transfer(object_id, manager_vector, context); + request_transfer(object_id, managers, context); } /* Run the callback for wait requests. 
*/ - object_present_callback(object_id, manager_vector, context); + update_object_wait_requests(manager_state, object_id, + plasma::PLASMA_QUERY_ANYWHERE, + ObjectStatus_Remote); } void process_fetch_requests(ClientConnection *client_conn, @@ -1170,7 +1153,7 @@ void process_wait_request(ClientConnection *client_conn, */ void request_status_done(ObjectID object_id, bool never_created, - const std::vector &manager_vector, + const std::vector &manager_vector, void *context) { ClientConnection *client_conn = (ClientConnection *) context; int status = request_status(object_id, manager_vector, context); @@ -1181,7 +1164,7 @@ void request_status_done(ObjectID object_id, } int request_status(ObjectID object_id, - const std::vector &manager_vector, + const std::vector &manager_vector, void *context) { ClientConnection *client_conn = (ClientConnection *) context;
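
A few self-contained sketches of the behavior this patch introduces follow; every type and helper name in them is a simplified stand-in, not the actual Ray definition.

The patch removes the synchronous `RAY.GET_CLIENT_ADDRESS` command and instead keeps a local cache of `DBClient` records in `DBHandle::db_client_cache`, keyed by `DBClientID`, filled on misses by `redis_cache_get_db_client`, and kept fresh by db client table notifications (`db_client_table_cache_init` registers `db_client_table_update_cache_callback`). A minimal sketch of that cache shape, where `ClientID`, `ClientIDHasher`, and `fetch_from_redis` stand in for the real `DBClientID`, `UniqueIDHasher`, and the synchronous `HGETALL` fallback:

```cpp
#include <array>
#include <cstdint>
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>

// Stand-in for DBClientID (the real one is a 20-byte UniqueID).
struct ClientID {
  std::array<uint8_t, 20> id{};
  bool operator==(const ClientID &other) const { return id == other.id; }
};

// Stand-in for UniqueIDHasher.
struct ClientIDHasher {
  size_t operator()(const ClientID &key) const {
    size_t h = 0;
    for (uint8_t b : key.id) { h = h * 31 + b; }
    return h;
  }
};

// Mirrors the post-patch DBClient fields (std::string members, is_alive flag).
struct DBClient {
  ClientID id;
  std::string client_type;
  std::string manager_address;
  bool is_alive;
};

class DBClientCache {
 public:
  // Called by the subscription callback on every db client notification.
  void set(const DBClient &client) { cache_[client.id] = client; }

  // On a miss, fall back to a (here simulated) synchronous table read.
  DBClient get(const ClientID &key,
               const std::function<DBClient(const ClientID &)> &fetch_from_redis) {
    auto it = cache_.find(key);
    if (it == cache_.end()) {
      it = cache_.emplace(key, fetch_from_redis(key)).first;
    }
    return it->second;
  }

 private:
  std::unordered_map<ClientID, DBClient, ClientIDHasher> cache_;
};

int main() {
  DBClientCache cache;
  ClientID id;
  id.id[0] = 1;
  DBClient c;
  c.id = id;
  c.client_type = "plasma_manager";
  c.manager_address = "127.0.0.1:23894";
  c.is_alive = true;
  cache.set(c);
  DBClient hit = cache.get(id, [](const ClientID &) { return DBClient(); });
  std::cout << hit.manager_address << std::endl;  // 127.0.0.1:23894
  return 0;
}
```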
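
On a cache miss, `redis_db_client_table_get` builds the `DBClient` by walking the flat field/value array that `HGETALL` returns; `ray_client_id`, `client_type`, and `deleted` are mandatory, `manager_address` is optional, and `is_alive` is the negation of the stored `deleted` flag. A simplified sketch of that parsing step, with the hiredis reply replaced by a plain vector of strings and the ID kept as a string:

```cpp
#include <cassert>
#include <iostream>
#include <string>
#include <vector>

struct DBClient {
  std::string id;               // simplified: string instead of a 20-byte ID
  std::string client_type;
  std::string manager_address;  // optional field
  bool is_alive = true;
};

// `fields` holds the HGETALL reply flattened as key, value, key, value, ...
DBClient parse_db_client(const std::vector<std::string> &fields) {
  DBClient client;
  int num_mandatory = 0;
  for (size_t j = 0; j + 1 < fields.size(); j += 2) {
    const std::string &key = fields[j];
    const std::string &value = fields[j + 1];
    if (key == "ray_client_id") {
      client.id = value;
      num_mandatory++;
    } else if (key == "client_type") {
      client.client_type = value;
      num_mandatory++;
    } else if (key == "manager_address") {
      client.manager_address = value;
    } else if (key == "deleted") {
      client.is_alive = (value == "0");
      num_mandatory++;
    }
  }
  // The client ID, type, and deleted flag must all be present.
  assert(num_mandatory == 3);
  return client;
}

int main() {
  DBClient c = parse_db_client({"ray_client_id", "abc123",
                                "client_type", "local_scheduler",
                                "manager_address", "127.0.0.1:23894",
                                "deleted", "0"});
  std::cout << c.client_type << " at " << c.manager_address
            << (c.is_alive ? " (alive)" : " (dead)") << std::endl;
  return 0;
}
```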
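
Object table callbacks now deliver manager `DBClientID`s rather than resolved "ip:port" strings, and callers that need addresses go through `db_client_table_get_ip_addresses`, which keeps the old timing guard: if resolving all IDs takes longer than `RayConfig::instance().max_time_for_loop()`, a warning is logged so a cache-miss storm does not silently starve plasma manager heartbeats. A rough standalone equivalent using `std::chrono`, where the `resolve` callback and the 1000 ms threshold are placeholders for the cache lookup and the config value:

```cpp
#include <chrono>
#include <functional>
#include <iostream>
#include <string>
#include <vector>

std::vector<std::string> get_ip_addresses(
    const std::vector<int> &manager_ids,
    const std::function<std::string(int)> &resolve,
    int64_t max_loop_ms = 1000) {
  using clock = std::chrono::steady_clock;
  auto start = clock::now();

  std::vector<std::string> addresses;
  addresses.reserve(manager_ids.size());
  for (int id : manager_ids) {
    addresses.push_back(resolve(id));  // cache hit, or sync Redis read on a miss
  }

  auto elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(
                        clock::now() - start).count();
  if (elapsed_ms > max_loop_ms) {
    std::cerr << "resolving " << manager_ids.size() << " manager IDs took "
              << elapsed_ms << " ms" << std::endl;
  }
  return addresses;
}

int main() {
  auto addrs = get_ip_addresses({1, 2}, [](int id) {
    return "127.0.0.1:" + std::to_string(23890 + id);
  });
  for (const auto &a : addrs) {
    std::cout << a << std::endl;
  }
  return 0;
}
```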
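
The extended `RAY.TASK_TABLE_TEST_AND_UPDATE` only performs the write when the current scheduling state matches the test bitmask and, if a sixth argument is supplied, the current `local_scheduler_id` equals the supplied test ID. The decision reduces to a small predicate, sketched here outside the Redis module; the optional test ID is modeled with a pointer (the real command models "not provided" as `argc == 5`), and the status values are illustrative:

```cpp
#include <cassert>
#include <string>

// Scheduling-state bits as used by the task table (values are illustrative).
enum TaskStatus {
  TASK_STATUS_RUNNING = 8,
  TASK_STATUS_DONE = 16,
  TASK_STATUS_LOST = 32,
  TASK_STATUS_RECONSTRUCTING = 64,
};

// Returns true if the task entry should be updated.
bool test_and_update_allowed(int current_state,
                             int test_state_bitmask,
                             const std::string &current_local_scheduler_id,
                             const std::string *test_local_scheduler_id) {
  if ((current_state & test_state_bitmask) == 0) {
    return false;  // Scheduling-state test failed.
  }
  if (test_local_scheduler_id != nullptr &&
      *test_local_scheduler_id != current_local_scheduler_id) {
    return false;  // A test ID was provided and it does not match.
  }
  return true;
}

int main() {
  std::string current = "scheduler-A";
  std::string other = "scheduler-B";
  // No test ID: only the state bitmask matters.
  assert(test_and_update_allowed(TASK_STATUS_DONE,
                                 TASK_STATUS_DONE | TASK_STATUS_LOST,
                                 current, nullptr));
  // A matching test ID passes; a stale ID blocks the update.
  assert(test_and_update_allowed(TASK_STATUS_LOST, TASK_STATUS_LOST,
                                 current, &current));
  assert(!test_and_update_allowed(TASK_STATUS_LOST, TASK_STATUS_LOST,
                                  current, &other));
  return 0;
}
```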
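
In `reconstruct_task_update_callback` (and the `ray.put` variant), a failed test-and-set is no longer always treated as "someone else is handling it": if the db client cache reports that the task's current local scheduler is dead, the caller retries the test-and-set against that scheduler's ID, so exactly one scheduler ends up claiming the lost task. A compressed sketch of that control flow, with the table call and the cache lookup replaced by injected callbacks and `TaskInfo` standing in for the real task structure:

```cpp
#include <functional>
#include <iostream>
#include <string>

struct TaskInfo {
  std::string task_id;
  std::string local_scheduler_id;  // empty string models a NIL ID
  int state;
};

// Invoked when a test-and-set on the task table comes back.
void on_task_update(
    const TaskInfo &task,
    bool updated,
    const std::function<bool(const std::string &)> &scheduler_is_alive,
    const std::function<void(const TaskInfo &)> &retry_test_and_set,
    const std::function<void(const TaskInfo &)> &resubmit_task) {
  if (updated) {
    // We own reconstruction now; resubmit the task for execution.
    resubmit_task(task);
    return;
  }
  // The test failed: the task is unfinished, already being reconstructed,
  // or assigned to a scheduler that died but has not been cleaned up yet.
  if (!task.local_scheduler_id.empty() &&
      !scheduler_is_alive(task.local_scheduler_id)) {
    // The assigned scheduler is dead; retry, testing against its ID so the
    // update only happens if no one else has claimed the task meanwhile.
    retry_test_and_set(task);
    return;
  }
  // Otherwise suppress the reconstruction request.
}

int main() {
  TaskInfo task{"task-1", "scheduler-A", 8};
  on_task_update(
      task, /*updated=*/false,
      [](const std::string &) { return false; },  // scheduler reported dead
      [](const TaskInfo &t) { std::cout << "retry " << t.task_id << "\n"; },
      [](const TaskInfo &t) { std::cout << "resubmit " << t.task_id << "\n"; });
  return 0;
}
```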
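
Similarly, `reconstruct_object_lookup_callback` now treats an object as evicted only if none of the plasma managers listed in the object table entry are alive according to the db client cache; previously an empty manager vector was the only trigger. The filter itself is a one-line count, sketched below with string manager IDs and a plain map standing in for the cache:

```cpp
#include <algorithm>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  // Manager ID -> is_alive, as the db client cache would report it.
  std::unordered_map<std::string, bool> cache = {
      {"manager-A", false}, {"manager-B", false}};

  // Managers that the object table still lists for the object.
  std::vector<std::string> manager_ids = {"manager-A", "manager-B"};

  auto num_live_managers = std::count_if(
      manager_ids.begin(), manager_ids.end(),
      [&](const std::string &id) { return cache.at(id); });

  if (num_live_managers == 0) {
    // Every copy is gone (evicted or on a failed node): check whether any
    // other instance of the task is running before reconstructing.
    std::cout << "no live copies; consider reconstruction" << std::endl;
  }
  return 0;
}
```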
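
Finally, the db tests drop their hand-rolled `sscanf` pattern and call `parse_ip_addr_port` from `net.h`, which is why `received_port1`/`received_port2` become plain `int`s. A plausible standalone equivalent of that helper, written only from how the tests use it (a zero return on success); the real signature in `net.h` may differ slightly:

```cpp
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>

// Parse "ip:port" into a caller-provided buffer and an int port.
// Returns 0 on success and -1 on failure, matching how the tests
// compare the result against 0.
int parse_ip_addr_port(const char *ip_addr_port, char *ip_addr, int *port) {
  char port_str[6];
  if (sscanf(ip_addr_port, "%15[0-9.]:%5[0-9]", ip_addr, port_str) != 2) {
    return -1;
  }
  *port = atoi(port_str);
  return 0;
}

int main() {
  char addr[16] = {0};
  int port = 0;
  assert(parse_ip_addr_port("127.0.0.1:12345", addr, &port) == 0);
  assert(strcmp(addr, "127.0.0.1") == 0 && port == 12345);
  std::cout << addr << ":" << port << std::endl;
  return 0;
}
```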