[core][2/2] Worker resubscribe when GCS failed (#24813)

A follow-up PR from this one: https://github.com/ray-project/ray/pull/24628

The previous PR fixed the resubscription issue for the raylet, but the core worker also needs to resubscribe.

There are two ways of doing resubscribe:
1. When the client-side detects any failure, it'll do resubscribing.
2. Server side will ask the client to do resubscribing.

1) is a cleaner and better solution. However, it's a little bit hard due to the following reasons:

- We are using long-polling, so for some extreme cases, we won't be able to detect the failure. For example, the client-side received the message, but before it sends another request, the server-side restarts, and the client will miss the opportunity of detecting the failure. This could happen if we have a standby GCS that starts very fast and somehow the client-side has a lot of traffic and runs very slow.
- The current gRPC framework doesn't give the user a way to handle failures, so supporting this would require some refactoring.

We can switch to this approach once we have gRPC streaming.

This PR is implementing 2) which includes three parts:
- raylet: (https://github.com/ray-project/ray/pull/24628)
- core worker: (this pr)
- python

Correctness: whenever a worker starts, it registers with the raylet immediately (sync call) before connecting to GCS. So we just need to send the restart RPCs to all registered workers, and it should work because:
- if the worker just started and hasn't registered with the raylet: it's ok, because the worker hasn't connected to GCS yet, so there is no need to resubscribe.
- if the worker has registered with the raylet: it's covered by the code path here.
This commit is contained in:
Yi Cheng 2022-05-16 23:47:52 -07:00 committed by GitHub
parent a565948094
commit 379fa634ac
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 162 additions and 48 deletions

View file

@ -294,6 +294,46 @@ def test_raylet_resubscription(tmp_path, ray_start_regular_with_external_redis):
wait_for_pid_to_exit(long_run_pid, 5) wait_for_pid_to_exit(long_run_pid, 5)
@pytest.mark.parametrize(
    "ray_start_regular_with_external_redis",
    [
        generate_system_config_map(
            num_heartbeats_timeout=20, gcs_rpc_server_reconnect_timeout_s=60
        )
    ],
    indirect=True,
)
def test_core_worker_resubscription(tmp_path, ray_start_regular_with_external_redis):
    # Verifies that a core worker resubscribes to GCS once GCS has been
    # restarted.
    from filelock import FileLock

    # The test holds the file lock first, so the actor constructor below
    # blocks on it until the test lets go.
    lock_path = str(tmp_path / "lock")
    gate = FileLock(lock_path)
    gate.acquire()

    @ray.remote
    class Actor:
        def __init__(self):
            # Waits here until the test releases its hold on the same file.
            held = FileLock(lock_path)
            held.acquire()

        def ready(self):
            return

    actor = Actor.remote()
    ready_ref = actor.ready.remote()

    # The actor cannot have become ready yet when GCS is taken down.
    ray.worker._global_node.kill_gcs_server()

    # Let the actor constructor proceed while GCS is unavailable.
    gate.release()

    # Bring GCS back; the actor is expected to become ready afterwards.
    ray.worker._global_node.start_gcs_server()

    # If resubscription did not happen, this call times out because the
    # worker would still believe the actor is not ready.
    ray.get(ready_ref, timeout=5)
@pytest.mark.parametrize("auto_reconnect", [True, False]) @pytest.mark.parametrize("auto_reconnect", [True, False])
def test_gcs_client_reconnect(ray_start_regular_with_external_redis, auto_reconnect): def test_gcs_client_reconnect(ray_start_regular_with_external_redis, auto_reconnect):
gcs_address = ray.worker.global_worker.gcs_client.address gcs_address = ray.worker.global_worker.gcs_client.address

View file

@ -28,6 +28,7 @@ class MockWorkerInterface : public WorkerInterface {
MOCK_METHOD(void, SetProcess, (Process proc), (override)); MOCK_METHOD(void, SetProcess, (Process proc), (override));
MOCK_METHOD(Language, GetLanguage, (), (const, override)); MOCK_METHOD(Language, GetLanguage, (), (const, override));
MOCK_METHOD(const std::string, IpAddress, (), (const, override)); MOCK_METHOD(const std::string, IpAddress, (), (const, override));
MOCK_METHOD(void, AsyncNotifyGCSRestart, (), (override));
MOCK_METHOD(void, Connect, (int port), (override)); MOCK_METHOD(void, Connect, (int port), (override));
MOCK_METHOD(void, MOCK_METHOD(void,
Connect, Connect,

View file

@ -2658,6 +2658,14 @@ void CoreWorker::HandleDirectActorCallArgWaitComplete(
send_reply_callback(Status::OK(), nullptr, nullptr); send_reply_callback(Status::OK(), nullptr, nullptr);
} }
/// Implements the gRPC server handler for `RayletNotifyGCSRestart`.
///
/// Calls `AsyncResubscribe()` on the GCS client and then replies with
/// `Status::OK()`. Neither `request` nor `reply` is read or written here.
///
/// \param request Incoming request; unused by this handler.
/// \param reply Outgoing reply; left untouched by this handler.
/// \param send_reply_callback Invoked once, with `Status::OK()`, to answer the RPC.
void CoreWorker::HandleRayletNotifyGCSRestart(
const rpc::RayletNotifyGCSRestartRequest &request,
rpc::RayletNotifyGCSRestartReply *reply,
rpc::SendReplyCallback send_reply_callback) {
// Start resubscribing to GCS before acknowledging the notification.
// NOTE(review): the name suggests this call does not block on completion;
// confirm against the GCS client implementation.
gcs_client_->AsyncResubscribe();
// Reply OK right away; this handler does not wait on any result from the
// resubscribe call above.
send_reply_callback(Status::OK(), nullptr, nullptr);
}
void CoreWorker::HandleGetObjectStatus(const rpc::GetObjectStatusRequest &request, void CoreWorker::HandleGetObjectStatus(const rpc::GetObjectStatusRequest &request,
rpc::GetObjectStatusReply *reply, rpc::GetObjectStatusReply *reply,
rpc::SendReplyCallback send_reply_callback) { rpc::SendReplyCallback send_reply_callback) {

View file

@ -687,6 +687,11 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
rpc::DirectActorCallArgWaitCompleteReply *reply, rpc::DirectActorCallArgWaitCompleteReply *reply,
rpc::SendReplyCallback send_reply_callback) override; rpc::SendReplyCallback send_reply_callback) override;
/// Implements gRPC server handler.
void HandleRayletNotifyGCSRestart(const rpc::RayletNotifyGCSRestartRequest &request,
rpc::RayletNotifyGCSRestartReply *reply,
rpc::SendReplyCallback send_reply_callback) override;
/// Implements gRPC server handler. /// Implements gRPC server handler.
void HandleGetObjectStatus(const rpc::GetObjectStatusRequest &request, void HandleGetObjectStatus(const rpc::GetObjectStatusRequest &request,
rpc::GetObjectStatusReply *reply, rpc::GetObjectStatusReply *reply,

View file

@ -357,10 +357,16 @@ message AssignObjectOwnerRequest {
string call_site = 5; string call_site = 5;
} }
message AssignObjectOwnerReply { message AssignObjectOwnerReply {}
}
message RayletNotifyGCSRestartRequest {}
message RayletNotifyGCSRestartReply {}
service CoreWorkerService { service CoreWorkerService {
// Notify core worker GCS has restarted.
rpc RayletNotifyGCSRestart(RayletNotifyGCSRestartRequest)
returns (RayletNotifyGCSRestartReply);
// Push a task directly to this worker from another. // Push a task directly to this worker from another.
rpc PushTask(PushTaskRequest) returns (PushTaskReply); rpc PushTask(PushTaskRequest) returns (PushTaskReply);
// Reply from raylet that wait for direct actor call args has completed. // Reply from raylet that wait for direct actor call args has completed.

View file

@ -1058,7 +1058,19 @@ void NodeManager::ResourceDeleted(const NodeID &node_id,
void NodeManager::HandleNotifyGCSRestart(const rpc::NotifyGCSRestartRequest &request, void NodeManager::HandleNotifyGCSRestart(const rpc::NotifyGCSRestartRequest &request,
rpc::NotifyGCSRestartReply *reply, rpc::NotifyGCSRestartReply *reply,
rpc::SendReplyCallback send_reply_callback) { rpc::SendReplyCallback send_reply_callback) {
// When GCS restarts, it'll notify raylet to do some initialization work
// (resubscribing). Raylet will also notify all workers to do this job. Workers are
// registered to raylet first (blocking call) and then connect to GCS, so there is no
// race condition here.
gcs_client_->AsyncResubscribe(); gcs_client_->AsyncResubscribe();
auto workers = worker_pool_.GetAllRegisteredWorkers(true);
for (auto worker : workers) {
worker->AsyncNotifyGCSRestart();
}
auto drivers = worker_pool_.GetAllRegisteredDrivers(true);
for (auto driver : drivers) {
driver->AsyncNotifyGCSRestart();
}
send_reply_callback(Status::OK(), nullptr, nullptr); send_reply_callback(Status::OK(), nullptr, nullptr);
} }

View file

@ -26,138 +26,144 @@ class MockWorker : public WorkerInterface {
is_detached_actor_(false), is_detached_actor_(false),
runtime_env_hash_(runtime_env_hash) {} runtime_env_hash_(runtime_env_hash) {}
WorkerID WorkerId() const { return worker_id_; } WorkerID WorkerId() const override { return worker_id_; }
rpc::WorkerType GetWorkerType() const { return rpc::WorkerType::WORKER; } rpc::WorkerType GetWorkerType() const override { return rpc::WorkerType::WORKER; }
int Port() const { return port_; } int Port() const override { return port_; }
void SetOwnerAddress(const rpc::Address &address) { address_ = address; } void SetOwnerAddress(const rpc::Address &address) override { address_ = address; }
void AssignTaskId(const TaskID &task_id) {} void AssignTaskId(const TaskID &task_id) override {}
void SetAssignedTask(const RayTask &assigned_task) { task_ = assigned_task; } void SetAssignedTask(const RayTask &assigned_task) override { task_ = assigned_task; }
const std::string IpAddress() const { return address_.ip_address(); } const std::string IpAddress() const override { return address_.ip_address(); }
void AsyncNotifyGCSRestart() override {}
void SetAllocatedInstances( void SetAllocatedInstances(
const std::shared_ptr<TaskResourceInstances> &allocated_instances) { const std::shared_ptr<TaskResourceInstances> &allocated_instances) override {
allocated_instances_ = allocated_instances; allocated_instances_ = allocated_instances;
} }
void SetLifetimeAllocatedInstances( void SetLifetimeAllocatedInstances(
const std::shared_ptr<TaskResourceInstances> &allocated_instances) { const std::shared_ptr<TaskResourceInstances> &allocated_instances) override {
lifetime_allocated_instances_ = allocated_instances; lifetime_allocated_instances_ = allocated_instances;
} }
std::shared_ptr<TaskResourceInstances> GetAllocatedInstances() { std::shared_ptr<TaskResourceInstances> GetAllocatedInstances() override {
return allocated_instances_; return allocated_instances_;
} }
std::shared_ptr<TaskResourceInstances> GetLifetimeAllocatedInstances() { std::shared_ptr<TaskResourceInstances> GetLifetimeAllocatedInstances() override {
return lifetime_allocated_instances_; return lifetime_allocated_instances_;
} }
void MarkDead() { RAY_CHECK(false) << "Method unused"; } void MarkDead() override { RAY_CHECK(false) << "Method unused"; }
bool IsDead() const { bool IsDead() const override {
RAY_CHECK(false) << "Method unused"; RAY_CHECK(false) << "Method unused";
return false; return false;
} }
void MarkBlocked() { blocked_ = true; } void MarkBlocked() override { blocked_ = true; }
void MarkUnblocked() { blocked_ = false; } void MarkUnblocked() override { blocked_ = false; }
bool IsBlocked() const { return blocked_; } bool IsBlocked() const override { return blocked_; }
Process GetProcess() const { return Process::CreateNewDummy(); } Process GetProcess() const override { return Process::CreateNewDummy(); }
StartupToken GetStartupToken() const { return 0; } StartupToken GetStartupToken() const override { return 0; }
void SetProcess(Process proc) { RAY_CHECK(false) << "Method unused"; } void SetProcess(Process proc) override { RAY_CHECK(false) << "Method unused"; }
Language GetLanguage() const { Language GetLanguage() const override {
RAY_CHECK(false) << "Method unused"; RAY_CHECK(false) << "Method unused";
return Language::PYTHON; return Language::PYTHON;
} }
void Connect(int port) { RAY_CHECK(false) << "Method unused"; } void Connect(int port) override { RAY_CHECK(false) << "Method unused"; }
void Connect(std::shared_ptr<rpc::CoreWorkerClientInterface> rpc_client) { void Connect(std::shared_ptr<rpc::CoreWorkerClientInterface> rpc_client) override {
RAY_CHECK(false) << "Method unused"; RAY_CHECK(false) << "Method unused";
} }
int AssignedPort() const { int AssignedPort() const override {
RAY_CHECK(false) << "Method unused"; RAY_CHECK(false) << "Method unused";
return -1; return -1;
} }
void SetAssignedPort(int port) { RAY_CHECK(false) << "Method unused"; } void SetAssignedPort(int port) override { RAY_CHECK(false) << "Method unused"; }
const TaskID &GetAssignedTaskId() const { const TaskID &GetAssignedTaskId() const override {
RAY_CHECK(false) << "Method unused"; RAY_CHECK(false) << "Method unused";
return TaskID::Nil(); return TaskID::Nil();
} }
bool AddBlockedTaskId(const TaskID &task_id) { bool AddBlockedTaskId(const TaskID &task_id) override {
RAY_CHECK(false) << "Method unused"; RAY_CHECK(false) << "Method unused";
return false; return false;
} }
bool RemoveBlockedTaskId(const TaskID &task_id) { bool RemoveBlockedTaskId(const TaskID &task_id) override {
RAY_CHECK(false) << "Method unused"; RAY_CHECK(false) << "Method unused";
return false; return false;
} }
const std::unordered_set<TaskID> &GetBlockedTaskIds() const { const std::unordered_set<TaskID> &GetBlockedTaskIds() const override {
RAY_CHECK(false) << "Method unused"; RAY_CHECK(false) << "Method unused";
auto *t = new std::unordered_set<TaskID>(); auto *t = new std::unordered_set<TaskID>();
return *t; return *t;
} }
const JobID &GetAssignedJobId() const { const JobID &GetAssignedJobId() const override {
RAY_CHECK(false) << "Method unused"; RAY_CHECK(false) << "Method unused";
return JobID::Nil(); return JobID::Nil();
} }
int GetRuntimeEnvHash() const { return runtime_env_hash_; } int GetRuntimeEnvHash() const override { return runtime_env_hash_; }
void AssignActorId(const ActorID &actor_id) { RAY_CHECK(false) << "Method unused"; } void AssignActorId(const ActorID &actor_id) override {
const ActorID &GetActorId() const { RAY_CHECK(false) << "Method unused";
}
const ActorID &GetActorId() const override {
RAY_CHECK(false) << "Method unused"; RAY_CHECK(false) << "Method unused";
return ActorID::Nil(); return ActorID::Nil();
} }
void MarkDetachedActor() { is_detached_actor_ = true; } void MarkDetachedActor() override { is_detached_actor_ = true; }
bool IsDetachedActor() const { return is_detached_actor_; } bool IsDetachedActor() const override { return is_detached_actor_; }
const std::shared_ptr<ClientConnection> Connection() const { const std::shared_ptr<ClientConnection> Connection() const override {
RAY_CHECK(false) << "Method unused"; RAY_CHECK(false) << "Method unused";
return nullptr; return nullptr;
} }
const rpc::Address &GetOwnerAddress() const { const rpc::Address &GetOwnerAddress() const override {
RAY_CHECK(false) << "Method unused"; RAY_CHECK(false) << "Method unused";
return address_; return address_;
} }
void DirectActorCallArgWaitComplete(int64_t tag) { void DirectActorCallArgWaitComplete(int64_t tag) override {
RAY_CHECK(false) << "Method unused"; RAY_CHECK(false) << "Method unused";
} }
void ClearAllocatedInstances() { allocated_instances_ = nullptr; } void ClearAllocatedInstances() override { allocated_instances_ = nullptr; }
void ClearLifetimeAllocatedInstances() { lifetime_allocated_instances_ = nullptr; } void ClearLifetimeAllocatedInstances() override {
lifetime_allocated_instances_ = nullptr;
}
const BundleID &GetBundleId() const { const BundleID &GetBundleId() const override {
RAY_CHECK(false) << "Method unused"; RAY_CHECK(false) << "Method unused";
return bundle_id_; return bundle_id_;
} }
void SetBundleId(const BundleID &bundle_id) { bundle_id_ = bundle_id; } void SetBundleId(const BundleID &bundle_id) override { bundle_id_ = bundle_id; }
RayTask &GetAssignedTask() { return task_; } RayTask &GetAssignedTask() override { return task_; }
bool IsRegistered() { bool IsRegistered() override {
RAY_CHECK(false) << "Method unused"; RAY_CHECK(false) << "Method unused";
return false; return false;
} }
rpc::CoreWorkerClientInterface *rpc_client() { rpc::CoreWorkerClientInterface *rpc_client() override {
RAY_CHECK(false) << "Method unused"; RAY_CHECK(false) << "Method unused";
return nullptr; return nullptr;
} }
bool IsAvailableForScheduling() const { bool IsAvailableForScheduling() const override {
RAY_CHECK(false) << "Method unused"; RAY_CHECK(false) << "Method unused";
return true; return true;
} }
protected: protected:
void SetStartupToken(StartupToken startup_token) { void SetStartupToken(StartupToken startup_token) override {
RAY_CHECK(false) << "Method unused"; RAY_CHECK(false) << "Method unused";
}; };

View file

@ -96,6 +96,20 @@ int Worker::AssignedPort() const { return assigned_port_; }
void Worker::SetAssignedPort(int port) { assigned_port_ = port; }; void Worker::SetAssignedPort(int port) { assigned_port_ = port; };
void Worker::AsyncNotifyGCSRestart() {
if (rpc_client_) {
rpc::RayletNotifyGCSRestartRequest request;
rpc_client_->RayletNotifyGCSRestart(request, [](Status status, auto reply) {
if (!status.ok()) {
RAY_LOG(ERROR) << "Failed to notify worker about GCS restarting: "
<< status.ToString();
}
});
} else {
notify_gcs_restarted_ = true;
}
}
void Worker::Connect(int port) { void Worker::Connect(int port) {
RAY_CHECK(port > 0); RAY_CHECK(port > 0);
port_ = port; port_ = port;
@ -103,10 +117,16 @@ void Worker::Connect(int port) {
addr.set_ip_address(ip_address_); addr.set_ip_address(ip_address_);
addr.set_port(port_); addr.set_port(port_);
rpc_client_ = std::make_unique<rpc::CoreWorkerClient>(addr, client_call_manager_); rpc_client_ = std::make_unique<rpc::CoreWorkerClient>(addr, client_call_manager_);
Connect(rpc_client_);
} }
void Worker::Connect(std::shared_ptr<rpc::CoreWorkerClientInterface> rpc_client) { void Worker::Connect(std::shared_ptr<rpc::CoreWorkerClientInterface> rpc_client) {
rpc_client_ = rpc_client; rpc_client_ = rpc_client;
if (notify_gcs_restarted_) {
// We need to send RPC to notify about the GCS restarts
AsyncNotifyGCSRestart();
notify_gcs_restarted_ = false;
}
} }
void Worker::AssignTaskId(const TaskID &task_id) { assigned_task_id_ = task_id; } void Worker::AssignTaskId(const TaskID &task_id) { assigned_task_id_ = task_id; }

View file

@ -54,6 +54,7 @@ class WorkerInterface {
virtual void SetProcess(Process proc) = 0; virtual void SetProcess(Process proc) = 0;
virtual Language GetLanguage() const = 0; virtual Language GetLanguage() const = 0;
virtual const std::string IpAddress() const = 0; virtual const std::string IpAddress() const = 0;
virtual void AsyncNotifyGCSRestart() = 0;
/// Connect this worker's gRPC client. /// Connect this worker's gRPC client.
virtual void Connect(int port) = 0; virtual void Connect(int port) = 0;
/// Testing-only /// Testing-only
@ -150,6 +151,7 @@ class Worker : public WorkerInterface {
void SetProcess(Process proc); void SetProcess(Process proc);
Language GetLanguage() const; Language GetLanguage() const;
const std::string IpAddress() const; const std::string IpAddress() const;
void AsyncNotifyGCSRestart();
/// Connect this worker's gRPC client. /// Connect this worker's gRPC client.
void Connect(int port); void Connect(int port);
/// Testing-only /// Testing-only
@ -281,6 +283,8 @@ class Worker : public WorkerInterface {
std::shared_ptr<TaskResourceInstances> lifetime_allocated_instances_; std::shared_ptr<TaskResourceInstances> lifetime_allocated_instances_;
/// RayTask being assigned to this worker. /// RayTask being assigned to this worker.
RayTask assigned_task_; RayTask assigned_task_;
/// If true, an RPC needs to be sent to notify the worker about the GCS restart.
bool notify_gcs_restarted_ = false;
}; };
} // namespace raylet } // namespace raylet

View file

@ -193,6 +193,10 @@ class CoreWorkerClientInterface : public pubsub::SubscriberClientInterface {
const ClientCallback<AssignObjectOwnerReply> &callback) { const ClientCallback<AssignObjectOwnerReply> &callback) {
} }
virtual void RayletNotifyGCSRestart(
const RayletNotifyGCSRestartRequest &request,
const ClientCallback<RayletNotifyGCSRestartReply> &callback) {}
/// Returns the max acked sequence number, useful for checking on progress. /// Returns the max acked sequence number, useful for checking on progress.
virtual int64_t ClientProcessedUpToSeqno() { return -1; } virtual int64_t ClientProcessedUpToSeqno() { return -1; }
@ -312,6 +316,12 @@ class CoreWorkerClient : public std::enable_shared_from_this<CoreWorkerClient>,
/*method_timeout_ms*/ -1, /*method_timeout_ms*/ -1,
override) override)
VOID_RPC_CLIENT_METHOD(CoreWorkerService,
RayletNotifyGCSRestart,
grpc_client_,
/*method_timeout_ms*/ -1,
override)
VOID_RPC_CLIENT_METHOD( VOID_RPC_CLIENT_METHOD(
CoreWorkerService, Exit, grpc_client_, /*method_timeout_ms*/ -1, override) CoreWorkerService, Exit, grpc_client_, /*method_timeout_ms*/ -1, override)

View file

@ -30,6 +30,7 @@ namespace rpc {
#define RAY_CORE_WORKER_RPC_HANDLERS \ #define RAY_CORE_WORKER_RPC_HANDLERS \
RPC_SERVICE_HANDLER(CoreWorkerService, PushTask, -1) \ RPC_SERVICE_HANDLER(CoreWorkerService, PushTask, -1) \
RPC_SERVICE_HANDLER(CoreWorkerService, DirectActorCallArgWaitComplete, -1) \ RPC_SERVICE_HANDLER(CoreWorkerService, DirectActorCallArgWaitComplete, -1) \
RPC_SERVICE_HANDLER(CoreWorkerService, RayletNotifyGCSRestart, -1) \
RPC_SERVICE_HANDLER(CoreWorkerService, GetObjectStatus, -1) \ RPC_SERVICE_HANDLER(CoreWorkerService, GetObjectStatus, -1) \
RPC_SERVICE_HANDLER(CoreWorkerService, WaitForActorOutOfScope, -1) \ RPC_SERVICE_HANDLER(CoreWorkerService, WaitForActorOutOfScope, -1) \
RPC_SERVICE_HANDLER(CoreWorkerService, PubsubLongPolling, -1) \ RPC_SERVICE_HANDLER(CoreWorkerService, PubsubLongPolling, -1) \
@ -51,6 +52,7 @@ namespace rpc {
#define RAY_CORE_WORKER_DECLARE_RPC_HANDLERS \ #define RAY_CORE_WORKER_DECLARE_RPC_HANDLERS \
DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(PushTask) \ DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(PushTask) \
DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(DirectActorCallArgWaitComplete) \ DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(DirectActorCallArgWaitComplete) \
DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(RayletNotifyGCSRestart) \
DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(GetObjectStatus) \ DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(GetObjectStatus) \
DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(WaitForActorOutOfScope) \ DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(WaitForActorOutOfScope) \
DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(PubsubLongPolling) \ DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(PubsubLongPolling) \