From ade7ebfc0c47b077273fa36bbb26a889c8637665 Mon Sep 17 00:00:00 2001 From: fangfengbin <869218239a@zju.edu.cn> Date: Wed, 5 Feb 2020 12:06:25 +0800 Subject: [PATCH] Add service based gcs client (#6686) --- .travis.yml | 18 + BUILD.bazel | 38 + java/BUILD.bazel | 1 + python/setup.py | 1 + src/ray/common/ray_config_def.h | 5 + src/ray/core_worker/test/core_worker_test.cc | 38 +- .../gcs/gcs_client/service_based_accessor.cc | 821 ++++++++++++++++++ .../gcs/gcs_client/service_based_accessor.h | 319 +++++++ .../gcs_client/service_based_gcs_client.cc | 86 ++ .../gcs/gcs_client/service_based_gcs_client.h | 43 + .../test/service_based_gcs_client_test.cc | 631 ++++++++++++++ .../gcs/gcs_server/actor_info_handler_impl.cc | 5 +- src/ray/gcs/gcs_server/gcs_server.cc | 33 + src/ray/gcs/gcs_server/gcs_server.h | 7 + .../gcs/gcs_server/node_info_handler_impl.cc | 3 +- .../gcs_server/test/gcs_server_rpc_test.cc | 4 +- src/ray/raylet/main.cc | 14 +- src/ray/rpc/gcs_server/gcs_rpc_client.h | 1 + src/ray/test/run_core_worker_tests.sh | 5 +- streaming/src/test/queue_tests_base.h | 34 +- .../src/test/run_streaming_queue_test.sh | 5 +- streaming/src/test/streaming_queue_tests.cc | 8 +- 22 files changed, 2105 insertions(+), 15 deletions(-) create mode 100644 src/ray/gcs/gcs_client/service_based_accessor.cc create mode 100644 src/ray/gcs/gcs_client/service_based_accessor.h create mode 100644 src/ray/gcs/gcs_client/service_based_gcs_client.cc create mode 100644 src/ray/gcs/gcs_client/service_based_gcs_client.h create mode 100644 src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc diff --git a/.travis.yml b/.travis.yml index 509928db8..e3b341333 100644 --- a/.travis.yml +++ b/.travis.yml @@ -47,6 +47,24 @@ matrix: - if [ $RAY_CI_STREAMING_PYTHON_AFFECTED == "1" ]; then python -m pytest -v --durations=5 --timeout=300 streaming/python/tests/; fi - if [ $RAY_CI_STREAMING_JAVA_AFFECTED == "1" ]; then ./streaming/java/test.sh; fi + - os: linux + env: + - TESTSUITE=gcs_service + - JDK='Oracle JDK 8' + - RAY_GCS_SERVICE_ENABLED=true + - RAY_INSTALL_JAVA=1 + - PYTHON=3.5 PYTHONWARNINGS=ignore + install: + - python $TRAVIS_BUILD_DIR/ci/travis/determine_tests_to_run.py + - eval `python $TRAVIS_BUILD_DIR/ci/travis/determine_tests_to_run.py` + - ./ci/travis/install-bazel.sh + - ./ci/suppress_output ./ci/travis/install-dependencies.sh + - export PATH="$HOME/miniconda/bin:$PATH" + - ./ci/suppress_output ./ci/travis/install-ray.sh + script: + - ./ci/suppress_output bash src/ray/test/run_core_worker_tests.sh + - ./ci/suppress_output bash streaming/src/test/run_streaming_queue_test.sh + - os: linux env: LINT=1 PYTHONWARNINGS=ignore before_install: diff --git a/BUILD.bazel b/BUILD.bazel index 12a1c213f..92c9f0b0f 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -329,6 +329,7 @@ cc_binary( "src/ray/gcs/gcs_server/gcs_server_main.cc", ], copts = COPTS, + visibility = ["//java:__subpackages__"], deps = [ ":gcs_server_lib", "@com_github_gflags_gflags//:gflags", @@ -412,6 +413,7 @@ cc_library( ":object_manager", ":ray_common", ":ray_util", + ":service_based_gcs_client_lib", ":stats_lib", ":worker_rpc", "@boost//:asio", @@ -702,6 +704,40 @@ cc_test( ], ) +cc_library( + name = "service_based_gcs_client_lib", + srcs = glob( + [ + "src/ray/gcs/gcs_client/service_based_*.cc", + ], + ), + hdrs = glob( + [ + "src/ray/gcs/gcs_client/service_based_*.h", + ], + ), + copts = COPTS, + deps = [ + ":gcs_server_lib", + ], +) + +cc_test( + name = "gcs_server_test", + srcs = ["src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc"], + args = ["$(location redis-server) $(location redis-cli) $(location libray_redis_module.so)"], + copts = COPTS, + data = [ + "//:libray_redis_module.so", + "//:redis-cli", + "//:redis-server", + ], + deps = [ + ":service_based_gcs_client_lib", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "object_manager", srcs = glob([ @@ -1209,6 +1245,7 @@ genrule( "//:libray_redis_module.so", "//:raylet", "//:raylet_monitor", + "//:gcs_server", "@plasma//:plasma_store_server", "//streaming:copy_streaming_py_proto", ] + select({ @@ -1228,6 +1265,7 @@ genrule( cp -f $(location //:raylet_monitor) "$$WORK_DIR/python/ray/core/src/ray/raylet/" && cp -f $(location @plasma//:plasma_store_server) "$$WORK_DIR/python/ray/core/src/plasma/" && cp -f $(location //:raylet) "$$WORK_DIR/python/ray/core/src/ray/raylet/" && + cp -f $(location //:gcs_server) "$$WORK_DIR/python/ray/core/src/ray/gcs/" && mkdir -p "$$WORK_DIR/python/ray/core/generated/ray/protocol/" && for f in $(locations //:all_py_proto); do cp -f "$$f" "$$WORK_DIR/python/ray/core/generated/"; diff --git a/java/BUILD.bazel b/java/BUILD.bazel index 9159ffc7f..90e376b92 100644 --- a/java/BUILD.bazel +++ b/java/BUILD.bazel @@ -155,6 +155,7 @@ filegroup( srcs = [ ":cp_plasma_store_server", "//:core_worker_library_java", + "//:gcs_server", "//:libray_redis_module.so", "//:raylet", "//:redis-server", diff --git a/python/setup.py b/python/setup.py index a29d4d542..bc0dd003f 100644 --- a/python/setup.py +++ b/python/setup.py @@ -22,6 +22,7 @@ ray_files = [ "ray/core/src/plasma/plasma_store_server", "ray/_raylet.so", "ray/core/src/ray/raylet/raylet_monitor", + "ray/core/src/ray/gcs/gcs_server", "ray/core/src/ray/raylet/raylet", "ray/dashboard/dashboard.py", "ray/streaming/_streaming.so", diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h index a708358fd..26b3a05c9 100644 --- a/src/ray/common/ray_config_def.h +++ b/src/ray/common/ray_config_def.h @@ -191,3 +191,8 @@ RAY_CONFIG(uint32_t, object_store_get_warn_per_num_attempts, 50) /// When getting objects from object store, max number of ids to print in the warning /// message. RAY_CONFIG(uint32_t, object_store_get_max_ids_to_print_in_warning, 20) + +/// Allow up to 5 seconds for connecting to gcs service. +/// Note: this only takes effect when gcs service is enabled. +RAY_CONFIG(int64_t, gcs_service_connect_retries, 50) +RAY_CONFIG(int64_t, gcs_service_connect_wait_milliseconds, 100) diff --git a/src/ray/core_worker/test/core_worker_test.cc b/src/ray/core_worker/test/core_worker_test.cc index 165264cc0..6e366b601 100644 --- a/src/ray/core_worker/test/core_worker_test.cc +++ b/src/ray/core_worker/test/core_worker_test.cc @@ -29,6 +29,7 @@ std::string raylet_executable; int node_manager_port = 0; std::string raylet_monitor_executable; std::string mock_worker_executable; +std::string gcs_server_executable; } // namespace @@ -91,6 +92,11 @@ class CoreWorkerTest : public ::testing::Test { // receive the heartbeat from another. So starting raylet monitor is required here. raylet_monitor_pid_ = StartRayletMonitor("127.0.0.1"); + // start gcs server + if (getenv("RAY_GCS_SERVICE_ENABLED") != nullptr) { + gcs_server_pid_ = StartGcsServer("127.0.0.1"); + } + // start raylet on each node. Assign each node with different resources so that // a task can be scheduled to the desired node. for (int i = 0; i < num_nodes; i++) { @@ -112,6 +118,10 @@ class CoreWorkerTest : public ::testing::Test { if (!raylet_monitor_pid_.empty()) { StopRayletMonitor(raylet_monitor_pid_); } + + if (!gcs_server_pid_.empty()) { + StopGcsServer(gcs_server_pid_); + } } JobID NextJobId() const { @@ -192,6 +202,30 @@ class CoreWorkerTest : public ::testing::Test { std::string kill_9 = "kill -9 `cat " + raylet_monitor_pid + "`"; RAY_LOG(DEBUG) << kill_9; ASSERT_TRUE(system(kill_9.c_str()) == 0); + ASSERT_TRUE(system(("rm -f " + raylet_monitor_pid).c_str()) == 0); + } + + std::string StartGcsServer(std::string redis_address) { + std::string gcs_server_pid = + "/tmp/gcs_server" + ObjectID::FromRandom().Hex() + ".pid"; + std::string gcs_server_start_cmd = gcs_server_executable; + gcs_server_start_cmd.append(" --redis_address=" + redis_address) + .append(" --redis_port=6379") + .append(" --config_list=initial_reconstruction_timeout_milliseconds,2000") + .append(" & echo $! > " + gcs_server_pid); + + RAY_LOG(DEBUG) << "Starting GCS server, command: " << gcs_server_start_cmd; + RAY_CHECK(system(gcs_server_start_cmd.c_str()) == 0); + usleep(200 * 1000); + RAY_LOG(INFO) << "GCS server started."; + return gcs_server_pid; + } + + void StopGcsServer(std::string gcs_server_pid) { + std::string kill_9 = "kill -9 `cat " + gcs_server_pid + "`"; + RAY_LOG(DEBUG) << kill_9; + ASSERT_TRUE(system(kill_9.c_str()) == 0); + ASSERT_TRUE(system(("rm -f " + gcs_server_pid).c_str()) == 0); } void SetUp() {} @@ -230,6 +264,7 @@ class CoreWorkerTest : public ::testing::Test { std::vector raylet_store_socket_names_; std::string raylet_monitor_pid_; gcs::GcsClientOptions gcs_options_; + std::string gcs_server_pid_; }; bool CoreWorkerTest::WaitForDirectCallActorState(CoreWorker &worker, @@ -1020,11 +1055,12 @@ TEST_F(TwoNodeTest, TestDirectActorTaskCrossNodesFailure) { int main(int argc, char **argv) { ::testing::InitGoogleTest(&argc, argv); - RAY_CHECK(argc == 6); + RAY_CHECK(argc == 7); store_executable = std::string(argv[1]); raylet_executable = std::string(argv[2]); node_manager_port = std::stoi(std::string(argv[3])); raylet_monitor_executable = std::string(argv[4]); mock_worker_executable = std::string(argv[5]); + gcs_server_executable = std::string(argv[6]); return RUN_ALL_TESTS(); } diff --git a/src/ray/gcs/gcs_client/service_based_accessor.cc b/src/ray/gcs/gcs_client/service_based_accessor.cc new file mode 100644 index 000000000..9c3627267 --- /dev/null +++ b/src/ray/gcs/gcs_client/service_based_accessor.cc @@ -0,0 +1,821 @@ +#include "ray/gcs/gcs_client/service_based_accessor.h" +#include "ray/gcs/gcs_client/service_based_gcs_client.h" + +namespace ray { +namespace gcs { + +ServiceBasedJobInfoAccessor::ServiceBasedJobInfoAccessor( + ServiceBasedGcsClient *client_impl) + : client_impl_(client_impl), + job_sub_executor_(client_impl->GetRedisGcsClient().job_table()) {} + +Status ServiceBasedJobInfoAccessor::AsyncAdd( + const std::shared_ptr &data_ptr, const StatusCallback &callback) { + JobID job_id = JobID::FromBinary(data_ptr->job_id()); + RAY_LOG(DEBUG) << "Adding job, job id = " << job_id + << ", driver pid = " << data_ptr->driver_pid(); + rpc::AddJobRequest request; + request.mutable_data()->CopyFrom(*data_ptr); + client_impl_->GetGcsRpcClient().AddJob( + request, + [job_id, data_ptr, callback](const Status &status, const rpc::AddJobReply &reply) { + if (callback) { + callback(status); + } + RAY_LOG(DEBUG) << "Finished adding job, status = " << status + << ", job id = " << job_id + << ", driver pid = " << data_ptr->driver_pid(); + }); + return Status::OK(); +} + +Status ServiceBasedJobInfoAccessor::AsyncMarkFinished(const JobID &job_id, + const StatusCallback &callback) { + RAY_LOG(DEBUG) << "Marking job state, job id = " << job_id; + rpc::MarkJobFinishedRequest request; + request.set_job_id(job_id.Binary()); + client_impl_->GetGcsRpcClient().MarkJobFinished( + request, + [job_id, callback](const Status &status, const rpc::MarkJobFinishedReply &reply) { + if (callback) { + callback(status); + } + RAY_LOG(DEBUG) << "Finished marking job state, status = " << status + << ", job id = " << job_id; + }); + return Status::OK(); +} + +Status ServiceBasedJobInfoAccessor::AsyncSubscribeToFinishedJobs( + const SubscribeCallback &subscribe, const StatusCallback &done) { + RAY_LOG(DEBUG) << "Subscribing finished job."; + RAY_CHECK(subscribe != nullptr); + auto on_subscribe = [subscribe](const JobID &job_id, const JobTableData &job_data) { + if (job_data.is_dead()) { + subscribe(job_id, job_data); + } + }; + Status status = + job_sub_executor_.AsyncSubscribeAll(ClientID::Nil(), on_subscribe, done); + RAY_LOG(DEBUG) << "Finished subscribing finished job."; + return status; +} + +ServiceBasedActorInfoAccessor::ServiceBasedActorInfoAccessor( + ServiceBasedGcsClient *client_impl) + : client_impl_(client_impl), + subscribe_id_(ClientID::FromRandom()), + actor_sub_executor_(client_impl->GetRedisGcsClient().actor_table()) {} + +Status ServiceBasedActorInfoAccessor::AsyncGet( + const ActorID &actor_id, const OptionalItemCallback &callback) { + RAY_LOG(DEBUG) << "Getting actor info, actor id = " << actor_id; + rpc::GetActorInfoRequest request; + request.set_actor_id(actor_id.Binary()); + client_impl_->GetGcsRpcClient().GetActorInfo( + request, + [actor_id, callback](const Status &status, const rpc::GetActorInfoReply &reply) { + if (reply.has_actor_table_data()) { + rpc::ActorTableData actor_table_data(reply.actor_table_data()); + callback(status, actor_table_data); + } else { + callback(status, boost::none); + } + RAY_LOG(DEBUG) << "Finished getting actor info, status = " << status + << ", actor id = " << actor_id; + }); + return Status::OK(); +} + +Status ServiceBasedActorInfoAccessor::AsyncRegister( + const std::shared_ptr &data_ptr, + const StatusCallback &callback) { + ActorID actor_id = ActorID::FromBinary(data_ptr->actor_id()); + RAY_LOG(DEBUG) << "Registering actor info, actor id = " << actor_id; + rpc::RegisterActorInfoRequest request; + request.mutable_actor_table_data()->CopyFrom(*data_ptr); + client_impl_->GetGcsRpcClient().RegisterActorInfo( + request, [actor_id, callback](const Status &status, + const rpc::RegisterActorInfoReply &reply) { + if (callback) { + callback(status); + } + RAY_LOG(DEBUG) << "Finished registering actor info, status = " << status + << ", actor id = " << actor_id; + }); + return Status::OK(); +} + +Status ServiceBasedActorInfoAccessor::AsyncUpdate( + const ActorID &actor_id, const std::shared_ptr &data_ptr, + const StatusCallback &callback) { + RAY_LOG(DEBUG) << "Updating actor info, actor id = " << actor_id; + rpc::UpdateActorInfoRequest request; + request.set_actor_id(actor_id.Binary()); + request.mutable_actor_table_data()->CopyFrom(*data_ptr); + client_impl_->GetGcsRpcClient().UpdateActorInfo( + request, + [actor_id, callback](const Status &status, const rpc::UpdateActorInfoReply &reply) { + if (callback) { + callback(status); + } + RAY_LOG(DEBUG) << "Finished updating actor info, status = " << status + << ", actor id = " << actor_id; + }); + return Status::OK(); +} + +Status ServiceBasedActorInfoAccessor::AsyncSubscribeAll( + const SubscribeCallback &subscribe, + const StatusCallback &done) { + RAY_LOG(DEBUG) << "Subscribing register or update operations of actors."; + RAY_CHECK(subscribe != nullptr); + auto status = actor_sub_executor_.AsyncSubscribeAll(ClientID::Nil(), subscribe, done); + RAY_LOG(DEBUG) << "Finished subscribing register or update operations of actors."; + return status; +} + +Status ServiceBasedActorInfoAccessor::AsyncSubscribe( + const ActorID &actor_id, + const SubscribeCallback &subscribe, + const StatusCallback &done) { + RAY_LOG(DEBUG) << "Subscribing update operations of actor, actor id = " << actor_id; + RAY_CHECK(subscribe != nullptr) << "Failed to subscribe actor, actor id = " << actor_id; + auto status = + actor_sub_executor_.AsyncSubscribe(subscribe_id_, actor_id, subscribe, done); + RAY_LOG(DEBUG) << "Finished subscribing update operations of actor, actor id = " + << actor_id; + return status; +} + +Status ServiceBasedActorInfoAccessor::AsyncUnsubscribe(const ActorID &actor_id, + const StatusCallback &done) { + RAY_LOG(DEBUG) << "Cancelling subscription to an actor, actor id = " << actor_id; + auto status = actor_sub_executor_.AsyncUnsubscribe(subscribe_id_, actor_id, done); + RAY_LOG(DEBUG) << "Finished cancelling subscription to an actor, actor id = " + << actor_id; + return status; +} + +Status ServiceBasedActorInfoAccessor::AsyncAddCheckpoint( + const std::shared_ptr &data_ptr, + const StatusCallback &callback) { + ActorID actor_id = ActorID::FromBinary(data_ptr->actor_id()); + ActorCheckpointID checkpoint_id = + ActorCheckpointID::FromBinary(data_ptr->checkpoint_id()); + RAY_LOG(DEBUG) << "Adding actor checkpoint, actor id = " << actor_id + << ", checkpoint id = " << checkpoint_id; + rpc::AddActorCheckpointRequest request; + request.mutable_checkpoint_data()->CopyFrom(*data_ptr); + client_impl_->GetGcsRpcClient().AddActorCheckpoint( + request, [actor_id, checkpoint_id, callback]( + const Status &status, const rpc::AddActorCheckpointReply &reply) { + if (callback) { + callback(status); + } + RAY_LOG(DEBUG) << "Finished adding actor checkpoint, status = " << status + << ", actor id = " << actor_id + << ", checkpoint id = " << checkpoint_id; + }); + return Status::OK(); +} + +Status ServiceBasedActorInfoAccessor::AsyncGetCheckpoint( + const ActorCheckpointID &checkpoint_id, + const OptionalItemCallback &callback) { + RAY_LOG(DEBUG) << "Getting actor checkpoint, checkpoint id = " << checkpoint_id; + rpc::GetActorCheckpointRequest request; + request.set_checkpoint_id(checkpoint_id.Binary()); + client_impl_->GetGcsRpcClient().GetActorCheckpoint( + request, [checkpoint_id, callback](const Status &status, + const rpc::GetActorCheckpointReply &reply) { + if (reply.has_checkpoint_data()) { + rpc::ActorCheckpointData checkpoint_data(reply.checkpoint_data()); + callback(status, checkpoint_data); + } else { + callback(status, boost::none); + } + RAY_LOG(DEBUG) << "Finished getting actor checkpoint, status = " << status + << ", checkpoint id = " << checkpoint_id; + }); + return Status::OK(); +} + +Status ServiceBasedActorInfoAccessor::AsyncGetCheckpointID( + const ActorID &actor_id, + const OptionalItemCallback &callback) { + RAY_LOG(DEBUG) << "Getting actor checkpoint id, actor id = " << actor_id; + rpc::GetActorCheckpointIDRequest request; + request.set_actor_id(actor_id.Binary()); + client_impl_->GetGcsRpcClient().GetActorCheckpointID( + request, [actor_id, callback](const Status &status, + const rpc::GetActorCheckpointIDReply &reply) { + if (reply.has_checkpoint_id_data()) { + rpc::ActorCheckpointIdData checkpoint_id_data(reply.checkpoint_id_data()); + callback(status, checkpoint_id_data); + } else { + callback(status, boost::none); + } + RAY_LOG(DEBUG) << "Finished getting actor checkpoint id, status = " << status + << ", actor id = " << actor_id; + }); + return Status::OK(); +} + +ServiceBasedNodeInfoAccessor::ServiceBasedNodeInfoAccessor( + ServiceBasedGcsClient *client_impl) + : client_impl_(client_impl), + resource_sub_executor_(client_impl->GetRedisGcsClient().resource_table()), + heartbeat_sub_executor_(client_impl->GetRedisGcsClient().heartbeat_table()), + heartbeat_batch_sub_executor_( + client_impl->GetRedisGcsClient().heartbeat_batch_table()) {} + +Status ServiceBasedNodeInfoAccessor::RegisterSelf(const GcsNodeInfo &local_node_info) { + auto node_id = ClientID::FromBinary(local_node_info.node_id()); + RAY_LOG(DEBUG) << "Registering node info, node id = " << node_id + << ", address is = " << local_node_info.node_manager_address(); + RAY_CHECK(local_node_id_.IsNil()) << "This node is already connected."; + RAY_CHECK(local_node_info.state() == GcsNodeInfo::ALIVE); + rpc::RegisterNodeRequest request; + request.mutable_node_info()->CopyFrom(local_node_info); + client_impl_->GetGcsRpcClient().RegisterNode( + request, [this, node_id, &local_node_info](const Status &status, + const rpc::RegisterNodeReply &reply) { + if (status.ok()) { + local_node_info_.CopyFrom(local_node_info); + local_node_id_ = ClientID::FromBinary(local_node_info.node_id()); + } + RAY_LOG(DEBUG) << "Finished registering node info, status = " << status + << ", node id = " << node_id; + }); + return Status::OK(); +} + +Status ServiceBasedNodeInfoAccessor::UnregisterSelf() { + RAY_CHECK(!local_node_id_.IsNil()) << "This node is disconnected."; + ClientID node_id = ClientID::FromBinary(local_node_info_.node_id()); + RAY_LOG(DEBUG) << "Unregistering node info, node id = " << node_id; + rpc::UnregisterNodeRequest request; + request.set_node_id(local_node_info_.node_id()); + client_impl_->GetGcsRpcClient().UnregisterNode( + request, + [this, node_id](const Status &status, const rpc::UnregisterNodeReply &reply) { + if (status.ok()) { + local_node_info_.set_state(GcsNodeInfo::DEAD); + local_node_id_ = ClientID::Nil(); + } + RAY_LOG(DEBUG) << "Finished unregistering node info, status = " << status + << ", node id = " << node_id; + }); + return Status::OK(); +} + +const ClientID &ServiceBasedNodeInfoAccessor::GetSelfId() const { return local_node_id_; } + +const GcsNodeInfo &ServiceBasedNodeInfoAccessor::GetSelfInfo() const { + return local_node_info_; +} + +Status ServiceBasedNodeInfoAccessor::AsyncRegister(const rpc::GcsNodeInfo &node_info, + const StatusCallback &callback) { + ClientID node_id = ClientID::FromBinary(node_info.node_id()); + RAY_LOG(DEBUG) << "Registering node info, node id = " << node_id; + rpc::RegisterNodeRequest request; + request.mutable_node_info()->CopyFrom(node_info); + client_impl_->GetGcsRpcClient().RegisterNode( + request, + [node_id, callback](const Status &status, const rpc::RegisterNodeReply &reply) { + if (callback) { + callback(status); + } + RAY_LOG(DEBUG) << "Finished registering node info, status = " << status + << ", node id = " << node_id; + }); + return Status::OK(); +} + +Status ServiceBasedNodeInfoAccessor::AsyncUnregister(const ClientID &node_id, + const StatusCallback &callback) { + RAY_LOG(DEBUG) << "Unregistering node info, node id = " << node_id; + rpc::UnregisterNodeRequest request; + request.set_node_id(node_id.Binary()); + client_impl_->GetGcsRpcClient().UnregisterNode( + request, + [node_id, callback](const Status &status, const rpc::UnregisterNodeReply &reply) { + if (callback) { + callback(status); + } + RAY_LOG(DEBUG) << "Finished unregistering node info, status = " << status + << ", node id = " << node_id; + }); + return Status::OK(); +} + +Status ServiceBasedNodeInfoAccessor::AsyncGetAll( + const MultiItemCallback &callback) { + RAY_LOG(DEBUG) << "Getting information of all nodes."; + rpc::GetAllNodeInfoRequest request; + client_impl_->GetGcsRpcClient().GetAllNodeInfo( + request, [callback](const Status &status, const rpc::GetAllNodeInfoReply &reply) { + std::vector result; + result.reserve((reply.node_info_list_size())); + for (int index = 0; index < reply.node_info_list_size(); ++index) { + result.emplace_back(reply.node_info_list(index)); + } + callback(status, result); + RAY_LOG(DEBUG) << "Finished getting information of all nodes, status = " + << status; + }); + return Status::OK(); +} + +Status ServiceBasedNodeInfoAccessor::AsyncSubscribeToNodeChange( + const SubscribeCallback &subscribe, + const StatusCallback &done) { + RAY_LOG(DEBUG) << "Subscribing node change."; + RAY_CHECK(subscribe != nullptr); + ClientTable &client_table = client_impl_->GetRedisGcsClient().client_table(); + auto status = client_table.SubscribeToNodeChange(subscribe, done); + RAY_LOG(DEBUG) << "Finished subscribing node change."; + return status; +} + +boost::optional ServiceBasedNodeInfoAccessor::Get( + const ClientID &node_id) const { + GcsNodeInfo node_info; + ClientTable &client_table = client_impl_->GetRedisGcsClient().client_table(); + bool found = client_table.GetClient(node_id, &node_info); + boost::optional optional_node; + if (found) { + optional_node = std::move(node_info); + } + return optional_node; +} + +const std::unordered_map &ServiceBasedNodeInfoAccessor::GetAll() + const { + ClientTable &client_table = client_impl_->GetRedisGcsClient().client_table(); + return client_table.GetAllClients(); +} + +bool ServiceBasedNodeInfoAccessor::IsRemoved(const ClientID &node_id) const { + ClientTable &client_table = client_impl_->GetRedisGcsClient().client_table(); + return client_table.IsRemoved(node_id); +} + +Status ServiceBasedNodeInfoAccessor::AsyncGetResources( + const ClientID &node_id, const OptionalItemCallback &callback) { + RAY_LOG(DEBUG) << "Getting node resources, node id = " << node_id; + rpc::GetResourcesRequest request; + request.set_node_id(node_id.Binary()); + client_impl_->GetGcsRpcClient().GetResources( + request, + [node_id, callback](const Status &status, const rpc::GetResourcesReply &reply) { + ResourceMap resource_map; + for (auto resource : reply.resources()) { + resource_map[resource.first] = + std::make_shared(resource.second); + } + callback(status, resource_map); + RAY_LOG(DEBUG) << "Finished getting node resources, status = " << status + << ", node id = " << node_id; + }); + return Status::OK(); +} + +Status ServiceBasedNodeInfoAccessor::AsyncUpdateResources( + const ClientID &node_id, const ResourceMap &resources, + const StatusCallback &callback) { + RAY_LOG(DEBUG) << "Updating node resources, node id = " << node_id; + rpc::UpdateResourcesRequest request; + request.set_node_id(node_id.Binary()); + for (auto &resource : resources) { + (*request.mutable_resources())[resource.first] = *resource.second; + } + client_impl_->GetGcsRpcClient().UpdateResources( + request, + [node_id, callback](const Status &status, const rpc::UpdateResourcesReply &reply) { + if (callback) { + callback(status); + } + RAY_LOG(DEBUG) << "Finished updating node resources, status = " << status + << ", node id = " << node_id; + }); + return Status::OK(); +} + +Status ServiceBasedNodeInfoAccessor::AsyncDeleteResources( + const ClientID &node_id, const std::vector &resource_names, + const StatusCallback &callback) { + RAY_LOG(DEBUG) << "Deleting node resources, node id = " << node_id; + rpc::DeleteResourcesRequest request; + request.set_node_id(node_id.Binary()); + for (auto &resource_name : resource_names) { + request.add_resource_name_list(resource_name); + } + client_impl_->GetGcsRpcClient().DeleteResources( + request, + [node_id, callback](const Status &status, const rpc::DeleteResourcesReply &reply) { + if (callback) { + callback(status); + } + RAY_LOG(DEBUG) << "Finished deleting node resources, status = " << status + << ", node id = " << node_id; + }); + return Status::OK(); +} + +Status ServiceBasedNodeInfoAccessor::AsyncSubscribeToResources( + const SubscribeCallback &subscribe, + const StatusCallback &done) { + RAY_LOG(DEBUG) << "Subscribing node resources change."; + RAY_CHECK(subscribe != nullptr); + auto status = + resource_sub_executor_.AsyncSubscribeAll(ClientID::Nil(), subscribe, done); + RAY_LOG(DEBUG) << "Finished subscribing node resources change."; + return status; +} + +Status ServiceBasedNodeInfoAccessor::AsyncReportHeartbeat( + const std::shared_ptr &data_ptr, + const StatusCallback &callback) { + ClientID node_id = ClientID::FromBinary(data_ptr->client_id()); + RAY_LOG(DEBUG) << "Reporting heartbeat, node id = " << node_id; + rpc::ReportHeartbeatRequest request; + request.mutable_heartbeat()->CopyFrom(*data_ptr); + client_impl_->GetGcsRpcClient().ReportHeartbeat( + request, + [node_id, callback](const Status &status, const rpc::ReportHeartbeatReply &reply) { + if (callback) { + callback(status); + } + RAY_LOG(DEBUG) << "Finished reporting heartbeat, status = " << status + << ", node id = " << node_id; + }); + return Status::OK(); +} + +Status ServiceBasedNodeInfoAccessor::AsyncSubscribeHeartbeat( + const SubscribeCallback &subscribe, + const StatusCallback &done) { + RAY_LOG(DEBUG) << "Subscribing heartbeat."; + RAY_CHECK(subscribe != nullptr); + auto status = + heartbeat_sub_executor_.AsyncSubscribeAll(ClientID::Nil(), subscribe, done); + RAY_LOG(DEBUG) << "Finished subscribing heartbeat."; + return status; +} + +Status ServiceBasedNodeInfoAccessor::AsyncReportBatchHeartbeat( + const std::shared_ptr &data_ptr, + const StatusCallback &callback) { + RAY_LOG(DEBUG) << "Reporting batch heartbeat, batch size = " << data_ptr->batch_size(); + rpc::ReportBatchHeartbeatRequest request; + request.mutable_heartbeat_batch()->CopyFrom(*data_ptr); + client_impl_->GetGcsRpcClient().ReportBatchHeartbeat( + request, [data_ptr, callback](const Status &status, + const rpc::ReportBatchHeartbeatReply &reply) { + if (callback) { + callback(status); + } + RAY_LOG(DEBUG) << "Finished reporting batch heartbeat, status = " << status + << ", batch size = " << data_ptr->batch_size(); + }); + return Status::OK(); +} + +Status ServiceBasedNodeInfoAccessor::AsyncSubscribeBatchHeartbeat( + const ItemCallback &subscribe, + const StatusCallback &done) { + RAY_LOG(DEBUG) << "Subscribing batch heartbeat."; + RAY_CHECK(subscribe != nullptr); + auto on_subscribe = [subscribe](const ClientID &node_id, + const HeartbeatBatchTableData &data) { + subscribe(data); + }; + auto status = heartbeat_batch_sub_executor_.AsyncSubscribeAll(ClientID::Nil(), + on_subscribe, done); + RAY_LOG(DEBUG) << "Finished subscribing batch heartbeat."; + return status; +} + +ServiceBasedTaskInfoAccessor::ServiceBasedTaskInfoAccessor( + ServiceBasedGcsClient *client_impl) + : client_impl_(client_impl), + subscribe_id_(ClientID::FromRandom()), + task_sub_executor_(client_impl->GetRedisGcsClient().raylet_task_table()), + task_lease_sub_executor_(client_impl->GetRedisGcsClient().task_lease_table()) {} + +Status ServiceBasedTaskInfoAccessor::AsyncAdd( + const std::shared_ptr &data_ptr, const StatusCallback &callback) { + TaskID task_id = TaskID::FromBinary(data_ptr->task().task_spec().task_id()); + JobID job_id = JobID::FromBinary(data_ptr->task().task_spec().job_id()); + RAY_LOG(DEBUG) << "Adding task, task id = " << task_id << ", job id = " << job_id; + rpc::AddTaskRequest request; + request.mutable_task_data()->CopyFrom(*data_ptr); + client_impl_->GetGcsRpcClient().AddTask( + request, + [task_id, job_id, callback](const Status &status, const rpc::AddTaskReply &reply) { + if (callback) { + callback(status); + } + RAY_LOG(DEBUG) << "Finished adding task, status = " << status + << ", task id = " << task_id << ", job id = " << job_id; + }); + return Status::OK(); +} + +Status ServiceBasedTaskInfoAccessor::AsyncGet( + const TaskID &task_id, const OptionalItemCallback &callback) { + RAY_LOG(DEBUG) << "Getting task, task id = " << task_id; + rpc::GetTaskRequest request; + request.set_task_id(task_id.Binary()); + client_impl_->GetGcsRpcClient().GetTask( + request, [task_id, callback](const Status &status, const rpc::GetTaskReply &reply) { + if (reply.has_task_data()) { + TaskTableData task_table_data(reply.task_data()); + callback(status, task_table_data); + } else { + callback(status, boost::none); + } + RAY_LOG(DEBUG) << "Finished getting task, status = " << status + << ", task id = " << task_id; + }); + return Status::OK(); +} + +Status ServiceBasedTaskInfoAccessor::AsyncDelete(const std::vector &task_ids, + const StatusCallback &callback) { + RAY_LOG(DEBUG) << "Deleting tasks, task id list size = " << task_ids.size(); + rpc::DeleteTasksRequest request; + for (auto &task_id : task_ids) { + request.add_task_id_list(task_id.Binary()); + } + client_impl_->GetGcsRpcClient().DeleteTasks( + request, + [task_ids, callback](const Status &status, const rpc::DeleteTasksReply &reply) { + if (callback) { + callback(status); + } + RAY_LOG(DEBUG) << "Finished deleting tasks, status = " << status + << ", task id list size = " << task_ids.size(); + }); + return Status::OK(); +} + +Status ServiceBasedTaskInfoAccessor::AsyncSubscribe( + const TaskID &task_id, const SubscribeCallback &subscribe, + const StatusCallback &done) { + RAY_LOG(DEBUG) << "Subscribing task, task id = " << task_id; + RAY_CHECK(subscribe != nullptr) << "Failed to subscribe task, task id = " << task_id; + auto status = + task_sub_executor_.AsyncSubscribe(subscribe_id_, task_id, subscribe, done); + RAY_LOG(DEBUG) << "Finished subscribing task, task id = " << task_id; + return status; +} + +Status ServiceBasedTaskInfoAccessor::AsyncUnsubscribe(const TaskID &task_id, + const StatusCallback &done) { + RAY_LOG(DEBUG) << "Unsubscribing task, task id = " << task_id; + auto status = task_sub_executor_.AsyncUnsubscribe(subscribe_id_, task_id, done); + RAY_LOG(DEBUG) << "Finished unsubscribing task, task id = " << task_id; + return status; +} + +Status ServiceBasedTaskInfoAccessor::AsyncAddTaskLease( + const std::shared_ptr &data_ptr, const StatusCallback &callback) { + TaskID task_id = TaskID::FromBinary(data_ptr->task_id()); + ClientID node_id = ClientID::FromBinary(data_ptr->node_manager_id()); + RAY_LOG(DEBUG) << "Adding task lease, task id = " << task_id + << ", node id = " << node_id; + rpc::AddTaskLeaseRequest request; + request.mutable_task_lease_data()->CopyFrom(*data_ptr); + client_impl_->GetGcsRpcClient().AddTaskLease( + request, [task_id, node_id, callback](const Status &status, + const rpc::AddTaskLeaseReply &reply) { + if (callback) { + callback(status); + } + RAY_LOG(DEBUG) << "Finished adding task lease, status = " << status + << ", task id = " << task_id << ", node id = " << node_id; + }); + return Status::OK(); +} + +Status ServiceBasedTaskInfoAccessor::AsyncSubscribeTaskLease( + const TaskID &task_id, + const SubscribeCallback> &subscribe, + const StatusCallback &done) { + RAY_LOG(DEBUG) << "Subscribing task lease, task id = " << task_id; + RAY_CHECK(subscribe != nullptr) + << "Failed to subscribe task lease, task id = " << task_id; + auto status = + task_lease_sub_executor_.AsyncSubscribe(subscribe_id_, task_id, subscribe, done); + RAY_LOG(DEBUG) << "Finished subscribing task lease, task id = " << task_id; + return status; +} + +Status ServiceBasedTaskInfoAccessor::AsyncUnsubscribeTaskLease( + const TaskID &task_id, const StatusCallback &done) { + RAY_LOG(DEBUG) << "Unsubscribing task lease, task id = " << task_id; + auto status = task_lease_sub_executor_.AsyncUnsubscribe(subscribe_id_, task_id, done); + RAY_LOG(DEBUG) << "Finished unsubscribing task lease, task id = " << task_id; + return status; +} + +Status ServiceBasedTaskInfoAccessor::AttemptTaskReconstruction( + const std::shared_ptr &data_ptr, + const StatusCallback &callback) { + ClientID node_id = ClientID::FromBinary(data_ptr->node_manager_id()); + RAY_LOG(DEBUG) << "Reconstructing task, reconstructions num = " + << data_ptr->num_reconstructions() << ", node id = " << node_id; + rpc::AttemptTaskReconstructionRequest request; + request.mutable_task_reconstruction()->CopyFrom(*data_ptr); + client_impl_->GetGcsRpcClient().AttemptTaskReconstruction( + request, + [data_ptr, node_id, callback](const Status &status, + const rpc::AttemptTaskReconstructionReply &reply) { + if (callback) { + callback(status); + } + RAY_LOG(DEBUG) << "Finished reconstructing task, status = " << status + << ", reconstructions num = " << data_ptr->num_reconstructions() + << ", node id = " << node_id; + }); + return Status::OK(); +} + +ServiceBasedObjectInfoAccessor::ServiceBasedObjectInfoAccessor( + ServiceBasedGcsClient *client_impl) + : client_impl_(client_impl), + subscribe_id_(ClientID::FromRandom()), + object_sub_executor_(client_impl->GetRedisGcsClient().object_table()) {} + +Status ServiceBasedObjectInfoAccessor::AsyncGetLocations( + const ObjectID &object_id, const MultiItemCallback &callback) { + RAY_LOG(DEBUG) << "Getting object locations, object id = " << object_id; + rpc::GetObjectLocationsRequest request; + request.set_object_id(object_id.Binary()); + client_impl_->GetGcsRpcClient().GetObjectLocations( + request, [object_id, callback](const Status &status, + const rpc::GetObjectLocationsReply &reply) { + std::vector result; + result.reserve((reply.object_table_data_list_size())); + for (int index = 0; index < reply.object_table_data_list_size(); ++index) { + result.emplace_back(reply.object_table_data_list(index)); + } + callback(status, result); + RAY_LOG(DEBUG) << "Finished getting object locations, status = " << status + << ", object id = " << object_id; + }); + return Status::OK(); +} + +Status ServiceBasedObjectInfoAccessor::AsyncAddLocation(const ObjectID &object_id, + const ClientID &node_id, + const StatusCallback &callback) { + RAY_LOG(DEBUG) << "Adding object location, object id = " << object_id + << ", node id = " << node_id; + rpc::AddObjectLocationRequest request; + request.set_object_id(object_id.Binary()); + request.set_node_id(node_id.Binary()); + client_impl_->GetGcsRpcClient().AddObjectLocation( + request, [object_id, node_id, callback](const Status &status, + const rpc::AddObjectLocationReply &reply) { + if (callback) { + callback(status); + } + RAY_LOG(DEBUG) << "Finished adding object location, status = " << status + << ", object id = " << object_id << ", node id = " << node_id; + }); + return Status::OK(); +} + +Status ServiceBasedObjectInfoAccessor::AsyncRemoveLocation( + const ObjectID &object_id, const ClientID &node_id, const StatusCallback &callback) { + RAY_LOG(DEBUG) << "Removing object location, object id = " << object_id + << ", node id = " << node_id; + rpc::RemoveObjectLocationRequest request; + request.set_object_id(object_id.Binary()); + request.set_node_id(node_id.Binary()); + client_impl_->GetGcsRpcClient().RemoveObjectLocation( + request, [object_id, node_id, callback]( + const Status &status, const rpc::RemoveObjectLocationReply &reply) { + if (callback) { + callback(status); + } + RAY_LOG(DEBUG) << "Finished removing object location, status = " << status + << ", object id = " << object_id << ", node id = " << node_id; + }); + return Status::OK(); +} + +Status ServiceBasedObjectInfoAccessor::AsyncSubscribeToLocations( + const ObjectID &object_id, + const SubscribeCallback &subscribe, + const StatusCallback &done) { + RAY_LOG(DEBUG) << "Subscribing object location, object id = " << object_id; + RAY_CHECK(subscribe != nullptr) + << "Failed to subscribe object location, object id = " << object_id; + auto status = + object_sub_executor_.AsyncSubscribe(subscribe_id_, object_id, subscribe, done); + RAY_LOG(DEBUG) << "Finished subscribing object location, object id = " << object_id; + return status; +} + +Status ServiceBasedObjectInfoAccessor::AsyncUnsubscribeToLocations( + const ObjectID &object_id, const StatusCallback &done) { + RAY_LOG(DEBUG) << "Unsubscribing object location, object id = " << object_id; + auto status = object_sub_executor_.AsyncUnsubscribe(subscribe_id_, object_id, done); + RAY_LOG(DEBUG) << "Finished unsubscribing object location, object id = " << object_id; + return status; +} + +ServiceBasedStatsInfoAccessor::ServiceBasedStatsInfoAccessor( + ServiceBasedGcsClient *client_impl) + : client_impl_(client_impl) {} + +Status ServiceBasedStatsInfoAccessor::AsyncAddProfileData( + const std::shared_ptr &data_ptr, + const StatusCallback &callback) { + ClientID node_id = ClientID::FromBinary(data_ptr->component_id()); + RAY_LOG(DEBUG) << "Adding profile data, component type = " << data_ptr->component_type() + << ", node id = " << node_id; + rpc::AddProfileDataRequest request; + request.mutable_profile_data()->CopyFrom(*data_ptr); + client_impl_->GetGcsRpcClient().AddProfileData( + request, [data_ptr, node_id, callback](const Status &status, + const rpc::AddProfileDataReply &reply) { + if (callback) { + callback(status); + } + RAY_LOG(DEBUG) << "Finished adding profile data, status = " << status + << ", component type = " << data_ptr->component_type() + << ", node id = " << node_id; + }); + return Status::OK(); +} + +ServiceBasedErrorInfoAccessor::ServiceBasedErrorInfoAccessor( + ServiceBasedGcsClient *client_impl) + : client_impl_(client_impl) {} + +Status ServiceBasedErrorInfoAccessor::AsyncReportJobError( + const std::shared_ptr &data_ptr, + const StatusCallback &callback) { + JobID job_id = JobID::FromBinary(data_ptr->job_id()); + std::string type = data_ptr->type(); + RAY_LOG(DEBUG) << "Reporting job error, job id = " << job_id << ", type = " << type; + rpc::ReportJobErrorRequest request; + request.mutable_error_data()->CopyFrom(*data_ptr); + client_impl_->GetGcsRpcClient().ReportJobError( + request, [job_id, type, callback](const Status &status, + const rpc::ReportJobErrorReply &reply) { + if (callback) { + callback(status); + } + RAY_LOG(DEBUG) << "Finished reporting job error, status = " << status + << ", job id = " << job_id << ", type = " << type; + ; + }); + return Status::OK(); +} + +ServiceBasedWorkerInfoAccessor::ServiceBasedWorkerInfoAccessor( + ServiceBasedGcsClient *client_impl) + : client_impl_(client_impl), + worker_failure_sub_executor_( + client_impl->GetRedisGcsClient().worker_failure_table()) {} + +Status ServiceBasedWorkerInfoAccessor::AsyncSubscribeToWorkerFailures( + const SubscribeCallback &subscribe, + const StatusCallback &done) { + RAY_LOG(DEBUG) << "Subscribing worker failures."; + RAY_CHECK(subscribe != nullptr); + auto status = + worker_failure_sub_executor_.AsyncSubscribeAll(ClientID::Nil(), subscribe, done); + RAY_LOG(DEBUG) << "Finished subscribing worker failures."; + return status; +} + +Status ServiceBasedWorkerInfoAccessor::AsyncReportWorkerFailure( + const std::shared_ptr &data_ptr, + const StatusCallback &callback) { + rpc::Address worker_address = data_ptr->worker_address(); + RAY_LOG(DEBUG) << "Reporting worker failure, " << worker_address.DebugString(); + rpc::ReportWorkerFailureRequest request; + request.mutable_worker_failure()->CopyFrom(*data_ptr); + client_impl_->GetGcsRpcClient().ReportWorkerFailure( + request, [worker_address, callback](const Status &status, + const rpc::ReportWorkerFailureReply &reply) { + if (callback) { + callback(status); + } + RAY_LOG(DEBUG) << "Finished reporting worker failure, " + << worker_address.DebugString() << ", status = " << status; + }); + return Status::OK(); +} + +} // namespace gcs +} // namespace ray diff --git a/src/ray/gcs/gcs_client/service_based_accessor.h b/src/ray/gcs/gcs_client/service_based_accessor.h new file mode 100644 index 000000000..2f781c6f9 --- /dev/null +++ b/src/ray/gcs/gcs_client/service_based_accessor.h @@ -0,0 +1,319 @@ +#ifndef RAY_GCS_SERVICE_BASED_ACCESSOR_H +#define RAY_GCS_SERVICE_BASED_ACCESSOR_H + +#include "src/ray/gcs/accessor.h" +#include "src/ray/gcs/subscription_executor.h" + +namespace ray { +namespace gcs { + +class ServiceBasedGcsClient; + +/// \class ServiceBasedJobInfoAccessor +/// ServiceBasedJobInfoAccessor is an implementation of `JobInfoAccessor` +/// that uses GCS Service as the backend. +class ServiceBasedJobInfoAccessor : public JobInfoAccessor { + public: + explicit ServiceBasedJobInfoAccessor(ServiceBasedGcsClient *client_impl); + + virtual ~ServiceBasedJobInfoAccessor() = default; + + Status AsyncAdd(const std::shared_ptr &data_ptr, + const StatusCallback &callback) override; + + Status AsyncMarkFinished(const JobID &job_id, const StatusCallback &callback) override; + + Status AsyncSubscribeToFinishedJobs( + const SubscribeCallback &subscribe, + const StatusCallback &done) override; + + private: + ServiceBasedGcsClient *client_impl_; + + typedef SubscriptionExecutor JobSubscriptionExecutor; + JobSubscriptionExecutor job_sub_executor_; +}; + +/// \class ServiceBasedActorInfoAccessor +/// ServiceBasedActorInfoAccessor is an implementation of `ActorInfoAccessor` +/// that uses GCS Service as the backend. +class ServiceBasedActorInfoAccessor : public ActorInfoAccessor { + public: + explicit ServiceBasedActorInfoAccessor(ServiceBasedGcsClient *client_impl); + + virtual ~ServiceBasedActorInfoAccessor() = default; + + Status AsyncGet(const ActorID &actor_id, + const OptionalItemCallback &callback) override; + + Status AsyncRegister(const std::shared_ptr &data_ptr, + const StatusCallback &callback) override; + + Status AsyncUpdate(const ActorID &actor_id, + const std::shared_ptr &data_ptr, + const StatusCallback &callback) override; + + Status AsyncSubscribeAll( + const SubscribeCallback &subscribe, + const StatusCallback &done) override; + + Status AsyncSubscribe(const ActorID &actor_id, + const SubscribeCallback &subscribe, + const StatusCallback &done) override; + + Status AsyncUnsubscribe(const ActorID &actor_id, const StatusCallback &done) override; + + Status AsyncAddCheckpoint(const std::shared_ptr &data_ptr, + const StatusCallback &callback) override; + + Status AsyncGetCheckpoint( + const ActorCheckpointID &checkpoint_id, + const OptionalItemCallback &callback) override; + + Status AsyncGetCheckpointID( + const ActorID &actor_id, + const OptionalItemCallback &callback) override; + + private: + ServiceBasedGcsClient *client_impl_; + + ClientID subscribe_id_; + + typedef SubscriptionExecutor + ActorSubscriptionExecutor; + ActorSubscriptionExecutor actor_sub_executor_; +}; + +/// \class ServiceBasedNodeInfoAccessor +/// ServiceBasedNodeInfoAccessor is an implementation of `NodeInfoAccessor` +/// that uses GCS Service as the backend. +class ServiceBasedNodeInfoAccessor : public NodeInfoAccessor { + public: + explicit ServiceBasedNodeInfoAccessor(ServiceBasedGcsClient *client_impl); + + virtual ~ServiceBasedNodeInfoAccessor() = default; + + Status RegisterSelf(const GcsNodeInfo &local_node_info) override; + + Status UnregisterSelf() override; + + const ClientID &GetSelfId() const override; + + const GcsNodeInfo &GetSelfInfo() const override; + + Status AsyncRegister(const rpc::GcsNodeInfo &node_info, + const StatusCallback &callback) override; + + Status AsyncUnregister(const ClientID &node_id, + const StatusCallback &callback) override; + + Status AsyncGetAll(const MultiItemCallback &callback) override; + + Status AsyncSubscribeToNodeChange( + const SubscribeCallback &subscribe, + const StatusCallback &done) override; + + boost::optional Get(const ClientID &node_id) const override; + + const std::unordered_map &GetAll() const override; + + bool IsRemoved(const ClientID &node_id) const override; + + Status AsyncGetResources(const ClientID &node_id, + const OptionalItemCallback &callback) override; + + Status AsyncUpdateResources(const ClientID &node_id, const ResourceMap &resources, + const StatusCallback &callback) override; + + Status AsyncDeleteResources(const ClientID &node_id, + const std::vector &resource_names, + const StatusCallback &callback) override; + + Status AsyncSubscribeToResources( + const SubscribeCallback &subscribe, + const StatusCallback &done) override; + + Status AsyncReportHeartbeat(const std::shared_ptr &data_ptr, + const StatusCallback &callback) override; + + Status AsyncSubscribeHeartbeat( + const SubscribeCallback &subscribe, + const StatusCallback &done) override; + + Status AsyncReportBatchHeartbeat( + const std::shared_ptr &data_ptr, + const StatusCallback &callback) override; + + Status AsyncSubscribeBatchHeartbeat( + const ItemCallback &subscribe, + const StatusCallback &done) override; + + private: + ServiceBasedGcsClient *client_impl_; + + typedef SubscriptionExecutor + DynamicResourceSubscriptionExecutor; + DynamicResourceSubscriptionExecutor resource_sub_executor_; + + typedef SubscriptionExecutor + HeartbeatSubscriptionExecutor; + HeartbeatSubscriptionExecutor heartbeat_sub_executor_; + + typedef SubscriptionExecutor + HeartbeatBatchSubscriptionExecutor; + HeartbeatBatchSubscriptionExecutor heartbeat_batch_sub_executor_; + + GcsNodeInfo local_node_info_; + ClientID local_node_id_; +}; + +/// \class ServiceBasedTaskInfoAccessor +/// ServiceBasedTaskInfoAccessor is an implementation of `TaskInfoAccessor` +/// that uses GCS service as the backend. +class ServiceBasedTaskInfoAccessor : public TaskInfoAccessor { + public: + explicit ServiceBasedTaskInfoAccessor(ServiceBasedGcsClient *client_impl); + + virtual ~ServiceBasedTaskInfoAccessor() = default; + + Status AsyncAdd(const std::shared_ptr &data_ptr, + const StatusCallback &callback) override; + + Status AsyncGet(const TaskID &task_id, + const OptionalItemCallback &callback) override; + + Status AsyncDelete(const std::vector &task_ids, + const StatusCallback &callback) override; + + Status AsyncSubscribe(const TaskID &task_id, + const SubscribeCallback &subscribe, + const StatusCallback &done) override; + + Status AsyncUnsubscribe(const TaskID &task_id, const StatusCallback &done) override; + + Status AsyncAddTaskLease(const std::shared_ptr &data_ptr, + const StatusCallback &callback) override; + + Status AsyncSubscribeTaskLease( + const TaskID &task_id, + const SubscribeCallback> &subscribe, + const StatusCallback &done) override; + + Status AsyncUnsubscribeTaskLease(const TaskID &task_id, + const StatusCallback &done) override; + + Status AttemptTaskReconstruction( + const std::shared_ptr &data_ptr, + const StatusCallback &callback) override; + + private: + ServiceBasedGcsClient *client_impl_; + + ClientID subscribe_id_; + + typedef SubscriptionExecutor + TaskSubscriptionExecutor; + TaskSubscriptionExecutor task_sub_executor_; + + typedef SubscriptionExecutor, TaskLeaseTable> + TaskLeaseSubscriptionExecutor; + TaskLeaseSubscriptionExecutor task_lease_sub_executor_; +}; + +/// \class ServiceBasedObjectInfoAccessor +/// ServiceBasedObjectInfoAccessor is an implementation of `ObjectInfoAccessor` +/// that uses GCS service as the backend. +class ServiceBasedObjectInfoAccessor : public ObjectInfoAccessor { + public: + explicit ServiceBasedObjectInfoAccessor(ServiceBasedGcsClient *client_impl); + + virtual ~ServiceBasedObjectInfoAccessor() = default; + + Status AsyncGetLocations( + const ObjectID &object_id, + const MultiItemCallback &callback) override; + + Status AsyncAddLocation(const ObjectID &object_id, const ClientID &node_id, + const StatusCallback &callback) override; + + Status AsyncRemoveLocation(const ObjectID &object_id, const ClientID &node_id, + const StatusCallback &callback) override; + + Status AsyncSubscribeToLocations( + const ObjectID &object_id, + const SubscribeCallback &subscribe, + const StatusCallback &done) override; + + Status AsyncUnsubscribeToLocations(const ObjectID &object_id, + const StatusCallback &done) override; + + private: + ServiceBasedGcsClient *client_impl_; + + ClientID subscribe_id_; + + typedef SubscriptionExecutor + ObjectSubscriptionExecutor; + ObjectSubscriptionExecutor object_sub_executor_; +}; + +/// \class ServiceBasedStatsInfoAccessor +/// ServiceBasedStatsInfoAccessor is an implementation of `StatsInfoAccessor` +/// that uses GCS Service as the backend. +class ServiceBasedStatsInfoAccessor : public StatsInfoAccessor { + public: + explicit ServiceBasedStatsInfoAccessor(ServiceBasedGcsClient *client_impl); + + virtual ~ServiceBasedStatsInfoAccessor() = default; + + Status AsyncAddProfileData(const std::shared_ptr &data_ptr, + const StatusCallback &callback) override; + + private: + ServiceBasedGcsClient *client_impl_; +}; + +/// \class ServiceBasedErrorInfoAccessor +/// ServiceBasedErrorInfoAccessor is an implementation of `ErrorInfoAccessor` +/// that uses GCS Service as the backend. +class ServiceBasedErrorInfoAccessor : public ErrorInfoAccessor { + public: + explicit ServiceBasedErrorInfoAccessor(ServiceBasedGcsClient *client_impl); + + virtual ~ServiceBasedErrorInfoAccessor() = default; + + Status AsyncReportJobError(const std::shared_ptr &data_ptr, + const StatusCallback &callback) override; + + private: + ServiceBasedGcsClient *client_impl_; +}; + +/// \class ServiceBasedWorkerInfoAccessor +/// ServiceBasedWorkerInfoAccessor is an implementation of `WorkerInfoAccessor` +/// that uses GCS Service as the backend. +class ServiceBasedWorkerInfoAccessor : public WorkerInfoAccessor { + public: + explicit ServiceBasedWorkerInfoAccessor(ServiceBasedGcsClient *client_impl); + + virtual ~ServiceBasedWorkerInfoAccessor() = default; + + Status AsyncSubscribeToWorkerFailures( + const SubscribeCallback &subscribe, + const StatusCallback &done) override; + + Status AsyncReportWorkerFailure(const std::shared_ptr &data_ptr, + const StatusCallback &callback) override; + + private: + ServiceBasedGcsClient *client_impl_; + + typedef SubscriptionExecutor + WorkerFailureSubscriptionExecutor; + WorkerFailureSubscriptionExecutor worker_failure_sub_executor_; +}; + +} // namespace gcs +} // namespace ray + +#endif // RAY_GCS_SERVICE_BASED_ACCESSOR_H diff --git a/src/ray/gcs/gcs_client/service_based_gcs_client.cc b/src/ray/gcs/gcs_client/service_based_gcs_client.cc new file mode 100644 index 000000000..45ac9d166 --- /dev/null +++ b/src/ray/gcs/gcs_client/service_based_gcs_client.cc @@ -0,0 +1,86 @@ +#include "ray/gcs/gcs_client/service_based_gcs_client.h" +#include "ray/common/ray_config.h" +#include "ray/gcs/gcs_client/service_based_accessor.h" + +namespace ray { +namespace gcs { + +ServiceBasedGcsClient::ServiceBasedGcsClient(const GcsClientOptions &options) + : GcsClient(options) {} + +Status ServiceBasedGcsClient::Connect(boost::asio::io_service &io_service) { + RAY_CHECK(!is_connected_); + + if (options_.server_ip_.empty()) { + RAY_LOG(ERROR) << "Failed to connect, gcs service address is empty."; + return Status::Invalid("gcs service address is invalid!"); + } + + // Connect to gcs + redis_gcs_client_.reset(new RedisGcsClient(options_)); + RAY_CHECK_OK(redis_gcs_client_->Connect(io_service)); + + // Get gcs service address + std::pair address; + GetGcsServerAddressFromRedis(redis_gcs_client_->primary_context()->sync_context(), + &address); + + // Connect to gcs service + client_call_manager_.reset(new rpc::ClientCallManager(io_service)); + gcs_rpc_client_.reset( + new rpc::GcsRpcClient(address.first, address.second, *client_call_manager_)); + + job_accessor_.reset(new ServiceBasedJobInfoAccessor(this)); + actor_accessor_.reset(new ServiceBasedActorInfoAccessor(this)); + node_accessor_.reset(new ServiceBasedNodeInfoAccessor(this)); + task_accessor_.reset(new ServiceBasedTaskInfoAccessor(this)); + object_accessor_.reset(new ServiceBasedObjectInfoAccessor(this)); + stats_accessor_.reset(new ServiceBasedStatsInfoAccessor(this)); + error_accessor_.reset(new ServiceBasedErrorInfoAccessor(this)); + worker_accessor_.reset(new ServiceBasedWorkerInfoAccessor(this)); + + is_connected_ = true; + + RAY_LOG(INFO) << "ServiceBasedGcsClient Connected."; + return Status::OK(); +} + +void ServiceBasedGcsClient::Disconnect() { + RAY_CHECK(is_connected_); + is_connected_ = false; + RAY_LOG(INFO) << "ServiceBasedGcsClient Disconnected."; +} + +void ServiceBasedGcsClient::GetGcsServerAddressFromRedis( + redisContext *context, std::pair *address) { + // Get gcs server address. + int num_attempts = 0; + redisReply *reply = nullptr; + while (num_attempts < RayConfig::instance().gcs_service_connect_retries()) { + reply = reinterpret_cast(redisCommand(context, "GET GcsServerAddress")); + if (reply->type != REDIS_REPLY_NIL) { + break; + } + + // Sleep for a little, and try again if the entry isn't there yet. + freeReplyObject(reply); + usleep(RayConfig::instance().gcs_service_connect_wait_milliseconds() * 1000); + num_attempts++; + } + RAY_CHECK(num_attempts < RayConfig::instance().gcs_service_connect_retries()) + << "No entry found for GcsServerAddress"; + RAY_CHECK(reply->type == REDIS_REPLY_STRING) + << "Expected string, found Redis type " << reply->type << " for GcsServerAddress"; + std::string result(reply->str); + freeReplyObject(reply); + + RAY_CHECK(!result.empty()) << "Gcs service address is empty"; + size_t pos = result.find(':'); + RAY_CHECK(pos != std::string::npos) + << "Gcs service address format is erroneous: " << result; + address->first = result.substr(0, pos); + address->second = std::stoi(result.substr(pos + 1)); +} + +} // namespace gcs +} // namespace ray diff --git a/src/ray/gcs/gcs_client/service_based_gcs_client.h b/src/ray/gcs/gcs_client/service_based_gcs_client.h new file mode 100644 index 000000000..22213b1a4 --- /dev/null +++ b/src/ray/gcs/gcs_client/service_based_gcs_client.h @@ -0,0 +1,43 @@ +#ifndef RAY_GCS_SERVICE_BASED_GCS_CLIENT_H +#define RAY_GCS_SERVICE_BASED_GCS_CLIENT_H + +#include "ray/gcs/redis_gcs_client.h" +#include "ray/rpc/gcs_server/gcs_rpc_client.h" + +namespace ray { +namespace gcs { + +class RAY_EXPORT ServiceBasedGcsClient : public GcsClient { + public: + ServiceBasedGcsClient(const GcsClientOptions &options); + + ServiceBasedGcsClient(RedisGcsClient *redis_gcs_client); + + Status Connect(boost::asio::io_service &io_service) override; + + void Disconnect() override; + + RedisGcsClient &GetRedisGcsClient() { return *redis_gcs_client_; } + + rpc::GcsRpcClient &GetGcsRpcClient() { return *gcs_rpc_client_; } + + private: + /// Get gcs server address from redis. + /// This address is set by GcsServer::StoreGcsServerAddressInRedis function. + /// + /// \param context The context of redis. + /// \param address The address of gcs server. + void GetGcsServerAddressFromRedis(redisContext *context, + std::pair *address); + + std::unique_ptr redis_gcs_client_; + + // Gcs rpc client + std::unique_ptr gcs_rpc_client_; + std::unique_ptr client_call_manager_; +}; + +} // namespace gcs +} // namespace ray + +#endif // RAY_GCS_SERVICE_BASED_GCS_CLIENT_H diff --git a/src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc b/src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc new file mode 100644 index 000000000..c5a278afb --- /dev/null +++ b/src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc @@ -0,0 +1,631 @@ +#include "ray/gcs/gcs_client/service_based_gcs_client.h" +#include "gtest/gtest.h" +#include "ray/gcs/gcs_client/service_based_accessor.h" +#include "ray/gcs/gcs_server/gcs_server.h" +#include "ray/rpc/gcs_server/gcs_rpc_client.h" +#include "ray/util/test_util.h" + +namespace ray { + +static std::string redis_server_executable; +static std::string redis_client_executable; +static std::string libray_redis_module_path; + +class ServiceBasedGcsGcsClientTest : public RedisServiceManagerForTest { + public: + void SetUp() override { + gcs::GcsServerConfig config; + config.grpc_server_port = 0; + config.grpc_server_name = "MockedGcsServer"; + config.grpc_server_thread_num = 1; + config.redis_address = "127.0.0.1"; + config.is_test = true; + config.redis_port = REDIS_SERVER_PORT; + gcs_server_.reset(new gcs::GcsServer(config)); + + thread_io_service_.reset(new std::thread([this] { + std::unique_ptr work( + new boost::asio::io_service::work(io_service_)); + io_service_.run(); + })); + + thread_gcs_server_.reset(new std::thread([this] { gcs_server_->Start(); })); + + // Wait until server starts listening. + while (gcs_server_->GetPort() == 0) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + + // Create gcs client + gcs::GcsClientOptions options(config.redis_address, config.redis_port, + config.redis_password, config.is_test); + gcs_client_.reset(new gcs::ServiceBasedGcsClient(options)); + RAY_CHECK_OK(gcs_client_->Connect(io_service_)); + } + + void TearDown() override { + gcs_server_->Stop(); + io_service_.stop(); + thread_io_service_->join(); + thread_gcs_server_->join(); + gcs_client_->Disconnect(); + } + + bool AddJob(const std::shared_ptr &job_table_data) { + std::promise promise; + RAY_CHECK_OK(gcs_client_->Jobs().AsyncAdd( + job_table_data, [&promise](Status status) { promise.set_value(status.ok()); })); + return WaitReady(promise.get_future(), timeout_ms_); + } + + bool MarkJobFinished(const JobID &job_id) { + std::promise promise; + RAY_CHECK_OK(gcs_client_->Jobs().AsyncMarkFinished( + job_id, [&promise](Status status) { promise.set_value(status.ok()); })); + return WaitReady(promise.get_future(), timeout_ms_); + } + + bool RegisterActor(const std::shared_ptr &actor_table_data) { + std::promise promise; + RAY_CHECK_OK(gcs_client_->Actors().AsyncRegister( + actor_table_data, [&promise](Status status) { promise.set_value(status.ok()); })); + return WaitReady(promise.get_future(), timeout_ms_); + } + + bool UpdateActor(const ActorID &actor_id, + const std::shared_ptr &actor_table_data) { + std::promise promise; + RAY_CHECK_OK(gcs_client_->Actors().AsyncUpdate( + actor_id, actor_table_data, + [&promise](Status status) { promise.set_value(status.ok()); })); + return WaitReady(promise.get_future(), timeout_ms_); + } + + rpc::ActorTableData GetActor(const ActorID &actor_id) { + std::promise promise; + rpc::ActorTableData actor_table_data; + RAY_CHECK_OK(gcs_client_->Actors().AsyncGet( + actor_id, [&actor_table_data, &promise]( + Status status, const boost::optional &result) { + assert(result); + actor_table_data.CopyFrom(*result); + promise.set_value(true); + })); + EXPECT_TRUE(WaitReady(promise.get_future(), timeout_ms_)); + return actor_table_data; + } + + bool AddCheckpoint( + const std::shared_ptr &actor_checkpoint_data) { + std::promise promise; + RAY_CHECK_OK(gcs_client_->Actors().AsyncAddCheckpoint( + actor_checkpoint_data, + [&promise](Status status) { promise.set_value(status.ok()); })); + return WaitReady(promise.get_future(), timeout_ms_); + } + + rpc::ActorCheckpointData GetCheckpoint(const ActorCheckpointID &checkpoint_id) { + std::promise promise; + rpc::ActorCheckpointData actor_checkpoint_data; + RAY_CHECK_OK(gcs_client_->Actors().AsyncGetCheckpoint( + checkpoint_id, + [&actor_checkpoint_data, &promise]( + Status status, const boost::optional &result) { + assert(result); + actor_checkpoint_data.CopyFrom(*result); + promise.set_value(true); + })); + EXPECT_TRUE(WaitReady(promise.get_future(), timeout_ms_)); + return actor_checkpoint_data; + } + + rpc::ActorCheckpointIdData GetCheckpointID(const ActorID &actor_id) { + std::promise promise; + rpc::ActorCheckpointIdData actor_checkpoint_id_data; + RAY_CHECK_OK(gcs_client_->Actors().AsyncGetCheckpointID( + actor_id, + [&actor_checkpoint_id_data, &promise]( + Status status, const boost::optional &result) { + assert(result); + actor_checkpoint_id_data.CopyFrom(*result); + promise.set_value(true); + })); + EXPECT_TRUE(WaitReady(promise.get_future(), timeout_ms_)); + return actor_checkpoint_id_data; + } + + bool RegisterSelf(const rpc::GcsNodeInfo &local_node_info) { + Status status = gcs_client_->Nodes().RegisterSelf(local_node_info); + return status.ok(); + } + + bool UnregisterSelf() { + Status status = gcs_client_->Nodes().UnregisterSelf(); + return status.ok(); + } + + bool RegisterNode(const rpc::GcsNodeInfo &node_info) { + std::promise promise; + RAY_CHECK_OK(gcs_client_->Nodes().AsyncRegister( + node_info, [&promise](Status status) { promise.set_value(status.ok()); })); + return WaitReady(promise.get_future(), timeout_ms_); + } + + std::vector GetNodeInfoList() { + std::promise promise; + std::vector nodes; + RAY_CHECK_OK(gcs_client_->Nodes().AsyncGetAll( + [&nodes, &promise](Status status, const std::vector &result) { + assert(result); + nodes.assign(result.begin(), result.end()); + promise.set_value(status.ok()); + })); + EXPECT_TRUE(WaitReady(promise.get_future(), timeout_ms_)); + return nodes; + } + + bool UnregisterNode(const ClientID &node_id) { + std::promise promise; + RAY_CHECK_OK(gcs_client_->Nodes().AsyncUnregister( + node_id, [&promise](Status status) { promise.set_value(status.ok()); })); + return WaitReady(promise.get_future(), timeout_ms_); + } + + gcs::NodeInfoAccessor::ResourceMap GetResources(const ClientID &node_id) { + gcs::NodeInfoAccessor::ResourceMap resource_map; + std::promise promise; + RAY_CHECK_OK(gcs_client_->Nodes().AsyncGetResources( + node_id, [&resource_map, &promise]( + Status status, + const boost::optional &result) { + if (result) { + resource_map.insert(result->begin(), result->end()); + } + promise.set_value(true); + })); + EXPECT_TRUE(WaitReady(promise.get_future(), timeout_ms_)); + return resource_map; + } + + bool UpdateResources(const ClientID &node_id, + const gcs::NodeInfoAccessor::ResourceMap &resource_map) { + std::promise promise; + RAY_CHECK_OK(gcs_client_->Nodes().AsyncUpdateResources( + node_id, resource_map, + [&promise](Status status) { promise.set_value(status.ok()); })); + return WaitReady(promise.get_future(), timeout_ms_); + } + + bool DeleteResources(const ClientID &node_id, + const std::vector &resource_names) { + std::promise promise; + RAY_CHECK_OK(gcs_client_->Nodes().AsyncDeleteResources( + node_id, resource_names, + [&promise](Status status) { promise.set_value(status.ok()); })); + return WaitReady(promise.get_future(), timeout_ms_); + } + + bool ReportHeartbeat(const std::shared_ptr heartbeat) { + std::promise promise; + RAY_CHECK_OK(gcs_client_->Nodes().AsyncReportHeartbeat( + heartbeat, [&promise](Status status) { promise.set_value(status.ok()); })); + return WaitReady(promise.get_future(), timeout_ms_); + } + + bool ReportBatchHeartbeat( + const std::shared_ptr batch_heartbeat) { + std::promise promise; + RAY_CHECK_OK(gcs_client_->Nodes().AsyncReportBatchHeartbeat( + batch_heartbeat, [&promise](Status status) { promise.set_value(status.ok()); })); + return WaitReady(promise.get_future(), timeout_ms_); + } + + bool AddTask(const std::shared_ptr task) { + std::promise promise; + RAY_CHECK_OK(gcs_client_->Tasks().AsyncAdd( + task, [&promise](Status status) { promise.set_value(status.ok()); })); + return WaitReady(promise.get_future(), timeout_ms_); + } + + rpc::TaskTableData GetTask(const TaskID &task_id) { + std::promise promise; + rpc::TaskTableData task_table_data; + RAY_CHECK_OK(gcs_client_->Tasks().AsyncGet( + task_id, [&task_table_data, &promise]( + Status status, const boost::optional &result) { + if (result) { + task_table_data.CopyFrom(*result); + } + promise.set_value(status.ok()); + })); + EXPECT_TRUE(WaitReady(promise.get_future(), timeout_ms_)); + return task_table_data; + } + + bool DeleteTask(const std::vector &task_ids) { + std::promise promise; + RAY_CHECK_OK(gcs_client_->Tasks().AsyncDelete( + task_ids, [&promise](Status status) { promise.set_value(status.ok()); })); + return WaitReady(promise.get_future(), timeout_ms_); + } + + bool AddTaskLease(const std::shared_ptr task_lease) { + std::promise promise; + RAY_CHECK_OK(gcs_client_->Tasks().AsyncAddTaskLease( + task_lease, [&promise](Status status) { promise.set_value(status.ok()); })); + return WaitReady(promise.get_future(), timeout_ms_); + } + + bool AttemptTaskReconstruction( + const std::shared_ptr task_reconstruction_data) { + std::promise promise; + RAY_CHECK_OK(gcs_client_->Tasks().AttemptTaskReconstruction( + task_reconstruction_data, + [&promise](Status status) { promise.set_value(status.ok()); })); + return WaitReady(promise.get_future(), timeout_ms_); + } + + protected: + bool WaitReady(const std::future &future, + const std::chrono::milliseconds &timeout_ms) { + auto status = future.wait_for(timeout_ms); + return status == std::future_status::ready; + } + + std::shared_ptr GenJobTableData(JobID job_id) { + auto job_table_data = std::make_shared(); + job_table_data->set_job_id(job_id.Binary()); + job_table_data->set_is_dead(false); + job_table_data->set_timestamp(std::time(nullptr)); + job_table_data->set_node_manager_address("127.0.0.1"); + job_table_data->set_driver_pid(5667L); + return job_table_data; + } + + std::shared_ptr GenActorTableData(const JobID &job_id) { + auto actor_table_data = std::make_shared(); + ActorID actor_id = ActorID::Of(job_id, RandomTaskId(), 0); + actor_table_data->set_actor_id(actor_id.Binary()); + actor_table_data->set_job_id(job_id.Binary()); + actor_table_data->set_state( + rpc::ActorTableData_ActorState::ActorTableData_ActorState_ALIVE); + actor_table_data->set_max_reconstructions(1); + actor_table_data->set_remaining_reconstructions(1); + return actor_table_data; + } + + rpc::GcsNodeInfo GenGcsNodeInfo(const std::string &node_id) { + rpc::GcsNodeInfo gcs_node_info; + gcs_node_info.set_node_id(node_id); + gcs_node_info.set_state(rpc::GcsNodeInfo_GcsNodeState_ALIVE); + return gcs_node_info; + } + + std::shared_ptr GenTaskTableData(const std::string &job_id, + const std::string &task_id) { + auto task_table_data = std::make_shared(); + rpc::Task task; + rpc::TaskSpec task_spec; + task_spec.set_job_id(job_id); + task_spec.set_task_id(task_id); + task.mutable_task_spec()->CopyFrom(task_spec); + task_table_data->mutable_task()->CopyFrom(task); + return task_table_data; + } + + std::shared_ptr GenTaskLeaseData(const std::string &task_id, + const std::string &node_id) { + auto task_lease_data = std::make_shared(); + task_lease_data->set_task_id(task_id); + task_lease_data->set_node_manager_id(node_id); + return task_lease_data; + } + + // Gcs server + std::unique_ptr gcs_server_; + std::unique_ptr thread_io_service_; + std::unique_ptr thread_gcs_server_; + boost::asio::io_service io_service_; + + // Gcs client + std::unique_ptr gcs_client_; + + // Timeout waiting for gcs server reply, default is 2s + const std::chrono::milliseconds timeout_ms_{2000}; +}; + +TEST_F(ServiceBasedGcsGcsClientTest, TestJobInfo) { + // Create job_table_data + JobID add_job_id = JobID::FromInt(1); + auto job_table_data = GenJobTableData(add_job_id); + + std::promise promise; + auto on_subscribe = [&promise, add_job_id](const JobID &job_id, + const gcs::JobTableData &data) { + ASSERT_TRUE(add_job_id == job_id); + promise.set_value(true); + }; + RAY_CHECK_OK(gcs_client_->Jobs().AsyncSubscribeToFinishedJobs( + on_subscribe, [](Status status) { RAY_CHECK_OK(status); })); + + ASSERT_TRUE(AddJob(job_table_data)); + ASSERT_TRUE(MarkJobFinished(add_job_id)); + ASSERT_TRUE(WaitReady(promise.get_future(), timeout_ms_)); +} + +TEST_F(ServiceBasedGcsGcsClientTest, TestActorInfo) { + // Create actor_table_data + JobID job_id = JobID::FromInt(1); + auto actor_table_data = GenActorTableData(job_id); + ActorID actor_id = ActorID::FromBinary(actor_table_data->actor_id()); + + // Subscribe + std::promise promise_subscribe; + std::atomic subscribe_callback_count(0); + auto on_subscribe = [&subscribe_callback_count](const ActorID &actor_id, + const gcs::ActorTableData &data) { + ++subscribe_callback_count; + }; + RAY_CHECK_OK(gcs_client_->Actors().AsyncSubscribe(actor_id, on_subscribe, + [&promise_subscribe](Status status) { + RAY_CHECK_OK(status); + promise_subscribe.set_value(true); + })); + + // Register actor + ASSERT_TRUE(RegisterActor(actor_table_data)); + ASSERT_TRUE(GetActor(actor_id).state() == + rpc::ActorTableData_ActorState::ActorTableData_ActorState_ALIVE); + + // Unsubscribe + std::promise promise_unsubscribe; + RAY_CHECK_OK(gcs_client_->Actors().AsyncUnsubscribe( + actor_id, [&promise_unsubscribe](Status status) { + RAY_CHECK_OK(status); + promise_unsubscribe.set_value(true); + })); + ASSERT_TRUE(WaitReady(promise_unsubscribe.get_future(), timeout_ms_)); + + // Update actor + actor_table_data->set_state( + rpc::ActorTableData_ActorState::ActorTableData_ActorState_DEAD); + ASSERT_TRUE(UpdateActor(actor_id, actor_table_data)); + ASSERT_TRUE(GetActor(actor_id).state() == + rpc::ActorTableData_ActorState::ActorTableData_ActorState_DEAD); + ASSERT_TRUE(WaitReady(promise_subscribe.get_future(), timeout_ms_)); + auto condition = [&subscribe_callback_count]() { + return 1 == subscribe_callback_count; + }; + EXPECT_TRUE(WaitForCondition(condition, timeout_ms_.count())); +} + +TEST_F(ServiceBasedGcsGcsClientTest, TestActorCheckpoint) { + // Create actor checkpoint + JobID job_id = JobID::FromInt(1); + auto actor_table_data = GenActorTableData(job_id); + ActorID actor_id = ActorID::FromBinary(actor_table_data->actor_id()); + + ActorCheckpointID checkpoint_id = ActorCheckpointID::FromRandom(); + auto checkpoint = std::make_shared(); + checkpoint->set_actor_id(actor_table_data->actor_id()); + checkpoint->set_checkpoint_id(checkpoint_id.Binary()); + checkpoint->set_execution_dependency(checkpoint_id.Binary()); + + // Add checkpoint + ASSERT_TRUE(AddCheckpoint(checkpoint)); + + // Get Checkpoint + auto get_checkpoint_result = GetCheckpoint(checkpoint_id); + ASSERT_TRUE(get_checkpoint_result.actor_id() == actor_id.Binary()); + + // Get CheckpointID + auto get_checkpoint_id_result = GetCheckpointID(actor_id); + ASSERT_TRUE(get_checkpoint_id_result.checkpoint_ids_size() == 1); + ASSERT_TRUE(get_checkpoint_id_result.checkpoint_ids(0) == checkpoint_id.Binary()); +} + +TEST_F(ServiceBasedGcsGcsClientTest, TestActorSubscribeAll) { + // Create actor_table_data + JobID job_id = JobID::FromInt(1); + auto actor_table_data1 = GenActorTableData(job_id); + auto actor_table_data2 = GenActorTableData(job_id); + + // Subscribe all + std::promise promise_subscribe_all; + std::atomic subscribe_all_callback_count(0); + auto on_subscribe_all = [&subscribe_all_callback_count]( + const ActorID &actor_id, const gcs::ActorTableData &data) { + ++subscribe_all_callback_count; + }; + RAY_CHECK_OK(gcs_client_->Actors().AsyncSubscribeAll( + on_subscribe_all, [&promise_subscribe_all](Status status) { + RAY_CHECK_OK(status); + promise_subscribe_all.set_value(true); + })); + ASSERT_TRUE(WaitReady(promise_subscribe_all.get_future(), timeout_ms_)); + + // Register actor + ASSERT_TRUE(RegisterActor(actor_table_data1)); + ASSERT_TRUE(RegisterActor(actor_table_data2)); + auto condition = [&subscribe_all_callback_count]() { + return 2 == subscribe_all_callback_count; + }; + EXPECT_TRUE(WaitForCondition(condition, timeout_ms_.count())); +} + +TEST_F(ServiceBasedGcsGcsClientTest, TestNodeInfo) { + // Create gcs node info + ClientID node1_id = ClientID::FromRandom(); + auto gcs_node1_info = GenGcsNodeInfo(node1_id.Binary()); + + int register_count = 0; + int unregister_count = 0; + RAY_CHECK_OK(gcs_client_->Nodes().AsyncSubscribeToNodeChange( + [®ister_count, &unregister_count](const ClientID &node_id, + const rpc::GcsNodeInfo &data) { + if (data.state() == rpc::GcsNodeInfo::ALIVE) { + ++register_count; + } else if (data.state() == rpc::GcsNodeInfo::DEAD) { + ++unregister_count; + } + }, + nullptr)); + + // Register self + ASSERT_TRUE(RegisterSelf(gcs_node1_info)); + sleep(1); + EXPECT_EQ(gcs_client_->Nodes().GetSelfId(), node1_id); + EXPECT_EQ(gcs_client_->Nodes().GetSelfInfo().node_id(), gcs_node1_info.node_id()); + EXPECT_EQ(gcs_client_->Nodes().GetSelfInfo().state(), gcs_node1_info.state()); + + // Register node + ClientID node2_id = ClientID::FromRandom(); + auto gcs_node2_info = GenGcsNodeInfo(node2_id.Binary()); + ASSERT_TRUE(RegisterNode(gcs_node2_info)); + + // Get node list + std::vector node_list = GetNodeInfoList(); + EXPECT_EQ(node_list.size(), 2); + EXPECT_EQ(register_count, 2); + ASSERT_TRUE(gcs_client_->Nodes().Get(node1_id)); + EXPECT_EQ(gcs_client_->Nodes().GetAll().size(), 2); + + // Unregister self + ASSERT_TRUE(UnregisterSelf()); + + // Unregister node + ASSERT_TRUE(UnregisterNode(node2_id)); + node_list = GetNodeInfoList(); + EXPECT_EQ(node_list.size(), 2); + EXPECT_EQ(node_list[0].state(), + rpc::GcsNodeInfo_GcsNodeState::GcsNodeInfo_GcsNodeState_DEAD); + EXPECT_EQ(node_list[1].state(), + rpc::GcsNodeInfo_GcsNodeState::GcsNodeInfo_GcsNodeState_DEAD); + EXPECT_EQ(unregister_count, 2); + ASSERT_TRUE(gcs_client_->Nodes().IsRemoved(node2_id)); +} + +TEST_F(ServiceBasedGcsGcsClientTest, TestNodeResources) { + int add_count = 0; + int remove_count = 0; + auto subscribe = [&add_count, &remove_count]( + const ClientID &id, + const gcs::ResourceChangeNotification ¬ification) { + if (notification.IsAdded()) { + ++add_count; + } else if (notification.IsRemoved()) { + ++remove_count; + } + }; + RAY_CHECK_OK(gcs_client_->Nodes().AsyncSubscribeToResources(subscribe, nullptr)); + + // Update resources + ClientID node_id = ClientID::FromRandom(); + gcs::NodeInfoAccessor::ResourceMap resource_map; + std::string key = "CPU"; + auto resource = std::make_shared(); + resource->set_resource_capacity(1.0); + resource_map[key] = resource; + ASSERT_TRUE(UpdateResources(node_id, resource_map)); + auto get_resources_result = GetResources(node_id); + ASSERT_TRUE(get_resources_result.count(key)); + + // Delete resources + ASSERT_TRUE(DeleteResources(node_id, {key})); + get_resources_result = GetResources(node_id); + ASSERT_TRUE(get_resources_result.empty()); + EXPECT_EQ(add_count, 1); + EXPECT_EQ(remove_count, 1); +} + +TEST_F(ServiceBasedGcsGcsClientTest, TestNodeHeartbeat) { + int heartbeat_count = 0; + auto heartbeat_subscribe = [&heartbeat_count](const ClientID &id, + const gcs::HeartbeatTableData &result) { + ++heartbeat_count; + }; + RAY_CHECK_OK( + gcs_client_->Nodes().AsyncSubscribeHeartbeat(heartbeat_subscribe, nullptr)); + + int heartbeat_batch_count = 0; + auto heartbeat_batch_subscribe = + [&heartbeat_batch_count](const gcs::HeartbeatBatchTableData &result) { + ++heartbeat_batch_count; + }; + RAY_CHECK_OK(gcs_client_->Nodes().AsyncSubscribeBatchHeartbeat( + heartbeat_batch_subscribe, nullptr)); + + // Report heartbeat + ClientID node_id = ClientID::FromRandom(); + auto heartbeat = std::make_shared(); + heartbeat->set_client_id(node_id.Binary()); + ASSERT_TRUE(ReportHeartbeat(heartbeat)); + + // Report batch heartbeat + auto batch_heartbeat = std::make_shared(); + batch_heartbeat->add_batch()->set_client_id(node_id.Binary()); + ASSERT_TRUE(ReportBatchHeartbeat(batch_heartbeat)); + + EXPECT_EQ(heartbeat_count, 1); + EXPECT_EQ(heartbeat_batch_count, 1); +} + +TEST_F(ServiceBasedGcsGcsClientTest, TestTaskInfo) { + JobID job_id = JobID::FromInt(1); + TaskID task_id = TaskID::ForDriverTask(job_id); + auto task_table_data = GenTaskTableData(job_id.Binary(), task_id.Binary()); + + int task_count = 0; + auto task_subscribe = [&task_count](const TaskID &id, + const rpc::TaskTableData &result) { ++task_count; }; + RAY_CHECK_OK(gcs_client_->Tasks().AsyncSubscribe(task_id, task_subscribe, nullptr)); + + // Add task + ASSERT_TRUE(AddTask(task_table_data)); + + auto get_task_result = GetTask(task_id); + ASSERT_TRUE(get_task_result.task().task_spec().task_id() == task_id.Binary()); + ASSERT_TRUE(get_task_result.task().task_spec().job_id() == job_id.Binary()); + RAY_CHECK_OK(gcs_client_->Tasks().AsyncUnsubscribe(task_id, nullptr)); + ASSERT_TRUE(AddTask(task_table_data)); + + // Delete task + std::vector task_ids = {task_id}; + ASSERT_TRUE(DeleteTask(task_ids)); + EXPECT_EQ(task_count, 1); + + // Add task lease + int task_lease_count = 0; + auto task_lease_subscribe = [&task_lease_count]( + const TaskID &id, + const boost::optional &result) { + ++task_lease_count; + }; + RAY_CHECK_OK(gcs_client_->Tasks().AsyncSubscribeTaskLease(task_id, task_lease_subscribe, + nullptr)); + ClientID node_id = ClientID::FromRandom(); + auto task_lease = GenTaskLeaseData(task_id.Binary(), node_id.Binary()); + ASSERT_TRUE(AddTaskLease(task_lease)); + EXPECT_EQ(task_lease_count, 2); + + RAY_CHECK_OK(gcs_client_->Tasks().AsyncUnsubscribeTaskLease(task_id, nullptr)); + ASSERT_TRUE(AddTaskLease(task_lease)); + EXPECT_EQ(task_lease_count, 2); + + // Attempt task reconstruction + auto task_reconstruction_data = std::make_shared(); + task_reconstruction_data->set_task_id(task_id.Binary()); + task_reconstruction_data->set_num_reconstructions(0); + ASSERT_TRUE(AttemptTaskReconstruction(task_reconstruction_data)); +} + +} // namespace ray + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + RAY_CHECK(argc == 4); + ray::REDIS_SERVER_EXEC_PATH = argv[1]; + ray::REDIS_CLIENT_EXEC_PATH = argv[2]; + ray::REDIS_MODULE_LIBRARY_PATH = argv[3]; + return RUN_ALL_TESTS(); +} diff --git a/src/ray/gcs/gcs_server/actor_info_handler_impl.cc b/src/ray/gcs/gcs_server/actor_info_handler_impl.cc index 37088d882..d55ebe7e2 100644 --- a/src/ray/gcs/gcs_server/actor_info_handler_impl.cc +++ b/src/ray/gcs/gcs_server/actor_info_handler_impl.cc @@ -13,8 +13,9 @@ void DefaultActorInfoHandler::HandleGetActorInfo( auto on_done = [actor_id, reply, send_reply_callback]( Status status, const boost::optional &result) { if (status.ok()) { - RAY_DCHECK(result); - reply->mutable_actor_table_data()->CopyFrom(*result); + if (result) { + reply->mutable_actor_table_data()->CopyFrom(*result); + } } else { RAY_LOG(ERROR) << "Failed to get actor info: " << status.ToString() << ", actor id = " << actor_id; diff --git a/src/ray/gcs/gcs_server/gcs_server.cc b/src/ray/gcs/gcs_server/gcs_server.cc index a9e9a144e..29247cc7f 100644 --- a/src/ray/gcs/gcs_server/gcs_server.cc +++ b/src/ray/gcs/gcs_server/gcs_server.cc @@ -64,6 +64,9 @@ void GcsServer::Start() { // Run rpc server. rpc_server_.Run(); + // Store gcs rpc server address in redis + StoreGcsServerAddressInRedis(); + // Run the event loop. // Using boost::asio::io_context::work to avoid ending the event loop when // there are no events to handle. @@ -107,6 +110,36 @@ std::unique_ptr GcsServer::InitObjectInfoHandler() { new rpc::DefaultObjectInfoHandler(*redis_gcs_client_)); } +void GcsServer::StoreGcsServerAddressInRedis() { + boost::asio::ip::detail::endpoint primary_endpoint; + boost::asio::ip::tcp::resolver resolver(main_service_); + boost::asio::ip::tcp::resolver::query query(boost::asio::ip::host_name(), ""); + boost::asio::ip::tcp::resolver::iterator iter = resolver.resolve(query); + boost::asio::ip::tcp::resolver::iterator end; // End marker. + while (iter != end) { + boost::asio::ip::tcp::endpoint ep = *iter; + if (ep.address().is_v4() && !ep.address().is_loopback() && + !ep.address().is_multicast()) { + primary_endpoint.address(ep.address()); + primary_endpoint.port(ep.port()); + break; + } + iter++; + } + + std::string address; + if (iter == end) { + address = "127.0.0.1:" + std::to_string(GetPort()); + } else { + address = primary_endpoint.address().to_string() + ":" + std::to_string(GetPort()); + } + RAY_LOG(INFO) << "Gcs server address = " << address; + + RAY_CHECK_OK(redis_gcs_client_->primary_context()->RunArgvAsync( + {"SET", "GcsServerAddress", address})); + RAY_LOG(INFO) << "Finished setting gcs server address: " << address; +} + std::unique_ptr GcsServer::InitTaskInfoHandler() { return std::unique_ptr( new rpc::DefaultTaskInfoHandler(*redis_gcs_client_)); diff --git a/src/ray/gcs/gcs_server/gcs_server.h b/src/ray/gcs/gcs_server/gcs_server.h index 3a0887401..9321440c0 100644 --- a/src/ray/gcs/gcs_server/gcs_server.h +++ b/src/ray/gcs/gcs_server/gcs_server.h @@ -69,6 +69,13 @@ class GcsServer { virtual std::unique_ptr InitWorkerInfoHandler(); private: + /// Store the address of GCS server in Redis. + /// + /// Clients will look up this address in Redis and use it to connect to GCS server. + /// TODO(ffbin): Once we entirely migrate to service-based GCS, we should pass GCS + /// server address directly to raylets and get rid of this lookup. + void StoreGcsServerAddressInRedis(); + /// Gcs server configuration GcsServerConfig config_; /// The grpc server diff --git a/src/ray/gcs/gcs_server/node_info_handler_impl.cc b/src/ray/gcs/gcs_server/node_info_handler_impl.cc index eb231d7d0..e7a805e42 100644 --- a/src/ray/gcs/gcs_server/node_info_handler_impl.cc +++ b/src/ray/gcs/gcs_server/node_info_handler_impl.cc @@ -152,12 +152,13 @@ void DefaultNodeInfoHandler::HandleUpdateResources( const UpdateResourcesRequest &request, UpdateResourcesReply *reply, SendReplyCallback send_reply_callback) { ClientID node_id = ClientID::FromBinary(request.node_id()); + RAY_LOG(DEBUG) << "Updating node resources, node id = " << node_id; + gcs::NodeInfoAccessor::ResourceMap resources; for (auto resource : request.resources()) { resources[resource.first] = std::make_shared(resource.second); } - RAY_LOG(DEBUG) << "Updating node resources, node id = " << node_id; auto on_done = [node_id, send_reply_callback](Status status) { if (!status.ok()) { RAY_LOG(ERROR) << "Failed to update node resources: " << status.ToString() diff --git a/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc b/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc index 0b15c2749..92efeafb7 100644 --- a/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc @@ -235,7 +235,7 @@ class GcsServerTest : public RedisServiceManagerForTest { [&resources, &promise](const Status &status, const rpc::GetResourcesReply &reply) { RAY_CHECK_OK(status); - for (auto resource : reply.resources()) { + for (auto &resource : reply.resources()) { resources[resource.first] = resource.second; } promise.set_value(true); @@ -553,7 +553,7 @@ TEST_F(GcsServerTest, TestNodeInfo) { delete_resources_request.add_resource_name_list(resource_name); ASSERT_TRUE(DeleteResources(delete_resources_request)); resources = GetResources(node_id.Binary()); - ASSERT_TRUE(resources.size() == 0); + ASSERT_TRUE(resources.empty()); } TEST_F(GcsServerTest, TestObjectInfo) { diff --git a/src/ray/raylet/main.cc b/src/ray/raylet/main.cc index 82912b59d..9ec49e6d5 100644 --- a/src/ray/raylet/main.cc +++ b/src/ray/raylet/main.cc @@ -8,6 +8,8 @@ #include "ray/raylet/raylet.h" #include "ray/stats/stats.h" +#include "ray/gcs/gcs_client/service_based_gcs_client.h" + DEFINE_string(raylet_socket_name, "", "The socket name of raylet."); DEFINE_string(store_socket_name, "", "The socket name of object store."); DEFINE_int32(object_manager_port, -1, "The port of object manager."); @@ -156,7 +158,17 @@ int main(int argc, char *argv[]) { // Initialize gcs client ray::gcs::GcsClientOptions client_options(redis_address, redis_port, redis_password); - auto gcs_client = std::make_shared(client_options); + std::shared_ptr gcs_client; + + std::unique_ptr thread_io_service; + boost::asio::io_service io_service; + + // RAY_GCS_SERVICE_ENABLED only set in ci job, so we just check if it is null. + if (getenv("RAY_GCS_SERVICE_ENABLED") != nullptr) { + gcs_client = std::make_shared(client_options); + } else { + gcs_client = std::make_shared(client_options); + } RAY_CHECK_OK(gcs_client->Connect(main_service)); std::unique_ptr server(new ray::raylet::Raylet( diff --git a/src/ray/rpc/gcs_server/gcs_rpc_client.h b/src/ray/rpc/gcs_server/gcs_rpc_client.h index 551e868cc..c9237c7f8 100644 --- a/src/ray/rpc/gcs_server/gcs_rpc_client.h +++ b/src/ray/rpc/gcs_server/gcs_rpc_client.h @@ -1,6 +1,7 @@ #ifndef RAY_RPC_GCS_RPC_CLIENT_H #define RAY_RPC_GCS_RPC_CLIENT_H +#include "src/ray/protobuf/gcs_service.grpc.pb.h" #include "src/ray/rpc/grpc_client.h" namespace ray { diff --git a/src/ray/test/run_core_worker_tests.sh b/src/ray/test/run_core_worker_tests.sh index 4a602d2cd..340a065eb 100644 --- a/src/ray/test/run_core_worker_tests.sh +++ b/src/ray/test/run_core_worker_tests.sh @@ -22,7 +22,7 @@ fi set -e set -x -bazel build -c dbg $RAY_BAZEL_CONFIG "//:core_worker_test" "//:mock_worker" "//:raylet" "//:raylet_monitor" "//:libray_redis_module.so" "@plasma//:plasma_store_server" +bazel build -c dbg $RAY_BAZEL_CONFIG "//:core_worker_test" "//:mock_worker" "//:raylet" "//:raylet_monitor" "//:gcs_server" "//:libray_redis_module.so" "@plasma//:plasma_store_server" # Get the directory in which this script is executing. SCRIPT_DIR="`dirname \"$0\"`" @@ -45,6 +45,7 @@ STORE_EXEC="$BAZEL_BIN_PREFIX/external/plasma/plasma_store_server" RAYLET_EXEC="$BAZEL_BIN_PREFIX/raylet" RAYLET_MONITOR_EXEC="$BAZEL_BIN_PREFIX/raylet_monitor" MOCK_WORKER_EXEC="$BAZEL_BIN_PREFIX/mock_worker" +GCS_SERVER_EXEC="$BAZEL_BIN_PREFIX/gcs_server" # Allow cleanup commands to fail. bazel run "//:redis-cli" -- -p 6379 shutdown || true @@ -56,7 +57,7 @@ sleep 2s bazel run "//:redis-server" -- --loglevel warning ${LOAD_MODULE_ARGS} --port 6380 & sleep 2s # Run tests. -bazel run -c dbg $RAY_BAZEL_CONFIG "//:core_worker_test" $STORE_EXEC $RAYLET_EXEC $RAYLET_PORT $RAYLET_MONITOR_EXEC $MOCK_WORKER_EXEC +bazel run -c dbg $RAY_BAZEL_CONFIG "//:core_worker_test" $STORE_EXEC $RAYLET_EXEC $RAYLET_PORT $RAYLET_MONITOR_EXEC $MOCK_WORKER_EXEC $GCS_SERVER_EXEC sleep 1s bazel run "//:redis-cli" -- -p 6379 shutdown bazel run "//:redis-cli" -- -p 6380 shutdown diff --git a/streaming/src/test/queue_tests_base.h b/streaming/src/test/queue_tests_base.h index 818b7b852..ea4b7bfd0 100644 --- a/streaming/src/test/queue_tests_base.h +++ b/streaming/src/test/queue_tests_base.h @@ -14,9 +14,10 @@ static void flushall_redis(void) { class StreamingQueueTestBase : public ::testing::TestWithParam { public: StreamingQueueTestBase(int num_nodes, std::string raylet_exe, std::string store_exe, - int port, std::string actor_exe) + int port, std::string actor_exe, std::string gcs_server_exe) : gcs_options_("127.0.0.1", 6379, ""), raylet_executable_(raylet_exe), + gcs_server_executable_(gcs_server_exe), store_executable_(store_exe), actor_executable_(actor_exe), node_manager_port_(port) { @@ -34,6 +35,9 @@ class StreamingQueueTestBase : public ::testing::TestWithParam { store_socket = StartStore(); } + // start gcs server + gcs_server_pid_ = StartGcsServer("127.0.0.1"); + // start raylet on each node. Assign each node with different resources so that // a task can be scheduled to the desired node. for (int i = 0; i < num_nodes; i++) { @@ -52,6 +56,8 @@ class StreamingQueueTestBase : public ::testing::TestWithParam { for (const auto &store_socket : raylet_store_socket_names_) { StopStore(store_socket); } + + StopGcsServer(gcs_server_pid_); } JobID NextJobId() const { @@ -80,6 +86,30 @@ class StreamingQueueTestBase : public ::testing::TestWithParam { ASSERT_EQ(system(("rm -rf " + store_socket_name + ".pid").c_str()), 0); } + std::string StartGcsServer(std::string redis_address) { + std::string gcs_server_socket_name = "/tmp/gcs_server" + ObjectID::FromRandom().Hex(); + std::string ray_start_cmd = gcs_server_executable_; + ray_start_cmd.append(" --redis_address=" + redis_address) + .append(" --redis_port=6379") + .append(" --config_list=initial_reconstruction_timeout_milliseconds,2000") + .append(" & echo $! > " + gcs_server_socket_name + ".pid"); + + RAY_LOG(INFO) << "Start gcs server command: " << ray_start_cmd; + RAY_CHECK(system(ray_start_cmd.c_str()) == 0); + usleep(200 * 1000); + RAY_LOG(INFO) << "Finished start gcs server."; + return gcs_server_socket_name; + } + + void StopGcsServer(std::string gcs_server_socket_name) { + std::string gcs_server_pid = gcs_server_socket_name + ".pid"; + std::string kill_9 = "kill -9 `cat " + gcs_server_pid + "`"; + RAY_LOG(DEBUG) << kill_9; + ASSERT_TRUE(system(kill_9.c_str()) == 0); + ASSERT_TRUE(system(("rm -rf " + gcs_server_socket_name).c_str()) == 0); + ASSERT_TRUE(system(("rm -rf " + gcs_server_socket_name + ".pid").c_str()) == 0); + } + std::string StartRaylet(std::string store_socket_name, std::string node_ip_address, int port, std::string redis_address, std::string resource) { std::string raylet_socket_name = "/tmp/raylet" + RandomObjectID().Hex(); @@ -304,9 +334,11 @@ class StreamingQueueTestBase : public ::testing::TestWithParam { std::vector raylet_store_socket_names_; gcs::GcsClientOptions gcs_options_; std::string raylet_executable_; + std::string gcs_server_executable_; std::string store_executable_; std::string actor_executable_; int node_manager_port_; + std::string gcs_server_pid_; }; } // namespace streaming diff --git a/streaming/src/test/run_streaming_queue_test.sh b/streaming/src/test/run_streaming_queue_test.sh index 64dc8181f..7bf4bd1dd 100644 --- a/streaming/src/test/run_streaming_queue_test.sh +++ b/streaming/src/test/run_streaming_queue_test.sh @@ -35,7 +35,7 @@ if [ -z "$RAY_ROOT" ] ; then exit 1 fi -bazel build "//:core_worker_test" "//:mock_worker" "//:raylet" "//:libray_redis_module.so" "@plasma//:plasma_store_server" +bazel build "//:core_worker_test" "//:mock_worker" "//:raylet" "//:gcs_server" "//:libray_redis_module.so" "@plasma//:plasma_store_server" bazel build //streaming:streaming_test_worker bazel build //streaming:streaming_queue_tests @@ -50,6 +50,7 @@ LOAD_MODULE_ARGS="--loadmodule ${REDIS_MODULE}" STORE_EXEC="./bazel-bin/external/plasma/plasma_store_server" RAYLET_EXEC="./bazel-bin/raylet" STREAMING_TEST_WORKER_EXEC="./bazel-bin/streaming/streaming_test_worker" +GCS_SERVER_EXEC="./bazel-bin/gcs_server" # Allow cleanup commands to fail. bazel run //:redis-cli -- -p 6379 shutdown || true @@ -61,7 +62,7 @@ sleep 2s bazel run //:redis-server -- --loglevel warning ${LOAD_MODULE_ARGS} --port 6380 & sleep 2s # Run tests. -./bazel-bin/streaming/streaming_queue_tests $STORE_EXEC $RAYLET_EXEC $RAYLET_PORT $STREAMING_TEST_WORKER_EXEC +./bazel-bin/streaming/streaming_queue_tests $STORE_EXEC $RAYLET_EXEC $RAYLET_PORT $STREAMING_TEST_WORKER_EXEC $GCS_SERVER_EXEC sleep 1s bazel run //:redis-cli -- -p 6379 shutdown bazel run //:redis-cli -- -p 6380 shutdown diff --git a/streaming/src/test/streaming_queue_tests.cc b/streaming/src/test/streaming_queue_tests.cc index e5eb87b13..4bb4d2c5a 100644 --- a/streaming/src/test/streaming_queue_tests.cc +++ b/streaming/src/test/streaming_queue_tests.cc @@ -18,6 +18,7 @@ namespace streaming { static std::string store_executable; static std::string raylet_executable; +static std::string gcs_server_executable; static std::string actor_executable; static int node_manager_port; @@ -25,14 +26,14 @@ class StreamingWriterTest : public StreamingQueueTestBase { public: StreamingWriterTest() : StreamingQueueTestBase(1, raylet_executable, store_executable, node_manager_port, - actor_executable) {} + actor_executable, gcs_server_executable) {} }; class StreamingExactlySameTest : public StreamingQueueTestBase { public: StreamingExactlySameTest() : StreamingQueueTestBase(1, raylet_executable, store_executable, node_manager_port, - actor_executable) {} + actor_executable, gcs_server_executable) {} }; TEST_P(StreamingWriterTest, streaming_writer_exactly_once_test) { @@ -56,10 +57,11 @@ INSTANTIATE_TEST_CASE_P(StreamingTest, StreamingExactlySameTest, int main(int argc, char **argv) { // set_streaming_log_config("streaming_writer_test", StreamingLogLevel::INFO, 0); ::testing::InitGoogleTest(&argc, argv); - RAY_CHECK(argc == 5); + RAY_CHECK(argc == 6); ray::streaming::store_executable = std::string(argv[1]); ray::streaming::raylet_executable = std::string(argv[2]); ray::streaming::node_manager_port = std::stoi(std::string(argv[3])); ray::streaming::actor_executable = std::string(argv[4]); + ray::streaming::gcs_server_executable = std::string(argv[5]); return RUN_ALL_TESTS(); }