diff --git a/src/ray/raylet/lineage_cache.cc b/src/ray/raylet/lineage_cache.cc index 47daf769b..16f6da5a0 100644 --- a/src/ray/raylet/lineage_cache.cc +++ b/src/ray/raylet/lineage_cache.cc @@ -23,6 +23,14 @@ void LineageEntry::ResetStatus(GcsStatus new_status) { status_ = new_status; } +void LineageEntry::MarkExplicitlyForwarded(const ClientID &node_id) { + forwarded_to_.insert(node_id); +} + +bool LineageEntry::WasExplicitlyForwarded(const ClientID &node_id) const { + return forwarded_to_.find(node_id) != forwarded_to_.end(); +} + const TaskID LineageEntry::GetEntryId() const { return task_.GetTaskSpecification().TaskId(); } @@ -139,7 +147,7 @@ LineageCache::LineageCache(const ClientID &client_id, /// lineage_from. This should return true if the merge should stop. void MergeLineageHelper(const TaskID &task_id, const Lineage &lineage_from, Lineage &lineage_to, - std::function stopping_condition) { + std::function stopping_condition) { // If the entry is not found in the lineage to merge, then we stop since // there is nothing to copy into the merged lineage. auto entry = lineage_from.GetEntry(task_id); @@ -147,8 +155,7 @@ void MergeLineageHelper(const TaskID &task_id, const Lineage &lineage_from, return; } // Check whether we should stop at this entry in the DFS. - auto status = entry->GetStatus(); - if (stopping_condition(status)) { + if (stopping_condition(entry.get())) { return; } @@ -167,16 +174,17 @@ void MergeLineageHelper(const TaskID &task_id, const Lineage &lineage_from, void LineageCache::AddWaitingTask(const Task &task, const Lineage &uncommitted_lineage) { auto task_id = task.GetTaskSpecification().TaskId(); // Merge the uncommitted lineage into the lineage cache. - MergeLineageHelper(task_id, uncommitted_lineage, lineage_, [](GcsStatus status) { - if (status != GcsStatus::NONE) { - // We received the uncommitted lineage from a remote node, so make sure - // that all entries in the lineage to merge have status - // UNCOMMITTED_REMOTE. - RAY_CHECK(status == GcsStatus::UNCOMMITTED_REMOTE); - } - // The only stopping condition is that an entry is not found. - return false; - }); + MergeLineageHelper(task_id, uncommitted_lineage, lineage_, + [](const LineageEntry &entry) { + if (entry.GetStatus() != GcsStatus::NONE) { + // We received the uncommitted lineage from a remote node, so + // make sure that all entries in the lineage to merge have + // status UNCOMMITTED_REMOTE. + RAY_CHECK(entry.GetStatus() == GcsStatus::UNCOMMITTED_REMOTE); + } + // The only stopping condition is that an entry is not found. + return false; + }); // If the task was previously remote, then we may have been subscribed to // it. Unsubscribe since we are now responsible for committing the task. @@ -256,15 +264,23 @@ void LineageCache::RemoveWaitingTask(const TaskID &task_id) { } } -Lineage LineageCache::GetUncommittedLineage(const TaskID &task_id) const { +void LineageCache::MarkTaskAsForwarded(const TaskID &task_id, const ClientID &node_id) { + RAY_CHECK(!node_id.is_nil()); + lineage_.GetEntryMutable(task_id)->MarkExplicitlyForwarded(node_id); +} + +Lineage LineageCache::GetUncommittedLineage(const TaskID &task_id, + const ClientID &node_id) const { Lineage uncommitted_lineage; // Add all uncommitted ancestors from the lineage cache to the uncommitted // lineage of the requested task. - MergeLineageHelper(task_id, lineage_, uncommitted_lineage, [](GcsStatus status) { - // The stopping condition for recursion is that the entry has been - // committed to the GCS. - return false; - }); + MergeLineageHelper( + task_id, lineage_, uncommitted_lineage, [&](const LineageEntry &entry) { + // The stopping condition for recursion is that the entry has + // been committed to the GCS or has already been forwarded. + // The lineage always includes the requested task id. + return entry.WasExplicitlyForwarded(node_id) && !(entry.GetEntryId() == task_id); + }); return uncommitted_lineage; } diff --git a/src/ray/raylet/lineage_cache.h b/src/ray/raylet/lineage_cache.h index 7987953d0..0a104ac97 100644 --- a/src/ray/raylet/lineage_cache.h +++ b/src/ray/raylet/lineage_cache.h @@ -64,6 +64,17 @@ class LineageEntry { /// \param new_status This must be lower than the current status. void ResetStatus(GcsStatus new_status); + /// Mark this entry as having been explicitly forwarded to a remote node manager. + /// + /// \param node_id The ID of the remote node manager. + void MarkExplicitlyForwarded(const ClientID &node_id); + + /// Gets whether this entry was explicitly forwarded to a remote node. + /// + /// \param node_id The ID of the remote node manager. + /// \return Whether this entry was explicitly forwarded to the remote node. + bool WasExplicitlyForwarded(const ClientID &node_id) const; + /// Get this entry's ID. /// /// \return The entry's ID. @@ -88,6 +99,9 @@ class LineageEntry { /// an object. // const Task task_; Task task_; + + /// IDs of node managers that this task has been explicitly forwarded to. + std::unordered_set forwarded_to_; }; /// \class Lineage @@ -184,14 +198,22 @@ class LineageCache { /// \param task_id The ID of the waiting task to remove. void RemoveWaitingTask(const TaskID &task_id); - /// Get the uncommitted lineage of a task. The uncommitted lineage consists - /// of all tasks in the given task's lineage that have not been committed in - /// the GCS, as far as we know. + /// Mark a task as having been explicitly forwarded to a node. + /// The lineage of the task is implicitly assumed to have also been forwarded. /// - /// \param entry_id The ID of the task to get the uncommitted lineage for. - /// \return The uncommitted lineage of the task. The returned lineage + /// \param task_id The ID of the task to get the uncommitted lineage for. + /// \param node_id The ID of the node to get the uncommitted lineage for. + void MarkTaskAsForwarded(const TaskID &task_id, const ClientID &node_id); + + /// Get the uncommitted lineage of a task that hasn't been forwarded to a node yet. + /// The uncommitted lineage consists of all tasks in the given task's lineage + /// that have not been committed in the GCS, as far as we know. + /// + /// \param task_id The ID of the task to get the uncommitted lineage for. + /// \param node_id The ID of the receiving node. + /// \return The uncommitted, unforwarded lineage of the task. The returned lineage /// includes the entry for the requested entry_id. - Lineage GetUncommittedLineage(const TaskID &entry_id) const; + Lineage GetUncommittedLineage(const TaskID &task_id, const ClientID &node_id) const; /// Asynchronously write any tasks that are in the UNCOMMITTED_READY state /// and for which all parents have been committed to the GCS. These tasks diff --git a/src/ray/raylet/lineage_cache_test.cc b/src/ray/raylet/lineage_cache_test.cc index ba6f5367b..240d622ec 100644 --- a/src/ray/raylet/lineage_cache_test.cc +++ b/src/ray/raylet/lineage_cache_test.cc @@ -157,11 +157,10 @@ TEST_F(LineageCacheTest, TestGetUncommittedLineage) { task_ids2.push_back(task.GetTaskSpecification().TaskId()); } - // Get the uncommitted lineage for the last task (the leaf) of one of the - // chains. - auto uncommitted_lineage = lineage_cache_.GetUncommittedLineage(task_ids1.back()); - // Check that the uncommitted lineage is exactly equal to the first chain of - // tasks. + // Get the uncommitted lineage for the last task (the leaf) of one of the chains. + auto uncommitted_lineage = + lineage_cache_.GetUncommittedLineage(task_ids1.back(), ClientID::nil()); + // Check that the uncommitted lineage is exactly equal to the first chain of tasks. ASSERT_EQ(task_ids1.size(), uncommitted_lineage.GetEntries().size()); for (auto &task_id : task_ids1) { ASSERT_TRUE(uncommitted_lineage.GetEntry(task_id)); @@ -180,7 +179,8 @@ TEST_F(LineageCacheTest, TestGetUncommittedLineage) { } // Get the uncommitted lineage for the inserted task. - uncommitted_lineage = lineage_cache_.GetUncommittedLineage(combined_task_ids.back()); + uncommitted_lineage = + lineage_cache_.GetUncommittedLineage(combined_task_ids.back(), ClientID::nil()); // Check that the uncommitted lineage is exactly equal to the entire set of // tasks inserted so far. ASSERT_EQ(combined_task_ids.size(), uncommitted_lineage.GetEntries().size()); @@ -189,6 +189,38 @@ TEST_F(LineageCacheTest, TestGetUncommittedLineage) { } } +TEST_F(LineageCacheTest, TestMarkTaskAsForwarded) { + // Insert chain of tasks. + std::vector tasks; + auto return_values = + InsertTaskChain(lineage_cache_, tasks, 4, std::vector(), 1); + std::vector task_ids; + for (const auto &task : tasks) { + task_ids.push_back(task.GetTaskSpecification().TaskId()); + } + + auto node_id = ClientID::from_random(); + auto node_id2 = ClientID::from_random(); + auto forwarded_task_id = task_ids[task_ids.size() - 2]; + auto remaining_task_id = task_ids[task_ids.size() - 1]; + lineage_cache_.MarkTaskAsForwarded(forwarded_task_id, node_id); + + auto uncommitted_lineage = + lineage_cache_.GetUncommittedLineage(remaining_task_id, node_id); + auto uncommitted_lineage_all = + lineage_cache_.GetUncommittedLineage(remaining_task_id, node_id2); + + ASSERT_EQ(1, uncommitted_lineage.GetEntries().size()); + ASSERT_EQ(4, uncommitted_lineage_all.GetEntries().size()); + ASSERT_TRUE(uncommitted_lineage.GetEntry(remaining_task_id)); + + // Check that lineage of requested task includes itself, regardless of whether + // it has been forwarded before. + auto uncommitted_lineage_forwarded = + lineage_cache_.GetUncommittedLineage(remaining_task_id, node_id); + ASSERT_EQ(1, uncommitted_lineage_forwarded.GetEntries().size()); +} + void CheckFlush(LineageCache &lineage_cache, MockGcs &mock_gcs, size_t num_tasks_flushed) { lineage_cache.Flush(); @@ -199,8 +231,7 @@ TEST_F(LineageCacheTest, TestWritebackNoneReady) { // Insert a chain of dependent tasks. size_t num_tasks_flushed = 0; std::vector tasks; - auto return_values1 = - InsertTaskChain(lineage_cache_, tasks, 3, std::vector(), 1); + InsertTaskChain(lineage_cache_, tasks, 3, std::vector(), 1); // Check that when no tasks have been marked as ready, we do not flush any // entries. @@ -211,8 +242,7 @@ TEST_F(LineageCacheTest, TestWritebackReady) { // Insert a chain of dependent tasks. size_t num_tasks_flushed = 0; std::vector tasks; - auto return_values1 = - InsertTaskChain(lineage_cache_, tasks, 3, std::vector(), 1); + InsertTaskChain(lineage_cache_, tasks, 3, std::vector(), 1); // Check that after marking the first task as ready, we flush only that task. lineage_cache_.AddReadyTask(tasks.front()); @@ -224,8 +254,7 @@ TEST_F(LineageCacheTest, TestWritebackOrder) { // Insert a chain of dependent tasks. size_t num_tasks_flushed = 0; std::vector tasks; - auto return_values1 = - InsertTaskChain(lineage_cache_, tasks, 3, std::vector(), 1); + InsertTaskChain(lineage_cache_, tasks, 3, std::vector(), 1); // Mark all tasks as ready. The first task, which has no dependencies, should // be flushed. @@ -288,15 +317,15 @@ TEST_F(LineageCacheTest, TestForwardTasksRoundTrip) { // Insert a chain of dependent tasks. uint64_t lineage_size = max_lineage_size_ + 1; std::vector tasks; - auto return_values1 = - InsertTaskChain(lineage_cache_, tasks, lineage_size, std::vector(), 1); + InsertTaskChain(lineage_cache_, tasks, lineage_size, std::vector(), 1); // Simulate removing each task, forwarding it to another node, then // receiving the task back again. for (auto it = tasks.begin(); it != tasks.end(); it++) { const auto task_id = it->GetTaskSpecification().TaskId(); // Simulate removing the task and forwarding it to another node. - auto uncommitted_lineage = lineage_cache_.GetUncommittedLineage(task_id); + auto uncommitted_lineage = + lineage_cache_.GetUncommittedLineage(task_id, ClientID::nil()); lineage_cache_.RemoveWaitingTask(task_id); // Simulate receiving the task again. Make sure we can add the task back. flatbuffers::FlatBufferBuilder fbb; @@ -312,15 +341,15 @@ TEST_F(LineageCacheTest, TestForwardTask) { // Insert a chain of dependent tasks. size_t num_tasks_flushed = 0; std::vector tasks; - auto return_values1 = - InsertTaskChain(lineage_cache_, tasks, 3, std::vector(), 1); + InsertTaskChain(lineage_cache_, tasks, 3, std::vector(), 1); // Simulate removing the task and forwarding it to another node. auto it = tasks.begin() + 1; auto forwarded_task = *it; tasks.erase(it); auto task_id_to_remove = forwarded_task.GetTaskSpecification().TaskId(); - auto uncommitted_lineage = lineage_cache_.GetUncommittedLineage(task_id_to_remove); + auto uncommitted_lineage = + lineage_cache_.GetUncommittedLineage(task_id_to_remove, ClientID::nil()); lineage_cache_.RemoveWaitingTask(task_id_to_remove); // Simulate executing the remaining tasks. @@ -366,7 +395,8 @@ TEST_F(LineageCacheTest, TestEviction) { // Check that the last task in the chain still has all tasks in its // uncommitted lineage. const auto last_task_id = tasks.back().GetTaskSpecification().TaskId(); - auto uncommitted_lineage = lineage_cache_.GetUncommittedLineage(last_task_id); + auto uncommitted_lineage = + lineage_cache_.GetUncommittedLineage(last_task_id, ClientID::nil()); ASSERT_EQ(uncommitted_lineage.GetEntries().size(), lineage_size); // Simulate executing the first task on a remote node and adding it to the @@ -394,7 +424,8 @@ TEST_F(LineageCacheTest, TestEviction) { } // All tasks have now been flushed. Check that enough lineage has been // evicted that the uncommitted lineage is now less than the maximum size. - uncommitted_lineage = lineage_cache_.GetUncommittedLineage(last_task_id); + uncommitted_lineage = + lineage_cache_.GetUncommittedLineage(last_task_id, ClientID::nil()); ASSERT_TRUE(uncommitted_lineage.GetEntries().size() <= max_lineage_size_); } @@ -418,7 +449,8 @@ TEST_F(LineageCacheTest, TestOutOfOrderEviction) { // Check that the last task in the chain still has all tasks in its // uncommitted lineage. const auto last_task_id = tasks.back().GetTaskSpecification().TaskId(); - auto uncommitted_lineage = lineage_cache_.GetUncommittedLineage(last_task_id); + auto uncommitted_lineage = + lineage_cache_.GetUncommittedLineage(last_task_id, ClientID::nil()); ASSERT_EQ(uncommitted_lineage.GetEntries().size(), lineage_size); // Simulate executing the tasks at the remote node and receiving the @@ -446,7 +478,8 @@ TEST_F(LineageCacheTest, TestOutOfOrderEviction) { } // All tasks have now been flushed. Check that enough lineage has been // evicted that the uncommitted lineage is now less than the maximum size. - uncommitted_lineage = lineage_cache_.GetUncommittedLineage(last_task_id); + uncommitted_lineage = + lineage_cache_.GetUncommittedLineage(last_task_id, ClientID::nil()); ASSERT_TRUE(uncommitted_lineage.GetEntries().size() <= max_lineage_size_); } @@ -479,7 +512,8 @@ TEST_F(LineageCacheTest, TestEvictionUncommittedChildren) { // Check that the last task in the chain still has all tasks in its // uncommitted lineage. const auto last_task_id = tasks.back().GetTaskSpecification().TaskId(); - auto uncommitted_lineage = lineage_cache_.GetUncommittedLineage(last_task_id); + auto uncommitted_lineage = + lineage_cache_.GetUncommittedLineage(last_task_id, ClientID::nil()); ASSERT_EQ(uncommitted_lineage.GetEntries().size(), lineage_size); // Simulate executing the last task on a remote node and adding it to the diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 2be170a46..b2c6f91e4 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -1123,8 +1123,7 @@ void NodeManager::ResubmitTask(const TaskID &task_id) { void NodeManager::HandleObjectLocal(const ObjectID &object_id) { // Notify the task dependency manager that this object is local. const auto ready_task_ids = task_dependency_manager_.HandleObjectLocal(object_id); - // Transition the tasks whose dependencies are now fulfilled to the ready - // state. + // Transition the tasks whose dependencies are now fulfilled to the ready state. if (ready_task_ids.size() > 0) { std::unordered_set ready_task_id_set(ready_task_ids.begin(), ready_task_ids.end()); @@ -1197,10 +1196,11 @@ ray::Status NodeManager::ForwardTask(const Task &task, const ClientID &node_id) const auto &spec = task.GetTaskSpecification(); auto task_id = spec.TaskId(); - // Get and serialize the task's uncommitted lineage. - auto uncommitted_lineage = lineage_cache_.GetUncommittedLineage(task_id); + // Get and serialize the task's unforwarded, uncommitted lineage. + auto uncommitted_lineage = lineage_cache_.GetUncommittedLineage(task_id, node_id); Task &lineage_cache_entry_task = uncommitted_lineage.GetEntryMutable(task_id)->TaskDataMutable(); + // Increment forward count for the forwarded task. lineage_cache_entry_task.GetTaskExecutionSpec().IncrementNumForwards(); @@ -1230,6 +1230,10 @@ ray::Status NodeManager::ForwardTask(const Task &task, const ClientID &node_id) // lineage cache since the receiving node is now responsible for writing // the task to the GCS. lineage_cache_.RemoveWaitingTask(task_id); + // Mark as forwarded so that the task and its lineage is not re-forwarded + // in the future to the receiving node. + lineage_cache_.MarkTaskAsForwarded(task_id, node_id); + // Notify the task dependency manager that we are no longer responsible // for executing this task. task_dependency_manager_.TaskCanceled(task_id);