mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
Implement GetByJobId in gcs table storage (#8727)
This commit is contained in:
parent
84a8f2ccb5
commit
41072fbcc8
8 changed files with 177 additions and 73 deletions
|
@ -85,9 +85,17 @@ Status GcsTableWithJobId<Key, Data>::Put(const Key &key, const Data &value,
|
|||
template <typename Key, typename Data>
|
||||
Status GcsTableWithJobId<Key, Data>::GetByJobId(const JobID &job_id,
|
||||
const MapCallback<Key, Data> &callback) {
|
||||
// TODO(ffbin): We will add this function after redis store client support
|
||||
// AsyncGetByIndex interface.
|
||||
return Status::NotImplemented("GetByJobId not implemented");
|
||||
auto on_done = [callback](const std::unordered_map<std::string, std::string> &result) {
|
||||
std::unordered_map<Key, Data> values;
|
||||
for (auto &item : result) {
|
||||
Data data;
|
||||
data.ParseFromString(item.second);
|
||||
values[Key::FromBinary(item.first)] = std::move(data);
|
||||
}
|
||||
callback(values);
|
||||
};
|
||||
return this->store_client_->AsyncGetByIndex(this->table_name_, job_id.Binary(),
|
||||
on_done);
|
||||
}
|
||||
|
||||
template <typename Key, typename Data>
|
||||
|
|
|
@ -67,6 +67,9 @@ class GcsTableStorageTestBase : public ::testing::Test {
|
|||
std::vector<rpc::ActorTableData> values;
|
||||
ASSERT_EQ(Get(table, actor_id, values), 1);
|
||||
|
||||
// Get by job id.
|
||||
ASSERT_EQ(GetByJobId(table, job_id, actor_id, values), 1);
|
||||
|
||||
// Delete.
|
||||
Delete(table, actor_id);
|
||||
ASSERT_EQ(Get(table, actor_id, values), 0);
|
||||
|
@ -96,6 +99,24 @@ class GcsTableStorageTestBase : public ::testing::Test {
|
|||
return values.size();
|
||||
}
|
||||
|
||||
template <typename TABLE, typename KEY, typename VALUE>
|
||||
int GetByJobId(TABLE &table, const JobID &job_id, const KEY &key,
|
||||
std::vector<VALUE> &values) {
|
||||
auto on_done = [this, &values](const std::unordered_map<KEY, VALUE> &result) {
|
||||
--pending_count_;
|
||||
values.clear();
|
||||
if (!result.empty()) {
|
||||
for (auto &item : result) {
|
||||
values.push_back(item.second);
|
||||
}
|
||||
}
|
||||
};
|
||||
++pending_count_;
|
||||
RAY_CHECK_OK(table.GetByJobId(job_id, on_done));
|
||||
WaitPendingDone();
|
||||
return values.size();
|
||||
}
|
||||
|
||||
template <typename TABLE, typename KEY>
|
||||
void Delete(TABLE &table, const KEY &key) {
|
||||
auto on_done = [this](Status status) {
|
||||
|
|
|
@ -89,6 +89,26 @@ Status InMemoryStoreClient::AsyncBatchDelete(const std::string &table_name,
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status InMemoryStoreClient::AsyncGetByIndex(
|
||||
const std::string &table_name, const std::string &index_key,
|
||||
const MapCallback<std::string, std::string> &callback) {
|
||||
auto table = GetOrCreateTable(table_name);
|
||||
absl::MutexLock lock(&(table->mutex_));
|
||||
auto iter = table->index_keys_.find(index_key);
|
||||
std::unordered_map<std::string, std::string> result;
|
||||
if (iter != table->index_keys_.end()) {
|
||||
for (auto &key : iter->second) {
|
||||
auto kv_iter = table->records_.find(key);
|
||||
if (kv_iter != table->records_.end()) {
|
||||
result[kv_iter->first] = kv_iter->second;
|
||||
}
|
||||
}
|
||||
}
|
||||
main_io_service_.post([result, callback]() { callback(result); });
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status InMemoryStoreClient::AsyncDeleteByIndex(const std::string &table_name,
|
||||
const std::string &index_key,
|
||||
const StatusCallback &callback) {
|
||||
|
|
|
@ -42,6 +42,9 @@ class InMemoryStoreClient : public StoreClient {
|
|||
Status AsyncGet(const std::string &table_name, const std::string &key,
|
||||
const OptionalItemCallback<std::string> &callback) override;
|
||||
|
||||
Status AsyncGetByIndex(const std::string &table_name, const std::string &index_key,
|
||||
const MapCallback<std::string, std::string> &callback) override;
|
||||
|
||||
Status AsyncGetAll(const std::string &table_name,
|
||||
const MapCallback<std::string, std::string> &callback) override;
|
||||
|
||||
|
|
|
@ -89,12 +89,12 @@ Status RedisStoreClient::AsyncGetAll(
|
|||
const MapCallback<std::string, std::string> &callback) {
|
||||
RAY_CHECK(callback);
|
||||
std::string match_pattern = GenRedisMatchPattern(table_name);
|
||||
auto scanner = std::make_shared<RedisScanner>(redis_client_, table_name, match_pattern);
|
||||
auto scanner = std::make_shared<RedisScanner>(redis_client_, table_name);
|
||||
auto on_done = [callback,
|
||||
scanner](const std::unordered_map<std::string, std::string> &result) {
|
||||
callback(result);
|
||||
};
|
||||
return scanner->ScanKeysAndValues(on_done);
|
||||
return scanner->ScanKeysAndValues(match_pattern, on_done);
|
||||
}
|
||||
|
||||
Status RedisStoreClient::AsyncDelete(const std::string &table_name,
|
||||
|
@ -125,11 +125,35 @@ Status RedisStoreClient::AsyncBatchDelete(const std::string &table_name,
|
|||
return DeleteByKeys(redis_keys, callback);
|
||||
}
|
||||
|
||||
Status RedisStoreClient::AsyncGetByIndex(
|
||||
const std::string &table_name, const std::string &index_key,
|
||||
const MapCallback<std::string, std::string> &callback) {
|
||||
RAY_CHECK(callback);
|
||||
std::string match_pattern = GenRedisMatchPattern(table_name, index_key);
|
||||
auto scanner = std::make_shared<RedisScanner>(redis_client_, table_name);
|
||||
auto on_done = [this, callback, scanner, table_name, index_key](
|
||||
const Status &status, const std::vector<std::string> &result) {
|
||||
if (!result.empty()) {
|
||||
std::vector<std::string> keys;
|
||||
keys.reserve(result.size());
|
||||
for (auto &item : result) {
|
||||
keys.push_back(
|
||||
GenRedisKey(table_name, GetKeyFromRedisKey(item, table_name, index_key)));
|
||||
}
|
||||
|
||||
RAY_CHECK_OK(MGetValues(redis_client_, table_name, keys, callback));
|
||||
} else {
|
||||
callback(std::unordered_map<std::string, std::string>());
|
||||
}
|
||||
};
|
||||
return scanner->ScanKeys(match_pattern, on_done);
|
||||
}
|
||||
|
||||
Status RedisStoreClient::AsyncDeleteByIndex(const std::string &table_name,
|
||||
const std::string &index_key,
|
||||
const StatusCallback &callback) {
|
||||
std::string match_pattern = GenRedisMatchPattern(table_name, index_key);
|
||||
auto scanner = std::make_shared<RedisScanner>(redis_client_, table_name, match_pattern);
|
||||
auto scanner = std::make_shared<RedisScanner>(redis_client_, table_name);
|
||||
auto on_done = [this, table_name, index_key, callback, scanner](
|
||||
const Status &status, const std::vector<std::string> &result) {
|
||||
if (!result.empty()) {
|
||||
|
@ -149,7 +173,7 @@ Status RedisStoreClient::AsyncDeleteByIndex(const std::string &table_name,
|
|||
}
|
||||
};
|
||||
|
||||
return scanner->ScanKeys(on_done);
|
||||
return scanner->ScanKeys(match_pattern, on_done);
|
||||
}
|
||||
|
||||
Status RedisStoreClient::DoPut(const std::string &key, const std::string &data,
|
||||
|
@ -248,42 +272,74 @@ std::string RedisStoreClient::GetKeyFromRedisKey(const std::string &redis_key,
|
|||
return redis_key.substr(pos, redis_key.size() - pos);
|
||||
}
|
||||
|
||||
Status RedisStoreClient::MGetValues(
|
||||
std::shared_ptr<RedisClient> redis_client, std::string table_name,
|
||||
const std::vector<std::string> &keys,
|
||||
const ItemCallback<std::unordered_map<std::string, std::string>> &callback) {
|
||||
// The `MGET` command for each shard.
|
||||
auto mget_commands_by_shards = GenCommandsByShards(redis_client, "MGET", keys);
|
||||
|
||||
auto finished_count = std::make_shared<int>(0);
|
||||
int size = mget_commands_by_shards.size();
|
||||
for (auto &item : mget_commands_by_shards) {
|
||||
auto mget_keys = std::move(item.second);
|
||||
auto mget_callback = [table_name, finished_count, size, mget_keys,
|
||||
callback](const std::shared_ptr<CallbackReply> &reply) {
|
||||
std::unordered_map<std::string, std::string> key_value_map;
|
||||
if (!reply->IsNil()) {
|
||||
auto value = reply->ReadAsStringArray();
|
||||
// The 0 th element of mget_keys is "MGET", so we start from the 1 th element.
|
||||
for (int index = 0; index < (int)value.size(); ++index) {
|
||||
key_value_map[GetKeyFromRedisKey(mget_keys[index + 1], table_name)] =
|
||||
value[index];
|
||||
}
|
||||
}
|
||||
|
||||
++(*finished_count);
|
||||
if (*finished_count == size) {
|
||||
callback(key_value_map);
|
||||
}
|
||||
};
|
||||
RAY_CHECK_OK(item.first->RunArgvAsync(mget_keys, mget_callback));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
RedisStoreClient::RedisScanner::RedisScanner(std::shared_ptr<RedisClient> redis_client,
|
||||
std::string table_name,
|
||||
std::string match_pattern)
|
||||
: table_name_(std::move(table_name)),
|
||||
match_pattern_(std::move(match_pattern)),
|
||||
redis_client_(std::move(redis_client)) {
|
||||
std::string table_name)
|
||||
: table_name_(std::move(table_name)), redis_client_(std::move(redis_client)) {
|
||||
for (size_t index = 0; index < redis_client_->GetShardContexts().size(); ++index) {
|
||||
shard_to_cursor_[index] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
Status RedisStoreClient::RedisScanner::ScanKeysAndValues(
|
||||
std::string match_pattern,
|
||||
const ItemCallback<std::unordered_map<std::string, std::string>> &callback) {
|
||||
auto on_done = [this, callback](const Status &status,
|
||||
const std::vector<std::string> &result) {
|
||||
if (result.empty()) {
|
||||
callback(std::unordered_map<std::string, std::string>());
|
||||
} else {
|
||||
MGetValues(result, callback);
|
||||
RAY_CHECK_OK(MGetValues(redis_client_, table_name_, result, callback));
|
||||
}
|
||||
};
|
||||
return ScanKeys(on_done);
|
||||
return ScanKeys(match_pattern, on_done);
|
||||
}
|
||||
|
||||
Status RedisStoreClient::RedisScanner::ScanKeys(
|
||||
const MultiItemCallback<std::string> &callback) {
|
||||
std::string match_pattern, const MultiItemCallback<std::string> &callback) {
|
||||
auto on_done = [this, callback](const Status &status) {
|
||||
std::vector<std::string> result;
|
||||
result.insert(result.begin(), keys_.begin(), keys_.end());
|
||||
callback(status, result);
|
||||
};
|
||||
Scan(on_done);
|
||||
Scan(match_pattern, on_done);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void RedisStoreClient::RedisScanner::Scan(const StatusCallback &callback) {
|
||||
void RedisStoreClient::RedisScanner::Scan(std::string match_pattern,
|
||||
const StatusCallback &callback) {
|
||||
if (shard_to_cursor_.empty()) {
|
||||
callback(Status::OK());
|
||||
return;
|
||||
|
@ -296,18 +352,16 @@ void RedisStoreClient::RedisScanner::Scan(const StatusCallback &callback) {
|
|||
size_t shard_index = item.first;
|
||||
size_t cursor = item.second;
|
||||
|
||||
auto scan_callback = [this, shard_index,
|
||||
auto scan_callback = [this, match_pattern, shard_index,
|
||||
callback](const std::shared_ptr<CallbackReply> &reply) {
|
||||
OnScanCallback(shard_index, reply, callback);
|
||||
OnScanCallback(match_pattern, shard_index, reply, callback);
|
||||
};
|
||||
|
||||
// Scan by prefix from Redis.
|
||||
std::vector<std::string> args = {"SCAN", std::to_string(cursor),
|
||||
"MATCH", match_pattern_,
|
||||
"MATCH", match_pattern,
|
||||
"COUNT", std::to_string(batch_count)};
|
||||
auto shard_context = redis_client_->GetShardContexts()[shard_index];
|
||||
Status status = shard_context->RunArgvAsync(args, scan_callback);
|
||||
|
||||
if (!status.ok()) {
|
||||
RAY_LOG(FATAL) << "Scan failed, status " << status.ToString();
|
||||
}
|
||||
|
@ -315,12 +369,11 @@ void RedisStoreClient::RedisScanner::Scan(const StatusCallback &callback) {
|
|||
}
|
||||
|
||||
void RedisStoreClient::RedisScanner::OnScanCallback(
|
||||
size_t shard_index, const std::shared_ptr<CallbackReply> &reply,
|
||||
const StatusCallback &callback) {
|
||||
std::string match_pattern, size_t shard_index,
|
||||
const std::shared_ptr<CallbackReply> &reply, const StatusCallback &callback) {
|
||||
RAY_CHECK(reply);
|
||||
std::vector<std::string> scan_result;
|
||||
size_t cursor = reply->ReadAsScanArray(&scan_result);
|
||||
|
||||
// Update shard cursors and keys_.
|
||||
{
|
||||
absl::MutexLock lock(&mutex_);
|
||||
|
@ -340,40 +393,7 @@ void RedisStoreClient::RedisScanner::OnScanCallback(
|
|||
// If pending_request_count_ is equal to 0, it means that the scan of this batch is
|
||||
// completed and the next batch is started if any.
|
||||
if (--pending_request_count_ == 0) {
|
||||
Scan(callback);
|
||||
}
|
||||
}
|
||||
|
||||
void RedisStoreClient::RedisScanner::MGetValues(
|
||||
const std::vector<std::string> &keys,
|
||||
const ItemCallback<std::unordered_map<std::string, std::string>> &callback) {
|
||||
// The `MGET` command for each shard.
|
||||
auto mget_commands_by_shards = GenCommandsByShards(redis_client_, "MGET", keys);
|
||||
|
||||
auto finished_count = std::make_shared<int>(0);
|
||||
int size = mget_commands_by_shards.size();
|
||||
for (auto &item : mget_commands_by_shards) {
|
||||
auto mget_keys = std::move(item.second);
|
||||
auto mget_callback = [this, finished_count, size, mget_keys,
|
||||
callback](const std::shared_ptr<CallbackReply> &reply) {
|
||||
if (!reply->IsNil()) {
|
||||
auto value = reply->ReadAsStringArray();
|
||||
{
|
||||
absl::MutexLock lock(&mutex_);
|
||||
// The 0 th element of mget_keys is "MGET", so we start from the 1 th element.
|
||||
for (int index = 0; index < (int)value.size(); ++index) {
|
||||
key_value_map_[GetKeyFromRedisKey(mget_keys[index + 1], table_name_)] =
|
||||
value[index];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
++(*finished_count);
|
||||
if (*finished_count == size) {
|
||||
callback(key_value_map_);
|
||||
}
|
||||
};
|
||||
RAY_CHECK_OK(item.first->RunArgvAsync(mget_keys, mget_callback));
|
||||
Scan(match_pattern, callback);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -40,6 +40,9 @@ class RedisStoreClient : public StoreClient {
|
|||
Status AsyncGet(const std::string &table_name, const std::string &key,
|
||||
const OptionalItemCallback<std::string> &callback) override;
|
||||
|
||||
Status AsyncGetByIndex(const std::string &table_name, const std::string &index_key,
|
||||
const MapCallback<std::string, std::string> &callback) override;
|
||||
|
||||
Status AsyncGetAll(const std::string &table_name,
|
||||
const MapCallback<std::string, std::string> &callback) override;
|
||||
|
||||
|
@ -62,26 +65,23 @@ class RedisStoreClient : public StoreClient {
|
|||
class RedisScanner {
|
||||
public:
|
||||
explicit RedisScanner(std::shared_ptr<RedisClient> redis_client,
|
||||
std::string table_name, std::string match_pattern);
|
||||
std::string table_name);
|
||||
|
||||
Status ScanKeysAndValues(const MapCallback<std::string, std::string> &callback);
|
||||
Status ScanKeysAndValues(std::string match_pattern,
|
||||
const MapCallback<std::string, std::string> &callback);
|
||||
|
||||
Status ScanKeys(const MultiItemCallback<std::string> &callback);
|
||||
Status ScanKeys(std::string match_pattern,
|
||||
const MultiItemCallback<std::string> &callback);
|
||||
|
||||
private:
|
||||
void Scan(const StatusCallback &callback);
|
||||
void Scan(std::string match_pattern, const StatusCallback &callback);
|
||||
|
||||
void OnScanCallback(size_t shard_index, const std::shared_ptr<CallbackReply> &reply,
|
||||
void OnScanCallback(std::string match_pattern, size_t shard_index,
|
||||
const std::shared_ptr<CallbackReply> &reply,
|
||||
const StatusCallback &callback);
|
||||
|
||||
void MGetValues(const std::vector<std::string> &keys,
|
||||
const MapCallback<std::string, std::string> &callback);
|
||||
|
||||
std::string table_name_;
|
||||
|
||||
/// The scan match pattern.
|
||||
std::string match_pattern_;
|
||||
|
||||
/// Mutex to protect the shard_to_cursor_ field and the keys_ field and the
|
||||
/// key_value_map_ field.
|
||||
absl::Mutex mutex_;
|
||||
|
@ -89,9 +89,6 @@ class RedisStoreClient : public StoreClient {
|
|||
/// All keys that scanned from redis.
|
||||
absl::flat_hash_set<std::string> keys_;
|
||||
|
||||
/// Key-Value pairs that scanned from redis.
|
||||
std::unordered_map<std::string, std::string> key_value_map_;
|
||||
|
||||
/// The scan cursor for each shard.
|
||||
std::unordered_map<size_t, size_t> shard_to_cursor_;
|
||||
|
||||
|
@ -132,6 +129,10 @@ class RedisStoreClient : public StoreClient {
|
|||
const std::string &table_name,
|
||||
const std::string &index_key);
|
||||
|
||||
static Status MGetValues(std::shared_ptr<RedisClient> redis_client,
|
||||
std::string table_name, const std::vector<std::string> &keys,
|
||||
const MapCallback<std::string, std::string> &callback);
|
||||
|
||||
std::shared_ptr<RedisClient> redis_client_;
|
||||
};
|
||||
|
||||
|
|
|
@ -65,6 +65,16 @@ class StoreClient {
|
|||
virtual Status AsyncGet(const std::string &table_name, const std::string &key,
|
||||
const OptionalItemCallback<std::string> &callback) = 0;
|
||||
|
||||
/// Get data by index from the given table asynchronously.
|
||||
///
|
||||
/// \param table_name The name of the table to be read.
|
||||
/// \param index_key The secondary key that will be used to get the indexed data.
|
||||
/// \param callback Callback that will be called after read finishes.
|
||||
/// \return Status
|
||||
virtual Status AsyncGetByIndex(
|
||||
const std::string &table_name, const std::string &index_key,
|
||||
const MapCallback<std::string, std::string> &callback) = 0;
|
||||
|
||||
/// Get all data from the given table asynchronously.
|
||||
///
|
||||
/// \param table_name The name of the table to be read.
|
||||
|
|
|
@ -132,6 +132,24 @@ class StoreClientTestBase : public ::testing::Test {
|
|||
WaitPendingDone();
|
||||
}
|
||||
|
||||
void GetByIndex() {
|
||||
auto get_calllback =
|
||||
[this](const std::unordered_map<std::string, std::string> &result) {
|
||||
if (!result.empty()) {
|
||||
auto key = ActorID::FromBinary(result.begin()->first);
|
||||
auto it = key_to_index_.find(key);
|
||||
RAY_CHECK(it != key_to_index_.end());
|
||||
RAY_CHECK(index_to_keys_[it->second].size() == result.size());
|
||||
}
|
||||
pending_count_ -= result.size();
|
||||
};
|
||||
auto iter = index_to_keys_.begin();
|
||||
pending_count_ += iter->second.size();
|
||||
RAY_CHECK_OK(
|
||||
store_client_->AsyncGetByIndex(table_name_, iter->first.Hex(), get_calllback));
|
||||
WaitPendingDone();
|
||||
}
|
||||
|
||||
void DeleteByIndex() {
|
||||
auto delete_calllback = [this](const Status &status) {
|
||||
RAY_CHECK_OK(status);
|
||||
|
@ -198,6 +216,9 @@ class StoreClientTestBase : public ::testing::Test {
|
|||
// AsyncPut with index
|
||||
PutWithIndex();
|
||||
|
||||
// AsyncGet with index
|
||||
GetByIndex();
|
||||
|
||||
// AsyncDelete by index
|
||||
DeleteByIndex();
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue