Track pending tasks with TaskManager (#6259)

* TaskStateManager to track and complete pending tasks

* Convert actor transport to use task state manager

* Refactor direct actor transport to use TaskStateManager

* rename

* Unit test

* doc

* IsTaskPending

* Fix?

* Shared ptr

* HUH?

* Update src/ray/core_worker/task_manager.cc

Co-Authored-By: Zhijun Fu <37800433+zhijunfu@users.noreply.github.com>

* Revert "HUH?"

This reverts commit f80f0ba204ff4da5e0b03191fa0d5a4d9f552434.

* Fix memory issue

* oops
Author: Stephanie Wang, 2019-11-25 16:37:26 -08:00 (committed by GitHub)
Parent: ed5154d7fe
Commit: f6a0408173
14 changed files with 423 additions and 164 deletions

BUILD.bazel

@@ -448,6 +448,16 @@ cc_test(
],
)
cc_test(
name = "task_manager_test",
srcs = ["src/ray/core_worker/test/task_manager_test.cc"],
copts = COPTS,
deps = [
":core_worker_lib",
"@com_google_googletest//:gtest_main",
],
)
cc_test(
name = "scheduling_test",
srcs = ["src/ray/common/scheduling/scheduling_test.cc"],

src/ray/common/ray_object.cc

@@ -2,7 +2,7 @@
namespace ray {
bool RayObject::IsException() const {
bool RayObject::IsException(rpc::ErrorType *error_type) const {
if (metadata_ == nullptr) {
return false;
}
@@ -13,6 +13,9 @@ bool RayObject::IsException() const {
for (int i = 0; i < error_type_descriptor->value_count(); i++) {
const auto error_type_number = error_type_descriptor->value(i)->number();
if (metadata == std::to_string(error_type_number)) {
if (error_type) {
*error_type = rpc::ErrorType(error_type_number);
}
return true;
}
}

src/ray/common/ray_object.h

@@ -62,7 +62,7 @@ class RayObject {
bool HasMetadata() const { return metadata_ != nullptr; }
/// Whether the object represents an exception.
bool IsException() const;
bool IsException(rpc::ErrorType *error_type = nullptr) const;
/// Whether the object has been promoted to plasma (i.e., since it was too
/// large to return directly as part of a gRPC response).
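The defaulted out-parameter keeps existing call sites source-compatible while letting new callers recover which error a stored exception object encodes. A minimal caller-side sketch (obj is a hypothetical std::shared_ptr<RayObject> fetched from a store; the pattern mirrors task_manager_test.cc below):

rpc::ErrorType error_type;
if (obj->IsException(&error_type)) {
  // The object is an error marker; error_type holds the decoded value,
  // e.g. rpc::ErrorType::WORKER_DIED or rpc::ErrorType::OBJECT_IN_PLASMA.
} else {
  // The object carries real data and can be deserialized normally.
}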

src/ray/core_worker/core_worker.cc

@@ -164,6 +164,7 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
RAY_CHECK_OK(plasma_store_provider_->Put(obj, obj_id));
},
ref_counting_enabled ? reference_counter_ : nullptr, raylet_client_));
task_manager_.reset(new TaskManager(memory_store_));
resolver_.reset(new LocalDependencyResolver(memory_store_));
// Create an entry for the driver task in the task table. This task is
@@ -193,7 +194,8 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
new rpc::CoreWorkerClient(addr.first, addr.second, *client_call_manager_));
};
direct_actor_submitter_ = std::unique_ptr<CoreWorkerDirectActorTaskSubmitter>(
new CoreWorkerDirectActorTaskSubmitter(client_factory, memory_store_));
new CoreWorkerDirectActorTaskSubmitter(client_factory, memory_store_,
task_manager_));
direct_task_submitter_ =
std::unique_ptr<CoreWorkerDirectTaskSubmitter>(new CoreWorkerDirectTaskSubmitter(
@@ -204,7 +206,8 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
return std::shared_ptr<RayletClient>(
new RayletClient(std::move(grpc_client)));
},
memory_store_, RayConfig::instance().worker_lease_timeout_milliseconds()));
memory_store_, task_manager_,
RayConfig::instance().worker_lease_timeout_milliseconds()));
}
CoreWorker::~CoreWorker() {
@@ -577,6 +580,7 @@ Status CoreWorker::SubmitTask(const RayFunction &function,
return_ids);
TaskSpecification task_spec = builder.Build();
if (task_options.is_direct_call) {
task_manager_->AddPendingTask(task_spec);
PinObjectReferences(task_spec, TaskTransportType::DIRECT);
return direct_task_submitter_->SubmitTask(task_spec);
} else {
@@ -659,6 +663,7 @@ Status CoreWorker::SubmitActorTask(const ActorID &actor_id, const RayFunction &f
Status status;
TaskSpecification task_spec = builder.Build();
if (is_direct_call) {
task_manager_->AddPendingTask(task_spec);
PinObjectReferences(task_spec, TaskTransportType::DIRECT);
status = direct_actor_submitter_->SubmitTask(task_spec);
} else {

src/ray/core_worker/core_worker.h

@@ -513,6 +513,9 @@ class CoreWorker {
/// Fields related to task submission.
///
// Tracks the currently pending tasks.
std::shared_ptr<TaskManager> task_manager_;
// Interface to submit tasks directly to other actors.
std::unique_ptr<CoreWorkerDirectActorTaskSubmitter> direct_actor_submitter_;

src/ray/core_worker/task_manager.cc

@@ -0,0 +1,80 @@
#include "ray/core_worker/task_manager.h"
namespace ray {
void TaskManager::AddPendingTask(const TaskSpecification &spec) {
RAY_LOG(DEBUG) << "Adding pending task " << spec.TaskId();
absl::MutexLock lock(&mu_);
RAY_CHECK(pending_tasks_.emplace(spec.TaskId(), spec.NumReturns()).second);
}
void TaskManager::CompletePendingTask(const TaskID &task_id,
const rpc::PushTaskReply &reply) {
RAY_LOG(DEBUG) << "Completing task " << task_id;
{
absl::MutexLock lock(&mu_);
auto it = pending_tasks_.find(task_id);
RAY_CHECK(it != pending_tasks_.end())
<< "Tried to complete task that was not pending " << task_id;
pending_tasks_.erase(it);
}
for (int i = 0; i < reply.return_objects_size(); i++) {
const auto &return_object = reply.return_objects(i);
ObjectID object_id = ObjectID::FromBinary(return_object.object_id());
if (return_object.in_plasma()) {
// Mark it as in plasma with a dummy object.
std::string meta =
std::to_string(static_cast<int>(rpc::ErrorType::OBJECT_IN_PLASMA));
auto metadata =
const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(meta.data()));
auto meta_buffer = std::make_shared<LocalMemoryBuffer>(metadata, meta.size());
RAY_CHECK_OK(in_memory_store_->Put(RayObject(nullptr, meta_buffer), object_id));
} else {
std::shared_ptr<LocalMemoryBuffer> data_buffer;
if (return_object.data().size() > 0) {
data_buffer = std::make_shared<LocalMemoryBuffer>(
const_cast<uint8_t *>(
reinterpret_cast<const uint8_t *>(return_object.data().data())),
return_object.data().size());
}
std::shared_ptr<LocalMemoryBuffer> metadata_buffer;
if (return_object.metadata().size() > 0) {
metadata_buffer = std::make_shared<LocalMemoryBuffer>(
const_cast<uint8_t *>(
reinterpret_cast<const uint8_t *>(return_object.metadata().data())),
return_object.metadata().size());
}
RAY_CHECK_OK(
in_memory_store_->Put(RayObject(data_buffer, metadata_buffer), object_id));
}
}
}
void TaskManager::FailPendingTask(const TaskID &task_id, rpc::ErrorType error_type) {
RAY_LOG(DEBUG) << "Failing task " << task_id;
int64_t num_returns;
{
absl::MutexLock lock(&mu_);
auto it = pending_tasks_.find(task_id);
RAY_CHECK(it != pending_tasks_.end())
<< "Tried to complete task that was not pending " << task_id;
num_returns = it->second;
pending_tasks_.erase(it);
}
RAY_LOG(DEBUG) << "Treat task as failed. task_id: " << task_id
<< ", error_type: " << ErrorType_Name(error_type);
for (int i = 0; i < num_returns; i++) {
const auto object_id = ObjectID::ForTaskReturn(
task_id, /*index=*/i + 1,
/*transport_type=*/static_cast<int>(TaskTransportType::DIRECT));
std::string meta = std::to_string(static_cast<int>(error_type));
auto metadata = const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(meta.data()));
auto meta_buffer = std::make_shared<LocalMemoryBuffer>(metadata, meta.size());
RAY_CHECK_OK(in_memory_store_->Put(RayObject(nullptr, meta_buffer), object_id));
}
}
} // namespace ray
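Both paths above encode status into object metadata: the error (or in-plasma) type is written as its decimal enum value, which is exactly the encoding RayObject::IsException parses back. A minimal round-trip sketch (store and object_id are hypothetical stand-ins for the members used above):

// Encode: mark a return value with an in-plasma / error marker.
std::string meta = std::to_string(static_cast<int>(rpc::ErrorType::OBJECT_IN_PLASMA));
auto metadata = const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(meta.data()));
auto meta_buffer = std::make_shared<LocalMemoryBuffer>(metadata, meta.size());
RAY_CHECK_OK(store->Put(RayObject(nullptr, meta_buffer), object_id));

// Decode: a reader detects the marker and reacts to the specific type.
rpc::ErrorType type;
if (obj->IsException(&type) && type == rpc::ErrorType::OBJECT_IN_PLASMA) {
  // Fetch the actual value from the plasma store instead.
}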

src/ray/core_worker/task_manager.h

@@ -0,0 +1,74 @@
#ifndef RAY_CORE_WORKER_TASK_MANAGER_H
#define RAY_CORE_WORKER_TASK_MANAGER_H
#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/synchronization/mutex.h"
#include "ray/common/id.h"
#include "ray/common/task/task.h"
#include "ray/core_worker/store_provider/memory_store/memory_store.h"
#include "ray/protobuf/core_worker.pb.h"
#include "ray/protobuf/gcs.pb.h"
namespace ray {
class TaskFinisherInterface {
public:
virtual void CompletePendingTask(const TaskID &task_id,
const rpc::PushTaskReply &reply) = 0;
virtual void FailPendingTask(const TaskID &task_id, rpc::ErrorType error_type) = 0;
virtual ~TaskFinisherInterface() {}
};
class TaskManager : public TaskFinisherInterface {
public:
TaskManager(std::shared_ptr<CoreWorkerMemoryStore> in_memory_store)
: in_memory_store_(in_memory_store) {}
/// Add a task that is pending execution.
///
/// \param[in] spec The spec of the pending task.
/// \return Void.
void AddPendingTask(const TaskSpecification &spec);
/// Return whether the task is pending.
///
/// \param[in] task_id ID of the task to query.
/// \return Whether the task is pending.
bool IsTaskPending(const TaskID &task_id) const {
absl::MutexLock lock(&mu_);  // pending_tasks_ is guarded by mu_.
return pending_tasks_.count(task_id) > 0;
}
/// Write return objects for a pending task to the memory store.
///
/// \param[in] task_id ID of the pending task.
/// \param[in] reply Proto response to a direct actor or task call.
/// \return Void.
void CompletePendingTask(const TaskID &task_id,
const rpc::PushTaskReply &reply) override;
/// Treat a pending task as failed.
///
/// \param[in] task_id ID of the pending task.
/// \param[in] error_type The type of the specific error.
/// \return Void.
void FailPendingTask(const TaskID &task_id, rpc::ErrorType error_type) override;
private:
/// Used to store task results.
std::shared_ptr<CoreWorkerMemoryStore> in_memory_store_;
/// Protects below fields. Mutable so that the const accessor above can lock it.
mutable absl::Mutex mu_;
/// Map from task ID to the task's number of return values. This map contains
/// one entry per pending task that we submitted.
absl::flat_hash_map<TaskID, int64_t> pending_tasks_ GUARDED_BY(mu_);
};
} // namespace ray
#endif // RAY_CORE_WORKER_TASK_MANAGER_H
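Putting the interface together, the intended lifecycle is: register the task before handing it to a transport, then resolve it exactly once from the transport's reply callback. A minimal usage sketch (spec and reply are stand-ins for a built TaskSpecification and a received rpc::PushTaskReply):

auto store = std::make_shared<CoreWorkerMemoryStore>();
TaskManager manager(store);

// Register before submission; adding the same task twice trips the RAY_CHECK.
manager.AddPendingTask(spec);
RAY_CHECK(manager.IsTaskPending(spec.TaskId()));

// On a successful push, write the reply's return objects to the store:
manager.CompletePendingTask(spec.TaskId(), reply);
// ...or, on failure, store one error marker per expected return value:
// manager.FailPendingTask(spec.TaskId(), rpc::ErrorType::WORKER_DIED);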

src/ray/core_worker/test/direct_actor_transport_test.cc

@@ -1,3 +1,4 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "ray/common/task/task_spec.h"
@@ -9,6 +10,8 @@
namespace ray {
using ::testing::_;
class MockWorkerClient : public rpc::CoreWorkerClientInterface {
public:
ray::Status PushActorTask(
@@ -20,10 +23,28 @@ class MockWorkerClient : public rpc::CoreWorkerClientInterface {
return Status::OK();
}
std::vector<rpc::ClientCallback<rpc::PushTaskReply>> callbacks;
bool ReplyPushTask(Status status = Status::OK()) {
if (callbacks.size() == 0) {
return false;
}
auto callback = callbacks.front();
callback(status, rpc::PushTaskReply());
callbacks.pop_front();
return true;
}
std::list<rpc::ClientCallback<rpc::PushTaskReply>> callbacks;
uint64_t counter = 0;
};
class MockTaskFinisher : public TaskFinisherInterface {
public:
MockTaskFinisher() {}
MOCK_METHOD2(CompletePendingTask, void(const TaskID &, const rpc::PushTaskReply &));
MOCK_METHOD2(FailPendingTask, void(const TaskID &task_id, rpc::ErrorType error_type));
};
TaskSpecification CreateActorTaskHelper(ActorID actor_id, int64_t counter) {
TaskSpecification task;
task.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
@@ -38,11 +59,13 @@ class DirectActorTransportTest : public ::testing::Test {
DirectActorTransportTest()
: worker_client_(std::shared_ptr<MockWorkerClient>(new MockWorkerClient())),
store_(std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore())),
submitter_([&](const rpc::WorkerAddress &addr) { return worker_client_; },
store_) {}
task_finisher_(std::make_shared<MockTaskFinisher>()),
submitter_([&](const rpc::WorkerAddress &addr) { return worker_client_; }, store_,
task_finisher_) {}
std::shared_ptr<MockWorkerClient> worker_client_;
std::shared_ptr<CoreWorkerMemoryStore> store_;
std::shared_ptr<MockTaskFinisher> task_finisher_;
CoreWorkerDirectActorTaskSubmitter submitter_;
};
@@ -60,6 +83,13 @@ TEST_F(DirectActorTransportTest, TestSubmitTask) {
task = CreateActorTaskHelper(actor_id, 1);
ASSERT_TRUE(submitter_.SubmitTask(task).ok());
ASSERT_EQ(worker_client_->callbacks.size(), 2);
EXPECT_CALL(*task_finisher_, CompletePendingTask(TaskID::Nil(), _))
.Times(worker_client_->callbacks.size());
EXPECT_CALL(*task_finisher_, FailPendingTask(_, _)).Times(0);
while (!worker_client_->callbacks.empty()) {
ASSERT_TRUE(worker_client_->ReplyPushTask());
}
}
TEST_F(DirectActorTransportTest, TestDependencies) {
@@ -119,6 +149,27 @@ TEST_F(DirectActorTransportTest, TestOutOfOrderDependencies) {
ASSERT_EQ(worker_client_->callbacks.size(), 2);
}
TEST_F(DirectActorTransportTest, TestActorFailure) {
ActorID actor_id = ActorID::Of(JobID::FromInt(0), TaskID::Nil(), 0);
gcs::ActorTableData actor_data;
submitter_.HandleActorUpdate(actor_id, actor_data);
ASSERT_EQ(worker_client_->callbacks.size(), 0);
// Create two tasks for the actor.
auto task1 = CreateActorTaskHelper(actor_id, 0);
auto task2 = CreateActorTaskHelper(actor_id, 1);
ASSERT_TRUE(submitter_.SubmitTask(task1).ok());
ASSERT_TRUE(submitter_.SubmitTask(task2).ok());
ASSERT_EQ(worker_client_->callbacks.size(), 2);
// Simulate the actor dying. All submitted tasks should get failed.
EXPECT_CALL(*task_finisher_, FailPendingTask(_, _)).Times(2);
EXPECT_CALL(*task_finisher_, CompletePendingTask(_, _)).Times(0);
while (!worker_client_->callbacks.empty()) {
ASSERT_TRUE(worker_client_->ReplyPushTask(Status::IOError("")));
}
}
} // namespace ray
int main(int argc, char **argv) {

src/ray/core_worker/test/direct_task_transport_test.cc

@@ -23,7 +23,32 @@ class MockWorkerClient : public rpc::CoreWorkerClientInterface {
return Status::OK();
}
std::vector<rpc::ClientCallback<rpc::PushTaskReply>> callbacks;
bool ReplyPushTask(Status status = Status::OK()) {
if (callbacks.size() == 0) {
return false;
}
auto callback = callbacks.front();
callback(status, rpc::PushTaskReply());
callbacks.pop_front();
return true;
}
std::list<rpc::ClientCallback<rpc::PushTaskReply>> callbacks;
};
class MockTaskFinisher : public TaskFinisherInterface {
public:
MockTaskFinisher() {}
void CompletePendingTask(const TaskID &, const rpc::PushTaskReply &) override {
num_tasks_complete++;
}
void FailPendingTask(const TaskID &task_id, rpc::ErrorType error_type) override {
num_tasks_failed++;
}
int num_tasks_complete = 0;
int num_tasks_failed = 0;
};
class MockRayletClient : public WorkerLeaseInterface {
@@ -196,8 +221,9 @@ TEST(DirectTaskTransportTest, TestSubmitOneTask) {
auto worker_client = std::make_shared<MockWorkerClient>();
auto store = std::make_shared<CoreWorkerMemoryStore>();
auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; };
auto task_finisher = std::make_shared<MockTaskFinisher>();
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store,
kLongTimeout);
task_finisher, kLongTimeout);
TaskSpecification task;
task.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
@@ -208,10 +234,14 @@
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1234, ClientID::Nil()));
ASSERT_EQ(worker_client->callbacks.size(), 1);
ASSERT_EQ(task_finisher->num_tasks_complete, 0);
ASSERT_EQ(task_finisher->num_tasks_failed, 0);
worker_client->callbacks[0](Status::OK(), rpc::PushTaskReply());
ASSERT_TRUE(worker_client->ReplyPushTask());
ASSERT_EQ(raylet_client->num_workers_returned, 1);
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
ASSERT_EQ(task_finisher->num_tasks_complete, 1);
ASSERT_EQ(task_finisher->num_tasks_failed, 0);
}
TEST(DirectTaskTransportTest, TestHandleTaskFailure) {
@@ -219,18 +249,21 @@
auto worker_client = std::make_shared<MockWorkerClient>();
auto store = std::make_shared<CoreWorkerMemoryStore>();
auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; };
auto task_finisher = std::make_shared<MockTaskFinisher>();
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store,
kLongTimeout);
task_finisher, kLongTimeout);
TaskSpecification task;
task.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
ASSERT_TRUE(submitter.SubmitTask(task).ok());
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1234, ClientID::Nil()));
// Simulate a system failure, i.e., worker died unexpectedly.
worker_client->callbacks[0](Status::IOError("oops"), rpc::PushTaskReply());
ASSERT_EQ(worker_client->callbacks.size(), 1);
ASSERT_TRUE(worker_client->ReplyPushTask(Status::IOError("oops")));
ASSERT_EQ(worker_client->callbacks.size(), 0);
ASSERT_EQ(raylet_client->num_workers_returned, 0);
ASSERT_EQ(raylet_client->num_workers_disconnected, 1);
ASSERT_EQ(task_finisher->num_tasks_complete, 0);
ASSERT_EQ(task_finisher->num_tasks_failed, 1);
}
TEST(DirectTaskTransportTest, TestConcurrentWorkerLeases) {
@@ -238,8 +271,9 @@
auto worker_client = std::make_shared<MockWorkerClient>();
auto store = std::make_shared<CoreWorkerMemoryStore>();
auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; };
auto task_finisher = std::make_shared<MockTaskFinisher>();
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store,
kLongTimeout);
task_finisher, kLongTimeout);
TaskSpecification task1;
TaskSpecification task2;
TaskSpecification task3;
@@ -268,11 +302,13 @@
ASSERT_EQ(raylet_client->num_workers_requested, 3);
// All workers returned.
for (const auto &cb : worker_client->callbacks) {
cb(Status::OK(), rpc::PushTaskReply());
while (!worker_client->callbacks.empty()) {
ASSERT_TRUE(worker_client->ReplyPushTask());
}
ASSERT_EQ(raylet_client->num_workers_returned, 3);
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
ASSERT_EQ(task_finisher->num_tasks_complete, 3);
ASSERT_EQ(task_finisher->num_tasks_failed, 0);
}
TEST(DirectTaskTransportTest, TestReuseWorkerLease) {
@@ -280,8 +316,9 @@
auto worker_client = std::make_shared<MockWorkerClient>();
auto store = std::make_shared<CoreWorkerMemoryStore>();
auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; };
auto task_finisher = std::make_shared<MockTaskFinisher>();
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store,
kLongTimeout);
task_finisher, kLongTimeout);
TaskSpecification task1;
TaskSpecification task2;
TaskSpecification task3;
@@ -300,23 +337,26 @@
ASSERT_EQ(raylet_client->num_workers_requested, 2);
// Task 1 finishes, Task 2 is scheduled on the same worker.
worker_client->callbacks[0](Status::OK(), rpc::PushTaskReply());
ASSERT_EQ(worker_client->callbacks.size(), 2);
ASSERT_TRUE(worker_client->ReplyPushTask());
ASSERT_EQ(worker_client->callbacks.size(), 1);
ASSERT_EQ(raylet_client->num_workers_returned, 0);
// Task 2 finishes, Task 3 is scheduled on the same worker.
worker_client->callbacks[1](Status::OK(), rpc::PushTaskReply());
ASSERT_EQ(worker_client->callbacks.size(), 3);
ASSERT_TRUE(worker_client->ReplyPushTask());
ASSERT_EQ(worker_client->callbacks.size(), 1);
ASSERT_EQ(raylet_client->num_workers_returned, 0);
// Task 3 finishes, the worker is returned.
worker_client->callbacks[2](Status::OK(), rpc::PushTaskReply());
ASSERT_TRUE(worker_client->ReplyPushTask());
ASSERT_EQ(raylet_client->num_workers_returned, 1);
// The second lease request is returned immediately.
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1001, ClientID::Nil()));
ASSERT_EQ(worker_client->callbacks.size(), 0);
ASSERT_EQ(raylet_client->num_workers_returned, 2);
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
ASSERT_EQ(task_finisher->num_tasks_complete, 3);
ASSERT_EQ(task_finisher->num_tasks_failed, 0);
}
TEST(DirectTaskTransportTest, TestWorkerNotReusedOnError) {
@@ -324,8 +364,9 @@
auto worker_client = std::make_shared<MockWorkerClient>();
auto store = std::make_shared<CoreWorkerMemoryStore>();
auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; };
auto task_finisher = std::make_shared<MockTaskFinisher>();
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store,
kLongTimeout);
task_finisher, kLongTimeout);
TaskSpecification task1;
TaskSpecification task2;
task1.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
@@ -341,17 +382,18 @@
ASSERT_EQ(raylet_client->num_workers_requested, 2);
// Task 1 finishes with failure; the worker is returned.
worker_client->callbacks[0](Status::IOError("worker dead"), rpc::PushTaskReply());
ASSERT_EQ(worker_client->callbacks.size(), 1);
ASSERT_TRUE(worker_client->ReplyPushTask(Status::IOError("worker dead")));
ASSERT_EQ(worker_client->callbacks.size(), 0);
ASSERT_EQ(raylet_client->num_workers_returned, 0);
ASSERT_EQ(raylet_client->num_workers_disconnected, 1);
// Task 2 runs successfully on the second worker.
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1001, ClientID::Nil()));
ASSERT_EQ(worker_client->callbacks.size(), 2);
worker_client->callbacks[1](Status::OK(), rpc::PushTaskReply());
ASSERT_TRUE(worker_client->ReplyPushTask());
ASSERT_EQ(raylet_client->num_workers_returned, 1);
ASSERT_EQ(raylet_client->num_workers_disconnected, 1);
ASSERT_EQ(task_finisher->num_tasks_complete, 1);
ASSERT_EQ(task_finisher->num_tasks_failed, 1);
}
TEST(DirectTaskTransportTest, TestSpillback) {
@@ -369,8 +411,9 @@
remote_lease_clients[raylet_id] = client;
return client;
};
auto task_finisher = std::make_shared<MockTaskFinisher>();
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, lease_client_factory,
store, kLongTimeout);
store, task_finisher, kLongTimeout);
TaskSpecification task;
task.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
@@ -389,15 +432,15 @@
// Trigger retry at the remote node.
ASSERT_TRUE(remote_lease_clients[remote_raylet_id]->GrantWorkerLease("remote", 1234,
ClientID::Nil()));
ASSERT_EQ(worker_client->callbacks.size(), 1);
// The worker is returned to the remote node, not the local one.
worker_client->callbacks[0](Status::OK(), rpc::PushTaskReply());
ASSERT_TRUE(worker_client->ReplyPushTask());
ASSERT_EQ(raylet_client->num_workers_returned, 0);
ASSERT_EQ(remote_lease_clients[remote_raylet_id]->num_workers_returned, 1);
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
ASSERT_EQ(remote_lease_clients[remote_raylet_id]->num_workers_disconnected, 0);
ASSERT_EQ(task_finisher->num_tasks_complete, 1);
ASSERT_EQ(task_finisher->num_tasks_failed, 0);
}
TEST(DirectTaskTransportTest, TestWorkerLeaseTimeout) {
@@ -405,7 +448,9 @@
auto worker_client = std::make_shared<MockWorkerClient>();
auto store = std::make_shared<CoreWorkerMemoryStore>();
auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; };
auto task_finisher = std::make_shared<MockTaskFinisher>();
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store,
task_finisher,
/*lease_timeout_ms=*/5);
TaskSpecification task1;
TaskSpecification task2;
@@ -421,13 +466,11 @@
// Task 1 is pushed.
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1000, ClientID::Nil()));
ASSERT_EQ(worker_client->callbacks.size(), 1);
ASSERT_EQ(raylet_client->num_workers_requested, 2);
// Task 1 finishes with failure; the worker is returned due to the error even though
// it hasn't timed out.
worker_client->callbacks[0](Status::IOError("worker dead"), rpc::PushTaskReply());
ASSERT_EQ(worker_client->callbacks.size(), 1);
ASSERT_TRUE(worker_client->ReplyPushTask(Status::IOError("worker dead")));
ASSERT_EQ(raylet_client->num_workers_returned, 0);
ASSERT_EQ(raylet_client->num_workers_disconnected, 1);
@@ -435,16 +478,15 @@
// timeout.
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1001, ClientID::Nil()));
usleep(10 * 1000); // Sleep for 10ms, causing the lease to time out.
ASSERT_EQ(worker_client->callbacks.size(), 2);
worker_client->callbacks[1](Status::OK(), rpc::PushTaskReply());
ASSERT_TRUE(worker_client->ReplyPushTask());
ASSERT_EQ(raylet_client->num_workers_returned, 1);
ASSERT_EQ(raylet_client->num_workers_disconnected, 1);
// Task 3 runs successfully on the third worker; the worker is returned even though it
// hasn't timed out.
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1002, ClientID::Nil()));
ASSERT_EQ(worker_client->callbacks.size(), 3);
worker_client->callbacks[2](Status::OK(), rpc::PushTaskReply());
ASSERT_TRUE(worker_client->ReplyPushTask());
ASSERT_EQ(worker_client->callbacks.size(), 0);
ASSERT_EQ(raylet_client->num_workers_returned, 2);
ASSERT_EQ(raylet_client->num_workers_disconnected, 1);
}

src/ray/core_worker/test/task_manager_test.cc

@@ -0,0 +1,77 @@
#include "gtest/gtest.h"
#include "ray/common/task/task_spec.h"
#include "ray/core_worker/store_provider/memory_store/memory_store.h"
#include "ray/core_worker/task_manager.h"
#include "ray/util/test_util.h"
namespace ray {
TaskSpecification CreateTaskHelper(uint64_t num_returns) {
TaskSpecification task;
task.GetMutableMessage().set_task_id(TaskID::ForFakeTask().Binary());
task.GetMutableMessage().set_num_returns(num_returns);
return task;
}
class TaskManagerTest : public ::testing::Test {
public:
TaskManagerTest()
: store_(std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore())),
manager_(store_) {}
std::shared_ptr<CoreWorkerMemoryStore> store_;
TaskManager manager_;
};
TEST_F(TaskManagerTest, TestTaskSuccess) {
auto spec = CreateTaskHelper(1);
ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId()));
manager_.AddPendingTask(spec);
ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId()));
auto return_id = spec.ReturnId(0, TaskTransportType::DIRECT);
WorkerContext ctx(WorkerType::WORKER, JobID::FromInt(0));
rpc::PushTaskReply reply;
auto return_object = reply.add_return_objects();
return_object->set_object_id(return_id.Binary());
auto data = GenerateRandomBuffer();
return_object->set_data(data->Data(), data->Size());
manager_.CompletePendingTask(spec.TaskId(), reply);
ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId()));
std::vector<std::shared_ptr<RayObject>> results;
RAY_CHECK_OK(store_->Get({return_id}, 1, -1, ctx, false, &results));
ASSERT_EQ(results.size(), 1);
ASSERT_FALSE(results[0]->IsException());
ASSERT_EQ(std::memcmp(results[0]->GetData()->Data(), return_object->data().data(),
return_object->data().size()),
0);
}
TEST_F(TaskManagerTest, TestTaskFailure) {
auto spec = CreateTaskHelper(1);
ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId()));
manager_.AddPendingTask(spec);
ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId()));
auto return_id = spec.ReturnId(0, TaskTransportType::DIRECT);
WorkerContext ctx(WorkerType::WORKER, JobID::FromInt(0));
auto error = rpc::ErrorType::WORKER_DIED;
manager_.FailPendingTask(spec.TaskId(), error);
ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId()));
std::vector<std::shared_ptr<RayObject>> results;
RAY_CHECK_OK(store_->Get({return_id}, 1, -1, ctx, false, &results));
ASSERT_EQ(results.size(), 1);
rpc::ErrorType stored_error;
ASSERT_TRUE(results[0]->IsException(&stored_error));
ASSERT_EQ(stored_error, error);
}
} // namespace ray
int main(int argc, char **argv) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

src/ray/core_worker/transport/direct_actor_transport.cc

@@ -11,57 +11,6 @@ int64_t GetRequestNumber(const std::unique_ptr<rpc::PushTaskRequest> &request) {
return request->task_spec().actor_task_spec().actor_counter();
}
void TreatTaskAsFailed(const TaskID &task_id, int num_returns,
const rpc::ErrorType &error_type,
std::shared_ptr<CoreWorkerMemoryStore> &in_memory_store) {
RAY_LOG(DEBUG) << "Treat task as failed. task_id: " << task_id
<< ", error_type: " << ErrorType_Name(error_type);
for (int i = 0; i < num_returns; i++) {
const auto object_id = ObjectID::ForTaskReturn(
task_id, /*index=*/i + 1,
/*transport_type=*/static_cast<int>(TaskTransportType::DIRECT));
std::string meta = std::to_string(static_cast<int>(error_type));
auto metadata = const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(meta.data()));
auto meta_buffer = std::make_shared<LocalMemoryBuffer>(metadata, meta.size());
RAY_CHECK_OK(in_memory_store->Put(RayObject(nullptr, meta_buffer), object_id));
}
}
void WriteObjectsToMemoryStore(const rpc::PushTaskReply &reply,
std::shared_ptr<CoreWorkerMemoryStore> &in_memory_store) {
for (int i = 0; i < reply.return_objects_size(); i++) {
const auto &return_object = reply.return_objects(i);
ObjectID object_id = ObjectID::FromBinary(return_object.object_id());
if (return_object.in_plasma()) {
// Mark it as in plasma with a dummy object.
std::string meta =
std::to_string(static_cast<int>(rpc::ErrorType::OBJECT_IN_PLASMA));
auto metadata =
const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(meta.data()));
auto meta_buffer = std::make_shared<LocalMemoryBuffer>(metadata, meta.size());
RAY_CHECK_OK(in_memory_store->Put(RayObject(nullptr, meta_buffer), object_id));
} else {
std::shared_ptr<LocalMemoryBuffer> data_buffer;
if (return_object.data().size() > 0) {
data_buffer = std::make_shared<LocalMemoryBuffer>(
const_cast<uint8_t *>(
reinterpret_cast<const uint8_t *>(return_object.data().data())),
return_object.data().size());
}
std::shared_ptr<LocalMemoryBuffer> metadata_buffer;
if (return_object.metadata().size() > 0) {
metadata_buffer = std::make_shared<LocalMemoryBuffer>(
const_cast<uint8_t *>(
reinterpret_cast<const uint8_t *>(return_object.metadata().data())),
return_object.metadata().size());
}
RAY_CHECK_OK(
in_memory_store->Put(RayObject(data_buffer, metadata_buffer), object_id));
}
}
}
Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spec) {
RAY_LOG(DEBUG) << "Submitting task " << task_spec.TaskId();
RAY_CHECK(task_spec.IsActorTask());
@@ -69,7 +18,6 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spe
resolver_.ResolveDependencies(task_spec, [this, task_spec]() mutable {
const auto &actor_id = task_spec.ActorId();
const auto task_id = task_spec.TaskId();
const auto num_returns = task_spec.NumReturns();
auto request = std::unique_ptr<rpc::PushTaskRequest>(new rpc::PushTaskRequest);
request->mutable_task_spec()->Swap(&task_spec.GetMutableMessage());
@@ -97,8 +45,7 @@
} else {
// Actor is dead, treat the task as failure.
RAY_CHECK(iter->second.state_ == ActorTableData::DEAD);
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED,
in_memory_store_);
task_finisher_->FailPendingTask(task_id, rpc::ErrorType::ACTOR_DIED);
}
});
@@ -130,19 +77,6 @@ void CoreWorkerDirectActorTaskSubmitter::HandleActorUpdate(
// Remove rpc client if it's dead or being reconstructed.
rpc_clients_.erase(actor_id);
// For tasks that have been sent and are waiting for replies, treat them
// as failed when the destination actor is dead or reconstructing.
auto iter = waiting_reply_tasks_.find(actor_id);
if (iter != waiting_reply_tasks_.end()) {
for (const auto &entry : iter->second) {
const auto &task_id = entry.first;
const auto num_returns = entry.second;
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED,
in_memory_store_);
}
waiting_reply_tasks_.erase(actor_id);
}
// If there are pending requests, treat the pending tasks as failed.
auto pending_it = pending_requests_.find(actor_id);
if (pending_it != pending_requests_.end()) {
@@ -150,15 +84,16 @@
while (head != pending_it->second.end()) {
auto request = std::move(head->second);
head = pending_it->second.erase(head);
TreatTaskAsFailed(TaskID::FromBinary(request->task_spec().task_id()),
request->task_spec().num_returns(), rpc::ErrorType::ACTOR_DIED,
in_memory_store_);
auto task_id = TaskID::FromBinary(request->task_spec().task_id());
task_finisher_->FailPendingTask(task_id, rpc::ErrorType::ACTOR_DIED);
}
pending_requests_.erase(pending_it);
}
next_sequence_number_.erase(actor_id);
// No need to clean up tasks that have been sent and are waiting for
// replies. They will be treated as failed once the connection dies.
}
}
@@ -182,7 +117,6 @@ void CoreWorkerDirectActorTaskSubmitter::PushActorTask(
rpc::CoreWorkerClientInterface &client, std::unique_ptr<rpc::PushTaskRequest> request,
const ActorID &actor_id, const TaskID &task_id, int num_returns) {
RAY_LOG(DEBUG) << "Pushing task " << task_id << " to actor " << actor_id;
waiting_reply_tasks_[actor_id].insert(std::make_pair(task_id, num_returns));
auto task_number = GetRequestNumber(request);
RAY_CHECK(next_sequence_number_[actor_id] == task_number)
@@ -190,24 +124,19 @@
next_sequence_number_[actor_id]++;
auto status = client.PushActorTask(
std::move(request), [this, actor_id, task_id, num_returns](
Status status, const rpc::PushTaskReply &reply) {
{
std::unique_lock<std::mutex> guard(mutex_);
waiting_reply_tasks_[actor_id].erase(task_id);
}
std::move(request),
[this, task_id](Status status, const rpc::PushTaskReply &reply) {
if (!status.ok()) {
// Note that this might be the __ray_terminate__ task, so we don't log
// loudly with ERROR here.
RAY_LOG(INFO) << "Task failed with error: " << status;
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED,
in_memory_store_);
return;
task_finisher_->FailPendingTask(task_id, rpc::ErrorType::ACTOR_DIED);
} else {
task_finisher_->CompletePendingTask(task_id, reply);
}
WriteObjectsToMemoryStore(reply, in_memory_store_);
});
if (!status.ok()) {
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED, in_memory_store_);
task_finisher_->FailPendingTask(task_id, rpc::ErrorType::ACTOR_DIED);
}
}
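After this refactor the transport no longer writes to the memory store itself; every outcome funnels through the TaskFinisherInterface. A condensed sketch of the callback shape both submitters now share (finisher stands in for the submitter's task_finisher_ member; the actor transport maps failures to ACTOR_DIED, the direct task transport to WORKER_DIED):

auto on_reply = [finisher, task_id](Status status, const rpc::PushTaskReply &reply) {
  if (status.ok()) {
    // TaskManager writes the reply's return objects to the in-memory store.
    finisher->CompletePendingTask(task_id, reply);
  } else {
    // TaskManager stores an error marker for each of the task's returns.
    finisher->FailPendingTask(task_id, rpc::ErrorType::ACTOR_DIED);
  }
};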

src/ray/core_worker/transport/direct_actor_transport.h

@@ -16,6 +16,7 @@
#include "ray/common/ray_object.h"
#include "ray/core_worker/context.h"
#include "ray/core_worker/store_provider/memory_store/memory_store.h"
#include "ray/core_worker/task_manager.h"
#include "ray/core_worker/transport/dependency_resolver.h"
#include "ray/gcs/redis_gcs_client.h"
#include "ray/rpc/grpc_server.h"
@@ -28,25 +29,6 @@ namespace ray {
/// The max time to wait for out-of-order tasks.
const int kMaxReorderWaitSeconds = 30;
/// Treat a task as failed.
///
/// \param[in] task_id The ID of a task.
/// \param[in] num_returns Number of return objects.
/// \param[in] error_type The type of the specific error.
/// \param[in] in_memory_store The memory store to write to.
/// \return Void.
void TreatTaskAsFailed(const TaskID &task_id, int num_returns,
const rpc::ErrorType &error_type,
std::shared_ptr<CoreWorkerMemoryStore> &in_memory_store);
/// Write return objects to the memory store.
///
/// \param[in] reply Proto response to a direct actor or task call.
/// \param[in] in_memory_store The memory store to write to.
/// \return Void.
void WriteObjectsToMemoryStore(const rpc::PushTaskReply &reply,
std::shared_ptr<CoreWorkerMemoryStore> &in_memory_store);
/// In direct actor call task submitter and receiver, a task is directly submitted
/// to the actor that will execute it.
@@ -66,10 +48,11 @@ struct ActorStateData {
class CoreWorkerDirectActorTaskSubmitter {
public:
CoreWorkerDirectActorTaskSubmitter(rpc::ClientFactoryFn client_factory,
std::shared_ptr<CoreWorkerMemoryStore> store)
std::shared_ptr<CoreWorkerMemoryStore> store,
std::shared_ptr<TaskFinisherInterface> task_finisher)
: client_factory_(client_factory),
in_memory_store_(store),
resolver_(in_memory_store_) {}
resolver_(store),
task_finisher_(task_finisher) {}
/// Submit a task to an actor for execution.
///
@@ -138,15 +121,12 @@ class CoreWorkerDirectActorTaskSubmitter {
/// actor.
std::unordered_map<ActorID, int64_t> next_sequence_number_;
/// Map from actor id to the tasks that are waiting for reply.
std::unordered_map<ActorID, std::unordered_map<TaskID, int>> waiting_reply_tasks_;
/// The in-memory store.
std::shared_ptr<CoreWorkerMemoryStore> in_memory_store_;
/// Resolve direct call object dependencies;
LocalDependencyResolver resolver_;
/// Used to complete tasks.
std::shared_ptr<TaskFinisherInterface> task_finisher_;
friend class CoreWorkerTest;
};

src/ray/core_worker/transport/direct_task_transport.cc

@@ -43,12 +43,16 @@ void CoreWorkerDirectTaskSubmitter::OnWorkerIdle(const rpc::WorkerAddress &addr,
if (was_error || queued_tasks_.empty() || current_time_ms() > entry.second) {
RAY_CHECK_OK(entry.first->ReturnWorker(addr.second, was_error));
worker_to_lease_client_.erase(addr);
} else {
} else if (!queued_tasks_.empty()) {
auto &client = *client_cache_[addr];
PushNormalTask(addr, client, queued_tasks_.front());
queued_tasks_.pop_front();
}
RequestNewWorkerIfNeeded(queued_tasks_.front());
// There are more tasks to run, so try to get another worker.
if (!queued_tasks_.empty()) {
RequestNewWorkerIfNeeded(queued_tasks_.front());
}
}
std::shared_ptr<WorkerLeaseInterface>
@@ -129,23 +133,21 @@ void CoreWorkerDirectTaskSubmitter::PushNormalTask(const rpc::WorkerAddress &add
rpc::CoreWorkerClientInterface &client,
TaskSpecification &task_spec) {
auto task_id = task_spec.TaskId();
auto num_returns = task_spec.NumReturns();
auto request = std::unique_ptr<rpc::PushTaskRequest>(new rpc::PushTaskRequest);
request->mutable_task_spec()->Swap(&task_spec.GetMutableMessage());
auto status = client.PushNormalTask(
std::move(request),
[this, task_id, num_returns, addr](Status status, const rpc::PushTaskReply &reply) {
[this, task_id, addr](Status status, const rpc::PushTaskReply &reply) {
OnWorkerIdle(addr, /*error=*/!status.ok());
if (!status.ok()) {
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::WORKER_DIED,
in_memory_store_);
return;
task_finisher_->FailPendingTask(task_id, rpc::ErrorType::WORKER_DIED);
} else {
task_finisher_->CompletePendingTask(task_id, reply);
}
WriteObjectsToMemoryStore(reply, in_memory_store_);
});
if (!status.ok()) {
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::WORKER_DIED,
in_memory_store_);
// TODO(swang): add unit test for this.
task_finisher_->FailPendingTask(task_id, rpc::ErrorType::WORKER_DIED);
}
}
}; // namespace ray

src/ray/core_worker/transport/direct_task_transport.h

@@ -3,10 +3,12 @@
#include "absl/base/thread_annotations.h"
#include "absl/synchronization/mutex.h"
#include "ray/common/id.h"
#include "ray/common/ray_object.h"
#include "ray/core_worker/context.h"
#include "ray/core_worker/store_provider/memory_store/memory_store.h"
#include "ray/core_worker/task_manager.h"
#include "ray/core_worker/transport/dependency_resolver.h"
#include "ray/core_worker/transport/direct_actor_transport.h"
#include "ray/raylet/raylet_client.h"
@@ -24,12 +26,13 @@ class CoreWorkerDirectTaskSubmitter {
rpc::ClientFactoryFn client_factory,
LeaseClientFactoryFn lease_client_factory,
std::shared_ptr<CoreWorkerMemoryStore> store,
std::shared_ptr<TaskFinisherInterface> task_finisher,
int64_t lease_timeout_ms)
: local_lease_client_(lease_client),
client_factory_(client_factory),
lease_client_factory_(lease_client_factory),
in_memory_store_(store),
resolver_(in_memory_store_),
resolver_(store),
task_finisher_(task_finisher),
lease_timeout_ms_(lease_timeout_ms) {}
/// Schedule a task for direct submission to a worker.
@@ -81,12 +84,12 @@ class CoreWorkerDirectTaskSubmitter {
/// Factory for producing new clients to request leases from remote nodes.
LeaseClientFactoryFn lease_client_factory_;
/// The store provider.
std::shared_ptr<CoreWorkerMemoryStore> in_memory_store_;
/// Resolve local and remote dependencies;
LocalDependencyResolver resolver_;
/// Used to complete tasks.
std::shared_ptr<TaskFinisherInterface> task_finisher_;
/// The timeout for worker leases; after this duration, workers will be returned
/// to the raylet.
int64_t lease_timeout_ms_;