From 495eb141795e9748a99c82e5a6daa5ad39b715f1 Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Wed, 9 Feb 2022 18:22:16 -0800 Subject: [PATCH] [core] Recover spilled objects that are lost during node failure (#21485) * Failing test * trigger recovery from ref counter * x * update * lint * stress test * update * format * x --- python/ray/tests/test_reconstruction.py | 109 ++++++++++++++++++ src/ray/common/ray_config_def.h | 9 ++ src/ray/core_worker/core_worker.cc | 57 +++++---- .../core_worker/object_recovery_manager.cc | 8 +- src/ray/core_worker/reference_count.cc | 78 +++++++++---- src/ray/core_worker/reference_count.h | 33 ++++-- src/ray/core_worker/reference_count_test.cc | 21 ++-- src/ray/core_worker/task_manager.cc | 14 +-- src/ray/core_worker/task_manager.h | 15 --- .../test/object_recovery_manager_test.cc | 4 +- src/ray/core_worker/test/task_manager_test.cc | 23 ++-- 11 files changed, 261 insertions(+), 110 deletions(-) diff --git a/python/ray/tests/test_reconstruction.py b/python/ray/tests/test_reconstruction.py index d9d7debc3..cf1116cf9 100644 --- a/python/ray/tests/test_reconstruction.py +++ b/python/ray/tests/test_reconstruction.py @@ -640,6 +640,61 @@ def test_reconstruction_stress(ray_start_cluster): i += 1 +@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.") +def test_reconstruction_stress_spill(ray_start_cluster): + config = { + "num_heartbeats_timeout": 10, + "raylet_heartbeat_period_milliseconds": 100, + "max_direct_call_object_size": 100, + "task_retry_delay_ms": 100, + "object_timeout_milliseconds": 200, + } + cluster = ray_start_cluster + # Head node with no resources. + cluster.add_node( + num_cpus=0, _system_config=config, enable_object_reconstruction=True + ) + ray.init(address=cluster.address) + # Node to place the initial object. + node_to_kill = cluster.add_node( + num_cpus=1, resources={"node1": 1}, object_store_memory=10 ** 8 + ) + cluster.add_node(num_cpus=1, resources={"node2": 1}, object_store_memory=10 ** 8) + cluster.wait_for_nodes() + + @ray.remote + def large_object(): + return np.zeros(10 ** 6, dtype=np.uint8) + + @ray.remote + def dependent_task(x): + return + + for _ in range(3): + obj = large_object.options(resources={"node1": 1}).remote() + ray.get(dependent_task.options(resources={"node2": 1}).remote(obj)) + + outputs = [ + large_object.options(resources={"node1": 1}).remote() for _ in range(1000) + ] + outputs = [ + dependent_task.options(resources={"node2": 1}).remote(obj) + for obj in outputs + ] + + cluster.remove_node(node_to_kill, allow_graceful=False) + node_to_kill = cluster.add_node( + num_cpus=1, resources={"node1": 1}, object_store_memory=10 ** 8 + ) + + i = 0 + while outputs: + ref = outputs.pop(0) + print(i, ref) + ray.get(ref) + i += 1 + + @pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.") @pytest.mark.parametrize("reconstruction_enabled", [False, True]) def test_nondeterministic_output(ray_start_cluster, reconstruction_enabled): @@ -905,6 +960,60 @@ def test_nested(ray_start_cluster, reconstruction_enabled): ray.get(ref, timeout=60) +@pytest.mark.parametrize("reconstruction_enabled", [False, True]) +def test_spilled(ray_start_cluster, reconstruction_enabled): + config = { + "num_heartbeats_timeout": 10, + "raylet_heartbeat_period_milliseconds": 100, + "object_timeout_milliseconds": 200, + } + # Workaround to reset the config to the default value. 
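+    # (Lineage pinning keeps the specs of the tasks that created an object,
+    # including their recursive dependencies, so disabling it approximates
+    # running with object reconstruction turned off.)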
+ if not reconstruction_enabled: + config["lineage_pinning_enabled"] = False + + cluster = ray_start_cluster + # Head node with no resources. + cluster.add_node( + num_cpus=0, + _system_config=config, + enable_object_reconstruction=reconstruction_enabled, + ) + ray.init(address=cluster.address) + # Node to place the initial object. + node_to_kill = cluster.add_node( + num_cpus=1, resources={"node1": 1}, object_store_memory=10 ** 8 + ) + cluster.wait_for_nodes() + + @ray.remote(max_retries=1 if reconstruction_enabled else 0) + def large_object(): + return np.zeros(10 ** 7, dtype=np.uint8) + + @ray.remote + def dependent_task(x): + return + + obj = large_object.options(resources={"node1": 1}).remote() + ray.get(dependent_task.options(resources={"node1": 1}).remote(obj)) + # Force spilling. + objs = [large_object.options(resources={"node1": 1}).remote() for _ in range(20)] + for o in objs: + ray.get(o) + + cluster.remove_node(node_to_kill, allow_graceful=False) + node_to_kill = cluster.add_node( + num_cpus=1, resources={"node1": 1}, object_store_memory=10 ** 8 + ) + + if reconstruction_enabled: + ray.get(dependent_task.remote(obj), timeout=60) + else: + with pytest.raises(ray.exceptions.RayTaskError): + ray.get(dependent_task.remote(obj), timeout=60) + with pytest.raises(ray.exceptions.ObjectLostError): + ray.get(obj, timeout=60) + + if __name__ == "__main__": import pytest diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h index 7ff9a3847..1b5494fb3 100644 --- a/src/ray/common/ray_config_def.h +++ b/src/ray/common/ray_config_def.h @@ -90,8 +90,17 @@ RAY_CONFIG(int64_t, free_objects_period_milliseconds, 1000) /// to -1. RAY_CONFIG(size_t, free_objects_batch_size, 100) +/// Whether to pin object lineage, i.e. the task that created the object and +/// the task's recursive dependencies. If this is set to true, then the system +/// will attempt to reconstruct the object from its lineage if the object is +/// lost. RAY_CONFIG(bool, lineage_pinning_enabled, false) +/// Objects that require recovery are added to a local cache. This is the +/// duration between attempts to flush and recover the objects in the local +/// cache. +RAY_CONFIG(int64_t, reconstruct_objects_period_milliseconds, 100) + /// Maximum amount of lineage to keep in bytes. This includes the specs of all /// tasks that have previously already finished but that may be retried again. 
/// If we reach this limit, 50% of the current lineage will be evicted and diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 0c8edfa2a..7421750ef 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -218,10 +218,14 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ }, /*callback_service*/ &io_service_); + auto check_node_alive_fn = [this](const NodeID &node_id) { + auto node = gcs_client_->Nodes().Get(node_id); + return node != nullptr; + }; reference_counter_ = std::make_shared( rpc_address_, /*object_info_publisher=*/object_info_publisher_.get(), - /*object_info_subscriber=*/object_info_subscriber_.get(), + /*object_info_subscriber=*/object_info_subscriber_.get(), check_node_alive_fn, RayConfig::instance().lineage_pinning_enabled(), [this](const rpc::Address &addr) { return std::shared_ptr( new rpc::CoreWorkerClient(addr, *client_call_manager_)); @@ -257,17 +261,6 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ periodical_runner_.RunFnPeriodically([this] { InternalHeartbeat(); }, kInternalHeartbeatMillis); - auto check_node_alive_fn = [this](const NodeID &node_id) { - auto node = gcs_client_->Nodes().Get(node_id); - return node != nullptr; - }; - auto reconstruct_object_callback = [this](const ObjectID &object_id) { - io_service_.post( - [this, object_id]() { - RAY_CHECK(object_recovery_manager_->RecoverObject(object_id)); - }, - "CoreWorker.ReconstructObject"); - }; auto push_error_callback = [this](const JobID &job_id, const std::string &type, const std::string &error_message, double timestamp) { return PushError(job_id, type, error_message, timestamp); @@ -301,8 +294,7 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ } } }, - check_node_alive_fn, reconstruct_object_callback, push_error_callback, - RayConfig::instance().max_lineage_bytes())); + push_error_callback, RayConfig::instance().max_lineage_bytes())); // Create an entry for the driver task in the task table. This task is // added immediately with status RUNNING. This allows us to push errors @@ -458,6 +450,24 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ RayEventContext::Instance().SetEventContext( ray::rpc::Event_SourceType::Event_SourceType_CORE_WORKER, {{"worker_id", worker_id.Hex()}}); + + periodical_runner_.RunFnPeriodically( + [this] { + const auto lost_objects = reference_counter_->FlushObjectsToRecover(); + // Delete the objects from the in-memory store to indicate that they are not + // available. The object recovery manager will guarantee that a new value + // will eventually be stored for the objects (either an + // UnreconstructableError or a value reconstructed from lineage). + memory_store_->Delete(lost_objects); + for (const auto &object_id : lost_objects) { + // NOTE(swang): There is a race condition where this can return false if + // the reference went out of scope since the call to the ref counter to get + // the lost objects. It's okay to not mark the object as failed or recover + // the object since there are no reference holders. + RAY_UNUSED(object_recovery_manager_->RecoverObject(object_id)); + } + }, + 100); } CoreWorker::~CoreWorker() { RAY_LOG(INFO) << "Core worker is destructed"; } @@ -621,22 +631,7 @@ void CoreWorker::OnNodeRemoved(const NodeID &node_id) { RAY_LOG(INFO) << "Node failure from " << node_id << ". 
All objects pinned on that node will be lost if object " "reconstruction is not enabled."; - const auto lost_objects = reference_counter_->ResetObjectsOnRemovedNode(node_id); - // Delete the objects from the in-memory store to indicate that they are not - // available. The object recovery manager will guarantee that a new value - // will eventually be stored for the objects (either an - // UnreconstructableError or a value reconstructed from lineage). - memory_store_->Delete(lost_objects); - for (const auto &object_id : lost_objects) { - // NOTE(swang): There is a race condition where this can return false if - // the reference went out of scope since the call to the ref counter to get - // the lost objects. It's okay to not mark the object as failed or recover - // the object since there are no reference holders. - auto recovered = object_recovery_manager_->RecoverObject(object_id); - if (!recovered) { - RAY_LOG(DEBUG) << "Object " << object_id << " lost due to node failure " << node_id; - } - } + reference_counter_->ResetObjectsOnRemovedNode(node_id); } const WorkerID &CoreWorker::GetWorkerID() const { return worker_context_.GetWorkerID(); } @@ -2953,7 +2948,7 @@ void CoreWorker::HandleAddSpilledUrl(const rpc::AddSpilledUrlRequest &request, << ", which has been spilled to " << spilled_url << " on node " << node_id; auto reference_exists = reference_counter_->HandleObjectSpilled( - object_id, spilled_url, node_id, request.size(), /*release*/ false); + object_id, spilled_url, node_id, request.size()); Status status = reference_exists ? Status::OK() diff --git a/src/ray/core_worker/object_recovery_manager.cc b/src/ray/core_worker/object_recovery_manager.cc index ae7513db6..a81498f8f 100644 --- a/src/ray/core_worker/object_recovery_manager.cc +++ b/src/ray/core_worker/object_recovery_manager.cc @@ -38,7 +38,8 @@ bool ObjectRecoveryManager::RecoverObject(const ObjectID &object_id) { } bool already_pending_recovery = true; - if (pinned_at.IsNil() && !spilled) { + bool requires_recovery = pinned_at.IsNil() && !spilled; + if (requires_recovery) { { absl::MutexLock lock(&mu_); // Mark that we are attempting recovery for this object to prevent @@ -61,8 +62,11 @@ bool ObjectRecoveryManager::RecoverObject(const ObjectID &object_id) { [this](const ObjectID &object_id, const std::vector &locations) { PinOrReconstructObject(object_id, locations); })); - } else { + } else if (requires_recovery) { RAY_LOG(DEBUG) << "Recovery already started for object " << object_id; + } else { + RAY_LOG(DEBUG) << "Object " << object_id + << " has a pinned or spilled location, skipping recovery"; } return true; } diff --git a/src/ray/core_worker/reference_count.cc b/src/ray/core_worker/reference_count.cc index 88405f8ff..9ca66137a 100644 --- a/src/ray/core_worker/reference_count.cc +++ b/src/ray/core_worker/reference_count.cc @@ -590,6 +590,15 @@ void ReferenceCounter::ReleasePlasmaObject(ReferenceTable::iterator it) { it->second.on_delete = nullptr; } it->second.pinned_at_raylet_id.reset(); + if (it->second.spilled && !it->second.spilled_node_id.IsNil()) { + // The spilled copy of the object should get deleted during the on_delete + // callback, so reset the spill location metadata here. + // NOTE(swang): Spilled copies in cloud storage are not GCed, so we do not + // reset the spilled metadata. 
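+    // (A Nil spilled_node_id corresponds to a copy in external/cloud storage
+    // rather than on a raylet, which is why the check above leaves that
+    // metadata untouched.)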
+ it->second.spilled = false; + it->second.spilled_url = ""; + it->second.spilled_node_id = NodeID::Nil(); + } } bool ReferenceCounter::SetDeleteCallback( @@ -619,19 +628,26 @@ bool ReferenceCounter::SetDeleteCallback( return true; } -std::vector ReferenceCounter::ResetObjectsOnRemovedNode( - const NodeID &raylet_id) { +void ReferenceCounter::ResetObjectsOnRemovedNode(const NodeID &raylet_id) { absl::MutexLock lock(&mutex_); - std::vector lost_objects; for (auto it = object_id_refs_.begin(); it != object_id_refs_.end(); it++) { const auto &object_id = it->first; - if (it->second.pinned_at_raylet_id.value_or(NodeID::Nil()) == raylet_id) { - lost_objects.push_back(object_id); + if (it->second.pinned_at_raylet_id.value_or(NodeID::Nil()) == raylet_id || + it->second.spilled_node_id == raylet_id) { ReleasePlasmaObject(it); + if (!it->second.OutOfScope(lineage_pinning_enabled_)) { + objects_to_recover_.push_back(object_id); + } } RemoveObjectLocationInternal(it, raylet_id); } - return lost_objects; +} + +std::vector ReferenceCounter::FlushObjectsToRecover() { + absl::MutexLock lock(&mutex_); + std::vector objects_to_recover = std::move(objects_to_recover_); + objects_to_recover_.clear(); + return objects_to_recover; } void ReferenceCounter::UpdateObjectPinnedAtRaylet(const ObjectID &object_id, @@ -655,9 +671,14 @@ void ReferenceCounter::UpdateObjectPinnedAtRaylet(const ObjectID &object_id, // Only the owner tracks the location. RAY_CHECK(it->second.owned_by_us); if (!it->second.OutOfScope(lineage_pinning_enabled_)) { - it->second.pinned_at_raylet_id = raylet_id; - // We eagerly add the pinned location to the set of object locations. - AddObjectLocationInternal(it, raylet_id); + if (check_node_alive_(raylet_id)) { + it->second.pinned_at_raylet_id = raylet_id; + // We eagerly add the pinned location to the set of object locations. + AddObjectLocationInternal(it, raylet_id); + } else { + ReleasePlasmaObject(it); + objects_to_recover_.push_back(object_id); + } } } } @@ -1164,29 +1185,40 @@ size_t ReferenceCounter::GetObjectSize(const ObjectID &object_id) const { bool ReferenceCounter::HandleObjectSpilled(const ObjectID &object_id, const std::string spilled_url, - const NodeID &spilled_node_id, int64_t size, - bool release) { + const NodeID &spilled_node_id, int64_t size) { absl::MutexLock lock(&mutex_); auto it = object_id_refs_.find(object_id); if (it == object_id_refs_.end()) { RAY_LOG(WARNING) << "Spilled object " << object_id << " already out of scope"; return false; } + if (it->second.OutOfScope(lineage_pinning_enabled_) && !spilled_node_id.IsNil()) { + // NOTE(swang): If the object is out of scope and was spilled locally by + // its primary raylet, then we should have already sent the "object + // evicted" notification to delete the copy at this spilled URL. Therefore, + // we should not add this spill URL as a location. + return false; + } it->second.spilled = true; - if (spilled_url != "") { - it->second.spilled_url = spilled_url; - } - if (!spilled_node_id.IsNil()) { - it->second.spilled_node_id = spilled_node_id; - } - if (size > 0) { - it->second.object_size = size; - } - PushToLocationSubscribers(it); - if (release) { - // Release the primary plasma copy, if any. 
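+  // Only record the spill location if the node that holds the spilled copy
+  // (if any) is still alive; otherwise release the primary copy and queue the
+  // object so the periodic recovery pass can reconstruct it or mark it as
+  // failed.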
+ bool spilled_location_alive = + spilled_node_id.IsNil() || check_node_alive_(spilled_node_id); + if (spilled_location_alive) { + if (spilled_url != "") { + it->second.spilled_url = spilled_url; + } + if (!spilled_node_id.IsNil()) { + it->second.spilled_node_id = spilled_node_id; + } + if (size > 0) { + it->second.object_size = size; + } + PushToLocationSubscribers(it); + } else { + RAY_LOG(DEBUG) << "Object " << object_id << " spilled to dead node " + << spilled_node_id; ReleasePlasmaObject(it); + objects_to_recover_.push_back(object_id); } return true; } diff --git a/src/ray/core_worker/reference_count.h b/src/ray/core_worker/reference_count.h index 6c239db04..71b4e138c 100644 --- a/src/ray/core_worker/reference_count.h +++ b/src/ray/core_worker/reference_count.h @@ -66,13 +66,15 @@ class ReferenceCounter : public ReferenceCounterInterface, ReferenceCounter(const rpc::WorkerAddress &rpc_address, pubsub::PublisherInterface *object_info_publisher, pubsub::SubscriberInterface *object_info_subscriber, + const std::function &check_node_alive, bool lineage_pinning_enabled = false, rpc::ClientFactoryFn client_factory = nullptr) : rpc_address_(rpc_address), lineage_pinning_enabled_(lineage_pinning_enabled), borrower_pool_(client_factory), object_info_publisher_(object_info_publisher), - object_info_subscriber_(object_info_subscriber) {} + object_info_subscriber_(object_info_subscriber), + check_node_alive_(check_node_alive) {} ~ReferenceCounter() {} @@ -346,14 +348,18 @@ class ReferenceCounter : public ReferenceCounterInterface, NodeID *pinned_at, bool *spilled) const LOCKS_EXCLUDED(mutex_); - /// Get and reset the objects that were pinned on the given node. This - /// method should be called upon a node failure, to determine which plasma - /// objects were lost. If a deletion callback was set for a lost object, it - /// will be invoked and reset. + /// Get and reset the objects that were pinned or spilled on the given node. + /// This method should be called upon a node failure, to trigger + /// reconstruction for any lost objects that are still in scope. + /// + /// If a deletion callback was set for a lost object, it will be invoked and + /// reset. /// /// \param[in] node_id The node whose object store has been removed. /// \return The set of objects that were pinned on the given node. - std::vector ResetObjectsOnRemovedNode(const NodeID &raylet_id); + void ResetObjectsOnRemovedNode(const NodeID &raylet_id); + + std::vector FlushObjectsToRecover(); /// Whether we have a reference to a particular ObjectID. /// @@ -427,10 +433,9 @@ class ReferenceCounter : public ReferenceCounterInterface, /// \param[in] spilled_url The URL to which the object has been spilled. /// \param[in] spilled_node_id The ID of the node on which the object was spilled. /// \param[in] size The size of the object. - /// \param[in] release Whether to release the reference. - /// \return True if the reference exists, false otherwise. + /// \return True if the reference exists and is in scope, false otherwise. bool HandleObjectSpilled(const ObjectID &object_id, const std::string spilled_url, - const NodeID &spilled_node_id, int64_t size, bool release); + const NodeID &spilled_node_id, int64_t size); /// Get locality data for object. This is used by the leasing policy to implement /// locality-aware leasing. @@ -879,6 +884,16 @@ class ReferenceCounter : public ReferenceCounterInterface, /// object's place in the queue. 
absl::flat_hash_map::iterator> reconstructable_owned_objects_index_ GUARDED_BY(mutex_); + + /// Called to check whether a raylet is still alive. This is used when adding + /// the primary or spilled location of an object. If the node is dead, then + /// the object will be added to the buffer objects to recover. + const std::function check_node_alive_; + + /// A buffer of the objects whose primary or spilled locations have been lost + /// due to node failure. These objects are still in scope and need to be + /// recovered. + std::vector objects_to_recover_ GUARDED_BY(mutex_); }; } // namespace core diff --git a/src/ray/core_worker/reference_count_test.cc b/src/ray/core_worker/reference_count_test.cc index b9512b822..254f9a8ae 100644 --- a/src/ray/core_worker/reference_count_test.cc +++ b/src/ray/core_worker/reference_count_test.cc @@ -40,7 +40,8 @@ class ReferenceCountTest : public ::testing::Test { rpc::Address addr; publisher_ = std::make_shared(); subscriber_ = std::make_shared(); - rc = std::make_unique(addr, publisher_.get(), subscriber_.get()); + rc = std::make_unique(addr, publisher_.get(), subscriber_.get(), + [](const NodeID &node_id) { return true; }); } virtual void TearDown() { @@ -63,8 +64,10 @@ class ReferenceCountLineageEnabledTest : public ::testing::Test { rpc::Address addr; publisher_ = std::make_shared(); subscriber_ = std::make_shared(); - rc = std::make_unique(addr, publisher_.get(), subscriber_.get(), - /*lineage_pinning_enabled=*/true); + rc = std::make_unique( + addr, publisher_.get(), subscriber_.get(), + [](const NodeID &node_id) { return true; }, + /*lineage_pinning_enabled=*/true); } virtual void TearDown() { @@ -280,7 +283,9 @@ class MockWorkerClient : public MockCoreWorkerClientInterface { subscriber_(std::make_shared( &directory, &subscription_callback_map, &subscription_failure_callback_map, WorkerID::FromBinary(address_.worker_id()), client_factory)), - rc_(rpc::WorkerAddress(address_), publisher_.get(), subscriber_.get(), + rc_( + rpc::WorkerAddress(address_), publisher_.get(), subscriber_.get(), + [](const NodeID &node_id) { return true; }, /*lineage_pinning_enabled=*/false, client_factory) {} ~MockWorkerClient() override { @@ -704,8 +709,9 @@ TEST(MemoryStoreIntegrationTest, TestSimple) { auto publisher = std::make_shared(); auto subscriber = std::make_shared(); - auto rc = std::shared_ptr(new ReferenceCounter( - rpc::WorkerAddress(rpc::Address()), publisher.get(), subscriber.get())); + auto rc = std::shared_ptr( + new ReferenceCounter(rpc::WorkerAddress(rpc::Address()), publisher.get(), + subscriber.get(), [](const NodeID &node_id) { return true; })); CoreWorkerMemoryStore store(rc); // Tests putting an object with no references is ignored. 
@@ -2543,7 +2549,8 @@ TEST_F(ReferenceCountLineageEnabledTest, TestPlasmaLocation) { rc->AddOwnedObject(id, {}, rpc::Address(), "", 0, true, /*add_local_ref=*/true); ASSERT_TRUE(rc->SetDeleteCallback(id, callback)); rc->UpdateObjectPinnedAtRaylet(id, node_id); - auto objects = rc->ResetObjectsOnRemovedNode(node_id); + rc->ResetObjectsOnRemovedNode(node_id); + auto objects = rc->FlushObjectsToRecover(); ASSERT_EQ(objects.size(), 1); ASSERT_EQ(objects[0], id); ASSERT_TRUE(rc->IsPlasmaObjectPinnedOrSpilled(id, &owned_by_us, &pinned_at, &spilled)); diff --git a/src/ray/core_worker/task_manager.cc b/src/ray/core_worker/task_manager.cc index f8e55e4f5..e2b488ee2 100644 --- a/src/ray/core_worker/task_manager.cc +++ b/src/ray/core_worker/task_manager.cc @@ -238,17 +238,11 @@ void TaskManager::CompletePendingTask(const TaskID &task_id, const auto nested_refs = VectorFromProtobuf(return_object.nested_inlined_refs()); if (return_object.in_plasma()) { + // Mark it as in plasma with a dummy object. + RAY_CHECK( + in_memory_store_->Put(RayObject(rpc::ErrorType::OBJECT_IN_PLASMA), object_id)); const auto pinned_at_raylet_id = NodeID::FromBinary(worker_addr.raylet_id()); - if (check_node_alive_(pinned_at_raylet_id)) { - reference_counter_->UpdateObjectPinnedAtRaylet(object_id, pinned_at_raylet_id); - // Mark it as in plasma with a dummy object. - RAY_CHECK(in_memory_store_->Put(RayObject(rpc::ErrorType::OBJECT_IN_PLASMA), - object_id)); - } else { - RAY_LOG(DEBUG) << "Task " << task_id << " returned object " << object_id - << " in plasma on a dead node, attempting to recover."; - reconstruct_object_callback_(object_id); - } + reference_counter_->UpdateObjectPinnedAtRaylet(object_id, pinned_at_raylet_id); } else { // NOTE(swang): If a direct object was promoted to plasma, then we do not // record the node ID that it was pinned at, which means that we will not diff --git a/src/ray/core_worker/task_manager.h b/src/ray/core_worker/task_manager.h index 4a444418d..160bc5459 100644 --- a/src/ray/core_worker/task_manager.h +++ b/src/ray/core_worker/task_manager.h @@ -79,15 +79,11 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa std::shared_ptr reference_counter, PutInLocalPlasmaCallback put_in_local_plasma_callback, RetryTaskCallback retry_task_callback, - const std::function &check_node_alive, - ReconstructObjectCallback reconstruct_object_callback, PushErrorCallback push_error_callback, int64_t max_lineage_bytes) : in_memory_store_(in_memory_store), reference_counter_(reference_counter), put_in_local_plasma_callback_(put_in_local_plasma_callback), retry_task_callback_(retry_task_callback), - check_node_alive_(check_node_alive), - reconstruct_object_callback_(reconstruct_object_callback), push_error_callback_(push_error_callback), max_lineage_bytes_(max_lineage_bytes) { reference_counter_->SetReleaseLineageCallback( @@ -322,17 +318,6 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa /// Called when a task should be retried. const RetryTaskCallback retry_task_callback_; - /// Called to check whether a raylet is still alive. This is used when - /// processing a worker's reply to check whether the node that the worker - /// was on is still alive. If the node is down, the plasma objects returned by the task - /// are marked as failed. - const std::function check_node_alive_; - /// Called when processing a worker's reply if the node that the worker was - /// on died. 
This should be called to attempt to recover a plasma object - /// returned by the task (or store an error if the object is not - /// recoverable). - const ReconstructObjectCallback reconstruct_object_callback_; - // Called to push an error to the relevant driver. const PushErrorCallback push_error_callback_; diff --git a/src/ray/core_worker/test/object_recovery_manager_test.cc b/src/ray/core_worker/test/object_recovery_manager_test.cc index f3272b5e0..951684ee1 100644 --- a/src/ray/core_worker/test/object_recovery_manager_test.cc +++ b/src/ray/core_worker/test/object_recovery_manager_test.cc @@ -116,6 +116,7 @@ class ObjectRecoveryManagerTestBase : public ::testing::Test { task_resubmitter_(std::make_shared()), ref_counter_(std::make_shared( rpc::Address(), publisher_.get(), subscriber_.get(), + [](const NodeID &node_id) { return true; }, /*lineage_pinning_enabled=*/lineage_enabled)), manager_( rpc::Address(), @@ -257,7 +258,8 @@ TEST_F(ObjectRecoveryManagerTest, TestReconstructionSuppression) { ASSERT_EQ(object_directory_->Flush(), 0); // The object is removed and can be recovered again. - auto objects = ref_counter_->ResetObjectsOnRemovedNode(remote_node_id); + ref_counter_->ResetObjectsOnRemovedNode(remote_node_id); + auto objects = ref_counter_->FlushObjectsToRecover(); ASSERT_EQ(objects.size(), 1); ASSERT_EQ(objects[0], object_id); memory_store_->Delete(objects); diff --git a/src/ray/core_worker/test/task_manager_test.cc b/src/ray/core_worker/test/task_manager_test.cc index 8b054a5a0..cdcf3557d 100644 --- a/src/ray/core_worker/test/task_manager_test.cc +++ b/src/ray/core_worker/test/task_manager_test.cc @@ -44,9 +44,10 @@ class TaskManagerTest : public ::testing::Test { : store_(std::shared_ptr(new CoreWorkerMemoryStore())), publisher_(std::make_shared()), subscriber_(std::make_shared()), - reference_counter_(std::shared_ptr( - new ReferenceCounter(rpc::Address(), publisher_.get(), subscriber_.get(), - lineage_pinning_enabled))), + reference_counter_(std::shared_ptr(new ReferenceCounter( + rpc::Address(), publisher_.get(), subscriber_.get(), + [this](const NodeID &node_id) { return all_nodes_alive_; }, + lineage_pinning_enabled))), manager_( store_, reference_counter_, [this](const RayObject &object, const ObjectID &object_id) { @@ -56,10 +57,6 @@ class TaskManagerTest : public ::testing::Test { num_retries_++; return Status::OK(); }, - [this](const NodeID &node_id) { return all_nodes_alive_; }, - [this](const ObjectID &object_id) { - objects_to_recover_.push_back(object_id); - }, [](const JobID &job_id, const std::string &type, const std::string &error_message, double timestamp) { return Status::OK(); }, @@ -79,7 +76,6 @@ class TaskManagerTest : public ::testing::Test { std::shared_ptr subscriber_; std::shared_ptr reference_counter_; bool all_nodes_alive_ = true; - std::vector objects_to_recover_; TaskManager manager_; int num_retries_ = 0; std::unordered_set stored_in_plasma; @@ -173,7 +169,7 @@ TEST_F(TaskManagerTest, TestPlasmaConcurrentFailure) { auto return_id = spec.ReturnId(0); WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0)); - ASSERT_TRUE(objects_to_recover_.empty()); + ASSERT_TRUE(reference_counter_->FlushObjectsToRecover().empty()); all_nodes_alive_ = false; rpc::PushTaskReply reply; @@ -185,9 +181,12 @@ TEST_F(TaskManagerTest, TestPlasmaConcurrentFailure) { ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId())); std::vector> results; - ASSERT_FALSE(store_->Get({return_id}, 1, 0, ctx, false, &results).ok()); - 
ASSERT_EQ(objects_to_recover_.size(), 1); - ASSERT_EQ(objects_to_recover_[0], return_id); + // Caller of FlushObjectsToRecover is responsible for deleting the object + // from the in-memory store and recovering the object. + ASSERT_TRUE(store_->Get({return_id}, 1, 0, ctx, false, &results).ok()); + auto objects_to_recover = reference_counter_->FlushObjectsToRecover(); + ASSERT_EQ(objects_to_recover.size(), 1); + ASSERT_EQ(objects_to_recover[0], return_id); } TEST_F(TaskManagerTest, TestFailPendingTask) {