diff --git a/python/ray/tests/test_cancel.py b/python/ray/tests/test_cancel.py index adec360a3..e4965e550 100644 --- a/python/ray/tests/test_cancel.py +++ b/python/ray/tests/test_cancel.py @@ -228,5 +228,31 @@ def test_fast(shutdown_only, use_force): assert isinstance(e, valid_exceptions(use_force)) +@pytest.mark.parametrize("use_force", [True, False]) +def test_remote_cancel(ray_start_regular, use_force): + signaler = SignalActor.remote() + + @ray.remote + def wait_for(y): + return ray.get(y[0]) + + @ray.remote + def remote_wait(sg): + return [wait_for.remote([sg[0]])] + + sig = signaler.wait.remote() + + outer = remote_wait.remote([sig]) + inner = ray.get(outer)[0] + + with pytest.raises(RayTimeoutError): + ray.get(inner, 1) + + ray.cancel(inner) + + with pytest.raises(valid_exceptions(use_force)): + ray.get(inner, 10) + + if __name__ == "__main__": sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/worker.py b/python/ray/worker.py index 6963bd08c..0660ae5af 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -1679,7 +1679,7 @@ def kill(actor): def cancel(object_id, force=False): - """Cancels a locally-submitted task according to the following conditions. + """Cancels a task according to the following conditions. If the specified task is pending execution, it will not be executed. If the task is currently executing, the behavior depends on the ``force`` @@ -1698,8 +1698,7 @@ def cancel(object_id, force=False): force (boolean): Whether to force-kill a running task by killing the worker that is running the task. Raises: - ValueError: This is also raised for actor tasks, already completed - tasks, and non-locally submitted tasks. + TypeError: This is also raised for actor tasks. """ worker = ray.worker.global_worker worker.check_connected() diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index c3c0ffe26..a3e823a11 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -1218,9 +1218,11 @@ Status CoreWorker::CancelTask(const ObjectID &object_id, bool force_kill) { return Status::Invalid("Actor task cancellation is not supported."); } rpc::Address obj_addr; - if (!reference_counter_->GetOwner(object_id, nullptr, &obj_addr) || - obj_addr.SerializeAsString() != rpc_address_.SerializeAsString()) { - return Status::Invalid("Task is not locally submitted."); + if (!reference_counter_->GetOwner(object_id, nullptr, &obj_addr)) { + return Status::Invalid("No owner found for object."); + } + if (obj_addr.SerializeAsString() != rpc_address_.SerializeAsString()) { + return direct_task_submitter_->CancelRemoteTask(object_id, obj_addr, force_kill); } auto task_spec = task_manager_->GetTaskSpec(object_id.TaskId()); @@ -1766,6 +1768,14 @@ void CoreWorker::HandleWaitForRefRemoved(const rpc::WaitForRefRemovedRequest &re owner_address, ref_removed_callback); } +void CoreWorker::HandleRemoteCancelTask(const rpc::RemoteCancelTaskRequest &request, + rpc::RemoteCancelTaskReply *reply, + rpc::SendReplyCallback send_reply_callback) { + auto status = + CancelTask(ObjectID::FromBinary(request.remote_object_id()), request.force_kill()); + send_reply_callback(status, nullptr, nullptr); +} + void CoreWorker::HandleCancelTask(const rpc::CancelTaskRequest &request, rpc::CancelTaskReply *reply, rpc::SendReplyCallback send_reply_callback) { diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 520a8bf6c..27c374ccf 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -707,6 +707,11 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { rpc::CancelTaskReply *reply, rpc::SendReplyCallback send_reply_callback) override; + /// Implements gRPC server handler. + void HandleRemoteCancelTask(const rpc::RemoteCancelTaskRequest &request, + rpc::RemoteCancelTaskReply *reply, + rpc::SendReplyCallback send_reply_callback) override; + /// Implements gRPC server handler. void HandlePlasmaObjectReady(const rpc::PlasmaObjectReadyRequest &request, rpc::PlasmaObjectReadyReply *reply, diff --git a/src/ray/core_worker/transport/direct_task_transport.cc b/src/ray/core_worker/transport/direct_task_transport.cc index a9abdeb1e..ba2104160 100644 --- a/src/ray/core_worker/transport/direct_task_transport.cc +++ b/src/ray/core_worker/transport/direct_task_transport.cc @@ -367,4 +367,19 @@ Status CoreWorkerDirectTaskSubmitter::CancelTask(TaskSpecification task_spec, })); return Status::OK(); } + +Status CoreWorkerDirectTaskSubmitter::CancelRemoteTask(const ObjectID &object_id, + const rpc::Address &worker_addr, + bool force_kill) { + absl::MutexLock lock(&mu_); + auto client = client_cache_.find(rpc::WorkerAddress(worker_addr)); + if (client == client_cache_.end()) { + return Status::Invalid("No remote worker found"); + } + auto request = rpc::RemoteCancelTaskRequest(); + request.set_force_kill(force_kill); + request.set_remote_object_id(object_id.Binary()); + return client->second->RemoteCancelTask(request, nullptr); +} + }; // namespace ray diff --git a/src/ray/core_worker/transport/direct_task_transport.h b/src/ray/core_worker/transport/direct_task_transport.h index 2a90477ef..9be1d0323 100644 --- a/src/ray/core_worker/transport/direct_task_transport.h +++ b/src/ray/core_worker/transport/direct_task_transport.h @@ -80,6 +80,9 @@ class CoreWorkerDirectTaskSubmitter { /// \param[in] force_kill Whether to kill the worker executing the task. Status CancelTask(TaskSpecification task_spec, bool force_kill); + Status CancelRemoteTask(const ObjectID &object_id, const rpc::Address &worker_addr, + bool force_kill); + private: /// Schedule more work onto an idle worker or return it back to the raylet if /// no more tasks are queued for submission. If an error was encountered diff --git a/src/ray/protobuf/core_worker.proto b/src/ray/protobuf/core_worker.proto index 501830015..5ea21de3d 100644 --- a/src/ray/protobuf/core_worker.proto +++ b/src/ray/protobuf/core_worker.proto @@ -197,6 +197,16 @@ message CancelTaskReply { bool attempt_succeeded = 1; } +message RemoteCancelTaskRequest { + // Object ID of the remote task that should be killed. + bytes remote_object_id = 1; + // Whether to kill the worker. + bool force_kill = 2; +} + +message RemoteCancelTaskReply { +} + message GetCoreWorkerStatsRequest { // The ID of the worker this message is intended for. bytes intended_worker_id = 1; @@ -296,6 +306,8 @@ service CoreWorkerService { rpc KillActor(KillActorRequest) returns (KillActorReply); // Request that a worker cancels a task. rpc CancelTask(CancelTaskRequest) returns (CancelTaskReply); + // Request for a worker to issue a cancelation. + rpc RemoteCancelTask(RemoteCancelTaskRequest) returns (RemoteCancelTaskReply); // Get metrics from core workers. rpc GetCoreWorkerStats(GetCoreWorkerStatsRequest) returns (GetCoreWorkerStatsReply); // Wait for a borrower to finish using an object. Sent by the object's owner. diff --git a/src/ray/rpc/worker/core_worker_client.h b/src/ray/rpc/worker/core_worker_client.h index db4ce5ee4..d175d0553 100644 --- a/src/ray/rpc/worker/core_worker_client.h +++ b/src/ray/rpc/worker/core_worker_client.h @@ -162,6 +162,12 @@ class CoreWorkerClientInterface { return Status::NotImplemented(""); } + virtual ray::Status RemoteCancelTask( + const RemoteCancelTaskRequest &request, + const ClientCallback &callback) { + return Status::NotImplemented(""); + } + virtual ray::Status GetCoreWorkerStats( const GetCoreWorkerStatsRequest &request, const ClientCallback &callback) { @@ -217,6 +223,8 @@ class CoreWorkerClient : public std::enable_shared_from_this, RPC_CLIENT_METHOD(CoreWorkerService, CancelTask, grpc_client_, override) + RPC_CLIENT_METHOD(CoreWorkerService, RemoteCancelTask, grpc_client_, override) + RPC_CLIENT_METHOD(CoreWorkerService, WaitForObjectEviction, grpc_client_, override) RPC_CLIENT_METHOD(CoreWorkerService, GetCoreWorkerStats, grpc_client_, override) diff --git a/src/ray/rpc/worker/core_worker_server.h b/src/ray/rpc/worker/core_worker_server.h index 152c971cb..a73dfba16 100644 --- a/src/ray/rpc/worker/core_worker_server.h +++ b/src/ray/rpc/worker/core_worker_server.h @@ -36,6 +36,7 @@ namespace rpc { RPC_SERVICE_HANDLER(CoreWorkerService, WaitForRefRemoved) \ RPC_SERVICE_HANDLER(CoreWorkerService, KillActor) \ RPC_SERVICE_HANDLER(CoreWorkerService, CancelTask) \ + RPC_SERVICE_HANDLER(CoreWorkerService, RemoteCancelTask) \ RPC_SERVICE_HANDLER(CoreWorkerService, GetCoreWorkerStats) \ RPC_SERVICE_HANDLER(CoreWorkerService, LocalGC) \ RPC_SERVICE_HANDLER(CoreWorkerService, PlasmaObjectReady) @@ -49,6 +50,7 @@ namespace rpc { DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(WaitForRefRemoved) \ DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(KillActor) \ DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(CancelTask) \ + DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(RemoteCancelTask) \ DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(GetCoreWorkerStats) \ DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(LocalGC) \ DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(PlasmaObjectReady)