Work stealing! (#15475)

* work_stealing one commit squash

* using random task id to request workers

* inlining methods in direct_task_transport.h

* faster checking for presence of stealable tasks in RequestNewWorkerIfNeeded

* linting

* fixup! using random task id to request workers

* estimating number of tasks to steal based only on tasks in flight

* linting

* fixup! linting

* backup of changes

* fixed issue in scheduling queue test after merge

* linting

* redesigned work stealing. compiles but not tested

* all tests passing locally

* fixup! all tests passing locally

* fixup! fixup! all tests passing locally

* fixed big bug in StealTasksIfNeeded

* rev1

* rev2 (before removing the work_stealing param)

* removed work_stealing flag, fixed existing unit tests

* added unit tests; need to figure out how to assign distinct worker ids in GrantWorkerLease

* fixed work stealing test

* revisions, added two more unit/regression tests

* test
Gabriele Oliaro 2021-06-23 20:08:28 -04:00 committed by GitHub
parent 9249287a36
commit 3e2f608145
14 changed files with 1246 additions and 259 deletions
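
The commit messages above only sketch the stealing heuristic ("estimating number of
tasks to steal based only on tasks in flight"). As an illustrative aid, not code from
this commit, the behavior asserted in the tests below (a thief ends up with
floor(n/2) of a victim's n in-flight tasks) can be written as the following
hypothetical helper:

#include <cstddef>

// Hypothetical sketch: how many tasks a thief asks a victim for, based only on the
// number of tasks currently in flight to that victim.
size_t EstimateTasksToSteal(size_t victim_tasks_in_flight) {
  // Steal half, rounded down: with 10 in-flight tasks the thief takes 5, and with 5
  // it takes 2, matching the assertions in TestStealingTasks below.
  return victim_tasks_in_flight / 2;
}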


@ -296,6 +296,7 @@ def test_ray_options(shutdown_only):
to_check = ["CPU", "GPU", "memory", "custom1"] to_check = ["CPU", "GPU", "memory", "custom1"]
for key in to_check: for key in to_check:
print(key, without_options[key], with_options[key])
assert without_options[key] != with_options[key], key assert without_options[key] != with_options[key], key
assert without_options != with_options assert without_options != with_options


@ -629,6 +629,7 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
reference_counter_, node_addr_factory, rpc_address_)) reference_counter_, node_addr_factory, rpc_address_))
: std::shared_ptr<LeasePolicyInterface>( : std::shared_ptr<LeasePolicyInterface>(
std::make_shared<LocalLeasePolicy>(rpc_address_)); std::make_shared<LocalLeasePolicy>(rpc_address_));
direct_task_submitter_ = std::make_unique<CoreWorkerDirectTaskSubmitter>( direct_task_submitter_ = std::make_unique<CoreWorkerDirectTaskSubmitter>(
rpc_address_, local_raylet_client_, core_worker_client_pool_, raylet_client_factory, rpc_address_, local_raylet_client_, core_worker_client_pool_, raylet_client_factory,
std::move(lease_policy), memory_store_, task_manager_, local_raylet_id, std::move(lease_policy), memory_store_, task_manager_, local_raylet_id,
@ -643,6 +644,7 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
future_resolver_.reset(new FutureResolver(memory_store_, future_resolver_.reset(new FutureResolver(memory_store_,
std::move(report_locality_data_callback), std::move(report_locality_data_callback),
core_worker_client_pool_, rpc_address_)); core_worker_client_pool_, rpc_address_));
// Unfortunately the raylet client has to be constructed after the receivers. // Unfortunately the raylet client has to be constructed after the receivers.
if (direct_task_receiver_ != nullptr) { if (direct_task_receiver_ != nullptr) {
task_argument_waiter_.reset(new DependencyWaiterImpl(*local_raylet_client_)); task_argument_waiter_.reset(new DependencyWaiterImpl(*local_raylet_client_));
@ -2357,6 +2359,13 @@ void CoreWorker::HandlePushTask(const rpc::PushTaskRequest &request,
} }
} }
void CoreWorker::HandleStealTasks(const rpc::StealTasksRequest &request,
rpc::StealTasksReply *reply,
rpc::SendReplyCallback send_reply_callback) {
RAY_LOG(DEBUG) << "Entering CoreWorker::HandleStealWork!";
direct_task_receiver_->HandleStealTasks(request, reply, send_reply_callback);
}
void CoreWorker::HandleDirectActorCallArgWaitComplete( void CoreWorker::HandleDirectActorCallArgWaitComplete(
const rpc::DirectActorCallArgWaitCompleteRequest &request, const rpc::DirectActorCallArgWaitCompleteRequest &request,
rpc::DirectActorCallArgWaitCompleteReply *reply, rpc::DirectActorCallArgWaitCompleteReply *reply,


@ -896,6 +896,11 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
void HandlePushTask(const rpc::PushTaskRequest &request, rpc::PushTaskReply *reply, void HandlePushTask(const rpc::PushTaskRequest &request, rpc::PushTaskReply *reply,
rpc::SendReplyCallback send_reply_callback) override; rpc::SendReplyCallback send_reply_callback) override;
/// Implements gRPC server handler.
void HandleStealTasks(const rpc::StealTasksRequest &request,
rpc::StealTasksReply *reply,
rpc::SendReplyCallback send_reply_callback) override;
/// Implements gRPC server handler. /// Implements gRPC server handler.
void HandleDirectActorCallArgWaitComplete( void HandleDirectActorCallArgWaitComplete(
const rpc::DirectActorCallArgWaitCompleteRequest &request, const rpc::DirectActorCallArgWaitCompleteRequest &request,


@ -45,6 +45,8 @@ class TaskFinisherInterface {
const TaskID &task_id, const TaskSpecification &spec, rpc::ErrorType error_type, const TaskID &task_id, const TaskSpecification &spec, rpc::ErrorType error_type,
const std::shared_ptr<rpc::RayException> &creation_task_exception = nullptr) = 0; const std::shared_ptr<rpc::RayException> &creation_task_exception = nullptr) = 0;
virtual absl::optional<TaskSpecification> GetTaskSpec(const TaskID &task_id) const = 0;
virtual ~TaskFinisherInterface() {} virtual ~TaskFinisherInterface() {}
}; };
@ -158,7 +160,7 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa
bool MarkTaskCanceled(const TaskID &task_id) override; bool MarkTaskCanceled(const TaskID &task_id) override;
/// Return the spec for a pending task. /// Return the spec for a pending task.
absl::optional<TaskSpecification> GetTaskSpec(const TaskID &task_id) const; absl::optional<TaskSpecification> GetTaskSpec(const TaskID &task_id) const override;
/// Return specs for pending children tasks of the given parent task. /// Return specs for pending children tasks of the given parent task.
std::vector<TaskID> GetPendingChildrenTasks(const TaskID &parent_task_id) const; std::vector<TaskID> GetPendingChildrenTasks(const TaskID &parent_task_id) const;


@ -97,6 +97,9 @@ class MockTaskFinisher : public TaskFinisherInterface {
MOCK_METHOD1(MarkTaskCanceled, bool(const TaskID &task_id)); MOCK_METHOD1(MarkTaskCanceled, bool(const TaskID &task_id));
MOCK_CONST_METHOD1(GetTaskSpec,
absl::optional<TaskSpecification>(const TaskID &task_id));
MOCK_METHOD4(MarkPendingTaskFailed, MOCK_METHOD4(MarkPendingTaskFailed,
void(const TaskID &task_id, const TaskSpecification &spec, void(const TaskID &task_id, const TaskSpecification &spec,
rpc::ErrorType error_type, rpc::ErrorType error_type,


@ -28,6 +28,10 @@ namespace ray {
// be better to use a mock clock or lease manager interface, but that's high // be better to use a mock clock or lease manager interface, but that's high
// overhead for the very simple timeout logic we currently have. // overhead for the very simple timeout logic we currently have.
int64_t kLongTimeout = 1024 * 1024 * 1024; int64_t kLongTimeout = 1024 * 1024 * 1024;
TaskSpecification BuildTaskSpec(const std::unordered_map<std::string, double> &resources,
const ray::FunctionDescriptor &function_descriptor);
// Calls BuildTaskSpec with empty resources map and empty function descriptor
TaskSpecification BuildEmptyTaskSpec();
class MockWorkerClient : public rpc::CoreWorkerClientInterface { class MockWorkerClient : public rpc::CoreWorkerClientInterface {
public: public:
@ -36,7 +40,31 @@ class MockWorkerClient : public rpc::CoreWorkerClientInterface {
callbacks.push_back(callback); callbacks.push_back(callback);
} }
bool ReplyPushTask(Status status = Status::OK(), bool exit = false) { void StealTasks(std::unique_ptr<rpc::StealTasksRequest> request,
const rpc::ClientCallback<rpc::StealTasksReply> &callback) override {
steal_callbacks.push_back(callback);
}
bool ReplyStealTasks(
Status status = Status::OK(),
std::vector<TaskSpecification> tasks_stolen = std::vector<TaskSpecification>()) {
if (steal_callbacks.size() == 0) {
return false;
}
auto callback = steal_callbacks.front();
auto reply = rpc::StealTasksReply();
for (auto task_spec : tasks_stolen) {
reply.add_stolen_tasks_ids(task_spec.TaskId().Binary());
}
callback(status, reply);
steal_callbacks.pop_front();
return true;
}
bool ReplyPushTask(Status status = Status::OK(), bool exit = false,
bool stolen = false) {
if (callbacks.size() == 0) { if (callbacks.size() == 0) {
return false; return false;
} }
@ -45,6 +73,9 @@ class MockWorkerClient : public rpc::CoreWorkerClientInterface {
if (exit) { if (exit) {
reply.set_worker_exiting(true); reply.set_worker_exiting(true);
} }
if (stolen) {
reply.set_task_stolen(true);
}
callback(status, reply); callback(status, reply);
callbacks.pop_front(); callbacks.pop_front();
return true; return true;
@ -56,6 +87,7 @@ class MockWorkerClient : public rpc::CoreWorkerClientInterface {
} }
std::list<rpc::ClientCallback<rpc::PushTaskReply>> callbacks; std::list<rpc::ClientCallback<rpc::PushTaskReply>> callbacks;
std::list<rpc::ClientCallback<rpc::StealTasksReply>> steal_callbacks;
std::list<rpc::CancelTaskRequest> kill_requests; std::list<rpc::CancelTaskRequest> kill_requests;
}; };
@ -89,6 +121,11 @@ class MockTaskFinisher : public TaskFinisherInterface {
bool MarkTaskCanceled(const TaskID &task_id) override { return true; } bool MarkTaskCanceled(const TaskID &task_id) override { return true; }
absl::optional<TaskSpecification> GetTaskSpec(const TaskID &task_id) const override {
TaskSpecification task = BuildEmptyTaskSpec();
return task;
}
int num_tasks_complete = 0; int num_tasks_complete = 0;
int num_tasks_failed = 0; int num_tasks_failed = 0;
int num_inlined_dependencies = 0; int num_inlined_dependencies = 0;
@ -128,7 +165,8 @@ class MockRayletClient : public WorkerLeaseInterface {
// Trigger reply to RequestWorkerLease. // Trigger reply to RequestWorkerLease.
bool GrantWorkerLease(const std::string &address, int port, bool GrantWorkerLease(const std::string &address, int port,
const NodeID &retry_at_raylet_id, bool cancel = false) { const NodeID &retry_at_raylet_id, bool cancel = false,
std::string worker_id = std::string()) {
rpc::RequestWorkerLeaseReply reply; rpc::RequestWorkerLeaseReply reply;
if (cancel) { if (cancel) {
reply.set_canceled(true); reply.set_canceled(true);
@ -140,6 +178,11 @@ class MockRayletClient : public WorkerLeaseInterface {
reply.mutable_worker_address()->set_ip_address(address); reply.mutable_worker_address()->set_ip_address(address);
reply.mutable_worker_address()->set_port(port); reply.mutable_worker_address()->set_port(port);
reply.mutable_worker_address()->set_raylet_id(retry_at_raylet_id.Binary()); reply.mutable_worker_address()->set_raylet_id(retry_at_raylet_id.Binary());
// Set the worker ID if the worker_id string is a valid, non-empty argument. A
// worker ID can only be set using a 28-character string.
if (worker_id.length() == 28) {
reply.mutable_worker_address()->set_worker_id(worker_id);
}
} }
if (callbacks.size() == 0) { if (callbacks.size() == 0) {
return false; return false;
@ -358,6 +401,13 @@ TaskSpecification BuildTaskSpec(const std::unordered_map<std::string, double> &r
return builder.Build(); return builder.Build();
} }
TaskSpecification BuildEmptyTaskSpec() {
std::unordered_map<std::string, double> empty_resources;
ray::FunctionDescriptor empty_descriptor =
ray::FunctionDescriptorBuilder::BuildPython("", "", "", "");
return BuildTaskSpec(empty_resources, empty_descriptor);
}
TEST(DirectTaskTransportTest, TestSubmitOneTask) { TEST(DirectTaskTransportTest, TestSubmitOneTask) {
rpc::Address address; rpc::Address address;
auto raylet_client = std::make_shared<MockRayletClient>(); auto raylet_client = std::make_shared<MockRayletClient>();
@ -372,10 +422,7 @@ TEST(DirectTaskTransportTest, TestSubmitOneTask) {
lease_policy, store, task_finisher, lease_policy, store, task_finisher,
NodeID::Nil(), kLongTimeout, actor_creator); NodeID::Nil(), kLongTimeout, actor_creator);
std::unordered_map<std::string, double> empty_resources; TaskSpecification task = BuildEmptyTaskSpec();
ray::FunctionDescriptor empty_descriptor =
ray::FunctionDescriptorBuilder::BuildPython("", "", "", "");
TaskSpecification task = BuildTaskSpec(empty_resources, empty_descriptor);
ASSERT_TRUE(submitter.SubmitTask(task).ok()); ASSERT_TRUE(submitter.SubmitTask(task).ok());
ASSERT_EQ(lease_policy->num_lease_policy_consults, 1); ASSERT_EQ(lease_policy->num_lease_policy_consults, 1);
@ -414,10 +461,7 @@ TEST(DirectTaskTransportTest, TestHandleTaskFailure) {
CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr,
lease_policy, store, task_finisher, lease_policy, store, task_finisher,
NodeID::Nil(), kLongTimeout, actor_creator); NodeID::Nil(), kLongTimeout, actor_creator);
std::unordered_map<std::string, double> empty_resources; TaskSpecification task = BuildEmptyTaskSpec();
ray::FunctionDescriptor empty_descriptor =
ray::FunctionDescriptorBuilder::BuildPython("", "", "", "");
TaskSpecification task = BuildTaskSpec(empty_resources, empty_descriptor);
ASSERT_TRUE(submitter.SubmitTask(task).ok()); ASSERT_TRUE(submitter.SubmitTask(task).ok());
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1234, NodeID::Nil())); ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1234, NodeID::Nil()));
@ -449,12 +493,10 @@ TEST(DirectTaskTransportTest, TestConcurrentWorkerLeases) {
CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr,
lease_policy, store, task_finisher, lease_policy, store, task_finisher,
NodeID::Nil(), kLongTimeout, actor_creator); NodeID::Nil(), kLongTimeout, actor_creator);
std::unordered_map<std::string, double> empty_resources;
ray::FunctionDescriptor empty_descriptor = TaskSpecification task1 = BuildEmptyTaskSpec();
ray::FunctionDescriptorBuilder::BuildPython("", "", "", ""); TaskSpecification task2 = BuildEmptyTaskSpec();
TaskSpecification task1 = BuildTaskSpec(empty_resources, empty_descriptor); TaskSpecification task3 = BuildEmptyTaskSpec();
TaskSpecification task2 = BuildTaskSpec(empty_resources, empty_descriptor);
TaskSpecification task3 = BuildTaskSpec(empty_resources, empty_descriptor);
ASSERT_TRUE(submitter.SubmitTask(task1).ok()); ASSERT_TRUE(submitter.SubmitTask(task1).ok());
ASSERT_TRUE(submitter.SubmitTask(task2).ok()); ASSERT_TRUE(submitter.SubmitTask(task2).ok());
@ -509,12 +551,10 @@ TEST(DirectTaskTransportTest, TestReuseWorkerLease) {
CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr,
lease_policy, store, task_finisher, lease_policy, store, task_finisher,
NodeID::Nil(), kLongTimeout, actor_creator); NodeID::Nil(), kLongTimeout, actor_creator);
std::unordered_map<std::string, double> empty_resources;
ray::FunctionDescriptor empty_descriptor = TaskSpecification task1 = BuildEmptyTaskSpec();
ray::FunctionDescriptorBuilder::BuildPython("", "", "", ""); TaskSpecification task2 = BuildEmptyTaskSpec();
TaskSpecification task1 = BuildTaskSpec(empty_resources, empty_descriptor); TaskSpecification task3 = BuildEmptyTaskSpec();
TaskSpecification task2 = BuildTaskSpec(empty_resources, empty_descriptor);
TaskSpecification task3 = BuildTaskSpec(empty_resources, empty_descriptor);
ASSERT_TRUE(submitter.SubmitTask(task1).ok()); ASSERT_TRUE(submitter.SubmitTask(task1).ok());
ASSERT_TRUE(submitter.SubmitTask(task2).ok()); ASSERT_TRUE(submitter.SubmitTask(task2).ok());
@ -574,12 +614,9 @@ TEST(DirectTaskTransportTest, TestRetryLeaseCancellation) {
CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr,
lease_policy, store, task_finisher, lease_policy, store, task_finisher,
NodeID::Nil(), kLongTimeout, actor_creator); NodeID::Nil(), kLongTimeout, actor_creator);
std::unordered_map<std::string, double> empty_resources; TaskSpecification task1 = BuildEmptyTaskSpec();
ray::FunctionDescriptor empty_descriptor = TaskSpecification task2 = BuildEmptyTaskSpec();
ray::FunctionDescriptorBuilder::BuildPython("", "", "", ""); TaskSpecification task3 = BuildEmptyTaskSpec();
TaskSpecification task1 = BuildTaskSpec(empty_resources, empty_descriptor);
TaskSpecification task2 = BuildTaskSpec(empty_resources, empty_descriptor);
TaskSpecification task3 = BuildTaskSpec(empty_resources, empty_descriptor);
ASSERT_TRUE(submitter.SubmitTask(task1).ok()); ASSERT_TRUE(submitter.SubmitTask(task1).ok());
ASSERT_TRUE(submitter.SubmitTask(task2).ok()); ASSERT_TRUE(submitter.SubmitTask(task2).ok());
@ -635,12 +672,9 @@ TEST(DirectTaskTransportTest, TestConcurrentCancellationAndSubmission) {
CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr,
lease_policy, store, task_finisher, lease_policy, store, task_finisher,
NodeID::Nil(), kLongTimeout, actor_creator); NodeID::Nil(), kLongTimeout, actor_creator);
std::unordered_map<std::string, double> empty_resources; TaskSpecification task1 = BuildEmptyTaskSpec();
ray::FunctionDescriptor empty_descriptor = TaskSpecification task2 = BuildEmptyTaskSpec();
ray::FunctionDescriptorBuilder::BuildPython("", "", "", ""); TaskSpecification task3 = BuildEmptyTaskSpec();
TaskSpecification task1 = BuildTaskSpec(empty_resources, empty_descriptor);
TaskSpecification task2 = BuildTaskSpec(empty_resources, empty_descriptor);
TaskSpecification task3 = BuildTaskSpec(empty_resources, empty_descriptor);
ASSERT_TRUE(submitter.SubmitTask(task1).ok()); ASSERT_TRUE(submitter.SubmitTask(task1).ok());
ASSERT_TRUE(submitter.SubmitTask(task2).ok()); ASSERT_TRUE(submitter.SubmitTask(task2).ok());
@ -693,11 +727,8 @@ TEST(DirectTaskTransportTest, TestWorkerNotReusedOnError) {
CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr,
lease_policy, store, task_finisher, lease_policy, store, task_finisher,
NodeID::Nil(), kLongTimeout, actor_creator); NodeID::Nil(), kLongTimeout, actor_creator);
std::unordered_map<std::string, double> empty_resources; TaskSpecification task1 = BuildEmptyTaskSpec();
ray::FunctionDescriptor empty_descriptor = TaskSpecification task2 = BuildEmptyTaskSpec();
ray::FunctionDescriptorBuilder::BuildPython("", "", "", "");
TaskSpecification task1 = BuildTaskSpec(empty_resources, empty_descriptor);
TaskSpecification task2 = BuildTaskSpec(empty_resources, empty_descriptor);
ASSERT_TRUE(submitter.SubmitTask(task1).ok()); ASSERT_TRUE(submitter.SubmitTask(task1).ok());
ASSERT_TRUE(submitter.SubmitTask(task2).ok()); ASSERT_TRUE(submitter.SubmitTask(task2).ok());
@ -742,10 +773,7 @@ TEST(DirectTaskTransportTest, TestWorkerNotReturnedOnExit) {
CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr,
lease_policy, store, task_finisher, lease_policy, store, task_finisher,
NodeID::Nil(), kLongTimeout, actor_creator); NodeID::Nil(), kLongTimeout, actor_creator);
std::unordered_map<std::string, double> empty_resources; TaskSpecification task1 = BuildEmptyTaskSpec();
ray::FunctionDescriptor empty_descriptor =
ray::FunctionDescriptorBuilder::BuildPython("", "", "", "");
TaskSpecification task1 = BuildTaskSpec(empty_resources, empty_descriptor);
ASSERT_TRUE(submitter.SubmitTask(task1).ok()); ASSERT_TRUE(submitter.SubmitTask(task1).ok());
ASSERT_EQ(raylet_client->num_workers_requested, 1); ASSERT_EQ(raylet_client->num_workers_requested, 1);
@ -790,10 +818,7 @@ TEST(DirectTaskTransportTest, TestSpillback) {
CoreWorkerDirectTaskSubmitter submitter( CoreWorkerDirectTaskSubmitter submitter(
address, raylet_client, client_pool, lease_client_factory, lease_policy, store, address, raylet_client, client_pool, lease_client_factory, lease_policy, store,
task_finisher, NodeID::Nil(), kLongTimeout, actor_creator); task_finisher, NodeID::Nil(), kLongTimeout, actor_creator);
std::unordered_map<std::string, double> empty_resources; TaskSpecification task = BuildEmptyTaskSpec();
ray::FunctionDescriptor empty_descriptor =
ray::FunctionDescriptorBuilder::BuildPython("", "", "", "");
TaskSpecification task = BuildTaskSpec(empty_resources, empty_descriptor);
ASSERT_TRUE(submitter.SubmitTask(task).ok()); ASSERT_TRUE(submitter.SubmitTask(task).ok());
ASSERT_EQ(lease_policy->num_lease_policy_consults, 1); ASSERT_EQ(lease_policy->num_lease_policy_consults, 1);
@ -857,10 +882,7 @@ TEST(DirectTaskTransportTest, TestSpillbackRoundTrip) {
CoreWorkerDirectTaskSubmitter submitter( CoreWorkerDirectTaskSubmitter submitter(
address, raylet_client, client_pool, lease_client_factory, lease_policy, store, address, raylet_client, client_pool, lease_client_factory, lease_policy, store,
task_finisher, local_raylet_id, kLongTimeout, actor_creator); task_finisher, local_raylet_id, kLongTimeout, actor_creator);
std::unordered_map<std::string, double> empty_resources; TaskSpecification task = BuildEmptyTaskSpec();
ray::FunctionDescriptor empty_descriptor =
ray::FunctionDescriptorBuilder::BuildPython("", "", "", "");
TaskSpecification task = BuildTaskSpec(empty_resources, empty_descriptor);
ASSERT_TRUE(submitter.SubmitTask(task).ok()); ASSERT_TRUE(submitter.SubmitTask(task).ok());
ASSERT_EQ(raylet_client->num_workers_requested, 1); ASSERT_EQ(raylet_client->num_workers_requested, 1);
@ -1049,12 +1071,9 @@ TEST(DirectTaskTransportTest, TestWorkerLeaseTimeout) {
lease_policy, store, task_finisher, lease_policy, store, task_finisher,
NodeID::Nil(), NodeID::Nil(),
/*lease_timeout_ms=*/5, actor_creator); /*lease_timeout_ms=*/5, actor_creator);
std::unordered_map<std::string, double> empty_resources; TaskSpecification task1 = BuildEmptyTaskSpec();
ray::FunctionDescriptor empty_descriptor = TaskSpecification task2 = BuildEmptyTaskSpec();
ray::FunctionDescriptorBuilder::BuildPython("", "", "", ""); TaskSpecification task3 = BuildEmptyTaskSpec();
TaskSpecification task1 = BuildTaskSpec(empty_resources, empty_descriptor);
TaskSpecification task2 = BuildTaskSpec(empty_resources, empty_descriptor);
TaskSpecification task3 = BuildTaskSpec(empty_resources, empty_descriptor);
ASSERT_TRUE(submitter.SubmitTask(task1).ok()); ASSERT_TRUE(submitter.SubmitTask(task1).ok());
ASSERT_TRUE(submitter.SubmitTask(task2).ok()); ASSERT_TRUE(submitter.SubmitTask(task2).ok());
@ -1109,10 +1128,7 @@ TEST(DirectTaskTransportTest, TestKillExecutingTask) {
CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr,
lease_policy, store, task_finisher, lease_policy, store, task_finisher,
NodeID::Nil(), kLongTimeout, actor_creator); NodeID::Nil(), kLongTimeout, actor_creator);
std::unordered_map<std::string, double> empty_resources; TaskSpecification task = BuildEmptyTaskSpec();
ray::FunctionDescriptor empty_descriptor =
ray::FunctionDescriptorBuilder::BuildPython("", "", "", "");
TaskSpecification task = BuildTaskSpec(empty_resources, empty_descriptor);
ASSERT_TRUE(submitter.SubmitTask(task).ok()); ASSERT_TRUE(submitter.SubmitTask(task).ok());
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1234, NodeID::Nil())); ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1234, NodeID::Nil()));
@ -1162,10 +1178,7 @@ TEST(DirectTaskTransportTest, TestKillPendingTask) {
CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr,
lease_policy, store, task_finisher, lease_policy, store, task_finisher,
NodeID::Nil(), kLongTimeout, actor_creator); NodeID::Nil(), kLongTimeout, actor_creator);
std::unordered_map<std::string, double> empty_resources; TaskSpecification task = BuildEmptyTaskSpec();
ray::FunctionDescriptor empty_descriptor =
ray::FunctionDescriptorBuilder::BuildPython("", "", "", "");
TaskSpecification task = BuildTaskSpec(empty_resources, empty_descriptor);
ASSERT_TRUE(submitter.SubmitTask(task).ok()); ASSERT_TRUE(submitter.SubmitTask(task).ok());
ASSERT_TRUE(submitter.CancelTask(task, true, false).ok()); ASSERT_TRUE(submitter.CancelTask(task, true, false).ok());
@ -1199,10 +1212,7 @@ TEST(DirectTaskTransportTest, TestKillResolvingTask) {
CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr,
lease_policy, store, task_finisher, lease_policy, store, task_finisher,
NodeID::Nil(), kLongTimeout, actor_creator); NodeID::Nil(), kLongTimeout, actor_creator);
std::unordered_map<std::string, double> empty_resources; TaskSpecification task = BuildEmptyTaskSpec();
ray::FunctionDescriptor empty_descriptor =
ray::FunctionDescriptorBuilder::BuildPython("", "", "", "");
TaskSpecification task = BuildTaskSpec(empty_resources, empty_descriptor);
ObjectID obj1 = ObjectID::FromRandom(); ObjectID obj1 = ObjectID::FromRandom();
task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj1.Binary()); task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj1.Binary());
ASSERT_TRUE(submitter.SubmitTask(task).ok()); ASSERT_TRUE(submitter.SubmitTask(task).ok());
@ -1242,12 +1252,9 @@ TEST(DirectTaskTransportTest, TestPipeliningConcurrentWorkerLeases) {
NodeID::Nil(), kLongTimeout, actor_creator, max_tasks_in_flight_per_worker); NodeID::Nil(), kLongTimeout, actor_creator, max_tasks_in_flight_per_worker);
// Prepare 20 tasks and save them in a vector. // Prepare 20 tasks and save them in a vector.
std::unordered_map<std::string, double> empty_resources;
ray::FunctionDescriptor empty_descriptor =
ray::FunctionDescriptorBuilder::BuildPython("", "", "", "");
std::vector<TaskSpecification> tasks; std::vector<TaskSpecification> tasks;
for (int i = 1; i <= 20; i++) { for (int i = 1; i <= 20; i++) {
tasks.push_back(BuildTaskSpec(empty_resources, empty_descriptor)); tasks.push_back(BuildEmptyTaskSpec());
} }
ASSERT_EQ(tasks.size(), 20); ASSERT_EQ(tasks.size(), 20);
@ -1262,10 +1269,11 @@ TEST(DirectTaskTransportTest, TestPipeliningConcurrentWorkerLeases) {
ASSERT_EQ(worker_client->callbacks.size(), 10); ASSERT_EQ(worker_client->callbacks.size(), 10);
ASSERT_EQ(raylet_client->num_workers_requested, 2); ASSERT_EQ(raylet_client->num_workers_requested, 2);
// Last 10 tasks are pushed; no more workers are requested. // Last 10 tasks are pushed; one more worker is requested due to the Eager Worker
// Requesting Mode.
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1001, NodeID::Nil())); ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1001, NodeID::Nil()));
ASSERT_EQ(worker_client->callbacks.size(), 20); ASSERT_EQ(worker_client->callbacks.size(), 20);
ASSERT_EQ(raylet_client->num_workers_requested, 2); ASSERT_EQ(raylet_client->num_workers_requested, 3);
for (int i = 1; i <= 20; i++) { for (int i = 1; i <= 20; i++) {
ASSERT_FALSE(worker_client->callbacks.empty()); ASSERT_FALSE(worker_client->callbacks.empty());
@ -1283,14 +1291,15 @@ TEST(DirectTaskTransportTest, TestPipeliningConcurrentWorkerLeases) {
} }
} }
ASSERT_EQ(raylet_client->num_workers_requested, 2); ASSERT_EQ(raylet_client->num_workers_requested, 3);
ASSERT_EQ(raylet_client->num_workers_returned, 2); ASSERT_EQ(raylet_client->num_workers_returned, 2);
ASSERT_EQ(raylet_client->num_workers_disconnected, 0); ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
ASSERT_EQ(task_finisher->num_tasks_complete, 20); ASSERT_EQ(task_finisher->num_tasks_complete, 20);
ASSERT_EQ(task_finisher->num_tasks_failed, 0); ASSERT_EQ(task_finisher->num_tasks_failed, 0);
ASSERT_EQ(raylet_client->num_leases_canceled, 0); ASSERT_EQ(raylet_client->num_leases_canceled, 1);
ASSERT_TRUE(raylet_client->ReplyCancelWorkerLease());
ASSERT_FALSE(raylet_client->ReplyCancelWorkerLease()); ASSERT_TRUE(raylet_client->GrantWorkerLease("nil", 0, NodeID::Nil(), /*cancel=*/true));
ASSERT_EQ(raylet_client->num_leases_canceled, 1);
// Check that there are no entries left in the scheduling_key_entries_ hashmap. These // Check that there are no entries left in the scheduling_key_entries_ hashmap. These
// would otherwise cause a memory leak. // would otherwise cause a memory leak.
@ -1317,12 +1326,9 @@ TEST(DirectTaskTransportTest, TestPipeliningReuseWorkerLease) {
NodeID::Nil(), kLongTimeout, actor_creator, max_tasks_in_flight_per_worker); NodeID::Nil(), kLongTimeout, actor_creator, max_tasks_in_flight_per_worker);
// prepare 30 tasks and save them in a vector // prepare 30 tasks and save them in a vector
std::unordered_map<std::string, double> empty_resources;
ray::FunctionDescriptor empty_descriptor =
ray::FunctionDescriptorBuilder::BuildPython("", "", "", "");
std::vector<TaskSpecification> tasks; std::vector<TaskSpecification> tasks;
for (int i = 0; i < 30; i++) { for (int i = 0; i < 30; i++) {
tasks.push_back(BuildTaskSpec(empty_resources, empty_descriptor)); tasks.push_back(BuildEmptyTaskSpec());
} }
ASSERT_EQ(tasks.size(), 30); ASSERT_EQ(tasks.size(), 30);
@ -1353,14 +1359,17 @@ TEST(DirectTaskTransportTest, TestPipeliningReuseWorkerLease) {
} }
ASSERT_EQ(worker_client->callbacks.size(), 10); ASSERT_EQ(worker_client->callbacks.size(), 10);
ASSERT_EQ(raylet_client->num_workers_returned, 0); ASSERT_EQ(raylet_client->num_workers_returned, 0);
ASSERT_EQ(raylet_client->num_leases_canceled, 1); ASSERT_EQ(raylet_client->num_leases_canceled, 0);
ASSERT_TRUE(raylet_client->ReplyCancelWorkerLease());
// Tasks 21-30 finish, and the worker is finally returned. // Tasks 21-30 finish, and the worker is finally returned.
for (int i = 21; i <= 30; i++) { for (int i = 21; i <= 30; i++) {
ASSERT_TRUE(worker_client->ReplyPushTask()); ASSERT_TRUE(worker_client->ReplyPushTask());
} }
ASSERT_EQ(raylet_client->num_workers_returned, 1); ASSERT_EQ(raylet_client->num_workers_returned, 1);
ASSERT_EQ(worker_client->callbacks.size(), 0);
ASSERT_EQ(task_finisher->num_tasks_complete, 30);
ASSERT_EQ(raylet_client->num_leases_canceled, 1);
ASSERT_TRUE(raylet_client->ReplyCancelWorkerLease());
// The second lease request is returned immediately. // The second lease request is returned immediately.
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1001, NodeID::Nil())); ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1001, NodeID::Nil()));
@ -1397,12 +1406,9 @@ TEST(DirectTaskTransportTest, TestPipeliningNumberOfWorkersRequested) {
NodeID::Nil(), kLongTimeout, actor_creator, max_tasks_in_flight_per_worker); NodeID::Nil(), kLongTimeout, actor_creator, max_tasks_in_flight_per_worker);
// prepare 30 tasks and save them in a vector // prepare 30 tasks and save them in a vector
std::unordered_map<std::string, double> empty_resources;
ray::FunctionDescriptor empty_descriptor =
ray::FunctionDescriptorBuilder::BuildPython("", "", "", "");
std::vector<TaskSpecification> tasks; std::vector<TaskSpecification> tasks;
for (int i = 0; i < 30; i++) { for (int i = 0; i < 30; i++) {
tasks.push_back(BuildTaskSpec(empty_resources, empty_descriptor)); tasks.push_back(BuildEmptyTaskSpec());
} }
ASSERT_EQ(tasks.size(), 30); ASSERT_EQ(tasks.size(), 30);
@ -1419,9 +1425,10 @@ TEST(DirectTaskTransportTest, TestPipeliningNumberOfWorkersRequested) {
ASSERT_EQ(raylet_client->num_leases_canceled, 0); ASSERT_EQ(raylet_client->num_leases_canceled, 0);
ASSERT_EQ(worker_client->callbacks.size(), 0); ASSERT_EQ(worker_client->callbacks.size(), 0);
// Grant a worker lease, and check that still only 1 worker was requested. // Grant a worker lease, and check that one more worker was requested due to the Eager
// Worker Requesting Mode.
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1000, NodeID::Nil())); ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1000, NodeID::Nil()));
ASSERT_EQ(raylet_client->num_workers_requested, 1); ASSERT_EQ(raylet_client->num_workers_requested, 2);
ASSERT_EQ(raylet_client->num_workers_returned, 0); ASSERT_EQ(raylet_client->num_workers_returned, 0);
ASSERT_EQ(raylet_client->num_workers_disconnected, 0); ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
ASSERT_EQ(task_finisher->num_tasks_complete, 0); ASSERT_EQ(task_finisher->num_tasks_complete, 0);
@ -1429,14 +1436,14 @@ TEST(DirectTaskTransportTest, TestPipeliningNumberOfWorkersRequested) {
ASSERT_EQ(raylet_client->num_leases_canceled, 0); ASSERT_EQ(raylet_client->num_leases_canceled, 0);
ASSERT_EQ(worker_client->callbacks.size(), 4); ASSERT_EQ(worker_client->callbacks.size(), 4);
// Submit 6 more tasks, and check that still only 1 worker was requested. // Submit 6 more tasks, and check that still only 2 workers were requested.
for (int i = 1; i <= 6; i++) { for (int i = 1; i <= 6; i++) {
auto task = tasks.front(); auto task = tasks.front();
ASSERT_TRUE(submitter.SubmitTask(task).ok()); ASSERT_TRUE(submitter.SubmitTask(task).ok());
tasks.erase(tasks.begin()); tasks.erase(tasks.begin());
} }
ASSERT_EQ(tasks.size(), 20); ASSERT_EQ(tasks.size(), 20);
ASSERT_EQ(raylet_client->num_workers_requested, 1); ASSERT_EQ(raylet_client->num_workers_requested, 2);
ASSERT_EQ(raylet_client->num_workers_returned, 0); ASSERT_EQ(raylet_client->num_workers_returned, 0);
ASSERT_EQ(raylet_client->num_workers_disconnected, 0); ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
ASSERT_EQ(task_finisher->num_tasks_complete, 0); ASSERT_EQ(task_finisher->num_tasks_complete, 0);
@ -1444,7 +1451,8 @@ TEST(DirectTaskTransportTest, TestPipeliningNumberOfWorkersRequested) {
ASSERT_EQ(raylet_client->num_leases_canceled, 0); ASSERT_EQ(raylet_client->num_leases_canceled, 0);
ASSERT_EQ(worker_client->callbacks.size(), 10); ASSERT_EQ(worker_client->callbacks.size(), 10);
// Submit 1 more task, and check that one more worker is requested, for a total of 2. // Submit 1 more task, and check that no additional worker is requested, because a
// request is already pending (due to the Eager Worker Requesting mode)
auto task = tasks.front(); auto task = tasks.front();
ASSERT_TRUE(submitter.SubmitTask(task).ok()); ASSERT_TRUE(submitter.SubmitTask(task).ok());
tasks.erase(tasks.begin()); tasks.erase(tasks.begin());
@ -1457,9 +1465,10 @@ TEST(DirectTaskTransportTest, TestPipeliningNumberOfWorkersRequested) {
ASSERT_EQ(raylet_client->num_leases_canceled, 0); ASSERT_EQ(raylet_client->num_leases_canceled, 0);
ASSERT_EQ(worker_client->callbacks.size(), 10); ASSERT_EQ(worker_client->callbacks.size(), 10);
// Grant a worker lease, and check that still only 2 workers were requested. // Grant a worker lease, and check that one more worker is requested because there are
// stealable tasks.
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1001, NodeID::Nil())); ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1001, NodeID::Nil()));
ASSERT_EQ(raylet_client->num_workers_requested, 2); ASSERT_EQ(raylet_client->num_workers_requested, 3);
ASSERT_EQ(raylet_client->num_workers_returned, 0); ASSERT_EQ(raylet_client->num_workers_returned, 0);
ASSERT_EQ(raylet_client->num_workers_disconnected, 0); ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
ASSERT_EQ(task_finisher->num_tasks_complete, 0); ASSERT_EQ(task_finisher->num_tasks_complete, 0);
@ -1475,7 +1484,7 @@ TEST(DirectTaskTransportTest, TestPipeliningNumberOfWorkersRequested) {
tasks.erase(tasks.begin()); tasks.erase(tasks.begin());
} }
ASSERT_EQ(tasks.size(), 10); ASSERT_EQ(tasks.size(), 10);
ASSERT_EQ(raylet_client->num_workers_requested, 2); ASSERT_EQ(raylet_client->num_workers_requested, 3);
ASSERT_EQ(raylet_client->num_workers_returned, 0); ASSERT_EQ(raylet_client->num_workers_returned, 0);
ASSERT_EQ(raylet_client->num_workers_disconnected, 0); ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
ASSERT_EQ(task_finisher->num_tasks_complete, 0); ASSERT_EQ(task_finisher->num_tasks_complete, 0);
@ -1484,11 +1493,11 @@ TEST(DirectTaskTransportTest, TestPipeliningNumberOfWorkersRequested) {
ASSERT_EQ(worker_client->callbacks.size(), 20); ASSERT_EQ(worker_client->callbacks.size(), 20);
// Call ReplyPushTask on a quarter of the submitted tasks (5), and check that the // Call ReplyPushTask on a quarter of the submitted tasks (5), and check that the
// total number of workers requested remains equal to 2. // total number of workers requested remains equal to 3.
for (int i = 1; i <= 5; i++) { for (int i = 1; i <= 5; i++) {
ASSERT_TRUE(worker_client->ReplyPushTask()); ASSERT_TRUE(worker_client->ReplyPushTask());
} }
ASSERT_EQ(raylet_client->num_workers_requested, 2); ASSERT_EQ(raylet_client->num_workers_requested, 3);
ASSERT_EQ(raylet_client->num_workers_returned, 0); ASSERT_EQ(raylet_client->num_workers_returned, 0);
ASSERT_EQ(raylet_client->num_workers_disconnected, 0); ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
ASSERT_EQ(task_finisher->num_tasks_complete, 5); ASSERT_EQ(task_finisher->num_tasks_complete, 5);
@ -1503,7 +1512,7 @@ TEST(DirectTaskTransportTest, TestPipeliningNumberOfWorkersRequested) {
tasks.erase(tasks.begin()); tasks.erase(tasks.begin());
} }
ASSERT_EQ(tasks.size(), 5); ASSERT_EQ(tasks.size(), 5);
ASSERT_EQ(raylet_client->num_workers_requested, 2); ASSERT_EQ(raylet_client->num_workers_requested, 3);
ASSERT_EQ(raylet_client->num_workers_returned, 0); ASSERT_EQ(raylet_client->num_workers_returned, 0);
ASSERT_EQ(raylet_client->num_workers_disconnected, 0); ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
ASSERT_EQ(task_finisher->num_tasks_complete, 5); ASSERT_EQ(task_finisher->num_tasks_complete, 5);
@ -1512,11 +1521,11 @@ TEST(DirectTaskTransportTest, TestPipeliningNumberOfWorkersRequested) {
ASSERT_EQ(worker_client->callbacks.size(), 20); ASSERT_EQ(worker_client->callbacks.size(), 20);
// Call ReplyPushTask on a quarter of the submitted tasks (5), and check that the // Call ReplyPushTask on a quarter of the submitted tasks (5), and check that the
// total number of workers requested remains equal to 2. // total number of workers requested remains equal to 3.
for (int i = 1; i <= 5; i++) { for (int i = 1; i <= 5; i++) {
ASSERT_TRUE(worker_client->ReplyPushTask()); ASSERT_TRUE(worker_client->ReplyPushTask());
} }
ASSERT_EQ(raylet_client->num_workers_requested, 2); ASSERT_EQ(raylet_client->num_workers_requested, 3);
ASSERT_EQ(raylet_client->num_workers_returned, 0); ASSERT_EQ(raylet_client->num_workers_returned, 0);
ASSERT_EQ(raylet_client->num_workers_disconnected, 0); ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
ASSERT_EQ(task_finisher->num_tasks_complete, 10); ASSERT_EQ(task_finisher->num_tasks_complete, 10);
@ -1525,14 +1534,14 @@ TEST(DirectTaskTransportTest, TestPipeliningNumberOfWorkersRequested) {
ASSERT_EQ(worker_client->callbacks.size(), 15); ASSERT_EQ(worker_client->callbacks.size(), 15);
// Submit last 5 tasks, and check that the total number of workers requested is still // Submit last 5 tasks, and check that the total number of workers requested is still
// 2 // 3
for (int i = 1; i <= 5; i++) { for (int i = 1; i <= 5; i++) {
auto task = tasks.front(); auto task = tasks.front();
ASSERT_TRUE(submitter.SubmitTask(task).ok()); ASSERT_TRUE(submitter.SubmitTask(task).ok());
tasks.erase(tasks.begin()); tasks.erase(tasks.begin());
} }
ASSERT_EQ(tasks.size(), 0); ASSERT_EQ(tasks.size(), 0);
ASSERT_EQ(raylet_client->num_workers_requested, 2); ASSERT_EQ(raylet_client->num_workers_requested, 3);
ASSERT_EQ(raylet_client->num_workers_returned, 0); ASSERT_EQ(raylet_client->num_workers_returned, 0);
ASSERT_EQ(raylet_client->num_workers_disconnected, 0); ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
ASSERT_EQ(task_finisher->num_tasks_complete, 10); ASSERT_EQ(task_finisher->num_tasks_complete, 10);
@ -1545,19 +1554,436 @@ TEST(DirectTaskTransportTest, TestPipeliningNumberOfWorkersRequested) {
for (int i = 1; i <= 20; i++) { for (int i = 1; i <= 20; i++) {
ASSERT_TRUE(worker_client->ReplyPushTask()); ASSERT_TRUE(worker_client->ReplyPushTask());
} }
ASSERT_EQ(raylet_client->num_workers_requested, 2); ASSERT_EQ(raylet_client->num_workers_requested, 3);
ASSERT_EQ(raylet_client->num_workers_returned, 2); ASSERT_EQ(raylet_client->num_workers_returned, 2);
ASSERT_EQ(raylet_client->num_workers_disconnected, 0); ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
ASSERT_EQ(task_finisher->num_tasks_complete, 30); ASSERT_EQ(task_finisher->num_tasks_complete, 30);
ASSERT_EQ(task_finisher->num_tasks_failed, 0); ASSERT_EQ(task_finisher->num_tasks_failed, 0);
ASSERT_EQ(raylet_client->num_leases_canceled, 0); ASSERT_EQ(raylet_client->num_leases_canceled, 1);
ASSERT_EQ(worker_client->callbacks.size(), 0); ASSERT_EQ(worker_client->callbacks.size(), 0);
ASSERT_TRUE(raylet_client->ReplyCancelWorkerLease());
ASSERT_TRUE(raylet_client->GrantWorkerLease("nil", 0, NodeID::Nil(), /*cancel=*/true));
ASSERT_FALSE(raylet_client->ReplyCancelWorkerLease());
ASSERT_EQ(raylet_client->num_leases_canceled, 1);
// Check that there are no entries left in the scheduling_key_entries_ hashmap. These // Check that there are no entries left in the scheduling_key_entries_ hashmap. These
// would otherwise cause a memory leak. // would otherwise cause a memory leak.
ASSERT_TRUE(submitter.CheckNoSchedulingKeyEntriesPublic()); ASSERT_TRUE(submitter.CheckNoSchedulingKeyEntriesPublic());
} }
TEST(DirectTaskTransportTest, TestStealingTasks) {
rpc::Address address;
auto raylet_client = std::make_shared<MockRayletClient>();
auto worker_client = std::make_shared<MockWorkerClient>();
auto store = std::make_shared<CoreWorkerMemoryStore>();
auto client_pool = std::make_shared<rpc::CoreWorkerClientPool>(
[&](const rpc::Address &addr) { return worker_client; });
auto task_finisher = std::make_shared<MockTaskFinisher>();
auto actor_creator = std::make_shared<MockActorCreator>();
auto lease_policy = std::make_shared<MockLeasePolicy>();
// Set max_tasks_in_flight_per_worker to a value larger than 1 to enable the
// pipelining of task submissions. This is done by passing a
// max_tasks_in_flight_per_worker parameter to the CoreWorkerDirectTaskSubmitter.
uint32_t max_tasks_in_flight_per_worker = 10;
CoreWorkerDirectTaskSubmitter submitter(
address, raylet_client, client_pool, nullptr, lease_policy, store, task_finisher,
NodeID::Nil(), kLongTimeout, actor_creator, max_tasks_in_flight_per_worker);
// prepare 20 tasks and save them in a vector
std::vector<TaskSpecification> tasks;
for (int i = 0; i < 20; i++) {
tasks.push_back(BuildEmptyTaskSpec());
}
ASSERT_EQ(tasks.size(), 20);
// Submit 20 tasks, and check that 1 worker is requested.
for (int i = 1; i <= 20; i++) {
auto task = tasks.front();
ASSERT_TRUE(submitter.SubmitTask(task).ok());
tasks.erase(tasks.begin());
}
ASSERT_EQ(tasks.size(), 0);
ASSERT_EQ(raylet_client->num_workers_requested, 1);
ASSERT_EQ(task_finisher->num_tasks_complete, 0);
ASSERT_EQ(task_finisher->num_tasks_failed, 0);
ASSERT_EQ(raylet_client->num_leases_canceled, 0);
ASSERT_EQ(worker_client->callbacks.size(), 0);
// Grant a worker lease, and check that one more worker is requested due to the Eager
// Worker Requesting Mode.
std::string worker1_id = "worker1_ID_abcdefghijklmnopq";
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1001, NodeID::Nil(), false,
worker1_id));
ASSERT_EQ(raylet_client->num_workers_requested, 2);
ASSERT_EQ(raylet_client->num_workers_returned, 0);
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
ASSERT_EQ(task_finisher->num_tasks_complete, 0);
ASSERT_EQ(task_finisher->num_tasks_failed, 0);
ASSERT_EQ(raylet_client->num_leases_canceled, 0);
ASSERT_EQ(worker_client->callbacks.size(), 10);
ASSERT_EQ(worker_client->steal_callbacks.size(), 0);
std::string worker2_id = "worker2_ID_abcdefghijklmnopq";
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1002, NodeID::Nil(), false,
worker2_id));
ASSERT_EQ(raylet_client->num_workers_requested, 3);
ASSERT_EQ(raylet_client->num_workers_returned, 0);
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
ASSERT_EQ(task_finisher->num_tasks_complete, 0);
ASSERT_EQ(task_finisher->num_tasks_failed, 0);
ASSERT_EQ(raylet_client->num_leases_canceled, 0);
ASSERT_EQ(worker_client->callbacks.size(), 20);
ASSERT_EQ(worker_client->steal_callbacks.size(), 0);
// First worker runs the first 10 tasks
for (int i = 1; i <= 10; i++) {
ASSERT_TRUE(worker_client->ReplyPushTask());
}
// First worker begins stealing from the second worker
ASSERT_EQ(raylet_client->num_workers_requested, 3);
ASSERT_EQ(raylet_client->num_workers_returned, 0);
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
ASSERT_EQ(task_finisher->num_tasks_complete, 10);
ASSERT_EQ(task_finisher->num_tasks_failed, 0);
ASSERT_EQ(raylet_client->num_leases_canceled, 0);
ASSERT_EQ(worker_client->callbacks.size(), 10);
ASSERT_EQ(worker_client->steal_callbacks.size(), 1);
// 5 tasks get stolen!
for (int i = 1; i <= 5; i++) {
ASSERT_TRUE(worker_client->ReplyPushTask(Status::OK(), false, true));
}
ASSERT_EQ(raylet_client->num_workers_requested, 3);
ASSERT_EQ(raylet_client->num_workers_returned, 0);
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
ASSERT_EQ(task_finisher->num_tasks_complete, 10);
ASSERT_EQ(task_finisher->num_tasks_failed, 0);
ASSERT_EQ(raylet_client->num_leases_canceled, 0);
ASSERT_EQ(worker_client->callbacks.size(), 5);
ASSERT_EQ(worker_client->steal_callbacks.size(), 1);
// The 5 stolen tasks are forwarded from the victim (2nd worker) to the thief (1st
// worker)
std::vector<TaskSpecification> tasks_stolen;
for (int i = 0; i < 5; i++) {
tasks_stolen.push_back(BuildEmptyTaskSpec());
}
ASSERT_TRUE(worker_client->ReplyStealTasks(Status::OK(), tasks_stolen));
tasks_stolen.clear();
ASSERT_TRUE(tasks_stolen.empty());
ASSERT_EQ(raylet_client->num_workers_requested, 3);
ASSERT_EQ(raylet_client->num_workers_returned, 0);
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
ASSERT_EQ(task_finisher->num_tasks_complete, 10);
ASSERT_EQ(task_finisher->num_tasks_failed, 0);
ASSERT_EQ(raylet_client->num_leases_canceled, 0);
ASSERT_EQ(worker_client->callbacks.size(), 10);
ASSERT_EQ(worker_client->steal_callbacks.size(), 0);
// The second worker finishes its workload of 5 tasks and begins stealing from the first
// worker
for (int i = 1; i <= 5; i++) {
ASSERT_TRUE(worker_client->ReplyPushTask());
}
ASSERT_EQ(raylet_client->num_workers_requested, 3);
ASSERT_EQ(raylet_client->num_workers_returned, 0);
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
ASSERT_EQ(task_finisher->num_tasks_complete, 15);
ASSERT_EQ(task_finisher->num_tasks_failed, 0);
ASSERT_EQ(raylet_client->num_leases_canceled, 0);
ASSERT_EQ(worker_client->callbacks.size(), 5);
ASSERT_EQ(worker_client->steal_callbacks.size(), 1);
// The second worker steals floor(5/2)=2 tasks from the first worker
for (int i = 1; i <= 2; i++) {
ASSERT_TRUE(worker_client->ReplyPushTask(Status::OK(), false, true));
}
ASSERT_EQ(raylet_client->num_workers_requested, 3);
ASSERT_EQ(raylet_client->num_workers_returned, 0);
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
ASSERT_EQ(task_finisher->num_tasks_complete, 15);
ASSERT_EQ(task_finisher->num_tasks_failed, 0);
ASSERT_EQ(raylet_client->num_leases_canceled, 0);
ASSERT_EQ(worker_client->callbacks.size(), 3);
ASSERT_EQ(worker_client->steal_callbacks.size(), 1);
ASSERT_TRUE(tasks_stolen.empty());
for (int i = 0; i < 2; i++) {
tasks_stolen.push_back(BuildEmptyTaskSpec());
}
ASSERT_FALSE(tasks_stolen.empty());
ASSERT_TRUE(worker_client->ReplyStealTasks(Status::OK(), tasks_stolen));
tasks_stolen.clear();
ASSERT_TRUE(tasks_stolen.empty());
ASSERT_EQ(raylet_client->num_workers_requested, 3);
ASSERT_EQ(raylet_client->num_workers_returned, 0);
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
ASSERT_EQ(task_finisher->num_tasks_complete, 15);
ASSERT_EQ(task_finisher->num_tasks_failed, 0);
ASSERT_EQ(raylet_client->num_leases_canceled, 0);
ASSERT_EQ(worker_client->callbacks.size(), 5);
ASSERT_EQ(worker_client->steal_callbacks.size(), 0);
// The first worker executes the remaining 3 tasks (the ones not stolen) and returns
for (int i = 1; i <= 3; i++) {
ASSERT_TRUE(worker_client->ReplyPushTask());
}
ASSERT_EQ(raylet_client->num_workers_requested, 3);
ASSERT_EQ(raylet_client->num_workers_returned, 1);
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
ASSERT_EQ(task_finisher->num_tasks_complete, 18);
ASSERT_EQ(task_finisher->num_tasks_failed, 0);
ASSERT_EQ(raylet_client->num_leases_canceled, 1);
ASSERT_EQ(worker_client->callbacks.size(), 2);
ASSERT_EQ(worker_client->steal_callbacks.size(), 0);
// The second worker executes the stolen 2 tasks and returns.
for (int i = 1; i <= 2; i++) {
ASSERT_TRUE(worker_client->ReplyPushTask());
}
ASSERT_EQ(raylet_client->num_workers_requested, 3);
ASSERT_EQ(raylet_client->num_workers_returned, 2);
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
ASSERT_EQ(task_finisher->num_tasks_complete, 20);
ASSERT_EQ(task_finisher->num_tasks_failed, 0);
ASSERT_EQ(raylet_client->num_leases_canceled, 2);
ASSERT_EQ(worker_client->callbacks.size(), 0);
ASSERT_EQ(worker_client->steal_callbacks.size(), 0);
}
TEST(DirectTaskTransportTest, TestNoStealingByExpiredWorker) {
rpc::Address address;
auto raylet_client = std::make_shared<MockRayletClient>();
auto worker_client = std::make_shared<MockWorkerClient>();
auto store = std::make_shared<CoreWorkerMemoryStore>();
auto client_pool = std::make_shared<rpc::CoreWorkerClientPool>(
[&](const rpc::Address &addr) { return worker_client; });
auto task_finisher = std::make_shared<MockTaskFinisher>();
auto actor_creator = std::make_shared<MockActorCreator>();
auto lease_policy = std::make_shared<MockLeasePolicy>();
// Set max_tasks_in_flight_per_worker to a value larger than 1 to enable the
// pipelining of task submissions. This is done by passing a
// max_tasks_in_flight_per_worker parameter to the CoreWorkerDirectTaskSubmitter.
uint32_t max_tasks_in_flight_per_worker = 10;
CoreWorkerDirectTaskSubmitter submitter(
address, raylet_client, client_pool, nullptr, lease_policy, store, task_finisher,
NodeID::Nil(), 1000, actor_creator, max_tasks_in_flight_per_worker);
// prepare 30 tasks and save them in a vector
std::vector<TaskSpecification> tasks;
for (int i = 0; i < 30; i++) {
tasks.push_back(BuildEmptyTaskSpec());
}
ASSERT_EQ(tasks.size(), 30);
// Submit the tasks, and check that one worker is requested.
for (int i = 1; i <= 30; i++) {
auto task = tasks.front();
ASSERT_TRUE(submitter.SubmitTask(task).ok());
tasks.erase(tasks.begin());
}
ASSERT_EQ(tasks.size(), 0);
ASSERT_EQ(raylet_client->num_workers_requested, 1);
ASSERT_EQ(task_finisher->num_tasks_complete, 0);
ASSERT_EQ(task_finisher->num_tasks_failed, 0);
ASSERT_EQ(raylet_client->num_leases_canceled, 0);
ASSERT_EQ(worker_client->callbacks.size(), 0);
// Grant a worker lease, and check that one more worker is requested due to the Eager
// Worker Requesting Mode.
std::string worker1_id = "worker1_ID_abcdefghijklmnopq";
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1001, NodeID::Nil(), false,
worker1_id));
ASSERT_EQ(raylet_client->num_workers_requested, 2);
ASSERT_EQ(raylet_client->num_workers_returned, 0);
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
ASSERT_EQ(task_finisher->num_tasks_complete, 0);
ASSERT_EQ(task_finisher->num_tasks_failed, 0);
ASSERT_EQ(raylet_client->num_leases_canceled, 0);
ASSERT_EQ(worker_client->callbacks.size(), 10);
ASSERT_EQ(worker_client->steal_callbacks.size(), 0);
// Grant a second worker lease, and check that one more worker is requested due to the
// Eager Worker Requesting Mode.
std::string worker2_id = "worker2_ID_abcdefghijklmnopq";
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1002, NodeID::Nil(), false,
worker2_id));
ASSERT_EQ(raylet_client->num_workers_requested, 3);
ASSERT_EQ(raylet_client->num_workers_returned, 0);
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
ASSERT_EQ(task_finisher->num_tasks_complete, 0);
ASSERT_EQ(task_finisher->num_tasks_failed, 0);
ASSERT_EQ(raylet_client->num_leases_canceled, 0);
ASSERT_EQ(worker_client->callbacks.size(), 20);
ASSERT_EQ(worker_client->steal_callbacks.size(), 0);
// Grant a third worker lease, and check that one more worker is requested due to the
// Eager Worker Requesting Mode.
std::string worker3_id = "worker3_ID_abcdefghijklmnopq";
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1003, NodeID::Nil(), false,
worker3_id));
ASSERT_EQ(raylet_client->num_workers_requested, 4);
ASSERT_EQ(raylet_client->num_workers_returned, 0);
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
ASSERT_EQ(task_finisher->num_tasks_complete, 0);
ASSERT_EQ(task_finisher->num_tasks_failed, 0);
ASSERT_EQ(raylet_client->num_leases_canceled, 0);
ASSERT_EQ(worker_client->callbacks.size(), 30);
ASSERT_EQ(worker_client->steal_callbacks.size(), 0);
// First worker runs the first 9 tasks and returns an error on completion of the last
// one (10th task).
for (int i = 1; i <= 10; i++) {
bool found_error = (i == 10);
auto status = Status::OK();
ASSERT_TRUE(status.ok());
if (found_error) {
status = Status::UnknownError("Worker has experienced an unknown error!");
ASSERT_FALSE(status.ok());
}
ASSERT_TRUE(worker_client->ReplyPushTask(status));
}
// Check that the first worker does not start stealing, and that it is disconnected
// (not returned to the Raylet) after the error.
ASSERT_EQ(raylet_client->num_workers_requested, 4);
ASSERT_EQ(raylet_client->num_workers_returned, 0);
ASSERT_EQ(raylet_client->num_workers_disconnected, 1);
ASSERT_EQ(task_finisher->num_tasks_complete, 9);
ASSERT_EQ(task_finisher->num_tasks_failed, 1);
ASSERT_EQ(raylet_client->num_leases_canceled, 0);
ASSERT_EQ(worker_client->callbacks.size(), 20);
ASSERT_EQ(worker_client->steal_callbacks.size(), 0);
// Second worker runs the first 9 tasks. Then we let its lease expire, and check that it
// does not initiate stealing.
for (int i = 1; i <= 9; i++) {
ASSERT_TRUE(worker_client->ReplyPushTask());
}
std::this_thread::sleep_for(
std::chrono::milliseconds(2000)); // Sleep for 2s, causing the 1s lease to time out.
ASSERT_TRUE(worker_client->ReplyPushTask());
// Check that the second worker does not start stealing, and that it is returned to the
// Raylet instead.
ASSERT_EQ(raylet_client->num_workers_requested, 4);
ASSERT_EQ(raylet_client->num_workers_returned, 1);
ASSERT_EQ(raylet_client->num_workers_disconnected, 1);
ASSERT_EQ(task_finisher->num_tasks_complete, 19);
ASSERT_EQ(task_finisher->num_tasks_failed, 1);
ASSERT_EQ(raylet_client->num_leases_canceled, 0);
ASSERT_EQ(worker_client->callbacks.size(), 10);
ASSERT_EQ(worker_client->steal_callbacks.size(), 0);
// Last worker finishes its workload and returns
for (int i = 1; i <= 10; i++) {
ASSERT_TRUE(worker_client->ReplyPushTask());
}
ASSERT_EQ(raylet_client->num_workers_requested, 4);
ASSERT_EQ(raylet_client->num_workers_returned, 2);
ASSERT_EQ(raylet_client->num_workers_disconnected, 1);
ASSERT_EQ(task_finisher->num_tasks_complete, 29);
ASSERT_EQ(task_finisher->num_tasks_failed, 1);
ASSERT_EQ(raylet_client->num_leases_canceled, 0);
ASSERT_EQ(worker_client->callbacks.size(), 0);
ASSERT_EQ(worker_client->steal_callbacks.size(), 0);
}
TEST(DirectTaskTransportTest, TestNoWorkerRequestedIfStealingUnavailable) {
rpc::Address address;
auto raylet_client = std::make_shared<MockRayletClient>();
auto worker_client = std::make_shared<MockWorkerClient>();
auto store = std::make_shared<CoreWorkerMemoryStore>();
auto client_pool = std::make_shared<rpc::CoreWorkerClientPool>(
[&](const rpc::Address &addr) { return worker_client; });
auto task_finisher = std::make_shared<MockTaskFinisher>();
auto actor_creator = std::make_shared<MockActorCreator>();
auto lease_policy = std::make_shared<MockLeasePolicy>();
// Set max_tasks_in_flight_per_worker to a value larger than 1 to enable the
// pipelining of task submissions. This is done by passing a
// max_tasks_in_flight_per_worker parameter to the CoreWorkerDirectTaskSubmitter.
uint32_t max_tasks_in_flight_per_worker = 10;
CoreWorkerDirectTaskSubmitter submitter(
address, raylet_client, client_pool, nullptr, lease_policy, store, task_finisher,
NodeID::Nil(), kLongTimeout, actor_creator, max_tasks_in_flight_per_worker);
// prepare 10 tasks and save them in a vector
std::vector<TaskSpecification> tasks;
for (int i = 0; i < 10; i++) {
tasks.push_back(BuildEmptyTaskSpec());
}
ASSERT_EQ(tasks.size(), 10);
// submit all 10 tasks
for (int i = 1; i <= 10; i++) {
auto task = tasks.front();
ASSERT_TRUE(submitter.SubmitTask(task).ok());
tasks.erase(tasks.begin());
}
ASSERT_EQ(tasks.size(), 0);
ASSERT_EQ(raylet_client->num_workers_requested, 1);
ASSERT_EQ(task_finisher->num_tasks_complete, 0);
ASSERT_EQ(task_finisher->num_tasks_failed, 0);
ASSERT_EQ(raylet_client->num_leases_canceled, 0);
ASSERT_EQ(worker_client->callbacks.size(), 0);
// Grant a worker lease, and check that one more worker is requested due to the Eager
// Worker Requesting Mode, even if the task queue is empty.
std::string worker1_id = "worker1_ID_abcdefghijklmnopq";
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1001, NodeID::Nil(), false,
worker1_id));
ASSERT_EQ(raylet_client->num_workers_requested, 2);
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
ASSERT_EQ(raylet_client->num_workers_returned, 0);
ASSERT_EQ(task_finisher->num_tasks_complete, 0);
ASSERT_EQ(task_finisher->num_tasks_failed, 0);
ASSERT_EQ(raylet_client->num_leases_canceled, 0);
ASSERT_EQ(worker_client->callbacks.size(), 10);
ASSERT_EQ(worker_client->steal_callbacks.size(), 0);
// Execute 9 tasks
for (int i = 1; i <= 9; i++) {
ASSERT_TRUE(worker_client->ReplyPushTask());
}
ASSERT_EQ(raylet_client->num_workers_requested, 2);
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
ASSERT_EQ(raylet_client->num_workers_returned, 0);
ASSERT_EQ(task_finisher->num_tasks_complete, 9);
ASSERT_EQ(task_finisher->num_tasks_failed, 0);
ASSERT_EQ(raylet_client->num_leases_canceled, 0);
ASSERT_EQ(worker_client->callbacks.size(), 1);
ASSERT_EQ(worker_client->steal_callbacks.size(), 0);
// Grant a second worker, which returns immediately because there are no stealable
// tasks.
std::string worker2_id = "worker2_ID_abcdefghijklmnopq";
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1002, NodeID::Nil(), false,
worker2_id));
// Check that no more workers are requested now that there are no more stealable tasks.
ASSERT_EQ(raylet_client->num_workers_requested, 2);
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
ASSERT_EQ(raylet_client->num_workers_returned, 1);
ASSERT_EQ(task_finisher->num_tasks_complete, 9);
ASSERT_EQ(task_finisher->num_tasks_failed, 0);
ASSERT_EQ(raylet_client->num_leases_canceled, 0);
ASSERT_EQ(worker_client->callbacks.size(), 1);
ASSERT_EQ(worker_client->steal_callbacks.size(), 0);
// Last task runs and first worker is returned
ASSERT_TRUE(worker_client->ReplyPushTask());
ASSERT_EQ(raylet_client->num_workers_requested, 2);
ASSERT_EQ(raylet_client->num_workers_returned, 2);
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
ASSERT_EQ(task_finisher->num_tasks_complete, 10);
ASSERT_EQ(task_finisher->num_tasks_failed, 0);
ASSERT_EQ(raylet_client->num_leases_canceled, 0);
ASSERT_EQ(worker_client->callbacks.size(), 0);
ASSERT_EQ(worker_client->steal_callbacks.size(), 0);
}
} // namespace ray
int main(int argc, char **argv) {

@@ -42,15 +42,18 @@ TEST(SchedulingQueueTest, TestInOrder) {
ActorSchedulingQueue queue(io_service, waiter);
int n_ok = 0;
int n_rej = 0;
int n_steal = 0;
auto fn_ok = [&n_ok](rpc::SendReplyCallback callback) { n_ok++; };
auto fn_rej = [&n_rej](rpc::SendReplyCallback callback) { n_rej++; };
auto fn_steal = [&n_steal](rpc::SendReplyCallback callback) { n_steal++; };
queue.Add(0, -1, fn_ok, fn_rej, nullptr, fn_steal);
queue.Add(1, -1, fn_ok, fn_rej, nullptr, fn_steal);
queue.Add(2, -1, fn_ok, fn_rej, nullptr, fn_steal);
queue.Add(3, -1, fn_ok, fn_rej, nullptr, fn_steal);
io_service.run();
ASSERT_EQ(n_ok, 4);
ASSERT_EQ(n_rej, 0);
ASSERT_EQ(n_steal, 0);
}
TEST(SchedulingQueueTest, TestWaitForObjects) {
@@ -62,12 +65,19 @@ TEST(SchedulingQueueTest, TestWaitForObjects) {
ActorSchedulingQueue queue(io_service, waiter);
int n_ok = 0;
int n_rej = 0;
int n_steal = 0;
auto fn_ok = [&n_ok](rpc::SendReplyCallback callback) { n_ok++; };
auto fn_rej = [&n_rej](rpc::SendReplyCallback callback) { n_rej++; };
auto fn_steal = [&n_steal](rpc::SendReplyCallback callback) { n_steal++; };
queue.Add(0, -1, fn_ok, fn_rej, nullptr, fn_steal);
queue.Add(1, -1, fn_ok, fn_rej, nullptr, fn_steal, TaskID::Nil(),
ObjectIdsToRefs({obj1}));
queue.Add(2, -1, fn_ok, fn_rej, nullptr, fn_steal, TaskID::Nil(),
ObjectIdsToRefs({obj2}));
queue.Add(3, -1, fn_ok, fn_rej, nullptr, fn_steal, TaskID::Nil(),
ObjectIdsToRefs({obj3}));
ASSERT_EQ(n_ok, 1);
waiter.Complete(0);
@@ -78,6 +88,8 @@ TEST(SchedulingQueueTest, TestWaitForObjects) {
waiter.Complete(1);
ASSERT_EQ(n_ok, 4);
ASSERT_EQ(n_steal, 0);
}
TEST(SchedulingQueueTest, TestWaitForObjectsNotSubjectToSeqTimeout) {
@@ -87,15 +99,21 @@ TEST(SchedulingQueueTest, TestWaitForObjectsNotSubjectToSeqTimeout) {
ActorSchedulingQueue queue(io_service, waiter);
int n_ok = 0;
int n_rej = 0;
int n_steal = 0;
auto fn_ok = [&n_ok](rpc::SendReplyCallback callback) { n_ok++; };
auto fn_rej = [&n_rej](rpc::SendReplyCallback callback) { n_rej++; };
auto fn_steal = [&n_steal](rpc::SendReplyCallback callback) { n_steal++; };
queue.Add(0, -1, fn_ok, fn_rej, nullptr, fn_steal);
queue.Add(1, -1, fn_ok, fn_rej, nullptr, fn_steal, TaskID::Nil(),
ObjectIdsToRefs({obj1}));
ASSERT_EQ(n_ok, 1);
io_service.run();
ASSERT_EQ(n_rej, 0);
waiter.Complete(0);
ASSERT_EQ(n_ok, 2);
ASSERT_EQ(n_steal, 0);
}
TEST(SchedulingQueueTest, TestOutOfOrder) {
@@ -104,15 +122,18 @@ TEST(SchedulingQueueTest, TestOutOfOrder) {
ActorSchedulingQueue queue(io_service, waiter);
int n_ok = 0;
int n_rej = 0;
int n_steal = 0;
auto fn_ok = [&n_ok](rpc::SendReplyCallback callback) { n_ok++; };
auto fn_rej = [&n_rej](rpc::SendReplyCallback callback) { n_rej++; };
auto fn_steal = [&n_steal](rpc::SendReplyCallback callback) { n_steal++; };
queue.Add(2, -1, fn_ok, fn_rej, nullptr, fn_steal);
queue.Add(0, -1, fn_ok, fn_rej, nullptr, fn_steal);
queue.Add(3, -1, fn_ok, fn_rej, nullptr, fn_steal);
queue.Add(1, -1, fn_ok, fn_rej, nullptr, fn_steal);
io_service.run();
ASSERT_EQ(n_ok, 4);
ASSERT_EQ(n_rej, 0);
ASSERT_EQ(n_steal, 0);
}
TEST(SchedulingQueueTest, TestSeqWaitTimeout) {
@@ -121,20 +142,23 @@ TEST(SchedulingQueueTest, TestSeqWaitTimeout) {
ActorSchedulingQueue queue(io_service, waiter);
int n_ok = 0;
int n_rej = 0;
int n_steal = 0;
auto fn_ok = [&n_ok](rpc::SendReplyCallback callback) { n_ok++; };
auto fn_rej = [&n_rej](rpc::SendReplyCallback callback) { n_rej++; };
auto fn_steal = [&n_steal](rpc::SendReplyCallback callback) { n_steal++; };
queue.Add(2, -1, fn_ok, fn_rej, nullptr, fn_steal);
queue.Add(0, -1, fn_ok, fn_rej, nullptr, fn_steal);
queue.Add(3, -1, fn_ok, fn_rej, nullptr, fn_steal);
ASSERT_EQ(n_ok, 1);
ASSERT_EQ(n_rej, 0);
io_service.run(); // immediately triggers timeout
ASSERT_EQ(n_ok, 1);
ASSERT_EQ(n_rej, 2);
queue.Add(4, -1, fn_ok, fn_rej, nullptr, fn_steal);
queue.Add(5, -1, fn_ok, fn_rej, nullptr, fn_steal);
ASSERT_EQ(n_ok, 3);
ASSERT_EQ(n_rej, 2);
ASSERT_EQ(n_steal, 0);
}
TEST(SchedulingQueueTest, TestSkipAlreadyProcessedByClient) {
@@ -143,14 +167,17 @@ TEST(SchedulingQueueTest, TestSkipAlreadyProcessedByClient) {
ActorSchedulingQueue queue(io_service, waiter);
int n_ok = 0;
int n_rej = 0;
int n_steal = 0;
auto fn_ok = [&n_ok](rpc::SendReplyCallback callback) { n_ok++; };
auto fn_rej = [&n_rej](rpc::SendReplyCallback callback) { n_rej++; };
auto fn_steal = [&n_steal](rpc::SendReplyCallback callback) { n_steal++; };
queue.Add(2, 2, fn_ok, fn_rej, nullptr, fn_steal);
queue.Add(3, 2, fn_ok, fn_rej, nullptr, fn_steal);
queue.Add(1, 2, fn_ok, fn_rej, nullptr, fn_steal);
io_service.run();
ASSERT_EQ(n_ok, 1);
ASSERT_EQ(n_rej, 2);
ASSERT_EQ(n_steal, 0);
}
TEST(SchedulingQueueTest, TestCancelQueuedTask) {
@@ -158,18 +185,127 @@ TEST(SchedulingQueueTest, TestCancelQueuedTask) {
ASSERT_TRUE(queue->TaskQueueEmpty());
int n_ok = 0;
int n_rej = 0;
int n_steal = 0;
auto fn_ok = [&n_ok](rpc::SendReplyCallback callback) { n_ok++; };
auto fn_rej = [&n_rej](rpc::SendReplyCallback callback) { n_rej++; };
auto fn_steal = [&n_steal](rpc::SendReplyCallback callback) { n_steal++; };
queue->Add(-1, -1, fn_ok, fn_rej, nullptr, fn_steal);
queue->Add(-1, -1, fn_ok, fn_rej, nullptr, fn_steal);
queue->Add(-1, -1, fn_ok, fn_rej, nullptr, fn_steal);
queue->Add(-1, -1, fn_ok, fn_rej, nullptr, fn_steal);
queue->Add(-1, -1, fn_ok, fn_rej, nullptr, fn_steal);
ASSERT_TRUE(queue->CancelTaskIfFound(TaskID::Nil()));
ASSERT_FALSE(queue->TaskQueueEmpty());
queue->ScheduleRequests();
ASSERT_EQ(n_ok, 4);
ASSERT_EQ(n_rej, 0);
ASSERT_EQ(n_steal, 0);
}
TEST(SchedulingQueueTest, TestStealingOneTask) {
NormalSchedulingQueue *queue = new NormalSchedulingQueue();
ASSERT_TRUE(queue->TaskQueueEmpty());
int n_ok = 0;
int n_rej = 0;
int n_steal = 0;
auto fn_ok = [&n_ok](rpc::SendReplyCallback callback) { n_ok++; };
auto fn_rej = [&n_rej](rpc::SendReplyCallback callback) { n_rej++; };
auto fn_steal = [&n_steal](rpc::SendReplyCallback callback) { n_steal++; };
queue->Add(-1, -1, fn_ok, fn_rej, nullptr, fn_steal);
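// With only one task in the queue, Steal() is expected to take nothing:
// a victim always keeps at least half of its queued work.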
auto reply = rpc::StealTasksReply();
size_t n_stolen = reply.stolen_tasks_ids_size();
ASSERT_EQ(n_stolen, 0);
ASSERT_EQ(queue->Steal(&reply), 0);
n_stolen = reply.stolen_tasks_ids_size();
ASSERT_EQ(n_stolen, 0);
ASSERT_FALSE(queue->TaskQueueEmpty());
queue->ScheduleRequests();
ASSERT_TRUE(queue->TaskQueueEmpty());
ASSERT_EQ(n_ok, 1);
ASSERT_EQ(n_rej, 0);
ASSERT_EQ(n_steal, 0);
}
TEST(SchedulingQueueTest, TestStealingEvenNumberTasks) {
NormalSchedulingQueue *queue = new NormalSchedulingQueue();
ASSERT_TRUE(queue->TaskQueueEmpty());
int n_ok = 0;
int n_rej = 0;
int n_steal = 0;
auto fn_ok = [&n_ok](rpc::SendReplyCallback callback) { n_ok++; };
auto fn_rej = [&n_rej](rpc::SendReplyCallback callback) { n_rej++; };
auto fn_steal = [&n_steal](rpc::SendReplyCallback callback) { n_steal++; };
queue->Add(-1, -1, fn_ok, fn_rej, nullptr, fn_steal);
queue->Add(-1, -1, fn_ok, fn_rej, nullptr, fn_steal);
queue->Add(-1, -1, fn_ok, fn_rej, nullptr, fn_steal);
queue->Add(-1, -1, fn_ok, fn_rej, nullptr, fn_steal);
queue->Add(-1, -1, fn_ok, fn_rej, nullptr, fn_steal);
queue->Add(-1, -1, fn_ok, fn_rej, nullptr, fn_steal);
queue->Add(-1, -1, fn_ok, fn_rej, nullptr, fn_steal);
queue->Add(-1, -1, fn_ok, fn_rej, nullptr, fn_steal);
queue->Add(-1, -1, fn_ok, fn_rej, nullptr, fn_steal);
queue->Add(-1, -1, fn_ok, fn_rej, nullptr, fn_steal);
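// Ten tasks are queued, so the steal-half policy should hand over floor(10 / 2) == 5 of them.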
auto reply = rpc::StealTasksReply();
size_t n_stolen = reply.stolen_tasks_ids_size();
ASSERT_EQ(n_stolen, 0);
ASSERT_EQ(queue->Steal(&reply), 5);
n_stolen = reply.stolen_tasks_ids_size();
ASSERT_EQ(n_stolen, 5);
ASSERT_FALSE(queue->TaskQueueEmpty());
queue->ScheduleRequests();
queue->ScheduleRequests();
queue->ScheduleRequests();
queue->ScheduleRequests();
queue->ScheduleRequests();
ASSERT_TRUE(queue->TaskQueueEmpty());
ASSERT_EQ(n_ok, 5);
ASSERT_EQ(n_rej, 0);
ASSERT_EQ(n_steal, 5);
}
TEST(SchedulingQueueTest, TestStealingOddNumberTasks) {
NormalSchedulingQueue *queue = new NormalSchedulingQueue();
ASSERT_TRUE(queue->TaskQueueEmpty());
int n_ok = 0;
int n_rej = 0;
int n_steal = 0;
auto fn_ok = [&n_ok](rpc::SendReplyCallback callback) { n_ok++; };
auto fn_rej = [&n_rej](rpc::SendReplyCallback callback) { n_rej++; };
auto fn_steal = [&n_steal](rpc::SendReplyCallback callback) { n_steal++; };
queue->Add(-1, -1, fn_ok, fn_rej, nullptr, fn_steal);
queue->Add(-1, -1, fn_ok, fn_rej, nullptr, fn_steal);
queue->Add(-1, -1, fn_ok, fn_rej, nullptr, fn_steal);
queue->Add(-1, -1, fn_ok, fn_rej, nullptr, fn_steal);
queue->Add(-1, -1, fn_ok, fn_rej, nullptr, fn_steal);
queue->Add(-1, -1, fn_ok, fn_rej, nullptr, fn_steal);
queue->Add(-1, -1, fn_ok, fn_rej, nullptr, fn_steal);
queue->Add(-1, -1, fn_ok, fn_rej, nullptr, fn_steal);
queue->Add(-1, -1, fn_ok, fn_rej, nullptr, fn_steal);
queue->Add(-1, -1, fn_ok, fn_rej, nullptr, fn_steal);
queue->Add(-1, -1, fn_ok, fn_rej, nullptr, fn_steal);
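// Eleven tasks are queued, so floor(11 / 2) == 5 tasks should be stolen, leaving 6 to run locally.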
auto reply = rpc::StealTasksReply();
size_t n_stolen = reply.stolen_tasks_ids_size();
ASSERT_EQ(n_stolen, 0);
ASSERT_EQ(queue->Steal(&reply), 5);
n_stolen = reply.stolen_tasks_ids_size();
ASSERT_EQ(n_stolen, 5);
ASSERT_FALSE(queue->TaskQueueEmpty());
queue->ScheduleRequests();
queue->ScheduleRequests();
queue->ScheduleRequests();
queue->ScheduleRequests();
queue->ScheduleRequests();
queue->ScheduleRequests();
ASSERT_TRUE(queue->TaskQueueEmpty());
ASSERT_EQ(n_ok, 6);
ASSERT_EQ(n_rej, 0);
ASSERT_EQ(n_steal, 5);
}
} // namespace ray
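The three stealing tests above pin down the steal-half arithmetic used by NormalSchedulingQueue::Steal. As a rough standalone illustration (not part of this diff; the helper name StolenCount and the assertions are invented for the example), the expected number of stolen tasks for a queue of size n is:

#include <cassert>
#include <cstddef>

// Hypothetical helper mirroring the policy exercised above: a victim with
// n queued tasks gives up floor(n / 2) of them, and never its last task.
size_t StolenCount(size_t queued) { return queued <= 1 ? 0 : queued / 2; }

int main() {
  assert(StolenCount(1) == 0);   // TestStealingOneTask
  assert(StolenCount(10) == 5);  // TestStealingEvenNumberTasks
  assert(StolenCount(11) == 5);  // TestStealingOddNumberTasks
  return 0;
}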

@@ -530,6 +530,15 @@ void CoreWorkerDirectTaskReceiver::HandleTask(
send_reply_callback(Status::Invalid("client cancelled stale rpc"), nullptr, nullptr);
};
auto steal_callback = [this, task_spec,
reply](rpc::SendReplyCallback send_reply_callback) {
RAY_LOG(DEBUG) << "Task " << task_spec.TaskId() << " was stolen from "
<< worker_context_.GetWorkerID()
<< "'s non_actor_task_queue_! Setting reply->set_task_stolen(true)!";
reply->set_task_stolen(true);
send_reply_callback(Status::OK(), nullptr, nullptr);
};
auto dependencies = task_spec.GetDependencies(false);
if (task_spec.IsActorTask()) {
@@ -544,13 +553,15 @@ void CoreWorkerDirectTaskReceiver::HandleTask(
it->second->Add(request.sequence_number(), request.client_processed_up_to(),
std::move(accept_callback), std::move(reject_callback),
std::move(send_reply_callback), nullptr, task_spec.TaskId(),
dependencies);
} else {
// Add the normal task's callbacks to the non-actor scheduling queue.
normal_scheduling_queue_->Add(
request.sequence_number(), request.client_processed_up_to(),
std::move(accept_callback), std::move(reject_callback),
std::move(send_reply_callback), std::move(steal_callback), task_spec.TaskId(),
dependencies);
}
}
@@ -564,6 +575,16 @@ void CoreWorkerDirectTaskReceiver::RunNormalTasksFromQueue() {
normal_scheduling_queue_->ScheduleRequests();
}
void CoreWorkerDirectTaskReceiver::HandleStealTasks(
const rpc::StealTasksRequest &request, rpc::StealTasksReply *reply,
rpc::SendReplyCallback send_reply_callback) {
size_t n_tasks_stolen = normal_scheduling_queue_->Steal(reply);
RAY_LOG(DEBUG) << "Number of tasks stolen is " << n_tasks_stolen;
// send reply back
send_reply_callback(Status::OK(), nullptr, nullptr);
}
bool CoreWorkerDirectTaskReceiver::CancelQueuedNormalTask(TaskID task_id) {
// Look up the task to be canceled in the queue of normal tasks. If it is found and
// removed successfully, return true.

@@ -276,18 +276,28 @@ class CoreWorkerDirectActorTaskSubmitter
class InboundRequest {
public:
InboundRequest(){};
InboundRequest(std::function<void(rpc::SendReplyCallback)> accept_callback,
std::function<void(rpc::SendReplyCallback)> reject_callback,
std::function<void(rpc::SendReplyCallback)> steal_callback,
rpc::SendReplyCallback send_reply_callback, TaskID task_id,
bool has_dependencies)
: accept_callback_(std::move(accept_callback)),
reject_callback_(std::move(reject_callback)),
steal_callback_(std::move(steal_callback)),
send_reply_callback_(std::move(send_reply_callback)),
task_id(task_id),
has_pending_dependencies_(has_dependencies) {}
void Accept() { accept_callback_(std::move(send_reply_callback_)); }
void Cancel() { reject_callback_(std::move(send_reply_callback_)); }
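// Steal() records this task's id in the StealTasks reply and hands the reply
// callback to the steal callback, which notifies the task's owner that the task was stolen.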
void Steal(rpc::StealTasksReply *reply) {
reply->add_stolen_tasks_ids(task_id.Binary());
RAY_CHECK(TaskID::FromBinary(reply->stolen_tasks_ids(reply->stolen_tasks_ids_size() -
1)) == task_id);
steal_callback_(std::move(send_reply_callback_));
}
bool CanExecute() const { return !has_pending_dependencies_; }
ray::TaskID TaskID() const { return task_id; }
void MarkDependenciesSatisfied() { has_pending_dependencies_ = false; }
@@ -295,7 +305,9 @@ class InboundRequest {
private:
std::function<void(rpc::SendReplyCallback)> accept_callback_;
std::function<void(rpc::SendReplyCallback)> reject_callback_;
std::function<void(rpc::SendReplyCallback)> steal_callback_;
rpc::SendReplyCallback send_reply_callback_;
ray::TaskID task_id;
bool has_pending_dependencies_;
};
@@ -378,10 +390,13 @@ class SchedulingQueue {
std::function<void(rpc::SendReplyCallback)> accept_request,
std::function<void(rpc::SendReplyCallback)> reject_request,
rpc::SendReplyCallback send_reply_callback,
std::function<void(rpc::SendReplyCallback)> steal_request = nullptr,
TaskID task_id = TaskID::Nil(),
const std::vector<rpc::ObjectReference> &dependencies = {}) = 0;
virtual void ScheduleRequests() = 0;
virtual bool TaskQueueEmpty() const = 0;
virtual size_t Size() const = 0;
virtual size_t Steal(rpc::StealTasksReply *reply) = 0;
virtual bool CancelTaskIfFound(TaskID task_id) = 0;
virtual ~SchedulingQueue(){};
};
@@ -407,13 +422,27 @@ class ActorSchedulingQueue : public SchedulingQueue {
}
}
bool TaskQueueEmpty() const {
RAY_CHECK(false) << "TaskQueueEmpty() not implemented for actor queues";
// The return instruction will never be executed, but we need to include it
// nonetheless because this is a non-void function.
return false;
}
size_t Size() const {
RAY_CHECK(false) << "Size() not implemented for actor queues";
// The return instruction will never be executed, but we need to include it
// nonetheless because this is a non-void function.
return 0;
}
/// Add a new actor task's callbacks to the worker queue.
void Add(int64_t seq_no, int64_t client_processed_up_to,
std::function<void(rpc::SendReplyCallback)> accept_request,
std::function<void(rpc::SendReplyCallback)> reject_request,
rpc::SendReplyCallback send_reply_callback,
std::function<void(rpc::SendReplyCallback)> steal_request = nullptr,
TaskID task_id = TaskID::Nil(),
const std::vector<rpc::ObjectReference> &dependencies = {}) {
// A seq_no of -1 means no ordering constraint. Actor tasks must be executed in order.
RAY_CHECK(seq_no != -1);
@@ -425,9 +454,11 @@ class ActorSchedulingQueue : public SchedulingQueue {
next_seq_no_ = client_processed_up_to + 1;
}
RAY_LOG(DEBUG) << "Enqueue " << seq_no << " cur seqno " << next_seq_no_;
pending_actor_tasks_[seq_no] = InboundRequest(
std::move(accept_request), std::move(reject_request), std::move(steal_request),
std::move(send_reply_callback), task_id, dependencies.size() > 0);
if (dependencies.size() > 0) {
waiter_.Wait(dependencies, [seq_no, this]() {
RAY_CHECK(boost::this_thread::get_id() == main_thread_id_);
@@ -441,8 +472,15 @@ class ActorSchedulingQueue : public SchedulingQueue {
ScheduleRequests();
}
size_t Steal(rpc::StealTasksReply *reply) {
RAY_CHECK(false) << "Cannot steal actor tasks";
// The return instruction will never be executed, but we need to include it
// nonetheless because this is a non-void function.
return 0;
}
// We don't allow the cancellation of actor tasks, so invoking CancelTaskIfFound
// results in a fatal error.
bool CancelTaskIfFound(TaskID task_id) {
RAY_CHECK(false) << "Cannot cancel actor tasks";
// The return instruction will never be executed, but we need to include it
@@ -550,24 +588,60 @@ class NormalSchedulingQueue : public SchedulingQueue {
return pending_normal_tasks_.empty();
}
// Returns the current size of the task queue.
size_t Size() const {
absl::MutexLock lock(&mu_);
return pending_normal_tasks_.size();
}
/// Add a new task's callbacks to the worker queue.
void Add(int64_t seq_no, int64_t client_processed_up_to,
std::function<void(rpc::SendReplyCallback)> accept_request,
std::function<void(rpc::SendReplyCallback)> reject_request,
rpc::SendReplyCallback send_reply_callback,
std::function<void(rpc::SendReplyCallback)> steal_request = nullptr,
TaskID task_id = TaskID::Nil(),
const std::vector<rpc::ObjectReference> &dependencies = {}) {
absl::MutexLock lock(&mu_);
// Normal tasks should not have ordering constraints.
RAY_CHECK(seq_no == -1);
// Create an InboundRequest object for the new task, and add it to the queue.
pending_normal_tasks_.push_back(InboundRequest(
std::move(accept_request), std::move(reject_request), std::move(steal_request),
std::move(send_reply_callback), task_id, dependencies.size() > 0));
}
/// Steal up to max_tasks tasks by removing them from the queue and responding to the
/// owner.
size_t Steal(rpc::StealTasksReply *reply) {
size_t tasks_stolen = 0;
absl::MutexLock lock(&mu_);
if (pending_normal_tasks_.size() <= 1) {
RAY_LOG(DEBUG) << "We don't have enough tasks to steal, so we return early!";
return tasks_stolen;
}
size_t half = pending_normal_tasks_.size() / 2;
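// Tasks are stolen from the back of the deque, so the tasks that have been queued
// the longest stay with the victim and keep their original order.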
for (tasks_stolen = 0; tasks_stolen < half; tasks_stolen++) {
RAY_CHECK(!pending_normal_tasks_.empty());
InboundRequest tail = pending_normal_tasks_.back();
pending_normal_tasks_.pop_back();
int stolen_task_ids = reply->stolen_tasks_ids_size();
tail.Steal(reply);
RAY_CHECK(reply->stolen_tasks_ids_size() == stolen_task_ids + 1);
}
return tasks_stolen;
}
// Search for an InboundRequest associated with the task that we are trying to cancel.
// If found, remove the InboundRequest from the queue and return true. Otherwise,
// return false.
bool CancelTaskIfFound(TaskID task_id) {
absl::MutexLock lock(&mu_);
for (std::deque<InboundRequest>::reverse_iterator it = pending_normal_tasks_.rbegin();
@@ -641,6 +715,15 @@ class CoreWorkerDirectTaskReceiver {
/// Pop tasks from the queue and execute them sequentially
void RunNormalTasksFromQueue();
/// Handle a `StealTask` request.
///
/// \param[in] request The request message.
/// \param[out] reply The reply message.
/// \param[in] send_reply_callback The callback to be called when the request is done.
void HandleStealTasks(const rpc::StealTasksRequest &request,
rpc::StealTasksReply *reply,
rpc::SendReplyCallback send_reply_callback);
bool CancelQueuedNormalTask(TaskID task_id);
private:

@@ -77,6 +77,8 @@ Status CoreWorkerDirectTaskSubmitter::SubmitTask(TaskSpecification task_spec) {
: ActorID::Nil());
auto &scheduling_key_entry = scheduling_key_entries_[scheduling_key];
scheduling_key_entry.task_queue.push_back(task_spec);
scheduling_key_entry.resource_spec = task_spec;
if (!scheduling_key_entry.AllPipelinesToWorkersFull(
max_tasks_in_flight_per_worker_)) {
// The pipelines to the current workers are not full yet, so we don't need more
@@ -118,8 +120,8 @@ void CoreWorkerDirectTaskSubmitter::AddWorkerLeaseClient(
const SchedulingKey &scheduling_key) {
client_cache_->GetOrConnect(addr.ToProto());
int64_t expiration = current_time_ms() + lease_timeout_ms_;
LeaseEntry new_lease_entry =
LeaseEntry(std::move(lease_client), expiration, assigned_resources, scheduling_key);
worker_to_lease_entry_.emplace(addr, new_lease_entry);
auto &scheduling_key_entry = scheduling_key_entries_[scheduling_key];
@@ -127,25 +129,16 @@ void CoreWorkerDirectTaskSubmitter::AddWorkerLeaseClient(
RAY_CHECK(scheduling_key_entry.active_workers.size() >= 1);
}
void CoreWorkerDirectTaskSubmitter::ReturnWorker(const rpc::WorkerAddress addr,
bool was_error,
const SchedulingKey &scheduling_key) {
auto &scheduling_key_entry = scheduling_key_entries_[scheduling_key];
RAY_CHECK(scheduling_key_entry.active_workers.size() >= 1);
auto &lease_entry = worker_to_lease_entry_[addr];
RAY_CHECK(lease_entry.lease_client);
RAY_CHECK(lease_entry.tasks_in_flight == 0);
RAY_CHECK(lease_entry.WorkerIsStealing() == false);
// Decrement the number of active workers consuming tasks from the queue associated
// with the current scheduling_key
scheduling_key_entry.active_workers.erase(addr);
@@ -163,14 +156,201 @@ void CoreWorkerDirectTaskSubmitter::OnWorkerIdle(
worker_to_lease_entry_.erase(addr);
}
bool CoreWorkerDirectTaskSubmitter::FindOptimalVictimForStealing(
const SchedulingKey &scheduling_key, rpc::WorkerAddress thief_addr,
rpc::Address *victim_raw_addr) {
auto &scheduling_key_entry = scheduling_key_entries_[scheduling_key];
// Check that there is at least one worker (other than the thief) with the current
// SchedulingKey and that there are stealable tasks
if (scheduling_key_entry.active_workers.size() <= 1 ||
!scheduling_key_entry.StealableTasks()) {
return false;
}
// Iterate through the active workers with the relevant SchedulingKey, and select the
// best one for stealing by updating the victim_raw_addr (pointing to the designated
// victim) every time we find a candidate that is better than the incumbent. A candidate
// is better if: (1) the incumbent victim is the thief -- because this choice would be
// illegal (thief cannot steal from itself), so any alternative choice is better (2) the
// candidate is not the thief (otherwise, again, it cannot be designated as the victim),
// and it has more stealable tasks than the incumbent victim
*victim_raw_addr = scheduling_key_entry.active_workers.begin()->ToProto();
for (auto candidate_it = scheduling_key_entry.active_workers.begin();
candidate_it != scheduling_key_entry.active_workers.end(); candidate_it++) {
const rpc::WorkerAddress &candidate_addr = *candidate_it;
const auto &candidate_entry = worker_to_lease_entry_[candidate_addr];
const rpc::WorkerAddress victim_addr = rpc::WorkerAddress(*victim_raw_addr);
RAY_CHECK(worker_to_lease_entry_.find(victim_addr) != worker_to_lease_entry_.end());
const auto &victim_entry = worker_to_lease_entry_[victim_addr];
// Update the designated victim if the alternative candidate is a better choice than
// the incumbent victim
if (victim_addr.worker_id == thief_addr.worker_id ||
((candidate_entry.tasks_in_flight > victim_entry.tasks_in_flight) &&
candidate_addr.worker_id != thief_addr.worker_id)) {
// We copy the candidate's rpc::Address (instead of its rpc::WorkerAddress) because
// objects of type 'ray::rpc::WorkerAddress' cannot be assigned as their copy
// assignment operator is implicitly deleted
*victim_raw_addr = candidate_addr.ToProto();
}
}
const rpc::WorkerAddress victim_addr = rpc::WorkerAddress(*victim_raw_addr);
// We can't steal unless we can find a thief and a victim with distinct addresses/worker
// ids. In fact, if we allow stealing among workers with the same address/worker id, we
// will also necessarily enable self-stealing.
if ((victim_addr == thief_addr) || victim_addr.worker_id == thief_addr.worker_id) {
RAY_LOG(INFO) << "No victim available with address distinct from thief!";
RAY_LOG(INFO) << "victim_addr.worker_id: " << victim_addr.worker_id
<< " thief_addr.worker_id: " << thief_addr.worker_id;
return false;
}
const auto &victim_entry = worker_to_lease_entry_[victim_addr];
// Double check that the victim has the correct SchedulingKey
RAY_CHECK(victim_entry.scheduling_key == scheduling_key);
RAY_LOG(DEBUG) << "Victim is worker " << victim_addr.worker_id << " and has "
<< victim_entry.tasks_in_flight << " tasks in flight, "
<< " among which we estimate that " << victim_entry.tasks_in_flight / 2
<< " are available for stealing";
RAY_CHECK(scheduling_key_entry.total_tasks_in_flight >= victim_entry.tasks_in_flight);
if ((victim_entry.tasks_in_flight / 2) < 1) {
RAY_LOG(DEBUG) << "The designated victim does not have enough tasks to steal.";
return false;
}
return true;
}
void CoreWorkerDirectTaskSubmitter::StealTasksOrReturnWorker(
const rpc::WorkerAddress &thief_addr, bool was_error,
const SchedulingKey &scheduling_key,
const google::protobuf::RepeatedPtrField<rpc::ResourceMapEntry> &assigned_resources) {
auto &thief_entry = worker_to_lease_entry_[thief_addr];
// Check that the thief still retains its lease_client, and that it has no tasks in flight
RAY_CHECK(thief_entry.lease_client);
RAY_CHECK(thief_entry.tasks_in_flight == 0);
RAY_CHECK(thief_entry.WorkerIsStealing() == false);
// Return the worker if there was an error or the lease has expired.
if ((was_error || current_time_ms() > thief_entry.lease_expiration_time)) {
RAY_LOG(DEBUG) << "Returning worker " << thief_addr.worker_id
<< " due to error or lease expiration";
ReturnWorker(thief_addr, was_error, scheduling_key);
return;
}
RAY_LOG(DEBUG) << "Beginning to steal work now! Thief is worker: "
<< thief_addr.worker_id;
// Search for a suitable victim
rpc::Address victim_raw_addr;
if (!FindOptimalVictimForStealing(scheduling_key, thief_addr, &victim_raw_addr)) {
RAY_LOG(DEBUG) << "Could not find a suitable victim for stealing! Returning worker "
<< thief_addr.worker_id;
// If stealing was enabled, we can now cancel any pending new worker lease request,
// because stealing is not possible at this time.
if (max_tasks_in_flight_per_worker_ > 1) {
CancelWorkerLeaseIfNeeded(scheduling_key);
}
ReturnWorker(thief_addr, was_error, scheduling_key);
return;
}
// If we get here, stealing must be enabled.
RAY_CHECK(max_tasks_in_flight_per_worker_ > 1);
rpc::WorkerAddress victim_addr = rpc::WorkerAddress(victim_raw_addr);
RAY_CHECK(worker_to_lease_entry_.find(victim_addr) != worker_to_lease_entry_.end());
thief_entry.SetWorkerIsStealing();
// By this point, we have ascertained that the victim is available for stealing, so we
// can go ahead with the RPC
RAY_LOG(DEBUG) << "Executing StealTasks RPC!";
auto request = std::unique_ptr<rpc::StealTasksRequest>(new rpc::StealTasksRequest);
request->mutable_thief_addr()->CopyFrom(thief_addr.ToProto());
auto &victim_client = *client_cache_->GetOrConnect(victim_addr.ToProto());
auto victim_wid = victim_addr.worker_id;
RAY_UNUSED(victim_client.StealTasks(
std::move(request), [this, scheduling_key, victim_wid, victim_addr, thief_addr](
Status status, const rpc::StealTasksReply &reply) {
absl::MutexLock lock(&mu_);
// Obtain the thief's lease entry (after ensuring that it still exists)
RAY_CHECK(worker_to_lease_entry_.find(thief_addr) !=
worker_to_lease_entry_.end());
auto &thief_entry = worker_to_lease_entry_[thief_addr];
RAY_CHECK(thief_entry.WorkerIsStealing());
// Compute number of tasks stolen
size_t number_of_tasks_stolen = reply.stolen_tasks_ids_size();
RAY_LOG(DEBUG) << "We stole " << number_of_tasks_stolen << " tasks "
<< "from worker: " << victim_wid;
thief_entry.SetWorkerDoneStealing();
// push all tasks to the front of the queue
for (size_t i = 0; i < number_of_tasks_stolen; i++) {
// Get the task_id of the stolen task, and obtain the corresponding task_spec
// from the TaskManager
TaskID stolen_task_id = TaskID::FromBinary(reply.stolen_tasks_ids(i));
RAY_CHECK(task_finisher_->GetTaskSpec(stolen_task_id));
auto stolen_task_spec = *(task_finisher_->GetTaskSpec(stolen_task_id));
// delete the stolen task from the executing_tasks map if it is still there.
executing_tasks_.erase(stolen_task_id);
auto &scheduling_key_entry = scheduling_key_entries_[scheduling_key];
// Add the task to the queue
RAY_LOG(DEBUG) << "Adding stolen task " << stolen_task_spec.TaskId()
<< " back to the queue (of current size="
<< scheduling_key_entry.task_queue.size() << ")!";
scheduling_key_entry.task_queue.push_front(stolen_task_spec);
}
// call OnWorkerIdle to ship the task to the thief
OnWorkerIdle(thief_addr, scheduling_key, /*error=*/!status.ok(),
thief_entry.assigned_resources);
}));
}
void CoreWorkerDirectTaskSubmitter::OnWorkerIdle(
const rpc::WorkerAddress &addr, const SchedulingKey &scheduling_key, bool was_error,
const google::protobuf::RepeatedPtrField<rpc::ResourceMapEntry> &assigned_resources) {
auto &lease_entry = worker_to_lease_entry_[addr];
if (!lease_entry.lease_client) {
return;
}
auto &scheduling_key_entry = scheduling_key_entries_[scheduling_key];
auto &current_queue = scheduling_key_entry.task_queue;
// Return the worker if there was an error executing the previous task or the
// lease has expired; steal tasks or return the worker if there are no more
// queued tasks and the worker is not already stealing.
if ((was_error || current_time_ms() > lease_entry.lease_expiration_time) ||
(current_queue.empty() && !lease_entry.WorkerIsStealing())) {
RAY_CHECK(scheduling_key_entry.active_workers.size() >= 1);
// Return the worker only if there are no tasks in flight
if (lease_entry.tasks_in_flight == 0) {
RAY_LOG(DEBUG)
<< "Number of tasks in flight == 0, calling StealTasksOrReturnWorker!";
StealTasksOrReturnWorker(addr, was_error, scheduling_key, assigned_resources);
}
} else {
auto &client = *client_cache_->GetOrConnect(addr.ToProto());
while (!current_queue.empty() &&
!lease_entry.PipelineToWorkerFull(max_tasks_in_flight_per_worker_)) {
auto task_spec = current_queue.front();
// Increment the number of tasks in flight to the worker
lease_entry.tasks_in_flight++;
// Increment the total number of tasks in flight to any worker associated with the
// current scheduling_key
@@ -182,11 +362,8 @@ void CoreWorkerDirectTaskSubmitter::OnWorkerIdle(
PushNormalTask(addr, client, scheduling_key, task_spec, assigned_resources);
current_queue.pop_front();
}
// If stealing is not an option, we can cancel the request for new worker leases
if (max_tasks_in_flight_per_worker_ == 1) {
CancelWorkerLeaseIfNeeded(scheduling_key);
}
}
@@ -197,11 +374,15 @@ void CoreWorkerDirectTaskSubmitter::CancelWorkerLeaseIfNeeded(
const SchedulingKey &scheduling_key) {
auto &scheduling_key_entry = scheduling_key_entries_[scheduling_key];
auto &task_queue = scheduling_key_entry.task_queue;
if (!task_queue.empty() || scheduling_key_entry.StealableTasks()) {
// There are still pending tasks, or there are tasks that can be stolen by a new
// worker, so let the worker lease request succeed.
return;
}
RAY_LOG(DEBUG)
<< "Task queue is empty, and there are no stealable tasks; canceling lease request";
auto &pending_lease_request = scheduling_key_entry.pending_lease_request;
if (pending_lease_request.first) {
// There is an in-flight lease request. Cancel it.
@@ -237,7 +418,7 @@ CoreWorkerDirectTaskSubmitter::GetOrConnectLeaseClient(
NodeID raylet_id = NodeID::FromBinary(raylet_address->raylet_id());
auto it = remote_lease_clients_.find(raylet_id);
if (it == remote_lease_clients_.end()) {
RAY_LOG(INFO) << "Connecting to raylet " << raylet_id;
it = remote_lease_clients_
.emplace(raylet_id, lease_client_factory_(raylet_address->ip_address(),
raylet_address->port()))
@@ -261,9 +442,27 @@ void CoreWorkerDirectTaskSubmitter::RequestNewWorkerIfNeeded(
return;
}
// Check whether we really need a new worker or whether we have
// enough room in an existing worker's pipeline to send the new tasks. If the pipelines
// are not full, we do not request a new worker (unless work stealing is enabled, in
// which case we can request a worker under the Eager Worker Requesting mode)
if (!scheduling_key_entry.AllPipelinesToWorkersFull(max_tasks_in_flight_per_worker_) &&
max_tasks_in_flight_per_worker_ == 1) {
// The pipelines to the current workers are not full yet, so we don't need more
// workers.
return;
}
auto &task_queue = scheduling_key_entry.task_queue;
// Check if the task queue is empty. If that is the case, it only makes sense to
// consider requesting a new worker if work stealing is enabled, and there is at least a
// worker with stealable tasks. If work stealing is not enabled, or there are no tasks
// that we can steal from existing workers, we don't need a new worker because we don't
// have any tasks to execute on that worker.
if (task_queue.empty()) {
// If any worker has more than one task in flight, then that task can be stolen.
bool stealable_tasks = scheduling_key_entry.StealableTasks();
if (!stealable_tasks) {
if (scheduling_key_entry.CanDelete()) {
// We can safely remove the entry keyed by scheduling_key from the
// scheduling_key_entries_ hashmap.
@@ -271,26 +470,26 @@ void CoreWorkerDirectTaskSubmitter::RequestNewWorkerIfNeeded(
}
return;
}
}
// Create a TaskSpecification with an overwritten TaskID to make sure we don't reuse the
// same TaskID to request a worker
auto resource_spec_msg = scheduling_key_entry.resource_spec.GetMutableMessage();
resource_spec_msg.set_task_id(TaskID::ForFakeTask().Binary());
TaskSpecification resource_spec = TaskSpecification(resource_spec_msg);
rpc::Address best_node_address;
if (raylet_address == nullptr) {
// If no raylet address is given, find the best worker for our next lease request.
best_node_address = lease_policy_->GetBestNodeForTask(resource_spec);
raylet_address = &best_node_address;
}
auto lease_client = GetOrConnectLeaseClient(raylet_address);
TaskID task_id = resource_spec.TaskId();
// Subtract 1 so we don't double count the task we are requesting for.
int64_t queue_size = task_queue.size() - 1;
lease_client->RequestWorkerLease(
resource_spec,
[this, scheduling_key](const Status &status,
@@ -313,6 +512,7 @@ void CoreWorkerDirectTaskSubmitter::RequestNewWorkerIfNeeded(
// assign work to the worker.
RAY_LOG(DEBUG) << "Lease granted " << task_id;
rpc::WorkerAddress addr(reply.worker_address());
auto resources_copy = reply.resource_mapping();
AddWorkerLeaseClient(addr, std::move(lease_client), resources_copy,
@@ -322,6 +522,7 @@ void CoreWorkerDirectTaskSubmitter::RequestNewWorkerIfNeeded(
/*error=*/false, resources_copy);
} else {
// The raylet redirected us to a different raylet to retry at.
RequestNewWorkerIfNeeded(scheduling_key, &reply.retry_at_raylet_address());
}
} else if (lease_client != local_lease_client_) {
@@ -330,7 +531,9 @@ void CoreWorkerDirectTaskSubmitter::RequestNewWorkerIfNeeded(
// TODO(swang): Fail after some number of retries?
RAY_LOG(ERROR) << "Retrying attempt to schedule task at remote node. Error: "
<< status.ToString();
RequestNewWorkerIfNeeded(scheduling_key);
} else {
// A local request failed. This shouldn't happen if the raylet is still alive
// and we don't currently handle raylet failures, so treat it as a fatal
@@ -360,10 +563,10 @@ void CoreWorkerDirectTaskSubmitter::PushNormalTask(
request->mutable_task_spec()->CopyFrom(task_spec.GetMessage());
request->mutable_resource_mapping()->CopyFrom(assigned_resources);
request->set_intended_worker_id(addr.worker_id.Binary());
client.PushNormalTask(
std::move(request),
[this, task_spec, task_id, is_actor, is_actor_creation, scheduling_key, addr,
assigned_resources](Status status, const rpc::PushTaskReply &reply) {
{
absl::MutexLock lock(&mu_);
executing_tasks_.erase(task_id);
@@ -379,11 +582,12 @@ void CoreWorkerDirectTaskSubmitter::PushNormalTask(
RAY_CHECK(scheduling_key_entry.active_workers.size() >= 1);
RAY_CHECK(scheduling_key_entry.total_tasks_in_flight >= 1);
scheduling_key_entry.total_tasks_in_flight--;
}
if (reply.worker_exiting()) {
RAY_LOG(DEBUG) << "Worker " << addr.worker_id
<< " replied that it is exiting.";
// The worker is draining and will shutdown after it is done. Don't return
// it to the Raylet since that will kill it early.
absl::MutexLock lock(&mu_);
worker_to_lease_entry_.erase(addr);
auto &scheduling_key_entry = scheduling_key_entries_[scheduling_key];
scheduling_key_entry.active_workers.erase(addr);
@@ -392,19 +596,25 @@ void CoreWorkerDirectTaskSubmitter::PushNormalTask(
// scheduling_key_entries_ hashmap.
scheduling_key_entries_.erase(scheduling_key);
}
} else if (reply.task_stolen()) {
// If the task was stolen, we push it to the thief worker & call OnWorkerIdle
// in the StealTasks callback within StealTasksOrReturnWorker. So we don't
// need to do anything here.
return;
} else if (!status.ok() || !is_actor_creation) {
// Successful actor creation leases the worker indefinitely from the raylet.
absl::MutexLock lock(&mu_);
OnWorkerIdle(addr, scheduling_key,
/*error=*/!status.ok(), assigned_resources);
}
if (!status.ok()) {
// TODO: It'd be nice to differentiate here between process vs node
// failure (e.g., by contacting the raylet). If it was a process
// failure, it may have been an application-level error and it may
// not make sense to retry the task.
RAY_UNUSED(task_finisher_->PendingTaskFailed(
task_id,
is_actor ? rpc::ErrorType::ACTOR_DIED : rpc::ErrorType::WORKER_DIED,
&status));
} else {
task_finisher_->CompletePendingTask(task_id, reply, addr.ToProto());
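For reference, the victim-selection rule implemented by FindOptimalVictimForStealing above boils down to: among the active workers other than the thief, pick the one with the most tasks in flight, and only steal if it has at least two. A minimal standalone sketch of that rule (the Candidate struct, worker ids, and PickVictim are invented for this example and are not part of the diff):

#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-in for a lease entry: just an id and its tasks in flight.
struct Candidate {
  std::string worker_id;
  uint32_t tasks_in_flight;
};

// Returns the index of the chosen victim, or -1 if no worker other than the
// thief has at least two tasks in flight (i.e. nothing worth stealing).
int PickVictim(const std::vector<Candidate> &workers, const std::string &thief_id) {
  int best = -1;
  for (int i = 0; i < static_cast<int>(workers.size()); i++) {
    if (workers[i].worker_id == thief_id) continue;  // never steal from yourself
    if (best == -1 || workers[i].tasks_in_flight > workers[best].tasks_in_flight) {
      best = i;
    }
  }
  // Estimated stealable work is tasks_in_flight / 2; require at least one task.
  if (best != -1 && workers[best].tasks_in_flight / 2 < 1) return -1;
  return best;
}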

@@ -136,6 +136,43 @@ class CoreWorkerDirectTaskSubmitter {
const google::protobuf::RepeatedPtrField<rpc::ResourceMapEntry> &assigned_resources,
const SchedulingKey &scheduling_key) EXCLUSIVE_LOCKS_REQUIRED(mu_);
/// This function takes care of returning a worker to the Raylet.
/// \param[in] addr The address of the worker.
/// \param[in] was_error Whether the task failed to be submitted.
void ReturnWorker(const rpc::WorkerAddress addr, bool was_error,
const SchedulingKey &scheduling_key) EXCLUSIVE_LOCKS_REQUIRED(mu_);
/// Check that the scheduling_key_entries_ hashmap is empty.
inline bool CheckNoSchedulingKeyEntries() const EXCLUSIVE_LOCKS_REQUIRED(mu_) {
return scheduling_key_entries_.empty();
}
/// Find the optimal victim (if there is any) for stealing work
///
/// \param[in] scheduling_key The SchedulingKey of the thief.
/// \param[out] victim_raw_addr Pointer to a variable that the function will fill with
/// the address of the victim, if one is found.
/// \return A boolean indicating whether we found a suitable victim or not.
bool FindOptimalVictimForStealing(const SchedulingKey &scheduling_key,
rpc::WorkerAddress thief_addr,
rpc::Address *victim_raw_addr)
EXCLUSIVE_LOCKS_REQUIRED(mu_);
/// Look for workers with a surplus of tasks in flight, and, if it is possible,
/// steal some of those tasks and submit them to the current worker. If no tasks
/// are available for stealing, return the worker to the Raylet.
///
/// \param[in] thief_addr The address of the worker that has finished its own work,
/// and is ready for stealing.
/// \param[in] was_error Whether the last task failed to be submitted to the worker.
/// \param[in] scheduling_key The scheduling class of the worker.
/// \param[in] assigned_resources Resource ids previously assigned to the worker.
void StealTasksOrReturnWorker(
const rpc::WorkerAddress &thief_addr, bool was_error,
const SchedulingKey &scheduling_key,
const google::protobuf::RepeatedPtrField<rpc::ResourceMapEntry> &assigned_resources)
EXCLUSIVE_LOCKS_REQUIRED(mu_);
/// Push a task to a specific worker.
void PushNormalTask(const rpc::WorkerAddress &addr,
rpc::CoreWorkerClientInterface &client,
@@ -144,11 +181,6 @@ class CoreWorkerDirectTaskSubmitter {
const google::protobuf::RepeatedPtrField<rpc::ResourceMapEntry>
&assigned_resources);
/// Address of our RPC server.
rpc::Address rpc_address_;
@@ -197,32 +229,53 @@ class CoreWorkerDirectTaskSubmitter {
/// (1) The lease client through which the worker should be returned
/// (2) The expiration time of a worker's lease.
/// (3) The number of tasks that are currently in flight to the worker
/// (4) A boolean that indicates whether we have launched a StealTasks request, and we
/// are waiting for the stolen tasks
/// (5) The resources assigned to the worker
/// (6) The SchedulingKey assigned to tasks that will be sent to the worker
struct LeaseEntry {
std::shared_ptr<WorkerLeaseInterface> lease_client;
int64_t lease_expiration_time;
uint32_t tasks_in_flight = 0;
bool currently_stealing = false;
google::protobuf::RepeatedPtrField<rpc::ResourceMapEntry> assigned_resources;
SchedulingKey scheduling_key;
LeaseEntry(
std::shared_ptr<WorkerLeaseInterface> lease_client = nullptr,
int64_t lease_expiration_time = 0,
google::protobuf::RepeatedPtrField<rpc::ResourceMapEntry> assigned_resources =
google::protobuf::RepeatedPtrField<rpc::ResourceMapEntry>(),
SchedulingKey scheduling_key = std::make_tuple(0, std::vector<ObjectID>(),
ActorID::Nil()))
: lease_client(lease_client),
lease_expiration_time(lease_expiration_time),
assigned_resources(assigned_resources),
scheduling_key(scheduling_key) {}
// Check whether the pipeline to the worker associated with a LeaseEntry is full.
inline bool PipelineToWorkerFull(uint32_t max_tasks_in_flight_per_worker) const {
return tasks_in_flight == max_tasks_in_flight_per_worker;
}
// Check whether the worker is a thief that is currently in the process of stealing
// tasks. Knowing whether a thief is currently stealing is important to prevent the
// thief from initiating another StealTasks request, or from being returned to the
// raylet, until stealing has completed (see the usage sketch after this struct).
inline bool WorkerIsStealing() const { return currently_stealing; }
// Once stealing has begun, update the thief's currently_stealing flag to reflect the
// new state.
inline void SetWorkerIsStealing() {
RAY_CHECK(!currently_stealing);
currently_stealing = true;
}
// Once stealing has completed, update the thief's currently_stealing flag to reflect
// the new state.
inline void SetWorkerDoneStealing() {
RAY_CHECK(currently_stealing);
currently_stealing = false;
}
};
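As a rough illustration of how these flags are intended to be used, here is a minimal, self-contained sketch with a simplified stand-in type; the real submitter manipulates LeaseEntry under mu_ and interleaves these calls with the actual RPC traffic.

#include <cassert>

// Simplified stand-in for the stealing-related state of a LeaseEntry.
struct LeaseStealingState {
  bool currently_stealing = false;
  bool WorkerIsStealing() const { return currently_stealing; }
  void SetWorkerIsStealing() { assert(!currently_stealing); currently_stealing = true; }
  void SetWorkerDoneStealing() { assert(currently_stealing); currently_stealing = false; }
};

// Intended lifecycle: mark the thief as stealing before sending the StealTasks
// RPC, and mark it done in the reply callback, before deciding whether to
// return the worker to the raylet.
void StealLifecycleSketch(LeaseStealingState &thief) {
  if (thief.WorkerIsStealing()) {
    return;  // A second StealTasks request must not be issued while one is pending.
  }
  thief.SetWorkerIsStealing();
  // ... send the StealTasks RPC; then, in the reply callback:
  thief.SetWorkerDoneStealing();
  // ... and either submit the stolen tasks to the thief or return the worker.
}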
// Map from worker address to a LeaseEntry struct containing the lease's metadata.
@@ -233,6 +286,7 @@ class CoreWorkerDirectTaskSubmitter {
// Keep track of pending worker lease requests to the raylet.
std::pair<std::shared_ptr<WorkerLeaseInterface>, TaskID> pending_lease_request =
std::make_pair(nullptr, TaskID::Nil());
TaskSpecification resource_spec = TaskSpecification();
// Tasks that are queued for execution. We keep an individual queue per
// scheduling class to ensure fairness.
std::deque<TaskSpecification> task_queue = std::deque<TaskSpecification>();
@@ -245,7 +299,7 @@ class CoreWorkerDirectTaskSubmitter {
// Check whether it's safe to delete this SchedulingKeyEntry from the
// scheduling_key_entries_ hashmap.
inline bool CanDelete() const {
if (!pending_lease_request.first && task_queue.empty() &&
active_workers.size() == 0 && total_tasks_in_flight == 0) {
return true;
@@ -256,10 +310,23 @@ class CoreWorkerDirectTaskSubmitter {
// Check whether the pipelines to the active workers associated with a
// SchedulingKeyEntry are all full.
inline bool AllPipelinesToWorkersFull(uint32_t max_tasks_in_flight_per_worker) const {
return total_tasks_in_flight >=
(active_workers.size() * max_tasks_in_flight_per_worker);
}
// Check whether there exists at least one task that can be stolen
inline bool StealableTasks() const {
// TODO: Make this function more accurate without introducing excessive
// inefficiencies. Currently, there is one scenario where this function can return
// false even if there are stealable tasks. This happens if the number of tasks in
// flight is less than or equal to the number of active workers (so the condition
// below evaluates to false), but some workers have more than one task queued while
// others have none. Conversely, if any worker has more than one task in flight,
// then one of its tasks could be stolen (see the worked example after this struct).
return total_tasks_in_flight > active_workers.size();
}
};
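To make the TODO above concrete, here is a small worked example of the false negative; the counts are illustrative and the heuristic is reproduced as a free function rather than the member shown above.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// The heuristic used by StealableTasks(): it only compares the aggregate task
// count against the number of active workers.
bool StealableTasksHeuristic(uint32_t total_tasks_in_flight, std::size_t num_active_workers) {
  return total_tasks_in_flight > num_active_workers;
}

int main() {
  // Tasks in flight per worker; worker 0 has a surplus that could be stolen.
  std::vector<uint32_t> per_worker = {2, 1, 0};
  uint32_t total = 0;
  for (uint32_t t : per_worker) total += t;
  // total == 3 and there are 3 workers, so the heuristic returns false
  // (a false negative), even though worker 0 holds more than one queued task.
  std::cout << std::boolalpha
            << StealableTasksHeuristic(total, per_worker.size()) << std::endl;
  return 0;
}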
// For each Scheduling Key, scheduling_key_entries_ contains a SchedulingKeyEntry struct // For each Scheduling Key, scheduling_key_entries_ contains a SchedulingKeyEntry struct


@@ -72,6 +72,16 @@ message ReturnObject {
int64 size = 6;
}
message StealTasksRequest {
// The address of the thief that is requesting to steal tasks.
Address thief_addr = 1;
}
message StealTasksReply {
// The TaskIDs of the tasks that were stolen
repeated bytes stolen_tasks_ids = 2;
}
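For reference, the generated C++ API for these two messages would be used roughly as follows. This is a sketch: the include path, the ray::rpc namespace, and the way the thief obtains its own address are assumptions.

#include "src/ray/protobuf/core_worker.pb.h"  // Assumed path of the generated header.

// Thief side: build the request carrying the thief's own address.
ray::rpc::StealTasksRequest MakeStealRequest(const ray::rpc::Address &self_addr) {
  ray::rpc::StealTasksRequest request;
  request.mutable_thief_addr()->CopyFrom(self_addr);
  return request;
}

// Thief side: read back how many tasks the victim agreed to give up.
// Each entry in stolen_tasks_ids is the binary form of a TaskID.
int CountStolenTasks(const ray::rpc::StealTasksReply &reply) {
  return reply.stolen_tasks_ids_size();
}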
message PushTaskRequest {
// The ID of the worker this message is intended for.
bytes intended_worker_id = 1;
@@ -95,8 +105,10 @@ message PushTaskRequest {
message PushTaskReply {
// The returned objects.
repeated ReturnObject return_objects = 1;
// Set to true if the task was stolen before its execution at the worker
// (see the sketch after this message).
bool task_stolen = 2;
// Set to true if the worker will be exiting.
bool worker_exiting = 3;
// The references that the worker borrowed during the task execution. A
// borrower is a process that is currently using the object ID, in one of 3
// ways:
@@ -111,7 +123,7 @@ message PushTaskReply {
// counts for any IDs that were nested inside these objects that the worker
// may now be borrowing. The reference counts also include any new borrowers
// that the worker created by passing a borrowed ID into a nested task.
repeated ObjectReferenceCount borrowed_refs = 4;
}
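The new task_stolen field changes how the submitter should interpret a PushTask reply: a stolen task was never executed by this worker, so its result must not be completed or failed here. A hedged sketch of that branch follows; the helper callables and the surrounding callback structure are illustrative, not the actual submitter code.

#include <functional>
#include "src/ray/protobuf/core_worker.pb.h"  // Assumed path of the generated header.

// Illustrative only: how a PushTask reply might be triaged once task_stolen exists.
// CompleteTask/FailTask stand in for the real task-finisher calls.
void OnPushTaskReplySketch(const ray::rpc::PushTaskReply &reply, bool rpc_ok,
                           const std::function<void()> &CompleteTask,
                           const std::function<void()> &FailTask) {
  if (rpc_ok && reply.task_stolen()) {
    // The task never ran on this worker; the thief that stole it will push it
    // again, so neither success nor failure should be recorded here.
    return;
  }
  if (!rpc_ok) {
    FailTask();
  } else {
    CompleteTask();
  }
}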
message DirectActorCallArgWaitCompleteRequest {
@@ -365,6 +377,8 @@ message RunOnUtilWorkerReply {
service CoreWorkerService {
// Push a task directly to this worker from another.
rpc PushTask(PushTaskRequest) returns (PushTaskReply);
// Steal tasks from a worker if it has a surplus of work
rpc StealTasks(StealTasksRequest) returns (StealTasksReply);
// Reply from raylet that wait for direct actor call args has completed.
rpc DirectActorCallArgWaitComplete(DirectActorCallArgWaitCompleteRequest)
returns (DirectActorCallArgWaitCompleteReply);


@@ -119,6 +119,9 @@ class CoreWorkerClientInterface {
virtual void PushNormalTask(std::unique_ptr<PushTaskRequest> request,
const ClientCallback<PushTaskReply> &callback) {}
virtual void StealTasks(std::unique_ptr<StealTasksRequest> request,
const ClientCallback<StealTasksReply> &callback) {}
/// Notify a wait has completed for direct actor call arguments.
///
/// \param[in] request The request message.
@@ -292,6 +295,11 @@ class CoreWorkerClient : public std::enable_shared_from_this<CoreWorkerClient>,
INVOKE_RPC_CALL(CoreWorkerService, PushTask, *request, callback, grpc_client_);
}
void StealTasks(std::unique_ptr<StealTasksRequest> request,
const ClientCallback<StealTasksReply> &callback) override {
INVOKE_RPC_CALL(CoreWorkerService, StealTasks, *request, callback, grpc_client_);
}
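A sketch of how a submitter might invoke this client method is below. The victim lookup and the callback body are illustrative, and the callback signature is assumed to follow the usual ClientCallback convention of (const Status &, const Reply &); in practice the client would come from the core worker client pool.

#include <memory>

// Illustrative call site: ask the victim worker to hand over tasks, then
// inspect the reply on the thief's side.
void RequestStealSketch(ray::rpc::CoreWorkerClientInterface &victim_client,
                        const ray::rpc::Address &thief_addr) {
  auto request = std::make_unique<ray::rpc::StealTasksRequest>();
  request->mutable_thief_addr()->CopyFrom(thief_addr);
  victim_client.StealTasks(
      std::move(request),
      [](const ray::Status &status, const ray::rpc::StealTasksReply &reply) {
        if (!status.ok()) {
          return;  // Treat a failed RPC as "nothing stolen".
        }
        // Each entry in stolen_tasks_ids identifies a task the victim gave up;
        // the thief should now expect those tasks to be pushed to it.
        const int num_stolen = reply.stolen_tasks_ids_size();
        (void)num_stolen;
      });
}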
/// Send as many pending tasks as possible. This method is thread-safe.
///
/// The client will guarantee no more than kMaxBytesInFlight bytes of RPCs are being


@@ -29,6 +29,7 @@ namespace rpc {
/// NOTE: See src/ray/core_worker/core_worker.h on how to add a new grpc handler.
#define RAY_CORE_WORKER_RPC_HANDLERS \
RPC_SERVICE_HANDLER(CoreWorkerService, PushTask) \
RPC_SERVICE_HANDLER(CoreWorkerService, StealTasks) \
RPC_SERVICE_HANDLER(CoreWorkerService, DirectActorCallArgWaitComplete) \
RPC_SERVICE_HANDLER(CoreWorkerService, GetObjectStatus) \
RPC_SERVICE_HANDLER(CoreWorkerService, WaitForActorOutOfScope) \
@@ -52,6 +53,7 @@ namespace rpc {
#define RAY_CORE_WORKER_DECLARE_RPC_HANDLERS \
DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(PushTask) \
DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(StealTasks) \
DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(DirectActorCallArgWaitComplete) \
DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(GetObjectStatus) \
DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(WaitForActorOutOfScope) \