diff --git a/src/ray/common/network_util.cc b/src/ray/common/network_util.cc new file mode 100644 index 000000000..a628de50f --- /dev/null +++ b/src/ray/common/network_util.cc @@ -0,0 +1,73 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "network_util.h" +#include "ray/util/logging.h" + +std::string GetValidLocalIp(int port, int64_t timeout_ms) { + AsyncClient async_client; + boost::system::error_code error_code; + std::string address; + bool is_timeout; + if (async_client.Connect(kPublicDNSServerIp, kPublicDNSServerPort, timeout_ms, + &is_timeout, &error_code)) { + address = async_client.GetLocalIPAddress(); + } else { + address = "127.0.0.1"; + + if (is_timeout || error_code == boost::system::errc::host_unreachable) { + boost::asio::ip::detail::endpoint primary_endpoint; + boost::asio::io_context io_context; + boost::asio::ip::tcp::resolver resolver(io_context); + boost::asio::ip::tcp::resolver::query query( + boost::asio::ip::host_name(), "", + boost::asio::ip::resolver_query_base::flags::v4_mapped); + boost::asio::ip::tcp::resolver::iterator iter = resolver.resolve(query, error_code); + boost::asio::ip::tcp::resolver::iterator end; // End marker. + if (!error_code) { + while (iter != end) { + boost::asio::ip::tcp::endpoint ep = *iter; + if (ep.address().is_v4() && !ep.address().is_loopback() && + !ep.address().is_multicast()) { + primary_endpoint.address(ep.address()); + primary_endpoint.port(ep.port()); + + AsyncClient client; + if (client.Connect(primary_endpoint.address().to_string(), port, timeout_ms, + &is_timeout)) { + break; + } + } + iter++; + } + } else { + RAY_LOG(WARNING) << "Failed to resolve ip address, error = " + << strerror(error_code.value()); + iter = end; + } + + if (iter != end) { + address = primary_endpoint.address().to_string(); + } + } + } + + return address; +} + +bool Ping(const std::string &ip, int port, int64_t timeout_ms) { + AsyncClient client; + bool is_timeout; + return client.Connect(ip, port, timeout_ms, &is_timeout); +} diff --git a/src/ray/common/network_util.h b/src/ray/common/network_util.h index 83e6d0e3e..56776888c 100644 --- a/src/ray/common/network_util.h +++ b/src/ray/common/network_util.h @@ -33,6 +33,11 @@ class AsyncClient { public: AsyncClient() : socket_(io_service_), timer_(io_service_) {} + ~AsyncClient() { + io_service_.stop(); + socket_.close(); + } + /// This function is used to asynchronously connect a socket to the specified address /// with timeout. /// @@ -108,55 +113,14 @@ class AsyncClient { /// \param port The port that the local ip is listening on. /// \param timeout_ms The maximum wait time in milliseconds. /// \return A valid local ip. -std::string GetValidLocalIp(int port, int64_t timeout_ms) { - AsyncClient async_client; - boost::system::error_code error_code; - std::string address; - bool is_timeout; - if (async_client.Connect(kPublicDNSServerIp, kPublicDNSServerPort, timeout_ms, - &is_timeout, &error_code)) { - address = async_client.GetLocalIPAddress(); - } else { - address = "127.0.0.1"; +std::string GetValidLocalIp(int port, int64_t timeout_ms); - if (is_timeout || error_code == boost::system::errc::host_unreachable) { - boost::asio::ip::detail::endpoint primary_endpoint; - boost::asio::io_context io_context; - boost::asio::ip::tcp::resolver resolver(io_context); - boost::asio::ip::tcp::resolver::query query( - boost::asio::ip::host_name(), "", - boost::asio::ip::resolver_query_base::flags::v4_mapped); - boost::asio::ip::tcp::resolver::iterator iter = resolver.resolve(query, error_code); - boost::asio::ip::tcp::resolver::iterator end; // End marker. - if (!error_code) { - while (iter != end) { - boost::asio::ip::tcp::endpoint ep = *iter; - if (ep.address().is_v4() && !ep.address().is_loopback() && - !ep.address().is_multicast()) { - primary_endpoint.address(ep.address()); - primary_endpoint.port(ep.port()); - - AsyncClient client; - if (client.Connect(primary_endpoint.address().to_string(), port, timeout_ms, - &is_timeout)) { - break; - } - } - iter++; - } - } else { - RAY_LOG(WARNING) << "Failed to resolve ip address, error = " - << strerror(error_code.value()); - iter = end; - } - - if (iter != end) { - address = primary_endpoint.address().to_string(); - } - } - } - - return address; -} +/// A helper function to test whether target rpc server is valid. +/// +/// \param ip The ip that the target rpc server is listening on. +/// \param port The port that the target rpc server is listening on. +/// \param timeout_ms The maximum wait time in milliseconds. +/// \return Whether target rpc server is valid. +bool Ping(const std::string &ip, int port, int64_t timeout_ms); #endif // RAY_COMMON_NETWORK_UTIL_H diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h index 8aae98f1b..9b73a032f 100644 --- a/src/ray/common/ray_config_def.h +++ b/src/ray/common/ray_config_def.h @@ -284,6 +284,12 @@ RAY_CONFIG(uint32_t, task_retry_delay_ms, 5000) /// Duration to wait between retrying to kill a task. RAY_CONFIG(uint32_t, cancellation_retry_ms, 2000) +/// The interval at which the gcs rpc client will check if gcs rpc server is ready. +RAY_CONFIG(int64_t, ping_gcs_rpc_server_interval_milliseconds, 1000) + +/// Maximum number of times to retry ping gcs rpc server when gcs server restarts. +RAY_CONFIG(int32_t, ping_gcs_rpc_server_max_retries, 600) + /// Whether to enable gcs service. /// RAY_GCS_SERVICE_ENABLED is an env variable which only set in ci job. /// If the value of RAY_GCS_SERVICE_ENABLED is false, we will disable gcs service, diff --git a/src/ray/gcs/accessor.h b/src/ray/gcs/accessor.h index 17cd5cd60..4961a99bc 100644 --- a/src/ray/gcs/accessor.h +++ b/src/ray/gcs/accessor.h @@ -151,6 +151,12 @@ class ActorInfoAccessor { const ActorID &actor_id, const OptionalItemCallback &callback) = 0; + /// Reestablish subscription. + /// This should be called when GCS server restarts from a failure. + /// + /// \return Status + virtual Status AsyncReSubscribe() = 0; + protected: ActorInfoAccessor() = default; }; @@ -195,6 +201,12 @@ class JobInfoAccessor { /// \return Status virtual Status AsyncGetAll(const MultiItemCallback &callback) = 0; + /// Reestablish subscription. + /// This should be called when GCS server restarts from a failure. + /// + /// \return Status + virtual Status AsyncReSubscribe() = 0; + protected: JobInfoAccessor() = default; }; diff --git a/src/ray/gcs/gcs_client/service_based_accessor.cc b/src/ray/gcs/gcs_client/service_based_accessor.cc index 1420dfa7d..c822332b2 100644 --- a/src/ray/gcs/gcs_client/service_based_accessor.cc +++ b/src/ray/gcs/gcs_client/service_based_accessor.cc @@ -61,19 +61,30 @@ Status ServiceBasedJobInfoAccessor::AsyncMarkFinished(const JobID &job_id, Status ServiceBasedJobInfoAccessor::AsyncSubscribeToFinishedJobs( const SubscribeCallback &subscribe, const StatusCallback &done) { - RAY_LOG(DEBUG) << "Subscribing finished job."; RAY_CHECK(subscribe != nullptr); - auto on_subscribe = [subscribe](const std::string &id, const std::string &data) { - JobTableData job_data; - job_data.ParseFromString(data); - if (job_data.is_dead()) { - subscribe(JobID::FromBinary(id), job_data); - } + subscribe_operation_ = [this, subscribe](const StatusCallback &done) { + RAY_LOG(DEBUG) << "Subscribing finished job."; + auto on_subscribe = [subscribe](const std::string &id, const std::string &data) { + JobTableData job_data; + job_data.ParseFromString(data); + if (job_data.is_dead()) { + subscribe(JobID::FromBinary(id), job_data); + } + }; + Status status = + client_impl_->GetGcsPubSub().SubscribeAll(JOB_CHANNEL, on_subscribe, done); + RAY_LOG(DEBUG) << "Finished subscribing finished job."; + return status; }; - Status status = - client_impl_->GetGcsPubSub().SubscribeAll(JOB_CHANNEL, on_subscribe, done); - RAY_LOG(DEBUG) << "Finished subscribing finished job."; - return status; + return subscribe_operation_(done); +} + +Status ServiceBasedJobInfoAccessor::AsyncReSubscribe() { + RAY_LOG(INFO) << "Reestablishing subscription for job info."; + if (subscribe_operation_ != nullptr) { + return subscribe_operation_(nullptr); + } + return Status::OK(); } Status ServiceBasedJobInfoAccessor::AsyncGetAll( @@ -224,32 +235,35 @@ Status ServiceBasedActorInfoAccessor::AsyncSubscribeAll( const StatusCallback &done) { RAY_LOG(DEBUG) << "Subscribing register or update operations of actors."; RAY_CHECK(subscribe != nullptr); - auto on_subscribe = [subscribe](const std::string &id, const std::string &data) { - ActorTableData actor_data; - actor_data.ParseFromString(data); - subscribe(ActorID::FromBinary(actor_data.actor_id()), actor_data); + subscribe_all_operation_ = [this, subscribe](const StatusCallback &done) { + auto on_subscribe = [subscribe](const std::string &id, const std::string &data) { + ActorTableData actor_data; + actor_data.ParseFromString(data); + subscribe(ActorID::FromBinary(actor_data.actor_id()), actor_data); + }; + auto on_done = [this, subscribe, done](const Status &status) { + if (status.ok()) { + auto callback = [subscribe, done]( + const Status &status, + const std::vector &actor_info_list) { + for (auto &actor_info : actor_info_list) { + subscribe(ActorID::FromBinary(actor_info.actor_id()), actor_info); + } + if (done) { + done(status); + } + }; + RAY_CHECK_OK(AsyncGetAll(callback)); + } else if (done) { + done(status); + } + }; + auto status = + client_impl_->GetGcsPubSub().SubscribeAll(ACTOR_CHANNEL, on_subscribe, on_done); + RAY_LOG(DEBUG) << "Finished subscribing register or update operations of actors."; + return status; }; - auto on_done = [this, subscribe, done](const Status &status) { - if (status.ok()) { - auto callback = [subscribe, done]( - const Status &status, - const std::vector &actor_info_list) { - for (auto &actor_info : actor_info_list) { - subscribe(ActorID::FromBinary(actor_info.actor_id()), actor_info); - } - if (done) { - done(status); - } - }; - RAY_CHECK_OK(AsyncGetAll(callback)); - } else if (done) { - done(status); - } - }; - auto status = - client_impl_->GetGcsPubSub().SubscribeAll(ACTOR_CHANNEL, on_subscribe, on_done); - RAY_LOG(DEBUG) << "Finished subscribing register or update operations of actors."; - return status; + return subscribe_all_operation_(done); } Status ServiceBasedActorInfoAccessor::AsyncSubscribe( @@ -258,38 +272,43 @@ Status ServiceBasedActorInfoAccessor::AsyncSubscribe( const StatusCallback &done) { RAY_LOG(DEBUG) << "Subscribing update operations of actor, actor id = " << actor_id; RAY_CHECK(subscribe != nullptr) << "Failed to subscribe actor, actor id = " << actor_id; - auto on_subscribe = [subscribe](const std::string &id, const std::string &data) { - ActorTableData actor_data; - actor_data.ParseFromString(data); - subscribe(ActorID::FromBinary(actor_data.actor_id()), actor_data); + auto subscribe_operation = [this, actor_id, subscribe](const StatusCallback &done) { + auto on_subscribe = [subscribe](const std::string &id, const std::string &data) { + ActorTableData actor_data; + actor_data.ParseFromString(data); + subscribe(ActorID::FromBinary(actor_data.actor_id()), actor_data); + }; + auto on_done = [this, actor_id, subscribe, done](const Status &status) { + if (status.ok()) { + auto callback = [actor_id, subscribe, done]( + const Status &status, + const boost::optional &result) { + if (result) { + subscribe(actor_id, *result); + } + if (done) { + done(status); + } + }; + RAY_CHECK_OK(AsyncGet(actor_id, callback)); + } else if (done) { + done(status); + } + }; + auto status = client_impl_->GetGcsPubSub().Subscribe(ACTOR_CHANNEL, actor_id.Hex(), + on_subscribe, on_done); + RAY_LOG(DEBUG) << "Finished subscribing update operations of actor, actor id = " + << actor_id; + return status; }; - auto on_done = [this, actor_id, subscribe, done](const Status &status) { - if (status.ok()) { - auto callback = [actor_id, subscribe, done]( - const Status &status, - const boost::optional &result) { - if (result) { - subscribe(actor_id, *result); - } - if (done) { - done(status); - } - }; - RAY_CHECK_OK(AsyncGet(actor_id, callback)); - } else if (done) { - done(status); - } - }; - auto status = client_impl_->GetGcsPubSub().Subscribe(ACTOR_CHANNEL, actor_id.Hex(), - on_subscribe, on_done); - RAY_LOG(DEBUG) << "Finished subscribing update operations of actor, actor id = " - << actor_id; - return status; + subscribe_operations_[actor_id] = subscribe_operation; + return subscribe_operation(done); } Status ServiceBasedActorInfoAccessor::AsyncUnsubscribe(const ActorID &actor_id) { RAY_LOG(DEBUG) << "Cancelling subscription to an actor, actor id = " << actor_id; auto status = client_impl_->GetGcsPubSub().Unsubscribe(ACTOR_CHANNEL, actor_id.Hex()); + subscribe_operations_.erase(actor_id); RAY_LOG(DEBUG) << "Finished cancelling subscription to an actor, actor id = " << actor_id; return status; @@ -366,6 +385,17 @@ Status ServiceBasedActorInfoAccessor::AsyncGetCheckpointID( return Status::OK(); } +Status ServiceBasedActorInfoAccessor::AsyncReSubscribe() { + RAY_LOG(INFO) << "Reestablishing subscription for actor info."; + if (subscribe_all_operation_ != nullptr) { + RAY_CHECK_OK(subscribe_all_operation_(nullptr)); + } + for (auto &item : subscribe_operations_) { + RAY_CHECK_OK(item.second(nullptr)); + } + return Status::OK(); +} + ServiceBasedNodeInfoAccessor::ServiceBasedNodeInfoAccessor( ServiceBasedGcsClient *client_impl) : client_impl_(client_impl) {} diff --git a/src/ray/gcs/gcs_client/service_based_accessor.h b/src/ray/gcs/gcs_client/service_based_accessor.h index af0edba4f..51f0026b0 100644 --- a/src/ray/gcs/gcs_client/service_based_accessor.h +++ b/src/ray/gcs/gcs_client/service_based_accessor.h @@ -23,6 +23,8 @@ namespace ray { namespace gcs { +using SubscribeOperation = std::function; + class ServiceBasedGcsClient; /// \class ServiceBasedJobInfoAccessor @@ -45,7 +47,13 @@ class ServiceBasedJobInfoAccessor : public JobInfoAccessor { Status AsyncGetAll(const MultiItemCallback &callback) override; + Status AsyncReSubscribe() override; + private: + /// Save the subscribe operation in this function, so we can call it again when GCS + /// restarts from a failure. + SubscribeOperation subscribe_operation_; + ServiceBasedGcsClient *client_impl_; }; @@ -100,7 +108,15 @@ class ServiceBasedActorInfoAccessor : public ActorInfoAccessor { const ActorID &actor_id, const OptionalItemCallback &callback) override; + Status AsyncReSubscribe() override; + private: + /// Save the subscribe operation in this function, so we can call it again when GCS + /// restarts from a failure. + SubscribeOperation subscribe_all_operation_; + /// Save the subscribe operation of actors. + std::unordered_map subscribe_operations_; + ServiceBasedGcsClient *client_impl_; Sequencer sequencer_; diff --git a/src/ray/gcs/gcs_client/service_based_gcs_client.cc b/src/ray/gcs/gcs_client/service_based_gcs_client.cc index c277dc6ee..388e939fd 100644 --- a/src/ray/gcs/gcs_client/service_based_gcs_client.cc +++ b/src/ray/gcs/gcs_client/service_based_gcs_client.cc @@ -47,11 +47,16 @@ Status ServiceBasedGcsClient::Connect(boost::asio::io_service &io_service) { }; std::pair address = get_server_address(); + auto re_subscribe = [this]() { + RAY_CHECK_OK(job_accessor_->AsyncReSubscribe()); + RAY_CHECK_OK(actor_accessor_->AsyncReSubscribe()); + }; + // Connect to gcs service. client_call_manager_.reset(new rpc::ClientCallManager(io_service)); gcs_rpc_client_.reset(new rpc::GcsRpcClient(address.first, address.second, - *client_call_manager_, get_server_address)); - + *client_call_manager_, get_server_address, + re_subscribe)); job_accessor_.reset(new ServiceBasedJobInfoAccessor(this)); actor_accessor_.reset(new ServiceBasedActorInfoAccessor(this)); node_accessor_.reset(new ServiceBasedNodeInfoAccessor(this)); diff --git a/src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc b/src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc index a5a70f044..7da0b376b 100644 --- a/src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc +++ b/src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc @@ -55,14 +55,29 @@ class ServiceBasedGcsClientTest : public RedisServiceManagerForTest { } void TearDown() override { - gcs_server_->Stop(); io_service_->stop(); + gcs_server_->Stop(); thread_io_service_->join(); thread_gcs_server_->join(); gcs_client_->Disconnect(); FlushAll(); } + void RestartGcsServer() { + RAY_LOG(INFO) << "Stopping GCS service, port = " << gcs_server_->GetPort(); + gcs_server_->Stop(); + thread_gcs_server_->join(); + + gcs_server_.reset(new gcs::GcsServer(config)); + thread_gcs_server_.reset(new std::thread([this] { gcs_server_->Start(); })); + + // Wait until server starts listening. + while (gcs_server_->GetPort() == 0) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + RAY_LOG(INFO) << "GCS service restarted, port = " << gcs_server_->GetPort(); + } + bool SubscribeToFinishedJobs( const gcs::SubscribeCallback &subscribe) { std::promise promise; @@ -797,28 +812,67 @@ TEST_F(ServiceBasedGcsClientTest, TestErrorInfo) { ASSERT_TRUE(ReportJobError(error_table_data)); } -TEST_F(ServiceBasedGcsClientTest, TestDetectGcsAvailability) { - // Create job table data. - JobID add_job_id = JobID::FromInt(1); - auto job_table_data = Mocker::GenJobTableData(add_job_id); +TEST_F(ServiceBasedGcsClientTest, TestJobTableReSubscribe) { + // Test that subscription of the job table can still work when GCS server restarts. + JobID job_id = JobID::FromInt(1); + auto job_table_data = Mocker::GenJobTableData(job_id); - RAY_LOG(INFO) << "Initializing GCS service, port = " << gcs_server_->GetPort(); - gcs_server_->Stop(); - thread_gcs_server_->join(); + // Subscribe to finished jobs. + std::atomic job_update_count(0); + auto subscribe = [&job_update_count](const JobID &id, const rpc::JobTableData &result) { + ++job_update_count; + }; + ASSERT_TRUE(SubscribeToFinishedJobs(subscribe)); - gcs_server_.reset(new gcs::GcsServer(config)); - thread_gcs_server_.reset(new std::thread([this] { gcs_server_->Start(); })); + RestartGcsServer(); - // Wait until server starts listening. - while (gcs_server_->GetPort() == 0) { - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - } - RAY_LOG(INFO) << "GCS service restarted, port = " << gcs_server_->GetPort(); + ASSERT_TRUE(AddJob(job_table_data)); + ASSERT_TRUE(MarkJobFinished(job_id)); + WaitPendingDone(job_update_count, 1); +} - std::promise promise; - RAY_CHECK_OK(gcs_client_->Jobs().AsyncAdd( - job_table_data, [&promise](Status status) { promise.set_value(status.ok()); })); - promise.get_future().get(); +TEST_F(ServiceBasedGcsClientTest, TestActorTableReSubscribe) { + JobID job_id = JobID::FromInt(1); + auto actor1_table_data = Mocker::GenActorTableData(job_id); + auto actor1_id = ActorID::FromBinary(actor1_table_data->actor_id()); + auto actor2_table_data = Mocker::GenActorTableData(job_id); + auto actor2_id = ActorID::FromBinary(actor2_table_data->actor_id()); + + // Subscribe to any register or update operations of actors. + std::atomic actors_update_count(0); + auto subscribe_all = [&actors_update_count](const ActorID &id, + const rpc::ActorTableData &result) { + ++actors_update_count; + }; + ASSERT_TRUE(SubscribeAllActors(subscribe_all)); + + // Subscribe to any update operations of actor1. + std::atomic actor1_update_count(0); + auto actor1_subscribe = [&actor1_update_count](const ActorID &actor_id, + const gcs::ActorTableData &data) { + ++actor1_update_count; + }; + ASSERT_TRUE(SubscribeActor(actor1_id, actor1_subscribe)); + + // Subscribe to any update operations of actor2. + std::atomic actor2_update_count(0); + auto actor2_subscribe = [&actor2_update_count](const ActorID &actor_id, + const gcs::ActorTableData &data) { + ++actor2_update_count; + }; + ASSERT_TRUE(SubscribeActor(actor2_id, actor2_subscribe)); + + ASSERT_TRUE(RegisterActor(actor1_table_data)); + ASSERT_TRUE(RegisterActor(actor2_table_data)); + WaitPendingDone(actor2_update_count, 1); + UnsubscribeActor(actor2_id); + + RestartGcsServer(); + + ASSERT_TRUE(UpdateActor(actor1_id, actor1_table_data)); + ASSERT_TRUE(UpdateActor(actor2_id, actor2_table_data)); + WaitPendingDone(actor1_update_count, 3); + WaitPendingDone(actor2_update_count, 1); } TEST_F(ServiceBasedGcsClientTest, TestGcsRedisFailureDetector) { diff --git a/src/ray/gcs/gcs_server/test/gcs_server_test_util.h b/src/ray/gcs/gcs_server/test/gcs_server_test_util.h index 2c7212364..79f105940 100644 --- a/src/ray/gcs/gcs_server/test/gcs_server_test_util.h +++ b/src/ray/gcs/gcs_server/test/gcs_server_test_util.h @@ -242,6 +242,8 @@ struct GcsServerMocker { const gcs::OptionalItemCallback &callback) override { return Status::NotImplemented(""); } + + Status AsyncReSubscribe() override { return Status::NotImplemented(""); } }; class MockedNodeInfoAccessor : public gcs::NodeInfoAccessor { diff --git a/src/ray/gcs/redis_accessor.h b/src/ray/gcs/redis_accessor.h index 71992cb51..15e5d458f 100644 --- a/src/ray/gcs/redis_accessor.h +++ b/src/ray/gcs/redis_accessor.h @@ -83,6 +83,8 @@ class RedisLogBasedActorInfoAccessor : public ActorInfoAccessor { const ActorID &actor_id, const OptionalItemCallback &callback) override; + Status AsyncReSubscribe() override { return Status::NotImplemented(""); } + protected: virtual std::vector GetAllActorID() const; virtual Status Get(const ActorID &actor_id, ActorTableData *actor_table_data) const; @@ -181,6 +183,10 @@ class RedisJobInfoAccessor : public JobInfoAccessor { return Status::NotImplemented("AsyncGetAll not implemented"); } + Status AsyncReSubscribe() override { + return Status::NotImplemented("AsyncReSubscribe not implemented"); + } + private: /// Append job information to GCS asynchronously. /// diff --git a/src/ray/rpc/gcs_server/gcs_rpc_client.h b/src/ray/rpc/gcs_server/gcs_rpc_client.h index 0c3dfa007..9da56705c 100644 --- a/src/ray/rpc/gcs_server/gcs_rpc_client.h +++ b/src/ray/rpc/gcs_server/gcs_rpc_client.h @@ -15,6 +15,8 @@ #ifndef RAY_RPC_GCS_RPC_CLIENT_H #define RAY_RPC_GCS_RPC_CLIENT_H +#include +#include "src/ray/common/network_util.h" #include "src/ray/protobuf/gcs_service.grpc.pb.h" #include "src/ray/rpc/grpc_client.h" @@ -84,9 +86,11 @@ class GcsRpcClient { /// rpc server. GcsRpcClient(const std::string &address, const int port, ClientCallManager &client_call_manager, - std::function()> get_server_address = nullptr) + std::function()> get_server_address = nullptr, + std::function reconnected_callback = nullptr) : client_call_manager_(client_call_manager), - get_server_address_(std::move(get_server_address)) { + get_server_address_(std::move(get_server_address)), + reconnected_callback_(std::move(reconnected_callback)) { Init(address, port, client_call_manager); }; @@ -228,15 +232,47 @@ class GcsRpcClient { } void Reconnect() { + absl::MutexLock lock(&mutex_); if (get_server_address_) { - auto address = get_server_address_(); - Init(address.first, address.second, client_call_manager_); + std::pair address; + int index = 0; + for (; index < RayConfig::instance().ping_gcs_rpc_server_max_retries(); ++index) { + address = get_server_address_(); + RAY_LOG(DEBUG) << "Attempt to reconnect to GCS server: " << address.first << ":" + << address.second; + if (Ping(address.first, address.second, 100)) { + RAY_LOG(INFO) << "Reconnected to GCS server: " << address.first << ":" + << address.second; + break; + } + usleep(RayConfig::instance().ping_gcs_rpc_server_interval_milliseconds() * 1000); + } + + if (index < RayConfig::instance().ping_gcs_rpc_server_max_retries()) { + Init(address.first, address.second, client_call_manager_); + if (reconnected_callback_) { + reconnected_callback_(); + } + } else { + RAY_LOG(FATAL) << "Couldn't reconnect to GCS server. The last attempted GCS " + "server address was " + << address.first << ":" << address.second; + } } } + absl::Mutex mutex_; + ClientCallManager &client_call_manager_; std::function()> get_server_address_; + /// The callback that will be called when we reconnect to GCS server. + /// Currently, we use this function to reestablish subscription to GCS. + /// Note, we use ping to detect whether the reconnection is successful. If the ping + /// succeeds but the RPC connection fails, this function might be called called again. + /// So it needs to be idempotent. + std::function reconnected_callback_; + /// The gRPC-generated stub. std::unique_ptr> job_info_grpc_client_; std::unique_ptr> actor_info_grpc_client_;