Reference counting for direct call submitted tasks (#6514)

Co-authored-by: Zhijun Fu <37800433+zhijunfu@users.noreply.github.com>
This commit is contained in:
Edward Oakes 2019-12-20 17:06:33 -08:00 committed by GitHub
parent b0b6b56bb7
commit e50aa99be1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 512 additions and 344 deletions

View file

@ -1073,16 +1073,12 @@ cdef class CoreWorker:
return output return output
def add_object_id_reference(self, ObjectID object_id): def add_object_id_reference(self, ObjectID object_id):
cdef:
CObjectID c_object_id = object_id.native()
# Note: faster to not release GIL for short-running op. # Note: faster to not release GIL for short-running op.
self.core_worker.get().AddObjectIDReference(c_object_id) self.core_worker.get().AddLocalReference(object_id.native())
def remove_object_id_reference(self, ObjectID object_id): def remove_object_id_reference(self, ObjectID object_id):
cdef:
CObjectID c_object_id = object_id.native()
# Note: faster to not release GIL for short-running op. # Note: faster to not release GIL for short-running op.
self.core_worker.get().RemoveObjectIDReference(c_object_id) self.core_worker.get().RemoveLocalReference(object_id.native())
def serialize_and_promote_object_id(self, ObjectID object_id): def serialize_and_promote_object_id(self, ObjectID object_id):
cdef: cdef:
@ -1174,6 +1170,24 @@ cdef class CoreWorker:
def current_actor_is_asyncio(self): def current_actor_is_asyncio(self):
return self.core_worker.get().GetWorkerContext().CurrentActorIsAsync() return self.core_worker.get().GetWorkerContext().CurrentActorIsAsync()
def get_all_reference_counts(self):
cdef:
unordered_map[CObjectID, pair[size_t, size_t]] c_ref_counts
unordered_map[CObjectID, pair[size_t, size_t]].iterator it
c_ref_counts = self.core_worker.get().GetAllReferenceCounts()
it = c_ref_counts.begin()
ref_counts = {}
while it != c_ref_counts.end():
object_id = ObjectID(dereference(it).first.Binary())
ref_counts[object_id] = {
"local": dereference(it).second.first,
"submitted": dereference(it).second.second}
postincrement(it)
return ref_counts
def in_memory_store_get_async(self, ObjectID object_id, future): def in_memory_store_get_async(self, ObjectID object_id, future):
self.core_worker.get().GetAsync( self.core_worker.get().GetAsync(
object_id.native(), object_id.native(),

View file

@ -118,8 +118,8 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
CActorID DeserializeAndRegisterActorHandle(const c_string &bytes) CActorID DeserializeAndRegisterActorHandle(const c_string &bytes)
CRayStatus SerializeActorHandle(const CActorID &actor_id, c_string CRayStatus SerializeActorHandle(const CActorID &actor_id, c_string
*bytes) *bytes)
void AddObjectIDReference(const CObjectID &object_id) void AddLocalReference(const CObjectID &object_id)
void RemoveObjectIDReference(const CObjectID &object_id) void RemoveLocalReference(const CObjectID &object_id)
void PromoteObjectToPlasma(const CObjectID &object_id) void PromoteObjectToPlasma(const CObjectID &object_id)
void PromoteToPlasmaAndGetOwnershipInfo(const CObjectID &object_id, void PromoteToPlasmaAndGetOwnershipInfo(const CObjectID &object_id,
CTaskID *owner_id, CTaskID *owner_id,
@ -149,6 +149,9 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
CWorkerContext &GetWorkerContext() CWorkerContext &GetWorkerContext()
void YieldCurrentFiber(CFiberEvent &coroutine_done) void YieldCurrentFiber(CFiberEvent &coroutine_done)
unordered_map[CObjectID, pair[size_t, size_t]] GetAllReferenceCounts()
void GetAsync(const CObjectID &object_id, void GetAsync(const CObjectID &object_id,
ray_callback_function successs_callback, ray_callback_function successs_callback,
ray_callback_function fallback_callback, ray_callback_function fallback_callback,

View file

@ -0,0 +1,162 @@
# coding: utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import copy
import tempfile
import numpy as np
import time
import logging
import uuid
import ray
import ray.cluster_utils
import ray.test_utils
logger = logging.getLogger(__name__)
def _check_refcounts(expected):
actual = ray.worker.global_worker.core_worker.get_all_reference_counts()
assert len(expected) == len(actual)
for object_id, (local, submitted) in expected.items():
assert object_id in actual
assert local == actual[object_id]["local"]
assert submitted == actual[object_id]["submitted"]
def check_refcounts(expected, timeout=1):
start = time.time()
while True:
try:
_check_refcounts(expected)
break
except AssertionError as e:
if time.time() - start > timeout:
raise e
else:
time.sleep(0.1)
def test_local_refcounts(ray_start_regular):
oid1 = ray.put(None)
check_refcounts({oid1: (1, 0)})
oid1_copy = copy.copy(oid1)
check_refcounts({oid1: (2, 0)})
del oid1
check_refcounts({oid1_copy: (1, 0)})
del oid1_copy
check_refcounts({})
def test_dependency_refcounts(ray_start_regular):
# Return a large object that will be spilled to plasma.
def large_object():
return np.zeros(10 * 1024 * 1024, dtype=np.uint8)
# TODO: Clean up tmpfiles?
def random_path():
return os.path.join(tempfile.gettempdir(), uuid.uuid4().hex)
def touch(path):
with open(path, "w"):
pass
def wait_for_file(path):
while True:
if os.path.exists(path):
break
time.sleep(0.1)
@ray.remote
def one_dep(dep, path=None, fail=False):
if path is not None:
wait_for_file(path)
if fail:
raise Exception("failed on purpose")
@ray.remote
def one_dep_large(dep, path=None):
if path is not None:
wait_for_file(path)
# This should be spilled to plasma.
return large_object()
# Test that regular plasma dependency refcounts are decremented once the
# task finishes.
f = random_path()
large_dep = ray.put(large_object())
result = one_dep.remote(large_dep, path=f)
check_refcounts({large_dep: (1, 1), result: (1, 0)})
touch(f)
# Reference count should be removed once the task finishes.
check_refcounts({large_dep: (1, 0), result: (1, 0)})
del large_dep, result
check_refcounts({})
# Test that inlined dependency refcounts are decremented once they are
# inlined.
f = random_path()
dep = one_dep.remote(None, path=f)
check_refcounts({dep: (1, 0)})
result = one_dep.remote(dep)
check_refcounts({dep: (1, 1), result: (1, 0)})
touch(f)
# Reference count should be removed as soon as the dependency is inlined.
check_refcounts({dep: (1, 0), result: (1, 0)}, timeout=1)
del dep, result
check_refcounts({})
# Test that spilled plasma dependency refcounts are decremented once
# the task finishes.
f1, f2 = random_path(), random_path()
dep = one_dep_large.remote(None, path=f1)
check_refcounts({dep: (1, 0)})
result = one_dep.remote(dep, path=f2)
check_refcounts({dep: (1, 1), result: (1, 0)})
touch(f1)
ray.get(dep, timeout=5.0)
# Reference count should remain because the dependency is in plasma.
check_refcounts({dep: (1, 1), result: (1, 0)})
touch(f2)
# Reference count should be removed because the task finished.
check_refcounts({dep: (1, 0), result: (1, 0)})
del dep, result
check_refcounts({})
# Test that regular plasma dependency refcounts are decremented if a task
# fails.
f = random_path()
large_dep = ray.put(large_object())
result = one_dep.remote(large_dep, path=f, fail=True)
check_refcounts({large_dep: (1, 1), result: (1, 0)})
touch(f)
# Reference count should be removed once the task finishes.
check_refcounts({large_dep: (1, 0), result: (1, 0)})
del large_dep, result
check_refcounts({})
# Test that spilled plasma dependency refcounts are decremented if a task
# fails.
f1, f2 = random_path(), random_path()
dep = one_dep_large.remote(None, path=f1)
check_refcounts({dep: (1, 0)})
result = one_dep.remote(dep, path=f2, fail=True)
check_refcounts({dep: (1, 1), result: (1, 0)})
touch(f1)
ray.get(dep, timeout=5.0)
# Reference count should remain because the dependency is in plasma.
check_refcounts({dep: (1, 1), result: (1, 0)})
touch(f2)
# Reference count should be removed because the task finished.
check_refcounts({dep: (1, 0), result: (1, 0)})
del dep, result
check_refcounts({})
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))

View file

@ -192,14 +192,14 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
ref_counting_enabled ? reference_counter_ : nullptr, local_raylet_client_)); ref_counting_enabled ? reference_counter_ : nullptr, local_raylet_client_));
task_manager_.reset(new TaskManager( task_manager_.reset(new TaskManager(
memory_store_, actor_manager_, [this](const TaskSpecification &spec) { memory_store_, reference_counter_, actor_manager_,
[this](const TaskSpecification &spec) {
// Retry after a delay to emulate the existing Raylet reconstruction // Retry after a delay to emulate the existing Raylet reconstruction
// behaviour. TODO(ekl) backoff exponentially. // behaviour. TODO(ekl) backoff exponentially.
RAY_LOG(ERROR) << "Will resubmit task after a 5 second delay: " RAY_LOG(ERROR) << "Will resubmit task after a 5 second delay: "
<< spec.DebugString(); << spec.DebugString();
to_resubmit_.push_back(std::make_pair(current_time_ms() + 5000, spec)); to_resubmit_.push_back(std::make_pair(current_time_ms() + 5000, spec));
})); }));
resolver_.reset(new LocalDependencyResolver(memory_store_));
// Create an entry for the driver task in the task table. This task is // Create an entry for the driver task in the task table. This task is
// added immediately with status RUNNING. This allows us to push errors // added immediately with status RUNNING. This allows us to push errors
@ -377,8 +377,7 @@ Status CoreWorker::Put(const RayObject &object, ObjectID *object_id) {
*object_id = ObjectID::ForPut(worker_context_.GetCurrentTaskID(), *object_id = ObjectID::ForPut(worker_context_.GetCurrentTaskID(),
worker_context_.GetNextPutIndex(), worker_context_.GetNextPutIndex(),
static_cast<uint8_t>(TaskTransportType::RAYLET)); static_cast<uint8_t>(TaskTransportType::RAYLET));
reference_counter_->AddOwnedObject(*object_id, GetCallerId(), rpc_address_, reference_counter_->AddOwnedObject(*object_id, GetCallerId(), rpc_address_);
std::make_shared<std::vector<ObjectID>>());
return Put(object, *object_id); return Put(object, *object_id);
} }
@ -460,10 +459,6 @@ Status CoreWorker::Get(const std::vector<ObjectID> &ids, const int64_t timeout_m
// object. // object.
will_throw_exception = true; will_throw_exception = true;
} }
// If we got the result for this ObjectID, the task that created it must
// have finished. Therefore, we can safely remove its reference counting
// dependencies.
RemoveObjectIDDependencies(ids[i]);
} else { } else {
missing_result = true; missing_result = true;
} }
@ -597,30 +592,6 @@ TaskID CoreWorker::GetCallerId() const {
return caller_id; return caller_id;
} }
void CoreWorker::PinObjectReferences(const TaskSpecification &task_spec,
const TaskTransportType transport_type) {
size_t num_returns = task_spec.NumReturns();
if (task_spec.IsActorCreationTask() || task_spec.IsActorTask()) {
num_returns--;
}
std::shared_ptr<std::vector<ObjectID>> task_deps =
std::make_shared<std::vector<ObjectID>>();
for (size_t i = 0; i < task_spec.NumArgs(); i++) {
if (task_spec.ArgByRef(i)) {
for (size_t j = 0; j < task_spec.ArgIdCount(i); j++) {
task_deps->push_back(task_spec.ArgId(i, j));
}
}
}
// Note that we call this even if task_deps.size() == 0, in order to pin the return id.
for (size_t i = 0; i < num_returns; i++) {
reference_counter_->AddOwnedObject(task_spec.ReturnId(i, transport_type),
GetCallerId(), rpc_address_, task_deps);
}
}
Status CoreWorker::SubmitTask(const RayFunction &function, Status CoreWorker::SubmitTask(const RayFunction &function,
const std::vector<TaskArg> &args, const std::vector<TaskArg> &args,
const TaskOptions &task_options, const TaskOptions &task_options,
@ -642,11 +613,9 @@ Status CoreWorker::SubmitTask(const RayFunction &function,
return_ids); return_ids);
TaskSpecification task_spec = builder.Build(); TaskSpecification task_spec = builder.Build();
if (task_options.is_direct_call) { if (task_options.is_direct_call) {
task_manager_->AddPendingTask(task_spec, max_retries); task_manager_->AddPendingTask(GetCallerId(), rpc_address_, task_spec, max_retries);
PinObjectReferences(task_spec, TaskTransportType::DIRECT);
return direct_task_submitter_->SubmitTask(task_spec); return direct_task_submitter_->SubmitTask(task_spec);
} else { } else {
PinObjectReferences(task_spec, TaskTransportType::RAYLET);
return local_raylet_client_->SubmitTask(task_spec); return local_raylet_client_->SubmitTask(task_spec);
} }
} }
@ -685,11 +654,10 @@ Status CoreWorker::CreateActor(const RayFunction &function,
*return_actor_id = actor_id; *return_actor_id = actor_id;
TaskSpecification task_spec = builder.Build(); TaskSpecification task_spec = builder.Build();
if (actor_creation_options.is_direct_call) { if (actor_creation_options.is_direct_call) {
task_manager_->AddPendingTask(task_spec, actor_creation_options.max_reconstructions); task_manager_->AddPendingTask(GetCallerId(), rpc_address_, task_spec,
PinObjectReferences(task_spec, TaskTransportType::DIRECT); actor_creation_options.max_reconstructions);
return direct_task_submitter_->SubmitTask(task_spec); return direct_task_submitter_->SubmitTask(task_spec);
} else { } else {
PinObjectReferences(task_spec, TaskTransportType::RAYLET);
return local_raylet_client_->SubmitTask(task_spec); return local_raylet_client_->SubmitTask(task_spec);
} }
} }
@ -729,8 +697,7 @@ Status CoreWorker::SubmitActorTask(const ActorID &actor_id, const RayFunction &f
Status status; Status status;
TaskSpecification task_spec = builder.Build(); TaskSpecification task_spec = builder.Build();
if (is_direct_call) { if (is_direct_call) {
task_manager_->AddPendingTask(task_spec); task_manager_->AddPendingTask(GetCallerId(), rpc_address_, task_spec);
PinObjectReferences(task_spec, TaskTransportType::DIRECT);
if (actor_handle->IsDead()) { if (actor_handle->IsDead()) {
auto status = Status::IOError("sent task to dead actor"); auto status = Status::IOError("sent task to dead actor");
task_manager_->PendingTaskFailed(task_spec.TaskId(), rpc::ErrorType::ACTOR_DIED, task_manager_->PendingTaskFailed(task_spec.TaskId(), rpc::ErrorType::ACTOR_DIED,
@ -739,7 +706,6 @@ Status CoreWorker::SubmitActorTask(const ActorID &actor_id, const RayFunction &f
status = direct_actor_submitter_->SubmitTask(task_spec); status = direct_actor_submitter_->SubmitTask(task_spec);
} }
} else { } else {
PinObjectReferences(task_spec, TaskTransportType::RAYLET);
RAY_CHECK_OK(local_raylet_client_->SubmitTask(task_spec)); RAY_CHECK_OK(local_raylet_client_->SubmitTask(task_spec));
} }
return status; return status;
@ -1043,17 +1009,17 @@ void CoreWorker::HandleGetObjectStatus(const rpc::GetObjectStatusRequest &reques
if (task_manager_->IsTaskPending(object_id.TaskId())) { if (task_manager_->IsTaskPending(object_id.TaskId())) {
// Acquire a reference and retry. This prevents the object from being // Acquire a reference and retry. This prevents the object from being
// evicted out from under us before we can start the get. // evicted out from under us before we can start the get.
AddObjectIDReference(object_id); AddLocalReference(object_id);
if (task_manager_->IsTaskPending(object_id.TaskId())) { if (task_manager_->IsTaskPending(object_id.TaskId())) {
// The task is pending. Send the reply once the task finishes. // The task is pending. Send the reply once the task finishes.
memory_store_->GetAsync(object_id, memory_store_->GetAsync(object_id,
[send_reply_callback](std::shared_ptr<RayObject> obj) { [send_reply_callback](std::shared_ptr<RayObject> obj) {
send_reply_callback(Status::OK(), nullptr, nullptr); send_reply_callback(Status::OK(), nullptr, nullptr);
}); });
RemoveObjectIDReference(object_id); RemoveLocalReference(object_id);
} else { } else {
// We lost the race, the task is done. // We lost the race, the task is done.
RemoveObjectIDReference(object_id); RemoveLocalReference(object_id);
send_reply_callback(Status::OK(), nullptr, nullptr); send_reply_callback(Status::OK(), nullptr, nullptr);
} }
} else { } else {

View file

@ -101,17 +101,19 @@ class CoreWorker {
actor_id_ = actor_id; actor_id_ = actor_id;
} }
/// Increase the reference count for this object ID. /// Increase the local reference count for this object ID. Should be called
/// by the language frontend when a new reference is created.
/// ///
/// \param[in] object_id The object ID to increase the reference count for. /// \param[in] object_id The object ID to increase the reference count for.
void AddObjectIDReference(const ObjectID &object_id) { void AddLocalReference(const ObjectID &object_id) {
reference_counter_->AddLocalReference(object_id); reference_counter_->AddLocalReference(object_id);
} }
/// Decrease the reference count for this object ID. /// Decrease the reference count for this object ID. Should be called
/// by the language frontend when a reference is destroyed.
/// ///
/// \param[in] object_id The object ID to decrease the reference count for. /// \param[in] object_id The object ID to decrease the reference count for.
void RemoveObjectIDReference(const ObjectID &object_id) { void RemoveLocalReference(const ObjectID &object_id) {
std::vector<ObjectID> deleted; std::vector<ObjectID> deleted;
reference_counter_->RemoveLocalReference(object_id, &deleted); reference_counter_->RemoveLocalReference(object_id, &deleted);
if (ref_counting_enabled_) { if (ref_counting_enabled_) {
@ -119,6 +121,12 @@ class CoreWorker {
} }
} }
/// Returns a map of all ObjectIDs currently in scope with a pair of their
/// (local, submitted_task) reference counts. For debugging purposes.
std::unordered_map<ObjectID, std::pair<size_t, size_t>> GetAllReferenceCounts() const {
return reference_counter_->GetAllReferenceCounts();
}
/// Promote an object to plasma and get its owner information. This should be /// Promote an object to plasma and get its owner information. This should be
/// called when serializing an object ID, and the returned information should /// called when serializing an object ID, and the returned information should
/// be stored with the serialized object ID. For plasma promotion, if the /// be stored with the serialized object ID. For plasma promotion, if the
@ -434,11 +442,6 @@ class CoreWorker {
/// Private methods related to task submission. /// Private methods related to task submission.
/// ///
/// Add task dependencies to the reference counter. This prevents the argument
/// objects from early eviction, and also adds the return object.
void PinObjectReferences(const TaskSpecification &task_spec,
const TaskTransportType transport_type);
/// Give this worker a handle to an actor. /// Give this worker a handle to an actor.
/// ///
/// This handle will remain as long as the current actor or task is /// This handle will remain as long as the current actor or task is
@ -485,17 +488,6 @@ class CoreWorker {
std::vector<std::shared_ptr<RayObject>> *args, std::vector<std::shared_ptr<RayObject>> *args,
std::vector<ObjectID> *arg_reference_ids); std::vector<ObjectID> *arg_reference_ids);
/// Remove reference counting dependencies of this object ID.
///
/// \param[in] object_id The object whose dependencies should be removed.
void RemoveObjectIDDependencies(const ObjectID &object_id) {
std::vector<ObjectID> deleted;
reference_counter_->RemoveDependencies(object_id, &deleted);
if (ref_counting_enabled_) {
memory_store_->Delete(deleted);
}
}
/// Returns whether the message was sent to the wrong worker. The right error reply /// Returns whether the message was sent to the wrong worker. The right error reply
/// is sent automatically. Messages end up on the wrong worker when a worker dies /// is sent automatically. Messages end up on the wrong worker when a worker dies
/// and a new one takes its place with the same place. In this situation, we want /// and a new one takes its place with the same place. In this situation, we want
@ -624,9 +616,6 @@ class CoreWorker {
absl::flat_hash_map<ActorID, std::unique_ptr<ActorHandle>> actor_handles_ absl::flat_hash_map<ActorID, std::unique_ptr<ActorHandle>> actor_handles_
GUARDED_BY(actor_handles_mutex_); GUARDED_BY(actor_handles_mutex_);
/// Resolve local and remote dependencies for actor creation.
std::unique_ptr<LocalDependencyResolver> resolver_;
/// ///
/// Fields related to task execution. /// Fields related to task execution.
/// ///

View file

@ -14,81 +14,76 @@ void ReferenceCounter::AddBorrowedObject(const ObjectID &object_id,
} }
} }
void ReferenceCounter::AddOwnedObject( void ReferenceCounter::AddOwnedObject(const ObjectID &object_id, const TaskID &owner_id,
const ObjectID &object_id, const TaskID &owner_id, const rpc::Address &owner_address, const rpc::Address &owner_address) {
std::shared_ptr<std::vector<ObjectID>> dependencies) {
absl::MutexLock lock(&mutex_); absl::MutexLock lock(&mutex_);
for (const ObjectID &dependency_id : *dependencies) {
AddLocalReferenceInternal(dependency_id);
}
RAY_CHECK(object_id_refs_.count(object_id) == 0) RAY_CHECK(object_id_refs_.count(object_id) == 0)
<< "Cannot create an object that already exists. ObjectID: " << object_id; << "Tried to create an owned object that already exists: " << object_id;
// If the entry doesn't exist, we initialize the direct reference count to zero // If the entry doesn't exist, we initialize the direct reference count to zero
// because this corresponds to a submitted task whose return ObjectID will be created // because this corresponds to a submitted task whose return ObjectID will be created
// in the frontend language, incrementing the reference count. // in the frontend language, incrementing the reference count.
object_id_refs_.emplace(object_id, Reference(owner_id, owner_address, dependencies)); object_id_refs_.emplace(object_id, Reference(owner_id, owner_address));
} }
void ReferenceCounter::AddLocalReferenceInternal(const ObjectID &object_id) { void ReferenceCounter::AddLocalReference(const ObjectID &object_id) {
absl::MutexLock lock(&mutex_);
auto entry = object_id_refs_.find(object_id); auto entry = object_id_refs_.find(object_id);
if (entry == object_id_refs_.end()) { if (entry == object_id_refs_.end()) {
// TODO: Once ref counting is implemented, we should always know how the // TODO: Once ref counting is implemented, we should always know how the
// ObjectID was created, so there should always ben an entry. // ObjectID was created, so there should always be an entry.
entry = object_id_refs_.emplace(object_id, Reference()).first; entry = object_id_refs_.emplace(object_id, Reference()).first;
} }
entry->second.local_ref_count++; entry->second.local_ref_count++;
} }
void ReferenceCounter::AddLocalReference(const ObjectID &object_id) {
absl::MutexLock lock(&mutex_);
AddLocalReferenceInternal(object_id);
}
void ReferenceCounter::RemoveLocalReference(const ObjectID &object_id, void ReferenceCounter::RemoveLocalReference(const ObjectID &object_id,
std::vector<ObjectID> *deleted) { std::vector<ObjectID> *deleted) {
absl::MutexLock lock(&mutex_); absl::MutexLock lock(&mutex_);
RemoveReferenceRecursive(object_id, deleted);
}
void ReferenceCounter::RemoveDependencies(const ObjectID &object_id,
std::vector<ObjectID> *deleted) {
absl::MutexLock lock(&mutex_);
auto entry = object_id_refs_.find(object_id);
if (entry == object_id_refs_.end()) {
RAY_LOG(WARNING) << "Tried to remove dependencies for nonexistent object ID: "
<< object_id;
return;
}
if (entry->second.dependencies) {
for (const ObjectID &pending_task_object_id : *entry->second.dependencies) {
RemoveReferenceRecursive(pending_task_object_id, deleted);
}
entry->second.dependencies = nullptr;
}
}
void ReferenceCounter::RemoveReferenceRecursive(const ObjectID &object_id,
std::vector<ObjectID> *deleted) {
auto entry = object_id_refs_.find(object_id); auto entry = object_id_refs_.find(object_id);
if (entry == object_id_refs_.end()) { if (entry == object_id_refs_.end()) {
RAY_LOG(WARNING) << "Tried to decrease ref count for nonexistent object ID: " RAY_LOG(WARNING) << "Tried to decrease ref count for nonexistent object ID: "
<< object_id; << object_id;
return; return;
} }
if (--entry->second.local_ref_count == 0) { if (--entry->second.local_ref_count == 0 &&
// If the reference count reached 0, decrease the reference count for each dependency. entry->second.submitted_task_ref_count == 0) {
if (entry->second.dependencies) {
for (const ObjectID &pending_task_object_id : *entry->second.dependencies) {
RemoveReferenceRecursive(pending_task_object_id, deleted);
}
}
object_id_refs_.erase(entry); object_id_refs_.erase(entry);
deleted->push_back(object_id); deleted->push_back(object_id);
} }
} }
void ReferenceCounter::AddSubmittedTaskReferences(
const std::vector<ObjectID> &object_ids) {
absl::MutexLock lock(&mutex_);
for (const ObjectID &object_id : object_ids) {
auto entry = object_id_refs_.find(object_id);
if (entry == object_id_refs_.end()) {
// TODO: Once ref counting is implemented, we should always know how the
// ObjectID was created, so there should always be an entry.
entry = object_id_refs_.emplace(object_id, Reference()).first;
}
entry->second.submitted_task_ref_count++;
}
}
void ReferenceCounter::RemoveSubmittedTaskReferences(
const std::vector<ObjectID> &object_ids, std::vector<ObjectID> *deleted) {
absl::MutexLock lock(&mutex_);
for (const ObjectID &object_id : object_ids) {
auto entry = object_id_refs_.find(object_id);
if (entry == object_id_refs_.end()) {
RAY_LOG(WARNING) << "Tried to decrease ref count for nonexistent object ID: "
<< object_id;
return;
}
if (--entry->second.submitted_task_ref_count == 0 &&
entry->second.local_ref_count == 0) {
object_id_refs_.erase(entry);
deleted->push_back(object_id);
}
}
}
bool ReferenceCounter::GetOwner(const ObjectID &object_id, TaskID *owner_id, bool ReferenceCounter::GetOwner(const ObjectID &object_id, TaskID *owner_id,
rpc::Address *owner_address) const { rpc::Address *owner_address) const {
absl::MutexLock lock(&mutex_); absl::MutexLock lock(&mutex_);
@ -126,6 +121,19 @@ std::unordered_set<ObjectID> ReferenceCounter::GetAllInScopeObjectIDs() const {
return in_scope_object_ids; return in_scope_object_ids;
} }
std::unordered_map<ObjectID, std::pair<size_t, size_t>>
ReferenceCounter::GetAllReferenceCounts() const {
absl::MutexLock lock(&mutex_);
std::unordered_map<ObjectID, std::pair<size_t, size_t>> all_ref_counts;
all_ref_counts.reserve(object_id_refs_.size());
for (auto it : object_id_refs_) {
all_ref_counts.emplace(it.first,
std::pair<size_t, size_t>(it.second.local_ref_count,
it.second.submitted_task_ref_count));
}
return all_ref_counts;
}
void ReferenceCounter::LogDebugString() const { void ReferenceCounter::LogDebugString() const {
absl::MutexLock lock(&mutex_); absl::MutexLock lock(&mutex_);
@ -137,15 +145,9 @@ void ReferenceCounter::LogDebugString() const {
for (const auto &entry : object_id_refs_) { for (const auto &entry : object_id_refs_) {
RAY_LOG(DEBUG) << "\t" << entry.first.Hex(); RAY_LOG(DEBUG) << "\t" << entry.first.Hex();
RAY_LOG(DEBUG) << "\t\treference count: " << entry.second.local_ref_count; RAY_LOG(DEBUG) << "\t\tlocal refcount: " << entry.second.local_ref_count;
RAY_LOG(DEBUG) << "\t\tdependencies: "; RAY_LOG(DEBUG) << "\t\tsubmitted task refcount: "
if (!entry.second.dependencies) { << entry.second.submitted_task_ref_count;
RAY_LOG(DEBUG) << "\t\t\tNULL";
} else {
for (const ObjectID &pending_task_object_id : *entry.second.dependencies) {
RAY_LOG(DEBUG) << "\t\t\t" << pending_task_object_id.Hex();
}
}
} }
} }

View file

@ -25,27 +25,32 @@ class ReferenceCounter {
/// \param[in] object_id The object to to increment the count for. /// \param[in] object_id The object to to increment the count for.
void AddLocalReference(const ObjectID &object_id) LOCKS_EXCLUDED(mutex_); void AddLocalReference(const ObjectID &object_id) LOCKS_EXCLUDED(mutex_);
/// Decrease the reference count for the ObjectID by one. If the reference count reaches /// Decrease the local reference count for the ObjectID by one.
/// zero, it will be erased from the map and the reference count for all of its
/// dependencies will be decreased be one.
/// ///
/// \param[in] object_id The object to decrement the count for. /// \param[in] object_id The object to decrement the count for.
/// \param[out] deleted List to store objects that hit zero ref count. /// \param[out] deleted List to store objects that hit zero ref count.
void RemoveLocalReference(const ObjectID &object_id, std::vector<ObjectID> *deleted) void RemoveLocalReference(const ObjectID &object_id, std::vector<ObjectID> *deleted)
LOCKS_EXCLUDED(mutex_); LOCKS_EXCLUDED(mutex_);
/// Remove any references to dependencies that this object may have. This does *not* /// Add references for the provided object IDs that correspond to them being
/// decrease the object's own reference count. /// dependencies to a submitted task.
/// ///
/// \param[in] object_id The object whose dependencies should be removed. /// \param[in] object_ids The object IDs to add references for.
/// \param[out] deleted List to store objects that hit zero ref count. void AddSubmittedTaskReferences(const std::vector<ObjectID> &object_ids);
void RemoveDependencies(const ObjectID &object_id, std::vector<ObjectID> *deleted)
LOCKS_EXCLUDED(mutex_); /// Remove references for the provided object IDs that correspond to them being
/// dependencies to a submitted task. This should be called when inlined
/// dependencies are inlined or when the task finishes for plasma dependencies.
///
/// \param[in] object_ids The object IDs to remove references for.
/// \param[out] deleted The object IDs whos reference counts reached zero.
void RemoveSubmittedTaskReferences(const std::vector<ObjectID> &object_ids,
std::vector<ObjectID> *deleted);
/// Add an object that we own. The object may depend on other objects. /// Add an object that we own. The object may depend on other objects.
/// Dependencies for each ObjectID must be set at most once. The direct /// Dependencies for each ObjectID must be set at most once. The local
/// reference count for the ObjectID is set to zero and the reference count /// reference count for the ObjectID is set to zero, which assumes that an
/// for each dependency is incremented. /// ObjectID for it will be created in the language frontend after this call.
/// ///
/// TODO(swang): We could avoid copying the owner_id and owner_address since /// TODO(swang): We could avoid copying the owner_id and owner_address since
/// we are the owner, but it is easier to store a copy for now, since the /// we are the owner, but it is easier to store a copy for now, since the
@ -57,9 +62,7 @@ class ReferenceCounter {
/// \param[in] owner_address The address of the object's owner. /// \param[in] owner_address The address of the object's owner.
/// \param[in] dependencies The objects that the object depends on. /// \param[in] dependencies The objects that the object depends on.
void AddOwnedObject(const ObjectID &object_id, const TaskID &owner_id, void AddOwnedObject(const ObjectID &object_id, const TaskID &owner_id,
const rpc::Address &owner_address, const rpc::Address &owner_address) LOCKS_EXCLUDED(mutex_);
std::shared_ptr<std::vector<ObjectID>> dependencies)
LOCKS_EXCLUDED(mutex_);
/// Add an object that we are borrowing. /// Add an object that we are borrowing.
/// ///
@ -82,6 +85,11 @@ class ReferenceCounter {
/// Returns a set of all ObjectIDs currently in scope (i.e., nonzero reference count). /// Returns a set of all ObjectIDs currently in scope (i.e., nonzero reference count).
std::unordered_set<ObjectID> GetAllInScopeObjectIDs() const LOCKS_EXCLUDED(mutex_); std::unordered_set<ObjectID> GetAllInScopeObjectIDs() const LOCKS_EXCLUDED(mutex_);
/// Returns a map of all ObjectIDs currently in scope with a pair of their
/// (local, submitted_task) reference counts. For debugging purposes.
std::unordered_map<ObjectID, std::pair<size_t, size_t>> GetAllReferenceCounts() const
LOCKS_EXCLUDED(mutex_);
/// Dumps information about all currently tracked references to RAY_LOG(DEBUG). /// Dumps information about all currently tracked references to RAY_LOG(DEBUG).
void LogDebugString() const LOCKS_EXCLUDED(mutex_); void LogDebugString() const LOCKS_EXCLUDED(mutex_);
@ -91,22 +99,12 @@ class ReferenceCounter {
/// Constructor for a reference whose origin is unknown. /// Constructor for a reference whose origin is unknown.
Reference() : owned_by_us(false) {} Reference() : owned_by_us(false) {}
/// Constructor for a reference that we created. /// Constructor for a reference that we created.
Reference(const TaskID &owner_id, const rpc::Address &owner_address,
std::shared_ptr<std::vector<ObjectID>> deps)
: dependencies(std::move(deps)),
owned_by_us(true),
owner({owner_id, owner_address}) {}
/// Constructor for a reference that was given to us.
Reference(const TaskID &owner_id, const rpc::Address &owner_address) Reference(const TaskID &owner_id, const rpc::Address &owner_address)
: owned_by_us(false), owner({owner_id, owner_address}) {} : owned_by_us(true), owner({owner_id, owner_address}) {}
/// The local ref count for the ObjectID in the language frontend. /// The local ref count for the ObjectID in the language frontend.
size_t local_ref_count = 0; size_t local_ref_count = 0;
/// The objects that this object depends on. Tracked only by the owner of /// The ref count for submitted tasks that depend on the ObjectID.
/// the object. Dependencies are stored as shared_ptrs because the same set size_t submitted_task_ref_count = 0;
/// of dependencies can be shared among multiple entries. For example, when
/// a task has multiple return values, the entry for each return ObjectID
/// depends on all task dependencies.
std::shared_ptr<std::vector<ObjectID>> dependencies;
/// Whether we own the object. If we own the object, then we are /// Whether we own the object. If we own the object, then we are
/// responsible for tracking the state of the task that creates the object /// responsible for tracking the state of the task that creates the object
/// (see task_manager.h). /// (see task_manager.h).

View file

@ -16,173 +16,63 @@ class ReferenceCountTest : public ::testing::Test {
virtual void TearDown() {} virtual void TearDown() {}
}; };
// Tests basic incrementing/decrementing of direct reference counts. An entry should only // Tests basic incrementing/decrementing of direct/submitted task reference counts. An
// be removed once its reference count reaches zero. // entry should only be removed once both of its reference counts reach zero.
TEST_F(ReferenceCountTest, TestBasic) { TEST_F(ReferenceCountTest, TestBasic) {
std::vector<ObjectID> out; std::vector<ObjectID> out;
ObjectID id = ObjectID::FromRandom();
rc->AddLocalReference(id);
ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
rc->AddLocalReference(id);
ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
rc->AddLocalReference(id);
ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
rc->RemoveLocalReference(id, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
ASSERT_EQ(out.size(), 0);
rc->RemoveLocalReference(id, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
ASSERT_EQ(out.size(), 0);
rc->RemoveLocalReference(id, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 0);
ASSERT_EQ(out.size(), 1);
}
// Tests the basic logic for dependencies - when an ObjectID with dependencies
// goes out of scope (i.e., reference count reaches zero), all of its dependencies
// should have their reference count decremented and be removed if it reaches zero.
TEST_F(ReferenceCountTest, TestDependencies) {
std::vector<ObjectID> out;
ObjectID id1 = ObjectID::FromRandom(); ObjectID id1 = ObjectID::FromRandom();
ObjectID id2 = ObjectID::FromRandom(); ObjectID id2 = ObjectID::FromRandom();
ObjectID id3 = ObjectID::FromRandom();
std::shared_ptr<std::vector<ObjectID>> deps = std::make_shared<std::vector<ObjectID>>();
deps->push_back(id2);
deps->push_back(id3);
rc->AddOwnedObject(id1, TaskID::Nil(), rpc::Address(), deps);
// Local references.
rc->AddLocalReference(id1); rc->AddLocalReference(id1);
rc->AddLocalReference(id1);
rc->AddLocalReference(id3);
ASSERT_EQ(rc->NumObjectIDsInScope(), 3);
rc->RemoveLocalReference(id1, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 3);
ASSERT_EQ(out.size(), 0);
rc->RemoveLocalReference(id1, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
ASSERT_EQ(out.size(), 2);
rc->RemoveLocalReference(id3, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 0);
ASSERT_EQ(out.size(), 3);
}
// Tests the case where two entries share the same set of dependencies. When one
// entry goes out of scope, it should decrease the reference count for the dependencies
// but they should still be nonzero until the second entry goes out of scope and all
// direct dependencies to the dependencies are removed.
TEST_F(ReferenceCountTest, TestSharedDependencies) {
std::vector<ObjectID> out;
ObjectID id1 = ObjectID::FromRandom();
ObjectID id2 = ObjectID::FromRandom();
ObjectID id3 = ObjectID::FromRandom();
ObjectID id4 = ObjectID::FromRandom();
std::shared_ptr<std::vector<ObjectID>> deps = std::make_shared<std::vector<ObjectID>>();
deps->push_back(id3);
deps->push_back(id4);
rc->AddOwnedObject(id1, TaskID::Nil(), rpc::Address(), deps);
rc->AddOwnedObject(id2, TaskID::Nil(), rpc::Address(), deps);
rc->AddLocalReference(id1); rc->AddLocalReference(id1);
rc->AddLocalReference(id2); rc->AddLocalReference(id2);
rc->AddLocalReference(id4); ASSERT_EQ(rc->NumObjectIDsInScope(), 2);
ASSERT_EQ(rc->NumObjectIDsInScope(), 4);
rc->RemoveLocalReference(id1, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 3);
ASSERT_EQ(out.size(), 1);
rc->RemoveLocalReference(id2, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
ASSERT_EQ(out.size(), 3);
rc->RemoveLocalReference(id4, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 0);
ASSERT_EQ(out.size(), 4);
}
// Tests the case when an entry has a dependency that itself has a
// dependency. In this case, when the first entry goes out of scope
// it should decrease the reference count for its dependency, causing
// that entry to go out of scope and decrease its dependencies' reference counts.
TEST_F(ReferenceCountTest, TestRecursiveDependencies) {
std::vector<ObjectID> out;
ObjectID id1 = ObjectID::FromRandom();
ObjectID id2 = ObjectID::FromRandom();
ObjectID id3 = ObjectID::FromRandom();
ObjectID id4 = ObjectID::FromRandom();
std::shared_ptr<std::vector<ObjectID>> deps2 =
std::make_shared<std::vector<ObjectID>>();
deps2->push_back(id3);
deps2->push_back(id4);
rc->AddOwnedObject(id2, TaskID::Nil(), rpc::Address(), deps2);
std::shared_ptr<std::vector<ObjectID>> deps1 =
std::make_shared<std::vector<ObjectID>>();
deps1->push_back(id2);
rc->AddOwnedObject(id1, TaskID::Nil(), rpc::Address(), deps1);
rc->AddLocalReference(id1);
rc->AddLocalReference(id2);
rc->AddLocalReference(id4);
ASSERT_EQ(rc->NumObjectIDsInScope(), 4);
rc->RemoveLocalReference(id2, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 4);
ASSERT_EQ(out.size(), 0);
rc->RemoveLocalReference(id1, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
ASSERT_EQ(out.size(), 3);
rc->RemoveLocalReference(id4, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 0);
ASSERT_EQ(out.size(), 4);
}
TEST_F(ReferenceCountTest, TestRemoveDependenciesOnly) {
std::vector<ObjectID> out;
ObjectID id1 = ObjectID::FromRandom();
ObjectID id2 = ObjectID::FromRandom();
ObjectID id3 = ObjectID::FromRandom();
ObjectID id4 = ObjectID::FromRandom();
std::shared_ptr<std::vector<ObjectID>> deps2 =
std::make_shared<std::vector<ObjectID>>();
deps2->push_back(id3);
deps2->push_back(id4);
rc->AddOwnedObject(id2, TaskID::Nil(), rpc::Address(), deps2);
std::shared_ptr<std::vector<ObjectID>> deps1 =
std::make_shared<std::vector<ObjectID>>();
deps1->push_back(id2);
rc->AddOwnedObject(id1, TaskID::Nil(), rpc::Address(), deps1);
rc->AddLocalReference(id1);
rc->AddLocalReference(id2);
rc->AddLocalReference(id4);
ASSERT_EQ(rc->NumObjectIDsInScope(), 4);
rc->RemoveDependencies(id2, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 3);
ASSERT_EQ(out.size(), 1);
rc->RemoveDependencies(id1, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 3);
ASSERT_EQ(out.size(), 1);
rc->RemoveLocalReference(id1, &out); rc->RemoveLocalReference(id1, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 2); ASSERT_EQ(rc->NumObjectIDsInScope(), 2);
ASSERT_EQ(out.size(), 2); ASSERT_EQ(out.size(), 0);
rc->RemoveLocalReference(id2, &out); rc->RemoveLocalReference(id2, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 1); ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
ASSERT_EQ(out.size(), 3); ASSERT_EQ(out.size(), 1);
rc->RemoveLocalReference(id1, &out);
rc->RemoveLocalReference(id4, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 0); ASSERT_EQ(rc->NumObjectIDsInScope(), 0);
ASSERT_EQ(out.size(), 4); ASSERT_EQ(out.size(), 2);
out.clear();
// Submitted task references.
rc->AddSubmittedTaskReferences({id1});
rc->AddSubmittedTaskReferences({id1, id2});
ASSERT_EQ(rc->NumObjectIDsInScope(), 2);
rc->RemoveSubmittedTaskReferences({id1}, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 2);
ASSERT_EQ(out.size(), 0);
rc->RemoveSubmittedTaskReferences({id2}, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
ASSERT_EQ(out.size(), 1);
rc->RemoveSubmittedTaskReferences({id1}, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 0);
ASSERT_EQ(out.size(), 2);
out.clear();
// Local & submitted task references.
rc->AddLocalReference(id1);
rc->AddSubmittedTaskReferences({id1, id2});
rc->AddLocalReference(id2);
ASSERT_EQ(rc->NumObjectIDsInScope(), 2);
rc->RemoveLocalReference(id1, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 2);
ASSERT_EQ(out.size(), 0);
rc->RemoveSubmittedTaskReferences({id2}, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 2);
ASSERT_EQ(out.size(), 0);
rc->RemoveSubmittedTaskReferences({id1}, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
ASSERT_EQ(out.size(), 1);
rc->RemoveLocalReference(id2, &out);
ASSERT_EQ(rc->NumObjectIDsInScope(), 0);
ASSERT_EQ(out.size(), 2);
out.clear();
} }
// Tests that we can get the owner address correctly for objects that we own, // Tests that we can get the owner address correctly for objects that we own,
@ -193,8 +83,7 @@ TEST_F(ReferenceCountTest, TestOwnerAddress) {
TaskID task_id = TaskID::ForFakeTask(); TaskID task_id = TaskID::ForFakeTask();
rpc::Address address; rpc::Address address;
address.set_ip_address("1234"); address.set_ip_address("1234");
auto deps = std::make_shared<std::vector<ObjectID>>(); rc->AddOwnedObject(object_id, task_id, address);
rc->AddOwnedObject(object_id, task_id, address, deps);
TaskID added_id; TaskID added_id;
rpc::Address added_address; rpc::Address added_address;
@ -205,7 +94,7 @@ TEST_F(ReferenceCountTest, TestOwnerAddress) {
auto object_id2 = ObjectID::FromRandom(); auto object_id2 = ObjectID::FromRandom();
task_id = TaskID::ForFakeTask(); task_id = TaskID::ForFakeTask();
address.set_ip_address("5678"); address.set_ip_address("5678");
rc->AddOwnedObject(object_id2, task_id, address, deps); rc->AddOwnedObject(object_id2, task_id, address);
ASSERT_TRUE(rc->GetOwner(object_id2, &added_id, &added_address)); ASSERT_TRUE(rc->GetOwner(object_id2, &added_id, &added_address));
ASSERT_EQ(task_id, added_id); ASSERT_EQ(task_id, added_id);
ASSERT_EQ(address.ip_address(), added_address.ip_address()); ASSERT_EQ(address.ip_address(), added_address.ip_address());

View file

@ -10,11 +10,34 @@ const int64_t kTaskFailureThrottlingThreshold = 50;
// Throttle task failure logs to once this interval. // Throttle task failure logs to once this interval.
const int64_t kTaskFailureLoggingFrequencyMillis = 5000; const int64_t kTaskFailureLoggingFrequencyMillis = 5000;
void TaskManager::AddPendingTask(const TaskSpecification &spec, int max_retries) { void TaskManager::AddPendingTask(const TaskID &caller_id,
const rpc::Address &caller_address,
const TaskSpecification &spec, int max_retries) {
RAY_LOG(DEBUG) << "Adding pending task " << spec.TaskId(); RAY_LOG(DEBUG) << "Adding pending task " << spec.TaskId();
absl::MutexLock lock(&mu_); absl::MutexLock lock(&mu_);
std::pair<TaskSpecification, int> entry = {spec, max_retries}; std::pair<TaskSpecification, int> entry = {spec, max_retries};
RAY_CHECK(pending_tasks_.emplace(spec.TaskId(), std::move(entry)).second); RAY_CHECK(pending_tasks_.emplace(spec.TaskId(), std::move(entry)).second);
// Add references for the dependencies to the task.
std::vector<ObjectID> task_deps;
for (size_t i = 0; i < spec.NumArgs(); i++) {
if (spec.ArgByRef(i)) {
for (size_t j = 0; j < spec.ArgIdCount(i); j++) {
task_deps.push_back(spec.ArgId(i, j));
}
}
}
reference_counter_->AddSubmittedTaskReferences(task_deps);
// Add new owned objects for the return values of the task.
size_t num_returns = spec.NumReturns();
if (spec.IsActorCreationTask() || spec.IsActorTask()) {
num_returns--;
}
for (size_t i = 0; i < num_returns; i++) {
reference_counter_->AddOwnedObject(spec.ReturnId(i, TaskTransportType::DIRECT),
caller_id, caller_address);
}
} }
void TaskManager::DrainAndShutdown(std::function<void()> shutdown) { void TaskManager::DrainAndShutdown(std::function<void()> shutdown) {
@ -48,6 +71,8 @@ void TaskManager::CompletePendingTask(const TaskID &task_id,
pending_tasks_.erase(it); pending_tasks_.erase(it);
} }
RemovePlasmaSubmittedTaskReferences(spec);
for (int i = 0; i < reply.return_objects_size(); i++) { for (int i = 0; i < reply.return_objects_size(); i++) {
const auto &return_object = reply.return_objects(i); const auto &return_object = reply.return_objects(i);
ObjectID object_id = ObjectID::FromBinary(return_object.object_id()); ObjectID object_id = ObjectID::FromBinary(return_object.object_id());
@ -134,6 +159,7 @@ void TaskManager::PendingTaskFailed(const TaskID &task_id, rpc::ErrorType error_
} }
} }
} }
RemovePlasmaSubmittedTaskReferences(spec);
MarkPendingTaskFailed(task_id, spec, error_type); MarkPendingTaskFailed(task_id, spec, error_type);
} }
@ -148,6 +174,28 @@ void TaskManager::ShutdownIfNeeded() {
} }
} }
void TaskManager::RemoveSubmittedTaskReferences(const std::vector<ObjectID> &object_ids) {
std::vector<ObjectID> deleted;
reference_counter_->RemoveSubmittedTaskReferences(object_ids, &deleted);
in_memory_store_->Delete(deleted);
}
void TaskManager::OnTaskDependenciesInlined(const std::vector<ObjectID> &object_ids) {
RemoveSubmittedTaskReferences(object_ids);
}
void TaskManager::RemovePlasmaSubmittedTaskReferences(TaskSpecification &spec) {
std::vector<ObjectID> plasma_dependencies;
for (size_t i = 0; i < spec.NumArgs(); i++) {
auto count = spec.ArgIdCount(i);
if (count > 0) {
const auto &id = spec.ArgId(i, 0);
plasma_dependencies.push_back(id);
}
}
RemoveSubmittedTaskReferences(plasma_dependencies);
}
void TaskManager::MarkPendingTaskFailed(const TaskID &task_id, void TaskManager::MarkPendingTaskFailed(const TaskID &task_id,
const TaskSpecification &spec, const TaskSpecification &spec,
rpc::ErrorType error_type) { rpc::ErrorType error_type) {

View file

@ -21,6 +21,8 @@ class TaskFinisherInterface {
virtual void PendingTaskFailed(const TaskID &task_id, rpc::ErrorType error_type, virtual void PendingTaskFailed(const TaskID &task_id, rpc::ErrorType error_type,
Status *status = nullptr) = 0; Status *status = nullptr) = 0;
virtual void OnTaskDependenciesInlined(const std::vector<ObjectID> &object_ids) = 0;
virtual ~TaskFinisherInterface() {} virtual ~TaskFinisherInterface() {}
}; };
@ -29,19 +31,24 @@ using RetryTaskCallback = std::function<void(const TaskSpecification &spec)>;
class TaskManager : public TaskFinisherInterface { class TaskManager : public TaskFinisherInterface {
public: public:
TaskManager(std::shared_ptr<CoreWorkerMemoryStore> in_memory_store, TaskManager(std::shared_ptr<CoreWorkerMemoryStore> in_memory_store,
std::shared_ptr<ReferenceCounter> reference_counter,
std::shared_ptr<ActorManagerInterface> actor_manager, std::shared_ptr<ActorManagerInterface> actor_manager,
RetryTaskCallback retry_task_callback) RetryTaskCallback retry_task_callback)
: in_memory_store_(in_memory_store), : in_memory_store_(in_memory_store),
reference_counter_(reference_counter),
actor_manager_(actor_manager), actor_manager_(actor_manager),
retry_task_callback_(retry_task_callback) {} retry_task_callback_(retry_task_callback) {}
/// Add a task that is pending execution. /// Add a task that is pending execution.
/// ///
/// \param[in] caller_id The TaskID of the calling task.
/// \param[in] caller_address The rpc address of the calling task.
/// \param[in] spec The spec of the pending task. /// \param[in] spec The spec of the pending task.
/// \param[in] max_retries Number of times this task may be retried /// \param[in] max_retries Number of times this task may be retried
/// on failure. /// on failure.
/// \return Void. /// \return Void.
void AddPendingTask(const TaskSpecification &spec, int max_retries = 0); void AddPendingTask(const TaskID &caller_id, const rpc::Address &caller_address,
const TaskSpecification &spec, int max_retries = 0);
/// Wait for all pending tasks to finish, and then shutdown. /// Wait for all pending tasks to finish, and then shutdown.
/// ///
@ -72,6 +79,8 @@ class TaskManager : public TaskFinisherInterface {
void PendingTaskFailed(const TaskID &task_id, rpc::ErrorType error_type, void PendingTaskFailed(const TaskID &task_id, rpc::ErrorType error_type,
Status *status = nullptr) override; Status *status = nullptr) override;
void OnTaskDependenciesInlined(const std::vector<ObjectID> &object_id) override;
/// Return the spec for a pending task. /// Return the spec for a pending task.
TaskSpecification GetTaskSpec(const TaskID &task_id) const; TaskSpecification GetTaskSpec(const TaskID &task_id) const;
@ -81,12 +90,25 @@ class TaskManager : public TaskFinisherInterface {
void MarkPendingTaskFailed(const TaskID &task_id, const TaskSpecification &spec, void MarkPendingTaskFailed(const TaskID &task_id, const TaskSpecification &spec,
rpc::ErrorType error_type) LOCKS_EXCLUDED(mu_); rpc::ErrorType error_type) LOCKS_EXCLUDED(mu_);
/// Remove submittted task references in the reference counter for the object IDs.
/// If their reference counts reach zero, they are deleted from the in-memory store.
void RemoveSubmittedTaskReferences(const std::vector<ObjectID> &object_ids);
/// Helper function to call RemoveSubmittedTaskReferences on the plasma dependencies
/// of the given task spec.
void RemovePlasmaSubmittedTaskReferences(TaskSpecification &spec);
/// Shutdown if all tasks are finished and shutdown is scheduled. /// Shutdown if all tasks are finished and shutdown is scheduled.
void ShutdownIfNeeded() LOCKS_EXCLUDED(mu_); void ShutdownIfNeeded() LOCKS_EXCLUDED(mu_);
/// Used to store task results. /// Used to store task results.
std::shared_ptr<CoreWorkerMemoryStore> in_memory_store_; std::shared_ptr<CoreWorkerMemoryStore> in_memory_store_;
/// Used for reference counting objects.
/// The task manager is responsible for managing all references related to
/// submitted tasks (dependencies and return objects).
std::shared_ptr<ReferenceCounter> reference_counter_;
// Interface for publishing actor creation. // Interface for publishing actor creation.
std::shared_ptr<ActorManagerInterface> actor_manager_; std::shared_ptr<ActorManagerInterface> actor_manager_;

View file

@ -1,6 +1,5 @@
#include "gmock/gmock.h" #include "gmock/gmock.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ray/common/task/task_spec.h" #include "ray/common/task/task_spec.h"
#include "ray/core_worker/store_provider/memory_store/memory_store.h" #include "ray/core_worker/store_provider/memory_store/memory_store.h"
#include "ray/core_worker/transport/direct_task_transport.h" #include "ray/core_worker/transport/direct_task_transport.h"
@ -45,6 +44,8 @@ class MockTaskFinisher : public TaskFinisherInterface {
const rpc::Address *addr)); const rpc::Address *addr));
MOCK_METHOD3(PendingTaskFailed, MOCK_METHOD3(PendingTaskFailed,
void(const TaskID &task_id, rpc::ErrorType error_type, Status *status)); void(const TaskID &task_id, rpc::ErrorType error_type, Status *status));
MOCK_METHOD1(OnTaskDependenciesInlined, void(const std::vector<ObjectID> &object_ids));
}; };
TaskSpecification CreateActorTaskHelper(ActorID actor_id, int64_t counter) { TaskSpecification CreateActorTaskHelper(ActorID actor_id, int64_t counter) {

View file

@ -55,8 +55,13 @@ class MockTaskFinisher : public TaskFinisherInterface {
num_tasks_failed++; num_tasks_failed++;
} }
void OnTaskDependenciesInlined(const std::vector<ObjectID> &object_ids) override {
num_inlined += object_ids.size();
}
int num_tasks_complete = 0; int num_tasks_complete = 0;
int num_tasks_failed = 0; int num_tasks_failed = 0;
int num_inlined = 0;
}; };
class MockRayletClient : public WorkerLeaseInterface { class MockRayletClient : public WorkerLeaseInterface {
@ -136,17 +141,20 @@ TEST(TestMemoryStore, TestPromoteToPlasma) {
TEST(LocalDependencyResolverTest, TestNoDependencies) { TEST(LocalDependencyResolverTest, TestNoDependencies) {
auto store = std::make_shared<CoreWorkerMemoryStore>(); auto store = std::make_shared<CoreWorkerMemoryStore>();
LocalDependencyResolver resolver(store); auto task_finisher = std::make_shared<MockTaskFinisher>();
LocalDependencyResolver resolver(store, task_finisher);
TaskSpecification task; TaskSpecification task;
bool ok = false; bool ok = false;
resolver.ResolveDependencies(task, [&ok]() { ok = true; }); resolver.ResolveDependencies(task, [&ok]() { ok = true; });
ASSERT_TRUE(ok); ASSERT_TRUE(ok);
ASSERT_EQ(task_finisher->num_inlined, 0);
} }
TEST(LocalDependencyResolverTest, TestIgnorePlasmaDependencies) { TEST(LocalDependencyResolverTest, TestIgnorePlasmaDependencies) {
auto store = std::make_shared<CoreWorkerMemoryStore>(); auto store = std::make_shared<CoreWorkerMemoryStore>();
LocalDependencyResolver resolver(store); auto task_finisher = std::make_shared<MockTaskFinisher>();
ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::RAYLET); LocalDependencyResolver resolver(store, task_finisher);
ObjectID obj1 = ObjectID::FromRandom();
TaskSpecification task; TaskSpecification task;
task.GetMutableMessage().add_args()->add_object_ids(obj1.Binary()); task.GetMutableMessage().add_args()->add_object_ids(obj1.Binary());
bool ok = false; bool ok = false;
@ -154,11 +162,13 @@ TEST(LocalDependencyResolverTest, TestIgnorePlasmaDependencies) {
// We ignore and don't block on plasma dependencies. // We ignore and don't block on plasma dependencies.
ASSERT_TRUE(ok); ASSERT_TRUE(ok);
ASSERT_EQ(resolver.NumPendingTasks(), 0); ASSERT_EQ(resolver.NumPendingTasks(), 0);
ASSERT_EQ(task_finisher->num_inlined, 0);
} }
TEST(LocalDependencyResolverTest, TestHandlePlasmaPromotion) { TEST(LocalDependencyResolverTest, TestHandlePlasmaPromotion) {
auto store = std::make_shared<CoreWorkerMemoryStore>(); auto store = std::make_shared<CoreWorkerMemoryStore>();
LocalDependencyResolver resolver(store); auto task_finisher = std::make_shared<MockTaskFinisher>();
LocalDependencyResolver resolver(store, task_finisher);
ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT); ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
std::string meta = std::to_string(static_cast<int>(rpc::ErrorType::OBJECT_IN_PLASMA)); std::string meta = std::to_string(static_cast<int>(rpc::ErrorType::OBJECT_IN_PLASMA));
auto metadata = const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(meta.data())); auto metadata = const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(meta.data()));
@ -175,11 +185,13 @@ TEST(LocalDependencyResolverTest, TestHandlePlasmaPromotion) {
// Checks that the object id is still a direct call id. // Checks that the object id is still a direct call id.
ASSERT_TRUE(task.ArgId(0, 0).IsDirectCallType()); ASSERT_TRUE(task.ArgId(0, 0).IsDirectCallType());
ASSERT_EQ(resolver.NumPendingTasks(), 0); ASSERT_EQ(resolver.NumPendingTasks(), 0);
ASSERT_EQ(task_finisher->num_inlined, 0);
} }
TEST(LocalDependencyResolverTest, TestInlineLocalDependencies) { TEST(LocalDependencyResolverTest, TestInlineLocalDependencies) {
auto store = std::make_shared<CoreWorkerMemoryStore>(); auto store = std::make_shared<CoreWorkerMemoryStore>();
LocalDependencyResolver resolver(store); auto task_finisher = std::make_shared<MockTaskFinisher>();
LocalDependencyResolver resolver(store, task_finisher);
ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT); ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
ObjectID obj2 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT); ObjectID obj2 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
auto data = GenerateRandomObject(); auto data = GenerateRandomObject();
@ -198,11 +210,13 @@ TEST(LocalDependencyResolverTest, TestInlineLocalDependencies) {
ASSERT_NE(task.ArgData(0), nullptr); ASSERT_NE(task.ArgData(0), nullptr);
ASSERT_NE(task.ArgData(1), nullptr); ASSERT_NE(task.ArgData(1), nullptr);
ASSERT_EQ(resolver.NumPendingTasks(), 0); ASSERT_EQ(resolver.NumPendingTasks(), 0);
ASSERT_EQ(task_finisher->num_inlined, 2);
} }
TEST(LocalDependencyResolverTest, TestInlinePendingDependencies) { TEST(LocalDependencyResolverTest, TestInlinePendingDependencies) {
auto store = std::make_shared<CoreWorkerMemoryStore>(); auto store = std::make_shared<CoreWorkerMemoryStore>();
LocalDependencyResolver resolver(store); auto task_finisher = std::make_shared<MockTaskFinisher>();
LocalDependencyResolver resolver(store, task_finisher);
ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT); ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
ObjectID obj2 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT); ObjectID obj2 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
auto data = GenerateRandomObject(); auto data = GenerateRandomObject();
@ -223,6 +237,7 @@ TEST(LocalDependencyResolverTest, TestInlinePendingDependencies) {
ASSERT_NE(task.ArgData(0), nullptr); ASSERT_NE(task.ArgData(0), nullptr);
ASSERT_NE(task.ArgData(1), nullptr); ASSERT_NE(task.ArgData(1), nullptr);
ASSERT_EQ(resolver.NumPendingTasks(), 0); ASSERT_EQ(resolver.NumPendingTasks(), 0);
ASSERT_EQ(task_finisher->num_inlined, 2);
} }
TaskSpecification BuildTaskSpec(const std::unordered_map<std::string, double> &resources, TaskSpecification BuildTaskSpec(const std::unordered_map<std::string, double> &resources,

View file

@ -1,17 +1,22 @@
#include "gtest/gtest.h" #include "ray/core_worker/task_manager.h"
#include "gtest/gtest.h"
#include "ray/common/task/task_spec.h" #include "ray/common/task/task_spec.h"
#include "ray/core_worker/actor_manager.h" #include "ray/core_worker/actor_manager.h"
#include "ray/core_worker/reference_count.h"
#include "ray/core_worker/store_provider/memory_store/memory_store.h" #include "ray/core_worker/store_provider/memory_store/memory_store.h"
#include "ray/core_worker/task_manager.h"
#include "ray/util/test_util.h" #include "ray/util/test_util.h"
namespace ray { namespace ray {
TaskSpecification CreateTaskHelper(uint64_t num_returns) { TaskSpecification CreateTaskHelper(uint64_t num_returns,
std::vector<ObjectID> dependencies) {
TaskSpecification task; TaskSpecification task;
task.GetMutableMessage().set_task_id(TaskID::ForFakeTask().Binary()); task.GetMutableMessage().set_task_id(TaskID::ForFakeTask().Binary());
task.GetMutableMessage().set_num_returns(num_returns); task.GetMutableMessage().set_num_returns(num_returns);
for (const ObjectID &dep : dependencies) {
task.GetMutableMessage().add_args()->add_object_ids(dep.Binary());
}
return task; return task;
} }
@ -33,23 +38,31 @@ class TaskManagerTest : public ::testing::Test {
public: public:
TaskManagerTest() TaskManagerTest()
: store_(std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore())), : store_(std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore())),
reference_counter_(std::shared_ptr<ReferenceCounter>(new ReferenceCounter())),
actor_manager_(std::shared_ptr<ActorManagerInterface>(new MockActorManager())), actor_manager_(std::shared_ptr<ActorManagerInterface>(new MockActorManager())),
manager_(store_, actor_manager_, [this](const TaskSpecification &spec) { manager_(store_, reference_counter_, actor_manager_,
[this](const TaskSpecification &spec) {
num_retries_++; num_retries_++;
return Status::OK(); return Status::OK();
}) {} }) {}
std::shared_ptr<CoreWorkerMemoryStore> store_; std::shared_ptr<CoreWorkerMemoryStore> store_;
std::shared_ptr<ReferenceCounter> reference_counter_;
std::shared_ptr<ActorManagerInterface> actor_manager_; std::shared_ptr<ActorManagerInterface> actor_manager_;
TaskManager manager_; TaskManager manager_;
int num_retries_ = 0; int num_retries_ = 0;
}; };
TEST_F(TaskManagerTest, TestTaskSuccess) { TEST_F(TaskManagerTest, TestTaskSuccess) {
auto spec = CreateTaskHelper(1); TaskID caller_id = TaskID::Nil();
rpc::Address caller_address;
ObjectID dep1 = ObjectID::FromRandom();
ObjectID dep2 = ObjectID::FromRandom();
auto spec = CreateTaskHelper(1, {dep1, dep2});
ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId())); ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId()));
manager_.AddPendingTask(spec); manager_.AddPendingTask(caller_id, caller_address, spec);
ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId())); ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId()));
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 3);
auto return_id = spec.ReturnId(0, TaskTransportType::DIRECT); auto return_id = spec.ReturnId(0, TaskTransportType::DIRECT);
WorkerContext ctx(WorkerType::WORKER, JobID::FromInt(0)); WorkerContext ctx(WorkerType::WORKER, JobID::FromInt(0));
@ -60,6 +73,8 @@ TEST_F(TaskManagerTest, TestTaskSuccess) {
return_object->set_data(data->Data(), data->Size()); return_object->set_data(data->Data(), data->Size());
manager_.CompletePendingTask(spec.TaskId(), reply, nullptr); manager_.CompletePendingTask(spec.TaskId(), reply, nullptr);
ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId())); ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId()));
// Only the return object reference should remain.
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 1);
std::vector<std::shared_ptr<RayObject>> results; std::vector<std::shared_ptr<RayObject>> results;
RAY_CHECK_OK(store_->Get({return_id}, 1, -1, ctx, false, &results)); RAY_CHECK_OK(store_->Get({return_id}, 1, -1, ctx, false, &results));
@ -69,19 +84,33 @@ TEST_F(TaskManagerTest, TestTaskSuccess) {
return_object->data().size()), return_object->data().size()),
0); 0);
ASSERT_EQ(num_retries_, 0); ASSERT_EQ(num_retries_, 0);
std::vector<ObjectID> removed;
reference_counter_->AddLocalReference(return_id);
reference_counter_->RemoveLocalReference(return_id, &removed);
ASSERT_EQ(removed[0], return_id);
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 0);
} }
TEST_F(TaskManagerTest, TestTaskFailure) { TEST_F(TaskManagerTest, TestTaskFailure) {
auto spec = CreateTaskHelper(1); TaskID caller_id = TaskID::Nil();
rpc::Address caller_address;
ObjectID dep1 = ObjectID::FromRandom();
ObjectID dep2 = ObjectID::FromRandom();
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 0);
auto spec = CreateTaskHelper(1, {dep1, dep2});
ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId())); ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId()));
manager_.AddPendingTask(spec); manager_.AddPendingTask(caller_id, caller_address, spec);
ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId())); ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId()));
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 3);
auto return_id = spec.ReturnId(0, TaskTransportType::DIRECT); auto return_id = spec.ReturnId(0, TaskTransportType::DIRECT);
WorkerContext ctx(WorkerType::WORKER, JobID::FromInt(0)); WorkerContext ctx(WorkerType::WORKER, JobID::FromInt(0));
auto error = rpc::ErrorType::WORKER_DIED; auto error = rpc::ErrorType::WORKER_DIED;
manager_.PendingTaskFailed(spec.TaskId(), error); manager_.PendingTaskFailed(spec.TaskId(), error);
ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId())); ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId()));
// Only the return object reference should remain.
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 1);
std::vector<std::shared_ptr<RayObject>> results; std::vector<std::shared_ptr<RayObject>> results;
RAY_CHECK_OK(store_->Get({return_id}, 1, -1, ctx, false, &results)); RAY_CHECK_OK(store_->Get({return_id}, 1, -1, ctx, false, &results));
@ -90,14 +119,26 @@ TEST_F(TaskManagerTest, TestTaskFailure) {
ASSERT_TRUE(results[0]->IsException(&stored_error)); ASSERT_TRUE(results[0]->IsException(&stored_error));
ASSERT_EQ(stored_error, error); ASSERT_EQ(stored_error, error);
ASSERT_EQ(num_retries_, 0); ASSERT_EQ(num_retries_, 0);
std::vector<ObjectID> removed;
reference_counter_->AddLocalReference(return_id);
reference_counter_->RemoveLocalReference(return_id, &removed);
ASSERT_EQ(removed[0], return_id);
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 0);
} }
TEST_F(TaskManagerTest, TestTaskRetry) { TEST_F(TaskManagerTest, TestTaskRetry) {
auto spec = CreateTaskHelper(1); TaskID caller_id = TaskID::Nil();
rpc::Address caller_address;
ObjectID dep1 = ObjectID::FromRandom();
ObjectID dep2 = ObjectID::FromRandom();
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 0);
auto spec = CreateTaskHelper(1, {dep1, dep2});
ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId())); ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId()));
int num_retries = 3; int num_retries = 3;
manager_.AddPendingTask(spec, num_retries); manager_.AddPendingTask(caller_id, caller_address, spec, num_retries);
ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId())); ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId()));
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 3);
auto return_id = spec.ReturnId(0, TaskTransportType::DIRECT); auto return_id = spec.ReturnId(0, TaskTransportType::DIRECT);
WorkerContext ctx(WorkerType::WORKER, JobID::FromInt(0)); WorkerContext ctx(WorkerType::WORKER, JobID::FromInt(0));
@ -105,6 +146,7 @@ TEST_F(TaskManagerTest, TestTaskRetry) {
for (int i = 0; i < num_retries; i++) { for (int i = 0; i < num_retries; i++) {
manager_.PendingTaskFailed(spec.TaskId(), error); manager_.PendingTaskFailed(spec.TaskId(), error);
ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId())); ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId()));
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 3);
std::vector<std::shared_ptr<RayObject>> results; std::vector<std::shared_ptr<RayObject>> results;
ASSERT_FALSE(store_->Get({return_id}, 1, 0, ctx, false, &results).ok()); ASSERT_FALSE(store_->Get({return_id}, 1, 0, ctx, false, &results).ok());
ASSERT_EQ(num_retries_, i + 1); ASSERT_EQ(num_retries_, i + 1);
@ -112,6 +154,8 @@ TEST_F(TaskManagerTest, TestTaskRetry) {
manager_.PendingTaskFailed(spec.TaskId(), error); manager_.PendingTaskFailed(spec.TaskId(), error);
ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId())); ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId()));
// Only the return object reference should remain.
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 1);
std::vector<std::shared_ptr<RayObject>> results; std::vector<std::shared_ptr<RayObject>> results;
RAY_CHECK_OK(store_->Get({return_id}, 1, -0, ctx, false, &results)); RAY_CHECK_OK(store_->Get({return_id}, 1, -0, ctx, false, &results));
@ -119,6 +163,12 @@ TEST_F(TaskManagerTest, TestTaskRetry) {
rpc::ErrorType stored_error; rpc::ErrorType stored_error;
ASSERT_TRUE(results[0]->IsException(&stored_error)); ASSERT_TRUE(results[0]->IsException(&stored_error));
ASSERT_EQ(stored_error, error); ASSERT_EQ(stored_error, error);
std::vector<ObjectID> removed;
reference_counter_->AddLocalReference(return_id);
reference_counter_->RemoveLocalReference(return_id, &removed);
ASSERT_EQ(removed[0], return_id);
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 0);
} }
} // namespace ray } // namespace ray

