Implement GetByJobId in gcs table storage (#8727)

This commit is contained in:
Tao Wang 2020-06-04 20:51:43 +08:00 committed by GitHub
parent 84a8f2ccb5
commit 41072fbcc8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 177 additions and 73 deletions

View file

@ -85,9 +85,17 @@ Status GcsTableWithJobId<Key, Data>::Put(const Key &key, const Data &value,
template <typename Key, typename Data>
Status GcsTableWithJobId<Key, Data>::GetByJobId(const JobID &job_id,
const MapCallback<Key, Data> &callback) {
// TODO(ffbin): We will add this function after redis store client support
// AsyncGetByIndex interface.
return Status::NotImplemented("GetByJobId not implemented");
auto on_done = [callback](const std::unordered_map<std::string, std::string> &result) {
std::unordered_map<Key, Data> values;
for (auto &item : result) {
Data data;
data.ParseFromString(item.second);
values[Key::FromBinary(item.first)] = std::move(data);
}
callback(values);
};
return this->store_client_->AsyncGetByIndex(this->table_name_, job_id.Binary(),
on_done);
}
template <typename Key, typename Data>

View file

@ -67,6 +67,9 @@ class GcsTableStorageTestBase : public ::testing::Test {
std::vector<rpc::ActorTableData> values;
ASSERT_EQ(Get(table, actor_id, values), 1);
// Get by job id.
ASSERT_EQ(GetByJobId(table, job_id, actor_id, values), 1);
// Delete.
Delete(table, actor_id);
ASSERT_EQ(Get(table, actor_id, values), 0);
@ -96,6 +99,24 @@ class GcsTableStorageTestBase : public ::testing::Test {
return values.size();
}
template <typename TABLE, typename KEY, typename VALUE>
int GetByJobId(TABLE &table, const JobID &job_id, const KEY &key,
std::vector<VALUE> &values) {
auto on_done = [this, &values](const std::unordered_map<KEY, VALUE> &result) {
--pending_count_;
values.clear();
if (!result.empty()) {
for (auto &item : result) {
values.push_back(item.second);
}
}
};
++pending_count_;
RAY_CHECK_OK(table.GetByJobId(job_id, on_done));
WaitPendingDone();
return values.size();
}
template <typename TABLE, typename KEY>
void Delete(TABLE &table, const KEY &key) {
auto on_done = [this](Status status) {

View file

@ -89,6 +89,26 @@ Status InMemoryStoreClient::AsyncBatchDelete(const std::string &table_name,
return Status::OK();
}
Status InMemoryStoreClient::AsyncGetByIndex(
const std::string &table_name, const std::string &index_key,
const MapCallback<std::string, std::string> &callback) {
auto table = GetOrCreateTable(table_name);
absl::MutexLock lock(&(table->mutex_));
auto iter = table->index_keys_.find(index_key);
std::unordered_map<std::string, std::string> result;
if (iter != table->index_keys_.end()) {
for (auto &key : iter->second) {
auto kv_iter = table->records_.find(key);
if (kv_iter != table->records_.end()) {
result[kv_iter->first] = kv_iter->second;
}
}
}
main_io_service_.post([result, callback]() { callback(result); });
return Status::OK();
}
Status InMemoryStoreClient::AsyncDeleteByIndex(const std::string &table_name,
const std::string &index_key,
const StatusCallback &callback) {

View file

@ -42,6 +42,9 @@ class InMemoryStoreClient : public StoreClient {
Status AsyncGet(const std::string &table_name, const std::string &key,
const OptionalItemCallback<std::string> &callback) override;
Status AsyncGetByIndex(const std::string &table_name, const std::string &index_key,
const MapCallback<std::string, std::string> &callback) override;
Status AsyncGetAll(const std::string &table_name,
const MapCallback<std::string, std::string> &callback) override;

View file

@ -89,12 +89,12 @@ Status RedisStoreClient::AsyncGetAll(
const MapCallback<std::string, std::string> &callback) {
RAY_CHECK(callback);
std::string match_pattern = GenRedisMatchPattern(table_name);
auto scanner = std::make_shared<RedisScanner>(redis_client_, table_name, match_pattern);
auto scanner = std::make_shared<RedisScanner>(redis_client_, table_name);
auto on_done = [callback,
scanner](const std::unordered_map<std::string, std::string> &result) {
callback(result);
};
return scanner->ScanKeysAndValues(on_done);
return scanner->ScanKeysAndValues(match_pattern, on_done);
}
Status RedisStoreClient::AsyncDelete(const std::string &table_name,
@ -125,11 +125,35 @@ Status RedisStoreClient::AsyncBatchDelete(const std::string &table_name,
return DeleteByKeys(redis_keys, callback);
}
Status RedisStoreClient::AsyncGetByIndex(
const std::string &table_name, const std::string &index_key,
const MapCallback<std::string, std::string> &callback) {
RAY_CHECK(callback);
std::string match_pattern = GenRedisMatchPattern(table_name, index_key);
auto scanner = std::make_shared<RedisScanner>(redis_client_, table_name);
auto on_done = [this, callback, scanner, table_name, index_key](
const Status &status, const std::vector<std::string> &result) {
if (!result.empty()) {
std::vector<std::string> keys;
keys.reserve(result.size());
for (auto &item : result) {
keys.push_back(
GenRedisKey(table_name, GetKeyFromRedisKey(item, table_name, index_key)));
}
RAY_CHECK_OK(MGetValues(redis_client_, table_name, keys, callback));
} else {
callback(std::unordered_map<std::string, std::string>());
}
};
return scanner->ScanKeys(match_pattern, on_done);
}
Status RedisStoreClient::AsyncDeleteByIndex(const std::string &table_name,
const std::string &index_key,
const StatusCallback &callback) {
std::string match_pattern = GenRedisMatchPattern(table_name, index_key);
auto scanner = std::make_shared<RedisScanner>(redis_client_, table_name, match_pattern);
auto scanner = std::make_shared<RedisScanner>(redis_client_, table_name);
auto on_done = [this, table_name, index_key, callback, scanner](
const Status &status, const std::vector<std::string> &result) {
if (!result.empty()) {
@ -149,7 +173,7 @@ Status RedisStoreClient::AsyncDeleteByIndex(const std::string &table_name,
}
};
return scanner->ScanKeys(on_done);
return scanner->ScanKeys(match_pattern, on_done);
}
Status RedisStoreClient::DoPut(const std::string &key, const std::string &data,
@ -248,42 +272,74 @@ std::string RedisStoreClient::GetKeyFromRedisKey(const std::string &redis_key,
return redis_key.substr(pos, redis_key.size() - pos);
}
Status RedisStoreClient::MGetValues(
std::shared_ptr<RedisClient> redis_client, std::string table_name,
const std::vector<std::string> &keys,
const ItemCallback<std::unordered_map<std::string, std::string>> &callback) {
// The `MGET` command for each shard.
auto mget_commands_by_shards = GenCommandsByShards(redis_client, "MGET", keys);
auto finished_count = std::make_shared<int>(0);
int size = mget_commands_by_shards.size();
for (auto &item : mget_commands_by_shards) {
auto mget_keys = std::move(item.second);
auto mget_callback = [table_name, finished_count, size, mget_keys,
callback](const std::shared_ptr<CallbackReply> &reply) {
std::unordered_map<std::string, std::string> key_value_map;
if (!reply->IsNil()) {
auto value = reply->ReadAsStringArray();
// The 0 th element of mget_keys is "MGET", so we start from the 1 th element.
for (int index = 0; index < (int)value.size(); ++index) {
key_value_map[GetKeyFromRedisKey(mget_keys[index + 1], table_name)] =
value[index];
}
}
++(*finished_count);
if (*finished_count == size) {
callback(key_value_map);
}
};
RAY_CHECK_OK(item.first->RunArgvAsync(mget_keys, mget_callback));
}
return Status::OK();
}
RedisStoreClient::RedisScanner::RedisScanner(std::shared_ptr<RedisClient> redis_client,
std::string table_name,
std::string match_pattern)
: table_name_(std::move(table_name)),
match_pattern_(std::move(match_pattern)),
redis_client_(std::move(redis_client)) {
std::string table_name)
: table_name_(std::move(table_name)), redis_client_(std::move(redis_client)) {
for (size_t index = 0; index < redis_client_->GetShardContexts().size(); ++index) {
shard_to_cursor_[index] = 0;
}
}
Status RedisStoreClient::RedisScanner::ScanKeysAndValues(
std::string match_pattern,
const ItemCallback<std::unordered_map<std::string, std::string>> &callback) {
auto on_done = [this, callback](const Status &status,
const std::vector<std::string> &result) {
if (result.empty()) {
callback(std::unordered_map<std::string, std::string>());
} else {
MGetValues(result, callback);
RAY_CHECK_OK(MGetValues(redis_client_, table_name_, result, callback));
}
};
return ScanKeys(on_done);
return ScanKeys(match_pattern, on_done);
}
Status RedisStoreClient::RedisScanner::ScanKeys(
const MultiItemCallback<std::string> &callback) {
std::string match_pattern, const MultiItemCallback<std::string> &callback) {
auto on_done = [this, callback](const Status &status) {
std::vector<std::string> result;
result.insert(result.begin(), keys_.begin(), keys_.end());
callback(status, result);
};
Scan(on_done);
Scan(match_pattern, on_done);
return Status::OK();
}
void RedisStoreClient::RedisScanner::Scan(const StatusCallback &callback) {
void RedisStoreClient::RedisScanner::Scan(std::string match_pattern,
const StatusCallback &callback) {
if (shard_to_cursor_.empty()) {
callback(Status::OK());
return;
@ -296,18 +352,16 @@ void RedisStoreClient::RedisScanner::Scan(const StatusCallback &callback) {
size_t shard_index = item.first;
size_t cursor = item.second;
auto scan_callback = [this, shard_index,
auto scan_callback = [this, match_pattern, shard_index,
callback](const std::shared_ptr<CallbackReply> &reply) {
OnScanCallback(shard_index, reply, callback);
OnScanCallback(match_pattern, shard_index, reply, callback);
};
// Scan by prefix from Redis.
std::vector<std::string> args = {"SCAN", std::to_string(cursor),
"MATCH", match_pattern_,
"MATCH", match_pattern,
"COUNT", std::to_string(batch_count)};
auto shard_context = redis_client_->GetShardContexts()[shard_index];
Status status = shard_context->RunArgvAsync(args, scan_callback);
if (!status.ok()) {
RAY_LOG(FATAL) << "Scan failed, status " << status.ToString();
}
@ -315,12 +369,11 @@ void RedisStoreClient::RedisScanner::Scan(const StatusCallback &callback) {
}
void RedisStoreClient::RedisScanner::OnScanCallback(
size_t shard_index, const std::shared_ptr<CallbackReply> &reply,
const StatusCallback &callback) {
std::string match_pattern, size_t shard_index,
const std::shared_ptr<CallbackReply> &reply, const StatusCallback &callback) {
RAY_CHECK(reply);
std::vector<std::string> scan_result;
size_t cursor = reply->ReadAsScanArray(&scan_result);
// Update shard cursors and keys_.
{
absl::MutexLock lock(&mutex_);
@ -340,40 +393,7 @@ void RedisStoreClient::RedisScanner::OnScanCallback(
// If pending_request_count_ is equal to 0, it means that the scan of this batch is
// completed and the next batch is started if any.
if (--pending_request_count_ == 0) {
Scan(callback);
}
}
void RedisStoreClient::RedisScanner::MGetValues(
const std::vector<std::string> &keys,
const ItemCallback<std::unordered_map<std::string, std::string>> &callback) {
// The `MGET` command for each shard.
auto mget_commands_by_shards = GenCommandsByShards(redis_client_, "MGET", keys);
auto finished_count = std::make_shared<int>(0);
int size = mget_commands_by_shards.size();
for (auto &item : mget_commands_by_shards) {
auto mget_keys = std::move(item.second);
auto mget_callback = [this, finished_count, size, mget_keys,
callback](const std::shared_ptr<CallbackReply> &reply) {
if (!reply->IsNil()) {
auto value = reply->ReadAsStringArray();
{
absl::MutexLock lock(&mutex_);
// The 0 th element of mget_keys is "MGET", so we start from the 1 th element.
for (int index = 0; index < (int)value.size(); ++index) {
key_value_map_[GetKeyFromRedisKey(mget_keys[index + 1], table_name_)] =
value[index];
}
}
}
++(*finished_count);
if (*finished_count == size) {
callback(key_value_map_);
}
};
RAY_CHECK_OK(item.first->RunArgvAsync(mget_keys, mget_callback));
Scan(match_pattern, callback);
}
}

View file

@ -40,6 +40,9 @@ class RedisStoreClient : public StoreClient {
Status AsyncGet(const std::string &table_name, const std::string &key,
const OptionalItemCallback<std::string> &callback) override;
Status AsyncGetByIndex(const std::string &table_name, const std::string &index_key,
const MapCallback<std::string, std::string> &callback) override;
Status AsyncGetAll(const std::string &table_name,
const MapCallback<std::string, std::string> &callback) override;
@ -62,25 +65,22 @@ class RedisStoreClient : public StoreClient {
class RedisScanner {
public:
explicit RedisScanner(std::shared_ptr<RedisClient> redis_client,
std::string table_name, std::string match_pattern);
std::string table_name);
Status ScanKeysAndValues(const MapCallback<std::string, std::string> &callback);
Status ScanKeys(const MultiItemCallback<std::string> &callback);
private:
void Scan(const StatusCallback &callback);
void OnScanCallback(size_t shard_index, const std::shared_ptr<CallbackReply> &reply,
const StatusCallback &callback);
void MGetValues(const std::vector<std::string> &keys,
Status ScanKeysAndValues(std::string match_pattern,
const MapCallback<std::string, std::string> &callback);
std::string table_name_;
Status ScanKeys(std::string match_pattern,
const MultiItemCallback<std::string> &callback);
/// The scan match pattern.
std::string match_pattern_;
private:
void Scan(std::string match_pattern, const StatusCallback &callback);
void OnScanCallback(std::string match_pattern, size_t shard_index,
const std::shared_ptr<CallbackReply> &reply,
const StatusCallback &callback);
std::string table_name_;
/// Mutex to protect the shard_to_cursor_ field and the keys_ field and the
/// key_value_map_ field.
@ -89,9 +89,6 @@ class RedisStoreClient : public StoreClient {
/// All keys that scanned from redis.
absl::flat_hash_set<std::string> keys_;
/// Key-Value pairs that scanned from redis.
std::unordered_map<std::string, std::string> key_value_map_;
/// The scan cursor for each shard.
std::unordered_map<size_t, size_t> shard_to_cursor_;
@ -132,6 +129,10 @@ class RedisStoreClient : public StoreClient {
const std::string &table_name,
const std::string &index_key);
static Status MGetValues(std::shared_ptr<RedisClient> redis_client,
std::string table_name, const std::vector<std::string> &keys,
const MapCallback<std::string, std::string> &callback);
std::shared_ptr<RedisClient> redis_client_;
};

View file

@ -65,6 +65,16 @@ class StoreClient {
virtual Status AsyncGet(const std::string &table_name, const std::string &key,
const OptionalItemCallback<std::string> &callback) = 0;
/// Get data by index from the given table asynchronously.
///
/// \param table_name The name of the table to be read.
/// \param index_key The secondary key that will be used to get the indexed data.
/// \param callback Callback that will be called after read finishes.
/// \return Status
virtual Status AsyncGetByIndex(
const std::string &table_name, const std::string &index_key,
const MapCallback<std::string, std::string> &callback) = 0;
/// Get all data from the given table asynchronously.
///
/// \param table_name The name of the table to be read.

View file

@ -132,6 +132,24 @@ class StoreClientTestBase : public ::testing::Test {
WaitPendingDone();
}
void GetByIndex() {
auto get_calllback =
[this](const std::unordered_map<std::string, std::string> &result) {
if (!result.empty()) {
auto key = ActorID::FromBinary(result.begin()->first);
auto it = key_to_index_.find(key);
RAY_CHECK(it != key_to_index_.end());
RAY_CHECK(index_to_keys_[it->second].size() == result.size());
}
pending_count_ -= result.size();
};
auto iter = index_to_keys_.begin();
pending_count_ += iter->second.size();
RAY_CHECK_OK(
store_client_->AsyncGetByIndex(table_name_, iter->first.Hex(), get_calllback));
WaitPendingDone();
}
void DeleteByIndex() {
auto delete_calllback = [this](const Status &status) {
RAY_CHECK_OK(status);
@ -198,6 +216,9 @@ class StoreClientTestBase : public ::testing::Test {
// AsyncPut with index
PutWithIndex();
// AsyncGet with index
GetByIndex();
// AsyncDelete by index
DeleteByIndex();