From f8ae2b2b62de9fecffbdbd7a5c7d5eb79abd96c5 Mon Sep 17 00:00:00 2001 From: Jiajun Yao Date: Mon, 13 Sep 2021 11:27:42 -0700 Subject: [PATCH] Don't pass in TaskID to TaskManager::MarkPendingTaskFailed since it can (#18532) be got from TaskSpecification --- src/mock/ray/core_worker/task_manager.h | 3 +-- src/ray/core_worker/task_manager.cc | 5 +++-- src/ray/core_worker/task_manager.h | 10 +++++----- src/ray/core_worker/test/direct_task_transport_test.cc | 3 +-- .../core_worker/transport/direct_actor_transport.cc | 6 ++---- src/ray/core_worker/transport/direct_task_transport.cc | 3 +-- 6 files changed, 13 insertions(+), 17 deletions(-) diff --git a/src/mock/ray/core_worker/task_manager.h b/src/mock/ray/core_worker/task_manager.h index 36a6b051a..effea598d 100644 --- a/src/mock/ray/core_worker/task_manager.h +++ b/src/mock/ray/core_worker/task_manager.h @@ -34,8 +34,7 @@ class MockTaskFinisherInterface : public TaskFinisherInterface { (override)); MOCK_METHOD(bool, MarkTaskCanceled, (const TaskID &task_id), (override)); MOCK_METHOD(void, MarkPendingTaskFailed, - (const TaskID &task_id, const TaskSpecification &spec, - rpc::ErrorType error_type, + (const TaskSpecification &spec, rpc::ErrorType error_type, const std::shared_ptr &creation_task_exception), (override)); MOCK_METHOD(absl::optional, GetTaskSpec, (const TaskID &task_id), diff --git a/src/ray/core_worker/task_manager.cc b/src/ray/core_worker/task_manager.cc index 1bcf77003..7e070017e 100644 --- a/src/ray/core_worker/task_manager.cc +++ b/src/ray/core_worker/task_manager.cc @@ -383,7 +383,7 @@ bool TaskManager::PendingTaskFailed( RemoveFinishedTaskReferences(spec, release_lineage, rpc::Address(), ReferenceCounter::ReferenceTableProto()); if (immediately_mark_object_fail) { - MarkPendingTaskFailed(task_id, spec, error_type, creation_task_exception); + MarkPendingTaskFailed(spec, error_type, creation_task_exception); } } @@ -492,8 +492,9 @@ bool TaskManager::MarkTaskCanceled(const TaskID &task_id) { } void TaskManager::MarkPendingTaskFailed( - const TaskID &task_id, const TaskSpecification &spec, rpc::ErrorType error_type, + const TaskSpecification &spec, rpc::ErrorType error_type, const std::shared_ptr &creation_task_exception) { + const TaskID task_id = spec.TaskId(); RAY_LOG(DEBUG) << "Treat task as failed. task_id: " << task_id << ", error_type: " << ErrorType_Name(error_type); int64_t num_returns = spec.NumReturns(); diff --git a/src/ray/core_worker/task_manager.h b/src/ray/core_worker/task_manager.h index f337d82bc..aece58ffc 100644 --- a/src/ray/core_worker/task_manager.h +++ b/src/ray/core_worker/task_manager.h @@ -45,7 +45,7 @@ class TaskFinisherInterface { virtual bool MarkTaskCanceled(const TaskID &task_id) = 0; virtual void MarkPendingTaskFailed( - const TaskID &task_id, const TaskSpecification &spec, rpc::ErrorType error_type, + const TaskSpecification &spec, rpc::ErrorType error_type, const std::shared_ptr &creation_task_exception = nullptr) = 0; virtual absl::optional GetTaskSpec(const TaskID &task_id) const = 0; @@ -148,10 +148,10 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa /// Treat a pending task as failed. The lock should not be held when calling /// this method because it may trigger callbacks in this or other classes. - void MarkPendingTaskFailed( - const TaskID &task_id, const TaskSpecification &spec, rpc::ErrorType error_type, - const std::shared_ptr &creation_task_exception = - nullptr) override LOCKS_EXCLUDED(mu_); + void MarkPendingTaskFailed(const TaskSpecification &spec, rpc::ErrorType error_type, + const std::shared_ptr + &creation_task_exception = nullptr) override + LOCKS_EXCLUDED(mu_); /// A task's dependencies were inlined in the task spec. This will decrement /// the ref count for the dependency IDs. If the dependencies contained other diff --git a/src/ray/core_worker/test/direct_task_transport_test.cc b/src/ray/core_worker/test/direct_task_transport_test.cc index b5b24a278..6f5346ed7 100644 --- a/src/ray/core_worker/test/direct_task_transport_test.cc +++ b/src/ray/core_worker/test/direct_task_transport_test.cc @@ -123,8 +123,7 @@ class MockTaskFinisher : public TaskFinisherInterface { num_contained_ids += contained_ids.size(); } - void MarkPendingTaskFailed(const TaskID &task_id, const TaskSpecification &spec, - rpc::ErrorType error_type, + void MarkPendingTaskFailed(const TaskSpecification &spec, rpc::ErrorType error_type, const std::shared_ptr &creation_task_exception = nullptr) override {} diff --git a/src/ray/core_worker/transport/direct_actor_transport.cc b/src/ray/core_worker/transport/direct_actor_transport.cc index 7be01da15..2ee73a5de 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.cc +++ b/src/ray/core_worker/transport/direct_actor_transport.cc @@ -229,8 +229,7 @@ void CoreWorkerDirectActorTaskSubmitter::DisconnectActor( << wait_for_death_info_tasks.size() << ", actor_id=" << actor_id; for (auto &net_err_task : wait_for_death_info_tasks) { RAY_UNUSED(task_finisher_->MarkPendingTaskFailed( - net_err_task.second.TaskId(), net_err_task.second, rpc::ErrorType::ACTOR_DIED, - creation_task_exception)); + net_err_task.second, rpc::ErrorType::ACTOR_DIED, creation_task_exception)); } // No need to clean up tasks that have been sent and are waiting for @@ -253,8 +252,7 @@ void CoreWorkerDirectActorTaskSubmitter::CheckTimeoutTasks() { while (deque_itr != queue.wait_for_death_info_tasks.end() && /*timeout timestamp*/ deque_itr->first < current_time_ms()) { auto task_spec = deque_itr->second; - task_finisher_->MarkPendingTaskFailed(task_spec.TaskId(), task_spec, - rpc::ErrorType::ACTOR_DIED); + task_finisher_->MarkPendingTaskFailed(task_spec, rpc::ErrorType::ACTOR_DIED); deque_itr = queue.wait_for_death_info_tasks.erase(deque_itr); } } diff --git a/src/ray/core_worker/transport/direct_task_transport.cc b/src/ray/core_worker/transport/direct_task_transport.cc index e1338f957..57a1b1e23 100644 --- a/src/ray/core_worker/transport/direct_task_transport.cc +++ b/src/ray/core_worker/transport/direct_task_transport.cc @@ -517,8 +517,7 @@ void CoreWorkerDirectTaskSubmitter::RequestNewWorkerIfNeeded( while (!task_queue.empty()) { auto &task_spec = task_queue.front(); RAY_UNUSED(task_finisher_->MarkPendingTaskFailed( - task_spec.TaskId(), task_spec, rpc::ErrorType::RUNTIME_ENV_SETUP_FAILED, - nullptr)); + task_spec, rpc::ErrorType::RUNTIME_ENV_SETUP_FAILED, nullptr)); task_queue.pop_front(); } if (scheduling_key_entry.CanDelete()) {