diff --git a/BUILD.bazel b/BUILD.bazel index 4ae6c0492..e991687eb 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -448,6 +448,16 @@ cc_test( ], ) +cc_test( + name = "task_manager_test", + srcs = ["src/ray/core_worker/test/task_manager_test.cc"], + copts = COPTS, + deps = [ + ":core_worker_lib", + "@com_google_googletest//:gtest_main", + ], +) + cc_test( name = "scheduling_test", srcs = ["src/ray/common/scheduling/scheduling_test.cc"], diff --git a/src/ray/common/ray_object.cc b/src/ray/common/ray_object.cc index 76e73b4da..faabaaee5 100644 --- a/src/ray/common/ray_object.cc +++ b/src/ray/common/ray_object.cc @@ -2,7 +2,7 @@ namespace ray { -bool RayObject::IsException() const { +bool RayObject::IsException(rpc::ErrorType *error_type) const { if (metadata_ == nullptr) { return false; } @@ -13,6 +13,9 @@ bool RayObject::IsException() const { for (int i = 0; i < error_type_descriptor->value_count(); i++) { const auto error_type_number = error_type_descriptor->value(i)->number(); if (metadata == std::to_string(error_type_number)) { + if (error_type) { + *error_type = rpc::ErrorType(error_type_number); + } return true; } } diff --git a/src/ray/common/ray_object.h b/src/ray/common/ray_object.h index c7d9d73a6..6c82a5d6f 100644 --- a/src/ray/common/ray_object.h +++ b/src/ray/common/ray_object.h @@ -62,7 +62,7 @@ class RayObject { bool HasMetadata() const { return metadata_ != nullptr; } /// Whether the object represents an exception. - bool IsException() const; + bool IsException(rpc::ErrorType *error_type = nullptr) const; /// Whether the object has been promoted to plasma (i.e., since it was too /// large to return directly as part of a gRPC response). 
diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 173499637..d1d8b8703 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -164,6 +164,7 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, RAY_CHECK_OK(plasma_store_provider_->Put(obj, obj_id)); }, ref_counting_enabled ? reference_counter_ : nullptr, raylet_client_)); + task_manager_.reset(new TaskManager(memory_store_)); resolver_.reset(new LocalDependencyResolver(memory_store_)); // Create an entry for the driver task in the task table. This task is @@ -193,7 +194,8 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, new rpc::CoreWorkerClient(addr.first, addr.second, *client_call_manager_)); }; direct_actor_submitter_ = std::unique_ptr( - new CoreWorkerDirectActorTaskSubmitter(client_factory, memory_store_)); + new CoreWorkerDirectActorTaskSubmitter(client_factory, memory_store_, + task_manager_)); direct_task_submitter_ = std::unique_ptr(new CoreWorkerDirectTaskSubmitter( @@ -204,7 +206,8 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, return std::shared_ptr( new RayletClient(std::move(grpc_client))); }, - memory_store_, RayConfig::instance().worker_lease_timeout_milliseconds())); + memory_store_, task_manager_, + RayConfig::instance().worker_lease_timeout_milliseconds())); } CoreWorker::~CoreWorker() { @@ -577,6 +580,7 @@ Status CoreWorker::SubmitTask(const RayFunction &function, return_ids); TaskSpecification task_spec = builder.Build(); if (task_options.is_direct_call) { + task_manager_->AddPendingTask(task_spec); PinObjectReferences(task_spec, TaskTransportType::DIRECT); return direct_task_submitter_->SubmitTask(task_spec); } else { @@ -659,6 +663,7 @@ Status CoreWorker::SubmitActorTask(const ActorID &actor_id, const RayFunction &f Status status; TaskSpecification task_spec = builder.Build(); if (is_direct_call) { + 
task_manager_->AddPendingTask(task_spec); PinObjectReferences(task_spec, TaskTransportType::DIRECT); status = direct_actor_submitter_->SubmitTask(task_spec); } else { diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index ef91c8b06..2746a9f7c 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -513,6 +513,9 @@ class CoreWorker { /// Fields related to task submission. /// + // Tracks the currently pending tasks. + std::shared_ptr task_manager_; + // Interface to submit tasks directly to other actors. std::unique_ptr direct_actor_submitter_; diff --git a/src/ray/core_worker/task_manager.cc b/src/ray/core_worker/task_manager.cc new file mode 100644 index 000000000..b7df83802 --- /dev/null +++ b/src/ray/core_worker/task_manager.cc @@ -0,0 +1,84 @@ +#include "ray/core_worker/task_manager.h" + +namespace ray { + +void TaskManager::AddPendingTask(const TaskSpecification &spec) { + RAY_LOG(DEBUG) << "Adding pending task " << spec.TaskId(); + absl::MutexLock lock(&mu_); + RAY_CHECK(pending_tasks_.emplace(spec.TaskId(), spec.NumReturns()).second); +} + +void TaskManager::CompletePendingTask(const TaskID &task_id, + const rpc::PushTaskReply &reply) { + RAY_LOG(DEBUG) << "Completing task " << task_id; + // Remove the task from the pending map. The lock is scoped to this block, so it + // is no longer held when the return objects are written to the store below. + { + absl::MutexLock lock(&mu_); + auto it = pending_tasks_.find(task_id); + RAY_CHECK(it != pending_tasks_.end()) + << "Tried to complete task that was not pending " << task_id; + pending_tasks_.erase(it); + } + + for (int i = 0; i < reply.return_objects_size(); i++) { + const auto &return_object = reply.return_objects(i); + ObjectID object_id = ObjectID::FromBinary(return_object.object_id()); + + if (return_object.in_plasma()) { + // Mark it as in plasma with a dummy object. 
+ std::string meta = + std::to_string(static_cast(rpc::ErrorType::OBJECT_IN_PLASMA)); + auto metadata = + const_cast(reinterpret_cast(meta.data())); + auto meta_buffer = std::make_shared(metadata, meta.size()); + RAY_CHECK_OK(in_memory_store_->Put(RayObject(nullptr, meta_buffer), object_id)); + } else { + std::shared_ptr data_buffer; + if (return_object.data().size() > 0) { + data_buffer = std::make_shared( + const_cast( + reinterpret_cast(return_object.data().data())), + return_object.data().size()); + } + std::shared_ptr metadata_buffer; + if (return_object.metadata().size() > 0) { + metadata_buffer = std::make_shared( + const_cast( + reinterpret_cast(return_object.metadata().data())), + return_object.metadata().size()); + } + RAY_CHECK_OK( + in_memory_store_->Put(RayObject(data_buffer, metadata_buffer), object_id)); + } + } +} + +void TaskManager::FailPendingTask(const TaskID &task_id, rpc::ErrorType error_type) { + RAY_LOG(DEBUG) << "Failing task " << task_id; + int64_t num_returns = 0; + // Read the task's return count and remove it from the pending map while holding + // the lock; the error objects are written after the lock is released. + { + absl::MutexLock lock(&mu_); + auto it = pending_tasks_.find(task_id); + RAY_CHECK(it != pending_tasks_.end()) + << "Tried to fail task that was not pending " << task_id; + num_returns = it->second; + pending_tasks_.erase(it); + } + + RAY_LOG(DEBUG) << "Treat task as failed. 
task_id: " << task_id + << ", error_type: " << ErrorType_Name(error_type); + for (int i = 0; i < num_returns; i++) { + const auto object_id = ObjectID::ForTaskReturn( + task_id, /*index=*/i + 1, + /*transport_type=*/static_cast(TaskTransportType::DIRECT)); + std::string meta = std::to_string(static_cast(error_type)); + auto metadata = const_cast(reinterpret_cast(meta.data())); + auto meta_buffer = std::make_shared(metadata, meta.size()); + RAY_CHECK_OK(in_memory_store_->Put(RayObject(nullptr, meta_buffer), object_id)); + } +} + +} // namespace ray diff --git a/src/ray/core_worker/task_manager.h b/src/ray/core_worker/task_manager.h new file mode 100644 index 000000000..6309a3a48 --- /dev/null +++ b/src/ray/core_worker/task_manager.h @@ -0,0 +1,80 @@ +#ifndef RAY_CORE_WORKER_TASK_MANAGER_H +#define RAY_CORE_WORKER_TASK_MANAGER_H + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/synchronization/mutex.h" + +#include "ray/common/id.h" +#include "ray/common/task/task.h" +#include "ray/core_worker/store_provider/memory_store/memory_store.h" +#include "ray/protobuf/core_worker.pb.h" +#include "ray/protobuf/gcs.pb.h" + +namespace ray { + +/// Interface through which the task submitters report the outcome of a task. +class TaskFinisherInterface { + public: + /// Report that the task finished and pass along the reply that was received. + virtual void CompletePendingTask(const TaskID &task_id, + const rpc::PushTaskReply &reply) = 0; + + /// Report that the task failed with the given error type. + virtual void FailPendingTask(const TaskID &task_id, rpc::ErrorType error_type) = 0; + + virtual ~TaskFinisherInterface() {} +}; + +/// Tracks the tasks submitted by this worker that are still pending and writes +/// their return objects, or error markers on failure, to the in-memory store. +class TaskManager : public TaskFinisherInterface { + public: + explicit TaskManager(std::shared_ptr in_memory_store) + : in_memory_store_(in_memory_store) {} + + /// Add a task that is pending execution. + /// + /// \param[in] spec The spec of the pending task. + /// \return Void. + void AddPendingTask(const TaskSpecification &spec); + + /// Return whether the task is pending. + /// + /// \param[in] task_id ID of the task to query. + /// \return Whether the task is pending. 
+ bool IsTaskPending(const TaskID &task_id) const { + absl::MutexLock lock(&mu_); + return pending_tasks_.count(task_id) > 0; + } + + /// Write return objects for a pending task to the memory store. + /// + /// \param[in] task_id ID of the pending task. + /// \param[in] reply Proto response to a direct actor or task call. + /// \return Void. + void CompletePendingTask(const TaskID &task_id, + const rpc::PushTaskReply &reply) override; + + /// Treat a pending task as failed. + /// + /// \param[in] task_id ID of the pending task. + /// \param[in] error_type The type of the specific error. + /// \return Void. + void FailPendingTask(const TaskID &task_id, rpc::ErrorType error_type) override; + + private: + /// Used to store task results. + std::shared_ptr in_memory_store_; + + /// Protects below fields. Mutable so that const methods can acquire it. + mutable absl::Mutex mu_; + + /// Map from task ID to the task's number of return values. This map contains + /// one entry per pending task that we submitted. + absl::flat_hash_map pending_tasks_ GUARDED_BY(mu_); +}; + +} // namespace ray + +#endif // RAY_CORE_WORKER_TASK_MANAGER_H diff --git a/src/ray/core_worker/test/direct_actor_transport_test.cc b/src/ray/core_worker/test/direct_actor_transport_test.cc index 1ef7a006d..769b69523 100644 --- a/src/ray/core_worker/test/direct_actor_transport_test.cc +++ b/src/ray/core_worker/test/direct_actor_transport_test.cc @@ -1,3 +1,4 @@ +#include "gmock/gmock.h" #include "gtest/gtest.h" #include "ray/common/task/task_spec.h" @@ -9,6 +10,8 @@ namespace ray { +using ::testing::_; + class MockWorkerClient : public rpc::CoreWorkerClientInterface { public: ray::Status PushActorTask( @@ -20,10 +23,28 @@ class MockWorkerClient : public rpc::CoreWorkerClientInterface { return Status::OK(); } - std::vector> callbacks; + bool ReplyPushTask(Status status = Status::OK()) { + if (callbacks.size() == 0) { + return false; + } + auto callback = callbacks.front(); + callback(status, rpc::PushTaskReply()); + callbacks.pop_front(); + return true; + } + + std::list> callbacks; 
uint64_t counter = 0; }; +class MockTaskFinisher : public TaskFinisherInterface { + public: + MockTaskFinisher() {} + + MOCK_METHOD2(CompletePendingTask, void(const TaskID &, const rpc::PushTaskReply &)); + MOCK_METHOD2(FailPendingTask, void(const TaskID &task_id, rpc::ErrorType error_type)); +}; + TaskSpecification CreateActorTaskHelper(ActorID actor_id, int64_t counter) { TaskSpecification task; task.GetMutableMessage().set_task_id(TaskID::Nil().Binary()); @@ -38,11 +59,13 @@ class DirectActorTransportTest : public ::testing::Test { DirectActorTransportTest() : worker_client_(std::shared_ptr(new MockWorkerClient())), store_(std::shared_ptr(new CoreWorkerMemoryStore())), - submitter_([&](const rpc::WorkerAddress &addr) { return worker_client_; }, - store_) {} + task_finisher_(std::make_shared()), + submitter_([&](const rpc::WorkerAddress &addr) { return worker_client_; }, store_, + task_finisher_) {} std::shared_ptr worker_client_; std::shared_ptr store_; + std::shared_ptr task_finisher_; CoreWorkerDirectActorTaskSubmitter submitter_; }; @@ -60,6 +83,13 @@ TEST_F(DirectActorTransportTest, TestSubmitTask) { task = CreateActorTaskHelper(actor_id, 1); ASSERT_TRUE(submitter_.SubmitTask(task).ok()); ASSERT_EQ(worker_client_->callbacks.size(), 2); + + EXPECT_CALL(*task_finisher_, CompletePendingTask(TaskID::Nil(), _)) + .Times(worker_client_->callbacks.size()); + EXPECT_CALL(*task_finisher_, FailPendingTask(_, _)).Times(0); + while (!worker_client_->callbacks.empty()) { + ASSERT_TRUE(worker_client_->ReplyPushTask()); + } } TEST_F(DirectActorTransportTest, TestDependencies) { @@ -119,6 +149,27 @@ TEST_F(DirectActorTransportTest, TestOutOfOrderDependencies) { ASSERT_EQ(worker_client_->callbacks.size(), 2); } +TEST_F(DirectActorTransportTest, TestActorFailure) { + ActorID actor_id = ActorID::Of(JobID::FromInt(0), TaskID::Nil(), 0); + gcs::ActorTableData actor_data; + submitter_.HandleActorUpdate(actor_id, actor_data); + ASSERT_EQ(worker_client_->callbacks.size(), 0); + + 
// Create two tasks for the actor. + auto task1 = CreateActorTaskHelper(actor_id, 0); + auto task2 = CreateActorTaskHelper(actor_id, 1); + ASSERT_TRUE(submitter_.SubmitTask(task1).ok()); + ASSERT_TRUE(submitter_.SubmitTask(task2).ok()); + ASSERT_EQ(worker_client_->callbacks.size(), 2); + + // Simulate the actor dying. All submitted tasks should get failed. + EXPECT_CALL(*task_finisher_, FailPendingTask(_, _)).Times(2); + EXPECT_CALL(*task_finisher_, CompletePendingTask(_, _)).Times(0); + while (!worker_client_->callbacks.empty()) { + ASSERT_TRUE(worker_client_->ReplyPushTask(Status::IOError(""))); + } +} + } // namespace ray int main(int argc, char **argv) { diff --git a/src/ray/core_worker/test/direct_task_transport_test.cc b/src/ray/core_worker/test/direct_task_transport_test.cc index 80fdbf5d3..3f2a05aed 100644 --- a/src/ray/core_worker/test/direct_task_transport_test.cc +++ b/src/ray/core_worker/test/direct_task_transport_test.cc @@ -23,7 +23,32 @@ class MockWorkerClient : public rpc::CoreWorkerClientInterface { return Status::OK(); } - std::vector> callbacks; + bool ReplyPushTask(Status status = Status::OK()) { + if (callbacks.size() == 0) { + return false; + } + auto callback = callbacks.front(); + callback(status, rpc::PushTaskReply()); + callbacks.pop_front(); + return true; + } + + std::list> callbacks; +}; + +class MockTaskFinisher : public TaskFinisherInterface { + public: + MockTaskFinisher() {} + + void CompletePendingTask(const TaskID &, const rpc::PushTaskReply &) override { + num_tasks_complete++; + } + void FailPendingTask(const TaskID &task_id, rpc::ErrorType error_type) override { + num_tasks_failed++; + } + + int num_tasks_complete = 0; + int num_tasks_failed = 0; }; class MockRayletClient : public WorkerLeaseInterface { @@ -196,8 +221,9 @@ TEST(DirectTaskTransportTest, TestSubmitOneTask) { auto worker_client = std::make_shared(); auto store = std::make_shared(); auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; }; + 
auto task_finisher = std::make_shared(); CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store, - kLongTimeout); + task_finisher, kLongTimeout); TaskSpecification task; task.GetMutableMessage().set_task_id(TaskID::Nil().Binary()); @@ -208,10 +234,14 @@ TEST(DirectTaskTransportTest, TestSubmitOneTask) { ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1234, ClientID::Nil())); ASSERT_EQ(worker_client->callbacks.size(), 1); + ASSERT_EQ(task_finisher->num_tasks_complete, 0); + ASSERT_EQ(task_finisher->num_tasks_failed, 0); - worker_client->callbacks[0](Status::OK(), rpc::PushTaskReply()); + ASSERT_TRUE(worker_client->ReplyPushTask()); ASSERT_EQ(raylet_client->num_workers_returned, 1); ASSERT_EQ(raylet_client->num_workers_disconnected, 0); + ASSERT_EQ(task_finisher->num_tasks_complete, 1); + ASSERT_EQ(task_finisher->num_tasks_failed, 0); } TEST(DirectTaskTransportTest, TestHandleTaskFailure) { @@ -219,18 +249,21 @@ TEST(DirectTaskTransportTest, TestHandleTaskFailure) { auto worker_client = std::make_shared(); auto store = std::make_shared(); auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; }; + auto task_finisher = std::make_shared(); CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store, - kLongTimeout); + task_finisher, kLongTimeout); TaskSpecification task; task.GetMutableMessage().set_task_id(TaskID::Nil().Binary()); ASSERT_TRUE(submitter.SubmitTask(task).ok()); ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1234, ClientID::Nil())); // Simulate a system failure, i.e., worker died unexpectedly. 
- worker_client->callbacks[0](Status::IOError("oops"), rpc::PushTaskReply()); - ASSERT_EQ(worker_client->callbacks.size(), 1); + ASSERT_TRUE(worker_client->ReplyPushTask(Status::IOError("oops"))); + ASSERT_EQ(worker_client->callbacks.size(), 0); ASSERT_EQ(raylet_client->num_workers_returned, 0); ASSERT_EQ(raylet_client->num_workers_disconnected, 1); + ASSERT_EQ(task_finisher->num_tasks_complete, 0); + ASSERT_EQ(task_finisher->num_tasks_failed, 1); } TEST(DirectTaskTransportTest, TestConcurrentWorkerLeases) { @@ -238,8 +271,9 @@ TEST(DirectTaskTransportTest, TestConcurrentWorkerLeases) { auto worker_client = std::make_shared(); auto store = std::make_shared(); auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; }; + auto task_finisher = std::make_shared(); CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store, - kLongTimeout); + task_finisher, kLongTimeout); TaskSpecification task1; TaskSpecification task2; TaskSpecification task3; @@ -268,11 +302,13 @@ TEST(DirectTaskTransportTest, TestConcurrentWorkerLeases) { ASSERT_EQ(raylet_client->num_workers_requested, 3); // All workers returned. 
- for (const auto &cb : worker_client->callbacks) { - cb(Status::OK(), rpc::PushTaskReply()); + while (!worker_client->callbacks.empty()) { + ASSERT_TRUE(worker_client->ReplyPushTask()); } ASSERT_EQ(raylet_client->num_workers_returned, 3); ASSERT_EQ(raylet_client->num_workers_disconnected, 0); + ASSERT_EQ(task_finisher->num_tasks_complete, 3); + ASSERT_EQ(task_finisher->num_tasks_failed, 0); } TEST(DirectTaskTransportTest, TestReuseWorkerLease) { @@ -280,8 +316,9 @@ TEST(DirectTaskTransportTest, TestReuseWorkerLease) { auto worker_client = std::make_shared(); auto store = std::make_shared(); auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; }; + auto task_finisher = std::make_shared(); CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store, - kLongTimeout); + task_finisher, kLongTimeout); TaskSpecification task1; TaskSpecification task2; TaskSpecification task3; @@ -300,23 +337,26 @@ TEST(DirectTaskTransportTest, TestReuseWorkerLease) { ASSERT_EQ(raylet_client->num_workers_requested, 2); // Task 1 finishes, Task 2 is scheduled on the same worker. - worker_client->callbacks[0](Status::OK(), rpc::PushTaskReply()); - ASSERT_EQ(worker_client->callbacks.size(), 2); + ASSERT_TRUE(worker_client->ReplyPushTask()); + ASSERT_EQ(worker_client->callbacks.size(), 1); ASSERT_EQ(raylet_client->num_workers_returned, 0); // Task 2 finishes, Task 3 is scheduled on the same worker. - worker_client->callbacks[1](Status::OK(), rpc::PushTaskReply()); - ASSERT_EQ(worker_client->callbacks.size(), 3); + ASSERT_TRUE(worker_client->ReplyPushTask()); + ASSERT_EQ(worker_client->callbacks.size(), 1); ASSERT_EQ(raylet_client->num_workers_returned, 0); // Task 3 finishes, the worker is returned. - worker_client->callbacks[2](Status::OK(), rpc::PushTaskReply()); + ASSERT_TRUE(worker_client->ReplyPushTask()); ASSERT_EQ(raylet_client->num_workers_returned, 1); // The second lease request is returned immediately. 
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1001, ClientID::Nil())); + ASSERT_EQ(worker_client->callbacks.size(), 0); ASSERT_EQ(raylet_client->num_workers_returned, 2); ASSERT_EQ(raylet_client->num_workers_disconnected, 0); + ASSERT_EQ(task_finisher->num_tasks_complete, 3); + ASSERT_EQ(task_finisher->num_tasks_failed, 0); } TEST(DirectTaskTransportTest, TestWorkerNotReusedOnError) { @@ -324,8 +364,9 @@ TEST(DirectTaskTransportTest, TestWorkerNotReusedOnError) { auto worker_client = std::make_shared(); auto store = std::make_shared(); auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; }; + auto task_finisher = std::make_shared(); CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store, - kLongTimeout); + task_finisher, kLongTimeout); TaskSpecification task1; TaskSpecification task2; task1.GetMutableMessage().set_task_id(TaskID::Nil().Binary()); @@ -341,17 +382,18 @@ TEST(DirectTaskTransportTest, TestWorkerNotReusedOnError) { ASSERT_EQ(raylet_client->num_workers_requested, 2); // Task 1 finishes with failure; the worker is returned. - worker_client->callbacks[0](Status::IOError("worker dead"), rpc::PushTaskReply()); - ASSERT_EQ(worker_client->callbacks.size(), 1); + ASSERT_TRUE(worker_client->ReplyPushTask(Status::IOError("worker dead"))); + ASSERT_EQ(worker_client->callbacks.size(), 0); ASSERT_EQ(raylet_client->num_workers_returned, 0); ASSERT_EQ(raylet_client->num_workers_disconnected, 1); // Task 2 runs successfully on the second worker. 
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1001, ClientID::Nil())); - ASSERT_EQ(worker_client->callbacks.size(), 2); - worker_client->callbacks[1](Status::OK(), rpc::PushTaskReply()); + ASSERT_TRUE(worker_client->ReplyPushTask()); ASSERT_EQ(raylet_client->num_workers_returned, 1); ASSERT_EQ(raylet_client->num_workers_disconnected, 1); + ASSERT_EQ(task_finisher->num_tasks_complete, 1); + ASSERT_EQ(task_finisher->num_tasks_failed, 1); } TEST(DirectTaskTransportTest, TestSpillback) { @@ -369,8 +411,9 @@ TEST(DirectTaskTransportTest, TestSpillback) { remote_lease_clients[raylet_id] = client; return client; }; + auto task_finisher = std::make_shared(); CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, lease_client_factory, - store, kLongTimeout); + store, task_finisher, kLongTimeout); TaskSpecification task; task.GetMutableMessage().set_task_id(TaskID::Nil().Binary()); @@ -389,15 +432,15 @@ TEST(DirectTaskTransportTest, TestSpillback) { // Trigger retry at the remote node. ASSERT_TRUE(remote_lease_clients[remote_raylet_id]->GrantWorkerLease("remote", 1234, ClientID::Nil())); - ASSERT_EQ(worker_client->callbacks.size(), 1); // The worker is returned to the remote node, not the local one. 
- worker_client->callbacks[0](Status::OK(), rpc::PushTaskReply()); + ASSERT_TRUE(worker_client->ReplyPushTask()); ASSERT_EQ(raylet_client->num_workers_returned, 0); ASSERT_EQ(remote_lease_clients[remote_raylet_id]->num_workers_returned, 1); - ASSERT_EQ(raylet_client->num_workers_disconnected, 0); ASSERT_EQ(remote_lease_clients[remote_raylet_id]->num_workers_disconnected, 0); + ASSERT_EQ(task_finisher->num_tasks_complete, 1); + ASSERT_EQ(task_finisher->num_tasks_failed, 0); } TEST(DirectTaskTransportTest, TestWorkerLeaseTimeout) { @@ -405,7 +448,9 @@ TEST(DirectTaskTransportTest, TestWorkerLeaseTimeout) { auto worker_client = std::make_shared(); auto store = std::make_shared(); auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; }; + auto task_finisher = std::make_shared(); CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store, + task_finisher, /*lease_timeout_ms=*/5); TaskSpecification task1; TaskSpecification task2; @@ -421,13 +466,11 @@ TEST(DirectTaskTransportTest, TestWorkerLeaseTimeout) { // Task 1 is pushed. ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1000, ClientID::Nil())); - ASSERT_EQ(worker_client->callbacks.size(), 1); ASSERT_EQ(raylet_client->num_workers_requested, 2); // Task 1 finishes with failure; the worker is returned due to the error even though // it hasn't timed out. - worker_client->callbacks[0](Status::IOError("worker dead"), rpc::PushTaskReply()); - ASSERT_EQ(worker_client->callbacks.size(), 1); + ASSERT_TRUE(worker_client->ReplyPushTask(Status::IOError("worker dead"))); ASSERT_EQ(raylet_client->num_workers_returned, 0); ASSERT_EQ(raylet_client->num_workers_disconnected, 1); @@ -435,16 +478,15 @@ TEST(DirectTaskTransportTest, TestWorkerLeaseTimeout) { // timeout. ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1001, ClientID::Nil())); usleep(10 * 1000); // Sleep for 10ms, causing the lease to time out. 
- ASSERT_EQ(worker_client->callbacks.size(), 2); - worker_client->callbacks[1](Status::OK(), rpc::PushTaskReply()); + ASSERT_TRUE(worker_client->ReplyPushTask()); ASSERT_EQ(raylet_client->num_workers_returned, 1); ASSERT_EQ(raylet_client->num_workers_disconnected, 1); // Task 3 runs successfully on the third worker; the worker is returned even though it // hasn't timed out. ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1002, ClientID::Nil())); - ASSERT_EQ(worker_client->callbacks.size(), 3); - worker_client->callbacks[2](Status::OK(), rpc::PushTaskReply()); + ASSERT_TRUE(worker_client->ReplyPushTask()); + ASSERT_EQ(worker_client->callbacks.size(), 0); ASSERT_EQ(raylet_client->num_workers_returned, 2); ASSERT_EQ(raylet_client->num_workers_disconnected, 1); } diff --git a/src/ray/core_worker/test/task_manager_test.cc b/src/ray/core_worker/test/task_manager_test.cc new file mode 100644 index 000000000..a276980cd --- /dev/null +++ b/src/ray/core_worker/test/task_manager_test.cc @@ -0,0 +1,77 @@ +#include "gtest/gtest.h" + +#include "ray/common/task/task_spec.h" +#include "ray/core_worker/store_provider/memory_store/memory_store.h" +#include "ray/core_worker/task_manager.h" +#include "ray/util/test_util.h" + +namespace ray { + +TaskSpecification CreateTaskHelper(uint64_t num_returns) { + TaskSpecification task; + task.GetMutableMessage().set_task_id(TaskID::ForFakeTask().Binary()); + task.GetMutableMessage().set_num_returns(num_returns); + return task; +} + +class TaskManagerTest : public ::testing::Test { + public: + TaskManagerTest() + : store_(std::shared_ptr(new CoreWorkerMemoryStore())), + manager_(store_) {} + + std::shared_ptr store_; + TaskManager manager_; +}; + +TEST_F(TaskManagerTest, TestTaskSuccess) { + auto spec = CreateTaskHelper(1); + ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId())); + manager_.AddPendingTask(spec); + ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId())); + auto return_id = spec.ReturnId(0, TaskTransportType::DIRECT); + 
WorkerContext ctx(WorkerType::WORKER, JobID::FromInt(0)); + + rpc::PushTaskReply reply; + auto return_object = reply.add_return_objects(); + return_object->set_object_id(return_id.Binary()); + auto data = GenerateRandomBuffer(); + return_object->set_data(data->Data(), data->Size()); + manager_.CompletePendingTask(spec.TaskId(), reply); + ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId())); + + std::vector> results; + RAY_CHECK_OK(store_->Get({return_id}, 1, -1, ctx, false, &results)); + ASSERT_EQ(results.size(), 1); + ASSERT_FALSE(results[0]->IsException()); + ASSERT_EQ(std::memcmp(results[0]->GetData()->Data(), return_object->data().data(), + return_object->data().size()), + 0); +} + +TEST_F(TaskManagerTest, TestTaskFailure) { + auto spec = CreateTaskHelper(1); + ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId())); + manager_.AddPendingTask(spec); + ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId())); + auto return_id = spec.ReturnId(0, TaskTransportType::DIRECT); + WorkerContext ctx(WorkerType::WORKER, JobID::FromInt(0)); + + auto error = rpc::ErrorType::WORKER_DIED; + manager_.FailPendingTask(spec.TaskId(), error); + ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId())); + + std::vector> results; + RAY_CHECK_OK(store_->Get({return_id}, 1, -1, ctx, false, &results)); + ASSERT_EQ(results.size(), 1); + rpc::ErrorType stored_error; + ASSERT_TRUE(results[0]->IsException(&stored_error)); + ASSERT_EQ(stored_error, error); +} + +} // namespace ray + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/src/ray/core_worker/transport/direct_actor_transport.cc b/src/ray/core_worker/transport/direct_actor_transport.cc index 9137afd0f..ac9b7b8db 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.cc +++ b/src/ray/core_worker/transport/direct_actor_transport.cc @@ -11,57 +11,6 @@ int64_t GetRequestNumber(const std::unique_ptr &request) { return request->task_spec().actor_task_spec().actor_counter(); } 
-void TreatTaskAsFailed(const TaskID &task_id, int num_returns, - const rpc::ErrorType &error_type, - std::shared_ptr &in_memory_store) { - RAY_LOG(DEBUG) << "Treat task as failed. task_id: " << task_id - << ", error_type: " << ErrorType_Name(error_type); - for (int i = 0; i < num_returns; i++) { - const auto object_id = ObjectID::ForTaskReturn( - task_id, /*index=*/i + 1, - /*transport_type=*/static_cast(TaskTransportType::DIRECT)); - std::string meta = std::to_string(static_cast(error_type)); - auto metadata = const_cast(reinterpret_cast(meta.data())); - auto meta_buffer = std::make_shared(metadata, meta.size()); - RAY_CHECK_OK(in_memory_store->Put(RayObject(nullptr, meta_buffer), object_id)); - } -} - -void WriteObjectsToMemoryStore(const rpc::PushTaskReply &reply, - std::shared_ptr &in_memory_store) { - for (int i = 0; i < reply.return_objects_size(); i++) { - const auto &return_object = reply.return_objects(i); - ObjectID object_id = ObjectID::FromBinary(return_object.object_id()); - - if (return_object.in_plasma()) { - // Mark it as in plasma with a dummy object. 
- std::string meta = - std::to_string(static_cast(rpc::ErrorType::OBJECT_IN_PLASMA)); - auto metadata = - const_cast(reinterpret_cast(meta.data())); - auto meta_buffer = std::make_shared(metadata, meta.size()); - RAY_CHECK_OK(in_memory_store->Put(RayObject(nullptr, meta_buffer), object_id)); - } else { - std::shared_ptr data_buffer; - if (return_object.data().size() > 0) { - data_buffer = std::make_shared( - const_cast( - reinterpret_cast(return_object.data().data())), - return_object.data().size()); - } - std::shared_ptr metadata_buffer; - if (return_object.metadata().size() > 0) { - metadata_buffer = std::make_shared( - const_cast( - reinterpret_cast(return_object.metadata().data())), - return_object.metadata().size()); - } - RAY_CHECK_OK( - in_memory_store->Put(RayObject(data_buffer, metadata_buffer), object_id)); - } - } -} - Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spec) { RAY_LOG(DEBUG) << "Submitting task " << task_spec.TaskId(); RAY_CHECK(task_spec.IsActorTask()); @@ -69,7 +18,6 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spe resolver_.ResolveDependencies(task_spec, [this, task_spec]() mutable { const auto &actor_id = task_spec.ActorId(); const auto task_id = task_spec.TaskId(); - const auto num_returns = task_spec.NumReturns(); auto request = std::unique_ptr(new rpc::PushTaskRequest); request->mutable_task_spec()->Swap(&task_spec.GetMutableMessage()); @@ -97,8 +45,7 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spe } else { // Actor is dead, treat the task as failure. RAY_CHECK(iter->second.state_ == ActorTableData::DEAD); - TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED, - in_memory_store_); + task_finisher_->FailPendingTask(task_id, rpc::ErrorType::ACTOR_DIED); } }); @@ -130,19 +77,6 @@ void CoreWorkerDirectActorTaskSubmitter::HandleActorUpdate( // Remove rpc client if it's dead or being reconstructed. 
rpc_clients_.erase(actor_id); - // For tasks that have been sent and are waiting for replies, treat them - // as failed when the destination actor is dead or reconstructing. - auto iter = waiting_reply_tasks_.find(actor_id); - if (iter != waiting_reply_tasks_.end()) { - for (const auto &entry : iter->second) { - const auto &task_id = entry.first; - const auto num_returns = entry.second; - TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED, - in_memory_store_); - } - waiting_reply_tasks_.erase(actor_id); - } - // If there are pending requests, treat the pending tasks as failed. auto pending_it = pending_requests_.find(actor_id); if (pending_it != pending_requests_.end()) { @@ -150,15 +84,16 @@ void CoreWorkerDirectActorTaskSubmitter::HandleActorUpdate( while (head != pending_it->second.end()) { auto request = std::move(head->second); head = pending_it->second.erase(head); - - TreatTaskAsFailed(TaskID::FromBinary(request->task_spec().task_id()), - request->task_spec().num_returns(), rpc::ErrorType::ACTOR_DIED, - in_memory_store_); + auto task_id = TaskID::FromBinary(request->task_spec().task_id()); + task_finisher_->FailPendingTask(task_id, rpc::ErrorType::ACTOR_DIED); } pending_requests_.erase(pending_it); } next_sequence_number_.erase(actor_id); + + // No need to clean up tasks that have been sent and are waiting for + // replies. They will be treated as failed once the connection dies. 
} } @@ -182,7 +117,6 @@ void CoreWorkerDirectActorTaskSubmitter::PushActorTask( rpc::CoreWorkerClientInterface &client, std::unique_ptr request, const ActorID &actor_id, const TaskID &task_id, int num_returns) { RAY_LOG(DEBUG) << "Pushing task " << task_id << " to actor " << actor_id; - waiting_reply_tasks_[actor_id].insert(std::make_pair(task_id, num_returns)); auto task_number = GetRequestNumber(request); RAY_CHECK(next_sequence_number_[actor_id] == task_number) @@ -190,24 +124,19 @@ void CoreWorkerDirectActorTaskSubmitter::PushActorTask( next_sequence_number_[actor_id]++; auto status = client.PushActorTask( - std::move(request), [this, actor_id, task_id, num_returns]( - Status status, const rpc::PushTaskReply &reply) { - { - std::unique_lock guard(mutex_); - waiting_reply_tasks_[actor_id].erase(task_id); - } + std::move(request), + [this, task_id](Status status, const rpc::PushTaskReply &reply) { if (!status.ok()) { // Note that this might be the __ray_terminate__ task, so we don't log // loudly with ERROR here. 
RAY_LOG(INFO) << "Task failed with error: " << status; - TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED, - in_memory_store_); - return; + task_finisher_->FailPendingTask(task_id, rpc::ErrorType::ACTOR_DIED); + } else { + task_finisher_->CompletePendingTask(task_id, reply); } - WriteObjectsToMemoryStore(reply, in_memory_store_); }); if (!status.ok()) { - TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED, in_memory_store_); + task_finisher_->FailPendingTask(task_id, rpc::ErrorType::ACTOR_DIED); } } diff --git a/src/ray/core_worker/transport/direct_actor_transport.h b/src/ray/core_worker/transport/direct_actor_transport.h index 18e029b45..c93e860a9 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.h +++ b/src/ray/core_worker/transport/direct_actor_transport.h @@ -16,6 +16,7 @@ #include "ray/common/ray_object.h" #include "ray/core_worker/context.h" #include "ray/core_worker/store_provider/memory_store/memory_store.h" +#include "ray/core_worker/task_manager.h" #include "ray/core_worker/transport/dependency_resolver.h" #include "ray/gcs/redis_gcs_client.h" #include "ray/rpc/grpc_server.h" @@ -28,25 +29,6 @@ namespace ray { /// The max time to wait for out-of-order tasks. const int kMaxReorderWaitSeconds = 30; -/// Treat a task as failed. -/// -/// \param[in] task_id The ID of a task. -/// \param[in] num_returns Number of return objects. -/// \param[in] error_type The type of the specific error. -/// \param[in] in_memory_store The memory store to write to. -/// \return Void. -void TreatTaskAsFailed(const TaskID &task_id, int num_returns, - const rpc::ErrorType &error_type, - std::shared_ptr &in_memory_store); - -/// Write return objects to the memory store. -/// -/// \param[in] reply Proto response to a direct actor or task call. -/// \param[in] in_memory_store The memory store to write to. -/// \return Void. 
-void WriteObjectsToMemoryStore(const rpc::PushTaskReply &reply, - std::shared_ptr &in_memory_store); - /// In direct actor call task submitter and receiver, a task is directly submitted /// to the actor that will execute it. @@ -66,10 +48,11 @@ struct ActorStateData { class CoreWorkerDirectActorTaskSubmitter { public: CoreWorkerDirectActorTaskSubmitter(rpc::ClientFactoryFn client_factory, - std::shared_ptr store) + std::shared_ptr store, + std::shared_ptr task_finisher) : client_factory_(client_factory), - in_memory_store_(store), - resolver_(in_memory_store_) {} + resolver_(store), + task_finisher_(task_finisher) {} /// Submit a task to an actor for execution. /// @@ -138,15 +121,12 @@ class CoreWorkerDirectActorTaskSubmitter { /// actor. std::unordered_map next_sequence_number_; - /// Map from actor id to the tasks that are waiting for reply. - std::unordered_map> waiting_reply_tasks_; - - /// The in-memory store. - std::shared_ptr in_memory_store_; - /// Resolve direct call object dependencies; LocalDependencyResolver resolver_; + /// Used to complete tasks. + std::shared_ptr task_finisher_; + friend class CoreWorkerTest; }; diff --git a/src/ray/core_worker/transport/direct_task_transport.cc b/src/ray/core_worker/transport/direct_task_transport.cc index 418ee55e8..7b092c8fd 100644 --- a/src/ray/core_worker/transport/direct_task_transport.cc +++ b/src/ray/core_worker/transport/direct_task_transport.cc @@ -43,12 +43,16 @@ void CoreWorkerDirectTaskSubmitter::OnWorkerIdle(const rpc::WorkerAddress &addr, if (was_error || queued_tasks_.empty() || current_time_ms() > entry.second) { RAY_CHECK_OK(entry.first->ReturnWorker(addr.second, was_error)); worker_to_lease_client_.erase(addr); - } else { + } else if (!queued_tasks_.empty()) { auto &client = *client_cache_[addr]; PushNormalTask(addr, client, queued_tasks_.front()); queued_tasks_.pop_front(); } - RequestNewWorkerIfNeeded(queued_tasks_.front()); + + // There are more tasks to run, so try to get another worker. 
+ if (!queued_tasks_.empty()) { + RequestNewWorkerIfNeeded(queued_tasks_.front()); + } } std::shared_ptr @@ -129,23 +133,21 @@ void CoreWorkerDirectTaskSubmitter::PushNormalTask(const rpc::WorkerAddress &add rpc::CoreWorkerClientInterface &client, TaskSpecification &task_spec) { auto task_id = task_spec.TaskId(); - auto num_returns = task_spec.NumReturns(); auto request = std::unique_ptr(new rpc::PushTaskRequest); request->mutable_task_spec()->Swap(&task_spec.GetMutableMessage()); auto status = client.PushNormalTask( std::move(request), - [this, task_id, num_returns, addr](Status status, const rpc::PushTaskReply &reply) { + [this, task_id, addr](Status status, const rpc::PushTaskReply &reply) { OnWorkerIdle(addr, /*error=*/!status.ok()); if (!status.ok()) { - TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::WORKER_DIED, - in_memory_store_); - return; + task_finisher_->FailPendingTask(task_id, rpc::ErrorType::WORKER_DIED); + } else { + task_finisher_->CompletePendingTask(task_id, reply); } - WriteObjectsToMemoryStore(reply, in_memory_store_); }); if (!status.ok()) { - TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::WORKER_DIED, - in_memory_store_); + // TODO(swang): add unit test for this. 
+ task_finisher_->FailPendingTask(task_id, rpc::ErrorType::WORKER_DIED); } } }; // namespace ray diff --git a/src/ray/core_worker/transport/direct_task_transport.h b/src/ray/core_worker/transport/direct_task_transport.h index b9459ea55..39e786abe 100644 --- a/src/ray/core_worker/transport/direct_task_transport.h +++ b/src/ray/core_worker/transport/direct_task_transport.h @@ -3,10 +3,12 @@ #include "absl/base/thread_annotations.h" #include "absl/synchronization/mutex.h" + #include "ray/common/id.h" #include "ray/common/ray_object.h" #include "ray/core_worker/context.h" #include "ray/core_worker/store_provider/memory_store/memory_store.h" +#include "ray/core_worker/task_manager.h" #include "ray/core_worker/transport/dependency_resolver.h" #include "ray/core_worker/transport/direct_actor_transport.h" #include "ray/raylet/raylet_client.h" @@ -24,12 +26,13 @@ class CoreWorkerDirectTaskSubmitter { rpc::ClientFactoryFn client_factory, LeaseClientFactoryFn lease_client_factory, std::shared_ptr store, + std::shared_ptr task_finisher, int64_t lease_timeout_ms) : local_lease_client_(lease_client), client_factory_(client_factory), lease_client_factory_(lease_client_factory), - in_memory_store_(store), - resolver_(in_memory_store_), + resolver_(store), + task_finisher_(task_finisher), lease_timeout_ms_(lease_timeout_ms) {} /// Schedule a task for direct submission to a worker. @@ -81,12 +84,12 @@ class CoreWorkerDirectTaskSubmitter { /// Factory for producing new clients to request leases from remote nodes. LeaseClientFactoryFn lease_client_factory_; - /// The store provider. - std::shared_ptr in_memory_store_; - /// Resolve local and remote dependencies; LocalDependencyResolver resolver_; + /// Used to complete tasks. + std::shared_ptr task_finisher_; + /// The timeout for worker leases; after this duration, workers will be returned /// to the raylet. int64_t lease_timeout_ms_;