diff --git a/python/ray/tests/test_gcs_fault_tolerance.py b/python/ray/tests/test_gcs_fault_tolerance.py index fcf0b7074..59c76752c 100644 --- a/python/ray/tests/test_gcs_fault_tolerance.py +++ b/python/ray/tests/test_gcs_fault_tolerance.py @@ -294,6 +294,46 @@ def test_raylet_resubscription(tmp_path, ray_start_regular_with_external_redis): wait_for_pid_to_exit(long_run_pid, 5) +@pytest.mark.parametrize( + "ray_start_regular_with_external_redis", + [ + generate_system_config_map( + num_heartbeats_timeout=20, gcs_rpc_server_reconnect_timeout_s=60 + ) + ], + indirect=True, +) +def test_core_worker_resubscription(tmp_path, ray_start_regular_with_external_redis): + # This test is to ensure core worker will resubscribe to GCS after GCS + # restarts. + from filelock import FileLock + + lock_file = str(tmp_path / "lock") + lock = FileLock(lock_file) + lock.acquire() + + @ray.remote + class Actor: + def __init__(self): + lock = FileLock(lock_file) + lock.acquire() + + def ready(self): + return + + a = Actor.remote() + r = a.ready.remote() + # Actor is not ready before GCS is down. + ray.worker._global_node.kill_gcs_server() + + lock.release() + # Actor is ready after GCS starts + ray.worker._global_node.start_gcs_server() + # Test the resubscribe works: if not, it'll timeout because worker + # will think the actor is not ready. + ray.get(r, timeout=5) + + @pytest.mark.parametrize("auto_reconnect", [True, False]) def test_gcs_client_reconnect(ray_start_regular_with_external_redis, auto_reconnect): gcs_address = ray.worker.global_worker.gcs_client.address diff --git a/src/mock/ray/raylet/worker.h b/src/mock/ray/raylet/worker.h index a296183d4..7c915a731 100644 --- a/src/mock/ray/raylet/worker.h +++ b/src/mock/ray/raylet/worker.h @@ -28,6 +28,7 @@ class MockWorkerInterface : public WorkerInterface { MOCK_METHOD(void, SetProcess, (Process proc), (override)); MOCK_METHOD(Language, GetLanguage, (), (const, override)); MOCK_METHOD(const std::string, IpAddress, (), (const, override)); + MOCK_METHOD(void, AsyncNotifyGCSRestart, (), (override)); MOCK_METHOD(void, Connect, (int port), (override)); MOCK_METHOD(void, Connect, diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index e2e2a473c..af98e6a93 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -2658,6 +2658,14 @@ void CoreWorker::HandleDirectActorCallArgWaitComplete( send_reply_callback(Status::OK(), nullptr, nullptr); } +void CoreWorker::HandleRayletNotifyGCSRestart( + const rpc::RayletNotifyGCSRestartRequest &request, + rpc::RayletNotifyGCSRestartReply *reply, + rpc::SendReplyCallback send_reply_callback) { + gcs_client_->AsyncResubscribe(); + send_reply_callback(Status::OK(), nullptr, nullptr); +} + void CoreWorker::HandleGetObjectStatus(const rpc::GetObjectStatusRequest &request, rpc::GetObjectStatusReply *reply, rpc::SendReplyCallback send_reply_callback) { diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 777db21b2..bfe1b9050 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -687,6 +687,11 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { rpc::DirectActorCallArgWaitCompleteReply *reply, rpc::SendReplyCallback send_reply_callback) override; + /// Implements gRPC server handler. + void HandleRayletNotifyGCSRestart(const rpc::RayletNotifyGCSRestartRequest &request, + rpc::RayletNotifyGCSRestartReply *reply, + rpc::SendReplyCallback send_reply_callback) override; + /// Implements gRPC server handler. void HandleGetObjectStatus(const rpc::GetObjectStatusRequest &request, rpc::GetObjectStatusReply *reply, diff --git a/src/ray/protobuf/core_worker.proto b/src/ray/protobuf/core_worker.proto index 977315efb..0ec55e98e 100644 --- a/src/ray/protobuf/core_worker.proto +++ b/src/ray/protobuf/core_worker.proto @@ -357,10 +357,16 @@ message AssignObjectOwnerRequest { string call_site = 5; } -message AssignObjectOwnerReply { -} +message AssignObjectOwnerReply {} + +message RayletNotifyGCSRestartRequest {} + +message RayletNotifyGCSRestartReply {} service CoreWorkerService { + // Notify core worker GCS has restarted. + rpc RayletNotifyGCSRestart(RayletNotifyGCSRestartRequest) + returns (RayletNotifyGCSRestartReply); // Push a task directly to this worker from another. rpc PushTask(PushTaskRequest) returns (PushTaskReply); // Reply from raylet that wait for direct actor call args has completed. diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index de9f22ee3..b95a458a7 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -1058,7 +1058,19 @@ void NodeManager::ResourceDeleted(const NodeID &node_id, void NodeManager::HandleNotifyGCSRestart(const rpc::NotifyGCSRestartRequest &request, rpc::NotifyGCSRestartReply *reply, rpc::SendReplyCallback send_reply_callback) { + // When GCS restarts, it'll notify raylet to do some initialization work + // (resubscribing). Raylet will also notify all workers to do this job. Workers are + // registered to raylet first (blocking call) and then connect to GCS, so there is no + // race condition here. gcs_client_->AsyncResubscribe(); + auto workers = worker_pool_.GetAllRegisteredWorkers(true); + for (auto worker : workers) { + worker->AsyncNotifyGCSRestart(); + } + auto drivers = worker_pool_.GetAllRegisteredDrivers(true); + for (auto driver : drivers) { + driver->AsyncNotifyGCSRestart(); + } send_reply_callback(Status::OK(), nullptr, nullptr); } diff --git a/src/ray/raylet/test/util.h b/src/ray/raylet/test/util.h index 749398c95..e7f3ebeda 100644 --- a/src/ray/raylet/test/util.h +++ b/src/ray/raylet/test/util.h @@ -26,138 +26,144 @@ class MockWorker : public WorkerInterface { is_detached_actor_(false), runtime_env_hash_(runtime_env_hash) {} - WorkerID WorkerId() const { return worker_id_; } + WorkerID WorkerId() const override { return worker_id_; } - rpc::WorkerType GetWorkerType() const { return rpc::WorkerType::WORKER; } + rpc::WorkerType GetWorkerType() const override { return rpc::WorkerType::WORKER; } - int Port() const { return port_; } + int Port() const override { return port_; } - void SetOwnerAddress(const rpc::Address &address) { address_ = address; } + void SetOwnerAddress(const rpc::Address &address) override { address_ = address; } - void AssignTaskId(const TaskID &task_id) {} + void AssignTaskId(const TaskID &task_id) override {} - void SetAssignedTask(const RayTask &assigned_task) { task_ = assigned_task; } + void SetAssignedTask(const RayTask &assigned_task) override { task_ = assigned_task; } - const std::string IpAddress() const { return address_.ip_address(); } + const std::string IpAddress() const override { return address_.ip_address(); } + + void AsyncNotifyGCSRestart() override {} void SetAllocatedInstances( - const std::shared_ptr &allocated_instances) { + const std::shared_ptr &allocated_instances) override { allocated_instances_ = allocated_instances; } void SetLifetimeAllocatedInstances( - const std::shared_ptr &allocated_instances) { + const std::shared_ptr &allocated_instances) override { lifetime_allocated_instances_ = allocated_instances; } - std::shared_ptr GetAllocatedInstances() { + std::shared_ptr GetAllocatedInstances() override { return allocated_instances_; } - std::shared_ptr GetLifetimeAllocatedInstances() { + std::shared_ptr GetLifetimeAllocatedInstances() override { return lifetime_allocated_instances_; } - void MarkDead() { RAY_CHECK(false) << "Method unused"; } - bool IsDead() const { + void MarkDead() override { RAY_CHECK(false) << "Method unused"; } + bool IsDead() const override { RAY_CHECK(false) << "Method unused"; return false; } - void MarkBlocked() { blocked_ = true; } - void MarkUnblocked() { blocked_ = false; } - bool IsBlocked() const { return blocked_; } + void MarkBlocked() override { blocked_ = true; } + void MarkUnblocked() override { blocked_ = false; } + bool IsBlocked() const override { return blocked_; } - Process GetProcess() const { return Process::CreateNewDummy(); } - StartupToken GetStartupToken() const { return 0; } - void SetProcess(Process proc) { RAY_CHECK(false) << "Method unused"; } + Process GetProcess() const override { return Process::CreateNewDummy(); } + StartupToken GetStartupToken() const override { return 0; } + void SetProcess(Process proc) override { RAY_CHECK(false) << "Method unused"; } - Language GetLanguage() const { + Language GetLanguage() const override { RAY_CHECK(false) << "Method unused"; return Language::PYTHON; } - void Connect(int port) { RAY_CHECK(false) << "Method unused"; } + void Connect(int port) override { RAY_CHECK(false) << "Method unused"; } - void Connect(std::shared_ptr rpc_client) { + void Connect(std::shared_ptr rpc_client) override { RAY_CHECK(false) << "Method unused"; } - int AssignedPort() const { + int AssignedPort() const override { RAY_CHECK(false) << "Method unused"; return -1; } - void SetAssignedPort(int port) { RAY_CHECK(false) << "Method unused"; } - const TaskID &GetAssignedTaskId() const { + void SetAssignedPort(int port) override { RAY_CHECK(false) << "Method unused"; } + const TaskID &GetAssignedTaskId() const override { RAY_CHECK(false) << "Method unused"; return TaskID::Nil(); } - bool AddBlockedTaskId(const TaskID &task_id) { + bool AddBlockedTaskId(const TaskID &task_id) override { RAY_CHECK(false) << "Method unused"; return false; } - bool RemoveBlockedTaskId(const TaskID &task_id) { + bool RemoveBlockedTaskId(const TaskID &task_id) override { RAY_CHECK(false) << "Method unused"; return false; } - const std::unordered_set &GetBlockedTaskIds() const { + const std::unordered_set &GetBlockedTaskIds() const override { RAY_CHECK(false) << "Method unused"; auto *t = new std::unordered_set(); return *t; } - const JobID &GetAssignedJobId() const { + const JobID &GetAssignedJobId() const override { RAY_CHECK(false) << "Method unused"; return JobID::Nil(); } - int GetRuntimeEnvHash() const { return runtime_env_hash_; } - void AssignActorId(const ActorID &actor_id) { RAY_CHECK(false) << "Method unused"; } - const ActorID &GetActorId() const { + int GetRuntimeEnvHash() const override { return runtime_env_hash_; } + void AssignActorId(const ActorID &actor_id) override { + RAY_CHECK(false) << "Method unused"; + } + const ActorID &GetActorId() const override { RAY_CHECK(false) << "Method unused"; return ActorID::Nil(); } - void MarkDetachedActor() { is_detached_actor_ = true; } - bool IsDetachedActor() const { return is_detached_actor_; } - const std::shared_ptr Connection() const { + void MarkDetachedActor() override { is_detached_actor_ = true; } + bool IsDetachedActor() const override { return is_detached_actor_; } + const std::shared_ptr Connection() const override { RAY_CHECK(false) << "Method unused"; return nullptr; } - const rpc::Address &GetOwnerAddress() const { + const rpc::Address &GetOwnerAddress() const override { RAY_CHECK(false) << "Method unused"; return address_; } - void DirectActorCallArgWaitComplete(int64_t tag) { + void DirectActorCallArgWaitComplete(int64_t tag) override { RAY_CHECK(false) << "Method unused"; } - void ClearAllocatedInstances() { allocated_instances_ = nullptr; } + void ClearAllocatedInstances() override { allocated_instances_ = nullptr; } - void ClearLifetimeAllocatedInstances() { lifetime_allocated_instances_ = nullptr; } + void ClearLifetimeAllocatedInstances() override { + lifetime_allocated_instances_ = nullptr; + } - const BundleID &GetBundleId() const { + const BundleID &GetBundleId() const override { RAY_CHECK(false) << "Method unused"; return bundle_id_; } - void SetBundleId(const BundleID &bundle_id) { bundle_id_ = bundle_id; } + void SetBundleId(const BundleID &bundle_id) override { bundle_id_ = bundle_id; } - RayTask &GetAssignedTask() { return task_; } + RayTask &GetAssignedTask() override { return task_; } - bool IsRegistered() { + bool IsRegistered() override { RAY_CHECK(false) << "Method unused"; return false; } - rpc::CoreWorkerClientInterface *rpc_client() { + rpc::CoreWorkerClientInterface *rpc_client() override { RAY_CHECK(false) << "Method unused"; return nullptr; } - bool IsAvailableForScheduling() const { + bool IsAvailableForScheduling() const override { RAY_CHECK(false) << "Method unused"; return true; } protected: - void SetStartupToken(StartupToken startup_token) { + void SetStartupToken(StartupToken startup_token) override { RAY_CHECK(false) << "Method unused"; }; diff --git a/src/ray/raylet/worker.cc b/src/ray/raylet/worker.cc index cac5dff73..49d0a6556 100644 --- a/src/ray/raylet/worker.cc +++ b/src/ray/raylet/worker.cc @@ -96,6 +96,20 @@ int Worker::AssignedPort() const { return assigned_port_; } void Worker::SetAssignedPort(int port) { assigned_port_ = port; }; +void Worker::AsyncNotifyGCSRestart() { + if (rpc_client_) { + rpc::RayletNotifyGCSRestartRequest request; + rpc_client_->RayletNotifyGCSRestart(request, [](Status status, auto reply) { + if (!status.ok()) { + RAY_LOG(ERROR) << "Failed to notify worker about GCS restarting: " + << status.ToString(); + } + }); + } else { + notify_gcs_restarted_ = true; + } +} + void Worker::Connect(int port) { RAY_CHECK(port > 0); port_ = port; @@ -103,10 +117,16 @@ void Worker::Connect(int port) { addr.set_ip_address(ip_address_); addr.set_port(port_); rpc_client_ = std::make_unique(addr, client_call_manager_); + Connect(rpc_client_); } void Worker::Connect(std::shared_ptr rpc_client) { rpc_client_ = rpc_client; + if (notify_gcs_restarted_) { + // We need to send RPC to notify about the GCS restarts + AsyncNotifyGCSRestart(); + notify_gcs_restarted_ = false; + } } void Worker::AssignTaskId(const TaskID &task_id) { assigned_task_id_ = task_id; } diff --git a/src/ray/raylet/worker.h b/src/ray/raylet/worker.h index cfee1ea6f..84dc7bc9a 100644 --- a/src/ray/raylet/worker.h +++ b/src/ray/raylet/worker.h @@ -54,6 +54,7 @@ class WorkerInterface { virtual void SetProcess(Process proc) = 0; virtual Language GetLanguage() const = 0; virtual const std::string IpAddress() const = 0; + virtual void AsyncNotifyGCSRestart() = 0; /// Connect this worker's gRPC client. virtual void Connect(int port) = 0; /// Testing-only @@ -150,6 +151,7 @@ class Worker : public WorkerInterface { void SetProcess(Process proc); Language GetLanguage() const; const std::string IpAddress() const; + void AsyncNotifyGCSRestart(); /// Connect this worker's gRPC client. void Connect(int port); /// Testing-only @@ -281,6 +283,8 @@ class Worker : public WorkerInterface { std::shared_ptr lifetime_allocated_instances_; /// RayTask being assigned to this worker. RayTask assigned_task_; + /// If true, a RPC need to be sent to notify the worker about GCS restarting. + bool notify_gcs_restarted_ = false; }; } // namespace raylet diff --git a/src/ray/rpc/worker/core_worker_client.h b/src/ray/rpc/worker/core_worker_client.h index a2a76d001..dc80444cf 100644 --- a/src/ray/rpc/worker/core_worker_client.h +++ b/src/ray/rpc/worker/core_worker_client.h @@ -193,6 +193,10 @@ class CoreWorkerClientInterface : public pubsub::SubscriberClientInterface { const ClientCallback &callback) { } + virtual void RayletNotifyGCSRestart( + const RayletNotifyGCSRestartRequest &request, + const ClientCallback &callback) {} + /// Returns the max acked sequence number, useful for checking on progress. virtual int64_t ClientProcessedUpToSeqno() { return -1; } @@ -312,6 +316,12 @@ class CoreWorkerClient : public std::enable_shared_from_this, /*method_timeout_ms*/ -1, override) + VOID_RPC_CLIENT_METHOD(CoreWorkerService, + RayletNotifyGCSRestart, + grpc_client_, + /*method_timeout_ms*/ -1, + override) + VOID_RPC_CLIENT_METHOD( CoreWorkerService, Exit, grpc_client_, /*method_timeout_ms*/ -1, override) diff --git a/src/ray/rpc/worker/core_worker_server.h b/src/ray/rpc/worker/core_worker_server.h index 18dc51b87..a66ef4657 100644 --- a/src/ray/rpc/worker/core_worker_server.h +++ b/src/ray/rpc/worker/core_worker_server.h @@ -30,6 +30,7 @@ namespace rpc { #define RAY_CORE_WORKER_RPC_HANDLERS \ RPC_SERVICE_HANDLER(CoreWorkerService, PushTask, -1) \ RPC_SERVICE_HANDLER(CoreWorkerService, DirectActorCallArgWaitComplete, -1) \ + RPC_SERVICE_HANDLER(CoreWorkerService, RayletNotifyGCSRestart, -1) \ RPC_SERVICE_HANDLER(CoreWorkerService, GetObjectStatus, -1) \ RPC_SERVICE_HANDLER(CoreWorkerService, WaitForActorOutOfScope, -1) \ RPC_SERVICE_HANDLER(CoreWorkerService, PubsubLongPolling, -1) \ @@ -51,6 +52,7 @@ namespace rpc { #define RAY_CORE_WORKER_DECLARE_RPC_HANDLERS \ DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(PushTask) \ DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(DirectActorCallArgWaitComplete) \ + DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(RayletNotifyGCSRestart) \ DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(GetObjectStatus) \ DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(WaitForActorOutOfScope) \ DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(PubsubLongPolling) \