mirror of
https://github.com/vale981/ray
synced 2025-04-23 06:25:52 -04:00
Fix local scheduler crash when driver creates actor and exits. (#1474)
* Make check failures in redis.cc more informative. * Fix bug by calling task_table_add_task. * Add test.
This commit is contained in:
parent
668737f383
commit
3195c6aa63
3 changed files with 75 additions and 28 deletions
|
@ -225,7 +225,7 @@ void db_connect_shard(const std::string &db_address,
|
|||
* we've defined. */
|
||||
reply = (redisReply *) redisCommandArgv(sync_context, argc, argv, argvlen);
|
||||
CHECKM(reply != NULL, "db_connect failed on RAY.CONNECT");
|
||||
CHECK(reply->type != REDIS_REPLY_ERROR);
|
||||
CHECKM(reply->type != REDIS_REPLY_ERROR, "reply->str is %s", reply->str);
|
||||
CHECKM(strcmp(reply->str, "OK") == 0, "reply->str is %s", reply->str);
|
||||
freeReplyObject(reply);
|
||||
free(argv);
|
||||
|
@ -326,7 +326,7 @@ void db_disconnect(DBHandle *db) {
|
|||
redisReply *reply =
|
||||
(redisReply *) redisCommand(db->sync_context, "RAY.DISCONNECT %b",
|
||||
db->client.data(), sizeof(db->client));
|
||||
CHECK(reply->type != REDIS_REPLY_ERROR);
|
||||
CHECKM(reply->type != REDIS_REPLY_ERROR, "reply->str is %s", reply->str);
|
||||
CHECKM(strcmp(reply->str, "OK") == 0, "reply->str is %s", reply->str);
|
||||
freeReplyObject(reply);
|
||||
|
||||
|
@ -382,7 +382,7 @@ void redis_object_table_add_callback(redisAsyncContext *c,
|
|||
"because a nondeterministic task was executed twice, either for "
|
||||
"reconstruction or for speculation.");
|
||||
} else {
|
||||
CHECK(reply->type != REDIS_REPLY_ERROR);
|
||||
CHECKM(reply->type != REDIS_REPLY_ERROR, "reply->str is %s", reply->str);
|
||||
CHECKM(strcmp(reply->str, "OK") == 0, "reply->str is %s", reply->str);
|
||||
}
|
||||
/* Call the done callback if there is one. */
|
||||
|
@ -428,7 +428,7 @@ void redis_object_table_remove_callback(redisAsyncContext *c,
|
|||
* condition with an object_table_add. */
|
||||
return;
|
||||
}
|
||||
CHECK(reply->type != REDIS_REPLY_ERROR);
|
||||
CHECKM(reply->type != REDIS_REPLY_ERROR, "reply->str is %s", reply->str);
|
||||
CHECKM(strcmp(reply->str, "OK") == 0, "reply->str is %s", reply->str);
|
||||
/* Call the done callback if there is one. */
|
||||
if (callback_data->done_callback != NULL) {
|
||||
|
@ -486,7 +486,7 @@ void redis_result_table_add_callback(redisAsyncContext *c,
|
|||
REDIS_CALLBACK_HEADER(db, callback_data, r);
|
||||
redisReply *reply = (redisReply *) r;
|
||||
/* Check that the command succeeded. */
|
||||
CHECK(reply->type != REDIS_REPLY_ERROR);
|
||||
CHECKM(reply->type != REDIS_REPLY_ERROR, "reply->str is %s", reply->str);
|
||||
CHECKM(strncmp(reply->str, "OK", strlen("OK")) == 0, "reply->str is %s",
|
||||
reply->str);
|
||||
/* Call the done callback if there is one. */
|
||||
|
@ -802,7 +802,7 @@ void redis_object_table_request_notifications_callback(redisAsyncContext *c,
|
|||
|
||||
/* Do some minimal checking. */
|
||||
redisReply *reply = (redisReply *) r;
|
||||
CHECK(reply->type != REDIS_REPLY_ERROR);
|
||||
CHECKM(reply->type != REDIS_REPLY_ERROR, "reply->str is %s", reply->str);
|
||||
CHECKM(strcmp(reply->str, "OK") == 0, "reply->str is %s", reply->str);
|
||||
CHECK(callback_data->done_callback == NULL);
|
||||
/* Clean up the timer and callback. */
|
||||
|
@ -909,7 +909,7 @@ void redis_task_table_add_task_callback(redisAsyncContext *c,
|
|||
callback_data->data->Get());
|
||||
}
|
||||
} else {
|
||||
CHECK(reply->type != REDIS_REPLY_ERROR);
|
||||
CHECKM(reply->type != REDIS_REPLY_ERROR, "reply->str is %s", reply->str);
|
||||
CHECKM(strcmp(reply->str, "OK") == 0, "reply->str is %s", reply->str);
|
||||
/* Call the done callback if there is one. */
|
||||
if (callback_data->done_callback != NULL) {
|
||||
|
@ -974,7 +974,7 @@ void redis_task_table_update_callback(redisAsyncContext *c,
|
|||
callback_data->data->Get());
|
||||
}
|
||||
} else {
|
||||
CHECK(reply->type != REDIS_REPLY_ERROR);
|
||||
CHECKM(reply->type != REDIS_REPLY_ERROR, "reply->str is %s", reply->str);
|
||||
CHECKM(strcmp(reply->str, "OK") == 0, "reply->str is %s", reply->str);
|
||||
|
||||
/* Call the done callback if there is one. */
|
||||
|
@ -1194,7 +1194,7 @@ void redis_db_client_table_remove_callback(redisAsyncContext *c,
|
|||
REDIS_CALLBACK_HEADER(db, callback_data, r);
|
||||
redisReply *reply = (redisReply *) r;
|
||||
|
||||
CHECK(reply->type != REDIS_REPLY_ERROR);
|
||||
CHECKM(reply->type != REDIS_REPLY_ERROR, "reply->str is %s", reply->str);
|
||||
CHECKM(strcmp(reply->str, "OK") == 0, "reply->str is %s", reply->str);
|
||||
|
||||
/* Call the done callback if there is one. */
|
||||
|
@ -1424,7 +1424,7 @@ void redis_local_scheduler_table_disconnect(DBHandle *db) {
|
|||
redisReply *reply = (redisReply *) redisCommand(
|
||||
db->sync_context, "PUBLISH local_schedulers %b", fbb.GetBufferPointer(),
|
||||
(size_t) fbb.GetSize());
|
||||
CHECK(reply->type != REDIS_REPLY_ERROR);
|
||||
CHECKM(reply->type != REDIS_REPLY_ERROR, "reply->str is %s", reply->str);
|
||||
CHECK(reply->type == REDIS_REPLY_INTEGER);
|
||||
LOG_DEBUG("%lld subscribers received this publish.\n", reply->integer);
|
||||
freeReplyObject(reply);
|
||||
|
@ -1632,7 +1632,7 @@ void redis_push_error_hmset_callback(redisAsyncContext *c,
|
|||
redisReply *reply = (redisReply *) r;
|
||||
|
||||
/* Make sure we were able to add the error information. */
|
||||
CHECK(reply->type != REDIS_REPLY_ERROR);
|
||||
CHECKM(reply->type != REDIS_REPLY_ERROR, "reply->str is %s", reply->str);
|
||||
CHECKM(strcmp(reply->str, "OK") == 0, "reply->str is %s", reply->str);
|
||||
|
||||
/* Add the error to this driver's list of errors. */
|
||||
|
|
|
@ -407,7 +407,11 @@ void finish_killed_task(LocalSchedulerState *state,
|
|||
if (state->db != NULL) {
|
||||
Task *task = Task_alloc(execution_spec, TASK_STATUS_DONE,
|
||||
get_db_client_id(state->db));
|
||||
task_table_update(state->db, task, NULL, NULL, NULL);
|
||||
// In most cases, task_table_update would be appropriate, however, it is
|
||||
// possible in some cases that the task has not yet been added to the task
|
||||
// table (e.g., if it is an actor task that is queued locally because the
|
||||
// actor has not been created yet).
|
||||
task_table_add_task(state->db, task, NULL, NULL, NULL);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -10,6 +10,24 @@ import tempfile
|
|||
import time
|
||||
|
||||
|
||||
def run_string_as_driver(driver_script):
|
||||
"""Run a driver as a separate process.
|
||||
|
||||
Args:
|
||||
driver_script: A string to run as a Python script.
|
||||
|
||||
Returns:
|
||||
The scripts output.
|
||||
"""
|
||||
# Save the driver script as a file so we can call it using subprocess.
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
f.write(driver_script.encode("ascii"))
|
||||
f.flush()
|
||||
out = subprocess.check_output([sys.executable,
|
||||
f.name]).decode("ascii")
|
||||
return out
|
||||
|
||||
|
||||
class MultiNodeTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
@ -24,6 +42,7 @@ class MultiNodeTest(unittest.TestCase):
|
|||
self.redis_address = redis_address.split("\"")[0]
|
||||
|
||||
def tearDown(self):
|
||||
ray.worker.cleanup()
|
||||
# Kill the Ray cluster.
|
||||
subprocess.Popen(["ray", "stop"]).wait()
|
||||
|
||||
|
@ -86,13 +105,7 @@ assert "{}" in ray.error_info()[0][b"message"].decode("ascii")
|
|||
print("success")
|
||||
""".format(self.redis_address, error_string2, error_string2)
|
||||
|
||||
# Save the driver script as a file so we can call it using subprocess.
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
f.write(driver_script.encode("ascii"))
|
||||
f.flush()
|
||||
out = subprocess.check_output([sys.executable,
|
||||
f.name]).decode("ascii")
|
||||
|
||||
out = run_string_as_driver(driver_script)
|
||||
# Make sure the other driver succeeded.
|
||||
self.assertIn("success", out)
|
||||
|
||||
|
@ -102,8 +115,6 @@ print("success")
|
|||
self.assertIn(error_string1,
|
||||
ray.error_info()[0][b"message"].decode("ascii"))
|
||||
|
||||
ray.worker.cleanup()
|
||||
|
||||
def testRemoteFunctionIsolation(self):
|
||||
# This test will run multiple remote functions with the same names in
|
||||
# two different drivers. Connect a driver to the Ray cluster.
|
||||
|
@ -127,12 +138,7 @@ for _ in range(10000):
|
|||
print("success")
|
||||
""".format(self.redis_address)
|
||||
|
||||
# Save the driver script as a file so we can call it using subprocess.
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
f.write(driver_script.encode("ascii"))
|
||||
f.flush()
|
||||
out = subprocess.check_output([sys.executable,
|
||||
f.name]).decode("ascii")
|
||||
out = run_string_as_driver(driver_script)
|
||||
|
||||
@ray.remote
|
||||
def f():
|
||||
|
@ -149,7 +155,44 @@ print("success")
|
|||
# Make sure the other driver succeeded.
|
||||
self.assertIn("success", out)
|
||||
|
||||
ray.worker.cleanup()
|
||||
def testDriverExitingQuickly(self):
|
||||
# This test will create some drivers that submit some tasks and then
|
||||
# exit without waiting for the tasks to complete.
|
||||
ray.init(redis_address=self.redis_address, driver_mode=ray.SILENT_MODE)
|
||||
|
||||
# Define a driver that creates an actor and exits.
|
||||
driver_script1 = """
|
||||
import ray
|
||||
ray.init(redis_address="{}")
|
||||
@ray.remote
|
||||
class Foo(object):
|
||||
def __init__(self):
|
||||
pass
|
||||
Foo.remote()
|
||||
print("success")
|
||||
""".format(self.redis_address)
|
||||
|
||||
# Define a driver that creates some tasks and exits.
|
||||
driver_script2 = """
|
||||
import ray
|
||||
ray.init(redis_address="{}")
|
||||
@ray.remote
|
||||
def f():
|
||||
return 1
|
||||
f.remote()
|
||||
print("success")
|
||||
""".format(self.redis_address)
|
||||
|
||||
# Create some drivers and let them exit and make sure everything is
|
||||
# still alive.
|
||||
for _ in range(3):
|
||||
out = run_string_as_driver(driver_script1)
|
||||
# Make sure the first driver ran to completion.
|
||||
self.assertIn("success", out)
|
||||
out = run_string_as_driver(driver_script2)
|
||||
# Make sure the first driver ran to completion.
|
||||
self.assertIn("success", out)
|
||||
self.assertTrue(ray.services.all_processes_alive())
|
||||
|
||||
|
||||
class StartRayScriptTest(unittest.TestCase):
|
||||
|
|
Loading…
Add table
Reference in a new issue