mirror of
https://github.com/vale981/ray
synced 2025-03-05 18:11:42 -05:00
Track pending tasks with TaskManager (#6259)
* TaskStateManager to track and complete pending tasks * Convert actor transport to use task state manager * Refactor direct actor transport to use TaskStateManager * rename * Unit test * doc * IsTaskPending * Fix? * Shared ptr * HUH? * Update src/ray/core_worker/task_manager.cc Co-Authored-By: Zhijun Fu <37800433+zhijunfu@users.noreply.github.com> * Revert "HUH?" This reverts commit f80f0ba204ff4da5e0b03191fa0d5a4d9f552434. * Fix memory issue * oops
This commit is contained in:
parent
ed5154d7fe
commit
f6a0408173
14 changed files with 423 additions and 164 deletions
10
BUILD.bazel
10
BUILD.bazel
|
@ -448,6 +448,16 @@ cc_test(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Test target that compiles src/ray/core_worker/test/task_manager_test.cc and
# links it against :core_worker_lib and googletest's gtest_main.
cc_test(
|
||||||
|
name = "task_manager_test",
|
||||||
|
srcs = ["src/ray/core_worker/test/task_manager_test.cc"],
|
||||||
|
copts = COPTS,
|
||||||
|
deps = [
|
||||||
|
":core_worker_lib",
|
||||||
|
"@com_google_googletest//:gtest_main",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_test(
|
cc_test(
|
||||||
name = "scheduling_test",
|
name = "scheduling_test",
|
||||||
srcs = ["src/ray/common/scheduling/scheduling_test.cc"],
|
srcs = ["src/ray/common/scheduling/scheduling_test.cc"],
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
|
|
||||||
namespace ray {
|
namespace ray {
|
||||||
|
|
||||||
bool RayObject::IsException() const {
|
bool RayObject::IsException(rpc::ErrorType *error_type) const {
|
||||||
if (metadata_ == nullptr) {
|
if (metadata_ == nullptr) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -13,6 +13,9 @@ bool RayObject::IsException() const {
|
||||||
for (int i = 0; i < error_type_descriptor->value_count(); i++) {
|
for (int i = 0; i < error_type_descriptor->value_count(); i++) {
|
||||||
const auto error_type_number = error_type_descriptor->value(i)->number();
|
const auto error_type_number = error_type_descriptor->value(i)->number();
|
||||||
if (metadata == std::to_string(error_type_number)) {
|
if (metadata == std::to_string(error_type_number)) {
|
||||||
|
if (error_type) {
|
||||||
|
*error_type = rpc::ErrorType(error_type_number);
|
||||||
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -62,7 +62,7 @@ class RayObject {
|
||||||
bool HasMetadata() const { return metadata_ != nullptr; }
|
bool HasMetadata() const { return metadata_ != nullptr; }
|
||||||
|
|
||||||
/// Whether the object represents an exception.
|
/// Whether the object represents an exception.
|
||||||
bool IsException() const;
|
bool IsException(rpc::ErrorType *error_type = nullptr) const;
|
||||||
|
|
||||||
/// Whether the object has been promoted to plasma (i.e., since it was too
|
/// Whether the object has been promoted to plasma (i.e., since it was too
|
||||||
/// large to return directly as part of a gRPC response).
|
/// large to return directly as part of a gRPC response).
|
||||||
|
|
|
@ -164,6 +164,7 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
|
||||||
RAY_CHECK_OK(plasma_store_provider_->Put(obj, obj_id));
|
RAY_CHECK_OK(plasma_store_provider_->Put(obj, obj_id));
|
||||||
},
|
},
|
||||||
ref_counting_enabled ? reference_counter_ : nullptr, raylet_client_));
|
ref_counting_enabled ? reference_counter_ : nullptr, raylet_client_));
|
||||||
|
task_manager_.reset(new TaskManager(memory_store_));
|
||||||
resolver_.reset(new LocalDependencyResolver(memory_store_));
|
resolver_.reset(new LocalDependencyResolver(memory_store_));
|
||||||
|
|
||||||
// Create an entry for the driver task in the task table. This task is
|
// Create an entry for the driver task in the task table. This task is
|
||||||
|
@ -193,7 +194,8 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
|
||||||
new rpc::CoreWorkerClient(addr.first, addr.second, *client_call_manager_));
|
new rpc::CoreWorkerClient(addr.first, addr.second, *client_call_manager_));
|
||||||
};
|
};
|
||||||
direct_actor_submitter_ = std::unique_ptr<CoreWorkerDirectActorTaskSubmitter>(
|
direct_actor_submitter_ = std::unique_ptr<CoreWorkerDirectActorTaskSubmitter>(
|
||||||
new CoreWorkerDirectActorTaskSubmitter(client_factory, memory_store_));
|
new CoreWorkerDirectActorTaskSubmitter(client_factory, memory_store_,
|
||||||
|
task_manager_));
|
||||||
|
|
||||||
direct_task_submitter_ =
|
direct_task_submitter_ =
|
||||||
std::unique_ptr<CoreWorkerDirectTaskSubmitter>(new CoreWorkerDirectTaskSubmitter(
|
std::unique_ptr<CoreWorkerDirectTaskSubmitter>(new CoreWorkerDirectTaskSubmitter(
|
||||||
|
@ -204,7 +206,8 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
|
||||||
return std::shared_ptr<RayletClient>(
|
return std::shared_ptr<RayletClient>(
|
||||||
new RayletClient(std::move(grpc_client)));
|
new RayletClient(std::move(grpc_client)));
|
||||||
},
|
},
|
||||||
memory_store_, RayConfig::instance().worker_lease_timeout_milliseconds()));
|
memory_store_, task_manager_,
|
||||||
|
RayConfig::instance().worker_lease_timeout_milliseconds()));
|
||||||
}
|
}
|
||||||
|
|
||||||
CoreWorker::~CoreWorker() {
|
CoreWorker::~CoreWorker() {
|
||||||
|
@ -577,6 +580,7 @@ Status CoreWorker::SubmitTask(const RayFunction &function,
|
||||||
return_ids);
|
return_ids);
|
||||||
TaskSpecification task_spec = builder.Build();
|
TaskSpecification task_spec = builder.Build();
|
||||||
if (task_options.is_direct_call) {
|
if (task_options.is_direct_call) {
|
||||||
|
task_manager_->AddPendingTask(task_spec);
|
||||||
PinObjectReferences(task_spec, TaskTransportType::DIRECT);
|
PinObjectReferences(task_spec, TaskTransportType::DIRECT);
|
||||||
return direct_task_submitter_->SubmitTask(task_spec);
|
return direct_task_submitter_->SubmitTask(task_spec);
|
||||||
} else {
|
} else {
|
||||||
|
@ -659,6 +663,7 @@ Status CoreWorker::SubmitActorTask(const ActorID &actor_id, const RayFunction &f
|
||||||
Status status;
|
Status status;
|
||||||
TaskSpecification task_spec = builder.Build();
|
TaskSpecification task_spec = builder.Build();
|
||||||
if (is_direct_call) {
|
if (is_direct_call) {
|
||||||
|
task_manager_->AddPendingTask(task_spec);
|
||||||
PinObjectReferences(task_spec, TaskTransportType::DIRECT);
|
PinObjectReferences(task_spec, TaskTransportType::DIRECT);
|
||||||
status = direct_actor_submitter_->SubmitTask(task_spec);
|
status = direct_actor_submitter_->SubmitTask(task_spec);
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -513,6 +513,9 @@ class CoreWorker {
|
||||||
/// Fields related to task submission.
|
/// Fields related to task submission.
|
||||||
///
|
///
|
||||||
|
|
||||||
|
// Tracks the currently pending tasks.
|
||||||
|
std::shared_ptr<TaskManager> task_manager_;
|
||||||
|
|
||||||
// Interface to submit tasks directly to other actors.
|
// Interface to submit tasks directly to other actors.
|
||||||
std::unique_ptr<CoreWorkerDirectActorTaskSubmitter> direct_actor_submitter_;
|
std::unique_ptr<CoreWorkerDirectActorTaskSubmitter> direct_actor_submitter_;
|
||||||
|
|
||||||
|
|
80
src/ray/core_worker/task_manager.cc
Normal file
80
src/ray/core_worker/task_manager.cc
Normal file
|
@ -0,0 +1,80 @@
|
||||||
|
#include "ray/core_worker/task_manager.h"
|
||||||
|
|
||||||
|
namespace ray {
|
||||||
|
|
||||||
|
/// Start tracking `spec` as a pending task.
///
/// Records the task's ID together with its number of return values so that
/// the task can later be completed or failed. Registering the same task ID
/// twice is a fatal error.
///
/// \param[in] spec The spec of the pending task.
void TaskManager::AddPendingTask(const TaskSpecification &spec) {
  RAY_LOG(DEBUG) << "Adding pending task " << spec.TaskId();
  absl::MutexLock lock(&mu_);
  // Do the insertion outside of the check macro so the mutation does not hide
  // inside an assertion, then report which task collided if it failed.
  const bool inserted =
      pending_tasks_.emplace(spec.TaskId(), spec.NumReturns()).second;
  RAY_CHECK(inserted) << "Tried to add task that was already pending "
                      << spec.TaskId();
}
|
||||||
|
|
||||||
|
void TaskManager::CompletePendingTask(const TaskID &task_id,
|
||||||
|
const rpc::PushTaskReply &reply) {
|
||||||
|
RAY_LOG(DEBUG) << "Completing task " << task_id;
|
||||||
|
{
|
||||||
|
absl::MutexLock lock(&mu_);
|
||||||
|
auto it = pending_tasks_.find(task_id);
|
||||||
|
RAY_CHECK(it != pending_tasks_.end())
|
||||||
|
<< "Tried to complete task that was not pending " << task_id;
|
||||||
|
pending_tasks_.erase(it);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < reply.return_objects_size(); i++) {
|
||||||
|
const auto &return_object = reply.return_objects(i);
|
||||||
|
ObjectID object_id = ObjectID::FromBinary(return_object.object_id());
|
||||||
|
|
||||||
|
if (return_object.in_plasma()) {
|
||||||
|
// Mark it as in plasma with a dummy object.
|
||||||
|
std::string meta =
|
||||||
|
std::to_string(static_cast<int>(rpc::ErrorType::OBJECT_IN_PLASMA));
|
||||||
|
auto metadata =
|
||||||
|
const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(meta.data()));
|
||||||
|
auto meta_buffer = std::make_shared<LocalMemoryBuffer>(metadata, meta.size());
|
||||||
|
RAY_CHECK_OK(in_memory_store_->Put(RayObject(nullptr, meta_buffer), object_id));
|
||||||
|
} else {
|
||||||
|
std::shared_ptr<LocalMemoryBuffer> data_buffer;
|
||||||
|
if (return_object.data().size() > 0) {
|
||||||
|
data_buffer = std::make_shared<LocalMemoryBuffer>(
|
||||||
|
const_cast<uint8_t *>(
|
||||||
|
reinterpret_cast<const uint8_t *>(return_object.data().data())),
|
||||||
|
return_object.data().size());
|
||||||
|
}
|
||||||
|
std::shared_ptr<LocalMemoryBuffer> metadata_buffer;
|
||||||
|
if (return_object.metadata().size() > 0) {
|
||||||
|
metadata_buffer = std::make_shared<LocalMemoryBuffer>(
|
||||||
|
const_cast<uint8_t *>(
|
||||||
|
reinterpret_cast<const uint8_t *>(return_object.metadata().data())),
|
||||||
|
return_object.metadata().size());
|
||||||
|
}
|
||||||
|
RAY_CHECK_OK(
|
||||||
|
in_memory_store_->Put(RayObject(data_buffer, metadata_buffer), object_id));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void TaskManager::FailPendingTask(const TaskID &task_id, rpc::ErrorType error_type) {
|
||||||
|
RAY_LOG(DEBUG) << "Failing task " << task_id;
|
||||||
|
int64_t num_returns;
|
||||||
|
{
|
||||||
|
absl::MutexLock lock(&mu_);
|
||||||
|
auto it = pending_tasks_.find(task_id);
|
||||||
|
RAY_CHECK(it != pending_tasks_.end())
|
||||||
|
<< "Tried to complete task that was not pending " << task_id;
|
||||||
|
num_returns = it->second;
|
||||||
|
pending_tasks_.erase(it);
|
||||||
|
}
|
||||||
|
|
||||||
|
RAY_LOG(DEBUG) << "Treat task as failed. task_id: " << task_id
|
||||||
|
<< ", error_type: " << ErrorType_Name(error_type);
|
||||||
|
for (int i = 0; i < num_returns; i++) {
|
||||||
|
const auto object_id = ObjectID::ForTaskReturn(
|
||||||
|
task_id, /*index=*/i + 1,
|
||||||
|
/*transport_type=*/static_cast<int>(TaskTransportType::DIRECT));
|
||||||
|
std::string meta = std::to_string(static_cast<int>(error_type));
|
||||||
|
auto metadata = const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(meta.data()));
|
||||||
|
auto meta_buffer = std::make_shared<LocalMemoryBuffer>(metadata, meta.size());
|
||||||
|
RAY_CHECK_OK(in_memory_store_->Put(RayObject(nullptr, meta_buffer), object_id));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace ray
|
74
src/ray/core_worker/task_manager.h
Normal file
74
src/ray/core_worker/task_manager.h
Normal file
|
@ -0,0 +1,74 @@
|
||||||
|
#ifndef RAY_CORE_WORKER_TASK_MANAGER_H
|
||||||
|
#define RAY_CORE_WORKER_TASK_MANAGER_H
|
||||||
|
|
||||||
|
#include "absl/base/thread_annotations.h"
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
|
#include "absl/synchronization/mutex.h"
|
||||||
|
|
||||||
|
#include "ray/common/id.h"
|
||||||
|
#include "ray/common/task/task.h"
|
||||||
|
#include "ray/core_worker/store_provider/memory_store/memory_store.h"
|
||||||
|
#include "ray/protobuf/core_worker.pb.h"
|
||||||
|
#include "ray/protobuf/gcs.pb.h"
|
||||||
|
|
||||||
|
namespace ray {
|
||||||
|
|
||||||
|
class TaskFinisherInterface {
|
||||||
|
public:
|
||||||
|
virtual void CompletePendingTask(const TaskID &task_id,
|
||||||
|
const rpc::PushTaskReply &reply) = 0;
|
||||||
|
|
||||||
|
virtual void FailPendingTask(const TaskID &task_id, rpc::ErrorType error_type) = 0;
|
||||||
|
|
||||||
|
virtual ~TaskFinisherInterface() {}
|
||||||
|
};
|
||||||
|
|
||||||
|
class TaskManager : public TaskFinisherInterface {
|
||||||
|
public:
|
||||||
|
TaskManager(std::shared_ptr<CoreWorkerMemoryStore> in_memory_store)
|
||||||
|
: in_memory_store_(in_memory_store) {}
|
||||||
|
|
||||||
|
/// Add a task that is pending execution.
|
||||||
|
///
|
||||||
|
/// \param[in] spec The spec of the pending task.
|
||||||
|
/// \return Void.
|
||||||
|
void AddPendingTask(const TaskSpecification &spec);
|
||||||
|
|
||||||
|
/// Return whether the task is pending.
|
||||||
|
///
|
||||||
|
/// \param[in] task_id ID of the task to query.
|
||||||
|
/// \return Whether the task is pending.
|
||||||
|
bool IsTaskPending(const TaskID &task_id) const {
|
||||||
|
return pending_tasks_.count(task_id) > 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Write return objects for a pending task to the memory store.
|
||||||
|
///
|
||||||
|
/// \param[in] task_id ID of the pending task.
|
||||||
|
/// \param[in] reply Proto response to a direct actor or task call.
|
||||||
|
/// \return Void.
|
||||||
|
void CompletePendingTask(const TaskID &task_id,
|
||||||
|
const rpc::PushTaskReply &reply) override;
|
||||||
|
|
||||||
|
/// Treat a pending task as failed.
|
||||||
|
///
|
||||||
|
/// \param[in] task_id ID of the pending task.
|
||||||
|
/// \param[in] error_type The type of the specific error.
|
||||||
|
/// \return Void.
|
||||||
|
void FailPendingTask(const TaskID &task_id, rpc::ErrorType error_type) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
/// Used to store task results.
|
||||||
|
std::shared_ptr<CoreWorkerMemoryStore> in_memory_store_;
|
||||||
|
|
||||||
|
/// Protects below fields.
|
||||||
|
absl::Mutex mu_;
|
||||||
|
|
||||||
|
/// Map from task ID to the task's number of return values. This map contains
|
||||||
|
/// one entry per pending task that we submitted.
|
||||||
|
absl::flat_hash_map<TaskID, int64_t> pending_tasks_ GUARDED_BY(mu_);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace ray
|
||||||
|
|
||||||
|
#endif // RAY_CORE_WORKER_TASK_MANAGER_H
|
|
@ -1,3 +1,4 @@
|
||||||
|
#include "gmock/gmock.h"
|
||||||
#include "gtest/gtest.h"
|
#include "gtest/gtest.h"
|
||||||
|
|
||||||
#include "ray/common/task/task_spec.h"
|
#include "ray/common/task/task_spec.h"
|
||||||
|
@ -9,6 +10,8 @@
|
||||||
|
|
||||||
namespace ray {
|
namespace ray {
|
||||||
|
|
||||||
|
using ::testing::_;
|
||||||
|
|
||||||
class MockWorkerClient : public rpc::CoreWorkerClientInterface {
|
class MockWorkerClient : public rpc::CoreWorkerClientInterface {
|
||||||
public:
|
public:
|
||||||
ray::Status PushActorTask(
|
ray::Status PushActorTask(
|
||||||
|
@ -20,10 +23,28 @@ class MockWorkerClient : public rpc::CoreWorkerClientInterface {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<rpc::ClientCallback<rpc::PushTaskReply>> callbacks;
|
/// Answer the oldest queued PushTask callback with `status` and an empty
/// PushTaskReply, then drop it from the queue.
///
/// \param[in] status Status to pass to the callback.
/// \return False if no callback was queued, true otherwise.
bool ReplyPushTask(Status status = Status::OK()) {
  if (callbacks.empty()) {
    return false;
  }
  // Invoke a copy and pop only afterwards, exactly as before.
  // NOTE(review): presumably the callback may enqueue further callbacks while
  // it runs -- confirm against the submitter.
  auto oldest = callbacks.front();
  oldest(status, rpc::PushTaskReply());
  callbacks.pop_front();
  return true;
}
|
||||||
|
|
||||||
|
std::list<rpc::ClientCallback<rpc::PushTaskReply>> callbacks;
|
||||||
uint64_t counter = 0;
|
uint64_t counter = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class MockTaskFinisher : public TaskFinisherInterface {
|
||||||
|
public:
|
||||||
|
MockTaskFinisher() {}
|
||||||
|
|
||||||
|
MOCK_METHOD2(CompletePendingTask, void(const TaskID &, const rpc::PushTaskReply &));
|
||||||
|
MOCK_METHOD2(FailPendingTask, void(const TaskID &task_id, rpc::ErrorType error_type));
|
||||||
|
};
|
||||||
|
|
||||||
TaskSpecification CreateActorTaskHelper(ActorID actor_id, int64_t counter) {
|
TaskSpecification CreateActorTaskHelper(ActorID actor_id, int64_t counter) {
|
||||||
TaskSpecification task;
|
TaskSpecification task;
|
||||||
task.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
|
task.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
|
||||||
|
@ -38,11 +59,13 @@ class DirectActorTransportTest : public ::testing::Test {
|
||||||
DirectActorTransportTest()
|
DirectActorTransportTest()
|
||||||
: worker_client_(std::shared_ptr<MockWorkerClient>(new MockWorkerClient())),
|
: worker_client_(std::shared_ptr<MockWorkerClient>(new MockWorkerClient())),
|
||||||
store_(std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore())),
|
store_(std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore())),
|
||||||
submitter_([&](const rpc::WorkerAddress &addr) { return worker_client_; },
|
task_finisher_(std::make_shared<MockTaskFinisher>()),
|
||||||
store_) {}
|
submitter_([&](const rpc::WorkerAddress &addr) { return worker_client_; }, store_,
|
||||||
|
task_finisher_) {}
|
||||||
|
|
||||||
std::shared_ptr<MockWorkerClient> worker_client_;
|
std::shared_ptr<MockWorkerClient> worker_client_;
|
||||||
std::shared_ptr<CoreWorkerMemoryStore> store_;
|
std::shared_ptr<CoreWorkerMemoryStore> store_;
|
||||||
|
std::shared_ptr<MockTaskFinisher> task_finisher_;
|
||||||
CoreWorkerDirectActorTaskSubmitter submitter_;
|
CoreWorkerDirectActorTaskSubmitter submitter_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -60,6 +83,13 @@ TEST_F(DirectActorTransportTest, TestSubmitTask) {
|
||||||
task = CreateActorTaskHelper(actor_id, 1);
|
task = CreateActorTaskHelper(actor_id, 1);
|
||||||
ASSERT_TRUE(submitter_.SubmitTask(task).ok());
|
ASSERT_TRUE(submitter_.SubmitTask(task).ok());
|
||||||
ASSERT_EQ(worker_client_->callbacks.size(), 2);
|
ASSERT_EQ(worker_client_->callbacks.size(), 2);
|
||||||
|
|
||||||
|
EXPECT_CALL(*task_finisher_, CompletePendingTask(TaskID::Nil(), _))
|
||||||
|
.Times(worker_client_->callbacks.size());
|
||||||
|
EXPECT_CALL(*task_finisher_, FailPendingTask(_, _)).Times(0);
|
||||||
|
while (!worker_client_->callbacks.empty()) {
|
||||||
|
ASSERT_TRUE(worker_client_->ReplyPushTask());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DirectActorTransportTest, TestDependencies) {
|
TEST_F(DirectActorTransportTest, TestDependencies) {
|
||||||
|
@ -119,6 +149,27 @@ TEST_F(DirectActorTransportTest, TestOutOfOrderDependencies) {
|
||||||
ASSERT_EQ(worker_client_->callbacks.size(), 2);
|
ASSERT_EQ(worker_client_->callbacks.size(), 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Submits two tasks to an actor, answers both pushes with an IOError, and
// expects each task to be reported as failed and none as completed.
TEST_F(DirectActorTransportTest, TestActorFailure) {
  const ActorID actor_id = ActorID::Of(JobID::FromInt(0), TaskID::Nil(), 0);
  gcs::ActorTableData actor_data;
  submitter_.HandleActorUpdate(actor_id, actor_data);
  ASSERT_EQ(worker_client_->callbacks.size(), 0);

  // Create two tasks for the actor and submit them in order.
  auto first_task = CreateActorTaskHelper(actor_id, 0);
  auto second_task = CreateActorTaskHelper(actor_id, 1);
  ASSERT_TRUE(submitter_.SubmitTask(first_task).ok());
  ASSERT_TRUE(submitter_.SubmitTask(second_task).ok());
  ASSERT_EQ(worker_client_->callbacks.size(), 2);

  // Simulate the actor dying. All submitted tasks should get failed.
  EXPECT_CALL(*task_finisher_, CompletePendingTask(_, _)).Times(0);
  EXPECT_CALL(*task_finisher_, FailPendingTask(_, _)).Times(2);
  while (!worker_client_->callbacks.empty()) {
    ASSERT_TRUE(worker_client_->ReplyPushTask(Status::IOError("")));
  }
}
|
||||||
|
|
||||||
} // namespace ray
|
} // namespace ray
|
||||||
|
|
||||||
int main(int argc, char **argv) {
|
int main(int argc, char **argv) {
|
||||||
|
|
|
@ -23,7 +23,32 @@ class MockWorkerClient : public rpc::CoreWorkerClientInterface {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<rpc::ClientCallback<rpc::PushTaskReply>> callbacks;
|
/// Answer the oldest queued PushTask callback with `status` and an empty
/// PushTaskReply, then drop it from the queue.
///
/// \param[in] status Status to pass to the callback.
/// \return False if no callback was queued, true otherwise.
bool ReplyPushTask(Status status = Status::OK()) {
  if (callbacks.empty()) {
    return false;
  }
  // Invoke a copy and pop only afterwards, exactly as before.
  // NOTE(review): presumably the callback may enqueue further callbacks while
  // it runs -- confirm against the submitter.
  auto oldest = callbacks.front();
  oldest(status, rpc::PushTaskReply());
  callbacks.pop_front();
  return true;
}
|
||||||
|
|
||||||
|
std::list<rpc::ClientCallback<rpc::PushTaskReply>> callbacks;
|
||||||
|
};
|
||||||
|
|
||||||
|
class MockTaskFinisher : public TaskFinisherInterface {
|
||||||
|
public:
|
||||||
|
MockTaskFinisher() {}
|
||||||
|
|
||||||
|
void CompletePendingTask(const TaskID &, const rpc::PushTaskReply &) override {
|
||||||
|
num_tasks_complete++;
|
||||||
|
}
|
||||||
|
void FailPendingTask(const TaskID &task_id, rpc::ErrorType error_type) override {
|
||||||
|
num_tasks_failed++;
|
||||||
|
}
|
||||||
|
|
||||||
|
int num_tasks_complete = 0;
|
||||||
|
int num_tasks_failed = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
class MockRayletClient : public WorkerLeaseInterface {
|
class MockRayletClient : public WorkerLeaseInterface {
|
||||||
|
@ -196,8 +221,9 @@ TEST(DirectTaskTransportTest, TestSubmitOneTask) {
|
||||||
auto worker_client = std::make_shared<MockWorkerClient>();
|
auto worker_client = std::make_shared<MockWorkerClient>();
|
||||||
auto store = std::make_shared<CoreWorkerMemoryStore>();
|
auto store = std::make_shared<CoreWorkerMemoryStore>();
|
||||||
auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; };
|
auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; };
|
||||||
|
auto task_finisher = std::make_shared<MockTaskFinisher>();
|
||||||
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store,
|
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store,
|
||||||
kLongTimeout);
|
task_finisher, kLongTimeout);
|
||||||
TaskSpecification task;
|
TaskSpecification task;
|
||||||
task.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
|
task.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
|
||||||
|
|
||||||
|
@ -208,10 +234,14 @@ TEST(DirectTaskTransportTest, TestSubmitOneTask) {
|
||||||
|
|
||||||
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1234, ClientID::Nil()));
|
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1234, ClientID::Nil()));
|
||||||
ASSERT_EQ(worker_client->callbacks.size(), 1);
|
ASSERT_EQ(worker_client->callbacks.size(), 1);
|
||||||
|
ASSERT_EQ(task_finisher->num_tasks_complete, 0);
|
||||||
|
ASSERT_EQ(task_finisher->num_tasks_failed, 0);
|
||||||
|
|
||||||
worker_client->callbacks[0](Status::OK(), rpc::PushTaskReply());
|
ASSERT_TRUE(worker_client->ReplyPushTask());
|
||||||
ASSERT_EQ(raylet_client->num_workers_returned, 1);
|
ASSERT_EQ(raylet_client->num_workers_returned, 1);
|
||||||
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
|
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
|
||||||
|
ASSERT_EQ(task_finisher->num_tasks_complete, 1);
|
||||||
|
ASSERT_EQ(task_finisher->num_tasks_failed, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(DirectTaskTransportTest, TestHandleTaskFailure) {
|
TEST(DirectTaskTransportTest, TestHandleTaskFailure) {
|
||||||
|
@ -219,18 +249,21 @@ TEST(DirectTaskTransportTest, TestHandleTaskFailure) {
|
||||||
auto worker_client = std::make_shared<MockWorkerClient>();
|
auto worker_client = std::make_shared<MockWorkerClient>();
|
||||||
auto store = std::make_shared<CoreWorkerMemoryStore>();
|
auto store = std::make_shared<CoreWorkerMemoryStore>();
|
||||||
auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; };
|
auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; };
|
||||||
|
auto task_finisher = std::make_shared<MockTaskFinisher>();
|
||||||
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store,
|
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store,
|
||||||
kLongTimeout);
|
task_finisher, kLongTimeout);
|
||||||
TaskSpecification task;
|
TaskSpecification task;
|
||||||
task.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
|
task.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
|
||||||
|
|
||||||
ASSERT_TRUE(submitter.SubmitTask(task).ok());
|
ASSERT_TRUE(submitter.SubmitTask(task).ok());
|
||||||
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1234, ClientID::Nil()));
|
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1234, ClientID::Nil()));
|
||||||
// Simulate a system failure, i.e., worker died unexpectedly.
|
// Simulate a system failure, i.e., worker died unexpectedly.
|
||||||
worker_client->callbacks[0](Status::IOError("oops"), rpc::PushTaskReply());
|
ASSERT_TRUE(worker_client->ReplyPushTask(Status::IOError("oops")));
|
||||||
ASSERT_EQ(worker_client->callbacks.size(), 1);
|
ASSERT_EQ(worker_client->callbacks.size(), 0);
|
||||||
ASSERT_EQ(raylet_client->num_workers_returned, 0);
|
ASSERT_EQ(raylet_client->num_workers_returned, 0);
|
||||||
ASSERT_EQ(raylet_client->num_workers_disconnected, 1);
|
ASSERT_EQ(raylet_client->num_workers_disconnected, 1);
|
||||||
|
ASSERT_EQ(task_finisher->num_tasks_complete, 0);
|
||||||
|
ASSERT_EQ(task_finisher->num_tasks_failed, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(DirectTaskTransportTest, TestConcurrentWorkerLeases) {
|
TEST(DirectTaskTransportTest, TestConcurrentWorkerLeases) {
|
||||||
|
@ -238,8 +271,9 @@ TEST(DirectTaskTransportTest, TestConcurrentWorkerLeases) {
|
||||||
auto worker_client = std::make_shared<MockWorkerClient>();
|
auto worker_client = std::make_shared<MockWorkerClient>();
|
||||||
auto store = std::make_shared<CoreWorkerMemoryStore>();
|
auto store = std::make_shared<CoreWorkerMemoryStore>();
|
||||||
auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; };
|
auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; };
|
||||||
|
auto task_finisher = std::make_shared<MockTaskFinisher>();
|
||||||
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store,
|
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store,
|
||||||
kLongTimeout);
|
task_finisher, kLongTimeout);
|
||||||
TaskSpecification task1;
|
TaskSpecification task1;
|
||||||
TaskSpecification task2;
|
TaskSpecification task2;
|
||||||
TaskSpecification task3;
|
TaskSpecification task3;
|
||||||
|
@ -268,11 +302,13 @@ TEST(DirectTaskTransportTest, TestConcurrentWorkerLeases) {
|
||||||
ASSERT_EQ(raylet_client->num_workers_requested, 3);
|
ASSERT_EQ(raylet_client->num_workers_requested, 3);
|
||||||
|
|
||||||
// All workers returned.
|
// All workers returned.
|
||||||
for (const auto &cb : worker_client->callbacks) {
|
while (!worker_client->callbacks.empty()) {
|
||||||
cb(Status::OK(), rpc::PushTaskReply());
|
ASSERT_TRUE(worker_client->ReplyPushTask());
|
||||||
}
|
}
|
||||||
ASSERT_EQ(raylet_client->num_workers_returned, 3);
|
ASSERT_EQ(raylet_client->num_workers_returned, 3);
|
||||||
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
|
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
|
||||||
|
ASSERT_EQ(task_finisher->num_tasks_complete, 3);
|
||||||
|
ASSERT_EQ(task_finisher->num_tasks_failed, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(DirectTaskTransportTest, TestReuseWorkerLease) {
|
TEST(DirectTaskTransportTest, TestReuseWorkerLease) {
|
||||||
|
@ -280,8 +316,9 @@ TEST(DirectTaskTransportTest, TestReuseWorkerLease) {
|
||||||
auto worker_client = std::make_shared<MockWorkerClient>();
|
auto worker_client = std::make_shared<MockWorkerClient>();
|
||||||
auto store = std::make_shared<CoreWorkerMemoryStore>();
|
auto store = std::make_shared<CoreWorkerMemoryStore>();
|
||||||
auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; };
|
auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; };
|
||||||
|
auto task_finisher = std::make_shared<MockTaskFinisher>();
|
||||||
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store,
|
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store,
|
||||||
kLongTimeout);
|
task_finisher, kLongTimeout);
|
||||||
TaskSpecification task1;
|
TaskSpecification task1;
|
||||||
TaskSpecification task2;
|
TaskSpecification task2;
|
||||||
TaskSpecification task3;
|
TaskSpecification task3;
|
||||||
|
@ -300,23 +337,26 @@ TEST(DirectTaskTransportTest, TestReuseWorkerLease) {
|
||||||
ASSERT_EQ(raylet_client->num_workers_requested, 2);
|
ASSERT_EQ(raylet_client->num_workers_requested, 2);
|
||||||
|
|
||||||
// Task 1 finishes, Task 2 is scheduled on the same worker.
|
// Task 1 finishes, Task 2 is scheduled on the same worker.
|
||||||
worker_client->callbacks[0](Status::OK(), rpc::PushTaskReply());
|
ASSERT_TRUE(worker_client->ReplyPushTask());
|
||||||
ASSERT_EQ(worker_client->callbacks.size(), 2);
|
ASSERT_EQ(worker_client->callbacks.size(), 1);
|
||||||
ASSERT_EQ(raylet_client->num_workers_returned, 0);
|
ASSERT_EQ(raylet_client->num_workers_returned, 0);
|
||||||
|
|
||||||
// Task 2 finishes, Task 3 is scheduled on the same worker.
|
// Task 2 finishes, Task 3 is scheduled on the same worker.
|
||||||
worker_client->callbacks[1](Status::OK(), rpc::PushTaskReply());
|
ASSERT_TRUE(worker_client->ReplyPushTask());
|
||||||
ASSERT_EQ(worker_client->callbacks.size(), 3);
|
ASSERT_EQ(worker_client->callbacks.size(), 1);
|
||||||
ASSERT_EQ(raylet_client->num_workers_returned, 0);
|
ASSERT_EQ(raylet_client->num_workers_returned, 0);
|
||||||
|
|
||||||
// Task 3 finishes, the worker is returned.
|
// Task 3 finishes, the worker is returned.
|
||||||
worker_client->callbacks[2](Status::OK(), rpc::PushTaskReply());
|
ASSERT_TRUE(worker_client->ReplyPushTask());
|
||||||
ASSERT_EQ(raylet_client->num_workers_returned, 1);
|
ASSERT_EQ(raylet_client->num_workers_returned, 1);
|
||||||
|
|
||||||
// The second lease request is returned immediately.
|
// The second lease request is returned immediately.
|
||||||
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1001, ClientID::Nil()));
|
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1001, ClientID::Nil()));
|
||||||
|
ASSERT_EQ(worker_client->callbacks.size(), 0);
|
||||||
ASSERT_EQ(raylet_client->num_workers_returned, 2);
|
ASSERT_EQ(raylet_client->num_workers_returned, 2);
|
||||||
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
|
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
|
||||||
|
ASSERT_EQ(task_finisher->num_tasks_complete, 3);
|
||||||
|
ASSERT_EQ(task_finisher->num_tasks_failed, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(DirectTaskTransportTest, TestWorkerNotReusedOnError) {
|
TEST(DirectTaskTransportTest, TestWorkerNotReusedOnError) {
|
||||||
|
@ -324,8 +364,9 @@ TEST(DirectTaskTransportTest, TestWorkerNotReusedOnError) {
|
||||||
auto worker_client = std::make_shared<MockWorkerClient>();
|
auto worker_client = std::make_shared<MockWorkerClient>();
|
||||||
auto store = std::make_shared<CoreWorkerMemoryStore>();
|
auto store = std::make_shared<CoreWorkerMemoryStore>();
|
||||||
auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; };
|
auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; };
|
||||||
|
auto task_finisher = std::make_shared<MockTaskFinisher>();
|
||||||
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store,
|
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store,
|
||||||
kLongTimeout);
|
task_finisher, kLongTimeout);
|
||||||
TaskSpecification task1;
|
TaskSpecification task1;
|
||||||
TaskSpecification task2;
|
TaskSpecification task2;
|
||||||
task1.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
|
task1.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
|
||||||
|
@ -341,17 +382,18 @@ TEST(DirectTaskTransportTest, TestWorkerNotReusedOnError) {
|
||||||
ASSERT_EQ(raylet_client->num_workers_requested, 2);
|
ASSERT_EQ(raylet_client->num_workers_requested, 2);
|
||||||
|
|
||||||
// Task 1 finishes with failure; the worker is returned.
|
// Task 1 finishes with failure; the worker is returned.
|
||||||
worker_client->callbacks[0](Status::IOError("worker dead"), rpc::PushTaskReply());
|
ASSERT_TRUE(worker_client->ReplyPushTask(Status::IOError("worker dead")));
|
||||||
ASSERT_EQ(worker_client->callbacks.size(), 1);
|
ASSERT_EQ(worker_client->callbacks.size(), 0);
|
||||||
ASSERT_EQ(raylet_client->num_workers_returned, 0);
|
ASSERT_EQ(raylet_client->num_workers_returned, 0);
|
||||||
ASSERT_EQ(raylet_client->num_workers_disconnected, 1);
|
ASSERT_EQ(raylet_client->num_workers_disconnected, 1);
|
||||||
|
|
||||||
// Task 2 runs successfully on the second worker.
|
// Task 2 runs successfully on the second worker.
|
||||||
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1001, ClientID::Nil()));
|
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1001, ClientID::Nil()));
|
||||||
ASSERT_EQ(worker_client->callbacks.size(), 2);
|
ASSERT_TRUE(worker_client->ReplyPushTask());
|
||||||
worker_client->callbacks[1](Status::OK(), rpc::PushTaskReply());
|
|
||||||
ASSERT_EQ(raylet_client->num_workers_returned, 1);
|
ASSERT_EQ(raylet_client->num_workers_returned, 1);
|
||||||
ASSERT_EQ(raylet_client->num_workers_disconnected, 1);
|
ASSERT_EQ(raylet_client->num_workers_disconnected, 1);
|
||||||
|
ASSERT_EQ(task_finisher->num_tasks_complete, 1);
|
||||||
|
ASSERT_EQ(task_finisher->num_tasks_failed, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(DirectTaskTransportTest, TestSpillback) {
|
TEST(DirectTaskTransportTest, TestSpillback) {
|
||||||
|
@ -369,8 +411,9 @@ TEST(DirectTaskTransportTest, TestSpillback) {
|
||||||
remote_lease_clients[raylet_id] = client;
|
remote_lease_clients[raylet_id] = client;
|
||||||
return client;
|
return client;
|
||||||
};
|
};
|
||||||
|
auto task_finisher = std::make_shared<MockTaskFinisher>();
|
||||||
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, lease_client_factory,
|
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, lease_client_factory,
|
||||||
store, kLongTimeout);
|
store, task_finisher, kLongTimeout);
|
||||||
TaskSpecification task;
|
TaskSpecification task;
|
||||||
task.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
|
task.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
|
||||||
|
|
||||||
|
@ -389,15 +432,15 @@ TEST(DirectTaskTransportTest, TestSpillback) {
|
||||||
// Trigger retry at the remote node.
|
// Trigger retry at the remote node.
|
||||||
ASSERT_TRUE(remote_lease_clients[remote_raylet_id]->GrantWorkerLease("remote", 1234,
|
ASSERT_TRUE(remote_lease_clients[remote_raylet_id]->GrantWorkerLease("remote", 1234,
|
||||||
ClientID::Nil()));
|
ClientID::Nil()));
|
||||||
ASSERT_EQ(worker_client->callbacks.size(), 1);
|
|
||||||
|
|
||||||
// The worker is returned to the remote node, not the local one.
|
// The worker is returned to the remote node, not the local one.
|
||||||
worker_client->callbacks[0](Status::OK(), rpc::PushTaskReply());
|
ASSERT_TRUE(worker_client->ReplyPushTask());
|
||||||
ASSERT_EQ(raylet_client->num_workers_returned, 0);
|
ASSERT_EQ(raylet_client->num_workers_returned, 0);
|
||||||
ASSERT_EQ(remote_lease_clients[remote_raylet_id]->num_workers_returned, 1);
|
ASSERT_EQ(remote_lease_clients[remote_raylet_id]->num_workers_returned, 1);
|
||||||
|
|
||||||
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
|
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
|
||||||
ASSERT_EQ(remote_lease_clients[remote_raylet_id]->num_workers_disconnected, 0);
|
ASSERT_EQ(remote_lease_clients[remote_raylet_id]->num_workers_disconnected, 0);
|
||||||
|
ASSERT_EQ(task_finisher->num_tasks_complete, 1);
|
||||||
|
ASSERT_EQ(task_finisher->num_tasks_failed, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(DirectTaskTransportTest, TestWorkerLeaseTimeout) {
|
TEST(DirectTaskTransportTest, TestWorkerLeaseTimeout) {
|
||||||
|
@ -405,7 +448,9 @@ TEST(DirectTaskTransportTest, TestWorkerLeaseTimeout) {
|
||||||
auto worker_client = std::make_shared<MockWorkerClient>();
|
auto worker_client = std::make_shared<MockWorkerClient>();
|
||||||
auto store = std::make_shared<CoreWorkerMemoryStore>();
|
auto store = std::make_shared<CoreWorkerMemoryStore>();
|
||||||
auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; };
|
auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; };
|
||||||
|
auto task_finisher = std::make_shared<MockTaskFinisher>();
|
||||||
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store,
|
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store,
|
||||||
|
task_finisher,
|
||||||
/*lease_timeout_ms=*/5);
|
/*lease_timeout_ms=*/5);
|
||||||
TaskSpecification task1;
|
TaskSpecification task1;
|
||||||
TaskSpecification task2;
|
TaskSpecification task2;
|
||||||
|
@ -421,13 +466,11 @@ TEST(DirectTaskTransportTest, TestWorkerLeaseTimeout) {
|
||||||
|
|
||||||
// Task 1 is pushed.
|
// Task 1 is pushed.
|
||||||
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1000, ClientID::Nil()));
|
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1000, ClientID::Nil()));
|
||||||
ASSERT_EQ(worker_client->callbacks.size(), 1);
|
|
||||||
ASSERT_EQ(raylet_client->num_workers_requested, 2);
|
ASSERT_EQ(raylet_client->num_workers_requested, 2);
|
||||||
|
|
||||||
// Task 1 finishes with failure; the worker is returned due to the error even though
|
// Task 1 finishes with failure; the worker is returned due to the error even though
|
||||||
// it hasn't timed out.
|
// it hasn't timed out.
|
||||||
worker_client->callbacks[0](Status::IOError("worker dead"), rpc::PushTaskReply());
|
ASSERT_TRUE(worker_client->ReplyPushTask(Status::IOError("worker dead")));
|
||||||
ASSERT_EQ(worker_client->callbacks.size(), 1);
|
|
||||||
ASSERT_EQ(raylet_client->num_workers_returned, 0);
|
ASSERT_EQ(raylet_client->num_workers_returned, 0);
|
||||||
ASSERT_EQ(raylet_client->num_workers_disconnected, 1);
|
ASSERT_EQ(raylet_client->num_workers_disconnected, 1);
|
||||||
|
|
||||||
|
@ -435,16 +478,15 @@ TEST(DirectTaskTransportTest, TestWorkerLeaseTimeout) {
|
||||||
// timeout.
|
// timeout.
|
||||||
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1001, ClientID::Nil()));
|
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1001, ClientID::Nil()));
|
||||||
usleep(10 * 1000); // Sleep for 10ms, causing the lease to time out.
|
usleep(10 * 1000); // Sleep for 10ms, causing the lease to time out.
|
||||||
ASSERT_EQ(worker_client->callbacks.size(), 2);
|
ASSERT_TRUE(worker_client->ReplyPushTask());
|
||||||
worker_client->callbacks[1](Status::OK(), rpc::PushTaskReply());
|
|
||||||
ASSERT_EQ(raylet_client->num_workers_returned, 1);
|
ASSERT_EQ(raylet_client->num_workers_returned, 1);
|
||||||
ASSERT_EQ(raylet_client->num_workers_disconnected, 1);
|
ASSERT_EQ(raylet_client->num_workers_disconnected, 1);
|
||||||
|
|
||||||
// Task 3 runs successfully on the third worker; the worker is returned even though it
|
// Task 3 runs successfully on the third worker; the worker is returned even though it
|
||||||
// hasn't timed out.
|
// hasn't timed out.
|
||||||
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1002, ClientID::Nil()));
|
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1002, ClientID::Nil()));
|
||||||
ASSERT_EQ(worker_client->callbacks.size(), 3);
|
ASSERT_TRUE(worker_client->ReplyPushTask());
|
||||||
worker_client->callbacks[2](Status::OK(), rpc::PushTaskReply());
|
ASSERT_EQ(worker_client->callbacks.size(), 0);
|
||||||
ASSERT_EQ(raylet_client->num_workers_returned, 2);
|
ASSERT_EQ(raylet_client->num_workers_returned, 2);
|
||||||
ASSERT_EQ(raylet_client->num_workers_disconnected, 1);
|
ASSERT_EQ(raylet_client->num_workers_disconnected, 1);
|
||||||
}
|
}
|
||||||
|
|
77
src/ray/core_worker/test/task_manager_test.cc
Normal file
77
src/ray/core_worker/test/task_manager_test.cc
Normal file
|
@ -0,0 +1,77 @@
|
||||||
|
#include "gtest/gtest.h"
|
||||||
|
|
||||||
|
#include "ray/common/task/task_spec.h"
|
||||||
|
#include "ray/core_worker/store_provider/memory_store/memory_store.h"
|
||||||
|
#include "ray/core_worker/task_manager.h"
|
||||||
|
#include "ray/util/test_util.h"
|
||||||
|
|
||||||
|
namespace ray {
|
||||||
|
|
||||||
|
/// Build a minimal task specification for use in the tests below.
///
/// \param[in] num_returns Value stored in the spec's num_returns field.
/// \return A spec whose task ID is taken from TaskID::ForFakeTask() and whose
/// num_returns field is set as requested. No other field is populated here.
TaskSpecification CreateTaskHelper(uint64_t num_returns) {
  TaskSpecification spec;
  // Both setters operate on the same underlying message, so fetch it once.
  auto &message = spec.GetMutableMessage();
  message.set_task_id(TaskID::ForFakeTask().Binary());
  message.set_num_returns(num_returns);
  return spec;
}
|
||||||
|
|
||||||
|
class TaskManagerTest : public ::testing::Test {
|
||||||
|
public:
|
||||||
|
TaskManagerTest()
|
||||||
|
: store_(std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore())),
|
||||||
|
manager_(store_) {}
|
||||||
|
|
||||||
|
std::shared_ptr<CoreWorkerMemoryStore> store_;
|
||||||
|
TaskManager manager_;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// A pending task that is completed with a reply carrying one return object
/// stops being pending, and that object can be read back from the store with
/// the same bytes and without being flagged as an exception.
TEST_F(TaskManagerTest, TestTaskSuccess) {
  auto spec = CreateTaskHelper(1);
  ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId()));
  manager_.AddPendingTask(spec);
  ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId()));
  auto return_id = spec.ReturnId(0, TaskTransportType::DIRECT);
  WorkerContext ctx(WorkerType::WORKER, JobID::FromInt(0));

  // Build a reply that carries a single return object with random contents.
  rpc::PushTaskReply reply;
  auto return_object = reply.add_return_objects();
  return_object->set_object_id(return_id.Binary());
  auto data = GenerateRandomBuffer();
  return_object->set_data(data->Data(), data->Size());
  manager_.CompletePendingTask(spec.TaskId(), reply);
  ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId()));

  std::vector<std::shared_ptr<RayObject>> results;
  RAY_CHECK_OK(store_->Get({return_id}, 1, -1, ctx, false, &results));
  ASSERT_EQ(results.size(), 1U);
  ASSERT_FALSE(results[0]->IsException());

  // Check presence and length before comparing bytes: memcmp reads
  // `expected.size()` bytes from both sides, so a missing or shorter stored
  // buffer must fail the test here instead of being read out of bounds.
  // NOTE(review): assumes GetData() yields the same buffer interface as
  // GenerateRandomBuffer() (Data()/Size()) -- confirm against ray_object.h.
  const auto &expected = return_object->data();
  ASSERT_TRUE(results[0]->GetData() != nullptr);
  ASSERT_EQ(results[0]->GetData()->Size(), expected.size());
  ASSERT_EQ(std::memcmp(results[0]->GetData()->Data(), expected.data(), expected.size()),
            0);
}
|
||||||
|
|
||||||
|
/// A pending task that is failed with WORKER_DIED stops being pending, and its
/// return object in the store is an exception carrying that same error type.
TEST_F(TaskManagerTest, TestTaskFailure) {
  auto task_spec = CreateTaskHelper(1);
  const auto task_id = task_spec.TaskId();

  // The task only becomes pending once it has been registered.
  ASSERT_FALSE(manager_.IsTaskPending(task_id));
  manager_.AddPendingTask(task_spec);
  ASSERT_TRUE(manager_.IsTaskPending(task_id));

  const auto return_id = task_spec.ReturnId(0, TaskTransportType::DIRECT);
  WorkerContext ctx(WorkerType::WORKER, JobID::FromInt(0));

  const rpc::ErrorType expected_error = rpc::ErrorType::WORKER_DIED;
  manager_.FailPendingTask(task_id, expected_error);
  ASSERT_FALSE(manager_.IsTaskPending(task_id));

  // The failure must be visible to readers of the return object.
  std::vector<std::shared_ptr<RayObject>> fetched;
  RAY_CHECK_OK(store_->Get({return_id}, 1, -1, ctx, false, &fetched));
  ASSERT_EQ(fetched.size(), 1);
  rpc::ErrorType actual_error;
  ASSERT_TRUE(fetched[0]->IsException(&actual_error));
  ASSERT_EQ(actual_error, expected_error);
}
|
||||||
|
|
||||||
|
} // namespace ray
|
||||||
|
|
||||||
|
int main(int argc, char **argv) {
|
||||||
|
::testing::InitGoogleTest(&argc, argv);
|
||||||
|
return RUN_ALL_TESTS();
|
||||||
|
}
|
|
@ -11,57 +11,6 @@ int64_t GetRequestNumber(const std::unique_ptr<rpc::PushTaskRequest> &request) {
|
||||||
return request->task_spec().actor_task_spec().actor_counter();
|
return request->task_spec().actor_task_spec().actor_counter();
|
||||||
}
|
}
|
||||||
|
|
||||||
void TreatTaskAsFailed(const TaskID &task_id, int num_returns,
|
|
||||||
const rpc::ErrorType &error_type,
|
|
||||||
std::shared_ptr<CoreWorkerMemoryStore> &in_memory_store) {
|
|
||||||
RAY_LOG(DEBUG) << "Treat task as failed. task_id: " << task_id
|
|
||||||
<< ", error_type: " << ErrorType_Name(error_type);
|
|
||||||
for (int i = 0; i < num_returns; i++) {
|
|
||||||
const auto object_id = ObjectID::ForTaskReturn(
|
|
||||||
task_id, /*index=*/i + 1,
|
|
||||||
/*transport_type=*/static_cast<int>(TaskTransportType::DIRECT));
|
|
||||||
std::string meta = std::to_string(static_cast<int>(error_type));
|
|
||||||
auto metadata = const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(meta.data()));
|
|
||||||
auto meta_buffer = std::make_shared<LocalMemoryBuffer>(metadata, meta.size());
|
|
||||||
RAY_CHECK_OK(in_memory_store->Put(RayObject(nullptr, meta_buffer), object_id));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Copy the return objects of a task reply into the in-memory store.
///
/// An object flagged as in plasma is stored as a data-less object whose
/// metadata is the decimal string of rpc::ErrorType::OBJECT_IN_PLASMA. Any
/// other object is stored with its data and metadata; a field that is empty in
/// the reply is passed on as a null buffer.
///
/// \param[in] reply Proto response to a direct actor or task call.
/// \param[in] in_memory_store The memory store to write to.
void WriteObjectsToMemoryStore(const rpc::PushTaskReply &reply,
                               std::shared_ptr<CoreWorkerMemoryStore> &in_memory_store) {
  // Wraps the bytes of a reply field in a buffer; empty fields give a null pointer.
  const auto wrap_if_nonempty =
      [](const std::string &bytes) -> std::shared_ptr<LocalMemoryBuffer> {
    std::shared_ptr<LocalMemoryBuffer> buffer;
    if (bytes.size() > 0) {
      buffer = std::make_shared<LocalMemoryBuffer>(
          const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(bytes.data())),
          bytes.size());
    }
    return buffer;
  };

  const int num_objects = reply.return_objects_size();
  for (int idx = 0; idx < num_objects; ++idx) {
    const auto &returned = reply.return_objects(idx);
    const ObjectID object_id = ObjectID::FromBinary(returned.object_id());

    if (!returned.in_plasma()) {
      auto data_buffer = wrap_if_nonempty(returned.data());
      auto metadata_buffer = wrap_if_nonempty(returned.metadata());
      RAY_CHECK_OK(
          in_memory_store->Put(RayObject(data_buffer, metadata_buffer), object_id));
      continue;
    }

    // Mark it as in plasma with a dummy object.
    std::string marker = std::to_string(static_cast<int>(rpc::ErrorType::OBJECT_IN_PLASMA));
    auto meta_buffer = std::make_shared<LocalMemoryBuffer>(
        const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(marker.data())),
        marker.size());
    RAY_CHECK_OK(in_memory_store->Put(RayObject(nullptr, meta_buffer), object_id));
  }
}
|
|
||||||
|
|
||||||
Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spec) {
|
Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spec) {
|
||||||
RAY_LOG(DEBUG) << "Submitting task " << task_spec.TaskId();
|
RAY_LOG(DEBUG) << "Submitting task " << task_spec.TaskId();
|
||||||
RAY_CHECK(task_spec.IsActorTask());
|
RAY_CHECK(task_spec.IsActorTask());
|
||||||
|
@ -69,7 +18,6 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spe
|
||||||
resolver_.ResolveDependencies(task_spec, [this, task_spec]() mutable {
|
resolver_.ResolveDependencies(task_spec, [this, task_spec]() mutable {
|
||||||
const auto &actor_id = task_spec.ActorId();
|
const auto &actor_id = task_spec.ActorId();
|
||||||
const auto task_id = task_spec.TaskId();
|
const auto task_id = task_spec.TaskId();
|
||||||
const auto num_returns = task_spec.NumReturns();
|
|
||||||
|
|
||||||
auto request = std::unique_ptr<rpc::PushTaskRequest>(new rpc::PushTaskRequest);
|
auto request = std::unique_ptr<rpc::PushTaskRequest>(new rpc::PushTaskRequest);
|
||||||
request->mutable_task_spec()->Swap(&task_spec.GetMutableMessage());
|
request->mutable_task_spec()->Swap(&task_spec.GetMutableMessage());
|
||||||
|
@ -97,8 +45,7 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spe
|
||||||
} else {
|
} else {
|
||||||
// Actor is dead, treat the task as failure.
|
// Actor is dead, treat the task as failure.
|
||||||
RAY_CHECK(iter->second.state_ == ActorTableData::DEAD);
|
RAY_CHECK(iter->second.state_ == ActorTableData::DEAD);
|
||||||
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED,
|
task_finisher_->FailPendingTask(task_id, rpc::ErrorType::ACTOR_DIED);
|
||||||
in_memory_store_);
|
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -130,19 +77,6 @@ void CoreWorkerDirectActorTaskSubmitter::HandleActorUpdate(
|
||||||
// Remove rpc client if it's dead or being reconstructed.
|
// Remove rpc client if it's dead or being reconstructed.
|
||||||
rpc_clients_.erase(actor_id);
|
rpc_clients_.erase(actor_id);
|
||||||
|
|
||||||
// For tasks that have been sent and are waiting for replies, treat them
|
|
||||||
// as failed when the destination actor is dead or reconstructing.
|
|
||||||
auto iter = waiting_reply_tasks_.find(actor_id);
|
|
||||||
if (iter != waiting_reply_tasks_.end()) {
|
|
||||||
for (const auto &entry : iter->second) {
|
|
||||||
const auto &task_id = entry.first;
|
|
||||||
const auto num_returns = entry.second;
|
|
||||||
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED,
|
|
||||||
in_memory_store_);
|
|
||||||
}
|
|
||||||
waiting_reply_tasks_.erase(actor_id);
|
|
||||||
}
|
|
||||||
|
|
||||||
// If there are pending requests, treat the pending tasks as failed.
|
// If there are pending requests, treat the pending tasks as failed.
|
||||||
auto pending_it = pending_requests_.find(actor_id);
|
auto pending_it = pending_requests_.find(actor_id);
|
||||||
if (pending_it != pending_requests_.end()) {
|
if (pending_it != pending_requests_.end()) {
|
||||||
|
@ -150,15 +84,16 @@ void CoreWorkerDirectActorTaskSubmitter::HandleActorUpdate(
|
||||||
while (head != pending_it->second.end()) {
|
while (head != pending_it->second.end()) {
|
||||||
auto request = std::move(head->second);
|
auto request = std::move(head->second);
|
||||||
head = pending_it->second.erase(head);
|
head = pending_it->second.erase(head);
|
||||||
|
auto task_id = TaskID::FromBinary(request->task_spec().task_id());
|
||||||
TreatTaskAsFailed(TaskID::FromBinary(request->task_spec().task_id()),
|
task_finisher_->FailPendingTask(task_id, rpc::ErrorType::ACTOR_DIED);
|
||||||
request->task_spec().num_returns(), rpc::ErrorType::ACTOR_DIED,
|
|
||||||
in_memory_store_);
|
|
||||||
}
|
}
|
||||||
pending_requests_.erase(pending_it);
|
pending_requests_.erase(pending_it);
|
||||||
}
|
}
|
||||||
|
|
||||||
next_sequence_number_.erase(actor_id);
|
next_sequence_number_.erase(actor_id);
|
||||||
|
|
||||||
|
// No need to clean up tasks that have been sent and are waiting for
|
||||||
|
// replies. They will be treated as failed once the connection dies.
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -182,7 +117,6 @@ void CoreWorkerDirectActorTaskSubmitter::PushActorTask(
|
||||||
rpc::CoreWorkerClientInterface &client, std::unique_ptr<rpc::PushTaskRequest> request,
|
rpc::CoreWorkerClientInterface &client, std::unique_ptr<rpc::PushTaskRequest> request,
|
||||||
const ActorID &actor_id, const TaskID &task_id, int num_returns) {
|
const ActorID &actor_id, const TaskID &task_id, int num_returns) {
|
||||||
RAY_LOG(DEBUG) << "Pushing task " << task_id << " to actor " << actor_id;
|
RAY_LOG(DEBUG) << "Pushing task " << task_id << " to actor " << actor_id;
|
||||||
waiting_reply_tasks_[actor_id].insert(std::make_pair(task_id, num_returns));
|
|
||||||
|
|
||||||
auto task_number = GetRequestNumber(request);
|
auto task_number = GetRequestNumber(request);
|
||||||
RAY_CHECK(next_sequence_number_[actor_id] == task_number)
|
RAY_CHECK(next_sequence_number_[actor_id] == task_number)
|
||||||
|
@ -190,24 +124,19 @@ void CoreWorkerDirectActorTaskSubmitter::PushActorTask(
|
||||||
next_sequence_number_[actor_id]++;
|
next_sequence_number_[actor_id]++;
|
||||||
|
|
||||||
auto status = client.PushActorTask(
|
auto status = client.PushActorTask(
|
||||||
std::move(request), [this, actor_id, task_id, num_returns](
|
std::move(request),
|
||||||
Status status, const rpc::PushTaskReply &reply) {
|
[this, task_id](Status status, const rpc::PushTaskReply &reply) {
|
||||||
{
|
|
||||||
std::unique_lock<std::mutex> guard(mutex_);
|
|
||||||
waiting_reply_tasks_[actor_id].erase(task_id);
|
|
||||||
}
|
|
||||||
if (!status.ok()) {
|
if (!status.ok()) {
|
||||||
// Note that this might be the __ray_terminate__ task, so we don't log
|
// Note that this might be the __ray_terminate__ task, so we don't log
|
||||||
// loudly with ERROR here.
|
// loudly with ERROR here.
|
||||||
RAY_LOG(INFO) << "Task failed with error: " << status;
|
RAY_LOG(INFO) << "Task failed with error: " << status;
|
||||||
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED,
|
task_finisher_->FailPendingTask(task_id, rpc::ErrorType::ACTOR_DIED);
|
||||||
in_memory_store_);
|
} else {
|
||||||
return;
|
task_finisher_->CompletePendingTask(task_id, reply);
|
||||||
}
|
}
|
||||||
WriteObjectsToMemoryStore(reply, in_memory_store_);
|
|
||||||
});
|
});
|
||||||
if (!status.ok()) {
|
if (!status.ok()) {
|
||||||
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED, in_memory_store_);
|
task_finisher_->FailPendingTask(task_id, rpc::ErrorType::ACTOR_DIED);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
#include "ray/common/ray_object.h"
|
#include "ray/common/ray_object.h"
|
||||||
#include "ray/core_worker/context.h"
|
#include "ray/core_worker/context.h"
|
||||||
#include "ray/core_worker/store_provider/memory_store/memory_store.h"
|
#include "ray/core_worker/store_provider/memory_store/memory_store.h"
|
||||||
|
#include "ray/core_worker/task_manager.h"
|
||||||
#include "ray/core_worker/transport/dependency_resolver.h"
|
#include "ray/core_worker/transport/dependency_resolver.h"
|
||||||
#include "ray/gcs/redis_gcs_client.h"
|
#include "ray/gcs/redis_gcs_client.h"
|
||||||
#include "ray/rpc/grpc_server.h"
|
#include "ray/rpc/grpc_server.h"
|
||||||
|
@ -28,25 +29,6 @@ namespace ray {
|
||||||
/// The max time to wait for out-of-order tasks.
|
/// The max time to wait for out-of-order tasks.
|
||||||
const int kMaxReorderWaitSeconds = 30;
|
const int kMaxReorderWaitSeconds = 30;
|
||||||
|
|
||||||
/// Treat a task as failed.
|
|
||||||
///
|
|
||||||
/// \param[in] task_id The ID of a task.
|
|
||||||
/// \param[in] num_returns Number of return objects.
|
|
||||||
/// \param[in] error_type The type of the specific error.
|
|
||||||
/// \param[in] in_memory_store The memory store to write to.
|
|
||||||
/// \return Void.
|
|
||||||
void TreatTaskAsFailed(const TaskID &task_id, int num_returns,
|
|
||||||
const rpc::ErrorType &error_type,
|
|
||||||
std::shared_ptr<CoreWorkerMemoryStore> &in_memory_store);
|
|
||||||
|
|
||||||
/// Write return objects to the memory store.
|
|
||||||
///
|
|
||||||
/// \param[in] reply Proto response to a direct actor or task call.
|
|
||||||
/// \param[in] in_memory_store The memory store to write to.
|
|
||||||
/// \return Void.
|
|
||||||
void WriteObjectsToMemoryStore(const rpc::PushTaskReply &reply,
|
|
||||||
std::shared_ptr<CoreWorkerMemoryStore> &in_memory_store);
|
|
||||||
|
|
||||||
/// In direct actor call task submitter and receiver, a task is directly submitted
|
/// In direct actor call task submitter and receiver, a task is directly submitted
|
||||||
/// to the actor that will execute it.
|
/// to the actor that will execute it.
|
||||||
|
|
||||||
|
@ -66,10 +48,11 @@ struct ActorStateData {
|
||||||
class CoreWorkerDirectActorTaskSubmitter {
|
class CoreWorkerDirectActorTaskSubmitter {
|
||||||
public:
|
public:
|
||||||
CoreWorkerDirectActorTaskSubmitter(rpc::ClientFactoryFn client_factory,
|
CoreWorkerDirectActorTaskSubmitter(rpc::ClientFactoryFn client_factory,
|
||||||
std::shared_ptr<CoreWorkerMemoryStore> store)
|
std::shared_ptr<CoreWorkerMemoryStore> store,
|
||||||
|
std::shared_ptr<TaskFinisherInterface> task_finisher)
|
||||||
: client_factory_(client_factory),
|
: client_factory_(client_factory),
|
||||||
in_memory_store_(store),
|
resolver_(store),
|
||||||
resolver_(in_memory_store_) {}
|
task_finisher_(task_finisher) {}
|
||||||
|
|
||||||
/// Submit a task to an actor for execution.
|
/// Submit a task to an actor for execution.
|
||||||
///
|
///
|
||||||
|
@ -138,15 +121,12 @@ class CoreWorkerDirectActorTaskSubmitter {
|
||||||
/// actor.
|
/// actor.
|
||||||
std::unordered_map<ActorID, int64_t> next_sequence_number_;
|
std::unordered_map<ActorID, int64_t> next_sequence_number_;
|
||||||
|
|
||||||
/// Map from actor id to the tasks that are waiting for reply.
|
|
||||||
std::unordered_map<ActorID, std::unordered_map<TaskID, int>> waiting_reply_tasks_;
|
|
||||||
|
|
||||||
/// The in-memory store.
|
|
||||||
std::shared_ptr<CoreWorkerMemoryStore> in_memory_store_;
|
|
||||||
|
|
||||||
/// Resolve direct call object dependencies;
|
/// Resolve direct call object dependencies;
|
||||||
LocalDependencyResolver resolver_;
|
LocalDependencyResolver resolver_;
|
||||||
|
|
||||||
|
/// Used to complete tasks.
|
||||||
|
std::shared_ptr<TaskFinisherInterface> task_finisher_;
|
||||||
|
|
||||||
friend class CoreWorkerTest;
|
friend class CoreWorkerTest;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -43,12 +43,16 @@ void CoreWorkerDirectTaskSubmitter::OnWorkerIdle(const rpc::WorkerAddress &addr,
|
||||||
if (was_error || queued_tasks_.empty() || current_time_ms() > entry.second) {
|
if (was_error || queued_tasks_.empty() || current_time_ms() > entry.second) {
|
||||||
RAY_CHECK_OK(entry.first->ReturnWorker(addr.second, was_error));
|
RAY_CHECK_OK(entry.first->ReturnWorker(addr.second, was_error));
|
||||||
worker_to_lease_client_.erase(addr);
|
worker_to_lease_client_.erase(addr);
|
||||||
} else {
|
} else if (!queued_tasks_.empty()) {
|
||||||
auto &client = *client_cache_[addr];
|
auto &client = *client_cache_[addr];
|
||||||
PushNormalTask(addr, client, queued_tasks_.front());
|
PushNormalTask(addr, client, queued_tasks_.front());
|
||||||
queued_tasks_.pop_front();
|
queued_tasks_.pop_front();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// There are more tasks to run, so try to get another worker.
|
||||||
|
if (!queued_tasks_.empty()) {
|
||||||
RequestNewWorkerIfNeeded(queued_tasks_.front());
|
RequestNewWorkerIfNeeded(queued_tasks_.front());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<WorkerLeaseInterface>
|
std::shared_ptr<WorkerLeaseInterface>
|
||||||
|
@ -129,23 +133,21 @@ void CoreWorkerDirectTaskSubmitter::PushNormalTask(const rpc::WorkerAddress &add
|
||||||
rpc::CoreWorkerClientInterface &client,
|
rpc::CoreWorkerClientInterface &client,
|
||||||
TaskSpecification &task_spec) {
|
TaskSpecification &task_spec) {
|
||||||
auto task_id = task_spec.TaskId();
|
auto task_id = task_spec.TaskId();
|
||||||
auto num_returns = task_spec.NumReturns();
|
|
||||||
auto request = std::unique_ptr<rpc::PushTaskRequest>(new rpc::PushTaskRequest);
|
auto request = std::unique_ptr<rpc::PushTaskRequest>(new rpc::PushTaskRequest);
|
||||||
request->mutable_task_spec()->Swap(&task_spec.GetMutableMessage());
|
request->mutable_task_spec()->Swap(&task_spec.GetMutableMessage());
|
||||||
auto status = client.PushNormalTask(
|
auto status = client.PushNormalTask(
|
||||||
std::move(request),
|
std::move(request),
|
||||||
[this, task_id, num_returns, addr](Status status, const rpc::PushTaskReply &reply) {
|
[this, task_id, addr](Status status, const rpc::PushTaskReply &reply) {
|
||||||
OnWorkerIdle(addr, /*error=*/!status.ok());
|
OnWorkerIdle(addr, /*error=*/!status.ok());
|
||||||
if (!status.ok()) {
|
if (!status.ok()) {
|
||||||
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::WORKER_DIED,
|
task_finisher_->FailPendingTask(task_id, rpc::ErrorType::WORKER_DIED);
|
||||||
in_memory_store_);
|
} else {
|
||||||
return;
|
task_finisher_->CompletePendingTask(task_id, reply);
|
||||||
}
|
}
|
||||||
WriteObjectsToMemoryStore(reply, in_memory_store_);
|
|
||||||
});
|
});
|
||||||
if (!status.ok()) {
|
if (!status.ok()) {
|
||||||
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::WORKER_DIED,
|
// TODO(swang): add unit test for this.
|
||||||
in_memory_store_);
|
task_finisher_->FailPendingTask(task_id, rpc::ErrorType::WORKER_DIED);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}; // namespace ray
|
}; // namespace ray
|
||||||
|
|
|
@ -3,10 +3,12 @@
|
||||||
|
|
||||||
#include "absl/base/thread_annotations.h"
|
#include "absl/base/thread_annotations.h"
|
||||||
#include "absl/synchronization/mutex.h"
|
#include "absl/synchronization/mutex.h"
|
||||||
|
|
||||||
#include "ray/common/id.h"
|
#include "ray/common/id.h"
|
||||||
#include "ray/common/ray_object.h"
|
#include "ray/common/ray_object.h"
|
||||||
#include "ray/core_worker/context.h"
|
#include "ray/core_worker/context.h"
|
||||||
#include "ray/core_worker/store_provider/memory_store/memory_store.h"
|
#include "ray/core_worker/store_provider/memory_store/memory_store.h"
|
||||||
|
#include "ray/core_worker/task_manager.h"
|
||||||
#include "ray/core_worker/transport/dependency_resolver.h"
|
#include "ray/core_worker/transport/dependency_resolver.h"
|
||||||
#include "ray/core_worker/transport/direct_actor_transport.h"
|
#include "ray/core_worker/transport/direct_actor_transport.h"
|
||||||
#include "ray/raylet/raylet_client.h"
|
#include "ray/raylet/raylet_client.h"
|
||||||
|
@ -24,12 +26,13 @@ class CoreWorkerDirectTaskSubmitter {
|
||||||
rpc::ClientFactoryFn client_factory,
|
rpc::ClientFactoryFn client_factory,
|
||||||
LeaseClientFactoryFn lease_client_factory,
|
LeaseClientFactoryFn lease_client_factory,
|
||||||
std::shared_ptr<CoreWorkerMemoryStore> store,
|
std::shared_ptr<CoreWorkerMemoryStore> store,
|
||||||
|
std::shared_ptr<TaskFinisherInterface> task_finisher,
|
||||||
int64_t lease_timeout_ms)
|
int64_t lease_timeout_ms)
|
||||||
: local_lease_client_(lease_client),
|
: local_lease_client_(lease_client),
|
||||||
client_factory_(client_factory),
|
client_factory_(client_factory),
|
||||||
lease_client_factory_(lease_client_factory),
|
lease_client_factory_(lease_client_factory),
|
||||||
in_memory_store_(store),
|
resolver_(store),
|
||||||
resolver_(in_memory_store_),
|
task_finisher_(task_finisher),
|
||||||
lease_timeout_ms_(lease_timeout_ms) {}
|
lease_timeout_ms_(lease_timeout_ms) {}
|
||||||
|
|
||||||
/// Schedule a task for direct submission to a worker.
|
/// Schedule a task for direct submission to a worker.
|
||||||
|
@ -81,12 +84,12 @@ class CoreWorkerDirectTaskSubmitter {
|
||||||
/// Factory for producing new clients to request leases from remote nodes.
|
/// Factory for producing new clients to request leases from remote nodes.
|
||||||
LeaseClientFactoryFn lease_client_factory_;
|
LeaseClientFactoryFn lease_client_factory_;
|
||||||
|
|
||||||
/// The store provider.
|
|
||||||
std::shared_ptr<CoreWorkerMemoryStore> in_memory_store_;
|
|
||||||
|
|
||||||
/// Resolve local and remote dependencies;
|
/// Resolve local and remote dependencies;
|
||||||
LocalDependencyResolver resolver_;
|
LocalDependencyResolver resolver_;
|
||||||
|
|
||||||
|
/// Used to complete tasks.
|
||||||
|
std::shared_ptr<TaskFinisherInterface> task_finisher_;
|
||||||
|
|
||||||
/// The timeout for worker leases; after this duration, workers will be returned
|
/// The timeout for worker leases; after this duration, workers will be returned
|
||||||
/// to the raylet.
|
/// to the raylet.
|
||||||
int64_t lease_timeout_ms_;
|
int64_t lease_timeout_ms_;
|
||||||
|
|
Loading…
Add table
Reference in a new issue