Introduce set data structure in GCS (#4199)

* Introduce set data structure in GCS. Change the object table to a Set instance. (An illustrative usage sketch follows the commit metadata below.)

* Fix a logic bug. Update python code.

* lint

* lint again

* Remove CURRENT_VALUE mode

* Remove 'CURRENT_VALUE'

* Add more test cases

* rename has_been_created to subscribed.

* Make the `changed` parameter type `bool *`

* Rename mode to notification_mode

* fix build

* RAY.SET_REMOVE returns an error if the entry doesn't exist

* lint

* Address comments

* lint and fix build
Kai Yang 2019-03-12 05:42:58 +08:00 committed by Stephanie Wang
parent c435013b27
commit 7ff56ce826
17 changed files with 881 additions and 241 deletions
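
At a glance, the Set-backed object table API introduced here looks roughly like the following. This is an illustrative sketch based on the tests and headers in the diff below, not code from the commit; it assumes a connected gcs::AsyncGcsClient, a valid JobID, and the same namespace/using context as the GCS client test file, so types such as ObjectTableDataT and GcsTableNotificationMode resolve unqualified.

// Illustrative sketch: exercising the object table as a gcs::Set.
void ExampleObjectSetUsage(const JobID &job_id,
                           std::shared_ptr<gcs::AsyncGcsClient> client) {
  ObjectID object_id = ObjectID::from_random();
  auto data = std::make_shared<ObjectTableDataT>();
  data->manager = "node-manager-1";  // placeholder manager ID

  // The object table is now a Set: Add/Remove replace the old Append calls, and
  // adding a duplicate entry is a no-op that produces no notification.
  RAY_CHECK_OK(client->object_table().Add(job_id, object_id, data, /*done=*/nullptr));
  RAY_CHECK_OK(client->object_table().Remove(job_id, object_id, data, /*done=*/nullptr));

  // Set subscribers receive the notification mode (APPEND_OR_ADD or REMOVE)
  // together with the entries that changed.
  auto notification_callback = [](gcs::AsyncGcsClient *client, const ObjectID &id,
                                  const GcsTableNotificationMode notification_mode,
                                  const std::vector<ObjectTableDataT> &entries) {
    if (notification_mode == GcsTableNotificationMode::REMOVE) {
      // `entries` were removed from the set stored at `id`.
    } else {
      // `entries` were added to the set stored at `id`.
    }
  };
  auto subscribe_callback = [](gcs::AsyncGcsClient *client) {
    // The subscription is established; writes after this point will be delivered.
  };
  RAY_CHECK_OK(client->object_table().Subscribe(
      job_id, ClientID::nil(), notification_callback, subscribe_callback));
}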

View file

@ -243,14 +243,8 @@ class GlobalState(object):
object_info = {
"DataSize": entry.ObjectSize(),
"Manager": entry.Manager(),
"IsEviction": [entry.IsEviction()],
}
for i in range(1, gcs_entry.EntriesLength()):
entry = ray.gcs_utils.ObjectTableData.GetRootAsObjectTableData(
gcs_entry.Entries(i), 0)
object_info["IsEviction"].append(entry.IsEviction())
return object_info
def object_table(self, object_id=None):

View file

@ -2505,10 +2505,6 @@ def test_global_state_api(shutdown_only):
object_table = ray.global_state.object_table()
assert len(object_table) == 2
assert object_table[x_id]["IsEviction"][0] is False
assert object_table[result_id]["IsEviction"][0] is False
assert object_table[x_id] == ray.global_state.object_table(x_id)
object_table_entry = ray.global_state.object_table(result_id)
assert object_table[result_id] == object_table_entry

View file

@ -655,11 +655,14 @@ def test_redis_module_failure(shutdown_only):
-1)
run_failure_test("Index is not a number.", "RAY.TABLE_APPEND", 1, 1, 2, 1,
b"a")
run_failure_test("The entry to remove doesn't exist.", "RAY.SET_REMOVE", 1,
1, 3, 1)
run_one_command("RAY.TABLE_APPEND", 1, 1, 2, 1)
# It's okay to add duplicate entries.
run_one_command("RAY.TABLE_APPEND", 1, 1, 2, 1)
run_one_command("RAY.TABLE_APPEND", 1, 1, 2, 1, 0)
run_one_command("RAY.TABLE_APPEND", 1, 1, 2, 1, 1)
run_one_command("RAY.SET_ADD", 1, 1, 3, 1)
@pytest.fixture

View file

@ -112,7 +112,7 @@ AsyncGcsClient::AsyncGcsClient(const std::string &address, int port,
driver_table_.reset(new DriverTable({primary_context_}, this));
heartbeat_batch_table_.reset(new HeartbeatBatchTable({primary_context_}, this));
// Tables below would be sharded.
object_table_.reset(new ObjectTable(shard_contexts_, this, command_type));
object_table_.reset(new ObjectTable(shard_contexts_, this));
raylet_task_table_.reset(new raylet::TaskTable(shard_contexts_, this, command_type));
task_reconstruction_log_.reset(new TaskReconstructionLog(shard_contexts_, this));
task_lease_table_.reset(new TaskLeaseTable(shard_contexts_, this));

View file

@ -131,40 +131,42 @@ TEST_MACRO(TestGcsWithChainAsio, TestTableLookup);
void TestLogLookup(const JobID &job_id, std::shared_ptr<gcs::AsyncGcsClient> client) {
// Append some entries to the log at an object ID.
ObjectID object_id = ObjectID::from_random();
std::vector<std::string> managers = {"abc", "def", "ghi"};
for (auto &manager : managers) {
auto data = std::make_shared<ObjectTableDataT>();
data->manager = manager;
TaskID task_id = TaskID::from_random();
std::vector<std::string> node_manager_ids = {"abc", "def", "ghi"};
for (auto &node_manager_id : node_manager_ids) {
auto data = std::make_shared<TaskReconstructionDataT>();
data->node_manager_id = node_manager_id;
// Check that we added the correct object entries.
auto add_callback = [object_id, data](gcs::AsyncGcsClient *client, const UniqueID &id,
const ObjectTableDataT &d) {
ASSERT_EQ(id, object_id);
ASSERT_EQ(data->manager, d.manager);
auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const UniqueID &id,
const TaskReconstructionDataT &d) {
ASSERT_EQ(id, task_id);
ASSERT_EQ(data->node_manager_id, d.node_manager_id);
};
RAY_CHECK_OK(client->object_table().Append(job_id, object_id, data, add_callback));
RAY_CHECK_OK(
client->task_reconstruction_log().Append(job_id, task_id, data, add_callback));
}
// Check that lookup returns the added object entries.
auto lookup_callback = [object_id, managers](
gcs::AsyncGcsClient *client, const ObjectID &id,
const std::vector<ObjectTableDataT> &data) {
ASSERT_EQ(id, object_id);
auto lookup_callback = [task_id, node_manager_ids](
gcs::AsyncGcsClient *client, const UniqueID &id,
const std::vector<TaskReconstructionDataT> &data) {
ASSERT_EQ(id, task_id);
for (const auto &entry : data) {
ASSERT_EQ(entry.manager, managers[test->NumCallbacks()]);
ASSERT_EQ(entry.node_manager_id, node_manager_ids[test->NumCallbacks()]);
test->IncrementNumCallbacks();
}
if (test->NumCallbacks() == managers.size()) {
if (test->NumCallbacks() == node_manager_ids.size()) {
test->Stop();
}
};
// Do a lookup at the object ID.
RAY_CHECK_OK(client->object_table().Lookup(job_id, object_id, lookup_callback));
RAY_CHECK_OK(
client->task_reconstruction_log().Lookup(job_id, task_id, lookup_callback));
// Run the event loop. The loop will only stop if the Lookup callback is
// called (or an assertion failure).
test->Start();
ASSERT_EQ(test->NumCallbacks(), managers.size());
ASSERT_EQ(test->NumCallbacks(), node_manager_ids.size());
}
TEST_F(TestGcsWithAsio, TestLogLookup) {
@ -201,11 +203,11 @@ TEST_MACRO(TestGcsWithChainAsio, TestTableLookupFailure);
void TestLogAppendAt(const JobID &job_id, std::shared_ptr<gcs::AsyncGcsClient> client) {
TaskID task_id = TaskID::from_random();
std::vector<std::string> managers = {"A", "B"};
std::vector<std::string> node_manager_ids = {"A", "B"};
std::vector<std::shared_ptr<TaskReconstructionDataT>> data_log;
for (const auto &manager : managers) {
for (const auto &node_manager_id : node_manager_ids) {
auto data = std::make_shared<TaskReconstructionDataT>();
data->node_manager_id = manager;
data->node_manager_id = node_manager_id;
data_log.push_back(data);
}
@ -234,13 +236,14 @@ void TestLogAppendAt(const JobID &job_id, std::shared_ptr<gcs::AsyncGcsClient> c
job_id, task_id, data_log[1],
/*done callback=*/nullptr, failure_callback, /*log_length=*/1));
auto lookup_callback = [managers](gcs::AsyncGcsClient *client, const UniqueID &id,
const std::vector<TaskReconstructionDataT> &data) {
auto lookup_callback = [node_manager_ids](
gcs::AsyncGcsClient *client, const UniqueID &id,
const std::vector<TaskReconstructionDataT> &data) {
std::vector<std::string> appended_managers;
for (const auto &entry : data) {
appended_managers.push_back(entry.node_manager_id);
}
ASSERT_EQ(appended_managers, managers);
ASSERT_EQ(appended_managers, node_manager_ids);
test->Stop();
};
RAY_CHECK_OK(
@ -256,14 +259,13 @@ TEST_F(TestGcsWithAsio, TestLogAppendAt) {
TestLogAppendAt(job_id_, client_);
}
void TestDeleteKeysFromLog(const JobID &job_id,
std::shared_ptr<gcs::AsyncGcsClient> client,
std::vector<std::shared_ptr<ObjectTableDataT>> &data_vector) {
std::vector<ObjectID> ids;
ObjectID object_id;
for (auto &data : data_vector) {
object_id = ObjectID::from_random();
ids.push_back(object_id);
void TestSet(const JobID &job_id, std::shared_ptr<gcs::AsyncGcsClient> client) {
// Add some entries to the set at an object ID.
ObjectID object_id = ObjectID::from_random();
std::vector<std::string> managers = {"abc", "def", "ghi"};
for (auto &manager : managers) {
auto data = std::make_shared<ObjectTableDataT>();
data->manager = manager;
// Check that we added the correct object entries.
auto add_callback = [object_id, data](gcs::AsyncGcsClient *client, const UniqueID &id,
const ObjectTableDataT &d) {
@ -271,32 +273,102 @@ void TestDeleteKeysFromLog(const JobID &job_id,
ASSERT_EQ(data->manager, d.manager);
test->IncrementNumCallbacks();
};
RAY_CHECK_OK(client->object_table().Append(job_id, object_id, data, add_callback));
RAY_CHECK_OK(client->object_table().Add(job_id, object_id, data, add_callback));
}
for (const auto &object_id : ids) {
// Check that lookup returns the added object entries.
auto lookup_callback = [object_id, data_vector](
gcs::AsyncGcsClient *client, const ObjectID &id,
const std::vector<ObjectTableDataT> &data) {
// Check that lookup returns the added object entries.
auto lookup_callback = [object_id, managers](
gcs::AsyncGcsClient *client, const ObjectID &id,
const std::vector<ObjectTableDataT> &data) {
ASSERT_EQ(id, object_id);
ASSERT_EQ(data.size(), managers.size());
test->IncrementNumCallbacks();
};
// Do a lookup at the object ID.
RAY_CHECK_OK(client->object_table().Lookup(job_id, object_id, lookup_callback));
for (auto &manager : managers) {
auto data = std::make_shared<ObjectTableDataT>();
data->manager = manager;
// Check that we added the correct object entries.
auto remove_entry_callback = [object_id, data](
gcs::AsyncGcsClient *client, const UniqueID &id, const ObjectTableDataT &d) {
ASSERT_EQ(id, object_id);
ASSERT_EQ(data->manager, d.manager);
test->IncrementNumCallbacks();
};
RAY_CHECK_OK(
client->object_table().Remove(job_id, object_id, data, remove_entry_callback));
}
// Check that the entries are removed.
auto lookup_callback2 = [object_id, managers](
gcs::AsyncGcsClient *client, const ObjectID &id,
const std::vector<ObjectTableDataT> &data) {
ASSERT_EQ(id, object_id);
ASSERT_EQ(data.size(), 0);
test->IncrementNumCallbacks();
test->Stop();
};
// Do a lookup at the object ID.
RAY_CHECK_OK(client->object_table().Lookup(job_id, object_id, lookup_callback2));
// Run the event loop. The loop will only stop if the Lookup callback is
// called (or an assertion failure).
test->Start();
ASSERT_EQ(test->NumCallbacks(), managers.size() * 2 + 2);
}
TEST_F(TestGcsWithAsio, TestSet) {
test = this;
TestSet(job_id_, client_);
}
void TestDeleteKeysFromLog(
const JobID &job_id, std::shared_ptr<gcs::AsyncGcsClient> client,
std::vector<std::shared_ptr<TaskReconstructionDataT>> &data_vector) {
std::vector<TaskID> ids;
TaskID task_id;
for (auto &data : data_vector) {
task_id = TaskID::from_random();
ids.push_back(task_id);
// Check that we added the correct object entries.
auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const UniqueID &id,
const TaskReconstructionDataT &d) {
ASSERT_EQ(id, task_id);
ASSERT_EQ(data->node_manager_id, d.node_manager_id);
test->IncrementNumCallbacks();
};
RAY_CHECK_OK(
client->task_reconstruction_log().Append(job_id, task_id, data, add_callback));
}
for (const auto &task_id : ids) {
// Check that lookup returns the added object entries.
auto lookup_callback = [task_id, data_vector](
gcs::AsyncGcsClient *client, const UniqueID &id,
const std::vector<TaskReconstructionDataT> &data) {
ASSERT_EQ(id, task_id);
ASSERT_EQ(data.size(), 1);
test->IncrementNumCallbacks();
};
RAY_CHECK_OK(client->object_table().Lookup(job_id, object_id, lookup_callback));
RAY_CHECK_OK(
client->task_reconstruction_log().Lookup(job_id, task_id, lookup_callback));
}
if (ids.size() == 1) {
client->object_table().Delete(job_id, ids[0]);
client->task_reconstruction_log().Delete(job_id, ids[0]);
} else {
client->object_table().Delete(job_id, ids);
client->task_reconstruction_log().Delete(job_id, ids);
}
for (const auto &object_id : ids) {
auto lookup_callback = [object_id](gcs::AsyncGcsClient *client, const ObjectID &id,
const std::vector<ObjectTableDataT> &data) {
ASSERT_EQ(id, object_id);
for (const auto &task_id : ids) {
auto lookup_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id,
const std::vector<TaskReconstructionDataT> &data) {
ASSERT_EQ(id, task_id);
ASSERT_TRUE(data.size() == 0);
test->IncrementNumCallbacks();
};
RAY_CHECK_OK(client->object_table().Lookup(job_id, object_id, lookup_callback));
RAY_CHECK_OK(
client->task_reconstruction_log().Lookup(job_id, task_id, lookup_callback));
}
}
@ -349,34 +421,80 @@ void TestDeleteKeysFromTable(const JobID &job_id,
}
}
void TestDeleteKeysFromSet(const JobID &job_id,
std::shared_ptr<gcs::AsyncGcsClient> client,
std::vector<std::shared_ptr<ObjectTableDataT>> &data_vector) {
std::vector<ObjectID> ids;
ObjectID object_id;
for (auto &data : data_vector) {
object_id = ObjectID::from_random();
ids.push_back(object_id);
// Check that we added the correct object entries.
auto add_callback = [object_id, data](gcs::AsyncGcsClient *client, const UniqueID &id,
const ObjectTableDataT &d) {
ASSERT_EQ(id, object_id);
ASSERT_EQ(data->manager, d.manager);
test->IncrementNumCallbacks();
};
RAY_CHECK_OK(client->object_table().Add(job_id, object_id, data, add_callback));
}
for (const auto &object_id : ids) {
// Check that lookup returns the added object entries.
auto lookup_callback = [object_id, data_vector](
gcs::AsyncGcsClient *client, const ObjectID &id,
const std::vector<ObjectTableDataT> &data) {
ASSERT_EQ(id, object_id);
ASSERT_EQ(data.size(), 1);
test->IncrementNumCallbacks();
};
RAY_CHECK_OK(client->object_table().Lookup(job_id, object_id, lookup_callback));
}
if (ids.size() == 1) {
client->object_table().Delete(job_id, ids[0]);
} else {
client->object_table().Delete(job_id, ids);
}
for (const auto &object_id : ids) {
auto lookup_callback = [object_id](gcs::AsyncGcsClient *client, const ObjectID &id,
const std::vector<ObjectTableDataT> &data) {
ASSERT_EQ(id, object_id);
ASSERT_TRUE(data.size() == 0);
test->IncrementNumCallbacks();
};
RAY_CHECK_OK(client->object_table().Lookup(job_id, object_id, lookup_callback));
}
}
// Test delete function for keys of Log or Table.
void TestDeleteKeys(const JobID &job_id, std::shared_ptr<gcs::AsyncGcsClient> client) {
// Test delete function for keys of Log.
std::vector<std::shared_ptr<ObjectTableDataT>> object_vector;
auto AppendObjectData = [&object_vector](size_t add_count) {
std::vector<std::shared_ptr<TaskReconstructionDataT>> task_reconstruction_vector;
auto AppendTaskReconstructionData = [&task_reconstruction_vector](size_t add_count) {
for (size_t i = 0; i < add_count; ++i) {
auto data = std::make_shared<ObjectTableDataT>();
data->manager = ObjectID::from_random().hex();
object_vector.push_back(data);
auto data = std::make_shared<TaskReconstructionDataT>();
data->node_manager_id = ObjectID::from_random().hex();
task_reconstruction_vector.push_back(data);
}
};
// Test one element case.
AppendObjectData(1);
ASSERT_EQ(object_vector.size(), 1);
TestDeleteKeysFromLog(job_id, client, object_vector);
AppendTaskReconstructionData(1);
ASSERT_EQ(task_reconstruction_vector.size(), 1);
TestDeleteKeysFromLog(job_id, client, task_reconstruction_vector);
// Test the case for more than one element and less than
// maximum_gcs_deletion_batch_size.
AppendObjectData(RayConfig::instance().maximum_gcs_deletion_batch_size() / 2);
ASSERT_GT(object_vector.size(), 1);
ASSERT_LT(object_vector.size(),
AppendTaskReconstructionData(RayConfig::instance().maximum_gcs_deletion_batch_size() /
2);
ASSERT_GT(task_reconstruction_vector.size(), 1);
ASSERT_LT(task_reconstruction_vector.size(),
RayConfig::instance().maximum_gcs_deletion_batch_size());
TestDeleteKeysFromLog(job_id, client, object_vector);
TestDeleteKeysFromLog(job_id, client, task_reconstruction_vector);
// Test the case for more than maximum_gcs_deletion_batch_size.
// The Delete function will split the data into two commands.
AppendObjectData(RayConfig::instance().maximum_gcs_deletion_batch_size() / 2);
ASSERT_GT(object_vector.size(),
AppendTaskReconstructionData(RayConfig::instance().maximum_gcs_deletion_batch_size() /
2);
ASSERT_GT(task_reconstruction_vector.size(),
RayConfig::instance().maximum_gcs_deletion_batch_size());
TestDeleteKeysFromLog(job_id, client, object_vector);
TestDeleteKeysFromLog(job_id, client, task_reconstruction_vector);
// Test delete function for keys of Table.
std::vector<std::shared_ptr<protocol::TaskT>> task_vector;
@ -403,6 +521,33 @@ void TestDeleteKeys(const JobID &job_id, std::shared_ptr<gcs::AsyncGcsClient> cl
test->Start();
ASSERT_GT(test->NumCallbacks(),
9 * RayConfig::instance().maximum_gcs_deletion_batch_size());
// Test delete function for keys of Set.
std::vector<std::shared_ptr<ObjectTableDataT>> object_vector;
auto AppendObjectData = [&object_vector](size_t add_count) {
for (size_t i = 0; i < add_count; ++i) {
auto data = std::make_shared<ObjectTableDataT>();
data->manager = ObjectID::from_random().hex();
object_vector.push_back(data);
}
};
// Test one element case.
AppendObjectData(1);
ASSERT_EQ(object_vector.size(), 1);
TestDeleteKeysFromSet(job_id, client, object_vector);
// Test the case for more than one element and less than
// maximum_gcs_deletion_batch_size.
AppendObjectData(RayConfig::instance().maximum_gcs_deletion_batch_size() / 2);
ASSERT_GT(object_vector.size(), 1);
ASSERT_LT(object_vector.size(),
RayConfig::instance().maximum_gcs_deletion_batch_size());
TestDeleteKeysFromSet(job_id, client, object_vector);
// Test the case for more than maximum_gcs_deletion_batch_size.
// The Delete function will split the data into two commands.
AppendObjectData(RayConfig::instance().maximum_gcs_deletion_batch_size() / 2);
ASSERT_GT(object_vector.size(),
RayConfig::instance().maximum_gcs_deletion_batch_size());
TestDeleteKeysFromSet(job_id, client, object_vector);
}
TEST_F(TestGcsWithAsio, TestDeleteKey) {
@ -451,22 +596,77 @@ void TaskLookupAfterUpdateFailure(gcs::AsyncGcsClient *client, const TaskID &id)
void TestLogSubscribeAll(const JobID &job_id,
std::shared_ptr<gcs::AsyncGcsClient> client) {
std::vector<std::string> managers = {"abc", "def", "ghi"};
std::vector<ObjectID> object_ids;
for (size_t i = 0; i < managers.size(); i++) {
object_ids.push_back(ObjectID::from_random());
std::vector<DriverID> driver_ids;
for (int i = 0; i < 3; i++) {
driver_ids.emplace_back(DriverID::from_random());
}
// Callback for a notification.
auto notification_callback = [driver_ids](gcs::AsyncGcsClient *client,
const UniqueID &id,
const std::vector<DriverTableDataT> data) {
ASSERT_EQ(id, driver_ids[test->NumCallbacks()]);
// Check that we get notifications in the same order as the writes.
for (const auto &entry : data) {
ASSERT_EQ(entry.driver_id, driver_ids[test->NumCallbacks()].binary());
test->IncrementNumCallbacks();
}
if (test->NumCallbacks() == driver_ids.size()) {
test->Stop();
}
};
// Callback for subscription success. We are guaranteed to receive
// notifications after this is called.
auto subscribe_callback = [driver_ids](gcs::AsyncGcsClient *client) {
// We have subscribed. Do the writes to the table.
for (size_t i = 0; i < driver_ids.size(); i++) {
RAY_CHECK_OK(client->driver_table().AppendDriverData(driver_ids[i], false));
}
};
// Subscribe to all driver table notifications. Once we have successfully
// subscribed, we will append to the key several times and check that we get
// notified for each.
RAY_CHECK_OK(client->driver_table().Subscribe(
job_id, ClientID::nil(), notification_callback, subscribe_callback));
// Run the event loop. The loop will only stop if the registered subscription
// callback is called (or an assertion failure).
test->Start();
// Check that we received one notification callback for each write.
ASSERT_EQ(test->NumCallbacks(), driver_ids.size());
}
TEST_F(TestGcsWithAsio, TestLogSubscribeAll) {
test = this;
TestLogSubscribeAll(job_id_, client_);
}
void TestSetSubscribeAll(const JobID &job_id,
std::shared_ptr<gcs::AsyncGcsClient> client) {
std::vector<ObjectID> object_ids;
for (int i = 0; i < 3; i++) {
object_ids.emplace_back(ObjectID::from_random());
}
std::vector<std::string> managers = {"abc", "def", "ghi"};
// Callback for a notification.
auto notification_callback = [object_ids, managers](
gcs::AsyncGcsClient *client, const UniqueID &id,
const GcsTableNotificationMode notification_mode,
const std::vector<ObjectTableDataT> data) {
ASSERT_EQ(id, object_ids[test->NumCallbacks()]);
if (test->NumCallbacks() < 3 * 3) {
ASSERT_EQ(notification_mode, GcsTableNotificationMode::APPEND_OR_ADD);
} else {
ASSERT_EQ(notification_mode, GcsTableNotificationMode::REMOVE);
}
ASSERT_EQ(id, object_ids[test->NumCallbacks() / 3 % 3]);
// Check that we get notifications in the same order as the writes.
for (const auto &entry : data) {
ASSERT_EQ(entry.manager, managers[test->NumCallbacks()]);
ASSERT_EQ(entry.manager, managers[test->NumCallbacks() % 3]);
test->IncrementNumCallbacks();
}
if (test->NumCallbacks() == managers.size()) {
if (test->NumCallbacks() == object_ids.size() * 3 * 2) {
test->Stop();
}
};
@ -476,13 +676,26 @@ void TestLogSubscribeAll(const JobID &job_id,
auto subscribe_callback = [job_id, object_ids, managers](gcs::AsyncGcsClient *client) {
// We have subscribed. Do the writes to the table.
for (size_t i = 0; i < object_ids.size(); i++) {
auto data = std::make_shared<ObjectTableDataT>();
data->manager = managers[i];
RAY_CHECK_OK(client->object_table().Append(job_id, object_ids[i], data, nullptr));
for (size_t j = 0; j < managers.size(); j++) {
auto data = std::make_shared<ObjectTableDataT>();
data->manager = managers[j];
for (int k = 0; k < 3; k++) {
// Add the same entry several times.
// Expect no notification if the entry already exists.
RAY_CHECK_OK(client->object_table().Add(job_id, object_ids[i], data, nullptr));
}
}
}
for (size_t i = 0; i < object_ids.size(); i++) {
for (size_t j = 0; j < managers.size(); j++) {
auto data = std::make_shared<ObjectTableDataT>();
data->manager = managers[j];
RAY_CHECK_OK(client->object_table().Remove(job_id, object_ids[i], data, nullptr));
}
}
};
// Subscribe to all object table notifications. Once we have successfully
// Subscribe to all driver table notifications. Once we have successfully
// subscribed, we will append to the key several times and check that we get
// notified for each.
RAY_CHECK_OK(client->object_table().Subscribe(
@ -492,12 +705,12 @@ void TestLogSubscribeAll(const JobID &job_id,
// callback is called (or an assertion failure).
test->Start();
// Check that we received one notification callback for each write.
ASSERT_EQ(test->NumCallbacks(), managers.size());
ASSERT_EQ(test->NumCallbacks(), object_ids.size() * 3 * 2);
}
TEST_F(TestGcsWithAsio, TestLogSubscribeAll) {
TEST_F(TestGcsWithAsio, TestSetSubscribeAll) {
test = this;
TestLogSubscribeAll(job_id_, client_);
TestSetSubscribeAll(job_id_, client_);
}
void TestTableSubscribeId(const JobID &job_id,
@ -579,24 +792,100 @@ TEST_MACRO(TestGcsWithChainAsio, TestTableSubscribeId);
void TestLogSubscribeId(const JobID &job_id,
std::shared_ptr<gcs::AsyncGcsClient> client) {
// Add a log entry.
DriverID driver_id1 = DriverID::from_random();
std::vector<std::string> driver_ids1 = {"abc", "def", "ghi"};
auto data1 = std::make_shared<DriverTableDataT>();
data1->driver_id = driver_ids1[0];
RAY_CHECK_OK(client->driver_table().Append(job_id, driver_id1, data1, nullptr));
// Add a log entry at a second key.
DriverID driver_id2 = DriverID::from_random();
std::vector<std::string> driver_ids2 = {"jkl", "mno", "pqr"};
auto data2 = std::make_shared<DriverTableDataT>();
data2->driver_id = driver_ids2[0];
RAY_CHECK_OK(client->driver_table().Append(job_id, driver_id2, data2, nullptr));
// The callback for a notification from the table. This should only be
// received for keys that we requested notifications for.
auto notification_callback = [driver_id2, driver_ids2](
gcs::AsyncGcsClient *client, const UniqueID &id,
const std::vector<DriverTableDataT> &data) {
// Check that we only get notifications for the requested key.
ASSERT_EQ(id, driver_id2);
// Check that we get notifications in the same order as the writes.
for (const auto &entry : data) {
ASSERT_EQ(entry.driver_id, driver_ids2[test->NumCallbacks()]);
test->IncrementNumCallbacks();
}
if (test->NumCallbacks() == driver_ids2.size()) {
test->Stop();
}
};
// The callback for subscription success. Once we've subscribed, request
// notifications for only one of the keys, then write to both keys.
auto subscribe_callback = [job_id, driver_id1, driver_id2, driver_ids1,
driver_ids2](gcs::AsyncGcsClient *client) {
// Request notifications for one of the keys.
RAY_CHECK_OK(client->driver_table().RequestNotifications(
job_id, driver_id2, client->client_table().GetLocalClientId()));
// Write both keys. We should only receive notifications for the key that
// we requested them for.
auto remaining = std::vector<std::string>(++driver_ids1.begin(), driver_ids1.end());
for (const auto &driver_id : remaining) {
auto data = std::make_shared<DriverTableDataT>();
data->driver_id = driver_id;
RAY_CHECK_OK(client->driver_table().Append(job_id, driver_id1, data, nullptr));
}
remaining = std::vector<std::string>(++driver_ids2.begin(), driver_ids2.end());
for (const auto &driver_id : remaining) {
auto data = std::make_shared<DriverTableDataT>();
data->driver_id = driver_id;
RAY_CHECK_OK(client->driver_table().Append(job_id, driver_id2, data, nullptr));
}
};
// Subscribe to notifications for this client. This allows us to request and
// receive notifications for specific keys.
RAY_CHECK_OK(
client->driver_table().Subscribe(job_id, client->client_table().GetLocalClientId(),
notification_callback, subscribe_callback));
// Run the event loop. The loop will only stop if the registered subscription
// callback is called for the requested key.
test->Start();
// Check that we received one notification callback for each write to the
// requested key.
ASSERT_EQ(test->NumCallbacks(), driver_ids2.size());
}
TEST_F(TestGcsWithAsio, TestLogSubscribeId) {
test = this;
TestLogSubscribeId(job_id_, client_);
}
void TestSetSubscribeId(const JobID &job_id,
std::shared_ptr<gcs::AsyncGcsClient> client) {
// Add a set entry.
ObjectID object_id1 = ObjectID::from_random();
std::vector<std::string> managers1 = {"abc", "def", "ghi"};
auto data1 = std::make_shared<ObjectTableDataT>();
data1->manager = managers1[0];
RAY_CHECK_OK(client->object_table().Append(job_id, object_id1, data1, nullptr));
RAY_CHECK_OK(client->object_table().Add(job_id, object_id1, data1, nullptr));
// Add a log entry at a second key.
// Add a set entry at a second key.
ObjectID object_id2 = ObjectID::from_random();
std::vector<std::string> managers2 = {"jkl", "mno", "pqr"};
auto data2 = std::make_shared<ObjectTableDataT>();
data2->manager = managers2[0];
RAY_CHECK_OK(client->object_table().Append(job_id, object_id2, data2, nullptr));
RAY_CHECK_OK(client->object_table().Add(job_id, object_id2, data2, nullptr));
// The callback for a notification from the table. This should only be
// received for keys that we requested notifications for.
auto notification_callback = [object_id2, managers2](
gcs::AsyncGcsClient *client, const ObjectID &id,
const GcsTableNotificationMode notification_mode,
const std::vector<ObjectTableDataT> &data) {
ASSERT_EQ(notification_mode, GcsTableNotificationMode::APPEND_OR_ADD);
// Check that we only get notifications for the requested key.
ASSERT_EQ(id, object_id2);
// Check that we get notifications in the same order as the writes.
@ -622,13 +911,13 @@ void TestLogSubscribeId(const JobID &job_id,
for (const auto &manager : remaining) {
auto data = std::make_shared<ObjectTableDataT>();
data->manager = manager;
RAY_CHECK_OK(client->object_table().Append(job_id, object_id1, data, nullptr));
RAY_CHECK_OK(client->object_table().Add(job_id, object_id1, data, nullptr));
}
remaining = std::vector<std::string>(++managers2.begin(), managers2.end());
for (const auto &manager : remaining) {
auto data = std::make_shared<ObjectTableDataT>();
data->manager = manager;
RAY_CHECK_OK(client->object_table().Append(job_id, object_id2, data, nullptr));
RAY_CHECK_OK(client->object_table().Add(job_id, object_id2, data, nullptr));
}
};
@ -645,9 +934,9 @@ void TestLogSubscribeId(const JobID &job_id,
ASSERT_EQ(test->NumCallbacks(), managers2.size());
}
TEST_F(TestGcsWithAsio, TestLogSubscribeId) {
TEST_F(TestGcsWithAsio, TestSetSubscribeId) {
test = this;
TestLogSubscribeId(job_id_, client_);
TestSetSubscribeId(job_id_, client_);
}
void TestTableSubscribeCancel(const JobID &job_id,
@ -727,28 +1016,110 @@ TEST_MACRO(TestGcsWithChainAsio, TestTableSubscribeCancel);
void TestLogSubscribeCancel(const JobID &job_id,
std::shared_ptr<gcs::AsyncGcsClient> client) {
// Add a log entry.
DriverID driver_id = DriverID::from_random();
std::vector<std::string> driver_ids = {"jkl", "mno", "pqr"};
auto data = std::make_shared<DriverTableDataT>();
data->driver_id = driver_ids[0];
RAY_CHECK_OK(client->driver_table().Append(job_id, driver_id, data, nullptr));
// The callback for a notification from the object table. This should only be
// received for the object that we requested notifications for.
auto notification_callback = [driver_id, driver_ids](
gcs::AsyncGcsClient *client, const UniqueID &id,
const std::vector<DriverTableDataT> &data) {
ASSERT_EQ(id, driver_id);
// Check that we get a duplicate notification for the first write. We get a
// duplicate notification because the log is append-only and notifications
// are canceled after the first write, then requested again.
auto driver_ids_copy = driver_ids;
driver_ids_copy.insert(driver_ids_copy.begin(), driver_ids_copy.front());
for (const auto &entry : data) {
ASSERT_EQ(entry.driver_id, driver_ids_copy[test->NumCallbacks()]);
test->IncrementNumCallbacks();
}
if (test->NumCallbacks() == driver_ids_copy.size()) {
test->Stop();
}
};
// The callback for a notification from the table. This should only be
// received for keys that we requested notifications for.
auto subscribe_callback = [job_id, driver_id, driver_ids](gcs::AsyncGcsClient *client) {
// Request notifications, then cancel immediately. We should receive a
// notification for the current value at the key.
RAY_CHECK_OK(client->driver_table().RequestNotifications(
job_id, driver_id, client->client_table().GetLocalClientId()));
RAY_CHECK_OK(client->driver_table().CancelNotifications(
job_id, driver_id, client->client_table().GetLocalClientId()));
// Append to the key. Since we canceled notifications, we should not
// receive a notification for these writes.
auto remaining = std::vector<std::string>(++driver_ids.begin(), driver_ids.end());
for (const auto &remaining_driver_id : remaining) {
auto data = std::make_shared<DriverTableDataT>();
data->driver_id = remaining_driver_id;
RAY_CHECK_OK(client->driver_table().Append(job_id, driver_id, data, nullptr));
}
// Request notifications again. We should receive a notification for the
// current values at the key.
RAY_CHECK_OK(client->driver_table().RequestNotifications(
job_id, driver_id, client->client_table().GetLocalClientId()));
};
// Subscribe to notifications for this client. This allows us to request and
// receive notifications for specific keys.
RAY_CHECK_OK(
client->driver_table().Subscribe(job_id, client->client_table().GetLocalClientId(),
notification_callback, subscribe_callback));
// Run the event loop. The loop will only stop if the registered subscription
// callback is called for the requested key.
test->Start();
// Check that we received a notification callback for the first append to the
// key, then a notification for all of the appends, because we cancel
// notifications in between.
ASSERT_EQ(test->NumCallbacks(), driver_ids.size() + 1);
}
TEST_F(TestGcsWithAsio, TestLogSubscribeCancel) {
test = this;
TestLogSubscribeCancel(job_id_, client_);
}
void TestSetSubscribeCancel(const JobID &job_id,
std::shared_ptr<gcs::AsyncGcsClient> client) {
// Add a set entry.
ObjectID object_id = ObjectID::from_random();
std::vector<std::string> managers = {"jkl", "mno", "pqr"};
auto data = std::make_shared<ObjectTableDataT>();
data->manager = managers[0];
RAY_CHECK_OK(client->object_table().Append(job_id, object_id, data, nullptr));
RAY_CHECK_OK(client->object_table().Add(job_id, object_id, data, nullptr));
// The callback for a notification from the object table. This should only be
// received for the object that we requested notifications for.
auto notification_callback = [object_id, managers](
gcs::AsyncGcsClient *client, const ObjectID &id,
const GcsTableNotificationMode notification_mode,
const std::vector<ObjectTableDataT> &data) {
ASSERT_EQ(notification_mode, GcsTableNotificationMode::APPEND_OR_ADD);
ASSERT_EQ(id, object_id);
// Check that we get a duplicate notification for the first write. We get a
// duplicate notification because the log is append-only and notifications
// duplicate notification because notifications
// are canceled after the first write, then requested again.
auto managers_copy = managers;
managers_copy.insert(managers_copy.begin(), managers_copy.front());
for (const auto &entry : data) {
ASSERT_EQ(entry.manager, managers_copy[test->NumCallbacks()]);
if (data.size() == 1) {
// first notification
ASSERT_EQ(data[0].manager, managers[0]);
test->IncrementNumCallbacks();
} else {
// second notification
ASSERT_EQ(data.size(), managers.size());
std::unordered_set<std::string> managers_set(managers.begin(), managers.end());
std::unordered_set<std::string> data_managers_set;
for (const auto &entry : data) {
data_managers_set.insert(entry.manager);
test->IncrementNumCallbacks();
}
ASSERT_EQ(managers_set, data_managers_set);
}
if (test->NumCallbacks() == managers_copy.size()) {
if (test->NumCallbacks() == managers.size() + 1) {
test->Stop();
}
};
@ -762,13 +1133,13 @@ void TestLogSubscribeCancel(const JobID &job_id,
job_id, object_id, client->client_table().GetLocalClientId()));
RAY_CHECK_OK(client->object_table().CancelNotifications(
job_id, object_id, client->client_table().GetLocalClientId()));
// Append to the key. Since we canceled notifications, we should not
// Add to the key. Since we canceled notifications, we should not
// receive a notification for these writes.
auto remaining = std::vector<std::string>(++managers.begin(), managers.end());
for (const auto &manager : remaining) {
auto data = std::make_shared<ObjectTableDataT>();
data->manager = manager;
RAY_CHECK_OK(client->object_table().Append(job_id, object_id, data, nullptr));
RAY_CHECK_OK(client->object_table().Add(job_id, object_id, data, nullptr));
}
// Request notifications again. We should receive a notification for the
// current values at the key.
@ -790,9 +1161,9 @@ void TestLogSubscribeCancel(const JobID &job_id,
ASSERT_EQ(test->NumCallbacks(), managers.size() + 1);
}
TEST_F(TestGcsWithAsio, TestLogSubscribeCancel) {
TEST_F(TestGcsWithAsio, TestSetSubscribeCancel) {
test = this;
TestLogSubscribeCancel(job_id_, client_);
TestSetSubscribeCancel(job_id_, client_);
}
void ClientTableNotification(gcs::AsyncGcsClient *client, const ClientID &client_id,

View file

@ -108,7 +108,13 @@ table ResourcePair {
value: double;
}
enum GcsTableNotificationMode:int {
APPEND_OR_ADD = 0,
REMOVE,
}
table GcsTableEntry {
notification_mode: GcsTableNotificationMode;
id: string;
entries: [string];
}
@ -124,8 +130,6 @@ table ObjectTableData {
object_size: long;
// The node manager ID that this object appeared on or was evicted by.
manager: string;
// Whether this entry is an addition or a deletion.
is_eviction: bool;
}
table TaskReconstructionData {

View file

@ -181,7 +181,7 @@ flatbuffers::Offset<flatbuffers::String> RedisStringToFlatbuf(
return fbb.CreateString(redis_string_str, redis_string_size);
}
/// Publish a notification for a new entry at a key. This publishes a
/// Publish a notification for an entry update at a key. This publishes a
/// notification to all subscribers of the table, as well as every client that
/// has requested notifications for this key.
///
@ -189,15 +189,18 @@ flatbuffers::Offset<flatbuffers::String> RedisStringToFlatbuf(
/// this key should be published to. When publishing to a specific
/// client, the channel name should be <pubsub_channel>:<client_id>.
/// \param id The ID of the key that the notification is about.
/// \param data The data to publish.
/// \param notification_mode The update mode, i.e. append/add or remove.
/// \param data The appended/removed data.
/// \return OK if there is no error during a publish.
int PublishTableAdd(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_str,
RedisModuleString *id, RedisModuleString *data) {
int PublishTableUpdate(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_str,
RedisModuleString *id, GcsTableNotificationMode notification_mode,
RedisModuleString *data) {
// Serialize the notification to send.
flatbuffers::FlatBufferBuilder fbb;
auto data_flatbuf = RedisStringToFlatbuf(fbb, data);
auto message = CreateGcsTableEntry(fbb, RedisStringToFlatbuf(fbb, id),
fbb.CreateVector(&data_flatbuf, 1));
auto message =
CreateGcsTableEntry(fbb, notification_mode, RedisStringToFlatbuf(fbb, id),
fbb.CreateVector(&data_flatbuf, 1));
fbb.Finish(message);
// Write the data back to any subscribers that are listening to all table
@ -265,7 +268,8 @@ int TableAdd_DoPublish(RedisModuleCtx *ctx, RedisModuleString **argv, int argc)
if (pubsub_channel != TablePubsub::NO_PUBLISH) {
// All other pubsub channels write the data back directly onto the channel.
return PublishTableAdd(ctx, pubsub_channel_str, id, data);
return PublishTableUpdate(ctx, pubsub_channel_str, id,
GcsTableNotificationMode::APPEND_OR_ADD, data);
} else {
return RedisModule_ReplyWithSimpleString(ctx, "OK");
}
@ -364,7 +368,8 @@ int TableAppend_DoPublish(RedisModuleCtx *ctx, RedisModuleString **argv, int /*a
if (pubsub_channel != TablePubsub::NO_PUBLISH) {
// All other pubsub channels write the data back directly onto the
// channel.
return PublishTableAdd(ctx, pubsub_channel_str, id, data);
return PublishTableUpdate(ctx, pubsub_channel_str, id,
GcsTableNotificationMode::APPEND_OR_ADD, data);
} else {
return RedisModule_ReplyWithSimpleString(ctx, "OK");
}
@ -407,6 +412,112 @@ int ChainTableAppend_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv,
}
#endif
int Set_DoPublish(RedisModuleCtx *ctx, RedisModuleString **argv, bool is_add) {
RedisModuleString *pubsub_channel_str = argv[2];
RedisModuleString *id = argv[3];
RedisModuleString *data = argv[4];
// Publish a message on the requested pubsub channel if necessary.
TablePubsub pubsub_channel;
REPLY_AND_RETURN_IF_NOT_OK(ParseTablePubsub(&pubsub_channel, pubsub_channel_str));
if (pubsub_channel != TablePubsub::NO_PUBLISH) {
// All other pubsub channels write the data back directly onto the
// channel.
return PublishTableUpdate(ctx, pubsub_channel_str, id,
is_add ? GcsTableNotificationMode::APPEND_OR_ADD
: GcsTableNotificationMode::REMOVE,
data);
} else {
return RedisModule_ReplyWithSimpleString(ctx, "OK");
}
}
int Set_DoWrite(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, bool is_add,
bool *changed) {
if (argc != 5) {
return RedisModule_WrongArity(ctx);
}
RedisModuleString *prefix_str = argv[1];
RedisModuleString *id = argv[3];
RedisModuleString *data = argv[4];
RedisModuleString *key_string = PrefixedKeyString(ctx, prefix_str, id);
// TODO(kfstorm): According to https://redis.io/topics/modules-intro, the modules
// API for the set type is not available yet. We can switch from RedisModule_Call
// to the native set type API once it becomes available.
RedisModuleCallReply *reply =
RedisModule_Call(ctx, is_add ? "SADD" : "SREM", "ss", key_string, data);
if (RedisModule_CallReplyType(reply) != REDISMODULE_REPLY_ERROR) {
*changed = RedisModule_CallReplyInteger(reply) > 0;
if (!is_add && *changed) {
// Try to delete the set key if it has become empty.
RedisModuleKey *key;
REPLY_AND_RETURN_IF_NOT_OK(
OpenPrefixedKey(&key, ctx, prefix_str, id, REDISMODULE_WRITE));
auto size = RedisModule_ValueLength(key);
if (size == 0) {
REPLY_AND_RETURN_IF_FALSE(RedisModule_DeleteKey(key) == REDISMODULE_OK,
"ERR Failed to delete empty set.");
}
}
return REDISMODULE_OK;
} else {
// The SADD/SREM command failed; forward the error reply to the caller.
RedisModule_ReplyWithCallReply(ctx, reply);
return REDISMODULE_ERR;
}
}
/// Add an entry to the set stored at a key. Publishes a notification about
/// the update to all subscribers, if a pubsub channel is provided.
///
/// This is called from a client with the command:
///
/// RAY.SET_ADD <table_prefix> <pubsub_channel> <id> <data>
///
/// \param table_prefix The prefix string for keys in this set.
/// \param pubsub_channel The pubsub channel name that notifications for
/// this key should be published to. When publishing to a specific
/// client, the channel name should be <pubsub_channel>:<client_id>.
/// \param id The ID of the key to add to.
/// \param data The data to add to the key.
/// \return OK if the add succeeds, or an error message string if the add
/// fails.
int SetAdd_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
bool changed;
if (Set_DoWrite(ctx, argv, argc, /*is_add=*/true, &changed) != REDISMODULE_OK) {
return REDISMODULE_ERR;
}
if (changed) {
return Set_DoPublish(ctx, argv, /*is_add=*/true);
}
return REDISMODULE_OK;
}
/// Remove an entry from the set stored at a key. Publishes a notification about
/// the update to all subscribers, if a pubsub channel is provided.
///
/// This is called from a client with the command:
///
/// RAY.SET_REMOVE <table_prefix> <pubsub_channel> <id> <data>
///
/// \param table_prefix The prefix string for keys in this table.
/// \param pubsub_channel The pubsub channel name that notifications for
/// this key should be published to. When publishing to a specific
/// client, the channel name should be <pubsub_channel>:<client_id>.
/// \param id The ID of the key to remove from.
/// \param data The data to remove from the key.
/// \return OK if the remove succeeds, or an error message string if the remove
/// fails.
int SetRemove_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
bool changed;
if (Set_DoWrite(ctx, argv, argc, /*is_add=*/false, &changed) != REDISMODULE_OK) {
return REDISMODULE_ERR;
}
REPLY_AND_RETURN_IF_FALSE(changed, "ERR The entry to remove doesn't exist.");
return Set_DoPublish(ctx, argv, /*is_add=*/false);
}
/// A helper function to create and finish a GcsTableEntry, based on the
/// current value or values at the given key.
///
@ -428,11 +539,13 @@ Status TableEntryToFlatbuf(RedisModuleCtx *ctx, RedisModuleKey *table_key,
size_t data_len = 0;
char *data_buf = RedisModule_StringDMA(table_key, &data_len, REDISMODULE_READ);
auto data = fbb.CreateString(data_buf, data_len);
auto message = CreateGcsTableEntry(fbb, RedisStringToFlatbuf(fbb, entry_id),
auto message = CreateGcsTableEntry(fbb, GcsTableNotificationMode::APPEND_OR_ADD,
RedisStringToFlatbuf(fbb, entry_id),
fbb.CreateVector(&data, 1));
fbb.Finish(message);
} break;
case REDISMODULE_KEYTYPE_LIST: {
case REDISMODULE_KEYTYPE_LIST:
case REDISMODULE_KEYTYPE_SET: {
RedisModule_CloseKey(table_key);
// Close the key before executing the command. NOTE(swang): According to
// https://github.com/RedisLabs/RedisModulesSDK/blob/master/API.md, "While
@ -440,10 +553,17 @@ Status TableEntryToFlatbuf(RedisModuleCtx *ctx, RedisModuleKey *table_key,
RedisModuleString *table_key_str = PrefixedKeyString(ctx, prefix_str, entry_id);
// TODO(swang): This could potentially be replaced with the native redis
// server list iterator, once it is implemented for redis modules.
RedisModuleCallReply *reply =
RedisModule_Call(ctx, "LRANGE", "sll", table_key_str, 0, -1);
RedisModuleCallReply *reply = nullptr;
switch (key_type) {
case REDISMODULE_KEYTYPE_LIST:
reply = RedisModule_Call(ctx, "LRANGE", "sll", table_key_str, 0, -1);
break;
case REDISMODULE_KEYTYPE_SET:
reply = RedisModule_Call(ctx, "SMEMBERS", "s", table_key_str);
break;
}
// Build the flatbuffer from the set of log entries.
if (RedisModule_CallReplyType(reply) != REDISMODULE_REPLY_ARRAY) {
if (reply == nullptr || RedisModule_CallReplyType(reply) != REDISMODULE_REPLY_ARRAY) {
return Status::RedisError("Empty list or wrong type");
}
std::vector<flatbuffers::Offset<flatbuffers::String>> data;
@ -453,13 +573,14 @@ Status TableEntryToFlatbuf(RedisModuleCtx *ctx, RedisModuleKey *table_key,
const char *element_str = RedisModule_CallReplyStringPtr(element, &len);
data.push_back(fbb.CreateString(element_str, len));
}
auto message = CreateGcsTableEntry(fbb, RedisStringToFlatbuf(fbb, entry_id),
fbb.CreateVector(data));
auto message =
CreateGcsTableEntry(fbb, GcsTableNotificationMode::APPEND_OR_ADD,
RedisStringToFlatbuf(fbb, entry_id), fbb.CreateVector(data));
fbb.Finish(message);
} break;
case REDISMODULE_KEYTYPE_EMPTY: {
auto message = CreateGcsTableEntry(
fbb, RedisStringToFlatbuf(fbb, entry_id),
fbb, GcsTableNotificationMode::APPEND_OR_ADD, RedisStringToFlatbuf(fbb, entry_id),
fbb.CreateVector(std::vector<flatbuffers::Offset<flatbuffers::String>>()));
fbb.Finish(message);
} break;
@ -752,6 +873,8 @@ int DebugString_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int
// Wrap all Redis commands with Redis' auto memory management.
AUTO_MEMORY(TableAdd_RedisCommand);
AUTO_MEMORY(TableAppend_RedisCommand);
AUTO_MEMORY(SetAdd_RedisCommand);
AUTO_MEMORY(SetRemove_RedisCommand);
AUTO_MEMORY(TableLookup_RedisCommand);
AUTO_MEMORY(TableRequestNotifications_RedisCommand);
AUTO_MEMORY(TableDelete_RedisCommand);
@ -781,7 +904,17 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc)
}
if (RedisModule_CreateCommand(ctx, "ray.table_append", TableAppend_RedisCommand,
"write", 0, 0, 0) == REDISMODULE_ERR) {
"write pubsub", 0, 0, 0) == REDISMODULE_ERR) {
return REDISMODULE_ERR;
}
if (RedisModule_CreateCommand(ctx, "ray.set_add", SetAdd_RedisCommand, "write pubsub",
0, 0, 0) == REDISMODULE_ERR) {
return REDISMODULE_ERR;
}
if (RedisModule_CreateCommand(ctx, "ray.set_remove", SetRemove_RedisCommand,
"write pubsub", 0, 0, 0) == REDISMODULE_ERR) {
return REDISMODULE_ERR;
}
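
To make the published payload concrete: every notification sent by PublishTableUpdate (whether from RAY.TABLE_ADD, RAY.TABLE_APPEND, RAY.SET_ADD, or RAY.SET_REMOVE) is a GcsTableEntry flatbuffer whose first field is the notification mode. Below is a minimal decoding sketch; the include path is an assumption about where the header generated from gcs.fbs lives, the function name is illustrative, and namespace qualification is omitted to match the module code above.

#include <string>

#include "ray/gcs/format/gcs_generated.h"  // assumed path of the header generated from gcs.fbs

// Inspect a raw pubsub payload and note whether it carries additions or removals.
void InspectSetNotification(const std::string &payload) {
  auto root = flatbuffers::GetRoot<GcsTableEntry>(payload.data());
  const bool is_remove =
      root->notification_mode() == GcsTableNotificationMode::REMOVE;
  for (flatbuffers::uoffset_t i = 0; i < root->entries()->size(); i++) {
    // Each element is itself a serialized table row (e.g. an ObjectTableData).
    const flatbuffers::String *entry = root->entries()->Get(i);
    (void)entry;
    (void)is_remove;  // A real consumer would apply or undo the entry based on the mode.
  }
}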

View file

@ -112,6 +112,19 @@ template <typename ID, typename Data>
Status Log<ID, Data>::Subscribe(const JobID &job_id, const ClientID &client_id,
const Callback &subscribe,
const SubscriptionCallback &done) {
auto subscribe_wrapper = [subscribe](AsyncGcsClient *client, const ID &id,
const GcsTableNotificationMode notification_mode,
const std::vector<DataT> &data) {
RAY_CHECK(notification_mode != GcsTableNotificationMode::REMOVE);
subscribe(client, id, data);
};
return Subscribe(job_id, client_id, subscribe_wrapper, done);
}
template <typename ID, typename Data>
Status Log<ID, Data>::Subscribe(const JobID &job_id, const ClientID &client_id,
const NotificationCallback &subscribe,
const SubscriptionCallback &done) {
RAY_CHECK(subscribe_callback_index_ == -1)
<< "Client called Subscribe twice on the same table";
auto callback = [this, subscribe, done](const std::string &data) {
@ -137,7 +150,7 @@ Status Log<ID, Data>::Subscribe(const JobID &job_id, const ClientID &client_id,
data_root->UnPackTo(&result);
results.emplace_back(std::move(result));
}
subscribe(client_, id, results);
subscribe(client_, id, root->notification_mode(), results);
}
}
// We do not delete the callback after calling it since there may be
@ -274,6 +287,50 @@ std::string Table<ID, Data>::DebugString() const {
return result.str();
}
template <typename ID, typename Data>
Status Set<ID, Data>::Add(const JobID &job_id, const ID &id,
std::shared_ptr<DataT> &dataT, const WriteCallback &done) {
num_adds_++;
auto callback = [this, id, dataT, done](const std::string &data) {
if (done != nullptr) {
(done)(client_, id, *dataT);
}
return true;
};
flatbuffers::FlatBufferBuilder fbb;
fbb.ForceDefaults(true);
fbb.Finish(Data::Pack(fbb, dataT.get()));
return GetRedisContext(id)->RunAsync("RAY.SET_ADD", id, fbb.GetBufferPointer(),
fbb.GetSize(), prefix_, pubsub_channel_,
std::move(callback));
}
template <typename ID, typename Data>
Status Set<ID, Data>::Remove(const JobID &job_id, const ID &id,
std::shared_ptr<DataT> &dataT, const WriteCallback &done) {
num_removes_++;
auto callback = [this, id, dataT, done](const std::string &data) {
if (done != nullptr) {
(done)(client_, id, *dataT);
}
return true;
};
flatbuffers::FlatBufferBuilder fbb;
fbb.ForceDefaults(true);
fbb.Finish(Data::Pack(fbb, dataT.get()));
return GetRedisContext(id)->RunAsync("RAY.SET_REMOVE", id, fbb.GetBufferPointer(),
fbb.GetSize(), prefix_, pubsub_channel_,
std::move(callback));
}
template <typename ID, typename Data>
std::string Set<ID, Data>::DebugString() const {
std::stringstream result;
result << "num lookups: " << num_lookups_ << ", num adds: " << num_adds_
<< ", num removes: " << num_removes_;
return result.str();
}
Status ErrorTable::PushErrorToDriver(const DriverID &driver_id, const std::string &type,
const std::string &error_message, double timestamp) {
auto data = std::make_shared<ErrorTableDataT>();
@ -534,6 +591,7 @@ Status ActorCheckpointIdTable::AddCheckpointId(const JobID &job_id,
}
template class Log<ObjectID, ObjectTableData>;
template class Set<ObjectID, ObjectTableData>;
template class Log<TaskID, ray::protocol::Task>;
template class Table<TaskID, ray::protocol::Task>;
template class Table<TaskID, TaskTableData>;
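
One behavioral detail worth noting from Set_DoWrite in the Redis module and the TestSet case above: once the last entry is removed, the module deletes the underlying Redis key, so a subsequent Lookup observes an empty set. An illustrative sketch, under the same assumptions as the sketch near the top of this page:

// After every entry has been Remove()d, a Lookup at the key yields an empty vector.
void ExampleLookupAfterRemove(const JobID &job_id,
                              std::shared_ptr<gcs::AsyncGcsClient> client,
                              const ObjectID &object_id) {
  auto lookup_callback = [](gcs::AsyncGcsClient *client, const ObjectID &id,
                            const std::vector<ObjectTableDataT> &data) {
    if (data.empty()) {
      // No manager holds the object; the Redis key itself has been deleted.
    }
  };
  RAY_CHECK_OK(client->object_table().Lookup(job_id, object_id, lookup_callback));
}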

View file

@ -67,8 +67,6 @@ class LogInterface {
/// pubsub_channel_ member if pubsub is required.
///
/// Example tables backed by Log:
/// ObjectTable: Stores a log of which clients have added or evicted an
/// object.
/// ClientTable: Stores a log of which GCS clients have been added or deleted
/// from the system.
template <typename ID, typename Data>
@ -77,6 +75,9 @@ class Log : public LogInterface<ID, Data>, virtual public PubsubInterface<ID> {
using DataT = typename Data::NativeTableType;
using Callback = std::function<void(AsyncGcsClient *client, const ID &id,
const std::vector<DataT> &data)>;
using NotificationCallback = std::function<void(
AsyncGcsClient *client, const ID &id,
const GcsTableNotificationMode notification_mode, const std::vector<DataT> &data)>;
/// The callback to call when a write to a key succeeds.
using WriteCallback = typename LogInterface<ID, Data>::WriteCallback;
/// The callback to call when a SUBSCRIBE call completes and we are ready to
@ -208,6 +209,29 @@ class Log : public LogInterface<ID, Data>, virtual public PubsubInterface<ID> {
static std::hash<ray::UniqueID> index;
return shard_contexts_[index(id) % shard_contexts_.size()];
}
/// Subscribe to any modifications to the key. The caller may choose
/// to subscribe to all modifications, or to subscribe only to keys that it
/// requests notifications for. This may only be called once per Log
/// instance. This overload differs from the public version in that its
/// NotificationCallback takes an additional notification_mode parameter, which
/// allows it to also deliver notifications for remove operations.
///
/// \param job_id The ID of the job (= driver).
/// \param client_id The type of update to listen to. If this is nil, then a
/// message for each Add to the table will be received. Else, only
/// messages for the given client will be received. In the latter
/// case, the client may request notifications on specific keys in the
/// table via `RequestNotifications`.
/// \param subscribe Callback that is called on each received message. If the
/// callback is called with an empty vector, then there was no data at the key.
/// \param done Callback that is called when subscription is complete and we
/// are ready to receive messages.
/// \return Status
Status Subscribe(const JobID &job_id, const ClientID &client_id,
const NotificationCallback &subscribe,
const SubscriptionCallback &done);
/// The connection to the GCS.
std::vector<std::shared_ptr<RedisContext>> shard_contexts_;
/// The GCS client.
@ -228,7 +252,6 @@ class Log : public LogInterface<ID, Data>, virtual public PubsubInterface<ID> {
/// Commands to a GCS table can either be regular (default) or chain-replicated.
CommandType command_type_ = CommandType::kRegular;
private:
int64_t num_appends_ = 0;
int64_t num_lookups_ = 0;
};
@ -337,26 +360,104 @@ class Table : private Log<ID, Data>,
using Log<ID, Data>::command_type_;
using Log<ID, Data>::GetRedisContext;
private:
int64_t num_adds_ = 0;
int64_t num_lookups_ = 0;
};
class ObjectTable : public Log<ObjectID, ObjectTableData> {
template <typename ID, typename Data>
class SetInterface {
public:
using DataT = typename Data::NativeTableType;
using WriteCallback = typename Log<ID, Data>::WriteCallback;
virtual Status Add(const JobID &job_id, const ID &id, std::shared_ptr<DataT> &data,
const WriteCallback &done) = 0;
virtual Status Remove(const JobID &job_id, const ID &id, std::shared_ptr<DataT> &data,
const WriteCallback &done) = 0;
virtual ~SetInterface(){};
};
/// \class Set
///
/// A GCS table where every entry is a set whose members can be added and removed. This class is not
/// meant to be used directly. All set classes should derive from this class
/// and override the prefix_ member with a unique prefix for that set, and the
/// pubsub_channel_ member if pubsub is required.
///
/// Example tables backed by Set:
/// ObjectTable: Stores the set of clients that have added an object.
template <typename ID, typename Data>
class Set : private Log<ID, Data>,
public SetInterface<ID, Data>,
virtual public PubsubInterface<ID> {
public:
using DataT = typename Log<ID, Data>::DataT;
using Callback = typename Log<ID, Data>::Callback;
using WriteCallback = typename Log<ID, Data>::WriteCallback;
using NotificationCallback = typename Log<ID, Data>::NotificationCallback;
using SubscriptionCallback = typename Log<ID, Data>::SubscriptionCallback;
Set(const std::vector<std::shared_ptr<RedisContext>> &contexts, AsyncGcsClient *client)
: Log<ID, Data>(contexts, client) {}
using Log<ID, Data>::RequestNotifications;
using Log<ID, Data>::CancelNotifications;
using Log<ID, Data>::Lookup;
using Log<ID, Data>::Delete;
/// Add an entry to the set.
///
/// \param job_id The ID of the job (= driver).
/// \param id The ID of the data that is added to the GCS.
/// \param data Data to add to the set.
/// \param done Callback that is called once the data has been written to the
/// GCS.
/// \return Status
Status Add(const JobID &job_id, const ID &id, std::shared_ptr<DataT> &data,
const WriteCallback &done);
/// Remove an entry from the set.
///
/// \param job_id The ID of the job (= driver).
/// \param id The ID of the data that is removed from the GCS.
/// \param data Data to remove from the set.
/// \param done Callback that is called once the data has been written to the
/// GCS.
/// \return Status
Status Remove(const JobID &job_id, const ID &id, std::shared_ptr<DataT> &data,
const WriteCallback &done);
Status Subscribe(const JobID &job_id, const ClientID &client_id,
const NotificationCallback &subscribe,
const SubscriptionCallback &done) {
return Log<ID, Data>::Subscribe(job_id, client_id, subscribe, done);
}
/// Returns debug string for class.
///
/// \return string.
std::string DebugString() const;
protected:
using Log<ID, Data>::shard_contexts_;
using Log<ID, Data>::client_;
using Log<ID, Data>::pubsub_channel_;
using Log<ID, Data>::prefix_;
using Log<ID, Data>::GetRedisContext;
int64_t num_adds_ = 0;
int64_t num_removes_ = 0;
using Log<ID, Data>::num_lookups_;
};
class ObjectTable : public Set<ObjectID, ObjectTableData> {
public:
ObjectTable(const std::vector<std::shared_ptr<RedisContext>> &contexts,
AsyncGcsClient *client)
: Log(contexts, client) {
: Set(contexts, client) {
pubsub_channel_ = TablePubsub::OBJECT;
prefix_ = TablePrefix::OBJECT;
};
ObjectTable(const std::vector<std::shared_ptr<RedisContext>> &contexts,
AsyncGcsClient *client, gcs::CommandType command_type)
: ObjectTable(contexts, client) {
command_type_ = command_type;
};
virtual ~ObjectTable(){};
};
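
As the Set class comment requires, any new set-backed table follows the same pattern as ObjectTable above: derive from Set<ID, Data> and set prefix_ (and pubsub_channel_, if notifications are needed). The class below is a hypothetical example, not part of this commit; it reuses the existing Set<ObjectID, ObjectTableData> instantiation from tables.cc, and a real table would reserve its own TablePrefix/TablePubsub entries in gcs.fbs rather than reuse OBJECT.

// Hypothetical set-backed table; reuses the OBJECT prefix/channel only to keep
// the sketch compilable. A real table would define its own enum entries in gcs.fbs.
class ExampleObjectLocationSet : public Set<ObjectID, ObjectTableData> {
 public:
  ExampleObjectLocationSet(const std::vector<std::shared_ptr<RedisContext>> &contexts,
                           AsyncGcsClient *client)
      : Set(contexts, client) {
    pubsub_channel_ = TablePubsub::OBJECT;  // placeholder; reserve a unique channel
    prefix_ = TablePrefix::OBJECT;          // placeholder; reserve a unique prefix
  }
  virtual ~ExampleObjectLocationSet() {}
};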

View file

@ -8,30 +8,19 @@ ObjectDirectory::ObjectDirectory(boost::asio::io_service &io_service,
namespace {
/// Process a suffix of the object table log and store the result in
/// Process a notification of the object table entries and store the result in
/// client_ids. This assumes that client_ids already contains the result of the
/// object table log up to but not including this suffix. This also stores a
/// bool in has_been_created indicating whether the object has ever been
/// created before.
void UpdateObjectLocations(const std::vector<ObjectTableDataT> &location_history,
/// object table entries up to but not including this notification.
void UpdateObjectLocations(const GcsTableNotificationMode notification_mode,
const std::vector<ObjectTableDataT> &location_updates,
const ray::gcs::ClientTable &client_table,
std::unordered_set<ClientID> *client_ids,
bool *has_been_created) {
// location_history contains the history of locations of the object (it is a log),
// which might look like the following:
// client1.is_eviction = false
// client1.is_eviction = true
// client2.is_eviction = false
// In such a scenario, we want to indicate client2 is the only client that contains
// the object, which the following code achieves.
if (!location_history.empty()) {
// If there are entries, then the object has been created. Once this flag
// is set to true, it should never go back to false.
*has_been_created = true;
}
for (const auto &object_table_data : location_history) {
std::unordered_set<ClientID> *client_ids) {
// location_updates contains the updates to the object's locations.
// The GcsTableNotificationMode indicates whether the update is an addition or a removal.
for (const auto &object_table_data : location_updates) {
ClientID client_id = ClientID::from_binary(object_table_data.manager);
if (!object_table_data.is_eviction) {
if (notification_mode != GcsTableNotificationMode::REMOVE) {
client_ids->insert(client_id);
} else {
client_ids->erase(client_id);
@ -52,17 +41,22 @@ void UpdateObjectLocations(const std::vector<ObjectTableDataT> &location_history
void ObjectDirectory::RegisterBackend() {
auto object_notification_callback = [this](
gcs::AsyncGcsClient *client, const ObjectID &object_id,
const std::vector<ObjectTableDataT> &location_history) {
const GcsTableNotificationMode notification_mode,
const std::vector<ObjectTableDataT> &location_updates) {
// Objects are added to this map in SubscribeObjectLocations.
auto it = listeners_.find(object_id);
// Do nothing for objects we are not listening for.
if (it == listeners_.end()) {
return;
}
// Once this flag is set to true, it should never go back to false.
it->second.subscribed = true;
// Update entries for this object.
UpdateObjectLocations(location_history, gcs_client_->client_table(),
&it->second.current_object_locations,
&it->second.has_been_created);
UpdateObjectLocations(notification_mode, location_updates,
gcs_client_->client_table(),
&it->second.current_object_locations);
// Copy the callbacks so that a callback can unsubscribe without invalidating
// the iteration over the callbacks.
auto callbacks = it->second.callbacks;
@ -73,8 +67,7 @@ void ObjectDirectory::RegisterBackend() {
for (const auto &callback_pair : callbacks) {
// It is safe to call the callback directly since this is already running
// in the subscription callback stack.
callback_pair.second(object_id, it->second.current_object_locations,
it->second.has_been_created);
callback_pair.second(object_id, it->second.current_object_locations);
}
};
RAY_CHECK_OK(gcs_client_->object_table().Subscribe(
@ -89,22 +82,22 @@ ray::Status ObjectDirectory::ReportObjectAdded(
// Add the object's location entry to the object table.
auto data = std::make_shared<ObjectTableDataT>();
data->manager = client_id.binary();
data->is_eviction = false;
data->object_size = object_info.data_size;
ray::Status status =
gcs_client_->object_table().Append(JobID::nil(), object_id, data, nullptr);
gcs_client_->object_table().Add(JobID::nil(), object_id, data, nullptr);
return status;
}
ray::Status ObjectDirectory::ReportObjectRemoved(const ObjectID &object_id,
const ClientID &client_id) {
ray::Status ObjectDirectory::ReportObjectRemoved(
const ObjectID &object_id, const ClientID &client_id,
const object_manager::protocol::ObjectInfoT &object_info) {
RAY_LOG(DEBUG) << "Reporting object removed to GCS " << object_id;
// Remove the object's location entry from the object table.
auto data = std::make_shared<ObjectTableDataT>();
data->manager = client_id.binary();
data->is_eviction = true;
data->object_size = object_info.data_size;
ray::Status status =
gcs_client_->object_table().Append(JobID::nil(), object_id, data, nullptr);
gcs_client_->object_table().Remove(JobID::nil(), object_id, data, nullptr);
return status;
};
@ -141,17 +134,16 @@ void ObjectDirectory::HandleClientRemoved(const ClientID &client_id) {
const ObjectID &object_id = listener.first;
if (listener.second.current_object_locations.count(client_id) > 0) {
// If the subscribed object has the removed client as a location, update
// its locations with an empty log so that the location will be removed.
UpdateObjectLocations({}, gcs_client_->client_table(),
&listener.second.current_object_locations,
&listener.second.has_been_created);
// its locations with an empty update so that the location will be removed.
UpdateObjectLocations(GcsTableNotificationMode::APPEND_OR_ADD, {},
gcs_client_->client_table(),
&listener.second.current_object_locations);
// Re-call all the subscribed callbacks for the object, since its
// locations have changed.
for (const auto &callback_pair : listener.second.callbacks) {
// It is safe to call the callback directly since this is already running
// in the subscription callback stack.
callback_pair.second(object_id, listener.second.current_object_locations,
listener.second.has_been_created);
callback_pair.second(object_id, listener.second.current_object_locations);
}
}
}
@ -175,11 +167,10 @@ ray::Status ObjectDirectory::SubscribeObjectLocations(const UniqueID &callback_i
listener_state.callbacks.emplace(callback_id, callback);
// If we previously received some notifications about the object's locations,
// immediately notify the caller of the current known locations.
if (listener_state.has_been_created) {
if (listener_state.subscribed) {
auto &locations = listener_state.current_object_locations;
io_service_.post([callback, locations, object_id]() {
callback(object_id, locations, /*has_been_created=*/true);
});
io_service_.post(
[callback, locations, object_id]() { callback(object_id, locations); });
}
return status;
}
@ -204,16 +195,14 @@ ray::Status ObjectDirectory::LookupLocations(const ObjectID &object_id,
const OnLocationsFound &callback) {
ray::Status status;
auto it = listeners_.find(object_id);
if (it != listeners_.end() && it->second.has_been_created) {
if (it != listeners_.end() && it->second.subscribed) {
// If we have locations cached due to a concurrent SubscribeObjectLocations
// call, and we have received at least one notification from the GCS about
// the object's locations, then call the callback immediately with the
// cached locations.
auto &locations = it->second.current_object_locations;
bool has_been_created = it->second.has_been_created;
io_service_.post([callback, object_id, locations, has_been_created]() {
callback(object_id, locations, has_been_created);
});
io_service_.post(
[callback, object_id, locations]() { callback(object_id, locations); });
} else {
// We do not have any locations cached due to a concurrent
// SubscribeObjectLocations call, so look up the object's locations
@ -221,15 +210,14 @@ ray::Status ObjectDirectory::LookupLocations(const ObjectID &object_id,
status = gcs_client_->object_table().Lookup(
JobID::nil(), object_id,
[this, callback](gcs::AsyncGcsClient *client, const ObjectID &object_id,
const std::vector<ObjectTableDataT> &location_history) {
const std::vector<ObjectTableDataT> &location_updates) {
// Build the set of current locations based on the entries returned by the lookup.
std::unordered_set<ClientID> client_ids;
bool has_been_created = false;
UpdateObjectLocations(location_history, gcs_client_->client_table(),
&client_ids, &has_been_created);
UpdateObjectLocations(GcsTableNotificationMode::APPEND_OR_ADD, location_updates,
gcs_client_->client_table(), &client_ids);
// It is safe to call the callback directly since this is already running
// in the GCS client's lookup callback stack.
callback(object_id, client_ids, has_been_created);
callback(object_id, client_ids);
});
}
return status;
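A caller-side sketch of the new lookup callback shape follows; `object_directory_` and `object_id` are assumed from the surrounding class, as at the call sites later in this diff.
// With the Set-based table there is no has_been_created flag: the callback
// receives only the object id and the current locations, and an empty set
// means there are currently no known copies (evicted or never created).
RAY_CHECK_OK(object_directory_->LookupLocations(
    object_id,
    [](const ray::ObjectID &object_id,
       const std::unordered_set<ray::ClientID> &clients) {
      if (clients.empty()) {
        // No known locations for this object right now.
      }
    }));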


@ -51,8 +51,7 @@ class ObjectDirectoryInterface {
/// Callback for object location notifications.
using OnLocationsFound = std::function<void(const ray::ObjectID &object_id,
const std::unordered_set<ray::ClientID> &,
bool has_been_created)>;
const std::unordered_set<ray::ClientID> &)>;
/// Lookup object locations. The callback may be invoked with an empty list of client ids.
///
@ -110,9 +109,11 @@ class ObjectDirectoryInterface {
///
/// \param object_id The object id that was removed from the store.
/// \param client_id The client id corresponding to this node.
/// \param object_info Additional information about the object.
/// \return Status of whether this method succeeded.
virtual ray::Status ReportObjectRemoved(const ObjectID &object_id,
const ClientID &client_id) = 0;
virtual ray::Status ReportObjectRemoved(
const ObjectID &object_id, const ClientID &client_id,
const object_manager::protocol::ObjectInfoT &object_info) = 0;
/// Get local client id
///
@ -159,8 +160,9 @@ class ObjectDirectory : public ObjectDirectoryInterface {
ray::Status ReportObjectAdded(
const ObjectID &object_id, const ClientID &client_id,
const object_manager::protocol::ObjectInfoT &object_info) override;
ray::Status ReportObjectRemoved(const ObjectID &object_id,
const ClientID &client_id) override;
ray::Status ReportObjectRemoved(
const ObjectID &object_id, const ClientID &client_id,
const object_manager::protocol::ObjectInfoT &object_info) override;
ray::ClientID GetLocalClientID() override;
@ -176,12 +178,12 @@ class ObjectDirectory : public ObjectDirectoryInterface {
std::unordered_map<UniqueID, OnLocationsFound> callbacks;
/// The current set of known locations of this object.
std::unordered_set<ClientID> current_object_locations;
/// This flag will get set to true if the object has ever been created. It
/// This flag will get set to true once any notification for the object has been
/// received, meaning current_object_locations is up to date with the GCS. It
/// should never go back to false once set to true. If this is true, and
/// the current_object_locations is empty, then this means that the object
/// does not exist on any nodes due to eviction (rather than due to the
/// object never getting created, for instance).
bool has_been_created;
/// does not exist on any nodes due to eviction or the object never getting created.
bool subscribed;
};
/// Reference to the event loop.


@ -93,8 +93,10 @@ void ObjectManager::HandleObjectAdded(
void ObjectManager::NotifyDirectoryObjectDeleted(const ObjectID &object_id) {
auto it = local_objects_.find(object_id);
RAY_CHECK(it != local_objects_.end());
auto object_info = it->second.object_info;
local_objects_.erase(it);
ray::Status status = object_directory_->ReportObjectRemoved(object_id, client_id_);
ray::Status status =
object_directory_->ReportObjectRemoved(object_id, client_id_, object_info);
}
ray::Status ObjectManager::SubscribeObjAdded(
@ -127,8 +129,7 @@ ray::Status ObjectManager::Pull(const ObjectID &object_id) {
// no ordering guarantee between notifications.
return object_directory_->SubscribeObjectLocations(
object_directory_pull_callback_id_, object_id,
[this](const ObjectID &object_id, const std::unordered_set<ClientID> &client_ids,
bool created) {
[this](const ObjectID &object_id, const std::unordered_set<ClientID> &client_ids) {
// Exit if the Pull request has already been fulfilled or canceled.
auto it = pull_requests_.find(object_id);
if (it == pull_requests_.end()) {
@ -578,9 +579,8 @@ ray::Status ObjectManager::LookupRemainingWaitObjects(const UniqueID &wait_id) {
// Lookup remaining objects.
wait_state.requested_objects.insert(object_id);
RAY_RETURN_NOT_OK(object_directory_->LookupLocations(
object_id,
[this, wait_id](const ObjectID &lookup_object_id,
const std::unordered_set<ClientID> &client_ids, bool created) {
object_id, [this, wait_id](const ObjectID &lookup_object_id,
const std::unordered_set<ClientID> &client_ids) {
auto &wait_state = active_wait_requests_.find(wait_id)->second;
if (!client_ids.empty()) {
wait_state.remaining.erase(lookup_object_id);
@ -618,7 +618,7 @@ void ObjectManager::SubscribeRemainingWaitObjects(const UniqueID &wait_id) {
RAY_CHECK_OK(object_directory_->SubscribeObjectLocations(
wait_id, object_id,
[this, wait_id](const ObjectID &subscribe_object_id,
const std::unordered_set<ClientID> &client_ids, bool created) {
const std::unordered_set<ClientID> &client_ids) {
if (!client_ids.empty()) {
RAY_LOG(DEBUG) << "Wait request " << wait_id
<< ": subscription notification received for object "


@ -291,10 +291,9 @@ class TestObjectManager : public TestObjectManagerBase {
UniqueID sub_id = ray::ObjectID::from_random();
RAY_CHECK_OK(server1->object_manager_.object_directory_->SubscribeObjectLocations(
sub_id, object_1,
[this, sub_id, object_1, object_2](
const ray::ObjectID &object_id,
const std::unordered_set<ray::ClientID> &clients, bool created) {
sub_id, object_1, [this, sub_id, object_1, object_2](
const ray::ObjectID &object_id,
const std::unordered_set<ray::ClientID> &clients) {
if (!clients.empty()) {
TestWaitWhileSubscribed(sub_id, object_1, object_2);
}


@ -60,9 +60,7 @@ NodeManager::NodeManager(boost::asio::io_service &io_service,
scheduling_policy_(local_queues_),
reconstruction_policy_(
io_service_,
[this](const TaskID &task_id, bool return_values_lost) {
HandleTaskReconstruction(task_id);
},
[this](const TaskID &task_id) { HandleTaskReconstruction(task_id); },
RayConfig::instance().initial_reconstruction_timeout_milliseconds(),
gcs_client_->client_table().GetLocalClientId(), gcs_client_->task_lease_table(),
object_directory_, gcs_client_->task_reconstruction_log()),
@ -1287,14 +1285,13 @@ void NodeManager::TreatTaskAsFailedIfLost(const Task &task) {
const ObjectID object_id = spec.ReturnId(i);
// Lookup the return value's locations.
RAY_CHECK_OK(object_directory_->LookupLocations(
object_id,
[this, task_marked_as_failed, task](
const ray::ObjectID &object_id,
const std::unordered_set<ray::ClientID> &clients, bool has_been_created) {
object_id, [this, task_marked_as_failed, task](
const ray::ObjectID &object_id,
const std::unordered_set<ray::ClientID> &clients) {
if (!*task_marked_as_failed) {
// Only process the object locations if we haven't already marked the
// task as failed.
if (clients.empty() && has_been_created) {
if (clients.empty()) {
// The object does not exist on any nodes, so it has been lost (or was
// never created). Mark the task as failed to prevent any tasks that
// depend on this object from hanging.


@ -6,7 +6,7 @@ namespace raylet {
ReconstructionPolicy::ReconstructionPolicy(
boost::asio::io_service &io_service,
std::function<void(const TaskID &, bool)> reconstruction_handler,
std::function<void(const TaskID &)> reconstruction_handler,
int64_t initial_reconstruction_timeout_ms, const ClientID &client_id,
gcs::PubsubInterface<TaskID> &task_lease_pubsub,
std::shared_ptr<ObjectDirectoryInterface> object_directory,
@ -74,14 +74,13 @@ void ReconstructionPolicy::HandleReconstructionLogAppend(const TaskID &task_id,
SetTaskTimeout(it, initial_reconstruction_timeout_ms_);
if (success) {
reconstruction_handler_(task_id, it->second.return_values_lost);
reconstruction_handler_(task_id);
}
}
void ReconstructionPolicy::AttemptReconstruction(const TaskID &task_id,
const ObjectID &required_object_id,
int reconstruction_attempt,
bool created) {
int reconstruction_attempt) {
// If we are no longer listening for objects created by this task, give up.
auto it = listening_tasks_.find(task_id);
if (it == listening_tasks_.end()) {
@ -93,10 +92,6 @@ void ReconstructionPolicy::AttemptReconstruction(const TaskID &task_id,
return;
}
if (created) {
it->second.return_values_lost = true;
}
// Suppress duplicate reconstructions of the same task. This can happen if,
// for example, a task creates two different objects that both require
// reconstruction.
@ -142,14 +137,13 @@ void ReconstructionPolicy::HandleTaskLeaseExpired(const TaskID &task_id) {
// attempted asynchronously.
for (const auto &created_object_id : it->second.created_objects) {
RAY_CHECK_OK(object_directory_->LookupLocations(
created_object_id,
[this, task_id, reconstruction_attempt](
const ray::ObjectID &object_id,
const std::unordered_set<ray::ClientID> &clients, bool created) {
created_object_id, [this, task_id, reconstruction_attempt](
const ray::ObjectID &object_id,
const std::unordered_set<ray::ClientID> &clients) {
if (clients.empty()) {
// The required object no longer exists on any live nodes. Attempt
// reconstruction.
AttemptReconstruction(task_id, object_id, reconstruction_attempt, created);
AttemptReconstruction(task_id, object_id, reconstruction_attempt);
}
}));
}


@ -40,7 +40,7 @@ class ReconstructionPolicy : public ReconstructionPolicyInterface {
/// lease notifications from.
ReconstructionPolicy(
boost::asio::io_service &io_service,
std::function<void(const TaskID &, bool)> reconstruction_handler,
std::function<void(const TaskID &)> reconstruction_handler,
int64_t initial_reconstruction_timeout_ms, const ClientID &client_id,
gcs::PubsubInterface<TaskID> &task_lease_pubsub,
std::shared_ptr<ObjectDirectoryInterface> object_directory,
@ -93,7 +93,6 @@ class ReconstructionPolicy : public ReconstructionPolicyInterface {
bool subscribed;
// The number of times we've attempted reconstructing this task so far.
int reconstruction_attempt;
bool return_values_lost;
// The task's reconstruction timer. If this expires before a lease
// notification is received, then the task will be reconstructed.
std::unique_ptr<boost::asio::deadline_timer> reconstruction_timer;
@ -116,7 +115,7 @@ class ReconstructionPolicy : public ReconstructionPolicyInterface {
/// reconstructions of the same task (e.g., if a task creates two objects
/// that both require reconstruction).
void AttemptReconstruction(const TaskID &task_id, const ObjectID &required_object_id,
int reconstruction_attempt, bool created);
int reconstruction_attempt);
/// Handle expiration of a task lease.
void HandleTaskLeaseExpired(const TaskID &task_id);
@ -128,7 +127,7 @@ class ReconstructionPolicy : public ReconstructionPolicyInterface {
/// The event loop.
boost::asio::io_service &io_service_;
/// The handler to call for tasks that require reconstruction.
const std::function<void(const TaskID &, bool)> reconstruction_handler_;
const std::function<void(const TaskID &)> reconstruction_handler_;
/// The initial timeout within which a task lease notification must be
/// received. Otherwise, reconstruction will be triggered.
const int64_t initial_reconstruction_timeout_ms_;


@ -29,10 +29,9 @@ class MockObjectDirectory : public ObjectDirectoryInterface {
const ObjectID object_id = callback.first;
auto it = locations_.find(object_id);
if (it == locations_.end()) {
callback.second(object_id, std::unordered_set<ray::ClientID>(),
/*created=*/false);
callback.second(object_id, std::unordered_set<ray::ClientID>());
} else {
callback.second(object_id, it->second, /*created=*/true);
callback.second(object_id, it->second);
}
}
callbacks_.clear();
@ -63,7 +62,9 @@ class MockObjectDirectory : public ObjectDirectoryInterface {
MOCK_METHOD3(ReportObjectAdded,
ray::Status(const ObjectID &, const ClientID &,
const object_manager::protocol::ObjectInfoT &));
MOCK_METHOD2(ReportObjectRemoved, ray::Status(const ObjectID &, const ClientID &));
MOCK_METHOD3(ReportObjectRemoved,
ray::Status(const ObjectID &, const ClientID &,
const object_manager::protocol::ObjectInfoT &));
private:
std::vector<std::pair<ObjectID, OnLocationsFound>> callbacks_;
@ -151,8 +152,8 @@ class ReconstructionPolicyTest : public ::testing::Test {
mock_object_directory_(std::make_shared<MockObjectDirectory>()),
reconstruction_timeout_ms_(50),
reconstruction_policy_(std::make_shared<ReconstructionPolicy>(
io_service_, [this](const TaskID &task_id,
bool created) { TriggerReconstruction(task_id); },
io_service_,
[this](const TaskID &task_id) { TriggerReconstruction(task_id); },
reconstruction_timeout_ms_, ClientID::from_random(), mock_gcs_,
mock_object_directory_, mock_gcs_)),
timer_canceled_(false) {