View file

@ -18,7 +18,7 @@ struct TaskState {
void InlineDependencies( void InlineDependencies(
absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> dependencies, absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> dependencies,
TaskSpecification &task) { TaskSpecification &task, std::vector<ObjectID> *inlined) {
auto &msg = task.GetMutableMessage(); auto &msg = task.GetMutableMessage();
size_t found = 0; size_t found = 0;
for (size_t i = 0; i < task.NumArgs(); i++) { for (size_t i = 0; i < task.NumArgs(); i++) {
@ -43,6 +43,7 @@ void InlineDependencies(
const auto &metadata = it->second->GetMetadata(); const auto &metadata = it->second->GetMetadata();
mutable_arg->set_metadata(metadata->Data(), metadata->Size()); mutable_arg->set_metadata(metadata->Data(), metadata->Size());
} }
inlined->push_back(id);
} }
found++; found++;
} else { } else {
@ -83,15 +84,19 @@ void LocalDependencyResolver::ResolveDependencies(TaskSpecification &task,
obj_id, [this, state, obj_id, on_complete](std::shared_ptr<RayObject> obj) { obj_id, [this, state, obj_id, on_complete](std::shared_ptr<RayObject> obj) {
RAY_CHECK(obj != nullptr); RAY_CHECK(obj != nullptr);
bool complete = false; bool complete = false;
std::vector<ObjectID> inlined;
{ {
absl::MutexLock lock(&mu_); absl::MutexLock lock(&mu_);
state->local_dependencies[obj_id] = std::move(obj); state->local_dependencies[obj_id] = std::move(obj);
if (--state->dependencies_remaining == 0) { if (--state->dependencies_remaining == 0) {
InlineDependencies(state->local_dependencies, state->task); InlineDependencies(state->local_dependencies, state->task, &inlined);
complete = true; complete = true;
num_pending_ -= 1; num_pending_ -= 1;
} }
} }
if (inlined.size() > 0) {
task_finisher_->OnTaskDependenciesInlined(inlined);
}
if (complete) { if (complete) {
on_complete(); on_complete();
} }

View file

@ -6,14 +6,16 @@
#include "ray/common/id.h" #include "ray/common/id.h"
#include "ray/common/task/task_spec.h" #include "ray/common/task/task_spec.h"
#include "ray/core_worker/store_provider/memory_store/memory_store.h" #include "ray/core_worker/store_provider/memory_store/memory_store.h"
#include "ray/core_worker/task_manager.h"
namespace ray { namespace ray {
// This class is thread-safe. // This class is thread-safe.
class LocalDependencyResolver { class LocalDependencyResolver {
public: public:
LocalDependencyResolver(std::shared_ptr<CoreWorkerMemoryStore> store) LocalDependencyResolver(std::shared_ptr<CoreWorkerMemoryStore> store,
: in_memory_store_(store), num_pending_(0) {} std::shared_ptr<TaskFinisherInterface> task_finisher)
: in_memory_store_(store), task_finisher_(task_finisher), num_pending_(0) {}
/// Resolve all local and remote dependencies for the task, calling the specified /// Resolve all local and remote dependencies for the task, calling the specified
/// callback when done. Direct call ids in the task specification will be resolved /// callback when done. Direct call ids in the task specification will be resolved
@ -33,6 +35,9 @@ class LocalDependencyResolver {
/// The in-memory store. /// The in-memory store.
std::shared_ptr<CoreWorkerMemoryStore> in_memory_store_; std::shared_ptr<CoreWorkerMemoryStore> in_memory_store_;
/// Used to complete tasks.
std::shared_ptr<TaskFinisherInterface> task_finisher_;
/// Number of tasks pending dependency resolution. /// Number of tasks pending dependency resolution.
std::atomic<int> num_pending_; std::atomic<int> num_pending_;

View file

@ -8,9 +8,9 @@
#include <queue> #include <queue>
#include <set> #include <set>
#include <utility> #include <utility>
#include "absl/container/flat_hash_map.h"
#include "absl/base/thread_annotations.h" #include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/synchronization/mutex.h" #include "absl/synchronization/mutex.h"
#include "ray/common/id.h" #include "ray/common/id.h"
#include "ray/common/ray_object.h" #include "ray/common/ray_object.h"
@ -39,7 +39,7 @@ class CoreWorkerDirectActorTaskSubmitter {
std::shared_ptr<CoreWorkerMemoryStore> store, std::shared_ptr<CoreWorkerMemoryStore> store,
std::shared_ptr<TaskFinisherInterface> task_finisher) std::shared_ptr<TaskFinisherInterface> task_finisher)
: client_factory_(client_factory), : client_factory_(client_factory),
resolver_(store), resolver_(store, task_finisher),
task_finisher_(task_finisher) {} task_finisher_(task_finisher) {}
/// Submit a task to an actor for execution. /// Submit a task to an actor for execution.

View file

@ -5,7 +5,6 @@
#include "absl/base/thread_annotations.h" #include "absl/base/thread_annotations.h"
#include "absl/synchronization/mutex.h" #include "absl/synchronization/mutex.h"
#include "ray/common/id.h" #include "ray/common/id.h"
#include "ray/common/ray_object.h" #include "ray/common/ray_object.h"
#include "ray/core_worker/context.h" #include "ray/core_worker/context.h"
@ -41,7 +40,7 @@ class CoreWorkerDirectTaskSubmitter {
: local_lease_client_(lease_client), : local_lease_client_(lease_client),
client_factory_(client_factory), client_factory_(client_factory),
lease_client_factory_(lease_client_factory), lease_client_factory_(lease_client_factory),
resolver_(store), resolver_(store, task_finisher),
task_finisher_(task_finisher), task_finisher_(task_finisher),
local_raylet_id_(local_raylet_id), local_raylet_id_(local_raylet_id),
lease_timeout_ms_(lease_timeout_ms) {} lease_timeout_ms_(lease_timeout_ms) {}