Add WorkerID check to AssignTask (#6355)

This commit is contained in:
Edward Oakes 2019-12-04 12:38:29 -08:00 committed by GitHub
parent 1a3b83abf8
commit f65d65f5de
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 32 additions and 7 deletions

View file

@ -108,7 +108,8 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
std::placeholders::_2, std::placeholders::_3);
raylet_task_receiver_ =
std::unique_ptr<CoreWorkerRayletTaskReceiver>(new CoreWorkerRayletTaskReceiver(
local_raylet_client_, execute_task, exit_handler));
worker_context_.GetWorkerID(), local_raylet_client_, execute_task,
exit_handler));
direct_task_receiver_ =
std::unique_ptr<CoreWorkerDirectTaskReceiver>(new CoreWorkerDirectTaskReceiver(
worker_context_, task_execution_service_, execute_task, exit_handler));

View file

@ -6,15 +6,29 @@
namespace ray {
CoreWorkerRayletTaskReceiver::CoreWorkerRayletTaskReceiver(
std::shared_ptr<RayletClient> &raylet_client, const TaskHandler &task_handler,
const std::function<void()> &exit_handler)
: raylet_client_(raylet_client),
const WorkerID &worker_id, std::shared_ptr<RayletClient> &raylet_client,
const TaskHandler &task_handler, const std::function<void()> &exit_handler)
: worker_id_(worker_id),
raylet_client_(raylet_client),
task_handler_(task_handler),
exit_handler_(exit_handler) {}
void CoreWorkerRayletTaskReceiver::HandleAssignTask(
const rpc::AssignTaskRequest &request, rpc::AssignTaskReply *reply,
rpc::SendReplyCallback send_reply_callback) {
// Check that the message was intended for our WorkerID and drop it if not.
// This handles the case where a message is delayed, so we receive an
// AssignTask intended for a previous worker that has since died and whose
// port we are now bound to. Note that returning the status here doesn't
// actually cause the raylet to fail the task; that happens due to the
// unintentional disconnect from the asio connection when the previous worker died.
WorkerID intended_worker_id = WorkerID::FromBinary(request.worker_id());
if (intended_worker_id != worker_id_) {
RAY_LOG(WARNING) << "Received task for mismatched WorkerID " << intended_worker_id;
send_reply_callback(Status::Invalid("Mismatched WorkerID"), nullptr, nullptr);
return;
}
const Task task(request.task());
const auto &task_spec = task.GetTaskSpecification();
RAY_LOG(DEBUG) << "Received task " << task_spec.TaskId() << " is create "

View file

@ -15,7 +15,8 @@ class CoreWorkerRayletTaskReceiver {
const TaskSpecification &task_spec, const ResourceMappingType &resource_ids,
std::vector<std::shared_ptr<RayObject>> *return_objects)>;
CoreWorkerRayletTaskReceiver(std::shared_ptr<RayletClient> &raylet_client,
CoreWorkerRayletTaskReceiver(const WorkerID &worker_id,
std::shared_ptr<RayletClient> &raylet_client,
const TaskHandler &task_handler,
const std::function<void()> &exit_handler);
@ -31,6 +32,8 @@ class CoreWorkerRayletTaskReceiver {
rpc::SendReplyCallback send_reply_callback);
private:
/// WorkerID of this worker.
WorkerID worker_id_;
/// Reference to the core worker's raylet client. This is a pointer ref so that it
/// can be initialized by core worker after this class is constructed.
std::shared_ptr<RayletClient> &raylet_client_;

View file

@ -33,12 +33,18 @@ message ActorHandle {
}
message AssignTaskRequest {
// The ID of the worker this message is intended for. This is used to
// ensure that workers don't try to execute tasks assigned to workers
// that used to be bound to the same port.
bytes worker_id = 1;
// The task to be pushed.
Task task = 1;
Task task = 2;
// A list of the resources reserved for this worker.
// TODO(zhijunfu): `resource_ids` is represented as
// flatbuffers-serialized bytes, will be moved to protobuf later.
bytes resource_ids = 2;
bytes resource_ids = 3;
}
message AssignTaskReply {

View file

@ -131,6 +131,7 @@ void Worker::SetActiveObjectIds(const std::unordered_set<ObjectID> &&object_ids)
Status Worker::AssignTask(const Task &task, const ResourceIdSet &resource_id_set) {
RAY_CHECK(port_ > 0);
rpc::AssignTaskRequest request;
request.set_worker_id(worker_id_.Binary());
request.mutable_task()->mutable_task_spec()->CopyFrom(
task.GetTaskSpecification().GetMessage());
request.mutable_task()->mutable_task_execution_spec()->CopyFrom(