mirror of
https://github.com/vale981/ray
synced 2025-03-08 19:41:38 -05:00
Reference counting for direct call submitted tasks (#6514)
Co-authored-by: Zhijun Fu <37800433+zhijunfu@users.noreply.github.com>
This commit is contained in:
parent
b0b6b56bb7
commit
e50aa99be1
17 changed files with 512 additions and 344 deletions
|
@ -1073,16 +1073,12 @@ cdef class CoreWorker:
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def add_object_id_reference(self, ObjectID object_id):
|
def add_object_id_reference(self, ObjectID object_id):
|
||||||
cdef:
|
|
||||||
CObjectID c_object_id = object_id.native()
|
|
||||||
# Note: faster to not release GIL for short-running op.
|
# Note: faster to not release GIL for short-running op.
|
||||||
self.core_worker.get().AddObjectIDReference(c_object_id)
|
self.core_worker.get().AddLocalReference(object_id.native())
|
||||||
|
|
||||||
def remove_object_id_reference(self, ObjectID object_id):
|
def remove_object_id_reference(self, ObjectID object_id):
|
||||||
cdef:
|
|
||||||
CObjectID c_object_id = object_id.native()
|
|
||||||
# Note: faster to not release GIL for short-running op.
|
# Note: faster to not release GIL for short-running op.
|
||||||
self.core_worker.get().RemoveObjectIDReference(c_object_id)
|
self.core_worker.get().RemoveLocalReference(object_id.native())
|
||||||
|
|
||||||
def serialize_and_promote_object_id(self, ObjectID object_id):
|
def serialize_and_promote_object_id(self, ObjectID object_id):
|
||||||
cdef:
|
cdef:
|
||||||
|
@ -1174,6 +1170,24 @@ cdef class CoreWorker:
|
||||||
def current_actor_is_asyncio(self):
|
def current_actor_is_asyncio(self):
|
||||||
return self.core_worker.get().GetWorkerContext().CurrentActorIsAsync()
|
return self.core_worker.get().GetWorkerContext().CurrentActorIsAsync()
|
||||||
|
|
||||||
|
def get_all_reference_counts(self):
|
||||||
|
cdef:
|
||||||
|
unordered_map[CObjectID, pair[size_t, size_t]] c_ref_counts
|
||||||
|
unordered_map[CObjectID, pair[size_t, size_t]].iterator it
|
||||||
|
|
||||||
|
c_ref_counts = self.core_worker.get().GetAllReferenceCounts()
|
||||||
|
it = c_ref_counts.begin()
|
||||||
|
|
||||||
|
ref_counts = {}
|
||||||
|
while it != c_ref_counts.end():
|
||||||
|
object_id = ObjectID(dereference(it).first.Binary())
|
||||||
|
ref_counts[object_id] = {
|
||||||
|
"local": dereference(it).second.first,
|
||||||
|
"submitted": dereference(it).second.second}
|
||||||
|
postincrement(it)
|
||||||
|
|
||||||
|
return ref_counts
|
||||||
|
|
||||||
def in_memory_store_get_async(self, ObjectID object_id, future):
|
def in_memory_store_get_async(self, ObjectID object_id, future):
|
||||||
self.core_worker.get().GetAsync(
|
self.core_worker.get().GetAsync(
|
||||||
object_id.native(),
|
object_id.native(),
|
||||||
|
|
|
@ -118,8 +118,8 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
|
||||||
CActorID DeserializeAndRegisterActorHandle(const c_string &bytes)
|
CActorID DeserializeAndRegisterActorHandle(const c_string &bytes)
|
||||||
CRayStatus SerializeActorHandle(const CActorID &actor_id, c_string
|
CRayStatus SerializeActorHandle(const CActorID &actor_id, c_string
|
||||||
*bytes)
|
*bytes)
|
||||||
void AddObjectIDReference(const CObjectID &object_id)
|
void AddLocalReference(const CObjectID &object_id)
|
||||||
void RemoveObjectIDReference(const CObjectID &object_id)
|
void RemoveLocalReference(const CObjectID &object_id)
|
||||||
void PromoteObjectToPlasma(const CObjectID &object_id)
|
void PromoteObjectToPlasma(const CObjectID &object_id)
|
||||||
void PromoteToPlasmaAndGetOwnershipInfo(const CObjectID &object_id,
|
void PromoteToPlasmaAndGetOwnershipInfo(const CObjectID &object_id,
|
||||||
CTaskID *owner_id,
|
CTaskID *owner_id,
|
||||||
|
@ -149,6 +149,9 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
|
||||||
|
|
||||||
CWorkerContext &GetWorkerContext()
|
CWorkerContext &GetWorkerContext()
|
||||||
void YieldCurrentFiber(CFiberEvent &coroutine_done)
|
void YieldCurrentFiber(CFiberEvent &coroutine_done)
|
||||||
|
|
||||||
|
unordered_map[CObjectID, pair[size_t, size_t]] GetAllReferenceCounts()
|
||||||
|
|
||||||
void GetAsync(const CObjectID &object_id,
|
void GetAsync(const CObjectID &object_id,
|
||||||
ray_callback_function successs_callback,
|
ray_callback_function successs_callback,
|
||||||
ray_callback_function fallback_callback,
|
ray_callback_function fallback_callback,
|
||||||
|
|
162
python/ray/tests/test_reference_counting.py
Normal file
162
python/ray/tests/test_reference_counting.py
Normal file
|
@ -0,0 +1,162 @@
|
||||||
|
# coding: utf-8
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
|
import copy
|
||||||
|
import tempfile
|
||||||
|
import numpy as np
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import ray
|
||||||
|
import ray.cluster_utils
|
||||||
|
import ray.test_utils
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _check_refcounts(expected):
|
||||||
|
actual = ray.worker.global_worker.core_worker.get_all_reference_counts()
|
||||||
|
assert len(expected) == len(actual)
|
||||||
|
for object_id, (local, submitted) in expected.items():
|
||||||
|
assert object_id in actual
|
||||||
|
assert local == actual[object_id]["local"]
|
||||||
|
assert submitted == actual[object_id]["submitted"]
|
||||||
|
|
||||||
|
|
||||||
|
def check_refcounts(expected, timeout=1):
|
||||||
|
start = time.time()
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
_check_refcounts(expected)
|
||||||
|
break
|
||||||
|
except AssertionError as e:
|
||||||
|
if time.time() - start > timeout:
|
||||||
|
raise e
|
||||||
|
else:
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_local_refcounts(ray_start_regular):
|
||||||
|
oid1 = ray.put(None)
|
||||||
|
check_refcounts({oid1: (1, 0)})
|
||||||
|
oid1_copy = copy.copy(oid1)
|
||||||
|
check_refcounts({oid1: (2, 0)})
|
||||||
|
del oid1
|
||||||
|
check_refcounts({oid1_copy: (1, 0)})
|
||||||
|
del oid1_copy
|
||||||
|
check_refcounts({})
|
||||||
|
|
||||||
|
|
||||||
|
def test_dependency_refcounts(ray_start_regular):
|
||||||
|
# Return a large object that will be spilled to plasma.
|
||||||
|
def large_object():
|
||||||
|
return np.zeros(10 * 1024 * 1024, dtype=np.uint8)
|
||||||
|
|
||||||
|
# TODO: Clean up tmpfiles?
|
||||||
|
def random_path():
|
||||||
|
return os.path.join(tempfile.gettempdir(), uuid.uuid4().hex)
|
||||||
|
|
||||||
|
def touch(path):
|
||||||
|
with open(path, "w"):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def wait_for_file(path):
|
||||||
|
while True:
|
||||||
|
if os.path.exists(path):
|
||||||
|
break
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
@ray.remote
|
||||||
|
def one_dep(dep, path=None, fail=False):
|
||||||
|
if path is not None:
|
||||||
|
wait_for_file(path)
|
||||||
|
if fail:
|
||||||
|
raise Exception("failed on purpose")
|
||||||
|
|
||||||
|
@ray.remote
|
||||||
|
def one_dep_large(dep, path=None):
|
||||||
|
if path is not None:
|
||||||
|
wait_for_file(path)
|
||||||
|
# This should be spilled to plasma.
|
||||||
|
return large_object()
|
||||||
|
|
||||||
|
# Test that regular plasma dependency refcounts are decremented once the
|
||||||
|
# task finishes.
|
||||||
|
f = random_path()
|
||||||
|
large_dep = ray.put(large_object())
|
||||||
|
result = one_dep.remote(large_dep, path=f)
|
||||||
|
check_refcounts({large_dep: (1, 1), result: (1, 0)})
|
||||||
|
touch(f)
|
||||||
|
# Reference count should be removed once the task finishes.
|
||||||
|
check_refcounts({large_dep: (1, 0), result: (1, 0)})
|
||||||
|
del large_dep, result
|
||||||
|
check_refcounts({})
|
||||||
|
|
||||||
|
# Test that inlined dependency refcounts are decremented once they are
|
||||||
|
# inlined.
|
||||||
|
f = random_path()
|
||||||
|
dep = one_dep.remote(None, path=f)
|
||||||
|
check_refcounts({dep: (1, 0)})
|
||||||
|
result = one_dep.remote(dep)
|
||||||
|
check_refcounts({dep: (1, 1), result: (1, 0)})
|
||||||
|
touch(f)
|
||||||
|
# Reference count should be removed as soon as the dependency is inlined.
|
||||||
|
check_refcounts({dep: (1, 0), result: (1, 0)}, timeout=1)
|
||||||
|
del dep, result
|
||||||
|
check_refcounts({})
|
||||||
|
|
||||||
|
# Test that spilled plasma dependency refcounts are decremented once
|
||||||
|
# the task finishes.
|
||||||
|
f1, f2 = random_path(), random_path()
|
||||||
|
dep = one_dep_large.remote(None, path=f1)
|
||||||
|
check_refcounts({dep: (1, 0)})
|
||||||
|
result = one_dep.remote(dep, path=f2)
|
||||||
|
check_refcounts({dep: (1, 1), result: (1, 0)})
|
||||||
|
touch(f1)
|
||||||
|
ray.get(dep, timeout=5.0)
|
||||||
|
# Reference count should remain because the dependency is in plasma.
|
||||||
|
check_refcounts({dep: (1, 1), result: (1, 0)})
|
||||||
|
touch(f2)
|
||||||
|
# Reference count should be removed because the task finished.
|
||||||
|
check_refcounts({dep: (1, 0), result: (1, 0)})
|
||||||
|
del dep, result
|
||||||
|
check_refcounts({})
|
||||||
|
|
||||||
|
# Test that regular plasma dependency refcounts are decremented if a task
|
||||||
|
# fails.
|
||||||
|
f = random_path()
|
||||||
|
large_dep = ray.put(large_object())
|
||||||
|
result = one_dep.remote(large_dep, path=f, fail=True)
|
||||||
|
check_refcounts({large_dep: (1, 1), result: (1, 0)})
|
||||||
|
touch(f)
|
||||||
|
# Reference count should be removed once the task finishes.
|
||||||
|
check_refcounts({large_dep: (1, 0), result: (1, 0)})
|
||||||
|
del large_dep, result
|
||||||
|
check_refcounts({})
|
||||||
|
|
||||||
|
# Test that spilled plasma dependency refcounts are decremented if a task
|
||||||
|
# fails.
|
||||||
|
f1, f2 = random_path(), random_path()
|
||||||
|
dep = one_dep_large.remote(None, path=f1)
|
||||||
|
check_refcounts({dep: (1, 0)})
|
||||||
|
result = one_dep.remote(dep, path=f2, fail=True)
|
||||||
|
check_refcounts({dep: (1, 1), result: (1, 0)})
|
||||||
|
touch(f1)
|
||||||
|
ray.get(dep, timeout=5.0)
|
||||||
|
# Reference count should remain because the dependency is in plasma.
|
||||||
|
check_refcounts({dep: (1, 1), result: (1, 0)})
|
||||||
|
touch(f2)
|
||||||
|
# Reference count should be removed because the task finished.
|
||||||
|
check_refcounts({dep: (1, 0), result: (1, 0)})
|
||||||
|
del dep, result
|
||||||
|
check_refcounts({})
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import pytest
|
||||||
|
import sys
|
||||||
|
sys.exit(pytest.main(["-v", __file__]))
|
|
@ -192,14 +192,14 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
|
||||||
ref_counting_enabled ? reference_counter_ : nullptr, local_raylet_client_));
|
ref_counting_enabled ? reference_counter_ : nullptr, local_raylet_client_));
|
||||||
|
|
||||||
task_manager_.reset(new TaskManager(
|
task_manager_.reset(new TaskManager(
|
||||||
memory_store_, actor_manager_, [this](const TaskSpecification &spec) {
|
memory_store_, reference_counter_, actor_manager_,
|
||||||
|
[this](const TaskSpecification &spec) {
|
||||||
// Retry after a delay to emulate the existing Raylet reconstruction
|
// Retry after a delay to emulate the existing Raylet reconstruction
|
||||||
// behaviour. TODO(ekl) backoff exponentially.
|
// behaviour. TODO(ekl) backoff exponentially.
|
||||||
RAY_LOG(ERROR) << "Will resubmit task after a 5 second delay: "
|
RAY_LOG(ERROR) << "Will resubmit task after a 5 second delay: "
|
||||||
<< spec.DebugString();
|
<< spec.DebugString();
|
||||||
to_resubmit_.push_back(std::make_pair(current_time_ms() + 5000, spec));
|
to_resubmit_.push_back(std::make_pair(current_time_ms() + 5000, spec));
|
||||||
}));
|
}));
|
||||||
resolver_.reset(new LocalDependencyResolver(memory_store_));
|
|
||||||
|
|
||||||
// Create an entry for the driver task in the task table. This task is
|
// Create an entry for the driver task in the task table. This task is
|
||||||
// added immediately with status RUNNING. This allows us to push errors
|
// added immediately with status RUNNING. This allows us to push errors
|
||||||
|
@ -377,8 +377,7 @@ Status CoreWorker::Put(const RayObject &object, ObjectID *object_id) {
|
||||||
*object_id = ObjectID::ForPut(worker_context_.GetCurrentTaskID(),
|
*object_id = ObjectID::ForPut(worker_context_.GetCurrentTaskID(),
|
||||||
worker_context_.GetNextPutIndex(),
|
worker_context_.GetNextPutIndex(),
|
||||||
static_cast<uint8_t>(TaskTransportType::RAYLET));
|
static_cast<uint8_t>(TaskTransportType::RAYLET));
|
||||||
reference_counter_->AddOwnedObject(*object_id, GetCallerId(), rpc_address_,
|
reference_counter_->AddOwnedObject(*object_id, GetCallerId(), rpc_address_);
|
||||||
std::make_shared<std::vector<ObjectID>>());
|
|
||||||
return Put(object, *object_id);
|
return Put(object, *object_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -460,10 +459,6 @@ Status CoreWorker::Get(const std::vector<ObjectID> &ids, const int64_t timeout_m
|
||||||
// object.
|
// object.
|
||||||
will_throw_exception = true;
|
will_throw_exception = true;
|
||||||
}
|
}
|
||||||
// If we got the result for this ObjectID, the task that created it must
|
|
||||||
// have finished. Therefore, we can safely remove its reference counting
|
|
||||||
// dependencies.
|
|
||||||
RemoveObjectIDDependencies(ids[i]);
|
|
||||||
} else {
|
} else {
|
||||||
missing_result = true;
|
missing_result = true;
|
||||||
}
|
}
|
||||||
|
@ -597,30 +592,6 @@ TaskID CoreWorker::GetCallerId() const {
|
||||||
return caller_id;
|
return caller_id;
|
||||||
}
|
}
|
||||||
|
|
||||||
void CoreWorker::PinObjectReferences(const TaskSpecification &task_spec,
|
|
||||||
const TaskTransportType transport_type) {
|
|
||||||
size_t num_returns = task_spec.NumReturns();
|
|
||||||
if (task_spec.IsActorCreationTask() || task_spec.IsActorTask()) {
|
|
||||||
num_returns--;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::shared_ptr<std::vector<ObjectID>> task_deps =
|
|
||||||
std::make_shared<std::vector<ObjectID>>();
|
|
||||||
for (size_t i = 0; i < task_spec.NumArgs(); i++) {
|
|
||||||
if (task_spec.ArgByRef(i)) {
|
|
||||||
for (size_t j = 0; j < task_spec.ArgIdCount(i); j++) {
|
|
||||||
task_deps->push_back(task_spec.ArgId(i, j));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Note that we call this even if task_deps.size() == 0, in order to pin the return id.
|
|
||||||
for (size_t i = 0; i < num_returns; i++) {
|
|
||||||
reference_counter_->AddOwnedObject(task_spec.ReturnId(i, transport_type),
|
|
||||||
GetCallerId(), rpc_address_, task_deps);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Status CoreWorker::SubmitTask(const RayFunction &function,
|
Status CoreWorker::SubmitTask(const RayFunction &function,
|
||||||
const std::vector<TaskArg> &args,
|
const std::vector<TaskArg> &args,
|
||||||
const TaskOptions &task_options,
|
const TaskOptions &task_options,
|
||||||
|
@ -642,11 +613,9 @@ Status CoreWorker::SubmitTask(const RayFunction &function,
|
||||||
return_ids);
|
return_ids);
|
||||||
TaskSpecification task_spec = builder.Build();
|
TaskSpecification task_spec = builder.Build();
|
||||||
if (task_options.is_direct_call) {
|
if (task_options.is_direct_call) {
|
||||||
task_manager_->AddPendingTask(task_spec, max_retries);
|
task_manager_->AddPendingTask(GetCallerId(), rpc_address_, task_spec, max_retries);
|
||||||
PinObjectReferences(task_spec, TaskTransportType::DIRECT);
|
|
||||||
return direct_task_submitter_->SubmitTask(task_spec);
|
return direct_task_submitter_->SubmitTask(task_spec);
|
||||||
} else {
|
} else {
|
||||||
PinObjectReferences(task_spec, TaskTransportType::RAYLET);
|
|
||||||
return local_raylet_client_->SubmitTask(task_spec);
|
return local_raylet_client_->SubmitTask(task_spec);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -685,11 +654,10 @@ Status CoreWorker::CreateActor(const RayFunction &function,
|
||||||
*return_actor_id = actor_id;
|
*return_actor_id = actor_id;
|
||||||
TaskSpecification task_spec = builder.Build();
|
TaskSpecification task_spec = builder.Build();
|
||||||
if (actor_creation_options.is_direct_call) {
|
if (actor_creation_options.is_direct_call) {
|
||||||
task_manager_->AddPendingTask(task_spec, actor_creation_options.max_reconstructions);
|
task_manager_->AddPendingTask(GetCallerId(), rpc_address_, task_spec,
|
||||||
PinObjectReferences(task_spec, TaskTransportType::DIRECT);
|
actor_creation_options.max_reconstructions);
|
||||||
return direct_task_submitter_->SubmitTask(task_spec);
|
return direct_task_submitter_->SubmitTask(task_spec);
|
||||||
} else {
|
} else {
|
||||||
PinObjectReferences(task_spec, TaskTransportType::RAYLET);
|
|
||||||
return local_raylet_client_->SubmitTask(task_spec);
|
return local_raylet_client_->SubmitTask(task_spec);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -729,8 +697,7 @@ Status CoreWorker::SubmitActorTask(const ActorID &actor_id, const RayFunction &f
|
||||||
Status status;
|
Status status;
|
||||||
TaskSpecification task_spec = builder.Build();
|
TaskSpecification task_spec = builder.Build();
|
||||||
if (is_direct_call) {
|
if (is_direct_call) {
|
||||||
task_manager_->AddPendingTask(task_spec);
|
task_manager_->AddPendingTask(GetCallerId(), rpc_address_, task_spec);
|
||||||
PinObjectReferences(task_spec, TaskTransportType::DIRECT);
|
|
||||||
if (actor_handle->IsDead()) {
|
if (actor_handle->IsDead()) {
|
||||||
auto status = Status::IOError("sent task to dead actor");
|
auto status = Status::IOError("sent task to dead actor");
|
||||||
task_manager_->PendingTaskFailed(task_spec.TaskId(), rpc::ErrorType::ACTOR_DIED,
|
task_manager_->PendingTaskFailed(task_spec.TaskId(), rpc::ErrorType::ACTOR_DIED,
|
||||||
|
@ -739,7 +706,6 @@ Status CoreWorker::SubmitActorTask(const ActorID &actor_id, const RayFunction &f
|
||||||
status = direct_actor_submitter_->SubmitTask(task_spec);
|
status = direct_actor_submitter_->SubmitTask(task_spec);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
PinObjectReferences(task_spec, TaskTransportType::RAYLET);
|
|
||||||
RAY_CHECK_OK(local_raylet_client_->SubmitTask(task_spec));
|
RAY_CHECK_OK(local_raylet_client_->SubmitTask(task_spec));
|
||||||
}
|
}
|
||||||
return status;
|
return status;
|
||||||
|
@ -1043,17 +1009,17 @@ void CoreWorker::HandleGetObjectStatus(const rpc::GetObjectStatusRequest &reques
|
||||||
if (task_manager_->IsTaskPending(object_id.TaskId())) {
|
if (task_manager_->IsTaskPending(object_id.TaskId())) {
|
||||||
// Acquire a reference and retry. This prevents the object from being
|
// Acquire a reference and retry. This prevents the object from being
|
||||||
// evicted out from under us before we can start the get.
|
// evicted out from under us before we can start the get.
|
||||||
AddObjectIDReference(object_id);
|
AddLocalReference(object_id);
|
||||||
if (task_manager_->IsTaskPending(object_id.TaskId())) {
|
if (task_manager_->IsTaskPending(object_id.TaskId())) {
|
||||||
// The task is pending. Send the reply once the task finishes.
|
// The task is pending. Send the reply once the task finishes.
|
||||||
memory_store_->GetAsync(object_id,
|
memory_store_->GetAsync(object_id,
|
||||||
[send_reply_callback](std::shared_ptr<RayObject> obj) {
|
[send_reply_callback](std::shared_ptr<RayObject> obj) {
|
||||||
send_reply_callback(Status::OK(), nullptr, nullptr);
|
send_reply_callback(Status::OK(), nullptr, nullptr);
|
||||||
});
|
});
|
||||||
RemoveObjectIDReference(object_id);
|
RemoveLocalReference(object_id);
|
||||||
} else {
|
} else {
|
||||||
// We lost the race, the task is done.
|
// We lost the race, the task is done.
|
||||||
RemoveObjectIDReference(object_id);
|
RemoveLocalReference(object_id);
|
||||||
send_reply_callback(Status::OK(), nullptr, nullptr);
|
send_reply_callback(Status::OK(), nullptr, nullptr);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -101,17 +101,19 @@ class CoreWorker {
|
||||||
actor_id_ = actor_id;
|
actor_id_ = actor_id;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Increase the reference count for this object ID.
|
/// Increase the local reference count for this object ID. Should be called
|
||||||
|
/// by the language frontend when a new reference is created.
|
||||||
///
|
///
|
||||||
/// \param[in] object_id The object ID to increase the reference count for.
|
/// \param[in] object_id The object ID to increase the reference count for.
|
||||||
void AddObjectIDReference(const ObjectID &object_id) {
|
void AddLocalReference(const ObjectID &object_id) {
|
||||||
reference_counter_->AddLocalReference(object_id);
|
reference_counter_->AddLocalReference(object_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Decrease the reference count for this object ID.
|
/// Decrease the reference count for this object ID. Should be called
|
||||||
|
/// by the language frontend when a reference is destroyed.
|
||||||
///
|
///
|
||||||
/// \param[in] object_id The object ID to decrease the reference count for.
|
/// \param[in] object_id The object ID to decrease the reference count for.
|
||||||
void RemoveObjectIDReference(const ObjectID &object_id) {
|
void RemoveLocalReference(const ObjectID &object_id) {
|
||||||
std::vector<ObjectID> deleted;
|
std::vector<ObjectID> deleted;
|
||||||
reference_counter_->RemoveLocalReference(object_id, &deleted);
|
reference_counter_->RemoveLocalReference(object_id, &deleted);
|
||||||
if (ref_counting_enabled_) {
|
if (ref_counting_enabled_) {
|
||||||
|
@ -119,6 +121,12 @@ class CoreWorker {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns a map of all ObjectIDs currently in scope with a pair of their
|
||||||
|
/// (local, submitted_task) reference counts. For debugging purposes.
|
||||||
|
std::unordered_map<ObjectID, std::pair<size_t, size_t>> GetAllReferenceCounts() const {
|
||||||
|
return reference_counter_->GetAllReferenceCounts();
|
||||||
|
}
|
||||||
|
|
||||||
/// Promote an object to plasma and get its owner information. This should be
|
/// Promote an object to plasma and get its owner information. This should be
|
||||||
/// called when serializing an object ID, and the returned information should
|
/// called when serializing an object ID, and the returned information should
|
||||||
/// be stored with the serialized object ID. For plasma promotion, if the
|
/// be stored with the serialized object ID. For plasma promotion, if the
|
||||||
|
@ -434,11 +442,6 @@ class CoreWorker {
|
||||||
/// Private methods related to task submission.
|
/// Private methods related to task submission.
|
||||||
///
|
///
|
||||||
|
|
||||||
/// Add task dependencies to the reference counter. This prevents the argument
|
|
||||||
/// objects from early eviction, and also adds the return object.
|
|
||||||
void PinObjectReferences(const TaskSpecification &task_spec,
|
|
||||||
const TaskTransportType transport_type);
|
|
||||||
|
|
||||||
/// Give this worker a handle to an actor.
|
/// Give this worker a handle to an actor.
|
||||||
///
|
///
|
||||||
/// This handle will remain as long as the current actor or task is
|
/// This handle will remain as long as the current actor or task is
|
||||||
|
@ -485,17 +488,6 @@ class CoreWorker {
|
||||||
std::vector<std::shared_ptr<RayObject>> *args,
|
std::vector<std::shared_ptr<RayObject>> *args,
|
||||||
std::vector<ObjectID> *arg_reference_ids);
|
std::vector<ObjectID> *arg_reference_ids);
|
||||||
|
|
||||||
/// Remove reference counting dependencies of this object ID.
|
|
||||||
///
|
|
||||||
/// \param[in] object_id The object whose dependencies should be removed.
|
|
||||||
void RemoveObjectIDDependencies(const ObjectID &object_id) {
|
|
||||||
std::vector<ObjectID> deleted;
|
|
||||||
reference_counter_->RemoveDependencies(object_id, &deleted);
|
|
||||||
if (ref_counting_enabled_) {
|
|
||||||
memory_store_->Delete(deleted);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns whether the message was sent to the wrong worker. The right error reply
|
/// Returns whether the message was sent to the wrong worker. The right error reply
|
||||||
/// is sent automatically. Messages end up on the wrong worker when a worker dies
|
/// is sent automatically. Messages end up on the wrong worker when a worker dies
|
||||||
/// and a new one takes its place with the same place. In this situation, we want
|
/// and a new one takes its place with the same place. In this situation, we want
|
||||||
|
@ -624,9 +616,6 @@ class CoreWorker {
|
||||||
absl::flat_hash_map<ActorID, std::unique_ptr<ActorHandle>> actor_handles_
|
absl::flat_hash_map<ActorID, std::unique_ptr<ActorHandle>> actor_handles_
|
||||||
GUARDED_BY(actor_handles_mutex_);
|
GUARDED_BY(actor_handles_mutex_);
|
||||||
|
|
||||||
/// Resolve local and remote dependencies for actor creation.
|
|
||||||
std::unique_ptr<LocalDependencyResolver> resolver_;
|
|
||||||
|
|
||||||
///
|
///
|
||||||
/// Fields related to task execution.
|
/// Fields related to task execution.
|
||||||
///
|
///
|
||||||
|
|
|
@ -14,81 +14,76 @@ void ReferenceCounter::AddBorrowedObject(const ObjectID &object_id,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void ReferenceCounter::AddOwnedObject(
|
void ReferenceCounter::AddOwnedObject(const ObjectID &object_id, const TaskID &owner_id,
|
||||||
const ObjectID &object_id, const TaskID &owner_id, const rpc::Address &owner_address,
|
const rpc::Address &owner_address) {
|
||||||
std::shared_ptr<std::vector<ObjectID>> dependencies) {
|
|
||||||
absl::MutexLock lock(&mutex_);
|
absl::MutexLock lock(&mutex_);
|
||||||
|
|
||||||
for (const ObjectID &dependency_id : *dependencies) {
|
|
||||||
AddLocalReferenceInternal(dependency_id);
|
|
||||||
}
|
|
||||||
|
|
||||||
RAY_CHECK(object_id_refs_.count(object_id) == 0)
|
RAY_CHECK(object_id_refs_.count(object_id) == 0)
|
||||||
<< "Cannot create an object that already exists. ObjectID: " << object_id;
|
<< "Tried to create an owned object that already exists: " << object_id;
|
||||||
// If the entry doesn't exist, we initialize the direct reference count to zero
|
// If the entry doesn't exist, we initialize the direct reference count to zero
|
||||||
// because this corresponds to a submitted task whose return ObjectID will be created
|
// because this corresponds to a submitted task whose return ObjectID will be created
|
||||||
// in the frontend language, incrementing the reference count.
|
// in the frontend language, incrementing the reference count.
|
||||||
object_id_refs_.emplace(object_id, Reference(owner_id, owner_address, dependencies));
|
object_id_refs_.emplace(object_id, Reference(owner_id, owner_address));
|
||||||
}
|
}
|
||||||
|
|
||||||
void ReferenceCounter::AddLocalReferenceInternal(const ObjectID &object_id) {
|
void ReferenceCounter::AddLocalReference(const ObjectID &object_id) {
|
||||||
|
absl::MutexLock lock(&mutex_);
|
||||||
auto entry = object_id_refs_.find(object_id);
|
auto entry = object_id_refs_.find(object_id);
|
||||||
if (entry == object_id_refs_.end()) {
|
if (entry == object_id_refs_.end()) {
|
||||||
// TODO: Once ref counting is implemented, we should always know how the
|
// TODO: Once ref counting is implemented, we should always know how the
|
||||||
// ObjectID was created, so there should always ben an entry.
|
// ObjectID was created, so there should always be an entry.
|
||||||
entry = object_id_refs_.emplace(object_id, Reference()).first;
|
entry = object_id_refs_.emplace(object_id, Reference()).first;
|
||||||
}
|
}
|
||||||
entry->second.local_ref_count++;
|
entry->second.local_ref_count++;
|
||||||
}
|
}
|
||||||
|
|
||||||
void ReferenceCounter::AddLocalReference(const ObjectID &object_id) {
|
|
||||||
absl::MutexLock lock(&mutex_);
|
|
||||||
AddLocalReferenceInternal(object_id);
|
|
||||||
}
|
|
||||||
|
|
||||||
void ReferenceCounter::RemoveLocalReference(const ObjectID &object_id,
|
void ReferenceCounter::RemoveLocalReference(const ObjectID &object_id,
|
||||||
std::vector<ObjectID> *deleted) {
|
std::vector<ObjectID> *deleted) {
|
||||||
absl::MutexLock lock(&mutex_);
|
absl::MutexLock lock(&mutex_);
|
||||||
RemoveReferenceRecursive(object_id, deleted);
|
|
||||||
}
|
|
||||||
|
|
||||||
void ReferenceCounter::RemoveDependencies(const ObjectID &object_id,
|
|
||||||
std::vector<ObjectID> *deleted) {
|
|
||||||
absl::MutexLock lock(&mutex_);
|
|
||||||
auto entry = object_id_refs_.find(object_id);
|
|
||||||
if (entry == object_id_refs_.end()) {
|
|
||||||
RAY_LOG(WARNING) << "Tried to remove dependencies for nonexistent object ID: "
|
|
||||||
<< object_id;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (entry->second.dependencies) {
|
|
||||||
for (const ObjectID &pending_task_object_id : *entry->second.dependencies) {
|
|
||||||
RemoveReferenceRecursive(pending_task_object_id, deleted);
|
|
||||||
}
|
|
||||||
entry->second.dependencies = nullptr;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void ReferenceCounter::RemoveReferenceRecursive(const ObjectID &object_id,
|
|
||||||
std::vector<ObjectID> *deleted) {
|
|
||||||
auto entry = object_id_refs_.find(object_id);
|
auto entry = object_id_refs_.find(object_id);
|
||||||
if (entry == object_id_refs_.end()) {
|
if (entry == object_id_refs_.end()) {
|
||||||
RAY_LOG(WARNING) << "Tried to decrease ref count for nonexistent object ID: "
|
RAY_LOG(WARNING) << "Tried to decrease ref count for nonexistent object ID: "
|
||||||
<< object_id;
|
<< object_id;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (--entry->second.local_ref_count == 0) {
|
if (--entry->second.local_ref_count == 0 &&
|
||||||
// If the reference count reached 0, decrease the reference count for each dependency.
|
entry->second.submitted_task_ref_count == 0) {
|
||||||
if (entry->second.dependencies) {
|
|
||||||
for (const ObjectID &pending_task_object_id : *entry->second.dependencies) {
|
|
||||||
RemoveReferenceRecursive(pending_task_object_id, deleted);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
object_id_refs_.erase(entry);
|
object_id_refs_.erase(entry);
|
||||||
deleted->push_back(object_id);
|
deleted->push_back(object_id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ReferenceCounter::AddSubmittedTaskReferences(
|
||||||
|
const std::vector<ObjectID> &object_ids) {
|
||||||
|
absl::MutexLock lock(&mutex_);
|
||||||
|
for (const ObjectID &object_id : object_ids) {
|
||||||
|
auto entry = object_id_refs_.find(object_id);
|
||||||
|
if (entry == object_id_refs_.end()) {
|
||||||
|
// TODO: Once ref counting is implemented, we should always know how the
|
||||||
|
// ObjectID was created, so there should always be an entry.
|
||||||
|
entry = object_id_refs_.emplace(object_id, Reference()).first;
|
||||||
|
}
|
||||||
|
entry->second.submitted_task_ref_count++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ReferenceCounter::RemoveSubmittedTaskReferences(
|
||||||
|
const std::vector<ObjectID> &object_ids, std::vector<ObjectID> *deleted) {
|
||||||
|
absl::MutexLock lock(&mutex_);
|
||||||
|
for (const ObjectID &object_id : object_ids) {
|
||||||
|
auto entry = object_id_refs_.find(object_id);
|
||||||
|
if (entry == object_id_refs_.end()) {
|
||||||
|
RAY_LOG(WARNING) << "Tried to decrease ref count for nonexistent object ID: "
|
||||||
|
<< object_id;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (--entry->second.submitted_task_ref_count == 0 &&
|
||||||
|
entry->second.local_ref_count == 0) {
|
||||||
|
object_id_refs_.erase(entry);
|
||||||
|
deleted->push_back(object_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
bool ReferenceCounter::GetOwner(const ObjectID &object_id, TaskID *owner_id,
|
bool ReferenceCounter::GetOwner(const ObjectID &object_id, TaskID *owner_id,
|
||||||
rpc::Address *owner_address) const {
|
rpc::Address *owner_address) const {
|
||||||
absl::MutexLock lock(&mutex_);
|
absl::MutexLock lock(&mutex_);
|
||||||
|
@ -126,6 +121,19 @@ std::unordered_set<ObjectID> ReferenceCounter::GetAllInScopeObjectIDs() const {
|
||||||
return in_scope_object_ids;
|
return in_scope_object_ids;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::unordered_map<ObjectID, std::pair<size_t, size_t>>
|
||||||
|
ReferenceCounter::GetAllReferenceCounts() const {
|
||||||
|
absl::MutexLock lock(&mutex_);
|
||||||
|
std::unordered_map<ObjectID, std::pair<size_t, size_t>> all_ref_counts;
|
||||||
|
all_ref_counts.reserve(object_id_refs_.size());
|
||||||
|
for (auto it : object_id_refs_) {
|
||||||
|
all_ref_counts.emplace(it.first,
|
||||||
|
std::pair<size_t, size_t>(it.second.local_ref_count,
|
||||||
|
it.second.submitted_task_ref_count));
|
||||||
|
}
|
||||||
|
return all_ref_counts;
|
||||||
|
}
|
||||||
|
|
||||||
void ReferenceCounter::LogDebugString() const {
|
void ReferenceCounter::LogDebugString() const {
|
||||||
absl::MutexLock lock(&mutex_);
|
absl::MutexLock lock(&mutex_);
|
||||||
|
|
||||||
|
@ -137,15 +145,9 @@ void ReferenceCounter::LogDebugString() const {
|
||||||
|
|
||||||
for (const auto &entry : object_id_refs_) {
|
for (const auto &entry : object_id_refs_) {
|
||||||
RAY_LOG(DEBUG) << "\t" << entry.first.Hex();
|
RAY_LOG(DEBUG) << "\t" << entry.first.Hex();
|
||||||
RAY_LOG(DEBUG) << "\t\treference count: " << entry.second.local_ref_count;
|
RAY_LOG(DEBUG) << "\t\tlocal refcount: " << entry.second.local_ref_count;
|
||||||
RAY_LOG(DEBUG) << "\t\tdependencies: ";
|
RAY_LOG(DEBUG) << "\t\tsubmitted task refcount: "
|
||||||
if (!entry.second.dependencies) {
|
<< entry.second.submitted_task_ref_count;
|
||||||
RAY_LOG(DEBUG) << "\t\t\tNULL";
|
|
||||||
} else {
|
|
||||||
for (const ObjectID &pending_task_object_id : *entry.second.dependencies) {
|
|
||||||
RAY_LOG(DEBUG) << "\t\t\t" << pending_task_object_id.Hex();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -25,27 +25,32 @@ class ReferenceCounter {
|
||||||
/// \param[in] object_id The object to to increment the count for.
|
/// \param[in] object_id The object to to increment the count for.
|
||||||
void AddLocalReference(const ObjectID &object_id) LOCKS_EXCLUDED(mutex_);
|
void AddLocalReference(const ObjectID &object_id) LOCKS_EXCLUDED(mutex_);
|
||||||
|
|
||||||
/// Decrease the reference count for the ObjectID by one. If the reference count reaches
|
/// Decrease the local reference count for the ObjectID by one.
|
||||||
/// zero, it will be erased from the map and the reference count for all of its
|
|
||||||
/// dependencies will be decreased be one.
|
|
||||||
///
|
///
|
||||||
/// \param[in] object_id The object to decrement the count for.
|
/// \param[in] object_id The object to decrement the count for.
|
||||||
/// \param[out] deleted List to store objects that hit zero ref count.
|
/// \param[out] deleted List to store objects that hit zero ref count.
|
||||||
void RemoveLocalReference(const ObjectID &object_id, std::vector<ObjectID> *deleted)
|
void RemoveLocalReference(const ObjectID &object_id, std::vector<ObjectID> *deleted)
|
||||||
LOCKS_EXCLUDED(mutex_);
|
LOCKS_EXCLUDED(mutex_);
|
||||||
|
|
||||||
/// Remove any references to dependencies that this object may have. This does *not*
|
/// Add references for the provided object IDs that correspond to them being
|
||||||
/// decrease the object's own reference count.
|
/// dependencies to a submitted task.
|
||||||
///
|
///
|
||||||
/// \param[in] object_id The object whose dependencies should be removed.
|
/// \param[in] object_ids The object IDs to add references for.
|
||||||
/// \param[out] deleted List to store objects that hit zero ref count.
|
void AddSubmittedTaskReferences(const std::vector<ObjectID> &object_ids);
|
||||||
void RemoveDependencies(const ObjectID &object_id, std::vector<ObjectID> *deleted)
|
|
||||||
LOCKS_EXCLUDED(mutex_);
|
/// Remove references for the provided object IDs that correspond to them being
|
||||||
|
/// dependencies to a submitted task. This should be called when inlined
|
||||||
|
/// dependencies are inlined or when the task finishes for plasma dependencies.
|
||||||
|
///
|
||||||
|
/// \param[in] object_ids The object IDs to remove references for.
|
||||||
|
/// \param[out] deleted The object IDs whos reference counts reached zero.
|
||||||
|
void RemoveSubmittedTaskReferences(const std::vector<ObjectID> &object_ids,
|
||||||
|
std::vector<ObjectID> *deleted);
|
||||||
|
|
||||||
/// Add an object that we own. The object may depend on other objects.
|
/// Add an object that we own. The object may depend on other objects.
|
||||||
/// Dependencies for each ObjectID must be set at most once. The direct
|
/// Dependencies for each ObjectID must be set at most once. The local
|
||||||
/// reference count for the ObjectID is set to zero and the reference count
|
/// reference count for the ObjectID is set to zero, which assumes that an
|
||||||
/// for each dependency is incremented.
|
/// ObjectID for it will be created in the language frontend after this call.
|
||||||
///
|
///
|
||||||
/// TODO(swang): We could avoid copying the owner_id and owner_address since
|
/// TODO(swang): We could avoid copying the owner_id and owner_address since
|
||||||
/// we are the owner, but it is easier to store a copy for now, since the
|
/// we are the owner, but it is easier to store a copy for now, since the
|
||||||
|
@ -57,9 +62,7 @@ class ReferenceCounter {
|
||||||
/// \param[in] owner_address The address of the object's owner.
|
/// \param[in] owner_address The address of the object's owner.
|
||||||
/// \param[in] dependencies The objects that the object depends on.
|
/// \param[in] dependencies The objects that the object depends on.
|
||||||
void AddOwnedObject(const ObjectID &object_id, const TaskID &owner_id,
|
void AddOwnedObject(const ObjectID &object_id, const TaskID &owner_id,
|
||||||
const rpc::Address &owner_address,
|
const rpc::Address &owner_address) LOCKS_EXCLUDED(mutex_);
|
||||||
std::shared_ptr<std::vector<ObjectID>> dependencies)
|
|
||||||
LOCKS_EXCLUDED(mutex_);
|
|
||||||
|
|
||||||
/// Add an object that we are borrowing.
|
/// Add an object that we are borrowing.
|
||||||
///
|
///
|
||||||
|
@ -82,6 +85,11 @@ class ReferenceCounter {
|
||||||
/// Returns a set of all ObjectIDs currently in scope (i.e., nonzero reference count).
|
/// Returns a set of all ObjectIDs currently in scope (i.e., nonzero reference count).
|
||||||
std::unordered_set<ObjectID> GetAllInScopeObjectIDs() const LOCKS_EXCLUDED(mutex_);
|
std::unordered_set<ObjectID> GetAllInScopeObjectIDs() const LOCKS_EXCLUDED(mutex_);
|
||||||
|
|
||||||
|
/// Returns a map of all ObjectIDs currently in scope with a pair of their
|
||||||
|
/// (local, submitted_task) reference counts. For debugging purposes.
|
||||||
|
std::unordered_map<ObjectID, std::pair<size_t, size_t>> GetAllReferenceCounts() const
|
||||||
|
LOCKS_EXCLUDED(mutex_);
|
||||||
|
|
||||||
/// Dumps information about all currently tracked references to RAY_LOG(DEBUG).
|
/// Dumps information about all currently tracked references to RAY_LOG(DEBUG).
|
||||||
void LogDebugString() const LOCKS_EXCLUDED(mutex_);
|
void LogDebugString() const LOCKS_EXCLUDED(mutex_);
|
||||||
|
|
||||||
|
@ -91,22 +99,12 @@ class ReferenceCounter {
|
||||||
/// Constructor for a reference whose origin is unknown.
|
/// Constructor for a reference whose origin is unknown.
|
||||||
Reference() : owned_by_us(false) {}
|
Reference() : owned_by_us(false) {}
|
||||||
/// Constructor for a reference that we created.
|
/// Constructor for a reference that we created.
|
||||||
Reference(const TaskID &owner_id, const rpc::Address &owner_address,
|
|
||||||
std::shared_ptr<std::vector<ObjectID>> deps)
|
|
||||||
: dependencies(std::move(deps)),
|
|
||||||
owned_by_us(true),
|
|
||||||
owner({owner_id, owner_address}) {}
|
|
||||||
/// Constructor for a reference that was given to us.
|
|
||||||
Reference(const TaskID &owner_id, const rpc::Address &owner_address)
|
Reference(const TaskID &owner_id, const rpc::Address &owner_address)
|
||||||
: owned_by_us(false), owner({owner_id, owner_address}) {}
|
: owned_by_us(true), owner({owner_id, owner_address}) {}
|
||||||
/// The local ref count for the ObjectID in the language frontend.
|
/// The local ref count for the ObjectID in the language frontend.
|
||||||
size_t local_ref_count = 0;
|
size_t local_ref_count = 0;
|
||||||
/// The objects that this object depends on. Tracked only by the owner of
|
/// The ref count for submitted tasks that depend on the ObjectID.
|
||||||
/// the object. Dependencies are stored as shared_ptrs because the same set
|
size_t submitted_task_ref_count = 0;
|
||||||
/// of dependencies can be shared among multiple entries. For example, when
|
|
||||||
/// a task has multiple return values, the entry for each return ObjectID
|
|
||||||
/// depends on all task dependencies.
|
|
||||||
std::shared_ptr<std::vector<ObjectID>> dependencies;
|
|
||||||
/// Whether we own the object. If we own the object, then we are
|
/// Whether we own the object. If we own the object, then we are
|
||||||
/// responsible for tracking the state of the task that creates the object
|
/// responsible for tracking the state of the task that creates the object
|
||||||
/// (see task_manager.h).
|
/// (see task_manager.h).
|
||||||
|
|
|
@ -16,173 +16,63 @@ class ReferenceCountTest : public ::testing::Test {
|
||||||
virtual void TearDown() {}
|
virtual void TearDown() {}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Tests basic incrementing/decrementing of direct reference counts. An entry should only
|
// Tests basic incrementing/decrementing of direct/submitted task reference counts. An
|
||||||
// be removed once its reference count reaches zero.
|
// entry should only be removed once both of its reference counts reach zero.
|
||||||
TEST_F(ReferenceCountTest, TestBasic) {
|
TEST_F(ReferenceCountTest, TestBasic) {
|
||||||
std::vector<ObjectID> out;
|
std::vector<ObjectID> out;
|
||||||
ObjectID id = ObjectID::FromRandom();
|
|
||||||
rc->AddLocalReference(id);
|
|
||||||
ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
|
|
||||||
rc->AddLocalReference(id);
|
|
||||||
ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
|
|
||||||
rc->AddLocalReference(id);
|
|
||||||
ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
|
|
||||||
rc->RemoveLocalReference(id, &out);
|
|
||||||
ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
|
|
||||||
ASSERT_EQ(out.size(), 0);
|
|
||||||
rc->RemoveLocalReference(id, &out);
|
|
||||||
ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
|
|
||||||
ASSERT_EQ(out.size(), 0);
|
|
||||||
rc->RemoveLocalReference(id, &out);
|
|
||||||
ASSERT_EQ(rc->NumObjectIDsInScope(), 0);
|
|
||||||
ASSERT_EQ(out.size(), 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Tests the basic logic for dependencies - when an ObjectID with dependencies
|
|
||||||
// goes out of scope (i.e., reference count reaches zero), all of its dependencies
|
|
||||||
// should have their reference count decremented and be removed if it reaches zero.
|
|
||||||
TEST_F(ReferenceCountTest, TestDependencies) {
|
|
||||||
std::vector<ObjectID> out;
|
|
||||||
ObjectID id1 = ObjectID::FromRandom();
|
ObjectID id1 = ObjectID::FromRandom();
|
||||||
ObjectID id2 = ObjectID::FromRandom();
|
ObjectID id2 = ObjectID::FromRandom();
|
||||||
ObjectID id3 = ObjectID::FromRandom();
|
|
||||||
|
|
||||||
std::shared_ptr<std::vector<ObjectID>> deps = std::make_shared<std::vector<ObjectID>>();
|
|
||||||
deps->push_back(id2);
|
|
||||||
deps->push_back(id3);
|
|
||||||
rc->AddOwnedObject(id1, TaskID::Nil(), rpc::Address(), deps);
|
|
||||||
|
|
||||||
|
// Local references.
|
||||||
rc->AddLocalReference(id1);
|
rc->AddLocalReference(id1);
|
||||||
rc->AddLocalReference(id1);
|
|
||||||
rc->AddLocalReference(id3);
|
|
||||||
ASSERT_EQ(rc->NumObjectIDsInScope(), 3);
|
|
||||||
|
|
||||||
rc->RemoveLocalReference(id1, &out);
|
|
||||||
ASSERT_EQ(rc->NumObjectIDsInScope(), 3);
|
|
||||||
ASSERT_EQ(out.size(), 0);
|
|
||||||
rc->RemoveLocalReference(id1, &out);
|
|
||||||
ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
|
|
||||||
ASSERT_EQ(out.size(), 2);
|
|
||||||
|
|
||||||
rc->RemoveLocalReference(id3, &out);
|
|
||||||
ASSERT_EQ(rc->NumObjectIDsInScope(), 0);
|
|
||||||
ASSERT_EQ(out.size(), 3);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Tests the case where two entries share the same set of dependencies. When one
|
|
||||||
// entry goes out of scope, it should decrease the reference count for the dependencies
|
|
||||||
// but they should still be nonzero until the second entry goes out of scope and all
|
|
||||||
// direct dependencies to the dependencies are removed.
|
|
||||||
TEST_F(ReferenceCountTest, TestSharedDependencies) {
|
|
||||||
std::vector<ObjectID> out;
|
|
||||||
ObjectID id1 = ObjectID::FromRandom();
|
|
||||||
ObjectID id2 = ObjectID::FromRandom();
|
|
||||||
ObjectID id3 = ObjectID::FromRandom();
|
|
||||||
ObjectID id4 = ObjectID::FromRandom();
|
|
||||||
|
|
||||||
std::shared_ptr<std::vector<ObjectID>> deps = std::make_shared<std::vector<ObjectID>>();
|
|
||||||
deps->push_back(id3);
|
|
||||||
deps->push_back(id4);
|
|
||||||
rc->AddOwnedObject(id1, TaskID::Nil(), rpc::Address(), deps);
|
|
||||||
rc->AddOwnedObject(id2, TaskID::Nil(), rpc::Address(), deps);
|
|
||||||
|
|
||||||
rc->AddLocalReference(id1);
|
rc->AddLocalReference(id1);
|
||||||
rc->AddLocalReference(id2);
|
rc->AddLocalReference(id2);
|
||||||
rc->AddLocalReference(id4);
|
ASSERT_EQ(rc->NumObjectIDsInScope(), 2);
|
||||||
ASSERT_EQ(rc->NumObjectIDsInScope(), 4);
|
|
||||||
|
|
||||||
rc->RemoveLocalReference(id1, &out);
|
|
||||||
ASSERT_EQ(rc->NumObjectIDsInScope(), 3);
|
|
||||||
ASSERT_EQ(out.size(), 1);
|
|
||||||
rc->RemoveLocalReference(id2, &out);
|
|
||||||
ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
|
|
||||||
ASSERT_EQ(out.size(), 3);
|
|
||||||
|
|
||||||
rc->RemoveLocalReference(id4, &out);
|
|
||||||
ASSERT_EQ(rc->NumObjectIDsInScope(), 0);
|
|
||||||
ASSERT_EQ(out.size(), 4);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Tests the case when an entry has a dependency that itself has a
|
|
||||||
// dependency. In this case, when the first entry goes out of scope
|
|
||||||
// it should decrease the reference count for its dependency, causing
|
|
||||||
// that entry to go out of scope and decrease its dependencies' reference counts.
|
|
||||||
TEST_F(ReferenceCountTest, TestRecursiveDependencies) {
|
|
||||||
std::vector<ObjectID> out;
|
|
||||||
ObjectID id1 = ObjectID::FromRandom();
|
|
||||||
ObjectID id2 = ObjectID::FromRandom();
|
|
||||||
ObjectID id3 = ObjectID::FromRandom();
|
|
||||||
ObjectID id4 = ObjectID::FromRandom();
|
|
||||||
|
|
||||||
std::shared_ptr<std::vector<ObjectID>> deps2 =
|
|
||||||
std::make_shared<std::vector<ObjectID>>();
|
|
||||||
deps2->push_back(id3);
|
|
||||||
deps2->push_back(id4);
|
|
||||||
rc->AddOwnedObject(id2, TaskID::Nil(), rpc::Address(), deps2);
|
|
||||||
|
|
||||||
std::shared_ptr<std::vector<ObjectID>> deps1 =
|
|
||||||
std::make_shared<std::vector<ObjectID>>();
|
|
||||||
deps1->push_back(id2);
|
|
||||||
rc->AddOwnedObject(id1, TaskID::Nil(), rpc::Address(), deps1);
|
|
||||||
|
|
||||||
rc->AddLocalReference(id1);
|
|
||||||
rc->AddLocalReference(id2);
|
|
||||||
rc->AddLocalReference(id4);
|
|
||||||
ASSERT_EQ(rc->NumObjectIDsInScope(), 4);
|
|
||||||
|
|
||||||
rc->RemoveLocalReference(id2, &out);
|
|
||||||
ASSERT_EQ(rc->NumObjectIDsInScope(), 4);
|
|
||||||
ASSERT_EQ(out.size(), 0);
|
|
||||||
rc->RemoveLocalReference(id1, &out);
|
|
||||||
ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
|
|
||||||
ASSERT_EQ(out.size(), 3);
|
|
||||||
|
|
||||||
rc->RemoveLocalReference(id4, &out);
|
|
||||||
ASSERT_EQ(rc->NumObjectIDsInScope(), 0);
|
|
||||||
ASSERT_EQ(out.size(), 4);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ReferenceCountTest, TestRemoveDependenciesOnly) {
|
|
||||||
std::vector<ObjectID> out;
|
|
||||||
ObjectID id1 = ObjectID::FromRandom();
|
|
||||||
ObjectID id2 = ObjectID::FromRandom();
|
|
||||||
ObjectID id3 = ObjectID::FromRandom();
|
|
||||||
ObjectID id4 = ObjectID::FromRandom();
|
|
||||||
|
|
||||||
std::shared_ptr<std::vector<ObjectID>> deps2 =
|
|
||||||
std::make_shared<std::vector<ObjectID>>();
|
|
||||||
deps2->push_back(id3);
|
|
||||||
deps2->push_back(id4);
|
|
||||||
rc->AddOwnedObject(id2, TaskID::Nil(), rpc::Address(), deps2);
|
|
||||||
|
|
||||||
std::shared_ptr<std::vector<ObjectID>> deps1 =
|
|
||||||
std::make_shared<std::vector<ObjectID>>();
|
|
||||||
deps1->push_back(id2);
|
|
||||||
rc->AddOwnedObject(id1, TaskID::Nil(), rpc::Address(), deps1);
|
|
||||||
|
|
||||||
rc->AddLocalReference(id1);
|
|
||||||
rc->AddLocalReference(id2);
|
|
||||||
rc->AddLocalReference(id4);
|
|
||||||
ASSERT_EQ(rc->NumObjectIDsInScope(), 4);
|
|
||||||
|
|
||||||
rc->RemoveDependencies(id2, &out);
|
|
||||||
ASSERT_EQ(rc->NumObjectIDsInScope(), 3);
|
|
||||||
ASSERT_EQ(out.size(), 1);
|
|
||||||
rc->RemoveDependencies(id1, &out);
|
|
||||||
ASSERT_EQ(rc->NumObjectIDsInScope(), 3);
|
|
||||||
ASSERT_EQ(out.size(), 1);
|
|
||||||
|
|
||||||
rc->RemoveLocalReference(id1, &out);
|
rc->RemoveLocalReference(id1, &out);
|
||||||
ASSERT_EQ(rc->NumObjectIDsInScope(), 2);
|
ASSERT_EQ(rc->NumObjectIDsInScope(), 2);
|
||||||
ASSERT_EQ(out.size(), 2);
|
ASSERT_EQ(out.size(), 0);
|
||||||
|
|
||||||
rc->RemoveLocalReference(id2, &out);
|
rc->RemoveLocalReference(id2, &out);
|
||||||
ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
|
ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
|
||||||
ASSERT_EQ(out.size(), 3);
|
ASSERT_EQ(out.size(), 1);
|
||||||
|
rc->RemoveLocalReference(id1, &out);
|
||||||
rc->RemoveLocalReference(id4, &out);
|
|
||||||
ASSERT_EQ(rc->NumObjectIDsInScope(), 0);
|
ASSERT_EQ(rc->NumObjectIDsInScope(), 0);
|
||||||
ASSERT_EQ(out.size(), 4);
|
ASSERT_EQ(out.size(), 2);
|
||||||
|
out.clear();
|
||||||
|
|
||||||
|
// Submitted task references.
|
||||||
|
rc->AddSubmittedTaskReferences({id1});
|
||||||
|
rc->AddSubmittedTaskReferences({id1, id2});
|
||||||
|
ASSERT_EQ(rc->NumObjectIDsInScope(), 2);
|
||||||
|
rc->RemoveSubmittedTaskReferences({id1}, &out);
|
||||||
|
ASSERT_EQ(rc->NumObjectIDsInScope(), 2);
|
||||||
|
ASSERT_EQ(out.size(), 0);
|
||||||
|
rc->RemoveSubmittedTaskReferences({id2}, &out);
|
||||||
|
ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
|
||||||
|
ASSERT_EQ(out.size(), 1);
|
||||||
|
rc->RemoveSubmittedTaskReferences({id1}, &out);
|
||||||
|
ASSERT_EQ(rc->NumObjectIDsInScope(), 0);
|
||||||
|
ASSERT_EQ(out.size(), 2);
|
||||||
|
out.clear();
|
||||||
|
|
||||||
|
// Local & submitted task references.
|
||||||
|
rc->AddLocalReference(id1);
|
||||||
|
rc->AddSubmittedTaskReferences({id1, id2});
|
||||||
|
rc->AddLocalReference(id2);
|
||||||
|
ASSERT_EQ(rc->NumObjectIDsInScope(), 2);
|
||||||
|
rc->RemoveLocalReference(id1, &out);
|
||||||
|
ASSERT_EQ(rc->NumObjectIDsInScope(), 2);
|
||||||
|
ASSERT_EQ(out.size(), 0);
|
||||||
|
rc->RemoveSubmittedTaskReferences({id2}, &out);
|
||||||
|
ASSERT_EQ(rc->NumObjectIDsInScope(), 2);
|
||||||
|
ASSERT_EQ(out.size(), 0);
|
||||||
|
rc->RemoveSubmittedTaskReferences({id1}, &out);
|
||||||
|
ASSERT_EQ(rc->NumObjectIDsInScope(), 1);
|
||||||
|
ASSERT_EQ(out.size(), 1);
|
||||||
|
rc->RemoveLocalReference(id2, &out);
|
||||||
|
ASSERT_EQ(rc->NumObjectIDsInScope(), 0);
|
||||||
|
ASSERT_EQ(out.size(), 2);
|
||||||
|
out.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tests that we can get the owner address correctly for objects that we own,
|
// Tests that we can get the owner address correctly for objects that we own,
|
||||||
|
@ -193,8 +83,7 @@ TEST_F(ReferenceCountTest, TestOwnerAddress) {
|
||||||
TaskID task_id = TaskID::ForFakeTask();
|
TaskID task_id = TaskID::ForFakeTask();
|
||||||
rpc::Address address;
|
rpc::Address address;
|
||||||
address.set_ip_address("1234");
|
address.set_ip_address("1234");
|
||||||
auto deps = std::make_shared<std::vector<ObjectID>>();
|
rc->AddOwnedObject(object_id, task_id, address);
|
||||||
rc->AddOwnedObject(object_id, task_id, address, deps);
|
|
||||||
|
|
||||||
TaskID added_id;
|
TaskID added_id;
|
||||||
rpc::Address added_address;
|
rpc::Address added_address;
|
||||||
|
@ -205,7 +94,7 @@ TEST_F(ReferenceCountTest, TestOwnerAddress) {
|
||||||
auto object_id2 = ObjectID::FromRandom();
|
auto object_id2 = ObjectID::FromRandom();
|
||||||
task_id = TaskID::ForFakeTask();
|
task_id = TaskID::ForFakeTask();
|
||||||
address.set_ip_address("5678");
|
address.set_ip_address("5678");
|
||||||
rc->AddOwnedObject(object_id2, task_id, address, deps);
|
rc->AddOwnedObject(object_id2, task_id, address);
|
||||||
ASSERT_TRUE(rc->GetOwner(object_id2, &added_id, &added_address));
|
ASSERT_TRUE(rc->GetOwner(object_id2, &added_id, &added_address));
|
||||||
ASSERT_EQ(task_id, added_id);
|
ASSERT_EQ(task_id, added_id);
|
||||||
ASSERT_EQ(address.ip_address(), added_address.ip_address());
|
ASSERT_EQ(address.ip_address(), added_address.ip_address());
|
||||||
|
|
|
@ -10,11 +10,34 @@ const int64_t kTaskFailureThrottlingThreshold = 50;
|
||||||
// Throttle task failure logs to once this interval.
|
// Throttle task failure logs to once this interval.
|
||||||
const int64_t kTaskFailureLoggingFrequencyMillis = 5000;
|
const int64_t kTaskFailureLoggingFrequencyMillis = 5000;
|
||||||
|
|
||||||
void TaskManager::AddPendingTask(const TaskSpecification &spec, int max_retries) {
|
void TaskManager::AddPendingTask(const TaskID &caller_id,
|
||||||
|
const rpc::Address &caller_address,
|
||||||
|
const TaskSpecification &spec, int max_retries) {
|
||||||
RAY_LOG(DEBUG) << "Adding pending task " << spec.TaskId();
|
RAY_LOG(DEBUG) << "Adding pending task " << spec.TaskId();
|
||||||
absl::MutexLock lock(&mu_);
|
absl::MutexLock lock(&mu_);
|
||||||
std::pair<TaskSpecification, int> entry = {spec, max_retries};
|
std::pair<TaskSpecification, int> entry = {spec, max_retries};
|
||||||
RAY_CHECK(pending_tasks_.emplace(spec.TaskId(), std::move(entry)).second);
|
RAY_CHECK(pending_tasks_.emplace(spec.TaskId(), std::move(entry)).second);
|
||||||
|
|
||||||
|
// Add references for the dependencies to the task.
|
||||||
|
std::vector<ObjectID> task_deps;
|
||||||
|
for (size_t i = 0; i < spec.NumArgs(); i++) {
|
||||||
|
if (spec.ArgByRef(i)) {
|
||||||
|
for (size_t j = 0; j < spec.ArgIdCount(i); j++) {
|
||||||
|
task_deps.push_back(spec.ArgId(i, j));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
reference_counter_->AddSubmittedTaskReferences(task_deps);
|
||||||
|
|
||||||
|
// Add new owned objects for the return values of the task.
|
||||||
|
size_t num_returns = spec.NumReturns();
|
||||||
|
if (spec.IsActorCreationTask() || spec.IsActorTask()) {
|
||||||
|
num_returns--;
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < num_returns; i++) {
|
||||||
|
reference_counter_->AddOwnedObject(spec.ReturnId(i, TaskTransportType::DIRECT),
|
||||||
|
caller_id, caller_address);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void TaskManager::DrainAndShutdown(std::function<void()> shutdown) {
|
void TaskManager::DrainAndShutdown(std::function<void()> shutdown) {
|
||||||
|
@ -48,6 +71,8 @@ void TaskManager::CompletePendingTask(const TaskID &task_id,
|
||||||
pending_tasks_.erase(it);
|
pending_tasks_.erase(it);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
RemovePlasmaSubmittedTaskReferences(spec);
|
||||||
|
|
||||||
for (int i = 0; i < reply.return_objects_size(); i++) {
|
for (int i = 0; i < reply.return_objects_size(); i++) {
|
||||||
const auto &return_object = reply.return_objects(i);
|
const auto &return_object = reply.return_objects(i);
|
||||||
ObjectID object_id = ObjectID::FromBinary(return_object.object_id());
|
ObjectID object_id = ObjectID::FromBinary(return_object.object_id());
|
||||||
|
@ -134,6 +159,7 @@ void TaskManager::PendingTaskFailed(const TaskID &task_id, rpc::ErrorType error_
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
RemovePlasmaSubmittedTaskReferences(spec);
|
||||||
MarkPendingTaskFailed(task_id, spec, error_type);
|
MarkPendingTaskFailed(task_id, spec, error_type);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -148,6 +174,28 @@ void TaskManager::ShutdownIfNeeded() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void TaskManager::RemoveSubmittedTaskReferences(const std::vector<ObjectID> &object_ids) {
|
||||||
|
std::vector<ObjectID> deleted;
|
||||||
|
reference_counter_->RemoveSubmittedTaskReferences(object_ids, &deleted);
|
||||||
|
in_memory_store_->Delete(deleted);
|
||||||
|
}
|
||||||
|
|
||||||
|
void TaskManager::OnTaskDependenciesInlined(const std::vector<ObjectID> &object_ids) {
|
||||||
|
RemoveSubmittedTaskReferences(object_ids);
|
||||||
|
}
|
||||||
|
|
||||||
|
void TaskManager::RemovePlasmaSubmittedTaskReferences(TaskSpecification &spec) {
|
||||||
|
std::vector<ObjectID> plasma_dependencies;
|
||||||
|
for (size_t i = 0; i < spec.NumArgs(); i++) {
|
||||||
|
auto count = spec.ArgIdCount(i);
|
||||||
|
if (count > 0) {
|
||||||
|
const auto &id = spec.ArgId(i, 0);
|
||||||
|
plasma_dependencies.push_back(id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
RemoveSubmittedTaskReferences(plasma_dependencies);
|
||||||
|
}
|
||||||
|
|
||||||
void TaskManager::MarkPendingTaskFailed(const TaskID &task_id,
|
void TaskManager::MarkPendingTaskFailed(const TaskID &task_id,
|
||||||
const TaskSpecification &spec,
|
const TaskSpecification &spec,
|
||||||
rpc::ErrorType error_type) {
|
rpc::ErrorType error_type) {
|
||||||
|
|
|
@ -21,6 +21,8 @@ class TaskFinisherInterface {
|
||||||
virtual void PendingTaskFailed(const TaskID &task_id, rpc::ErrorType error_type,
|
virtual void PendingTaskFailed(const TaskID &task_id, rpc::ErrorType error_type,
|
||||||
Status *status = nullptr) = 0;
|
Status *status = nullptr) = 0;
|
||||||
|
|
||||||
|
virtual void OnTaskDependenciesInlined(const std::vector<ObjectID> &object_ids) = 0;
|
||||||
|
|
||||||
virtual ~TaskFinisherInterface() {}
|
virtual ~TaskFinisherInterface() {}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -29,19 +31,24 @@ using RetryTaskCallback = std::function<void(const TaskSpecification &spec)>;
|
||||||
class TaskManager : public TaskFinisherInterface {
|
class TaskManager : public TaskFinisherInterface {
|
||||||
public:
|
public:
|
||||||
TaskManager(std::shared_ptr<CoreWorkerMemoryStore> in_memory_store,
|
TaskManager(std::shared_ptr<CoreWorkerMemoryStore> in_memory_store,
|
||||||
|
std::shared_ptr<ReferenceCounter> reference_counter,
|
||||||
std::shared_ptr<ActorManagerInterface> actor_manager,
|
std::shared_ptr<ActorManagerInterface> actor_manager,
|
||||||
RetryTaskCallback retry_task_callback)
|
RetryTaskCallback retry_task_callback)
|
||||||
: in_memory_store_(in_memory_store),
|
: in_memory_store_(in_memory_store),
|
||||||
|
reference_counter_(reference_counter),
|
||||||
actor_manager_(actor_manager),
|
actor_manager_(actor_manager),
|
||||||
retry_task_callback_(retry_task_callback) {}
|
retry_task_callback_(retry_task_callback) {}
|
||||||
|
|
||||||
/// Add a task that is pending execution.
|
/// Add a task that is pending execution.
|
||||||
///
|
///
|
||||||
|
/// \param[in] caller_id The TaskID of the calling task.
|
||||||
|
/// \param[in] caller_address The rpc address of the calling task.
|
||||||
/// \param[in] spec The spec of the pending task.
|
/// \param[in] spec The spec of the pending task.
|
||||||
/// \param[in] max_retries Number of times this task may be retried
|
/// \param[in] max_retries Number of times this task may be retried
|
||||||
/// on failure.
|
/// on failure.
|
||||||
/// \return Void.
|
/// \return Void.
|
||||||
void AddPendingTask(const TaskSpecification &spec, int max_retries = 0);
|
void AddPendingTask(const TaskID &caller_id, const rpc::Address &caller_address,
|
||||||
|
const TaskSpecification &spec, int max_retries = 0);
|
||||||
|
|
||||||
/// Wait for all pending tasks to finish, and then shutdown.
|
/// Wait for all pending tasks to finish, and then shutdown.
|
||||||
///
|
///
|
||||||
|
@ -72,6 +79,8 @@ class TaskManager : public TaskFinisherInterface {
|
||||||
void PendingTaskFailed(const TaskID &task_id, rpc::ErrorType error_type,
|
void PendingTaskFailed(const TaskID &task_id, rpc::ErrorType error_type,
|
||||||
Status *status = nullptr) override;
|
Status *status = nullptr) override;
|
||||||
|
|
||||||
|
void OnTaskDependenciesInlined(const std::vector<ObjectID> &object_id) override;
|
||||||
|
|
||||||
/// Return the spec for a pending task.
|
/// Return the spec for a pending task.
|
||||||
TaskSpecification GetTaskSpec(const TaskID &task_id) const;
|
TaskSpecification GetTaskSpec(const TaskID &task_id) const;
|
||||||
|
|
||||||
|
@ -81,12 +90,25 @@ class TaskManager : public TaskFinisherInterface {
|
||||||
void MarkPendingTaskFailed(const TaskID &task_id, const TaskSpecification &spec,
|
void MarkPendingTaskFailed(const TaskID &task_id, const TaskSpecification &spec,
|
||||||
rpc::ErrorType error_type) LOCKS_EXCLUDED(mu_);
|
rpc::ErrorType error_type) LOCKS_EXCLUDED(mu_);
|
||||||
|
|
||||||
|
/// Remove submittted task references in the reference counter for the object IDs.
|
||||||
|
/// If their reference counts reach zero, they are deleted from the in-memory store.
|
||||||
|
void RemoveSubmittedTaskReferences(const std::vector<ObjectID> &object_ids);
|
||||||
|
|
||||||
|
/// Helper function to call RemoveSubmittedTaskReferences on the plasma dependencies
|
||||||
|
/// of the given task spec.
|
||||||
|
void RemovePlasmaSubmittedTaskReferences(TaskSpecification &spec);
|
||||||
|
|
||||||
/// Shutdown if all tasks are finished and shutdown is scheduled.
|
/// Shutdown if all tasks are finished and shutdown is scheduled.
|
||||||
void ShutdownIfNeeded() LOCKS_EXCLUDED(mu_);
|
void ShutdownIfNeeded() LOCKS_EXCLUDED(mu_);
|
||||||
|
|
||||||
/// Used to store task results.
|
/// Used to store task results.
|
||||||
std::shared_ptr<CoreWorkerMemoryStore> in_memory_store_;
|
std::shared_ptr<CoreWorkerMemoryStore> in_memory_store_;
|
||||||
|
|
||||||
|
/// Used for reference counting objects.
|
||||||
|
/// The task manager is responsible for managing all references related to
|
||||||
|
/// submitted tasks (dependencies and return objects).
|
||||||
|
std::shared_ptr<ReferenceCounter> reference_counter_;
|
||||||
|
|
||||||
// Interface for publishing actor creation.
|
// Interface for publishing actor creation.
|
||||||
std::shared_ptr<ActorManagerInterface> actor_manager_;
|
std::shared_ptr<ActorManagerInterface> actor_manager_;
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
#include "gmock/gmock.h"
|
#include "gmock/gmock.h"
|
||||||
#include "gtest/gtest.h"
|
#include "gtest/gtest.h"
|
||||||
|
|
||||||
#include "ray/common/task/task_spec.h"
|
#include "ray/common/task/task_spec.h"
|
||||||
#include "ray/core_worker/store_provider/memory_store/memory_store.h"
|
#include "ray/core_worker/store_provider/memory_store/memory_store.h"
|
||||||
#include "ray/core_worker/transport/direct_task_transport.h"
|
#include "ray/core_worker/transport/direct_task_transport.h"
|
||||||
|
@ -45,6 +44,8 @@ class MockTaskFinisher : public TaskFinisherInterface {
|
||||||
const rpc::Address *addr));
|
const rpc::Address *addr));
|
||||||
MOCK_METHOD3(PendingTaskFailed,
|
MOCK_METHOD3(PendingTaskFailed,
|
||||||
void(const TaskID &task_id, rpc::ErrorType error_type, Status *status));
|
void(const TaskID &task_id, rpc::ErrorType error_type, Status *status));
|
||||||
|
|
||||||
|
MOCK_METHOD1(OnTaskDependenciesInlined, void(const std::vector<ObjectID> &object_ids));
|
||||||
};
|
};
|
||||||
|
|
||||||
TaskSpecification CreateActorTaskHelper(ActorID actor_id, int64_t counter) {
|
TaskSpecification CreateActorTaskHelper(ActorID actor_id, int64_t counter) {
|
||||||
|
|
|
@ -55,8 +55,13 @@ class MockTaskFinisher : public TaskFinisherInterface {
|
||||||
num_tasks_failed++;
|
num_tasks_failed++;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void OnTaskDependenciesInlined(const std::vector<ObjectID> &object_ids) override {
|
||||||
|
num_inlined += object_ids.size();
|
||||||
|
}
|
||||||
|
|
||||||
int num_tasks_complete = 0;
|
int num_tasks_complete = 0;
|
||||||
int num_tasks_failed = 0;
|
int num_tasks_failed = 0;
|
||||||
|
int num_inlined = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
class MockRayletClient : public WorkerLeaseInterface {
|
class MockRayletClient : public WorkerLeaseInterface {
|
||||||
|
@ -136,17 +141,20 @@ TEST(TestMemoryStore, TestPromoteToPlasma) {
|
||||||
|
|
||||||
TEST(LocalDependencyResolverTest, TestNoDependencies) {
|
TEST(LocalDependencyResolverTest, TestNoDependencies) {
|
||||||
auto store = std::make_shared<CoreWorkerMemoryStore>();
|
auto store = std::make_shared<CoreWorkerMemoryStore>();
|
||||||
LocalDependencyResolver resolver(store);
|
auto task_finisher = std::make_shared<MockTaskFinisher>();
|
||||||
|
LocalDependencyResolver resolver(store, task_finisher);
|
||||||
TaskSpecification task;
|
TaskSpecification task;
|
||||||
bool ok = false;
|
bool ok = false;
|
||||||
resolver.ResolveDependencies(task, [&ok]() { ok = true; });
|
resolver.ResolveDependencies(task, [&ok]() { ok = true; });
|
||||||
ASSERT_TRUE(ok);
|
ASSERT_TRUE(ok);
|
||||||
|
ASSERT_EQ(task_finisher->num_inlined, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(LocalDependencyResolverTest, TestIgnorePlasmaDependencies) {
|
TEST(LocalDependencyResolverTest, TestIgnorePlasmaDependencies) {
|
||||||
auto store = std::make_shared<CoreWorkerMemoryStore>();
|
auto store = std::make_shared<CoreWorkerMemoryStore>();
|
||||||
LocalDependencyResolver resolver(store);
|
auto task_finisher = std::make_shared<MockTaskFinisher>();
|
||||||
ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::RAYLET);
|
LocalDependencyResolver resolver(store, task_finisher);
|
||||||
|
ObjectID obj1 = ObjectID::FromRandom();
|
||||||
TaskSpecification task;
|
TaskSpecification task;
|
||||||
task.GetMutableMessage().add_args()->add_object_ids(obj1.Binary());
|
task.GetMutableMessage().add_args()->add_object_ids(obj1.Binary());
|
||||||
bool ok = false;
|
bool ok = false;
|
||||||
|
@ -154,11 +162,13 @@ TEST(LocalDependencyResolverTest, TestIgnorePlasmaDependencies) {
|
||||||
// We ignore and don't block on plasma dependencies.
|
// We ignore and don't block on plasma dependencies.
|
||||||
ASSERT_TRUE(ok);
|
ASSERT_TRUE(ok);
|
||||||
ASSERT_EQ(resolver.NumPendingTasks(), 0);
|
ASSERT_EQ(resolver.NumPendingTasks(), 0);
|
||||||
|
ASSERT_EQ(task_finisher->num_inlined, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(LocalDependencyResolverTest, TestHandlePlasmaPromotion) {
|
TEST(LocalDependencyResolverTest, TestHandlePlasmaPromotion) {
|
||||||
auto store = std::make_shared<CoreWorkerMemoryStore>();
|
auto store = std::make_shared<CoreWorkerMemoryStore>();
|
||||||
LocalDependencyResolver resolver(store);
|
auto task_finisher = std::make_shared<MockTaskFinisher>();
|
||||||
|
LocalDependencyResolver resolver(store, task_finisher);
|
||||||
ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
|
ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
|
||||||
std::string meta = std::to_string(static_cast<int>(rpc::ErrorType::OBJECT_IN_PLASMA));
|
std::string meta = std::to_string(static_cast<int>(rpc::ErrorType::OBJECT_IN_PLASMA));
|
||||||
auto metadata = const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(meta.data()));
|
auto metadata = const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(meta.data()));
|
||||||
|
@ -175,11 +185,13 @@ TEST(LocalDependencyResolverTest, TestHandlePlasmaPromotion) {
|
||||||
// Checks that the object id is still a direct call id.
|
// Checks that the object id is still a direct call id.
|
||||||
ASSERT_TRUE(task.ArgId(0, 0).IsDirectCallType());
|
ASSERT_TRUE(task.ArgId(0, 0).IsDirectCallType());
|
||||||
ASSERT_EQ(resolver.NumPendingTasks(), 0);
|
ASSERT_EQ(resolver.NumPendingTasks(), 0);
|
||||||
|
ASSERT_EQ(task_finisher->num_inlined, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(LocalDependencyResolverTest, TestInlineLocalDependencies) {
|
TEST(LocalDependencyResolverTest, TestInlineLocalDependencies) {
|
||||||
auto store = std::make_shared<CoreWorkerMemoryStore>();
|
auto store = std::make_shared<CoreWorkerMemoryStore>();
|
||||||
LocalDependencyResolver resolver(store);
|
auto task_finisher = std::make_shared<MockTaskFinisher>();
|
||||||
|
LocalDependencyResolver resolver(store, task_finisher);
|
||||||
ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
|
ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
|
||||||
ObjectID obj2 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
|
ObjectID obj2 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
|
||||||
auto data = GenerateRandomObject();
|
auto data = GenerateRandomObject();
|
||||||
|
@ -198,11 +210,13 @@ TEST(LocalDependencyResolverTest, TestInlineLocalDependencies) {
|
||||||
ASSERT_NE(task.ArgData(0), nullptr);
|
ASSERT_NE(task.ArgData(0), nullptr);
|
||||||
ASSERT_NE(task.ArgData(1), nullptr);
|
ASSERT_NE(task.ArgData(1), nullptr);
|
||||||
ASSERT_EQ(resolver.NumPendingTasks(), 0);
|
ASSERT_EQ(resolver.NumPendingTasks(), 0);
|
||||||
|
ASSERT_EQ(task_finisher->num_inlined, 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(LocalDependencyResolverTest, TestInlinePendingDependencies) {
|
TEST(LocalDependencyResolverTest, TestInlinePendingDependencies) {
|
||||||
auto store = std::make_shared<CoreWorkerMemoryStore>();
|
auto store = std::make_shared<CoreWorkerMemoryStore>();
|
||||||
LocalDependencyResolver resolver(store);
|
auto task_finisher = std::make_shared<MockTaskFinisher>();
|
||||||
|
LocalDependencyResolver resolver(store, task_finisher);
|
||||||
ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
|
ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
|
||||||
ObjectID obj2 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
|
ObjectID obj2 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
|
||||||
auto data = GenerateRandomObject();
|
auto data = GenerateRandomObject();
|
||||||
|
@ -223,6 +237,7 @@ TEST(LocalDependencyResolverTest, TestInlinePendingDependencies) {
|
||||||
ASSERT_NE(task.ArgData(0), nullptr);
|
ASSERT_NE(task.ArgData(0), nullptr);
|
||||||
ASSERT_NE(task.ArgData(1), nullptr);
|
ASSERT_NE(task.ArgData(1), nullptr);
|
||||||
ASSERT_EQ(resolver.NumPendingTasks(), 0);
|
ASSERT_EQ(resolver.NumPendingTasks(), 0);
|
||||||
|
ASSERT_EQ(task_finisher->num_inlined, 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
TaskSpecification BuildTaskSpec(const std::unordered_map<std::string, double> &resources,
|
TaskSpecification BuildTaskSpec(const std::unordered_map<std::string, double> &resources,
|
||||||
|
|
|
@ -1,17 +1,22 @@
|
||||||
#include "gtest/gtest.h"
|
#include "ray/core_worker/task_manager.h"
|
||||||
|
|
||||||
|
#include "gtest/gtest.h"
|
||||||
#include "ray/common/task/task_spec.h"
|
#include "ray/common/task/task_spec.h"
|
||||||
#include "ray/core_worker/actor_manager.h"
|
#include "ray/core_worker/actor_manager.h"
|
||||||
|
#include "ray/core_worker/reference_count.h"
|
||||||
#include "ray/core_worker/store_provider/memory_store/memory_store.h"
|
#include "ray/core_worker/store_provider/memory_store/memory_store.h"
|
||||||
#include "ray/core_worker/task_manager.h"
|
|
||||||
#include "ray/util/test_util.h"
|
#include "ray/util/test_util.h"
|
||||||
|
|
||||||
namespace ray {
|
namespace ray {
|
||||||
|
|
||||||
TaskSpecification CreateTaskHelper(uint64_t num_returns) {
|
TaskSpecification CreateTaskHelper(uint64_t num_returns,
|
||||||
|
std::vector<ObjectID> dependencies) {
|
||||||
TaskSpecification task;
|
TaskSpecification task;
|
||||||
task.GetMutableMessage().set_task_id(TaskID::ForFakeTask().Binary());
|
task.GetMutableMessage().set_task_id(TaskID::ForFakeTask().Binary());
|
||||||
task.GetMutableMessage().set_num_returns(num_returns);
|
task.GetMutableMessage().set_num_returns(num_returns);
|
||||||
|
for (const ObjectID &dep : dependencies) {
|
||||||
|
task.GetMutableMessage().add_args()->add_object_ids(dep.Binary());
|
||||||
|
}
|
||||||
return task;
|
return task;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -33,23 +38,31 @@ class TaskManagerTest : public ::testing::Test {
|
||||||
public:
|
public:
|
||||||
TaskManagerTest()
|
TaskManagerTest()
|
||||||
: store_(std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore())),
|
: store_(std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore())),
|
||||||
|
reference_counter_(std::shared_ptr<ReferenceCounter>(new ReferenceCounter())),
|
||||||
actor_manager_(std::shared_ptr<ActorManagerInterface>(new MockActorManager())),
|
actor_manager_(std::shared_ptr<ActorManagerInterface>(new MockActorManager())),
|
||||||
manager_(store_, actor_manager_, [this](const TaskSpecification &spec) {
|
manager_(store_, reference_counter_, actor_manager_,
|
||||||
|
[this](const TaskSpecification &spec) {
|
||||||
num_retries_++;
|
num_retries_++;
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}) {}
|
}) {}
|
||||||
|
|
||||||
std::shared_ptr<CoreWorkerMemoryStore> store_;
|
std::shared_ptr<CoreWorkerMemoryStore> store_;
|
||||||
|
std::shared_ptr<ReferenceCounter> reference_counter_;
|
||||||
std::shared_ptr<ActorManagerInterface> actor_manager_;
|
std::shared_ptr<ActorManagerInterface> actor_manager_;
|
||||||
TaskManager manager_;
|
TaskManager manager_;
|
||||||
int num_retries_ = 0;
|
int num_retries_ = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(TaskManagerTest, TestTaskSuccess) {
|
TEST_F(TaskManagerTest, TestTaskSuccess) {
|
||||||
auto spec = CreateTaskHelper(1);
|
TaskID caller_id = TaskID::Nil();
|
||||||
|
rpc::Address caller_address;
|
||||||
|
ObjectID dep1 = ObjectID::FromRandom();
|
||||||
|
ObjectID dep2 = ObjectID::FromRandom();
|
||||||
|
auto spec = CreateTaskHelper(1, {dep1, dep2});
|
||||||
ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId()));
|
ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId()));
|
||||||
manager_.AddPendingTask(spec);
|
manager_.AddPendingTask(caller_id, caller_address, spec);
|
||||||
ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId()));
|
ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId()));
|
||||||
|
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 3);
|
||||||
auto return_id = spec.ReturnId(0, TaskTransportType::DIRECT);
|
auto return_id = spec.ReturnId(0, TaskTransportType::DIRECT);
|
||||||
WorkerContext ctx(WorkerType::WORKER, JobID::FromInt(0));
|
WorkerContext ctx(WorkerType::WORKER, JobID::FromInt(0));
|
||||||
|
|
||||||
|
@ -60,6 +73,8 @@ TEST_F(TaskManagerTest, TestTaskSuccess) {
|
||||||
return_object->set_data(data->Data(), data->Size());
|
return_object->set_data(data->Data(), data->Size());
|
||||||
manager_.CompletePendingTask(spec.TaskId(), reply, nullptr);
|
manager_.CompletePendingTask(spec.TaskId(), reply, nullptr);
|
||||||
ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId()));
|
ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId()));
|
||||||
|
// Only the return object reference should remain.
|
||||||
|
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 1);
|
||||||
|
|
||||||
std::vector<std::shared_ptr<RayObject>> results;
|
std::vector<std::shared_ptr<RayObject>> results;
|
||||||
RAY_CHECK_OK(store_->Get({return_id}, 1, -1, ctx, false, &results));
|
RAY_CHECK_OK(store_->Get({return_id}, 1, -1, ctx, false, &results));
|
||||||
|
@ -69,19 +84,33 @@ TEST_F(TaskManagerTest, TestTaskSuccess) {
|
||||||
return_object->data().size()),
|
return_object->data().size()),
|
||||||
0);
|
0);
|
||||||
ASSERT_EQ(num_retries_, 0);
|
ASSERT_EQ(num_retries_, 0);
|
||||||
|
|
||||||
|
std::vector<ObjectID> removed;
|
||||||
|
reference_counter_->AddLocalReference(return_id);
|
||||||
|
reference_counter_->RemoveLocalReference(return_id, &removed);
|
||||||
|
ASSERT_EQ(removed[0], return_id);
|
||||||
|
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TaskManagerTest, TestTaskFailure) {
|
TEST_F(TaskManagerTest, TestTaskFailure) {
|
||||||
auto spec = CreateTaskHelper(1);
|
TaskID caller_id = TaskID::Nil();
|
||||||
|
rpc::Address caller_address;
|
||||||
|
ObjectID dep1 = ObjectID::FromRandom();
|
||||||
|
ObjectID dep2 = ObjectID::FromRandom();
|
||||||
|
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 0);
|
||||||
|
auto spec = CreateTaskHelper(1, {dep1, dep2});
|
||||||
ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId()));
|
ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId()));
|
||||||
manager_.AddPendingTask(spec);
|
manager_.AddPendingTask(caller_id, caller_address, spec);
|
||||||
ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId()));
|
ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId()));
|
||||||
|
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 3);
|
||||||
auto return_id = spec.ReturnId(0, TaskTransportType::DIRECT);
|
auto return_id = spec.ReturnId(0, TaskTransportType::DIRECT);
|
||||||
WorkerContext ctx(WorkerType::WORKER, JobID::FromInt(0));
|
WorkerContext ctx(WorkerType::WORKER, JobID::FromInt(0));
|
||||||
|
|
||||||
auto error = rpc::ErrorType::WORKER_DIED;
|
auto error = rpc::ErrorType::WORKER_DIED;
|
||||||
manager_.PendingTaskFailed(spec.TaskId(), error);
|
manager_.PendingTaskFailed(spec.TaskId(), error);
|
||||||
ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId()));
|
ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId()));
|
||||||
|
// Only the return object reference should remain.
|
||||||
|
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 1);
|
||||||
|
|
||||||
std::vector<std::shared_ptr<RayObject>> results;
|
std::vector<std::shared_ptr<RayObject>> results;
|
||||||
RAY_CHECK_OK(store_->Get({return_id}, 1, -1, ctx, false, &results));
|
RAY_CHECK_OK(store_->Get({return_id}, 1, -1, ctx, false, &results));
|
||||||
|
@ -90,14 +119,26 @@ TEST_F(TaskManagerTest, TestTaskFailure) {
|
||||||
ASSERT_TRUE(results[0]->IsException(&stored_error));
|
ASSERT_TRUE(results[0]->IsException(&stored_error));
|
||||||
ASSERT_EQ(stored_error, error);
|
ASSERT_EQ(stored_error, error);
|
||||||
ASSERT_EQ(num_retries_, 0);
|
ASSERT_EQ(num_retries_, 0);
|
||||||
|
|
||||||
|
std::vector<ObjectID> removed;
|
||||||
|
reference_counter_->AddLocalReference(return_id);
|
||||||
|
reference_counter_->RemoveLocalReference(return_id, &removed);
|
||||||
|
ASSERT_EQ(removed[0], return_id);
|
||||||
|
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TaskManagerTest, TestTaskRetry) {
|
TEST_F(TaskManagerTest, TestTaskRetry) {
|
||||||
auto spec = CreateTaskHelper(1);
|
TaskID caller_id = TaskID::Nil();
|
||||||
|
rpc::Address caller_address;
|
||||||
|
ObjectID dep1 = ObjectID::FromRandom();
|
||||||
|
ObjectID dep2 = ObjectID::FromRandom();
|
||||||
|
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 0);
|
||||||
|
auto spec = CreateTaskHelper(1, {dep1, dep2});
|
||||||
ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId()));
|
ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId()));
|
||||||
int num_retries = 3;
|
int num_retries = 3;
|
||||||
manager_.AddPendingTask(spec, num_retries);
|
manager_.AddPendingTask(caller_id, caller_address, spec, num_retries);
|
||||||
ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId()));
|
ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId()));
|
||||||
|
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 3);
|
||||||
auto return_id = spec.ReturnId(0, TaskTransportType::DIRECT);
|
auto return_id = spec.ReturnId(0, TaskTransportType::DIRECT);
|
||||||
WorkerContext ctx(WorkerType::WORKER, JobID::FromInt(0));
|
WorkerContext ctx(WorkerType::WORKER, JobID::FromInt(0));
|
||||||
|
|
||||||
|
@ -105,6 +146,7 @@ TEST_F(TaskManagerTest, TestTaskRetry) {
|
||||||
for (int i = 0; i < num_retries; i++) {
|
for (int i = 0; i < num_retries; i++) {
|
||||||
manager_.PendingTaskFailed(spec.TaskId(), error);
|
manager_.PendingTaskFailed(spec.TaskId(), error);
|
||||||
ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId()));
|
ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId()));
|
||||||
|
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 3);
|
||||||
std::vector<std::shared_ptr<RayObject>> results;
|
std::vector<std::shared_ptr<RayObject>> results;
|
||||||
ASSERT_FALSE(store_->Get({return_id}, 1, 0, ctx, false, &results).ok());
|
ASSERT_FALSE(store_->Get({return_id}, 1, 0, ctx, false, &results).ok());
|
||||||
ASSERT_EQ(num_retries_, i + 1);
|
ASSERT_EQ(num_retries_, i + 1);
|
||||||
|
@ -112,6 +154,8 @@ TEST_F(TaskManagerTest, TestTaskRetry) {
|
||||||
|
|
||||||
manager_.PendingTaskFailed(spec.TaskId(), error);
|
manager_.PendingTaskFailed(spec.TaskId(), error);
|
||||||
ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId()));
|
ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId()));
|
||||||
|
// Only the return object reference should remain.
|
||||||
|
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 1);
|
||||||
|
|
||||||
std::vector<std::shared_ptr<RayObject>> results;
|
std::vector<std::shared_ptr<RayObject>> results;
|
||||||
RAY_CHECK_OK(store_->Get({return_id}, 1, -0, ctx, false, &results));
|
RAY_CHECK_OK(store_->Get({return_id}, 1, -0, ctx, false, &results));
|
||||||
|
@ -119,6 +163,12 @@ TEST_F(TaskManagerTest, TestTaskRetry) {
|
||||||
rpc::ErrorType stored_error;
|
rpc::ErrorType stored_error;
|
||||||
ASSERT_TRUE(results[0]->IsException(&stored_error));
|
ASSERT_TRUE(results[0]->IsException(&stored_error));
|
||||||
ASSERT_EQ(stored_error, error);
|
ASSERT_EQ(stored_error, error);
|
||||||
|
|
||||||
|
std::vector<ObjectID> removed;
|
||||||
|
reference_counter_->AddLocalReference(return_id);
|
||||||
|
reference_counter_->RemoveLocalReference(return_id, &removed);
|
||||||
|
ASSERT_EQ(removed[0], return_id);
|
||||||
|
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace ray
|
} // namespace ray
|
||||||
|
|
|
@ -18,7 +18,7 @@ struct TaskState {
|
||||||
|
|
||||||
void InlineDependencies(
|
void InlineDependencies(
|
||||||
absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> dependencies,
|
absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> dependencies,
|
||||||
TaskSpecification &task) {
|
TaskSpecification &task, std::vector<ObjectID> *inlined) {
|
||||||
auto &msg = task.GetMutableMessage();
|
auto &msg = task.GetMutableMessage();
|
||||||
size_t found = 0;
|
size_t found = 0;
|
||||||
for (size_t i = 0; i < task.NumArgs(); i++) {
|
for (size_t i = 0; i < task.NumArgs(); i++) {
|
||||||
|
@ -43,6 +43,7 @@ void InlineDependencies(
|
||||||
const auto &metadata = it->second->GetMetadata();
|
const auto &metadata = it->second->GetMetadata();
|
||||||
mutable_arg->set_metadata(metadata->Data(), metadata->Size());
|
mutable_arg->set_metadata(metadata->Data(), metadata->Size());
|
||||||
}
|
}
|
||||||
|
inlined->push_back(id);
|
||||||
}
|
}
|
||||||
found++;
|
found++;
|
||||||
} else {
|
} else {
|
||||||
|
@ -83,15 +84,19 @@ void LocalDependencyResolver::ResolveDependencies(TaskSpecification &task,
|
||||||
obj_id, [this, state, obj_id, on_complete](std::shared_ptr<RayObject> obj) {
|
obj_id, [this, state, obj_id, on_complete](std::shared_ptr<RayObject> obj) {
|
||||||
RAY_CHECK(obj != nullptr);
|
RAY_CHECK(obj != nullptr);
|
||||||
bool complete = false;
|
bool complete = false;
|
||||||
|
std::vector<ObjectID> inlined;
|
||||||
{
|
{
|
||||||
absl::MutexLock lock(&mu_);
|
absl::MutexLock lock(&mu_);
|
||||||
state->local_dependencies[obj_id] = std::move(obj);
|
state->local_dependencies[obj_id] = std::move(obj);
|
||||||
if (--state->dependencies_remaining == 0) {
|
if (--state->dependencies_remaining == 0) {
|
||||||
InlineDependencies(state->local_dependencies, state->task);
|
InlineDependencies(state->local_dependencies, state->task, &inlined);
|
||||||
complete = true;
|
complete = true;
|
||||||
num_pending_ -= 1;
|
num_pending_ -= 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (inlined.size() > 0) {
|
||||||
|
task_finisher_->OnTaskDependenciesInlined(inlined);
|
||||||
|
}
|
||||||
if (complete) {
|
if (complete) {
|
||||||
on_complete();
|
on_complete();
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,14 +6,16 @@
|
||||||
#include "ray/common/id.h"
|
#include "ray/common/id.h"
|
||||||
#include "ray/common/task/task_spec.h"
|
#include "ray/common/task/task_spec.h"
|
||||||
#include "ray/core_worker/store_provider/memory_store/memory_store.h"
|
#include "ray/core_worker/store_provider/memory_store/memory_store.h"
|
||||||
|
#include "ray/core_worker/task_manager.h"
|
||||||
|
|
||||||
namespace ray {
|
namespace ray {
|
||||||
|
|
||||||
// This class is thread-safe.
|
// This class is thread-safe.
|
||||||
class LocalDependencyResolver {
|
class LocalDependencyResolver {
|
||||||
public:
|
public:
|
||||||
LocalDependencyResolver(std::shared_ptr<CoreWorkerMemoryStore> store)
|
LocalDependencyResolver(std::shared_ptr<CoreWorkerMemoryStore> store,
|
||||||
: in_memory_store_(store), num_pending_(0) {}
|
std::shared_ptr<TaskFinisherInterface> task_finisher)
|
||||||
|
: in_memory_store_(store), task_finisher_(task_finisher), num_pending_(0) {}
|
||||||
|
|
||||||
/// Resolve all local and remote dependencies for the task, calling the specified
|
/// Resolve all local and remote dependencies for the task, calling the specified
|
||||||
/// callback when done. Direct call ids in the task specification will be resolved
|
/// callback when done. Direct call ids in the task specification will be resolved
|
||||||
|
@ -33,6 +35,9 @@ class LocalDependencyResolver {
|
||||||
/// The in-memory store.
|
/// The in-memory store.
|
||||||
std::shared_ptr<CoreWorkerMemoryStore> in_memory_store_;
|
std::shared_ptr<CoreWorkerMemoryStore> in_memory_store_;
|
||||||
|
|
||||||
|
/// Used to complete tasks.
|
||||||
|
std::shared_ptr<TaskFinisherInterface> task_finisher_;
|
||||||
|
|
||||||
/// Number of tasks pending dependency resolution.
|
/// Number of tasks pending dependency resolution.
|
||||||
std::atomic<int> num_pending_;
|
std::atomic<int> num_pending_;
|
||||||
|
|
||||||
|
|
|
@ -8,9 +8,9 @@
|
||||||
#include <queue>
|
#include <queue>
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include "absl/container/flat_hash_map.h"
|
|
||||||
|
|
||||||
#include "absl/base/thread_annotations.h"
|
#include "absl/base/thread_annotations.h"
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/synchronization/mutex.h"
|
#include "absl/synchronization/mutex.h"
|
||||||
#include "ray/common/id.h"
|
#include "ray/common/id.h"
|
||||||
#include "ray/common/ray_object.h"
|
#include "ray/common/ray_object.h"
|
||||||
|
@ -39,7 +39,7 @@ class CoreWorkerDirectActorTaskSubmitter {
|
||||||
std::shared_ptr<CoreWorkerMemoryStore> store,
|
std::shared_ptr<CoreWorkerMemoryStore> store,
|
||||||
std::shared_ptr<TaskFinisherInterface> task_finisher)
|
std::shared_ptr<TaskFinisherInterface> task_finisher)
|
||||||
: client_factory_(client_factory),
|
: client_factory_(client_factory),
|
||||||
resolver_(store),
|
resolver_(store, task_finisher),
|
||||||
task_finisher_(task_finisher) {}
|
task_finisher_(task_finisher) {}
|
||||||
|
|
||||||
/// Submit a task to an actor for execution.
|
/// Submit a task to an actor for execution.
|
||||||
|
|
|
@ -5,7 +5,6 @@
|
||||||
|
|
||||||
#include "absl/base/thread_annotations.h"
|
#include "absl/base/thread_annotations.h"
|
||||||
#include "absl/synchronization/mutex.h"
|
#include "absl/synchronization/mutex.h"
|
||||||
|
|
||||||
#include "ray/common/id.h"
|
#include "ray/common/id.h"
|
||||||
#include "ray/common/ray_object.h"
|
#include "ray/common/ray_object.h"
|
||||||
#include "ray/core_worker/context.h"
|
#include "ray/core_worker/context.h"
|
||||||
|
@ -41,7 +40,7 @@ class CoreWorkerDirectTaskSubmitter {
|
||||||
: local_lease_client_(lease_client),
|
: local_lease_client_(lease_client),
|
||||||
client_factory_(client_factory),
|
client_factory_(client_factory),
|
||||||
lease_client_factory_(lease_client_factory),
|
lease_client_factory_(lease_client_factory),
|
||||||
resolver_(store),
|
resolver_(store, task_finisher),
|
||||||
task_finisher_(task_finisher),
|
task_finisher_(task_finisher),
|
||||||
local_raylet_id_(local_raylet_id),
|
local_raylet_id_(local_raylet_id),
|
||||||
lease_timeout_ms_(lease_timeout_ms) {}
|
lease_timeout_ms_(lease_timeout_ms) {}
|
||||||
|
|
Loading…
Add table
Reference in a new issue