Support for subscription to an actor (#5269)

micafan authored on 2019-08-20 20:32:53 +08:00, committed by Hao Chen
parent 851c5b2dae
commit da7bdacea5
20 changed files with 655 additions and 130 deletions

View file

@ -640,6 +640,7 @@ cc_library(
],
)
# TODO(micafan) Replace cc_binary with cc_test for GCS test.
cc_binary(
name = "redis_gcs_client_test",
testonly = 1,
@ -662,6 +663,17 @@ cc_binary(
],
)
cc_binary(
name = "subscription_executor_test",
testonly = 1,
srcs = ["src/ray/gcs/subscription_executor_test.cc"],
copts = COPTS,
deps = [
":gcs",
"@com_google_googletest//:gtest_main",
],
)
cc_binary(
name = "asio_test",
testonly = 1,

View file

@ -0,0 +1,76 @@
#include <atomic>
#include <chrono>
#include <string>
#include <thread>
#include <vector>
#include "gtest/gtest.h"
#include "ray/gcs/redis_gcs_client.h"
#include "ray/util/test_util.h"
namespace ray {
namespace gcs {
template <typename ID, typename Data>
class AccessorTestBase : public ::testing::Test {
public:
AccessorTestBase() : options_("127.0.0.1", 6379, "", true) {}
virtual ~AccessorTestBase() {}
virtual void SetUp() {
GenTestData();
gcs_client_.reset(new RedisGcsClient(options_));
RAY_CHECK_OK(gcs_client_->Connect(io_service_));
work_thread.reset(new std::thread([this] {
std::unique_ptr<boost::asio::io_service::work> work(
new boost::asio::io_service::work(io_service_));
io_service_.run();
}));
}
virtual void TearDown() {
gcs_client_->Disconnect();
io_service_.stop();
work_thread->join();
work_thread.reset();
gcs_client_.reset();
ClearTestData();
}
protected:
virtual void GenTestData() = 0;
void ClearTestData() { id_to_data_.clear(); }
void WaitPendingDone(std::chrono::milliseconds timeout) {
WaitPendingDone(pending_count_, timeout);
}
void WaitPendingDone(std::atomic<int> &pending_count,
std::chrono::milliseconds timeout) {
auto condition = [&pending_count]() { return pending_count == 0; };
EXPECT_TRUE(WaitForCondition(condition, timeout.count()));
}
protected:
GcsClientOptions options_;
std::unique_ptr<RedisGcsClient> gcs_client_;
boost::asio::io_service io_service_;
std::unique_ptr<std::thread> work_thread;
std::unordered_map<ID, std::shared_ptr<Data>> id_to_data_;
std::atomic<int> pending_count_{0};
std::chrono::milliseconds wait_pending_timeout_{10000};
};
} // namespace gcs
} // namespace ray
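For reference, a concrete fixture built on this base only has to implement GenTestData(); client setup, the io_service worker thread, and the pending-count helpers are inherited. A minimal sketch reusing the actor types from this change (the fixture name is illustrative; the real ActorStateAccessorTest later in this diff follows the same pattern):

class ExampleAccessorTest : public AccessorTestBase<ActorID, ActorTableData> {
 protected:
  // Fill id_to_data_ with a few actors; the base class clears it in TearDown().
  virtual void GenTestData() {
    for (size_t i = 0; i < 3; ++i) {
      auto actor = std::make_shared<ActorTableData>();
      JobID job_id = JobID::FromInt(i);
      actor->set_job_id(job_id.Binary());
      actor->set_state(ActorTableData::ALIVE);
      ActorID actor_id = ActorID::Of(job_id, RandomTaskId(), /*parent_task_counter=*/i);
      actor->set_actor_id(actor_id.Binary());
      id_to_data_[actor_id] = actor;
    }
  }
};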

View file

@ -8,7 +8,7 @@ namespace ray {
namespace gcs {
ActorStateAccessor::ActorStateAccessor(RedisGcsClient &client_impl)
: client_impl_(client_impl) {}
: client_impl_(client_impl), actor_sub_executor_(client_impl_.actor_table()) {}
Status ActorStateAccessor::AsyncGet(const ActorID &actor_id,
const MultiItemCallback<ActorTableData> &callback) {
@ -90,23 +90,19 @@ Status ActorStateAccessor::AsyncSubscribe(
const SubscribeCallback<ActorID, ActorTableData> &subscribe,
const StatusCallback &done) {
RAY_CHECK(subscribe != nullptr);
auto on_subscribe = [subscribe](RedisGcsClient *client, const ActorID &actor_id,
const std::vector<ActorTableData> &data) {
if (!data.empty()) {
// We only need the last entry, because it represents the latest state of
// this actor.
subscribe(actor_id, data.back());
}
};
return actor_sub_executor_.AsyncSubscribe(ClientID::Nil(), subscribe, done);
}
auto on_done = [done](RedisGcsClient *client) {
if (done != nullptr) {
done(Status::OK());
}
};
Status ActorStateAccessor::AsyncSubscribe(
const ActorID &actor_id, const SubscribeCallback<ActorID, ActorTableData> &subscribe,
const StatusCallback &done) {
RAY_CHECK(subscribe != nullptr);
return actor_sub_executor_.AsyncSubscribe(node_id_, actor_id, subscribe, done);
}
ActorTable &actor_table = client_impl_.actor_table();
return actor_table.Subscribe(JobID::Nil(), ClientID::Nil(), on_subscribe, on_done);
Status ActorStateAccessor::AsyncUnsubscribe(const ActorID &actor_id,
const StatusCallback &done) {
return actor_sub_executor_.AsyncUnsubscribe(node_id_, actor_id, done);
}
} // namespace gcs

View file

@ -3,6 +3,7 @@
#include "ray/common/id.h"
#include "ray/gcs/callback.h"
#include "ray/gcs/subscription_executor.h"
#include "ray/gcs/tables.h"
namespace ray {
@ -50,7 +51,7 @@ class ActorStateAccessor {
const std::shared_ptr<ActorTableData> &data_ptr,
const StatusCallback &callback);
/// Subscribe to any register operations of actors.
/// Subscribe to any register or update operations of actors.
///
/// \param subscribe Callback that will be called each time when an actor is registered
/// or updated.
@ -60,8 +61,36 @@ class ActorStateAccessor {
Status AsyncSubscribe(const SubscribeCallback<ActorID, ActorTableData> &subscribe,
const StatusCallback &done);
/// Subscribe to any update operations of an actor.
///
/// \param actor_id The ID of the actor to subscribe to.
/// \param subscribe Callback that will be called each time when the actor is updated.
/// \param done Callback that will be called when subscription is complete.
/// \return Status
Status AsyncSubscribe(const ActorID &actor_id,
const SubscribeCallback<ActorID, ActorTableData> &subscribe,
const StatusCallback &done);
/// Cancel subscription to an actor.
///
/// \param actor_id The ID of the actor to be unsubscribed from.
/// \param done Callback that will be called when unsubscribe is complete.
/// \return Status
Status AsyncUnsubscribe(const ActorID &actor_id, const StatusCallback &done);
private:
RedisGcsClient &client_impl_;
// Use a random ClientID for actor subscription, for two reasons:
// 1. If we used ClientID::Nil, GCS would still push every actor's updates to this GCS
//    client. Even though the irrelevant updates could be filtered out, that adds
//    unnecessary overhead.
// 2. The new GCS client no longer holds the local ClientID, so a random ClientID is
//    used instead.
// TODO(micafan): Remove this random id once GCS becomes a service.
ClientID node_id_{ClientID::FromRandom()};
typedef SubscriptionExecutor<ActorID, ActorTableData, ActorTable>
ActorSubscriptionExecutor;
ActorSubscriptionExecutor actor_sub_executor_;
};
} // namespace gcs
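With the new overloads in place, a caller can subscribe to a single actor instead of the whole actor table. A minimal usage sketch, assuming a connected RedisGcsClient named client (the callback bodies are illustrative):

void SubscribeToOneActor(gcs::RedisGcsClient &client, const ActorID &actor_id) {
  gcs::ActorStateAccessor &actors = client.Actors();
  RAY_CHECK_OK(actors.AsyncSubscribe(
      actor_id,
      [](const ActorID &id, const ActorTableData &data) {
        // Invoked on every update of this actor with its latest table entry.
        RAY_LOG(INFO) << "Actor " << id << " is now in state " << data.state();
      },
      [](Status status) {
        // Invoked once the subscription request has completed.
        RAY_CHECK_OK(status);
      }));
  // Later, actors.AsyncUnsubscribe(actor_id, /*done=*/nullptr) cancels the subscription.
}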

View file

@ -4,6 +4,7 @@
#include <thread>
#include <vector>
#include "gtest/gtest.h"
#include "ray/gcs/accessor_test_base.h"
#include "ray/gcs/redis_gcs_client.h"
#include "ray/util/test_util.h"
@ -11,39 +12,9 @@ namespace ray {
namespace gcs {
class ActorStateAccessorTest : public ::testing::Test {
public:
ActorStateAccessorTest() : options_("127.0.0.1", 6379, "", true) {}
virtual void SetUp() {
GenTestData();
gcs_client_.reset(new RedisGcsClient(options_));
RAY_CHECK_OK(gcs_client_->Connect(io_service_));
work_thread.reset(new std::thread([this] {
std::unique_ptr<boost::asio::io_service::work> work(
new boost::asio::io_service::work(io_service_));
io_service_.run();
}));
}
virtual void TearDown() {
gcs_client_->Disconnect();
io_service_.stop();
work_thread->join();
work_thread.reset();
gcs_client_.reset();
ClearTestData();
}
class ActorStateAccessorTest : public AccessorTestBase<ActorID, ActorTableData> {
protected:
void GenTestData() { GenActorData(); }
void GenActorData() {
virtual void GenTestData() {
for (size_t i = 0; i < 100; ++i) {
std::shared_ptr<ActorTableData> actor = std::make_shared<ActorTableData>();
actor->set_max_reconstructions(1);
@ -53,42 +24,15 @@ class ActorStateAccessorTest : public ::testing::Test {
actor->set_state(ActorTableData::ALIVE);
ActorID actor_id = ActorID::Of(job_id, RandomTaskId(), /*parent_task_counter=*/i);
actor->set_actor_id(actor_id.Binary());
actor_datas_[actor_id] = actor;
id_to_data_[actor_id] = actor;
}
}
void ClearTestData() { actor_datas_.clear(); }
void WaitPendingDone(std::chrono::milliseconds timeout) {
WaitPendingDone(pending_count_, timeout);
}
void WaitPendingDone(std::atomic<int> &pending_count,
std::chrono::milliseconds timeout) {
while (pending_count != 0 && timeout.count() > 0) {
std::chrono::milliseconds interval(10);
std::this_thread::sleep_for(interval);
timeout -= interval;
}
EXPECT_EQ(pending_count, 0);
}
protected:
GcsClientOptions options_;
std::unique_ptr<RedisGcsClient> gcs_client_;
boost::asio::io_service io_service_;
std::unique_ptr<std::thread> work_thread;
std::unordered_map<ActorID, std::shared_ptr<ActorTableData>> actor_datas_;
std::atomic<int> pending_count_{0};
};
TEST_F(ActorStateAccessorTest, RegisterAndGet) {
ActorStateAccessor &actor_accessor = gcs_client_->Actors();
// register
for (const auto &elem : actor_datas_) {
for (const auto &elem : id_to_data_) {
const auto &actor = elem.second;
++pending_count_;
RAY_CHECK_OK(actor_accessor.AsyncRegister(actor, [this](Status status) {
@ -97,35 +41,33 @@ TEST_F(ActorStateAccessorTest, RegisterAndGet) {
}));
}
std::chrono::milliseconds timeout(10000);
WaitPendingDone(timeout);
WaitPendingDone(wait_pending_timeout_);
// get
for (const auto &elem : actor_datas_) {
for (const auto &elem : id_to_data_) {
++pending_count_;
RAY_CHECK_OK(actor_accessor.AsyncGet(
elem.first, [this](Status status, std::vector<ActorTableData> datas) {
ASSERT_EQ(datas.size(), 1U);
ActorID actor_id = ActorID::FromBinary(datas[0].actor_id());
auto it = actor_datas_.find(actor_id);
ASSERT_TRUE(it != actor_datas_.end());
auto it = id_to_data_.find(actor_id);
ASSERT_TRUE(it != id_to_data_.end());
--pending_count_;
}));
}
WaitPendingDone(timeout);
WaitPendingDone(wait_pending_timeout_);
}
TEST_F(ActorStateAccessorTest, Subscribe) {
ActorStateAccessor &actor_accessor = gcs_client_->Actors();
std::chrono::milliseconds timeout(10000);
// subscribe
std::atomic<int> sub_pending_count(0);
std::atomic<int> do_sub_pending_count(0);
auto subscribe = [this, &sub_pending_count](const ActorID &actor_id,
const ActorTableData &data) {
const auto it = actor_datas_.find(actor_id);
ASSERT_TRUE(it != actor_datas_.end());
const auto it = id_to_data_.find(actor_id);
ASSERT_TRUE(it != id_to_data_.end());
--sub_pending_count;
};
auto done = [&do_sub_pending_count](Status status) {
@ -136,11 +78,11 @@ TEST_F(ActorStateAccessorTest, Subscribe) {
++do_sub_pending_count;
RAY_CHECK_OK(actor_accessor.AsyncSubscribe(subscribe, done));
// Wait until subscribe finishes.
WaitPendingDone(do_sub_pending_count, timeout);
WaitPendingDone(do_sub_pending_count, wait_pending_timeout_);
// register
std::atomic<int> register_pending_count(0);
for (const auto &elem : actor_datas_) {
for (const auto &elem : id_to_data_) {
const auto &actor = elem.second;
++sub_pending_count;
++register_pending_count;
@ -151,10 +93,10 @@ TEST_F(ActorStateAccessorTest, Subscribe) {
}));
}
// Wait until register finishes.
WaitPendingDone(register_pending_count, timeout);
WaitPendingDone(register_pending_count, wait_pending_timeout_);
// Wait for all subscribe notifications.
WaitPendingDone(sub_pending_count, timeout);
WaitPendingDone(sub_pending_count, wait_pending_timeout_);
}
} // namespace gcs

View file

@ -95,7 +95,7 @@ class GcsClientInterface : public std::enable_shared_from_this<GcsClientInterfac
GcsClientOptions options_;
// Whether this client is connected to GCS.
/// Whether this client is connected to GCS.
bool is_connected_{false};
std::unique_ptr<ActorStateAccessor> actor_accessor_;

View file

@ -19,6 +19,7 @@ class RedisContext;
class RAY_EXPORT RedisGcsClient : public GcsClientInterface {
friend class ActorStateAccessor;
friend class SubscriptionExecutorTest;
public:
/// Constructor of RedisGcsClient.

View file

@ -741,7 +741,7 @@ void TestTableSubscribeId(const JobID &job_id,
num_modifications](gcs::RedisGcsClient *client) {
// Request notifications for one of the keys.
RAY_CHECK_OK(client->raylet_task_table().RequestNotifications(
job_id, task_id2, client->client_table().GetLocalClientId()));
job_id, task_id2, client->client_table().GetLocalClientId(), nullptr));
// Write both keys. We should only receive notifications for the key that
// we requested them for.
for (uint64_t i = 0; i < num_modifications; i++) {
@ -814,7 +814,7 @@ void TestLogSubscribeId(const JobID &job_id,
job_ids2](gcs::RedisGcsClient *client) {
// Request notifications for one of the keys.
RAY_CHECK_OK(client->job_table().RequestNotifications(
job_id, job_id2, client->client_table().GetLocalClientId()));
job_id, job_id2, client->client_table().GetLocalClientId(), nullptr));
// Write both keys. We should only receive notifications for the key that
// we requested them for.
auto remaining = std::vector<std::string>(++job_ids1.begin(), job_ids1.end());
@ -890,7 +890,7 @@ void TestSetSubscribeId(const JobID &job_id,
managers2](gcs::RedisGcsClient *client) {
// Request notifications for one of the keys.
RAY_CHECK_OK(client->object_table().RequestNotifications(
job_id, object_id2, client->client_table().GetLocalClientId()));
job_id, object_id2, client->client_table().GetLocalClientId(), nullptr));
// Write both keys. We should only receive notifications for the key that
// we requested them for.
auto remaining = std::vector<std::string>(++managers1.begin(), managers1.end());
@ -964,9 +964,9 @@ void TestTableSubscribeCancel(const JobID &job_id,
// Request notifications, then cancel immediately. We should receive a
// notification for the current value at the key.
RAY_CHECK_OK(client->raylet_task_table().RequestNotifications(
job_id, task_id, client->client_table().GetLocalClientId()));
job_id, task_id, client->client_table().GetLocalClientId(), nullptr));
RAY_CHECK_OK(client->raylet_task_table().CancelNotifications(
job_id, task_id, client->client_table().GetLocalClientId()));
job_id, task_id, client->client_table().GetLocalClientId(), nullptr));
// Write to the key. Since we canceled notifications, we should not receive
// a notification for these writes.
for (uint64_t i = 1; i < num_modifications; i++) {
@ -976,7 +976,7 @@ void TestTableSubscribeCancel(const JobID &job_id,
// Request notifications again. We should receive a notification for the
// current value at the key.
RAY_CHECK_OK(client->raylet_task_table().RequestNotifications(
job_id, task_id, client->client_table().GetLocalClientId()));
job_id, task_id, client->client_table().GetLocalClientId(), nullptr));
};
// Subscribe to notifications for this client. This allows us to request and
@ -1033,9 +1033,9 @@ void TestLogSubscribeCancel(const JobID &job_id,
// Request notifications, then cancel immediately. We should receive a
// notification for the current value at the key.
RAY_CHECK_OK(client->job_table().RequestNotifications(
job_id, random_job_id, client->client_table().GetLocalClientId()));
job_id, random_job_id, client->client_table().GetLocalClientId(), nullptr));
RAY_CHECK_OK(client->job_table().CancelNotifications(
job_id, random_job_id, client->client_table().GetLocalClientId()));
job_id, random_job_id, client->client_table().GetLocalClientId(), nullptr));
// Append to the key. Since we canceled notifications, we should not
// receive a notification for these writes.
auto remaining = std::vector<std::string>(++job_ids.begin(), job_ids.end());
@ -1047,7 +1047,7 @@ void TestLogSubscribeCancel(const JobID &job_id,
// Request notifications again. We should receive a notification for the
// current values at the key.
RAY_CHECK_OK(client->job_table().RequestNotifications(
job_id, random_job_id, client->client_table().GetLocalClientId()));
job_id, random_job_id, client->client_table().GetLocalClientId(), nullptr));
};
// Subscribe to notifications for this client. This allows us to request and
@ -1115,9 +1115,9 @@ void TestSetSubscribeCancel(const JobID &job_id,
// Request notifications, then cancel immediately. We should receive a
// notification for the current value at the key.
RAY_CHECK_OK(client->object_table().RequestNotifications(
job_id, object_id, client->client_table().GetLocalClientId()));
job_id, object_id, client->client_table().GetLocalClientId(), nullptr));
RAY_CHECK_OK(client->object_table().CancelNotifications(
job_id, object_id, client->client_table().GetLocalClientId()));
job_id, object_id, client->client_table().GetLocalClientId(), nullptr));
// Add to the key. Since we canceled notifications, we should not
// receive a notification for these writes.
auto remaining = std::vector<std::string>(++managers.begin(), managers.end());
@ -1129,7 +1129,7 @@ void TestSetSubscribeCancel(const JobID &job_id,
// Request notifications again. We should receive a notification for the
// current values at the key.
RAY_CHECK_OK(client->object_table().RequestNotifications(
job_id, object_id, client->client_table().GetLocalClientId()));
job_id, object_id, client->client_table().GetLocalClientId(), nullptr));
};
// Subscribe to notifications for this client. This allows us to request and
@ -1342,7 +1342,7 @@ void TestHashTable(const JobID &job_id, std::shared_ptr<gcs::RedisGcsClient> cli
RAY_CHECK_OK(client->resource_table().Subscribe(
job_id, ClientID::Nil(), notification_callback, subscribe_callback));
RAY_CHECK_OK(client->resource_table().RequestNotifications(
job_id, client_id, client->client_table().GetLocalClientId()));
job_id, client_id, client->client_table().GetLocalClientId(), nullptr));
// Step 1: Add elements to the hash table.
auto update_callback1 = [data_map1, compare_test](

View file

@ -0,0 +1,139 @@
#include "ray/gcs/subscription_executor.h"
namespace ray {
namespace gcs {
template <typename ID, typename Data, typename Table>
Status SubscriptionExecutor<ID, Data, Table>::AsyncSubscribe(
const ClientID &client_id, const SubscribeCallback<ID, Data> &subscribe,
const StatusCallback &done) {
// TODO(micafan) Optimize the lock when necessary.
// Consider avoiding locking in single-threaded processes.
std::lock_guard<std::mutex> lock(mutex_);
if (subscribe_all_callback_ != nullptr) {
RAY_LOG(DEBUG) << "Duplicate subscription! Already subscribed to all elements.";
return Status::Invalid("Duplicate subscription!");
}
if (registered_) {
if (subscribe != nullptr) {
RAY_LOG(DEBUG) << "Duplicate subscription! Already subscribed to specific elements"
", can't subscribe to all elements.";
return Status::Invalid("Duplicate subscription!");
}
return Status::OK();
}
auto on_subscribe = [this](RedisGcsClient *client, const ID &id,
const std::vector<Data> &result) {
if (result.empty()) {
return;
}
RAY_LOG(DEBUG) << "Subscribe received update of id " << id;
SubscribeCallback<ID, Data> sub_one_callback = nullptr;
SubscribeCallback<ID, Data> sub_all_callback = nullptr;
{
std::lock_guard<std::mutex> lock(mutex_);
const auto it = id_to_callback_map_.find(id);
if (it != id_to_callback_map_.end()) {
sub_one_callback = it->second;
}
sub_all_callback = subscribe_all_callback_;
}
if (sub_one_callback != nullptr) {
sub_one_callback(id, result.back());
}
if (sub_all_callback != nullptr) {
RAY_CHECK(sub_one_callback == nullptr);
sub_all_callback(id, result.back());
}
};
auto on_done = [done](RedisGcsClient *client) {
if (done != nullptr) {
done(Status::OK());
}
};
Status status = table_.Subscribe(JobID::Nil(), client_id, on_subscribe, on_done);
if (status.ok()) {
registered_ = true;
subscribe_all_callback_ = subscribe;
}
return status;
}
template <typename ID, typename Data, typename Table>
Status SubscriptionExecutor<ID, Data, Table>::AsyncSubscribe(
const ClientID &client_id, const ID &id, const SubscribeCallback<ID, Data> &subscribe,
const StatusCallback &done) {
Status status = AsyncSubscribe(client_id, nullptr, nullptr);
if (!status.ok()) {
return status;
}
auto on_done = [this, done, id](Status status) {
if (!status.ok()) {
std::lock_guard<std::mutex> lock(mutex_);
id_to_callback_map_.erase(id);
}
if (done != nullptr) {
done(status);
}
};
{
std::lock_guard<std::mutex> lock(mutex_);
const auto it = id_to_callback_map_.find(id);
if (it != id_to_callback_map_.end()) {
RAY_LOG(DEBUG) << "Duplicate subscription to id " << id << " client_id "
<< client_id;
return Status::Invalid("Duplicate subscription to element!");
}
status = table_.RequestNotifications(JobID::Nil(), id, client_id, on_done);
if (status.ok()) {
id_to_callback_map_[id] = subscribe;
}
}
return status;
}
template <typename ID, typename Data, typename Table>
Status SubscriptionExecutor<ID, Data, Table>::AsyncUnsubscribe(
const ClientID &client_id, const ID &id, const StatusCallback &done) {
{
std::lock_guard<std::mutex> lock(mutex_);
const auto it = id_to_callback_map_.find(id);
if (it == id_to_callback_map_.end()) {
RAY_LOG(DEBUG) << "Invalid Unsubscribe! id " << id << " client_id " << client_id;
return Status::Invalid("Invalid Unsubscribe, no existing subscription found.");
}
}
auto on_done = [this, id, done](Status status) {
if (status.ok()) {
std::lock_guard<std::mutex> lock(mutex_);
const auto it = id_to_callback_map_.find(id);
if (it != id_to_callback_map_.end()) {
id_to_callback_map_.erase(it);
}
}
if (done != nullptr) {
done(status);
}
};
return table_.CancelNotifications(JobID::Nil(), id, client_id, on_done);
}
template class SubscriptionExecutor<ActorID, ActorTableData, ActorTable>;
} // namespace gcs
} // namespace ray
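Because the template's member functions are defined in this .cc file rather than in the header, each table type that uses the executor needs an explicit instantiation like the one above. Supporting another table would mean adding a similar line; for example (the types below are hypothetical and not part of this change):

// Hypothetical: instantiate the executor for the raylet task table as well.
// template class SubscriptionExecutor<TaskID, TaskTableData, raylet::TaskTable>;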

View file

@ -0,0 +1,85 @@
#ifndef RAY_GCS_SUBSCRIPTION_EXECUTOR_H
#define RAY_GCS_SUBSCRIPTION_EXECUTOR_H
#include <atomic>
#include <mutex>
#include "ray/gcs/callback.h"
#include "ray/gcs/tables.h"
namespace ray {
namespace gcs {
/// \class SubscriptionExecutor
/// SubscriptionExecutor encapsulates the implementation details of subscribing to
/// and unsubscribing from elements (e.g. actors, tasks, objects, or nodes).
/// It supports subscribing either to a specific element or to all elements.
template <typename ID, typename Data, typename Table>
class SubscriptionExecutor {
public:
SubscriptionExecutor(Table &table) : table_(table) {}
~SubscriptionExecutor() {}
/// Subscribe to operations of all elements.
/// A repeated subscription will return a failure.
///
/// \param client_id The type of update to listen to. If this is nil, then a
/// message for each update will be received. Else, only
/// messages for the given client will be received.
/// \param subscribe Callback that will be called each time when an element
/// is registered or updated.
/// \param done Callback that will be called when subscription is complete.
/// \return Status
Status AsyncSubscribe(const ClientID &client_id,
const SubscribeCallback<ID, Data> &subscribe,
const StatusCallback &done);
/// Subscribe to operations of an element.
/// A repeated subscription to the same element will return a failure.
///
/// \param client_id The type of update to listen to. If this is nil, then a
/// message for each update will be received. Else, only
/// messages for the given client will be received.
/// \param id The ID of the element to subscribe to.
/// \param subscribe Callback that will be called each time when the element
/// is registered or updated.
/// \param done Callback that will be called when subscription is complete.
/// \return Status
Status AsyncSubscribe(const ClientID &client_id, const ID &id,
const SubscribeCallback<ID, Data> &subscribe,
const StatusCallback &done);
/// Cancel subscription to an element.
/// Unsubscribing may only be called after the corresponding subscription request has completed.
///
/// \param client_id The type of update to listen to. If this is nil, then a
/// message for each update will be received. Else, only
/// messages for the given client will be received.
/// \param id The ID of the element to be unsubscribed from.
/// \param done Callback that will be called when cancelling the subscription is complete.
/// \return Status
Status AsyncUnsubscribe(const ClientID &client_id, const ID &id,
const StatusCallback &done);
private:
Table &table_;
std::mutex mutex_;
/// Whether the subscription has been successfully registered with GCS.
bool registered_{false};
/// Subscription callback for all elements.
SubscribeCallback<ID, Data> subscribe_all_callback_{nullptr};
/// A mapping from element ID to subscription callback.
typedef std::unordered_map<ID, SubscribeCallback<ID, Data>> IDToCallbackMap;
IDToCallbackMap id_to_callback_map_;
};
} // namespace gcs
} // namespace ray
#endif // RAY_GCS_SUBSCRIPTION_EXECUTOR_H
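The executor itself is table-agnostic: any table exposing Subscribe, RequestNotifications, and CancelNotifications can back it. A sketch of direct use against the actor table, in the spirit of the SubscriptionExecutorTest added below (client, node_id, and actor_id are assumed to exist, and actor_table() is only reachable from friends of RedisGcsClient):

gcs::SubscriptionExecutor<ActorID, ActorTableData, gcs::ActorTable> executor(
    client.actor_table());
// Subscribe to updates of one actor; a second call for the same ID returns
// Status::Invalid rather than silently succeeding.
RAY_CHECK_OK(executor.AsyncSubscribe(
    node_id, actor_id,
    [](const ActorID &id, const ActorTableData &data) {
      // data holds the latest table entry for the subscribed actor.
    },
    /*done=*/nullptr));
// Cancel the subscription once notifications are no longer needed.
RAY_CHECK_OK(executor.AsyncUnsubscribe(node_id, actor_id, /*done=*/nullptr));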

View file

@ -0,0 +1,201 @@
#include "gtest/gtest.h"
#include "ray/gcs/accessor_test_base.h"
#include "ray/gcs/callback.h"
#include "ray/gcs/redis_gcs_client.h"
namespace ray {
namespace gcs {
class SubscriptionExecutorTest : public AccessorTestBase<ActorID, ActorTableData> {
public:
typedef SubscriptionExecutor<ActorID, ActorTableData, ActorTable> ActorSubExecutor;
virtual void SetUp() {
AccessorTestBase<ActorID, ActorTableData>::SetUp();
actor_sub_executor_.reset(new ActorSubExecutor(gcs_client_->actor_table()));
subscribe_ = [this](const ActorID &id, const ActorTableData &data) {
const auto it = id_to_data_.find(id);
ASSERT_TRUE(it != id_to_data_.end());
--sub_pending_count_;
};
sub_done_ = [this](Status status) {
ASSERT_TRUE(status.ok()) << status;
--do_sub_pending_count_;
};
unsub_done_ = [this](Status status) {
ASSERT_TRUE(status.ok()) << status;
--do_unsub_pending_count_;
};
}
virtual void TearDown() {
AccessorTestBase<ActorID, ActorTableData>::TearDown();
ASSERT_EQ(sub_pending_count_, 0);
ASSERT_EQ(do_sub_pending_count_, 0);
ASSERT_EQ(do_unsub_pending_count_, 0);
}
protected:
virtual void GenTestData() {
for (size_t i = 0; i < 2; ++i) {
std::shared_ptr<ActorTableData> actor = std::make_shared<ActorTableData>();
actor->set_max_reconstructions(1);
actor->set_remaining_reconstructions(1);
JobID job_id = JobID::FromInt(i);
actor->set_job_id(job_id.Binary());
actor->set_state(ActorTableData::ALIVE);
ActorID actor_id = ActorID::Of(job_id, RandomTaskId(), /*parent_task_counter=*/i);
actor->set_actor_id(actor_id.Binary());
id_to_data_[actor_id] = actor;
}
}
size_t AsyncRegisterActorToGcs() {
ActorStateAccessor &actor_accessor = gcs_client_->Actors();
for (const auto &elem : id_to_data_) {
const auto &actor = elem.second;
auto done = [this](Status status) {
ASSERT_TRUE(status.ok());
--pending_count_;
};
++pending_count_;
Status status = actor_accessor.AsyncRegister(actor, done);
RAY_CHECK_OK(status);
}
return id_to_data_.size();
}
protected:
std::unique_ptr<ActorSubExecutor> actor_sub_executor_;
std::atomic<int> sub_pending_count_{0};
std::atomic<int> do_sub_pending_count_{0};
std::atomic<int> do_unsub_pending_count_{0};
SubscribeCallback<ActorID, ActorTableData> subscribe_{nullptr};
StatusCallback sub_done_{nullptr};
StatusCallback unsub_done_{nullptr};
};
TEST_F(SubscriptionExecutorTest, SubscribeAllTest) {
++do_sub_pending_count_;
Status status =
actor_sub_executor_->AsyncSubscribe(ClientID::Nil(), subscribe_, sub_done_);
WaitPendingDone(do_sub_pending_count_, wait_pending_timeout_);
ASSERT_TRUE(status.ok());
sub_pending_count_ = id_to_data_.size();
AsyncRegisterActorToGcs();
status = actor_sub_executor_->AsyncSubscribe(ClientID::Nil(), subscribe_, sub_done_);
ASSERT_TRUE(status.IsInvalid());
WaitPendingDone(sub_pending_count_, wait_pending_timeout_);
}
TEST_F(SubscriptionExecutorTest, SubscribeOneTest) {
Status status;
for (const auto &item : id_to_data_) {
++do_sub_pending_count_;
status = actor_sub_executor_->AsyncSubscribe(ClientID::Nil(), item.first, subscribe_,
sub_done_);
ASSERT_TRUE(status.ok());
}
WaitPendingDone(do_sub_pending_count_, wait_pending_timeout_);
sub_pending_count_ = id_to_data_.size();
AsyncRegisterActorToGcs();
for (const auto &item : id_to_data_) {
status = actor_sub_executor_->AsyncSubscribe(ClientID::Nil(), item.first, subscribe_,
sub_done_);
ASSERT_TRUE(status.IsInvalid());
}
WaitPendingDone(sub_pending_count_, wait_pending_timeout_);
}
TEST_F(SubscriptionExecutorTest, SubscribeOneWithClientIDTest) {
const auto &item = id_to_data_.begin();
++do_sub_pending_count_;
++sub_pending_count_;
Status status = actor_sub_executor_->AsyncSubscribe(ClientID::FromRandom(), item->first,
subscribe_, sub_done_);
WaitPendingDone(do_sub_pending_count_, wait_pending_timeout_);
ASSERT_TRUE(status.ok());
AsyncRegisterActorToGcs();
WaitPendingDone(sub_pending_count_, wait_pending_timeout_);
}
TEST_F(SubscriptionExecutorTest, SubscribeAllAndSubscribeOneTest) {
++do_sub_pending_count_;
Status status =
actor_sub_executor_->AsyncSubscribe(ClientID::Nil(), subscribe_, sub_done_);
ASSERT_TRUE(status.ok());
WaitPendingDone(do_sub_pending_count_, wait_pending_timeout_);
for (const auto &item : id_to_data_) {
status = actor_sub_executor_->AsyncSubscribe(ClientID::Nil(), item.first, subscribe_,
sub_done_);
ASSERT_FALSE(status.ok());
}
sub_pending_count_ = id_to_data_.size();
AsyncRegisterActorToGcs();
WaitPendingDone(sub_pending_count_, wait_pending_timeout_);
}
TEST_F(SubscriptionExecutorTest, UnsubscribeTest) {
Status status;
for (const auto &item : id_to_data_) {
status =
actor_sub_executor_->AsyncUnsubscribe(ClientID::Nil(), item.first, unsub_done_);
ASSERT_TRUE(status.IsInvalid());
}
for (const auto &item : id_to_data_) {
++do_sub_pending_count_;
status = actor_sub_executor_->AsyncSubscribe(ClientID::Nil(), item.first, subscribe_,
sub_done_);
ASSERT_TRUE(status.ok());
}
WaitPendingDone(do_sub_pending_count_, wait_pending_timeout_);
for (const auto &item : id_to_data_) {
++do_unsub_pending_count_;
status =
actor_sub_executor_->AsyncUnsubscribe(ClientID::Nil(), item.first, unsub_done_);
ASSERT_TRUE(status.ok());
}
WaitPendingDone(do_unsub_pending_count_, wait_pending_timeout_);
for (const auto &item : id_to_data_) {
status =
actor_sub_executor_->AsyncUnsubscribe(ClientID::Nil(), item.first, unsub_done_);
ASSERT_TRUE(!status.ok());
}
for (const auto &item : id_to_data_) {
++do_sub_pending_count_;
status = actor_sub_executor_->AsyncSubscribe(ClientID::Nil(), item.first, subscribe_,
sub_done_);
ASSERT_TRUE(status.ok());
}
WaitPendingDone(do_sub_pending_count_, wait_pending_timeout_);
for (const auto &item : id_to_data_) {
++do_unsub_pending_count_;
status =
actor_sub_executor_->AsyncUnsubscribe(ClientID::Nil(), item.first, unsub_done_);
ASSERT_TRUE(status.ok());
}
WaitPendingDone(do_unsub_pending_count_, wait_pending_timeout_);
for (const auto &item : id_to_data_) {
++do_sub_pending_count_;
status = actor_sub_executor_->AsyncSubscribe(ClientID::Nil(), item.first, subscribe_,
sub_done_);
ASSERT_TRUE(status.ok());
}
WaitPendingDone(do_sub_pending_count_, wait_pending_timeout_);
sub_pending_count_ = id_to_data_.size();
AsyncRegisterActorToGcs();
WaitPendingDone(sub_pending_count_, wait_pending_timeout_);
}
} // namespace gcs
} // namespace ray

View file

@ -162,22 +162,44 @@ Status Log<ID, Data>::Subscribe(const JobID &job_id, const ClientID &client_id,
template <typename ID, typename Data>
Status Log<ID, Data>::RequestNotifications(const JobID &job_id, const ID &id,
const ClientID &client_id) {
const ClientID &client_id,
const StatusCallback &done) {
RAY_CHECK(subscribe_callback_index_ >= 0)
<< "Client requested notifications on a key before Subscribe completed";
RedisCallback callback = nullptr;
if (done != nullptr) {
callback = [done](const CallbackReply &reply) {
const auto status = reply.IsNil()
? Status::OK()
: Status::RedisError("request notifications failed.");
done(status);
};
}
return GetRedisContext(id)->RunAsync("RAY.TABLE_REQUEST_NOTIFICATIONS", id,
client_id.Data(), client_id.Size(), prefix_,
pubsub_channel_, nullptr);
pubsub_channel_, callback);
}
template <typename ID, typename Data>
Status Log<ID, Data>::CancelNotifications(const JobID &job_id, const ID &id,
const ClientID &client_id) {
const ClientID &client_id,
const StatusCallback &done) {
RAY_CHECK(subscribe_callback_index_ >= 0)
<< "Client canceled notifications on a key before Subscribe completed";
RedisCallback callback = nullptr;
if (done != nullptr) {
callback = [done](const CallbackReply &reply) {
const auto status = reply.ReadAsStatus();
done(status);
};
}
return GetRedisContext(id)->RunAsync("RAY.TABLE_CANCEL_NOTIFICATIONS", id,
client_id.Data(), client_id.Size(), prefix_,
pubsub_channel_, nullptr);
pubsub_channel_, callback);
}
template <typename ID, typename Data>
@ -621,7 +643,8 @@ Status ClientTable::Connect(const GcsNodeInfo &local_node_info) {
// Callback to request notifications from the client table once we've
// successfully subscribed.
auto subscription_callback = [this](RedisGcsClient *c) {
RAY_CHECK_OK(RequestNotifications(JobID::Nil(), client_log_key_, node_id_));
RAY_CHECK_OK(RequestNotifications(JobID::Nil(), client_log_key_, node_id_,
/*done*/ nullptr));
};
// Subscribe to the client table.
RAY_CHECK_OK(
@ -636,7 +659,8 @@ Status ClientTable::Disconnect(const DisconnectCallback &callback) {
auto add_callback = [this, callback](RedisGcsClient *client, const ClientID &id,
const GcsNodeInfo &data) {
HandleConnected(client, data);
RAY_CHECK_OK(CancelNotifications(JobID::Nil(), client_log_key_, id));
RAY_CHECK_OK(
CancelNotifications(JobID::Nil(), client_log_key_, id, /*done*/ nullptr));
if (callback != nullptr) {
callback();
}
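The new done parameter lets callers learn when a notification request or cancellation has actually been processed, instead of passing nullptr as the call sites in this change do. A sketch of passing a real callback (client and task_id are assumed to be set up as in the surrounding tests):

RAY_CHECK_OK(client->raylet_task_table().RequestNotifications(
    JobID::Nil(), task_id, client->client_table().GetLocalClientId(),
    /*done=*/[](Status status) {
      // Runs once RAY.TABLE_REQUEST_NOTIFICATIONS has been applied for this key.
      RAY_CHECK_OK(status);
    }));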

View file

@ -11,6 +11,7 @@
#include "ray/common/status.h"
#include "ray/util/logging.h"
#include "ray/gcs/callback.h"
#include "ray/gcs/redis_context.h"
#include "ray/protobuf/gcs.pb.h"
@ -56,9 +57,11 @@ template <typename ID>
class PubsubInterface {
public:
virtual Status RequestNotifications(const JobID &job_id, const ID &id,
const ClientID &client_id) = 0;
const ClientID &client_id,
const StatusCallback &done) = 0;
virtual Status CancelNotifications(const JobID &job_id, const ID &id,
const ClientID &client_id) = 0;
const ClientID &client_id,
const StatusCallback &done) = 0;
virtual ~PubsubInterface(){};
};
@ -182,20 +185,22 @@ class Log : public LogInterface<ID, Data>, virtual public PubsubInterface<ID> {
/// \param job_id The ID of the job.
/// \param id The ID of the key to request notifications for.
/// \param client_id The client who is requesting notifications. Before
/// notifications can be requested, a call to `Subscribe` to this
/// table with the same `client_id` must complete successfully.
/// \param done Callback that is called when requesting notifications is complete.
/// \return Status
Status RequestNotifications(const JobID &job_id, const ID &id,
const ClientID &client_id);
const ClientID &client_id, const StatusCallback &done);
/// Cancel notifications about a key in this table.
///
/// \param job_id The ID of the job.
/// \param id The ID of the key to request notifications for.
/// \param client_id The client who originally requested notifications.
/// \param done Callback that is called when cancelling notifications is complete.
/// \return Status
Status CancelNotifications(const JobID &job_id, const ID &id,
const ClientID &client_id);
Status CancelNotifications(const JobID &job_id, const ID &id, const ClientID &client_id,
const StatusCallback &done);
/// Delete an entire key from redis.
///

View file

@ -159,7 +159,8 @@ ray::Status ObjectDirectory::SubscribeObjectLocations(const UniqueID &callback_i
if (it == listeners_.end()) {
it = listeners_.emplace(object_id, LocationListenerState()).first;
status = gcs_client_->object_table().RequestNotifications(
JobID::Nil(), object_id, gcs_client_->client_table().GetLocalClientId());
JobID::Nil(), object_id, gcs_client_->client_table().GetLocalClientId(),
/*done*/ nullptr);
}
auto &listener_state = it->second;
// TODO(hme): Make this fatal after implementing Pull suppression.
@ -187,7 +188,8 @@ ray::Status ObjectDirectory::UnsubscribeObjectLocations(const UniqueID &callback
entry->second.callbacks.erase(callback_id);
if (entry->second.callbacks.empty()) {
status = gcs_client_->object_table().CancelNotifications(
JobID::Nil(), object_id, gcs_client_->client_table().GetLocalClientId());
JobID::Nil(), object_id, gcs_client_->client_table().GetLocalClientId(),
/*done*/ nullptr);
listeners_.erase(entry);
}
return status;

View file

@ -291,7 +291,8 @@ bool LineageCache::SubscribeTask(const TaskID &task_id) {
if (unsubscribed) {
// Request notifications for the task if we haven't already requested
// notifications for it.
RAY_CHECK_OK(task_pubsub_.RequestNotifications(JobID::Nil(), task_id, client_id_));
RAY_CHECK_OK(task_pubsub_.RequestNotifications(JobID::Nil(), task_id, client_id_,
/*done*/ nullptr));
}
// Return whether we were previously unsubscribed to this task and are now
// subscribed.
@ -304,7 +305,8 @@ bool LineageCache::UnsubscribeTask(const TaskID &task_id) {
if (subscribed) {
// Cancel notifications for the task if we previously requested
// notifications for it.
RAY_CHECK_OK(task_pubsub_.CancelNotifications(JobID::Nil(), task_id, client_id_));
RAY_CHECK_OK(task_pubsub_.CancelNotifications(JobID::Nil(), task_id, client_id_,
/*done*/ nullptr));
subscribed_tasks_.erase(it);
}
// Return whether we were previously subscribed to this task and are now

View file

@ -7,8 +7,12 @@
#include "ray/common/task/task_execution_spec.h"
#include "ray/common/task/task_spec.h"
#include "ray/common/task/task_util.h"
#include "ray/gcs/callback.h"
#include "ray/raylet/format/node_manager_generated.h"
#include "ray/raylet/lineage_cache.h"
#include "ray/util/test_util.h"
namespace ray {
@ -67,7 +71,8 @@ class MockGcs : public gcs::TableInterface<TaskID, TaskTableData>,
}
Status RequestNotifications(const JobID &job_id, const TaskID &task_id,
const ClientID &client_id) {
const ClientID &client_id,
const gcs::StatusCallback &done) {
subscribed_tasks_.insert(task_id);
if (task_table_.count(task_id) == 1) {
callbacks_.push_back({notification_callback_, task_id});
@ -77,7 +82,7 @@ class MockGcs : public gcs::TableInterface<TaskID, TaskTableData>,
}
Status CancelNotifications(const JobID &job_id, const TaskID &task_id,
const ClientID &client_id) {
const ClientID &client_id, const gcs::StatusCallback &done) {
subscribed_tasks_.erase(task_id);
return ray::Status::OK();
}

View file

@ -53,7 +53,8 @@ void ReconstructionPolicy::SetTaskTimeout(
// task is still required after this initial period, then we now
// subscribe to task lease notifications.
RAY_CHECK_OK(task_lease_pubsub_.RequestNotifications(JobID::Nil(), task_id,
client_id_));
client_id_,
/*done*/ nullptr));
it->second.subscribed = true;
}
} else {
@ -200,8 +201,9 @@ void ReconstructionPolicy::Cancel(const ObjectID &object_id) {
if (it->second.created_objects.empty()) {
// Cancel notifications for the task lease if we were subscribed to them.
if (it->second.subscribed) {
RAY_CHECK_OK(
task_lease_pubsub_.CancelNotifications(JobID::Nil(), task_id, client_id_));
RAY_CHECK_OK(task_lease_pubsub_.CancelNotifications(JobID::Nil(), task_id,
client_id_,
/*done*/ nullptr));
}
listening_tasks_.erase(it);
}

View file

@ -5,6 +5,8 @@
#include <boost/asio.hpp>
#include "ray/gcs/callback.h"
#include "ray/raylet/format/node_manager_generated.h"
#include "ray/raylet/reconstruction_policy.h"
@ -102,7 +104,8 @@ class MockGcs : public gcs::PubsubInterface<TaskID>,
}
Status RequestNotifications(const JobID &job_id, const TaskID &task_id,
const ClientID &client_id) {
const ClientID &client_id,
const gcs::StatusCallback &done) {
subscribed_tasks_.insert(task_id);
auto entry = task_lease_table_.find(task_id);
if (entry == task_lease_table_.end()) {
@ -114,7 +117,7 @@ class MockGcs : public gcs::PubsubInterface<TaskID>,
}
Status CancelNotifications(const JobID &job_id, const TaskID &task_id,
const ClientID &client_id) {
const ClientID &client_id, const gcs::StatusCallback &done) {
subscribed_tasks_.erase(task_id);
return ray::Status::OK();
}

View file

@ -6,7 +6,7 @@
set -e
set -x
bazel build "//:redis_gcs_client_test" "//:actor_state_accessor_test" "//:asio_test" "//:libray_redis_module.so"
bazel build "//:redis_gcs_client_test" "//:actor_state_accessor_test" "//:subscription_executor_test" "//:asio_test" "//:libray_redis_module.so"
# Start Redis.
if [[ "${RAY_USE_NEW_GCS}" = "on" ]]; then
@ -25,6 +25,7 @@ sleep 1s
./bazel-bin/redis_gcs_client_test
./bazel-bin/actor_state_accessor_test
./bazel-bin/subscription_executor_test
./bazel-bin/asio_test
./bazel-genfiles/redis-cli -p 6379 shutdown

View file

@ -19,8 +19,8 @@ bool WaitForCondition(std::function<bool()> condition, int timeout_ms) {
return true;
}
// sleep 100ms.
const int wait_interval_ms = 100;
// sleep 10ms.
const int wait_interval_ms = 10;
usleep(wait_interval_ms * 1000);
wait_time += wait_interval_ms;
if (wait_time > timeout_ms) {