diff --git a/BUILD.bazel b/BUILD.bazel index 8057ba78d..02f664bb9 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -913,6 +913,45 @@ cc_library( ], ) +cc_library( + name = "global_state_accessor_lib", + srcs = glob( + [ + "src/ray/gcs/gcs_client/global_state_accessor.cc", + ], + ), + hdrs = glob( + [ + "src/ray/gcs/gcs_client/global_state_accessor.h", + ], + ), + copts = COPTS, + deps = [ + ":service_based_gcs_client_lib", + ], +) + +cc_test( + name = "global_state_accessor_test", + srcs = [ + "src/ray/gcs/gcs_client/test/global_state_accessor_test.cc", + ], + args = ["$(location redis-server) $(location redis-cli) $(location libray_redis_module.so)"], + copts = COPTS, + data = [ + "//:libray_redis_module.so", + "//:redis-cli", + "//:redis-server", + ], + deps = [ + ":gcs_server_lib", + ":gcs_test_util_lib", + ":global_state_accessor_lib", + ":service_based_gcs_client_lib", + "@com_google_googletest//:gtest_main", + ], +) + cc_test( name = "gcs_server_test", srcs = [ @@ -1386,6 +1425,7 @@ pyx_library( ), deps = [ "//:core_worker_lib", + "//:global_state_accessor_lib", "//:ray_util", "//:raylet_lib", "//:serialization_cc_proto", diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index e49c13fa8..b39e80f49 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -79,6 +79,7 @@ from ray.includes.libcoreworker cimport ( CActorHandle, ) from ray.includes.ray_config cimport RayConfig +from ray.includes.global_state_accessor cimport CGlobalStateAccessor import ray from ray.async_compat import (sync_to_async, @@ -107,6 +108,7 @@ include "includes/buffer.pxi" include "includes/common.pxi" include "includes/serialization.pxi" include "includes/libcoreworker.pxi" +include "includes/global_state_accessor.pxi" logger = logging.getLogger(__name__) diff --git a/python/ray/gcs_utils.py b/python/ray/gcs_utils.py index 22331c45f..42f7581ee 100644 --- a/python/ray/gcs_utils.py +++ b/python/ray/gcs_utils.py @@ -46,7 +46,7 @@ XRAY_HEARTBEAT_BATCH_CHANNEL = str( TablePubsub.Value("HEARTBEAT_BATCH_PUBSUB")).encode("ascii") # xray job updates -XRAY_JOB_CHANNEL = str(TablePubsub.Value("JOB_PUBSUB")).encode("ascii") +XRAY_JOB_CHANNEL = "JOB".encode("ascii") # These prefixes must be kept up-to-date with the TablePrefix enum in # gcs.proto. diff --git a/python/ray/includes/global_state_accessor.pxd b/python/ray/includes/global_state_accessor.pxd new file mode 100644 index 000000000..e22526229 --- /dev/null +++ b/python/ray/includes/global_state_accessor.pxd @@ -0,0 +1,13 @@ +from libcpp.string cimport string as c_string +from libcpp cimport bool as c_bool +from libcpp.vector cimport vector as c_vector + +cdef extern from "ray/gcs/gcs_client/global_state_accessor.h" nogil: + cdef cppclass CGlobalStateAccessor "ray::gcs::GlobalStateAccessor": + CGlobalStateAccessor(const c_string &redis_address, + const c_string &redis_password, + c_bool is_test) + c_bool Connect() + void Disconnect() + c_vector[c_string] GetAllJobInfo() + diff --git a/python/ray/includes/global_state_accessor.pxi b/python/ray/includes/global_state_accessor.pxi new file mode 100644 index 000000000..81ba5fdfd --- /dev/null +++ b/python/ray/includes/global_state_accessor.pxi @@ -0,0 +1,24 @@ +from ray.includes.global_state_accessor cimport ( + CGlobalStateAccessor, +) + +cdef class GlobalStateAccessor: + """Cython wrapper class of C++ `ray::gcs::GlobalStateAccessor`.""" + cdef: + unique_ptr[CGlobalStateAccessor] inner + + def __init__(self, redis_address, redis_password, c_bool is_test_client=False): + if not redis_password: + redis_password = "" + self.inner.reset( + new CGlobalStateAccessor(redis_address.encode("ascii"), + redis_password.encode("ascii"), is_test_client)) + + def connect(self): + return self.inner.get().Connect() + + def disconnect(self): + self.inner.get().Disconnect() + + def get_job_table(self): + return self.inner.get().GetAllJobInfo() diff --git a/python/ray/state.py b/python/ray/state.py index 0e608d032..b22c163c6 100644 --- a/python/ray/state.py +++ b/python/ray/state.py @@ -13,6 +13,8 @@ from ray import ( from ray.utils import (decode, binary_to_object_id, binary_to_hex, hex_to_binary) +from ray._raylet import GlobalStateAccessor + logger = logging.getLogger(__name__) @@ -125,6 +127,8 @@ class GlobalState: Attributes: redis_client: The Redis client used to query the primary redis server. redis_clients: Redis clients for each of the Redis shards. + global_state_accessor: The client used to query gcs table from gcs + server. """ def __init__(self): @@ -134,6 +138,7 @@ class GlobalState: self.redis_client = None # Clients for the redis shards, storing the object table & task table. self.redis_clients = None + self.global_state_accessor = None def _check_connected(self): """Check that the object has been initialized before it is used. @@ -150,10 +155,17 @@ class GlobalState: raise RuntimeError("The ray global state API cannot be used " "before ray.init has been called.") + if self.global_state_accessor is None: + raise RuntimeError("The ray global state API cannot be used " + "before ray.init has been called.") + def disconnect(self): """Disconnect global state from GCS.""" self.redis_client = None self.redis_clients = None + if self.global_state_accessor is not None: + self.global_state_accessor.disconnect() + self.global_state_accessor = None def _initialize_global_state(self, redis_address, @@ -171,6 +183,9 @@ class GlobalState: """ self.redis_client = services.create_redis_client( redis_address, redis_password) + self.global_state_accessor = GlobalStateAccessor( + redis_address, redis_password, False) + self.global_state_accessor.connect() start_time = time.time() num_redis_shards = None @@ -382,47 +397,6 @@ class GlobalState: client["alive"] = client["Alive"] return client_table - def _job_table(self, job_id): - """Fetch and parse the job table information for a single job ID. - - Args: - job_id: A job ID or hex string to get information about. - - Returns: - A dictionary with information about the job ID in question. - """ - # Allow the argument to be either a JobID or a hex string. - if not isinstance(job_id, ray.JobID): - assert isinstance(job_id, str) - job_id = ray.JobID(hex_to_binary(job_id)) - - # Return information about a single job ID. - message = self.redis_client.execute_command( - "RAY.TABLE_LOOKUP", gcs_utils.TablePrefix.Value("JOB"), "", - job_id.binary()) - - if message is None: - return {} - - gcs_entry = gcs_utils.GcsEntry.FromString(message) - - assert len(gcs_entry.entries) > 0 - - job_info = {} - - for i in range(len(gcs_entry.entries)): - entry = gcs_utils.JobTableData.FromString(gcs_entry.entries[i]) - assert entry.job_id == job_id.binary() - job_info["JobID"] = job_id.hex() - job_info["DriverIPAddress"] = entry.driver_ip_address - job_info["DriverPid"] = entry.driver_pid - if entry.is_dead: - job_info["StopTime"] = entry.timestamp - else: - job_info["StartTime"] = entry.timestamp - - return job_info - def job_table(self): """Fetch and parse the Redis job table. @@ -437,18 +411,20 @@ class GlobalState: """ self._check_connected() - job_keys = self.redis_client.keys(gcs_utils.TablePrefix_JOB_string + - "*") - - job_ids_binary = { - key[len(gcs_utils.TablePrefix_JOB_string):] - for key in job_keys - } + job_table = self.global_state_accessor.get_job_table() results = [] - - for job_id_binary in job_ids_binary: - results.append(self._job_table(binary_to_hex(job_id_binary))) + for i in range(len(job_table)): + entry = gcs_utils.JobTableData.FromString(job_table[i]) + job_info = {} + job_info["JobID"] = entry.job_id.hex() + job_info["DriverIPAddress"] = entry.driver_ip_address + job_info["DriverPid"] = entry.driver_pid + if entry.is_dead: + job_info["StopTime"] = entry.timestamp + else: + job_info["StartTime"] = entry.timestamp + results.append(job_info) return results diff --git a/src/ray/gcs/accessor.h b/src/ray/gcs/accessor.h index d5cfad211..1fc18c669 100644 --- a/src/ray/gcs/accessor.h +++ b/src/ray/gcs/accessor.h @@ -189,6 +189,12 @@ class JobInfoAccessor { const SubscribeCallback &subscribe, const StatusCallback &done) = 0; + /// Get all job info from GCS asynchronously. + /// + /// \param callback Callback that will be called after lookup finished. + /// \return Status + virtual Status AsyncGetAll(const MultiItemCallback &callback) = 0; + protected: JobInfoAccessor() = default; }; diff --git a/src/ray/gcs/gcs_client/global_state_accessor.cc b/src/ray/gcs/gcs_client/global_state_accessor.cc new file mode 100644 index 000000000..98a08493e --- /dev/null +++ b/src/ray/gcs/gcs_client/global_state_accessor.cc @@ -0,0 +1,84 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "global_state_accessor.h" + +namespace ray { +namespace gcs { + +GlobalStateAccessor::GlobalStateAccessor(const std::string &redis_address, + const std::string &redis_password, + bool is_test) { + RAY_LOG(INFO) << "Redis server address = " << redis_address + << ", is test flag = " << is_test; + std::vector address; + boost::split(address, redis_address, boost::is_any_of(":")); + RAY_CHECK(address.size() == 2); + GcsClientOptions options; + options.server_ip_ = address[0]; + options.server_port_ = std::stoi(address[1]); + options.password_ = redis_password; + options.is_test_client_ = is_test; + gcs_client_.reset(new ServiceBasedGcsClient(options)); + + io_service_.reset(new boost::asio::io_service()); + + std::promise promise; + thread_io_service_.reset(new std::thread([this, &promise] { + std::unique_ptr work( + new boost::asio::io_service::work(*io_service_)); + promise.set_value(true); + io_service_->run(); + })); + promise.get_future().get(); +} + +GlobalStateAccessor::~GlobalStateAccessor() { + Disconnect(); + io_service_->stop(); + thread_io_service_->join(); +} + +bool GlobalStateAccessor::Connect() { + is_connected_ = true; + return gcs_client_->Connect(*io_service_).ok(); +} + +void GlobalStateAccessor::Disconnect() { + if (is_connected_) { + gcs_client_->Disconnect(); + is_connected_ = false; + } +} + +std::vector GlobalStateAccessor::GetAllJobInfo() { + std::vector job_table_data; + std::promise promise; + auto on_done = [&job_table_data, &promise]( + const Status &status, const std::vector &result) { + RAY_CHECK_OK(status); + for (auto &data : result) { + job_table_data.push_back(data.SerializeAsString()); + } + promise.set_value(true); + }; + RAY_CHECK_OK(gcs_client_->Jobs().AsyncGetAll(on_done)); + promise.get_future().get(); + return job_table_data; +} + +} // namespace gcs +} // namespace ray diff --git a/src/ray/gcs/gcs_client/global_state_accessor.h b/src/ray/gcs/gcs_client/global_state_accessor.h new file mode 100644 index 000000000..c80972173 --- /dev/null +++ b/src/ray/gcs/gcs_client/global_state_accessor.h @@ -0,0 +1,67 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef RAY_GCS_GLOBAL_STATE_ACCESSOR_H +#define RAY_GCS_GLOBAL_STATE_ACCESSOR_H + +#include "service_based_gcs_client.h" + +namespace ray { +namespace gcs { + +/// \class GlobalStateAccessor +/// +/// `GlobalStateAccessor` is used to provide synchronous interfaces to access data in GCS +/// for the language front-end (e.g., Python's `state.py`). +class GlobalStateAccessor { + public: + /// Constructor of GlobalStateAccessor. + /// + /// \param redis_address The address of GCS Redis. + /// \param redis_password The password of GCS Redis. + /// \param is_test Whether this accessor is used for tests. + explicit GlobalStateAccessor(const std::string &redis_address, + const std::string &redis_password, bool is_test = false); + + ~GlobalStateAccessor(); + + /// Connect gcs server. + /// + /// \return Whether the connection is successful. + bool Connect(); + + /// Disconnect from gcs server. + void Disconnect(); + + /// Get information of all jobs from GCS Service. + /// + /// \return All job info. To support multi-language, we serialized each JobTableData and + /// returned the serialized string. Where used, it needs to be deserialized with + /// protobuf function. + std::vector GetAllJobInfo(); + + private: + /// Whether this client is connected to gcs server. + bool is_connected_{false}; + + std::unique_ptr gcs_client_; + + std::unique_ptr thread_io_service_; + std::unique_ptr io_service_; +}; + +} // namespace gcs +} // namespace ray + +#endif // RAY_GCS_GLOBAL_STATE_ACCESSOR_H diff --git a/src/ray/gcs/gcs_client/service_based_accessor.cc b/src/ray/gcs/gcs_client/service_based_accessor.cc index 930c76ea1..ec560c927 100644 --- a/src/ray/gcs/gcs_client/service_based_accessor.cc +++ b/src/ray/gcs/gcs_client/service_based_accessor.cc @@ -76,6 +76,20 @@ Status ServiceBasedJobInfoAccessor::AsyncSubscribeToFinishedJobs( return status; } +Status ServiceBasedJobInfoAccessor::AsyncGetAll( + const MultiItemCallback &callback) { + RAY_LOG(DEBUG) << "Getting all job info."; + RAY_CHECK(callback); + rpc::GetAllJobInfoRequest request; + client_impl_->GetGcsRpcClient().GetAllJobInfo( + request, [callback](const Status &status, const rpc::GetAllJobInfoReply &reply) { + auto result = VectorFromProtobuf(reply.job_info_list()); + callback(status, result); + RAY_LOG(DEBUG) << "Finished getting all job info."; + }); + return Status::OK(); +} + ServiceBasedActorInfoAccessor::ServiceBasedActorInfoAccessor( ServiceBasedGcsClient *client_impl) : client_impl_(client_impl) {} diff --git a/src/ray/gcs/gcs_client/service_based_accessor.h b/src/ray/gcs/gcs_client/service_based_accessor.h index e946d3b24..66edf2416 100644 --- a/src/ray/gcs/gcs_client/service_based_accessor.h +++ b/src/ray/gcs/gcs_client/service_based_accessor.h @@ -43,6 +43,8 @@ class ServiceBasedJobInfoAccessor : public JobInfoAccessor { const SubscribeCallback &subscribe, const StatusCallback &done) override; + Status AsyncGetAll(const MultiItemCallback &callback) override; + private: ServiceBasedGcsClient *client_impl_; }; diff --git a/src/ray/gcs/gcs_client/service_based_gcs_client.cc b/src/ray/gcs/gcs_client/service_based_gcs_client.cc index b1a094fb4..c277dc6ee 100644 --- a/src/ray/gcs/gcs_client/service_based_gcs_client.cc +++ b/src/ray/gcs/gcs_client/service_based_gcs_client.cc @@ -70,6 +70,9 @@ Status ServiceBasedGcsClient::Connect(boost::asio::io_service &io_service) { void ServiceBasedGcsClient::Disconnect() { RAY_CHECK(is_connected_); is_connected_ = false; + gcs_pub_sub_.reset(); + redis_gcs_client_->Disconnect(); + redis_gcs_client_.reset(); RAY_LOG(INFO) << "ServiceBasedGcsClient Disconnected."; } diff --git a/src/ray/gcs/gcs_client/test/global_state_accessor_test.cc b/src/ray/gcs/gcs_client/test/global_state_accessor_test.cc new file mode 100644 index 000000000..308bc8ce6 --- /dev/null +++ b/src/ray/gcs/gcs_client/test/global_state_accessor_test.cc @@ -0,0 +1,117 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ray/gcs/gcs_client/global_state_accessor.h" +#include "gtest/gtest.h" +#include "ray/common/test_util.h" +#include "ray/gcs/gcs_server/gcs_server.h" +#include "ray/gcs/test/gcs_test_util.h" +#include "ray/rpc/gcs_server/gcs_rpc_client.h" + +namespace ray { + +class GlobalStateAccessorTest : public RedisServiceManagerForTest { + protected: + void SetUp() override { + config.grpc_server_port = 0; + config.grpc_server_name = "MockedGcsServer"; + config.grpc_server_thread_num = 1; + config.redis_address = "127.0.0.1"; + config.is_test = true; + config.redis_port = REDIS_SERVER_PORTS.front(); + gcs_server_.reset(new gcs::GcsServer(config)); + io_service_.reset(new boost::asio::io_service()); + + thread_io_service_.reset(new std::thread([this] { + std::unique_ptr work( + new boost::asio::io_service::work(*io_service_)); + io_service_->run(); + })); + + thread_gcs_server_.reset(new std::thread([this] { gcs_server_->Start(); })); + + // Wait until server starts listening. + while (!gcs_server_->IsStarted()) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + + // Create GCS client. + gcs::GcsClientOptions options(config.redis_address, config.redis_port, + config.redis_password, config.is_test); + gcs_client_.reset(new gcs::ServiceBasedGcsClient(options)); + RAY_CHECK_OK(gcs_client_->Connect(*io_service_)); + + // Create global state. + std::stringstream address; + address << config.redis_address << ":" << config.redis_port; + global_state_.reset(new gcs::GlobalStateAccessor(address.str(), "", true)); + RAY_CHECK(global_state_->Connect()); + } + + void TearDown() override { + gcs_server_->Stop(); + io_service_->stop(); + thread_io_service_->join(); + thread_gcs_server_->join(); + gcs_client_->Disconnect(); + global_state_->Disconnect(); + global_state_.reset(); + FlushAll(); + } + + bool WaitReady(std::future future, const std::chrono::milliseconds &timeout_ms) { + auto status = future.wait_for(timeout_ms); + return status == std::future_status::ready && future.get(); + } + + // GCS server. + gcs::GcsServerConfig config; + std::unique_ptr gcs_server_; + std::unique_ptr thread_io_service_; + std::unique_ptr thread_gcs_server_; + std::unique_ptr io_service_; + + // GCS client. + std::unique_ptr gcs_client_; + + std::unique_ptr global_state_; + + // Timeout waiting for GCS server reply, default is 2s. + const std::chrono::milliseconds timeout_ms_{2000}; +}; + +TEST_F(GlobalStateAccessorTest, TestJobTable) { + int job_count = 100; + ASSERT_EQ(global_state_->GetAllJobInfo().size(), 0); + for (int index = 0; index < job_count; ++index) { + auto job_id = JobID::FromInt(index); + auto job_table_data = Mocker::GenJobTableData(job_id); + std::promise promise; + RAY_CHECK_OK(gcs_client_->Jobs().AsyncAdd( + job_table_data, [&promise](Status status) { promise.set_value(status.ok()); })); + WaitReady(promise.get_future(), timeout_ms_); + } + ASSERT_EQ(global_state_->GetAllJobInfo().size(), job_count); +} + +} // namespace ray + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + RAY_CHECK(argc == 4); + ray::REDIS_SERVER_EXEC_PATH = argv[1]; + ray::REDIS_CLIENT_EXEC_PATH = argv[2]; + ray::REDIS_MODULE_LIBRARY_PATH = argv[3]; + return RUN_ALL_TESTS(); +} diff --git a/src/ray/gcs/gcs_server/gcs_server.cc b/src/ray/gcs/gcs_server/gcs_server.cc index d741451b3..b17e5d336 100644 --- a/src/ray/gcs/gcs_server/gcs_server.cc +++ b/src/ray/gcs/gcs_server/gcs_server.cc @@ -44,6 +44,10 @@ void GcsServer::Start() { // Init gcs pub sub instance. gcs_pub_sub_ = std::make_shared(redis_gcs_client_->GetRedisClient()); + // Init gcs table storage. + gcs_table_storage_ = + std::make_shared(redis_gcs_client_->GetRedisClient()); + // Init gcs node_manager. InitGcsNodeManager(); @@ -192,7 +196,7 @@ void GcsServer::InitGcsActorManager() { std::unique_ptr GcsServer::InitJobInfoHandler() { return std::unique_ptr( - new rpc::DefaultJobInfoHandler(*redis_gcs_client_, gcs_pub_sub_)); + new rpc::DefaultJobInfoHandler(gcs_table_storage_, gcs_pub_sub_)); } std::unique_ptr GcsServer::InitActorInfoHandler() { diff --git a/src/ray/gcs/gcs_server/gcs_server.h b/src/ray/gcs/gcs_server/gcs_server.h index ccbdb904b..97ea40429 100644 --- a/src/ray/gcs/gcs_server/gcs_server.h +++ b/src/ray/gcs/gcs_server/gcs_server.h @@ -20,6 +20,7 @@ #include #include #include "ray/gcs/gcs_server/gcs_redis_failure_detector.h" +#include "ray/gcs/gcs_server/gcs_table_storage.h" namespace ray { namespace gcs { @@ -148,6 +149,8 @@ class GcsServer { std::shared_ptr redis_gcs_client_; /// A publisher for publishing gcs messages. std::shared_ptr gcs_pub_sub_; + /// The gcs table storage. + std::shared_ptr gcs_table_storage_; /// Gcs service state flag, which is used for ut. bool is_started_ = false; bool is_stopped_ = false; diff --git a/src/ray/gcs/gcs_server/job_info_handler_impl.cc b/src/ray/gcs/gcs_server/job_info_handler_impl.cc index b3fa67b84..e20d1d068 100644 --- a/src/ray/gcs/gcs_server/job_info_handler_impl.cc +++ b/src/ray/gcs/gcs_server/job_info_handler_impl.cc @@ -17,14 +17,13 @@ namespace ray { namespace rpc { + void DefaultJobInfoHandler::HandleAddJob(const rpc::AddJobRequest &request, rpc::AddJobReply *reply, rpc::SendReplyCallback send_reply_callback) { JobID job_id = JobID::FromBinary(request.data().job_id()); RAY_LOG(INFO) << "Adding job, job id = " << job_id << ", driver pid = " << request.data().driver_pid(); - auto job_table_data = std::make_shared(); - job_table_data->CopyFrom(request.data()); auto on_done = [job_id, request, reply, send_reply_callback](const Status &status) { if (!status.ok()) { RAY_LOG(ERROR) << "Failed to add job, job id = " << job_id @@ -36,7 +35,7 @@ void DefaultJobInfoHandler::HandleAddJob(const rpc::AddJobRequest &request, GCS_RPC_SEND_REPLY(send_reply_callback, reply, status); }; - Status status = gcs_client_.Jobs().AsyncAdd(job_table_data, on_done); + Status status = gcs_table_storage_->JobTable().Put(job_id, request.data(), on_done); if (!status.ok()) { on_done(status); } @@ -61,10 +60,29 @@ void DefaultJobInfoHandler::HandleMarkJobFinished( GCS_RPC_SEND_REPLY(send_reply_callback, reply, status); }; - Status status = gcs_client_.Jobs().AsyncMarkFinished(job_id, on_done); + Status status = gcs_table_storage_->JobTable().Put(job_id, *job_table_data, on_done); if (!status.ok()) { on_done(status); } } + +void DefaultJobInfoHandler::HandleGetAllJobInfo( + const rpc::GetAllJobInfoRequest &request, rpc::GetAllJobInfoReply *reply, + rpc::SendReplyCallback send_reply_callback) { + RAY_LOG(INFO) << "Getting all job info."; + auto on_done = [reply, send_reply_callback]( + const std::unordered_map &result) { + for (auto &data : result) { + reply->add_job_info_list()->CopyFrom(data.second); + } + RAY_LOG(INFO) << "Finished getting all job info."; + GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK()); + }; + Status status = gcs_table_storage_->JobTable().GetAll(on_done); + if (!status.ok()) { + on_done(std::unordered_map()); + } +} + } // namespace rpc } // namespace ray diff --git a/src/ray/gcs/gcs_server/job_info_handler_impl.h b/src/ray/gcs/gcs_server/job_info_handler_impl.h index 27caddf41..7b8531232 100644 --- a/src/ray/gcs/gcs_server/job_info_handler_impl.h +++ b/src/ray/gcs/gcs_server/job_info_handler_impl.h @@ -15,6 +15,7 @@ #ifndef RAY_GCS_JOB_INFO_HANDLER_IMPL_H #define RAY_GCS_JOB_INFO_HANDLER_IMPL_H +#include "gcs_table_storage.h" #include "ray/gcs/pubsub/gcs_pub_sub.h" #include "ray/gcs/redis_gcs_client.h" #include "ray/rpc/gcs_server/gcs_rpc_server.h" @@ -25,9 +26,10 @@ namespace rpc { /// This implementation class of `JobInfoHandler`. class DefaultJobInfoHandler : public rpc::JobInfoHandler { public: - explicit DefaultJobInfoHandler(gcs::RedisGcsClient &gcs_client, - std::shared_ptr &gcs_pub_sub) - : gcs_client_(gcs_client), gcs_pub_sub_(gcs_pub_sub) {} + explicit DefaultJobInfoHandler(std::shared_ptr gcs_table_storage, + std::shared_ptr gcs_pub_sub) + : gcs_table_storage_(std::move(gcs_table_storage)), + gcs_pub_sub_(std::move(gcs_pub_sub)) {} void HandleAddJob(const AddJobRequest &request, AddJobReply *reply, SendReplyCallback send_reply_callback) override; @@ -36,8 +38,11 @@ class DefaultJobInfoHandler : public rpc::JobInfoHandler { MarkJobFinishedReply *reply, SendReplyCallback send_reply_callback) override; + void HandleGetAllJobInfo(const GetAllJobInfoRequest &request, GetAllJobInfoReply *reply, + SendReplyCallback send_reply_callback) override; + private: - gcs::RedisGcsClient &gcs_client_; + std::shared_ptr gcs_table_storage_; std::shared_ptr gcs_pub_sub_; }; diff --git a/src/ray/gcs/redis_accessor.h b/src/ray/gcs/redis_accessor.h index 9e6ba61c2..799fa4e63 100644 --- a/src/ray/gcs/redis_accessor.h +++ b/src/ray/gcs/redis_accessor.h @@ -177,6 +177,10 @@ class RedisJobInfoAccessor : public JobInfoAccessor { const SubscribeCallback &subscribe, const StatusCallback &done) override; + Status AsyncGetAll(const MultiItemCallback &callback) override { + return Status::NotImplemented("AsyncGetAll not implemented"); + } + private: /// Append job information to GCS asynchronously. /// diff --git a/src/ray/gcs/store_client/redis_store_client.cc b/src/ray/gcs/store_client/redis_store_client.cc index 99e390ffb..35533f233 100644 --- a/src/ray/gcs/store_client/redis_store_client.cc +++ b/src/ray/gcs/store_client/redis_store_client.cc @@ -350,7 +350,7 @@ void RedisStoreClient::RedisScanner::MGetValues( auto finished_count = std::make_shared(0); int size = mget_commands_by_shards.size(); for (auto &item : mget_commands_by_shards) { - auto mget_keys = item.second; + auto mget_keys = std::move(item.second); auto mget_callback = [this, finished_count, size, mget_keys, callback](const std::shared_ptr &reply) { if (!reply->IsNil()) { @@ -370,7 +370,7 @@ void RedisStoreClient::RedisScanner::MGetValues( callback(key_value_map_); } }; - RAY_CHECK_OK(item.first->RunArgvAsync(item.second, mget_callback)); + RAY_CHECK_OK(item.first->RunArgvAsync(mget_keys, mget_callback)); } } diff --git a/src/ray/gcs/test/gcs_test_util.h b/src/ray/gcs/test/gcs_test_util.h index d5acbeb47..c6fd7371d 100644 --- a/src/ray/gcs/test/gcs_test_util.h +++ b/src/ray/gcs/test/gcs_test_util.h @@ -37,8 +37,10 @@ struct Mocker { ray::FunctionDescriptorBuilder::BuildPython("", "", "", ""); auto actor_id = ActorID::Of(job_id, RandomTaskId(), 0); auto task_id = TaskID::ForActorCreationTask(actor_id); + auto resource = std::unordered_map(); builder.SetCommonTaskSpec(task_id, Language::PYTHON, empty_descriptor, job_id, - TaskID::Nil(), 0, TaskID::Nil(), owner_address, 1, {}, {}); + TaskID::Nil(), 0, TaskID::Nil(), owner_address, 1, resource, + resource); builder.SetActorCreationTaskSpec(actor_id, max_restarts, {}, 1, detached, name); return builder.Build(); } diff --git a/src/ray/protobuf/gcs_service.proto b/src/ray/protobuf/gcs_service.proto index fc88f6f9f..4b68c2d84 100644 --- a/src/ray/protobuf/gcs_service.proto +++ b/src/ray/protobuf/gcs_service.proto @@ -40,12 +40,22 @@ message MarkJobFinishedReply { GcsStatus status = 1; } +message GetAllJobInfoRequest { +} + +message GetAllJobInfoReply { + GcsStatus status = 1; + repeated JobTableData job_info_list = 2; +} + // Service for job info access. service JobInfoGcsService { // Add job to GCS Service. rpc AddJob(AddJobRequest) returns (AddJobReply); // Mark job as finished to GCS Service. rpc MarkJobFinished(MarkJobFinishedRequest) returns (MarkJobFinishedReply); + // Get information of all jobs from GCS Service. + rpc GetAllJobInfo(GetAllJobInfoRequest) returns (GetAllJobInfoReply); } message GetActorInfoRequest { diff --git a/src/ray/rpc/gcs_server/gcs_rpc_client.h b/src/ray/rpc/gcs_server/gcs_rpc_client.h index 61607385d..07d9c89ce 100644 --- a/src/ray/rpc/gcs_server/gcs_rpc_client.h +++ b/src/ray/rpc/gcs_server/gcs_rpc_client.h @@ -96,6 +96,9 @@ class GcsRpcClient { /// Mark job as finished to gcs server. VOID_GCS_RPC_CLIENT_METHOD(JobInfoGcsService, MarkJobFinished, job_info_grpc_client_, ) + /// Get information of all jobs from GCS Service. + VOID_GCS_RPC_CLIENT_METHOD(JobInfoGcsService, GetAllJobInfo, job_info_grpc_client_, ) + /// Create actor via GCS Service. VOID_RPC_CLIENT_METHOD(ActorInfoGcsService, CreateActor, actor_info_grpc_client_, ) diff --git a/src/ray/rpc/gcs_server/gcs_rpc_server.h b/src/ray/rpc/gcs_server/gcs_rpc_server.h index e11d2b373..a449881c9 100644 --- a/src/ray/rpc/gcs_server/gcs_rpc_server.h +++ b/src/ray/rpc/gcs_server/gcs_rpc_server.h @@ -61,6 +61,10 @@ class JobInfoGcsServiceHandler { virtual void HandleMarkJobFinished(const MarkJobFinishedRequest &request, MarkJobFinishedReply *reply, SendReplyCallback send_reply_callback) = 0; + + virtual void HandleGetAllJobInfo(const GetAllJobInfoRequest &request, + GetAllJobInfoReply *reply, + SendReplyCallback send_reply_callback) = 0; }; /// The `GrpcService` for `JobInfoGcsService`. @@ -81,6 +85,7 @@ class JobInfoGrpcService : public GrpcService { std::vector> *server_call_factories) override { JOB_INFO_SERVICE_RPC_HANDLER(AddJob); JOB_INFO_SERVICE_RPC_HANDLER(MarkJobFinished); + JOB_INFO_SERVICE_RPC_HANDLER(GetAllJobInfo); } private: