Abstract plasma store get request queue (#18064)

* begin

* build

* add test

* add first test

* add test

* fix build

* lint bazel

* fix build

* fix build

* fix crash

* fix some comment

* revert shared_ptr ObjectLifecycleManager

* fix RemoveGetRequest lost

* no defer

* fix lots of comments

* fix build

* fix data race

* fix comments

* Revert "fix data race"

This reverts commit 8f58e3d70b73af864566e056211ff1b70cab870c.

* refine

* fix mac build

* fix unit test

* fix unit test
This commit is contained in:
wanxing 2021-09-03 05:16:50 +08:00 committed by GitHub
parent 549a8fa948
commit 60f84fa051
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 715 additions and 281 deletions

View file

@ -335,6 +335,7 @@ cc_library(
"src/ray/object_manager/plasma/create_request_queue.cc",
"src/ray/object_manager/plasma/dlmalloc.cc",
"src/ray/object_manager/plasma/eviction_policy.cc",
"src/ray/object_manager/plasma/get_request_queue.cc",
"src/ray/object_manager/plasma/object_lifecycle_manager.cc",
"src/ray/object_manager/plasma/object_store.cc",
"src/ray/object_manager/plasma/plasma_allocator.cc",
@ -347,6 +348,7 @@ cc_library(
"src/ray/object_manager/plasma/allocator.h",
"src/ray/object_manager/plasma/create_request_queue.h",
"src/ray/object_manager/plasma/eviction_policy.h",
"src/ray/object_manager/plasma/get_request_queue.h",
"src/ray/object_manager/plasma/object_lifecycle_manager.h",
"src/ray/object_manager/plasma/object_store.h",
"src/ray/object_manager/plasma/plasma_allocator.h",
@ -1096,6 +1098,21 @@ cc_test(
],
)
cc_test(
name = "get_request_queue_test",
size = "small",
srcs = [
"src/ray/object_manager/test/get_request_queue_test.cc",
],
copts = COPTS,
tags = ["team:core"],
deps = [
":plasma_store_server_lib",
"@com_google_googletest//:gtest",
"@com_google_googletest//:gtest_main",
],
)
cc_test(
name = "worker_pool_test",
size = "small",

View file

@ -27,6 +27,7 @@
#include "ray/common/id.h"
#include "ray/object_manager/common.h"
#include "ray/object_manager/plasma/compat.h"
#include "ray/object_manager/plasma/plasma.h"
#include "ray/object_manager/plasma/plasma_generated.h"
#include "ray/util/macros.h"
@ -85,6 +86,7 @@ struct Allocation {
friend struct ObjectLifecycleManagerTest;
FRIEND_TEST(ObjectStoreTest, PassThroughTest);
FRIEND_TEST(EvictionPolicyTest, Test);
friend struct GetRequestQueueTest;
};
/// This type is used by the Plasma store. It is here because it is exposed to
@ -107,6 +109,20 @@ class LocalObject {
const plasma::flatbuf::ObjectSource &GetSource() const { return source; }
void ToPlasmaObject(PlasmaObject *object, bool check_sealed) const {
RAY_DCHECK(object != nullptr);
if (check_sealed) {
RAY_DCHECK(Sealed());
}
object->store_fd = GetAllocation().fd;
object->data_offset = GetAllocation().offset;
object->metadata_offset = GetAllocation().offset + GetObjectInfo().data_size;
object->data_size = GetObjectInfo().data_size;
object->metadata_size = GetObjectInfo().metadata_size;
object->device_num = GetAllocation().device_num;
object->mmap_size = GetAllocation().mmap_size;
}
private:
friend class ObjectStore;
friend class ObjectLifecycleManager;
@ -115,6 +131,7 @@ class LocalObject {
FRIEND_TEST(ObjectLifecycleManagerTest, RemoveReferenceOneRefNotSealed);
friend struct ObjectStatsCollectorTest;
FRIEND_TEST(EvictionPolicyTest, Test);
friend struct GetRequestQueueTest;
/// Allocation Info;
Allocation allocation;

View file

@ -21,6 +21,11 @@ using PlasmaStoreMessageHandler = std::function<ray::Status(
class ClientInterface {
public:
virtual ~ClientInterface() {}
virtual ray::Status SendFd(MEMFD_TYPE fd) = 0;
virtual const std::unordered_set<ray::ObjectID> &GetObjectIDs() = 0;
virtual void MarkObjectAsUsed(const ray::ObjectID &object_id) = 0;
virtual void MarkObjectAsUnused(const ray::ObjectID &object_id) = 0;
};
/// Contains all information that is associated with a Plasma store client.
@ -29,10 +34,17 @@ class Client : public ray::ClientConnection, public ClientInterface {
static std::shared_ptr<Client> Create(PlasmaStoreMessageHandler message_handler,
ray::local_stream_socket &&socket);
ray::Status SendFd(MEMFD_TYPE fd);
ray::Status SendFd(MEMFD_TYPE fd) override;
/// Object ids that are used by this client.
std::unordered_set<ray::ObjectID> object_ids;
const std::unordered_set<ray::ObjectID> &GetObjectIDs() override { return object_ids; }
virtual void MarkObjectAsUsed(const ray::ObjectID &object_id) override {
object_ids.insert(object_id);
}
virtual void MarkObjectAsUnused(const ray::ObjectID &object_id) override {
object_ids.erase(object_id);
}
std::string name = "anonymous_client";
@ -41,6 +53,9 @@ class Client : public ray::ClientConnection, public ClientInterface {
/// File descriptors that are used by this client.
/// TODO(ekl) we should also clean up old fds that are removed.
absl::flat_hash_set<MEMFD_TYPE> used_fds_;
/// Object ids that are used by this client.
std::unordered_set<ray::ObjectID> object_ids;
};
std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Client> &client);

View file

@ -0,0 +1,195 @@
// Copyright 2017 The Ray Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ray/object_manager/plasma/get_request_queue.h"
namespace plasma {
GetRequest::GetRequest(instrumented_io_context &io_context,
const std::shared_ptr<ClientInterface> &client,
const std::vector<ObjectID> &object_ids, bool is_from_worker,
int64_t num_unique_objects_to_wait_for)
: client(client),
object_ids(object_ids.begin(), object_ids.end()),
objects(object_ids.size()),
num_unique_objects_to_wait_for(num_unique_objects_to_wait_for),
num_unique_objects_satisfied(0),
is_from_worker(is_from_worker),
timer_(io_context) {}
void GetRequest::AsyncWait(
int64_t timeout_ms,
std::function<void(const boost::system::error_code &)> on_timeout) {
RAY_CHECK(!is_removed_);
// Set an expiry time relative to now.
timer_.expires_from_now(std::chrono::milliseconds(timeout_ms));
timer_.async_wait(on_timeout);
}
void GetRequest::CancelTimer() {
RAY_CHECK(!is_removed_);
timer_.cancel();
}
void GetRequest::MarkRemoved() {
RAY_CHECK(!is_removed_);
is_removed_ = true;
}
bool GetRequest::IsRemoved() const { return is_removed_; }
void GetRequestQueue::AddRequest(const std::shared_ptr<ClientInterface> &client,
const std::vector<ObjectID> &object_ids,
int64_t timeout_ms, bool is_from_worker) {
const absl::flat_hash_set<ObjectID> unique_ids(object_ids.begin(), object_ids.end());
// Create a get request for this object.
auto get_request = std::make_shared<GetRequest>(io_context_, client, object_ids,
is_from_worker, unique_ids.size());
for (const auto &object_id : unique_ids) {
// Check if this object is already present
// locally. If so, record that the object is being used and mark it as accounted for.
auto entry = object_lifecycle_mgr_.GetObject(object_id);
if (entry && entry->Sealed()) {
// Update the get request to take into account the present object.
entry->ToPlasmaObject(&get_request->objects[object_id], /* checksealed */ true);
get_request->num_unique_objects_satisfied += 1;
object_satisfied_callback_(object_id, get_request);
} else {
// Add a placeholder plasma object to the get request to indicate that the
// object is not present. This will be parsed by the client. We set the
// data size to -1 to indicate that the object is not present.
get_request->objects[object_id].data_size = -1;
// Add the get request to the relevant data structures.
object_get_requests_[object_id].push_back(get_request);
}
}
// If all of the objects are present already or if the timeout is 0, return to
// the client.
if (get_request->num_unique_objects_satisfied ==
get_request->num_unique_objects_to_wait_for ||
timeout_ms == 0) {
OnGetRequestCompleted(get_request);
} else if (timeout_ms != -1) {
// Set a timer that will cause the get request to return to the client. Note
// that a timeout of -1 is used to indicate that no timer should be set.
get_request->AsyncWait(timeout_ms,
[this, get_request](const boost::system::error_code &ec) {
if (ec != boost::asio::error::operation_aborted) {
// Timer was not cancelled, take necessary action.
OnGetRequestCompleted(get_request);
}
});
}
}
void GetRequestQueue::RemoveGetRequestsForClient(
const std::shared_ptr<ClientInterface> &client) {
/// TODO: Preventing duplicated can be optimized.
absl::flat_hash_set<std::shared_ptr<GetRequest>> get_requests_to_remove;
for (auto const &pair : object_get_requests_) {
for (const auto &get_request : pair.second) {
if (get_request->client == client) {
get_requests_to_remove.insert(get_request);
}
}
}
// It shouldn't be possible for a given client to be in the middle of multiple get
// requests.
RAY_CHECK(get_requests_to_remove.size() <= 1);
for (const auto &get_request : get_requests_to_remove) {
RemoveGetRequest(get_request);
}
}
void GetRequestQueue::RemoveGetRequest(const std::shared_ptr<GetRequest> &get_request) {
// Remove the get request from each of the relevant object_get_requests hash
// tables if it is present there. It should only be present there if the get
// request timed out or if it was issued by a client that has disconnected.
for (const auto &object_id : get_request->object_ids) {
auto object_request_iter = object_get_requests_.find(object_id);
if (object_request_iter != object_get_requests_.end()) {
auto &get_requests = object_request_iter->second;
// Erase get_request from the vector.
auto it = std::find(get_requests.begin(), get_requests.end(), get_request);
if (it != get_requests.end()) {
get_requests.erase(it);
// If the vector is empty, remove the object ID from the map.
if (get_requests.empty()) {
object_get_requests_.erase(object_request_iter);
}
}
}
}
// Remove the get request.
get_request->CancelTimer();
get_request->MarkRemoved();
}
void GetRequestQueue::MarkObjectSealed(const ObjectID &object_id) {
auto it = object_get_requests_.find(object_id);
// If there are no get requests involving this object, then return.
if (it == object_get_requests_.end()) {
return;
}
auto &get_requests = it->second;
// After finishing the loop below, get_requests and it will have been
// invalidated by the removal of object_id from object_get_requests_.
size_t index = 0;
size_t num_requests = get_requests.size();
for (size_t i = 0; i < num_requests; ++i) {
auto get_request = get_requests[index];
auto entry = object_lifecycle_mgr_.GetObject(object_id);
RAY_CHECK(entry != nullptr);
entry->ToPlasmaObject(&get_request->objects[object_id], /* check sealed */ true);
get_request->num_unique_objects_satisfied += 1;
object_satisfied_callback_(object_id, get_request);
// If this get request is done, reply to the client.
if (get_request->num_unique_objects_satisfied ==
get_request->num_unique_objects_to_wait_for) {
OnGetRequestCompleted(get_request);
} else {
// The call to ReturnFromGet will remove the current element in the
// array, so we only increment the counter in the else branch.
index += 1;
}
}
// No get requests should be waiting for this object anymore. The object ID
// may have been removed from the object_get_requests_ by ReturnFromGet, but
// if the get request has not returned yet, then remove the object ID from the
// map here.
it = object_get_requests_.find(object_id);
if (it != object_get_requests_.end()) {
object_get_requests_.erase(object_id);
}
}
bool GetRequestQueue::IsGetRequestExist(const ObjectID &object_id) {
return object_get_requests_.contains(object_id);
}
int64_t GetRequestQueue::GetRequestCount(const ObjectID &object_id) {
return object_get_requests_[object_id].size();
}
void GetRequestQueue::OnGetRequestCompleted(
const std::shared_ptr<GetRequest> &get_request) {
all_objects_satisfied_callback_(get_request);
RemoveGetRequest(get_request);
}
} // namespace plasma

View file

@ -0,0 +1,135 @@
// Copyright 2017 The Ray Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "ray/common/asio/instrumented_io_context.h"
#include "ray/common/id.h"
#include "ray/object_manager/plasma/connection.h"
#include "ray/object_manager/plasma/object_lifecycle_manager.h"
namespace plasma {
struct GetRequest;
using ObjectReadyCallback = std::function<void(
const ObjectID &object_id, const std::shared_ptr<GetRequest> &get_request)>;
using AllObjectReadyCallback =
std::function<void(const std::shared_ptr<GetRequest> &get_request)>;
struct GetRequest {
GetRequest(instrumented_io_context &io_context,
const std::shared_ptr<ClientInterface> &client,
const std::vector<ObjectID> &object_ids, bool is_from_worker,
int64_t num_unique_objects_to_wait_for);
/// The client that called get.
std::shared_ptr<ClientInterface> client;
/// The object IDs involved in this request. This is used in the reply.
std::vector<ObjectID> object_ids;
/// The object information for the objects in this request. This is used in
/// the reply.
absl::flat_hash_map<ObjectID, PlasmaObject> objects;
/// The minimum number of objects to wait for in this request.
const int64_t num_unique_objects_to_wait_for;
/// The number of object requests in this wait request that are already
/// satisfied.
int64_t num_unique_objects_satisfied;
/// Whether or not the request comes from the core worker. It is used to track the size
/// of total objects that are consumed by core worker.
const bool is_from_worker;
void AsyncWait(int64_t timeout_ms,
std::function<void(const boost::system::error_code &)> on_timeout);
void CancelTimer();
/// Mark that the get request is removed.
void MarkRemoved();
bool IsRemoved() const;
private:
/// The timer that will time out and cause this wait to return to
/// the client if it hasn't already returned.
boost::asio::steady_timer timer_;
/// Whether or not if this get request is removed.
/// Once the get request is removed, any operation on top of the get request shouldn't
/// happen.
bool is_removed_ = false;
};
class GetRequestQueue {
public:
GetRequestQueue(instrumented_io_context &io_context,
IObjectLifecycleManager &object_lifecycle_mgr,
ObjectReadyCallback object_callback,
AllObjectReadyCallback all_objects_callback)
: io_context_(io_context),
object_lifecycle_mgr_(object_lifecycle_mgr),
object_satisfied_callback_(object_callback),
all_objects_satisfied_callback_(all_objects_callback) {}
/// Add a get request to get request queue. Note this will call callback functions
/// directly if all objects has been satisfied, otherwise store the request
/// in queue.
/// \param client the client where the request comes from.
/// \param object_ids the object ids to get.
/// \param timeout_ms timeout in millisecond, -1 is used to indicate that no timer
/// should be set. \param is_from_worker whether the get request from a worker or not.
/// \param object_callback the callback function called once any object has been
/// satisfied. \param all_objects_callback the callback function called when all objects
/// has been satisfied.
void AddRequest(const std::shared_ptr<ClientInterface> &client,
const std::vector<ObjectID> &object_ids, int64_t timeout_ms,
bool is_from_worker);
/// Remove all of the GetRequests for a given client.
///
/// \param client The client whose GetRequests should be removed.
void RemoveGetRequestsForClient(const std::shared_ptr<ClientInterface> &client);
/// Handle a sealed object, should be called when an object sealed. Mark
/// the object satisfied and call object callbacks.
/// \param object_id the object_id to mark.
void MarkObjectSealed(const ObjectID &object_id);
private:
/// Remove a GetRequest and clean up the relevant data structures.
///
/// \param get_request The GetRequest to remove.
void RemoveGetRequest(const std::shared_ptr<GetRequest> &get_request);
/// Only for tests.
bool IsGetRequestExist(const ObjectID &object_id);
int64_t GetRequestCount(const ObjectID &object_id);
/// Called when objects satisfied. Call get request callback function and
/// remove get request in queue.
/// \param get_request the get request to be completed.
void OnGetRequestCompleted(const std::shared_ptr<GetRequest> &get_request);
instrumented_io_context &io_context_;
/// A hash table mapping object IDs to a vector of the get requests that are
/// waiting for the object to arrive.
absl::flat_hash_map<ObjectID, std::vector<std::shared_ptr<GetRequest>>>
object_get_requests_;
IObjectLifecycleManager &object_lifecycle_mgr_;
ObjectReadyCallback object_satisfied_callback_;
AllObjectReadyCallback all_objects_satisfied_callback_;
friend struct GetRequestQueueTest;
};
} // namespace plasma

View file

@ -27,13 +27,9 @@
namespace plasma {
// ObjectLifecycleManager allocates LocalObjects from the allocator.
// It tracks objects lifecycle states such as reference count or object states
// (created/sealed). It lazily garbage collects objects when running out of space.
class ObjectLifecycleManager {
class IObjectLifecycleManager {
public:
ObjectLifecycleManager(IAllocator &allocator,
ray::DeleteObjectCallback delete_object_callback);
virtual ~IObjectLifecycleManager() = default;
/// Create a new object given object's info. Object creation might
/// fail if runs out of space; or an object with the same id exists.
@ -45,15 +41,15 @@ class ObjectLifecycleManager {
/// - pointer to created object and PlasmaError::OK when succeeds.
/// - nullptr and error message, including ObjectExists/OutOfMemory
/// TODO(scv119): use RAII instead of pointer for returned object.
std::pair<const LocalObject *, flatbuf::PlasmaError> CreateObject(
virtual std::pair<const LocalObject *, flatbuf::PlasmaError> CreateObject(
const ray::ObjectInfo &object_info, plasma::flatbuf::ObjectSource source,
bool fallback_allocator);
bool fallback_allocator) = 0;
/// Get object by id.
/// \return
/// - nullptr if such object doesn't exist.
/// - otherwise, pointer to the object.
const LocalObject *GetObject(const ObjectID &object_id) const;
virtual const LocalObject *GetObject(const ObjectID &object_id) const = 0;
/// Seal created object by id.
///
@ -61,7 +57,7 @@ class ObjectLifecycleManager {
/// \return
/// - nulltpr if such object doesn't exist, or the object has already been sealed.
/// - otherise, pointer to the sealed object.
const LocalObject *SealObject(const ObjectID &object_id);
virtual const LocalObject *SealObject(const ObjectID &object_id) = 0;
/// Abort object creation by id. It deletes the object regardless of reference
/// counting.
@ -71,7 +67,7 @@ class ObjectLifecycleManager {
/// - PlasmaError::OK, if the object was aborted successfully.
/// - PlasmaError::ObjectNonexistent, if ths object doesn't exist.
/// - PlasmaError::ObjectSealed, if ths object has already been sealed.
flatbuf::PlasmaError AbortObject(const ObjectID &object_id);
virtual flatbuf::PlasmaError AbortObject(const ObjectID &object_id) = 0;
/// Delete a specific object by object_id. The object is delete immediately
/// if it's been sealed and reference counting is zero. Otherwise it will be
@ -84,18 +80,43 @@ class ObjectLifecycleManager {
/// - PlasmaError::ObjectNotsealed, if ths object is created but not sealed.
/// - PlasmaError::ObjectInUse, if the object is in use; it will be deleted
/// once it's no longer used (ref count becomes 0).
flatbuf::PlasmaError DeleteObject(const ObjectID &object_id);
virtual flatbuf::PlasmaError DeleteObject(const ObjectID &object_id) = 0;
/// Bump up the reference count of the object.
///
/// \return true if object exists, false otherise.
bool AddReference(const ObjectID &object_id);
virtual bool AddReference(const ObjectID &object_id) = 0;
/// Decrese the reference count of the object. When reference count
/// drop to zero the object becomes evictable.
///
/// \return true if object exists and reference count is greater than 0, false otherise.
bool RemoveReference(const ObjectID &object_id);
virtual bool RemoveReference(const ObjectID &object_id) = 0;
};
// ObjectLifecycleManager allocates LocalObjects from the allocator.
// It tracks objects lifecycle states such as reference count or object states
// (created/sealed). It lazily garbage collects objects when running out of space.
class ObjectLifecycleManager : public IObjectLifecycleManager {
public:
ObjectLifecycleManager(IAllocator &allocator,
ray::DeleteObjectCallback delete_object_callback);
std::pair<const LocalObject *, flatbuf::PlasmaError> CreateObject(
const ray::ObjectInfo &object_info, plasma::flatbuf::ObjectSource source,
bool fallback_allocator) override;
const LocalObject *GetObject(const ObjectID &object_id) const override;
const LocalObject *SealObject(const ObjectID &object_id) override;
flatbuf::PlasmaError AbortObject(const ObjectID &object_id) override;
flatbuf::PlasmaError DeleteObject(const ObjectID &object_id) override;
bool AddReference(const ObjectID &object_id) override;
bool RemoveReference(const ObjectID &object_id) override;
/// Ask it to evict objects until we have at least size of capacity
/// available.
@ -140,6 +161,8 @@ class ObjectLifecycleManager {
friend struct ObjectStatsCollectorTest;
FRIEND_TEST(ObjectLifecycleManagerTest, DeleteFailure);
FRIEND_TEST(ObjectLifecycleManagerTest, RemoveReferenceOneRefEagerlyDeletion);
friend struct GetRequestQueueTest;
FRIEND_TEST(GetRequestQueueTest, TestAddRequest);
std::unique_ptr<IObjectStore> object_store_;
std::unique_ptr<IEvictionPolicy> eviction_policy_;

View file

@ -23,5 +23,4 @@ namespace plasma {
LocalObject::LocalObject(Allocation allocation)
: allocation(std::move(allocation)), ref_count(0) {}
} // namespace plasma

View file

@ -16,6 +16,7 @@
// under the License.
#pragma once
#include <stddef.h>
#include <memory>
#include <string>
@ -23,7 +24,6 @@
#include <unordered_set>
#include <vector>
#include "ray/object_manager/plasma/common.h"
#include "ray/object_manager/plasma/compat.h"
namespace plasma {

View file

@ -545,7 +545,7 @@ Status ReadGetRequest(uint8_t *data, size_t size, std::vector<ObjectID> &object_
}
Status SendGetReply(const std::shared_ptr<Client> &client, ObjectID object_ids[],
std::unordered_map<ObjectID, PlasmaObject> &plasma_objects,
absl::flat_hash_map<ObjectID, PlasmaObject> &plasma_objects,
int64_t num_objects, const std::vector<MEMFD_TYPE> &store_fds,
const std::vector<int64_t> &mmap_sizes) {
flatbuffers::FlatBufferBuilder fbb;

View file

@ -22,8 +22,10 @@
#include <unordered_map>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "ray/common/status.h"
#include "ray/object_manager/common.h"
#include "ray/object_manager/plasma/common.h"
#include "ray/object_manager/plasma/plasma.h"
#include "ray/object_manager/plasma/plasma_generated.h"
#include "src/ray/protobuf/common.pb.h"
@ -124,7 +126,7 @@ Status ReadGetRequest(uint8_t *data, size_t size, std::vector<ObjectID> &object_
int64_t *timeout_ms, bool *is_from_worker);
Status SendGetReply(const std::shared_ptr<Client> &client, ObjectID object_ids[],
std::unordered_map<ObjectID, PlasmaObject> &plasma_objects,
absl::flat_hash_map<ObjectID, PlasmaObject> &plasma_objects,
int64_t num_objects, const std::vector<MEMFD_TYPE> &store_fds,
const std::vector<int64_t> &mmap_sizes);

View file

@ -47,6 +47,7 @@
#include "ray/common/asio/asio_util.h"
#include "ray/common/asio/instrumented_io_context.h"
#include "ray/object_manager/plasma/common.h"
#include "ray/object_manager/plasma/get_request_queue.h"
#include "ray/object_manager/plasma/malloc.h"
#include "ray/object_manager/plasma/plasma_allocator.h"
#include "ray/object_manager/plasma/protocol.h"
@ -64,86 +65,8 @@ ray::ObjectID GetCreateRequestObjectId(const std::vector<uint8_t> &message) {
RAY_DCHECK(plasma::VerifyFlatbuffer(request, input, input_size));
return ray::ObjectID::FromBinary(request->object_id()->str());
}
void ToPlasmaObject(const LocalObject &entry, PlasmaObject *object, bool check_sealed) {
RAY_DCHECK(object != nullptr);
if (check_sealed) {
RAY_DCHECK(entry.Sealed());
}
object->store_fd = entry.GetAllocation().fd;
object->data_offset = entry.GetAllocation().offset;
object->metadata_offset =
entry.GetAllocation().offset + entry.GetObjectInfo().data_size;
object->data_size = entry.GetObjectInfo().data_size;
object->metadata_size = entry.GetObjectInfo().metadata_size;
object->device_num = entry.GetAllocation().device_num;
object->mmap_size = entry.GetAllocation().mmap_size;
}
} // namespace
struct GetRequest {
GetRequest(instrumented_io_context &io_context, const std::shared_ptr<Client> &client,
const std::vector<ObjectID> &object_ids, bool is_from_worker);
/// The client that called get.
std::shared_ptr<Client> client;
/// The object IDs involved in this request. This is used in the reply.
std::vector<ObjectID> object_ids;
/// The object information for the objects in this request. This is used in
/// the reply.
std::unordered_map<ObjectID, PlasmaObject> objects;
/// The minimum number of objects to wait for in this request.
int64_t num_objects_to_wait_for;
/// The number of object requests in this wait request that are already
/// satisfied.
int64_t num_satisfied;
/// Whether or not the request comes from the core worker. It is used to track the size
/// of total objects that are consumed by core worker.
bool is_from_worker;
void AsyncWait(int64_t timeout_ms,
std::function<void(const boost::system::error_code &)> on_timeout) {
RAY_CHECK(!is_removed_);
// Set an expiry time relative to now.
timer_.expires_from_now(std::chrono::milliseconds(timeout_ms));
timer_.async_wait(on_timeout);
}
void CancelTimer() {
RAY_CHECK(!is_removed_);
timer_.cancel();
}
/// Mark that the get request is removed.
void MarkRemoved() {
RAY_CHECK(!is_removed_);
is_removed_ = true;
}
bool IsRemoved() const { return is_removed_; }
private:
/// The timer that will time out and cause this wait to return to
/// the client if it hasn't already returned.
boost::asio::steady_timer timer_;
/// Whether or not if this get request is removed.
/// Once the get request is removed, any operation on top of the get request shouldn't
/// happen.
bool is_removed_ = false;
};
GetRequest::GetRequest(instrumented_io_context &io_context,
const std::shared_ptr<Client> &client,
const std::vector<ObjectID> &object_ids, bool is_from_worker)
: client(client),
object_ids(object_ids.begin(), object_ids.end()),
objects(object_ids.size()),
num_satisfied(0),
is_from_worker(is_from_worker),
timer_(io_context) {
std::unordered_set<ObjectID> unique_ids(object_ids.begin(), object_ids.end());
num_objects_to_wait_for = unique_ids.size();
}
PlasmaStore::PlasmaStore(instrumented_io_context &main_service, IAllocator &allocator,
const std::string &socket_name, uint32_t delay_on_oom_ms,
float object_spilling_threshold,
@ -156,7 +79,6 @@ PlasmaStore::PlasmaStore(instrumented_io_context &main_service, IAllocator &allo
acceptor_(main_service, ParseUrlEndpoint(socket_name)),
socket_(main_service),
allocator_(allocator),
spill_objects_callback_(spill_objects_callback),
add_object_callback_(add_object_callback),
delete_object_callback_(delete_object_callback),
object_lifecycle_mgr_(allocator_, delete_object_callback_),
@ -167,7 +89,12 @@ PlasmaStore::PlasmaStore(instrumented_io_context &main_service, IAllocator &allo
spill_objects_callback, object_store_full_callback,
/*get_time=*/
[]() { return absl::GetCurrentTimeNanos(); },
[this]() { return GetDebugDump(); }) {
[this]() { return GetDebugDump(); }),
get_request_queue_(io_context_, object_lifecycle_mgr_,
[this](const ObjectID &object_id, const auto &request) {
this->AddToClientObjectIds(object_id, request->client);
},
[this](const auto &request) { this->ReturnFromGet(request); }) {
const auto event_stats_print_interval_ms =
RayConfig::instance().event_stats_print_interval_ms();
if (event_stats_print_interval_ms > 0 && RayConfig::instance().event_stats()) {
@ -188,14 +115,15 @@ void PlasmaStore::Stop() { acceptor_.close(); }
// If this client is not already using the object, add the client to the
// object's list of clients, otherwise do nothing.
void PlasmaStore::AddToClientObjectIds(const ObjectID &object_id,
const std::shared_ptr<Client> &client) {
const std::shared_ptr<ClientInterface> &client) {
// Check if this client is already using the object.
if (client->object_ids.find(object_id) != client->object_ids.end()) {
auto &object_ids = client->GetObjectIDs();
if (object_ids.find(object_id) != object_ids.end()) {
return;
}
RAY_CHECK(object_lifecycle_mgr_.AddReference(object_id));
// Add object id to the list of object ids that this client is using.
client->object_ids.insert(object_id);
client->MarkObjectAsUsed(object_id);
}
PlasmaError PlasmaStore::HandleCreateObjectRequest(const std::shared_ptr<Client> &client,
@ -249,58 +177,16 @@ PlasmaError PlasmaStore::CreateObject(const ray::ObjectInfo &object_info,
if (entry == nullptr) {
return error;
}
ToPlasmaObject(*entry, result, /* check sealed */ false);
entry->ToPlasmaObject(result, /* check sealed */ false);
// Record that this client is using this object.
AddToClientObjectIds(object_info.object_id, client);
return PlasmaError::OK;
}
void PlasmaStore::RemoveGetRequest(const std::shared_ptr<GetRequest> &get_request) {
// Remove the get request from each of the relevant object_get_requests hash
// tables if it is present there. It should only be present there if the get
// request timed out or if it was issued by a client that has disconnected.
for (ObjectID &object_id : get_request->object_ids) {
auto object_request_iter = object_get_requests_.find(object_id);
if (object_request_iter != object_get_requests_.end()) {
auto &get_requests = object_request_iter->second;
// Erase get_req from the vector.
auto it = std::find(get_requests.begin(), get_requests.end(), get_request);
if (it != get_requests.end()) {
get_requests.erase(it);
// If the vector is empty, remove the object ID from the map.
if (get_requests.empty()) {
object_get_requests_.erase(object_request_iter);
}
}
}
}
// Remove the get request.
get_request->CancelTimer();
get_request->MarkRemoved();
}
void PlasmaStore::RemoveGetRequestsForClient(const std::shared_ptr<Client> &client) {
std::unordered_set<std::shared_ptr<GetRequest>> get_requests_to_remove;
for (auto const &pair : object_get_requests_) {
for (const auto &get_request : pair.second) {
if (get_request->client == client) {
get_requests_to_remove.insert(get_request);
}
}
}
// It shouldn't be possible for a given client to be in the middle of multiple get
// requests.
RAY_CHECK(get_requests_to_remove.size() <= 1);
for (const auto &get_request : get_requests_to_remove) {
RemoveGetRequest(get_request);
}
}
void PlasmaStore::ReturnFromGet(const std::shared_ptr<GetRequest> &get_req) {
void PlasmaStore::ReturnFromGet(const std::shared_ptr<GetRequest> &get_request) {
// If the get request is already removed, do no-op. This can happen because the boost
// timer is not atomic. See https://github.com/ray-project/ray/pull/15071.
if (get_req->IsRemoved()) {
if (get_request->IsRemoved()) {
return;
}
@ -308,133 +194,50 @@ void PlasmaStore::ReturnFromGet(const std::shared_ptr<GetRequest> &get_req) {
absl::flat_hash_set<MEMFD_TYPE> fds_to_send;
std::vector<MEMFD_TYPE> store_fds;
std::vector<int64_t> mmap_sizes;
for (const auto &object_id : get_req->object_ids) {
const PlasmaObject &object = get_req->objects[object_id];
for (const auto &object_id : get_request->object_ids) {
const PlasmaObject &object = get_request->objects[object_id];
MEMFD_TYPE fd = object.store_fd;
if (object.data_size != -1 && fds_to_send.count(fd) == 0 && fd.first != INVALID_FD) {
fds_to_send.insert(fd);
store_fds.push_back(fd);
mmap_sizes.push_back(object.mmap_size);
if (get_req->is_from_worker) {
if (get_request->is_from_worker) {
total_consumed_bytes_ += object.data_size + object.metadata_size;
}
}
}
// Send the get reply to the client.
Status s = SendGetReply(get_req->client, &get_req->object_ids[0], get_req->objects,
get_req->object_ids.size(), store_fds, mmap_sizes);
Status s = SendGetReply(std::dynamic_pointer_cast<Client>(get_request->client),
&get_request->object_ids[0], get_request->objects,
get_request->object_ids.size(), store_fds, mmap_sizes);
// If we successfully sent the get reply message to the client, then also send
// the file descriptors.
if (s.ok()) {
// Send all of the file descriptors for the present objects.
for (MEMFD_TYPE store_fd : store_fds) {
Status send_fd_status = get_req->client->SendFd(store_fd);
Status send_fd_status = get_request->client->SendFd(store_fd);
if (!send_fd_status.ok()) {
RAY_LOG(ERROR) << "Failed to send mmap results to client on fd "
<< get_req->client;
<< get_request->client;
}
}
} else {
RAY_LOG(ERROR) << "Failed to send Get reply to client on fd " << get_req->client;
}
// Remove the get request from each of the relevant object_get_requests hash
// tables if it is present there. It should only be present there if the get
// request timed out.
RemoveGetRequest(get_req);
}
void PlasmaStore::NotifyObjectSealedToGetRequests(const ObjectID &object_id) {
auto it = object_get_requests_.find(object_id);
// If there are no get requests involving this object, then return.
if (it == object_get_requests_.end()) {
return;
}
auto &get_requests = it->second;
// After finishing the loop below, get_requests and it will have been
// invalidated by the removal of object_id from object_get_requests_.
size_t index = 0;
size_t num_requests = get_requests.size();
for (size_t i = 0; i < num_requests; ++i) {
auto get_req = get_requests[index];
auto entry = object_lifecycle_mgr_.GetObject(object_id);
RAY_CHECK(entry != nullptr);
ToPlasmaObject(*entry, &get_req->objects[object_id], /* check sealed */ true);
get_req->num_satisfied += 1;
// Record the fact that this client will be using this object and will
// be responsible for releasing this object.
AddToClientObjectIds(object_id, get_req->client);
// If this get request is done, reply to the client.
if (get_req->num_satisfied == get_req->num_objects_to_wait_for) {
ReturnFromGet(get_req);
} else {
// The call to ReturnFromGet will remove the current element in the
// array, so we only increment the counter in the else branch.
index += 1;
}
}
// No get requests should be waiting for this object anymore. The object ID
// may have been removed from the object_get_requests_ by ReturnFromGet, but
// if the get request has not returned yet, then remove the object ID from the
// map here.
it = object_get_requests_.find(object_id);
if (it != object_get_requests_.end()) {
object_get_requests_.erase(object_id);
RAY_LOG(ERROR) << "Failed to send Get reply to client on fd " << get_request->client;
}
}
void PlasmaStore::ProcessGetRequest(const std::shared_ptr<Client> &client,
const std::vector<ObjectID> &object_ids,
int64_t timeout_ms, bool is_from_worker) {
// Create a get request for this object.
auto get_req = std::make_shared<GetRequest>(
GetRequest(io_context_, client, object_ids, is_from_worker));
for (auto object_id : object_ids) {
// Check if this object is already present
// locally. If so, record that the object is being used and mark it as accounted for.
auto entry = object_lifecycle_mgr_.GetObject(object_id);
if (entry && entry->Sealed()) {
// Update the get request to take into account the present object.
ToPlasmaObject(*entry, &get_req->objects[object_id], /* checksealed */ true);
get_req->num_satisfied += 1;
// If necessary, record that this client is using this object. In the case
// where entry == NULL, this will be called from SealObject.
AddToClientObjectIds(object_id, client);
} else {
// Add a placeholder plasma object to the get request to indicate that the
// object is not present. This will be parsed by the client. We set the
// data size to -1 to indicate that the object is not present.
get_req->objects[object_id].data_size = -1;
// Add the get request to the relevant data structures.
object_get_requests_[object_id].push_back(get_req);
}
}
// If all of the objects are present already or if the timeout is 0, return to
// the client.
if (get_req->num_satisfied == get_req->num_objects_to_wait_for || timeout_ms == 0) {
ReturnFromGet(get_req);
} else if (timeout_ms != -1) {
// Set a timer that will cause the get request to return to the client. Note
// that a timeout of -1 is used to indicate that no timer should be set.
get_req->AsyncWait(timeout_ms, [this, get_req](const boost::system::error_code &ec) {
if (ec != boost::asio::error::operation_aborted) {
// Timer was not cancelled, take necessary action.
ReturnFromGet(get_req);
}
});
}
get_request_queue_.AddRequest(client, object_ids, timeout_ms, is_from_worker);
}
int PlasmaStore::RemoveFromClientObjectIds(const ObjectID &object_id,
const std::shared_ptr<Client> &client) {
auto it = client->object_ids.find(object_id);
if (it != client->object_ids.end()) {
client->object_ids.erase(it);
auto &object_ids = client->GetObjectIDs();
auto it = object_ids.find(object_id);
if (it != object_ids.end()) {
client->MarkObjectAsUnused(*it);
RAY_LOG(DEBUG) << "Object " << object_id << " no longer in use by client";
// Decrease reference count.
object_lifecycle_mgr_.RemoveReference(object_id);
@ -463,21 +266,22 @@ void PlasmaStore::SealObjects(const std::vector<ObjectID> &object_ids) {
}
for (size_t i = 0; i < object_ids.size(); ++i) {
NotifyObjectSealedToGetRequests(object_ids[i]);
get_request_queue_.MarkObjectSealed(object_ids[i]);
}
}
int PlasmaStore::AbortObject(const ObjectID &object_id,
const std::shared_ptr<Client> &client) {
auto it = client->object_ids.find(object_id);
if (it == client->object_ids.end()) {
auto &object_ids = client->GetObjectIDs();
auto it = object_ids.find(object_id);
if (it == object_ids.end()) {
// If the client requesting the abort is not the creator, do not
// perform the abort.
return 0;
}
// The client requesting the abort is the creator. Free the object.
RAY_CHECK(object_lifecycle_mgr_.AbortObject(object_id) == PlasmaError::OK);
client->object_ids.erase(it);
client->MarkObjectAsUnused(*it);
return 1;
}
@ -496,7 +300,8 @@ void PlasmaStore::DisconnectClient(const std::shared_ptr<Client> &client) {
RAY_LOG(DEBUG) << "Disconnecting client on fd " << client;
// Release all the objects that the client was using.
std::unordered_map<ObjectID, const LocalObject *> sealed_objects;
for (const auto &object_id : client->object_ids) {
auto &object_ids = client->GetObjectIDs();
for (const auto &object_id : object_ids) {
auto entry = object_lifecycle_mgr_.GetObject(object_id);
if (entry == nullptr) {
continue;
@ -513,7 +318,7 @@ void PlasmaStore::DisconnectClient(const std::shared_ptr<Client> &client) {
}
/// Remove all of the client's GetRequests.
RemoveGetRequestsForClient(client);
get_request_queue_.RemoveGetRequestsForClient(client);
for (const auto &entry : sealed_objects) {
RemoveFromClientObjectIds(entry.first, client);

View file

@ -32,6 +32,7 @@
#include "ray/object_manager/plasma/connection.h"
#include "ray/object_manager/plasma/create_request_queue.h"
#include "ray/object_manager/plasma/eviction_policy.h"
#include "ray/object_manager/plasma/get_request_queue.h"
#include "ray/object_manager/plasma/object_lifecycle_manager.h"
#include "ray/object_manager/plasma/object_store.h"
#include "ray/object_manager/plasma/plasma.h"
@ -48,8 +49,6 @@ enum class PlasmaError;
using flatbuf::PlasmaError;
struct GetRequest;
class PlasmaStore {
public:
// TODO: PascalCase PlasmaStore methods.
@ -200,21 +199,9 @@ class PlasmaStore {
const ObjectID &object_id, uint64_t req_id);
void AddToClientObjectIds(const ObjectID &object_id,
const std::shared_ptr<Client> &client);
const std::shared_ptr<ClientInterface> &client);
/// Remove a GetRequest and clean up the relevant data structures.
///
/// \param get_request The GetRequest to remove.
void RemoveGetRequest(const std::shared_ptr<GetRequest> &get_request);
/// Remove all of the GetRequests for a given client.
///
/// \param client The client whose GetRequests should be removed.
void RemoveGetRequestsForClient(const std::shared_ptr<Client> &client);
void ReturnFromGet(const std::shared_ptr<GetRequest> &get_req);
void NotifyObjectSealedToGetRequests(const ObjectID &object_id);
void ReturnFromGet(const std::shared_ptr<GetRequest> &get_request);
int RemoveFromClientObjectIds(const ObjectID &object_id,
const std::shared_ptr<Client> &client);
@ -232,21 +219,9 @@ class PlasmaStore {
ray::local_stream_socket socket_;
/// The allocator that allocates mmaped memory.
IAllocator &allocator_;
/// The object store stores created objects.
/// A hash table mapping object IDs to a vector of the get requests that are
/// waiting for the object to arrive.
std::unordered_map<ObjectID, std::vector<std::shared_ptr<GetRequest>>>
object_get_requests_;
std::unordered_set<ObjectID> deletion_cache_;
/// A callback to asynchronously spill objects when space is needed. The
/// callback returns the amount of space still needed after the spilling is
/// complete.
/// NOTE: This function should guarantee the thread-safety because the callback is
/// shared with the main raylet thread.
const ray::SpillObjectsCallback spill_objects_callback_;
/// A callback to asynchronously notify that an object is sealed.
/// NOTE: This function should guarantee the thread-safety because the callback is
/// shared with the main raylet thread.
@ -291,6 +266,8 @@ class PlasmaStore {
/// Whether we have dumped debug information on OOM yet. This limits dump
/// (which can be expensive) to once per OOM event.
bool dumped_on_oom_ = false;
GetRequestQueue get_request_queue_;
};
} // namespace plasma

View file

@ -14,6 +14,7 @@
#include "ray/object_manager/plasma/create_request_queue.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "ray/common/status.h"
@ -21,7 +22,10 @@ namespace plasma {
class MockClient : public ClientInterface {
public:
MockClient() {}
MOCK_METHOD1(SendFd, Status(MEMFD_TYPE));
MOCK_METHOD0(GetObjectIDs, const std::unordered_set<ray::ObjectID> &());
MOCK_METHOD1(MarkObjectAsUsed, void(const ObjectID &object_id));
MOCK_METHOD1(MarkObjectAsUnused, void(const ObjectID &object_id));
};
#define ASSERT_REQUEST_UNFINISHED(queue, req_id) \

View file

@ -0,0 +1,245 @@
// Copyright 2017 The Ray Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ray/object_manager/plasma/get_request_queue.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
using namespace ray;
using namespace testing;
namespace plasma {
class MockClient : public ClientInterface {
public:
MOCK_METHOD1(SendFd, Status(MEMFD_TYPE));
MOCK_METHOD0(GetObjectIDs, const std::unordered_set<ray::ObjectID> &());
MOCK_METHOD1(MarkObjectAsUsed, void(const ObjectID &object_id));
MOCK_METHOD1(MarkObjectAsUnused, void(const ObjectID &object_id));
};
class MockObjectLifecycleManager : public IObjectLifecycleManager {
public:
MOCK_METHOD3(CreateObject,
std::pair<const LocalObject *, flatbuf::PlasmaError>(
const ray::ObjectInfo &object_info,
plasma::flatbuf::ObjectSource source, bool fallback_allocator));
MOCK_CONST_METHOD1(GetObject, const LocalObject *(const ObjectID &object_id));
MOCK_METHOD1(SealObject, const LocalObject *(const ObjectID &object_id));
MOCK_METHOD1(AbortObject, flatbuf::PlasmaError(const ObjectID &object_id));
MOCK_METHOD1(DeleteObject, flatbuf::PlasmaError(const ObjectID &object_id));
MOCK_METHOD1(AddReference, bool(const ObjectID &object_id));
MOCK_METHOD1(RemoveReference, bool(const ObjectID &object_id));
};
struct GetRequestQueueTest : public Test {
public:
GetRequestQueueTest() : io_work_(io_context_) {}
void SetUp() override {
Test::SetUp();
object_id1 = ObjectID::FromRandom();
object_id2 = ObjectID::FromRandom();
object1.object_info.data_size = 10;
object1.object_info.metadata_size = 0;
object2.object_info.data_size = 10;
object2.object_info.metadata_size = 0;
}
void TearDown() override { io_context_.stop(); }
protected:
void MarkObject(LocalObject &object, ObjectState state) { object.state = state; }
bool IsGetRequestExist(GetRequestQueue &queue, const ObjectID &object_id) {
return queue.IsGetRequestExist(object_id);
}
int64_t GetRequestCount(GetRequestQueue &queue, const ObjectID &object_id) {
return queue.GetRequestCount(object_id);
}
void AssertNoLeak(GetRequestQueue &queue) {
EXPECT_FALSE(IsGetRequestExist(queue, object_id1));
EXPECT_FALSE(IsGetRequestExist(queue, object_id2));
}
protected:
instrumented_io_context io_context_;
boost::asio::io_service::work io_work_;
std::thread thread_;
LocalObject object1{Allocation()};
LocalObject object2{Allocation()};
ObjectID object_id1;
ObjectID object_id2;
};
TEST_F(GetRequestQueueTest, TestObjectSealed) {
bool satisfied = false;
MockObjectLifecycleManager object_lifecycle_manager;
GetRequestQueue get_request_queue(
io_context_, object_lifecycle_manager,
[&](const ObjectID &object_id, const auto &request) {},
[&](const std::shared_ptr<GetRequest> &get_req) { satisfied = true; });
auto client = std::make_shared<MockClient>();
/// Test object has been satisfied.
std::vector<ObjectID> object_ids{object_id1};
/// Mock the object already sealed.
MarkObject(object1, ObjectState::PLASMA_SEALED);
EXPECT_CALL(object_lifecycle_manager, GetObject(_)).Times(1).WillOnce(Return(&object1));
get_request_queue.AddRequest(client, object_ids, 1000, false);
EXPECT_TRUE(satisfied);
AssertNoLeak(get_request_queue);
}
TEST_F(GetRequestQueueTest, TestObjectTimeout) {
std::promise<bool> promise;
MockObjectLifecycleManager object_lifecycle_manager;
GetRequestQueue get_request_queue(
io_context_, object_lifecycle_manager,
[&](const ObjectID &object_id, const auto &request) {},
[&](const std::shared_ptr<GetRequest> &get_req) { promise.set_value(true); });
auto client = std::make_shared<MockClient>();
/// Test object not satisfied, time out.
std::vector<ObjectID> object_ids{object_id1};
MarkObject(object1, ObjectState::PLASMA_CREATED);
EXPECT_CALL(object_lifecycle_manager, GetObject(_)).Times(1).WillOnce(Return(&object1));
get_request_queue.AddRequest(client, object_ids, 1000, false);
/// This trigger timeout
io_context_.run_one();
promise.get_future().get();
AssertNoLeak(get_request_queue);
}
TEST_F(GetRequestQueueTest, TestObjectNotSealed) {
std::promise<bool> promise;
MockObjectLifecycleManager object_lifecycle_manager;
GetRequestQueue get_request_queue(
io_context_, object_lifecycle_manager,
[&](const ObjectID &object_id, const auto &request) {},
[&](const std::shared_ptr<GetRequest> &get_req) { promise.set_value(true); });
auto client = std::make_shared<MockClient>();
/// Test object not satisfied, then sealed.
std::vector<ObjectID> object_ids{object_id1};
MarkObject(object1, ObjectState::PLASMA_CREATED);
EXPECT_CALL(object_lifecycle_manager, GetObject(_))
.Times(2)
.WillRepeatedly(Return(&object1));
get_request_queue.AddRequest(client, object_ids, /*timeout_ms*/ -1, false);
MarkObject(object1, ObjectState::PLASMA_SEALED);
get_request_queue.MarkObjectSealed(object_id1);
promise.get_future().get();
AssertNoLeak(get_request_queue);
}
TEST_F(GetRequestQueueTest, TestMultipleObjects) {
std::promise<bool> promise1, promise2, promise3;
MockObjectLifecycleManager object_lifecycle_manager;
GetRequestQueue get_request_queue(
io_context_, object_lifecycle_manager,
[&](const ObjectID &object_id, const auto &request) {
if (object_id == object_id1) {
promise1.set_value(true);
}
if (object_id == object_id2) {
promise2.set_value(true);
}
},
[&](const std::shared_ptr<GetRequest> &get_req) { promise3.set_value(true); });
auto client = std::make_shared<MockClient>();
/// Test get request of mulitiple objects, one sealed, one timed out.
std::vector<ObjectID> object_ids{object_id1, object_id2};
MarkObject(object1, ObjectState::PLASMA_SEALED);
MarkObject(object2, ObjectState::PLASMA_CREATED);
EXPECT_CALL(object_lifecycle_manager, GetObject(Eq(object_id1)))
.WillRepeatedly(Return(&object1));
EXPECT_CALL(object_lifecycle_manager, GetObject(Eq(object_id2)))
.WillRepeatedly(Return(&object2));
get_request_queue.AddRequest(client, object_ids, 1000, false);
promise1.get_future().get();
EXPECT_FALSE(IsGetRequestExist(get_request_queue, object_id1));
EXPECT_TRUE(IsGetRequestExist(get_request_queue, object_id2));
MarkObject(object2, ObjectState::PLASMA_SEALED);
get_request_queue.MarkObjectSealed(object_id2);
io_context_.run_one();
promise2.get_future().get();
promise3.get_future().get();
AssertNoLeak(get_request_queue);
}
TEST_F(GetRequestQueueTest, TestDuplicateObjects) {
MockObjectLifecycleManager object_lifecycle_manager;
GetRequestQueue get_request_queue(
io_context_, object_lifecycle_manager,
[&](const ObjectID &object_id, const auto &request) {},
[&](const std::shared_ptr<GetRequest> &get_req) {});
auto client = std::make_shared<MockClient>();
/// Test get request of duplicated objects.
std::vector<ObjectID> object_ids{object_id1, object_id2, object_id1};
/// Set state to PLASMA_CREATED, so we can check them using IsGetRequestExist.
MarkObject(object1, ObjectState::PLASMA_CREATED);
MarkObject(object2, ObjectState::PLASMA_CREATED);
EXPECT_CALL(object_lifecycle_manager, GetObject(_))
.Times(2)
.WillOnce(Return(&object1))
.WillOnce(Return(&object2));
get_request_queue.AddRequest(client, object_ids, 1000, false);
EXPECT_TRUE(IsGetRequestExist(get_request_queue, object_id1));
EXPECT_TRUE(IsGetRequestExist(get_request_queue, object_id2));
EXPECT_EQ(1, GetRequestCount(get_request_queue, object_id1));
EXPECT_EQ(1, GetRequestCount(get_request_queue, object_id2));
}
TEST_F(GetRequestQueueTest, TestRemoveAll) {
MockObjectLifecycleManager object_lifecycle_manager;
GetRequestQueue get_request_queue(
io_context_, object_lifecycle_manager,
[&](const ObjectID &object_id, const auto &request) {},
[&](const std::shared_ptr<GetRequest> &get_req) {});
auto client = std::make_shared<MockClient>();
/// Test get request two not-sealed objects, remove all requests for this client.
std::vector<ObjectID> object_ids{object_id1, object_id2};
MarkObject(object1, ObjectState::PLASMA_CREATED);
MarkObject(object2, ObjectState::PLASMA_CREATED);
EXPECT_CALL(object_lifecycle_manager, GetObject(_))
.Times(2)
.WillOnce(Return(&object1))
.WillOnce(Return(&object2));
get_request_queue.AddRequest(client, object_ids, 1000, false);
EXPECT_TRUE(IsGetRequestExist(get_request_queue, object_id1));
EXPECT_TRUE(IsGetRequestExist(get_request_queue, object_id2));
get_request_queue.RemoveGetRequestsForClient(client);
EXPECT_FALSE(IsGetRequestExist(get_request_queue, object_id1));
EXPECT_FALSE(IsGetRequestExist(get_request_queue, object_id2));
AssertNoLeak(get_request_queue);
}
} // namespace plasma
int main(int argc, char **argv) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}