diff --git a/python/ray/experimental/state.py b/python/ray/experimental/state.py
index c20da6406..eea005874 100644
--- a/python/ray/experimental/state.py
+++ b/python/ray/experimental/state.py
@@ -243,14 +243,8 @@ class GlobalState(object):
         object_info = {
             "DataSize": entry.ObjectSize(),
             "Manager": entry.Manager(),
-            "IsEviction": [entry.IsEviction()],
         }
 
-        for i in range(1, gcs_entry.EntriesLength()):
-            entry = ray.gcs_utils.ObjectTableData.GetRootAsObjectTableData(
-                gcs_entry.Entries(i), 0)
-            object_info["IsEviction"].append(entry.IsEviction())
-
         return object_info
 
     def object_table(self, object_id=None):
diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py
index c7c80965b..88c64d3e3 100644
--- a/python/ray/tests/test_basic.py
+++ b/python/ray/tests/test_basic.py
@@ -2505,10 +2505,6 @@ def test_global_state_api(shutdown_only):
     object_table = ray.global_state.object_table()
     assert len(object_table) == 2
 
-    assert object_table[x_id]["IsEviction"][0] is False
-
-    assert object_table[result_id]["IsEviction"][0] is False
-
     assert object_table[x_id] == ray.global_state.object_table(x_id)
     object_table_entry = ray.global_state.object_table(result_id)
     assert object_table[result_id] == object_table_entry
diff --git a/python/ray/tests/test_failure.py b/python/ray/tests/test_failure.py
index 905b1ee28..21152c353 100644
--- a/python/ray/tests/test_failure.py
+++ b/python/ray/tests/test_failure.py
@@ -655,11 +655,14 @@ def test_redis_module_failure(shutdown_only):
                      -1)
     run_failure_test("Index is not a number.", "RAY.TABLE_APPEND", 1, 1, 2, 1,
                      b"a")
+    run_failure_test("The entry to remove doesn't exist.", "RAY.SET_REMOVE", 1,
+                     1, 3, 1)
     run_one_command("RAY.TABLE_APPEND", 1, 1, 2, 1)
     # It's okay to add duplicate entries.
     run_one_command("RAY.TABLE_APPEND", 1, 1, 2, 1)
     run_one_command("RAY.TABLE_APPEND", 1, 1, 2, 1, 0)
     run_one_command("RAY.TABLE_APPEND", 1, 1, 2, 1, 1)
+    run_one_command("RAY.SET_ADD", 1, 1, 3, 1)
 
 
 @pytest.fixture
diff --git a/src/ray/gcs/client.cc b/src/ray/gcs/client.cc
index 4ce6a07b4..b51421e10 100644
--- a/src/ray/gcs/client.cc
+++ b/src/ray/gcs/client.cc
@@ -112,7 +112,7 @@ AsyncGcsClient::AsyncGcsClient(const std::string &address, int port,
   driver_table_.reset(new DriverTable({primary_context_}, this));
   heartbeat_batch_table_.reset(new HeartbeatBatchTable({primary_context_}, this));
   // Tables below would be sharded.
-  object_table_.reset(new ObjectTable(shard_contexts_, this, command_type));
+  object_table_.reset(new ObjectTable(shard_contexts_, this));
   raylet_task_table_.reset(new raylet::TaskTable(shard_contexts_, this, command_type));
   task_reconstruction_log_.reset(new TaskReconstructionLog(shard_contexts_, this));
   task_lease_table_.reset(new TaskLeaseTable(shard_contexts_, this));
diff --git a/src/ray/gcs/client_test.cc b/src/ray/gcs/client_test.cc
index b7aab1582..0d1f812a5 100644
--- a/src/ray/gcs/client_test.cc
+++ b/src/ray/gcs/client_test.cc
@@ -131,40 +131,42 @@ TEST_MACRO(TestGcsWithChainAsio, TestTableLookup);
 void TestLogLookup(const JobID &job_id, std::shared_ptr<gcs::AsyncGcsClient> client) {
   // Append some entries to the log at an object ID.
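The new RAY.SET_ADD / RAY.SET_REMOVE assertions in test_failure.py pin down the commands' observable contract: a duplicate add is accepted as a no-op, while removing an entry that was never added is an error. For reference, roughly the same behavior can be observed against a redis-server that has the Ray Redis module loaded; this is only a sketch using hiredis, and the numeric prefix/pubsub/id/data arguments mirror the placeholder integers used by the test rather than real table values.

    #include <cstdio>
    #include <hiredis/hiredis.h>

    int main() {
      // Assumes a local redis-server started with the Ray Redis module loaded.
      redisContext *ctx = redisConnect("127.0.0.1", 6379);
      if (ctx == nullptr || ctx->err) return 1;

      // Adding the same entry twice is fine; the duplicate add changes nothing
      // and publishes no notification.
      redisReply *reply =
          (redisReply *)redisCommand(ctx, "RAY.SET_ADD %s %s %s %s", "1", "1", "3", "1");
      freeReplyObject(reply);
      reply = (redisReply *)redisCommand(ctx, "RAY.SET_ADD %s %s %s %s", "1", "1", "3", "1");
      freeReplyObject(reply);

      // Removing an entry that was never added fails with
      // "The entry to remove doesn't exist."
      reply = (redisReply *)redisCommand(ctx, "RAY.SET_REMOVE %s %s %s %s", "1", "1", "4", "1");
      if (reply->type == REDIS_REPLY_ERROR) {
        printf("expected error: %s\n", reply->str);
      }
      freeReplyObject(reply);
      redisFree(ctx);
      return 0;
    }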
- ObjectID object_id = ObjectID::from_random(); - std::vector managers = {"abc", "def", "ghi"}; - for (auto &manager : managers) { - auto data = std::make_shared(); - data->manager = manager; + TaskID task_id = TaskID::from_random(); + std::vector node_manager_ids = {"abc", "def", "ghi"}; + for (auto &node_manager_id : node_manager_ids) { + auto data = std::make_shared(); + data->node_manager_id = node_manager_id; // Check that we added the correct object entries. - auto add_callback = [object_id, data](gcs::AsyncGcsClient *client, const UniqueID &id, - const ObjectTableDataT &d) { - ASSERT_EQ(id, object_id); - ASSERT_EQ(data->manager, d.manager); + auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const UniqueID &id, + const TaskReconstructionDataT &d) { + ASSERT_EQ(id, task_id); + ASSERT_EQ(data->node_manager_id, d.node_manager_id); }; - RAY_CHECK_OK(client->object_table().Append(job_id, object_id, data, add_callback)); + RAY_CHECK_OK( + client->task_reconstruction_log().Append(job_id, task_id, data, add_callback)); } // Check that lookup returns the added object entries. - auto lookup_callback = [object_id, managers]( - gcs::AsyncGcsClient *client, const ObjectID &id, - const std::vector &data) { - ASSERT_EQ(id, object_id); + auto lookup_callback = [task_id, node_manager_ids]( + gcs::AsyncGcsClient *client, const UniqueID &id, + const std::vector &data) { + ASSERT_EQ(id, task_id); for (const auto &entry : data) { - ASSERT_EQ(entry.manager, managers[test->NumCallbacks()]); + ASSERT_EQ(entry.node_manager_id, node_manager_ids[test->NumCallbacks()]); test->IncrementNumCallbacks(); } - if (test->NumCallbacks() == managers.size()) { + if (test->NumCallbacks() == node_manager_ids.size()) { test->Stop(); } }; // Do a lookup at the object ID. - RAY_CHECK_OK(client->object_table().Lookup(job_id, object_id, lookup_callback)); + RAY_CHECK_OK( + client->task_reconstruction_log().Lookup(job_id, task_id, lookup_callback)); // Run the event loop. The loop will only stop if the Lookup callback is // called (or an assertion failure). 
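Stripped of its assertions and bookkeeping, the call pattern that the rewritten TestLogLookup exercises against the task reconstruction log reduces to an Append followed by a Lookup. A minimal sketch, assuming a fully constructed gcs::AsyncGcsClient and the usual ray/gcs headers:

    void AppendAndReadReconstructionLog(const JobID &job_id, const TaskID &task_id,
                                        std::shared_ptr<gcs::AsyncGcsClient> client) {
      // Append one entry to the log stored at task_id.
      auto data = std::make_shared<TaskReconstructionDataT>();
      data->node_manager_id = "node-1";
      RAY_CHECK_OK(client->task_reconstruction_log().Append(
          job_id, task_id, data,
          /*done=*/[](gcs::AsyncGcsClient *client, const UniqueID &id,
                      const TaskReconstructionDataT &d) {
            // Invoked once the entry has been written to the GCS.
          }));

      // Read back every entry appended at task_id, in append order.
      RAY_CHECK_OK(client->task_reconstruction_log().Lookup(
          job_id, task_id,
          [](gcs::AsyncGcsClient *client, const UniqueID &id,
             const std::vector<TaskReconstructionDataT> &entries) {
            for (const auto &entry : entries) {
              RAY_LOG(INFO) << "reconstructed on " << entry.node_manager_id;
            }
          }));
    }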
test->Start(); - ASSERT_EQ(test->NumCallbacks(), managers.size()); + ASSERT_EQ(test->NumCallbacks(), node_manager_ids.size()); } TEST_F(TestGcsWithAsio, TestLogLookup) { @@ -201,11 +203,11 @@ TEST_MACRO(TestGcsWithChainAsio, TestTableLookupFailure); void TestLogAppendAt(const JobID &job_id, std::shared_ptr client) { TaskID task_id = TaskID::from_random(); - std::vector managers = {"A", "B"}; + std::vector node_manager_ids = {"A", "B"}; std::vector> data_log; - for (const auto &manager : managers) { + for (const auto &node_manager_id : node_manager_ids) { auto data = std::make_shared(); - data->node_manager_id = manager; + data->node_manager_id = node_manager_id; data_log.push_back(data); } @@ -234,13 +236,14 @@ void TestLogAppendAt(const JobID &job_id, std::shared_ptr c job_id, task_id, data_log[1], /*done callback=*/nullptr, failure_callback, /*log_length=*/1)); - auto lookup_callback = [managers](gcs::AsyncGcsClient *client, const UniqueID &id, - const std::vector &data) { + auto lookup_callback = [node_manager_ids]( + gcs::AsyncGcsClient *client, const UniqueID &id, + const std::vector &data) { std::vector appended_managers; for (const auto &entry : data) { appended_managers.push_back(entry.node_manager_id); } - ASSERT_EQ(appended_managers, managers); + ASSERT_EQ(appended_managers, node_manager_ids); test->Stop(); }; RAY_CHECK_OK( @@ -256,14 +259,13 @@ TEST_F(TestGcsWithAsio, TestLogAppendAt) { TestLogAppendAt(job_id_, client_); } -void TestDeleteKeysFromLog(const JobID &job_id, - std::shared_ptr client, - std::vector> &data_vector) { - std::vector ids; - ObjectID object_id; - for (auto &data : data_vector) { - object_id = ObjectID::from_random(); - ids.push_back(object_id); +void TestSet(const JobID &job_id, std::shared_ptr client) { + // Add some entries to the set at an object ID. + ObjectID object_id = ObjectID::from_random(); + std::vector managers = {"abc", "def", "ghi"}; + for (auto &manager : managers) { + auto data = std::make_shared(); + data->manager = manager; // Check that we added the correct object entries. auto add_callback = [object_id, data](gcs::AsyncGcsClient *client, const UniqueID &id, const ObjectTableDataT &d) { @@ -271,32 +273,102 @@ void TestDeleteKeysFromLog(const JobID &job_id, ASSERT_EQ(data->manager, d.manager); test->IncrementNumCallbacks(); }; - RAY_CHECK_OK(client->object_table().Append(job_id, object_id, data, add_callback)); + RAY_CHECK_OK(client->object_table().Add(job_id, object_id, data, add_callback)); } - for (const auto &object_id : ids) { - // Check that lookup returns the added object entries. - auto lookup_callback = [object_id, data_vector]( - gcs::AsyncGcsClient *client, const ObjectID &id, - const std::vector &data) { + + // Check that lookup returns the added object entries. + auto lookup_callback = [object_id, managers]( + gcs::AsyncGcsClient *client, const ObjectID &id, + const std::vector &data) { + ASSERT_EQ(id, object_id); + ASSERT_EQ(data.size(), managers.size()); + test->IncrementNumCallbacks(); + }; + + // Do a lookup at the object ID. + RAY_CHECK_OK(client->object_table().Lookup(job_id, object_id, lookup_callback)); + + for (auto &manager : managers) { + auto data = std::make_shared(); + data->manager = manager; + // Check that we added the correct object entries. 
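TestSet above is the end-to-end exercise of the new set semantics on the object table: Add, Lookup, Remove, then a Lookup that observes the empty set. Condensed into a usage sketch (client construction and includes assumed, test assertions trimmed):

    void ObjectTableRoundTrip(const JobID &job_id,
                              std::shared_ptr<gcs::AsyncGcsClient> client) {
      const ObjectID object_id = ObjectID::from_random();

      // Add one location entry to the set stored at object_id.
      auto data = std::make_shared<ObjectTableDataT>();
      data->manager = "node-manager-a";
      RAY_CHECK_OK(client->object_table().Add(job_id, object_id, data, nullptr));

      // Lookup returns the current members of the set (order is unspecified).
      RAY_CHECK_OK(client->object_table().Lookup(
          job_id, object_id,
          [](gcs::AsyncGcsClient *client, const ObjectID &id,
             const std::vector<ObjectTableDataT> &entries) {
            RAY_LOG(INFO) << id.hex() << " has " << entries.size() << " location(s)";
          }));

      // Removing the same entry empties the set; a later Lookup is called with
      // an empty vector, and the Redis module deletes the now-empty key.
      RAY_CHECK_OK(client->object_table().Remove(job_id, object_id, data, nullptr));
    }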
+ auto remove_entry_callback = [object_id, data]( + gcs::AsyncGcsClient *client, const UniqueID &id, const ObjectTableDataT &d) { ASSERT_EQ(id, object_id); + ASSERT_EQ(data->manager, d.manager); + test->IncrementNumCallbacks(); + }; + RAY_CHECK_OK( + client->object_table().Remove(job_id, object_id, data, remove_entry_callback)); + } + + // Check that the entries are removed. + auto lookup_callback2 = [object_id, managers]( + gcs::AsyncGcsClient *client, const ObjectID &id, + const std::vector &data) { + ASSERT_EQ(id, object_id); + ASSERT_EQ(data.size(), 0); + test->IncrementNumCallbacks(); + test->Stop(); + }; + + // Do a lookup at the object ID. + RAY_CHECK_OK(client->object_table().Lookup(job_id, object_id, lookup_callback2)); + // Run the event loop. The loop will only stop if the Lookup callback is + // called (or an assertion failure). + test->Start(); + ASSERT_EQ(test->NumCallbacks(), managers.size() * 2 + 2); +} + +TEST_F(TestGcsWithAsio, TestSet) { + test = this; + TestSet(job_id_, client_); +} + +void TestDeleteKeysFromLog( + const JobID &job_id, std::shared_ptr client, + std::vector> &data_vector) { + std::vector ids; + TaskID task_id; + for (auto &data : data_vector) { + task_id = TaskID::from_random(); + ids.push_back(task_id); + // Check that we added the correct object entries. + auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const UniqueID &id, + const TaskReconstructionDataT &d) { + ASSERT_EQ(id, task_id); + ASSERT_EQ(data->node_manager_id, d.node_manager_id); + test->IncrementNumCallbacks(); + }; + RAY_CHECK_OK( + client->task_reconstruction_log().Append(job_id, task_id, data, add_callback)); + } + for (const auto &task_id : ids) { + // Check that lookup returns the added object entries. + auto lookup_callback = [task_id, data_vector]( + gcs::AsyncGcsClient *client, const UniqueID &id, + const std::vector &data) { + ASSERT_EQ(id, task_id); ASSERT_EQ(data.size(), 1); test->IncrementNumCallbacks(); }; - RAY_CHECK_OK(client->object_table().Lookup(job_id, object_id, lookup_callback)); + RAY_CHECK_OK( + client->task_reconstruction_log().Lookup(job_id, task_id, lookup_callback)); } if (ids.size() == 1) { - client->object_table().Delete(job_id, ids[0]); + client->task_reconstruction_log().Delete(job_id, ids[0]); } else { - client->object_table().Delete(job_id, ids); + client->task_reconstruction_log().Delete(job_id, ids); } - for (const auto &object_id : ids) { - auto lookup_callback = [object_id](gcs::AsyncGcsClient *client, const ObjectID &id, - const std::vector &data) { - ASSERT_EQ(id, object_id); + for (const auto &task_id : ids) { + auto lookup_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id, + const std::vector &data) { + ASSERT_EQ(id, task_id); ASSERT_TRUE(data.size() == 0); test->IncrementNumCallbacks(); }; - RAY_CHECK_OK(client->object_table().Lookup(job_id, object_id, lookup_callback)); + RAY_CHECK_OK( + client->task_reconstruction_log().Lookup(job_id, task_id, lookup_callback)); } } @@ -349,34 +421,80 @@ void TestDeleteKeysFromTable(const JobID &job_id, } } +void TestDeleteKeysFromSet(const JobID &job_id, + std::shared_ptr client, + std::vector> &data_vector) { + std::vector ids; + ObjectID object_id; + for (auto &data : data_vector) { + object_id = ObjectID::from_random(); + ids.push_back(object_id); + // Check that we added the correct object entries. 
+ auto add_callback = [object_id, data](gcs::AsyncGcsClient *client, const UniqueID &id, + const ObjectTableDataT &d) { + ASSERT_EQ(id, object_id); + ASSERT_EQ(data->manager, d.manager); + test->IncrementNumCallbacks(); + }; + RAY_CHECK_OK(client->object_table().Add(job_id, object_id, data, add_callback)); + } + for (const auto &object_id : ids) { + // Check that lookup returns the added object entries. + auto lookup_callback = [object_id, data_vector]( + gcs::AsyncGcsClient *client, const ObjectID &id, + const std::vector &data) { + ASSERT_EQ(id, object_id); + ASSERT_EQ(data.size(), 1); + test->IncrementNumCallbacks(); + }; + RAY_CHECK_OK(client->object_table().Lookup(job_id, object_id, lookup_callback)); + } + if (ids.size() == 1) { + client->object_table().Delete(job_id, ids[0]); + } else { + client->object_table().Delete(job_id, ids); + } + for (const auto &object_id : ids) { + auto lookup_callback = [object_id](gcs::AsyncGcsClient *client, const ObjectID &id, + const std::vector &data) { + ASSERT_EQ(id, object_id); + ASSERT_TRUE(data.size() == 0); + test->IncrementNumCallbacks(); + }; + RAY_CHECK_OK(client->object_table().Lookup(job_id, object_id, lookup_callback)); + } +} + // Test delete function for keys of Log or Table. void TestDeleteKeys(const JobID &job_id, std::shared_ptr client) { // Test delete function for keys of Log. - std::vector> object_vector; - auto AppendObjectData = [&object_vector](size_t add_count) { + std::vector> task_reconstruction_vector; + auto AppendTaskReconstructionData = [&task_reconstruction_vector](size_t add_count) { for (size_t i = 0; i < add_count; ++i) { - auto data = std::make_shared(); - data->manager = ObjectID::from_random().hex(); - object_vector.push_back(data); + auto data = std::make_shared(); + data->node_manager_id = ObjectID::from_random().hex(); + task_reconstruction_vector.push_back(data); } }; // Test one element case. - AppendObjectData(1); - ASSERT_EQ(object_vector.size(), 1); - TestDeleteKeysFromLog(job_id, client, object_vector); + AppendTaskReconstructionData(1); + ASSERT_EQ(task_reconstruction_vector.size(), 1); + TestDeleteKeysFromLog(job_id, client, task_reconstruction_vector); // Test the case for more than one elements and less than // maximum_gcs_deletion_batch_size. - AppendObjectData(RayConfig::instance().maximum_gcs_deletion_batch_size() / 2); - ASSERT_GT(object_vector.size(), 1); - ASSERT_LT(object_vector.size(), + AppendTaskReconstructionData(RayConfig::instance().maximum_gcs_deletion_batch_size() / + 2); + ASSERT_GT(task_reconstruction_vector.size(), 1); + ASSERT_LT(task_reconstruction_vector.size(), RayConfig::instance().maximum_gcs_deletion_batch_size()); - TestDeleteKeysFromLog(job_id, client, object_vector); + TestDeleteKeysFromLog(job_id, client, task_reconstruction_vector); // Test the case for more than maximum_gcs_deletion_batch_size. // The Delete function will split the data into two commands. - AppendObjectData(RayConfig::instance().maximum_gcs_deletion_batch_size() / 2); - ASSERT_GT(object_vector.size(), + AppendTaskReconstructionData(RayConfig::instance().maximum_gcs_deletion_batch_size() / + 2); + ASSERT_GT(task_reconstruction_vector.size(), RayConfig::instance().maximum_gcs_deletion_batch_size()); - TestDeleteKeysFromLog(job_id, client, object_vector); + TestDeleteKeysFromLog(job_id, client, task_reconstruction_vector); // Test delete function for keys of Table. 
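The delete tests rely on Delete having the same shape for Log, Table, and the new Set: one overload for a single key and one for a batch, with batches larger than maximum_gcs_deletion_batch_size split into multiple commands by the client. A sketch of deleting object-table keys in bulk (sizes are illustrative, client setup assumed):

    void DeleteManyObjectKeys(const JobID &job_id,
                              std::shared_ptr<gcs::AsyncGcsClient> client) {
      std::vector<ObjectID> ids;
      // More keys than fit in one deletion command; the client is expected to
      // split these into several batches internally.
      const size_t n = RayConfig::instance().maximum_gcs_deletion_batch_size() + 10;
      for (size_t i = 0; i < n; ++i) {
        ids.push_back(ObjectID::from_random());
      }

      // Single-key overload.
      client->object_table().Delete(job_id, ids[0]);
      // Batched overload.
      client->object_table().Delete(job_id, ids);
    }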
std::vector> task_vector; @@ -403,6 +521,33 @@ void TestDeleteKeys(const JobID &job_id, std::shared_ptr cl test->Start(); ASSERT_GT(test->NumCallbacks(), 9 * RayConfig::instance().maximum_gcs_deletion_batch_size()); + + // Test delete function for keys of Set. + std::vector> object_vector; + auto AppendObjectData = [&object_vector](size_t add_count) { + for (size_t i = 0; i < add_count; ++i) { + auto data = std::make_shared(); + data->manager = ObjectID::from_random().hex(); + object_vector.push_back(data); + } + }; + // Test one element case. + AppendObjectData(1); + ASSERT_EQ(object_vector.size(), 1); + TestDeleteKeysFromSet(job_id, client, object_vector); + // Test the case for more than one elements and less than + // maximum_gcs_deletion_batch_size. + AppendObjectData(RayConfig::instance().maximum_gcs_deletion_batch_size() / 2); + ASSERT_GT(object_vector.size(), 1); + ASSERT_LT(object_vector.size(), + RayConfig::instance().maximum_gcs_deletion_batch_size()); + TestDeleteKeysFromSet(job_id, client, object_vector); + // Test the case for more than maximum_gcs_deletion_batch_size. + // The Delete function will split the data into two commands. + AppendObjectData(RayConfig::instance().maximum_gcs_deletion_batch_size() / 2); + ASSERT_GT(object_vector.size(), + RayConfig::instance().maximum_gcs_deletion_batch_size()); + TestDeleteKeysFromSet(job_id, client, object_vector); } TEST_F(TestGcsWithAsio, TestDeleteKey) { @@ -451,22 +596,77 @@ void TaskLookupAfterUpdateFailure(gcs::AsyncGcsClient *client, const TaskID &id) void TestLogSubscribeAll(const JobID &job_id, std::shared_ptr client) { - std::vector managers = {"abc", "def", "ghi"}; - std::vector object_ids; - for (size_t i = 0; i < managers.size(); i++) { - object_ids.push_back(ObjectID::from_random()); + std::vector driver_ids; + for (int i = 0; i < 3; i++) { + driver_ids.emplace_back(DriverID::from_random()); } + // Callback for a notification. + auto notification_callback = [driver_ids](gcs::AsyncGcsClient *client, + const UniqueID &id, + const std::vector data) { + ASSERT_EQ(id, driver_ids[test->NumCallbacks()]); + // Check that we get notifications in the same order as the writes. + for (const auto &entry : data) { + ASSERT_EQ(entry.driver_id, driver_ids[test->NumCallbacks()].binary()); + test->IncrementNumCallbacks(); + } + if (test->NumCallbacks() == driver_ids.size()) { + test->Stop(); + } + }; + + // Callback for subscription success. We are guaranteed to receive + // notifications after this is called. + auto subscribe_callback = [driver_ids](gcs::AsyncGcsClient *client) { + // We have subscribed. Do the writes to the table. + for (size_t i = 0; i < driver_ids.size(); i++) { + RAY_CHECK_OK(client->driver_table().AppendDriverData(driver_ids[i], false)); + } + }; + + // Subscribe to all driver table notifications. Once we have successfully + // subscribed, we will append to the key several times and check that we get + // notified for each. + RAY_CHECK_OK(client->driver_table().Subscribe( + job_id, ClientID::nil(), notification_callback, subscribe_callback)); + + // Run the event loop. The loop will only stop if the registered subscription + // callback is called (or an assertion failure). + test->Start(); + // Check that we received one notification callback for each write. 
+ ASSERT_EQ(test->NumCallbacks(), driver_ids.size()); +} + +TEST_F(TestGcsWithAsio, TestLogSubscribeAll) { + test = this; + TestLogSubscribeAll(job_id_, client_); +} + +void TestSetSubscribeAll(const JobID &job_id, + std::shared_ptr client) { + std::vector object_ids; + for (int i = 0; i < 3; i++) { + object_ids.emplace_back(ObjectID::from_random()); + } + std::vector managers = {"abc", "def", "ghi"}; + // Callback for a notification. auto notification_callback = [object_ids, managers]( gcs::AsyncGcsClient *client, const UniqueID &id, + const GcsTableNotificationMode notification_mode, const std::vector data) { - ASSERT_EQ(id, object_ids[test->NumCallbacks()]); + if (test->NumCallbacks() < 3 * 3) { + ASSERT_EQ(notification_mode, GcsTableNotificationMode::APPEND_OR_ADD); + } else { + ASSERT_EQ(notification_mode, GcsTableNotificationMode::REMOVE); + } + ASSERT_EQ(id, object_ids[test->NumCallbacks() / 3 % 3]); // Check that we get notifications in the same order as the writes. for (const auto &entry : data) { - ASSERT_EQ(entry.manager, managers[test->NumCallbacks()]); + ASSERT_EQ(entry.manager, managers[test->NumCallbacks() % 3]); test->IncrementNumCallbacks(); } - if (test->NumCallbacks() == managers.size()) { + if (test->NumCallbacks() == object_ids.size() * 3 * 2) { test->Stop(); } }; @@ -476,13 +676,26 @@ void TestLogSubscribeAll(const JobID &job_id, auto subscribe_callback = [job_id, object_ids, managers](gcs::AsyncGcsClient *client) { // We have subscribed. Do the writes to the table. for (size_t i = 0; i < object_ids.size(); i++) { - auto data = std::make_shared(); - data->manager = managers[i]; - RAY_CHECK_OK(client->object_table().Append(job_id, object_ids[i], data, nullptr)); + for (size_t j = 0; j < managers.size(); j++) { + auto data = std::make_shared(); + data->manager = managers[j]; + for (int k = 0; k < 3; k++) { + // Add the same entry several times. + // Expect no notification if the entry already exists. + RAY_CHECK_OK(client->object_table().Add(job_id, object_ids[i], data, nullptr)); + } + } + } + for (size_t i = 0; i < object_ids.size(); i++) { + for (size_t j = 0; j < managers.size(); j++) { + auto data = std::make_shared(); + data->manager = managers[j]; + RAY_CHECK_OK(client->object_table().Remove(job_id, object_ids[i], data, nullptr)); + } } }; - // Subscribe to all object table notifications. Once we have successfully + // Subscribe to all driver table notifications. Once we have successfully // subscribed, we will append to the key several times and check that we get // notified for each. RAY_CHECK_OK(client->object_table().Subscribe( @@ -492,12 +705,12 @@ void TestLogSubscribeAll(const JobID &job_id, // callback is called (or an assertion failure). test->Start(); // Check that we received one notification callback for each write. - ASSERT_EQ(test->NumCallbacks(), managers.size()); + ASSERT_EQ(test->NumCallbacks(), object_ids.size() * 3 * 2); } -TEST_F(TestGcsWithAsio, TestLogSubscribeAll) { +TEST_F(TestGcsWithAsio, TestSetSubscribeAll) { test = this; - TestLogSubscribeAll(job_id_, client_); + TestSetSubscribeAll(job_id_, client_); } void TestTableSubscribeId(const JobID &job_id, @@ -579,24 +792,100 @@ TEST_MACRO(TestGcsWithChainAsio, TestTableSubscribeId); void TestLogSubscribeId(const JobID &job_id, std::shared_ptr client) { // Add a log entry. 
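The expected callback count in TestSetSubscribeAll above follows from set semantics: each of the 3 objects has the same 3 managers added 3 times, but only the first add of each (object, manager) pair changes the set and is published, and each of the 9 removes is published as well, for object_ids.size() * 3 * 2 = 18 per-entry callbacks. A self-contained model of that counting:

    #include <cassert>
    #include <set>
    #include <string>
    #include <utility>

    int main() {
      const int num_objects = 3;
      const std::string managers[] = {"abc", "def", "ghi"};
      std::set<std::pair<int, std::string>> membership;
      int notifications = 0;

      // Adds: duplicates leave the set unchanged and publish nothing.
      for (int obj = 0; obj < num_objects; ++obj) {
        for (const auto &manager : managers) {
          for (int attempt = 0; attempt < 3; ++attempt) {
            if (membership.emplace(obj, manager).second) {
              ++notifications;  // only the first add is published
            }
          }
        }
      }
      // Removes: every existing entry produces one REMOVE notification.
      for (int obj = 0; obj < num_objects; ++obj) {
        for (const auto &manager : managers) {
          if (membership.erase(std::make_pair(obj, manager)) > 0) {
            ++notifications;
          }
        }
      }
      assert(notifications == num_objects * 3 * 2);  // 18
      return 0;
    }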
+ DriverID driver_id1 = DriverID::from_random(); + std::vector driver_ids1 = {"abc", "def", "ghi"}; + auto data1 = std::make_shared(); + data1->driver_id = driver_ids1[0]; + RAY_CHECK_OK(client->driver_table().Append(job_id, driver_id1, data1, nullptr)); + + // Add a log entry at a second key. + DriverID driver_id2 = DriverID::from_random(); + std::vector driver_ids2 = {"jkl", "mno", "pqr"}; + auto data2 = std::make_shared(); + data2->driver_id = driver_ids2[0]; + RAY_CHECK_OK(client->driver_table().Append(job_id, driver_id2, data2, nullptr)); + + // The callback for a notification from the table. This should only be + // received for keys that we requested notifications for. + auto notification_callback = [driver_id2, driver_ids2]( + gcs::AsyncGcsClient *client, const UniqueID &id, + const std::vector &data) { + // Check that we only get notifications for the requested key. + ASSERT_EQ(id, driver_id2); + // Check that we get notifications in the same order as the writes. + for (const auto &entry : data) { + ASSERT_EQ(entry.driver_id, driver_ids2[test->NumCallbacks()]); + test->IncrementNumCallbacks(); + } + if (test->NumCallbacks() == driver_ids2.size()) { + test->Stop(); + } + }; + + // The callback for subscription success. Once we've subscribed, request + // notifications for only one of the keys, then write to both keys. + auto subscribe_callback = [job_id, driver_id1, driver_id2, driver_ids1, + driver_ids2](gcs::AsyncGcsClient *client) { + // Request notifications for one of the keys. + RAY_CHECK_OK(client->driver_table().RequestNotifications( + job_id, driver_id2, client->client_table().GetLocalClientId())); + // Write both keys. We should only receive notifications for the key that + // we requested them for. + auto remaining = std::vector(++driver_ids1.begin(), driver_ids1.end()); + for (const auto &driver_id : remaining) { + auto data = std::make_shared(); + data->driver_id = driver_id; + RAY_CHECK_OK(client->driver_table().Append(job_id, driver_id1, data, nullptr)); + } + remaining = std::vector(++driver_ids2.begin(), driver_ids2.end()); + for (const auto &driver_id : remaining) { + auto data = std::make_shared(); + data->driver_id = driver_id; + RAY_CHECK_OK(client->driver_table().Append(job_id, driver_id2, data, nullptr)); + } + }; + + // Subscribe to notifications for this client. This allows us to request and + // receive notifications for specific keys. + RAY_CHECK_OK( + client->driver_table().Subscribe(job_id, client->client_table().GetLocalClientId(), + notification_callback, subscribe_callback)); + // Run the event loop. The loop will only stop if the registered subscription + // callback is called for the requested key. + test->Start(); + // Check that we received one notification callback for each write to the + // requested key. + ASSERT_EQ(test->NumCallbacks(), driver_ids2.size()); +} + +TEST_F(TestGcsWithAsio, TestLogSubscribeId) { + test = this; + TestLogSubscribeId(job_id_, client_); +} + +void TestSetSubscribeId(const JobID &job_id, + std::shared_ptr client) { + // Add a set entry. ObjectID object_id1 = ObjectID::from_random(); std::vector managers1 = {"abc", "def", "ghi"}; auto data1 = std::make_shared(); data1->manager = managers1[0]; - RAY_CHECK_OK(client->object_table().Append(job_id, object_id1, data1, nullptr)); + RAY_CHECK_OK(client->object_table().Add(job_id, object_id1, data1, nullptr)); - // Add a log entry at a second key. + // Add a set entry at a second key. 
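The per-key subscription tests (log and set alike) follow the same protocol: Subscribe with this client's ID, and only after the subscription callback fires call RequestNotifications for the keys of interest. A sketch of watching a single object in the set-backed object table (client setup assumed):

    void WatchOneObject(const JobID &job_id, const ObjectID &object_id,
                        std::shared_ptr<gcs::AsyncGcsClient> client) {
      auto notification_callback =
          [](gcs::AsyncGcsClient *client, const ObjectID &id,
             const GcsTableNotificationMode mode,
             const std::vector<ObjectTableDataT> &entries) {
            // APPEND_OR_ADD carries newly added locations; REMOVE carries
            // locations that were just deleted from the set.
          };
      auto subscribe_callback = [job_id, object_id](gcs::AsyncGcsClient *client) {
        // Safe to request per-key notifications only after this fires.
        RAY_CHECK_OK(client->object_table().RequestNotifications(
            job_id, object_id, client->client_table().GetLocalClientId()));
      };
      RAY_CHECK_OK(client->object_table().Subscribe(
          job_id, client->client_table().GetLocalClientId(), notification_callback,
          subscribe_callback));
    }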
ObjectID object_id2 = ObjectID::from_random(); std::vector managers2 = {"jkl", "mno", "pqr"}; auto data2 = std::make_shared(); data2->manager = managers2[0]; - RAY_CHECK_OK(client->object_table().Append(job_id, object_id2, data2, nullptr)); + RAY_CHECK_OK(client->object_table().Add(job_id, object_id2, data2, nullptr)); // The callback for a notification from the table. This should only be // received for keys that we requested notifications for. auto notification_callback = [object_id2, managers2]( gcs::AsyncGcsClient *client, const ObjectID &id, + const GcsTableNotificationMode notification_mode, const std::vector &data) { + ASSERT_EQ(notification_mode, GcsTableNotificationMode::APPEND_OR_ADD); // Check that we only get notifications for the requested key. ASSERT_EQ(id, object_id2); // Check that we get notifications in the same order as the writes. @@ -622,13 +911,13 @@ void TestLogSubscribeId(const JobID &job_id, for (const auto &manager : remaining) { auto data = std::make_shared(); data->manager = manager; - RAY_CHECK_OK(client->object_table().Append(job_id, object_id1, data, nullptr)); + RAY_CHECK_OK(client->object_table().Add(job_id, object_id1, data, nullptr)); } remaining = std::vector(++managers2.begin(), managers2.end()); for (const auto &manager : remaining) { auto data = std::make_shared(); data->manager = manager; - RAY_CHECK_OK(client->object_table().Append(job_id, object_id2, data, nullptr)); + RAY_CHECK_OK(client->object_table().Add(job_id, object_id2, data, nullptr)); } }; @@ -645,9 +934,9 @@ void TestLogSubscribeId(const JobID &job_id, ASSERT_EQ(test->NumCallbacks(), managers2.size()); } -TEST_F(TestGcsWithAsio, TestLogSubscribeId) { +TEST_F(TestGcsWithAsio, TestSetSubscribeId) { test = this; - TestLogSubscribeId(job_id_, client_); + TestSetSubscribeId(job_id_, client_); } void TestTableSubscribeCancel(const JobID &job_id, @@ -727,28 +1016,110 @@ TEST_MACRO(TestGcsWithChainAsio, TestTableSubscribeCancel); void TestLogSubscribeCancel(const JobID &job_id, std::shared_ptr client) { // Add a log entry. + DriverID driver_id = DriverID::from_random(); + std::vector driver_ids = {"jkl", "mno", "pqr"}; + auto data = std::make_shared(); + data->driver_id = driver_ids[0]; + RAY_CHECK_OK(client->driver_table().Append(job_id, driver_id, data, nullptr)); + + // The callback for a notification from the object table. This should only be + // received for the object that we requested notifications for. + auto notification_callback = [driver_id, driver_ids]( + gcs::AsyncGcsClient *client, const UniqueID &id, + const std::vector &data) { + ASSERT_EQ(id, driver_id); + // Check that we get a duplicate notification for the first write. We get a + // duplicate notification because the log is append-only and notifications + // are canceled after the first write, then requested again. + auto driver_ids_copy = driver_ids; + driver_ids_copy.insert(driver_ids_copy.begin(), driver_ids_copy.front()); + for (const auto &entry : data) { + ASSERT_EQ(entry.driver_id, driver_ids_copy[test->NumCallbacks()]); + test->IncrementNumCallbacks(); + } + if (test->NumCallbacks() == driver_ids_copy.size()) { + test->Stop(); + } + }; + + // The callback for a notification from the table. This should only be + // received for keys that we requested notifications for. + auto subscribe_callback = [job_id, driver_id, driver_ids](gcs::AsyncGcsClient *client) { + // Request notifications, then cancel immediately. We should receive a + // notification for the current value at the key. 
+ RAY_CHECK_OK(client->driver_table().RequestNotifications( + job_id, driver_id, client->client_table().GetLocalClientId())); + RAY_CHECK_OK(client->driver_table().CancelNotifications( + job_id, driver_id, client->client_table().GetLocalClientId())); + // Append to the key. Since we canceled notifications, we should not + // receive a notification for these writes. + auto remaining = std::vector(++driver_ids.begin(), driver_ids.end()); + for (const auto &remaining_driver_id : remaining) { + auto data = std::make_shared(); + data->driver_id = remaining_driver_id; + RAY_CHECK_OK(client->driver_table().Append(job_id, driver_id, data, nullptr)); + } + // Request notifications again. We should receive a notification for the + // current values at the key. + RAY_CHECK_OK(client->driver_table().RequestNotifications( + job_id, driver_id, client->client_table().GetLocalClientId())); + }; + + // Subscribe to notifications for this client. This allows us to request and + // receive notifications for specific keys. + RAY_CHECK_OK( + client->driver_table().Subscribe(job_id, client->client_table().GetLocalClientId(), + notification_callback, subscribe_callback)); + // Run the event loop. The loop will only stop if the registered subscription + // callback is called for the requested key. + test->Start(); + // Check that we received a notification callback for the first append to the + // key, then a notification for all of the appends, because we cancel + // notifications in between. + ASSERT_EQ(test->NumCallbacks(), driver_ids.size() + 1); +} + +TEST_F(TestGcsWithAsio, TestLogSubscribeCancel) { + test = this; + TestLogSubscribeCancel(job_id_, client_); +} + +void TestSetSubscribeCancel(const JobID &job_id, + std::shared_ptr client) { + // Add a set entry. ObjectID object_id = ObjectID::from_random(); std::vector managers = {"jkl", "mno", "pqr"}; auto data = std::make_shared(); data->manager = managers[0]; - RAY_CHECK_OK(client->object_table().Append(job_id, object_id, data, nullptr)); + RAY_CHECK_OK(client->object_table().Add(job_id, object_id, data, nullptr)); // The callback for a notification from the object table. This should only be // received for the object that we requested notifications for. auto notification_callback = [object_id, managers]( gcs::AsyncGcsClient *client, const ObjectID &id, + const GcsTableNotificationMode notification_mode, const std::vector &data) { + ASSERT_EQ(notification_mode, GcsTableNotificationMode::APPEND_OR_ADD); ASSERT_EQ(id, object_id); // Check that we get a duplicate notification for the first write. We get a - // duplicate notification because the log is append-only and notifications + // duplicate notification because notifications // are canceled after the first write, then requested again. 
- auto managers_copy = managers; - managers_copy.insert(managers_copy.begin(), managers_copy.front()); - for (const auto &entry : data) { - ASSERT_EQ(entry.manager, managers_copy[test->NumCallbacks()]); + if (data.size() == 1) { + // first notification + ASSERT_EQ(data[0].manager, managers[0]); test->IncrementNumCallbacks(); + } else { + // second notification + ASSERT_EQ(data.size(), managers.size()); + std::unordered_set managers_set(managers.begin(), managers.end()); + std::unordered_set data_managers_set; + for (const auto &entry : data) { + data_managers_set.insert(entry.manager); + test->IncrementNumCallbacks(); + } + ASSERT_EQ(managers_set, data_managers_set); } - if (test->NumCallbacks() == managers_copy.size()) { + if (test->NumCallbacks() == managers.size() + 1) { test->Stop(); } }; @@ -762,13 +1133,13 @@ void TestLogSubscribeCancel(const JobID &job_id, job_id, object_id, client->client_table().GetLocalClientId())); RAY_CHECK_OK(client->object_table().CancelNotifications( job_id, object_id, client->client_table().GetLocalClientId())); - // Append to the key. Since we canceled notifications, we should not + // Add to the key. Since we canceled notifications, we should not // receive a notification for these writes. auto remaining = std::vector(++managers.begin(), managers.end()); for (const auto &manager : remaining) { auto data = std::make_shared(); data->manager = manager; - RAY_CHECK_OK(client->object_table().Append(job_id, object_id, data, nullptr)); + RAY_CHECK_OK(client->object_table().Add(job_id, object_id, data, nullptr)); } // Request notifications again. We should receive a notification for the // current values at the key. @@ -790,9 +1161,9 @@ void TestLogSubscribeCancel(const JobID &job_id, ASSERT_EQ(test->NumCallbacks(), managers.size() + 1); } -TEST_F(TestGcsWithAsio, TestLogSubscribeCancel) { +TEST_F(TestGcsWithAsio, TestSetSubscribeCancel) { test = this; - TestLogSubscribeCancel(job_id_, client_); + TestSetSubscribeCancel(job_id_, client_); } void ClientTableNotification(gcs::AsyncGcsClient *client, const ClientID &client_id, diff --git a/src/ray/gcs/format/gcs.fbs b/src/ray/gcs/format/gcs.fbs index c826d97a6..5595c3657 100644 --- a/src/ray/gcs/format/gcs.fbs +++ b/src/ray/gcs/format/gcs.fbs @@ -108,7 +108,13 @@ table ResourcePair { value: double; } +enum GcsTableNotificationMode:int { + APPEND_OR_ADD = 0, + REMOVE, +} + table GcsTableEntry { + notification_mode: GcsTableNotificationMode; id: string; entries: [string]; } @@ -124,8 +130,6 @@ table ObjectTableData { object_size: long; // The node manager ID that this object appeared on or was evicted by. manager: string; - // Whether this entry is an addition or a deletion. - is_eviction: bool; } table TaskReconstructionData { diff --git a/src/ray/gcs/redis_module/ray_redis_module.cc b/src/ray/gcs/redis_module/ray_redis_module.cc index ee1e00f85..f1fa99a0f 100644 --- a/src/ray/gcs/redis_module/ray_redis_module.cc +++ b/src/ray/gcs/redis_module/ray_redis_module.cc @@ -181,7 +181,7 @@ flatbuffers::Offset RedisStringToFlatbuf( return fbb.CreateString(redis_string_str, redis_string_size); } -/// Publish a notification for a new entry at a key. This publishes a +/// Publish a notification for an entry update at a key. This publishes a /// notification to all subscribers of the table, as well as every client that /// has requested notifications for this key. /// @@ -189,15 +189,18 @@ flatbuffers::Offset RedisStringToFlatbuf( /// this key should be published to. 
When publishing to a specific /// client, the channel name should be :. /// \param id The ID of the key that the notification is about. -/// \param data The data to publish. +/// \param mode the update mode, such as append or remove. +/// \param data The appended/removed data. /// \return OK if there is no error during a publish. -int PublishTableAdd(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_str, - RedisModuleString *id, RedisModuleString *data) { +int PublishTableUpdate(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_str, + RedisModuleString *id, GcsTableNotificationMode notification_mode, + RedisModuleString *data) { // Serialize the notification to send. flatbuffers::FlatBufferBuilder fbb; auto data_flatbuf = RedisStringToFlatbuf(fbb, data); - auto message = CreateGcsTableEntry(fbb, RedisStringToFlatbuf(fbb, id), - fbb.CreateVector(&data_flatbuf, 1)); + auto message = + CreateGcsTableEntry(fbb, notification_mode, RedisStringToFlatbuf(fbb, id), + fbb.CreateVector(&data_flatbuf, 1)); fbb.Finish(message); // Write the data back to any subscribers that are listening to all table @@ -265,7 +268,8 @@ int TableAdd_DoPublish(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) if (pubsub_channel != TablePubsub::NO_PUBLISH) { // All other pubsub channels write the data back directly onto the channel. - return PublishTableAdd(ctx, pubsub_channel_str, id, data); + return PublishTableUpdate(ctx, pubsub_channel_str, id, + GcsTableNotificationMode::APPEND_OR_ADD, data); } else { return RedisModule_ReplyWithSimpleString(ctx, "OK"); } @@ -364,7 +368,8 @@ int TableAppend_DoPublish(RedisModuleCtx *ctx, RedisModuleString **argv, int /*a if (pubsub_channel != TablePubsub::NO_PUBLISH) { // All other pubsub channels write the data back directly onto the // channel. - return PublishTableAdd(ctx, pubsub_channel_str, id, data); + return PublishTableUpdate(ctx, pubsub_channel_str, id, + GcsTableNotificationMode::APPEND_OR_ADD, data); } else { return RedisModule_ReplyWithSimpleString(ctx, "OK"); } @@ -407,6 +412,112 @@ int ChainTableAppend_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, } #endif +int Set_DoPublish(RedisModuleCtx *ctx, RedisModuleString **argv, bool is_add) { + RedisModuleString *pubsub_channel_str = argv[2]; + RedisModuleString *id = argv[3]; + RedisModuleString *data = argv[4]; + // Publish a message on the requested pubsub channel if necessary. + TablePubsub pubsub_channel; + REPLY_AND_RETURN_IF_NOT_OK(ParseTablePubsub(&pubsub_channel, pubsub_channel_str)); + if (pubsub_channel != TablePubsub::NO_PUBLISH) { + // All other pubsub channels write the data back directly onto the + // channel. + return PublishTableUpdate(ctx, pubsub_channel_str, id, + is_add ? GcsTableNotificationMode::APPEND_OR_ADD + : GcsTableNotificationMode::REMOVE, + data); + } else { + return RedisModule_ReplyWithSimpleString(ctx, "OK"); + } +} + +int Set_DoWrite(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, bool is_add, + bool *changed) { + if (argc != 5) { + return RedisModule_WrongArity(ctx); + } + + RedisModuleString *prefix_str = argv[1]; + RedisModuleString *id = argv[3]; + RedisModuleString *data = argv[4]; + + RedisModuleString *key_string = PrefixedKeyString(ctx, prefix_str, id); + // TODO(kfstorm): According to https://redis.io/topics/modules-intro, + // set type API is not available yet. We can change RedisModule_Call to + // set type API later. + RedisModuleCallReply *reply = + RedisModule_Call(ctx, is_add ? 
"SADD" : "SREM", "ss", key_string, data); + if (RedisModule_CallReplyType(reply) != REDISMODULE_REPLY_ERROR) { + *changed = RedisModule_CallReplyInteger(reply) > 0; + if (!is_add && *changed) { + // try to delete the empty set. + RedisModuleKey *key; + REPLY_AND_RETURN_IF_NOT_OK( + OpenPrefixedKey(&key, ctx, prefix_str, id, REDISMODULE_WRITE)); + auto size = RedisModule_ValueLength(key); + if (size == 0) { + REPLY_AND_RETURN_IF_FALSE(RedisModule_DeleteKey(key) == REDISMODULE_OK, + "ERR Failed to delete empty set."); + } + } + return REDISMODULE_OK; + } else { + // the SADD/SREM command failed + RedisModule_ReplyWithCallReply(ctx, reply); + return REDISMODULE_ERR; + } +} + +/// Add an entry to the set stored at a key. Publishes a notification about +/// the update to all subscribers, if a pubsub channel is provided. +/// +/// This is called from a client with the command: +// +/// RAY.SET_ADD +/// +/// \param table_prefix The prefix string for keys in this set. +/// \param pubsub_channel The pubsub channel name that notifications for +/// this key should be published to. When publishing to a specific +/// client, the channel name should be :. +/// \param id The ID of the key to add to. +/// \param data The data to add to the key. +/// \return OK if the add succeeds, or an error message string if the add +/// fails. +int SetAdd_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + bool changed; + if (Set_DoWrite(ctx, argv, argc, /*is_add=*/true, &changed) != REDISMODULE_OK) { + return REDISMODULE_ERR; + } + if (changed) { + return Set_DoPublish(ctx, argv, /*is_add=*/true); + } + return REDISMODULE_OK; +} + +/// Remove an entry from the set stored at a key. Publishes a notification about +/// the update to all subscribers, if a pubsub channel is provided. +/// +/// This is called from a client with the command: +// +/// RAY.SET_REMOVE +/// +/// \param table_prefix The prefix string for keys in this table. +/// \param pubsub_channel The pubsub channel name that notifications for +/// this key should be published to. When publishing to a specific +/// client, the channel name should be :. +/// \param id The ID of the key to remove from. +/// \param data The data to remove from the key. +/// \return OK if the remove succeeds, or an error message string if the remove +/// fails. +int SetRemove_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + bool changed; + if (Set_DoWrite(ctx, argv, argc, /*is_add=*/false, &changed) != REDISMODULE_OK) { + return REDISMODULE_ERR; + } + REPLY_AND_RETURN_IF_FALSE(changed, "ERR The entry to remove doesn't exist."); + return Set_DoPublish(ctx, argv, /*is_add=*/false); +} + /// A helper function to create and finish a GcsTableEntry, based on the /// current value or values at the given key. /// @@ -428,11 +539,13 @@ Status TableEntryToFlatbuf(RedisModuleCtx *ctx, RedisModuleKey *table_key, size_t data_len = 0; char *data_buf = RedisModule_StringDMA(table_key, &data_len, REDISMODULE_READ); auto data = fbb.CreateString(data_buf, data_len); - auto message = CreateGcsTableEntry(fbb, RedisStringToFlatbuf(fbb, entry_id), + auto message = CreateGcsTableEntry(fbb, GcsTableNotificationMode::APPEND_OR_ADD, + RedisStringToFlatbuf(fbb, entry_id), fbb.CreateVector(&data, 1)); fbb.Finish(message); } break; - case REDISMODULE_KEYTYPE_LIST: { + case REDISMODULE_KEYTYPE_LIST: + case REDISMODULE_KEYTYPE_SET: { RedisModule_CloseKey(table_key); // Close the key before executing the command. 
NOTE(swang): According to // https://github.com/RedisLabs/RedisModulesSDK/blob/master/API.md, "While @@ -440,10 +553,17 @@ Status TableEntryToFlatbuf(RedisModuleCtx *ctx, RedisModuleKey *table_key, RedisModuleString *table_key_str = PrefixedKeyString(ctx, prefix_str, entry_id); // TODO(swang): This could potentially be replaced with the native redis // server list iterator, once it is implemented for redis modules. - RedisModuleCallReply *reply = - RedisModule_Call(ctx, "LRANGE", "sll", table_key_str, 0, -1); + RedisModuleCallReply *reply = nullptr; + switch (key_type) { + case REDISMODULE_KEYTYPE_LIST: + reply = RedisModule_Call(ctx, "LRANGE", "sll", table_key_str, 0, -1); + break; + case REDISMODULE_KEYTYPE_SET: + reply = RedisModule_Call(ctx, "SMEMBERS", "s", table_key_str); + break; + } // Build the flatbuffer from the set of log entries. - if (RedisModule_CallReplyType(reply) != REDISMODULE_REPLY_ARRAY) { + if (reply == nullptr || RedisModule_CallReplyType(reply) != REDISMODULE_REPLY_ARRAY) { return Status::RedisError("Empty list or wrong type"); } std::vector> data; @@ -453,13 +573,14 @@ Status TableEntryToFlatbuf(RedisModuleCtx *ctx, RedisModuleKey *table_key, const char *element_str = RedisModule_CallReplyStringPtr(element, &len); data.push_back(fbb.CreateString(element_str, len)); } - auto message = CreateGcsTableEntry(fbb, RedisStringToFlatbuf(fbb, entry_id), - fbb.CreateVector(data)); + auto message = + CreateGcsTableEntry(fbb, GcsTableNotificationMode::APPEND_OR_ADD, + RedisStringToFlatbuf(fbb, entry_id), fbb.CreateVector(data)); fbb.Finish(message); } break; case REDISMODULE_KEYTYPE_EMPTY: { auto message = CreateGcsTableEntry( - fbb, RedisStringToFlatbuf(fbb, entry_id), + fbb, GcsTableNotificationMode::APPEND_OR_ADD, RedisStringToFlatbuf(fbb, entry_id), fbb.CreateVector(std::vector>())); fbb.Finish(message); } break; @@ -752,6 +873,8 @@ int DebugString_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int // Wrap all Redis commands with Redis' auto memory management. 
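In the Redis module, Set_DoWrite reports whether SADD/SREM actually changed membership, and Set_DoPublish only runs when it did: a duplicate add publishes nothing, and removing a missing entry becomes an error reply instead. A small self-contained model of that decision, independent of the Redis module API:

    #include <iostream>
    #include <string>
    #include <unordered_set>

    enum class NotificationMode { APPEND_OR_ADD, REMOVE };

    // Returns true if a notification should be published for this write.
    bool ApplySetWrite(std::unordered_set<std::string> &members, bool is_add,
                       const std::string &entry, NotificationMode *mode) {
      bool changed =
          is_add ? members.insert(entry).second : members.erase(entry) > 0;
      *mode = is_add ? NotificationMode::APPEND_OR_ADD : NotificationMode::REMOVE;
      return changed;  // duplicate adds / missing removes publish nothing
    }

    int main() {
      std::unordered_set<std::string> set;
      NotificationMode mode;
      std::cout << ApplySetWrite(set, true, "abc", &mode) << "\n";   // 1: published
      std::cout << ApplySetWrite(set, true, "abc", &mode) << "\n";   // 0: duplicate add
      std::cout << ApplySetWrite(set, false, "abc", &mode) << "\n";  // 1: published
      std::cout << ApplySetWrite(set, false, "abc", &mode) << "\n";  // 0: would be an error reply
      return 0;
    }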
AUTO_MEMORY(TableAdd_RedisCommand); AUTO_MEMORY(TableAppend_RedisCommand); +AUTO_MEMORY(SetAdd_RedisCommand); +AUTO_MEMORY(SetRemove_RedisCommand); AUTO_MEMORY(TableLookup_RedisCommand); AUTO_MEMORY(TableRequestNotifications_RedisCommand); AUTO_MEMORY(TableDelete_RedisCommand); @@ -781,7 +904,17 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) } if (RedisModule_CreateCommand(ctx, "ray.table_append", TableAppend_RedisCommand, - "write", 0, 0, 0) == REDISMODULE_ERR) { + "write pubsub", 0, 0, 0) == REDISMODULE_ERR) { + return REDISMODULE_ERR; + } + + if (RedisModule_CreateCommand(ctx, "ray.set_add", SetAdd_RedisCommand, "write pubsub", + 0, 0, 0) == REDISMODULE_ERR) { + return REDISMODULE_ERR; + } + + if (RedisModule_CreateCommand(ctx, "ray.set_remove", SetRemove_RedisCommand, + "write pubsub", 0, 0, 0) == REDISMODULE_ERR) { return REDISMODULE_ERR; } diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index 8e60c3a0d..9b41c9460 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -112,6 +112,19 @@ template Status Log::Subscribe(const JobID &job_id, const ClientID &client_id, const Callback &subscribe, const SubscriptionCallback &done) { + auto subscribe_wrapper = [subscribe](AsyncGcsClient *client, const ID &id, + const GcsTableNotificationMode notification_mode, + const std::vector &data) { + RAY_CHECK(notification_mode != GcsTableNotificationMode::REMOVE); + subscribe(client, id, data); + }; + return Subscribe(job_id, client_id, subscribe_wrapper, done); +} + +template +Status Log::Subscribe(const JobID &job_id, const ClientID &client_id, + const NotificationCallback &subscribe, + const SubscriptionCallback &done) { RAY_CHECK(subscribe_callback_index_ == -1) << "Client called Subscribe twice on the same table"; auto callback = [this, subscribe, done](const std::string &data) { @@ -137,7 +150,7 @@ Status Log::Subscribe(const JobID &job_id, const ClientID &client_id, data_root->UnPackTo(&result); results.emplace_back(std::move(result)); } - subscribe(client_, id, results); + subscribe(client_, id, root->notification_mode(), results); } } // We do not delete the callback after calling it since there may be @@ -274,6 +287,50 @@ std::string Table::DebugString() const { return result.str(); } +template +Status Set::Add(const JobID &job_id, const ID &id, + std::shared_ptr &dataT, const WriteCallback &done) { + num_adds_++; + auto callback = [this, id, dataT, done](const std::string &data) { + if (done != nullptr) { + (done)(client_, id, *dataT); + } + return true; + }; + flatbuffers::FlatBufferBuilder fbb; + fbb.ForceDefaults(true); + fbb.Finish(Data::Pack(fbb, dataT.get())); + return GetRedisContext(id)->RunAsync("RAY.SET_ADD", id, fbb.GetBufferPointer(), + fbb.GetSize(), prefix_, pubsub_channel_, + std::move(callback)); +} + +template +Status Set::Remove(const JobID &job_id, const ID &id, + std::shared_ptr &dataT, const WriteCallback &done) { + num_removes_++; + auto callback = [this, id, dataT, done](const std::string &data) { + if (done != nullptr) { + (done)(client_, id, *dataT); + } + return true; + }; + flatbuffers::FlatBufferBuilder fbb; + fbb.ForceDefaults(true); + fbb.Finish(Data::Pack(fbb, dataT.get())); + return GetRedisContext(id)->RunAsync("RAY.SET_REMOVE", id, fbb.GetBufferPointer(), + fbb.GetSize(), prefix_, pubsub_channel_, + std::move(callback)); +} + +template +std::string Set::DebugString() const { + std::stringstream result; + result << "num lookups: " << num_lookups_ << ", num adds: " << num_adds_ + << ", num removes: " << 
num_removes_; + return result.str(); +} + Status ErrorTable::PushErrorToDriver(const DriverID &driver_id, const std::string &type, const std::string &error_message, double timestamp) { auto data = std::make_shared(); @@ -534,6 +591,7 @@ Status ActorCheckpointIdTable::AddCheckpointId(const JobID &job_id, } template class Log; +template class Set; template class Log; template class Table; template class Table; diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h index 2aabf2ae3..54f7a68da 100644 --- a/src/ray/gcs/tables.h +++ b/src/ray/gcs/tables.h @@ -67,8 +67,6 @@ class LogInterface { /// pubsub_channel_ member if pubsub is required. /// /// Example tables backed by Log: -/// ObjectTable: Stores a log of which clients have added or evicted an -/// object. /// ClientTable: Stores a log of which GCS clients have been added or deleted /// from the system. template @@ -77,6 +75,9 @@ class Log : public LogInterface, virtual public PubsubInterface { using DataT = typename Data::NativeTableType; using Callback = std::function &data)>; + using NotificationCallback = std::function &data)>; /// The callback to call when a write to a key succeeds. using WriteCallback = typename LogInterface::WriteCallback; /// The callback to call when a SUBSCRIBE call completes and we are ready to @@ -208,6 +209,29 @@ class Log : public LogInterface, virtual public PubsubInterface { static std::hash index; return shard_contexts_[index(id) % shard_contexts_.size()]; } + + /// Subscribe to any modifications to the key. The caller may choose + /// to subscribe to all modifications, or to subscribe only to keys that it + /// requests notifications for. This may only be called once per Log + /// instance. This function is different from public version due to + /// an additional parameter notification_mode in NotificationCallback. Therefore this + /// function supports notifications of remove operations. + /// + /// \param job_id The ID of the job (= driver). + /// \param client_id The type of update to listen to. If this is nil, then a + /// message for each Add to the table will be received. Else, only + /// messages for the given client will be received. In the latter + /// case, the client may request notifications on specific keys in the + /// table via `RequestNotifications`. + /// \param subscribe Callback that is called on each received message. If the + /// callback is called with an empty vector, then there was no data at the key. + /// \param done Callback that is called when subscription is complete and we + /// are ready to receive messages. + /// \return Status + Status Subscribe(const JobID &job_id, const ClientID &client_id, + const NotificationCallback &subscribe, + const SubscriptionCallback &done); + /// The connection to the GCS. std::vector> shard_contexts_; /// The GCS client. @@ -228,7 +252,6 @@ class Log : public LogInterface, virtual public PubsubInterface { /// Commands to a GCS table can either be regular (default) or chain-replicated. 
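The protected Subscribe overload is what carries removals to subscribers: the public Log::Subscribe wraps a plain Callback and RAY_CHECKs that it never observes REMOVE, while Set re-exports the overload so its subscribers receive the notification mode explicitly. A sketch of the two callback shapes, using the aliases declared in this header (namespace qualifiers abbreviated, client setup assumed):

    void ExampleSubscriberShapes() {
      // Log-style subscriber: entries only; REMOVE notifications are never
      // delivered here (the wrapper checks against them).
      gcs::Log<TaskID, TaskReconstructionData>::Callback log_cb =
          [](gcs::AsyncGcsClient *client, const TaskID &id,
             const std::vector<TaskReconstructionDataT> &entries) {};

      // Set-style subscriber: the notification mode distinguishes adds from removes.
      gcs::Set<ObjectID, ObjectTableData>::NotificationCallback set_cb =
          [](gcs::AsyncGcsClient *client, const ObjectID &id,
             const GcsTableNotificationMode mode,
             const std::vector<ObjectTableDataT> &entries) {
            if (mode == GcsTableNotificationMode::REMOVE) {
              // entries were just removed from the set at `id`.
            } else {
              // entries were appended/added at `id`.
            }
          };
    }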
CommandType command_type_ = CommandType::kRegular; - private: int64_t num_appends_ = 0; int64_t num_lookups_ = 0; }; @@ -337,26 +360,104 @@ class Table : private Log, using Log::command_type_; using Log::GetRedisContext; - private: int64_t num_adds_ = 0; int64_t num_lookups_ = 0; }; -class ObjectTable : public Log { +template +class SetInterface { + public: + using DataT = typename Data::NativeTableType; + using WriteCallback = typename Log::WriteCallback; + virtual Status Add(const JobID &job_id, const ID &id, std::shared_ptr &data, + const WriteCallback &done) = 0; + virtual Status Remove(const JobID &job_id, const ID &id, std::shared_ptr &data, + const WriteCallback &done) = 0; + virtual ~SetInterface(){}; +}; + +/// \class Set +/// +/// A GCS table where every entry is an addable & removable set. This class is not +/// meant to be used directly. All set classes should derive from this class +/// and override the prefix_ member with a unique prefix for that set, and the +/// pubsub_channel_ member if pubsub is required. +/// +/// Example tables backed by Set: +/// ObjectTable: Stores a set of which clients have added an object. +template +class Set : private Log, + public SetInterface, + virtual public PubsubInterface { + public: + using DataT = typename Log::DataT; + using Callback = typename Log::Callback; + using WriteCallback = typename Log::WriteCallback; + using NotificationCallback = typename Log::NotificationCallback; + using SubscriptionCallback = typename Log::SubscriptionCallback; + + Set(const std::vector> &contexts, AsyncGcsClient *client) + : Log(contexts, client) {} + + using Log::RequestNotifications; + using Log::CancelNotifications; + using Log::Lookup; + using Log::Delete; + + /// Add an entry to the set. + /// + /// \param job_id The ID of the job (= driver). + /// \param id The ID of the data that is added to the GCS. + /// \param data Data to add to the set. + /// \param done Callback that is called once the data has been written to the + /// GCS. + /// \return Status + Status Add(const JobID &job_id, const ID &id, std::shared_ptr &data, + const WriteCallback &done); + + /// Remove an entry from the set. + /// + /// \param job_id The ID of the job (= driver). + /// \param id The ID of the data that is removed from the GCS. + /// \param data Data to remove from the set. + /// \param done Callback that is called once the data has been written to the + /// GCS. + /// \return Status + Status Remove(const JobID &job_id, const ID &id, std::shared_ptr &data, + const WriteCallback &done); + + Status Subscribe(const JobID &job_id, const ClientID &client_id, + const NotificationCallback &subscribe, + const SubscriptionCallback &done) { + return Log::Subscribe(job_id, client_id, subscribe, done); + } + + /// Returns debug string for class. + /// + /// \return string. 
+ std::string DebugString() const; + + protected: + using Log::shard_contexts_; + using Log::client_; + using Log::pubsub_channel_; + using Log::prefix_; + using Log::GetRedisContext; + + int64_t num_adds_ = 0; + int64_t num_removes_ = 0; + using Log::num_lookups_; +}; + +class ObjectTable : public Set { public: ObjectTable(const std::vector> &contexts, AsyncGcsClient *client) - : Log(contexts, client) { + : Set(contexts, client) { pubsub_channel_ = TablePubsub::OBJECT; prefix_ = TablePrefix::OBJECT; }; - ObjectTable(const std::vector> &contexts, - AsyncGcsClient *client, gcs::CommandType command_type) - : ObjectTable(contexts, client) { - command_type_ = command_type; - }; - virtual ~ObjectTable(){}; }; diff --git a/src/ray/object_manager/object_directory.cc b/src/ray/object_manager/object_directory.cc index d9f7b87a7..f9ec35365 100644 --- a/src/ray/object_manager/object_directory.cc +++ b/src/ray/object_manager/object_directory.cc @@ -8,30 +8,19 @@ ObjectDirectory::ObjectDirectory(boost::asio::io_service &io_service, namespace { -/// Process a suffix of the object table log and store the result in +/// Process a notification of the object table entries and store the result in /// client_ids. This assumes that client_ids already contains the result of the -/// object table log up to but not including this suffix. This also stores a -/// bool in has_been_created indicating whether the object has ever been -/// created before. -void UpdateObjectLocations(const std::vector &location_history, +/// object table entries up to but not including this notification. +void UpdateObjectLocations(const GcsTableNotificationMode notification_mode, + const std::vector &location_updates, const ray::gcs::ClientTable &client_table, - std::unordered_set *client_ids, - bool *has_been_created) { - // location_history contains the history of locations of the object (it is a log), - // which might look like the following: - // client1.is_eviction = false - // client1.is_eviction = true - // client2.is_eviction = false - // In such a scenario, we want to indicate client2 is the only client that contains - // the object, which the following code achieves. - if (!location_history.empty()) { - // If there are entries, then the object has been created. Once this flag - // is set to true, it should never go back to false. - *has_been_created = true; - } - for (const auto &object_table_data : location_history) { + std::unordered_set *client_ids) { + // location_updates contains the updates of locations of the object. + // with GcsTableNotificationMode, we can determine whether the update mode is + // addition or deletion. + for (const auto &object_table_data : location_updates) { ClientID client_id = ClientID::from_binary(object_table_data.manager); - if (!object_table_data.is_eviction) { + if (notification_mode != GcsTableNotificationMode::REMOVE) { client_ids->insert(client_id); } else { client_ids->erase(client_id); @@ -52,17 +41,22 @@ void UpdateObjectLocations(const std::vector &location_history void ObjectDirectory::RegisterBackend() { auto object_notification_callback = [this]( gcs::AsyncGcsClient *client, const ObjectID &object_id, - const std::vector &location_history) { + const GcsTableNotificationMode notification_mode, + const std::vector &location_updates) { // Objects are added to this map in SubscribeObjectLocations. auto it = listeners_.find(object_id); // Do nothing for objects we are not listening for. 
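UpdateObjectLocations no longer folds an is_eviction log; it applies one batch of entries to the cached location set under a single notification mode. A self-contained model of that reducer, with ClientID stood in by a string and the client_table-based handling left out:

    #include <string>
    #include <unordered_set>
    #include <vector>

    enum class NotificationMode { APPEND_OR_ADD, REMOVE };

    // Applies one notification batch to the cached location set.
    void UpdateLocations(NotificationMode mode,
                         const std::vector<std::string> &managers,
                         std::unordered_set<std::string> *locations) {
      for (const auto &manager : managers) {
        if (mode == NotificationMode::REMOVE) {
          locations->erase(manager);
        } else {
          locations->insert(manager);
        }
      }
    }

    int main() {
      std::unordered_set<std::string> locations;
      UpdateLocations(NotificationMode::APPEND_OR_ADD, {"client1", "client2"}, &locations);
      UpdateLocations(NotificationMode::REMOVE, {"client1"}, &locations);
      // locations now contains only "client2".
      return 0;
    }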
     if (it == listeners_.end()) {
       return;
     }
+
+    // Once this flag is set to true, it should never go back to false.
+    it->second.subscribed = true;
+
     // Update entries for this object.
-    UpdateObjectLocations(location_history, gcs_client_->client_table(),
-                          &it->second.current_object_locations,
-                          &it->second.has_been_created);
+    UpdateObjectLocations(notification_mode, location_updates,
+                          gcs_client_->client_table(),
+                          &it->second.current_object_locations);
     // Copy the callbacks so that the callbacks can unsubscribe without interrupting
     // looping over the callbacks.
     auto callbacks = it->second.callbacks;
@@ -73,8 +67,7 @@ void ObjectDirectory::RegisterBackend() {
     for (const auto &callback_pair : callbacks) {
       // It is safe to call the callback directly since this is already running
       // in the subscription callback stack.
-      callback_pair.second(object_id, it->second.current_object_locations,
-                           it->second.has_been_created);
+      callback_pair.second(object_id, it->second.current_object_locations);
     }
   };
   RAY_CHECK_OK(gcs_client_->object_table().Subscribe(
@@ -89,22 +82,22 @@ ray::Status ObjectDirectory::ReportObjectAdded(
   // Append the addition entry to the object table.
   auto data = std::make_shared<ObjectTableDataT>();
   data->manager = client_id.binary();
-  data->is_eviction = false;
   data->object_size = object_info.data_size;
   ray::Status status =
-      gcs_client_->object_table().Append(JobID::nil(), object_id, data, nullptr);
+      gcs_client_->object_table().Add(JobID::nil(), object_id, data, nullptr);
   return status;
 }
 
-ray::Status ObjectDirectory::ReportObjectRemoved(const ObjectID &object_id,
-                                                 const ClientID &client_id) {
+ray::Status ObjectDirectory::ReportObjectRemoved(
+    const ObjectID &object_id, const ClientID &client_id,
+    const object_manager::protocol::ObjectInfoT &object_info) {
   RAY_LOG(DEBUG) << "Reporting object removed to GCS " << object_id;
   // Append the eviction entry to the object table.
   auto data = std::make_shared<ObjectTableDataT>();
   data->manager = client_id.binary();
-  data->is_eviction = true;
+  data->object_size = object_info.data_size;
   ray::Status status =
-      gcs_client_->object_table().Append(JobID::nil(), object_id, data, nullptr);
+      gcs_client_->object_table().Remove(JobID::nil(), object_id, data, nullptr);
   return status;
 };
 
@@ -141,17 +134,16 @@ void ObjectDirectory::HandleClientRemoved(const ClientID &client_id) {
     const ObjectID &object_id = listener.first;
     if (listener.second.current_object_locations.count(client_id) > 0) {
       // If the subscribed object has the removed client as a location, update
-      // its locations with an empty log so that the location will be removed.
-      UpdateObjectLocations({}, gcs_client_->client_table(),
-                            &listener.second.current_object_locations,
-                            &listener.second.has_been_created);
+      // its locations with an empty update so that the location will be removed.
+      UpdateObjectLocations(GcsTableNotificationMode::APPEND_OR_ADD, {},
+                            gcs_client_->client_table(),
+                            &listener.second.current_object_locations);
       // Re-call all the subscribed callbacks for the object, since its
       // locations have changed.
       for (const auto &callback_pair : listener.second.callbacks) {
         // It is safe to call the callback directly since this is already running
         // in the subscription callback stack.
-        callback_pair.second(object_id, listener.second.current_object_locations,
-                             listener.second.has_been_created);
+        callback_pair.second(object_id, listener.second.current_object_locations);
       }
     }
   }
@@ -175,11 +167,10 @@ ray::Status ObjectDirectory::SubscribeObjectLocations(const UniqueID &callback_i
   listener_state.callbacks.emplace(callback_id, callback);
   // If we previously received some notifications about the object's locations,
   // immediately notify the caller of the current known locations.
-  if (listener_state.has_been_created) {
+  if (listener_state.subscribed) {
     auto &locations = listener_state.current_object_locations;
-    io_service_.post([callback, locations, object_id]() {
-      callback(object_id, locations, /*has_been_created=*/true);
-    });
+    io_service_.post(
+        [callback, locations, object_id]() { callback(object_id, locations); });
   }
   return status;
 }
@@ -204,16 +195,14 @@ ray::Status ObjectDirectory::LookupLocations(const ObjectID &object_id,
                                              const OnLocationsFound &callback) {
   ray::Status status;
   auto it = listeners_.find(object_id);
-  if (it != listeners_.end() && it->second.has_been_created) {
+  if (it != listeners_.end() && it->second.subscribed) {
     // If we have locations cached due to a concurrent SubscribeObjectLocations
     // call, and we have received at least one notification from the GCS about
     // the object's creation, then call the callback immediately with the
     // cached locations.
     auto &locations = it->second.current_object_locations;
-    bool has_been_created = it->second.has_been_created;
-    io_service_.post([callback, object_id, locations, has_been_created]() {
-      callback(object_id, locations, has_been_created);
-    });
+    io_service_.post(
+        [callback, object_id, locations]() { callback(object_id, locations); });
   } else {
     // We do not have any locations cached due to a concurrent
     // SubscribeObjectLocations call, so look up the object's locations
@@ -221,15 +210,14 @@ ray::Status ObjectDirectory::LookupLocations(const ObjectID &object_id,
     status = gcs_client_->object_table().Lookup(
         JobID::nil(), object_id,
         [this, callback](gcs::AsyncGcsClient *client, const ObjectID &object_id,
-                         const std::vector<ObjectTableDataT> &location_history) {
+                         const std::vector<ObjectTableDataT> &location_updates) {
           // Build the set of current locations based on the entries in the log.
           std::unordered_set<ClientID> client_ids;
-          bool has_been_created = false;
-          UpdateObjectLocations(location_history, gcs_client_->client_table(),
-                                &client_ids, &has_been_created);
+          UpdateObjectLocations(GcsTableNotificationMode::APPEND_OR_ADD, location_updates,
+                                gcs_client_->client_table(), &client_ids);
           // It is safe to call the callback directly since this is already running
          // in the GCS client's lookup callback stack.
-          callback(object_id, client_ids, has_been_created);
+          callback(object_id, client_ids);
         });
   }
   return status;
diff --git a/src/ray/object_manager/object_directory.h b/src/ray/object_manager/object_directory.h
index 0559ad534..96a2d726e 100644
--- a/src/ray/object_manager/object_directory.h
+++ b/src/ray/object_manager/object_directory.h
@@ -51,8 +51,7 @@ class ObjectDirectoryInterface {
 
   /// Callback for object location notifications.
   using OnLocationsFound = std::function<void(const ray::ObjectID &object_id,
-                                              const std::unordered_set<ray::ClientID> &,
-                                              bool has_been_created)>;
+                                              const std::unordered_set<ray::ClientID> &)>;
 
   /// Lookup object locations. Callback may be invoked with empty list of client ids.
   ///
@@ -110,9 +109,11 @@ class ObjectDirectoryInterface {
   ///
   /// \param object_id The object id that was removed from the store.
   /// \param client_id The client id corresponding to this node.
+  /// \param object_info Additional information about the object.
   /// \return Status of whether this method succeeded.
-  virtual ray::Status ReportObjectRemoved(const ObjectID &object_id,
-                                          const ClientID &client_id) = 0;
+  virtual ray::Status ReportObjectRemoved(
+      const ObjectID &object_id, const ClientID &client_id,
+      const object_manager::protocol::ObjectInfoT &object_info) = 0;
 
   /// Get local client id
   ///
@@ -159,8 +160,9 @@ class ObjectDirectory : public ObjectDirectoryInterface {
   ray::Status ReportObjectAdded(
       const ObjectID &object_id, const ClientID &client_id,
       const object_manager::protocol::ObjectInfoT &object_info) override;
-  ray::Status ReportObjectRemoved(const ObjectID &object_id,
-                                  const ClientID &client_id) override;
+  ray::Status ReportObjectRemoved(
+      const ObjectID &object_id, const ClientID &client_id,
+      const object_manager::protocol::ObjectInfoT &object_info) override;
 
   ray::ClientID GetLocalClientID() override;
 
@@ -176,12 +178,12 @@ class ObjectDirectory : public ObjectDirectoryInterface {
     std::unordered_map<UniqueID, OnLocationsFound> callbacks;
     /// The current set of known locations of this object.
     std::unordered_set<ClientID> current_object_locations;
-    /// This flag will get set to true if the object has ever been created. It
+    /// This flag will get set to true once any notification of the object has been
+    /// received, meaning current_object_locations is up-to-date with the GCS. It
     /// should never go back to false once set to true. If this is true, and
     /// the current_object_locations is empty, then this means that the object
-    /// does not exist on any nodes due to eviction (rather than due to the
-    /// object never getting created, for instance).
-    bool has_been_created;
+    /// does not exist on any nodes, due either to eviction or to never being created.
+    bool subscribed;
   };
 
   /// Reference to the event loop.
diff --git a/src/ray/object_manager/object_manager.cc b/src/ray/object_manager/object_manager.cc
index 7c949be31..29338b165 100644
--- a/src/ray/object_manager/object_manager.cc
+++ b/src/ray/object_manager/object_manager.cc
@@ -93,8 +93,10 @@ void ObjectManager::HandleObjectAdded(
 void ObjectManager::NotifyDirectoryObjectDeleted(const ObjectID &object_id) {
   auto it = local_objects_.find(object_id);
   RAY_CHECK(it != local_objects_.end());
+  auto object_info = it->second.object_info;
   local_objects_.erase(it);
-  ray::Status status = object_directory_->ReportObjectRemoved(object_id, client_id_);
+  ray::Status status =
+      object_directory_->ReportObjectRemoved(object_id, client_id_, object_info);
 }
 
 ray::Status ObjectManager::SubscribeObjAdded(
@@ -127,8 +129,7 @@ ray::Status ObjectManager::Pull(const ObjectID &object_id) {
   // no ordering guarantee between notifications.
   return object_directory_->SubscribeObjectLocations(
       object_directory_pull_callback_id_, object_id,
-      [this](const ObjectID &object_id, const std::unordered_set<ClientID> &client_ids,
-             bool created) {
+      [this](const ObjectID &object_id, const std::unordered_set<ClientID> &client_ids) {
         // Exit if the Pull request has already been fulfilled or canceled.
         auto it = pull_requests_.find(object_id);
         if (it == pull_requests_.end()) {
@@ -578,9 +579,8 @@ ray::Status ObjectManager::LookupRemainingWaitObjects(const UniqueID &wait_id) {
     // Lookup remaining objects.
     wait_state.requested_objects.insert(object_id);
     RAY_RETURN_NOT_OK(object_directory_->LookupLocations(
-        object_id,
-        [this, wait_id](const ObjectID &lookup_object_id,
-                        const std::unordered_set<ClientID> &client_ids, bool created) {
+        object_id, [this, wait_id](const ObjectID &lookup_object_id,
+                                   const std::unordered_set<ClientID> &client_ids) {
           auto &wait_state = active_wait_requests_.find(wait_id)->second;
           if (!client_ids.empty()) {
             wait_state.remaining.erase(lookup_object_id);
@@ -618,7 +618,7 @@ void ObjectManager::SubscribeRemainingWaitObjects(const UniqueID &wait_id) {
     RAY_CHECK_OK(object_directory_->SubscribeObjectLocations(
         wait_id, object_id,
         [this, wait_id](const ObjectID &subscribe_object_id,
-                        const std::unordered_set<ClientID> &client_ids, bool created) {
+                        const std::unordered_set<ClientID> &client_ids) {
           if (!client_ids.empty()) {
             RAY_LOG(DEBUG) << "Wait request " << wait_id
                            << ": subscription notification received for object "
diff --git a/src/ray/object_manager/test/object_manager_test.cc b/src/ray/object_manager/test/object_manager_test.cc
index 699d119e4..a373ea9b9 100644
--- a/src/ray/object_manager/test/object_manager_test.cc
+++ b/src/ray/object_manager/test/object_manager_test.cc
@@ -291,10 +291,9 @@ class TestObjectManager : public TestObjectManagerBase {
     UniqueID sub_id = ray::ObjectID::from_random();
 
     RAY_CHECK_OK(server1->object_manager_.object_directory_->SubscribeObjectLocations(
-        sub_id, object_1,
-        [this, sub_id, object_1, object_2](
-            const ray::ObjectID &object_id,
-            const std::unordered_set<ray::ClientID> &clients, bool created) {
+        sub_id, object_1, [this, sub_id, object_1, object_2](
+                              const ray::ObjectID &object_id,
+                              const std::unordered_set<ray::ClientID> &clients) {
           if (!clients.empty()) {
             TestWaitWhileSubscribed(sub_id, object_1, object_2);
           }
diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc
index a49b6268c..f94ddaeb1 100644
--- a/src/ray/raylet/node_manager.cc
+++ b/src/ray/raylet/node_manager.cc
@@ -60,9 +60,7 @@ NodeManager::NodeManager(boost::asio::io_service &io_service,
       scheduling_policy_(local_queues_),
       reconstruction_policy_(
          io_service_,
-          [this](const TaskID &task_id, bool return_values_lost) {
-            HandleTaskReconstruction(task_id);
-          },
+          [this](const TaskID &task_id) { HandleTaskReconstruction(task_id); },
          RayConfig::instance().initial_reconstruction_timeout_milliseconds(),
          gcs_client_->client_table().GetLocalClientId(), gcs_client_->task_lease_table(),
          object_directory_, gcs_client_->task_reconstruction_log()),
@@ -1287,14 +1285,13 @@ void NodeManager::TreatTaskAsFailedIfLost(const Task &task) {
     const ObjectID object_id = spec.ReturnId(i);
     // Lookup the return value's locations.
     RAY_CHECK_OK(object_directory_->LookupLocations(
-        object_id,
-        [this, task_marked_as_failed, task](
-            const ray::ObjectID &object_id,
-            const std::unordered_set<ray::ClientID> &clients, bool has_been_created) {
+        object_id, [this, task_marked_as_failed, task](
+                       const ray::ObjectID &object_id,
+                       const std::unordered_set<ray::ClientID> &clients) {
          if (!*task_marked_as_failed) {
            // Only process the object locations if we haven't already marked the
            // task as failed.
-           if (clients.empty() && has_been_created) {
+           if (clients.empty()) {
              // The object does not exist on any nodes but has been created
              // before, so the object has been lost. Mark the task as failed to
              // prevent any tasks that depend on this object from hanging.
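The object-manager changes above all follow from the new Set semantics: the directory now issues paired Add/Remove calls instead of appending is_eviction entries to a log, and subscribers receive a notification mode plus only the changed entries. The following is a minimal sketch of that flow, mirroring ReportObjectAdded, ReportObjectRemoved, and the subscription callback in object_directory.cc; the wrapper function name, the literal object size, the include paths, and the nil/nullptr job and done arguments are illustrative assumptions, not part of this patch.

    // Sketch only: exercises the Set-based object table API introduced above.
    #include "ray/gcs/client.h"  // assumed header for gcs::AsyncGcsClient
    #include "ray/gcs/tables.h"  // assumed header for ObjectTable / ObjectTableDataT

    using namespace ray;  // for brevity; the patch's code lives inside namespace ray

    void ExampleObjectTableUsage(gcs::AsyncGcsClient &gcs_client,
                                 const ObjectID &object_id, const ClientID &client_id) {
      // Report that this client holds the object (replaces Append with
      // is_eviction = false).
      auto data = std::make_shared<ObjectTableDataT>();
      data->manager = client_id.binary();
      data->object_size = 1024;  // hypothetical size
      RAY_CHECK_OK(gcs_client.object_table().Add(JobID::nil(), object_id, data, nullptr));

      // Report that the object was evicted from this client (replaces Append
      // with is_eviction = true).
      RAY_CHECK_OK(
          gcs_client.object_table().Remove(JobID::nil(), object_id, data, nullptr));

      // Subscribers receive the notification mode plus only the changed entries,
      // rather than a suffix of a location log.
      auto notification_callback = [](gcs::AsyncGcsClient *client, const ObjectID &id,
                                      const GcsTableNotificationMode mode,
                                      const std::vector<ObjectTableDataT> &updates) {
        for (const auto &entry : updates) {
          ClientID manager = ClientID::from_binary(entry.manager);
          if (mode != GcsTableNotificationMode::REMOVE) {
            // `manager` now holds the object.
          } else {
            // `manager` no longer holds the object.
          }
        }
      };
      RAY_CHECK_OK(gcs_client.object_table().Subscribe(
          JobID::nil(), client_id, notification_callback, nullptr));
    }

The design point, reflected in the UpdateObjectLocations changes above, is that a notification now carries current membership changes directly, so callers no longer replay a history of is_eviction entries to compute an object's current locations.
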
diff --git a/src/ray/raylet/reconstruction_policy.cc b/src/ray/raylet/reconstruction_policy.cc
index d69840299..d75f8799f 100644
--- a/src/ray/raylet/reconstruction_policy.cc
+++ b/src/ray/raylet/reconstruction_policy.cc
@@ -6,7 +6,7 @@ namespace raylet {
 
 ReconstructionPolicy::ReconstructionPolicy(
     boost::asio::io_service &io_service,
-    std::function<void(const TaskID &, bool)> reconstruction_handler,
+    std::function<void(const TaskID &)> reconstruction_handler,
     int64_t initial_reconstruction_timeout_ms, const ClientID &client_id,
     gcs::PubsubInterface<TaskID> &task_lease_pubsub,
     std::shared_ptr<ObjectDirectoryInterface> object_directory,
@@ -74,14 +74,13 @@ void ReconstructionPolicy::HandleReconstructionLogAppend(const TaskID &task_id,
   SetTaskTimeout(it, initial_reconstruction_timeout_ms_);
 
   if (success) {
-    reconstruction_handler_(task_id, it->second.return_values_lost);
+    reconstruction_handler_(task_id);
   }
 }
 
 void ReconstructionPolicy::AttemptReconstruction(const TaskID &task_id,
                                                  const ObjectID &required_object_id,
-                                                 int reconstruction_attempt,
-                                                 bool created) {
+                                                 int reconstruction_attempt) {
   // If we are no longer listening for objects created by this task, give up.
   auto it = listening_tasks_.find(task_id);
   if (it == listening_tasks_.end()) {
@@ -93,10 +92,6 @@ void ReconstructionPolicy::AttemptReconstruction(const TaskID &task_id,
     return;
   }
 
-  if (created) {
-    it->second.return_values_lost = true;
-  }
-
   // Suppress duplicate reconstructions of the same task. This can happen if,
   // for example, a task creates two different objects that both require
   // reconstruction.
@@ -142,14 +137,13 @@ void ReconstructionPolicy::HandleTaskLeaseExpired(const TaskID &task_id) {
   // attempted asynchronously.
   for (const auto &created_object_id : it->second.created_objects) {
     RAY_CHECK_OK(object_directory_->LookupLocations(
-        created_object_id,
-        [this, task_id, reconstruction_attempt](
-            const ray::ObjectID &object_id,
-            const std::unordered_set<ray::ClientID> &clients, bool created) {
+        created_object_id, [this, task_id, reconstruction_attempt](
+                               const ray::ObjectID &object_id,
+                               const std::unordered_set<ray::ClientID> &clients) {
          if (clients.empty()) {
            // The required object no longer exists on any live nodes. Attempt
            // reconstruction.
-           AttemptReconstruction(task_id, object_id, reconstruction_attempt, created);
+           AttemptReconstruction(task_id, object_id, reconstruction_attempt);
          }
        }));
   }
diff --git a/src/ray/raylet/reconstruction_policy.h b/src/ray/raylet/reconstruction_policy.h
index d936a632e..f18290aa3 100644
--- a/src/ray/raylet/reconstruction_policy.h
+++ b/src/ray/raylet/reconstruction_policy.h
@@ -40,7 +40,7 @@ class ReconstructionPolicy : public ReconstructionPolicyInterface {
   /// lease notifications from.
   ReconstructionPolicy(
       boost::asio::io_service &io_service,
-      std::function<void(const TaskID &, bool)> reconstruction_handler,
+      std::function<void(const TaskID &)> reconstruction_handler,
       int64_t initial_reconstruction_timeout_ms, const ClientID &client_id,
       gcs::PubsubInterface<TaskID> &task_lease_pubsub,
       std::shared_ptr<ObjectDirectoryInterface> object_directory,
@@ -93,7 +93,6 @@ class ReconstructionPolicy : public ReconstructionPolicyInterface {
     bool subscribed;
     // The number of times we've attempted reconstructing this task so far.
     int reconstruction_attempt;
-    bool return_values_lost;
     // The task's reconstruction timer. If this expires before a lease
     // notification is received, then the task will be reconstructed.
     std::unique_ptr<boost::asio::deadline_timer> reconstruction_timer;
@@ -116,7 +115,7 @@ class ReconstructionPolicy : public ReconstructionPolicyInterface {
   /// reconstructions of the same task (e.g., if a task creates two objects
   /// that both require reconstruction).
   void AttemptReconstruction(const TaskID &task_id, const ObjectID &required_object_id,
-                             int reconstruction_attempt, bool created);
+                             int reconstruction_attempt);
 
   /// Handle expiration of a task lease.
   void HandleTaskLeaseExpired(const TaskID &task_id);
@@ -128,7 +127,7 @@ class ReconstructionPolicy : public ReconstructionPolicyInterface {
   /// The event loop.
   boost::asio::io_service &io_service_;
   /// The handler to call for tasks that require reconstruction.
-  const std::function<void(const TaskID &, bool)> reconstruction_handler_;
+  const std::function<void(const TaskID &)> reconstruction_handler_;
   /// The initial timeout within which a task lease notification must be
   /// received. Otherwise, reconstruction will be triggered.
   const int64_t initial_reconstruction_timeout_ms_;
diff --git a/src/ray/raylet/reconstruction_policy_test.cc b/src/ray/raylet/reconstruction_policy_test.cc
index 093f5c236..c5678d6ce 100644
--- a/src/ray/raylet/reconstruction_policy_test.cc
+++ b/src/ray/raylet/reconstruction_policy_test.cc
@@ -29,10 +29,9 @@ class MockObjectDirectory : public ObjectDirectoryInterface {
       const ObjectID object_id = callback.first;
       auto it = locations_.find(object_id);
       if (it == locations_.end()) {
-        callback.second(object_id, std::unordered_set<ray::ClientID>(),
-                        /*created=*/false);
+        callback.second(object_id, std::unordered_set<ray::ClientID>());
       } else {
-        callback.second(object_id, it->second, /*created=*/true);
+        callback.second(object_id, it->second);
       }
     }
     callbacks_.clear();
@@ -63,7 +62,9 @@ class MockObjectDirectory : public ObjectDirectoryInterface {
   MOCK_METHOD3(ReportObjectAdded,
               ray::Status(const ObjectID &, const ClientID &,
                           const object_manager::protocol::ObjectInfoT &));
-  MOCK_METHOD2(ReportObjectRemoved, ray::Status(const ObjectID &, const ClientID &));
+  MOCK_METHOD3(ReportObjectRemoved,
+              ray::Status(const ObjectID &, const ClientID &,
+                          const object_manager::protocol::ObjectInfoT &));
 
  private:
   std::vector<std::pair<ObjectID, OnLocationsFound>> callbacks_;
@@ -151,8 +152,8 @@ class ReconstructionPolicyTest : public ::testing::Test {
         mock_object_directory_(std::make_shared<MockObjectDirectory>()),
         reconstruction_timeout_ms_(50),
         reconstruction_policy_(std::make_shared<ReconstructionPolicy>(
-            io_service_, [this](const TaskID &task_id,
-                                bool created) { TriggerReconstruction(task_id); },
+            io_service_,
+            [this](const TaskID &task_id) { TriggerReconstruction(task_id); },
             reconstruction_timeout_ms_, ClientID::from_random(), mock_gcs_,
             mock_object_directory_, mock_gcs_)),
         timer_canceled_(false) {