From e50aa99be1ec439236d697596e22a05c9175362c Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Fri, 20 Dec 2019 17:06:33 -0800 Subject: [PATCH] Reference counting for direct call submitted tasks (#6514) Co-authored-by: Zhijun Fu <37800433+zhijunfu@users.noreply.github.com> --- python/ray/_raylet.pyx | 26 ++- python/ray/includes/libcoreworker.pxd | 7 +- python/ray/tests/test_reference_counting.py | 162 ++++++++++++++ src/ray/core_worker/core_worker.cc | 54 +---- src/ray/core_worker/core_worker.h | 35 ++- src/ray/core_worker/reference_count.cc | 112 +++++----- src/ray/core_worker/reference_count.h | 54 +++-- src/ray/core_worker/reference_count_test.cc | 201 ++++-------------- src/ray/core_worker/task_manager.cc | 50 ++++- src/ray/core_worker/task_manager.h | 24 ++- .../test/direct_actor_transport_test.cc | 3 +- .../test/direct_task_transport_test.cc | 27 ++- src/ray/core_worker/test/task_manager_test.cc | 76 +++++-- .../transport/dependency_resolver.cc | 9 +- .../transport/dependency_resolver.h | 9 +- .../transport/direct_actor_transport.h | 4 +- .../transport/direct_task_transport.h | 3 +- 17 files changed, 512 insertions(+), 344 deletions(-) create mode 100644 python/ray/tests/test_reference_counting.py diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 07ecb6f77..12123dc0f 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -1073,16 +1073,12 @@ cdef class CoreWorker: return output def add_object_id_reference(self, ObjectID object_id): - cdef: - CObjectID c_object_id = object_id.native() # Note: faster to not release GIL for short-running op. - self.core_worker.get().AddObjectIDReference(c_object_id) + self.core_worker.get().AddLocalReference(object_id.native()) def remove_object_id_reference(self, ObjectID object_id): - cdef: - CObjectID c_object_id = object_id.native() # Note: faster to not release GIL for short-running op. - self.core_worker.get().RemoveObjectIDReference(c_object_id) + self.core_worker.get().RemoveLocalReference(object_id.native()) def serialize_and_promote_object_id(self, ObjectID object_id): cdef: @@ -1174,6 +1170,24 @@ cdef class CoreWorker: def current_actor_is_asyncio(self): return self.core_worker.get().GetWorkerContext().CurrentActorIsAsync() + def get_all_reference_counts(self): + cdef: + unordered_map[CObjectID, pair[size_t, size_t]] c_ref_counts + unordered_map[CObjectID, pair[size_t, size_t]].iterator it + + c_ref_counts = self.core_worker.get().GetAllReferenceCounts() + it = c_ref_counts.begin() + + ref_counts = {} + while it != c_ref_counts.end(): + object_id = ObjectID(dereference(it).first.Binary()) + ref_counts[object_id] = { + "local": dereference(it).second.first, + "submitted": dereference(it).second.second} + postincrement(it) + + return ref_counts + def in_memory_store_get_async(self, ObjectID object_id, future): self.core_worker.get().GetAsync( object_id.native(), diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index b8153fe13..1e8abf08b 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -118,8 +118,8 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: CActorID DeserializeAndRegisterActorHandle(const c_string &bytes) CRayStatus SerializeActorHandle(const CActorID &actor_id, c_string *bytes) - void AddObjectIDReference(const CObjectID &object_id) - void RemoveObjectIDReference(const CObjectID &object_id) + void AddLocalReference(const CObjectID &object_id) + void RemoveLocalReference(const CObjectID &object_id) void PromoteObjectToPlasma(const CObjectID &object_id) void PromoteToPlasmaAndGetOwnershipInfo(const CObjectID &object_id, CTaskID *owner_id, @@ -149,6 +149,9 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: CWorkerContext &GetWorkerContext() void YieldCurrentFiber(CFiberEvent &coroutine_done) + + unordered_map[CObjectID, pair[size_t, size_t]] GetAllReferenceCounts() + void GetAsync(const CObjectID &object_id, ray_callback_function successs_callback, ray_callback_function fallback_callback, diff --git a/python/ray/tests/test_reference_counting.py b/python/ray/tests/test_reference_counting.py new file mode 100644 index 000000000..b5367afdc --- /dev/null +++ b/python/ray/tests/test_reference_counting.py @@ -0,0 +1,162 @@ +# coding: utf-8 +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import copy +import tempfile +import numpy as np +import time +import logging +import uuid + +import ray +import ray.cluster_utils +import ray.test_utils + +logger = logging.getLogger(__name__) + + +def _check_refcounts(expected): + actual = ray.worker.global_worker.core_worker.get_all_reference_counts() + assert len(expected) == len(actual) + for object_id, (local, submitted) in expected.items(): + assert object_id in actual + assert local == actual[object_id]["local"] + assert submitted == actual[object_id]["submitted"] + + +def check_refcounts(expected, timeout=1): + start = time.time() + while True: + try: + _check_refcounts(expected) + break + except AssertionError as e: + if time.time() - start > timeout: + raise e + else: + time.sleep(0.1) + + +def test_local_refcounts(ray_start_regular): + oid1 = ray.put(None) + check_refcounts({oid1: (1, 0)}) + oid1_copy = copy.copy(oid1) + check_refcounts({oid1: (2, 0)}) + del oid1 + check_refcounts({oid1_copy: (1, 0)}) + del oid1_copy + check_refcounts({}) + + +def test_dependency_refcounts(ray_start_regular): + # Return a large object that will be spilled to plasma. + def large_object(): + return np.zeros(10 * 1024 * 1024, dtype=np.uint8) + + # TODO: Clean up tmpfiles? + def random_path(): + return os.path.join(tempfile.gettempdir(), uuid.uuid4().hex) + + def touch(path): + with open(path, "w"): + pass + + def wait_for_file(path): + while True: + if os.path.exists(path): + break + time.sleep(0.1) + + @ray.remote + def one_dep(dep, path=None, fail=False): + if path is not None: + wait_for_file(path) + if fail: + raise Exception("failed on purpose") + + @ray.remote + def one_dep_large(dep, path=None): + if path is not None: + wait_for_file(path) + # This should be spilled to plasma. + return large_object() + + # Test that regular plasma dependency refcounts are decremented once the + # task finishes. + f = random_path() + large_dep = ray.put(large_object()) + result = one_dep.remote(large_dep, path=f) + check_refcounts({large_dep: (1, 1), result: (1, 0)}) + touch(f) + # Reference count should be removed once the task finishes. + check_refcounts({large_dep: (1, 0), result: (1, 0)}) + del large_dep, result + check_refcounts({}) + + # Test that inlined dependency refcounts are decremented once they are + # inlined. + f = random_path() + dep = one_dep.remote(None, path=f) + check_refcounts({dep: (1, 0)}) + result = one_dep.remote(dep) + check_refcounts({dep: (1, 1), result: (1, 0)}) + touch(f) + # Reference count should be removed as soon as the dependency is inlined. + check_refcounts({dep: (1, 0), result: (1, 0)}, timeout=1) + del dep, result + check_refcounts({}) + + # Test that spilled plasma dependency refcounts are decremented once + # the task finishes. + f1, f2 = random_path(), random_path() + dep = one_dep_large.remote(None, path=f1) + check_refcounts({dep: (1, 0)}) + result = one_dep.remote(dep, path=f2) + check_refcounts({dep: (1, 1), result: (1, 0)}) + touch(f1) + ray.get(dep, timeout=5.0) + # Reference count should remain because the dependency is in plasma. + check_refcounts({dep: (1, 1), result: (1, 0)}) + touch(f2) + # Reference count should be removed because the task finished. + check_refcounts({dep: (1, 0), result: (1, 0)}) + del dep, result + check_refcounts({}) + + # Test that regular plasma dependency refcounts are decremented if a task + # fails. + f = random_path() + large_dep = ray.put(large_object()) + result = one_dep.remote(large_dep, path=f, fail=True) + check_refcounts({large_dep: (1, 1), result: (1, 0)}) + touch(f) + # Reference count should be removed once the task finishes. + check_refcounts({large_dep: (1, 0), result: (1, 0)}) + del large_dep, result + check_refcounts({}) + + # Test that spilled plasma dependency refcounts are decremented if a task + # fails. + f1, f2 = random_path(), random_path() + dep = one_dep_large.remote(None, path=f1) + check_refcounts({dep: (1, 0)}) + result = one_dep.remote(dep, path=f2, fail=True) + check_refcounts({dep: (1, 1), result: (1, 0)}) + touch(f1) + ray.get(dep, timeout=5.0) + # Reference count should remain because the dependency is in plasma. + check_refcounts({dep: (1, 1), result: (1, 0)}) + touch(f2) + # Reference count should be removed because the task finished. + check_refcounts({dep: (1, 0), result: (1, 0)}) + del dep, result + check_refcounts({}) + + +if __name__ == "__main__": + import pytest + import sys + sys.exit(pytest.main(["-v", __file__])) diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index bd34b612e..cd28a7335 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -192,14 +192,14 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, ref_counting_enabled ? reference_counter_ : nullptr, local_raylet_client_)); task_manager_.reset(new TaskManager( - memory_store_, actor_manager_, [this](const TaskSpecification &spec) { + memory_store_, reference_counter_, actor_manager_, + [this](const TaskSpecification &spec) { // Retry after a delay to emulate the existing Raylet reconstruction // behaviour. TODO(ekl) backoff exponentially. RAY_LOG(ERROR) << "Will resubmit task after a 5 second delay: " << spec.DebugString(); to_resubmit_.push_back(std::make_pair(current_time_ms() + 5000, spec)); })); - resolver_.reset(new LocalDependencyResolver(memory_store_)); // Create an entry for the driver task in the task table. This task is // added immediately with status RUNNING. This allows us to push errors @@ -377,8 +377,7 @@ Status CoreWorker::Put(const RayObject &object, ObjectID *object_id) { *object_id = ObjectID::ForPut(worker_context_.GetCurrentTaskID(), worker_context_.GetNextPutIndex(), static_cast(TaskTransportType::RAYLET)); - reference_counter_->AddOwnedObject(*object_id, GetCallerId(), rpc_address_, - std::make_shared>()); + reference_counter_->AddOwnedObject(*object_id, GetCallerId(), rpc_address_); return Put(object, *object_id); } @@ -460,10 +459,6 @@ Status CoreWorker::Get(const std::vector &ids, const int64_t timeout_m // object. will_throw_exception = true; } - // If we got the result for this ObjectID, the task that created it must - // have finished. Therefore, we can safely remove its reference counting - // dependencies. - RemoveObjectIDDependencies(ids[i]); } else { missing_result = true; } @@ -597,30 +592,6 @@ TaskID CoreWorker::GetCallerId() const { return caller_id; } -void CoreWorker::PinObjectReferences(const TaskSpecification &task_spec, - const TaskTransportType transport_type) { - size_t num_returns = task_spec.NumReturns(); - if (task_spec.IsActorCreationTask() || task_spec.IsActorTask()) { - num_returns--; - } - - std::shared_ptr> task_deps = - std::make_shared>(); - for (size_t i = 0; i < task_spec.NumArgs(); i++) { - if (task_spec.ArgByRef(i)) { - for (size_t j = 0; j < task_spec.ArgIdCount(i); j++) { - task_deps->push_back(task_spec.ArgId(i, j)); - } - } - } - - // Note that we call this even if task_deps.size() == 0, in order to pin the return id. - for (size_t i = 0; i < num_returns; i++) { - reference_counter_->AddOwnedObject(task_spec.ReturnId(i, transport_type), - GetCallerId(), rpc_address_, task_deps); - } -} - Status CoreWorker::SubmitTask(const RayFunction &function, const std::vector &args, const TaskOptions &task_options, @@ -642,11 +613,9 @@ Status CoreWorker::SubmitTask(const RayFunction &function, return_ids); TaskSpecification task_spec = builder.Build(); if (task_options.is_direct_call) { - task_manager_->AddPendingTask(task_spec, max_retries); - PinObjectReferences(task_spec, TaskTransportType::DIRECT); + task_manager_->AddPendingTask(GetCallerId(), rpc_address_, task_spec, max_retries); return direct_task_submitter_->SubmitTask(task_spec); } else { - PinObjectReferences(task_spec, TaskTransportType::RAYLET); return local_raylet_client_->SubmitTask(task_spec); } } @@ -685,11 +654,10 @@ Status CoreWorker::CreateActor(const RayFunction &function, *return_actor_id = actor_id; TaskSpecification task_spec = builder.Build(); if (actor_creation_options.is_direct_call) { - task_manager_->AddPendingTask(task_spec, actor_creation_options.max_reconstructions); - PinObjectReferences(task_spec, TaskTransportType::DIRECT); + task_manager_->AddPendingTask(GetCallerId(), rpc_address_, task_spec, + actor_creation_options.max_reconstructions); return direct_task_submitter_->SubmitTask(task_spec); } else { - PinObjectReferences(task_spec, TaskTransportType::RAYLET); return local_raylet_client_->SubmitTask(task_spec); } } @@ -729,8 +697,7 @@ Status CoreWorker::SubmitActorTask(const ActorID &actor_id, const RayFunction &f Status status; TaskSpecification task_spec = builder.Build(); if (is_direct_call) { - task_manager_->AddPendingTask(task_spec); - PinObjectReferences(task_spec, TaskTransportType::DIRECT); + task_manager_->AddPendingTask(GetCallerId(), rpc_address_, task_spec); if (actor_handle->IsDead()) { auto status = Status::IOError("sent task to dead actor"); task_manager_->PendingTaskFailed(task_spec.TaskId(), rpc::ErrorType::ACTOR_DIED, @@ -739,7 +706,6 @@ Status CoreWorker::SubmitActorTask(const ActorID &actor_id, const RayFunction &f status = direct_actor_submitter_->SubmitTask(task_spec); } } else { - PinObjectReferences(task_spec, TaskTransportType::RAYLET); RAY_CHECK_OK(local_raylet_client_->SubmitTask(task_spec)); } return status; @@ -1043,17 +1009,17 @@ void CoreWorker::HandleGetObjectStatus(const rpc::GetObjectStatusRequest &reques if (task_manager_->IsTaskPending(object_id.TaskId())) { // Acquire a reference and retry. This prevents the object from being // evicted out from under us before we can start the get. - AddObjectIDReference(object_id); + AddLocalReference(object_id); if (task_manager_->IsTaskPending(object_id.TaskId())) { // The task is pending. Send the reply once the task finishes. memory_store_->GetAsync(object_id, [send_reply_callback](std::shared_ptr obj) { send_reply_callback(Status::OK(), nullptr, nullptr); }); - RemoveObjectIDReference(object_id); + RemoveLocalReference(object_id); } else { // We lost the race, the task is done. - RemoveObjectIDReference(object_id); + RemoveLocalReference(object_id); send_reply_callback(Status::OK(), nullptr, nullptr); } } else { diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 806742027..b6351d030 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -101,17 +101,19 @@ class CoreWorker { actor_id_ = actor_id; } - /// Increase the reference count for this object ID. + /// Increase the local reference count for this object ID. Should be called + /// by the language frontend when a new reference is created. /// /// \param[in] object_id The object ID to increase the reference count for. - void AddObjectIDReference(const ObjectID &object_id) { + void AddLocalReference(const ObjectID &object_id) { reference_counter_->AddLocalReference(object_id); } - /// Decrease the reference count for this object ID. + /// Decrease the reference count for this object ID. Should be called + /// by the language frontend when a reference is destroyed. /// /// \param[in] object_id The object ID to decrease the reference count for. - void RemoveObjectIDReference(const ObjectID &object_id) { + void RemoveLocalReference(const ObjectID &object_id) { std::vector deleted; reference_counter_->RemoveLocalReference(object_id, &deleted); if (ref_counting_enabled_) { @@ -119,6 +121,12 @@ class CoreWorker { } } + /// Returns a map of all ObjectIDs currently in scope with a pair of their + /// (local, submitted_task) reference counts. For debugging purposes. + std::unordered_map> GetAllReferenceCounts() const { + return reference_counter_->GetAllReferenceCounts(); + } + /// Promote an object to plasma and get its owner information. This should be /// called when serializing an object ID, and the returned information should /// be stored with the serialized object ID. For plasma promotion, if the @@ -434,11 +442,6 @@ class CoreWorker { /// Private methods related to task submission. /// - /// Add task dependencies to the reference counter. This prevents the argument - /// objects from early eviction, and also adds the return object. - void PinObjectReferences(const TaskSpecification &task_spec, - const TaskTransportType transport_type); - /// Give this worker a handle to an actor. /// /// This handle will remain as long as the current actor or task is @@ -485,17 +488,6 @@ class CoreWorker { std::vector> *args, std::vector *arg_reference_ids); - /// Remove reference counting dependencies of this object ID. - /// - /// \param[in] object_id The object whose dependencies should be removed. - void RemoveObjectIDDependencies(const ObjectID &object_id) { - std::vector deleted; - reference_counter_->RemoveDependencies(object_id, &deleted); - if (ref_counting_enabled_) { - memory_store_->Delete(deleted); - } - } - /// Returns whether the message was sent to the wrong worker. The right error reply /// is sent automatically. Messages end up on the wrong worker when a worker dies /// and a new one takes its place with the same place. In this situation, we want @@ -624,9 +616,6 @@ class CoreWorker { absl::flat_hash_map> actor_handles_ GUARDED_BY(actor_handles_mutex_); - /// Resolve local and remote dependencies for actor creation. - std::unique_ptr resolver_; - /// /// Fields related to task execution. /// diff --git a/src/ray/core_worker/reference_count.cc b/src/ray/core_worker/reference_count.cc index f15c7d6a6..f799d67e2 100644 --- a/src/ray/core_worker/reference_count.cc +++ b/src/ray/core_worker/reference_count.cc @@ -14,81 +14,76 @@ void ReferenceCounter::AddBorrowedObject(const ObjectID &object_id, } } -void ReferenceCounter::AddOwnedObject( - const ObjectID &object_id, const TaskID &owner_id, const rpc::Address &owner_address, - std::shared_ptr> dependencies) { +void ReferenceCounter::AddOwnedObject(const ObjectID &object_id, const TaskID &owner_id, + const rpc::Address &owner_address) { absl::MutexLock lock(&mutex_); - - for (const ObjectID &dependency_id : *dependencies) { - AddLocalReferenceInternal(dependency_id); - } - RAY_CHECK(object_id_refs_.count(object_id) == 0) - << "Cannot create an object that already exists. ObjectID: " << object_id; + << "Tried to create an owned object that already exists: " << object_id; // If the entry doesn't exist, we initialize the direct reference count to zero // because this corresponds to a submitted task whose return ObjectID will be created // in the frontend language, incrementing the reference count. - object_id_refs_.emplace(object_id, Reference(owner_id, owner_address, dependencies)); + object_id_refs_.emplace(object_id, Reference(owner_id, owner_address)); } -void ReferenceCounter::AddLocalReferenceInternal(const ObjectID &object_id) { +void ReferenceCounter::AddLocalReference(const ObjectID &object_id) { + absl::MutexLock lock(&mutex_); auto entry = object_id_refs_.find(object_id); if (entry == object_id_refs_.end()) { // TODO: Once ref counting is implemented, we should always know how the - // ObjectID was created, so there should always ben an entry. + // ObjectID was created, so there should always be an entry. entry = object_id_refs_.emplace(object_id, Reference()).first; } entry->second.local_ref_count++; } -void ReferenceCounter::AddLocalReference(const ObjectID &object_id) { - absl::MutexLock lock(&mutex_); - AddLocalReferenceInternal(object_id); -} - void ReferenceCounter::RemoveLocalReference(const ObjectID &object_id, std::vector *deleted) { absl::MutexLock lock(&mutex_); - RemoveReferenceRecursive(object_id, deleted); -} - -void ReferenceCounter::RemoveDependencies(const ObjectID &object_id, - std::vector *deleted) { - absl::MutexLock lock(&mutex_); - auto entry = object_id_refs_.find(object_id); - if (entry == object_id_refs_.end()) { - RAY_LOG(WARNING) << "Tried to remove dependencies for nonexistent object ID: " - << object_id; - return; - } - if (entry->second.dependencies) { - for (const ObjectID &pending_task_object_id : *entry->second.dependencies) { - RemoveReferenceRecursive(pending_task_object_id, deleted); - } - entry->second.dependencies = nullptr; - } -} - -void ReferenceCounter::RemoveReferenceRecursive(const ObjectID &object_id, - std::vector *deleted) { auto entry = object_id_refs_.find(object_id); if (entry == object_id_refs_.end()) { RAY_LOG(WARNING) << "Tried to decrease ref count for nonexistent object ID: " << object_id; return; } - if (--entry->second.local_ref_count == 0) { - // If the reference count reached 0, decrease the reference count for each dependency. - if (entry->second.dependencies) { - for (const ObjectID &pending_task_object_id : *entry->second.dependencies) { - RemoveReferenceRecursive(pending_task_object_id, deleted); - } - } + if (--entry->second.local_ref_count == 0 && + entry->second.submitted_task_ref_count == 0) { object_id_refs_.erase(entry); deleted->push_back(object_id); } } +void ReferenceCounter::AddSubmittedTaskReferences( + const std::vector &object_ids) { + absl::MutexLock lock(&mutex_); + for (const ObjectID &object_id : object_ids) { + auto entry = object_id_refs_.find(object_id); + if (entry == object_id_refs_.end()) { + // TODO: Once ref counting is implemented, we should always know how the + // ObjectID was created, so there should always be an entry. + entry = object_id_refs_.emplace(object_id, Reference()).first; + } + entry->second.submitted_task_ref_count++; + } +} + +void ReferenceCounter::RemoveSubmittedTaskReferences( + const std::vector &object_ids, std::vector *deleted) { + absl::MutexLock lock(&mutex_); + for (const ObjectID &object_id : object_ids) { + auto entry = object_id_refs_.find(object_id); + if (entry == object_id_refs_.end()) { + RAY_LOG(WARNING) << "Tried to decrease ref count for nonexistent object ID: " + << object_id; + return; + } + if (--entry->second.submitted_task_ref_count == 0 && + entry->second.local_ref_count == 0) { + object_id_refs_.erase(entry); + deleted->push_back(object_id); + } + } +} + bool ReferenceCounter::GetOwner(const ObjectID &object_id, TaskID *owner_id, rpc::Address *owner_address) const { absl::MutexLock lock(&mutex_); @@ -126,6 +121,19 @@ std::unordered_set ReferenceCounter::GetAllInScopeObjectIDs() const { return in_scope_object_ids; } +std::unordered_map> +ReferenceCounter::GetAllReferenceCounts() const { + absl::MutexLock lock(&mutex_); + std::unordered_map> all_ref_counts; + all_ref_counts.reserve(object_id_refs_.size()); + for (auto it : object_id_refs_) { + all_ref_counts.emplace(it.first, + std::pair(it.second.local_ref_count, + it.second.submitted_task_ref_count)); + } + return all_ref_counts; +} + void ReferenceCounter::LogDebugString() const { absl::MutexLock lock(&mutex_); @@ -137,15 +145,9 @@ void ReferenceCounter::LogDebugString() const { for (const auto &entry : object_id_refs_) { RAY_LOG(DEBUG) << "\t" << entry.first.Hex(); - RAY_LOG(DEBUG) << "\t\treference count: " << entry.second.local_ref_count; - RAY_LOG(DEBUG) << "\t\tdependencies: "; - if (!entry.second.dependencies) { - RAY_LOG(DEBUG) << "\t\t\tNULL"; - } else { - for (const ObjectID &pending_task_object_id : *entry.second.dependencies) { - RAY_LOG(DEBUG) << "\t\t\t" << pending_task_object_id.Hex(); - } - } + RAY_LOG(DEBUG) << "\t\tlocal refcount: " << entry.second.local_ref_count; + RAY_LOG(DEBUG) << "\t\tsubmitted task refcount: " + << entry.second.submitted_task_ref_count; } } diff --git a/src/ray/core_worker/reference_count.h b/src/ray/core_worker/reference_count.h index 60000a581..d18cc6be6 100644 --- a/src/ray/core_worker/reference_count.h +++ b/src/ray/core_worker/reference_count.h @@ -25,27 +25,32 @@ class ReferenceCounter { /// \param[in] object_id The object to to increment the count for. void AddLocalReference(const ObjectID &object_id) LOCKS_EXCLUDED(mutex_); - /// Decrease the reference count for the ObjectID by one. If the reference count reaches - /// zero, it will be erased from the map and the reference count for all of its - /// dependencies will be decreased be one. + /// Decrease the local reference count for the ObjectID by one. /// /// \param[in] object_id The object to decrement the count for. /// \param[out] deleted List to store objects that hit zero ref count. void RemoveLocalReference(const ObjectID &object_id, std::vector *deleted) LOCKS_EXCLUDED(mutex_); - /// Remove any references to dependencies that this object may have. This does *not* - /// decrease the object's own reference count. + /// Add references for the provided object IDs that correspond to them being + /// dependencies to a submitted task. /// - /// \param[in] object_id The object whose dependencies should be removed. - /// \param[out] deleted List to store objects that hit zero ref count. - void RemoveDependencies(const ObjectID &object_id, std::vector *deleted) - LOCKS_EXCLUDED(mutex_); + /// \param[in] object_ids The object IDs to add references for. + void AddSubmittedTaskReferences(const std::vector &object_ids); + + /// Remove references for the provided object IDs that correspond to them being + /// dependencies to a submitted task. This should be called when inlined + /// dependencies are inlined or when the task finishes for plasma dependencies. + /// + /// \param[in] object_ids The object IDs to remove references for. + /// \param[out] deleted The object IDs whos reference counts reached zero. + void RemoveSubmittedTaskReferences(const std::vector &object_ids, + std::vector *deleted); /// Add an object that we own. The object may depend on other objects. - /// Dependencies for each ObjectID must be set at most once. The direct - /// reference count for the ObjectID is set to zero and the reference count - /// for each dependency is incremented. + /// Dependencies for each ObjectID must be set at most once. The local + /// reference count for the ObjectID is set to zero, which assumes that an + /// ObjectID for it will be created in the language frontend after this call. /// /// TODO(swang): We could avoid copying the owner_id and owner_address since /// we are the owner, but it is easier to store a copy for now, since the @@ -57,9 +62,7 @@ class ReferenceCounter { /// \param[in] owner_address The address of the object's owner. /// \param[in] dependencies The objects that the object depends on. void AddOwnedObject(const ObjectID &object_id, const TaskID &owner_id, - const rpc::Address &owner_address, - std::shared_ptr> dependencies) - LOCKS_EXCLUDED(mutex_); + const rpc::Address &owner_address) LOCKS_EXCLUDED(mutex_); /// Add an object that we are borrowing. /// @@ -82,6 +85,11 @@ class ReferenceCounter { /// Returns a set of all ObjectIDs currently in scope (i.e., nonzero reference count). std::unordered_set GetAllInScopeObjectIDs() const LOCKS_EXCLUDED(mutex_); + /// Returns a map of all ObjectIDs currently in scope with a pair of their + /// (local, submitted_task) reference counts. For debugging purposes. + std::unordered_map> GetAllReferenceCounts() const + LOCKS_EXCLUDED(mutex_); + /// Dumps information about all currently tracked references to RAY_LOG(DEBUG). void LogDebugString() const LOCKS_EXCLUDED(mutex_); @@ -91,22 +99,12 @@ class ReferenceCounter { /// Constructor for a reference whose origin is unknown. Reference() : owned_by_us(false) {} /// Constructor for a reference that we created. - Reference(const TaskID &owner_id, const rpc::Address &owner_address, - std::shared_ptr> deps) - : dependencies(std::move(deps)), - owned_by_us(true), - owner({owner_id, owner_address}) {} - /// Constructor for a reference that was given to us. Reference(const TaskID &owner_id, const rpc::Address &owner_address) - : owned_by_us(false), owner({owner_id, owner_address}) {} + : owned_by_us(true), owner({owner_id, owner_address}) {} /// The local ref count for the ObjectID in the language frontend. size_t local_ref_count = 0; - /// The objects that this object depends on. Tracked only by the owner of - /// the object. Dependencies are stored as shared_ptrs because the same set - /// of dependencies can be shared among multiple entries. For example, when - /// a task has multiple return values, the entry for each return ObjectID - /// depends on all task dependencies. - std::shared_ptr> dependencies; + /// The ref count for submitted tasks that depend on the ObjectID. + size_t submitted_task_ref_count = 0; /// Whether we own the object. If we own the object, then we are /// responsible for tracking the state of the task that creates the object /// (see task_manager.h). diff --git a/src/ray/core_worker/reference_count_test.cc b/src/ray/core_worker/reference_count_test.cc index 384768d2d..c86c2a30e 100644 --- a/src/ray/core_worker/reference_count_test.cc +++ b/src/ray/core_worker/reference_count_test.cc @@ -16,173 +16,63 @@ class ReferenceCountTest : public ::testing::Test { virtual void TearDown() {} }; -// Tests basic incrementing/decrementing of direct reference counts. An entry should only -// be removed once its reference count reaches zero. +// Tests basic incrementing/decrementing of direct/submitted task reference counts. An +// entry should only be removed once both of its reference counts reach zero. TEST_F(ReferenceCountTest, TestBasic) { std::vector out; - ObjectID id = ObjectID::FromRandom(); - rc->AddLocalReference(id); - ASSERT_EQ(rc->NumObjectIDsInScope(), 1); - rc->AddLocalReference(id); - ASSERT_EQ(rc->NumObjectIDsInScope(), 1); - rc->AddLocalReference(id); - ASSERT_EQ(rc->NumObjectIDsInScope(), 1); - rc->RemoveLocalReference(id, &out); - ASSERT_EQ(rc->NumObjectIDsInScope(), 1); - ASSERT_EQ(out.size(), 0); - rc->RemoveLocalReference(id, &out); - ASSERT_EQ(rc->NumObjectIDsInScope(), 1); - ASSERT_EQ(out.size(), 0); - rc->RemoveLocalReference(id, &out); - ASSERT_EQ(rc->NumObjectIDsInScope(), 0); - ASSERT_EQ(out.size(), 1); -} -// Tests the basic logic for dependencies - when an ObjectID with dependencies -// goes out of scope (i.e., reference count reaches zero), all of its dependencies -// should have their reference count decremented and be removed if it reaches zero. -TEST_F(ReferenceCountTest, TestDependencies) { - std::vector out; ObjectID id1 = ObjectID::FromRandom(); ObjectID id2 = ObjectID::FromRandom(); - ObjectID id3 = ObjectID::FromRandom(); - - std::shared_ptr> deps = std::make_shared>(); - deps->push_back(id2); - deps->push_back(id3); - rc->AddOwnedObject(id1, TaskID::Nil(), rpc::Address(), deps); + // Local references. rc->AddLocalReference(id1); - rc->AddLocalReference(id1); - rc->AddLocalReference(id3); - ASSERT_EQ(rc->NumObjectIDsInScope(), 3); - - rc->RemoveLocalReference(id1, &out); - ASSERT_EQ(rc->NumObjectIDsInScope(), 3); - ASSERT_EQ(out.size(), 0); - rc->RemoveLocalReference(id1, &out); - ASSERT_EQ(rc->NumObjectIDsInScope(), 1); - ASSERT_EQ(out.size(), 2); - - rc->RemoveLocalReference(id3, &out); - ASSERT_EQ(rc->NumObjectIDsInScope(), 0); - ASSERT_EQ(out.size(), 3); -} - -// Tests the case where two entries share the same set of dependencies. When one -// entry goes out of scope, it should decrease the reference count for the dependencies -// but they should still be nonzero until the second entry goes out of scope and all -// direct dependencies to the dependencies are removed. -TEST_F(ReferenceCountTest, TestSharedDependencies) { - std::vector out; - ObjectID id1 = ObjectID::FromRandom(); - ObjectID id2 = ObjectID::FromRandom(); - ObjectID id3 = ObjectID::FromRandom(); - ObjectID id4 = ObjectID::FromRandom(); - - std::shared_ptr> deps = std::make_shared>(); - deps->push_back(id3); - deps->push_back(id4); - rc->AddOwnedObject(id1, TaskID::Nil(), rpc::Address(), deps); - rc->AddOwnedObject(id2, TaskID::Nil(), rpc::Address(), deps); - rc->AddLocalReference(id1); rc->AddLocalReference(id2); - rc->AddLocalReference(id4); - ASSERT_EQ(rc->NumObjectIDsInScope(), 4); - - rc->RemoveLocalReference(id1, &out); - ASSERT_EQ(rc->NumObjectIDsInScope(), 3); - ASSERT_EQ(out.size(), 1); - rc->RemoveLocalReference(id2, &out); - ASSERT_EQ(rc->NumObjectIDsInScope(), 1); - ASSERT_EQ(out.size(), 3); - - rc->RemoveLocalReference(id4, &out); - ASSERT_EQ(rc->NumObjectIDsInScope(), 0); - ASSERT_EQ(out.size(), 4); -} - -// Tests the case when an entry has a dependency that itself has a -// dependency. In this case, when the first entry goes out of scope -// it should decrease the reference count for its dependency, causing -// that entry to go out of scope and decrease its dependencies' reference counts. -TEST_F(ReferenceCountTest, TestRecursiveDependencies) { - std::vector out; - ObjectID id1 = ObjectID::FromRandom(); - ObjectID id2 = ObjectID::FromRandom(); - ObjectID id3 = ObjectID::FromRandom(); - ObjectID id4 = ObjectID::FromRandom(); - - std::shared_ptr> deps2 = - std::make_shared>(); - deps2->push_back(id3); - deps2->push_back(id4); - rc->AddOwnedObject(id2, TaskID::Nil(), rpc::Address(), deps2); - - std::shared_ptr> deps1 = - std::make_shared>(); - deps1->push_back(id2); - rc->AddOwnedObject(id1, TaskID::Nil(), rpc::Address(), deps1); - - rc->AddLocalReference(id1); - rc->AddLocalReference(id2); - rc->AddLocalReference(id4); - ASSERT_EQ(rc->NumObjectIDsInScope(), 4); - - rc->RemoveLocalReference(id2, &out); - ASSERT_EQ(rc->NumObjectIDsInScope(), 4); - ASSERT_EQ(out.size(), 0); - rc->RemoveLocalReference(id1, &out); - ASSERT_EQ(rc->NumObjectIDsInScope(), 1); - ASSERT_EQ(out.size(), 3); - - rc->RemoveLocalReference(id4, &out); - ASSERT_EQ(rc->NumObjectIDsInScope(), 0); - ASSERT_EQ(out.size(), 4); -} - -TEST_F(ReferenceCountTest, TestRemoveDependenciesOnly) { - std::vector out; - ObjectID id1 = ObjectID::FromRandom(); - ObjectID id2 = ObjectID::FromRandom(); - ObjectID id3 = ObjectID::FromRandom(); - ObjectID id4 = ObjectID::FromRandom(); - - std::shared_ptr> deps2 = - std::make_shared>(); - deps2->push_back(id3); - deps2->push_back(id4); - rc->AddOwnedObject(id2, TaskID::Nil(), rpc::Address(), deps2); - - std::shared_ptr> deps1 = - std::make_shared>(); - deps1->push_back(id2); - rc->AddOwnedObject(id1, TaskID::Nil(), rpc::Address(), deps1); - - rc->AddLocalReference(id1); - rc->AddLocalReference(id2); - rc->AddLocalReference(id4); - ASSERT_EQ(rc->NumObjectIDsInScope(), 4); - - rc->RemoveDependencies(id2, &out); - ASSERT_EQ(rc->NumObjectIDsInScope(), 3); - ASSERT_EQ(out.size(), 1); - rc->RemoveDependencies(id1, &out); - ASSERT_EQ(rc->NumObjectIDsInScope(), 3); - ASSERT_EQ(out.size(), 1); - + ASSERT_EQ(rc->NumObjectIDsInScope(), 2); rc->RemoveLocalReference(id1, &out); ASSERT_EQ(rc->NumObjectIDsInScope(), 2); - ASSERT_EQ(out.size(), 2); - + ASSERT_EQ(out.size(), 0); rc->RemoveLocalReference(id2, &out); ASSERT_EQ(rc->NumObjectIDsInScope(), 1); - ASSERT_EQ(out.size(), 3); - - rc->RemoveLocalReference(id4, &out); + ASSERT_EQ(out.size(), 1); + rc->RemoveLocalReference(id1, &out); ASSERT_EQ(rc->NumObjectIDsInScope(), 0); - ASSERT_EQ(out.size(), 4); + ASSERT_EQ(out.size(), 2); + out.clear(); + + // Submitted task references. + rc->AddSubmittedTaskReferences({id1}); + rc->AddSubmittedTaskReferences({id1, id2}); + ASSERT_EQ(rc->NumObjectIDsInScope(), 2); + rc->RemoveSubmittedTaskReferences({id1}, &out); + ASSERT_EQ(rc->NumObjectIDsInScope(), 2); + ASSERT_EQ(out.size(), 0); + rc->RemoveSubmittedTaskReferences({id2}, &out); + ASSERT_EQ(rc->NumObjectIDsInScope(), 1); + ASSERT_EQ(out.size(), 1); + rc->RemoveSubmittedTaskReferences({id1}, &out); + ASSERT_EQ(rc->NumObjectIDsInScope(), 0); + ASSERT_EQ(out.size(), 2); + out.clear(); + + // Local & submitted task references. + rc->AddLocalReference(id1); + rc->AddSubmittedTaskReferences({id1, id2}); + rc->AddLocalReference(id2); + ASSERT_EQ(rc->NumObjectIDsInScope(), 2); + rc->RemoveLocalReference(id1, &out); + ASSERT_EQ(rc->NumObjectIDsInScope(), 2); + ASSERT_EQ(out.size(), 0); + rc->RemoveSubmittedTaskReferences({id2}, &out); + ASSERT_EQ(rc->NumObjectIDsInScope(), 2); + ASSERT_EQ(out.size(), 0); + rc->RemoveSubmittedTaskReferences({id1}, &out); + ASSERT_EQ(rc->NumObjectIDsInScope(), 1); + ASSERT_EQ(out.size(), 1); + rc->RemoveLocalReference(id2, &out); + ASSERT_EQ(rc->NumObjectIDsInScope(), 0); + ASSERT_EQ(out.size(), 2); + out.clear(); } // Tests that we can get the owner address correctly for objects that we own, @@ -193,8 +83,7 @@ TEST_F(ReferenceCountTest, TestOwnerAddress) { TaskID task_id = TaskID::ForFakeTask(); rpc::Address address; address.set_ip_address("1234"); - auto deps = std::make_shared>(); - rc->AddOwnedObject(object_id, task_id, address, deps); + rc->AddOwnedObject(object_id, task_id, address); TaskID added_id; rpc::Address added_address; @@ -205,7 +94,7 @@ TEST_F(ReferenceCountTest, TestOwnerAddress) { auto object_id2 = ObjectID::FromRandom(); task_id = TaskID::ForFakeTask(); address.set_ip_address("5678"); - rc->AddOwnedObject(object_id2, task_id, address, deps); + rc->AddOwnedObject(object_id2, task_id, address); ASSERT_TRUE(rc->GetOwner(object_id2, &added_id, &added_address)); ASSERT_EQ(task_id, added_id); ASSERT_EQ(address.ip_address(), added_address.ip_address()); diff --git a/src/ray/core_worker/task_manager.cc b/src/ray/core_worker/task_manager.cc index 806352979..7d905bffe 100644 --- a/src/ray/core_worker/task_manager.cc +++ b/src/ray/core_worker/task_manager.cc @@ -10,11 +10,34 @@ const int64_t kTaskFailureThrottlingThreshold = 50; // Throttle task failure logs to once this interval. const int64_t kTaskFailureLoggingFrequencyMillis = 5000; -void TaskManager::AddPendingTask(const TaskSpecification &spec, int max_retries) { +void TaskManager::AddPendingTask(const TaskID &caller_id, + const rpc::Address &caller_address, + const TaskSpecification &spec, int max_retries) { RAY_LOG(DEBUG) << "Adding pending task " << spec.TaskId(); absl::MutexLock lock(&mu_); std::pair entry = {spec, max_retries}; RAY_CHECK(pending_tasks_.emplace(spec.TaskId(), std::move(entry)).second); + + // Add references for the dependencies to the task. + std::vector task_deps; + for (size_t i = 0; i < spec.NumArgs(); i++) { + if (spec.ArgByRef(i)) { + for (size_t j = 0; j < spec.ArgIdCount(i); j++) { + task_deps.push_back(spec.ArgId(i, j)); + } + } + } + reference_counter_->AddSubmittedTaskReferences(task_deps); + + // Add new owned objects for the return values of the task. + size_t num_returns = spec.NumReturns(); + if (spec.IsActorCreationTask() || spec.IsActorTask()) { + num_returns--; + } + for (size_t i = 0; i < num_returns; i++) { + reference_counter_->AddOwnedObject(spec.ReturnId(i, TaskTransportType::DIRECT), + caller_id, caller_address); + } } void TaskManager::DrainAndShutdown(std::function shutdown) { @@ -48,6 +71,8 @@ void TaskManager::CompletePendingTask(const TaskID &task_id, pending_tasks_.erase(it); } + RemovePlasmaSubmittedTaskReferences(spec); + for (int i = 0; i < reply.return_objects_size(); i++) { const auto &return_object = reply.return_objects(i); ObjectID object_id = ObjectID::FromBinary(return_object.object_id()); @@ -134,6 +159,7 @@ void TaskManager::PendingTaskFailed(const TaskID &task_id, rpc::ErrorType error_ } } } + RemovePlasmaSubmittedTaskReferences(spec); MarkPendingTaskFailed(task_id, spec, error_type); } @@ -148,6 +174,28 @@ void TaskManager::ShutdownIfNeeded() { } } +void TaskManager::RemoveSubmittedTaskReferences(const std::vector &object_ids) { + std::vector deleted; + reference_counter_->RemoveSubmittedTaskReferences(object_ids, &deleted); + in_memory_store_->Delete(deleted); +} + +void TaskManager::OnTaskDependenciesInlined(const std::vector &object_ids) { + RemoveSubmittedTaskReferences(object_ids); +} + +void TaskManager::RemovePlasmaSubmittedTaskReferences(TaskSpecification &spec) { + std::vector plasma_dependencies; + for (size_t i = 0; i < spec.NumArgs(); i++) { + auto count = spec.ArgIdCount(i); + if (count > 0) { + const auto &id = spec.ArgId(i, 0); + plasma_dependencies.push_back(id); + } + } + RemoveSubmittedTaskReferences(plasma_dependencies); +} + void TaskManager::MarkPendingTaskFailed(const TaskID &task_id, const TaskSpecification &spec, rpc::ErrorType error_type) { diff --git a/src/ray/core_worker/task_manager.h b/src/ray/core_worker/task_manager.h index 0db630f38..63f3441eb 100644 --- a/src/ray/core_worker/task_manager.h +++ b/src/ray/core_worker/task_manager.h @@ -21,6 +21,8 @@ class TaskFinisherInterface { virtual void PendingTaskFailed(const TaskID &task_id, rpc::ErrorType error_type, Status *status = nullptr) = 0; + virtual void OnTaskDependenciesInlined(const std::vector &object_ids) = 0; + virtual ~TaskFinisherInterface() {} }; @@ -29,19 +31,24 @@ using RetryTaskCallback = std::function; class TaskManager : public TaskFinisherInterface { public: TaskManager(std::shared_ptr in_memory_store, + std::shared_ptr reference_counter, std::shared_ptr actor_manager, RetryTaskCallback retry_task_callback) : in_memory_store_(in_memory_store), + reference_counter_(reference_counter), actor_manager_(actor_manager), retry_task_callback_(retry_task_callback) {} /// Add a task that is pending execution. /// + /// \param[in] caller_id The TaskID of the calling task. + /// \param[in] caller_address The rpc address of the calling task. /// \param[in] spec The spec of the pending task. /// \param[in] max_retries Number of times this task may be retried /// on failure. /// \return Void. - void AddPendingTask(const TaskSpecification &spec, int max_retries = 0); + void AddPendingTask(const TaskID &caller_id, const rpc::Address &caller_address, + const TaskSpecification &spec, int max_retries = 0); /// Wait for all pending tasks to finish, and then shutdown. /// @@ -72,6 +79,8 @@ class TaskManager : public TaskFinisherInterface { void PendingTaskFailed(const TaskID &task_id, rpc::ErrorType error_type, Status *status = nullptr) override; + void OnTaskDependenciesInlined(const std::vector &object_id) override; + /// Return the spec for a pending task. TaskSpecification GetTaskSpec(const TaskID &task_id) const; @@ -81,12 +90,25 @@ class TaskManager : public TaskFinisherInterface { void MarkPendingTaskFailed(const TaskID &task_id, const TaskSpecification &spec, rpc::ErrorType error_type) LOCKS_EXCLUDED(mu_); + /// Remove submittted task references in the reference counter for the object IDs. + /// If their reference counts reach zero, they are deleted from the in-memory store. + void RemoveSubmittedTaskReferences(const std::vector &object_ids); + + /// Helper function to call RemoveSubmittedTaskReferences on the plasma dependencies + /// of the given task spec. + void RemovePlasmaSubmittedTaskReferences(TaskSpecification &spec); + /// Shutdown if all tasks are finished and shutdown is scheduled. void ShutdownIfNeeded() LOCKS_EXCLUDED(mu_); /// Used to store task results. std::shared_ptr in_memory_store_; + /// Used for reference counting objects. + /// The task manager is responsible for managing all references related to + /// submitted tasks (dependencies and return objects). + std::shared_ptr reference_counter_; + // Interface for publishing actor creation. std::shared_ptr actor_manager_; diff --git a/src/ray/core_worker/test/direct_actor_transport_test.cc b/src/ray/core_worker/test/direct_actor_transport_test.cc index 6e3740b93..5af72d3c0 100644 --- a/src/ray/core_worker/test/direct_actor_transport_test.cc +++ b/src/ray/core_worker/test/direct_actor_transport_test.cc @@ -1,6 +1,5 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" - #include "ray/common/task/task_spec.h" #include "ray/core_worker/store_provider/memory_store/memory_store.h" #include "ray/core_worker/transport/direct_task_transport.h" @@ -45,6 +44,8 @@ class MockTaskFinisher : public TaskFinisherInterface { const rpc::Address *addr)); MOCK_METHOD3(PendingTaskFailed, void(const TaskID &task_id, rpc::ErrorType error_type, Status *status)); + + MOCK_METHOD1(OnTaskDependenciesInlined, void(const std::vector &object_ids)); }; TaskSpecification CreateActorTaskHelper(ActorID actor_id, int64_t counter) { diff --git a/src/ray/core_worker/test/direct_task_transport_test.cc b/src/ray/core_worker/test/direct_task_transport_test.cc index d7ee98608..a7a21f93d 100644 --- a/src/ray/core_worker/test/direct_task_transport_test.cc +++ b/src/ray/core_worker/test/direct_task_transport_test.cc @@ -55,8 +55,13 @@ class MockTaskFinisher : public TaskFinisherInterface { num_tasks_failed++; } + void OnTaskDependenciesInlined(const std::vector &object_ids) override { + num_inlined += object_ids.size(); + } + int num_tasks_complete = 0; int num_tasks_failed = 0; + int num_inlined = 0; }; class MockRayletClient : public WorkerLeaseInterface { @@ -136,17 +141,20 @@ TEST(TestMemoryStore, TestPromoteToPlasma) { TEST(LocalDependencyResolverTest, TestNoDependencies) { auto store = std::make_shared(); - LocalDependencyResolver resolver(store); + auto task_finisher = std::make_shared(); + LocalDependencyResolver resolver(store, task_finisher); TaskSpecification task; bool ok = false; resolver.ResolveDependencies(task, [&ok]() { ok = true; }); ASSERT_TRUE(ok); + ASSERT_EQ(task_finisher->num_inlined, 0); } TEST(LocalDependencyResolverTest, TestIgnorePlasmaDependencies) { auto store = std::make_shared(); - LocalDependencyResolver resolver(store); - ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::RAYLET); + auto task_finisher = std::make_shared(); + LocalDependencyResolver resolver(store, task_finisher); + ObjectID obj1 = ObjectID::FromRandom(); TaskSpecification task; task.GetMutableMessage().add_args()->add_object_ids(obj1.Binary()); bool ok = false; @@ -154,11 +162,13 @@ TEST(LocalDependencyResolverTest, TestIgnorePlasmaDependencies) { // We ignore and don't block on plasma dependencies. ASSERT_TRUE(ok); ASSERT_EQ(resolver.NumPendingTasks(), 0); + ASSERT_EQ(task_finisher->num_inlined, 0); } TEST(LocalDependencyResolverTest, TestHandlePlasmaPromotion) { auto store = std::make_shared(); - LocalDependencyResolver resolver(store); + auto task_finisher = std::make_shared(); + LocalDependencyResolver resolver(store, task_finisher); ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT); std::string meta = std::to_string(static_cast(rpc::ErrorType::OBJECT_IN_PLASMA)); auto metadata = const_cast(reinterpret_cast(meta.data())); @@ -175,11 +185,13 @@ TEST(LocalDependencyResolverTest, TestHandlePlasmaPromotion) { // Checks that the object id is still a direct call id. ASSERT_TRUE(task.ArgId(0, 0).IsDirectCallType()); ASSERT_EQ(resolver.NumPendingTasks(), 0); + ASSERT_EQ(task_finisher->num_inlined, 0); } TEST(LocalDependencyResolverTest, TestInlineLocalDependencies) { auto store = std::make_shared(); - LocalDependencyResolver resolver(store); + auto task_finisher = std::make_shared(); + LocalDependencyResolver resolver(store, task_finisher); ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT); ObjectID obj2 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT); auto data = GenerateRandomObject(); @@ -198,11 +210,13 @@ TEST(LocalDependencyResolverTest, TestInlineLocalDependencies) { ASSERT_NE(task.ArgData(0), nullptr); ASSERT_NE(task.ArgData(1), nullptr); ASSERT_EQ(resolver.NumPendingTasks(), 0); + ASSERT_EQ(task_finisher->num_inlined, 2); } TEST(LocalDependencyResolverTest, TestInlinePendingDependencies) { auto store = std::make_shared(); - LocalDependencyResolver resolver(store); + auto task_finisher = std::make_shared(); + LocalDependencyResolver resolver(store, task_finisher); ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT); ObjectID obj2 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT); auto data = GenerateRandomObject(); @@ -223,6 +237,7 @@ TEST(LocalDependencyResolverTest, TestInlinePendingDependencies) { ASSERT_NE(task.ArgData(0), nullptr); ASSERT_NE(task.ArgData(1), nullptr); ASSERT_EQ(resolver.NumPendingTasks(), 0); + ASSERT_EQ(task_finisher->num_inlined, 2); } TaskSpecification BuildTaskSpec(const std::unordered_map &resources, diff --git a/src/ray/core_worker/test/task_manager_test.cc b/src/ray/core_worker/test/task_manager_test.cc index 671c36338..b9d37b3e4 100644 --- a/src/ray/core_worker/test/task_manager_test.cc +++ b/src/ray/core_worker/test/task_manager_test.cc @@ -1,17 +1,22 @@ -#include "gtest/gtest.h" +#include "ray/core_worker/task_manager.h" +#include "gtest/gtest.h" #include "ray/common/task/task_spec.h" #include "ray/core_worker/actor_manager.h" +#include "ray/core_worker/reference_count.h" #include "ray/core_worker/store_provider/memory_store/memory_store.h" -#include "ray/core_worker/task_manager.h" #include "ray/util/test_util.h" namespace ray { -TaskSpecification CreateTaskHelper(uint64_t num_returns) { +TaskSpecification CreateTaskHelper(uint64_t num_returns, + std::vector dependencies) { TaskSpecification task; task.GetMutableMessage().set_task_id(TaskID::ForFakeTask().Binary()); task.GetMutableMessage().set_num_returns(num_returns); + for (const ObjectID &dep : dependencies) { + task.GetMutableMessage().add_args()->add_object_ids(dep.Binary()); + } return task; } @@ -33,23 +38,31 @@ class TaskManagerTest : public ::testing::Test { public: TaskManagerTest() : store_(std::shared_ptr(new CoreWorkerMemoryStore())), + reference_counter_(std::shared_ptr(new ReferenceCounter())), actor_manager_(std::shared_ptr(new MockActorManager())), - manager_(store_, actor_manager_, [this](const TaskSpecification &spec) { - num_retries_++; - return Status::OK(); - }) {} + manager_(store_, reference_counter_, actor_manager_, + [this](const TaskSpecification &spec) { + num_retries_++; + return Status::OK(); + }) {} std::shared_ptr store_; + std::shared_ptr reference_counter_; std::shared_ptr actor_manager_; TaskManager manager_; int num_retries_ = 0; }; TEST_F(TaskManagerTest, TestTaskSuccess) { - auto spec = CreateTaskHelper(1); + TaskID caller_id = TaskID::Nil(); + rpc::Address caller_address; + ObjectID dep1 = ObjectID::FromRandom(); + ObjectID dep2 = ObjectID::FromRandom(); + auto spec = CreateTaskHelper(1, {dep1, dep2}); ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId())); - manager_.AddPendingTask(spec); + manager_.AddPendingTask(caller_id, caller_address, spec); ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId())); + ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 3); auto return_id = spec.ReturnId(0, TaskTransportType::DIRECT); WorkerContext ctx(WorkerType::WORKER, JobID::FromInt(0)); @@ -60,6 +73,8 @@ TEST_F(TaskManagerTest, TestTaskSuccess) { return_object->set_data(data->Data(), data->Size()); manager_.CompletePendingTask(spec.TaskId(), reply, nullptr); ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId())); + // Only the return object reference should remain. + ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 1); std::vector> results; RAY_CHECK_OK(store_->Get({return_id}, 1, -1, ctx, false, &results)); @@ -69,19 +84,33 @@ TEST_F(TaskManagerTest, TestTaskSuccess) { return_object->data().size()), 0); ASSERT_EQ(num_retries_, 0); + + std::vector removed; + reference_counter_->AddLocalReference(return_id); + reference_counter_->RemoveLocalReference(return_id, &removed); + ASSERT_EQ(removed[0], return_id); + ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 0); } TEST_F(TaskManagerTest, TestTaskFailure) { - auto spec = CreateTaskHelper(1); + TaskID caller_id = TaskID::Nil(); + rpc::Address caller_address; + ObjectID dep1 = ObjectID::FromRandom(); + ObjectID dep2 = ObjectID::FromRandom(); + ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 0); + auto spec = CreateTaskHelper(1, {dep1, dep2}); ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId())); - manager_.AddPendingTask(spec); + manager_.AddPendingTask(caller_id, caller_address, spec); ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId())); + ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 3); auto return_id = spec.ReturnId(0, TaskTransportType::DIRECT); WorkerContext ctx(WorkerType::WORKER, JobID::FromInt(0)); auto error = rpc::ErrorType::WORKER_DIED; manager_.PendingTaskFailed(spec.TaskId(), error); ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId())); + // Only the return object reference should remain. + ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 1); std::vector> results; RAY_CHECK_OK(store_->Get({return_id}, 1, -1, ctx, false, &results)); @@ -90,14 +119,26 @@ TEST_F(TaskManagerTest, TestTaskFailure) { ASSERT_TRUE(results[0]->IsException(&stored_error)); ASSERT_EQ(stored_error, error); ASSERT_EQ(num_retries_, 0); + + std::vector removed; + reference_counter_->AddLocalReference(return_id); + reference_counter_->RemoveLocalReference(return_id, &removed); + ASSERT_EQ(removed[0], return_id); + ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 0); } TEST_F(TaskManagerTest, TestTaskRetry) { - auto spec = CreateTaskHelper(1); + TaskID caller_id = TaskID::Nil(); + rpc::Address caller_address; + ObjectID dep1 = ObjectID::FromRandom(); + ObjectID dep2 = ObjectID::FromRandom(); + ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 0); + auto spec = CreateTaskHelper(1, {dep1, dep2}); ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId())); int num_retries = 3; - manager_.AddPendingTask(spec, num_retries); + manager_.AddPendingTask(caller_id, caller_address, spec, num_retries); ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId())); + ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 3); auto return_id = spec.ReturnId(0, TaskTransportType::DIRECT); WorkerContext ctx(WorkerType::WORKER, JobID::FromInt(0)); @@ -105,6 +146,7 @@ TEST_F(TaskManagerTest, TestTaskRetry) { for (int i = 0; i < num_retries; i++) { manager_.PendingTaskFailed(spec.TaskId(), error); ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId())); + ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 3); std::vector> results; ASSERT_FALSE(store_->Get({return_id}, 1, 0, ctx, false, &results).ok()); ASSERT_EQ(num_retries_, i + 1); @@ -112,6 +154,8 @@ TEST_F(TaskManagerTest, TestTaskRetry) { manager_.PendingTaskFailed(spec.TaskId(), error); ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId())); + // Only the return object reference should remain. + ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 1); std::vector> results; RAY_CHECK_OK(store_->Get({return_id}, 1, -0, ctx, false, &results)); @@ -119,6 +163,12 @@ TEST_F(TaskManagerTest, TestTaskRetry) { rpc::ErrorType stored_error; ASSERT_TRUE(results[0]->IsException(&stored_error)); ASSERT_EQ(stored_error, error); + + std::vector removed; + reference_counter_->AddLocalReference(return_id); + reference_counter_->RemoveLocalReference(return_id, &removed); + ASSERT_EQ(removed[0], return_id); + ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 0); } } // namespace ray diff --git a/src/ray/core_worker/transport/dependency_resolver.cc b/src/ray/core_worker/transport/dependency_resolver.cc index 89b5237a8..3f3b6586a 100644 --- a/src/ray/core_worker/transport/dependency_resolver.cc +++ b/src/ray/core_worker/transport/dependency_resolver.cc @@ -18,7 +18,7 @@ struct TaskState { void InlineDependencies( absl::flat_hash_map> dependencies, - TaskSpecification &task) { + TaskSpecification &task, std::vector *inlined) { auto &msg = task.GetMutableMessage(); size_t found = 0; for (size_t i = 0; i < task.NumArgs(); i++) { @@ -43,6 +43,7 @@ void InlineDependencies( const auto &metadata = it->second->GetMetadata(); mutable_arg->set_metadata(metadata->Data(), metadata->Size()); } + inlined->push_back(id); } found++; } else { @@ -83,15 +84,19 @@ void LocalDependencyResolver::ResolveDependencies(TaskSpecification &task, obj_id, [this, state, obj_id, on_complete](std::shared_ptr obj) { RAY_CHECK(obj != nullptr); bool complete = false; + std::vector inlined; { absl::MutexLock lock(&mu_); state->local_dependencies[obj_id] = std::move(obj); if (--state->dependencies_remaining == 0) { - InlineDependencies(state->local_dependencies, state->task); + InlineDependencies(state->local_dependencies, state->task, &inlined); complete = true; num_pending_ -= 1; } } + if (inlined.size() > 0) { + task_finisher_->OnTaskDependenciesInlined(inlined); + } if (complete) { on_complete(); } diff --git a/src/ray/core_worker/transport/dependency_resolver.h b/src/ray/core_worker/transport/dependency_resolver.h index b30c5e4a0..20631e3ed 100644 --- a/src/ray/core_worker/transport/dependency_resolver.h +++ b/src/ray/core_worker/transport/dependency_resolver.h @@ -6,14 +6,16 @@ #include "ray/common/id.h" #include "ray/common/task/task_spec.h" #include "ray/core_worker/store_provider/memory_store/memory_store.h" +#include "ray/core_worker/task_manager.h" namespace ray { // This class is thread-safe. class LocalDependencyResolver { public: - LocalDependencyResolver(std::shared_ptr store) - : in_memory_store_(store), num_pending_(0) {} + LocalDependencyResolver(std::shared_ptr store, + std::shared_ptr task_finisher) + : in_memory_store_(store), task_finisher_(task_finisher), num_pending_(0) {} /// Resolve all local and remote dependencies for the task, calling the specified /// callback when done. Direct call ids in the task specification will be resolved @@ -33,6 +35,9 @@ class LocalDependencyResolver { /// The in-memory store. std::shared_ptr in_memory_store_; + /// Used to complete tasks. + std::shared_ptr task_finisher_; + /// Number of tasks pending dependency resolution. std::atomic num_pending_; diff --git a/src/ray/core_worker/transport/direct_actor_transport.h b/src/ray/core_worker/transport/direct_actor_transport.h index 1164b7cfc..feb08c17b 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.h +++ b/src/ray/core_worker/transport/direct_actor_transport.h @@ -8,9 +8,9 @@ #include #include #include -#include "absl/container/flat_hash_map.h" #include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" #include "absl/synchronization/mutex.h" #include "ray/common/id.h" #include "ray/common/ray_object.h" @@ -39,7 +39,7 @@ class CoreWorkerDirectActorTaskSubmitter { std::shared_ptr store, std::shared_ptr task_finisher) : client_factory_(client_factory), - resolver_(store), + resolver_(store, task_finisher), task_finisher_(task_finisher) {} /// Submit a task to an actor for execution. diff --git a/src/ray/core_worker/transport/direct_task_transport.h b/src/ray/core_worker/transport/direct_task_transport.h index 2df661d8d..2b8f6d36d 100644 --- a/src/ray/core_worker/transport/direct_task_transport.h +++ b/src/ray/core_worker/transport/direct_task_transport.h @@ -5,7 +5,6 @@ #include "absl/base/thread_annotations.h" #include "absl/synchronization/mutex.h" - #include "ray/common/id.h" #include "ray/common/ray_object.h" #include "ray/core_worker/context.h" @@ -41,7 +40,7 @@ class CoreWorkerDirectTaskSubmitter { : local_lease_client_(lease_client), client_factory_(client_factory), lease_client_factory_(lease_client_factory), - resolver_(store), + resolver_(store, task_finisher), task_finisher_(task_finisher), local_raylet_id_(local_raylet_id), lease_timeout_ms_(lease_timeout_ms) {}