[core] Recover spilled objects that are lost during node failure (#21485)

* Failing test

* trigger recovery from ref counter

* x

* update

* lint

* stress test

* update

* format

* x
Stephanie Wang, 2022-02-09 18:22:16 -08:00 (committed by GitHub)
commit 495eb14179 (parent 1c791b71d8)
11 changed files with 261 additions and 110 deletions
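The scenario this commit addresses, distilled into a minimal sketch. This is hedged: it mirrors the `test_spilled` case added below, and the node sizes, config values, and task bodies are illustrative rather than canonical.

import numpy as np
import ray
from ray.cluster_utils import Cluster

cluster = Cluster()
# Head node with no resources; object reconstruction must be enabled.
cluster.add_node(num_cpus=0, enable_object_reconstruction=True)
ray.init(address=cluster.address)
node = cluster.add_node(num_cpus=1, object_store_memory=10 ** 8)

@ray.remote(max_retries=1)
def large_object():
    return np.zeros(10 ** 7, dtype=np.uint8)

obj = large_object.remote()
# Filling the object store forces `obj` to be spilled to local storage.
ray.get([large_object.remote() for _ in range(20)])

# Kill the node holding both the primary and the spilled copy. Before this
# commit the spilled copy was silently lost; with it, `obj` is reconstructed
# from lineage the next time it is needed.
cluster.remove_node(node, allow_graceful=False)
cluster.add_node(num_cpus=1, object_store_memory=10 ** 8)
ray.get(obj, timeout=60)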

View file

@@ -640,6 +640,61 @@ def test_reconstruction_stress(ray_start_cluster):
            i += 1


@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
def test_reconstruction_stress_spill(ray_start_cluster):
    config = {
        "num_heartbeats_timeout": 10,
        "raylet_heartbeat_period_milliseconds": 100,
        "max_direct_call_object_size": 100,
        "task_retry_delay_ms": 100,
        "object_timeout_milliseconds": 200,
    }
    cluster = ray_start_cluster
    # Head node with no resources.
    cluster.add_node(
        num_cpus=0, _system_config=config, enable_object_reconstruction=True
    )
    ray.init(address=cluster.address)
    # Node to place the initial object.
    node_to_kill = cluster.add_node(
        num_cpus=1, resources={"node1": 1}, object_store_memory=10 ** 8
    )
    cluster.add_node(num_cpus=1, resources={"node2": 1}, object_store_memory=10 ** 8)
    cluster.wait_for_nodes()

    @ray.remote
    def large_object():
        return np.zeros(10 ** 6, dtype=np.uint8)

    @ray.remote
    def dependent_task(x):
        return

    for _ in range(3):
        obj = large_object.options(resources={"node1": 1}).remote()
        ray.get(dependent_task.options(resources={"node2": 1}).remote(obj))

        outputs = [
            large_object.options(resources={"node1": 1}).remote() for _ in range(1000)
        ]
        outputs = [
            dependent_task.options(resources={"node2": 1}).remote(obj)
            for obj in outputs
        ]

        cluster.remove_node(node_to_kill, allow_graceful=False)
        node_to_kill = cluster.add_node(
            num_cpus=1, resources={"node1": 1}, object_store_memory=10 ** 8
        )

        i = 0
        while outputs:
            ref = outputs.pop(0)
            print(i, ref)
            ray.get(ref)
            i += 1


@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
@pytest.mark.parametrize("reconstruction_enabled", [False, True])
def test_nondeterministic_output(ray_start_cluster, reconstruction_enabled):
@@ -905,6 +960,60 @@ def test_nested(ray_start_cluster, reconstruction_enabled):
            ray.get(ref, timeout=60)


@pytest.mark.parametrize("reconstruction_enabled", [False, True])
def test_spilled(ray_start_cluster, reconstruction_enabled):
    config = {
        "num_heartbeats_timeout": 10,
        "raylet_heartbeat_period_milliseconds": 100,
        "object_timeout_milliseconds": 200,
    }
    # Workaround to reset the config to the default value.
    if not reconstruction_enabled:
        config["lineage_pinning_enabled"] = False

    cluster = ray_start_cluster
    # Head node with no resources.
    cluster.add_node(
        num_cpus=0,
        _system_config=config,
        enable_object_reconstruction=reconstruction_enabled,
    )
    ray.init(address=cluster.address)
    # Node to place the initial object.
    node_to_kill = cluster.add_node(
        num_cpus=1, resources={"node1": 1}, object_store_memory=10 ** 8
    )
    cluster.wait_for_nodes()

    @ray.remote(max_retries=1 if reconstruction_enabled else 0)
    def large_object():
        return np.zeros(10 ** 7, dtype=np.uint8)

    @ray.remote
    def dependent_task(x):
        return

    obj = large_object.options(resources={"node1": 1}).remote()
    ray.get(dependent_task.options(resources={"node1": 1}).remote(obj))
    # Force spilling.
    objs = [large_object.options(resources={"node1": 1}).remote() for _ in range(20)]
    for o in objs:
        ray.get(o)

    cluster.remove_node(node_to_kill, allow_graceful=False)
    node_to_kill = cluster.add_node(
        num_cpus=1, resources={"node1": 1}, object_store_memory=10 ** 8
    )

    if reconstruction_enabled:
        ray.get(dependent_task.remote(obj), timeout=60)
    else:
        with pytest.raises(ray.exceptions.RayTaskError):
            ray.get(dependent_task.remote(obj), timeout=60)
        with pytest.raises(ray.exceptions.ObjectLostError):
            ray.get(obj, timeout=60)


if __name__ == "__main__":
    import pytest

View file

@@ -90,8 +90,17 @@ RAY_CONFIG(int64_t, free_objects_period_milliseconds, 1000)
/// to -1.
RAY_CONFIG(size_t, free_objects_batch_size, 100)
/// Whether to pin object lineage, i.e. the task that created the object and
/// the task's recursive dependencies. If this is set to true, then the system
/// will attempt to reconstruct the object from its lineage if the object is
/// lost.
RAY_CONFIG(bool, lineage_pinning_enabled, false)
/// Objects that require recovery are added to a local cache. This is the
/// duration between attempts to flush and recover the objects in the local
/// cache.
RAY_CONFIG(int64_t, reconstruct_objects_period_milliseconds, 100)
/// Maximum amount of lineage to keep in bytes. This includes the specs of all
/// tasks that have previously already finished but that may be retried again.
/// If we reach this limit, 50% of the current lineage will be evicted and
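For reference, both knobs are plain RAY_CONFIG entries, so they can be set from Python through `_system_config`, as the tests above do. A sketch; `reconstruct_objects_period_milliseconds` is shown at its default from this file, while lineage pinning (off by default) is flipped on:

import ray

ray.init(
    _system_config={
        # Pin lineage so lost objects can be rebuilt from their creating tasks.
        "lineage_pinning_enabled": True,
        # How often the buffered lost objects are flushed and recovered.
        "reconstruct_objects_period_milliseconds": 100,
    }
)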

View file

@@ -218,10 +218,14 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
      },
      /*callback_service*/ &io_service_);

  auto check_node_alive_fn = [this](const NodeID &node_id) {
    auto node = gcs_client_->Nodes().Get(node_id);
    return node != nullptr;
  };
  reference_counter_ = std::make_shared<ReferenceCounter>(
      rpc_address_,
      /*object_info_publisher=*/object_info_publisher_.get(),
      /*object_info_subscriber=*/object_info_subscriber_.get(),
      /*object_info_subscriber=*/object_info_subscriber_.get(), check_node_alive_fn,
      RayConfig::instance().lineage_pinning_enabled(), [this](const rpc::Address &addr) {
        return std::shared_ptr<rpc::CoreWorkerClient>(
            new rpc::CoreWorkerClient(addr, *client_call_manager_));
@@ -257,17 +261,6 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
  periodical_runner_.RunFnPeriodically([this] { InternalHeartbeat(); },
                                       kInternalHeartbeatMillis);

  auto check_node_alive_fn = [this](const NodeID &node_id) {
    auto node = gcs_client_->Nodes().Get(node_id);
    return node != nullptr;
  };
  auto reconstruct_object_callback = [this](const ObjectID &object_id) {
    io_service_.post(
        [this, object_id]() {
          RAY_CHECK(object_recovery_manager_->RecoverObject(object_id));
        },
        "CoreWorker.ReconstructObject");
  };
  auto push_error_callback = [this](const JobID &job_id, const std::string &type,
                                    const std::string &error_message, double timestamp) {
    return PushError(job_id, type, error_message, timestamp);
@@ -301,8 +294,7 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
          }
        }
      },
      check_node_alive_fn, reconstruct_object_callback, push_error_callback,
      RayConfig::instance().max_lineage_bytes()));
      push_error_callback, RayConfig::instance().max_lineage_bytes()));

  // Create an entry for the driver task in the task table. This task is
  // added immediately with status RUNNING. This allows us to push errors
@@ -458,6 +450,24 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
  RayEventContext::Instance().SetEventContext(
      ray::rpc::Event_SourceType::Event_SourceType_CORE_WORKER,
      {{"worker_id", worker_id.Hex()}});

  periodical_runner_.RunFnPeriodically(
      [this] {
        const auto lost_objects = reference_counter_->FlushObjectsToRecover();
        // Delete the objects from the in-memory store to indicate that they are not
        // available. The object recovery manager will guarantee that a new value
        // will eventually be stored for the objects (either an
        // UnreconstructableError or a value reconstructed from lineage).
        memory_store_->Delete(lost_objects);
        for (const auto &object_id : lost_objects) {
          // NOTE(swang): There is a race condition where this can return false if
          // the reference went out of scope since the call to the ref counter to get
          // the lost objects. It's okay to not mark the object as failed or recover
          // the object since there are no reference holders.
          RAY_UNUSED(object_recovery_manager_->RecoverObject(object_id));
        }
      },
      100);
}

CoreWorker::~CoreWorker() { RAY_LOG(INFO) << "Core worker is destructed"; }
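In other words, the ref counter now only buffers lost objects, and this periodic task drains the buffer. A minimal Python sketch of the control flow; names like `flush_objects_to_recover` are stand-ins for the C++ members above (reference_counter_, memory_store_, object_recovery_manager_), not a real Ray API:

def flush_and_recover(reference_counter, memory_store, recovery_manager):
    lost = reference_counter.flush_objects_to_recover()
    # Delete first, so readers block until either a reconstructed value or an
    # UnreconstructableError is eventually stored for each lost object.
    memory_store.delete(lost)
    for object_id in lost:
        # May report failure if the reference went out of scope since the
        # flush; with no reference holders left, ignoring that is safe.
        recovery_manager.recover_object(object_id)

def run_every_100ms(fn, stop_event):
    # Mirrors the 100 ms period passed to RunFnPeriodically above.
    while not stop_event.wait(0.1):
        fn()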
@@ -621,22 +631,7 @@ void CoreWorker::OnNodeRemoved(const NodeID &node_id) {
  RAY_LOG(INFO) << "Node failure from " << node_id
                << ". All objects pinned on that node will be lost if object "
                   "reconstruction is not enabled.";
  const auto lost_objects = reference_counter_->ResetObjectsOnRemovedNode(node_id);
  // Delete the objects from the in-memory store to indicate that they are not
  // available. The object recovery manager will guarantee that a new value
  // will eventually be stored for the objects (either an
  // UnreconstructableError or a value reconstructed from lineage).
  memory_store_->Delete(lost_objects);
  for (const auto &object_id : lost_objects) {
    // NOTE(swang): There is a race condition where this can return false if
    // the reference went out of scope since the call to the ref counter to get
    // the lost objects. It's okay to not mark the object as failed or recover
    // the object since there are no reference holders.
    auto recovered = object_recovery_manager_->RecoverObject(object_id);
    if (!recovered) {
      RAY_LOG(DEBUG) << "Object " << object_id << " lost due to node failure "
                     << node_id;
    }
  }
  reference_counter_->ResetObjectsOnRemovedNode(node_id);
}

const WorkerID &CoreWorker::GetWorkerID() const { return worker_context_.GetWorkerID(); }
@@ -2953,7 +2948,7 @@ void CoreWorker::HandleAddSpilledUrl(const rpc::AddSpilledUrlRequest &request,
                 << ", which has been spilled to " << spilled_url << " on node "
                 << node_id;
  auto reference_exists = reference_counter_->HandleObjectSpilled(
      object_id, spilled_url, node_id, request.size(), /*release*/ false);
      object_id, spilled_url, node_id, request.size());
  Status status =
      reference_exists
          ? Status::OK()

View file

@@ -38,7 +38,8 @@ bool ObjectRecoveryManager::RecoverObject(const ObjectID &object_id) {
  }

  bool already_pending_recovery = true;
  if (pinned_at.IsNil() && !spilled) {
  bool requires_recovery = pinned_at.IsNil() && !spilled;
  if (requires_recovery) {
    {
      absl::MutexLock lock(&mu_);
      // Mark that we are attempting recovery for this object to prevent
@@ -61,8 +62,11 @@ bool ObjectRecoveryManager::RecoverObject(const ObjectID &object_id) {
        [this](const ObjectID &object_id, const std::vector<rpc::Address> &locations) {
          PinOrReconstructObject(object_id, locations);
        }));
  } else {
  } else if (requires_recovery) {
    RAY_LOG(DEBUG) << "Recovery already started for object " << object_id;
  } else {
    RAY_LOG(DEBUG) << "Object " << object_id
                   << " has a pinned or spilled location, skipping recovery";
  }
  return true;
}
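After this change, RecoverObject has three outcomes. A hedged Python paraphrase of the branch structure only (the real method also resolves ownership and local state under a lock; `lookup_locations_async` and `log_debug` are hypothetical helpers):

def recover_object(ref, objects_pending_recovery):
    requires_recovery = ref.pinned_at is None and not ref.spilled
    if requires_recovery and ref.object_id not in objects_pending_recovery:
        # First caller for this object: mark it pending and start recovery.
        objects_pending_recovery.add(ref.object_id)
        lookup_locations_async(ref.object_id)  # then pin or reconstruct
    elif requires_recovery:
        log_debug("Recovery already started for object", ref.object_id)
    else:
        log_debug("Object has a pinned or spilled location, skipping recovery")
    return True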

View file

@@ -590,6 +590,15 @@ void ReferenceCounter::ReleasePlasmaObject(ReferenceTable::iterator it) {
    it->second.on_delete = nullptr;
  }
  it->second.pinned_at_raylet_id.reset();
  if (it->second.spilled && !it->second.spilled_node_id.IsNil()) {
    // The spilled copy of the object should get deleted during the on_delete
    // callback, so reset the spill location metadata here.
    // NOTE(swang): Spilled copies in cloud storage are not GCed, so we do not
    // reset the spilled metadata.
    it->second.spilled = false;
    it->second.spilled_url = "";
    it->second.spilled_node_id = NodeID::Nil();
  }
}
bool ReferenceCounter::SetDeleteCallback(
@@ -619,19 +628,26 @@ bool ReferenceCounter::SetDeleteCallback(
  return true;
}

std::vector<ObjectID> ReferenceCounter::ResetObjectsOnRemovedNode(
    const NodeID &raylet_id) {
void ReferenceCounter::ResetObjectsOnRemovedNode(const NodeID &raylet_id) {
  absl::MutexLock lock(&mutex_);
  std::vector<ObjectID> lost_objects;
  for (auto it = object_id_refs_.begin(); it != object_id_refs_.end(); it++) {
    const auto &object_id = it->first;
    if (it->second.pinned_at_raylet_id.value_or(NodeID::Nil()) == raylet_id) {
      lost_objects.push_back(object_id);
    if (it->second.pinned_at_raylet_id.value_or(NodeID::Nil()) == raylet_id ||
        it->second.spilled_node_id == raylet_id) {
      ReleasePlasmaObject(it);
      if (!it->second.OutOfScope(lineage_pinning_enabled_)) {
        objects_to_recover_.push_back(object_id);
      }
    }
    RemoveObjectLocationInternal(it, raylet_id);
  }
  return lost_objects;
}

std::vector<ObjectID> ReferenceCounter::FlushObjectsToRecover() {
  absl::MutexLock lock(&mutex_);
  std::vector<ObjectID> objects_to_recover = std::move(objects_to_recover_);
  objects_to_recover_.clear();
  return objects_to_recover;
}

void ReferenceCounter::UpdateObjectPinnedAtRaylet(const ObjectID &object_id,
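ResetObjectsOnRemovedNode is the producer half of the buffer drained by the CoreWorker loop shown earlier. A sketch of its new behavior (Python paraphrase with simplified names; `release_plasma_object` and `remove_object_location` stand in for the C++ helpers):

def reset_objects_on_removed_node(object_id_refs, objects_to_recover, dead_node_id):
    for object_id, ref in object_id_refs.items():
        # An object is lost if its primary (pinned) copy or its spilled copy
        # lived on the failed node.
        if ref.pinned_at == dead_node_id or ref.spilled_node_id == dead_node_id:
            release_plasma_object(ref)  # invokes and resets any delete callback
            if not ref.out_of_scope():
                objects_to_recover.append(object_id)
        remove_object_location(ref, dead_node_id)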
@@ -655,9 +671,14 @@ void ReferenceCounter::UpdateObjectPinnedAtRaylet(const ObjectID &object_id,
      // Only the owner tracks the location.
      RAY_CHECK(it->second.owned_by_us);
      if (!it->second.OutOfScope(lineage_pinning_enabled_)) {
        it->second.pinned_at_raylet_id = raylet_id;
        // We eagerly add the pinned location to the set of object locations.
        AddObjectLocationInternal(it, raylet_id);
        if (check_node_alive_(raylet_id)) {
          it->second.pinned_at_raylet_id = raylet_id;
          // We eagerly add the pinned location to the set of object locations.
          AddObjectLocationInternal(it, raylet_id);
        } else {
          ReleasePlasmaObject(it);
          objects_to_recover_.push_back(object_id);
        }
      }
    }
  }
@@ -1164,29 +1185,40 @@ size_t ReferenceCounter::GetObjectSize(const ObjectID &object_id) const {
bool ReferenceCounter::HandleObjectSpilled(const ObjectID &object_id,
                                           const std::string spilled_url,
                                           const NodeID &spilled_node_id, int64_t size,
                                           bool release) {
                                           const NodeID &spilled_node_id, int64_t size) {
  absl::MutexLock lock(&mutex_);
  auto it = object_id_refs_.find(object_id);
  if (it == object_id_refs_.end()) {
    RAY_LOG(WARNING) << "Spilled object " << object_id << " already out of scope";
    return false;
  }
  if (it->second.OutOfScope(lineage_pinning_enabled_) && !spilled_node_id.IsNil()) {
    // NOTE(swang): If the object is out of scope and was spilled locally by
    // its primary raylet, then we should have already sent the "object
    // evicted" notification to delete the copy at this spilled URL. Therefore,
    // we should not add this spill URL as a location.
    return false;
  }

  it->second.spilled = true;
  if (spilled_url != "") {
    it->second.spilled_url = spilled_url;
  }
  if (!spilled_node_id.IsNil()) {
    it->second.spilled_node_id = spilled_node_id;
  }
  if (size > 0) {
    it->second.object_size = size;
  }
  PushToLocationSubscribers(it);
  if (release) {
    // Release the primary plasma copy, if any.
  bool spilled_location_alive =
      spilled_node_id.IsNil() || check_node_alive_(spilled_node_id);
  if (spilled_location_alive) {
    if (spilled_url != "") {
      it->second.spilled_url = spilled_url;
    }
    if (!spilled_node_id.IsNil()) {
      it->second.spilled_node_id = spilled_node_id;
    }
    if (size > 0) {
      it->second.object_size = size;
    }
    PushToLocationSubscribers(it);
  } else {
    RAY_LOG(DEBUG) << "Object " << object_id << " spilled to dead node "
                   << spilled_node_id;
    ReleasePlasmaObject(it);
    objects_to_recover_.push_back(object_id);
  }
  return true;
}
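The new liveness check closes a race: a raylet's AddSpilledUrl notification can arrive after that raylet has already been marked dead, in which case recording the spill location would advertise a copy that no longer exists. A compressed Python paraphrase with assumed helper names (it omits the out-of-scope early return shown above; a nil spilled_node_id means external storage such as the cloud, which survives node failure):

def handle_object_spilled(ref, url, node_id, size, check_node_alive, to_recover):
    ref.spilled = True
    if node_id is None or check_node_alive(node_id):
        # Record the spill location and notify location subscribers.
        ref.spilled_url, ref.spilled_node_id, ref.object_size = url, node_id, size
        push_to_location_subscribers(ref)
    else:
        # The spilling node already died: drop the location, queue recovery.
        release_plasma_object(ref)
        to_recover.append(ref.object_id)
    return True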

View file

@@ -66,13 +66,15 @@ class ReferenceCounter : public ReferenceCounterInterface,
  ReferenceCounter(const rpc::WorkerAddress &rpc_address,
                   pubsub::PublisherInterface *object_info_publisher,
                   pubsub::SubscriberInterface *object_info_subscriber,
                   const std::function<bool(const NodeID &node_id)> &check_node_alive,
                   bool lineage_pinning_enabled = false,
                   rpc::ClientFactoryFn client_factory = nullptr)
      : rpc_address_(rpc_address),
        lineage_pinning_enabled_(lineage_pinning_enabled),
        borrower_pool_(client_factory),
        object_info_publisher_(object_info_publisher),
        object_info_subscriber_(object_info_subscriber) {}
        object_info_subscriber_(object_info_subscriber),
        check_node_alive_(check_node_alive) {}

  ~ReferenceCounter() {}
@@ -346,14 +348,18 @@ class ReferenceCounter : public ReferenceCounterInterface,
                                     NodeID *pinned_at, bool *spilled) const
      LOCKS_EXCLUDED(mutex_);

  /// Get and reset the objects that were pinned on the given node. This
  /// method should be called upon a node failure, to determine which plasma
  /// objects were lost. If a deletion callback was set for a lost object, it
  /// will be invoked and reset.
  /// Get and reset the objects that were pinned or spilled on the given node.
  /// This method should be called upon a node failure, to trigger
  /// reconstruction for any lost objects that are still in scope.
  ///
  /// If a deletion callback was set for a lost object, it will be invoked and
  /// reset.
  ///
  /// \param[in] node_id The node whose object store has been removed.
  /// \return The set of objects that were pinned on the given node.
  std::vector<ObjectID> ResetObjectsOnRemovedNode(const NodeID &raylet_id);
  void ResetObjectsOnRemovedNode(const NodeID &raylet_id);

  std::vector<ObjectID> FlushObjectsToRecover();

  /// Whether we have a reference to a particular ObjectID.
  ///
@@ -427,10 +433,9 @@ class ReferenceCounter : public ReferenceCounterInterface,
  /// \param[in] spilled_url The URL to which the object has been spilled.
  /// \param[in] spilled_node_id The ID of the node on which the object was spilled.
  /// \param[in] size The size of the object.
  /// \param[in] release Whether to release the reference.
  /// \return True if the reference exists, false otherwise.
  /// \return True if the reference exists and is in scope, false otherwise.
  bool HandleObjectSpilled(const ObjectID &object_id, const std::string spilled_url,
                           const NodeID &spilled_node_id, int64_t size, bool release);
                           const NodeID &spilled_node_id, int64_t size);

  /// Get locality data for object. This is used by the leasing policy to implement
  /// locality-aware leasing.
@@ -879,6 +884,16 @@ class ReferenceCounter : public ReferenceCounterInterface,
  /// object's place in the queue.
  absl::flat_hash_map<ObjectID, std::list<ObjectID>::iterator>
      reconstructable_owned_objects_index_ GUARDED_BY(mutex_);

  /// Called to check whether a raylet is still alive. This is used when adding
  /// the primary or spilled location of an object. If the node is dead, then
  /// the object will be added to the buffer of objects to recover.
  const std::function<bool(const NodeID &node_id)> check_node_alive_;

  /// A buffer of the objects whose primary or spilled locations have been lost
  /// due to node failure. These objects are still in scope and need to be
  /// recovered.
  std::vector<ObjectID> objects_to_recover_ GUARDED_BY(mutex_);
};

}  // namespace core

View file

@@ -40,7 +40,8 @@ class ReferenceCountTest : public ::testing::Test {
    rpc::Address addr;
    publisher_ = std::make_shared<mock_pubsub::MockPublisher>();
    subscriber_ = std::make_shared<mock_pubsub::MockSubscriber>();
    rc = std::make_unique<ReferenceCounter>(addr, publisher_.get(), subscriber_.get());
    rc = std::make_unique<ReferenceCounter>(addr, publisher_.get(), subscriber_.get(),
                                            [](const NodeID &node_id) { return true; });
  }

  virtual void TearDown() {
@@ -63,8 +64,10 @@ class ReferenceCountLineageEnabledTest : public ::testing::Test {
    rpc::Address addr;
    publisher_ = std::make_shared<mock_pubsub::MockPublisher>();
    subscriber_ = std::make_shared<mock_pubsub::MockSubscriber>();
    rc = std::make_unique<ReferenceCounter>(addr, publisher_.get(), subscriber_.get(),
                                            /*lineage_pinning_enabled=*/true);
    rc = std::make_unique<ReferenceCounter>(
        addr, publisher_.get(), subscriber_.get(),
        [](const NodeID &node_id) { return true; },
        /*lineage_pinning_enabled=*/true);
  }

  virtual void TearDown() {
@@ -280,7 +283,9 @@ class MockWorkerClient : public MockCoreWorkerClientInterface {
        subscriber_(std::make_shared<MockDistributedSubscriber>(
            &directory, &subscription_callback_map, &subscription_failure_callback_map,
            WorkerID::FromBinary(address_.worker_id()), client_factory)),
        rc_(rpc::WorkerAddress(address_), publisher_.get(), subscriber_.get(),
        rc_(
            rpc::WorkerAddress(address_), publisher_.get(), subscriber_.get(),
            [](const NodeID &node_id) { return true; },
            /*lineage_pinning_enabled=*/false, client_factory) {}

  ~MockWorkerClient() override {
@@ -704,8 +709,9 @@ TEST(MemoryStoreIntegrationTest, TestSimple) {
  auto publisher = std::make_shared<mock_pubsub::MockPublisher>();
  auto subscriber = std::make_shared<mock_pubsub::MockSubscriber>();
  auto rc = std::shared_ptr<ReferenceCounter>(new ReferenceCounter(
      rpc::WorkerAddress(rpc::Address()), publisher.get(), subscriber.get()));
  auto rc = std::shared_ptr<ReferenceCounter>(
      new ReferenceCounter(rpc::WorkerAddress(rpc::Address()), publisher.get(),
                           subscriber.get(), [](const NodeID &node_id) { return true; }));
  CoreWorkerMemoryStore store(rc);

  // Tests putting an object with no references is ignored.
@@ -2543,7 +2549,8 @@ TEST_F(ReferenceCountLineageEnabledTest, TestPlasmaLocation) {
  rc->AddOwnedObject(id, {}, rpc::Address(), "", 0, true, /*add_local_ref=*/true);
  ASSERT_TRUE(rc->SetDeleteCallback(id, callback));
  rc->UpdateObjectPinnedAtRaylet(id, node_id);
  auto objects = rc->ResetObjectsOnRemovedNode(node_id);
  rc->ResetObjectsOnRemovedNode(node_id);
  auto objects = rc->FlushObjectsToRecover();
  ASSERT_EQ(objects.size(), 1);
  ASSERT_EQ(objects[0], id);
  ASSERT_TRUE(rc->IsPlasmaObjectPinnedOrSpilled(id, &owned_by_us, &pinned_at, &spilled));

View file

@@ -238,17 +238,11 @@ void TaskManager::CompletePendingTask(const TaskID &task_id,
    const auto nested_refs =
        VectorFromProtobuf<rpc::ObjectReference>(return_object.nested_inlined_refs());
    if (return_object.in_plasma()) {
      // Mark it as in plasma with a dummy object.
      RAY_CHECK(
          in_memory_store_->Put(RayObject(rpc::ErrorType::OBJECT_IN_PLASMA), object_id));
      const auto pinned_at_raylet_id = NodeID::FromBinary(worker_addr.raylet_id());
      if (check_node_alive_(pinned_at_raylet_id)) {
        reference_counter_->UpdateObjectPinnedAtRaylet(object_id, pinned_at_raylet_id);
        // Mark it as in plasma with a dummy object.
        RAY_CHECK(in_memory_store_->Put(RayObject(rpc::ErrorType::OBJECT_IN_PLASMA),
                                        object_id));
      } else {
        RAY_LOG(DEBUG) << "Task " << task_id << " returned object " << object_id
                       << " in plasma on a dead node, attempting to recover.";
        reconstruct_object_callback_(object_id);
      }
      reference_counter_->UpdateObjectPinnedAtRaylet(object_id, pinned_at_raylet_id);
    } else {
      // NOTE(swang): If a direct object was promoted to plasma, then we do not
      // record the node ID that it was pinned at, which means that we will not

View file

@@ -79,15 +79,11 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa
              std::shared_ptr<ReferenceCounter> reference_counter,
              PutInLocalPlasmaCallback put_in_local_plasma_callback,
              RetryTaskCallback retry_task_callback,
              const std::function<bool(const NodeID &node_id)> &check_node_alive,
              ReconstructObjectCallback reconstruct_object_callback,
              PushErrorCallback push_error_callback, int64_t max_lineage_bytes)
      : in_memory_store_(in_memory_store),
        reference_counter_(reference_counter),
        put_in_local_plasma_callback_(put_in_local_plasma_callback),
        retry_task_callback_(retry_task_callback),
        check_node_alive_(check_node_alive),
        reconstruct_object_callback_(reconstruct_object_callback),
        push_error_callback_(push_error_callback),
        max_lineage_bytes_(max_lineage_bytes) {
    reference_counter_->SetReleaseLineageCallback(
@@ -322,17 +318,6 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa
  /// Called when a task should be retried.
  const RetryTaskCallback retry_task_callback_;

  /// Called to check whether a raylet is still alive. This is used when
  /// processing a worker's reply to check whether the node that the worker
  /// was on is still alive. If the node is down, the plasma objects returned
  /// by the task are marked as failed.
  const std::function<bool(const NodeID &node_id)> check_node_alive_;

  /// Called when processing a worker's reply if the node that the worker was
  /// on died. This should be called to attempt to recover a plasma object
  /// returned by the task (or store an error if the object is not
  /// recoverable).
  const ReconstructObjectCallback reconstruct_object_callback_;

  // Called to push an error to the relevant driver.
  const PushErrorCallback push_error_callback_;

View file

@@ -116,6 +116,7 @@ class ObjectRecoveryManagerTestBase : public ::testing::Test {
        task_resubmitter_(std::make_shared<MockTaskResubmitter>()),
        ref_counter_(std::make_shared<ReferenceCounter>(
            rpc::Address(), publisher_.get(), subscriber_.get(),
            [](const NodeID &node_id) { return true; },
            /*lineage_pinning_enabled=*/lineage_enabled)),
        manager_(
            rpc::Address(),
@@ -257,7 +258,8 @@ TEST_F(ObjectRecoveryManagerTest, TestReconstructionSuppression) {
  ASSERT_EQ(object_directory_->Flush(), 0);

  // The object is removed and can be recovered again.
  auto objects = ref_counter_->ResetObjectsOnRemovedNode(remote_node_id);
  ref_counter_->ResetObjectsOnRemovedNode(remote_node_id);
  auto objects = ref_counter_->FlushObjectsToRecover();
  ASSERT_EQ(objects.size(), 1);
  ASSERT_EQ(objects[0], object_id);
  memory_store_->Delete(objects);

View file

@@ -44,9 +44,10 @@ class TaskManagerTest : public ::testing::Test {
      : store_(std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore())),
        publisher_(std::make_shared<mock_pubsub::MockPublisher>()),
        subscriber_(std::make_shared<mock_pubsub::MockSubscriber>()),
        reference_counter_(std::shared_ptr<ReferenceCounter>(
            new ReferenceCounter(rpc::Address(), publisher_.get(), subscriber_.get(),
                                 lineage_pinning_enabled))),
        reference_counter_(std::shared_ptr<ReferenceCounter>(new ReferenceCounter(
            rpc::Address(), publisher_.get(), subscriber_.get(),
            [this](const NodeID &node_id) { return all_nodes_alive_; },
            lineage_pinning_enabled))),
        manager_(
            store_, reference_counter_,
            [this](const RayObject &object, const ObjectID &object_id) {
@@ -56,10 +57,6 @@ class TaskManagerTest : public ::testing::Test {
              num_retries_++;
              return Status::OK();
            },
            [this](const NodeID &node_id) { return all_nodes_alive_; },
            [this](const ObjectID &object_id) {
              objects_to_recover_.push_back(object_id);
            },
            [](const JobID &job_id, const std::string &type,
               const std::string &error_message,
               double timestamp) { return Status::OK(); },
@@ -79,7 +76,6 @@ class TaskManagerTest : public ::testing::Test {
  std::shared_ptr<mock_pubsub::MockSubscriber> subscriber_;
  std::shared_ptr<ReferenceCounter> reference_counter_;
  bool all_nodes_alive_ = true;
  std::vector<ObjectID> objects_to_recover_;
  TaskManager manager_;
  int num_retries_ = 0;
  std::unordered_set<ObjectID> stored_in_plasma;
@@ -173,7 +169,7 @@ TEST_F(TaskManagerTest, TestPlasmaConcurrentFailure) {
  auto return_id = spec.ReturnId(0);
  WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0));

  ASSERT_TRUE(objects_to_recover_.empty());
  ASSERT_TRUE(reference_counter_->FlushObjectsToRecover().empty());
  all_nodes_alive_ = false;

  rpc::PushTaskReply reply;
@@ -185,9 +181,12 @@ TEST_F(TaskManagerTest, TestPlasmaConcurrentFailure) {
  ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId()));
  std::vector<std::shared_ptr<RayObject>> results;
  ASSERT_FALSE(store_->Get({return_id}, 1, 0, ctx, false, &results).ok());
  ASSERT_EQ(objects_to_recover_.size(), 1);
  ASSERT_EQ(objects_to_recover_[0], return_id);
  // Caller of FlushObjectsToRecover is responsible for deleting the object
  // from the in-memory store and recovering the object.
  ASSERT_TRUE(store_->Get({return_id}, 1, 0, ctx, false, &results).ok());
  auto objects_to_recover = reference_counter_->FlushObjectsToRecover();
  ASSERT_EQ(objects_to_recover.size(), 1);
  ASSERT_EQ(objects_to_recover[0], return_id);
}

TEST_F(TaskManagerTest, TestFailPendingTask) {
TEST_F(TaskManagerTest, TestFailPendingTask) {