support for subscription to an actor (#5269)

This commit is contained in:
micafan 2019-08-20 20:32:53 +08:00 committed by Hao Chen
parent 851c5b2dae
commit da7bdacea5
20 changed files with 655 additions and 130 deletions

View file

@ -640,6 +640,7 @@ cc_library(
], ],
) )
# TODO(micafan) Replace cc_binary with cc_test for GCS test.
cc_binary( cc_binary(
name = "redis_gcs_client_test", name = "redis_gcs_client_test",
testonly = 1, testonly = 1,
@ -662,6 +663,17 @@ cc_binary(
], ],
) )
# Test binary built from src/ray/gcs/subscription_executor_test.cc.
# Declared as a test-only cc_binary (not cc_test); it links the ":gcs" library
# and gtest_main, which supplies the test entry point.
cc_binary(
name = "subscription_executor_test",
testonly = 1,
srcs = ["src/ray/gcs/subscription_executor_test.cc"],
copts = COPTS,
deps = [
":gcs",
"@com_google_googletest//:gtest_main",
],
)
cc_binary( cc_binary(
name = "asio_test", name = "asio_test",
testonly = 1, testonly = 1,

View file

@ -0,0 +1,76 @@
#include <atomic>
#include <chrono>
#include <string>
#include <thread>
#include <vector>
#include "gtest/gtest.h"
#include "ray/gcs/redis_gcs_client.h"
#include "ray/util/test_util.h"
namespace ray {
namespace gcs {
/// \class AccessorTestBase
/// Common gtest fixture for GCS accessor tests. It owns a RedisGcsClient, an
/// io_service driven by a background thread, and a map of generated test data
/// keyed by element ID. Subclasses provide the data through GenTestData().
///
/// \tparam ID   Key type of the generated test data.
/// \tparam Data Payload type stored (via shared_ptr) for each ID.
template <typename ID, typename Data>
class AccessorTestBase : public ::testing::Test {
 public:
  // NOTE(review): the option values presumably point at a local Redis server
  // started by the test environment -- confirm against GcsClientOptions.
  AccessorTestBase() : options_("127.0.0.1", 6379, "", true) {}

  virtual ~AccessorTestBase() {}

  /// Generate the test data, connect the GCS client and start the thread that
  /// runs the io_service.
  virtual void SetUp() {
    GenTestData();

    gcs_client_.reset(new RedisGcsClient(options_));
    RAY_CHECK_OK(gcs_client_->Connect(io_service_));

    auto run_io_service = [this] {
      // The work object keeps run() from returning while the queue is idle;
      // it lives for as long as this thread function does.
      boost::asio::io_service::work work(io_service_);
      io_service_.run();
    };
    work_thread.reset(new std::thread(run_io_service));
  }

  /// Disconnect the client, stop the io_service, join the worker thread and
  /// drop the client and the generated test data.
  virtual void TearDown() {
    gcs_client_->Disconnect();

    io_service_.stop();
    work_thread->join();
    work_thread.reset();

    gcs_client_.reset();
    ClearTestData();
  }

 protected:
  /// Fill id_to_data_ with the elements a concrete test needs.
  virtual void GenTestData() = 0;

  /// Remove all generated test data.
  void ClearTestData() { id_to_data_.clear(); }

  /// Wait until the fixture's own pending_count_ reaches zero.
  ///
  /// \param timeout Upper bound for the wait.
  void WaitPendingDone(std::chrono::milliseconds timeout) {
    WaitPendingDone(pending_count_, timeout);
  }

  /// Wait until the given counter reaches zero and record a (non-fatal) test
  /// failure if WaitForCondition reports that it did not.
  ///
  /// \param pending_count Counter that is polled.
  /// \param timeout Upper bound for the wait; its tick count is passed to
  /// WaitForCondition.
  void WaitPendingDone(std::atomic<int> &pending_count,
                       std::chrono::milliseconds timeout) {
    const bool reached_zero = WaitForCondition(
        [&pending_count]() { return pending_count == 0; }, timeout.count());
    EXPECT_TRUE(reached_zero);
  }

  /// Options used to construct the GCS client.
  GcsClientOptions options_;
  /// Client under test; created in SetUp() and released in TearDown().
  std::unique_ptr<RedisGcsClient> gcs_client_;

  /// io_service the client is connected with, run by work_thread.
  boost::asio::io_service io_service_;
  std::unique_ptr<std::thread> work_thread;

  /// Generated test data, keyed by element ID.
  std::unordered_map<ID, std::shared_ptr<Data>> id_to_data_;

  /// Number of asynchronous operations that have not completed yet.
  std::atomic<int> pending_count_{0};
  /// Default timeout (10000 ms) for the WaitPendingDone() helpers.
  std::chrono::milliseconds wait_pending_timeout_{10000};
};
} // namespace gcs
} // namespace ray

View file

@ -8,7 +8,7 @@ namespace ray {
namespace gcs { namespace gcs {
ActorStateAccessor::ActorStateAccessor(RedisGcsClient &client_impl) ActorStateAccessor::ActorStateAccessor(RedisGcsClient &client_impl)
: client_impl_(client_impl) {} : client_impl_(client_impl), actor_sub_executor_(client_impl_.actor_table()) {}
Status ActorStateAccessor::AsyncGet(const ActorID &actor_id, Status ActorStateAccessor::AsyncGet(const ActorID &actor_id,
const MultiItemCallback<ActorTableData> &callback) { const MultiItemCallback<ActorTableData> &callback) {
@ -90,23 +90,19 @@ Status ActorStateAccessor::AsyncSubscribe(
const SubscribeCallback<ActorID, ActorTableData> &subscribe, const SubscribeCallback<ActorID, ActorTableData> &subscribe,
const StatusCallback &done) { const StatusCallback &done) {
RAY_CHECK(subscribe != nullptr); RAY_CHECK(subscribe != nullptr);
auto on_subscribe = [subscribe](RedisGcsClient *client, const ActorID &actor_id, return actor_sub_executor_.AsyncSubscribe(ClientID::Nil(), subscribe, done);
const std::vector<ActorTableData> &data) { }
if (!data.empty()) {
// We only need the last entry, because it represents the latest state of
// this actor.
subscribe(actor_id, data.back());
}
};
auto on_done = [done](RedisGcsClient *client) { Status ActorStateAccessor::AsyncSubscribe(
if (done != nullptr) { const ActorID &actor_id, const SubscribeCallback<ActorID, ActorTableData> &subscribe,
done(Status::OK()); const StatusCallback &done) {
} RAY_CHECK(subscribe != nullptr);
}; return actor_sub_executor_.AsyncSubscribe(node_id_, actor_id, subscribe, done);
}
ActorTable &actor_table = client_impl_.actor_table(); Status ActorStateAccessor::AsyncUnsubscribe(const ActorID &actor_id,
return actor_table.Subscribe(JobID::Nil(), ClientID::Nil(), on_subscribe, on_done); const StatusCallback &done) {
return actor_sub_executor_.AsyncUnsubscribe(node_id_, actor_id, done);
} }
} // namespace gcs } // namespace gcs

View file

@ -3,6 +3,7 @@
#include "ray/common/id.h" #include "ray/common/id.h"
#include "ray/gcs/callback.h" #include "ray/gcs/callback.h"
#include "ray/gcs/subscription_executor.h"
#include "ray/gcs/tables.h" #include "ray/gcs/tables.h"
namespace ray { namespace ray {
@ -50,7 +51,7 @@ class ActorStateAccessor {
const std::shared_ptr<ActorTableData> &data_ptr, const std::shared_ptr<ActorTableData> &data_ptr,
const StatusCallback &callback); const StatusCallback &callback);
/// Subscribe to any register operations of actors. /// Subscribe to any register or update operations of actors.
/// ///
/// \param subscribe Callback that will be called each time when an actor is registered /// \param subscribe Callback that will be called each time when an actor is registered
/// or updated. /// or updated.
@ -60,8 +61,36 @@ class ActorStateAccessor {
Status AsyncSubscribe(const SubscribeCallback<ActorID, ActorTableData> &subscribe, Status AsyncSubscribe(const SubscribeCallback<ActorID, ActorTableData> &subscribe,
const StatusCallback &done); const StatusCallback &done);
/// Subscribe to any update operations of an actor.
///
/// \param actor_id The ID of the actor to subscribe to.
/// \param subscribe Callback that will be called each time when the actor is updated.
/// \param done Callback that will be called when subscription is complete.
/// \return Status
Status AsyncSubscribe(const ActorID &actor_id,
const SubscribeCallback<ActorID, ActorTableData> &subscribe,
const StatusCallback &done);
/// Cancel subscription to an actor.
///
/// \param actor_id The ID of the actor to unsubscribe from.
/// \param done Callback that will be called when unsubscribe is complete.
/// \return Status
Status AsyncUnsubscribe(const ActorID &actor_id, const StatusCallback &done);
private: private:
RedisGcsClient &client_impl_; RedisGcsClient &client_impl_;
// Use a random ClientID for actor subscription, for two reasons:
// 1. If we used ClientID::Nil, GCS would still send every actor's updates to this
//    GCS Client. We could filter out the irrelevant updates, but that would add
//    extra overhead.
// 2. The new GCS Client no longer holds the local ClientID, so a random ClientID
//    is used instead.
// TODO(micafan): Remove this random id, once GCS becomes a service.
ClientID node_id_{ClientID::FromRandom()};
typedef SubscriptionExecutor<ActorID, ActorTableData, ActorTable>
ActorSubscriptionExecutor;
ActorSubscriptionExecutor actor_sub_executor_;
}; };
} // namespace gcs } // namespace gcs

View file

@ -4,6 +4,7 @@
#include <thread> #include <thread>
#include <vector> #include <vector>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ray/gcs/accessor_test_base.h"
#include "ray/gcs/redis_gcs_client.h" #include "ray/gcs/redis_gcs_client.h"
#include "ray/util/test_util.h" #include "ray/util/test_util.h"
@ -11,39 +12,9 @@ namespace ray {
namespace gcs { namespace gcs {
class ActorStateAccessorTest : public ::testing::Test { class ActorStateAccessorTest : public AccessorTestBase<ActorID, ActorTableData> {
public:
ActorStateAccessorTest() : options_("127.0.0.1", 6379, "", true) {}
virtual void SetUp() {
GenTestData();
gcs_client_.reset(new RedisGcsClient(options_));
RAY_CHECK_OK(gcs_client_->Connect(io_service_));
work_thread.reset(new std::thread([this] {
std::unique_ptr<boost::asio::io_service::work> work(
new boost::asio::io_service::work(io_service_));
io_service_.run();
}));
}
virtual void TearDown() {
gcs_client_->Disconnect();
io_service_.stop();
work_thread->join();
work_thread.reset();
gcs_client_.reset();
ClearTestData();
}
protected: protected:
void GenTestData() { GenActorData(); } virtual void GenTestData() {
void GenActorData() {
for (size_t i = 0; i < 100; ++i) { for (size_t i = 0; i < 100; ++i) {
std::shared_ptr<ActorTableData> actor = std::make_shared<ActorTableData>(); std::shared_ptr<ActorTableData> actor = std::make_shared<ActorTableData>();
actor->set_max_reconstructions(1); actor->set_max_reconstructions(1);
@ -53,42 +24,15 @@ class ActorStateAccessorTest : public ::testing::Test {
actor->set_state(ActorTableData::ALIVE); actor->set_state(ActorTableData::ALIVE);
ActorID actor_id = ActorID::Of(job_id, RandomTaskId(), /*parent_task_counter=*/i); ActorID actor_id = ActorID::Of(job_id, RandomTaskId(), /*parent_task_counter=*/i);
actor->set_actor_id(actor_id.Binary()); actor->set_actor_id(actor_id.Binary());
actor_datas_[actor_id] = actor; id_to_data_[actor_id] = actor;
} }
} }
void ClearTestData() { actor_datas_.clear(); }
void WaitPendingDone(std::chrono::milliseconds timeout) {
WaitPendingDone(pending_count_, timeout);
}
void WaitPendingDone(std::atomic<int> &pending_count,
std::chrono::milliseconds timeout) {
while (pending_count != 0 && timeout.count() > 0) {
std::chrono::milliseconds interval(10);
std::this_thread::sleep_for(interval);
timeout -= interval;
}
EXPECT_EQ(pending_count, 0);
}
protected:
GcsClientOptions options_;
std::unique_ptr<RedisGcsClient> gcs_client_;
boost::asio::io_service io_service_;
std::unique_ptr<std::thread> work_thread;
std::unordered_map<ActorID, std::shared_ptr<ActorTableData>> actor_datas_;
std::atomic<int> pending_count_{0};
}; };
TEST_F(ActorStateAccessorTest, RegisterAndGet) { TEST_F(ActorStateAccessorTest, RegisterAndGet) {
ActorStateAccessor &actor_accessor = gcs_client_->Actors(); ActorStateAccessor &actor_accessor = gcs_client_->Actors();
// register // register
for (const auto &elem : actor_datas_) { for (const auto &elem : id_to_data_) {
const auto &actor = elem.second; const auto &actor = elem.second;
++pending_count_; ++pending_count_;
RAY_CHECK_OK(actor_accessor.AsyncRegister(actor, [this](Status status) { RAY_CHECK_OK(actor_accessor.AsyncRegister(actor, [this](Status status) {
@ -97,35 +41,33 @@ TEST_F(ActorStateAccessorTest, RegisterAndGet) {
})); }));
} }
std::chrono::milliseconds timeout(10000); WaitPendingDone(wait_pending_timeout_);
WaitPendingDone(timeout);
// get // get
for (const auto &elem : actor_datas_) { for (const auto &elem : id_to_data_) {
++pending_count_; ++pending_count_;
RAY_CHECK_OK(actor_accessor.AsyncGet( RAY_CHECK_OK(actor_accessor.AsyncGet(
elem.first, [this](Status status, std::vector<ActorTableData> datas) { elem.first, [this](Status status, std::vector<ActorTableData> datas) {
ASSERT_EQ(datas.size(), 1U); ASSERT_EQ(datas.size(), 1U);
ActorID actor_id = ActorID::FromBinary(datas[0].actor_id()); ActorID actor_id = ActorID::FromBinary(datas[0].actor_id());
auto it = actor_datas_.find(actor_id); auto it = id_to_data_.find(actor_id);
ASSERT_TRUE(it != actor_datas_.end()); ASSERT_TRUE(it != id_to_data_.end());
--pending_count_; --pending_count_;
})); }));
} }
WaitPendingDone(timeout); WaitPendingDone(wait_pending_timeout_);
} }
TEST_F(ActorStateAccessorTest, Subscribe) { TEST_F(ActorStateAccessorTest, Subscribe) {
ActorStateAccessor &actor_accessor = gcs_client_->Actors(); ActorStateAccessor &actor_accessor = gcs_client_->Actors();
std::chrono::milliseconds timeout(10000);
// subscribe // subscribe
std::atomic<int> sub_pending_count(0); std::atomic<int> sub_pending_count(0);
std::atomic<int> do_sub_pending_count(0); std::atomic<int> do_sub_pending_count(0);
auto subscribe = [this, &sub_pending_count](const ActorID &actor_id, auto subscribe = [this, &sub_pending_count](const ActorID &actor_id,
const ActorTableData &data) { const ActorTableData &data) {
const auto it = actor_datas_.find(actor_id); const auto it = id_to_data_.find(actor_id);
ASSERT_TRUE(it != actor_datas_.end()); ASSERT_TRUE(it != id_to_data_.end());
--sub_pending_count; --sub_pending_count;
}; };
auto done = [&do_sub_pending_count](Status status) { auto done = [&do_sub_pending_count](Status status) {
@ -136,11 +78,11 @@ TEST_F(ActorStateAccessorTest, Subscribe) {
++do_sub_pending_count; ++do_sub_pending_count;
RAY_CHECK_OK(actor_accessor.AsyncSubscribe(subscribe, done)); RAY_CHECK_OK(actor_accessor.AsyncSubscribe(subscribe, done));
// Wait until subscribe finishes. // Wait until subscribe finishes.
WaitPendingDone(do_sub_pending_count, timeout); WaitPendingDone(do_sub_pending_count, wait_pending_timeout_);
// register // register
std::atomic<int> register_pending_count(0); std::atomic<int> register_pending_count(0);
for (const auto &elem : actor_datas_) { for (const auto &elem : id_to_data_) {
const auto &actor = elem.second; const auto &actor = elem.second;
++sub_pending_count; ++sub_pending_count;
++register_pending_count; ++register_pending_count;
@ -151,10 +93,10 @@ TEST_F(ActorStateAccessorTest, Subscribe) {
})); }));
} }
// Wait until register finishes. // Wait until register finishes.
WaitPendingDone(register_pending_count, timeout); WaitPendingDone(register_pending_count, wait_pending_timeout_);
// Wait for all subscribe notifications. // Wait for all subscribe notifications.
WaitPendingDone(sub_pending_count, timeout); WaitPendingDone(sub_pending_count, wait_pending_timeout_);
} }
} // namespace gcs } // namespace gcs

View file

@ -95,7 +95,7 @@ class GcsClientInterface : public std::enable_shared_from_this<GcsClientInterfac
GcsClientOptions options_; GcsClientOptions options_;
// Whether this client is connected to GCS. /// Whether this client is connected to GCS.
bool is_connected_{false}; bool is_connected_{false};
std::unique_ptr<ActorStateAccessor> actor_accessor_; std::unique_ptr<ActorStateAccessor> actor_accessor_;

View file

@ -19,6 +19,7 @@ class RedisContext;
class RAY_EXPORT RedisGcsClient : public GcsClientInterface { class RAY_EXPORT RedisGcsClient : public GcsClientInterface {
friend class ActorStateAccessor; friend class ActorStateAccessor;
friend class SubscriptionExecutorTest;
public: public:
/// Constructor of RedisGcsClient. /// Constructor of RedisGcsClient.

View file

@ -741,7 +741,7 @@ void TestTableSubscribeId(const JobID &job_id,
num_modifications](gcs::RedisGcsClient *client) { num_modifications](gcs::RedisGcsClient *client) {
// Request notifications for one of the keys. // Request notifications for one of the keys.
RAY_CHECK_OK(client->raylet_task_table().RequestNotifications( RAY_CHECK_OK(client->raylet_task_table().RequestNotifications(
job_id, task_id2, client->client_table().GetLocalClientId())); job_id, task_id2, client->client_table().GetLocalClientId(), nullptr));
// Write both keys. We should only receive notifications for the key that // Write both keys. We should only receive notifications for the key that
// we requested them for. // we requested them for.
for (uint64_t i = 0; i < num_modifications; i++) { for (uint64_t i = 0; i < num_modifications; i++) {
@ -814,7 +814,7 @@ void TestLogSubscribeId(const JobID &job_id,
job_ids2](gcs::RedisGcsClient *client) { job_ids2](gcs::RedisGcsClient *client) {
// Request notifications for one of the keys. // Request notifications for one of the keys.
RAY_CHECK_OK(client->job_table().RequestNotifications( RAY_CHECK_OK(client->job_table().RequestNotifications(
job_id, job_id2, client->client_table().GetLocalClientId())); job_id, job_id2, client->client_table().GetLocalClientId(), nullptr));
// Write both keys. We should only receive notifications for the key that // Write both keys. We should only receive notifications for the key that
// we requested them for. // we requested them for.
auto remaining = std::vector<std::string>(++job_ids1.begin(), job_ids1.end()); auto remaining = std::vector<std::string>(++job_ids1.begin(), job_ids1.end());
@ -890,7 +890,7 @@ void TestSetSubscribeId(const JobID &job_id,
managers2](gcs::RedisGcsClient *client) { managers2](gcs::RedisGcsClient *client) {
// Request notifications for one of the keys. // Request notifications for one of the keys.
RAY_CHECK_OK(client->object_table().RequestNotifications( RAY_CHECK_OK(client->object_table().RequestNotifications(
job_id, object_id2, client->client_table().GetLocalClientId())); job_id, object_id2, client->client_table().GetLocalClientId(), nullptr));
// Write both keys. We should only receive notifications for the key that // Write both keys. We should only receive notifications for the key that
// we requested them for. // we requested them for.
auto remaining = std::vector<std::string>(++managers1.begin(), managers1.end()); auto remaining = std::vector<std::string>(++managers1.begin(), managers1.end());
@ -964,9 +964,9 @@ void TestTableSubscribeCancel(const JobID &job_id,
// Request notifications, then cancel immediately. We should receive a // Request notifications, then cancel immediately. We should receive a
// notification for the current value at the key. // notification for the current value at the key.
RAY_CHECK_OK(client->raylet_task_table().RequestNotifications( RAY_CHECK_OK(client->raylet_task_table().RequestNotifications(
job_id, task_id, client->client_table().GetLocalClientId())); job_id, task_id, client->client_table().GetLocalClientId(), nullptr));
RAY_CHECK_OK(client->raylet_task_table().CancelNotifications( RAY_CHECK_OK(client->raylet_task_table().CancelNotifications(
job_id, task_id, client->client_table().GetLocalClientId())); job_id, task_id, client->client_table().GetLocalClientId(), nullptr));
// Write to the key. Since we canceled notifications, we should not receive // Write to the key. Since we canceled notifications, we should not receive
// a notification for these writes. // a notification for these writes.
for (uint64_t i = 1; i < num_modifications; i++) { for (uint64_t i = 1; i < num_modifications; i++) {
@ -976,7 +976,7 @@ void TestTableSubscribeCancel(const JobID &job_id,
// Request notifications again. We should receive a notification for the // Request notifications again. We should receive a notification for the
// current value at the key. // current value at the key.
RAY_CHECK_OK(client->raylet_task_table().RequestNotifications( RAY_CHECK_OK(client->raylet_task_table().RequestNotifications(
job_id, task_id, client->client_table().GetLocalClientId())); job_id, task_id, client->client_table().GetLocalClientId(), nullptr));
}; };
// Subscribe to notifications for this client. This allows us to request and // Subscribe to notifications for this client. This allows us to request and
@ -1033,9 +1033,9 @@ void TestLogSubscribeCancel(const JobID &job_id,
// Request notifications, then cancel immediately. We should receive a // Request notifications, then cancel immediately. We should receive a
// notification for the current value at the key. // notification for the current value at the key.
RAY_CHECK_OK(client->job_table().RequestNotifications( RAY_CHECK_OK(client->job_table().RequestNotifications(
job_id, random_job_id, client->client_table().GetLocalClientId())); job_id, random_job_id, client->client_table().GetLocalClientId(), nullptr));
RAY_CHECK_OK(client->job_table().CancelNotifications( RAY_CHECK_OK(client->job_table().CancelNotifications(
job_id, random_job_id, client->client_table().GetLocalClientId())); job_id, random_job_id, client->client_table().GetLocalClientId(), nullptr));
// Append to the key. Since we canceled notifications, we should not // Append to the key. Since we canceled notifications, we should not
// receive a notification for these writes. // receive a notification for these writes.
auto remaining = std::vector<std::string>(++job_ids.begin(), job_ids.end()); auto remaining = std::vector<std::string>(++job_ids.begin(), job_ids.end());
@ -1047,7 +1047,7 @@ void TestLogSubscribeCancel(const JobID &job_id,
// Request notifications again. We should receive a notification for the // Request notifications again. We should receive a notification for the
// current values at the key. // current values at the key.
RAY_CHECK_OK(client->job_table().RequestNotifications( RAY_CHECK_OK(client->job_table().RequestNotifications(
job_id, random_job_id, client->client_table().GetLocalClientId())); job_id, random_job_id, client->client_table().GetLocalClientId(), nullptr));
}; };
// Subscribe to notifications for this client. This allows us to request and // Subscribe to notifications for this client. This allows us to request and
@ -1115,9 +1115,9 @@ void TestSetSubscribeCancel(const JobID &job_id,
// Request notifications, then cancel immediately. We should receive a // Request notifications, then cancel immediately. We should receive a
// notification for the current value at the key. // notification for the current value at the key.
RAY_CHECK_OK(client->object_table().RequestNotifications( RAY_CHECK_OK(client->object_table().RequestNotifications(
job_id, object_id, client->client_table().GetLocalClientId())); job_id, object_id, client->client_table().GetLocalClientId(), nullptr));
RAY_CHECK_OK(client->object_table().CancelNotifications( RAY_CHECK_OK(client->object_table().CancelNotifications(
job_id, object_id, client->client_table().GetLocalClientId())); job_id, object_id, client->client_table().GetLocalClientId(), nullptr));
// Add to the key. Since we canceled notifications, we should not // Add to the key. Since we canceled notifications, we should not
// receive a notification for these writes. // receive a notification for these writes.
auto remaining = std::vector<std::string>(++managers.begin(), managers.end()); auto remaining = std::vector<std::string>(++managers.begin(), managers.end());
@ -1129,7 +1129,7 @@ void TestSetSubscribeCancel(const JobID &job_id,
// Request notifications again. We should receive a notification for the // Request notifications again. We should receive a notification for the
// current values at the key. // current values at the key.
RAY_CHECK_OK(client->object_table().RequestNotifications( RAY_CHECK_OK(client->object_table().RequestNotifications(
job_id, object_id, client->client_table().GetLocalClientId())); job_id, object_id, client->client_table().GetLocalClientId(), nullptr));
}; };
// Subscribe to notifications for this client. This allows us to request and // Subscribe to notifications for this client. This allows us to request and
@ -1342,7 +1342,7 @@ void TestHashTable(const JobID &job_id, std::shared_ptr<gcs::RedisGcsClient> cli
RAY_CHECK_OK(client->resource_table().Subscribe( RAY_CHECK_OK(client->resource_table().Subscribe(
job_id, ClientID::Nil(), notification_callback, subscribe_callback)); job_id, ClientID::Nil(), notification_callback, subscribe_callback));
RAY_CHECK_OK(client->resource_table().RequestNotifications( RAY_CHECK_OK(client->resource_table().RequestNotifications(
job_id, client_id, client->client_table().GetLocalClientId())); job_id, client_id, client->client_table().GetLocalClientId(), nullptr));
// Step 1: Add elements to the hash table. // Step 1: Add elements to the hash table.
auto update_callback1 = [data_map1, compare_test]( auto update_callback1 = [data_map1, compare_test](

View file

@ -0,0 +1,139 @@
#include "ray/gcs/subscription_executor.h"
namespace ray {
namespace gcs {
/// Register the table-level subscription and, optionally, a callback that
/// receives the updates of all elements.
///
/// Every non-empty notification is dispatched with its last entry to the
/// per-element callback registered for that id (if any) and to the
/// subscribe-all callback (if any).
///
/// \param client_id Client ID forwarded to Table::Subscribe.
/// \param subscribe Callback for updates of all elements. May be null, in which
/// case only the table-level subscription is registered.
/// \param done Invoked with Status::OK() when Table::Subscribe reports completion.
/// \return Status::Invalid if a subscribe-all callback already exists, or if the
/// table subscription is already registered and `subscribe` is non-null;
/// Status::OK if it is already registered and `subscribe` is null; otherwise the
/// status returned by Table::Subscribe.
template <typename ID, typename Data, typename Table>
Status SubscriptionExecutor<ID, Data, Table>::AsyncSubscribe(
    const ClientID &client_id, const SubscribeCallback<ID, Data> &subscribe,
    const StatusCallback &done) {
  // TODO(micafan) Optimize the lock when necessary.
  // Consider avoiding locking in single-threaded processes.
  std::lock_guard<std::mutex> lock(mutex_);

  if (subscribe_all_callback_ != nullptr) {
    RAY_LOG(DEBUG) << "Duplicate subscription! Already subscribed to all elements.";
    return Status::Invalid("Duplicate subscription!");
  }

  if (registered_) {
    const bool wants_all_elements = (subscribe != nullptr);
    if (!wants_all_elements) {
      // The table-level subscription is already in place; nothing more to do.
      return Status::OK();
    }
    RAY_LOG(DEBUG) << "Duplicate subscription! Already subscribed to specific elements"
                      ", can't subscribe to all elements.";
    return Status::Invalid("Duplicate subscription!");
  }

  auto on_subscribe = [this](RedisGcsClient *client, const ID &id,
                             const std::vector<Data> &result) {
    if (result.empty()) {
      return;
    }
    RAY_LOG(DEBUG) << "Subscribe received update of id " << id;

    // Copy the callbacks while holding the lock, then invoke them without it.
    SubscribeCallback<ID, Data> sub_one_callback = nullptr;
    SubscribeCallback<ID, Data> sub_all_callback = nullptr;
    {
      std::lock_guard<std::mutex> lock(mutex_);
      const auto found = id_to_callback_map_.find(id);
      if (found != id_to_callback_map_.end()) {
        sub_one_callback = found->second;
      }
      sub_all_callback = subscribe_all_callback_;
    }

    // Only the last entry of the notification is handed to the callbacks.
    const Data &latest = result.back();
    if (sub_one_callback != nullptr) {
      sub_one_callback(id, latest);
    }
    if (sub_all_callback != nullptr) {
      RAY_CHECK(sub_one_callback == nullptr);
      sub_all_callback(id, latest);
    }
  };

  auto on_done = [done](RedisGcsClient *client) {
    if (done != nullptr) {
      done(Status::OK());
    }
  };

  Status status = table_.Subscribe(JobID::Nil(), client_id, on_subscribe, on_done);
  if (!status.ok()) {
    return status;
  }
  registered_ = true;
  subscribe_all_callback_ = subscribe;
  return status;
}
/// Subscribe to the updates of a single element.
///
/// \param client_id Client ID forwarded to the table subscription and to
/// Table::RequestNotifications.
/// \param id The element to subscribe to.
/// \param subscribe Callback stored for `id` once the request has been issued
/// successfully.
/// \param done Invoked with the status reported for the notification request.
/// \return The failure of the table-level subscription, if any; Status::Invalid
/// if `id` already has a callback; otherwise the status returned by
/// Table::RequestNotifications.
template <typename ID, typename Data, typename Table>
Status SubscriptionExecutor<ID, Data, Table>::AsyncSubscribe(
    const ClientID &client_id, const ID &id, const SubscribeCallback<ID, Data> &subscribe,
    const StatusCallback &done) {
  // Make sure the table-level subscription is registered (no subscribe-all
  // callback, no completion callback).
  Status status = AsyncSubscribe(client_id, nullptr, nullptr);
  if (!status.ok()) {
    return status;
  }

  auto on_done = [this, done, id](Status request_status) {
    if (!request_status.ok()) {
      // The request failed: forget the callback stored for this id.
      std::lock_guard<std::mutex> lock(mutex_);
      id_to_callback_map_.erase(id);
    }
    if (done != nullptr) {
      done(request_status);
    }
  };

  std::lock_guard<std::mutex> lock(mutex_);
  if (id_to_callback_map_.count(id) != 0) {
    RAY_LOG(DEBUG) << "Duplicate subscription to id " << id << " client_id "
                   << client_id;
    return Status::Invalid("Duplicate subscription to element!");
  }

  status = table_.RequestNotifications(JobID::Nil(), id, client_id, on_done);
  if (status.ok()) {
    id_to_callback_map_[id] = subscribe;
  }
  return status;
}
/// Cancel the subscription to a single element.
///
/// The callback stored for `id` is removed only after the cancel request has
/// completed successfully.
///
/// \param client_id Client ID forwarded to Table::CancelNotifications.
/// \param id The element to unsubscribe from.
/// \param done Invoked with the status reported for the cancel request.
/// \return Status::Invalid if no callback is stored for `id`; otherwise the
/// status returned by Table::CancelNotifications.
template <typename ID, typename Data, typename Table>
Status SubscriptionExecutor<ID, Data, Table>::AsyncUnsubscribe(
    const ClientID &client_id, const ID &id, const StatusCallback &done) {
  {
    std::lock_guard<std::mutex> lock(mutex_);
    const bool has_subscription = id_to_callback_map_.count(id) != 0;
    if (!has_subscription) {
      RAY_LOG(DEBUG) << "Invalid Unsubscribe! id " << id << " client_id " << client_id;
      return Status::Invalid("Invalid Unsubscribe, no existing subscription found.");
    }
  }

  auto on_done = [this, id, done](Status cancel_status) {
    if (cancel_status.ok()) {
      std::lock_guard<std::mutex> lock(mutex_);
      // erase-by-key is a no-op if the entry is already gone.
      id_to_callback_map_.erase(id);
    }
    if (done != nullptr) {
      done(cancel_status);
    }
  };

  return table_.CancelNotifications(JobID::Nil(), id, client_id, on_done);
}
// The member functions of SubscriptionExecutor are defined in this source file,
// so the actor specialization is instantiated explicitly here.
template class SubscriptionExecutor<ActorID, ActorTableData, ActorTable>;
} // namespace gcs
} // namespace ray

View file

@ -0,0 +1,85 @@
#ifndef RAY_GCS_SUBSCRIPTION_EXECUTOR_H
#define RAY_GCS_SUBSCRIPTION_EXECUTOR_H
#include <atomic>
#include <mutex>
#include "ray/gcs/callback.h"
#include "ray/gcs/tables.h"
namespace ray {
namespace gcs {
/// \class SubscriptionExecutor
/// Hides the mechanics of subscribing to and unsubscribing from elements of a
/// GCS table (e.g. actors, tasks, objects or nodes). Callers can subscribe
/// either to one specific element or to all elements.
///
/// \tparam ID    Type of the element IDs.
/// \tparam Data  Type of the data passed to subscription callbacks.
/// \tparam Table Type of the table the subscriptions are issued against.
template <typename ID, typename Data, typename Table>
class SubscriptionExecutor {
 public:
  /// \param table The table used for all subscription requests. It is held by
  /// reference.
  SubscriptionExecutor(Table &table) : table_(table) {}

  ~SubscriptionExecutor() = default;

  /// Subscribe to operations of all elements.
  /// Subscribing a second time fails.
  ///
  /// \param client_id The type of update to listen to. If this is nil, a message
  /// is received for each update. Otherwise only messages for the given client
  /// are received.
  /// \param subscribe Callback invoked each time an element is registered or
  /// updated.
  /// \param done Callback invoked when the subscription is complete.
  /// \return Status
  Status AsyncSubscribe(const ClientID &client_id,
                        const SubscribeCallback<ID, Data> &subscribe,
                        const StatusCallback &done);

  /// Subscribe to operations of a single element.
  /// Subscribing to the same element a second time fails.
  ///
  /// \param client_id The type of update to listen to. If this is nil, a message
  /// is received for each update. Otherwise only messages for the given client
  /// are received.
  /// \param id The id of the element to subscribe to.
  /// \param subscribe Callback invoked each time the element is registered or
  /// updated.
  /// \param done Callback invoked when the subscription is complete.
  /// \return Status
  Status AsyncSubscribe(const ClientID &client_id, const ID &id,
                        const SubscribeCallback<ID, Data> &subscribe,
                        const StatusCallback &done);

  /// Cancel the subscription to an element.
  /// Must only be called after the subscription request has completed.
  ///
  /// \param client_id Client ID passed along with the cancel request.
  /// NOTE(review): presumably the same ID the subscription was made with --
  /// confirm against the table implementation.
  /// \param id The id of the element to unsubscribe from.
  /// \param done Callback invoked when cancelling the subscription is complete.
  /// \return Status
  Status AsyncUnsubscribe(const ClientID &client_id, const ID &id,
                          const StatusCallback &done);

 private:
  /// Map type from element ID to its subscription callback.
  using IDToCallbackMap = std::unordered_map<ID, SubscribeCallback<ID, Data>>;

  Table &table_;

  std::mutex mutex_;

  /// Whether the subscription has been successfully registered with GCS.
  bool registered_{false};

  /// Subscription callback for all elements.
  SubscribeCallback<ID, Data> subscribe_all_callback_{nullptr};

  /// Subscription callbacks of individual elements, keyed by element ID.
  IDToCallbackMap id_to_callback_map_;
};
} // namespace gcs
} // namespace ray
#endif // RAY_GCS_SUBSCRIPTION_EXECUTOR_H

View file

@ -0,0 +1,201 @@
#include "gtest/gtest.h"
#include "ray/gcs/accessor_test_base.h"
#include "ray/gcs/callback.h"
#include "ray/gcs/redis_gcs_client.h"
namespace ray {
namespace gcs {
/// Fixture for SubscriptionExecutor tests against the actor table. On top of
/// the base fixture it owns the executor under test, the counters that track
/// outstanding callbacks, and the callbacks that decrement those counters.
class SubscriptionExecutorTest : public AccessorTestBase<ActorID, ActorTableData> {
 public:
  using ActorSubExecutor = SubscriptionExecutor<ActorID, ActorTableData, ActorTable>;

  /// Run the base set-up, then create the executor and the shared callbacks.
  virtual void SetUp() {
    AccessorTestBase<ActorID, ActorTableData>::SetUp();

    actor_sub_executor_.reset(new ActorSubExecutor(gcs_client_->actor_table()));

    // Every notification must refer to a generated actor; it consumes one
    // expected notification.
    subscribe_ = [this](const ActorID &id, const ActorTableData &data) {
      const auto it = id_to_data_.find(id);
      ASSERT_TRUE(it != id_to_data_.end());
      --sub_pending_count_;
    };

    // Completion of a subscribe request.
    sub_done_ = [this](Status status) {
      ASSERT_TRUE(status.ok()) << status;
      --do_sub_pending_count_;
    };

    // Completion of an unsubscribe request.
    unsub_done_ = [this](Status status) {
      ASSERT_TRUE(status.ok()) << status;
      --do_unsub_pending_count_;
    };
  }

  /// Run the base tear-down, then verify that no callback is still outstanding.
  virtual void TearDown() {
    AccessorTestBase<ActorID, ActorTableData>::TearDown();

    ASSERT_EQ(sub_pending_count_, 0);
    ASSERT_EQ(do_sub_pending_count_, 0);
    ASSERT_EQ(do_unsub_pending_count_, 0);
  }

 protected:
  /// Generate two ALIVE actors, each belonging to its own job.
  virtual void GenTestData() {
    for (size_t i = 0; i < 2; ++i) {
      const JobID job_id = JobID::FromInt(i);
      const ActorID actor_id =
          ActorID::Of(job_id, RandomTaskId(), /*parent_task_counter=*/i);

      auto actor = std::make_shared<ActorTableData>();
      actor->set_max_reconstructions(1);
      actor->set_remaining_reconstructions(1);
      actor->set_job_id(job_id.Binary());
      actor->set_state(ActorTableData::ALIVE);
      actor->set_actor_id(actor_id.Binary());

      id_to_data_[actor_id] = actor;
    }
  }

  /// Asynchronously register every generated actor through the actor accessor.
  /// pending_count_ is incremented per request and decremented on completion.
  ///
  /// \return The number of actors for which registration was requested.
  size_t AsyncRegisterActorToGcs() {
    ActorStateAccessor &actor_accessor = gcs_client_->Actors();

    auto done = [this](Status status) {
      ASSERT_TRUE(status.ok());
      --pending_count_;
    };

    for (const auto &elem : id_to_data_) {
      ++pending_count_;
      Status status = actor_accessor.AsyncRegister(elem.second, done);
      RAY_CHECK_OK(status);
    }
    return id_to_data_.size();
  }

  /// Executor under test.
  std::unique_ptr<ActorSubExecutor> actor_sub_executor_;

  /// Notifications still expected by subscribe_.
  std::atomic<int> sub_pending_count_{0};
  /// Subscribe requests whose completion callback has not run yet.
  std::atomic<int> do_sub_pending_count_{0};
  /// Unsubscribe requests whose completion callback has not run yet.
  std::atomic<int> do_unsub_pending_count_{0};

  SubscribeCallback<ActorID, ActorTableData> subscribe_{nullptr};
  StatusCallback sub_done_{nullptr};
  StatusCallback unsub_done_{nullptr};
};
/// Subscribes to all actors, registers the generated actors and waits for one
/// notification per actor. A second subscribe-all issued in the meantime is expected
/// to be rejected with an Invalid status.
TEST_F(SubscriptionExecutorTest, SubscribeAllTest) {
  ++do_sub_pending_count_;
  Status status =
      actor_sub_executor_->AsyncSubscribe(ClientID::Nil(), subscribe_, sub_done_);
  // Check the synchronous result before waiting: if the request was rejected, the
  // completion callback is not expected and waiting would only burn the timeout.
  ASSERT_TRUE(status.ok()) << status;
  WaitPendingDone(do_sub_pending_count_, wait_pending_timeout_);

  // One notification is expected for every actor that gets registered.
  sub_pending_count_ = static_cast<int>(id_to_data_.size());
  AsyncRegisterActorToGcs();

  // Subscribing to all actors a second time must be refused.
  status = actor_sub_executor_->AsyncSubscribe(ClientID::Nil(), subscribe_, sub_done_);
  ASSERT_TRUE(status.IsInvalid());

  WaitPendingDone(sub_pending_count_, wait_pending_timeout_);
}
/// Subscribes to each generated actor individually, registers the actors and waits
/// for one notification per actor. Re-subscribing to an already subscribed actor is
/// expected to return an Invalid status.
TEST_F(SubscriptionExecutorTest, SubscribeOneTest) {
  // Initial per-actor subscriptions: each one must be accepted.
  for (const auto &entry : id_to_data_) {
    ++do_sub_pending_count_;
    const Status first_attempt = actor_sub_executor_->AsyncSubscribe(
        ClientID::Nil(), entry.first, subscribe_, sub_done_);
    ASSERT_TRUE(first_attempt.ok());
  }
  WaitPendingDone(do_sub_pending_count_, wait_pending_timeout_);

  // Expect one notification for every actor about to be registered.
  sub_pending_count_ = id_to_data_.size();
  AsyncRegisterActorToGcs();

  // A duplicate subscription to the same actor must be refused.
  for (const auto &entry : id_to_data_) {
    const Status second_attempt = actor_sub_executor_->AsyncSubscribe(
        ClientID::Nil(), entry.first, subscribe_, sub_done_);
    ASSERT_TRUE(second_attempt.IsInvalid());
  }

  WaitPendingDone(sub_pending_count_, wait_pending_timeout_);
}
/// Subscribes to a single actor using a random (non-nil) client id, registers all
/// generated actors and waits for exactly one notification.
TEST_F(SubscriptionExecutorTest, SubscribeOneWithClientIDTest) {
  // Guard the dereference of begin() below.
  ASSERT_FALSE(id_to_data_.empty());
  const auto first_actor = id_to_data_.begin();

  ++do_sub_pending_count_;
  ++sub_pending_count_;
  const Status status = actor_sub_executor_->AsyncSubscribe(
      ClientID::FromRandom(), first_actor->first, subscribe_, sub_done_);
  // Check the synchronous result before waiting: if the request was rejected, the
  // completion callback is not expected and waiting would only burn the timeout.
  ASSERT_TRUE(status.ok()) << status;
  WaitPendingDone(do_sub_pending_count_, wait_pending_timeout_);

  AsyncRegisterActorToGcs();
  WaitPendingDone(sub_pending_count_, wait_pending_timeout_);
}
/// Subscribes to all actors first and then checks that additional per-actor
/// subscriptions are refused, while notifications for registered actors still arrive.
TEST_F(SubscriptionExecutorTest, SubscribeAllAndSubscribeOneTest) {
  // Subscribe-all must be accepted.
  ++do_sub_pending_count_;
  const Status subscribe_all_status =
      actor_sub_executor_->AsyncSubscribe(ClientID::Nil(), subscribe_, sub_done_);
  ASSERT_TRUE(subscribe_all_status.ok());
  WaitPendingDone(do_sub_pending_count_, wait_pending_timeout_);

  // With a subscribe-all in place, subscribing to a single actor must fail.
  for (const auto &entry : id_to_data_) {
    const Status subscribe_one_status = actor_sub_executor_->AsyncSubscribe(
        ClientID::Nil(), entry.first, subscribe_, sub_done_);
    ASSERT_FALSE(subscribe_one_status.ok());
  }

  // Expect one notification for every actor about to be registered.
  sub_pending_count_ = id_to_data_.size();
  AsyncRegisterActorToGcs();
  WaitPendingDone(sub_pending_count_, wait_pending_timeout_);
}
/// Walks the per-actor subscription through several subscribe/unsubscribe rounds and
/// checks the status returned at every step, finishing with a live subscription that
/// receives one notification per registered actor.
TEST_F(SubscriptionExecutorTest, UnsubscribeTest) {
  // Step 1: unsubscribing before any subscription exists is Invalid.
  for (const auto &entry : id_to_data_) {
    const Status result =
        actor_sub_executor_->AsyncUnsubscribe(ClientID::Nil(), entry.first, unsub_done_);
    ASSERT_TRUE(result.IsInvalid());
  }

  // Step 2: subscribe to every actor.
  for (const auto &entry : id_to_data_) {
    ++do_sub_pending_count_;
    const Status result = actor_sub_executor_->AsyncSubscribe(
        ClientID::Nil(), entry.first, subscribe_, sub_done_);
    ASSERT_TRUE(result.ok());
  }
  WaitPendingDone(do_sub_pending_count_, wait_pending_timeout_);

  // Step 3: unsubscribe from every actor.
  for (const auto &entry : id_to_data_) {
    ++do_unsub_pending_count_;
    const Status result =
        actor_sub_executor_->AsyncUnsubscribe(ClientID::Nil(), entry.first, unsub_done_);
    ASSERT_TRUE(result.ok());
  }
  WaitPendingDone(do_unsub_pending_count_, wait_pending_timeout_);

  // Step 4: a second unsubscribe for the same actor must not succeed.
  for (const auto &entry : id_to_data_) {
    const Status result =
        actor_sub_executor_->AsyncUnsubscribe(ClientID::Nil(), entry.first, unsub_done_);
    ASSERT_FALSE(result.ok());
  }

  // Step 5: subscribing again after an unsubscribe is accepted.
  for (const auto &entry : id_to_data_) {
    ++do_sub_pending_count_;
    const Status result = actor_sub_executor_->AsyncSubscribe(
        ClientID::Nil(), entry.first, subscribe_, sub_done_);
    ASSERT_TRUE(result.ok());
  }
  WaitPendingDone(do_sub_pending_count_, wait_pending_timeout_);

  // Step 6: and so is unsubscribing once more.
  for (const auto &entry : id_to_data_) {
    ++do_unsub_pending_count_;
    const Status result =
        actor_sub_executor_->AsyncUnsubscribe(ClientID::Nil(), entry.first, unsub_done_);
    ASSERT_TRUE(result.ok());
  }
  WaitPendingDone(do_unsub_pending_count_, wait_pending_timeout_);

  // Step 7: final subscription, this time left in place.
  for (const auto &entry : id_to_data_) {
    ++do_sub_pending_count_;
    const Status result = actor_sub_executor_->AsyncSubscribe(
        ClientID::Nil(), entry.first, subscribe_, sub_done_);
    ASSERT_TRUE(result.ok());
  }
  WaitPendingDone(do_sub_pending_count_, wait_pending_timeout_);

  // Step 8: register the actors and wait for one notification per actor.
  sub_pending_count_ = id_to_data_.size();
  AsyncRegisterActorToGcs();
  WaitPendingDone(sub_pending_count_, wait_pending_timeout_);
}
} // namespace gcs
} // namespace ray

View file

@ -162,22 +162,44 @@ Status Log<ID, Data>::Subscribe(const JobID &job_id, const ClientID &client_id,
template <typename ID, typename Data> template <typename ID, typename Data>
Status Log<ID, Data>::RequestNotifications(const JobID &job_id, const ID &id, Status Log<ID, Data>::RequestNotifications(const JobID &job_id, const ID &id,
const ClientID &client_id) { const ClientID &client_id,
const StatusCallback &done) {
RAY_CHECK(subscribe_callback_index_ >= 0) RAY_CHECK(subscribe_callback_index_ >= 0)
<< "Client requested notifications on a key before Subscribe completed"; << "Client requested notifications on a key before Subscribe completed";
RedisCallback callback = nullptr;
if (done != nullptr) {
callback = [done](const CallbackReply &reply) {
const auto status = reply.IsNil()
? Status::OK()
: Status::RedisError("request notifications failed.");
done(status);
};
}
return GetRedisContext(id)->RunAsync("RAY.TABLE_REQUEST_NOTIFICATIONS", id, return GetRedisContext(id)->RunAsync("RAY.TABLE_REQUEST_NOTIFICATIONS", id,
client_id.Data(), client_id.Size(), prefix_, client_id.Data(), client_id.Size(), prefix_,
pubsub_channel_, nullptr); pubsub_channel_, callback);
} }
template <typename ID, typename Data> template <typename ID, typename Data>
Status Log<ID, Data>::CancelNotifications(const JobID &job_id, const ID &id, Status Log<ID, Data>::CancelNotifications(const JobID &job_id, const ID &id,
const ClientID &client_id) { const ClientID &client_id,
const StatusCallback &done) {
RAY_CHECK(subscribe_callback_index_ >= 0) RAY_CHECK(subscribe_callback_index_ >= 0)
<< "Client canceled notifications on a key before Subscribe completed"; << "Client canceled notifications on a key before Subscribe completed";
RedisCallback callback = nullptr;
if (done != nullptr) {
callback = [done](const CallbackReply &reply) {
const auto status = reply.ReadAsStatus();
done(status);
};
}
return GetRedisContext(id)->RunAsync("RAY.TABLE_CANCEL_NOTIFICATIONS", id, return GetRedisContext(id)->RunAsync("RAY.TABLE_CANCEL_NOTIFICATIONS", id,
client_id.Data(), client_id.Size(), prefix_, client_id.Data(), client_id.Size(), prefix_,
pubsub_channel_, nullptr); pubsub_channel_, callback);
} }
template <typename ID, typename Data> template <typename ID, typename Data>
@ -621,7 +643,8 @@ Status ClientTable::Connect(const GcsNodeInfo &local_node_info) {
// Callback to request notifications from the client table once we've // Callback to request notifications from the client table once we've
// successfully subscribed. // successfully subscribed.
auto subscription_callback = [this](RedisGcsClient *c) { auto subscription_callback = [this](RedisGcsClient *c) {
RAY_CHECK_OK(RequestNotifications(JobID::Nil(), client_log_key_, node_id_)); RAY_CHECK_OK(RequestNotifications(JobID::Nil(), client_log_key_, node_id_,
/*done*/ nullptr));
}; };
// Subscribe to the client table. // Subscribe to the client table.
RAY_CHECK_OK( RAY_CHECK_OK(
@ -636,7 +659,8 @@ Status ClientTable::Disconnect(const DisconnectCallback &callback) {
auto add_callback = [this, callback](RedisGcsClient *client, const ClientID &id, auto add_callback = [this, callback](RedisGcsClient *client, const ClientID &id,
const GcsNodeInfo &data) { const GcsNodeInfo &data) {
HandleConnected(client, data); HandleConnected(client, data);
RAY_CHECK_OK(CancelNotifications(JobID::Nil(), client_log_key_, id)); RAY_CHECK_OK(
CancelNotifications(JobID::Nil(), client_log_key_, id, /*done*/ nullptr));
if (callback != nullptr) { if (callback != nullptr) {
callback(); callback();
} }

View file

@ -11,6 +11,7 @@
#include "ray/common/status.h" #include "ray/common/status.h"
#include "ray/util/logging.h" #include "ray/util/logging.h"
#include "ray/gcs/callback.h"
#include "ray/gcs/redis_context.h" #include "ray/gcs/redis_context.h"
#include "ray/protobuf/gcs.pb.h" #include "ray/protobuf/gcs.pb.h"
@ -56,9 +57,11 @@ template <typename ID>
class PubsubInterface { class PubsubInterface {
public: public:
virtual Status RequestNotifications(const JobID &job_id, const ID &id, virtual Status RequestNotifications(const JobID &job_id, const ID &id,
const ClientID &client_id) = 0; const ClientID &client_id,
const StatusCallback &done) = 0;
virtual Status CancelNotifications(const JobID &job_id, const ID &id, virtual Status CancelNotifications(const JobID &job_id, const ID &id,
const ClientID &client_id) = 0; const ClientID &client_id,
const StatusCallback &done) = 0;
virtual ~PubsubInterface(){}; virtual ~PubsubInterface(){};
}; };
@ -182,20 +185,22 @@ class Log : public LogInterface<ID, Data>, virtual public PubsubInterface<ID> {
/// \param job_id The ID of the job. /// \param job_id The ID of the job.
/// \param id The ID of the key to request notifications for. /// \param id The ID of the key to request notifications for.
/// \param client_id The client who is requesting notifications. Before /// \param client_id The client who is requesting notifications. Before
/// notifications can be requested, a call to `Subscribe` to this /// notifications can be requested, a call to `Subscribe` to this
/// table with the same `client_id` must complete successfully. /// table with the same `client_id` must complete successfully.
/// \param done Callback that is called when the notification request is complete.
/// \return Status /// \return Status
Status RequestNotifications(const JobID &job_id, const ID &id, Status RequestNotifications(const JobID &job_id, const ID &id,
const ClientID &client_id); const ClientID &client_id, const StatusCallback &done);
/// Cancel notifications about a key in this table. /// Cancel notifications about a key in this table.
/// ///
/// \param job_id The ID of the job. /// \param job_id The ID of the job.
/// \param id The ID of the key to request notifications for. /// \param id The ID of the key to request notifications for.
/// \param client_id The client who originally requested notifications. /// \param client_id The client who originally requested notifications.
/// \param done Callback that is called when cancel notifications is complete.
/// \return Status /// \return Status
Status CancelNotifications(const JobID &job_id, const ID &id, Status CancelNotifications(const JobID &job_id, const ID &id, const ClientID &client_id,
const ClientID &client_id); const StatusCallback &done);
/// Delete an entire key from redis. /// Delete an entire key from redis.
/// ///

View file

@ -159,7 +159,8 @@ ray::Status ObjectDirectory::SubscribeObjectLocations(const UniqueID &callback_i
if (it == listeners_.end()) { if (it == listeners_.end()) {
it = listeners_.emplace(object_id, LocationListenerState()).first; it = listeners_.emplace(object_id, LocationListenerState()).first;
status = gcs_client_->object_table().RequestNotifications( status = gcs_client_->object_table().RequestNotifications(
JobID::Nil(), object_id, gcs_client_->client_table().GetLocalClientId()); JobID::Nil(), object_id, gcs_client_->client_table().GetLocalClientId(),
/*done*/ nullptr);
} }
auto &listener_state = it->second; auto &listener_state = it->second;
// TODO(hme): Make this fatal after implementing Pull suppression. // TODO(hme): Make this fatal after implementing Pull suppression.
@ -187,7 +188,8 @@ ray::Status ObjectDirectory::UnsubscribeObjectLocations(const UniqueID &callback
entry->second.callbacks.erase(callback_id); entry->second.callbacks.erase(callback_id);
if (entry->second.callbacks.empty()) { if (entry->second.callbacks.empty()) {
status = gcs_client_->object_table().CancelNotifications( status = gcs_client_->object_table().CancelNotifications(
JobID::Nil(), object_id, gcs_client_->client_table().GetLocalClientId()); JobID::Nil(), object_id, gcs_client_->client_table().GetLocalClientId(),
/*done*/ nullptr);
listeners_.erase(entry); listeners_.erase(entry);
} }
return status; return status;

View file

@ -291,7 +291,8 @@ bool LineageCache::SubscribeTask(const TaskID &task_id) {
if (unsubscribed) { if (unsubscribed) {
// Request notifications for the task if we haven't already requested // Request notifications for the task if we haven't already requested
// notifications for it. // notifications for it.
RAY_CHECK_OK(task_pubsub_.RequestNotifications(JobID::Nil(), task_id, client_id_)); RAY_CHECK_OK(task_pubsub_.RequestNotifications(JobID::Nil(), task_id, client_id_,
/*done*/ nullptr));
} }
// Return whether we were previously unsubscribed to this task and are now // Return whether we were previously unsubscribed to this task and are now
// subscribed. // subscribed.
@ -304,7 +305,8 @@ bool LineageCache::UnsubscribeTask(const TaskID &task_id) {
if (subscribed) { if (subscribed) {
// Cancel notifications for the task if we previously requested // Cancel notifications for the task if we previously requested
// notifications for it. // notifications for it.
RAY_CHECK_OK(task_pubsub_.CancelNotifications(JobID::Nil(), task_id, client_id_)); RAY_CHECK_OK(task_pubsub_.CancelNotifications(JobID::Nil(), task_id, client_id_,
/*done*/ nullptr));
subscribed_tasks_.erase(it); subscribed_tasks_.erase(it);
} }
// Return whether we were previously subscribed to this task and are now // Return whether we were previously subscribed to this task and are now

View file

@ -7,8 +7,12 @@
#include "ray/common/task/task_execution_spec.h" #include "ray/common/task/task_execution_spec.h"
#include "ray/common/task/task_spec.h" #include "ray/common/task/task_spec.h"
#include "ray/common/task/task_util.h" #include "ray/common/task/task_util.h"
#include "ray/gcs/callback.h"
#include "ray/raylet/format/node_manager_generated.h" #include "ray/raylet/format/node_manager_generated.h"
#include "ray/raylet/lineage_cache.h" #include "ray/raylet/lineage_cache.h"
#include "ray/util/test_util.h" #include "ray/util/test_util.h"
namespace ray { namespace ray {
@ -67,7 +71,8 @@ class MockGcs : public gcs::TableInterface<TaskID, TaskTableData>,
} }
Status RequestNotifications(const JobID &job_id, const TaskID &task_id, Status RequestNotifications(const JobID &job_id, const TaskID &task_id,
const ClientID &client_id) { const ClientID &client_id,
const gcs::StatusCallback &done) {
subscribed_tasks_.insert(task_id); subscribed_tasks_.insert(task_id);
if (task_table_.count(task_id) == 1) { if (task_table_.count(task_id) == 1) {
callbacks_.push_back({notification_callback_, task_id}); callbacks_.push_back({notification_callback_, task_id});
@ -77,7 +82,7 @@ class MockGcs : public gcs::TableInterface<TaskID, TaskTableData>,
} }
Status CancelNotifications(const JobID &job_id, const TaskID &task_id, Status CancelNotifications(const JobID &job_id, const TaskID &task_id,
const ClientID &client_id) { const ClientID &client_id, const gcs::StatusCallback &done) {
subscribed_tasks_.erase(task_id); subscribed_tasks_.erase(task_id);
return ray::Status::OK(); return ray::Status::OK();
} }

View file

@ -53,7 +53,8 @@ void ReconstructionPolicy::SetTaskTimeout(
// task is still required after this initial period, then we now // task is still required after this initial period, then we now
// subscribe to task lease notifications. // subscribe to task lease notifications.
RAY_CHECK_OK(task_lease_pubsub_.RequestNotifications(JobID::Nil(), task_id, RAY_CHECK_OK(task_lease_pubsub_.RequestNotifications(JobID::Nil(), task_id,
client_id_)); client_id_,
/*done*/ nullptr));
it->second.subscribed = true; it->second.subscribed = true;
} }
} else { } else {
@ -200,8 +201,9 @@ void ReconstructionPolicy::Cancel(const ObjectID &object_id) {
if (it->second.created_objects.empty()) { if (it->second.created_objects.empty()) {
// Cancel notifications for the task lease if we were subscribed to them. // Cancel notifications for the task lease if we were subscribed to them.
if (it->second.subscribed) { if (it->second.subscribed) {
RAY_CHECK_OK( RAY_CHECK_OK(task_lease_pubsub_.CancelNotifications(JobID::Nil(), task_id,
task_lease_pubsub_.CancelNotifications(JobID::Nil(), task_id, client_id_)); client_id_,
/*done*/ nullptr));
} }
listening_tasks_.erase(it); listening_tasks_.erase(it);
} }

View file

@ -5,6 +5,8 @@
#include <boost/asio.hpp> #include <boost/asio.hpp>
#include "ray/gcs/callback.h"
#include "ray/raylet/format/node_manager_generated.h" #include "ray/raylet/format/node_manager_generated.h"
#include "ray/raylet/reconstruction_policy.h" #include "ray/raylet/reconstruction_policy.h"
@ -102,7 +104,8 @@ class MockGcs : public gcs::PubsubInterface<TaskID>,
} }
Status RequestNotifications(const JobID &job_id, const TaskID &task_id, Status RequestNotifications(const JobID &job_id, const TaskID &task_id,
const ClientID &client_id) { const ClientID &client_id,
const gcs::StatusCallback &done) {
subscribed_tasks_.insert(task_id); subscribed_tasks_.insert(task_id);
auto entry = task_lease_table_.find(task_id); auto entry = task_lease_table_.find(task_id);
if (entry == task_lease_table_.end()) { if (entry == task_lease_table_.end()) {
@ -114,7 +117,7 @@ class MockGcs : public gcs::PubsubInterface<TaskID>,
} }
Status CancelNotifications(const JobID &job_id, const TaskID &task_id, Status CancelNotifications(const JobID &job_id, const TaskID &task_id,
const ClientID &client_id) { const ClientID &client_id, const gcs::StatusCallback &done) {
subscribed_tasks_.erase(task_id); subscribed_tasks_.erase(task_id);
return ray::Status::OK(); return ray::Status::OK();
} }

View file

@ -6,7 +6,7 @@
set -e set -e
set -x set -x
bazel build "//:redis_gcs_client_test" "//:actor_state_accessor_test" "//:asio_test" "//:libray_redis_module.so" bazel build "//:redis_gcs_client_test" "//:actor_state_accessor_test" "//:subscription_executor_test" "//:asio_test" "//:libray_redis_module.so"
# Start Redis. # Start Redis.
if [[ "${RAY_USE_NEW_GCS}" = "on" ]]; then if [[ "${RAY_USE_NEW_GCS}" = "on" ]]; then
@ -25,6 +25,7 @@ sleep 1s
./bazel-bin/redis_gcs_client_test ./bazel-bin/redis_gcs_client_test
./bazel-bin/actor_state_accessor_test ./bazel-bin/actor_state_accessor_test
./bazel-bin/subscription_executor_test
./bazel-bin/asio_test ./bazel-bin/asio_test
./bazel-genfiles/redis-cli -p 6379 shutdown ./bazel-genfiles/redis-cli -p 6379 shutdown

View file

@ -19,8 +19,8 @@ bool WaitForCondition(std::function<bool()> condition, int timeout_ms) {
return true; return true;
} }
// sleep 100ms. // sleep 10ms.
const int wait_interval_ms = 100; const int wait_interval_ms = 10;
usleep(wait_interval_ms * 1000); usleep(wait_interval_ms * 1000);
wait_time += wait_interval_ms; wait_time += wait_interval_ms;
if (wait_time > timeout_ms) { if (wait_time > timeout_ms) {