Mirror of https://github.com/vale981/ray, synced 2025-03-06 10:31:39 -05:00

Reference counting for direct call submitted tasks (#6514)
Co-authored-by: Zhijun Fu <37800433+zhijunfu@users.noreply.github.com>

This commit is contained in:
parent b0b6b56bb7
commit e50aa99be1

17 changed files with 512 additions and 344 deletions
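At its core, the change gives each ObjectID two reference counts: a local count for handles held in the language frontend, and a submitted-task count for pending tasks that take the object as a dependency. An entry is dropped only once both counts reach zero. A minimal Python sketch of that rule (a model for illustration only; the real ReferenceCounter in this diff is C++ and mutex-protected):

    # Sketch of the counting rule introduced by this commit, using a plain
    # dict; names here are hypothetical.
    class RefCounts:
        def __init__(self):
            self.counts = {}  # object_id -> [local, submitted]

        def add_local(self, oid):
            self.counts.setdefault(oid, [0, 0])[0] += 1

        def add_submitted(self, oid):
            self.counts.setdefault(oid, [0, 0])[1] += 1

        def remove_local(self, oid):
            c = self.counts[oid]
            c[0] -= 1
            if c == [0, 0]:
                del self.counts[oid]  # out of scope: eligible for deletion

        def remove_submitted(self, oid):
            c = self.counts[oid]
            c[1] -= 1
            if c == [0, 0]:
                del self.counts[oid]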
@@ -1073,16 +1073,12 @@ cdef class CoreWorker:
         return output

     def add_object_id_reference(self, ObjectID object_id):
-        cdef:
-            CObjectID c_object_id = object_id.native()
         # Note: faster to not release GIL for short-running op.
-        self.core_worker.get().AddObjectIDReference(c_object_id)
+        self.core_worker.get().AddLocalReference(object_id.native())

     def remove_object_id_reference(self, ObjectID object_id):
-        cdef:
-            CObjectID c_object_id = object_id.native()
         # Note: faster to not release GIL for short-running op.
-        self.core_worker.get().RemoveObjectIDReference(c_object_id)
+        self.core_worker.get().RemoveLocalReference(object_id.native())

     def serialize_and_promote_object_id(self, ObjectID object_id):
         cdef:
@@ -1174,6 +1170,24 @@ cdef class CoreWorker:
     def current_actor_is_asyncio(self):
         return self.core_worker.get().GetWorkerContext().CurrentActorIsAsync()

+    def get_all_reference_counts(self):
+        cdef:
+            unordered_map[CObjectID, pair[size_t, size_t]] c_ref_counts
+            unordered_map[CObjectID, pair[size_t, size_t]].iterator it
+
+        c_ref_counts = self.core_worker.get().GetAllReferenceCounts()
+        it = c_ref_counts.begin()
+
+        ref_counts = {}
+        while it != c_ref_counts.end():
+            object_id = ObjectID(dereference(it).first.Binary())
+            ref_counts[object_id] = {
+                "local": dereference(it).second.first,
+                "submitted": dereference(it).second.second}
+            postincrement(it)
+
+        return ref_counts
+
     def in_memory_store_get_async(self, ObjectID object_id, future):
         self.core_worker.get().GetAsync(
             object_id.native(),
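The `get_all_reference_counts` hook added above is a debugging aid that surfaces both counts per ObjectID. A sketch of how it can be used from a driver (the `ray.worker.global_worker.core_worker` access path follows the new test file later in this diff):

    import ray

    ray.init()
    oid = ray.put(None)
    counts = ray.worker.global_worker.core_worker.get_all_reference_counts()
    # Expected shape: {oid: {"local": 1, "submitted": 0}} for a fresh ray.put().
    assert counts[oid]["local"] == 1
    assert counts[oid]["submitted"] == 0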
@@ -118,8 +118,8 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
         CActorID DeserializeAndRegisterActorHandle(const c_string &bytes)
         CRayStatus SerializeActorHandle(const CActorID &actor_id, c_string
                                         *bytes)
-        void AddObjectIDReference(const CObjectID &object_id)
-        void RemoveObjectIDReference(const CObjectID &object_id)
+        void AddLocalReference(const CObjectID &object_id)
+        void RemoveLocalReference(const CObjectID &object_id)
         void PromoteObjectToPlasma(const CObjectID &object_id)
         void PromoteToPlasmaAndGetOwnershipInfo(const CObjectID &object_id,
                                                 CTaskID *owner_id,
@@ -149,6 +149,9 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
         CWorkerContext &GetWorkerContext()
         void YieldCurrentFiber(CFiberEvent &coroutine_done)

+        unordered_map[CObjectID, pair[size_t, size_t]] GetAllReferenceCounts()
+
         void GetAsync(const CObjectID &object_id,
                       ray_callback_function success_callback,
                       ray_callback_function fallback_callback,
python/ray/tests/test_reference_counting.py (new file, 162 lines)
@@ -0,0 +1,162 @@
# coding: utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import copy
import tempfile
import numpy as np
import time
import logging
import uuid

import ray
import ray.cluster_utils
import ray.test_utils

logger = logging.getLogger(__name__)


def _check_refcounts(expected):
    actual = ray.worker.global_worker.core_worker.get_all_reference_counts()
    assert len(expected) == len(actual)
    for object_id, (local, submitted) in expected.items():
        assert object_id in actual
        assert local == actual[object_id]["local"]
        assert submitted == actual[object_id]["submitted"]


def check_refcounts(expected, timeout=1):
    start = time.time()
    while True:
        try:
            _check_refcounts(expected)
            break
        except AssertionError as e:
            if time.time() - start > timeout:
                raise e
            else:
                time.sleep(0.1)


def test_local_refcounts(ray_start_regular):
    oid1 = ray.put(None)
    check_refcounts({oid1: (1, 0)})
    oid1_copy = copy.copy(oid1)
    check_refcounts({oid1: (2, 0)})
    del oid1
    check_refcounts({oid1_copy: (1, 0)})
    del oid1_copy
    check_refcounts({})


def test_dependency_refcounts(ray_start_regular):
    # Return a large object that will be spilled to plasma.
    def large_object():
        return np.zeros(10 * 1024 * 1024, dtype=np.uint8)

    # TODO: Clean up tmpfiles?
    def random_path():
        return os.path.join(tempfile.gettempdir(), uuid.uuid4().hex)

    def touch(path):
        with open(path, "w"):
            pass

    def wait_for_file(path):
        while True:
            if os.path.exists(path):
                break
            time.sleep(0.1)

    @ray.remote
    def one_dep(dep, path=None, fail=False):
        if path is not None:
            wait_for_file(path)
        if fail:
            raise Exception("failed on purpose")

    @ray.remote
    def one_dep_large(dep, path=None):
        if path is not None:
            wait_for_file(path)
        # This should be spilled to plasma.
        return large_object()

    # Test that regular plasma dependency refcounts are decremented once the
    # task finishes.
    f = random_path()
    large_dep = ray.put(large_object())
    result = one_dep.remote(large_dep, path=f)
    check_refcounts({large_dep: (1, 1), result: (1, 0)})
    touch(f)
    # Reference count should be removed once the task finishes.
    check_refcounts({large_dep: (1, 0), result: (1, 0)})
    del large_dep, result
    check_refcounts({})

    # Test that inlined dependency refcounts are decremented once they are
    # inlined.
    f = random_path()
    dep = one_dep.remote(None, path=f)
    check_refcounts({dep: (1, 0)})
    result = one_dep.remote(dep)
    check_refcounts({dep: (1, 1), result: (1, 0)})
    touch(f)
    # Reference count should be removed as soon as the dependency is inlined.
    check_refcounts({dep: (1, 0), result: (1, 0)}, timeout=1)
    del dep, result
    check_refcounts({})

    # Test that spilled plasma dependency refcounts are decremented once
    # the task finishes.
    f1, f2 = random_path(), random_path()
    dep = one_dep_large.remote(None, path=f1)
    check_refcounts({dep: (1, 0)})
    result = one_dep.remote(dep, path=f2)
    check_refcounts({dep: (1, 1), result: (1, 0)})
    touch(f1)
    ray.get(dep, timeout=5.0)
    # Reference count should remain because the dependency is in plasma.
    check_refcounts({dep: (1, 1), result: (1, 0)})
    touch(f2)
    # Reference count should be removed because the task finished.
    check_refcounts({dep: (1, 0), result: (1, 0)})
    del dep, result
    check_refcounts({})

    # Test that regular plasma dependency refcounts are decremented if a task
    # fails.
    f = random_path()
    large_dep = ray.put(large_object())
    result = one_dep.remote(large_dep, path=f, fail=True)
    check_refcounts({large_dep: (1, 1), result: (1, 0)})
    touch(f)
    # Reference count should be removed once the task finishes.
    check_refcounts({large_dep: (1, 0), result: (1, 0)})
    del large_dep, result
    check_refcounts({})

    # Test that spilled plasma dependency refcounts are decremented if a task
    # fails.
    f1, f2 = random_path(), random_path()
    dep = one_dep_large.remote(None, path=f1)
    check_refcounts({dep: (1, 0)})
    result = one_dep.remote(dep, path=f2, fail=True)
    check_refcounts({dep: (1, 1), result: (1, 0)})
    touch(f1)
    ray.get(dep, timeout=5.0)
    # Reference count should remain because the dependency is in plasma.
    check_refcounts({dep: (1, 1), result: (1, 0)})
    touch(f2)
    # Reference count should be removed because the task finished.
    check_refcounts({dep: (1, 0), result: (1, 0)})
    del dep, result
    check_refcounts({})


if __name__ == "__main__":
    import pytest
    import sys
    sys.exit(pytest.main(["-v", __file__]))
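In these tests, each expected value reads as a (local, submitted) pair: for example `{large_dep: (1, 1), result: (1, 0)}` says the driver holds one handle to `large_dep` while one in-flight task still depends on it. A small hypothetical helper making that explicit, mirroring `_check_refcounts` above:

    def assert_counts(actual_entry, local, submitted):
        # actual_entry is one value from get_all_reference_counts(), e.g.
        # {"local": 1, "submitted": 1} while a dependent task is pending.
        assert actual_entry == {"local": local, "submitted": submitted}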
@@ -192,14 +192,14 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
       ref_counting_enabled ? reference_counter_ : nullptr, local_raylet_client_));

   task_manager_.reset(new TaskManager(
-      memory_store_, actor_manager_, [this](const TaskSpecification &spec) {
+      memory_store_, reference_counter_, actor_manager_,
+      [this](const TaskSpecification &spec) {
        // Retry after a delay to emulate the existing Raylet reconstruction
        // behaviour. TODO(ekl) backoff exponentially.
        RAY_LOG(ERROR) << "Will resubmit task after a 5 second delay: "
                       << spec.DebugString();
        to_resubmit_.push_back(std::make_pair(current_time_ms() + 5000, spec));
      }));
-  resolver_.reset(new LocalDependencyResolver(memory_store_));

   // Create an entry for the driver task in the task table. This task is
   // added immediately with status RUNNING. This allows us to push errors
@@ -377,8 +377,7 @@ Status CoreWorker::Put(const RayObject &object, ObjectID *object_id) {
   *object_id = ObjectID::ForPut(worker_context_.GetCurrentTaskID(),
                                 worker_context_.GetNextPutIndex(),
                                 static_cast<uint8_t>(TaskTransportType::RAYLET));
-  reference_counter_->AddOwnedObject(*object_id, GetCallerId(), rpc_address_,
-                                     std::make_shared<std::vector<ObjectID>>());
+  reference_counter_->AddOwnedObject(*object_id, GetCallerId(), rpc_address_);
   return Put(object, *object_id);
 }

@@ -460,10 +459,6 @@ Status CoreWorker::Get(const std::vector<ObjectID> &ids, const int64_t timeout_m
         // object.
         will_throw_exception = true;
       }
-      // If we got the result for this ObjectID, the task that created it must
-      // have finished. Therefore, we can safely remove its reference counting
-      // dependencies.
-      RemoveObjectIDDependencies(ids[i]);
     } else {
       missing_result = true;
     }
@@ -597,30 +592,6 @@ TaskID CoreWorker::GetCallerId() const {
   return caller_id;
 }

-void CoreWorker::PinObjectReferences(const TaskSpecification &task_spec,
-                                     const TaskTransportType transport_type) {
-  size_t num_returns = task_spec.NumReturns();
-  if (task_spec.IsActorCreationTask() || task_spec.IsActorTask()) {
-    num_returns--;
-  }
-
-  std::shared_ptr<std::vector<ObjectID>> task_deps =
-      std::make_shared<std::vector<ObjectID>>();
-  for (size_t i = 0; i < task_spec.NumArgs(); i++) {
-    if (task_spec.ArgByRef(i)) {
-      for (size_t j = 0; j < task_spec.ArgIdCount(i); j++) {
-        task_deps->push_back(task_spec.ArgId(i, j));
-      }
-    }
-  }
-
-  // Note that we call this even if task_deps.size() == 0, in order to pin the return id.
-  for (size_t i = 0; i < num_returns; i++) {
-    reference_counter_->AddOwnedObject(task_spec.ReturnId(i, transport_type),
-                                       GetCallerId(), rpc_address_, task_deps);
-  }
-}
-
 Status CoreWorker::SubmitTask(const RayFunction &function,
                               const std::vector<TaskArg> &args,
                               const TaskOptions &task_options,
@@ -642,11 +613,9 @@ Status CoreWorker::SubmitTask(const RayFunction &function,
                               return_ids);
   TaskSpecification task_spec = builder.Build();
   if (task_options.is_direct_call) {
-    task_manager_->AddPendingTask(task_spec, max_retries);
-    PinObjectReferences(task_spec, TaskTransportType::DIRECT);
+    task_manager_->AddPendingTask(GetCallerId(), rpc_address_, task_spec, max_retries);
     return direct_task_submitter_->SubmitTask(task_spec);
   } else {
-    PinObjectReferences(task_spec, TaskTransportType::RAYLET);
     return local_raylet_client_->SubmitTask(task_spec);
   }
 }
@@ -685,11 +654,10 @@ Status CoreWorker::CreateActor(const RayFunction &function,
   *return_actor_id = actor_id;
   TaskSpecification task_spec = builder.Build();
   if (actor_creation_options.is_direct_call) {
-    task_manager_->AddPendingTask(task_spec, actor_creation_options.max_reconstructions);
-    PinObjectReferences(task_spec, TaskTransportType::DIRECT);
+    task_manager_->AddPendingTask(GetCallerId(), rpc_address_, task_spec,
+                                  actor_creation_options.max_reconstructions);
     return direct_task_submitter_->SubmitTask(task_spec);
   } else {
-    PinObjectReferences(task_spec, TaskTransportType::RAYLET);
     return local_raylet_client_->SubmitTask(task_spec);
   }
 }
@@ -729,8 +697,7 @@ Status CoreWorker::SubmitActorTask(const ActorID &actor_id, const RayFunction &f
   Status status;
   TaskSpecification task_spec = builder.Build();
   if (is_direct_call) {
-    task_manager_->AddPendingTask(task_spec);
-    PinObjectReferences(task_spec, TaskTransportType::DIRECT);
+    task_manager_->AddPendingTask(GetCallerId(), rpc_address_, task_spec);
     if (actor_handle->IsDead()) {
       auto status = Status::IOError("sent task to dead actor");
       task_manager_->PendingTaskFailed(task_spec.TaskId(), rpc::ErrorType::ACTOR_DIED,
@@ -739,7 +706,6 @@ Status CoreWorker::SubmitActorTask(const ActorID &actor_id, const RayFunction &f
       status = direct_actor_submitter_->SubmitTask(task_spec);
     }
   } else {
-    PinObjectReferences(task_spec, TaskTransportType::RAYLET);
     RAY_CHECK_OK(local_raylet_client_->SubmitTask(task_spec));
   }
   return status;
@@ -1043,17 +1009,17 @@ void CoreWorker::HandleGetObjectStatus(const rpc::GetObjectStatusRequest &reques
   if (task_manager_->IsTaskPending(object_id.TaskId())) {
     // Acquire a reference and retry. This prevents the object from being
     // evicted out from under us before we can start the get.
-    AddObjectIDReference(object_id);
+    AddLocalReference(object_id);
     if (task_manager_->IsTaskPending(object_id.TaskId())) {
       // The task is pending. Send the reply once the task finishes.
       memory_store_->GetAsync(object_id,
                               [send_reply_callback](std::shared_ptr<RayObject> obj) {
                                 send_reply_callback(Status::OK(), nullptr, nullptr);
                               });
-      RemoveObjectIDReference(object_id);
+      RemoveLocalReference(object_id);
     } else {
       // We lost the race, the task is done.
-      RemoveObjectIDReference(object_id);
+      RemoveLocalReference(object_id);
       send_reply_callback(Status::OK(), nullptr, nullptr);
     }
   } else {
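With `PinObjectReferences` deleted above, all of the pinning moves into `TaskManager::AddPendingTask`, so direct-call submission has a single bookkeeping step before handoff. A Python-flavored model of the new ordering (a sketch with hypothetical names, not Ray's actual control flow):

    class TaskSubmitterModel:
        """Sketch of the new SubmitTask ordering; not the real implementation."""

        def __init__(self, task_manager, direct_submitter, raylet_client):
            self.task_manager = task_manager
            self.direct_submitter = direct_submitter
            self.raylet_client = raylet_client

        def submit_task(self, caller_id, rpc_address, spec, max_retries=0):
            if spec.is_direct_call:
                # AddPendingTask now pins by-reference args (submitted-task
                # refs) and registers return IDs as owned objects.
                self.task_manager.add_pending_task(caller_id, rpc_address,
                                                   spec, max_retries)
                return self.direct_submitter.submit_task(spec)
            # Raylet-transport tasks no longer pin anything at submission.
            return self.raylet_client.submit_task(spec)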
@@ -101,17 +101,19 @@ class CoreWorker {
     actor_id_ = actor_id;
   }

-  /// Increase the reference count for this object ID.
+  /// Increase the local reference count for this object ID. Should be called
+  /// by the language frontend when a new reference is created.
   ///
   /// \param[in] object_id The object ID to increase the reference count for.
-  void AddObjectIDReference(const ObjectID &object_id) {
+  void AddLocalReference(const ObjectID &object_id) {
     reference_counter_->AddLocalReference(object_id);
   }

-  /// Decrease the reference count for this object ID.
+  /// Decrease the reference count for this object ID. Should be called
+  /// by the language frontend when a reference is destroyed.
   ///
   /// \param[in] object_id The object ID to decrease the reference count for.
-  void RemoveObjectIDReference(const ObjectID &object_id) {
+  void RemoveLocalReference(const ObjectID &object_id) {
     std::vector<ObjectID> deleted;
     reference_counter_->RemoveLocalReference(object_id, &deleted);
     if (ref_counting_enabled_) {
@@ -119,6 +121,12 @@ class CoreWorker {
     }
   }

+  /// Returns a map of all ObjectIDs currently in scope with a pair of their
+  /// (local, submitted_task) reference counts. For debugging purposes.
+  std::unordered_map<ObjectID, std::pair<size_t, size_t>> GetAllReferenceCounts() const {
+    return reference_counter_->GetAllReferenceCounts();
+  }
+
   /// Promote an object to plasma and get its owner information. This should be
   /// called when serializing an object ID, and the returned information should
   /// be stored with the serialized object ID. For plasma promotion, if the
@@ -434,11 +442,6 @@ class CoreWorker {
   /// Private methods related to task submission.
   ///

-  /// Add task dependencies to the reference counter. This prevents the argument
-  /// objects from early eviction, and also adds the return object.
-  void PinObjectReferences(const TaskSpecification &task_spec,
-                           const TaskTransportType transport_type);
-
   /// Give this worker a handle to an actor.
   ///
   /// This handle will remain as long as the current actor or task is
@@ -485,17 +488,6 @@ class CoreWorker {
                     std::vector<std::shared_ptr<RayObject>> *args,
                     std::vector<ObjectID> *arg_reference_ids);

-  /// Remove reference counting dependencies of this object ID.
-  ///
-  /// \param[in] object_id The object whose dependencies should be removed.
-  void RemoveObjectIDDependencies(const ObjectID &object_id) {
-    std::vector<ObjectID> deleted;
-    reference_counter_->RemoveDependencies(object_id, &deleted);
-    if (ref_counting_enabled_) {
-      memory_store_->Delete(deleted);
-    }
-  }
-
   /// Returns whether the message was sent to the wrong worker. The right error reply
   /// is sent automatically. Messages end up on the wrong worker when a worker dies
   /// and a new one takes its place with the same address. In this situation, we want
@@ -624,9 +616,6 @@ class CoreWorker {
   absl::flat_hash_map<ActorID, std::unique_ptr<ActorHandle>> actor_handles_
       GUARDED_BY(actor_handles_mutex_);

-  /// Resolve local and remote dependencies for actor creation.
-  std::unique_ptr<LocalDependencyResolver> resolver_;
-
   ///
   /// Fields related to task execution.
   ///
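These two hooks exist so the Cython layer can tie the local count to Python handle lifetime, with each frontend ObjectID contributing exactly one local reference. A Python model of that pairing (a hypothetical class; the actual wiring lives in `_raylet.pyx` and may differ in detail):

    class ObjectIDHandle:
        """Hypothetical model of a frontend ObjectID handle."""

        def __init__(self, core_worker, object_id):
            self.core_worker = core_worker
            self.object_id = object_id
            core_worker.add_object_id_reference(object_id)  # local count +1

        def __del__(self):
            self.core_worker.remove_object_id_reference(self.object_id)  # -1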
@@ -14,81 +14,76 @@ void ReferenceCounter::AddBorrowedObject(const ObjectID &object_id,
   }
 }

-void ReferenceCounter::AddOwnedObject(
-    const ObjectID &object_id, const TaskID &owner_id, const rpc::Address &owner_address,
-    std::shared_ptr<std::vector<ObjectID>> dependencies) {
+void ReferenceCounter::AddOwnedObject(const ObjectID &object_id, const TaskID &owner_id,
+                                      const rpc::Address &owner_address) {
   absl::MutexLock lock(&mutex_);
-
-  for (const ObjectID &dependency_id : *dependencies) {
-    AddLocalReferenceInternal(dependency_id);
-  }
-
   RAY_CHECK(object_id_refs_.count(object_id) == 0)
-      << "Cannot create an object that already exists. ObjectID: " << object_id;
+      << "Tried to create an owned object that already exists: " << object_id;
   // If the entry doesn't exist, we initialize the direct reference count to zero
   // because this corresponds to a submitted task whose return ObjectID will be created
   // in the frontend language, incrementing the reference count.
-  object_id_refs_.emplace(object_id, Reference(owner_id, owner_address, dependencies));
+  object_id_refs_.emplace(object_id, Reference(owner_id, owner_address));
 }

-void ReferenceCounter::AddLocalReferenceInternal(const ObjectID &object_id) {
+void ReferenceCounter::AddLocalReference(const ObjectID &object_id) {
+  absl::MutexLock lock(&mutex_);
   auto entry = object_id_refs_.find(object_id);
   if (entry == object_id_refs_.end()) {
     // TODO: Once ref counting is implemented, we should always know how the
-    // ObjectID was created, so there should always ben an entry.
+    // ObjectID was created, so there should always be an entry.
     entry = object_id_refs_.emplace(object_id, Reference()).first;
   }
   entry->second.local_ref_count++;
 }

-void ReferenceCounter::AddLocalReference(const ObjectID &object_id) {
-  absl::MutexLock lock(&mutex_);
-  AddLocalReferenceInternal(object_id);
-}
-
 void ReferenceCounter::RemoveLocalReference(const ObjectID &object_id,
                                             std::vector<ObjectID> *deleted) {
   absl::MutexLock lock(&mutex_);
-  RemoveReferenceRecursive(object_id, deleted);
-}
-
-void ReferenceCounter::RemoveDependencies(const ObjectID &object_id,
-                                          std::vector<ObjectID> *deleted) {
-  absl::MutexLock lock(&mutex_);
-  auto entry = object_id_refs_.find(object_id);
-  if (entry == object_id_refs_.end()) {
-    RAY_LOG(WARNING) << "Tried to remove dependencies for nonexistent object ID: "
-                     << object_id;
-    return;
-  }
-  if (entry->second.dependencies) {
-    for (const ObjectID &pending_task_object_id : *entry->second.dependencies) {
-      RemoveReferenceRecursive(pending_task_object_id, deleted);
-    }
-    entry->second.dependencies = nullptr;
-  }
-}
-
-void ReferenceCounter::RemoveReferenceRecursive(const ObjectID &object_id,
-                                                std::vector<ObjectID> *deleted) {
   auto entry = object_id_refs_.find(object_id);
   if (entry == object_id_refs_.end()) {
     RAY_LOG(WARNING) << "Tried to decrease ref count for nonexistent object ID: "
                      << object_id;
     return;
   }
-  if (--entry->second.local_ref_count == 0) {
-    // If the reference count reached 0, decrease the reference count for each dependency.
-    if (entry->second.dependencies) {
-      for (const ObjectID &pending_task_object_id : *entry->second.dependencies) {
-        RemoveReferenceRecursive(pending_task_object_id, deleted);
-      }
-    }
+  if (--entry->second.local_ref_count == 0 &&
+      entry->second.submitted_task_ref_count == 0) {
     object_id_refs_.erase(entry);
     deleted->push_back(object_id);
   }
 }

+void ReferenceCounter::AddSubmittedTaskReferences(
+    const std::vector<ObjectID> &object_ids) {
+  absl::MutexLock lock(&mutex_);
+  for (const ObjectID &object_id : object_ids) {
+    auto entry = object_id_refs_.find(object_id);
+    if (entry == object_id_refs_.end()) {
+      // TODO: Once ref counting is implemented, we should always know how the
+      // ObjectID was created, so there should always be an entry.
+      entry = object_id_refs_.emplace(object_id, Reference()).first;
+    }
+    entry->second.submitted_task_ref_count++;
+  }
+}
+
+void ReferenceCounter::RemoveSubmittedTaskReferences(
+    const std::vector<ObjectID> &object_ids, std::vector<ObjectID> *deleted) {
+  absl::MutexLock lock(&mutex_);
+  for (const ObjectID &object_id : object_ids) {
+    auto entry = object_id_refs_.find(object_id);
+    if (entry == object_id_refs_.end()) {
+      RAY_LOG(WARNING) << "Tried to decrease ref count for nonexistent object ID: "
+                       << object_id;
+      return;
+    }
+    if (--entry->second.submitted_task_ref_count == 0 &&
+        entry->second.local_ref_count == 0) {
+      object_id_refs_.erase(entry);
+      deleted->push_back(object_id);
+    }
+  }
+}
+
 bool ReferenceCounter::GetOwner(const ObjectID &object_id, TaskID *owner_id,
                                 rpc::Address *owner_address) const {
   absl::MutexLock lock(&mutex_);
@@ -126,6 +121,19 @@ std::unordered_set<ObjectID> ReferenceCounter::GetAllInScopeObjectIDs() const {
   return in_scope_object_ids;
 }

+std::unordered_map<ObjectID, std::pair<size_t, size_t>>
+ReferenceCounter::GetAllReferenceCounts() const {
+  absl::MutexLock lock(&mutex_);
+  std::unordered_map<ObjectID, std::pair<size_t, size_t>> all_ref_counts;
+  all_ref_counts.reserve(object_id_refs_.size());
+  for (auto it : object_id_refs_) {
+    all_ref_counts.emplace(it.first,
+                           std::pair<size_t, size_t>(it.second.local_ref_count,
+                                                     it.second.submitted_task_ref_count));
+  }
+  return all_ref_counts;
+}
+
 void ReferenceCounter::LogDebugString() const {
   absl::MutexLock lock(&mutex_);

@@ -137,15 +145,9 @@ void ReferenceCounter::LogDebugString() const {

   for (const auto &entry : object_id_refs_) {
     RAY_LOG(DEBUG) << "\t" << entry.first.Hex();
-    RAY_LOG(DEBUG) << "\t\treference count: " << entry.second.local_ref_count;
-    RAY_LOG(DEBUG) << "\t\tdependencies: ";
-    if (!entry.second.dependencies) {
-      RAY_LOG(DEBUG) << "\t\t\tNULL";
-    } else {
-      for (const ObjectID &pending_task_object_id : *entry.second.dependencies) {
-        RAY_LOG(DEBUG) << "\t\t\t" << pending_task_object_id.Hex();
-      }
-    }
+    RAY_LOG(DEBUG) << "\t\tlocal refcount: " << entry.second.local_ref_count;
+    RAY_LOG(DEBUG) << "\t\tsubmitted task refcount: "
+                   << entry.second.submitted_task_ref_count;
   }
 }
@@ -25,27 +25,32 @@ class ReferenceCounter {
   /// \param[in] object_id The object to increment the count for.
   void AddLocalReference(const ObjectID &object_id) LOCKS_EXCLUDED(mutex_);

-  /// Decrease the reference count for the ObjectID by one. If the reference count reaches
-  /// zero, it will be erased from the map and the reference count for all of its
-  /// dependencies will be decreased by one.
+  /// Decrease the local reference count for the ObjectID by one.
   ///
   /// \param[in] object_id The object to decrement the count for.
   /// \param[out] deleted List to store objects that hit zero ref count.
   void RemoveLocalReference(const ObjectID &object_id, std::vector<ObjectID> *deleted)
       LOCKS_EXCLUDED(mutex_);

-  /// Remove any references to dependencies that this object may have. This does *not*
-  /// decrease the object's own reference count.
+  /// Add references for the provided object IDs that correspond to them being
+  /// dependencies to a submitted task.
   ///
-  /// \param[in] object_id The object whose dependencies should be removed.
-  /// \param[out] deleted List to store objects that hit zero ref count.
-  void RemoveDependencies(const ObjectID &object_id, std::vector<ObjectID> *deleted)
-      LOCKS_EXCLUDED(mutex_);
+  /// \param[in] object_ids The object IDs to add references for.
+  void AddSubmittedTaskReferences(const std::vector<ObjectID> &object_ids);
+
+  /// Remove references for the provided object IDs that correspond to them being
+  /// dependencies to a submitted task. This should be called when dependencies
+  /// are inlined, or when the task finishes for plasma dependencies.
+  ///
+  /// \param[in] object_ids The object IDs to remove references for.
+  /// \param[out] deleted The object IDs whose reference counts reached zero.
+  void RemoveSubmittedTaskReferences(const std::vector<ObjectID> &object_ids,
+                                     std::vector<ObjectID> *deleted);

   /// Add an object that we own. The object may depend on other objects.
-  /// Dependencies for each ObjectID must be set at most once. The direct
-  /// reference count for the ObjectID is set to zero and the reference count
-  /// for each dependency is incremented.
+  /// Dependencies for each ObjectID must be set at most once. The local
+  /// reference count for the ObjectID is set to zero, which assumes that an
+  /// ObjectID for it will be created in the language frontend after this call.
   ///
   /// TODO(swang): We could avoid copying the owner_id and owner_address since
   /// we are the owner, but it is easier to store a copy for now, since the
@@ -57,9 +62,7 @@ class ReferenceCounter {
   /// \param[in] owner_address The address of the object's owner.
-  /// \param[in] dependencies The objects that the object depends on.
   void AddOwnedObject(const ObjectID &object_id, const TaskID &owner_id,
-                      const rpc::Address &owner_address,
-                      std::shared_ptr<std::vector<ObjectID>> dependencies)
-      LOCKS_EXCLUDED(mutex_);
+                      const rpc::Address &owner_address) LOCKS_EXCLUDED(mutex_);

   /// Add an object that we are borrowing.
   ///
@@ -82,6 +85,11 @@ class ReferenceCounter {
   /// Returns a set of all ObjectIDs currently in scope (i.e., nonzero reference count).
   std::unordered_set<ObjectID> GetAllInScopeObjectIDs() const LOCKS_EXCLUDED(mutex_);

+  /// Returns a map of all ObjectIDs currently in scope with a pair of their
+  /// (local, submitted_task) reference counts. For debugging purposes.
+  std::unordered_map<ObjectID, std::pair<size_t, size_t>> GetAllReferenceCounts() const
+      LOCKS_EXCLUDED(mutex_);
+
   /// Dumps information about all currently tracked references to RAY_LOG(DEBUG).
   void LogDebugString() const LOCKS_EXCLUDED(mutex_);

@@ -91,22 +99,12 @@ class ReferenceCounter {
     /// Constructor for a reference whose origin is unknown.
     Reference() : owned_by_us(false) {}
     /// Constructor for a reference that we created.
-    Reference(const TaskID &owner_id, const rpc::Address &owner_address,
-              std::shared_ptr<std::vector<ObjectID>> deps)
-        : dependencies(std::move(deps)),
-          owned_by_us(true),
-          owner({owner_id, owner_address}) {}
-    /// Constructor for a reference that was given to us.
     Reference(const TaskID &owner_id, const rpc::Address &owner_address)
-        : owned_by_us(false), owner({owner_id, owner_address}) {}
+        : owned_by_us(true), owner({owner_id, owner_address}) {}
     /// The local ref count for the ObjectID in the language frontend.
     size_t local_ref_count = 0;
-    /// The objects that this object depends on. Tracked only by the owner of
-    /// the object. Dependencies are stored as shared_ptrs because the same set
-    /// of dependencies can be shared among multiple entries. For example, when
-    /// a task has multiple return values, the entry for each return ObjectID
-    /// depends on all task dependencies.
-    std::shared_ptr<std::vector<ObjectID>> dependencies;
+    /// The ref count for submitted tasks that depend on the ObjectID.
+    size_t submitted_task_ref_count = 0;
     /// Whether we own the object. If we own the object, then we are
     /// responsible for tracking the state of the task that creates the object
    /// (see task_manager.h).
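Taken together, the header describes one lifecycle: submission bumps the submitted-task count on each by-reference argument and registers return IDs as owned objects, while the frontend manages the local count; an ID is deleted only when both counts are zero. The expected call sequence, replayed against the dict-based RefCounts model sketched at the top of this page (a model only; the real API takes ObjectIDs and an output `deleted` vector):

    rc = RefCounts()                     # model from the sketch above
    dep, ret = "dep_id", "return_id"
    rc.add_local(dep)                    # driver holds a handle to the argument
    rc.add_submitted(dep)                # task submission pins the argument
    rc.counts.setdefault(ret, [0, 0])    # AddOwnedObject: entry with zero counts
    rc.add_local(ret)                    # frontend creates the return handle
    rc.remove_submitted(dep)             # dependency inlined / task finished
    rc.remove_local(dep)                 # driver drops its handle: deleted
    rc.remove_local(ret)                 # return handle dropped: deleted
    assert rc.counts == {}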
@@ -16,173 +16,63 @@ class ReferenceCountTest : public ::testing::Test {
   virtual void TearDown() {}
 };

-// Tests basic incrementing/decrementing of direct reference counts. An entry should only
-// be removed once its reference count reaches zero.
+// Tests basic incrementing/decrementing of direct/submitted task reference counts. An
+// entry should only be removed once both of its reference counts reach zero.
 TEST_F(ReferenceCountTest, TestBasic) {
   std::vector<ObjectID> out;
-  ObjectID id = ObjectID::FromRandom();
-  rc->AddLocalReference(id);
-  ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
-  rc->AddLocalReference(id);
-  ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
-  rc->AddLocalReference(id);
-  ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
-  rc->RemoveLocalReference(id, &out);
-  ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
-  ASSERT_EQ(out.size(), 0);
-  rc->RemoveLocalReference(id, &out);
-  ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
-  ASSERT_EQ(out.size(), 0);
-  rc->RemoveLocalReference(id, &out);
-  ASSERT_EQ(rc->NumObjectIDsInScope(), 0);
-  ASSERT_EQ(out.size(), 1);
-}
-
-// Tests the basic logic for dependencies - when an ObjectID with dependencies
-// goes out of scope (i.e., reference count reaches zero), all of its dependencies
-// should have their reference count decremented and be removed if it reaches zero.
-TEST_F(ReferenceCountTest, TestDependencies) {
-  std::vector<ObjectID> out;
-  ObjectID id1 = ObjectID::FromRandom();
-  ObjectID id2 = ObjectID::FromRandom();
-  ObjectID id3 = ObjectID::FromRandom();
-
-  std::shared_ptr<std::vector<ObjectID>> deps = std::make_shared<std::vector<ObjectID>>();
-  deps->push_back(id2);
-  deps->push_back(id3);
-  rc->AddOwnedObject(id1, TaskID::Nil(), rpc::Address(), deps);
-
-  // Local references.
-  rc->AddLocalReference(id1);
-  rc->AddLocalReference(id1);
-  rc->AddLocalReference(id3);
-  ASSERT_EQ(rc->NumObjectIDsInScope(), 3);
-
-  rc->RemoveLocalReference(id1, &out);
-  ASSERT_EQ(rc->NumObjectIDsInScope(), 3);
-  ASSERT_EQ(out.size(), 0);
-  rc->RemoveLocalReference(id1, &out);
-  ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
-  ASSERT_EQ(out.size(), 2);
-
-  rc->RemoveLocalReference(id3, &out);
-  ASSERT_EQ(rc->NumObjectIDsInScope(), 0);
-  ASSERT_EQ(out.size(), 3);
-}
-
-// Tests the case where two entries share the same set of dependencies. When one
-// entry goes out of scope, it should decrease the reference count for the dependencies
-// but they should still be nonzero until the second entry goes out of scope and all
-// direct dependencies to the dependencies are removed.
-TEST_F(ReferenceCountTest, TestSharedDependencies) {
-  std::vector<ObjectID> out;
-  ObjectID id1 = ObjectID::FromRandom();
-  ObjectID id2 = ObjectID::FromRandom();
-  ObjectID id3 = ObjectID::FromRandom();
-  ObjectID id4 = ObjectID::FromRandom();
-
-  std::shared_ptr<std::vector<ObjectID>> deps = std::make_shared<std::vector<ObjectID>>();
-  deps->push_back(id3);
-  deps->push_back(id4);
-  rc->AddOwnedObject(id1, TaskID::Nil(), rpc::Address(), deps);
-  rc->AddOwnedObject(id2, TaskID::Nil(), rpc::Address(), deps);
-
-  rc->AddLocalReference(id1);
-  rc->AddLocalReference(id2);
-  rc->AddLocalReference(id4);
-  ASSERT_EQ(rc->NumObjectIDsInScope(), 4);
-
-  rc->RemoveLocalReference(id1, &out);
-  ASSERT_EQ(rc->NumObjectIDsInScope(), 3);
-  ASSERT_EQ(out.size(), 1);
-  rc->RemoveLocalReference(id2, &out);
-  ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
-  ASSERT_EQ(out.size(), 3);
-
-  rc->RemoveLocalReference(id4, &out);
-  ASSERT_EQ(rc->NumObjectIDsInScope(), 0);
-  ASSERT_EQ(out.size(), 4);
-}
-
-// Tests the case when an entry has a dependency that itself has a
-// dependency. In this case, when the first entry goes out of scope
-// it should decrease the reference count for its dependency, causing
-// that entry to go out of scope and decrease its dependencies' reference counts.
-TEST_F(ReferenceCountTest, TestRecursiveDependencies) {
-  std::vector<ObjectID> out;
-  ObjectID id1 = ObjectID::FromRandom();
-  ObjectID id2 = ObjectID::FromRandom();
-  ObjectID id3 = ObjectID::FromRandom();
-  ObjectID id4 = ObjectID::FromRandom();
-
-  std::shared_ptr<std::vector<ObjectID>> deps2 =
-      std::make_shared<std::vector<ObjectID>>();
-  deps2->push_back(id3);
-  deps2->push_back(id4);
-  rc->AddOwnedObject(id2, TaskID::Nil(), rpc::Address(), deps2);
-
-  std::shared_ptr<std::vector<ObjectID>> deps1 =
-      std::make_shared<std::vector<ObjectID>>();
-  deps1->push_back(id2);
-  rc->AddOwnedObject(id1, TaskID::Nil(), rpc::Address(), deps1);
-
-  rc->AddLocalReference(id1);
-  rc->AddLocalReference(id2);
-  rc->AddLocalReference(id4);
-  ASSERT_EQ(rc->NumObjectIDsInScope(), 4);
-
-  rc->RemoveLocalReference(id2, &out);
-  ASSERT_EQ(rc->NumObjectIDsInScope(), 4);
-  ASSERT_EQ(out.size(), 0);
-  rc->RemoveLocalReference(id1, &out);
-  ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
-  ASSERT_EQ(out.size(), 3);
-
-  rc->RemoveLocalReference(id4, &out);
-  ASSERT_EQ(rc->NumObjectIDsInScope(), 0);
-  ASSERT_EQ(out.size(), 4);
-}
-
-TEST_F(ReferenceCountTest, TestRemoveDependenciesOnly) {
-  std::vector<ObjectID> out;
-  ObjectID id1 = ObjectID::FromRandom();
-  ObjectID id2 = ObjectID::FromRandom();
-  ObjectID id3 = ObjectID::FromRandom();
-  ObjectID id4 = ObjectID::FromRandom();
-
-  std::shared_ptr<std::vector<ObjectID>> deps2 =
-      std::make_shared<std::vector<ObjectID>>();
-  deps2->push_back(id3);
-  deps2->push_back(id4);
-  rc->AddOwnedObject(id2, TaskID::Nil(), rpc::Address(), deps2);
-
-  std::shared_ptr<std::vector<ObjectID>> deps1 =
-      std::make_shared<std::vector<ObjectID>>();
-  deps1->push_back(id2);
-  rc->AddOwnedObject(id1, TaskID::Nil(), rpc::Address(), deps1);
-
-  rc->AddLocalReference(id1);
-  rc->AddLocalReference(id2);
-  rc->AddLocalReference(id4);
-  ASSERT_EQ(rc->NumObjectIDsInScope(), 4);
-
-  rc->RemoveDependencies(id2, &out);
-  ASSERT_EQ(rc->NumObjectIDsInScope(), 3);
-  ASSERT_EQ(out.size(), 1);
-  rc->RemoveDependencies(id1, &out);
-  ASSERT_EQ(rc->NumObjectIDsInScope(), 3);
-  ASSERT_EQ(out.size(), 1);
-
-  rc->RemoveLocalReference(id1, &out);
-  ASSERT_EQ(rc->NumObjectIDsInScope(), 2);
-  ASSERT_EQ(out.size(), 2);
-  rc->RemoveLocalReference(id2, &out);
-  ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
-  ASSERT_EQ(out.size(), 3);
-  rc->RemoveLocalReference(id4, &out);
-  ASSERT_EQ(rc->NumObjectIDsInScope(), 0);
-  ASSERT_EQ(out.size(), 4);
-}
+  ObjectID id1 = ObjectID::FromRandom();
+  ObjectID id2 = ObjectID::FromRandom();
+
+  // Local references.
+  rc->AddLocalReference(id1);
+  rc->AddLocalReference(id1);
+  rc->AddLocalReference(id2);
+  ASSERT_EQ(rc->NumObjectIDsInScope(), 2);
+  rc->RemoveLocalReference(id1, &out);
+  ASSERT_EQ(rc->NumObjectIDsInScope(), 2);
+  ASSERT_EQ(out.size(), 0);
+  rc->RemoveLocalReference(id2, &out);
+  ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
+  ASSERT_EQ(out.size(), 1);
+  rc->RemoveLocalReference(id1, &out);
+  ASSERT_EQ(rc->NumObjectIDsInScope(), 0);
+  ASSERT_EQ(out.size(), 2);
+  out.clear();
+
+  // Submitted task references.
+  rc->AddSubmittedTaskReferences({id1});
+  rc->AddSubmittedTaskReferences({id1, id2});
+  ASSERT_EQ(rc->NumObjectIDsInScope(), 2);
+  rc->RemoveSubmittedTaskReferences({id1}, &out);
+  ASSERT_EQ(rc->NumObjectIDsInScope(), 2);
+  ASSERT_EQ(out.size(), 0);
+  rc->RemoveSubmittedTaskReferences({id2}, &out);
+  ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
+  ASSERT_EQ(out.size(), 1);
+  rc->RemoveSubmittedTaskReferences({id1}, &out);
+  ASSERT_EQ(rc->NumObjectIDsInScope(), 0);
+  ASSERT_EQ(out.size(), 2);
+  out.clear();
+
+  // Local & submitted task references.
+  rc->AddLocalReference(id1);
+  rc->AddSubmittedTaskReferences({id1, id2});
+  rc->AddLocalReference(id2);
+  ASSERT_EQ(rc->NumObjectIDsInScope(), 2);
+  rc->RemoveLocalReference(id1, &out);
+  ASSERT_EQ(rc->NumObjectIDsInScope(), 2);
+  ASSERT_EQ(out.size(), 0);
+  rc->RemoveSubmittedTaskReferences({id2}, &out);
+  ASSERT_EQ(rc->NumObjectIDsInScope(), 2);
+  ASSERT_EQ(out.size(), 0);
+  rc->RemoveSubmittedTaskReferences({id1}, &out);
+  ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
+  ASSERT_EQ(out.size(), 1);
+  rc->RemoveLocalReference(id2, &out);
+  ASSERT_EQ(rc->NumObjectIDsInScope(), 0);
+  ASSERT_EQ(out.size(), 2);
+  out.clear();
+}

 // Tests that we can get the owner address correctly for objects that we own,
@@ -193,8 +83,7 @@ TEST_F(ReferenceCountTest, TestOwnerAddress) {
   TaskID task_id = TaskID::ForFakeTask();
   rpc::Address address;
   address.set_ip_address("1234");
-  auto deps = std::make_shared<std::vector<ObjectID>>();
-  rc->AddOwnedObject(object_id, task_id, address, deps);
+  rc->AddOwnedObject(object_id, task_id, address);

   TaskID added_id;
   rpc::Address added_address;
@@ -205,7 +94,7 @@ TEST_F(ReferenceCountTest, TestOwnerAddress) {
   auto object_id2 = ObjectID::FromRandom();
   task_id = TaskID::ForFakeTask();
   address.set_ip_address("5678");
-  rc->AddOwnedObject(object_id2, task_id, address, deps);
+  rc->AddOwnedObject(object_id2, task_id, address);
   ASSERT_TRUE(rc->GetOwner(object_id2, &added_id, &added_address));
   ASSERT_EQ(task_id, added_id);
   ASSERT_EQ(address.ip_address(), added_address.ip_address());
@@ -10,11 +10,34 @@ const int64_t kTaskFailureThrottlingThreshold = 50;
 // Throttle task failure logs to once this interval.
 const int64_t kTaskFailureLoggingFrequencyMillis = 5000;

-void TaskManager::AddPendingTask(const TaskSpecification &spec, int max_retries) {
+void TaskManager::AddPendingTask(const TaskID &caller_id,
+                                 const rpc::Address &caller_address,
+                                 const TaskSpecification &spec, int max_retries) {
   RAY_LOG(DEBUG) << "Adding pending task " << spec.TaskId();
   absl::MutexLock lock(&mu_);
   std::pair<TaskSpecification, int> entry = {spec, max_retries};
   RAY_CHECK(pending_tasks_.emplace(spec.TaskId(), std::move(entry)).second);
+
+  // Add references for the dependencies to the task.
+  std::vector<ObjectID> task_deps;
+  for (size_t i = 0; i < spec.NumArgs(); i++) {
+    if (spec.ArgByRef(i)) {
+      for (size_t j = 0; j < spec.ArgIdCount(i); j++) {
+        task_deps.push_back(spec.ArgId(i, j));
+      }
+    }
+  }
+  reference_counter_->AddSubmittedTaskReferences(task_deps);
+
+  // Add new owned objects for the return values of the task.
+  size_t num_returns = spec.NumReturns();
+  if (spec.IsActorCreationTask() || spec.IsActorTask()) {
+    num_returns--;
+  }
+  for (size_t i = 0; i < num_returns; i++) {
+    reference_counter_->AddOwnedObject(spec.ReturnId(i, TaskTransportType::DIRECT),
+                                       caller_id, caller_address);
+  }
 }

 void TaskManager::DrainAndShutdown(std::function<void()> shutdown) {
@@ -48,6 +71,8 @@ void TaskManager::CompletePendingTask(const TaskID &task_id,
     pending_tasks_.erase(it);
   }

+  RemovePlasmaSubmittedTaskReferences(spec);
+
   for (int i = 0; i < reply.return_objects_size(); i++) {
     const auto &return_object = reply.return_objects(i);
     ObjectID object_id = ObjectID::FromBinary(return_object.object_id());
@@ -134,6 +159,7 @@ void TaskManager::PendingTaskFailed(const TaskID &task_id, rpc::ErrorType error_
       }
     }
   }
+  RemovePlasmaSubmittedTaskReferences(spec);
   MarkPendingTaskFailed(task_id, spec, error_type);
 }

@@ -148,6 +174,28 @@ void TaskManager::ShutdownIfNeeded() {
   }
 }

+void TaskManager::RemoveSubmittedTaskReferences(const std::vector<ObjectID> &object_ids) {
+  std::vector<ObjectID> deleted;
+  reference_counter_->RemoveSubmittedTaskReferences(object_ids, &deleted);
+  in_memory_store_->Delete(deleted);
+}
+
+void TaskManager::OnTaskDependenciesInlined(const std::vector<ObjectID> &object_ids) {
+  RemoveSubmittedTaskReferences(object_ids);
+}
+
+void TaskManager::RemovePlasmaSubmittedTaskReferences(TaskSpecification &spec) {
+  std::vector<ObjectID> plasma_dependencies;
+  for (size_t i = 0; i < spec.NumArgs(); i++) {
+    auto count = spec.ArgIdCount(i);
+    if (count > 0) {
+      const auto &id = spec.ArgId(i, 0);
+      plasma_dependencies.push_back(id);
+    }
+  }
+  RemoveSubmittedTaskReferences(plasma_dependencies);
+}
+
 void TaskManager::MarkPendingTaskFailed(const TaskID &task_id,
                                         const TaskSpecification &spec,
                                         rpc::ErrorType error_type) {
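The release points above differ by dependency kind: in-memory arguments are released via OnTaskDependenciesInlined as soon as the resolver copies their values into the task spec, while plasma arguments (which are never inlined) stay pinned until CompletePendingTask or PendingTaskFailed calls RemovePlasmaSubmittedTaskReferences. A compact Python model of the split (hypothetical names):

    def split_dependencies(arg_ids, inlined):
        # Model of the two release points for submitted-task references.
        inlined_ids = [a for a in arg_ids if a in inlined]
        plasma_ids = [a for a in arg_ids if a not in inlined]
        return inlined_ids, plasma_ids

    args = ["small_obj", "big_obj"]
    inlined_now, release_on_finish = split_dependencies(args, {"small_obj"})
    # inlined_now       -> released via OnTaskDependenciesInlined (resolver)
    # release_on_finish -> released via RemovePlasmaSubmittedTaskReferences
    #                      from CompletePendingTask / PendingTaskFailed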
@@ -21,6 +21,8 @@ class TaskFinisherInterface {
   virtual void PendingTaskFailed(const TaskID &task_id, rpc::ErrorType error_type,
                                  Status *status = nullptr) = 0;

+  virtual void OnTaskDependenciesInlined(const std::vector<ObjectID> &object_ids) = 0;
+
   virtual ~TaskFinisherInterface() {}
 };

@@ -29,19 +31,24 @@ using RetryTaskCallback = std::function<void(const TaskSpecification &spec)>;
 class TaskManager : public TaskFinisherInterface {
  public:
   TaskManager(std::shared_ptr<CoreWorkerMemoryStore> in_memory_store,
+              std::shared_ptr<ReferenceCounter> reference_counter,
               std::shared_ptr<ActorManagerInterface> actor_manager,
               RetryTaskCallback retry_task_callback)
       : in_memory_store_(in_memory_store),
+        reference_counter_(reference_counter),
         actor_manager_(actor_manager),
         retry_task_callback_(retry_task_callback) {}

   /// Add a task that is pending execution.
   ///
+  /// \param[in] caller_id The TaskID of the calling task.
+  /// \param[in] caller_address The rpc address of the calling task.
   /// \param[in] spec The spec of the pending task.
   /// \param[in] max_retries Number of times this task may be retried
   /// on failure.
   /// \return Void.
-  void AddPendingTask(const TaskSpecification &spec, int max_retries = 0);
+  void AddPendingTask(const TaskID &caller_id, const rpc::Address &caller_address,
+                      const TaskSpecification &spec, int max_retries = 0);

   /// Wait for all pending tasks to finish, and then shutdown.
   ///
@@ -72,6 +79,8 @@ class TaskManager : public TaskFinisherInterface {
   void PendingTaskFailed(const TaskID &task_id, rpc::ErrorType error_type,
                          Status *status = nullptr) override;

+  void OnTaskDependenciesInlined(const std::vector<ObjectID> &object_ids) override;
+
   /// Return the spec for a pending task.
   TaskSpecification GetTaskSpec(const TaskID &task_id) const;

@@ -81,12 +90,25 @@ class TaskManager : public TaskFinisherInterface {
   void MarkPendingTaskFailed(const TaskID &task_id, const TaskSpecification &spec,
                              rpc::ErrorType error_type) LOCKS_EXCLUDED(mu_);

+  /// Remove submitted task references in the reference counter for the object IDs.
+  /// If their reference counts reach zero, they are deleted from the in-memory store.
+  void RemoveSubmittedTaskReferences(const std::vector<ObjectID> &object_ids);
+
+  /// Helper function to call RemoveSubmittedTaskReferences on the plasma dependencies
+  /// of the given task spec.
+  void RemovePlasmaSubmittedTaskReferences(TaskSpecification &spec);
+
   /// Shutdown if all tasks are finished and shutdown is scheduled.
   void ShutdownIfNeeded() LOCKS_EXCLUDED(mu_);

   /// Used to store task results.
   std::shared_ptr<CoreWorkerMemoryStore> in_memory_store_;

+  /// Used for reference counting objects.
+  /// The task manager is responsible for managing all references related to
+  /// submitted tasks (dependencies and return objects).
+  std::shared_ptr<ReferenceCounter> reference_counter_;
+
   // Interface for publishing actor creation.
   std::shared_ptr<ActorManagerInterface> actor_manager_;

@@ -1,6 +1,5 @@
 #include "gmock/gmock.h"
 #include "gtest/gtest.h"
-
 #include "ray/common/task/task_spec.h"
 #include "ray/core_worker/store_provider/memory_store/memory_store.h"
 #include "ray/core_worker/transport/direct_task_transport.h"
@@ -45,6 +44,8 @@ class MockTaskFinisher : public TaskFinisherInterface {
                     const rpc::Address *addr));
   MOCK_METHOD3(PendingTaskFailed,
                void(const TaskID &task_id, rpc::ErrorType error_type, Status *status));
+
+  MOCK_METHOD1(OnTaskDependenciesInlined, void(const std::vector<ObjectID> &object_ids));
 };

 TaskSpecification CreateActorTaskHelper(ActorID actor_id, int64_t counter) {
@@ -55,8 +55,13 @@ class MockTaskFinisher : public TaskFinisherInterface {
     num_tasks_failed++;
   }

+  void OnTaskDependenciesInlined(const std::vector<ObjectID> &object_ids) override {
+    num_inlined += object_ids.size();
+  }
+
   int num_tasks_complete = 0;
   int num_tasks_failed = 0;
+  int num_inlined = 0;
 };

 class MockRayletClient : public WorkerLeaseInterface {
@@ -136,17 +141,20 @@ TEST(TestMemoryStore, TestPromoteToPlasma) {

 TEST(LocalDependencyResolverTest, TestNoDependencies) {
   auto store = std::make_shared<CoreWorkerMemoryStore>();
-  LocalDependencyResolver resolver(store);
+  auto task_finisher = std::make_shared<MockTaskFinisher>();
+  LocalDependencyResolver resolver(store, task_finisher);
   TaskSpecification task;
   bool ok = false;
   resolver.ResolveDependencies(task, [&ok]() { ok = true; });
   ASSERT_TRUE(ok);
+  ASSERT_EQ(task_finisher->num_inlined, 0);
 }

 TEST(LocalDependencyResolverTest, TestIgnorePlasmaDependencies) {
   auto store = std::make_shared<CoreWorkerMemoryStore>();
-  LocalDependencyResolver resolver(store);
-  ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::RAYLET);
+  auto task_finisher = std::make_shared<MockTaskFinisher>();
+  LocalDependencyResolver resolver(store, task_finisher);
+  ObjectID obj1 = ObjectID::FromRandom();
   TaskSpecification task;
   task.GetMutableMessage().add_args()->add_object_ids(obj1.Binary());
   bool ok = false;
@@ -154,11 +162,13 @@ TEST(LocalDependencyResolverTest, TestIgnorePlasmaDependencies) {
   // We ignore and don't block on plasma dependencies.
   ASSERT_TRUE(ok);
   ASSERT_EQ(resolver.NumPendingTasks(), 0);
+  ASSERT_EQ(task_finisher->num_inlined, 0);
 }

 TEST(LocalDependencyResolverTest, TestHandlePlasmaPromotion) {
   auto store = std::make_shared<CoreWorkerMemoryStore>();
-  LocalDependencyResolver resolver(store);
+  auto task_finisher = std::make_shared<MockTaskFinisher>();
+  LocalDependencyResolver resolver(store, task_finisher);
   ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
   std::string meta = std::to_string(static_cast<int>(rpc::ErrorType::OBJECT_IN_PLASMA));
   auto metadata = const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(meta.data()));
@@ -175,11 +185,13 @@ TEST(LocalDependencyResolverTest, TestHandlePlasmaPromotion) {
   // Checks that the object id is still a direct call id.
   ASSERT_TRUE(task.ArgId(0, 0).IsDirectCallType());
   ASSERT_EQ(resolver.NumPendingTasks(), 0);
+  ASSERT_EQ(task_finisher->num_inlined, 0);
 }

 TEST(LocalDependencyResolverTest, TestInlineLocalDependencies) {
   auto store = std::make_shared<CoreWorkerMemoryStore>();
-  LocalDependencyResolver resolver(store);
+  auto task_finisher = std::make_shared<MockTaskFinisher>();
+  LocalDependencyResolver resolver(store, task_finisher);
   ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
   ObjectID obj2 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
   auto data = GenerateRandomObject();
@@ -198,11 +210,13 @@ TEST(LocalDependencyResolverTest, TestInlineLocalDependencies) {
   ASSERT_NE(task.ArgData(0), nullptr);
   ASSERT_NE(task.ArgData(1), nullptr);
   ASSERT_EQ(resolver.NumPendingTasks(), 0);
+  ASSERT_EQ(task_finisher->num_inlined, 2);
 }

 TEST(LocalDependencyResolverTest, TestInlinePendingDependencies) {
   auto store = std::make_shared<CoreWorkerMemoryStore>();
-  LocalDependencyResolver resolver(store);
+  auto task_finisher = std::make_shared<MockTaskFinisher>();
+  LocalDependencyResolver resolver(store, task_finisher);
   ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
   ObjectID obj2 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
   auto data = GenerateRandomObject();
@@ -223,6 +237,7 @@ TEST(LocalDependencyResolverTest, TestInlinePendingDependencies) {
   ASSERT_NE(task.ArgData(0), nullptr);
   ASSERT_NE(task.ArgData(1), nullptr);
   ASSERT_EQ(resolver.NumPendingTasks(), 0);
+  ASSERT_EQ(task_finisher->num_inlined, 2);
 }

 TaskSpecification BuildTaskSpec(const std::unordered_map<std::string, double> &resources,
@ -1,17 +1,22 @@
|
|||
#include "gtest/gtest.h"
|
||||
#include "ray/core_worker/task_manager.h"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ray/common/task/task_spec.h"
|
||||
#include "ray/core_worker/actor_manager.h"
|
||||
#include "ray/core_worker/reference_count.h"
|
||||
#include "ray/core_worker/store_provider/memory_store/memory_store.h"
|
||||
#include "ray/core_worker/task_manager.h"
|
||||
#include "ray/util/test_util.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
TaskSpecification CreateTaskHelper(uint64_t num_returns) {
|
||||
TaskSpecification CreateTaskHelper(uint64_t num_returns,
|
||||
std::vector<ObjectID> dependencies) {
|
||||
TaskSpecification task;
|
||||
task.GetMutableMessage().set_task_id(TaskID::ForFakeTask().Binary());
|
||||
task.GetMutableMessage().set_num_returns(num_returns);
|
||||
for (const ObjectID &dep : dependencies) {
|
||||
task.GetMutableMessage().add_args()->add_object_ids(dep.Binary());
|
||||
}
|
||||
return task;
|
||||
}
|
||||
|
||||
|
@ -33,23 +38,31 @@ class TaskManagerTest : public ::testing::Test {
|
|||
public:
|
||||
TaskManagerTest()
|
||||
: store_(std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore())),
|
||||
reference_counter_(std::shared_ptr<ReferenceCounter>(new ReferenceCounter())),
|
||||
actor_manager_(std::shared_ptr<ActorManagerInterface>(new MockActorManager())),
|
||||
manager_(store_, actor_manager_, [this](const TaskSpecification &spec) {
|
||||
num_retries_++;
|
||||
return Status::OK();
|
||||
}) {}
|
||||
manager_(store_, reference_counter_, actor_manager_,
|
||||
[this](const TaskSpecification &spec) {
|
||||
num_retries_++;
|
||||
return Status::OK();
|
||||
}) {}
|
||||
|
||||
std::shared_ptr<CoreWorkerMemoryStore> store_;
|
||||
std::shared_ptr<ReferenceCounter> reference_counter_;
|
||||
std::shared_ptr<ActorManagerInterface> actor_manager_;
|
||||
TaskManager manager_;
|
||||
int num_retries_ = 0;
|
||||
};

TEST_F(TaskManagerTest, TestTaskSuccess) {
  auto spec = CreateTaskHelper(1);
  TaskID caller_id = TaskID::Nil();
  rpc::Address caller_address;
  ObjectID dep1 = ObjectID::FromRandom();
  ObjectID dep2 = ObjectID::FromRandom();
  auto spec = CreateTaskHelper(1, {dep1, dep2});
  ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId()));
  manager_.AddPendingTask(spec);
  manager_.AddPendingTask(caller_id, caller_address, spec);
  ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId()));
  ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 3);
  auto return_id = spec.ReturnId(0, TaskTransportType::DIRECT);
  WorkerContext ctx(WorkerType::WORKER, JobID::FromInt(0));

@ -60,6 +73,8 @@ TEST_F(TaskManagerTest, TestTaskSuccess) {
  return_object->set_data(data->Data(), data->Size());
  manager_.CompletePendingTask(spec.TaskId(), reply, nullptr);
  ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId()));
  // Only the return object reference should remain.
  ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 1);

  std::vector<std::shared_ptr<RayObject>> results;
  RAY_CHECK_OK(store_->Get({return_id}, 1, -1, ctx, false, &results));

@ -69,19 +84,33 @@ TEST_F(TaskManagerTest, TestTaskSuccess) {
                        return_object->data().size()),
            0);
  ASSERT_EQ(num_retries_, 0);

  std::vector<ObjectID> removed;
  reference_counter_->AddLocalReference(return_id);
  reference_counter_->RemoveLocalReference(return_id, &removed);
  ASSERT_EQ(removed[0], return_id);
  ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 0);
}
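
The AddLocalReference/RemoveLocalReference pair exercised above is the core of the new counting scheme. A compact sketch of the contract, assuming a standalone ReferenceCounter (the test gets its counter from the fixture):

// Sketch only; not part of the patch.
ReferenceCounter rc;
ObjectID id = ObjectID::FromRandom();
rc.AddLocalReference(id);  // id enters scope: NumObjectIDsInScope() == 1
std::vector<ObjectID> deleted;
rc.RemoveLocalReference(id, &deleted);  // deleted == {id}; scope back to 0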

TEST_F(TaskManagerTest, TestTaskFailure) {
  auto spec = CreateTaskHelper(1);
  TaskID caller_id = TaskID::Nil();
  rpc::Address caller_address;
  ObjectID dep1 = ObjectID::FromRandom();
  ObjectID dep2 = ObjectID::FromRandom();
  ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 0);
  auto spec = CreateTaskHelper(1, {dep1, dep2});
  ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId()));
  manager_.AddPendingTask(spec);
  manager_.AddPendingTask(caller_id, caller_address, spec);
  ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId()));
  ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 3);
  auto return_id = spec.ReturnId(0, TaskTransportType::DIRECT);
  WorkerContext ctx(WorkerType::WORKER, JobID::FromInt(0));

  auto error = rpc::ErrorType::WORKER_DIED;
  manager_.PendingTaskFailed(spec.TaskId(), error);
  ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId()));
  // Only the return object reference should remain.
  ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 1);

  std::vector<std::shared_ptr<RayObject>> results;
  RAY_CHECK_OK(store_->Get({return_id}, 1, -1, ctx, false, &results));

@ -90,14 +119,26 @@ TEST_F(TaskManagerTest, TestTaskFailure) {
  ASSERT_TRUE(results[0]->IsException(&stored_error));
  ASSERT_EQ(stored_error, error);
  ASSERT_EQ(num_retries_, 0);

  std::vector<ObjectID> removed;
  reference_counter_->AddLocalReference(return_id);
  reference_counter_->RemoveLocalReference(return_id, &removed);
  ASSERT_EQ(removed[0], return_id);
  ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 0);
}

TEST_F(TaskManagerTest, TestTaskRetry) {
  auto spec = CreateTaskHelper(1);
  TaskID caller_id = TaskID::Nil();
  rpc::Address caller_address;
  ObjectID dep1 = ObjectID::FromRandom();
  ObjectID dep2 = ObjectID::FromRandom();
  ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 0);
  auto spec = CreateTaskHelper(1, {dep1, dep2});
  ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId()));
  int num_retries = 3;
  manager_.AddPendingTask(spec, num_retries);
  manager_.AddPendingTask(caller_id, caller_address, spec, num_retries);
  ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId()));
  ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 3);
  auto return_id = spec.ReturnId(0, TaskTransportType::DIRECT);
  WorkerContext ctx(WorkerType::WORKER, JobID::FromInt(0));

@ -105,6 +146,7 @@ TEST_F(TaskManagerTest, TestTaskRetry) {
  for (int i = 0; i < num_retries; i++) {
    manager_.PendingTaskFailed(spec.TaskId(), error);
    ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId()));
    ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 3);
    std::vector<std::shared_ptr<RayObject>> results;
    ASSERT_FALSE(store_->Get({return_id}, 1, 0, ctx, false, &results).ok());
    ASSERT_EQ(num_retries_, i + 1);
@ -112,6 +154,8 @@ TEST_F(TaskManagerTest, TestTaskRetry) {
|
|||
|
||||
manager_.PendingTaskFailed(spec.TaskId(), error);
|
||||
ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId()));
|
||||
// Only the return object reference should remain.
|
||||
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 1);
|
||||
|
||||
std::vector<std::shared_ptr<RayObject>> results;
|
||||
RAY_CHECK_OK(store_->Get({return_id}, 1, -0, ctx, false, &results));
|
||||
|
@ -119,6 +163,12 @@ TEST_F(TaskManagerTest, TestTaskRetry) {
|
|||
rpc::ErrorType stored_error;
|
||||
ASSERT_TRUE(results[0]->IsException(&stored_error));
|
||||
ASSERT_EQ(stored_error, error);
|
||||
|
||||
std::vector<ObjectID> removed;
|
||||
reference_counter_->AddLocalReference(return_id);
|
||||
reference_counter_->RemoveLocalReference(return_id, &removed);
|
||||
ASSERT_EQ(removed[0], return_id);
|
||||
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 0);
|
||||
}
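
The scope counts asserted through the retry loop follow a simple ledger, reproduced here as a worked note (derived only from the assertions above):

// While any attempt is pending: dep1 + dep2 + return object -> 3 in scope.
// After the final failure: deps released, error object kept -> 1 in scope.
// After Add/RemoveLocalReference on return_id               -> 0 in scope.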

} // namespace ray
@ -18,7 +18,7 @@ struct TaskState {

void InlineDependencies(
    absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> dependencies,
    TaskSpecification &task) {
    TaskSpecification &task, std::vector<ObjectID> *inlined) {
  auto &msg = task.GetMutableMessage();
  size_t found = 0;
  for (size_t i = 0; i < task.NumArgs(); i++) {

@ -43,6 +43,7 @@ void InlineDependencies(
        const auto &metadata = it->second->GetMetadata();
        mutable_arg->set_metadata(metadata->Data(), metadata->Size());
      }
      inlined->push_back(id);
    }
    found++;
  } else {
@ -83,15 +84,19 @@ void LocalDependencyResolver::ResolveDependencies(TaskSpecification &task,
|
|||
obj_id, [this, state, obj_id, on_complete](std::shared_ptr<RayObject> obj) {
|
||||
RAY_CHECK(obj != nullptr);
|
||||
bool complete = false;
|
||||
std::vector<ObjectID> inlined;
|
||||
{
|
||||
absl::MutexLock lock(&mu_);
|
||||
state->local_dependencies[obj_id] = std::move(obj);
|
||||
if (--state->dependencies_remaining == 0) {
|
||||
InlineDependencies(state->local_dependencies, state->task);
|
||||
InlineDependencies(state->local_dependencies, state->task, &inlined);
|
||||
complete = true;
|
||||
num_pending_ -= 1;
|
||||
}
|
||||
}
|
||||
if (inlined.size() > 0) {
|
||||
task_finisher_->OnTaskDependenciesInlined(inlined);
|
||||
}
|
||||
if (complete) {
|
||||
on_complete();
|
||||
}
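
A note on the shape of this change: OnTaskDependenciesInlined runs only after the absl::MutexLock scope closes, so the task finisher is never invoked with mu_ held. Reduced to its skeleton (member names as above, state mutation elided):

std::vector<ObjectID> inlined;
{
  absl::MutexLock lock(&mu_);
  // ...mutate resolver state under the lock, filling `inlined`...
}
// Callbacks run lock-free, avoiding re-entrancy or lock-order issues on mu_.
if (!inlined.empty()) {
  task_finisher_->OnTaskDependenciesInlined(inlined);
}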

@ -6,14 +6,16 @@
#include "ray/common/id.h"
#include "ray/common/task/task_spec.h"
#include "ray/core_worker/store_provider/memory_store/memory_store.h"
#include "ray/core_worker/task_manager.h"

namespace ray {

// This class is thread-safe.
class LocalDependencyResolver {
 public:
  LocalDependencyResolver(std::shared_ptr<CoreWorkerMemoryStore> store)
      : in_memory_store_(store), num_pending_(0) {}
  LocalDependencyResolver(std::shared_ptr<CoreWorkerMemoryStore> store,
                          std::shared_ptr<TaskFinisherInterface> task_finisher)
      : in_memory_store_(store), task_finisher_(task_finisher), num_pending_(0) {}

  /// Resolve all local and remote dependencies for the task, calling the specified
  /// callback when done. Direct call ids in the task specification will be resolved

@ -33,6 +35,9 @@ class LocalDependencyResolver {
  /// The in-memory store.
  std::shared_ptr<CoreWorkerMemoryStore> in_memory_store_;

  /// Used to complete tasks.
  std::shared_ptr<TaskFinisherInterface> task_finisher_;

  /// Number of tasks pending dependency resolution.
  std::atomic<int> num_pending_;
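
Putting the new constructor and the completion callback together, a hypothetical call site might look like the following (the lambda body and the exact callback signature are assumptions; the submitters below pass their own continuations):

LocalDependencyResolver resolver(store, task_finisher);
resolver.ResolveDependencies(task_spec, [&]() {
  // All direct-call ObjectID args are now inlined; the task can be
  // submitted without further dependency waits.
});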

@ -8,9 +8,9 @@
#include <queue>
#include <set>
#include <utility>
#include "absl/container/flat_hash_map.h"

#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/synchronization/mutex.h"
#include "ray/common/id.h"
#include "ray/common/ray_object.h"

@ -39,7 +39,7 @@ class CoreWorkerDirectActorTaskSubmitter {
      std::shared_ptr<CoreWorkerMemoryStore> store,
      std::shared_ptr<TaskFinisherInterface> task_finisher)
      : client_factory_(client_factory),
        resolver_(store),
        resolver_(store, task_finisher),
        task_finisher_(task_finisher) {}

  /// Submit a task to an actor for execution.

@ -5,7 +5,6 @@

#include "absl/base/thread_annotations.h"
#include "absl/synchronization/mutex.h"

#include "ray/common/id.h"
#include "ray/common/ray_object.h"
#include "ray/core_worker/context.h"

@ -41,7 +40,7 @@ class CoreWorkerDirectTaskSubmitter {
      : local_lease_client_(lease_client),
        client_factory_(client_factory),
        lease_client_factory_(lease_client_factory),
        resolver_(store),
        resolver_(store, task_finisher),
        task_finisher_(task_finisher),
        local_raylet_id_(local_raylet_id),
        lease_timeout_ms_(lease_timeout_ms) {}