Reference counting for direct call submitted tasks (#6514)

Co-authored-by: Zhijun Fu <37800433+zhijunfu@users.noreply.github.com>
Authored by Edward Oakes on 2019-12-20 17:06:33 -08:00; committed by GitHub
parent b0b6b56bb7
commit e50aa99be1
17 changed files with 512 additions and 344 deletions


@ -1073,16 +1073,12 @@ cdef class CoreWorker:
return output
def add_object_id_reference(self, ObjectID object_id):
cdef:
CObjectID c_object_id = object_id.native()
# Note: faster to not release GIL for short-running op.
self.core_worker.get().AddObjectIDReference(c_object_id)
self.core_worker.get().AddLocalReference(object_id.native())
def remove_object_id_reference(self, ObjectID object_id):
cdef:
CObjectID c_object_id = object_id.native()
# Note: faster to not release GIL for short-running op.
self.core_worker.get().RemoveObjectIDReference(c_object_id)
self.core_worker.get().RemoveLocalReference(object_id.native())
def serialize_and_promote_object_id(self, ObjectID object_id):
cdef:
@ -1174,6 +1170,24 @@ cdef class CoreWorker:
def current_actor_is_asyncio(self):
return self.core_worker.get().GetWorkerContext().CurrentActorIsAsync()
def get_all_reference_counts(self):
cdef:
unordered_map[CObjectID, pair[size_t, size_t]] c_ref_counts
unordered_map[CObjectID, pair[size_t, size_t]].iterator it
c_ref_counts = self.core_worker.get().GetAllReferenceCounts()
it = c_ref_counts.begin()
ref_counts = {}
while it != c_ref_counts.end():
object_id = ObjectID(dereference(it).first.Binary())
ref_counts[object_id] = {
"local": dereference(it).second.first,
"submitted": dereference(it).second.second}
postincrement(it)
return ref_counts
def in_memory_store_get_async(self, ObjectID object_id, future):
self.core_worker.get().GetAsync(
object_id.native(),

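For orientation, here is a minimal sketch of how this debugging hook is reached from Python, assuming an initialized Ray worker; it mirrors the usage in the new test file further down:

import ray

ray.init()
oid = ray.put("value")  # creates one local reference to oid

# Maps each in-scope ObjectID to its two counts, e.g.
# {oid: {"local": 1, "submitted": 0}}.
counts = ray.worker.global_worker.core_worker.get_all_reference_counts()
assert counts[oid] == {"local": 1, "submitted": 0}
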

@ -118,8 +118,8 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
CActorID DeserializeAndRegisterActorHandle(const c_string &bytes)
CRayStatus SerializeActorHandle(const CActorID &actor_id, c_string
*bytes)
void AddObjectIDReference(const CObjectID &object_id)
void RemoveObjectIDReference(const CObjectID &object_id)
void AddLocalReference(const CObjectID &object_id)
void RemoveLocalReference(const CObjectID &object_id)
void PromoteObjectToPlasma(const CObjectID &object_id)
void PromoteToPlasmaAndGetOwnershipInfo(const CObjectID &object_id,
CTaskID *owner_id,
@ -149,6 +149,9 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
CWorkerContext &GetWorkerContext()
void YieldCurrentFiber(CFiberEvent &coroutine_done)
unordered_map[CObjectID, pair[size_t, size_t]] GetAllReferenceCounts()
void GetAsync(const CObjectID &object_id,
ray_callback_function success_callback,
ray_callback_function fallback_callback,


@ -0,0 +1,162 @@
# coding: utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import copy
import tempfile
import numpy as np
import time
import logging
import uuid
import ray
import ray.cluster_utils
import ray.test_utils
logger = logging.getLogger(__name__)
def _check_refcounts(expected):
actual = ray.worker.global_worker.core_worker.get_all_reference_counts()
assert len(expected) == len(actual)
for object_id, (local, submitted) in expected.items():
assert object_id in actual
assert local == actual[object_id]["local"]
assert submitted == actual[object_id]["submitted"]
def check_refcounts(expected, timeout=1):
start = time.time()
while True:
try:
_check_refcounts(expected)
break
except AssertionError as e:
if time.time() - start > timeout:
raise e
else:
time.sleep(0.1)
def test_local_refcounts(ray_start_regular):
oid1 = ray.put(None)
check_refcounts({oid1: (1, 0)})
oid1_copy = copy.copy(oid1)
check_refcounts({oid1: (2, 0)})
del oid1
check_refcounts({oid1_copy: (1, 0)})
del oid1_copy
check_refcounts({})
def test_dependency_refcounts(ray_start_regular):
# Return a large object that will be spilled to plasma.
def large_object():
return np.zeros(10 * 1024 * 1024, dtype=np.uint8)
# TODO: Clean up tmpfiles?
def random_path():
return os.path.join(tempfile.gettempdir(), uuid.uuid4().hex)
def touch(path):
with open(path, "w"):
pass
def wait_for_file(path):
while True:
if os.path.exists(path):
break
time.sleep(0.1)
@ray.remote
def one_dep(dep, path=None, fail=False):
if path is not None:
wait_for_file(path)
if fail:
raise Exception("failed on purpose")
@ray.remote
def one_dep_large(dep, path=None):
if path is not None:
wait_for_file(path)
# This should be spilled to plasma.
return large_object()
# Test that regular plasma dependency refcounts are decremented once the
# task finishes.
f = random_path()
large_dep = ray.put(large_object())
result = one_dep.remote(large_dep, path=f)
check_refcounts({large_dep: (1, 1), result: (1, 0)})
touch(f)
# Reference count should be removed once the task finishes.
check_refcounts({large_dep: (1, 0), result: (1, 0)})
del large_dep, result
check_refcounts({})
# Test that inlined dependency refcounts are decremented once they are
# inlined.
f = random_path()
dep = one_dep.remote(None, path=f)
check_refcounts({dep: (1, 0)})
result = one_dep.remote(dep)
check_refcounts({dep: (1, 1), result: (1, 0)})
touch(f)
# Reference count should be removed as soon as the dependency is inlined.
check_refcounts({dep: (1, 0), result: (1, 0)}, timeout=1)
del dep, result
check_refcounts({})
# Test that spilled plasma dependency refcounts are decremented once
# the task finishes.
f1, f2 = random_path(), random_path()
dep = one_dep_large.remote(None, path=f1)
check_refcounts({dep: (1, 0)})
result = one_dep.remote(dep, path=f2)
check_refcounts({dep: (1, 1), result: (1, 0)})
touch(f1)
ray.get(dep, timeout=5.0)
# Reference count should remain because the dependency is in plasma.
check_refcounts({dep: (1, 1), result: (1, 0)})
touch(f2)
# Reference count should be removed because the task finished.
check_refcounts({dep: (1, 0), result: (1, 0)})
del dep, result
check_refcounts({})
# Test that regular plasma dependency refcounts are decremented if a task
# fails.
f = random_path()
large_dep = ray.put(large_object())
result = one_dep.remote(large_dep, path=f, fail=True)
check_refcounts({large_dep: (1, 1), result: (1, 0)})
touch(f)
# Reference count should be removed once the task finishes.
check_refcounts({large_dep: (1, 0), result: (1, 0)})
del large_dep, result
check_refcounts({})
# Test that spilled plasma dependency refcounts are decremented if a task
# fails.
f1, f2 = random_path(), random_path()
dep = one_dep_large.remote(None, path=f1)
check_refcounts({dep: (1, 0)})
result = one_dep.remote(dep, path=f2, fail=True)
check_refcounts({dep: (1, 1), result: (1, 0)})
touch(f1)
ray.get(dep, timeout=5.0)
# Reference count should remain because the dependency is in plasma.
check_refcounts({dep: (1, 1), result: (1, 0)})
touch(f2)
# Reference count should be removed because the task finished.
check_refcounts({dep: (1, 0), result: (1, 0)})
del dep, result
check_refcounts({})
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))


@ -192,14 +192,14 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
ref_counting_enabled ? reference_counter_ : nullptr, local_raylet_client_));
task_manager_.reset(new TaskManager(
memory_store_, actor_manager_, [this](const TaskSpecification &spec) {
memory_store_, reference_counter_, actor_manager_,
[this](const TaskSpecification &spec) {
// Retry after a delay to emulate the existing Raylet reconstruction
// behaviour. TODO(ekl) backoff exponentially.
RAY_LOG(ERROR) << "Will resubmit task after a 5 second delay: "
<< spec.DebugString();
to_resubmit_.push_back(std::make_pair(current_time_ms() + 5000, spec));
}));
resolver_.reset(new LocalDependencyResolver(memory_store_));
// Create an entry for the driver task in the task table. This task is
// added immediately with status RUNNING. This allows us to push errors
@ -377,8 +377,7 @@ Status CoreWorker::Put(const RayObject &object, ObjectID *object_id) {
*object_id = ObjectID::ForPut(worker_context_.GetCurrentTaskID(),
worker_context_.GetNextPutIndex(),
static_cast<uint8_t>(TaskTransportType::RAYLET));
reference_counter_->AddOwnedObject(*object_id, GetCallerId(), rpc_address_,
std::make_shared<std::vector<ObjectID>>());
reference_counter_->AddOwnedObject(*object_id, GetCallerId(), rpc_address_);
return Put(object, *object_id);
}
@ -460,10 +459,6 @@ Status CoreWorker::Get(const std::vector<ObjectID> &ids, const int64_t timeout_m
// object.
will_throw_exception = true;
}
// If we got the result for this ObjectID, the task that created it must
// have finished. Therefore, we can safely remove its reference counting
// dependencies.
RemoveObjectIDDependencies(ids[i]);
} else {
missing_result = true;
}
@ -597,30 +592,6 @@ TaskID CoreWorker::GetCallerId() const {
return caller_id;
}
void CoreWorker::PinObjectReferences(const TaskSpecification &task_spec,
const TaskTransportType transport_type) {
size_t num_returns = task_spec.NumReturns();
if (task_spec.IsActorCreationTask() || task_spec.IsActorTask()) {
num_returns--;
}
std::shared_ptr<std::vector<ObjectID>> task_deps =
std::make_shared<std::vector<ObjectID>>();
for (size_t i = 0; i < task_spec.NumArgs(); i++) {
if (task_spec.ArgByRef(i)) {
for (size_t j = 0; j < task_spec.ArgIdCount(i); j++) {
task_deps->push_back(task_spec.ArgId(i, j));
}
}
}
// Note that we call this even if task_deps.size() == 0, in order to pin the return id.
for (size_t i = 0; i < num_returns; i++) {
reference_counter_->AddOwnedObject(task_spec.ReturnId(i, transport_type),
GetCallerId(), rpc_address_, task_deps);
}
}
Status CoreWorker::SubmitTask(const RayFunction &function,
const std::vector<TaskArg> &args,
const TaskOptions &task_options,
@ -642,11 +613,9 @@ Status CoreWorker::SubmitTask(const RayFunction &function,
return_ids);
TaskSpecification task_spec = builder.Build();
if (task_options.is_direct_call) {
task_manager_->AddPendingTask(task_spec, max_retries);
PinObjectReferences(task_spec, TaskTransportType::DIRECT);
task_manager_->AddPendingTask(GetCallerId(), rpc_address_, task_spec, max_retries);
return direct_task_submitter_->SubmitTask(task_spec);
} else {
PinObjectReferences(task_spec, TaskTransportType::RAYLET);
return local_raylet_client_->SubmitTask(task_spec);
}
}
@ -685,11 +654,10 @@ Status CoreWorker::CreateActor(const RayFunction &function,
*return_actor_id = actor_id;
TaskSpecification task_spec = builder.Build();
if (actor_creation_options.is_direct_call) {
task_manager_->AddPendingTask(task_spec, actor_creation_options.max_reconstructions);
PinObjectReferences(task_spec, TaskTransportType::DIRECT);
task_manager_->AddPendingTask(GetCallerId(), rpc_address_, task_spec,
actor_creation_options.max_reconstructions);
return direct_task_submitter_->SubmitTask(task_spec);
} else {
PinObjectReferences(task_spec, TaskTransportType::RAYLET);
return local_raylet_client_->SubmitTask(task_spec);
}
}
@ -729,8 +697,7 @@ Status CoreWorker::SubmitActorTask(const ActorID &actor_id, const RayFunction &f
Status status;
TaskSpecification task_spec = builder.Build();
if (is_direct_call) {
task_manager_->AddPendingTask(task_spec);
PinObjectReferences(task_spec, TaskTransportType::DIRECT);
task_manager_->AddPendingTask(GetCallerId(), rpc_address_, task_spec);
if (actor_handle->IsDead()) {
auto status = Status::IOError("sent task to dead actor");
task_manager_->PendingTaskFailed(task_spec.TaskId(), rpc::ErrorType::ACTOR_DIED,
@ -739,7 +706,6 @@ Status CoreWorker::SubmitActorTask(const ActorID &actor_id, const RayFunction &f
status = direct_actor_submitter_->SubmitTask(task_spec);
}
} else {
PinObjectReferences(task_spec, TaskTransportType::RAYLET);
RAY_CHECK_OK(local_raylet_client_->SubmitTask(task_spec));
}
return status;
@ -1043,17 +1009,17 @@ void CoreWorker::HandleGetObjectStatus(const rpc::GetObjectStatusRequest &reques
if (task_manager_->IsTaskPending(object_id.TaskId())) {
// Acquire a reference and retry. This prevents the object from being
// evicted out from under us before we can start the get.
AddObjectIDReference(object_id);
AddLocalReference(object_id);
if (task_manager_->IsTaskPending(object_id.TaskId())) {
// The task is pending. Send the reply once the task finishes.
memory_store_->GetAsync(object_id,
[send_reply_callback](std::shared_ptr<RayObject> obj) {
send_reply_callback(Status::OK(), nullptr, nullptr);
});
RemoveObjectIDReference(object_id);
RemoveLocalReference(object_id);
} else {
// We lost the race, the task is done.
RemoveObjectIDReference(object_id);
RemoveLocalReference(object_id);
send_reply_callback(Status::OK(), nullptr, nullptr);
}
} else {


@ -101,17 +101,19 @@ class CoreWorker {
actor_id_ = actor_id;
}
/// Increase the reference count for this object ID.
/// Increase the local reference count for this object ID. Should be called
/// by the language frontend when a new reference is created.
///
/// \param[in] object_id The object ID to increase the reference count for.
void AddObjectIDReference(const ObjectID &object_id) {
void AddLocalReference(const ObjectID &object_id) {
reference_counter_->AddLocalReference(object_id);
}
/// Decrease the reference count for this object ID.
/// Decrease the reference count for this object ID. Should be called
/// by the language frontend when a reference is destroyed.
///
/// \param[in] object_id The object ID to decrease the reference count for.
void RemoveObjectIDReference(const ObjectID &object_id) {
void RemoveLocalReference(const ObjectID &object_id) {
std::vector<ObjectID> deleted;
reference_counter_->RemoveLocalReference(object_id, &deleted);
if (ref_counting_enabled_) {
@ -119,6 +121,12 @@ class CoreWorker {
}
}
/// Returns a map of all ObjectIDs currently in scope with a pair of their
/// (local, submitted_task) reference counts. For debugging purposes.
std::unordered_map<ObjectID, std::pair<size_t, size_t>> GetAllReferenceCounts() const {
return reference_counter_->GetAllReferenceCounts();
}
/// Promote an object to plasma and get its owner information. This should be
/// called when serializing an object ID, and the returned information should
/// be stored with the serialized object ID. For plasma promotion, if the
@ -434,11 +442,6 @@ class CoreWorker {
/// Private methods related to task submission.
///
/// Add task dependencies to the reference counter. This prevents the argument
/// objects from early eviction, and also adds the return object.
void PinObjectReferences(const TaskSpecification &task_spec,
const TaskTransportType transport_type);
/// Give this worker a handle to an actor.
///
/// This handle will remain as long as the current actor or task is
@ -485,17 +488,6 @@ class CoreWorker {
std::vector<std::shared_ptr<RayObject>> *args,
std::vector<ObjectID> *arg_reference_ids);
/// Remove reference counting dependencies of this object ID.
///
/// \param[in] object_id The object whose dependencies should be removed.
void RemoveObjectIDDependencies(const ObjectID &object_id) {
std::vector<ObjectID> deleted;
reference_counter_->RemoveDependencies(object_id, &deleted);
if (ref_counting_enabled_) {
memory_store_->Delete(deleted);
}
}
/// Returns whether the message was sent to the wrong worker. The right error reply
/// is sent automatically. Messages end up on the wrong worker when a worker dies
/// and a new one takes its place with the same address. In this situation, we want
@ -624,9 +616,6 @@ class CoreWorker {
absl::flat_hash_map<ActorID, std::unique_ptr<ActorHandle>> actor_handles_
GUARDED_BY(actor_handles_mutex_);
/// Resolve local and remote dependencies for actor creation.
std::unique_ptr<LocalDependencyResolver> resolver_;
///
/// Fields related to task execution.
///


@ -14,81 +14,76 @@ void ReferenceCounter::AddBorrowedObject(const ObjectID &object_id,
}
}
void ReferenceCounter::AddOwnedObject(
const ObjectID &object_id, const TaskID &owner_id, const rpc::Address &owner_address,
std::shared_ptr<std::vector<ObjectID>> dependencies) {
void ReferenceCounter::AddOwnedObject(const ObjectID &object_id, const TaskID &owner_id,
const rpc::Address &owner_address) {
absl::MutexLock lock(&mutex_);
for (const ObjectID &dependency_id : *dependencies) {
AddLocalReferenceInternal(dependency_id);
}
RAY_CHECK(object_id_refs_.count(object_id) == 0)
<< "Cannot create an object that already exists. ObjectID: " << object_id;
<< "Tried to create an owned object that already exists: " << object_id;
// If the entry doesn't exist, we initialize the direct reference count to zero
// because this corresponds to a submitted task whose return ObjectID will be created
// in the frontend language, incrementing the reference count.
object_id_refs_.emplace(object_id, Reference(owner_id, owner_address, dependencies));
object_id_refs_.emplace(object_id, Reference(owner_id, owner_address));
}
void ReferenceCounter::AddLocalReferenceInternal(const ObjectID &object_id) {
void ReferenceCounter::AddLocalReference(const ObjectID &object_id) {
absl::MutexLock lock(&mutex_);
auto entry = object_id_refs_.find(object_id);
if (entry == object_id_refs_.end()) {
// TODO: Once ref counting is implemented, we should always know how the
// ObjectID was created, so there should always ben an entry.
// ObjectID was created, so there should always be an entry.
entry = object_id_refs_.emplace(object_id, Reference()).first;
}
entry->second.local_ref_count++;
}
void ReferenceCounter::AddLocalReference(const ObjectID &object_id) {
absl::MutexLock lock(&mutex_);
AddLocalReferenceInternal(object_id);
}
void ReferenceCounter::RemoveLocalReference(const ObjectID &object_id,
std::vector<ObjectID> *deleted) {
absl::MutexLock lock(&mutex_);
RemoveReferenceRecursive(object_id, deleted);
}
void ReferenceCounter::RemoveDependencies(const ObjectID &object_id,
std::vector<ObjectID> *deleted) {
absl::MutexLock lock(&mutex_);
auto entry = object_id_refs_.find(object_id);
if (entry == object_id_refs_.end()) {
RAY_LOG(WARNING) << "Tried to remove dependencies for nonexistent object ID: "
<< object_id;
return;
}
if (entry->second.dependencies) {
for (const ObjectID &pending_task_object_id : *entry->second.dependencies) {
RemoveReferenceRecursive(pending_task_object_id, deleted);
}
entry->second.dependencies = nullptr;
}
}
void ReferenceCounter::RemoveReferenceRecursive(const ObjectID &object_id,
std::vector<ObjectID> *deleted) {
auto entry = object_id_refs_.find(object_id);
if (entry == object_id_refs_.end()) {
RAY_LOG(WARNING) << "Tried to decrease ref count for nonexistent object ID: "
<< object_id;
return;
}
if (--entry->second.local_ref_count == 0) {
// If the reference count reached 0, decrease the reference count for each dependency.
if (entry->second.dependencies) {
for (const ObjectID &pending_task_object_id : *entry->second.dependencies) {
RemoveReferenceRecursive(pending_task_object_id, deleted);
}
}
if (--entry->second.local_ref_count == 0 &&
entry->second.submitted_task_ref_count == 0) {
object_id_refs_.erase(entry);
deleted->push_back(object_id);
}
}
void ReferenceCounter::AddSubmittedTaskReferences(
const std::vector<ObjectID> &object_ids) {
absl::MutexLock lock(&mutex_);
for (const ObjectID &object_id : object_ids) {
auto entry = object_id_refs_.find(object_id);
if (entry == object_id_refs_.end()) {
// TODO: Once ref counting is implemented, we should always know how the
// ObjectID was created, so there should always be an entry.
entry = object_id_refs_.emplace(object_id, Reference()).first;
}
entry->second.submitted_task_ref_count++;
}
}
void ReferenceCounter::RemoveSubmittedTaskReferences(
const std::vector<ObjectID> &object_ids, std::vector<ObjectID> *deleted) {
absl::MutexLock lock(&mutex_);
for (const ObjectID &object_id : object_ids) {
auto entry = object_id_refs_.find(object_id);
if (entry == object_id_refs_.end()) {
RAY_LOG(WARNING) << "Tried to decrease ref count for nonexistent object ID: "
<< object_id;
return;
}
if (--entry->second.submitted_task_ref_count == 0 &&
entry->second.local_ref_count == 0) {
object_id_refs_.erase(entry);
deleted->push_back(object_id);
}
}
}
bool ReferenceCounter::GetOwner(const ObjectID &object_id, TaskID *owner_id,
rpc::Address *owner_address) const {
absl::MutexLock lock(&mutex_);
@ -126,6 +121,19 @@ std::unordered_set<ObjectID> ReferenceCounter::GetAllInScopeObjectIDs() const {
return in_scope_object_ids;
}
std::unordered_map<ObjectID, std::pair<size_t, size_t>>
ReferenceCounter::GetAllReferenceCounts() const {
absl::MutexLock lock(&mutex_);
std::unordered_map<ObjectID, std::pair<size_t, size_t>> all_ref_counts;
all_ref_counts.reserve(object_id_refs_.size());
for (auto it : object_id_refs_) {
all_ref_counts.emplace(it.first,
std::pair<size_t, size_t>(it.second.local_ref_count,
it.second.submitted_task_ref_count));
}
return all_ref_counts;
}
void ReferenceCounter::LogDebugString() const {
absl::MutexLock lock(&mutex_);
@ -137,15 +145,9 @@ void ReferenceCounter::LogDebugString() const {
for (const auto &entry : object_id_refs_) {
RAY_LOG(DEBUG) << "\t" << entry.first.Hex();
RAY_LOG(DEBUG) << "\t\treference count: " << entry.second.local_ref_count;
RAY_LOG(DEBUG) << "\t\tdependencies: ";
if (!entry.second.dependencies) {
RAY_LOG(DEBUG) << "\t\t\tNULL";
} else {
for (const ObjectID &pending_task_object_id : *entry.second.dependencies) {
RAY_LOG(DEBUG) << "\t\t\t" << pending_task_object_id.Hex();
}
}
RAY_LOG(DEBUG) << "\t\tlocal refcount: " << entry.second.local_ref_count;
RAY_LOG(DEBUG) << "\t\tsubmitted task refcount: "
<< entry.second.submitted_task_ref_count;
}
}


@ -25,27 +25,32 @@ class ReferenceCounter {
/// \param[in] object_id The object to increment the count for.
void AddLocalReference(const ObjectID &object_id) LOCKS_EXCLUDED(mutex_);
/// Decrease the reference count for the ObjectID by one. If the reference count reaches
/// zero, it will be erased from the map and the reference count for all of its
/// dependencies will be decreased by one.
/// Decrease the local reference count for the ObjectID by one.
///
/// \param[in] object_id The object to decrement the count for.
/// \param[out] deleted List to store objects that hit zero ref count.
void RemoveLocalReference(const ObjectID &object_id, std::vector<ObjectID> *deleted)
LOCKS_EXCLUDED(mutex_);
/// Remove any references to dependencies that this object may have. This does *not*
/// decrease the object's own reference count.
/// Add references for the provided object IDs that correspond to them being
/// dependencies to a submitted task.
///
/// \param[in] object_id The object whose dependencies should be removed.
/// \param[out] deleted List to store objects that hit zero ref count.
void RemoveDependencies(const ObjectID &object_id, std::vector<ObjectID> *deleted)
LOCKS_EXCLUDED(mutex_);
/// \param[in] object_ids The object IDs to add references for.
void AddSubmittedTaskReferences(const std::vector<ObjectID> &object_ids);
/// Remove references for the provided object IDs that correspond to them being
/// dependencies to a submitted task. This should be called when dependencies
/// are inlined, or when the task finishes for plasma dependencies.
///
/// \param[in] object_ids The object IDs to remove references for.
/// \param[out] deleted The object IDs whose reference counts reached zero.
void RemoveSubmittedTaskReferences(const std::vector<ObjectID> &object_ids,
std::vector<ObjectID> *deleted);
/// Add an object that we own. The object may depend on other objects.
/// Dependencies for each ObjectID must be set at most once. The direct
/// reference count for the ObjectID is set to zero and the reference count
/// for each dependency is incremented.
/// Dependencies for each ObjectID must be set at most once. The local
/// reference count for the ObjectID is set to zero, which assumes that an
/// ObjectID for it will be created in the language frontend after this call.
///
/// TODO(swang): We could avoid copying the owner_id and owner_address since
/// we are the owner, but it is easier to store a copy for now, since the
@ -57,9 +62,7 @@ class ReferenceCounter {
/// \param[in] owner_address The address of the object's owner.
/// \param[in] dependencies The objects that the object depends on.
void AddOwnedObject(const ObjectID &object_id, const TaskID &owner_id,
const rpc::Address &owner_address,
std::shared_ptr<std::vector<ObjectID>> dependencies)
LOCKS_EXCLUDED(mutex_);
const rpc::Address &owner_address) LOCKS_EXCLUDED(mutex_);
/// Add an object that we are borrowing.
///
@ -82,6 +85,11 @@ class ReferenceCounter {
/// Returns a set of all ObjectIDs currently in scope (i.e., nonzero reference count).
std::unordered_set<ObjectID> GetAllInScopeObjectIDs() const LOCKS_EXCLUDED(mutex_);
/// Returns a map of all ObjectIDs currently in scope with a pair of their
/// (local, submitted_task) reference counts. For debugging purposes.
std::unordered_map<ObjectID, std::pair<size_t, size_t>> GetAllReferenceCounts() const
LOCKS_EXCLUDED(mutex_);
/// Dumps information about all currently tracked references to RAY_LOG(DEBUG).
void LogDebugString() const LOCKS_EXCLUDED(mutex_);
@ -91,22 +99,12 @@ class ReferenceCounter {
/// Constructor for a reference whose origin is unknown.
Reference() : owned_by_us(false) {}
/// Constructor for a reference that we created.
Reference(const TaskID &owner_id, const rpc::Address &owner_address,
std::shared_ptr<std::vector<ObjectID>> deps)
: dependencies(std::move(deps)),
owned_by_us(true),
owner({owner_id, owner_address}) {}
/// Constructor for a reference that was given to us.
Reference(const TaskID &owner_id, const rpc::Address &owner_address)
: owned_by_us(false), owner({owner_id, owner_address}) {}
: owned_by_us(true), owner({owner_id, owner_address}) {}
/// The local ref count for the ObjectID in the language frontend.
size_t local_ref_count = 0;
/// The objects that this object depends on. Tracked only by the owner of
/// the object. Dependencies are stored as shared_ptrs because the same set
/// of dependencies can be shared among multiple entries. For example, when
/// a task has multiple return values, the entry for each return ObjectID
/// depends on all task dependencies.
std::shared_ptr<std::vector<ObjectID>> dependencies;
/// The ref count for submitted tasks that depend on the ObjectID.
size_t submitted_task_ref_count = 0;
/// Whether we own the object. If we own the object, then we are
/// responsible for tracking the state of the task that creates the object
/// (see task_manager.h).


@ -16,173 +16,63 @@ class ReferenceCountTest : public ::testing::Test {
virtual void TearDown() {}
};
// Tests basic incrementing/decrementing of direct reference counts. An entry should only
// be removed once its reference count reaches zero.
// Tests basic incrementing/decrementing of direct/submitted task reference counts. An
// entry should only be removed once both of its reference counts reach zero.
TEST_F(ReferenceCountTest, TestBasic) {
std::vector<ObjectID> out;
ObjectID id = ObjectID::FromRandom();
rc->AddLocalReference(id);
ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
rc->AddLocalReference(id);
ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
rc->AddLocalReference(id);
ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
rc->RemoveLocalReference(id, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
ASSERT_EQ(out.size(), 0);
rc->RemoveLocalReference(id, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
ASSERT_EQ(out.size(), 0);
rc->RemoveLocalReference(id, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 0);
ASSERT_EQ(out.size(), 1);
}
// Tests the basic logic for dependencies - when an ObjectID with dependencies
// goes out of scope (i.e., reference count reaches zero), all of its dependencies
// should have their reference count decremented and be removed if it reaches zero.
TEST_F(ReferenceCountTest, TestDependencies) {
std::vector<ObjectID> out;
ObjectID id1 = ObjectID::FromRandom();
ObjectID id2 = ObjectID::FromRandom();
ObjectID id3 = ObjectID::FromRandom();
std::shared_ptr<std::vector<ObjectID>> deps = std::make_shared<std::vector<ObjectID>>();
deps->push_back(id2);
deps->push_back(id3);
rc->AddOwnedObject(id1, TaskID::Nil(), rpc::Address(), deps);
// Local references.
rc->AddLocalReference(id1);
rc->AddLocalReference(id1);
rc->AddLocalReference(id3);
ASSERT_EQ(rc->NumObjectIDsInScope(), 3);
rc->RemoveLocalReference(id1, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 3);
ASSERT_EQ(out.size(), 0);
rc->RemoveLocalReference(id1, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
ASSERT_EQ(out.size(), 2);
rc->RemoveLocalReference(id3, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 0);
ASSERT_EQ(out.size(), 3);
}
// Tests the case where two entries share the same set of dependencies. When one
// entry goes out of scope, it should decrease the reference count for the dependencies
// but they should still be nonzero until the second entry goes out of scope and all
// direct dependencies to the dependencies are removed.
TEST_F(ReferenceCountTest, TestSharedDependencies) {
std::vector<ObjectID> out;
ObjectID id1 = ObjectID::FromRandom();
ObjectID id2 = ObjectID::FromRandom();
ObjectID id3 = ObjectID::FromRandom();
ObjectID id4 = ObjectID::FromRandom();
std::shared_ptr<std::vector<ObjectID>> deps = std::make_shared<std::vector<ObjectID>>();
deps->push_back(id3);
deps->push_back(id4);
rc->AddOwnedObject(id1, TaskID::Nil(), rpc::Address(), deps);
rc->AddOwnedObject(id2, TaskID::Nil(), rpc::Address(), deps);
rc->AddLocalReference(id1);
rc->AddLocalReference(id2);
rc->AddLocalReference(id4);
ASSERT_EQ(rc->NumObjectIDsInScope(), 4);
rc->RemoveLocalReference(id1, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 3);
ASSERT_EQ(out.size(), 1);
rc->RemoveLocalReference(id2, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
ASSERT_EQ(out.size(), 3);
rc->RemoveLocalReference(id4, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 0);
ASSERT_EQ(out.size(), 4);
}
// Tests the case when an entry has a dependency that itself has a
// dependency. In this case, when the first entry goes out of scope
// it should decrease the reference count for its dependency, causing
// that entry to go out of scope and decrease its dependencies' reference counts.
TEST_F(ReferenceCountTest, TestRecursiveDependencies) {
std::vector<ObjectID> out;
ObjectID id1 = ObjectID::FromRandom();
ObjectID id2 = ObjectID::FromRandom();
ObjectID id3 = ObjectID::FromRandom();
ObjectID id4 = ObjectID::FromRandom();
std::shared_ptr<std::vector<ObjectID>> deps2 =
std::make_shared<std::vector<ObjectID>>();
deps2->push_back(id3);
deps2->push_back(id4);
rc->AddOwnedObject(id2, TaskID::Nil(), rpc::Address(), deps2);
std::shared_ptr<std::vector<ObjectID>> deps1 =
std::make_shared<std::vector<ObjectID>>();
deps1->push_back(id2);
rc->AddOwnedObject(id1, TaskID::Nil(), rpc::Address(), deps1);
rc->AddLocalReference(id1);
rc->AddLocalReference(id2);
rc->AddLocalReference(id4);
ASSERT_EQ(rc->NumObjectIDsInScope(), 4);
rc->RemoveLocalReference(id2, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 4);
ASSERT_EQ(out.size(), 0);
rc->RemoveLocalReference(id1, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
ASSERT_EQ(out.size(), 3);
rc->RemoveLocalReference(id4, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 0);
ASSERT_EQ(out.size(), 4);
}
TEST_F(ReferenceCountTest, TestRemoveDependenciesOnly) {
std::vector<ObjectID> out;
ObjectID id1 = ObjectID::FromRandom();
ObjectID id2 = ObjectID::FromRandom();
ObjectID id3 = ObjectID::FromRandom();
ObjectID id4 = ObjectID::FromRandom();
std::shared_ptr<std::vector<ObjectID>> deps2 =
std::make_shared<std::vector<ObjectID>>();
deps2->push_back(id3);
deps2->push_back(id4);
rc->AddOwnedObject(id2, TaskID::Nil(), rpc::Address(), deps2);
std::shared_ptr<std::vector<ObjectID>> deps1 =
std::make_shared<std::vector<ObjectID>>();
deps1->push_back(id2);
rc->AddOwnedObject(id1, TaskID::Nil(), rpc::Address(), deps1);
rc->AddLocalReference(id1);
rc->AddLocalReference(id2);
rc->AddLocalReference(id4);
ASSERT_EQ(rc->NumObjectIDsInScope(), 4);
rc->RemoveDependencies(id2, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 3);
ASSERT_EQ(out.size(), 1);
rc->RemoveDependencies(id1, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 3);
ASSERT_EQ(out.size(), 1);
ASSERT_EQ(rc->NumObjectIDsInScope(), 2);
rc->RemoveLocalReference(id1, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 2);
ASSERT_EQ(out.size(), 2);
ASSERT_EQ(out.size(), 0);
rc->RemoveLocalReference(id2, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
ASSERT_EQ(out.size(), 3);
rc->RemoveLocalReference(id4, &out);
ASSERT_EQ(out.size(), 1);
rc->RemoveLocalReference(id1, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 0);
ASSERT_EQ(out.size(), 4);
ASSERT_EQ(out.size(), 2);
out.clear();
// Submitted task references.
rc->AddSubmittedTaskReferences({id1});
rc->AddSubmittedTaskReferences({id1, id2});
ASSERT_EQ(rc->NumObjectIDsInScope(), 2);
rc->RemoveSubmittedTaskReferences({id1}, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 2);
ASSERT_EQ(out.size(), 0);
rc->RemoveSubmittedTaskReferences({id2}, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
ASSERT_EQ(out.size(), 1);
rc->RemoveSubmittedTaskReferences({id1}, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 0);
ASSERT_EQ(out.size(), 2);
out.clear();
// Local & submitted task references.
rc->AddLocalReference(id1);
rc->AddSubmittedTaskReferences({id1, id2});
rc->AddLocalReference(id2);
ASSERT_EQ(rc->NumObjectIDsInScope(), 2);
rc->RemoveLocalReference(id1, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 2);
ASSERT_EQ(out.size(), 0);
rc->RemoveSubmittedTaskReferences({id2}, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 2);
ASSERT_EQ(out.size(), 0);
rc->RemoveSubmittedTaskReferences({id1}, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
ASSERT_EQ(out.size(), 1);
rc->RemoveLocalReference(id2, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 0);
ASSERT_EQ(out.size(), 2);
out.clear();
}
// Tests that we can get the owner address correctly for objects that we own,
@ -193,8 +83,7 @@ TEST_F(ReferenceCountTest, TestOwnerAddress) {
TaskID task_id = TaskID::ForFakeTask();
rpc::Address address;
address.set_ip_address("1234");
auto deps = std::make_shared<std::vector<ObjectID>>();
rc->AddOwnedObject(object_id, task_id, address, deps);
rc->AddOwnedObject(object_id, task_id, address);
TaskID added_id;
rpc::Address added_address;
@ -205,7 +94,7 @@ TEST_F(ReferenceCountTest, TestOwnerAddress) {
auto object_id2 = ObjectID::FromRandom();
task_id = TaskID::ForFakeTask();
address.set_ip_address("5678");
rc->AddOwnedObject(object_id2, task_id, address, deps);
rc->AddOwnedObject(object_id2, task_id, address);
ASSERT_TRUE(rc->GetOwner(object_id2, &added_id, &added_address));
ASSERT_EQ(task_id, added_id);
ASSERT_EQ(address.ip_address(), added_address.ip_address());


@ -10,11 +10,34 @@ const int64_t kTaskFailureThrottlingThreshold = 50;
// Throttle task failure logs to once this interval.
const int64_t kTaskFailureLoggingFrequencyMillis = 5000;
void TaskManager::AddPendingTask(const TaskSpecification &spec, int max_retries) {
void TaskManager::AddPendingTask(const TaskID &caller_id,
const rpc::Address &caller_address,
const TaskSpecification &spec, int max_retries) {
RAY_LOG(DEBUG) << "Adding pending task " << spec.TaskId();
absl::MutexLock lock(&mu_);
std::pair<TaskSpecification, int> entry = {spec, max_retries};
RAY_CHECK(pending_tasks_.emplace(spec.TaskId(), std::move(entry)).second);
// Add references for the dependencies to the task.
std::vector<ObjectID> task_deps;
for (size_t i = 0; i < spec.NumArgs(); i++) {
if (spec.ArgByRef(i)) {
for (size_t j = 0; j < spec.ArgIdCount(i); j++) {
task_deps.push_back(spec.ArgId(i, j));
}
}
}
reference_counter_->AddSubmittedTaskReferences(task_deps);
// Add new owned objects for the return values of the task.
size_t num_returns = spec.NumReturns();
if (spec.IsActorCreationTask() || spec.IsActorTask()) {
num_returns--;
}
for (size_t i = 0; i < num_returns; i++) {
reference_counter_->AddOwnedObject(spec.ReturnId(i, TaskTransportType::DIRECT),
caller_id, caller_address);
}
}
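In frontend terms, a single submission now has both effects at once. A sketch, assuming tasks go through the direct call path (per the commit title) and reusing the hypothetical consume task:

import time
import numpy as np
import ray

@ray.remote
def consume(dep):
    time.sleep(5)  # keep the task pending so the submitted count stays visible

ray.init()
dep = ray.put(np.zeros(10 * 1024 * 1024, dtype=np.uint8))
result = consume.remote(dep)
counts = ray.worker.global_worker.core_worker.get_all_reference_counts()
# While the task is pending, AddPendingTask has registered both:
assert counts[dep] == {"local": 1, "submitted": 1}      # submitted-task ref on the argument
assert counts[result] == {"local": 1, "submitted": 0}   # newly owned return object
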
void TaskManager::DrainAndShutdown(std::function<void()> shutdown) {
@ -48,6 +71,8 @@ void TaskManager::CompletePendingTask(const TaskID &task_id,
pending_tasks_.erase(it);
}
RemovePlasmaSubmittedTaskReferences(spec);
for (int i = 0; i < reply.return_objects_size(); i++) {
const auto &return_object = reply.return_objects(i);
ObjectID object_id = ObjectID::FromBinary(return_object.object_id());
@ -134,6 +159,7 @@ void TaskManager::PendingTaskFailed(const TaskID &task_id, rpc::ErrorType error_
}
}
}
RemovePlasmaSubmittedTaskReferences(spec);
MarkPendingTaskFailed(task_id, spec, error_type);
}
@ -148,6 +174,28 @@ void TaskManager::ShutdownIfNeeded() {
}
}
void TaskManager::RemoveSubmittedTaskReferences(const std::vector<ObjectID> &object_ids) {
std::vector<ObjectID> deleted;
reference_counter_->RemoveSubmittedTaskReferences(object_ids, &deleted);
in_memory_store_->Delete(deleted);
}
void TaskManager::OnTaskDependenciesInlined(const std::vector<ObjectID> &object_ids) {
RemoveSubmittedTaskReferences(object_ids);
}
void TaskManager::RemovePlasmaSubmittedTaskReferences(TaskSpecification &spec) {
std::vector<ObjectID> plasma_dependencies;
for (size_t i = 0; i < spec.NumArgs(); i++) {
auto count = spec.ArgIdCount(i);
if (count > 0) {
const auto &id = spec.ArgId(i, 0);
plasma_dependencies.push_back(id);
}
}
RemoveSubmittedTaskReferences(plasma_dependencies);
}
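Timing differs between the two dependency kinds. A sketch of the inlined case, following the comments in test_dependency_refcounts; small_value and consume are hypothetical:

import time
import ray

@ray.remote
def small_value():
    return 1  # small result: stored in memory, eligible for inlining

@ray.remote
def consume(dep):
    time.sleep(5)

ray.init()
dep = small_value.remote()    # dep: (local=1, submitted=0) once created
result = consume.remote(dep)  # dep: (1, 1) only until its value is inlined
# As soon as the resolver copies dep's value into the task spec,
# OnTaskDependenciesInlined drops the submitted count: dep -> (1, 0),
# even though consume may still be running. A plasma-resident dependency
# instead keeps its submitted count until the task finishes.
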
void TaskManager::MarkPendingTaskFailed(const TaskID &task_id,
const TaskSpecification &spec,
rpc::ErrorType error_type) {


@ -21,6 +21,8 @@ class TaskFinisherInterface {
virtual void PendingTaskFailed(const TaskID &task_id, rpc::ErrorType error_type,
Status *status = nullptr) = 0;
virtual void OnTaskDependenciesInlined(const std::vector<ObjectID> &object_ids) = 0;
virtual ~TaskFinisherInterface() {}
};
@ -29,19 +31,24 @@ using RetryTaskCallback = std::function<void(const TaskSpecification &spec)>;
class TaskManager : public TaskFinisherInterface {
public:
TaskManager(std::shared_ptr<CoreWorkerMemoryStore> in_memory_store,
std::shared_ptr<ReferenceCounter> reference_counter,
std::shared_ptr<ActorManagerInterface> actor_manager,
RetryTaskCallback retry_task_callback)
: in_memory_store_(in_memory_store),
reference_counter_(reference_counter),
actor_manager_(actor_manager),
retry_task_callback_(retry_task_callback) {}
/// Add a task that is pending execution.
///
/// \param[in] caller_id The TaskID of the calling task.
/// \param[in] caller_address The rpc address of the calling task.
/// \param[in] spec The spec of the pending task.
/// \param[in] max_retries Number of times this task may be retried
/// on failure.
/// \return Void.
void AddPendingTask(const TaskSpecification &spec, int max_retries = 0);
void AddPendingTask(const TaskID &caller_id, const rpc::Address &caller_address,
const TaskSpecification &spec, int max_retries = 0);
/// Wait for all pending tasks to finish, and then shutdown.
///
@ -72,6 +79,8 @@ class TaskManager : public TaskFinisherInterface {
void PendingTaskFailed(const TaskID &task_id, rpc::ErrorType error_type,
Status *status = nullptr) override;
void OnTaskDependenciesInlined(const std::vector<ObjectID> &object_id) override;
/// Return the spec for a pending task.
TaskSpecification GetTaskSpec(const TaskID &task_id) const;
@ -81,12 +90,25 @@ class TaskManager : public TaskFinisherInterface {
void MarkPendingTaskFailed(const TaskID &task_id, const TaskSpecification &spec,
rpc::ErrorType error_type) LOCKS_EXCLUDED(mu_);
/// Remove submitted task references in the reference counter for the object IDs.
/// If their reference counts reach zero, they are deleted from the in-memory store.
void RemoveSubmittedTaskReferences(const std::vector<ObjectID> &object_ids);
/// Helper function to call RemoveSubmittedTaskReferences on the plasma dependencies
/// of the given task spec.
void RemovePlasmaSubmittedTaskReferences(TaskSpecification &spec);
/// Shutdown if all tasks are finished and shutdown is scheduled.
void ShutdownIfNeeded() LOCKS_EXCLUDED(mu_);
/// Used to store task results.
std::shared_ptr<CoreWorkerMemoryStore> in_memory_store_;
/// Used for reference counting objects.
/// The task manager is responsible for managing all references related to
/// submitted tasks (dependencies and return objects).
std::shared_ptr<ReferenceCounter> reference_counter_;
// Interface for publishing actor creation.
std::shared_ptr<ActorManagerInterface> actor_manager_;


@ -1,6 +1,5 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "ray/common/task/task_spec.h"
#include "ray/core_worker/store_provider/memory_store/memory_store.h"
#include "ray/core_worker/transport/direct_task_transport.h"
@ -45,6 +44,8 @@ class MockTaskFinisher : public TaskFinisherInterface {
const rpc::Address *addr));
MOCK_METHOD3(PendingTaskFailed,
void(const TaskID &task_id, rpc::ErrorType error_type, Status *status));
MOCK_METHOD1(OnTaskDependenciesInlined, void(const std::vector<ObjectID> &object_ids));
};
TaskSpecification CreateActorTaskHelper(ActorID actor_id, int64_t counter) {


@ -55,8 +55,13 @@ class MockTaskFinisher : public TaskFinisherInterface {
num_tasks_failed++;
}
void OnTaskDependenciesInlined(const std::vector<ObjectID> &object_ids) override {
num_inlined += object_ids.size();
}
int num_tasks_complete = 0;
int num_tasks_failed = 0;
int num_inlined = 0;
};
class MockRayletClient : public WorkerLeaseInterface {
@ -136,17 +141,20 @@ TEST(TestMemoryStore, TestPromoteToPlasma) {
TEST(LocalDependencyResolverTest, TestNoDependencies) {
auto store = std::make_shared<CoreWorkerMemoryStore>();
LocalDependencyResolver resolver(store);
auto task_finisher = std::make_shared<MockTaskFinisher>();
LocalDependencyResolver resolver(store, task_finisher);
TaskSpecification task;
bool ok = false;
resolver.ResolveDependencies(task, [&ok]() { ok = true; });
ASSERT_TRUE(ok);
ASSERT_EQ(task_finisher->num_inlined, 0);
}
TEST(LocalDependencyResolverTest, TestIgnorePlasmaDependencies) {
auto store = std::make_shared<CoreWorkerMemoryStore>();
LocalDependencyResolver resolver(store);
ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::RAYLET);
auto task_finisher = std::make_shared<MockTaskFinisher>();
LocalDependencyResolver resolver(store, task_finisher);
ObjectID obj1 = ObjectID::FromRandom();
TaskSpecification task;
task.GetMutableMessage().add_args()->add_object_ids(obj1.Binary());
bool ok = false;
@ -154,11 +162,13 @@ TEST(LocalDependencyResolverTest, TestIgnorePlasmaDependencies) {
// We ignore and don't block on plasma dependencies.
ASSERT_TRUE(ok);
ASSERT_EQ(resolver.NumPendingTasks(), 0);
ASSERT_EQ(task_finisher->num_inlined, 0);
}
TEST(LocalDependencyResolverTest, TestHandlePlasmaPromotion) {
auto store = std::make_shared<CoreWorkerMemoryStore>();
LocalDependencyResolver resolver(store);
auto task_finisher = std::make_shared<MockTaskFinisher>();
LocalDependencyResolver resolver(store, task_finisher);
ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
std::string meta = std::to_string(static_cast<int>(rpc::ErrorType::OBJECT_IN_PLASMA));
auto metadata = const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(meta.data()));
@ -175,11 +185,13 @@ TEST(LocalDependencyResolverTest, TestHandlePlasmaPromotion) {
// Checks that the object id is still a direct call id.
ASSERT_TRUE(task.ArgId(0, 0).IsDirectCallType());
ASSERT_EQ(resolver.NumPendingTasks(), 0);
ASSERT_EQ(task_finisher->num_inlined, 0);
}
TEST(LocalDependencyResolverTest, TestInlineLocalDependencies) {
auto store = std::make_shared<CoreWorkerMemoryStore>();
LocalDependencyResolver resolver(store);
auto task_finisher = std::make_shared<MockTaskFinisher>();
LocalDependencyResolver resolver(store, task_finisher);
ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
ObjectID obj2 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
auto data = GenerateRandomObject();
@ -198,11 +210,13 @@ TEST(LocalDependencyResolverTest, TestInlineLocalDependencies) {
ASSERT_NE(task.ArgData(0), nullptr);
ASSERT_NE(task.ArgData(1), nullptr);
ASSERT_EQ(resolver.NumPendingTasks(), 0);
ASSERT_EQ(task_finisher->num_inlined, 2);
}
TEST(LocalDependencyResolverTest, TestInlinePendingDependencies) {
auto store = std::make_shared<CoreWorkerMemoryStore>();
LocalDependencyResolver resolver(store);
auto task_finisher = std::make_shared<MockTaskFinisher>();
LocalDependencyResolver resolver(store, task_finisher);
ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
ObjectID obj2 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
auto data = GenerateRandomObject();
@ -223,6 +237,7 @@ TEST(LocalDependencyResolverTest, TestInlinePendingDependencies) {
ASSERT_NE(task.ArgData(0), nullptr);
ASSERT_NE(task.ArgData(1), nullptr);
ASSERT_EQ(resolver.NumPendingTasks(), 0);
ASSERT_EQ(task_finisher->num_inlined, 2);
}
TaskSpecification BuildTaskSpec(const std::unordered_map<std::string, double> &resources,


@ -1,17 +1,22 @@
#include "gtest/gtest.h"
#include "ray/core_worker/task_manager.h"
#include "gtest/gtest.h"
#include "ray/common/task/task_spec.h"
#include "ray/core_worker/actor_manager.h"
#include "ray/core_worker/reference_count.h"
#include "ray/core_worker/store_provider/memory_store/memory_store.h"
#include "ray/core_worker/task_manager.h"
#include "ray/util/test_util.h"
namespace ray {
TaskSpecification CreateTaskHelper(uint64_t num_returns) {
TaskSpecification CreateTaskHelper(uint64_t num_returns,
std::vector<ObjectID> dependencies) {
TaskSpecification task;
task.GetMutableMessage().set_task_id(TaskID::ForFakeTask().Binary());
task.GetMutableMessage().set_num_returns(num_returns);
for (const ObjectID &dep : dependencies) {
task.GetMutableMessage().add_args()->add_object_ids(dep.Binary());
}
return task;
}
@ -33,23 +38,31 @@ class TaskManagerTest : public ::testing::Test {
public:
TaskManagerTest()
: store_(std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore())),
reference_counter_(std::shared_ptr<ReferenceCounter>(new ReferenceCounter())),
actor_manager_(std::shared_ptr<ActorManagerInterface>(new MockActorManager())),
manager_(store_, actor_manager_, [this](const TaskSpecification &spec) {
num_retries_++;
return Status::OK();
}) {}
manager_(store_, reference_counter_, actor_manager_,
[this](const TaskSpecification &spec) {
num_retries_++;
return Status::OK();
}) {}
std::shared_ptr<CoreWorkerMemoryStore> store_;
std::shared_ptr<ReferenceCounter> reference_counter_;
std::shared_ptr<ActorManagerInterface> actor_manager_;
TaskManager manager_;
int num_retries_ = 0;
};
TEST_F(TaskManagerTest, TestTaskSuccess) {
auto spec = CreateTaskHelper(1);
TaskID caller_id = TaskID::Nil();
rpc::Address caller_address;
ObjectID dep1 = ObjectID::FromRandom();
ObjectID dep2 = ObjectID::FromRandom();
auto spec = CreateTaskHelper(1, {dep1, dep2});
ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId()));
manager_.AddPendingTask(spec);
manager_.AddPendingTask(caller_id, caller_address, spec);
ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId()));
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 3);
auto return_id = spec.ReturnId(0, TaskTransportType::DIRECT);
WorkerContext ctx(WorkerType::WORKER, JobID::FromInt(0));
@ -60,6 +73,8 @@ TEST_F(TaskManagerTest, TestTaskSuccess) {
return_object->set_data(data->Data(), data->Size());
manager_.CompletePendingTask(spec.TaskId(), reply, nullptr);
ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId()));
// Only the return object reference should remain.
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 1);
std::vector<std::shared_ptr<RayObject>> results;
RAY_CHECK_OK(store_->Get({return_id}, 1, -1, ctx, false, &results));
@ -69,19 +84,33 @@ TEST_F(TaskManagerTest, TestTaskSuccess) {
return_object->data().size()),
0);
ASSERT_EQ(num_retries_, 0);
std::vector<ObjectID> removed;
reference_counter_->AddLocalReference(return_id);
reference_counter_->RemoveLocalReference(return_id, &removed);
ASSERT_EQ(removed[0], return_id);
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 0);
}
TEST_F(TaskManagerTest, TestTaskFailure) {
auto spec = CreateTaskHelper(1);
TaskID caller_id = TaskID::Nil();
rpc::Address caller_address;
ObjectID dep1 = ObjectID::FromRandom();
ObjectID dep2 = ObjectID::FromRandom();
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 0);
auto spec = CreateTaskHelper(1, {dep1, dep2});
ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId()));
manager_.AddPendingTask(spec);
manager_.AddPendingTask(caller_id, caller_address, spec);
ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId()));
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 3);
auto return_id = spec.ReturnId(0, TaskTransportType::DIRECT);
WorkerContext ctx(WorkerType::WORKER, JobID::FromInt(0));
auto error = rpc::ErrorType::WORKER_DIED;
manager_.PendingTaskFailed(spec.TaskId(), error);
ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId()));
// Only the return object reference should remain.
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 1);
std::vector<std::shared_ptr<RayObject>> results;
RAY_CHECK_OK(store_->Get({return_id}, 1, -1, ctx, false, &results));
@ -90,14 +119,26 @@ TEST_F(TaskManagerTest, TestTaskFailure) {
ASSERT_TRUE(results[0]->IsException(&stored_error));
ASSERT_EQ(stored_error, error);
ASSERT_EQ(num_retries_, 0);
std::vector<ObjectID> removed;
reference_counter_->AddLocalReference(return_id);
reference_counter_->RemoveLocalReference(return_id, &removed);
ASSERT_EQ(removed[0], return_id);
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 0);
}
TEST_F(TaskManagerTest, TestTaskRetry) {
auto spec = CreateTaskHelper(1);
TaskID caller_id = TaskID::Nil();
rpc::Address caller_address;
ObjectID dep1 = ObjectID::FromRandom();
ObjectID dep2 = ObjectID::FromRandom();
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 0);
auto spec = CreateTaskHelper(1, {dep1, dep2});
ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId()));
int num_retries = 3;
manager_.AddPendingTask(spec, num_retries);
manager_.AddPendingTask(caller_id, caller_address, spec, num_retries);
ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId()));
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 3);
auto return_id = spec.ReturnId(0, TaskTransportType::DIRECT);
WorkerContext ctx(WorkerType::WORKER, JobID::FromInt(0));
@ -105,6 +146,7 @@ TEST_F(TaskManagerTest, TestTaskRetry) {
for (int i = 0; i < num_retries; i++) {
manager_.PendingTaskFailed(spec.TaskId(), error);
ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId()));
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 3);
std::vector<std::shared_ptr<RayObject>> results;
ASSERT_FALSE(store_->Get({return_id}, 1, 0, ctx, false, &results).ok());
ASSERT_EQ(num_retries_, i + 1);
@ -112,6 +154,8 @@ TEST_F(TaskManagerTest, TestTaskRetry) {
manager_.PendingTaskFailed(spec.TaskId(), error);
ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId()));
// Only the return object reference should remain.
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 1);
std::vector<std::shared_ptr<RayObject>> results;
RAY_CHECK_OK(store_->Get({return_id}, 1, -1, ctx, false, &results));
@ -119,6 +163,12 @@ TEST_F(TaskManagerTest, TestTaskRetry) {
rpc::ErrorType stored_error;
ASSERT_TRUE(results[0]->IsException(&stored_error));
ASSERT_EQ(stored_error, error);
std::vector<ObjectID> removed;
reference_counter_->AddLocalReference(return_id);
reference_counter_->RemoveLocalReference(return_id, &removed);
ASSERT_EQ(removed[0], return_id);
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 0);
}
} // namespace ray


@ -18,7 +18,7 @@ struct TaskState {
void InlineDependencies(
absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> dependencies,
TaskSpecification &task) {
TaskSpecification &task, std::vector<ObjectID> *inlined) {
auto &msg = task.GetMutableMessage();
size_t found = 0;
for (size_t i = 0; i < task.NumArgs(); i++) {
@ -43,6 +43,7 @@ void InlineDependencies(
const auto &metadata = it->second->GetMetadata();
mutable_arg->set_metadata(metadata->Data(), metadata->Size());
}
inlined->push_back(id);
}
found++;
} else {
@ -83,15 +84,19 @@ void LocalDependencyResolver::ResolveDependencies(TaskSpecification &task,
obj_id, [this, state, obj_id, on_complete](std::shared_ptr<RayObject> obj) {
RAY_CHECK(obj != nullptr);
bool complete = false;
std::vector<ObjectID> inlined;
{
absl::MutexLock lock(&mu_);
state->local_dependencies[obj_id] = std::move(obj);
if (--state->dependencies_remaining == 0) {
InlineDependencies(state->local_dependencies, state->task);
InlineDependencies(state->local_dependencies, state->task, &inlined);
complete = true;
num_pending_ -= 1;
}
}
if (inlined.size() > 0) {
task_finisher_->OnTaskDependenciesInlined(inlined);
}
if (complete) {
on_complete();
}


@ -6,14 +6,16 @@
#include "ray/common/id.h"
#include "ray/common/task/task_spec.h"
#include "ray/core_worker/store_provider/memory_store/memory_store.h"
#include "ray/core_worker/task_manager.h"
namespace ray {
// This class is thread-safe.
class LocalDependencyResolver {
public:
LocalDependencyResolver(std::shared_ptr<CoreWorkerMemoryStore> store)
: in_memory_store_(store), num_pending_(0) {}
LocalDependencyResolver(std::shared_ptr<CoreWorkerMemoryStore> store,
std::shared_ptr<TaskFinisherInterface> task_finisher)
: in_memory_store_(store), task_finisher_(task_finisher), num_pending_(0) {}
/// Resolve all local and remote dependencies for the task, calling the specified
/// callback when done. Direct call ids in the task specification will be resolved
@ -33,6 +35,9 @@ class LocalDependencyResolver {
/// The in-memory store.
std::shared_ptr<CoreWorkerMemoryStore> in_memory_store_;
/// Used to complete tasks.
std::shared_ptr<TaskFinisherInterface> task_finisher_;
/// Number of tasks pending dependency resolution.
std::atomic<int> num_pending_;


@ -8,9 +8,9 @@
#include <queue>
#include <set>
#include <utility>
#include "absl/container/flat_hash_map.h"
#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/synchronization/mutex.h"
#include "ray/common/id.h"
#include "ray/common/ray_object.h"
@ -39,7 +39,7 @@ class CoreWorkerDirectActorTaskSubmitter {
std::shared_ptr<CoreWorkerMemoryStore> store,
std::shared_ptr<TaskFinisherInterface> task_finisher)
: client_factory_(client_factory),
resolver_(store),
resolver_(store, task_finisher),
task_finisher_(task_finisher) {}
/// Submit a task to an actor for execution.


@ -5,7 +5,6 @@
#include "absl/base/thread_annotations.h"
#include "absl/synchronization/mutex.h"
#include "ray/common/id.h"
#include "ray/common/ray_object.h"
#include "ray/core_worker/context.h"
@ -41,7 +40,7 @@ class CoreWorkerDirectTaskSubmitter {
: local_lease_client_(lease_client),
client_factory_(client_factory),
lease_client_factory_(lease_client_factory),
resolver_(store),
resolver_(store, task_finisher),
task_finisher_(task_finisher),
local_raylet_id_(local_raylet_id),
lease_timeout_ms_(lease_timeout_ms) {}