diff --git a/python/ray/common/redis_module/runtest.py b/python/ray/common/redis_module/runtest.py index be1d35ed7..c82838a1b 100644 --- a/python/ray/common/redis_module/runtest.py +++ b/python/ray/common/redis_module/runtest.py @@ -14,6 +14,7 @@ import ray.services # Import flatbuffer bindings. from ray.core.generated.SubscribeToNotificationsReply import SubscribeToNotificationsReply from ray.core.generated.TaskReply import TaskReply +from ray.core.generated.ResultTableReply import ResultTableReply OBJECT_INFO_PREFIX = "OI:" OBJECT_LOCATION_PREFIX = "OL:" @@ -197,6 +198,11 @@ class TestGlobalStateStore(unittest.TestCase): [b"manager_id1", b"manager_id2", b"manager_id3"]) def testResultTableAddAndLookup(self): + def check_result_table_entry(message, task_id, is_put): + result_table_reply = ResultTableReply.GetRootAsResultTableReply(message, 0) + self.assertEqual(result_table_reply.TaskId(), task_id) + self.assertEqual(result_table_reply.IsPut(), is_put) + # Try looking up something in the result table before anything is added. response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", "object_id1") self.assertIsNone(response) @@ -206,17 +212,17 @@ class TestGlobalStateStore(unittest.TestCase): self.assertIsNone(response) # Add the result to the result table. The lookup now returns the task ID. task_id = b"task_id1" - self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id1", task_id) + self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id1", task_id, 0) response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", "object_id1") - self.assertEqual(response, task_id) + check_result_table_entry(response, task_id, False) # Doing it again should still work. response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", "object_id1") - self.assertEqual(response, task_id) + check_result_table_entry(response, task_id, False) # Try another result table lookup. This should succeed. task_id = b"task_id2" - self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id2", task_id) + self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id2", task_id, 1) response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", "object_id2") - self.assertEqual(response, task_id) + check_result_table_entry(response, task_id, True) def testInvalidTaskTableAdd(self): # Check that Redis returns an error when RAY.TASK_TABLE_ADD is called with @@ -241,12 +247,13 @@ class TestGlobalStateStore(unittest.TestCase): TASK_STATUS_SCHEDULED = 2 TASK_STATUS_QUEUED = 4 - def check_task_reply(message, task_args): + def check_task_reply(message, task_args, updated=False): task_status, local_scheduler_id, task_spec = task_args task_reply_object = TaskReply.GetRootAsTaskReply(message, 0) self.assertEqual(task_reply_object.State(), task_status) self.assertEqual(task_reply_object.LocalSchedulerId(), local_scheduler_id) self.assertEqual(task_reply_object.TaskSpec(), task_spec) + self.assertEqual(task_reply_object.Updated(), updated) # Check that task table adds, updates, and lookups work correctly. task_args = [TASK_STATUS_WAITING, b"node_id", b"task_spec"] @@ -266,7 +273,7 @@ class TestGlobalStateStore(unittest.TestCase): response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE", "task_id", *task_args[:3]) - check_task_reply(response, task_args[1:]) + check_task_reply(response, task_args[1:], updated=True) # Check that the task entry is still the same. 
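The updated assertions above follow the new reply contract for RAY.TASK_TABLE_TEST_AND_UPDATE: instead of replying with nil when the test bitmask does not match, the command now always replies with the current task entry plus an `updated` flag. A rough sketch of that contract in plain Python, for readers skimming the diff (the `test_and_update` helper and the dict-based entry are illustrative only, not code from this patch):

def test_and_update(entry, test_state_bitmask, new_state, new_local_scheduler_id):
    # The update is applied only when the current scheduling state matches the
    # test bitmask; otherwise the entry is left untouched.
    updated = bool(entry["state"] & test_state_bitmask)
    if updated:
        entry["state"] = new_state
        entry["local_scheduler_id"] = new_local_scheduler_id
    # The reply always carries the (possibly updated) entry and the flag, so a
    # caller can tell a failed test apart from a missing task.
    return {"state": entry["state"],
            "local_scheduler_id": entry["local_scheduler_id"],
            "task_spec": entry["task_spec"],
            "updated": updated}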
get_response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id") check_task_reply(get_response, task_args[1:]) @@ -277,43 +284,46 @@ class TestGlobalStateStore(unittest.TestCase): response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE", "task_id", *task_args[:3]) - check_task_reply(response, task_args[1:]) + check_task_reply(response, task_args[1:], updated=True) # Check that the update happened. get_response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id") check_task_reply(get_response, task_args[1:]) # If the current value is no longer the same as the test value, the - # response is nil. - task_args[1] = TASK_STATUS_WAITING + # response is the same task as before the test-and-set. + new_task_args = task_args[:] + new_task_args[1] = TASK_STATUS_WAITING response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE", "task_id", - *task_args[:3]) - self.assertEqual(response, None) + *new_task_args[:3]) + check_task_reply(response, task_args[1:], updated=False) # Check that the update did not happen. get_response2 = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id") self.assertEqual(get_response2, get_response) - self.assertNotEqual(get_response2, task_args[1:]) # If the test value is a bitmask that matches the current value, the update # happens. + task_args = new_task_args task_args[0] = TASK_STATUS_SCHEDULED | TASK_STATUS_QUEUED response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE", "task_id", *task_args[:3]) - check_task_reply(response, task_args[1:]) + check_task_reply(response, task_args[1:], updated=True) # If the test value is a bitmask that does not match the current value, the - # update does not happen. - task_args[1] = TASK_STATUS_SCHEDULED + # update does not happen, and the response is the same task as before the + # test-and-set. + new_task_args = task_args[:] + new_task_args[0] = TASK_STATUS_SCHEDULED old_response = response response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE", "task_id", - *task_args[:3]) - self.assertEqual(response, None) + *new_task_args[:3]) + check_task_reply(response, task_args[1:], updated=False) # Check that the update did not happen. get_response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id") - self.assertEqual(get_response, old_response) - self.assertNotEqual(get_response, task_args[1:]) + self.assertNotEqual(get_response, old_response) + check_task_reply(get_response, task_args[1:]) def testTaskTableSubscribe(self): scheduling_state = 1 diff --git a/python/ray/worker.py b/python/ray/worker.py index 5700b23a3..5dda40441 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -43,7 +43,10 @@ DRIVER_ID_LENGTH = 20 ERROR_ID_LENGTH = 20 # This must match the definition of NIL_ACTOR_ID in task.h. -NIL_ACTOR_ID = 20 * b"\xff" +NIL_ID = 20 * b"\xff" +NIL_LOCAL_SCHEDULER_ID = NIL_ID +NIL_FUNCTION_ID = NIL_ID +NIL_ACTOR_ID = NIL_ID # When performing ray.get, wait 1 second before attemping to reconstruct and # fetch the object again. @@ -52,6 +55,10 @@ GET_TIMEOUT_MILLISECONDS = 1000 # This must be kept in sync with the `error_types` array in # common/state/error_table.h. OBJECT_HASH_MISMATCH_ERROR_TYPE = b"object_hash_mismatch" +PUT_RECONSTRUCTION_ERROR_TYPE = b"put_reconstruction" + +# This must be kept in sync with the `scheduling_state` enum in common/task.h. 
+TASK_STATUS_RUNNING = 8 def random_string(): return np.random.bytes(20) @@ -696,9 +703,14 @@ def error_info(worker=global_worker): error_contents = worker.redis_client.hgetall(error_key) # If the error is an object hash mismatch, look up the function name for # the nondeterministic task. - if error_contents[b"type"] == OBJECT_HASH_MISMATCH_ERROR_TYPE: + error_type = error_contents[b"type"] + if (error_type == OBJECT_HASH_MISMATCH_ERROR_TYPE or error_type == + PUT_RECONSTRUCTION_ERROR_TYPE): function_id = error_contents[b"data"] - function_name = worker.redis_client.hget("RemoteFunction:{}".format(function_id), "name") + if function_id == NIL_FUNCTION_ID: + function_name = b"Driver" + else: + function_name = worker.redis_client.hget("RemoteFunction:{}".format(function_id), "name") error_contents[b"data"] = function_name errors.append(error_contents) @@ -1238,6 +1250,7 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker, a redis_ip_address, redis_port = info["redis_address"].split(":") worker.redis_client = redis.StrictRedis(host=redis_ip_address, port=int(redis_port)) worker.lock = threading.Lock() + # Register the worker with Redis. if mode in [SCRIPT_MODE, SILENT_MODE]: # The concept of a driver is the same as the concept of a "job". Register @@ -1266,7 +1279,10 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker, a # Create an object store client. worker.plasma_client = ray.plasma.PlasmaClient(info["store_socket_name"], info["manager_socket_name"]) # Create the local scheduler client. - worker.local_scheduler_client = ray.local_scheduler.LocalSchedulerClient(info["local_scheduler_socket_name"], worker.actor_id, is_worker) + worker.local_scheduler_client = ray.local_scheduler.LocalSchedulerClient( + info["local_scheduler_socket_name"], + worker.actor_id, + is_worker) # If this is a driver, set the current task ID, the task driver ID, and set # the task index to 0. @@ -1292,12 +1308,39 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker, a # Set other fields needed for computing task IDs. worker.task_index = 0 worker.put_index = 0 + + # Create an entry for the driver task in the task table. This task is added + # immediately with status RUNNING. This allows us to push errors related to + # this driver task back to the driver. For example, if the driver creates + # an object that is later evicted, we should notify the user that we're + # unable to reconstruct the object, since we cannot rerun the driver. + driver_task = ray.local_scheduler.Task( + worker.task_driver_id, + ray.local_scheduler.ObjectID(NIL_FUNCTION_ID), + [], + 0, + worker.current_task_id, + worker.task_index, + ray.local_scheduler.ObjectID(NIL_ACTOR_ID), + worker.actor_counters[actor_id], + [0, 0]) + worker.redis_client.execute_command( + "RAY.TASK_TABLE_ADD", + driver_task.task_id().id(), + TASK_STATUS_RUNNING, + NIL_LOCAL_SCHEDULER_ID, + ray.local_scheduler.task_to_string(driver_task)) + # Set the driver's current task ID to the task ID assigned to the driver + # task. + worker.current_task_id = driver_task.task_id() + # If this is a worker, then start a thread to import exports from the driver. if mode == WORKER_MODE: t = threading.Thread(target=import_thread, args=(worker,)) # Making the thread a daemon causes it to exit when the main thread exits. t.daemon = True t.start() + # If this is a driver running in SCRIPT_MODE, start a thread to print error # messages asynchronously in the background. 
Ideally the scheduler would push # messages to the driver's worker service, but we ran into bugs when trying to @@ -1503,7 +1546,8 @@ def put(value, worker=global_worker): if worker.mode == PYTHON_MODE: # In PYTHON_MODE, ray.put is the identity operation return value - object_id = ray.local_scheduler.compute_put_id(worker.current_task_id, worker.put_index) + object_id = worker.local_scheduler_client.compute_put_id( + worker.current_task_id, worker.put_index) worker.put_object(object_id, value) worker.put_index += 1 return object_id diff --git a/src/common/format/common.fbs b/src/common/format/common.fbs index 5d5c508e0..cde6b1df5 100644 --- a/src/common/format/common.fbs +++ b/src/common/format/common.fbs @@ -90,6 +90,9 @@ table TaskReply { local_scheduler_id: string; // A string of bytes representing the task specification. task_spec: string; + // A boolean representing whether the update was successful. This field + // should only be used for test-and-set operations. + updated: bool; } root_type TaskReply; @@ -127,3 +130,12 @@ table LocalSchedulerInfoMessage { } root_type LocalSchedulerInfoMessage; + +table ResultTableReply { + // The task ID of the task that created the object. + task_id: string; + // Whether the task created the object through a ray.put. + is_put: bool; +} + +root_type ResultTableReply; diff --git a/src/common/io.cc b/src/common/io.cc index c4d2679ee..25944ee26 100644 --- a/src/common/io.cc +++ b/src/common/io.cc @@ -116,6 +116,10 @@ int connect_ipc_sock_retry(const char *socket_pathname, if (fd >= 0) { break; } + if (num_attempts == 0) { + LOG_ERROR("Connection to socket failed for pathname %s.", + socket_pathname); + } /* Sleep for timeout milliseconds. */ usleep(timeout * 1000); } @@ -147,7 +151,7 @@ int connect_ipc_sock(const char *socket_pathname) { if (connect(socket_fd, (struct sockaddr *) &socket_address, sizeof(socket_address)) != 0) { - LOG_ERROR("Connection to socket failed for pathname %s.", socket_pathname); + close(socket_fd); return -1; } @@ -173,6 +177,10 @@ int connect_inet_sock_retry(const char *ip_addr, if (fd >= 0) { break; } + if (num_attempts == 0) { + LOG_ERROR("Connection to socket failed for address %s:%d.", ip_addr, + port); + } /* Sleep for timeout milliseconds. 
*/ usleep(timeout * 1000); } @@ -203,7 +211,6 @@ int connect_inet_sock(const char *ip_addr, int port) { addr.sin_port = htons(port); if (connect(fd, (struct sockaddr *) &addr, sizeof(addr)) != 0) { - LOG_ERROR("Connection to socket failed for address %s:%d.", ip_addr, port); close(fd); return -1; } diff --git a/src/common/lib/python/common_extension.cc b/src/common/lib/python/common_extension.cc index a8c52ece1..384dab91e 100644 --- a/src/common/lib/python/common_extension.cc +++ b/src/common/lib/python/common_extension.cc @@ -543,14 +543,3 @@ PyObject *check_simple_value(PyObject *self, PyObject *args) { } Py_RETURN_FALSE; } - -PyObject *compute_put_id(PyObject *self, PyObject *args) { - int put_index; - TaskID task_id; - if (!PyArg_ParseTuple(args, "O&i", &PyObjectToUniqueID, &task_id, - &put_index)) { - return NULL; - } - ObjectID put_id = task_compute_put_id(task_id, put_index); - return PyObjectID_make(put_id); -} diff --git a/src/common/redis_module/ray_redis_module.cc b/src/common/redis_module/ray_redis_module.cc index c9fb2194f..b21479875 100644 --- a/src/common/redis_module/ray_redis_module.cc +++ b/src/common/redis_module/ray_redis_module.cc @@ -690,27 +690,40 @@ int ObjectInfoSubscribe_RedisCommand(RedisModuleCtx *ctx, * * This is called from a client with the command: * - * RAY.RESULT_TABLE_ADD <object id> <task id> + * RAY.RESULT_TABLE_ADD <object id> <task id> <is_put> * * @param object_id A string representing the object ID. * @param task_id A string representing the task ID of the task that produced * the object. + * @param is_put An integer that is 1 if the object was created through ray.put + * and 0 if created by return value. * @return OK if the operation was successful. */ int ResultTableAdd_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { - if (argc != 3) { + if (argc != 4) { return RedisModule_WrongArity(ctx); } /* Set the task ID under field "task" in the object info table. */ RedisModuleString *object_id = argv[1]; RedisModuleString *task_id = argv[2]; + RedisModuleString *is_put = argv[3]; + + /* Check to make sure the is_put field was a 0 or a 1. */ long long is_put_integer; if ((RedisModule_StringToLongLong(is_put, &is_put_integer) != REDISMODULE_OK) || (is_put_integer != 0 && is_put_integer != 1)) { return RedisModule_ReplyWithError( ctx, "The is_put field must be either a 0 or a 1."); } RedisModuleKey *key; key = OpenPrefixedKey(ctx, OBJECT_INFO_PREFIX, object_id, REDISMODULE_WRITE); - RedisModule_HashSet(key, REDISMODULE_HASH_CFIELDS, "task", task_id, NULL); + RedisModule_HashSet(key, REDISMODULE_HASH_CFIELDS, "task", task_id, "is_put", + is_put, NULL); /* Clean up. */ RedisModule_CloseKey(key); @@ -723,13 +736,19 @@ int ResultTableAdd_RedisCommand(RedisModuleCtx *ctx, * Reply with information about a task ID. This is used by * RAY.RESULT_TABLE_LOOKUP and RAY.TASK_TABLE_GET. * + * @param ctx The Redis context. * @param task_id The task ID of the task to reply about. + * @param updated A boolean representing whether the task was updated during + * this operation. This field is only used for + * RAY.TASK_TABLE_TEST_AND_UPDATE operations. * @return NIL if the task ID is not in the task table. An error if the task ID * is in the task table but the appropriate fields are not there, and * an array of the task scheduling state, the local scheduler ID, and * the task spec for the task otherwise.
*/ -int ReplyWithTask(RedisModuleCtx *ctx, RedisModuleString *task_id) { +int ReplyWithTask(RedisModuleCtx *ctx, + RedisModuleString *task_id, + bool updated) { RedisModuleKey *key = OpenPrefixedKey(ctx, TASK_PREFIX, task_id, REDISMODULE_READ); @@ -762,7 +781,7 @@ int ReplyWithTask(RedisModuleCtx *ctx, RedisModuleString *task_id) { auto message = CreateTaskReply(fbb, RedisStringToFlatbuf(fbb, task_id), state_integer, RedisStringToFlatbuf(fbb, local_scheduler_id), - RedisStringToFlatbuf(fbb, task_spec)); + RedisStringToFlatbuf(fbb, task_spec), updated); fbb.Finish(message); RedisModuleString *reply = RedisModule_CreateString( @@ -790,10 +809,8 @@ int ReplyWithTask(RedisModuleCtx *ctx, RedisModuleString *task_id) { * RAY.RESULT_TABLE_LOOKUP <object id> * * @param object_id A string representing the object ID. - * @return NIL if the object ID is not in the result table or if the - * corresponding task ID is not in the task table. Otherwise, this - * returns an array of the scheduling state, the local scheduler ID, and - * the task spec for the task corresponding to this object ID. + * @return NIL if the object ID is not in the result table. Otherwise, this + * returns a ResultTableReply flatbuffer. */ int ResultTableLookup_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, @@ -814,13 +831,36 @@ int ResultTableLookup_RedisCommand(RedisModuleCtx *ctx, } RedisModuleString *task_id; - RedisModule_HashGet(key, REDISMODULE_HASH_CFIELDS, "task", &task_id, NULL); + RedisModuleString *is_put; + RedisModule_HashGet(key, REDISMODULE_HASH_CFIELDS, "task", &task_id, "is_put", + &is_put, NULL); RedisModule_CloseKey(key); - if (task_id == NULL) { + if (task_id == NULL || is_put == NULL) { return RedisModule_ReplyWithNull(ctx); } - RedisModule_ReplyWithString(ctx, task_id); + /* Check to make sure the is_put field was a 0 or a 1. */ + long long is_put_integer; + if (RedisModule_StringToLongLong(is_put, &is_put_integer) != REDISMODULE_OK || + (is_put_integer != 0 && is_put_integer != 1)) { + RedisModule_FreeString(ctx, is_put); + RedisModule_FreeString(ctx, task_id); + return RedisModule_ReplyWithError( + ctx, "The is_put field must be either a 0 or a 1."); + } + + /* Make and return the flatbuffer reply. */ + flatbuffers::FlatBufferBuilder fbb; + auto message = CreateResultTableReply(fbb, RedisStringToFlatbuf(fbb, task_id), + bool(is_put_integer)); + fbb.Finish(message); + RedisModuleString *reply = RedisModule_CreateString( + ctx, (const char *) fbb.GetBufferPointer(), fbb.GetSize()); + RedisModule_ReplyWithString(ctx, reply); + + /* Clean up. */ + RedisModule_FreeString(ctx, reply); + RedisModule_FreeString(ctx, is_put); RedisModule_FreeString(ctx, task_id); return REDISMODULE_OK; @@ -971,12 +1011,8 @@ int TaskTableUpdate_RedisCommand(RedisModuleCtx *ctx, * instance) to update the task entry with. * @param ray_client_id A string that is the ray client ID of the associated * local scheduler, if any, to update the task entry with. - * @return If the current scheduling state does not match the test bitmask, - * returns nil. Else, returns the same as RAY.TASK_TABLE_GET: an array - * of strings representing the updated task fields in the following - * order: 1) (integer) scheduling state 2) (string) associated local - * scheduler ID, if any 3) (string) the task specification, which can be - * cast to a task_spec. + * @return Returns the task entry as a TaskReply. The reply will reflect the + * update, if it happened.
*/ int TaskTableTestAndUpdate_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, @@ -1015,20 +1051,19 @@ int TaskTableTestAndUpdate_RedisCommand(RedisModuleCtx *ctx, return RedisModule_ReplyWithError( ctx, "Invalid test value for scheduling state"); } - if ((current_state_integer & test_state_bitmask) == 0) { - /* The current value does not match the test bitmask, so do not perform the - * update. */ - RedisModule_CloseKey(key); - return RedisModule_ReplyWithNull(ctx); + + bool updated = false; + if (current_state_integer & test_state_bitmask) { + /* The test passed, so perform the update. */ + RedisModule_HashSet(key, REDISMODULE_HASH_CFIELDS, "state", state, + "local_scheduler_id", argv[4], NULL); + updated = true; } - /* The test passed, so perform the update. */ - RedisModule_HashSet(key, REDISMODULE_HASH_CFIELDS, "state", state, - "local_scheduler_id", argv[4], NULL); /* Clean up. */ RedisModule_CloseKey(key); /* Construct a reply by getting the task from the task ID. */ - return ReplyWithTask(ctx, argv[1]); + return ReplyWithTask(ctx, argv[1], updated); } /** @@ -1052,7 +1087,7 @@ int TaskTableGet_RedisCommand(RedisModuleCtx *ctx, } /* Construct a reply by getting the task from the task ID. */ - return ReplyWithTask(ctx, argv[1]); + return ReplyWithTask(ctx, argv[1], false); } extern "C" { diff --git a/src/common/state/error_table.h b/src/common/state/error_table.h index c064b9b60..9c68ad112 100644 --- a/src/common/state/error_table.h +++ b/src/common/state/error_table.h @@ -18,14 +18,19 @@ typedef enum { /** An object was added with a different hash from the existing * one. */ OBJECT_HASH_MISMATCH_ERROR_INDEX = 0, + /** An object that was created through a ray.put is lost. */ + PUT_RECONSTRUCTION_ERROR_INDEX, /** The total number of error types. */ MAX_ERROR_INDEX } error_index; /** Information about the error to be displayed to the user. */ -static const char *error_types[] = {"object_hash_mismatch"}; +static const char *error_types[] = {"object_hash_mismatch", + "put_reconstruction"}; static const char *error_messages[] = { - "A nondeterministic task was reexecuted."}; + "A nondeterministic task was reexecuted.", + "An object created by ray.put was evicted and could not be reconstructed. " + "The driver may need to be restarted."}; /** * Push an error to the given Python driver. 
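With `put_reconstruction` added to the `error_types` and `error_messages` arrays, these failures reach the driver through the regular error table, exactly like the existing hash-mismatch errors. A minimal driver-side polling sketch, assuming an initialized driver (the `wait_for_put_reconstruction_errors` helper is illustrative; it mirrors the `wait_for_errors` helper added to test/stress_tests.py further down):

import time
import ray

def wait_for_put_reconstruction_errors(timeout_seconds=100):
    # Poll the error table until at least one put_reconstruction error has
    # been pushed to this driver, or until the timeout expires.
    deadline = time.time() + timeout_seconds
    while time.time() < deadline:
        errors = [error for error in ray.error_info()
                  if error[b"type"] == b"put_reconstruction"]
        if errors:
            return errors
        time.sleep(1)
    return []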
diff --git a/src/common/state/object_table.cc b/src/common/state/object_table.cc index fb3432188..7992f969f 100644 --- a/src/common/state/object_table.cc +++ b/src/common/state/object_table.cc @@ -104,13 +104,16 @@ void object_info_subscribe(DBHandle *db_handle, void result_table_add(DBHandle *db_handle, ObjectID object_id, - TaskID task_id_arg, + TaskID task_id, + bool is_put, RetryInfo *retry, result_table_done_callback done_callback, void *user_context) { - TaskID *task_id_copy = (TaskID *) malloc(sizeof(TaskID)); - memcpy(task_id_copy, task_id_arg.id, sizeof(*task_id_copy)); - init_table_callback(db_handle, object_id, __func__, task_id_copy, retry, + ResultTableAddInfo *info = + (ResultTableAddInfo *) malloc(sizeof(ResultTableAddInfo)); + info->task_id = task_id; + info->is_put = is_put; + init_table_callback(db_handle, object_id, __func__, info, retry, (table_done_callback) done_callback, redis_result_table_add, user_context); } diff --git a/src/common/state/object_table.h b/src/common/state/object_table.h index 422632be7..377115fa7 100644 --- a/src/common/state/object_table.h +++ b/src/common/state/object_table.h @@ -224,6 +224,15 @@ typedef struct { typedef void (*result_table_done_callback)(ObjectID object_id, void *user_context); +/** Information about a result table entry to add. */ +typedef struct { + /** The task ID of the task that created the requested object. */ + TaskID task_id; + /** True if the object was created through a put, and false if created by + * return value. */ + bool is_put; +} ResultTableAddInfo; + /** * Add information about a new object to the object table. This * is immutable information like the ID of the task that @@ -232,6 +241,8 @@ typedef void (*result_table_done_callback)(ObjectID object_id, * @param db_handle Handle to object_table database. * @param object_id ID of the object to add. * @param task_id ID of the task that creates this object. + * @param is_put A boolean that is true if the object was created through a + * ray.put, and false if the object was created by return value. * @param retry Information about retrying the request to the database. * @param done_callback Function to be called when database returns result. * @param user_context Context passed by the caller. @@ -240,6 +251,7 @@ typedef void (*result_table_done_callback)(ObjectID object_id, void result_table_add(DBHandle *db_handle, ObjectID object_id, TaskID task_id, + bool is_put, RetryInfo *retry, result_table_done_callback done_callback, void *user_context); @@ -247,6 +259,7 @@ void result_table_add(DBHandle *db_handle, /** Callback called when the result table lookup completes. */ typedef void (*result_table_lookup_callback)(ObjectID object_id, TaskID task_id, + bool is_put, void *user_context); /** diff --git a/src/common/state/redis.cc b/src/common/state/redis.cc index 01a2e1264..15050dd79 100644 --- a/src/common/state/redis.cc +++ b/src/common/state/redis.cc @@ -342,12 +342,14 @@ void redis_result_table_add(TableCallbackData *callback_data) { CHECK(callback_data); DBHandle *db = callback_data->db_handle; ObjectID id = callback_data->id; - TaskID *result_task_id = (TaskID *) callback_data->data; + ResultTableAddInfo *info = (ResultTableAddInfo *) callback_data->data; + int is_put = info->is_put ? 1 : 0; + /* Add the result entry to the result table. 
*/ int status = redisAsyncCommand( db->context, redis_result_table_add_callback, - (void *) callback_data->timer_id, "RAY.RESULT_TABLE_ADD %b %b", id.id, - sizeof(id.id), result_task_id->id, sizeof(result_task_id->id)); + (void *) callback_data->timer_id, "RAY.RESULT_TABLE_ADD %b %b %d", id.id, + sizeof(id.id), info->task_id.id, sizeof(info->task_id.id), is_put); if ((status == REDIS_ERR) || db->context->err) { LOG_REDIS_DEBUG(db->context, "Error in result table add"); } @@ -386,16 +388,19 @@ void redis_result_table_lookup_callback(redisAsyncContext *c, reply->type); /* Parse the task from the reply. */ TaskID result_id = NIL_TASK_ID; + bool is_put = false; if (reply->type == REDIS_REPLY_STRING) { - CHECK(reply->len == sizeof(result_id)); - memcpy(&result_id, reply->str, reply->len); + auto message = flatbuffers::GetRoot<ResultTableReply>(reply->str); + result_id = from_flatbuf(message->task_id()); + is_put = message->is_put(); } /* Call the done callback if there is one. */ result_table_lookup_callback done_callback = (result_table_lookup_callback) callback_data->done_callback; if (done_callback != NULL) { - done_callback(callback_data->id, result_id, callback_data->user_context); + done_callback(callback_data->id, result_id, is_put, + callback_data->user_context); } /* Clean up timer and callback. */ destroy_timer_callback(db->loop, callback_data); @@ -761,11 +766,15 @@ void redis_task_table_test_and_update_callback(redisAsyncContext *c, redisReply *reply = (redisReply *) r; /* Parse the task from the reply. */ Task *task = parse_and_construct_task_from_redis_reply(reply); + /* Determine whether the update happened. */ + auto message = flatbuffers::GetRoot<TaskReply>(reply->str); + bool updated = message->updated(); + /* Call the done callback if there is one. */ - task_table_get_callback done_callback = - (task_table_get_callback) callback_data->done_callback; + task_table_test_and_update_callback done_callback = + (task_table_test_and_update_callback) callback_data->done_callback; if (done_callback != NULL) { - done_callback(task, callback_data->user_context); + done_callback(task, callback_data->user_context, updated); } /* Free the task if it is not NULL. */ if (task != NULL) { diff --git a/src/common/state/task_table.cc b/src/common/state/task_table.cc index 8fdb6d0a0..6eb8bee27 100644 --- a/src/common/state/task_table.cc +++ b/src/common/state/task_table.cc @@ -33,13 +33,14 @@ void task_table_update(DBHandle *db_handle, redis_task_table_update, user_context); } -void task_table_test_and_update(DBHandle *db_handle, - TaskID task_id, - int test_state_bitmask, - int update_state, - RetryInfo *retry, - task_table_get_callback done_callback, - void *user_context) { +void task_table_test_and_update( + DBHandle *db_handle, + TaskID task_id, + int test_state_bitmask, + int update_state, + RetryInfo *retry, + task_table_test_and_update_callback done_callback, + void *user_context) { TaskTableTestAndUpdateData *update_data = (TaskTableTestAndUpdateData *) malloc(sizeof(TaskTableTestAndUpdateData)); update_data->test_state_bitmask = test_state_bitmask; diff --git a/src/common/state/task_table.h b/src/common/state/task_table.h index f09d9994c..fee7b854c 100644 --- a/src/common/state/task_table.h +++ b/src/common/state/task_table.h @@ -28,6 +28,13 @@ typedef void (*task_table_done_callback)(TaskID task_id, void *user_context); * was not in the task table, then the task pointer will be NULL.
*/ typedef void (*task_table_get_callback)(Task *task, void *user_context); +/* Callback called when a task table test-and-update operation completes. If + * the task ID was not in the task table, then the task pointer will be NULL. + * If the update succeeded, the updated field will be set to true. */ +typedef void (*task_table_test_and_update_callback)(Task *task, + void *user_context, + bool updated); + /** * Get a task's entry from the task table. * @@ -107,13 +114,14 @@ void task_table_update(DBHandle *db_handle, * fail_callback. * @return Void. */ -void task_table_test_and_update(DBHandle *db_handle, - TaskID task_id, - int test_state_bitmask, - int update_state, - RetryInfo *retry, - task_table_get_callback done_callback, - void *user_context); +void task_table_test_and_update( + DBHandle *db_handle, + TaskID task_id, + int test_state_bitmask, + int update_state, + RetryInfo *retry, + task_table_test_and_update_callback done_callback, + void *user_context); /* Data that is needed to test and set the task's scheduling state. */ typedef struct { diff --git a/src/common/test/object_table_tests.cc b/src/common/test/object_table_tests.cc index 787a8ea23..6a5216829 100644 --- a/src/common/test/object_table_tests.cc +++ b/src/common/test/object_table_tests.cc @@ -34,6 +34,7 @@ void new_object_fail_callback(UniqueID id, void new_object_done_callback(ObjectID object_id, TaskID task_id, + bool is_put, void *user_context) { new_object_succeeded = 1; CHECK(ObjectID_equal(object_id, new_object_id)); @@ -60,7 +61,7 @@ void new_object_task_callback(TaskID task_id, void *user_context) { .fail_callback = new_object_fail_callback, }; DBHandle *db = (DBHandle *) user_context; - result_table_add(db, new_object_id, new_object_task_id, &retry, + result_table_add(db, new_object_id, new_object_task_id, false, &retry, new_object_lookup_callback, (void *) db); } @@ -95,6 +96,7 @@ TEST new_object_test(void) { void new_object_no_task_callback(ObjectID object_id, TaskID task_id, + bool is_put, void *user_context) { new_object_succeeded = 1; CHECK(IS_NIL_ID(task_id)); diff --git a/src/local_scheduler/format/local_scheduler.fbs b/src/local_scheduler/format/local_scheduler.fbs index 13907f4b7..4dadf696e 100644 --- a/src/local_scheduler/format/local_scheduler.fbs +++ b/src/local_scheduler/format/local_scheduler.fbs @@ -26,7 +26,9 @@ enum MessageType:int { // For a worker that was blocked on some object(s), tell the local scheduler // that the worker is now unblocked. This is sent from a worker to a local // scheduler. - NotifyUnblocked + NotifyUnblocked, + // Add a result table entry for an object put. + PutObject } table EventLogMessage { @@ -48,3 +50,10 @@ table ReconstructObject { // Object ID of the object that needs to be reconstructed. object_id: string; } + +table PutObject { + // Task ID of the task that performed the put. + task_id: string; + // Object ID of the object that is being put. + object_id: string; +} diff --git a/src/local_scheduler/local_scheduler.cc b/src/local_scheduler/local_scheduler.cc index 1f1011558..996c4801d 100644 --- a/src/local_scheduler/local_scheduler.cc +++ b/src/local_scheduler/local_scheduler.cc @@ -20,6 +20,7 @@ #include "state/db.h" #include "state/task_table.h" #include "state/object_table.h" +#include "state/error_table.h" #include "utarray.h" #include "uthash.h" @@ -124,9 +125,12 @@ void kill_worker(LocalSchedulerClient *worker, bool cleanup) { /* Clean up the task in progress. */ if (worker->task_in_progress) { - /* Return the resources that the worker was using. 
*/ - TaskSpec *spec = Task_task_spec(worker->task_in_progress); - update_dynamic_resources(state, spec, true); + if (!worker->is_blocked) { + /* Return the resources that the worker was using, if any. Blocked + * workers do not use resources. */ + TaskSpec *spec = Task_task_spec(worker->task_in_progress); + update_dynamic_resources(state, spec, true); + } /* Update the task table to reflect that the task failed to complete. */ if (state->db != NULL) { Task_set_state(worker->task_in_progress, TASK_STATUS_LOST); @@ -420,11 +424,16 @@ void update_dynamic_resources(LocalSchedulerState *state, * subtract the resource quantities from our accounting. */ resource *= -1; } + + bool oversubscribed = + (!return_resources && state->dynamic_resources[i] < 0); /* Add or subtract the task's resources from our count. */ state->dynamic_resources[i] += resource; - if (!return_resources && state->dynamic_resources[i] < 0) { - /* We are using more resources than we have been allocated. */ + if ((!return_resources && state->dynamic_resources[i] < 0) && + !oversubscribed) { + /* Log a warning if we are using more resources than we have been + * allocated, and we weren't already oversubscribed. */ LOG_WARN("local_scheduler dynamic resources dropped to %8.4f\t%8.4f\n", state->dynamic_resources[0], state->dynamic_resources[1]); } @@ -491,8 +500,12 @@ void process_plasma_notification(event_loop *loop, free(notification); } -void reconstruct_task_update_callback(Task *task, void *user_context) { - if (task == NULL) { +void reconstruct_task_update_callback(Task *task, + void *user_context, + bool updated) { + /* The task ID should be in the task table. */ + CHECK(task != NULL); + if (!updated) { /* The test-and-set of the task's scheduling state failed, so the task was * either not finished yet, or it was already being reconstructed. * Suppress the reconstruction request. */ @@ -517,27 +530,56 @@ void reconstruct_task_update_callback(Task *task, void *user_context) { } } +void reconstruct_put_task_update_callback(Task *task, + void *user_context, + bool updated) { + CHECK(task != NULL); + if (updated) { + /* The update to TASK_STATUS_RECONSTRUCTING succeeded, so continue with + * reconstruction as usual. */ + reconstruct_task_update_callback(task, user_context, updated); + return; + } + + /* An object created by `ray.put` was not able to be reconstructed, and the + * workload will likely hang. Push an error to the appropriate driver. */ + LocalSchedulerState *state = (LocalSchedulerState *) user_context; + TaskSpec *spec = Task_task_spec(task); + FunctionID function = TaskSpec_function(spec); + push_error(state->db, TaskSpec_driver_id(spec), + PUT_RECONSTRUCTION_ERROR_INDEX, sizeof(function), function.id); +} + void reconstruct_evicted_result_lookup_callback(ObjectID reconstruct_object_id, TaskID task_id, + bool is_put, void *user_context) { - /* TODO(swang): The following check will fail if an object was created by a - * put. */ CHECKM(!IS_NIL_ID(task_id), "No task information found for object during reconstruction"); LocalSchedulerState *state = (LocalSchedulerState *) user_context; + + task_table_test_and_update_callback done_callback; + if (is_put) { + /* If the evicted object was created through ray.put and the originating + * task + * is still executing, it's very likely that the workload will hang and the + * worker needs to be restarted. 
Else, the reconstruction behavior is the + * same as for other evicted objects. */ + done_callback = reconstruct_put_task_update_callback; + } else { + done_callback = reconstruct_task_update_callback; + } /* If there are no other instances of the task running, it's safe for us to * claim responsibility for reconstruction. */ - task_table_test_and_update(state->db, task_id, - (TASK_STATUS_DONE | TASK_STATUS_LOST), - TASK_STATUS_RECONSTRUCTING, NULL, - reconstruct_task_update_callback, state); + task_table_test_and_update( + state->db, task_id, (TASK_STATUS_DONE | TASK_STATUS_LOST), + TASK_STATUS_RECONSTRUCTING, NULL, done_callback, state); } void reconstruct_failed_result_lookup_callback(ObjectID reconstruct_object_id, TaskID task_id, + bool is_put, void *user_context) { - /* TODO(swang): The following check will fail if an object was created by a - * put. */ if (IS_NIL_ID(task_id)) { /* NOTE(swang): For some reason, the result table update sometimes happens * after this lookup returns, possibly due to concurrent clients. In most @@ -613,7 +655,8 @@ void process_message(event_loop *loop, TaskID task_id = TaskSpec_task_id(spec); for (int64_t i = 0; i < TaskSpec_num_returns(spec); ++i) { ObjectID return_id = TaskSpec_return(spec, i); - result_table_add(state->db, return_id, task_id, NULL, NULL, NULL); + result_table_add(state->db, return_id, task_id, false, NULL, NULL, + NULL); } } @@ -745,6 +788,12 @@ void process_message(event_loop *loop, } print_worker_info("Worker unblocked", state->algorithm_state); } break; + case MessageType_PutObject: { + auto message = + flatbuffers::GetRoot<PutObject>(utarray_front(state->input_buffer)); + result_table_add(state->db, from_flatbuf(message->object_id()), + from_flatbuf(message->task_id()), true, NULL, NULL, NULL); + } break; default: /* This code should be unreachable. */ CHECK(0); diff --git a/src/local_scheduler/local_scheduler_algorithm.cc b/src/local_scheduler/local_scheduler_algorithm.cc index c6e10007d..05e84034d 100644 --- a/src/local_scheduler/local_scheduler_algorithm.cc +++ b/src/local_scheduler/local_scheduler_algorithm.cc @@ -1080,14 +1080,16 @@ void handle_worker_blocked(LocalSchedulerState *state, DCHECK(*q != worker); } + /* Add the worker to the list of blocked workers. */ + worker->is_blocked = true; + utarray_push_back(algorithm_state->blocked_workers, &worker); /* Return the resources that the blocked worker was using. */ CHECK(worker->task_in_progress != NULL); TaskSpec *spec = Task_task_spec(worker->task_in_progress); update_dynamic_resources(state, spec, true); - /* Add the worker to the list of blocked workers. */ - worker->is_blocked = true; - utarray_push_back(algorithm_state->blocked_workers, &worker); + /* Try to dispatch tasks, since we may have freed up some resources.
*/ + dispatch_tasks(state, algorithm_state); return; } } diff --git a/src/local_scheduler/local_scheduler_client.cc b/src/local_scheduler/local_scheduler_client.cc index 323f17731..61d3e5a47 100644 --- a/src/local_scheduler/local_scheduler_client.cc +++ b/src/local_scheduler/local_scheduler_client.cc @@ -92,3 +92,15 @@ void local_scheduler_log_message(LocalSchedulerConnection *conn) { void local_scheduler_notify_unblocked(LocalSchedulerConnection *conn) { write_message(conn->conn, MessageType_NotifyUnblocked, 0, NULL); } + +void local_scheduler_put_object(LocalSchedulerConnection *conn, + TaskID task_id, + ObjectID object_id) { + flatbuffers::FlatBufferBuilder fbb; + auto message = CreatePutObject(fbb, to_flatbuf(fbb, task_id), + to_flatbuf(fbb, object_id)); + fbb.Finish(message); + + write_message(conn->conn, MessageType_PutObject, fbb.GetSize(), + fbb.GetBufferPointer()); +} diff --git a/src/local_scheduler/local_scheduler_client.h b/src/local_scheduler/local_scheduler_client.h index 014c1861e..b05341712 100644 --- a/src/local_scheduler/local_scheduler_client.h +++ b/src/local_scheduler/local_scheduler_client.h @@ -112,4 +112,16 @@ void local_scheduler_log_message(LocalSchedulerConnection *conn); */ void local_scheduler_notify_unblocked(LocalSchedulerConnection *conn); +/** + * Record the mapping from object ID to task ID for put events. + * + * @param conn The connection information. + * @param task_id The ID of the task that called put. + * @param object_id The ID of the object being stored. + * @return Void. + */ +void local_scheduler_put_object(LocalSchedulerConnection *conn, + TaskID task_id, + ObjectID object_id); + #endif diff --git a/src/local_scheduler/local_scheduler_extension.cc b/src/local_scheduler/local_scheduler_extension.cc index 943d82643..46d133571 100644 --- a/src/local_scheduler/local_scheduler_extension.cc +++ b/src/local_scheduler/local_scheduler_extension.cc @@ -93,6 +93,21 @@ static PyObject *PyLocalSchedulerClient_notify_unblocked(PyObject *self) { Py_RETURN_NONE; } +static PyObject *PyLocalSchedulerClient_compute_put_id(PyObject *self, + PyObject *args) { + int put_index; + TaskID task_id; + if (!PyArg_ParseTuple(args, "O&i", &PyObjectToUniqueID, &task_id, + &put_index)) { + return NULL; + } + ObjectID put_id = task_compute_put_id(task_id, put_index); + local_scheduler_put_object( + ((PyLocalSchedulerClient *) self)->local_scheduler_connection, task_id, + put_id); + return PyObjectID_make(put_id); +} + static PyMethodDef PyLocalSchedulerClient_methods[] = { {"submit", (PyCFunction) PyLocalSchedulerClient_submit, METH_VARARGS, "Submit a task to the local scheduler."}, @@ -105,6 +120,8 @@ static PyMethodDef PyLocalSchedulerClient_methods[] = { "Log an event to the event log through the local scheduler."}, {"notify_unblocked", (PyCFunction) PyLocalSchedulerClient_notify_unblocked, METH_NOARGS, "Notify the local scheduler that we are unblocked."}, + {"compute_put_id", (PyCFunction) PyLocalSchedulerClient_compute_put_id, + METH_VARARGS, "Return the object ID for a put call within a task."}, {NULL} /* Sentinel */ }; @@ -152,8 +169,6 @@ static PyTypeObject PyLocalSchedulerClientType = { static PyMethodDef local_scheduler_methods[] = { {"check_simple_value", check_simple_value, METH_VARARGS, "Should the object be passed by value?"}, - {"compute_put_id", compute_put_id, METH_VARARGS, - "Return the object ID for a put call within a task."}, {"task_from_string", PyTask_from_string, METH_VARARGS, "Creates a Python PyTask object from a string representation of " 
"TaskSpec."}, diff --git a/src/plasma/plasma_manager.cc b/src/plasma/plasma_manager.cc index dcece41db..12f306c2b 100644 --- a/src/plasma/plasma_manager.cc +++ b/src/plasma/plasma_manager.cc @@ -1316,6 +1316,7 @@ void log_object_hash_mismatch_error_task_callback(Task *task, void log_object_hash_mismatch_error_result_callback(ObjectID object_id, TaskID task_id, + bool is_put, void *user_context) { CHECK(!IS_NIL_ID(task_id)); PlasmaManagerState *state = (PlasmaManagerState *) user_context; diff --git a/test/stress_tests.py b/test/stress_tests.py index af5aaee45..04bbe5f40 100644 --- a/test/stress_tests.py +++ b/test/stress_tests.py @@ -163,8 +163,8 @@ class ReconstructionTests(unittest.TestCase): def tearDown(self): self.assertTrue(ray.services.all_processes_alive()) - # Make sure that all nodes in the cluster were used by checking where tasks - # were scheduled and/or submitted from. + # Determine the IDs of all local schedulers that had a task scheduled or + # submitted. r = redis.StrictRedis(port=self.redis_port) task_ids = r.keys("TT:*") task_ids = [task_id[3:] for task_id in task_ids] @@ -174,7 +174,13 @@ class ReconstructionTests(unittest.TestCase): task_reply_object = TaskReply.GetRootAsTaskReply(message, 0) local_scheduler_ids.append(task_reply_object.LocalSchedulerId()) - self.assertEqual(len(set(local_scheduler_ids)), self.num_local_schedulers) + # Make sure that all nodes in the cluster were used by checking that the + # set of local scheduler IDs that had a task scheduled or submitted is + # equal to the total number of local schedulers started. We add one to the + # total number of local schedulers to account for NIL_LOCAL_SCHEDULER_ID. + # This is the local scheduler ID associated with the driver task, since it + # is not scheduled by a particular local scheduler. + self.assertEqual(len(set(local_scheduler_ids)), self.num_local_schedulers + 1) # Clean up the Ray cluster. ray.worker.cleanup() @@ -208,6 +214,12 @@ class ReconstructionTests(unittest.TestCase): for i in range(num_objects): value = ray.get(args[i]) self.assertEqual(value[0], i) + # Get values sequentially, in chunks. + num_chunks = 4 * self.num_local_schedulers + chunk = num_objects // num_chunks + for i in range(num_chunks): + values = ray.get(args[i * chunk : (i + 1) * chunk]) + del values def testRecursive(self): # Define the size of one task's return argument so that the combined sum of @@ -252,6 +264,12 @@ class ReconstructionTests(unittest.TestCase): i = np.random.randint(num_objects) value = ray.get(args[i]) self.assertEqual(value[0], i) + # Get values sequentially, in chunks. + num_chunks = 4 * self.num_local_schedulers + chunk = num_objects // num_chunks + for i in range(num_chunks): + values = ray.get(args[i * chunk : (i + 1) * chunk]) + del values def testMultipleRecursive(self): # Define the size of one task's return argument so that the combined sum of @@ -302,6 +320,21 @@ class ReconstructionTests(unittest.TestCase): value = ray.get(args[i]) self.assertEqual(value[0], i) + def wait_for_errors(self, error_check): + # Wait for errors from all the nondeterministic tasks. + errors = [] + time_left = 100 + while time_left > 0: + errors = ray.error_info() + if error_check(errors): + break + time_left -= 1 + time.sleep(1) + + # Make sure that enough errors came through. 
+ self.assertTrue(error_check(errors)) + return errors + def testNondeterministicTask(self): # Define the size of one task's return argument so that the combined sum of # all objects' sizes is at least twice the plasma stores' combined allotted @@ -345,22 +378,147 @@ class ReconstructionTests(unittest.TestCase): value = ray.get(args[i]) self.assertEqual(value[0], i) - # Wait for errors from all the nondeterministic tasks. - time_left = 100 - while time_left > 0: - errors = ray.error_info() - if len(errors) >= num_objects / 2: - break - time_left -= 0.1 - time.sleep(0.1) - - # Make sure that enough errors came through. - self.assertTrue(len(errors) >= num_objects / 2) + def error_check(errors): + if self.num_local_schedulers == 1: + # In a single-node setting, each object is evicted and reconstructed + # exactly once, so exactly half the objects will produce an error + # during reconstruction. + min_errors = num_objects // 2 + else: + # In a multinode setting, each object is evicted zero or one times, so + # some of the nondeterministic tasks may not be reexecuted. + min_errors = 1 + return len(errors) >= min_errors + errors = self.wait_for_errors(error_check) # Make sure all the errors have the correct type. self.assertTrue(all(error[b"type"] == b"object_hash_mismatch" for error in errors)) # Make sure all the errors have the correct function name. self.assertTrue(all(error[b"data"] == b"__main__.foo" for error in errors)) + def testPutErrors(self): + # Define the size of one task's return argument so that the combined sum of + # all objects' sizes is at least twice the plasma stores' combined allotted + # memory. + num_objects = 1000 + size = self.plasma_store_memory * 2 // (num_objects * 8) + + # Define a task with a single dependency, a numpy array, that returns + # another array. + @ray.remote + def single_dependency(i, arg): + arg = np.copy(arg) + arg[0] = i + return arg + + # Define a root task that calls `ray.put` to put an argument in the object + # store. + @ray.remote + def put_arg_task(size): + # Launch num_objects instances of the remote task, each dependent on the + # one before it. The first instance of the task takes a numpy array as an + # argument, which is put into the object store. + args = [] + arg = single_dependency.remote(0, np.zeros(size)) + for i in range(num_objects): + arg = single_dependency.remote(i, arg) + args.append(arg) + + # Get each value to force each task to finish. After some number of gets, + # old values should be evicted. + for i in range(num_objects): + value = ray.get(args[i]) + self.assertEqual(value[0], i) + # Get each value again to force reconstruction. Currently, since we're + # not able to reconstruct `ray.put` objects that were evicted and whose + # originating tasks are still running, this for-loop should hang on its + # first iteration and push an error to the driver. + for i in range(num_objects): + value = ray.get(args[i]) + self.assertEqual(value[0], i) + + # Define a root task that calls `ray.put` directly. + @ray.remote + def put_task(size): + # Launch num_objects instances of the remote task, each dependent on the + # one before it. The first instance of the task takes an object ID + # returned by ray.put. + args = [] + arg = ray.put(np.zeros(size)) + for i in range(num_objects): + arg = single_dependency.remote(i, arg) + args.append(arg) + + # Get each value to force each task to finish. After some number of gets, + # old values should be evicted. 
+ for i in range(num_objects): + value = ray.get(args[i]) + self.assertEqual(value[0], i) + # Get each value again to force reconstruction. Currently, since we're + # not able to reconstruct `ray.put` objects that were evicted and whose + # originating tasks are still running, this for-loop should hang on its + # first iteration and push an error to the driver. + for i in range(num_objects): + value = ray.get(args[i]) + self.assertEqual(value[0], i) + + put_arg_task.remote(size) + def error_check(errors): + return len(errors) > 1 + errors = self.wait_for_errors(error_check) + # Make sure all the errors have the correct type. + self.assertTrue(all(error[b"type"] == b"put_reconstruction" for error in errors)) + self.assertTrue(all(error[b"data"] == b"__main__.put_arg_task" for error in errors)) + + put_task.remote(size) + def error_check(errors): + return any(error[b"data"] == b"__main__.put_task" for error in errors) + errors = self.wait_for_errors(error_check) + # Make sure all the errors have the correct type. + self.assertTrue(all(error[b"type"] == b"put_reconstruction" for error in errors)) + self.assertTrue(any(error[b"data"] == b"__main__.put_task" for error in errors)) + + def testDriverPutErrors(self): + # Define the size of one task's return argument so that the combined sum of + # all objects' sizes is at least twice the plasma stores' combined allotted + # memory. + num_objects = 1000 + size = self.plasma_store_memory * 2 // (num_objects * 8) + + # Define a task with a single dependency, a numpy array, that returns + # another array. + @ray.remote + def single_dependency(i, arg): + arg = np.copy(arg) + arg[0] = i + return arg + + # Launch num_objects instances of the remote task, each dependent on the + # one before it. The first instance of the task takes a numpy array as an + # argument, which is put into the object store. + args = [] + arg = single_dependency.remote(0, np.zeros(size)) + for i in range(num_objects): + arg = single_dependency.remote(i, arg) + args.append(arg) + # Get each value to force each task to finish. After some number of gets, + # old values should be evicted. + for i in range(num_objects): + value = ray.get(args[i]) + self.assertEqual(value[0], i) + + # Get each value starting from the beginning to force reconstruction. + # Currently, since we're not able to reconstruct `ray.put` objects that + # were evicted and whose originating tasks are still running, this + # for-loop should hang on its first iteration and push an error to the + # driver. + ray.worker.global_worker.local_scheduler_client.reconstruct_object(args[0].id()) + def error_check(errors): + return len(errors) > 1 + errors = self.wait_for_errors(error_check) + self.assertTrue(all(error[b"type"] == b"put_reconstruction" for error in errors)) + self.assertTrue(all(error[b"data"] == b"Driver" for error in errors)) + + class ReconstructionTestsMultinode(ReconstructionTests): # Run the same tests as the single-node suite, but with 4 local schedulers,