diff --git a/src/ray/common/constants.h b/src/ray/common/constants.h index f28cd8299..096857e32 100644 --- a/src/ray/common/constants.h +++ b/src/ray/common/constants.h @@ -43,4 +43,8 @@ constexpr char kWorkerDynamicOptionPlaceholderPrefix[] = constexpr char kWorkerRayletConfigPlaceholder[] = "RAY_WORKER_RAYLET_CONFIG_PLACEHOLDER"; +/// Public DNS address which is is used to connect and get local IP. +constexpr char kPublicDNSServerIp[] = "8.8.8.8"; +constexpr int kPublicDNSServerPort = 53; + #endif // RAY_CONSTANTS_H_ diff --git a/src/ray/common/network_util.h b/src/ray/common/network_util.h new file mode 100644 index 000000000..3f1a4cec7 --- /dev/null +++ b/src/ray/common/network_util.h @@ -0,0 +1,153 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef RAY_COMMON_NETWORK_UTIL_H +#define RAY_COMMON_NETWORK_UTIL_H + +#include +#include +#include +#include +#include +#include "constants.h" + +using boost::asio::deadline_timer; +using boost::asio::io_service; +using boost::asio::ip::tcp; + +/// \class AsyncClient +/// +/// This class provides the socket asynchronous interface with timeout: Connect. +class AsyncClient { + public: + AsyncClient() : socket_(io_service_), timer_(io_service_) {} + + /// This function is used to asynchronously connect a socket to the specified address + /// with timeout. + /// + /// \param ip The ip that the rpc server is listening on. + /// \param port The port that the rpc server is listening on. + /// \param timeout_ms The maximum wait time in milliseconds. + /// \return Whether the connection is successful. + bool Connect(const std::string &ip, int port, int64_t timeout_ms) { + try { + auto endpoint = + boost::asio::ip::tcp::endpoint(boost::asio::ip::address::from_string(ip), port); + + bool is_connected = false; + bool is_timeout = false; + socket_.async_connect(endpoint, boost::bind(&AsyncClient::ConnectHandle, this, + boost::asio::placeholders::error, + boost::ref(is_connected))); + + // Set a deadline for the asynchronous operation. + timer_.expires_from_now(boost::posix_time::milliseconds(timeout_ms)); + timer_.async_wait(boost::bind(&AsyncClient::TimerHandle, this, + boost::asio::placeholders::error, + boost::ref(is_timeout))); + + do { + io_service_.run_one(); + } while (!is_timeout && !is_connected); + + timer_.cancel(); + return is_connected; + } catch (...) { + return false; + } + } + + private: + void ConnectHandle(boost::system::error_code error_code, bool &is_connected) { + if (!error_code) { + is_connected = true; + } + } + + void TimerHandle(boost::system::error_code error_code, bool &is_timeout) { + if (!error_code) { + socket_.close(); + is_timeout = true; + } + } + + boost::asio::io_service io_service_; + tcp::socket socket_; + deadline_timer timer_; +}; + +/// A helper function to get a valid local ip. +/// We will connect google public dns server and get local ip from socket. +/// If dns server is unreachable, try to resolve hostname and get a valid ip by ping the +/// port of the local ip is listening on. If there is no valid local ip, `127.0.0.1` is +/// returned. +/// +/// \param port The port that the local ip is listening on. +/// \param timeout_ms The maximum wait time in milliseconds. +/// \return A valid local ip. +std::string GetValidLocalIp(int port, int64_t timeout_ms) { + boost::asio::io_service io_service; + boost::asio::ip::tcp::socket socket(io_service); + boost::system::error_code error_code; + socket.connect(boost::asio::ip::tcp::endpoint( + boost::asio::ip::address::from_string(kPublicDNSServerIp), + kPublicDNSServerPort), + error_code); + std::string address; + if (!error_code) { + address = socket.local_endpoint().address().to_string(); + } else { + address = "127.0.0.1"; + + if (error_code == boost::system::errc::host_unreachable) { + boost::asio::ip::detail::endpoint primary_endpoint; + boost::asio::io_context io_context; + boost::asio::ip::tcp::resolver resolver(io_context); + boost::asio::ip::tcp::resolver::query query( + boost::asio::ip::host_name(), "", + boost::asio::ip::resolver_query_base::flags::v4_mapped); + boost::asio::ip::tcp::resolver::iterator iter = resolver.resolve(query, error_code); + boost::asio::ip::tcp::resolver::iterator end; // End marker. + if (!error_code) { + while (iter != end) { + boost::asio::ip::tcp::endpoint ep = *iter; + if (ep.address().is_v4() && !ep.address().is_loopback() && + !ep.address().is_multicast()) { + primary_endpoint.address(ep.address()); + primary_endpoint.port(ep.port()); + + AsyncClient client; + if (client.Connect(primary_endpoint.address().to_string(), port, + timeout_ms)) { + break; + } + } + iter++; + } + } else { + RAY_LOG(WARNING) << "Failed to resolve ip address, error = " + << strerror(error_code.value()); + iter = end; + } + + if (iter != end) { + address = primary_endpoint.address().to_string(); + } + } + } + + return address; +} + +#endif // RAY_COMMON_NETWORK_UTIL_H diff --git a/src/ray/gcs/gcs_server/gcs_server.cc b/src/ray/gcs/gcs_server/gcs_server.cc index 91ac168a5..058b978c0 100644 --- a/src/ray/gcs/gcs_server/gcs_server.cc +++ b/src/ray/gcs/gcs_server/gcs_server.cc @@ -20,6 +20,8 @@ #include "job_info_handler_impl.h" #include "node_info_handler_impl.h" #include "object_info_handler_impl.h" +#include "ray/common/network_util.h" +#include "ray/common/ray_config.h" #include "stats_handler_impl.h" #include "task_info_handler_impl.h" #include "worker_info_handler_impl.h" @@ -155,37 +157,11 @@ std::unique_ptr GcsServer::InitObjectInfoHandler() { } void GcsServer::StoreGcsServerAddressInRedis() { - boost::asio::ip::detail::endpoint primary_endpoint; - boost::asio::ip::tcp::resolver resolver(main_service_); - boost::asio::ip::tcp::resolver::query query( - boost::asio::ip::host_name(), "", - boost::asio::ip::resolver_query_base::flags::v4_mapped); - boost::system::error_code error_code; - boost::asio::ip::tcp::resolver::iterator iter = resolver.resolve(query, error_code); - boost::asio::ip::tcp::resolver::iterator end; // End marker. - if (!error_code) { - while (iter != end) { - boost::asio::ip::tcp::endpoint ep = *iter; - if (ep.address().is_v4() && !ep.address().is_loopback() && - !ep.address().is_multicast()) { - primary_endpoint.address(ep.address()); - primary_endpoint.port(ep.port()); - break; - } - iter++; - } - } else { - RAY_LOG(WARNING) << "Failed to resolve ip address, error = " - << strerror(error_code.value()); - iter = end; - } - - std::string address; - if (iter == end) { - address = "127.0.0.1:" + std::to_string(GetPort()); - } else { - address = primary_endpoint.address().to_string() + ":" + std::to_string(GetPort()); - } + std::string address = + GetValidLocalIp( + GetPort(), + RayConfig::instance().internal_gcs_service_connect_wait_milliseconds()) + + ":" + std::to_string(GetPort()); RAY_LOG(INFO) << "Gcs server address = " << address; RAY_CHECK_OK(redis_gcs_client_->primary_context()->RunArgvAsync(