From b5b8c1d361bf9100e5f5528adf5aa064776cc4eb Mon Sep 17 00:00:00 2001 From: micafan <550435771@qq.com> Date: Fri, 19 Jul 2019 11:28:34 +0800 Subject: [PATCH] [GCS] introduce new gcs client and refactor actor table (#5058) --- BUILD.bazel | 15 +- src/ray/gcs/actor_state_accessor.cc | 114 ++++++++++++ src/ray/gcs/actor_state_accessor.h | 71 +++++++ src/ray/gcs/actor_state_accessor_test.cc | 161 ++++++++++++++++ src/ray/gcs/callback.h | 41 +++++ src/ray/gcs/gcs_client_interface.h | 108 +++++++++++ .../gcs/{client.cc => redis_gcs_client.cc} | 140 +++++++------- src/ray/gcs/{client.h => redis_gcs_client.h} | 74 +++----- ...lient_test.cc => redis_gcs_client_test.cc} | 173 +++++++++--------- src/ray/gcs/tables.cc | 24 +-- src/ray/gcs/tables.h | 70 +++---- src/ray/object_manager/object_directory.cc | 6 +- src/ray/object_manager/object_directory.h | 6 +- .../test/object_manager_stress_test.cc | 27 +-- .../test/object_manager_test.cc | 27 +-- src/ray/raylet/lineage_cache.cc | 2 +- src/ray/raylet/lineage_cache_test.cc | 6 +- src/ray/raylet/main.cc | 13 +- src/ray/raylet/monitor.cc | 8 +- src/ray/raylet/monitor.h | 4 +- src/ray/raylet/node_manager.cc | 146 ++++++--------- src/ray/raylet/node_manager.h | 14 +- .../raylet/object_manager_integration_test.cc | 15 +- src/ray/raylet/raylet.cc | 4 +- src/ray/raylet/raylet.h | 4 +- src/ray/raylet/reconstruction_policy.cc | 4 +- src/ray/raylet/reconstruction_policy_test.cc | 6 +- src/ray/raylet/worker_pool.cc | 2 +- src/ray/raylet/worker_pool.h | 6 +- src/ray/test/run_gcs_tests.sh | 5 +- 30 files changed, 875 insertions(+), 421 deletions(-) create mode 100644 src/ray/gcs/actor_state_accessor.cc create mode 100644 src/ray/gcs/actor_state_accessor.h create mode 100644 src/ray/gcs/actor_state_accessor_test.cc create mode 100644 src/ray/gcs/callback.h create mode 100644 src/ray/gcs/gcs_client_interface.h rename src/ray/gcs/{client.cc => redis_gcs_client.cc} (59%) rename src/ray/gcs/{client.h => redis_gcs_client.h} (58%) rename src/ray/gcs/{client_test.cc => redis_gcs_client_test.cc} (90%) diff --git a/BUILD.bazel b/BUILD.bazel index cc3da5139..33c4fc793 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -586,9 +586,20 @@ cc_library( ) cc_binary( - name = "gcs_client_test", + name = "redis_gcs_client_test", testonly = 1, - srcs = ["src/ray/gcs/client_test.cc"], + srcs = ["src/ray/gcs/redis_gcs_client_test.cc"], + copts = COPTS, + deps = [ + ":gcs", + "@com_google_googletest//:gtest_main", + ], +) + +cc_binary( + name = "actor_state_accessor_test", + testonly = 1, + srcs = ["src/ray/gcs/actor_state_accessor_test.cc"], copts = COPTS, deps = [ ":gcs", diff --git a/src/ray/gcs/actor_state_accessor.cc b/src/ray/gcs/actor_state_accessor.cc new file mode 100644 index 000000000..6d868c301 --- /dev/null +++ b/src/ray/gcs/actor_state_accessor.cc @@ -0,0 +1,114 @@ +#include "ray/gcs/actor_state_accessor.h" +#include +#include "ray/gcs/redis_gcs_client.h" +#include "ray/util/logging.h" + +namespace ray { + +namespace gcs { + +ActorStateAccessor::ActorStateAccessor(RedisGcsClient &client_impl) + : client_impl_(client_impl) {} + +Status ActorStateAccessor::AsyncGet(const ActorID &actor_id, + const MultiItemCallback &callback) { + RAY_CHECK(callback != nullptr); + auto on_done = [callback](RedisGcsClient *client, const ActorID &actor_id, + const std::vector &data) { + callback(Status::OK(), data); + }; + + ActorTable &actor_table = client_impl_.actor_table(); + return actor_table.Lookup(JobID::Nil(), actor_id, on_done); +} + +Status ActorStateAccessor::AsyncRegister(const std::shared_ptr &data_ptr, + const StatusCallback &callback) { + auto on_success = [callback](RedisGcsClient *client, const ActorID &actor_id, + const ActorTableData &data) { + if (callback != nullptr) { + callback(Status::OK()); + } + }; + + auto on_failure = [callback](RedisGcsClient *client, const ActorID &actor_id, + const ActorTableData &data) { + if (callback != nullptr) { + callback(Status::Invalid("Adding actor failed.")); + } + }; + + ActorID actor_id = ActorID::FromBinary(data_ptr->actor_id()); + ActorTable &actor_table = client_impl_.actor_table(); + return actor_table.AppendAt(JobID::Nil(), actor_id, data_ptr, on_success, on_failure, + /*log_length*/ 0); +} + +Status ActorStateAccessor::AsyncUpdate(const ActorID &actor_id, + const std::shared_ptr &data_ptr, + const StatusCallback &callback) { + // The actor log starts with an ALIVE entry. This is followed by 0 to N pairs + // of (RECONSTRUCTING, ALIVE) entries, where N is the maximum number of + // reconstructions. This is followed optionally by a DEAD entry. + int log_length = + 2 * (data_ptr->max_reconstructions() - data_ptr->remaining_reconstructions()); + if (data_ptr->state() != ActorTableData::ALIVE) { + // RECONSTRUCTING or DEAD entries have an odd index. + log_length += 1; + } + + auto on_success = [callback](RedisGcsClient *client, const ActorID &actor_id, + const ActorTableData &data) { + // If we successfully appended a record to the GCS table of the actor that + // has died, signal this to anyone receiving signals from this actor. + if (data.state() == ActorTableData::DEAD || + data.state() == ActorTableData::RECONSTRUCTING) { + std::vector args = {"XADD", actor_id.Hex(), "*", "signal", + "ACTOR_DIED_SIGNAL"}; + auto redis_context = client->primary_context(); + RAY_CHECK_OK(redis_context->RunArgvAsync(args)); + } + + if (callback != nullptr) { + callback(Status::OK()); + } + }; + + auto on_failure = [callback](RedisGcsClient *client, const ActorID &actor_id, + const ActorTableData &data) { + if (callback != nullptr) { + callback(Status::Invalid("Updating actor failed.")); + } + }; + + ActorTable &actor_table = client_impl_.actor_table(); + return actor_table.AppendAt(JobID::Nil(), actor_id, data_ptr, on_success, on_failure, + log_length); +} + +Status ActorStateAccessor::AsyncSubscribe( + const SubscribeCallback &subscribe, + const StatusCallback &done) { + RAY_CHECK(subscribe != nullptr); + auto on_subscribe = [subscribe](RedisGcsClient *client, const ActorID &actor_id, + const std::vector &data) { + if (!data.empty()) { + // We only need the last entry, because it represents the latest state of + // this actor. + subscribe(actor_id, data.back()); + } + }; + + auto on_done = [done](RedisGcsClient *client) { + if (done != nullptr) { + done(Status::OK()); + } + }; + + ActorTable &actor_table = client_impl_.actor_table(); + return actor_table.Subscribe(JobID::Nil(), ClientID::Nil(), on_subscribe, on_done); +} + +} // namespace gcs + +} // namespace ray diff --git a/src/ray/gcs/actor_state_accessor.h b/src/ray/gcs/actor_state_accessor.h new file mode 100644 index 000000000..273c7aa88 --- /dev/null +++ b/src/ray/gcs/actor_state_accessor.h @@ -0,0 +1,71 @@ +#ifndef RAY_GCS_ACTOR_STATE_ACCESSOR_H +#define RAY_GCS_ACTOR_STATE_ACCESSOR_H + +#include "ray/common/id.h" +#include "ray/gcs/callback.h" +#include "ray/gcs/tables.h" + +namespace ray { + +namespace gcs { + +class RedisGcsClient; + +/// \class ActorStateAccessor +/// ActorStateAccessor class encapsulates the implementation details of +/// reading or writing or subscribing of actor's specification (immutable fields which +/// determined at submission time, and mutable fields which are determined at runtime). +class ActorStateAccessor { + public: + explicit ActorStateAccessor(RedisGcsClient &client_impl); + + ~ActorStateAccessor() {} + + /// Get actor specification from GCS asynchronously. + /// + /// \param actor_id The ID of actor to look up in the GCS. + /// \param callback Callback that will be called after lookup finishes. + /// \return Status + Status AsyncGet(const ActorID &actor_id, + const MultiItemCallback &callback); + + /// Register an actor to GCS asynchronously. + /// + /// \param data_ptr The actor that will be registered to the GCS. + /// \param callback Callback that will be called after actor has been registered + /// to the GCS. + /// \return Status + Status AsyncRegister(const std::shared_ptr &data_ptr, + const StatusCallback &callback); + + /// Update dynamic states of actor in GCS asynchronously. + /// + /// \param actor_id ID of the actor to update. + /// \param data_ptr Data of the actor to update. + /// \param callback Callback that will be called after update finishes. + /// \return Status + /// TODO(micafan) Don't expose the whole `ActorTableData` and only allow + /// updating dynamic states. + Status AsyncUpdate(const ActorID &actor_id, + const std::shared_ptr &data_ptr, + const StatusCallback &callback); + + /// Subscribe to any register operations of actors. + /// + /// \param subscribe Callback that will be called each time when an actor is registered + /// or updated. + /// \param done Callback that will be called when subscription is complete and we + /// are ready to receive notification. + /// \return Status + Status AsyncSubscribe(const SubscribeCallback &subscribe, + const StatusCallback &done); + + private: + RedisGcsClient &client_impl_; +}; + +} // namespace gcs + +} // namespace ray + +#endif // RAY_GCS_ACTOR_STATE_ACCESSOR_H diff --git a/src/ray/gcs/actor_state_accessor_test.cc b/src/ray/gcs/actor_state_accessor_test.cc new file mode 100644 index 000000000..2278fa5b7 --- /dev/null +++ b/src/ray/gcs/actor_state_accessor_test.cc @@ -0,0 +1,161 @@ +#include +#include +#include +#include +#include +#include "gtest/gtest.h" +#include "ray/gcs/redis_gcs_client.h" + +namespace ray { + +namespace gcs { + +class ActorStateAccessorTest : public ::testing::Test { + public: + ActorStateAccessorTest() : options_("127.0.0.1", 6379, "", true) {} + + virtual void SetUp() { + GenTestData(); + + gcs_client_.reset(new RedisGcsClient(options_)); + RAY_CHECK_OK(gcs_client_->Connect(io_service_)); + + work_thread.reset(new std::thread([this] { + std::auto_ptr work( + new boost::asio::io_service::work(io_service_)); + io_service_.run(); + })); + } + + virtual void TearDown() { + gcs_client_->Disconnect(); + + io_service_.stop(); + work_thread->join(); + work_thread.reset(); + + gcs_client_.reset(); + + ClearTestData(); + } + + protected: + void GenTestData() { GenActorData(); } + + void GenActorData() { + for (size_t i = 0; i < 2; ++i) { + std::shared_ptr actor = std::make_shared(); + ActorID actor_id = ActorID::FromRandom(); + actor->set_actor_id(actor_id.Binary()); + actor->set_max_reconstructions(1); + actor->set_remaining_reconstructions(1); + JobID job_id = JobID::FromInt(i); + actor->set_job_id(job_id.Binary()); + actor->set_state(ActorTableData::ALIVE); + actor_datas_[actor_id] = actor; + } + } + + void ClearTestData() { actor_datas_.clear(); } + + void WaitPendingDone(std::chrono::milliseconds timeout) { + WaitPendingDone(pending_count_, timeout); + } + + void WaitPendingDone(std::atomic &pending_count, + std::chrono::milliseconds timeout) { + while (pending_count != 0 && timeout.count() > 0) { + std::chrono::milliseconds interval(10); + std::this_thread::sleep_for(interval); + timeout -= interval; + } + EXPECT_EQ(pending_count, 0); + } + + protected: + GcsClientOptions options_; + std::unique_ptr gcs_client_; + + boost::asio::io_service io_service_; + std::unique_ptr work_thread; + + std::unordered_map> actor_datas_; + + std::atomic pending_count_{0}; +}; + +TEST_F(ActorStateAccessorTest, RegisterAndGet) { + ActorStateAccessor &actor_accessor = gcs_client_->Actors(); + // register + for (const auto &elem : actor_datas_) { + const auto &actor = elem.second; + ++pending_count_; + actor_accessor.AsyncRegister(actor, [this](Status status) { + RAY_CHECK_OK(status); + --pending_count_; + }); + } + + std::chrono::milliseconds timeout(10000); + WaitPendingDone(timeout); + + // get + for (const auto &elem : actor_datas_) { + const auto &actor = elem.second; + ++pending_count_; + actor_accessor.AsyncGet(elem.first, + [this](Status status, std::vector datas) { + ASSERT_EQ(datas.size(), 1U); + ActorID actor_id = ActorID::FromBinary(datas[0].actor_id()); + auto it = actor_datas_.find(actor_id); + ASSERT_TRUE(it != actor_datas_.end()); + --pending_count_; + }); + } + + WaitPendingDone(timeout); +} + +TEST_F(ActorStateAccessorTest, Subscribe) { + ActorStateAccessor &actor_accessor = gcs_client_->Actors(); + std::chrono::milliseconds timeout(10000); + // subscribe + std::atomic sub_pending_count(0); + std::atomic do_sub_pending_count(0); + auto subscribe = [this, &sub_pending_count](const ActorID &actor_id, + const ActorTableData &data) { + const auto it = actor_datas_.find(actor_id); + ASSERT_TRUE(it != actor_datas_.end()); + --sub_pending_count; + }; + auto done = [&do_sub_pending_count](Status status) { + RAY_CHECK_OK(status); + --do_sub_pending_count; + }; + + ++do_sub_pending_count; + actor_accessor.AsyncSubscribe(subscribe, done); + // Wait until subscribe finishes. + WaitPendingDone(do_sub_pending_count, timeout); + + // register + std::atomic register_pending_count(0); + for (const auto &elem : actor_datas_) { + const auto &actor = elem.second; + ++sub_pending_count; + ++register_pending_count; + actor_accessor.AsyncRegister(actor, [®ister_pending_count](Status status) { + RAY_CHECK_OK(status); + --register_pending_count; + }); + } + // Wait until register finishes. + WaitPendingDone(register_pending_count, timeout); + + // Wait for all subscribe notifications. + WaitPendingDone(sub_pending_count, timeout); +} + +} // namespace gcs + +} // namespace ray diff --git a/src/ray/gcs/callback.h b/src/ray/gcs/callback.h new file mode 100644 index 000000000..ef21a0beb --- /dev/null +++ b/src/ray/gcs/callback.h @@ -0,0 +1,41 @@ +#ifndef RAY_GCS_CALLBACK_H +#define RAY_GCS_CALLBACK_H + +#include +#include +#include "ray/common/status.h" + +namespace ray { + +namespace gcs { + +/// This callback is used to notify when a write/subscribe to GCS completes. +/// \param status Status indicates whether the write/subscribe was successful. +using StatusCallback = std::function; + +/// This callback is used to receive one item from GCS when a read completes. +/// \param status Status indicates whether the read was successful. +/// \param result The item returned by GCS. If the item to read doesn't exist, +/// this optional object is empty. +template +using OptionalItemCallback = + std::function result)>; + +/// This callback is used to receive multiple items from GCS when a read completes. +/// \param status Status indicates whether the read was successful. +/// \param result The items returned by GCS. +template +using MultiItemCallback = + std::function &result)>; + +/// This callback is used to receive notifications of the subscribed items in the GCS. +/// \param id The id of the item. +/// \param result The notification message. +template +using SubscribeCallback = std::function; + +} // namespace gcs + +} // namespace ray + +#endif // RAY_GCS_CALLBACK_H diff --git a/src/ray/gcs/gcs_client_interface.h b/src/ray/gcs/gcs_client_interface.h new file mode 100644 index 000000000..69c6fa814 --- /dev/null +++ b/src/ray/gcs/gcs_client_interface.h @@ -0,0 +1,108 @@ +#ifndef RAY_GCS_GCS_CLIENT_H +#define RAY_GCS_GCS_CLIENT_H + +#include +#include +#include +#include +#include "ray/common/status.h" +#include "ray/gcs/actor_state_accessor.h" +#include "ray/util/logging.h" + +namespace ray { + +namespace gcs { + +/// \class GcsClientOptions +/// GCS client's options (configuration items), such as service address, and service +/// password. +class GcsClientOptions { + public: + /// Constructor of GcsClientOptions. + /// + /// \param ip GCS service ip. + /// \param port GCS service port. + /// \param password GCS service password. + /// \param is_test_client Whether this client is used for tests. + GcsClientOptions(const std::string &ip, int port, const std::string &password, + bool is_test_client = false) + : server_ip_(ip), + server_port_(port), + password_(password), + is_test_client_(is_test_client) { +#if RAY_USE_NEW_GCS + command_type_ = CommandType::kChain; +#else + command_type_ = CommandType::kRegular; +#endif + } + + /// This constructor is only used for testing (RedisGcsClient's test). + /// + /// \param ip GCS service ip + /// \param port GCS service port + /// \param command_type Command type of RedisGcsClient + GcsClientOptions(const std::string &ip, int port, CommandType command_type) + : server_ip_(ip), + server_port_(port), + command_type_(command_type), + is_test_client_(true) {} + + // GCS server address + std::string server_ip_; + int server_port_; + + // Password of GCS server. + std::string password_; + // GCS command type. If CommandType::kChain, chain-replicated versions of the tables + // might be used, if available. + CommandType command_type_ = CommandType::kUnknown; + + // Whether this client is used for tests. + bool is_test_client_{false}; +}; + +/// \class GcsClientInterface +/// Abstract interface of the GCS client. +/// +/// To read and write from the GCS, `Connect()` must be called and return Status::OK. +/// Before exit, `Disconnect()` must be called. +class GcsClientInterface : public std::enable_shared_from_this { + public: + virtual ~GcsClientInterface() { RAY_CHECK(!is_connected_); } + + /// Connect to GCS Service. Non-thread safe. + /// This function must be called before calling other functions. + /// + /// \return Status + virtual Status Connect(boost::asio::io_service &io_service) = 0; + + /// Disconnect with GCS Service. Non-thread safe. + virtual void Disconnect() = 0; + + /// Get ActorStateAccessor for reading or writing or subscribing to + /// actors. This function is thread safe. + ActorStateAccessor &Actors() { + RAY_CHECK(actor_accessor_ != nullptr); + return *actor_accessor_; + } + + protected: + /// Constructor of GcsClientInterface. + /// + /// \param options Options for client. + GcsClientInterface(const GcsClientOptions &options) : options_(options) {} + + GcsClientOptions options_; + + // Whether this client is connected to GCS. + bool is_connected_{false}; + + std::unique_ptr actor_accessor_; +}; + +} // namespace gcs + +} // namespace ray + +#endif // RAY_GCS_GCS_CLIENT_H diff --git a/src/ray/gcs/client.cc b/src/ray/gcs/redis_gcs_client.cc similarity index 59% rename from src/ray/gcs/client.cc rename to src/ray/gcs/redis_gcs_client.cc index e96c5ad38..f7fad832e 100644 --- a/src/ray/gcs/client.cc +++ b/src/ray/gcs/redis_gcs_client.cc @@ -1,4 +1,4 @@ -#include "ray/gcs/client.h" +#include "ray/gcs/redis_gcs_client.h" #include "ray/common/ray_config.h" #include "ray/gcs/redis_context.h" @@ -70,50 +70,63 @@ namespace ray { namespace gcs { -AsyncGcsClient::AsyncGcsClient(const std::string &address, int port, - const ClientID &client_id, CommandType command_type, - bool is_test_client = false, - const std::string &password = "") { +RedisGcsClient::RedisGcsClient(const GcsClientOptions &options) + : GcsClientInterface(options) {} + +Status RedisGcsClient::Connect(boost::asio::io_service &io_service) { + RAY_CHECK(!is_connected_); + + if (options_.server_ip_.empty()) { + RAY_LOG(ERROR) << "Failed to connect, gcs service address is empty."; + return Status::Invalid("gcs service address is invalid!"); + } + primary_context_ = std::make_shared(); - RAY_CHECK_OK( - primary_context_->Connect(address, port, /*sharding=*/true, /*password=*/password)); + RAY_CHECK_OK(primary_context_->Connect(options_.server_ip_, options_.server_port_, + /*sharding=*/true, + /*password=*/options_.password_)); - if (!is_test_client) { + if (!options_.is_test_client_) { // Moving sharding into constructor defaultly means that sharding = true. // This design decision may worth a look. std::vector addresses; std::vector ports; GetRedisShards(primary_context_->sync_context(), addresses, ports); - if (addresses.size() == 0 || ports.size() == 0) { - addresses.push_back(address); - ports.push_back(port); + if (addresses.empty()) { + RAY_CHECK(ports.empty()); + addresses.push_back(options_.server_ip_); + ports.push_back(options_.server_port_); } - // Populate shard_contexts. for (size_t i = 0; i < addresses.size(); ++i) { + // Populate shard_contexts. shard_contexts_.push_back(std::make_shared()); - } - - RAY_CHECK(shard_contexts_.size() == addresses.size()); - for (size_t i = 0; i < addresses.size(); ++i) { RAY_CHECK_OK(shard_contexts_[i]->Connect(addresses[i], ports[i], /*sharding=*/true, - /*password=*/password)); + /*password=*/options_.password_)); } } else { shard_contexts_.push_back(std::make_shared()); - RAY_CHECK_OK(shard_contexts_[0]->Connect(address, port, /*sharding=*/true, - /*password=*/password)); + RAY_CHECK_OK(shard_contexts_[0]->Connect(options_.server_ip_, options_.server_port_, + /*sharding=*/true, + /*password=*/options_.password_)); } actor_table_.reset(new ActorTable({primary_context_}, this)); - client_table_.reset(new ClientTable({primary_context_}, this, client_id)); + + // TODO(micafan) Modify ClientTable' Constructor(remove ClientID) in future. + // We will use NodeID instead of ClientID. + // For worker/driver, it might not have this field(NodeID). + // For raylet, NodeID should be initialized in raylet layer(not here). + client_table_.reset(new ClientTable({primary_context_}, this, ClientID::FromRandom())); + error_table_.reset(new ErrorTable({primary_context_}, this)); job_table_.reset(new JobTable({primary_context_}, this)); heartbeat_batch_table_.reset(new HeartbeatBatchTable({primary_context_}, this)); // Tables below would be sharded. object_table_.reset(new ObjectTable(shard_contexts_, this)); - raylet_task_table_.reset(new raylet::TaskTable(shard_contexts_, this, command_type)); + raylet_task_table_.reset( + new raylet::TaskTable(shard_contexts_, this, options_.command_type_)); task_reconstruction_log_.reset(new TaskReconstructionLog(shard_contexts_, this)); task_lease_table_.reset(new TaskLeaseTable(shard_contexts_, this)); heartbeat_table_.reset(new HeartbeatTable(shard_contexts_, this)); @@ -121,47 +134,26 @@ AsyncGcsClient::AsyncGcsClient(const std::string &address, int port, actor_checkpoint_table_.reset(new ActorCheckpointTable(shard_contexts_, this)); actor_checkpoint_id_table_.reset(new ActorCheckpointIdTable(shard_contexts_, this)); resource_table_.reset(new DynamicResourceTable({primary_context_}, this)); - command_type_ = command_type; - // TODO(swang): Call the client table's Connect() method here. To do this, - // we need to make sure that we are attached to an event loop first. This - // currently isn't possible because the aeEventLoop, which we use for - // testing, requires us to connect to Redis first. + actor_accessor_.reset(new ActorStateAccessor(*this)); + + Status status = Attach(io_service); + is_connected_ = status.ok(); + + // TODO(micafan): Synchronously register node and look up existing nodes here + // for this client is Raylet. + RAY_LOG(INFO) << "RedisGcsClient::Connect finished with status " << status; + return status; } -#if RAY_USE_NEW_GCS -// Use of kChain currently only applies to Table::Add which affects only the -// task table, and when RAY_USE_NEW_GCS is set at compile time. -AsyncGcsClient::AsyncGcsClient(const std::string &address, int port, - const ClientID &client_id, bool is_test_client = false, - const std::string &password = "") - : AsyncGcsClient(address, port, client_id, CommandType::kChain, is_test_client, - password) {} -#else -AsyncGcsClient::AsyncGcsClient(const std::string &address, int port, - const ClientID &client_id, bool is_test_client = false, - const std::string &password = "") - : AsyncGcsClient(address, port, client_id, CommandType::kRegular, is_test_client, - password) {} -#endif // RAY_USE_NEW_GCS +void RedisGcsClient::Disconnect() { + RAY_CHECK(is_connected_); + is_connected_ = false; + RAY_LOG(INFO) << "RedisGcsClient Disconnected."; + // TODO(micafan): Synchronously unregister node if this client is Raylet. +} -AsyncGcsClient::AsyncGcsClient(const std::string &address, int port, - CommandType command_type) - : AsyncGcsClient(address, port, ClientID::FromRandom(), command_type) {} - -AsyncGcsClient::AsyncGcsClient(const std::string &address, int port, - CommandType command_type, bool is_test_client) - : AsyncGcsClient(address, port, ClientID::FromRandom(), command_type, - is_test_client) {} - -AsyncGcsClient::AsyncGcsClient(const std::string &address, int port, - const std::string &password = "") - : AsyncGcsClient(address, port, ClientID::FromRandom(), false, password) {} - -AsyncGcsClient::AsyncGcsClient(const std::string &address, int port, bool is_test_client) - : AsyncGcsClient(address, port, ClientID::FromRandom(), is_test_client) {} - -Status AsyncGcsClient::Attach(boost::asio::io_service &io_service) { +Status RedisGcsClient::Attach(boost::asio::io_service &io_service) { // Take care of sharding contexts. RAY_CHECK(shard_asio_async_clients_.empty()) << "Attach shall be called only once"; for (std::shared_ptr context : shard_contexts_) { @@ -177,9 +169,9 @@ Status AsyncGcsClient::Attach(boost::asio::io_service &io_service) { return Status::OK(); } -std::string AsyncGcsClient::DebugString() const { +std::string RedisGcsClient::DebugString() const { std::stringstream result; - result << "AsyncGcsClient:"; + result << "RedisGcsClient:"; result << "\n- TaskTable: " << raylet_task_table_->DebugString(); result << "\n- ActorTable: " << actor_table_->DebugString(); result << "\n- TaskReconstructionLog: " << task_reconstruction_log_->DebugString(); @@ -192,41 +184,41 @@ std::string AsyncGcsClient::DebugString() const { return result.str(); } -ObjectTable &AsyncGcsClient::object_table() { return *object_table_; } +ObjectTable &RedisGcsClient::object_table() { return *object_table_; } -raylet::TaskTable &AsyncGcsClient::raylet_task_table() { return *raylet_task_table_; } +raylet::TaskTable &RedisGcsClient::raylet_task_table() { return *raylet_task_table_; } -ActorTable &AsyncGcsClient::actor_table() { return *actor_table_; } +ActorTable &RedisGcsClient::actor_table() { return *actor_table_; } -TaskReconstructionLog &AsyncGcsClient::task_reconstruction_log() { +TaskReconstructionLog &RedisGcsClient::task_reconstruction_log() { return *task_reconstruction_log_; } -TaskLeaseTable &AsyncGcsClient::task_lease_table() { return *task_lease_table_; } +TaskLeaseTable &RedisGcsClient::task_lease_table() { return *task_lease_table_; } -ClientTable &AsyncGcsClient::client_table() { return *client_table_; } +ClientTable &RedisGcsClient::client_table() { return *client_table_; } -HeartbeatTable &AsyncGcsClient::heartbeat_table() { return *heartbeat_table_; } +HeartbeatTable &RedisGcsClient::heartbeat_table() { return *heartbeat_table_; } -HeartbeatBatchTable &AsyncGcsClient::heartbeat_batch_table() { +HeartbeatBatchTable &RedisGcsClient::heartbeat_batch_table() { return *heartbeat_batch_table_; } -ErrorTable &AsyncGcsClient::error_table() { return *error_table_; } +ErrorTable &RedisGcsClient::error_table() { return *error_table_; } -JobTable &AsyncGcsClient::job_table() { return *job_table_; } +JobTable &RedisGcsClient::job_table() { return *job_table_; } -ProfileTable &AsyncGcsClient::profile_table() { return *profile_table_; } +ProfileTable &RedisGcsClient::profile_table() { return *profile_table_; } -ActorCheckpointTable &AsyncGcsClient::actor_checkpoint_table() { +ActorCheckpointTable &RedisGcsClient::actor_checkpoint_table() { return *actor_checkpoint_table_; } -ActorCheckpointIdTable &AsyncGcsClient::actor_checkpoint_id_table() { +ActorCheckpointIdTable &RedisGcsClient::actor_checkpoint_id_table() { return *actor_checkpoint_id_table_; } -DynamicResourceTable &AsyncGcsClient::resource_table() { return *resource_table_; } +DynamicResourceTable &RedisGcsClient::resource_table() { return *resource_table_; } } // namespace gcs diff --git a/src/ray/gcs/client.h b/src/ray/gcs/redis_gcs_client.h similarity index 58% rename from src/ray/gcs/client.h rename to src/ray/gcs/redis_gcs_client.h index 0ebee0d70..52bf9f628 100644 --- a/src/ray/gcs/client.h +++ b/src/ray/gcs/redis_gcs_client.h @@ -1,5 +1,5 @@ -#ifndef RAY_GCS_CLIENT_H -#define RAY_GCS_CLIENT_H +#ifndef RAY_GCS_REDIS_GCS_CLIENT_H +#define RAY_GCS_REDIS_GCS_CLIENT_H #include #include @@ -7,6 +7,7 @@ #include "ray/common/id.h" #include "ray/common/status.h" #include "ray/gcs/asio.h" +#include "ray/gcs/gcs_client_interface.h" #include "ray/gcs/tables.h" #include "ray/util/logging.h" @@ -16,38 +17,30 @@ namespace gcs { class RedisContext; -class RAY_EXPORT AsyncGcsClient { - public: - /// Start a GCS client with the given client ID and command type (regular or - /// chain-replicated). To read from the GCS tables, Connect() and then - /// Attach() must be called. To read and write from the GCS tables requires a - /// further call to Connect() to the client table. - /// - /// \param address The GCS IP address. - /// \param port The GCS port. - /// \param sharding If true, use sharded redis for the GCS. - /// \param client_id The ID to assign to the client. - /// \param command_type GCS command type. If CommandType::kChain, chain-replicated - /// versions of the tables might be used, if available. - AsyncGcsClient(const std::string &address, int port, const ClientID &client_id, - CommandType command_type, bool is_test_client, - const std::string &redis_password); - AsyncGcsClient(const std::string &address, int port, const ClientID &client_id, - bool is_test_client, const std::string &password); - AsyncGcsClient(const std::string &address, int port, CommandType command_type); - AsyncGcsClient(const std::string &address, int port, CommandType command_type, - bool is_test_client); - AsyncGcsClient(const std::string &address, int port, const std::string &password); - AsyncGcsClient(const std::string &address, int port, bool is_test_client); +class RAY_EXPORT RedisGcsClient : public GcsClientInterface { + friend class ActorStateAccessor; - /// Attach this client to an asio event loop. Note that only - /// one event loop should be attached at a time. - Status Attach(boost::asio::io_service &io_service); + public: + /// Constructor of RedisGcsClient. + /// Connect() must be called(and return ok) before you call any other methods. + /// TODO(micafan) To read and write from the GCS tables requires a further + /// call to Connect() to the client table. Will fix this in next pr. + /// + /// \param GcsClientOptions Options of client, e.g. server address, is test client ... + RedisGcsClient(const GcsClientOptions &options); + + /// Connect to GCS Service. Non-thread safe. + /// Call this function before calling other functions. + /// + /// \return Status + Status Connect(boost::asio::io_service &io_service); + + /// Disconnect with GCS Service. Non-thread safe. + void Disconnect(); // TODO: Some API for getting the error on the driver ObjectTable &object_table(); raylet::TaskTable &raylet_task_table(); - ActorTable &actor_table(); TaskReconstructionLog &task_reconstruction_log(); TaskLeaseTable &task_lease_table(); ClientTable &client_table(); @@ -77,6 +70,13 @@ class RAY_EXPORT AsyncGcsClient { std::string DebugString() const; private: + /// Attach this client to an asio event loop. Note that only + /// one event loop should be attached at a time. + Status Attach(boost::asio::io_service &io_service); + + /// Use method Actors() instead + ActorTable &actor_table(); + std::unique_ptr object_table_; std::unique_ptr raylet_task_table_; std::unique_ptr actor_table_; @@ -99,24 +99,10 @@ class RAY_EXPORT AsyncGcsClient { std::unique_ptr job_table_; std::unique_ptr asio_async_auxiliary_client_; std::unique_ptr asio_subscribe_auxiliary_client_; - CommandType command_type_; -}; - -class SyncGcsClient { - Status LogEvent(const std::string &key, const std::string &value, double timestamp); - Status NotifyError(const std::map &error_info); - Status RegisterFunction(const JobID &job_id, const FunctionID &function_id, - const std::string &language, const std::string &name, - const std::string &data); - Status RetrieveFunction(const JobID &job_id, const FunctionID &function_id, - std::string *name, std::string *data); - - Status AddExport(const std::string &job_id, std::string &export_data); - Status GetExport(const std::string &job_id, int64_t export_index, std::string *data); }; } // namespace gcs } // namespace ray -#endif // RAY_GCS_CLIENT_H +#endif // RAY_GCS_REDIS_GCS_CLIENT_H diff --git a/src/ray/gcs/client_test.cc b/src/ray/gcs/redis_gcs_client_test.cc similarity index 90% rename from src/ray/gcs/client_test.cc rename to src/ray/gcs/redis_gcs_client_test.cc index a63439833..1ba998329 100644 --- a/src/ray/gcs/client_test.cc +++ b/src/ray/gcs/redis_gcs_client_test.cc @@ -6,7 +6,7 @@ extern "C" { } #include "ray/common/ray_config.h" -#include "ray/gcs/client.h" +#include "ray/gcs/redis_gcs_client.h" #include "ray/gcs/tables.h" namespace ray { @@ -29,8 +29,8 @@ inline JobID NextJobID() { class TestGcs : public ::testing::Test { public: TestGcs(CommandType command_type) : num_callbacks_(0), command_type_(command_type) { - client_ = std::make_shared("127.0.0.1", 6379, command_type_, - /*is_test_client=*/true); + GcsClientOptions options("127.0.0.1", 6379, command_type_); + client_ = std::make_shared(options); job_id_ = NextJobID(); } @@ -50,7 +50,7 @@ class TestGcs : public ::testing::Test { protected: uint64_t num_callbacks_; gcs::CommandType command_type_; - std::shared_ptr client_; + std::shared_ptr client_; JobID job_id_; }; @@ -60,13 +60,14 @@ class TestGcsWithAsio : public TestGcs { public: TestGcsWithAsio(CommandType command_type) : TestGcs(command_type), io_service_(), work_(io_service_) { - RAY_CHECK_OK(client_->Attach(io_service_)); + RAY_CHECK_OK(client_->Connect(io_service_)); } TestGcsWithAsio() : TestGcsWithAsio(CommandType::kRegular) {} ~TestGcsWithAsio() { // Destroy the client first since it has a reference to the event loop. + client_->Disconnect(); client_.reset(); } void Start() override { io_service_.run(); } @@ -102,19 +103,19 @@ bool TaskTableDataEqual(const TaskTableData &data1, const TaskTableData &data2) spec1.num_returns() == spec2.num_returns()); } -void TestTableLookup(const JobID &job_id, std::shared_ptr client) { +void TestTableLookup(const JobID &job_id, std::shared_ptr client) { const auto task_id = TaskID::FromRandom(); const auto data = CreateTaskTableData(task_id); // Check that we added the correct task. - auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, + auto add_callback = [task_id, data](gcs::RedisGcsClient *client, const TaskID &id, const TaskTableData &d) { ASSERT_EQ(id, task_id); ASSERT_TRUE(TaskTableDataEqual(*data, d)); }; // Check that the lookup returns the added task. - auto lookup_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, + auto lookup_callback = [task_id, data](gcs::RedisGcsClient *client, const TaskID &id, const TaskTableData &d) { ASSERT_EQ(id, task_id); ASSERT_TRUE(TaskTableDataEqual(*data, d)); @@ -122,7 +123,7 @@ void TestTableLookup(const JobID &job_id, std::shared_ptr c }; // Check that the lookup does not return an empty entry. - auto failure_callback = [](gcs::AsyncGcsClient *client, const TaskID &id) { + auto failure_callback = [](gcs::RedisGcsClient *client, const TaskID &id) { RAY_CHECK(false); }; @@ -148,7 +149,7 @@ TEST_MACRO(TestGcsWithAsio, TestTableLookup); TEST_MACRO(TestGcsWithChainAsio, TestTableLookup); #endif -void TestLogLookup(const JobID &job_id, std::shared_ptr client) { +void TestLogLookup(const JobID &job_id, std::shared_ptr client) { // Append some entries to the log at an object ID. TaskID task_id = TaskID::FromRandom(); std::vector node_manager_ids = {"abc", "def", "ghi"}; @@ -156,7 +157,7 @@ void TestLogLookup(const JobID &job_id, std::shared_ptr cli auto data = std::make_shared(); data->set_node_manager_id(node_manager_id); // Check that we added the correct object entries. - auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, + auto add_callback = [task_id, data](gcs::RedisGcsClient *client, const TaskID &id, const TaskReconstructionData &d) { ASSERT_EQ(id, task_id); ASSERT_EQ(data->node_manager_id(), d.node_manager_id()); @@ -167,7 +168,7 @@ void TestLogLookup(const JobID &job_id, std::shared_ptr cli // Check that lookup returns the added object entries. auto lookup_callback = [task_id, node_manager_ids]( - gcs::AsyncGcsClient *client, const TaskID &id, + gcs::RedisGcsClient *client, const TaskID &id, const std::vector &data) { ASSERT_EQ(id, task_id); for (const auto &entry : data) { @@ -194,15 +195,15 @@ TEST_F(TestGcsWithAsio, TestLogLookup) { } void TestTableLookupFailure(const JobID &job_id, - std::shared_ptr client) { + std::shared_ptr client) { TaskID task_id = TaskID::FromRandom(); // Check that the lookup does not return data. - auto lookup_callback = [](gcs::AsyncGcsClient *client, const TaskID &id, + auto lookup_callback = [](gcs::RedisGcsClient *client, const TaskID &id, const TaskTableData &d) { RAY_CHECK(false); }; // Check that the lookup returns an empty entry. - auto failure_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id) { + auto failure_callback = [task_id](gcs::RedisGcsClient *client, const TaskID &id) { ASSERT_EQ(id, task_id); test->Stop(); }; @@ -220,7 +221,7 @@ TEST_MACRO(TestGcsWithAsio, TestTableLookupFailure); TEST_MACRO(TestGcsWithChainAsio, TestTableLookupFailure); #endif -void TestLogAppendAt(const JobID &job_id, std::shared_ptr client) { +void TestLogAppendAt(const JobID &job_id, std::shared_ptr client) { TaskID task_id = TaskID::FromRandom(); std::vector node_manager_ids = {"A", "B"}; std::vector> data_log; @@ -231,7 +232,7 @@ void TestLogAppendAt(const JobID &job_id, std::shared_ptr c } // Check that we added the correct task. - auto failure_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id, + auto failure_callback = [task_id](gcs::RedisGcsClient *client, const TaskID &id, const TaskReconstructionData &d) { ASSERT_EQ(id, task_id); test->IncrementNumCallbacks(); @@ -256,7 +257,7 @@ void TestLogAppendAt(const JobID &job_id, std::shared_ptr c /*done callback=*/nullptr, failure_callback, /*log_length=*/1)); auto lookup_callback = [node_manager_ids]( - gcs::AsyncGcsClient *client, const TaskID &id, + gcs::RedisGcsClient *client, const TaskID &id, const std::vector &data) { std::vector appended_managers; for (const auto &entry : data) { @@ -278,7 +279,7 @@ TEST_F(TestGcsWithAsio, TestLogAppendAt) { TestLogAppendAt(job_id_, client_); } -void TestSet(const JobID &job_id, std::shared_ptr client) { +void TestSet(const JobID &job_id, std::shared_ptr client) { // Add some entries to the set at an object ID. ObjectID object_id = ObjectID::FromRandom(); std::vector managers = {"abc", "def", "ghi"}; @@ -286,7 +287,7 @@ void TestSet(const JobID &job_id, std::shared_ptr client) { auto data = std::make_shared(); data->set_manager(manager); // Check that we added the correct object entries. - auto add_callback = [object_id, data](gcs::AsyncGcsClient *client, const ObjectID &id, + auto add_callback = [object_id, data](gcs::RedisGcsClient *client, const ObjectID &id, const ObjectTableData &d) { ASSERT_EQ(id, object_id); ASSERT_EQ(data->manager(), d.manager()); @@ -296,7 +297,7 @@ void TestSet(const JobID &job_id, std::shared_ptr client) { } // Check that lookup returns the added object entries. - auto lookup_callback = [object_id, managers](gcs::AsyncGcsClient *client, + auto lookup_callback = [object_id, managers](gcs::RedisGcsClient *client, const ObjectID &id, const std::vector &data) { ASSERT_EQ(id, object_id); @@ -311,7 +312,7 @@ void TestSet(const JobID &job_id, std::shared_ptr client) { auto data = std::make_shared(); data->set_manager(manager); // Check that we added the correct object entries. - auto remove_entry_callback = [object_id, data](gcs::AsyncGcsClient *client, + auto remove_entry_callback = [object_id, data](gcs::RedisGcsClient *client, const ObjectID &id, const ObjectTableData &d) { ASSERT_EQ(id, object_id); @@ -324,7 +325,7 @@ void TestSet(const JobID &job_id, std::shared_ptr client) { // Check that the entries are removed. auto lookup_callback2 = [object_id, managers]( - gcs::AsyncGcsClient *client, const ObjectID &id, + gcs::RedisGcsClient *client, const ObjectID &id, const std::vector &data) { ASSERT_EQ(id, object_id); ASSERT_EQ(data.size(), 0); @@ -346,7 +347,7 @@ TEST_F(TestGcsWithAsio, TestSet) { } void TestDeleteKeysFromLog( - const JobID &job_id, std::shared_ptr client, + const JobID &job_id, std::shared_ptr client, std::vector> &data_vector) { std::vector ids; TaskID task_id; @@ -354,7 +355,7 @@ void TestDeleteKeysFromLog( task_id = TaskID::FromRandom(); ids.push_back(task_id); // Check that we added the correct object entries. - auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, + auto add_callback = [task_id, data](gcs::RedisGcsClient *client, const TaskID &id, const TaskReconstructionData &d) { ASSERT_EQ(id, task_id); ASSERT_EQ(data->node_manager_id(), d.node_manager_id()); @@ -366,7 +367,7 @@ void TestDeleteKeysFromLog( for (const auto &task_id : ids) { // Check that lookup returns the added object entries. auto lookup_callback = [task_id, data_vector]( - gcs::AsyncGcsClient *client, const TaskID &id, + gcs::RedisGcsClient *client, const TaskID &id, const std::vector &data) { ASSERT_EQ(id, task_id); ASSERT_EQ(data.size(), 1); @@ -381,7 +382,7 @@ void TestDeleteKeysFromLog( client->task_reconstruction_log().Delete(job_id, ids); } for (const auto &task_id : ids) { - auto lookup_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id, + auto lookup_callback = [task_id](gcs::RedisGcsClient *client, const TaskID &id, const std::vector &data) { ASSERT_EQ(id, task_id); ASSERT_TRUE(data.size() == 0); @@ -393,7 +394,7 @@ void TestDeleteKeysFromLog( } void TestDeleteKeysFromTable(const JobID &job_id, - std::shared_ptr client, + std::shared_ptr client, std::vector> &data_vector, bool stop_at_end) { std::vector ids; @@ -402,7 +403,7 @@ void TestDeleteKeysFromTable(const JobID &job_id, task_id = TaskID::FromRandom(); ids.push_back(task_id); // Check that we added the correct object entries. - auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, + auto add_callback = [task_id, data](gcs::RedisGcsClient *client, const TaskID &id, const TaskTableData &d) { ASSERT_EQ(id, task_id); ASSERT_TRUE(TaskTableDataEqual(*data, d)); @@ -411,7 +412,7 @@ void TestDeleteKeysFromTable(const JobID &job_id, RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id, data, add_callback)); } for (const auto &task_id : ids) { - auto task_lookup_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id, + auto task_lookup_callback = [task_id](gcs::RedisGcsClient *client, const TaskID &id, const TaskTableData &data) { ASSERT_EQ(id, task_id); test->IncrementNumCallbacks(); @@ -424,25 +425,25 @@ void TestDeleteKeysFromTable(const JobID &job_id, } else { client->raylet_task_table().Delete(job_id, ids); } - auto expected_failure_callback = [](AsyncGcsClient *client, const TaskID &id) { + auto expected_failure_callback = [](RedisGcsClient *client, const TaskID &id) { ASSERT_TRUE(true); test->IncrementNumCallbacks(); }; - auto undesired_callback = [](gcs::AsyncGcsClient *client, const TaskID &id, + auto undesired_callback = [](gcs::RedisGcsClient *client, const TaskID &id, const TaskTableData &data) { ASSERT_TRUE(false); }; for (size_t i = 0; i < ids.size(); ++i) { RAY_CHECK_OK(client->raylet_task_table().Lookup(job_id, task_id, undesired_callback, expected_failure_callback)); } if (stop_at_end) { - auto stop_callback = [](AsyncGcsClient *client, const TaskID &id) { test->Stop(); }; + auto stop_callback = [](RedisGcsClient *client, const TaskID &id) { test->Stop(); }; RAY_CHECK_OK( client->raylet_task_table().Lookup(job_id, ids[0], nullptr, stop_callback)); } } void TestDeleteKeysFromSet(const JobID &job_id, - std::shared_ptr client, + std::shared_ptr client, std::vector> &data_vector) { std::vector ids; ObjectID object_id; @@ -450,7 +451,7 @@ void TestDeleteKeysFromSet(const JobID &job_id, object_id = ObjectID::FromRandom(); ids.push_back(object_id); // Check that we added the correct object entries. - auto add_callback = [object_id, data](gcs::AsyncGcsClient *client, const ObjectID &id, + auto add_callback = [object_id, data](gcs::RedisGcsClient *client, const ObjectID &id, const ObjectTableData &d) { ASSERT_EQ(id, object_id); ASSERT_EQ(data->manager(), d.manager()); @@ -461,7 +462,7 @@ void TestDeleteKeysFromSet(const JobID &job_id, for (const auto &object_id : ids) { // Check that lookup returns the added object entries. auto lookup_callback = [object_id, data_vector]( - gcs::AsyncGcsClient *client, const ObjectID &id, + gcs::RedisGcsClient *client, const ObjectID &id, const std::vector &data) { ASSERT_EQ(id, object_id); ASSERT_EQ(data.size(), 1); @@ -475,7 +476,7 @@ void TestDeleteKeysFromSet(const JobID &job_id, client->object_table().Delete(job_id, ids); } for (const auto &object_id : ids) { - auto lookup_callback = [object_id](gcs::AsyncGcsClient *client, const ObjectID &id, + auto lookup_callback = [object_id](gcs::RedisGcsClient *client, const ObjectID &id, const std::vector &data) { ASSERT_EQ(id, object_id); ASSERT_TRUE(data.size() == 0); @@ -486,7 +487,7 @@ void TestDeleteKeysFromSet(const JobID &job_id, } // Test delete function for keys of Log or Table. -void TestDeleteKeys(const JobID &job_id, std::shared_ptr client) { +void TestDeleteKeys(const JobID &job_id, std::shared_ptr client) { // Test delete function for keys of Log. std::vector> task_reconstruction_vector; auto AppendTaskReconstructionData = [&task_reconstruction_vector](size_t add_count) { @@ -574,13 +575,13 @@ TEST_F(TestGcsWithAsio, TestDeleteKey) { } void TestLogSubscribeAll(const JobID &job_id, - std::shared_ptr client) { + std::shared_ptr client) { std::vector job_ids; for (int i = 0; i < 3; i++) { job_ids.emplace_back(NextJobID()); } // Callback for a notification. - auto notification_callback = [job_ids](gcs::AsyncGcsClient *client, const JobID &id, + auto notification_callback = [job_ids](gcs::RedisGcsClient *client, const JobID &id, const std::vector data) { ASSERT_EQ(id, job_ids[test->NumCallbacks()]); // Check that we get notifications in the same order as the writes. @@ -595,7 +596,7 @@ void TestLogSubscribeAll(const JobID &job_id, // Callback for subscription success. We are guaranteed to receive // notifications after this is called. - auto subscribe_callback = [job_ids](gcs::AsyncGcsClient *client) { + auto subscribe_callback = [job_ids](gcs::RedisGcsClient *client) { // We have subscribed. Do the writes to the table. for (size_t i = 0; i < job_ids.size(); i++) { RAY_CHECK_OK( @@ -622,7 +623,7 @@ TEST_F(TestGcsWithAsio, TestLogSubscribeAll) { } void TestSetSubscribeAll(const JobID &job_id, - std::shared_ptr client) { + std::shared_ptr client) { std::vector object_ids; for (int i = 0; i < 3; i++) { object_ids.emplace_back(ObjectID::FromRandom()); @@ -631,7 +632,7 @@ void TestSetSubscribeAll(const JobID &job_id, // Callback for a notification. auto notification_callback = [object_ids, managers]( - gcs::AsyncGcsClient *client, const ObjectID &id, + gcs::RedisGcsClient *client, const ObjectID &id, const GcsChangeMode change_mode, const std::vector data) { if (test->NumCallbacks() < 3 * 3) { @@ -652,7 +653,7 @@ void TestSetSubscribeAll(const JobID &job_id, // Callback for subscription success. We are guaranteed to receive // notifications after this is called. - auto subscribe_callback = [job_id, object_ids, managers](gcs::AsyncGcsClient *client) { + auto subscribe_callback = [job_id, object_ids, managers](gcs::RedisGcsClient *client) { // We have subscribed. Do the writes to the table. for (size_t i = 0; i < object_ids.size(); i++) { for (size_t j = 0; j < managers.size(); j++) { @@ -698,7 +699,7 @@ TEST_F(TestGcsWithAsio, TestSetSubscribeAll) { } void TestTableSubscribeId(const JobID &job_id, - std::shared_ptr client) { + std::shared_ptr client) { int num_modifications = 3; // Add a table entry. @@ -709,7 +710,7 @@ void TestTableSubscribeId(const JobID &job_id, // The callback for a notification from the table. This should only be // received for keys that we requested notifications for. - auto notification_callback = [task_id2, num_modifications](gcs::AsyncGcsClient *client, + auto notification_callback = [task_id2, num_modifications](gcs::RedisGcsClient *client, const TaskID &id, const TaskTableData &data) { // Check that we only get notifications for the requested key. @@ -726,7 +727,7 @@ void TestTableSubscribeId(const JobID &job_id, // The failure callback should be called once since both keys start as empty. bool failure_notification_received = false; auto failure_callback = [task_id2, &failure_notification_received]( - gcs::AsyncGcsClient *client, const TaskID &id) { + gcs::RedisGcsClient *client, const TaskID &id) { ASSERT_EQ(id, task_id2); // The failure notification should be the first notification received. ASSERT_EQ(test->NumCallbacks(), 0); @@ -736,7 +737,7 @@ void TestTableSubscribeId(const JobID &job_id, // The callback for subscription success. Once we've subscribed, request // notifications for only one of the keys, then write to both keys. auto subscribe_callback = [job_id, task_id1, task_id2, - num_modifications](gcs::AsyncGcsClient *client) { + num_modifications](gcs::RedisGcsClient *client) { // Request notifications for one of the keys. RAY_CHECK_OK(client->raylet_task_table().RequestNotifications( job_id, task_id2, client->client_table().GetLocalClientId())); @@ -774,7 +775,7 @@ TEST_MACRO(TestGcsWithChainAsio, TestTableSubscribeId); #endif void TestLogSubscribeId(const JobID &job_id, - std::shared_ptr client) { + std::shared_ptr client) { // Add a log entry. JobID job_id1 = NextJobID(); std::vector job_ids1 = {"abc", "def", "ghi"}; @@ -792,7 +793,7 @@ void TestLogSubscribeId(const JobID &job_id, // The callback for a notification from the table. This should only be // received for keys that we requested notifications for. auto notification_callback = [job_id2, job_ids2]( - gcs::AsyncGcsClient *client, const JobID &id, + gcs::RedisGcsClient *client, const JobID &id, const std::vector &data) { // Check that we only get notifications for the requested key. ASSERT_EQ(id, job_id2); @@ -809,7 +810,7 @@ void TestLogSubscribeId(const JobID &job_id, // The callback for subscription success. Once we've subscribed, request // notifications for only one of the keys, then write to both keys. auto subscribe_callback = [job_id, job_id1, job_id2, job_ids1, - job_ids2](gcs::AsyncGcsClient *client) { + job_ids2](gcs::RedisGcsClient *client) { // Request notifications for one of the keys. RAY_CHECK_OK(client->job_table().RequestNotifications( job_id, job_id2, client->client_table().GetLocalClientId())); @@ -848,7 +849,7 @@ TEST_F(TestGcsWithAsio, TestLogSubscribeId) { } void TestSetSubscribeId(const JobID &job_id, - std::shared_ptr client) { + std::shared_ptr client) { // Add a set entry. ObjectID object_id1 = ObjectID::FromRandom(); std::vector managers1 = {"abc", "def", "ghi"}; @@ -866,7 +867,7 @@ void TestSetSubscribeId(const JobID &job_id, // The callback for a notification from the table. This should only be // received for keys that we requested notifications for. auto notification_callback = [object_id2, managers2]( - gcs::AsyncGcsClient *client, const ObjectID &id, + gcs::RedisGcsClient *client, const ObjectID &id, const GcsChangeMode change_mode, const std::vector &data) { ASSERT_EQ(change_mode, GcsChangeMode::APPEND_OR_ADD); @@ -885,7 +886,7 @@ void TestSetSubscribeId(const JobID &job_id, // The callback for subscription success. Once we've subscribed, request // notifications for only one of the keys, then write to both keys. auto subscribe_callback = [job_id, object_id1, object_id2, managers1, - managers2](gcs::AsyncGcsClient *client) { + managers2](gcs::RedisGcsClient *client) { // Request notifications for one of the keys. RAY_CHECK_OK(client->object_table().RequestNotifications( job_id, object_id2, client->client_table().GetLocalClientId())); @@ -924,7 +925,7 @@ TEST_F(TestGcsWithAsio, TestSetSubscribeId) { } void TestTableSubscribeCancel(const JobID &job_id, - std::shared_ptr client) { + std::shared_ptr client) { // Add a table entry. const auto task_id = TaskID::FromRandom(); const int num_modifications = 3; @@ -933,13 +934,13 @@ void TestTableSubscribeCancel(const JobID &job_id, // The failure callback should not be called since all keys are non-empty // when notifications are requested. - auto failure_callback = [](gcs::AsyncGcsClient *client, const TaskID &id) { + auto failure_callback = [](gcs::RedisGcsClient *client, const TaskID &id) { RAY_CHECK(false); }; // The callback for a notification from the table. This should only be // received for keys that we requested notifications for. - auto notification_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id, + auto notification_callback = [task_id](gcs::RedisGcsClient *client, const TaskID &id, const TaskTableData &data) { ASSERT_EQ(id, task_id); // Check that we only get notifications for the first and last writes, @@ -958,7 +959,7 @@ void TestTableSubscribeCancel(const JobID &job_id, // The callback for a notification from the table. This should only be // received for keys that we requested notifications for. - auto subscribe_callback = [job_id, task_id](gcs::AsyncGcsClient *client) { + auto subscribe_callback = [job_id, task_id](gcs::RedisGcsClient *client) { // Request notifications, then cancel immediately. We should receive a // notification for the current value at the key. RAY_CHECK_OK(client->raylet_task_table().RequestNotifications( @@ -996,7 +997,7 @@ TEST_MACRO(TestGcsWithChainAsio, TestTableSubscribeCancel); #endif void TestLogSubscribeCancel(const JobID &job_id, - std::shared_ptr client) { + std::shared_ptr client) { // Add a log entry. JobID random_job_id = NextJobID(); std::vector job_ids = {"jkl", "mno", "pqr"}; @@ -1007,7 +1008,7 @@ void TestLogSubscribeCancel(const JobID &job_id, // The callback for a notification from the object table. This should only be // received for the object that we requested notifications for. auto notification_callback = [random_job_id, job_ids]( - gcs::AsyncGcsClient *client, const JobID &id, + gcs::RedisGcsClient *client, const JobID &id, const std::vector &data) { ASSERT_EQ(id, random_job_id); // Check that we get a duplicate notification for the first write. We get a @@ -1027,7 +1028,7 @@ void TestLogSubscribeCancel(const JobID &job_id, // The callback for a notification from the table. This should only be // received for keys that we requested notifications for. auto subscribe_callback = [job_id, random_job_id, - job_ids](gcs::AsyncGcsClient *client) { + job_ids](gcs::RedisGcsClient *client) { // Request notifications, then cancel immediately. We should receive a // notification for the current value at the key. RAY_CHECK_OK(client->job_table().RequestNotifications( @@ -1068,7 +1069,7 @@ TEST_F(TestGcsWithAsio, TestLogSubscribeCancel) { } void TestSetSubscribeCancel(const JobID &job_id, - std::shared_ptr client) { + std::shared_ptr client) { // Add a set entry. ObjectID object_id = ObjectID::FromRandom(); std::vector managers = {"jkl", "mno", "pqr"}; @@ -1079,7 +1080,7 @@ void TestSetSubscribeCancel(const JobID &job_id, // The callback for a notification from the object table. This should only be // received for the object that we requested notifications for. auto notification_callback = [object_id, managers]( - gcs::AsyncGcsClient *client, const ObjectID &id, + gcs::RedisGcsClient *client, const ObjectID &id, const GcsChangeMode change_mode, const std::vector &data) { ASSERT_EQ(change_mode, GcsChangeMode::APPEND_OR_ADD); @@ -1109,7 +1110,7 @@ void TestSetSubscribeCancel(const JobID &job_id, // The callback for a notification from the table. This should only be // received for keys that we requested notifications for. - auto subscribe_callback = [job_id, object_id, managers](gcs::AsyncGcsClient *client) { + auto subscribe_callback = [job_id, object_id, managers](gcs::RedisGcsClient *client) { // Request notifications, then cancel immediately. We should receive a // notification for the current value at the key. RAY_CHECK_OK(client->object_table().RequestNotifications( @@ -1149,7 +1150,7 @@ TEST_F(TestGcsWithAsio, TestSetSubscribeCancel) { TestSetSubscribeCancel(job_id_, client_); } -void ClientTableNotification(gcs::AsyncGcsClient *client, const ClientID &client_id, +void ClientTableNotification(gcs::RedisGcsClient *client, const ClientID &client_id, const ClientTableData &data, bool is_insertion) { ClientID added_id = client->client_table().GetLocalClientId(); ASSERT_EQ(client_id, added_id); @@ -1164,11 +1165,11 @@ void ClientTableNotification(gcs::AsyncGcsClient *client, const ClientID &client } void TestClientTableConnect(const JobID &job_id, - std::shared_ptr client) { + std::shared_ptr client) { // Register callbacks for when a client gets added and removed. The latter // event will stop the event loop. client->client_table().RegisterClientAddedCallback( - [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableData &data) { + [](gcs::RedisGcsClient *client, const ClientID &id, const ClientTableData &data) { ClientTableNotification(client, id, data, true); test->Stop(); }); @@ -1189,18 +1190,18 @@ TEST_F(TestGcsWithAsio, TestClientTableConnect) { } void TestClientTableDisconnect(const JobID &job_id, - std::shared_ptr client) { + std::shared_ptr client) { // Register callbacks for when a client gets added and removed. The latter // event will stop the event loop. client->client_table().RegisterClientAddedCallback( - [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableData &data) { + [](gcs::RedisGcsClient *client, const ClientID &id, const ClientTableData &data) { ClientTableNotification(client, id, data, /*is_insertion=*/true); // Disconnect from the client table. We should receive a notification // for the removal of our own entry. RAY_CHECK_OK(client->client_table().Disconnect()); }); client->client_table().RegisterClientRemovedCallback( - [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableData &data) { + [](gcs::RedisGcsClient *client, const ClientID &id, const ClientTableData &data) { ClientTableNotification(client, id, data, /*is_insertion=*/false); test->Stop(); }); @@ -1220,15 +1221,15 @@ TEST_F(TestGcsWithAsio, TestClientTableDisconnect) { } void TestClientTableImmediateDisconnect(const JobID &job_id, - std::shared_ptr client) { + std::shared_ptr client) { // Register callbacks for when a client gets added and removed. The latter // event will stop the event loop. client->client_table().RegisterClientAddedCallback( - [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableData &data) { + [](gcs::RedisGcsClient *client, const ClientID &id, const ClientTableData &data) { ClientTableNotification(client, id, data, true); }); client->client_table().RegisterClientRemovedCallback( - [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableData &data) { + [](gcs::RedisGcsClient *client, const ClientID &id, const ClientTableData &data) { ClientTableNotification(client, id, data, false); test->Stop(); }); @@ -1249,7 +1250,7 @@ TEST_F(TestGcsWithAsio, TestClientTableImmediateDisconnect) { } void TestClientTableMarkDisconnected(const JobID &job_id, - std::shared_ptr client) { + std::shared_ptr client) { ClientTableData local_client_info = client->client_table().GetLocalClient(); local_client_info.set_node_manager_address("127.0.0.1"); local_client_info.set_node_manager_port(0); @@ -1262,7 +1263,7 @@ void TestClientTableMarkDisconnected(const JobID &job_id, // Make sure we only get a notification for the removal of the client we // marked as dead. client->client_table().RegisterClientRemovedCallback( - [dead_client_id](gcs::AsyncGcsClient *client, const UniqueID &id, + [dead_client_id](gcs::RedisGcsClient *client, const UniqueID &id, const ClientTableData &data) { ASSERT_EQ(ClientID::FromBinary(data.client_id()), dead_client_id); test->Stop(); @@ -1275,7 +1276,7 @@ TEST_F(TestGcsWithAsio, TestClientTableMarkDisconnected) { TestClientTableMarkDisconnected(job_id_, client_); } -void TestHashTable(const JobID &job_id, std::shared_ptr client) { +void TestHashTable(const JobID &job_id, std::shared_ptr client) { const int expected_count = 14; ClientID client_id = ClientID::FromRandom(); // Prepare the first resource map: data_map1. @@ -1309,12 +1310,12 @@ void TestHashTable(const JobID &job_id, std::shared_ptr cli ASSERT_EQ(iter->second->resource_capacity(), data.second->resource_capacity()); } }; - auto subscribe_callback = [](AsyncGcsClient *client) { + auto subscribe_callback = [](RedisGcsClient *client) { ASSERT_TRUE(true); test->IncrementNumCallbacks(); }; auto notification_callback = [data_map1, data_map2, compare_test]( - AsyncGcsClient *client, const ClientID &id, + RedisGcsClient *client, const ClientID &id, const GcsChangeMode change_mode, const DynamicResourceTable::DataMap &data) { if (change_mode == GcsChangeMode::REMOVE) { @@ -1345,7 +1346,7 @@ void TestHashTable(const JobID &job_id, std::shared_ptr cli // Step 1: Add elements to the hash table. auto update_callback1 = [data_map1, compare_test]( - AsyncGcsClient *client, const ClientID &id, + RedisGcsClient *client, const ClientID &id, const DynamicResourceTable::DataMap &callback_data) { compare_test(data_map1, callback_data); test->IncrementNumCallbacks(); @@ -1353,7 +1354,7 @@ void TestHashTable(const JobID &job_id, std::shared_ptr cli RAY_CHECK_OK( client->resource_table().Update(job_id, client_id, data_map1, update_callback1)); auto lookup_callback1 = [data_map1, compare_test]( - AsyncGcsClient *client, const ClientID &id, + RedisGcsClient *client, const ClientID &id, const DynamicResourceTable::DataMap &callback_data) { compare_test(data_map1, callback_data); test->IncrementNumCallbacks(); @@ -1363,16 +1364,16 @@ void TestHashTable(const JobID &job_id, std::shared_ptr cli // Step 2: Decrease one element, increase one and add a new one. RAY_CHECK_OK(client->resource_table().Update(job_id, client_id, data_map2, nullptr)); auto lookup_callback2 = [data_map2, compare_test]( - AsyncGcsClient *client, const ClientID &id, + RedisGcsClient *client, const ClientID &id, const DynamicResourceTable::DataMap &callback_data) { compare_test(data_map2, callback_data); test->IncrementNumCallbacks(); }; RAY_CHECK_OK(client->resource_table().Lookup(job_id, client_id, lookup_callback2)); std::vector delete_keys({"GPU", "CUSTOM", "None-Existent"}); - auto remove_callback = [delete_keys](AsyncGcsClient *client, const ClientID &id, + auto remove_callback = [delete_keys](RedisGcsClient *client, const ClientID &id, const std::vector &callback_data) { - for (int i = 0; i < callback_data.size(); ++i) { + for (size_t i = 0; i < callback_data.size(); ++i) { // All deleting keys exist in this argument even if the key doesn't exist. ASSERT_EQ(callback_data[i], delete_keys[i]); } @@ -1384,7 +1385,7 @@ void TestHashTable(const JobID &job_id, std::shared_ptr cli data_map3.erase("GPU"); data_map3.erase("CUSTOM"); auto lookup_callback3 = [data_map3, compare_test]( - AsyncGcsClient *client, const ClientID &id, + RedisGcsClient *client, const ClientID &id, const DynamicResourceTable::DataMap &callback_data) { compare_test(data_map3, callback_data); test->IncrementNumCallbacks(); @@ -1395,7 +1396,7 @@ void TestHashTable(const JobID &job_id, std::shared_ptr cli RAY_CHECK_OK( client->resource_table().Update(job_id, client_id, data_map1, update_callback1)); auto lookup_callback4 = [data_map1, compare_test]( - AsyncGcsClient *client, const ClientID &id, + RedisGcsClient *client, const ClientID &id, const DynamicResourceTable::DataMap &callback_data) { compare_test(data_map1, callback_data); test->IncrementNumCallbacks(); @@ -1405,7 +1406,7 @@ void TestHashTable(const JobID &job_id, std::shared_ptr cli // Step 4: Removing all elements will remove the home Hash table from GCS. RAY_CHECK_OK(client->resource_table().RemoveEntries( job_id, client_id, {"GPU", "CPU", "CUSTOM", "None-Existent"}, nullptr)); - auto lookup_callback5 = [](AsyncGcsClient *client, const ClientID &id, + auto lookup_callback5 = [](RedisGcsClient *client, const ClientID &id, const DynamicResourceTable::DataMap &callback_data) { ASSERT_EQ(callback_data.size(), 0); test->IncrementNumCallbacks(); diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index 7a07cd879..b02ebb196 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -3,7 +3,7 @@ #include "ray/common/common_protocol.h" #include "ray/common/grpc_util.h" #include "ray/common/ray_config.h" -#include "ray/gcs/client.h" +#include "ray/gcs/redis_gcs_client.h" #include "ray/util/util.h" namespace { @@ -110,7 +110,7 @@ template Status Log::Subscribe(const JobID &job_id, const ClientID &client_id, const Callback &subscribe, const SubscriptionCallback &done) { - auto subscribe_wrapper = [subscribe](AsyncGcsClient *client, const ID &id, + auto subscribe_wrapper = [subscribe](RedisGcsClient *client, const ID &id, const GcsChangeMode change_mode, const std::vector &data) { RAY_CHECK(change_mode != GcsChangeMode::REMOVE); @@ -247,7 +247,7 @@ Status Table::Lookup(const JobID &job_id, const ID &id, const Callback const FailureCallback &failure) { num_lookups_++; return Log::Lookup(job_id, id, - [lookup, failure](AsyncGcsClient *client, const ID &id, + [lookup, failure](RedisGcsClient *client, const ID &id, const std::vector &data) { if (data.empty()) { if (failure != nullptr) { @@ -269,7 +269,7 @@ Status Table::Subscribe(const JobID &job_id, const ClientID &client_id const SubscriptionCallback &done) { return Log::Subscribe( job_id, client_id, - [subscribe, failure](AsyncGcsClient *client, const ID &id, + [subscribe, failure](RedisGcsClient *client, const ID &id, const std::vector &data) { RAY_CHECK(data.empty() || data.size() == 1); if (data.size() == 1) { @@ -511,7 +511,7 @@ void ClientTable::RegisterClientRemovedCallback(const ClientTableCallback &callb } } -void ClientTable::HandleNotification(AsyncGcsClient *client, +void ClientTable::HandleNotification(RedisGcsClient *client, const ClientTableData &data) { ClientID client_id = ClientID::FromBinary(data.client_id()); // It's possible to get duplicate notifications from the client table, so @@ -564,7 +564,7 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, } } -void ClientTable::HandleConnected(AsyncGcsClient *client, const ClientTableData &data) { +void ClientTable::HandleConnected(RedisGcsClient *client, const ClientTableData &data) { auto connected_client_id = ClientID::FromBinary(data.client_id()); RAY_CHECK(client_id_ == connected_client_id) << connected_client_id << " " << client_id_; @@ -589,14 +589,14 @@ Status ClientTable::Connect(const ClientTableData &local_client) { data->set_is_insertion(true); // Callback to handle our own successful connection once we've added // ourselves. - auto add_callback = [this](AsyncGcsClient *client, const UniqueID &log_key, + auto add_callback = [this](RedisGcsClient *client, const UniqueID &log_key, const ClientTableData &data) { RAY_CHECK(log_key == client_log_key_); HandleConnected(client, data); // Callback for a notification from the client table. auto notification_callback = [this]( - AsyncGcsClient *client, const UniqueID &log_key, + RedisGcsClient *client, const UniqueID &log_key, const std::vector ¬ifications) { RAY_CHECK(log_key == client_log_key_); std::unordered_map connected_nodes; @@ -623,7 +623,7 @@ Status ClientTable::Connect(const ClientTableData &local_client) { }; // Callback to request notifications from the client table once we've // successfully subscribed. - auto subscription_callback = [this](AsyncGcsClient *c) { + auto subscription_callback = [this](RedisGcsClient *c) { RAY_CHECK_OK(RequestNotifications(JobID::Nil(), client_log_key_, client_id_)); }; // Subscribe to the client table. @@ -636,7 +636,7 @@ Status ClientTable::Connect(const ClientTableData &local_client) { Status ClientTable::Disconnect(const DisconnectCallback &callback) { auto data = std::make_shared(local_client_); data->set_is_insertion(false); - auto add_callback = [this, callback](AsyncGcsClient *client, const ClientID &id, + auto add_callback = [this, callback](RedisGcsClient *client, const ClientID &id, const ClientTableData &data) { HandleConnected(client, data); RAY_CHECK_OK(CancelNotifications(JobID::Nil(), client_log_key_, id)); @@ -689,7 +689,7 @@ Status ActorCheckpointIdTable::AddCheckpointId(const JobID &job_id, const ActorID &actor_id, const ActorCheckpointID &checkpoint_id) { auto lookup_callback = [this, checkpoint_id, job_id, actor_id]( - ray::gcs::AsyncGcsClient *client, const UniqueID &id, + ray::gcs::RedisGcsClient *client, const UniqueID &id, const ActorCheckpointIdData &data) { std::shared_ptr copy = std::make_shared(data); @@ -707,7 +707,7 @@ Status ActorCheckpointIdTable::AddCheckpointId(const JobID &job_id, RAY_CHECK_OK(Add(job_id, actor_id, copy, nullptr)); }; auto failure_callback = [this, checkpoint_id, job_id, actor_id]( - ray::gcs::AsyncGcsClient *client, const UniqueID &id) { + ray::gcs::RedisGcsClient *client, const UniqueID &id) { std::shared_ptr data = std::make_shared(); data->set_actor_id(id.Binary()); diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h index 189555d2a..98241eb34 100644 --- a/src/ray/gcs/tables.h +++ b/src/ray/gcs/tables.h @@ -41,11 +41,11 @@ using rpc::TaskTableData; class RedisContext; -class AsyncGcsClient; +class RedisGcsClient; /// Specifies whether commands issued to a table should be regular or chain-replicated /// (when available). -enum class CommandType { kRegular, kChain }; +enum class CommandType { kRegular, kChain, kUnknown }; /// \class PubsubInterface /// @@ -66,7 +66,7 @@ template class LogInterface { public: using WriteCallback = - std::function; + std::function; virtual Status Append(const JobID &job_id, const ID &id, const std::shared_ptr &data, const WriteCallback &done) = 0; virtual Status AppendAt(const JobID &job_id, const ID &task_id, @@ -88,16 +88,16 @@ class LogInterface { template class Log : public LogInterface, virtual public PubsubInterface { public: - using Callback = std::function &data)>; using NotificationCallback = - std::function &data)>; /// The callback to call when a write to a key succeeds. using WriteCallback = typename LogInterface::WriteCallback; /// The callback to call when a SUBSCRIBE call completes and we are ready to /// request and receive notifications. - using SubscriptionCallback = std::function; + using SubscriptionCallback = std::function; struct CallbackData { ID id; @@ -107,10 +107,10 @@ class Log : public LogInterface, virtual public PubsubInterface { // first message is a notification of subscription success. SubscriptionCallback subscription_callback; Log *log; - AsyncGcsClient *client; + RedisGcsClient *client; }; - Log(const std::vector> &contexts, AsyncGcsClient *client) + Log(const std::vector> &contexts, RedisGcsClient *client) : shard_contexts_(contexts), client_(client), pubsub_channel_(TablePubsub::NO_PUBLISH), @@ -247,7 +247,7 @@ class Log : public LogInterface, virtual public PubsubInterface { /// The connection to the GCS. std::vector> shard_contexts_; /// The GCS client. - AsyncGcsClient *client_; + RedisGcsClient *client_; /// The pubsub channel to subscribe to for notifications about keys in this /// table. If no notifications are required, this should be set to /// TablePubsub_NO_PUBLISH. If notifications are required, then this must be @@ -292,16 +292,16 @@ class Table : private Log, virtual public PubsubInterface { public: using Callback = - std::function; + std::function; using WriteCallback = typename Log::WriteCallback; /// The callback to call when a Lookup call returns an empty entry. - using FailureCallback = std::function; + using FailureCallback = std::function; /// The callback to call when a Subscribe call completes and we are ready to /// request and receive notifications. using SubscriptionCallback = typename Log::SubscriptionCallback; Table(const std::vector> &contexts, - AsyncGcsClient *client) + RedisGcsClient *client) : Log(contexts, client) {} using Log::RequestNotifications; @@ -404,7 +404,7 @@ class Set : private Log, using NotificationCallback = typename Log::NotificationCallback; using SubscriptionCallback = typename Log::SubscriptionCallback; - Set(const std::vector> &contexts, AsyncGcsClient *client) + Set(const std::vector> &contexts, RedisGcsClient *client) : Log(contexts, client) {} using Log::RequestNotifications; @@ -471,7 +471,7 @@ class HashInterface { /// \param data Map data contains the change to the Hash Table. /// \return Void using HashCallback = - std::function; + std::function; /// The callback function used by function RemoveEntries. /// @@ -479,7 +479,7 @@ class HashInterface { /// \param id The ID of the Hash Table whose entries are removed. /// \param keys The keys that are moved from this Hash Table. /// \return Void - using HashRemoveCallback = std::function &keys)>; /// The notification function used by function Subscribe. @@ -489,7 +489,7 @@ class HashInterface { /// \param data Map data contains the change to the Hash Table. /// \return Void using HashNotificationCallback = - std::function; /// Add entries of a hash table. @@ -556,7 +556,7 @@ class Hash : private Log, typename HashInterface::HashNotificationCallback; using SubscriptionCallback = typename Log::SubscriptionCallback; - Hash(const std::vector> &contexts, AsyncGcsClient *client) + Hash(const std::vector> &contexts, RedisGcsClient *client) : Log(contexts, client) {} using Log::RequestNotifications; @@ -596,7 +596,7 @@ class Hash : private Log, class DynamicResourceTable : public Hash { public: DynamicResourceTable(const std::vector> &contexts, - AsyncGcsClient *client) + RedisGcsClient *client) : Hash(contexts, client) { pubsub_channel_ = TablePubsub::NODE_RESOURCE_PUBSUB; prefix_ = TablePrefix::NODE_RESOURCE; @@ -608,7 +608,7 @@ class DynamicResourceTable : public Hash { class ObjectTable : public Set { public: ObjectTable(const std::vector> &contexts, - AsyncGcsClient *client) + RedisGcsClient *client) : Set(contexts, client) { pubsub_channel_ = TablePubsub::OBJECT_PUBSUB; prefix_ = TablePrefix::OBJECT; @@ -620,7 +620,7 @@ class ObjectTable : public Set { class HeartbeatTable : public Table { public: HeartbeatTable(const std::vector> &contexts, - AsyncGcsClient *client) + RedisGcsClient *client) : Table(contexts, client) { pubsub_channel_ = TablePubsub::HEARTBEAT_PUBSUB; prefix_ = TablePrefix::HEARTBEAT; @@ -631,7 +631,7 @@ class HeartbeatTable : public Table { class HeartbeatBatchTable : public Table { public: HeartbeatBatchTable(const std::vector> &contexts, - AsyncGcsClient *client) + RedisGcsClient *client) : Table(contexts, client) { pubsub_channel_ = TablePubsub::HEARTBEAT_BATCH_PUBSUB; prefix_ = TablePrefix::HEARTBEAT_BATCH; @@ -642,7 +642,7 @@ class HeartbeatBatchTable : public Table { class JobTable : public Log { public: JobTable(const std::vector> &contexts, - AsyncGcsClient *client) + RedisGcsClient *client) : Log(contexts, client) { pubsub_channel_ = TablePubsub::JOB_PUBSUB; prefix_ = TablePrefix::JOB; @@ -670,7 +670,7 @@ class JobTable : public Log { class ActorTable : public Log { public: ActorTable(const std::vector> &contexts, - AsyncGcsClient *client) + RedisGcsClient *client) : Log(contexts, client) { pubsub_channel_ = TablePubsub::ACTOR_PUBSUB; prefix_ = TablePrefix::ACTOR; @@ -680,7 +680,7 @@ class ActorTable : public Log { class TaskReconstructionLog : public Log { public: TaskReconstructionLog(const std::vector> &contexts, - AsyncGcsClient *client) + RedisGcsClient *client) : Log(contexts, client) { prefix_ = TablePrefix::TASK_RECONSTRUCTION; } @@ -689,7 +689,7 @@ class TaskReconstructionLog : public Log { class TaskLeaseTable : public Table { public: TaskLeaseTable(const std::vector> &contexts, - AsyncGcsClient *client) + RedisGcsClient *client) : Table(contexts, client) { pubsub_channel_ = TablePubsub::TASK_LEASE_PUBSUB; prefix_ = TablePrefix::TASK_LEASE; @@ -715,7 +715,7 @@ class TaskLeaseTable : public Table { class ActorCheckpointTable : public Table { public: ActorCheckpointTable(const std::vector> &contexts, - AsyncGcsClient *client) + RedisGcsClient *client) : Table(contexts, client) { prefix_ = TablePrefix::ACTOR_CHECKPOINT; }; @@ -724,7 +724,7 @@ class ActorCheckpointTable : public Table { public: ActorCheckpointIdTable(const std::vector> &contexts, - AsyncGcsClient *client) + RedisGcsClient *client) : Table(contexts, client) { prefix_ = TablePrefix::ACTOR_CHECKPOINT_ID; }; @@ -745,14 +745,14 @@ namespace raylet { class TaskTable : public Table { public: TaskTable(const std::vector> &contexts, - AsyncGcsClient *client) + RedisGcsClient *client) : Table(contexts, client) { pubsub_channel_ = TablePubsub::RAYLET_TASK_PUBSUB; prefix_ = TablePrefix::RAYLET_TASK; } TaskTable(const std::vector> &contexts, - AsyncGcsClient *client, gcs::CommandType command_type) + RedisGcsClient *client, gcs::CommandType command_type) : TaskTable(contexts, client) { command_type_ = command_type; }; @@ -763,7 +763,7 @@ class TaskTable : public Table { class ErrorTable : private Log { public: ErrorTable(const std::vector> &contexts, - AsyncGcsClient *client) + RedisGcsClient *client) : Log(contexts, client) { pubsub_channel_ = TablePubsub::ERROR_INFO_PUBSUB; prefix_ = TablePrefix::ERROR_INFO; @@ -794,7 +794,7 @@ class ErrorTable : private Log { class ProfileTable : private Log { public: ProfileTable(const std::vector> &contexts, - AsyncGcsClient *client) + RedisGcsClient *client) : Log(contexts, client) { prefix_ = TablePrefix::PROFILE; }; @@ -823,10 +823,10 @@ class ProfileTable : private Log { class ClientTable : public Log { public: using ClientTableCallback = std::function; + RedisGcsClient *client, const ClientID &id, const ClientTableData &data)>; using DisconnectCallback = std::function; ClientTable(const std::vector> &contexts, - AsyncGcsClient *client, const ClientID &client_id) + RedisGcsClient *client, const ClientID &client_id) : Log(contexts, client), // We set the client log's key equal to nil so that all instances of // ClientTable have the same key. @@ -922,9 +922,9 @@ class ClientTable : public Log { private: /// Handle a client table notification. - void HandleNotification(AsyncGcsClient *client, const ClientTableData ¬ifications); + void HandleNotification(RedisGcsClient *client, const ClientTableData ¬ifications); /// Handle this client's successful connection to the GCS. - void HandleConnected(AsyncGcsClient *client, const ClientTableData &client_data); + void HandleConnected(RedisGcsClient *client, const ClientTableData &client_data); /// Whether this client has called Disconnect(). bool disconnected_; /// This client's ID. diff --git a/src/ray/object_manager/object_directory.cc b/src/ray/object_manager/object_directory.cc index 8d28c0a84..d4175bdb0 100644 --- a/src/ray/object_manager/object_directory.cc +++ b/src/ray/object_manager/object_directory.cc @@ -3,7 +3,7 @@ namespace ray { ObjectDirectory::ObjectDirectory(boost::asio::io_service &io_service, - std::shared_ptr &gcs_client) + std::shared_ptr &gcs_client) : io_service_(io_service), gcs_client_(gcs_client) {} namespace { @@ -44,7 +44,7 @@ void UpdateObjectLocations(const GcsChangeMode change_mode, void ObjectDirectory::RegisterBackend() { auto object_notification_callback = - [this](gcs::AsyncGcsClient *client, const ObjectID &object_id, + [this](gcs::RedisGcsClient *client, const ObjectID &object_id, const GcsChangeMode change_mode, const std::vector &location_updates) { // Objects are added to this map in SubscribeObjectLocations. @@ -211,7 +211,7 @@ ray::Status ObjectDirectory::LookupLocations(const ObjectID &object_id, // directly from the GCS. status = gcs_client_->object_table().Lookup( JobID::Nil(), object_id, - [this, callback](gcs::AsyncGcsClient *client, const ObjectID &object_id, + [this, callback](gcs::RedisGcsClient *client, const ObjectID &object_id, const std::vector &location_updates) { // Build the set of current locations based on the entries in the log. std::unordered_set client_ids; diff --git a/src/ray/object_manager/object_directory.h b/src/ray/object_manager/object_directory.h index be21e5280..8ad72db17 100644 --- a/src/ray/object_manager/object_directory.h +++ b/src/ray/object_manager/object_directory.h @@ -11,7 +11,7 @@ #include "ray/common/id.h" #include "ray/common/status.h" -#include "ray/gcs/client.h" +#include "ray/gcs/redis_gcs_client.h" #include "ray/object_manager/format/object_manager_generated.h" namespace ray { @@ -136,7 +136,7 @@ class ObjectDirectory : public ObjectDirectoryInterface { /// \param gcs_client A Ray GCS client to request object and client /// information from. ObjectDirectory(boost::asio::io_service &io_service, - std::shared_ptr &gcs_client); + std::shared_ptr &gcs_client); virtual ~ObjectDirectory() {} @@ -189,7 +189,7 @@ class ObjectDirectory : public ObjectDirectoryInterface { /// Reference to the event loop. boost::asio::io_service &io_service_; /// Reference to the gcs client. - std::shared_ptr gcs_client_; + std::shared_ptr gcs_client_; /// Info about subscribers to object locations. std::unordered_map listeners_; }; diff --git a/src/ray/object_manager/test/object_manager_stress_test.cc b/src/ray/object_manager/test/object_manager_stress_test.cc index 3b142e55c..fa217ab7b 100644 --- a/src/ray/object_manager/test/object_manager_stress_test.cc +++ b/src/ray/object_manager/test/object_manager_stress_test.cc @@ -32,7 +32,7 @@ class MockServer { public: MockServer(boost::asio::io_service &main_service, const ObjectManagerConfig &object_manager_config, - std::shared_ptr gcs_client) + std::shared_ptr gcs_client) : config_(object_manager_config), gcs_client_(gcs_client), object_manager_(main_service, object_manager_config, @@ -44,8 +44,6 @@ class MockServer { private: ray::Status RegisterGcs(boost::asio::io_service &io_service) { - RAY_RETURN_NOT_OK(gcs_client_->Attach(io_service)); - auto object_manager_port = config_.object_manager_port; ClientTableData client_info = gcs_client_->client_table().GetLocalClient(); client_info.set_node_manager_address("127.0.0.1"); @@ -60,7 +58,7 @@ class MockServer { friend class StressTestObjectManager; ObjectManagerConfig config_; - std::shared_ptr gcs_client_; + std::shared_ptr gcs_client_; ObjectManager object_manager_; }; @@ -102,8 +100,11 @@ class TestObjectManagerBase : public ::testing::Test { int push_timeout_ms = 10000; // start first server - gcs_client_1 = std::shared_ptr( - new gcs::AsyncGcsClient("127.0.0.1", 6379, /*is_test_client=*/true)); + gcs::GcsClientOptions client_options("127.0.0.1", 6379, /*password*/ "", + /*is_test_client=*/true); + gcs_client_1 = + std::shared_ptr(new gcs::RedisGcsClient(client_options)); + RAY_CHECK_OK(gcs_client_1->Connect(main_service)); ObjectManagerConfig om_config_1; om_config_1.store_socket_name = store_id_1; om_config_1.pull_timeout_ms = pull_timeout_ms; @@ -114,8 +115,9 @@ class TestObjectManagerBase : public ::testing::Test { server1.reset(new MockServer(main_service, om_config_1, gcs_client_1)); // start second server - gcs_client_2 = std::shared_ptr( - new gcs::AsyncGcsClient("127.0.0.1", 6379, /*is_test_client=*/true)); + gcs_client_2 = + std::shared_ptr(new gcs::RedisGcsClient(client_options)); + RAY_CHECK_OK(gcs_client_2->Connect(main_service)); ObjectManagerConfig om_config_2; om_config_2.store_socket_name = store_id_2; om_config_2.pull_timeout_ms = pull_timeout_ms; @@ -135,6 +137,9 @@ class TestObjectManagerBase : public ::testing::Test { arrow::Status client2_status = client2.Disconnect(); ASSERT_TRUE(client1_status.ok() && client2_status.ok()); + gcs_client_1->Disconnect(); + gcs_client_2->Disconnect(); + this->server1.reset(); this->server2.reset(); @@ -161,8 +166,8 @@ class TestObjectManagerBase : public ::testing::Test { protected: std::thread p; boost::asio::io_service main_service; - std::shared_ptr gcs_client_1; - std::shared_ptr gcs_client_2; + std::shared_ptr gcs_client_1; + std::shared_ptr gcs_client_2; std::unique_ptr server1; std::unique_ptr server2; @@ -210,7 +215,7 @@ class StressTestObjectManager : public TestObjectManagerBase { client_id_1 = gcs_client_1->client_table().GetLocalClientId(); client_id_2 = gcs_client_2->client_table().GetLocalClientId(); gcs_client_1->client_table().RegisterClientAddedCallback( - [this](gcs::AsyncGcsClient *client, const ClientID &id, + [this](gcs::RedisGcsClient *client, const ClientID &id, const ClientTableData &data) { ClientID parsed_id = ClientID::FromBinary(data.client_id()); if (parsed_id == client_id_1 || parsed_id == client_id_2) { diff --git a/src/ray/object_manager/test/object_manager_test.cc b/src/ray/object_manager/test/object_manager_test.cc index 0e0af1ad2..8d8941e9f 100644 --- a/src/ray/object_manager/test/object_manager_test.cc +++ b/src/ray/object_manager/test/object_manager_test.cc @@ -26,7 +26,7 @@ class MockServer { public: MockServer(boost::asio::io_service &main_service, const ObjectManagerConfig &object_manager_config, - std::shared_ptr gcs_client) + std::shared_ptr gcs_client) : config_(object_manager_config), gcs_client_(gcs_client), object_manager_(main_service, object_manager_config, @@ -38,8 +38,6 @@ class MockServer { private: ray::Status RegisterGcs(boost::asio::io_service &io_service) { - RAY_RETURN_NOT_OK(gcs_client_->Attach(io_service)); - auto object_manager_port = config_.object_manager_port; ClientTableData client_info = gcs_client_->client_table().GetLocalClient(); client_info.set_node_manager_address("127.0.0.1"); @@ -54,7 +52,7 @@ class MockServer { friend class TestObjectManager; ObjectManagerConfig config_; - std::shared_ptr gcs_client_; + std::shared_ptr gcs_client_; ObjectManager object_manager_; }; @@ -94,8 +92,11 @@ class TestObjectManagerBase : public ::testing::Test { push_timeout_ms = 1000; // start first server - gcs_client_1 = std::shared_ptr( - new gcs::AsyncGcsClient("127.0.0.1", 6379, /*is_test_client=*/true)); + gcs::GcsClientOptions client_options("127.0.0.1", 6379, /*password*/ "", + /*is_test_client=*/true); + gcs_client_1 = + std::shared_ptr(new gcs::RedisGcsClient(client_options)); + RAY_CHECK_OK(gcs_client_1->Connect(main_service)); ObjectManagerConfig om_config_1; om_config_1.store_socket_name = store_id_1; om_config_1.pull_timeout_ms = pull_timeout_ms; @@ -106,8 +107,9 @@ class TestObjectManagerBase : public ::testing::Test { server1.reset(new MockServer(main_service, om_config_1, gcs_client_1)); // start second server - gcs_client_2 = std::shared_ptr( - new gcs::AsyncGcsClient("127.0.0.1", 6379, /*is_test_client=*/true)); + gcs_client_2 = + std::shared_ptr(new gcs::RedisGcsClient(client_options)); + RAY_CHECK_OK(gcs_client_2->Connect(main_service)); ObjectManagerConfig om_config_2; om_config_2.store_socket_name = store_id_2; om_config_2.pull_timeout_ms = pull_timeout_ms; @@ -127,6 +129,9 @@ class TestObjectManagerBase : public ::testing::Test { arrow::Status client2_status = client2.Disconnect(); ASSERT_TRUE(client1_status.ok() && client2_status.ok()); + gcs_client_1->Disconnect(); + gcs_client_2->Disconnect(); + this->server1.reset(); this->server2.reset(); @@ -157,8 +162,8 @@ class TestObjectManagerBase : public ::testing::Test { protected: std::thread p; boost::asio::io_service main_service; - std::shared_ptr gcs_client_1; - std::shared_ptr gcs_client_2; + std::shared_ptr gcs_client_1; + std::shared_ptr gcs_client_2; std::unique_ptr server1; std::unique_ptr server2; @@ -191,7 +196,7 @@ class TestObjectManager : public TestObjectManagerBase { client_id_1 = gcs_client_1->client_table().GetLocalClientId(); client_id_2 = gcs_client_2->client_table().GetLocalClientId(); gcs_client_1->client_table().RegisterClientAddedCallback( - [this](gcs::AsyncGcsClient *client, const ClientID &id, + [this](gcs::RedisGcsClient *client, const ClientID &id, const ClientTableData &data) { ClientID parsed_id = ClientID::FromBinary(data.client_id()); if (parsed_id == client_id_1 || parsed_id == client_id_2) { diff --git a/src/ray/raylet/lineage_cache.cc b/src/ray/raylet/lineage_cache.cc index 986f735f7..1bdd7d4a2 100644 --- a/src/ray/raylet/lineage_cache.cc +++ b/src/ray/raylet/lineage_cache.cc @@ -269,7 +269,7 @@ void LineageCache::FlushTask(const TaskID &task_id) { RAY_CHECK(entry->GetStatus() < GcsStatus::COMMITTING); gcs::raylet::TaskTable::WriteCallback task_callback = - [this](ray::gcs::AsyncGcsClient *client, const TaskID &id, + [this](ray::gcs::RedisGcsClient *client, const TaskID &id, const TaskTableData &data) { HandleEntryCommitted(id); }; auto task = lineage_.GetEntry(task_id); auto task_data = std::make_shared(); diff --git a/src/ray/raylet/lineage_cache_test.cc b/src/ray/raylet/lineage_cache_test.cc index 19ac1918c..48b850d96 100644 --- a/src/ray/raylet/lineage_cache_test.cc +++ b/src/ray/raylet/lineage_cache_test.cc @@ -31,7 +31,7 @@ class MockGcs : public gcs::TableInterface, // If we requested notifications for this task ID, send the notification as // part of the callback. if (subscribed_tasks_.count(task_id) == 1) { - callback = [this, done](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, + callback = [this, done](ray::gcs::RedisGcsClient *client, const TaskID &task_id, const TaskTableData &data) { done(client, task_id, data); // If we're subscribed to the task to be added, also send a @@ -51,7 +51,7 @@ class MockGcs : public gcs::TableInterface, // Send a notification after the add if the lineage cache requested // notifications for this key. bool send_notification = (subscribed_tasks_.count(task_id) == 1); - auto callback = [this, send_notification](ray::gcs::AsyncGcsClient *client, + auto callback = [this, send_notification](ray::gcs::RedisGcsClient *client, const TaskID &task_id, const TaskTableData &data) { if (send_notification) { @@ -111,7 +111,7 @@ class LineageCacheTest : public ::testing::Test { num_notifications_(0), mock_gcs_(), lineage_cache_(ClientID::FromRandom(), mock_gcs_, mock_gcs_, max_lineage_size_) { - mock_gcs_.Subscribe([this](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, + mock_gcs_.Subscribe([this](ray::gcs::RedisGcsClient *client, const TaskID &task_id, const TaskTableData &data) { lineage_cache_.HandleEntryCommitted(task_id); num_notifications_++; diff --git a/src/ray/raylet/main.cc b/src/ray/raylet/main.cc index f8c82c739..4be890228 100644 --- a/src/ray/raylet/main.cc +++ b/src/ray/raylet/main.cc @@ -1,5 +1,6 @@ #include +#include "ray/common/id.h" #include "ray/common/ray_config.h" #include "ray/common/status.h" #include "ray/common/task/task_common.h" @@ -159,11 +160,10 @@ int main(int argc, char *argv[]) { // Initialize the node manager. boost::asio::io_service main_service; - // initialize mock gcs & object directory - auto gcs_client = std::make_shared(redis_address, redis_port, - redis_password); - RAY_LOG(DEBUG) << "Initializing GCS client " - << gcs_client->client_table().GetLocalClientId(); + // Initialize gcs client + ray::gcs::GcsClientOptions client_options(redis_address, redis_port, redis_password); + auto gcs_client = std::make_shared(client_options); + RAY_CHECK_OK(gcs_client->Connect(main_service)); std::unique_ptr server(new ray::raylet::Raylet( main_service, raylet_socket_name, node_ip_address, redis_address, redis_port, @@ -175,8 +175,9 @@ int main(int argc, char *argv[]) { // We should stop the service and remove the local socket file. auto handler = [&main_service, &raylet_socket_name, &server, &gcs_client]( const boost::system::error_code &error, int signal_number) { - auto shutdown_callback = [&server, &main_service]() { + auto shutdown_callback = [&server, &main_service, &gcs_client]() { server.reset(); + gcs_client->Disconnect(); main_service.stop(); }; RAY_CHECK_OK(gcs_client->client_table().Disconnect(shutdown_callback)); diff --git a/src/ray/raylet/monitor.cc b/src/ray/raylet/monitor.cc index 39bc3f9fb..708dfc039 100644 --- a/src/ray/raylet/monitor.cc +++ b/src/ray/raylet/monitor.cc @@ -17,10 +17,10 @@ namespace raylet { /// the client table, which broadcasts the event to all other Raylets. Monitor::Monitor(boost::asio::io_service &io_service, const std::string &redis_address, int redis_port, const std::string &redis_password) - : gcs_client_(redis_address, redis_port, redis_password), + : gcs_client_(gcs::GcsClientOptions(redis_address, redis_port, redis_password)), num_heartbeats_timeout_(RayConfig::instance().num_heartbeats_timeout()), heartbeat_timer_(io_service) { - RAY_CHECK_OK(gcs_client_.Attach(io_service)); + RAY_CHECK_OK(gcs_client_.Connect(io_service)); } void Monitor::HandleHeartbeat(const ClientID &client_id, @@ -30,7 +30,7 @@ void Monitor::HandleHeartbeat(const ClientID &client_id, } void Monitor::Start() { - const auto heartbeat_callback = [this](gcs::AsyncGcsClient *client, const ClientID &id, + const auto heartbeat_callback = [this](gcs::RedisGcsClient *client, const ClientID &id, const HeartbeatTableData &heartbeat_data) { HandleHeartbeat(id, heartbeat_data); }; @@ -48,7 +48,7 @@ void Monitor::Tick() { auto client_id = it->first; RAY_LOG(WARNING) << "Client timed out: " << client_id; auto lookup_callback = [this, client_id]( - gcs::AsyncGcsClient *client, const ClientID &id, + gcs::RedisGcsClient *client, const ClientID &id, const std::vector &all_data) { bool marked = false; for (const auto &data : all_data) { diff --git a/src/ray/raylet/monitor.h b/src/ray/raylet/monitor.h index 5725e52cf..8c5a8d150 100644 --- a/src/ray/raylet/monitor.h +++ b/src/ray/raylet/monitor.h @@ -5,7 +5,7 @@ #include #include "ray/common/id.h" -#include "ray/gcs/client.h" +#include "ray/gcs/redis_gcs_client.h" namespace ray { @@ -43,7 +43,7 @@ class Monitor { private: /// A client to the GCS, through which heartbeats are received. - gcs::AsyncGcsClient gcs_client_; + gcs::RedisGcsClient gcs_client_; /// The number of heartbeats that can be missed before a client is removed. int64_t num_heartbeats_timeout_; /// A timer that ticks every heartbeat_timeout_ms_ milliseconds. diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index b56029d28..e023d5a93 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -68,7 +68,7 @@ namespace raylet { NodeManager::NodeManager(boost::asio::io_service &io_service, const NodeManagerConfig &config, ObjectManager &object_manager, - std::shared_ptr gcs_client, + std::shared_ptr gcs_client, std::shared_ptr object_directory) : client_id_(gcs_client->client_table().GetLocalClientId()), io_service_(io_service), @@ -130,7 +130,7 @@ ray::Status NodeManager::RegisterGcs() { // Subscribe to task entry commits in the GCS. These notifications are // forwarded to the lineage cache, which requests notifications about tasks // that were executed remotely. - const auto task_committed_callback = [this](gcs::AsyncGcsClient *client, + const auto task_committed_callback = [this](gcs::RedisGcsClient *client, const TaskID &task_id, const TaskTableData &task_data) { lineage_cache_.HandleEntryCommitted(task_id); @@ -139,7 +139,7 @@ ray::Status NodeManager::RegisterGcs() { JobID::Nil(), gcs_client_->client_table().GetLocalClientId(), task_committed_callback, nullptr, nullptr)); - const auto task_lease_notification_callback = [this](gcs::AsyncGcsClient *client, + const auto task_lease_notification_callback = [this](gcs::RedisGcsClient *client, const TaskID &task_id, const TaskLeaseData &task_lease) { const ClientID node_manager_id = ClientID::FromBinary(task_lease.node_manager_id()); @@ -155,7 +155,7 @@ ray::Status NodeManager::RegisterGcs() { reconstruction_policy_.HandleTaskLeaseNotification(task_id, task_lease.timeout()); } }; - const auto task_lease_empty_callback = [this](gcs::AsyncGcsClient *client, + const auto task_lease_empty_callback = [this](gcs::RedisGcsClient *client, const TaskID &task_id) { reconstruction_policy_.HandleTaskLeaseNotification(task_id, 0); }; @@ -164,35 +164,30 @@ ray::Status NodeManager::RegisterGcs() { task_lease_notification_callback, task_lease_empty_callback, nullptr)); // Register a callback to handle actor notifications. - auto actor_notification_callback = [this](gcs::AsyncGcsClient *client, - const ActorID &actor_id, - const std::vector &data) { - if (!data.empty()) { - // We only need the last entry, because it represents the latest state of - // this actor. - HandleActorStateTransition(actor_id, ActorRegistration(data.back())); - } + auto actor_notification_callback = [this](const ActorID &actor_id, + const ActorTableData &data) { + HandleActorStateTransition(actor_id, ActorRegistration(data)); }; - RAY_RETURN_NOT_OK(gcs_client_->actor_table().Subscribe( - JobID::Nil(), ClientID::Nil(), actor_notification_callback, nullptr)); + RAY_RETURN_NOT_OK( + gcs_client_->Actors().AsyncSubscribe(actor_notification_callback, nullptr)); // Register a callback on the client table for new clients. - auto node_manager_client_added = [this](gcs::AsyncGcsClient *client, const UniqueID &id, + auto node_manager_client_added = [this](gcs::RedisGcsClient *client, const UniqueID &id, const ClientTableData &data) { ClientAdded(data); }; gcs_client_->client_table().RegisterClientAddedCallback(node_manager_client_added); // Register a callback on the client table for removed clients. auto node_manager_client_removed = - [this](gcs::AsyncGcsClient *client, const UniqueID &id, + [this](gcs::RedisGcsClient *client, const UniqueID &id, const ClientTableData &data) { ClientRemoved(data); }; gcs_client_->client_table().RegisterClientRemovedCallback(node_manager_client_removed); // Subscribe to resource changes. const auto &resources_changed = [this]( - gcs::AsyncGcsClient *client, const ClientID &id, + gcs::RedisGcsClient *client, const ClientID &id, const gcs::GcsChangeMode change_mode, const std::unordered_map> &data) { @@ -219,7 +214,7 @@ ray::Status NodeManager::RegisterGcs() { // Subscribe to heartbeat batches from the monitor. const auto &heartbeat_batch_added = - [this](gcs::AsyncGcsClient *client, const ClientID &id, + [this](gcs::RedisGcsClient *client, const ClientID &id, const HeartbeatBatchTableData &heartbeat_batch) { HeartbeatBatchAdded(heartbeat_batch); }; @@ -229,7 +224,7 @@ ray::Status NodeManager::RegisterGcs() { /*done_callback=*/nullptr)); // Subscribe to driver table updates. - const auto job_table_handler = [this](gcs::AsyncGcsClient *client, const JobID &job_id, + const auto job_table_handler = [this](gcs::RedisGcsClient *client, const JobID &job_id, const std::vector &job_data) { HandleJobTableUpdate(job_id, job_data); }; @@ -395,7 +390,7 @@ void NodeManager::ClientAdded(const ClientTableData &client_data) { // Fetch resource info for the remote client and update cluster resource map. RAY_CHECK_OK(gcs_client_->resource_table().Lookup( JobID::Nil(), client_id, - [this](gcs::AsyncGcsClient *client, const ClientID &client_id, + [this](gcs::RedisGcsClient *client, const ClientID &client_id, const std::unordered_map> &pairs) { ResourceSet resource_set; @@ -591,38 +586,6 @@ void NodeManager::HeartbeatBatchAdded(const HeartbeatBatchTableData &heartbeat_b } } -void NodeManager::PublishActorStateTransition( - const ActorID &actor_id, const ActorTableData &data, - const ray::gcs::ActorTable::WriteCallback &failure_callback) { - // Copy the actor notification data. - auto actor_notification = std::make_shared(data); - - // The actor log starts with an ALIVE entry. This is followed by 0 to N pairs - // of (RECONSTRUCTING, ALIVE) entries, where N is the maximum number of - // reconstructions. This is followed optionally by a DEAD entry. - int log_length = 2 * (actor_notification->max_reconstructions() - - actor_notification->remaining_reconstructions()); - if (actor_notification->state() != ActorTableData::ALIVE) { - // RECONSTRUCTING or DEAD entries have an odd index. - log_length += 1; - } - // If we successful appended a record to the GCS table of the actor that - // has died, signal this to anyone receiving signals from this actor. - auto success_callback = [](gcs::AsyncGcsClient *client, const ActorID &id, - const ActorTableData &data) { - auto redis_context = client->primary_context(); - if (data.state() == ActorTableData::DEAD || - data.state() == ActorTableData::RECONSTRUCTING) { - std::vector args = {"XADD", id.Hex(), "*", "signal", - "ACTOR_DIED_SIGNAL"}; - RAY_CHECK_OK(redis_context->RunArgvAsync(args)); - } - }; - RAY_CHECK_OK(gcs_client_->actor_table().AppendAt(JobID::Nil(), actor_id, - actor_notification, success_callback, - failure_callback, log_length)); -} - void NodeManager::HandleActorStateTransition(const ActorID &actor_id, ActorRegistration &&actor_registration) { // Update local registry. @@ -924,16 +887,16 @@ void NodeManager::HandleDisconnectedActor(const ActorID &actor_id, bool was_loca // instead of being assigned to the dead actor. HandleActorStateTransition(actor_id, ActorRegistration(new_actor_data)); } - ray::gcs::ActorTable::WriteCallback failure_callback = nullptr; - if (was_local) { - failure_callback = [](gcs::AsyncGcsClient *client, const ActorID &id, - const ActorTableData &data) { + + auto done = [was_local, actor_id](Status status) { + if (was_local && !status.ok()) { // If the disconnected actor was local, only this node will try to update actor // state. So the update shouldn't fail. - RAY_LOG(FATAL) << "Failed to update state for actor " << id; - }; - } - PublishActorStateTransition(actor_id, new_actor_data, failure_callback); + RAY_LOG(FATAL) << "Failed to update state for actor " << actor_id; + } + }; + auto actor_notification = std::make_shared(new_actor_data); + RAY_CHECK_OK(gcs_client_->Actors().AsyncUpdate(actor_id, actor_notification, done)); } void NodeManager::HandleWorkerAvailable( @@ -1206,7 +1169,7 @@ void NodeManager::ProcessPrepareActorCheckpointRequest( // Write checkpoint data to GCS. RAY_CHECK_OK(gcs_client_->actor_checkpoint_table().Add( JobID::Nil(), checkpoint_id, checkpoint_data, - [worker, actor_id, this](ray::gcs::AsyncGcsClient *client, + [worker, actor_id, this](ray::gcs::RedisGcsClient *client, const ActorCheckpointID &checkpoint_id, const ActorCheckpointData &data) { RAY_LOG(DEBUG) << "Checkpoint " << checkpoint_id << " saved for actor " @@ -1562,17 +1525,16 @@ void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineag // actor creation because this node joined the cluster after the actor // was already created. Look up the actor's registered location in case // we missed the creation notification. - auto lookup_callback = [this](gcs::AsyncGcsClient *client, - const ActorID &actor_id, - const std::vector &data) { + const ActorID &actor_id = spec.ActorId(); + auto lookup_callback = [this, actor_id](Status status, + const std::vector &data) { if (!data.empty()) { // The actor has been created. We only need the last entry, because // it represents the latest state of this actor. HandleActorStateTransition(actor_id, ActorRegistration(data.back())); } }; - RAY_CHECK_OK(gcs_client_->actor_table().Lookup(JobID::Nil(), spec.ActorId(), - lookup_callback)); + RAY_CHECK_OK(gcs_client_->Actors().AsyncGet(actor_id, lookup_callback)); actor_creation_dummy_object = spec.ActorCreationDummyObjectId(); } else { actor_creation_dummy_object = actor_entry->second.GetActorCreationDependency(); @@ -1911,7 +1873,7 @@ void NodeManager::FinishAssignedActorTask(Worker &worker, const Task &task) { JobID::Nil(), parent_task_id, /*success_callback=*/ [this, task_spec, resumed_from_checkpoint]( - ray::gcs::AsyncGcsClient *client, const TaskID &parent_task_id, + ray::gcs::RedisGcsClient *client, const TaskID &parent_task_id, const TaskTableData &parent_task_data) { // The task was in the GCS task table. Use the stored task spec to // get the parent actor id. @@ -1926,7 +1888,7 @@ void NodeManager::FinishAssignedActorTask(Worker &worker, const Task &task) { resumed_from_checkpoint); }, /*failure_callback=*/ - [this, task_spec, resumed_from_checkpoint](ray::gcs::AsyncGcsClient *client, + [this, task_spec, resumed_from_checkpoint](ray::gcs::RedisGcsClient *client, const TaskID &parent_task_id) { // The parent task was not in the GCS task table. It should most likely be in // the @@ -1990,6 +1952,13 @@ void NodeManager::FinishAssignedActorCreationTask(const ActorID &parent_actor_id const ActorID actor_id = task_spec.ActorCreationId(); auto new_actor_data = CreateActorTableDataFromCreationTask(task_spec); new_actor_data.set_parent_actor_id(parent_actor_id.Binary()); + auto update_callback = [actor_id](Status status) { + if (!status.ok()) { + // Only one node at a time should succeed at creating or updating the actor. + RAY_LOG(FATAL) << "Failed to update state to ALIVE for actor " << actor_id; + } + }; + if (resumed_from_checkpoint) { // This actor was resumed from a checkpoint. In this case, we first look // up the checkpoint in GCS and use it to restore the actor registration @@ -2000,9 +1969,9 @@ void NodeManager::FinishAssignedActorCreationTask(const ActorID &parent_actor_id << actor_id; RAY_CHECK_OK(gcs_client_->actor_checkpoint_table().Lookup( JobID::Nil(), checkpoint_id, - [this, actor_id, new_actor_data](ray::gcs::AsyncGcsClient *client, - const UniqueID &checkpoint_id, - const ActorCheckpointData &checkpoint_data) { + [this, actor_id, new_actor_data, update_callback]( + ray::gcs::RedisGcsClient *client, const UniqueID &checkpoint_id, + const ActorCheckpointData &checkpoint_data) { RAY_LOG(INFO) << "Restoring registration for actor " << actor_id << " from checkpoint " << checkpoint_id; ActorRegistration actor_registration = @@ -2012,16 +1981,12 @@ void NodeManager::FinishAssignedActorCreationTask(const ActorID &parent_actor_id HandleObjectLocal(entry.first); } HandleActorStateTransition(actor_id, std::move(actor_registration)); - PublishActorStateTransition( - actor_id, new_actor_data, - /*failure_callback=*/ - [](gcs::AsyncGcsClient *client, const ActorID &id, - const ActorTableData &data) { - // Only one node at a time should succeed at creating the actor. - RAY_LOG(FATAL) << "Failed to update state to ALIVE for actor " << id; - }); + auto actor_notification = std::make_shared(new_actor_data); + // The actor was created before. + RAY_CHECK_OK(gcs_client_->Actors().AsyncUpdate(actor_id, actor_notification, + update_callback)); }, - [actor_id](ray::gcs::AsyncGcsClient *client, const UniqueID &checkpoint_id) { + [actor_id](ray::gcs::RedisGcsClient *client, const UniqueID &checkpoint_id) { RAY_LOG(FATAL) << "Couldn't find checkpoint " << checkpoint_id << " for actor " << actor_id << " in GCS."; })); @@ -2029,13 +1994,16 @@ void NodeManager::FinishAssignedActorCreationTask(const ActorID &parent_actor_id // The actor did not resume from a checkpoint. Immediately notify the // other node managers that the actor has been created. HandleActorStateTransition(actor_id, ActorRegistration(new_actor_data)); - PublishActorStateTransition( - actor_id, new_actor_data, - /*failure_callback=*/ - [](gcs::AsyncGcsClient *client, const ActorID &id, const ActorTableData &data) { - // Only one node at a time should succeed at creating the actor. - RAY_LOG(FATAL) << "Failed to update state to ALIVE for actor " << id; - }); + auto actor_notification = std::make_shared(new_actor_data); + if (actor_registry_.find(actor_id) != actor_registry_.end()) { + // The actor was created before. + RAY_CHECK_OK(gcs_client_->Actors().AsyncUpdate(actor_id, actor_notification, + update_callback)); + } else { + // The actor was never created before. + RAY_CHECK_OK( + gcs_client_->Actors().AsyncRegister(actor_notification, update_callback)); + } } if (!resumed_from_checkpoint) { // The actor was not resumed from a checkpoint. We extend the actor's @@ -2049,14 +2017,14 @@ void NodeManager::HandleTaskReconstruction(const TaskID &task_id) { RAY_CHECK_OK(gcs_client_->raylet_task_table().Lookup( JobID::Nil(), task_id, /*success_callback=*/ - [this](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, + [this](ray::gcs::RedisGcsClient *client, const TaskID &task_id, const TaskTableData &task_data) { // The task was in the GCS task table. Use the stored task spec to // re-execute the task. ResubmitTask(Task(task_data.task())); }, /*failure_callback=*/ - [this](ray::gcs::AsyncGcsClient *client, const TaskID &task_id) { + [this](ray::gcs::RedisGcsClient *client, const TaskID &task_id) { // The task was not in the GCS task table. It must therefore be in the // lineage cache. RAY_CHECK(lineage_cache_.ContainsTask(task_id)) diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index 9fbb4a9bc..84b7b00e7 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -72,7 +72,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// \param object_manager A reference to the local object manager. NodeManager(boost::asio::io_service &io_service, const NodeManagerConfig &config, ObjectManager &object_manager, - std::shared_ptr gcs_client, + std::shared_ptr gcs_client, std::shared_ptr object_directory_); /// Process a new client connection. @@ -343,16 +343,6 @@ class NodeManager : public rpc::NodeManagerServiceHandler { void HandleActorStateTransition(const ActorID &actor_id, ActorRegistration &&actor_registration); - /// Publish an actor's state transition to all other nodes. - /// - /// \param actor_id The actor ID of the actor whose state was updated. - /// \param data Data to publish. - /// \param failure_callback An optional callback to call if the publish is - /// unsuccessful. - void PublishActorStateTransition( - const ActorID &actor_id, const ActorTableData &data, - const ray::gcs::ActorTable::WriteCallback &failure_callback); - /// When a job finished, loop over all of the queued tasks for that job and /// treat them as failed. /// @@ -501,7 +491,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// because the actor died). plasma::PlasmaClient store_client_; /// A client connection to the GCS. - std::shared_ptr gcs_client_; + std::shared_ptr gcs_client_; /// The object table. This is shared with the object manager. std::shared_ptr object_directory_; /// The timer used to send heartbeats. diff --git a/src/ray/raylet/object_manager_integration_test.cc b/src/ray/raylet/object_manager_integration_test.cc index 0f411e8c5..331ad027d 100644 --- a/src/ray/raylet/object_manager_integration_test.cc +++ b/src/ray/raylet/object_manager_integration_test.cc @@ -56,8 +56,9 @@ class TestObjectManagerBase : public ::testing::Test { std::string store_sock_2 = StartStore("2"); // start first server - gcs_client_1 = std::shared_ptr( - new gcs::AsyncGcsClient("127.0.0.1", 6379, /*is_test_client=*/true)); + gcs::GcsClientOptions client_options("127.0.0.1", 6379, /*password*/ "", true); + gcs_client_1 = + std::shared_ptr(new gcs::RedisGcsClient(client_options)); ObjectManagerConfig om_config_1; om_config_1.store_socket_name = store_sock_1; om_config_1.push_timeout_ms = 10000; @@ -66,8 +67,8 @@ class TestObjectManagerBase : public ::testing::Test { GetNodeManagerConfig("raylet_1", store_sock_1), om_config_1, gcs_client_1)); // start second server - gcs_client_2 = std::shared_ptr( - new gcs::AsyncGcsClient("127.0.0.1", 6379, /*is_test_client=*/true)); + gcs_client_2 = + std::shared_ptr(new gcs::RedisGcsClient(client_options)); ObjectManagerConfig om_config_2; om_config_2.store_socket_name = store_sock_2; om_config_2.push_timeout_ms = 10000; @@ -113,8 +114,8 @@ class TestObjectManagerBase : public ::testing::Test { protected: std::thread p; boost::asio::io_service main_service; - std::shared_ptr gcs_client_1; - std::shared_ptr gcs_client_2; + std::shared_ptr gcs_client_1; + std::shared_ptr gcs_client_2; std::unique_ptr server1; std::unique_ptr server2; @@ -137,7 +138,7 @@ class TestObjectManagerIntegration : public TestObjectManagerBase { client_id_1 = gcs_client_1->client_table().GetLocalClientId(); client_id_2 = gcs_client_2->client_table().GetLocalClientId(); gcs_client_1->client_table().RegisterClientAddedCallback( - [this](gcs::AsyncGcsClient *client, const ClientID &id, + [this](gcs::RedisGcsClient *client, const ClientID &id, const ClientTableDataT &data) { ClientID parsed_id = ClientID::FromBinary(data.client_id); if (parsed_id == client_id_1 || parsed_id == client_id_2) { diff --git a/src/ray/raylet/raylet.cc b/src/ray/raylet/raylet.cc index 0544a5674..72d9ab799 100644 --- a/src/ray/raylet/raylet.cc +++ b/src/ray/raylet/raylet.cc @@ -44,7 +44,7 @@ Raylet::Raylet(boost::asio::io_service &main_service, const std::string &socket_ int redis_port, const std::string &redis_password, const NodeManagerConfig &node_manager_config, const ObjectManagerConfig &object_manager_config, - std::shared_ptr gcs_client) + std::shared_ptr gcs_client) : gcs_client_(gcs_client), object_directory_(std::make_shared(main_service, gcs_client_)), object_manager_(main_service, object_manager_config, object_directory_), @@ -78,8 +78,6 @@ ray::Status Raylet::RegisterGcs(const std::string &node_ip_address, const std::string &redis_password, boost::asio::io_service &io_service, const NodeManagerConfig &node_manager_config) { - RAY_RETURN_NOT_OK(gcs_client_->Attach(io_service)); - ClientTableData client_info = gcs_client_->client_table().GetLocalClient(); client_info.set_node_manager_address(node_ip_address); client_info.set_raylet_socket_name(raylet_socket_name); diff --git a/src/ray/raylet/raylet.h b/src/ray/raylet/raylet.h index d0ae57c7f..39e226a77 100644 --- a/src/ray/raylet/raylet.h +++ b/src/ray/raylet/raylet.h @@ -41,7 +41,7 @@ class Raylet { int redis_port, const std::string &redis_password, const NodeManagerConfig &node_manager_config, const ObjectManagerConfig &object_manager_config, - std::shared_ptr gcs_client); + std::shared_ptr gcs_client); /// Destroy the NodeServer. ~Raylet(); @@ -64,7 +64,7 @@ class Raylet { friend class TestObjectManagerIntegration; /// A client connection to the GCS. - std::shared_ptr gcs_client_; + std::shared_ptr gcs_client_; /// The object table. This is shared between the object manager and node /// manager. std::shared_ptr object_directory_; diff --git a/src/ray/raylet/reconstruction_policy.cc b/src/ray/raylet/reconstruction_policy.cc index f522c8986..a51cded17 100644 --- a/src/ray/raylet/reconstruction_policy.cc +++ b/src/ray/raylet/reconstruction_policy.cc @@ -112,12 +112,12 @@ void ReconstructionPolicy::AttemptReconstruction(const TaskID &task_id, RAY_CHECK_OK(task_reconstruction_log_.AppendAt( JobID::Nil(), task_id, reconstruction_entry, /*success_callback=*/ - [this](gcs::AsyncGcsClient *client, const TaskID &task_id, + [this](gcs::RedisGcsClient *client, const TaskID &task_id, const TaskReconstructionData &data) { HandleReconstructionLogAppend(task_id, /*success=*/true); }, /*failure_callback=*/ - [this](gcs::AsyncGcsClient *client, const TaskID &task_id, + [this](gcs::RedisGcsClient *client, const TaskID &task_id, const TaskReconstructionData &data) { HandleReconstructionLogAppend(task_id, /*success=*/false); }, diff --git a/src/ray/raylet/reconstruction_policy_test.cc b/src/ray/raylet/reconstruction_policy_test.cc index 94155b442..02554399e 100644 --- a/src/ray/raylet/reconstruction_policy_test.cc +++ b/src/ray/raylet/reconstruction_policy_test.cc @@ -160,12 +160,12 @@ class ReconstructionPolicyTest : public ::testing::Test { mock_object_directory_, mock_gcs_)), timer_canceled_(false) { mock_gcs_.Subscribe( - [this](gcs::AsyncGcsClient *client, const TaskID &task_id, + [this](gcs::RedisGcsClient *client, const TaskID &task_id, const TaskLeaseData &task_lease) { reconstruction_policy_->HandleTaskLeaseNotification(task_id, task_lease.timeout()); }, - [this](gcs::AsyncGcsClient *client, const TaskID &task_id) { + [this](gcs::RedisGcsClient *client, const TaskID &task_id) { reconstruction_policy_->HandleTaskLeaseNotification(task_id, 0); }); } @@ -401,7 +401,7 @@ TEST_F(ReconstructionPolicyTest, TestSimultaneousReconstructionSuppressed) { RAY_CHECK_OK( mock_gcs_.AppendAt(JobID::Nil(), task_id, task_reconstruction_data, nullptr, /*failure_callback=*/ - [](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, + [](ray::gcs::RedisGcsClient *client, const TaskID &task_id, const TaskReconstructionData &data) { ASSERT_TRUE(false); }, /*log_index=*/0)); diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc index d9cb37843..ee2f10ea0 100644 --- a/src/ray/raylet/worker_pool.cc +++ b/src/ray/raylet/worker_pool.cc @@ -43,7 +43,7 @@ namespace raylet { /// (num_worker_processes * num_workers_per_process) workers for each language. WorkerPool::WorkerPool(int num_worker_processes, int num_workers_per_process, int maximum_startup_concurrency, - std::shared_ptr gcs_client, + std::shared_ptr gcs_client, const WorkerCommandMap &worker_commands) : num_workers_per_process_(num_workers_per_process), multiple_for_warning_(std::max(num_worker_processes, maximum_startup_concurrency)), diff --git a/src/ray/raylet/worker_pool.h b/src/ray/raylet/worker_pool.h index a243d53a7..3221c5cf0 100644 --- a/src/ray/raylet/worker_pool.h +++ b/src/ray/raylet/worker_pool.h @@ -9,7 +9,7 @@ #include "ray/common/client_connection.h" #include "ray/common/task/task.h" #include "ray/common/task/task_common.h" -#include "ray/gcs/client.h" +#include "ray/gcs/redis_gcs_client.h" #include "ray/raylet/worker.h" namespace ray { @@ -41,7 +41,7 @@ class WorkerPool { /// language. WorkerPool(int num_worker_processes, int num_workers_per_process, int maximum_startup_concurrency, - std::shared_ptr gcs_client, + std::shared_ptr gcs_client, const WorkerCommandMap &worker_commands); /// Destructor responsible for freeing a set of workers owned by this class. @@ -198,7 +198,7 @@ class WorkerPool { /// was generated. int64_t last_warning_multiple_; /// A client connection to the GCS. - std::shared_ptr gcs_client_; + std::shared_ptr gcs_client_; }; } // namespace raylet diff --git a/src/ray/test/run_gcs_tests.sh b/src/ray/test/run_gcs_tests.sh index 2aaf88da7..4c780a1e3 100644 --- a/src/ray/test/run_gcs_tests.sh +++ b/src/ray/test/run_gcs_tests.sh @@ -6,7 +6,7 @@ set -e set -x -bazel build "//:gcs_client_test" "//:asio_test" "//:libray_redis_module.so" +bazel build "//:redis_gcs_client_test" "//:actor_state_accessor_test" "//:asio_test" "//:libray_redis_module.so" # Start Redis. if [[ "${RAY_USE_NEW_GCS}" = "on" ]]; then @@ -23,7 +23,8 @@ else fi sleep 1s -./bazel-bin/gcs_client_test +./bazel-bin/redis_gcs_client_test +./bazel-bin/actor_state_accessor_test ./bazel-bin/asio_test ./bazel-genfiles/redis-cli -p 6379 shutdown