diff --git a/BUILD.bazel b/BUILD.bazel index dbeeb4e60..7f65bdbf8 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -1079,6 +1079,20 @@ cc_test( ], ) +cc_test( + name = "object_buffer_pool_test", + size = "small", + srcs = [ + "src/ray/object_manager/test/object_buffer_pool_test.cc", + ], + copts = COPTS, + tags = ["team:core"], + deps = [ + ":object_manager", + "@com_google_googletest//:gtest_main", + ], +) + cc_test( name = "ownership_based_object_directory_test", size = "small", diff --git a/src/ray/object_manager/object_buffer_pool.cc b/src/ray/object_manager/object_buffer_pool.cc index 9f4101c39..a91f64952 100644 --- a/src/ray/object_manager/object_buffer_pool.cc +++ b/src/ray/object_manager/object_buffer_pool.cc @@ -20,11 +20,9 @@ namespace ray { -ObjectBufferPool::ObjectBufferPool(const std::string &store_socket_name, - uint64_t chunk_size) - : store_socket_name_(store_socket_name), default_chunk_size_(chunk_size) { - RAY_CHECK_OK(store_client_.Connect(store_socket_name_.c_str(), "", 0, 300)); -} +ObjectBufferPool::ObjectBufferPool( + std::shared_ptr store_client, uint64_t chunk_size) + : store_client_(store_client), default_chunk_size_(chunk_size) {} ObjectBufferPool::~ObjectBufferPool() { absl::MutexLock lock(&pool_mutex_); @@ -50,13 +48,13 @@ ObjectBufferPool::~ObjectBufferPool() { // Abort unfinished buffers in progress. for (auto it = create_buffer_state_.begin(); it != create_buffer_state_.end(); it++) { - RAY_CHECK_OK(store_client_.Release(it->first)); - RAY_CHECK_OK(store_client_.Abort(it->first)); + RAY_CHECK_OK(store_client_->Release(it->first)); + RAY_CHECK_OK(store_client_->Abort(it->first)); create_buffer_state_.erase(it); } RAY_CHECK(create_buffer_state_.empty()); - RAY_CHECK_OK(store_client_.Disconnect()); + RAY_CHECK_OK(store_client_->Disconnect()); } uint64_t ObjectBufferPool::GetNumChunks(uint64_t data_size) const { @@ -78,7 +76,7 @@ ObjectBufferPool::CreateObjectReader(const ObjectID &object_id, std::vector object_ids{object_id}; std::vector object_buffers(1); RAY_CHECK_OK( - store_client_.Get(object_ids, 0, &object_buffers, /*is_from_worker=*/false)); + store_client_->Get(object_ids, 0, &object_buffers, /*is_from_worker=*/false)); if (object_buffers[0].data == nullptr) { RAY_LOG(INFO) << "Failed to get a chunk of the object: " << object_id @@ -104,6 +102,9 @@ ray::Status ObjectBufferPool::CreateChunk(const ObjectID &object_id, RAY_RETURN_NOT_OK(EnsureBufferExists(object_id, owner_address, data_size, metadata_size, chunk_index)); auto &state = create_buffer_state_.at(object_id); + if (chunk_index >= state.chunk_state.size()) { + return ray::Status::IOError("Object size mismatch"); + } if (state.chunk_state[chunk_index] != CreateChunkState::AVAILABLE) { // There can be only one reference to this chunk at any given time. return ray::Status::IOError("Chunk already received by a different thread."); @@ -112,14 +113,19 @@ ray::Status ObjectBufferPool::CreateChunk(const ObjectID &object_id, return ray::Status::OK(); } -void ObjectBufferPool::WriteChunk(const ObjectID &object_id, const uint64_t chunk_index, +void ObjectBufferPool::WriteChunk(const ObjectID &object_id, uint64_t data_size, + uint64_t metadata_size, const uint64_t chunk_index, const std::string &data) { absl::MutexLock lock(&pool_mutex_); auto it = create_buffer_state_.find(object_id); - if (it == create_buffer_state_.end() || + if (it == create_buffer_state_.end() || chunk_index >= it->second.chunk_state.size() || it->second.chunk_state.at(chunk_index) != CreateChunkState::REFERENCED) { - RAY_LOG(DEBUG) << "Object " << object_id << " aborted due to OOM before chunk " - << chunk_index << " could be sealed"; + RAY_LOG(DEBUG) << "Object " << object_id << " aborted before chunk " << chunk_index + << " could be sealed"; + return; + } + if (it->second.data_size != data_size || it->second.metadata_size != metadata_size) { + RAY_LOG(DEBUG) << "Object " << object_id << " size mismatch, rejecting chunk"; return; } RAY_CHECK(it->second.chunk_info.size() > chunk_index); @@ -131,8 +137,8 @@ void ObjectBufferPool::WriteChunk(const ObjectID &object_id, const uint64_t chun it->second.chunk_state.at(chunk_index) = CreateChunkState::SEALED; it->second.num_seals_remaining--; if (it->second.num_seals_remaining == 0) { - RAY_CHECK_OK(store_client_.Seal(object_id)); - RAY_CHECK_OK(store_client_.Release(object_id)); + RAY_CHECK_OK(store_client_->Seal(object_id)); + RAY_CHECK_OK(store_client_->Release(object_id)); create_buffer_state_.erase(it); RAY_LOG(DEBUG) << "Have received all chunks for object " << object_id << ", last chunk index: " << chunk_index; @@ -141,12 +147,16 @@ void ObjectBufferPool::WriteChunk(const ObjectID &object_id, const uint64_t chun void ObjectBufferPool::AbortCreate(const ObjectID &object_id) { absl::MutexLock lock(&pool_mutex_); + RAY_LOG(INFO) << "Not enough memory to create requested object " << object_id + << ", aborting"; + AbortCreateInternal(object_id); +} + +void ObjectBufferPool::AbortCreateInternal(const ObjectID &object_id) { auto it = create_buffer_state_.find(object_id); if (it != create_buffer_state_.end()) { - RAY_LOG(INFO) << "Not enough memory to create requested object " << object_id - << ", aborting"; - RAY_CHECK_OK(store_client_.Release(object_id)); - RAY_CHECK_OK(store_client_.Abort(object_id)); + RAY_CHECK_OK(store_client_->Release(object_id)); + RAY_CHECK_OK(store_client_->Abort(object_id)); create_buffer_state_.erase(object_id); } } @@ -177,9 +187,13 @@ ray::Status ObjectBufferPool::EnsureBufferExists(const ObjectID &object_id, uint64_t metadata_size, uint64_t chunk_index) { while (true) { - // Buffer for object_id already exists. - if (create_buffer_state_.contains(object_id)) { - return ray::Status::OK(); + // Buffer for object_id already exists and the size matches ours. + { + auto it = create_buffer_state_.find(object_id); + if (it != create_buffer_state_.end() && it->second.data_size == data_size && + it->second.metadata_size == metadata_size) { + return ray::Status::OK(); + } } auto it = create_buffer_ops_.find(object_id); @@ -198,13 +212,30 @@ ray::Status ObjectBufferPool::EnsureBufferExists(const ObjectID &object_id, // create_buffer_ops_. RAY_CHECK( create_buffer_ops_.insert({object_id, std::make_shared()}).second); + + // If the buffer currently exists, its size must be different. Abort the + // created buffer so we can recreate it with the correct size. + { + auto it = create_buffer_state_.find(object_id); + if (it != create_buffer_state_.end()) { + RAY_CHECK(it->second.data_size != data_size || + it->second.metadata_size != metadata_size); + RAY_LOG(WARNING) << "Object " << object_id << " size (" << data_size + << ") differs from the original (" << it->second.data_size + << "). This is likely due to re-execution of a task with a " + "nondeterministic output. Recreating object with size " + << data_size << "."; + AbortCreateInternal(it->first); + } + } + const int64_t object_size = static_cast(data_size) - static_cast(metadata_size); std::shared_ptr data; // Release pool_mutex_ during the blocking create call. pool_mutex_.Unlock(); - Status s = store_client_.CreateAndSpillIfNeeded( + Status s = store_client_->CreateAndSpillIfNeeded( object_id, owner_address, static_cast(object_size), nullptr, static_cast(metadata_size), &data, plasma::flatbuf::ObjectSource::ReceivedFromRemoteRaylet); @@ -231,10 +262,11 @@ ray::Status ObjectBufferPool::EnsureBufferExists(const ObjectID &object_id, // Read object into store. uint8_t *mutable_data = data->Data(); uint64_t num_chunks = GetNumChunks(data_size); - create_buffer_state_.emplace( + auto inserted = create_buffer_state_.emplace( std::piecewise_construct, std::forward_as_tuple(object_id), - std::forward_as_tuple(BuildChunks(object_id, mutable_data, data_size, data))); - RAY_CHECK(create_buffer_state_[object_id].chunk_info.size() == num_chunks); + std::forward_as_tuple(metadata_size, data_size, + BuildChunks(object_id, mutable_data, data_size, data))); + RAY_CHECK(inserted.first->second.chunk_info.size() == num_chunks); RAY_LOG(DEBUG) << "Created object " << object_id << " in plasma store, number of chunks: " << num_chunks << ", chunk index: " << chunk_index; @@ -244,7 +276,7 @@ ray::Status ObjectBufferPool::EnsureBufferExists(const ObjectID &object_id, void ObjectBufferPool::FreeObjects(const std::vector &object_ids) { absl::MutexLock lock(&pool_mutex_); - RAY_CHECK_OK(store_client_.Delete(object_ids)); + RAY_CHECK_OK(store_client_->Delete(object_ids)); } std::string ObjectBufferPool::DebugString() const { diff --git a/src/ray/object_manager/object_buffer_pool.h b/src/ray/object_manager/object_buffer_pool.h index b2722a3ec..21b7d69dc 100644 --- a/src/ray/object_manager/object_buffer_pool.h +++ b/src/ray/object_manager/object_buffer_pool.h @@ -56,10 +56,10 @@ class ObjectBufferPool { /// Constructor. /// - /// \param store_socket_name The socket name of the store to which plasma clients - /// connect. + /// \param store_client Plasma store client. Used for testing purposes only. /// \param chunk_size The chunk size into which objects are to be split. - ObjectBufferPool(const std::string &store_socket_name, const uint64_t chunk_size); + ObjectBufferPool(std::shared_ptr store_client, + const uint64_t chunk_size); ~ObjectBufferPool(); @@ -120,8 +120,9 @@ class ObjectBufferPool { /// \param object_id The ObjectID. /// \param chunk_index The index of the chunk. /// \param data The data to write into the chunk. - void WriteChunk(const ObjectID &object_id, uint64_t chunk_index, - const std::string &data) LOCKS_EXCLUDED(pool_mutex_); + void WriteChunk(const ObjectID &object_id, uint64_t data_size, uint64_t metadata_size, + uint64_t chunk_index, const std::string &data) + LOCKS_EXCLUDED(pool_mutex_); /// Free a list of objects from object store. /// @@ -155,16 +156,25 @@ class ObjectBufferPool { uint64_t metadata_size, uint64_t chunk_index) EXCLUSIVE_LOCKS_REQUIRED(pool_mutex_); + void AbortCreateInternal(const ObjectID &object_id) + EXCLUSIVE_LOCKS_REQUIRED(pool_mutex_); + /// The state of a chunk associated with a create operation. enum class CreateChunkState : unsigned int { AVAILABLE = 0, REFERENCED, SEALED }; /// Holds the state of creating chunks. Members are protected by pool_mutex_. struct CreateBufferState { - CreateBufferState() {} - CreateBufferState(std::vector chunk_info) - : chunk_info(chunk_info), + CreateBufferState(uint64_t metadata_size, uint64_t data_size, + std::vector chunk_info) + : metadata_size(metadata_size), + data_size(data_size), + chunk_info(chunk_info), chunk_state(chunk_info.size(), CreateChunkState::AVAILABLE), num_seals_remaining(chunk_info.size()) {} + /// Total size of the object metadata. + uint64_t metadata_size; + /// Total size of the object data. + uint64_t data_size; /// A vector maintaining information about the chunks which comprise /// an object. std::vector chunk_info; @@ -178,12 +188,6 @@ class ObjectBufferPool { /// Returned when GetChunk or CreateChunk fails. const ChunkInfo errored_chunk_ = {0, nullptr, 0, nullptr}; - /// Socket name of plasma store. - const std::string store_socket_name_; - - /// Determines the maximum chunk size to be transferred by a single thread. - const uint64_t default_chunk_size_; - /// Mutex to protect create_buffer_ops_, create_buffer_state_ and following invariants: /// - create_buffer_ops_ contains an object_id iff there is an inflight operation to /// create the buffer for the object. @@ -200,7 +204,12 @@ class ObjectBufferPool { GUARDED_BY(pool_mutex_); /// Plasma client pool. - plasma::PlasmaClient store_client_; + std::shared_ptr store_client_; + + /// Determines the maximum chunk size to be transferred by a single thread. + const uint64_t default_chunk_size_; + + friend class ObjectBufferPoolTest; }; } // namespace ray diff --git a/src/ray/object_manager/object_manager.cc b/src/ray/object_manager/object_manager.cc index 289d30be8..d5830ffc4 100644 --- a/src/ray/object_manager/object_manager.cc +++ b/src/ray/object_manager/object_manager.cc @@ -86,7 +86,8 @@ ObjectManager::ObjectManager( }, "ObjectManager.ObjectDeleted"); }), - buffer_pool_(config_.store_socket_name, config_.object_chunk_size), + buffer_pool_store_client_(std::make_shared()), + buffer_pool_(buffer_pool_store_client_, config_.object_chunk_size), rpc_work_(rpc_service_), object_manager_server_("ObjectManager", config_.object_manager_port, config_.object_manager_address == "127.0.0.1", @@ -127,6 +128,10 @@ ObjectManager::ObjectManager( self_node_id_, object_is_local, send_pull_request, cancel_pull_request, fail_pull_request, restore_spilled_object_, get_time, config.pull_timeout_ms, available_memory, pin_object, get_spilled_object_url)); + + RAY_CHECK_OK( + buffer_pool_store_client_->Connect(config_.store_socket_name.c_str(), "", 0, 300)); + // Start object manager rpc server and send & receive request threads StartRpcService(); } @@ -368,23 +373,16 @@ void ObjectManager::PushLocalObject(const ObjectID &object_id, const NodeID &nod if (object_reader->GetDataSize() != data_size || object_reader->GetMetadataSize() != metadata_size) { - if (object_reader->GetDataSize() == 0) { - // TODO(scv119): handle object size changes in a more graceful way. - RAY_LOG(WARNING) << object_id - << " is marked as failed but object_manager has stale info " - << " with data size: " << data_size - << ", metadata size: " << metadata_size - << ". This is likely due to race condition." - << " Update the info and proceed sending failed object."; - local_objects_[object_id].object_info.data_size = 0; - local_objects_[object_id].object_info.metadata_size = 1; - } else { - RAY_LOG(FATAL) << "Object id:" << object_id + // TODO(scv119): handle object size changes in a more graceful way. + RAY_LOG(WARNING) << "Object id:" << object_id << "'s size mismatches our record. Expected data size: " << data_size << ", expected metadata size: " << metadata_size << ", actual data size: " << object_reader->GetDataSize() - << ", actual metadata size: " << object_reader->GetMetadataSize(); - } + << ", actual metadata size: " << object_reader->GetMetadataSize() + << ". This is likely due to a race condition." + << " We will update the object size and proceed sending the object."; + local_objects_[object_id].object_info.data_size = 0; + local_objects_[object_id].object_info.metadata_size = 1; } PushObjectInternal(object_id, node_id, @@ -560,7 +558,7 @@ bool ObjectManager::ReceiveObjectChunk(const NodeID &node_id, const ObjectID &ob if (chunk_status.ok()) { // Avoid handling this chunk if it's already being handled by another process. - buffer_pool_.WriteChunk(object_id, chunk_index, data); + buffer_pool_.WriteChunk(object_id, data_size, metadata_size, chunk_index, data); return true; } else { num_chunks_received_failed_due_to_plasma_++; diff --git a/src/ray/object_manager/object_manager.h b/src/ray/object_manager/object_manager.h index 04d47019c..eb0aa4f46 100644 --- a/src/ray/object_manager/object_manager.h +++ b/src/ray/object_manager/object_manager.h @@ -377,6 +377,11 @@ class ObjectManager : public ObjectManagerInterface, /// Object store runner. ObjectStoreRunner object_store_internal_; + /// Used by the buffer pool to read and write objects in the local store + /// during object transfers. + std::shared_ptr buffer_pool_store_client_; + + /// Manages accesses to local objects for object transfers. ObjectBufferPool buffer_pool_; /// Multi-thread asio service, deal with all outgoing and incoming RPC request. diff --git a/src/ray/object_manager/plasma/client.h b/src/ray/object_manager/plasma/client.h index 6c6b9ea19..c15ecf75c 100644 --- a/src/ray/object_manager/plasma/client.h +++ b/src/ray/object_manager/plasma/client.h @@ -44,7 +44,103 @@ struct ObjectBuffer { int device_num; }; -class PlasmaClient { +class PlasmaClientInterface { + public: + virtual ~PlasmaClientInterface(){}; + + /// Tell Plasma that the client no longer needs the object. This should be + /// called after Get() or Create() when the client is done with the object. + /// After this call, the buffer returned by Get() is no longer valid. + /// + /// \param object_id The ID of the object that is no longer needed. + /// \return The return status. + virtual Status Release(const ObjectID &object_id) = 0; + + /// Disconnect from the local plasma instance, including the local store and + /// manager. + /// + /// \return The return status. + virtual Status Disconnect() = 0; + + /// Get some objects from the Plasma Store. This function will block until the + /// objects have all been created and sealed in the Plasma Store or the + /// timeout expires. + /// + /// If an object was not retrieved, the corresponding metadata and data + /// fields in the ObjectBuffer structure will evaluate to false. + /// Objects are automatically released by the client when their buffers + /// get out of scope. + /// + /// \param object_ids The IDs of the objects to get. + /// \param timeout_ms The amount of time in milliseconds to wait before this + /// request times out. If this value is -1, then no timeout is set. + /// \param[out] object_buffers The object results. + /// \param is_from_worker Whether or not if the Get request comes from a Ray workers. + /// \return The return status. + virtual Status Get(const std::vector &object_ids, int64_t timeout_ms, + std::vector *object_buffers, bool is_from_worker) = 0; + + /// Seal an object in the object store. The object will be immutable after + /// this + /// call. + /// + /// \param object_id The ID of the object to seal. + /// \return The return status. + virtual Status Seal(const ObjectID &object_id) = 0; + + /// Abort an unsealed object in the object store. If the abort succeeds, then + /// it will be as if the object was never created at all. The unsealed object + /// must have only a single reference (the one that would have been removed by + /// calling Seal). + /// + /// \param object_id The ID of the object to abort. + /// \return The return status. + virtual Status Abort(const ObjectID &object_id) = 0; + + /// Create an object in the Plasma Store. Any metadata for this object must be + /// be passed in when the object is created. + /// + /// If this request cannot be fulfilled immediately, this call will block until + /// enough objects have been spilled to make space. If spilling cannot free + /// enough space, an out of memory error will be returned. + /// + /// \param object_id The ID to use for the newly created object. + /// \param owner_address The address of the object's owner. + /// \param data_size The size in bytes of the space to be allocated for this + /// object's + /// data (this does not include space used for metadata). + /// \param metadata The object's metadata. If there is no metadata, this + /// pointer should be NULL. + /// \param metadata_size The size in bytes of the metadata. If there is no + /// metadata, this should be 0. + /// \param data The address of the newly created object will be written here. + /// \param device_num The number of the device where the object is being + /// created. + /// device_num = 0 corresponds to the host, + /// device_num = 1 corresponds to GPU0, + /// device_num = 2 corresponds to GPU1, etc. + /// \return The return status. + /// + /// The returned object must be released once it is done with. It must also + /// be either sealed or aborted. + virtual Status CreateAndSpillIfNeeded(const ObjectID &object_id, + const ray::rpc::Address &owner_address, + int64_t data_size, const uint8_t *metadata, + int64_t metadata_size, + std::shared_ptr *data, + plasma::flatbuf::ObjectSource source, + int device_num = 0) = 0; + + /// Delete a list of objects from the object store. This currently assumes that the + /// object is present, has been sealed and not used by another client. Otherwise, + /// it is a no operation. + /// + /// \param object_ids The list of IDs of the objects to delete. + /// \return The return status. If all the objects are non-existent, return OK. + virtual Status Delete(const std::vector &object_ids) = 0; +}; + +class PlasmaClient : public PlasmaClientInterface { public: PlasmaClient(); ~PlasmaClient(); diff --git a/src/ray/object_manager/test/object_buffer_pool_test.cc b/src/ray/object_manager/test/object_buffer_pool_test.cc new file mode 100644 index 000000000..0dac3bf1d --- /dev/null +++ b/src/ray/object_manager/test/object_buffer_pool_test.cc @@ -0,0 +1,168 @@ +// Copyright 2017 The Ray Authors. +// + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// clang-format off +#include "ray/object_manager/object_buffer_pool.h" + +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "ray/common/id.h" +#include "ray/object_manager/plasma/client.h" + +#ifdef UNORDERED_VS_ABSL_MAPS_EVALUATION +#include + +#include "absl/container/flat_hash_map.h" +#endif // UNORDERED_VS_ABSL_MAPS_EVALUATION +// clang-format on + +namespace ray { + +using ::testing::_; + +class MockPlasmaClient : public plasma::PlasmaClientInterface { + public: + MOCK_METHOD1(Release, ray::Status(const ObjectID &object_id)); + + MOCK_METHOD0(Disconnect, ray::Status()); + + MOCK_METHOD4(Get, + ray::Status(const std::vector &object_ids, int64_t timeout_ms, + std::vector *object_buffers, + bool is_from_worker)); + + MOCK_METHOD1(Seal, ray::Status(const ObjectID &object_id)); + + MOCK_METHOD1(Abort, ray::Status(const ObjectID &object_id)); + + ray::Status CreateAndSpillIfNeeded(const ObjectID &object_id, + const ray::rpc::Address &owner_address, + int64_t data_size, const uint8_t *metadata, + int64_t metadata_size, std::shared_ptr *data, + plasma::flatbuf::ObjectSource source, + int device_num) { + *data = std::make_shared(data_size); + return ray::Status::OK(); + } + + MOCK_METHOD1(Delete, ray::Status(const std::vector &object_ids)); +}; + +class ObjectBufferPoolTest : public ::testing::Test { + public: + ObjectBufferPoolTest() + : chunk_size_(1000), + mock_plasma_client_(std::make_shared()), + object_buffer_pool_(mock_plasma_client_, chunk_size_), + mock_data_(chunk_size_, 'x') {} + + void AssertNoLeaks() { + absl::MutexLock lock(&object_buffer_pool_.pool_mutex_); + ASSERT_TRUE(object_buffer_pool_.create_buffer_state_.empty()); + ASSERT_TRUE(object_buffer_pool_.create_buffer_ops_.empty()); + } + + uint64_t chunk_size_; + std::shared_ptr mock_plasma_client_; + ObjectBufferPool object_buffer_pool_; + std::string mock_data_; +}; + +TEST_F(ObjectBufferPoolTest, TestBasic) { + auto obj_id = ObjectID::FromRandom(); + rpc::Address owner_address; + + ASSERT_TRUE( + object_buffer_pool_.CreateChunk(obj_id, owner_address, chunk_size_, 0, 0).ok()); + ASSERT_FALSE( + object_buffer_pool_.CreateChunk(obj_id, owner_address, chunk_size_, 0, 0).ok()); + EXPECT_CALL(*mock_plasma_client_, Seal(obj_id)); + EXPECT_CALL(*mock_plasma_client_, Release(obj_id)); + object_buffer_pool_.WriteChunk(obj_id, chunk_size_, 0, 0, mock_data_); +} + +TEST_F(ObjectBufferPoolTest, TestMultiChunk) { + auto obj_id = ObjectID::FromRandom(); + rpc::Address owner_address; + + for (int i = 0; i < 3; i++) { + ASSERT_TRUE( + object_buffer_pool_.CreateChunk(obj_id, owner_address, 3 * chunk_size_, 0, i) + .ok()); + ASSERT_FALSE( + object_buffer_pool_.CreateChunk(obj_id, owner_address, 3 * chunk_size_, 0, i) + .ok()); + } + EXPECT_CALL(*mock_plasma_client_, Seal(obj_id)); + EXPECT_CALL(*mock_plasma_client_, Release(obj_id)); + for (int i = 0; i < 3; i++) { + object_buffer_pool_.WriteChunk(obj_id, 3 * chunk_size_, 0, i, mock_data_); + } +} + +TEST_F(ObjectBufferPoolTest, TestAbort) { + auto obj_id = ObjectID::FromRandom(); + rpc::Address owner_address; + + ASSERT_TRUE( + object_buffer_pool_.CreateChunk(obj_id, owner_address, chunk_size_, 0, 0).ok()); + ASSERT_FALSE( + object_buffer_pool_.CreateChunk(obj_id, owner_address, chunk_size_, 0, 0).ok()); + EXPECT_CALL(*mock_plasma_client_, Abort(obj_id)); + object_buffer_pool_.AbortCreate(obj_id); + ASSERT_TRUE( + object_buffer_pool_.CreateChunk(obj_id, owner_address, chunk_size_, 0, 0).ok()); + + EXPECT_CALL(*mock_plasma_client_, Seal(obj_id)); + EXPECT_CALL(*mock_plasma_client_, Release(obj_id)); + object_buffer_pool_.WriteChunk(obj_id, chunk_size_, 0, 0, mock_data_); +} + +TEST_F(ObjectBufferPoolTest, TestSizeMismatch) { + auto obj_id = ObjectID::FromRandom(); + rpc::Address owner_address; + + int64_t data_size_1 = 3 * chunk_size_; + int64_t data_size_2 = 2 * chunk_size_; + ASSERT_TRUE( + object_buffer_pool_.CreateChunk(obj_id, owner_address, data_size_1, 0, 0).ok()); + object_buffer_pool_.WriteChunk(obj_id, data_size_1, 0, 0, mock_data_); + + // Object gets created again with a different size. + EXPECT_CALL(*mock_plasma_client_, Release(obj_id)); + EXPECT_CALL(*mock_plasma_client_, Abort(obj_id)); + ASSERT_TRUE( + object_buffer_pool_.CreateChunk(obj_id, owner_address, data_size_2, 0, 1).ok()); + object_buffer_pool_.WriteChunk(obj_id, data_size_2, 0, 1, mock_data_); + + ASSERT_TRUE( + object_buffer_pool_.CreateChunk(obj_id, owner_address, data_size_2, 0, 0).ok()); + // Writing a chunk with a stale data size has no effect. + object_buffer_pool_.WriteChunk(obj_id, data_size_1, 0, 0, mock_data_); + + EXPECT_CALL(*mock_plasma_client_, Seal(obj_id)); + EXPECT_CALL(*mock_plasma_client_, Release(obj_id)); + object_buffer_pool_.WriteChunk(obj_id, data_size_2, 0, 0, mock_data_); +} + +} // namespace ray + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +}