mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
Track pending tasks with TaskManager (#6259)
* TaskStateManager to track and complete pending tasks * Convert actor transport to use task state manager * Refactor direct actor transport to use TaskStateManager * rename * Unit test * doc * IsTaskPending * Fix? * Shared ptr * HUH? * Update src/ray/core_worker/task_manager.cc Co-Authored-By: Zhijun Fu <37800433+zhijunfu@users.noreply.github.com> * Revert "HUH?" This reverts commit f80f0ba204ff4da5e0b03191fa0d5a4d9f552434. * Fix memory issue * oops
This commit is contained in:
parent
ed5154d7fe
commit
f6a0408173
14 changed files with 423 additions and 164 deletions
10
BUILD.bazel
10
BUILD.bazel
|
@ -448,6 +448,16 @@ cc_test(
|
|||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "task_manager_test",
|
||||
srcs = ["src/ray/core_worker/test/task_manager_test.cc"],
|
||||
copts = COPTS,
|
||||
deps = [
|
||||
":core_worker_lib",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "scheduling_test",
|
||||
srcs = ["src/ray/common/scheduling/scheduling_test.cc"],
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
namespace ray {
|
||||
|
||||
bool RayObject::IsException() const {
|
||||
bool RayObject::IsException(rpc::ErrorType *error_type) const {
|
||||
if (metadata_ == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
@ -13,6 +13,9 @@ bool RayObject::IsException() const {
|
|||
for (int i = 0; i < error_type_descriptor->value_count(); i++) {
|
||||
const auto error_type_number = error_type_descriptor->value(i)->number();
|
||||
if (metadata == std::to_string(error_type_number)) {
|
||||
if (error_type) {
|
||||
*error_type = rpc::ErrorType(error_type_number);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -62,7 +62,7 @@ class RayObject {
|
|||
bool HasMetadata() const { return metadata_ != nullptr; }
|
||||
|
||||
/// Whether the object represents an exception.
|
||||
bool IsException() const;
|
||||
bool IsException(rpc::ErrorType *error_type = nullptr) const;
|
||||
|
||||
/// Whether the object has been promoted to plasma (i.e., since it was too
|
||||
/// large to return directly as part of a gRPC response).
|
||||
|
|
|
@ -164,6 +164,7 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
|
|||
RAY_CHECK_OK(plasma_store_provider_->Put(obj, obj_id));
|
||||
},
|
||||
ref_counting_enabled ? reference_counter_ : nullptr, raylet_client_));
|
||||
task_manager_.reset(new TaskManager(memory_store_));
|
||||
resolver_.reset(new LocalDependencyResolver(memory_store_));
|
||||
|
||||
// Create an entry for the driver task in the task table. This task is
|
||||
|
@ -193,7 +194,8 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
|
|||
new rpc::CoreWorkerClient(addr.first, addr.second, *client_call_manager_));
|
||||
};
|
||||
direct_actor_submitter_ = std::unique_ptr<CoreWorkerDirectActorTaskSubmitter>(
|
||||
new CoreWorkerDirectActorTaskSubmitter(client_factory, memory_store_));
|
||||
new CoreWorkerDirectActorTaskSubmitter(client_factory, memory_store_,
|
||||
task_manager_));
|
||||
|
||||
direct_task_submitter_ =
|
||||
std::unique_ptr<CoreWorkerDirectTaskSubmitter>(new CoreWorkerDirectTaskSubmitter(
|
||||
|
@ -204,7 +206,8 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
|
|||
return std::shared_ptr<RayletClient>(
|
||||
new RayletClient(std::move(grpc_client)));
|
||||
},
|
||||
memory_store_, RayConfig::instance().worker_lease_timeout_milliseconds()));
|
||||
memory_store_, task_manager_,
|
||||
RayConfig::instance().worker_lease_timeout_milliseconds()));
|
||||
}
|
||||
|
||||
CoreWorker::~CoreWorker() {
|
||||
|
@ -577,6 +580,7 @@ Status CoreWorker::SubmitTask(const RayFunction &function,
|
|||
return_ids);
|
||||
TaskSpecification task_spec = builder.Build();
|
||||
if (task_options.is_direct_call) {
|
||||
task_manager_->AddPendingTask(task_spec);
|
||||
PinObjectReferences(task_spec, TaskTransportType::DIRECT);
|
||||
return direct_task_submitter_->SubmitTask(task_spec);
|
||||
} else {
|
||||
|
@ -659,6 +663,7 @@ Status CoreWorker::SubmitActorTask(const ActorID &actor_id, const RayFunction &f
|
|||
Status status;
|
||||
TaskSpecification task_spec = builder.Build();
|
||||
if (is_direct_call) {
|
||||
task_manager_->AddPendingTask(task_spec);
|
||||
PinObjectReferences(task_spec, TaskTransportType::DIRECT);
|
||||
status = direct_actor_submitter_->SubmitTask(task_spec);
|
||||
} else {
|
||||
|
|
|
@ -513,6 +513,9 @@ class CoreWorker {
|
|||
/// Fields related to task submission.
|
||||
///
|
||||
|
||||
// Tracks the currently pending tasks.
|
||||
std::shared_ptr<TaskManager> task_manager_;
|
||||
|
||||
// Interface to submit tasks directly to other actors.
|
||||
std::unique_ptr<CoreWorkerDirectActorTaskSubmitter> direct_actor_submitter_;
|
||||
|
||||
|
|
80
src/ray/core_worker/task_manager.cc
Normal file
80
src/ray/core_worker/task_manager.cc
Normal file
|
@ -0,0 +1,80 @@
|
|||
#include "ray/core_worker/task_manager.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
void TaskManager::AddPendingTask(const TaskSpecification &spec) {
|
||||
RAY_LOG(DEBUG) << "Adding pending task " << spec.TaskId();
|
||||
absl::MutexLock lock(&mu_);
|
||||
RAY_CHECK(pending_tasks_.emplace(spec.TaskId(), spec.NumReturns()).second);
|
||||
}
|
||||
|
||||
void TaskManager::CompletePendingTask(const TaskID &task_id,
|
||||
const rpc::PushTaskReply &reply) {
|
||||
RAY_LOG(DEBUG) << "Completing task " << task_id;
|
||||
{
|
||||
absl::MutexLock lock(&mu_);
|
||||
auto it = pending_tasks_.find(task_id);
|
||||
RAY_CHECK(it != pending_tasks_.end())
|
||||
<< "Tried to complete task that was not pending " << task_id;
|
||||
pending_tasks_.erase(it);
|
||||
}
|
||||
|
||||
for (int i = 0; i < reply.return_objects_size(); i++) {
|
||||
const auto &return_object = reply.return_objects(i);
|
||||
ObjectID object_id = ObjectID::FromBinary(return_object.object_id());
|
||||
|
||||
if (return_object.in_plasma()) {
|
||||
// Mark it as in plasma with a dummy object.
|
||||
std::string meta =
|
||||
std::to_string(static_cast<int>(rpc::ErrorType::OBJECT_IN_PLASMA));
|
||||
auto metadata =
|
||||
const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(meta.data()));
|
||||
auto meta_buffer = std::make_shared<LocalMemoryBuffer>(metadata, meta.size());
|
||||
RAY_CHECK_OK(in_memory_store_->Put(RayObject(nullptr, meta_buffer), object_id));
|
||||
} else {
|
||||
std::shared_ptr<LocalMemoryBuffer> data_buffer;
|
||||
if (return_object.data().size() > 0) {
|
||||
data_buffer = std::make_shared<LocalMemoryBuffer>(
|
||||
const_cast<uint8_t *>(
|
||||
reinterpret_cast<const uint8_t *>(return_object.data().data())),
|
||||
return_object.data().size());
|
||||
}
|
||||
std::shared_ptr<LocalMemoryBuffer> metadata_buffer;
|
||||
if (return_object.metadata().size() > 0) {
|
||||
metadata_buffer = std::make_shared<LocalMemoryBuffer>(
|
||||
const_cast<uint8_t *>(
|
||||
reinterpret_cast<const uint8_t *>(return_object.metadata().data())),
|
||||
return_object.metadata().size());
|
||||
}
|
||||
RAY_CHECK_OK(
|
||||
in_memory_store_->Put(RayObject(data_buffer, metadata_buffer), object_id));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void TaskManager::FailPendingTask(const TaskID &task_id, rpc::ErrorType error_type) {
|
||||
RAY_LOG(DEBUG) << "Failing task " << task_id;
|
||||
int64_t num_returns;
|
||||
{
|
||||
absl::MutexLock lock(&mu_);
|
||||
auto it = pending_tasks_.find(task_id);
|
||||
RAY_CHECK(it != pending_tasks_.end())
|
||||
<< "Tried to complete task that was not pending " << task_id;
|
||||
num_returns = it->second;
|
||||
pending_tasks_.erase(it);
|
||||
}
|
||||
|
||||
RAY_LOG(DEBUG) << "Treat task as failed. task_id: " << task_id
|
||||
<< ", error_type: " << ErrorType_Name(error_type);
|
||||
for (int i = 0; i < num_returns; i++) {
|
||||
const auto object_id = ObjectID::ForTaskReturn(
|
||||
task_id, /*index=*/i + 1,
|
||||
/*transport_type=*/static_cast<int>(TaskTransportType::DIRECT));
|
||||
std::string meta = std::to_string(static_cast<int>(error_type));
|
||||
auto metadata = const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(meta.data()));
|
||||
auto meta_buffer = std::make_shared<LocalMemoryBuffer>(metadata, meta.size());
|
||||
RAY_CHECK_OK(in_memory_store_->Put(RayObject(nullptr, meta_buffer), object_id));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ray
|
74
src/ray/core_worker/task_manager.h
Normal file
74
src/ray/core_worker/task_manager.h
Normal file
|
@ -0,0 +1,74 @@
|
|||
#ifndef RAY_CORE_WORKER_TASK_MANAGER_H
|
||||
#define RAY_CORE_WORKER_TASK_MANAGER_H
|
||||
|
||||
#include "absl/base/thread_annotations.h"
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
|
||||
#include "ray/common/id.h"
|
||||
#include "ray/common/task/task.h"
|
||||
#include "ray/core_worker/store_provider/memory_store/memory_store.h"
|
||||
#include "ray/protobuf/core_worker.pb.h"
|
||||
#include "ray/protobuf/gcs.pb.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
class TaskFinisherInterface {
|
||||
public:
|
||||
virtual void CompletePendingTask(const TaskID &task_id,
|
||||
const rpc::PushTaskReply &reply) = 0;
|
||||
|
||||
virtual void FailPendingTask(const TaskID &task_id, rpc::ErrorType error_type) = 0;
|
||||
|
||||
virtual ~TaskFinisherInterface() {}
|
||||
};
|
||||
|
||||
class TaskManager : public TaskFinisherInterface {
|
||||
public:
|
||||
TaskManager(std::shared_ptr<CoreWorkerMemoryStore> in_memory_store)
|
||||
: in_memory_store_(in_memory_store) {}
|
||||
|
||||
/// Add a task that is pending execution.
|
||||
///
|
||||
/// \param[in] spec The spec of the pending task.
|
||||
/// \return Void.
|
||||
void AddPendingTask(const TaskSpecification &spec);
|
||||
|
||||
/// Return whether the task is pending.
|
||||
///
|
||||
/// \param[in] task_id ID of the task to query.
|
||||
/// \return Whether the task is pending.
|
||||
bool IsTaskPending(const TaskID &task_id) const {
|
||||
return pending_tasks_.count(task_id) > 0;
|
||||
}
|
||||
|
||||
/// Write return objects for a pending task to the memory store.
|
||||
///
|
||||
/// \param[in] task_id ID of the pending task.
|
||||
/// \param[in] reply Proto response to a direct actor or task call.
|
||||
/// \return Void.
|
||||
void CompletePendingTask(const TaskID &task_id,
|
||||
const rpc::PushTaskReply &reply) override;
|
||||
|
||||
/// Treat a pending task as failed.
|
||||
///
|
||||
/// \param[in] task_id ID of the pending task.
|
||||
/// \param[in] error_type The type of the specific error.
|
||||
/// \return Void.
|
||||
void FailPendingTask(const TaskID &task_id, rpc::ErrorType error_type) override;
|
||||
|
||||
private:
|
||||
/// Used to store task results.
|
||||
std::shared_ptr<CoreWorkerMemoryStore> in_memory_store_;
|
||||
|
||||
/// Protects below fields.
|
||||
absl::Mutex mu_;
|
||||
|
||||
/// Map from task ID to the task's number of return values. This map contains
|
||||
/// one entry per pending task that we submitted.
|
||||
absl::flat_hash_map<TaskID, int64_t> pending_tasks_ GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
} // namespace ray
|
||||
|
||||
#endif // RAY_CORE_WORKER_TASK_MANAGER_H
|
|
@ -1,3 +1,4 @@
|
|||
#include "gmock/gmock.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ray/common/task/task_spec.h"
|
||||
|
@ -9,6 +10,8 @@
|
|||
|
||||
namespace ray {
|
||||
|
||||
using ::testing::_;
|
||||
|
||||
class MockWorkerClient : public rpc::CoreWorkerClientInterface {
|
||||
public:
|
||||
ray::Status PushActorTask(
|
||||
|
@ -20,10 +23,28 @@ class MockWorkerClient : public rpc::CoreWorkerClientInterface {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
std::vector<rpc::ClientCallback<rpc::PushTaskReply>> callbacks;
|
||||
bool ReplyPushTask(Status status = Status::OK()) {
|
||||
if (callbacks.size() == 0) {
|
||||
return false;
|
||||
}
|
||||
auto callback = callbacks.front();
|
||||
callback(status, rpc::PushTaskReply());
|
||||
callbacks.pop_front();
|
||||
return true;
|
||||
}
|
||||
|
||||
std::list<rpc::ClientCallback<rpc::PushTaskReply>> callbacks;
|
||||
uint64_t counter = 0;
|
||||
};
|
||||
|
||||
class MockTaskFinisher : public TaskFinisherInterface {
|
||||
public:
|
||||
MockTaskFinisher() {}
|
||||
|
||||
MOCK_METHOD2(CompletePendingTask, void(const TaskID &, const rpc::PushTaskReply &));
|
||||
MOCK_METHOD2(FailPendingTask, void(const TaskID &task_id, rpc::ErrorType error_type));
|
||||
};
|
||||
|
||||
TaskSpecification CreateActorTaskHelper(ActorID actor_id, int64_t counter) {
|
||||
TaskSpecification task;
|
||||
task.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
|
||||
|
@ -38,11 +59,13 @@ class DirectActorTransportTest : public ::testing::Test {
|
|||
DirectActorTransportTest()
|
||||
: worker_client_(std::shared_ptr<MockWorkerClient>(new MockWorkerClient())),
|
||||
store_(std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore())),
|
||||
submitter_([&](const rpc::WorkerAddress &addr) { return worker_client_; },
|
||||
store_) {}
|
||||
task_finisher_(std::make_shared<MockTaskFinisher>()),
|
||||
submitter_([&](const rpc::WorkerAddress &addr) { return worker_client_; }, store_,
|
||||
task_finisher_) {}
|
||||
|
||||
std::shared_ptr<MockWorkerClient> worker_client_;
|
||||
std::shared_ptr<CoreWorkerMemoryStore> store_;
|
||||
std::shared_ptr<MockTaskFinisher> task_finisher_;
|
||||
CoreWorkerDirectActorTaskSubmitter submitter_;
|
||||
};
|
||||
|
||||
|
@ -60,6 +83,13 @@ TEST_F(DirectActorTransportTest, TestSubmitTask) {
|
|||
task = CreateActorTaskHelper(actor_id, 1);
|
||||
ASSERT_TRUE(submitter_.SubmitTask(task).ok());
|
||||
ASSERT_EQ(worker_client_->callbacks.size(), 2);
|
||||
|
||||
EXPECT_CALL(*task_finisher_, CompletePendingTask(TaskID::Nil(), _))
|
||||
.Times(worker_client_->callbacks.size());
|
||||
EXPECT_CALL(*task_finisher_, FailPendingTask(_, _)).Times(0);
|
||||
while (!worker_client_->callbacks.empty()) {
|
||||
ASSERT_TRUE(worker_client_->ReplyPushTask());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(DirectActorTransportTest, TestDependencies) {
|
||||
|
@ -119,6 +149,27 @@ TEST_F(DirectActorTransportTest, TestOutOfOrderDependencies) {
|
|||
ASSERT_EQ(worker_client_->callbacks.size(), 2);
|
||||
}
|
||||
|
||||
TEST_F(DirectActorTransportTest, TestActorFailure) {
|
||||
ActorID actor_id = ActorID::Of(JobID::FromInt(0), TaskID::Nil(), 0);
|
||||
gcs::ActorTableData actor_data;
|
||||
submitter_.HandleActorUpdate(actor_id, actor_data);
|
||||
ASSERT_EQ(worker_client_->callbacks.size(), 0);
|
||||
|
||||
// Create two tasks for the actor.
|
||||
auto task1 = CreateActorTaskHelper(actor_id, 0);
|
||||
auto task2 = CreateActorTaskHelper(actor_id, 1);
|
||||
ASSERT_TRUE(submitter_.SubmitTask(task1).ok());
|
||||
ASSERT_TRUE(submitter_.SubmitTask(task2).ok());
|
||||
ASSERT_EQ(worker_client_->callbacks.size(), 2);
|
||||
|
||||
// Simulate the actor dying. All submitted tasks should get failed.
|
||||
EXPECT_CALL(*task_finisher_, FailPendingTask(_, _)).Times(2);
|
||||
EXPECT_CALL(*task_finisher_, CompletePendingTask(_, _)).Times(0);
|
||||
while (!worker_client_->callbacks.empty()) {
|
||||
ASSERT_TRUE(worker_client_->ReplyPushTask(Status::IOError("")));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ray
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
|
|
|
@ -23,7 +23,32 @@ class MockWorkerClient : public rpc::CoreWorkerClientInterface {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
std::vector<rpc::ClientCallback<rpc::PushTaskReply>> callbacks;
|
||||
bool ReplyPushTask(Status status = Status::OK()) {
|
||||
if (callbacks.size() == 0) {
|
||||
return false;
|
||||
}
|
||||
auto callback = callbacks.front();
|
||||
callback(status, rpc::PushTaskReply());
|
||||
callbacks.pop_front();
|
||||
return true;
|
||||
}
|
||||
|
||||
std::list<rpc::ClientCallback<rpc::PushTaskReply>> callbacks;
|
||||
};
|
||||
|
||||
class MockTaskFinisher : public TaskFinisherInterface {
|
||||
public:
|
||||
MockTaskFinisher() {}
|
||||
|
||||
void CompletePendingTask(const TaskID &, const rpc::PushTaskReply &) override {
|
||||
num_tasks_complete++;
|
||||
}
|
||||
void FailPendingTask(const TaskID &task_id, rpc::ErrorType error_type) override {
|
||||
num_tasks_failed++;
|
||||
}
|
||||
|
||||
int num_tasks_complete = 0;
|
||||
int num_tasks_failed = 0;
|
||||
};
|
||||
|
||||
class MockRayletClient : public WorkerLeaseInterface {
|
||||
|
@ -196,8 +221,9 @@ TEST(DirectTaskTransportTest, TestSubmitOneTask) {
|
|||
auto worker_client = std::make_shared<MockWorkerClient>();
|
||||
auto store = std::make_shared<CoreWorkerMemoryStore>();
|
||||
auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; };
|
||||
auto task_finisher = std::make_shared<MockTaskFinisher>();
|
||||
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store,
|
||||
kLongTimeout);
|
||||
task_finisher, kLongTimeout);
|
||||
TaskSpecification task;
|
||||
task.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
|
||||
|
||||
|
@ -208,10 +234,14 @@ TEST(DirectTaskTransportTest, TestSubmitOneTask) {
|
|||
|
||||
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1234, ClientID::Nil()));
|
||||
ASSERT_EQ(worker_client->callbacks.size(), 1);
|
||||
ASSERT_EQ(task_finisher->num_tasks_complete, 0);
|
||||
ASSERT_EQ(task_finisher->num_tasks_failed, 0);
|
||||
|
||||
worker_client->callbacks[0](Status::OK(), rpc::PushTaskReply());
|
||||
ASSERT_TRUE(worker_client->ReplyPushTask());
|
||||
ASSERT_EQ(raylet_client->num_workers_returned, 1);
|
||||
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
|
||||
ASSERT_EQ(task_finisher->num_tasks_complete, 1);
|
||||
ASSERT_EQ(task_finisher->num_tasks_failed, 0);
|
||||
}
|
||||
|
||||
TEST(DirectTaskTransportTest, TestHandleTaskFailure) {
|
||||
|
@ -219,18 +249,21 @@ TEST(DirectTaskTransportTest, TestHandleTaskFailure) {
|
|||
auto worker_client = std::make_shared<MockWorkerClient>();
|
||||
auto store = std::make_shared<CoreWorkerMemoryStore>();
|
||||
auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; };
|
||||
auto task_finisher = std::make_shared<MockTaskFinisher>();
|
||||
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store,
|
||||
kLongTimeout);
|
||||
task_finisher, kLongTimeout);
|
||||
TaskSpecification task;
|
||||
task.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
|
||||
|
||||
ASSERT_TRUE(submitter.SubmitTask(task).ok());
|
||||
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1234, ClientID::Nil()));
|
||||
// Simulate a system failure, i.e., worker died unexpectedly.
|
||||
worker_client->callbacks[0](Status::IOError("oops"), rpc::PushTaskReply());
|
||||
ASSERT_EQ(worker_client->callbacks.size(), 1);
|
||||
ASSERT_TRUE(worker_client->ReplyPushTask(Status::IOError("oops")));
|
||||
ASSERT_EQ(worker_client->callbacks.size(), 0);
|
||||
ASSERT_EQ(raylet_client->num_workers_returned, 0);
|
||||
ASSERT_EQ(raylet_client->num_workers_disconnected, 1);
|
||||
ASSERT_EQ(task_finisher->num_tasks_complete, 0);
|
||||
ASSERT_EQ(task_finisher->num_tasks_failed, 1);
|
||||
}
|
||||
|
||||
TEST(DirectTaskTransportTest, TestConcurrentWorkerLeases) {
|
||||
|
@ -238,8 +271,9 @@ TEST(DirectTaskTransportTest, TestConcurrentWorkerLeases) {
|
|||
auto worker_client = std::make_shared<MockWorkerClient>();
|
||||
auto store = std::make_shared<CoreWorkerMemoryStore>();
|
||||
auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; };
|
||||
auto task_finisher = std::make_shared<MockTaskFinisher>();
|
||||
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store,
|
||||
kLongTimeout);
|
||||
task_finisher, kLongTimeout);
|
||||
TaskSpecification task1;
|
||||
TaskSpecification task2;
|
||||
TaskSpecification task3;
|
||||
|
@ -268,11 +302,13 @@ TEST(DirectTaskTransportTest, TestConcurrentWorkerLeases) {
|
|||
ASSERT_EQ(raylet_client->num_workers_requested, 3);
|
||||
|
||||
// All workers returned.
|
||||
for (const auto &cb : worker_client->callbacks) {
|
||||
cb(Status::OK(), rpc::PushTaskReply());
|
||||
while (!worker_client->callbacks.empty()) {
|
||||
ASSERT_TRUE(worker_client->ReplyPushTask());
|
||||
}
|
||||
ASSERT_EQ(raylet_client->num_workers_returned, 3);
|
||||
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
|
||||
ASSERT_EQ(task_finisher->num_tasks_complete, 3);
|
||||
ASSERT_EQ(task_finisher->num_tasks_failed, 0);
|
||||
}
|
||||
|
||||
TEST(DirectTaskTransportTest, TestReuseWorkerLease) {
|
||||
|
@ -280,8 +316,9 @@ TEST(DirectTaskTransportTest, TestReuseWorkerLease) {
|
|||
auto worker_client = std::make_shared<MockWorkerClient>();
|
||||
auto store = std::make_shared<CoreWorkerMemoryStore>();
|
||||
auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; };
|
||||
auto task_finisher = std::make_shared<MockTaskFinisher>();
|
||||
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store,
|
||||
kLongTimeout);
|
||||
task_finisher, kLongTimeout);
|
||||
TaskSpecification task1;
|
||||
TaskSpecification task2;
|
||||
TaskSpecification task3;
|
||||
|
@ -300,23 +337,26 @@ TEST(DirectTaskTransportTest, TestReuseWorkerLease) {
|
|||
ASSERT_EQ(raylet_client->num_workers_requested, 2);
|
||||
|
||||
// Task 1 finishes, Task 2 is scheduled on the same worker.
|
||||
worker_client->callbacks[0](Status::OK(), rpc::PushTaskReply());
|
||||
ASSERT_EQ(worker_client->callbacks.size(), 2);
|
||||
ASSERT_TRUE(worker_client->ReplyPushTask());
|
||||
ASSERT_EQ(worker_client->callbacks.size(), 1);
|
||||
ASSERT_EQ(raylet_client->num_workers_returned, 0);
|
||||
|
||||
// Task 2 finishes, Task 3 is scheduled on the same worker.
|
||||
worker_client->callbacks[1](Status::OK(), rpc::PushTaskReply());
|
||||
ASSERT_EQ(worker_client->callbacks.size(), 3);
|
||||
ASSERT_TRUE(worker_client->ReplyPushTask());
|
||||
ASSERT_EQ(worker_client->callbacks.size(), 1);
|
||||
ASSERT_EQ(raylet_client->num_workers_returned, 0);
|
||||
|
||||
// Task 3 finishes, the worker is returned.
|
||||
worker_client->callbacks[2](Status::OK(), rpc::PushTaskReply());
|
||||
ASSERT_TRUE(worker_client->ReplyPushTask());
|
||||
ASSERT_EQ(raylet_client->num_workers_returned, 1);
|
||||
|
||||
// The second lease request is returned immediately.
|
||||
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1001, ClientID::Nil()));
|
||||
ASSERT_EQ(worker_client->callbacks.size(), 0);
|
||||
ASSERT_EQ(raylet_client->num_workers_returned, 2);
|
||||
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
|
||||
ASSERT_EQ(task_finisher->num_tasks_complete, 3);
|
||||
ASSERT_EQ(task_finisher->num_tasks_failed, 0);
|
||||
}
|
||||
|
||||
TEST(DirectTaskTransportTest, TestWorkerNotReusedOnError) {
|
||||
|
@ -324,8 +364,9 @@ TEST(DirectTaskTransportTest, TestWorkerNotReusedOnError) {
|
|||
auto worker_client = std::make_shared<MockWorkerClient>();
|
||||
auto store = std::make_shared<CoreWorkerMemoryStore>();
|
||||
auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; };
|
||||
auto task_finisher = std::make_shared<MockTaskFinisher>();
|
||||
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store,
|
||||
kLongTimeout);
|
||||
task_finisher, kLongTimeout);
|
||||
TaskSpecification task1;
|
||||
TaskSpecification task2;
|
||||
task1.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
|
||||
|
@ -341,17 +382,18 @@ TEST(DirectTaskTransportTest, TestWorkerNotReusedOnError) {
|
|||
ASSERT_EQ(raylet_client->num_workers_requested, 2);
|
||||
|
||||
// Task 1 finishes with failure; the worker is returned.
|
||||
worker_client->callbacks[0](Status::IOError("worker dead"), rpc::PushTaskReply());
|
||||
ASSERT_EQ(worker_client->callbacks.size(), 1);
|
||||
ASSERT_TRUE(worker_client->ReplyPushTask(Status::IOError("worker dead")));
|
||||
ASSERT_EQ(worker_client->callbacks.size(), 0);
|
||||
ASSERT_EQ(raylet_client->num_workers_returned, 0);
|
||||
ASSERT_EQ(raylet_client->num_workers_disconnected, 1);
|
||||
|
||||
// Task 2 runs successfully on the second worker.
|
||||
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1001, ClientID::Nil()));
|
||||
ASSERT_EQ(worker_client->callbacks.size(), 2);
|
||||
worker_client->callbacks[1](Status::OK(), rpc::PushTaskReply());
|
||||
ASSERT_TRUE(worker_client->ReplyPushTask());
|
||||
ASSERT_EQ(raylet_client->num_workers_returned, 1);
|
||||
ASSERT_EQ(raylet_client->num_workers_disconnected, 1);
|
||||
ASSERT_EQ(task_finisher->num_tasks_complete, 1);
|
||||
ASSERT_EQ(task_finisher->num_tasks_failed, 1);
|
||||
}
|
||||
|
||||
TEST(DirectTaskTransportTest, TestSpillback) {
|
||||
|
@ -369,8 +411,9 @@ TEST(DirectTaskTransportTest, TestSpillback) {
|
|||
remote_lease_clients[raylet_id] = client;
|
||||
return client;
|
||||
};
|
||||
auto task_finisher = std::make_shared<MockTaskFinisher>();
|
||||
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, lease_client_factory,
|
||||
store, kLongTimeout);
|
||||
store, task_finisher, kLongTimeout);
|
||||
TaskSpecification task;
|
||||
task.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
|
||||
|
||||
|
@ -389,15 +432,15 @@ TEST(DirectTaskTransportTest, TestSpillback) {
|
|||
// Trigger retry at the remote node.
|
||||
ASSERT_TRUE(remote_lease_clients[remote_raylet_id]->GrantWorkerLease("remote", 1234,
|
||||
ClientID::Nil()));
|
||||
ASSERT_EQ(worker_client->callbacks.size(), 1);
|
||||
|
||||
// The worker is returned to the remote node, not the local one.
|
||||
worker_client->callbacks[0](Status::OK(), rpc::PushTaskReply());
|
||||
ASSERT_TRUE(worker_client->ReplyPushTask());
|
||||
ASSERT_EQ(raylet_client->num_workers_returned, 0);
|
||||
ASSERT_EQ(remote_lease_clients[remote_raylet_id]->num_workers_returned, 1);
|
||||
|
||||
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
|
||||
ASSERT_EQ(remote_lease_clients[remote_raylet_id]->num_workers_disconnected, 0);
|
||||
ASSERT_EQ(task_finisher->num_tasks_complete, 1);
|
||||
ASSERT_EQ(task_finisher->num_tasks_failed, 0);
|
||||
}
|
||||
|
||||
TEST(DirectTaskTransportTest, TestWorkerLeaseTimeout) {
|
||||
|
@ -405,7 +448,9 @@ TEST(DirectTaskTransportTest, TestWorkerLeaseTimeout) {
|
|||
auto worker_client = std::make_shared<MockWorkerClient>();
|
||||
auto store = std::make_shared<CoreWorkerMemoryStore>();
|
||||
auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; };
|
||||
auto task_finisher = std::make_shared<MockTaskFinisher>();
|
||||
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store,
|
||||
task_finisher,
|
||||
/*lease_timeout_ms=*/5);
|
||||
TaskSpecification task1;
|
||||
TaskSpecification task2;
|
||||
|
@ -421,13 +466,11 @@ TEST(DirectTaskTransportTest, TestWorkerLeaseTimeout) {
|
|||
|
||||
// Task 1 is pushed.
|
||||
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1000, ClientID::Nil()));
|
||||
ASSERT_EQ(worker_client->callbacks.size(), 1);
|
||||
ASSERT_EQ(raylet_client->num_workers_requested, 2);
|
||||
|
||||
// Task 1 finishes with failure; the worker is returned due to the error even though
|
||||
// it hasn't timed out.
|
||||
worker_client->callbacks[0](Status::IOError("worker dead"), rpc::PushTaskReply());
|
||||
ASSERT_EQ(worker_client->callbacks.size(), 1);
|
||||
ASSERT_TRUE(worker_client->ReplyPushTask(Status::IOError("worker dead")));
|
||||
ASSERT_EQ(raylet_client->num_workers_returned, 0);
|
||||
ASSERT_EQ(raylet_client->num_workers_disconnected, 1);
|
||||
|
||||
|
@ -435,16 +478,15 @@ TEST(DirectTaskTransportTest, TestWorkerLeaseTimeout) {
|
|||
// timeout.
|
||||
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1001, ClientID::Nil()));
|
||||
usleep(10 * 1000); // Sleep for 10ms, causing the lease to time out.
|
||||
ASSERT_EQ(worker_client->callbacks.size(), 2);
|
||||
worker_client->callbacks[1](Status::OK(), rpc::PushTaskReply());
|
||||
ASSERT_TRUE(worker_client->ReplyPushTask());
|
||||
ASSERT_EQ(raylet_client->num_workers_returned, 1);
|
||||
ASSERT_EQ(raylet_client->num_workers_disconnected, 1);
|
||||
|
||||
// Task 3 runs successfully on the third worker; the worker is returned even though it
|
||||
// hasn't timed out.
|
||||
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1002, ClientID::Nil()));
|
||||
ASSERT_EQ(worker_client->callbacks.size(), 3);
|
||||
worker_client->callbacks[2](Status::OK(), rpc::PushTaskReply());
|
||||
ASSERT_TRUE(worker_client->ReplyPushTask());
|
||||
ASSERT_EQ(worker_client->callbacks.size(), 0);
|
||||
ASSERT_EQ(raylet_client->num_workers_returned, 2);
|
||||
ASSERT_EQ(raylet_client->num_workers_disconnected, 1);
|
||||
}
|
||||
|
|
77
src/ray/core_worker/test/task_manager_test.cc
Normal file
77
src/ray/core_worker/test/task_manager_test.cc
Normal file
|
@ -0,0 +1,77 @@
|
|||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ray/common/task/task_spec.h"
|
||||
#include "ray/core_worker/store_provider/memory_store/memory_store.h"
|
||||
#include "ray/core_worker/task_manager.h"
|
||||
#include "ray/util/test_util.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
TaskSpecification CreateTaskHelper(uint64_t num_returns) {
|
||||
TaskSpecification task;
|
||||
task.GetMutableMessage().set_task_id(TaskID::ForFakeTask().Binary());
|
||||
task.GetMutableMessage().set_num_returns(num_returns);
|
||||
return task;
|
||||
}
|
||||
|
||||
class TaskManagerTest : public ::testing::Test {
|
||||
public:
|
||||
TaskManagerTest()
|
||||
: store_(std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore())),
|
||||
manager_(store_) {}
|
||||
|
||||
std::shared_ptr<CoreWorkerMemoryStore> store_;
|
||||
TaskManager manager_;
|
||||
};
|
||||
|
||||
TEST_F(TaskManagerTest, TestTaskSuccess) {
|
||||
auto spec = CreateTaskHelper(1);
|
||||
ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId()));
|
||||
manager_.AddPendingTask(spec);
|
||||
ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId()));
|
||||
auto return_id = spec.ReturnId(0, TaskTransportType::DIRECT);
|
||||
WorkerContext ctx(WorkerType::WORKER, JobID::FromInt(0));
|
||||
|
||||
rpc::PushTaskReply reply;
|
||||
auto return_object = reply.add_return_objects();
|
||||
return_object->set_object_id(return_id.Binary());
|
||||
auto data = GenerateRandomBuffer();
|
||||
return_object->set_data(data->Data(), data->Size());
|
||||
manager_.CompletePendingTask(spec.TaskId(), reply);
|
||||
ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId()));
|
||||
|
||||
std::vector<std::shared_ptr<RayObject>> results;
|
||||
RAY_CHECK_OK(store_->Get({return_id}, 1, -1, ctx, false, &results));
|
||||
ASSERT_EQ(results.size(), 1);
|
||||
ASSERT_FALSE(results[0]->IsException());
|
||||
ASSERT_EQ(std::memcmp(results[0]->GetData()->Data(), return_object->data().data(),
|
||||
return_object->data().size()),
|
||||
0);
|
||||
}
|
||||
|
||||
TEST_F(TaskManagerTest, TestTaskFailure) {
|
||||
auto spec = CreateTaskHelper(1);
|
||||
ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId()));
|
||||
manager_.AddPendingTask(spec);
|
||||
ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId()));
|
||||
auto return_id = spec.ReturnId(0, TaskTransportType::DIRECT);
|
||||
WorkerContext ctx(WorkerType::WORKER, JobID::FromInt(0));
|
||||
|
||||
auto error = rpc::ErrorType::WORKER_DIED;
|
||||
manager_.FailPendingTask(spec.TaskId(), error);
|
||||
ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId()));
|
||||
|
||||
std::vector<std::shared_ptr<RayObject>> results;
|
||||
RAY_CHECK_OK(store_->Get({return_id}, 1, -1, ctx, false, &results));
|
||||
ASSERT_EQ(results.size(), 1);
|
||||
rpc::ErrorType stored_error;
|
||||
ASSERT_TRUE(results[0]->IsException(&stored_error));
|
||||
ASSERT_EQ(stored_error, error);
|
||||
}
|
||||
|
||||
} // namespace ray
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
|
@ -11,57 +11,6 @@ int64_t GetRequestNumber(const std::unique_ptr<rpc::PushTaskRequest> &request) {
|
|||
return request->task_spec().actor_task_spec().actor_counter();
|
||||
}
|
||||
|
||||
void TreatTaskAsFailed(const TaskID &task_id, int num_returns,
|
||||
const rpc::ErrorType &error_type,
|
||||
std::shared_ptr<CoreWorkerMemoryStore> &in_memory_store) {
|
||||
RAY_LOG(DEBUG) << "Treat task as failed. task_id: " << task_id
|
||||
<< ", error_type: " << ErrorType_Name(error_type);
|
||||
for (int i = 0; i < num_returns; i++) {
|
||||
const auto object_id = ObjectID::ForTaskReturn(
|
||||
task_id, /*index=*/i + 1,
|
||||
/*transport_type=*/static_cast<int>(TaskTransportType::DIRECT));
|
||||
std::string meta = std::to_string(static_cast<int>(error_type));
|
||||
auto metadata = const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(meta.data()));
|
||||
auto meta_buffer = std::make_shared<LocalMemoryBuffer>(metadata, meta.size());
|
||||
RAY_CHECK_OK(in_memory_store->Put(RayObject(nullptr, meta_buffer), object_id));
|
||||
}
|
||||
}
|
||||
|
||||
void WriteObjectsToMemoryStore(const rpc::PushTaskReply &reply,
|
||||
std::shared_ptr<CoreWorkerMemoryStore> &in_memory_store) {
|
||||
for (int i = 0; i < reply.return_objects_size(); i++) {
|
||||
const auto &return_object = reply.return_objects(i);
|
||||
ObjectID object_id = ObjectID::FromBinary(return_object.object_id());
|
||||
|
||||
if (return_object.in_plasma()) {
|
||||
// Mark it as in plasma with a dummy object.
|
||||
std::string meta =
|
||||
std::to_string(static_cast<int>(rpc::ErrorType::OBJECT_IN_PLASMA));
|
||||
auto metadata =
|
||||
const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(meta.data()));
|
||||
auto meta_buffer = std::make_shared<LocalMemoryBuffer>(metadata, meta.size());
|
||||
RAY_CHECK_OK(in_memory_store->Put(RayObject(nullptr, meta_buffer), object_id));
|
||||
} else {
|
||||
std::shared_ptr<LocalMemoryBuffer> data_buffer;
|
||||
if (return_object.data().size() > 0) {
|
||||
data_buffer = std::make_shared<LocalMemoryBuffer>(
|
||||
const_cast<uint8_t *>(
|
||||
reinterpret_cast<const uint8_t *>(return_object.data().data())),
|
||||
return_object.data().size());
|
||||
}
|
||||
std::shared_ptr<LocalMemoryBuffer> metadata_buffer;
|
||||
if (return_object.metadata().size() > 0) {
|
||||
metadata_buffer = std::make_shared<LocalMemoryBuffer>(
|
||||
const_cast<uint8_t *>(
|
||||
reinterpret_cast<const uint8_t *>(return_object.metadata().data())),
|
||||
return_object.metadata().size());
|
||||
}
|
||||
RAY_CHECK_OK(
|
||||
in_memory_store->Put(RayObject(data_buffer, metadata_buffer), object_id));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spec) {
|
||||
RAY_LOG(DEBUG) << "Submitting task " << task_spec.TaskId();
|
||||
RAY_CHECK(task_spec.IsActorTask());
|
||||
|
@ -69,7 +18,6 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spe
|
|||
resolver_.ResolveDependencies(task_spec, [this, task_spec]() mutable {
|
||||
const auto &actor_id = task_spec.ActorId();
|
||||
const auto task_id = task_spec.TaskId();
|
||||
const auto num_returns = task_spec.NumReturns();
|
||||
|
||||
auto request = std::unique_ptr<rpc::PushTaskRequest>(new rpc::PushTaskRequest);
|
||||
request->mutable_task_spec()->Swap(&task_spec.GetMutableMessage());
|
||||
|
@ -97,8 +45,7 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spe
|
|||
} else {
|
||||
// Actor is dead, treat the task as failure.
|
||||
RAY_CHECK(iter->second.state_ == ActorTableData::DEAD);
|
||||
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED,
|
||||
in_memory_store_);
|
||||
task_finisher_->FailPendingTask(task_id, rpc::ErrorType::ACTOR_DIED);
|
||||
}
|
||||
});
|
||||
|
||||
|
@ -130,19 +77,6 @@ void CoreWorkerDirectActorTaskSubmitter::HandleActorUpdate(
|
|||
// Remove rpc client if it's dead or being reconstructed.
|
||||
rpc_clients_.erase(actor_id);
|
||||
|
||||
// For tasks that have been sent and are waiting for replies, treat them
|
||||
// as failed when the destination actor is dead or reconstructing.
|
||||
auto iter = waiting_reply_tasks_.find(actor_id);
|
||||
if (iter != waiting_reply_tasks_.end()) {
|
||||
for (const auto &entry : iter->second) {
|
||||
const auto &task_id = entry.first;
|
||||
const auto num_returns = entry.second;
|
||||
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED,
|
||||
in_memory_store_);
|
||||
}
|
||||
waiting_reply_tasks_.erase(actor_id);
|
||||
}
|
||||
|
||||
// If there are pending requests, treat the pending tasks as failed.
|
||||
auto pending_it = pending_requests_.find(actor_id);
|
||||
if (pending_it != pending_requests_.end()) {
|
||||
|
@ -150,15 +84,16 @@ void CoreWorkerDirectActorTaskSubmitter::HandleActorUpdate(
|
|||
while (head != pending_it->second.end()) {
|
||||
auto request = std::move(head->second);
|
||||
head = pending_it->second.erase(head);
|
||||
|
||||
TreatTaskAsFailed(TaskID::FromBinary(request->task_spec().task_id()),
|
||||
request->task_spec().num_returns(), rpc::ErrorType::ACTOR_DIED,
|
||||
in_memory_store_);
|
||||
auto task_id = TaskID::FromBinary(request->task_spec().task_id());
|
||||
task_finisher_->FailPendingTask(task_id, rpc::ErrorType::ACTOR_DIED);
|
||||
}
|
||||
pending_requests_.erase(pending_it);
|
||||
}
|
||||
|
||||
next_sequence_number_.erase(actor_id);
|
||||
|
||||
// No need to clean up tasks that have been sent and are waiting for
|
||||
// replies. They will be treated as failed once the connection dies.
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -182,7 +117,6 @@ void CoreWorkerDirectActorTaskSubmitter::PushActorTask(
|
|||
rpc::CoreWorkerClientInterface &client, std::unique_ptr<rpc::PushTaskRequest> request,
|
||||
const ActorID &actor_id, const TaskID &task_id, int num_returns) {
|
||||
RAY_LOG(DEBUG) << "Pushing task " << task_id << " to actor " << actor_id;
|
||||
waiting_reply_tasks_[actor_id].insert(std::make_pair(task_id, num_returns));
|
||||
|
||||
auto task_number = GetRequestNumber(request);
|
||||
RAY_CHECK(next_sequence_number_[actor_id] == task_number)
|
||||
|
@ -190,24 +124,19 @@ void CoreWorkerDirectActorTaskSubmitter::PushActorTask(
|
|||
next_sequence_number_[actor_id]++;
|
||||
|
||||
auto status = client.PushActorTask(
|
||||
std::move(request), [this, actor_id, task_id, num_returns](
|
||||
Status status, const rpc::PushTaskReply &reply) {
|
||||
{
|
||||
std::unique_lock<std::mutex> guard(mutex_);
|
||||
waiting_reply_tasks_[actor_id].erase(task_id);
|
||||
}
|
||||
std::move(request),
|
||||
[this, task_id](Status status, const rpc::PushTaskReply &reply) {
|
||||
if (!status.ok()) {
|
||||
// Note that this might be the __ray_terminate__ task, so we don't log
|
||||
// loudly with ERROR here.
|
||||
RAY_LOG(INFO) << "Task failed with error: " << status;
|
||||
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED,
|
||||
in_memory_store_);
|
||||
return;
|
||||
task_finisher_->FailPendingTask(task_id, rpc::ErrorType::ACTOR_DIED);
|
||||
} else {
|
||||
task_finisher_->CompletePendingTask(task_id, reply);
|
||||
}
|
||||
WriteObjectsToMemoryStore(reply, in_memory_store_);
|
||||
});
|
||||
if (!status.ok()) {
|
||||
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED, in_memory_store_);
|
||||
task_finisher_->FailPendingTask(task_id, rpc::ErrorType::ACTOR_DIED);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#include "ray/common/ray_object.h"
|
||||
#include "ray/core_worker/context.h"
|
||||
#include "ray/core_worker/store_provider/memory_store/memory_store.h"
|
||||
#include "ray/core_worker/task_manager.h"
|
||||
#include "ray/core_worker/transport/dependency_resolver.h"
|
||||
#include "ray/gcs/redis_gcs_client.h"
|
||||
#include "ray/rpc/grpc_server.h"
|
||||
|
@ -28,25 +29,6 @@ namespace ray {
|
|||
/// The max time to wait for out-of-order tasks.
|
||||
const int kMaxReorderWaitSeconds = 30;
|
||||
|
||||
/// Treat a task as failed.
|
||||
///
|
||||
/// \param[in] task_id The ID of a task.
|
||||
/// \param[in] num_returns Number of return objects.
|
||||
/// \param[in] error_type The type of the specific error.
|
||||
/// \param[in] in_memory_store The memory store to write to.
|
||||
/// \return Void.
|
||||
void TreatTaskAsFailed(const TaskID &task_id, int num_returns,
|
||||
const rpc::ErrorType &error_type,
|
||||
std::shared_ptr<CoreWorkerMemoryStore> &in_memory_store);
|
||||
|
||||
/// Write return objects to the memory store.
|
||||
///
|
||||
/// \param[in] reply Proto response to a direct actor or task call.
|
||||
/// \param[in] in_memory_store The memory store to write to.
|
||||
/// \return Void.
|
||||
void WriteObjectsToMemoryStore(const rpc::PushTaskReply &reply,
|
||||
std::shared_ptr<CoreWorkerMemoryStore> &in_memory_store);
|
||||
|
||||
/// In direct actor call task submitter and receiver, a task is directly submitted
|
||||
/// to the actor that will execute it.
|
||||
|
||||
|
@ -66,10 +48,11 @@ struct ActorStateData {
|
|||
class CoreWorkerDirectActorTaskSubmitter {
|
||||
public:
|
||||
CoreWorkerDirectActorTaskSubmitter(rpc::ClientFactoryFn client_factory,
|
||||
std::shared_ptr<CoreWorkerMemoryStore> store)
|
||||
std::shared_ptr<CoreWorkerMemoryStore> store,
|
||||
std::shared_ptr<TaskFinisherInterface> task_finisher)
|
||||
: client_factory_(client_factory),
|
||||
in_memory_store_(store),
|
||||
resolver_(in_memory_store_) {}
|
||||
resolver_(store),
|
||||
task_finisher_(task_finisher) {}
|
||||
|
||||
/// Submit a task to an actor for execution.
|
||||
///
|
||||
|
@ -138,15 +121,12 @@ class CoreWorkerDirectActorTaskSubmitter {
|
|||
/// actor.
|
||||
std::unordered_map<ActorID, int64_t> next_sequence_number_;
|
||||
|
||||
/// Map from actor id to the tasks that are waiting for reply.
|
||||
std::unordered_map<ActorID, std::unordered_map<TaskID, int>> waiting_reply_tasks_;
|
||||
|
||||
/// The in-memory store.
|
||||
std::shared_ptr<CoreWorkerMemoryStore> in_memory_store_;
|
||||
|
||||
/// Resolve direct call object dependencies;
|
||||
LocalDependencyResolver resolver_;
|
||||
|
||||
/// Used to complete tasks.
|
||||
std::shared_ptr<TaskFinisherInterface> task_finisher_;
|
||||
|
||||
friend class CoreWorkerTest;
|
||||
};
|
||||
|
||||
|
|
|
@ -43,12 +43,16 @@ void CoreWorkerDirectTaskSubmitter::OnWorkerIdle(const rpc::WorkerAddress &addr,
|
|||
if (was_error || queued_tasks_.empty() || current_time_ms() > entry.second) {
|
||||
RAY_CHECK_OK(entry.first->ReturnWorker(addr.second, was_error));
|
||||
worker_to_lease_client_.erase(addr);
|
||||
} else {
|
||||
} else if (!queued_tasks_.empty()) {
|
||||
auto &client = *client_cache_[addr];
|
||||
PushNormalTask(addr, client, queued_tasks_.front());
|
||||
queued_tasks_.pop_front();
|
||||
}
|
||||
|
||||
// There are more tasks to run, so try to get another worker.
|
||||
if (!queued_tasks_.empty()) {
|
||||
RequestNewWorkerIfNeeded(queued_tasks_.front());
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<WorkerLeaseInterface>
|
||||
|
@ -129,23 +133,21 @@ void CoreWorkerDirectTaskSubmitter::PushNormalTask(const rpc::WorkerAddress &add
|
|||
rpc::CoreWorkerClientInterface &client,
|
||||
TaskSpecification &task_spec) {
|
||||
auto task_id = task_spec.TaskId();
|
||||
auto num_returns = task_spec.NumReturns();
|
||||
auto request = std::unique_ptr<rpc::PushTaskRequest>(new rpc::PushTaskRequest);
|
||||
request->mutable_task_spec()->Swap(&task_spec.GetMutableMessage());
|
||||
auto status = client.PushNormalTask(
|
||||
std::move(request),
|
||||
[this, task_id, num_returns, addr](Status status, const rpc::PushTaskReply &reply) {
|
||||
[this, task_id, addr](Status status, const rpc::PushTaskReply &reply) {
|
||||
OnWorkerIdle(addr, /*error=*/!status.ok());
|
||||
if (!status.ok()) {
|
||||
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::WORKER_DIED,
|
||||
in_memory_store_);
|
||||
return;
|
||||
task_finisher_->FailPendingTask(task_id, rpc::ErrorType::WORKER_DIED);
|
||||
} else {
|
||||
task_finisher_->CompletePendingTask(task_id, reply);
|
||||
}
|
||||
WriteObjectsToMemoryStore(reply, in_memory_store_);
|
||||
});
|
||||
if (!status.ok()) {
|
||||
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::WORKER_DIED,
|
||||
in_memory_store_);
|
||||
// TODO(swang): add unit test for this.
|
||||
task_finisher_->FailPendingTask(task_id, rpc::ErrorType::WORKER_DIED);
|
||||
}
|
||||
}
|
||||
}; // namespace ray
|
||||
|
|
|
@ -3,10 +3,12 @@
|
|||
|
||||
#include "absl/base/thread_annotations.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
|
||||
#include "ray/common/id.h"
|
||||
#include "ray/common/ray_object.h"
|
||||
#include "ray/core_worker/context.h"
|
||||
#include "ray/core_worker/store_provider/memory_store/memory_store.h"
|
||||
#include "ray/core_worker/task_manager.h"
|
||||
#include "ray/core_worker/transport/dependency_resolver.h"
|
||||
#include "ray/core_worker/transport/direct_actor_transport.h"
|
||||
#include "ray/raylet/raylet_client.h"
|
||||
|
@ -24,12 +26,13 @@ class CoreWorkerDirectTaskSubmitter {
|
|||
rpc::ClientFactoryFn client_factory,
|
||||
LeaseClientFactoryFn lease_client_factory,
|
||||
std::shared_ptr<CoreWorkerMemoryStore> store,
|
||||
std::shared_ptr<TaskFinisherInterface> task_finisher,
|
||||
int64_t lease_timeout_ms)
|
||||
: local_lease_client_(lease_client),
|
||||
client_factory_(client_factory),
|
||||
lease_client_factory_(lease_client_factory),
|
||||
in_memory_store_(store),
|
||||
resolver_(in_memory_store_),
|
||||
resolver_(store),
|
||||
task_finisher_(task_finisher),
|
||||
lease_timeout_ms_(lease_timeout_ms) {}
|
||||
|
||||
/// Schedule a task for direct submission to a worker.
|
||||
|
@ -81,12 +84,12 @@ class CoreWorkerDirectTaskSubmitter {
|
|||
/// Factory for producing new clients to request leases from remote nodes.
|
||||
LeaseClientFactoryFn lease_client_factory_;
|
||||
|
||||
/// The store provider.
|
||||
std::shared_ptr<CoreWorkerMemoryStore> in_memory_store_;
|
||||
|
||||
/// Resolve local and remote dependencies;
|
||||
LocalDependencyResolver resolver_;
|
||||
|
||||
/// Used to complete tasks.
|
||||
std::shared_ptr<TaskFinisherInterface> task_finisher_;
|
||||
|
||||
/// The timeout for worker leases; after this duration, workers will be returned
|
||||
/// to the raylet.
|
||||
int64_t lease_timeout_ms_;
|
||||
|
|
Loading…
Add table
Reference in a new issue