[core] Ref counting for actor handles (#7434)

* tmp

* Move Exit handler into CoreWorker, exit once owner's ref count goes to 0

* fix build

* Remove __ray_terminate__ and add test case for distributed ref counting

* lint

* Remove unused

* Fixes for detached actor, duplicate actor handles

* Remove unused

* Remove creation return ID

* Remove ObjectIDs from python, set references in CoreWorker

* Fix crash

* Fix memory crash

* Fix tests

* fix

* fixes

* fix tests

* fix java build

* fix build

* fix

* check status

* check status
Stephanie Wang 2020-03-10 17:45:07 -07:00 committed by GitHub
parent 119a303ea0
commit fdb528514b
23 changed files with 330 additions and 180 deletions

View file

@ -34,7 +34,7 @@ def py_func_call_java_actor(value):
@ray.remote
def py_func_call_java_actor_from_handle(value):
assert isinstance(value, bytes)
-actor_handle = ray.actor.ActorHandle._deserialization_helper(value, False)
+actor_handle = ray.actor.ActorHandle._deserialization_helper(value)
r = actor_handle.concat.remote(b"2")
return ray.get(r)
@ -42,7 +42,7 @@ def py_func_call_java_actor_from_handle(value):
@ray.remote
def py_func_call_python_actor_from_handle(value):
assert isinstance(value, bytes)
-actor_handle = ray.actor.ActorHandle._deserialization_helper(value, False)
+actor_handle = ray.actor.ActorHandle._deserialization_helper(value)
r = actor_handle.increase.remote(2)
return ray.get(r)
@ -52,7 +52,7 @@ def py_func_pass_python_actor_handle():
counter = Counter.remote(2)
f = ray.java_function("org.ray.api.test.CrossLanguageInvocationTest",
"callPythonActorHandle")
-r = f.remote(counter._serialization_helper(False))
+r = f.remote(counter._serialization_helper())
return ray.get(r)
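Both call sites above drop the old `ray_forking` flag: pickling and forking now share one serialization path, and reference counting happens in the core worker either way. A minimal sketch of the same round trip using only public APIs (assuming a running Ray instance; passing a handle to a task invokes these helpers implicitly via pickling):

    import ray

    ray.init()

    @ray.remote
    class Counter:
        def __init__(self, value):
            self.value = value

        def increase(self, delta):
            self.value += delta
            return self.value

    @ray.remote
    def use_handle(counter):
        # Deserializing the handle argument registers a borrowed reference
        # with this worker, keeping the actor alive while the task runs.
        return ray.get(counter.increase.remote(2))

    counter = Counter.remote(2)
    assert ray.get(use_handle.remote(counter)) == 4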

View file

@ -865,7 +865,7 @@ cdef class CoreWorker:
with nogil:
check_status(self.core_worker.get().KillActor(
-c_actor_id))
+c_actor_id, True))
def resource_ids(self):
cdef:
@ -894,15 +894,24 @@ cdef class CoreWorker:
self.core_worker.get().CreateProfileEvent(event_type),
extra_data)
-def deserialize_and_register_actor_handle(self, const c_string &bytes):
+def remove_actor_handle_reference(self, ActorID actor_id):
+cdef:
+CActorID c_actor_id = actor_id.native()
+self.core_worker.get().RemoveActorHandleReference(c_actor_id)
+def deserialize_and_register_actor_handle(self, const c_string &bytes,
+ObjectID
+outer_object_id):
cdef:
CActorHandle* c_actor_handle
+CObjectID c_outer_object_id = (outer_object_id.native() if
+outer_object_id else
+CObjectID.Nil())
worker = ray.worker.get_global_worker()
worker.check_connected()
manager = worker.function_actor_manager
c_actor_id = self.core_worker.get().DeserializeAndRegisterActorHandle(
-bytes)
+bytes, c_outer_object_id)
check_status(self.core_worker.get().GetActorHandle(
c_actor_id, &c_actor_handle))
actor_id = ActorID(c_actor_id.Binary())
@ -940,14 +949,13 @@ cdef class CoreWorker:
actor_creation_function_descriptor,
worker.current_session_and_job)
def serialize_actor_handle(self, actor_handle):
assert isinstance(actor_handle, ray.actor.ActorHandle)
def serialize_actor_handle(self, ActorID actor_id):
cdef:
ActorID actor_id = actor_handle._ray_actor_id
c_string output
CObjectID c_actor_handle_id
check_status(self.core_worker.get().SerializeActorHandle(
actor_id.native(), &output))
return output
actor_id.native(), &output, &c_actor_handle_id))
return output, ObjectID(c_actor_handle_id.Binary())
def add_object_id_reference(self, ObjectID object_id):
# Note: faster to not release GIL for short-running op.
@ -974,7 +982,9 @@ cdef class CoreWorker:
const c_string &serialized_owner_address):
cdef:
CObjectID c_object_id = CObjectID.FromBinary(object_id_binary)
-CObjectID c_outer_object_id = outer_object_id.native()
+CObjectID c_outer_object_id = (outer_object_id.native() if
+outer_object_id else
+CObjectID.Nil())
CTaskID c_owner_id = CTaskID.FromBinary(owner_id_binary)
CAddress c_owner_address = CAddress()

View file

@ -652,6 +652,14 @@ class ActorHandle:
decorator=self._ray_method_decorators.get(method_name))
setattr(self, method_name, method)
+def __del__(self):
+# Mark that this actor handle has gone out of scope. Once all actor
+# handles are out of scope, the actor will exit.
+worker = ray.worker.get_global_worker()
+if worker.connected and hasattr(worker, "core_worker"):
+worker.core_worker.remove_actor_handle_reference(
+self._ray_actor_id)
def _actor_method_call(self,
method_name,
args=None,
@ -752,36 +760,6 @@ class ActorHandle:
self._ray_actor_creation_function_descriptor.class_name,
self._actor_id.hex())
-def __del__(self):
-"""Terminate the worker that is running this actor."""
-# TODO(swang): Also clean up forked actor handles.
-# Kill the worker if this is the original actor handle, created
-# with Class.remote(). TODO(rkn): Even without passing handles around,
-# this is not the right policy. the actor should be alive as long as
-# there are ANY handles in scope in the process that created the actor,
-# not just the first one.
-worker = ray.worker.get_global_worker()
-exported_in_current_session_and_job = (
-self._ray_session_and_job == worker.current_session_and_job)
-if (worker.mode == ray.worker.SCRIPT_MODE
-and not exported_in_current_session_and_job):
-# If the worker is a driver and driver id has changed because
-# Ray was shut down re-initialized, the actor is already cleaned up
-# and we don't need to send `__ray_terminate__` again.
-logger.warning(
-"Actor is garbage collected in the wrong driver." +
-" Actor id = %s, class name = %s.", self._ray_actor_id,
-self._ray_actor_creation_function_descriptor.class_name)
-return
-if worker.connected and self._ray_original_handle:
-# Note: in py2 the weakref is destroyed prior to calling __del__
-# so we need to set the hardref here briefly
-try:
-self.__ray_terminate__._actor_hard_ref = self
-self.__ray_terminate__.remote()
-finally:
-self.__ray_terminate__._actor_hard_ref = None
def __ray_kill__(self):
"""Deprecated - use ray.kill() instead."""
logger.warning("actor.__ray_kill__() is deprecated and will be removed"
@ -792,13 +770,9 @@ class ActorHandle:
def _actor_id(self):
return self._ray_actor_id
-def _serialization_helper(self, ray_forking):
+def _serialization_helper(self):
"""This is defined in order to make pickling work.
-Args:
-ray_forking: True if this is being called because Ray is forking
-the actor handle and false if it is being called by pickling.
Returns:
A dictionary of the information needed to reconstruct the object.
"""
@ -807,10 +781,11 @@ class ActorHandle:
if hasattr(worker, "core_worker"):
# Non-local mode
-state = worker.core_worker.serialize_actor_handle(self)
+state = worker.core_worker.serialize_actor_handle(
+self._ray_actor_id)
else:
# Local mode
-state = {
+state = ({
"actor_language": self._ray_actor_language,
"actor_id": self._ray_actor_id,
"method_decorators": self._ray_method_decorators,
@ -819,18 +794,20 @@ class ActorHandle:
"actor_method_cpus": self._ray_actor_method_cpus,
"actor_creation_function_descriptor": self.
_ray_actor_creation_function_descriptor,
-}
+}, None)
return state
@classmethod
-def _deserialization_helper(cls, state, ray_forking):
+def _deserialization_helper(cls, state, outer_object_id=None):
"""This is defined in order to make pickling work.
Args:
state: The serialized state of the actor handle.
-ray_forking: True if this is being called because Ray is forking
-the actor handle and false if it is being called by pickling.
+outer_object_id: The ObjectID that the serialized actor handle was
+contained in, if any. This is used for counting references to
+the actor handle.
"""
worker = ray.worker.get_global_worker()
worker.check_connected()
@ -838,7 +815,7 @@ class ActorHandle:
if hasattr(worker, "core_worker"):
# Non-local mode
return worker.core_worker.deserialize_and_register_actor_handle(
-state)
+state, outer_object_id)
else:
# Local mode
return cls(
@ -855,8 +832,8 @@ class ActorHandle:
def __reduce__(self):
"""This code path is used by pickling but not by Ray forking."""
-state = self._serialization_helper(False)
-return ActorHandle._deserialization_helper, (state, False)
+state = self._serialization_helper()
+return ActorHandle._deserialization_helper, (state)
def modify_class(cls):
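Net effect of the two `__del__` hunks above: the old "terminate when the original handle is deleted" policy is replaced by reference counting, so the actor exits only when the last handle goes out of scope. A minimal sketch of the new lifetime rule (assuming a running Ray instance):

    import ray

    ray.init()

    @ray.remote
    class Worker:
        def ping(self):
            return "pong"

    original = Worker.remote()
    alias = original  # a second name for the same handle object
    del original      # actor stays alive; `alias` still holds the handle
    assert ray.get(alias.ping.remote()) == "pong"
    del alias         # last reference dropped: __del__ notifies the core
                      # worker, the owner's count reaches zero, and the
                      # actor exits cleanly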

View file

@ -116,7 +116,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
const CActorID &actor_id, const CRayFunction &function,
const c_vector[CTaskArg] &args, const CTaskOptions &options,
c_vector[CObjectID] *return_ids)
-CRayStatus KillActor(const CActorID &actor_id)
+CRayStatus KillActor(const CActorID &actor_id, c_bool force_kill)
unique_ptr[CProfileEvent] CreateProfileEvent(
const c_string &event_type)
@ -134,9 +134,12 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
void SetWebuiDisplay(const c_string &key, const c_string &message)
CTaskID GetCallerId()
const ResourceMappingType &GetResourceIDs() const
-CActorID DeserializeAndRegisterActorHandle(const c_string &bytes)
+void RemoveActorHandleReference(const CActorID &actor_id)
+CActorID DeserializeAndRegisterActorHandle(const c_string &bytes, const
+CObjectID &outer_object_id)
CRayStatus SerializeActorHandle(const CActorID &actor_id, c_string
-*bytes)
+*bytes,
+CObjectID *c_actor_handle_id)
CRayStatus GetActorHandle(const CActorID &actor_id,
CActorHandle **actor_handle) const
void AddLocalReference(const CObjectID &object_id)

View file

@ -133,11 +133,18 @@ class SerializationContext:
self._thread_local = threading.local()
def actor_handle_serializer(obj):
-return obj._serialization_helper(True)
+serialized, actor_handle_id = obj._serialization_helper()
+# Update ref counting for the actor handle
+self.add_contained_object_id(actor_handle_id)
+return serialized
def actor_handle_deserializer(serialized_obj):
+# If this actor handle was stored in another object, then tell the
+# core worker.
+context = ray.worker.global_worker.get_serialization_context()
+outer_id = context.get_outer_object_id()
return ray.actor.ActorHandle._deserialization_helper(
-serialized_obj, True)
+serialized_obj, outer_id)
self._register_cloudpickle_serializer(
ray.actor.ActorHandle,
@ -151,15 +158,7 @@ class SerializationContext:
return serialized_obj[0](*serialized_obj[1])
def object_id_serializer(obj):
-if self.is_in_band_serialization():
-self.add_contained_object_id(obj)
-else:
-# If this serialization is out-of-band (e.g., from a call to
-# cloudpickle directly or captured in a remote function/actor),
-# then pin the object for the lifetime of this worker by adding
-# a local reference that won't ever be removed.
-ray.worker.get_global_worker(
-).core_worker.add_object_id_reference(obj)
+self.add_contained_object_id(obj)
owner_id = ""
owner_address = ""
# TODO(swang): Remove this check. Otherwise, we will not be able to
@ -239,10 +238,20 @@ class SerializationContext:
return object_ids
def add_contained_object_id(self, object_id):
-if not hasattr(self._thread_local, "object_ids"):
-self._thread_local.object_ids = set()
-self._thread_local.object_ids.add(object_id)
+if self.is_in_band_serialization():
+# This object ID is being stored in an object. Add the ID to the
+# list of IDs contained in the object so that we keep the inner
+# object value alive as long as the outer object is in scope.
+if not hasattr(self._thread_local, "object_ids"):
+self._thread_local.object_ids = set()
+self._thread_local.object_ids.add(object_id)
+else:
+# If this serialization is out-of-band (e.g., from a call to
+# cloudpickle directly or captured in a remote function/actor),
+# then pin the object for the lifetime of this worker by adding
+# a local reference that won't ever be removed.
+ray.worker.get_global_worker().core_worker.add_object_id_reference(
+object_id)
def _deserialize_pickle5_data(self, data):
try:
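The consolidated `add_contained_object_id` above decides between the two paths by whether serialization is in-band. A rough illustration of both paths (hedged: `ray.cloudpickle` is assumed to be the bundled cloudpickle module of this era of the codebase):

    import ray
    import ray.cloudpickle as cloudpickle

    ray.init()

    inner = ray.put("payload")

    # In-band: the inner ID is recorded as contained in the outer object,
    # so the payload stays alive while the outer object is in scope.
    outer = ray.put([inner])

    # Out-of-band: cloudpickling an ObjectID directly pins it with a local
    # reference that is never removed for this worker's lifetime.
    blob = cloudpickle.dumps(inner)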

View file

@ -106,6 +106,7 @@ def test_actor_method_metadata_cache(ray_start_regular):
# The cache of ActorClassMethodMetadata.
cache = ray.actor.ActorClassMethodMetadata._cache
cache.clear()
+# Check cache hit during ActorHandle deserialization.
A1 = ray.remote(Actor)
@ -532,6 +533,34 @@ def test_actor_method_deletion(ray_start_regular):
assert ray.get(Actor.remote().method.remote()) == 1
+def test_distributed_actor_handle_deletion(ray_start_regular):
+@ray.remote
+class Actor:
+def method(self):
+return 1
+def getpid(self):
+return os.getpid()
+@ray.remote
+def f(actor, signal):
+ray.get(signal.wait.remote())
+return ray.get(actor.method.remote())
+signal = ray.test_utils.SignalActor.remote()
+a = Actor.remote()
+pid = ray.get(a.getpid.remote())
+# Pass the handle to another task that cannot run yet.
+x_id = f.remote(a, signal)
+# Delete the original handle. The actor should not get killed yet.
+del a
+# Once the task finishes, the actor process should get killed.
+ray.get(signal.send.remote())
+assert ray.get(x_id) == 1
+ray.test_utils.wait_for_pid_to_exit(pid)
def test_multiple_actors(ray_start_regular):
@ray.remote
class Counter:

View file

@ -202,7 +202,7 @@ def test_raylet_info_endpoint(shutdown_only):
try:
assert len(actor_info) == 1
_, parent_actor_info = actor_info.popitem()
-assert parent_actor_info["numObjectIdsInScope"] == 11
+assert parent_actor_info["numObjectIdsInScope"] == 13
assert parent_actor_info["numLocalObjects"] == 10
children = parent_actor_info["children"]
assert len(children) == 2

View file

@ -362,6 +362,12 @@ ObjectID ObjectID::FromRandom() {
flags);
}
+ObjectID ObjectID::ForActorHandle(const ActorID &actor_id) {
+return ObjectID::ForTaskReturn(TaskID::ForActorCreationTask(actor_id),
+/*return_index=*/1,
+static_cast<int>(TaskTransportType::DIRECT));
+}
ObjectID ObjectID::GenerateObjectId(const std::string &task_id_binary,
ObjectIDFlagsType flags,
ObjectIDIndexType object_index) {

View file

@ -363,6 +363,14 @@ class ObjectID : public BaseID<ObjectID> {
/// \return A random object id.
static ObjectID FromRandom();
+/// Compute the object ID that is used to track an actor's lifetime. This
+/// object does not actually have a value; it is just used for counting
+/// references (handles) to the actor.
+///
+/// \param actor_id The ID of the actor to track.
+/// \return The computed object ID.
+static ObjectID ForActorHandle(const ActorID &actor_id);
private:
/// A helper method to generate an ObjectID.
static ObjectID GenerateObjectId(const std::string &task_id_binary,

View file

@ -19,12 +19,15 @@
namespace {
ray::rpc::ActorHandle CreateInnerActorHandle(
-const class ActorID &actor_id, const class JobID &job_id,
+const class ActorID &actor_id, const TaskID &owner_id,
+const ray::rpc::Address &owner_address, const class JobID &job_id,
const ObjectID &initial_cursor, const Language actor_language, bool is_direct_call,
const ray::FunctionDescriptor &actor_creation_task_function_descriptor,
const std::string &extension_data) {
ray::rpc::ActorHandle inner;
inner.set_actor_id(actor_id.Data(), actor_id.Size());
+inner.set_owner_id(owner_id.Binary());
+inner.mutable_owner_address()->CopyFrom(owner_address);
inner.set_creation_job_id(job_id.Data(), job_id.Size());
inner.set_actor_language(actor_language);
*inner.mutable_actor_creation_task_function_descriptor() =
@ -46,13 +49,14 @@ ray::rpc::ActorHandle CreateInnerActorHandleFromString(const std::string &serial
namespace ray {
ActorHandle::ActorHandle(
-const class ActorID &actor_id, const class JobID &job_id,
+const class ActorID &actor_id, const TaskID &owner_id,
+const rpc::Address &owner_address, const class JobID &job_id,
const ObjectID &initial_cursor, const Language actor_language, bool is_direct_call,
const ray::FunctionDescriptor &actor_creation_task_function_descriptor,
const std::string &extension_data)
: ActorHandle(CreateInnerActorHandle(
-actor_id, job_id, initial_cursor, actor_language, is_direct_call,
-actor_creation_task_function_descriptor, extension_data)) {}
+actor_id, owner_id, owner_address, job_id, initial_cursor, actor_language,
+is_direct_call, actor_creation_task_function_descriptor, extension_data)) {}
ActorHandle::ActorHandle(const std::string &serialized)
: ActorHandle(CreateInnerActorHandleFromString(serialized)) {}

View file

@ -32,7 +32,8 @@ class ActorHandle {
: inner_(inner), actor_cursor_(ObjectID::FromBinary(inner_.actor_cursor())) {}
// Constructs a new ActorHandle as part of the actor creation process.
-ActorHandle(const ActorID &actor_id, const JobID &job_id,
+ActorHandle(const ActorID &actor_id, const TaskID &owner_id,
+const rpc::Address &owner_address, const JobID &job_id,
const ObjectID &initial_cursor, const Language actor_language,
bool is_direct_call,
const ray::FunctionDescriptor &actor_creation_task_function_descriptor,
@ -43,6 +44,10 @@ class ActorHandle {
ActorID GetActorID() const { return ActorID::FromBinary(inner_.actor_id()); };
+TaskID GetOwnerId() const { return TaskID::FromBinary(inner_.owner_id()); }
+rpc::Address GetOwnerAddress() const { return inner_.owner_address(); }
/// ID of the job that created the actor (it is possible that the handle
/// exists on a job with a different job ID).
JobID CreationJobID() const { return JobID::FromBinary(inner_.creation_job_id()); };

View file

@ -133,26 +133,12 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
auto execute_task =
std::bind(&CoreWorker::ExecuteTask, this, std::placeholders::_1,
std::placeholders::_2, std::placeholders::_3, std::placeholders::_4);
-auto exit = [this](bool intentional) {
-// Release the resources early in case draining takes a long time.
-RAY_CHECK_OK(local_raylet_client_->NotifyDirectCallTaskBlocked());
-task_manager_->DrainAndShutdown([this, intentional]() {
-// To avoid problems, make sure shutdown is always called from the same
-// event loop each time.
-task_execution_service_.post([this, intentional]() {
-if (intentional) {
-Disconnect(); // Notify the raylet this is an intentional exit.
-}
-Shutdown();
-});
-});
-};
raylet_task_receiver_ =
std::unique_ptr<CoreWorkerRayletTaskReceiver>(new CoreWorkerRayletTaskReceiver(
-worker_context_.GetWorkerID(), local_raylet_client_, execute_task, exit));
+worker_context_.GetWorkerID(), local_raylet_client_, execute_task));
direct_task_receiver_ = std::unique_ptr<CoreWorkerDirectTaskReceiver>(
new CoreWorkerDirectTaskReceiver(worker_context_, local_raylet_client_,
-task_execution_service_, execute_task, exit));
+task_execution_service_, execute_task));
}
// Start RPC server after all the task receivers are properly initialized.
@ -294,6 +280,22 @@ void CoreWorker::Disconnect() {
}
}
+void CoreWorker::Exit(bool intentional) {
+exiting_ = true;
+// Release the resources early in case draining takes a long time.
+RAY_CHECK_OK(local_raylet_client_->NotifyDirectCallTaskBlocked());
+task_manager_->DrainAndShutdown([this, intentional]() {
+// To avoid problems, make sure shutdown is always called from the same
+// event loop each time.
+task_execution_service_.post([this, intentional]() {
+if (intentional) {
+Disconnect(); // Notify the raylet this is an intentional exit.
+}
+Shutdown();
+});
+});
+}
void CoreWorker::RunIOService() {
#ifdef _WIN32
// TODO(mehrdadn): Is there an equivalent for Windows we need here?
@ -358,6 +360,20 @@ void CoreWorker::InternalHeartbeat() {
internal_timer_.async_wait(boost::bind(&CoreWorker::InternalHeartbeat, this));
}
+std::unordered_map<ObjectID, std::pair<size_t, size_t>>
+CoreWorker::GetAllReferenceCounts() const {
+auto counts = reference_counter_->GetAllReferenceCounts();
+absl::MutexLock lock(&actor_handles_mutex_);
+// Strip actor IDs from the ref counts since there is no associated ObjectID
+// in the language frontend.
+for (const auto &handle : actor_handles_) {
+auto actor_id = handle.first;
+auto actor_handle_id = ObjectID::ForActorHandle(actor_id);
+counts.erase(actor_handle_id);
+}
+return counts;
+}
void CoreWorker::PromoteToPlasmaAndGetOwnershipInfo(const ObjectID &object_id,
TaskID *owner_id,
rpc::Address *owner_address) {
@ -784,24 +800,27 @@ Status CoreWorker::CreateActor(const RayFunction &function,
actor_creation_options.is_direct_call, actor_creation_options.max_concurrency,
actor_creation_options.is_detached, actor_creation_options.is_asyncio);
-std::unique_ptr<ActorHandle> actor_handle(
-new ActorHandle(actor_id, job_id, /*actor_cursor=*/return_ids[0],
-function.GetLanguage(), actor_creation_options.is_direct_call,
-function.GetFunctionDescriptor(), extension_data));
-RAY_CHECK(AddActorHandle(std::move(actor_handle)))
-<< "Actor " << actor_id << " already exists";
*return_actor_id = actor_id;
TaskSpecification task_spec = builder.Build();
+Status status;
if (actor_creation_options.is_direct_call) {
task_manager_->AddPendingTask(
GetCallerId(), rpc_address_, task_spec,
std::max(RayConfig::instance().actor_creation_min_retries(),
actor_creation_options.max_reconstructions));
-return direct_task_submitter_->SubmitTask(task_spec);
+status = direct_task_submitter_->SubmitTask(task_spec);
} else {
-return local_raylet_client_->SubmitTask(task_spec);
+status = local_raylet_client_->SubmitTask(task_spec);
}
+std::unique_ptr<ActorHandle> actor_handle(new ActorHandle(
+actor_id, GetCallerId(), rpc_address_, job_id, /*actor_cursor=*/return_ids[0],
+function.GetLanguage(), actor_creation_options.is_direct_call,
+function.GetFunctionDescriptor(), extension_data));
+RAY_CHECK(AddActorHandle(std::move(actor_handle),
+/*is_owner_handle=*/!actor_creation_options.is_detached))
+<< "Actor " << actor_id << " already exists";
+return status;
}
Status CoreWorker::SubmitActorTask(const ActorID &actor_id, const RayFunction &function,
@ -853,35 +872,58 @@ Status CoreWorker::SubmitActorTask(const ActorID &actor_id, const RayFunction &f
return status;
}
-Status CoreWorker::KillActor(const ActorID &actor_id) {
+Status CoreWorker::KillActor(const ActorID &actor_id, bool force_kill) {
ActorHandle *actor_handle = nullptr;
RAY_RETURN_NOT_OK(GetActorHandle(actor_id, &actor_handle));
RAY_CHECK(actor_handle->IsDirectCallActor());
-return direct_actor_submitter_->KillActor(actor_id);
+direct_actor_submitter_->KillActor(actor_id, force_kill);
+return Status::OK();
}
-ActorID CoreWorker::DeserializeAndRegisterActorHandle(const std::string &serialized) {
+void CoreWorker::RemoveActorHandleReference(const ActorID &actor_id) {
+ObjectID actor_handle_id = ObjectID::ForActorHandle(actor_id);
+reference_counter_->RemoveLocalReference(actor_handle_id, nullptr);
+}
+ActorID CoreWorker::DeserializeAndRegisterActorHandle(const std::string &serialized,
+const ObjectID &outer_object_id) {
std::unique_ptr<ActorHandle> actor_handle(new ActorHandle(serialized));
-const ActorID actor_id = actor_handle->GetActorID();
-RAY_UNUSED(AddActorHandle(std::move(actor_handle)));
+const auto actor_id = actor_handle->GetActorID();
+const auto owner_id = actor_handle->GetOwnerId();
+const auto owner_address = actor_handle->GetOwnerAddress();
+RAY_UNUSED(AddActorHandle(std::move(actor_handle), /*is_owner_handle=*/false));
+ObjectID actor_handle_id = ObjectID::ForActorHandle(actor_id);
+reference_counter_->AddBorrowedObject(actor_handle_id, outer_object_id, owner_id,
+owner_address);
return actor_id;
}
-Status CoreWorker::SerializeActorHandle(const ActorID &actor_id,
-std::string *output) const {
+Status CoreWorker::SerializeActorHandle(const ActorID &actor_id, std::string *output,
+ObjectID *actor_handle_id) const {
ActorHandle *actor_handle = nullptr;
auto status = GetActorHandle(actor_id, &actor_handle);
if (status.ok()) {
actor_handle->Serialize(output);
+*actor_handle_id = ObjectID::ForActorHandle(actor_id);
}
return status;
}
-bool CoreWorker::AddActorHandle(std::unique_ptr<ActorHandle> actor_handle) {
-absl::MutexLock lock(&actor_handles_mutex_);
+bool CoreWorker::AddActorHandle(std::unique_ptr<ActorHandle> actor_handle,
+bool is_owner_handle) {
+const auto &actor_id = actor_handle->GetActorID();
+const auto actor_creation_return_id = ObjectID::ForActorHandle(actor_id);
+reference_counter_->AddLocalReference(actor_creation_return_id);
+bool inserted;
+{
+absl::MutexLock lock(&actor_handles_mutex_);
+inserted = actor_handles_.emplace(actor_id, std::move(actor_handle)).second;
+}
-auto inserted = actor_handles_.emplace(actor_id, std::move(actor_handle)).second;
if (inserted) {
// Register a callback to handle actor notifications.
auto actor_notification_callback = [this](const ActorID &actor_id,
@ -923,7 +965,23 @@ bool CoreWorker::AddActorHandle(std::unique_ptr<ActorHandle> actor_handle) {
RAY_CHECK_OK(gcs_client_->Actors().AsyncSubscribe(
actor_id, actor_notification_callback, nullptr));
+RAY_CHECK(reference_counter_->SetDeleteCallback(
+actor_creation_return_id,
+[this, actor_id, is_owner_handle](const ObjectID &object_id) {
+// TODO(swang): Unsubscribe from the actor table.
+// TODO(swang): Remove the actor handle entry.
+// If we own the actor and the actor handle is no longer in scope,
+// terminate the actor.
+if (is_owner_handle) {
+RAY_LOG(INFO) << "Owner's handle and creation ID " << object_id
+<< " has gone out of scope, sending message to actor "
+<< actor_id << " to do a clean exit.";
+RAY_CHECK_OK(KillActor(actor_id, /*force_kill=*/false));
+}
+}));
}
return inserted;
}
@ -1106,6 +1164,11 @@ Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec,
current_task_ = TaskSpecification();
}
RAY_LOG(DEBUG) << "Finished executing task " << task_spec.TaskId();
+if (status.IsSystemExit()) {
+Exit(status.IsIntentionalSystemExit());
+}
return status;
}
@ -1214,6 +1277,9 @@ void CoreWorker::HandlePushTask(const rpc::PushTaskRequest &request,
task_queue_length_ += 1;
task_execution_service_.post([=] {
+// We have posted an exit task onto the main event loop,
+// so shouldn't bother executing any further work.
+if (exiting_) return;
direct_task_receiver_->HandlePushTask(request, reply, send_reply_callback);
});
}
@ -1328,11 +1394,16 @@ void CoreWorker::HandleKillActor(const rpc::KillActorRequest &request,
send_reply_callback(Status::Invalid(msg), nullptr, nullptr);
return;
}
RAY_LOG(INFO) << "Got KillActor, exiting immediately...";
if (log_dir_ != "") {
RayLog::ShutDownRayLog();
if (request.force_kill()) {
RAY_LOG(INFO) << "Got KillActor, exiting immediately...";
if (log_dir_ != "") {
RayLog::ShutDownRayLog();
}
exit(1);
} else {
Exit(/*intentional=*/true);
}
exit(1);
}
void CoreWorker::HandleGetCoreWorkerStats(const rpc::GetCoreWorkerStatsRequest &request,

View file

@ -95,6 +95,8 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
virtual ~CoreWorker();
+void Exit(bool intentional);
void Disconnect();
WorkerType GetWorkerType() const { return worker_type_; }
@ -140,9 +142,7 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
/// Returns a map of all ObjectIDs currently in scope with a pair of their
/// (local, submitted_task) reference counts. For debugging purposes.
-std::unordered_map<ObjectID, std::pair<size_t, size_t>> GetAllReferenceCounts() const {
-return reference_counter_->GetAllReferenceCounts();
-}
+std::unordered_map<ObjectID, std::pair<size_t, size_t>> GetAllReferenceCounts() const;
/// Promote an object to plasma and get its owner information. This should be
/// called when serializing an object ID, and the returned information should
@ -395,7 +395,13 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
///
/// \param[in] actor_id ID of the actor to kill.
/// \param[out] Status
-Status KillActor(const ActorID &actor_id);
+Status KillActor(const ActorID &actor_id, bool force_kill);
+/// Decrease the reference count for this actor. Should be called by the
+/// language frontend when a reference to the ActorHandle is destroyed.
+///
+/// \param[in] actor_id The actor ID to decrease the reference count for.
+void RemoveActorHandleReference(const ActorID &actor_id);
/// Add an actor handle from a serialized string.
///
@ -404,8 +410,11 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
/// actor.
///
/// \param[in] serialized The serialized actor handle.
+/// \param[in] outer_object_id The object ID that contained the serialized
+/// actor handle, if any.
/// \return The ActorID of the deserialized handle.
-ActorID DeserializeAndRegisterActorHandle(const std::string &serialized);
+ActorID DeserializeAndRegisterActorHandle(const std::string &serialized,
+const ObjectID &outer_object_id);
/// Serialize an actor handle.
///
@ -414,8 +423,12 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
///
/// \param[in] actor_id The ID of the actor handle to serialize.
/// \param[out] The serialized handle.
+/// \param[out] The ID used to track references to the actor handle. If the
+/// serialized actor handle in the language frontend is stored inside an
+/// object, then this must be recorded in the worker's ReferenceCounter.
/// \return Status::Invalid if we don't have the specified handle.
-Status SerializeActorHandle(const ActorID &actor_id, std::string *output) const;
+Status SerializeActorHandle(const ActorID &actor_id, std::string *output,
+ObjectID *actor_handle_id) const;
///
/// Public methods related to task execution. Should not be used by driver processes.
@ -572,9 +585,12 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
/// they are submitted.
///
/// \param actor_handle The handle to the actor.
+/// \param is_owner_handle Whether this is the owner's handle to the actor.
+/// The owner is the creator of the actor and is responsible for telling the
+/// actor to disconnect once all handles are out of scope.
/// \return True if the handle was added and False if we already had a handle
/// to the same actor.
-bool AddActorHandle(std::unique_ptr<ActorHandle> actor_handle);
+bool AddActorHandle(std::unique_ptr<ActorHandle> actor_handle, bool is_owner_handle);
///
/// Private methods related to task execution. Should not be used by driver processes.
@ -813,6 +829,9 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
// Plasma Callback
PlasmaSubscriptionCallback plasma_done_callback_;
+/// Whether we are shutting down and not running further tasks.
+bool exiting_ = false;
friend class CoreWorkerTest;
};

View file

@ -160,7 +160,8 @@ JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeSetResource(
JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeKillActor(
JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jbyteArray actorId) {
auto core_worker = reinterpret_cast<ray::CoreWorker *>(nativeCoreWorkerPointer);
-auto status = core_worker->KillActor(JavaByteArrayToId<ActorID>(env, actorId));
+auto status = core_worker->KillActor(JavaByteArrayToId<ActorID>(env, actorId),
+/*force_kill=*/true);
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
}

View file

@ -64,8 +64,9 @@ JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeSer
JNIEnv *env, jclass o, jlong nativeCoreWorkerPointer, jbyteArray actorId) {
auto actor_id = JavaByteArrayToId<ray::ActorID>(env, actorId);
std::string output;
-ray::Status status =
-GetCoreWorker(nativeCoreWorkerPointer).SerializeActorHandle(actor_id, &output);
+ObjectID actor_handle_id;
+ray::Status status = GetCoreWorker(nativeCoreWorkerPointer)
+.SerializeActorHandle(actor_id, &output, &actor_handle_id);
jbyteArray bytes = env->NewByteArray(output.size());
env->SetByteArrayRegion(bytes, 0, output.size(),
reinterpret_cast<const jbyte *>(output.c_str()));
@ -78,7 +79,8 @@ JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeDes
RAY_CHECK(buffer->Size() > 0);
auto binary = std::string(reinterpret_cast<char *>(buffer->Data()), buffer->Size());
auto actor_id =
-GetCoreWorker(nativeCoreWorkerPointer).DeserializeAndRegisterActorHandle(binary);
+GetCoreWorker(nativeCoreWorkerPointer)
+.DeserializeAndRegisterActorHandle(binary, /*outer_object_id=*/ObjectID::Nil());
return IdToJavaByteArray<ray::ActorID>(env, actor_id);
}

View file

@ -48,11 +48,16 @@ void TaskManager::AddPendingTask(const TaskID &caller_id,
}
}
}
+if (spec.IsActorTask()) {
+const auto actor_creation_return_id =
+spec.ActorCreationDummyObjectId().WithTransportType(TaskTransportType::DIRECT);
+task_deps.push_back(actor_creation_return_id);
+}
reference_counter_->UpdateSubmittedTaskReferences(task_deps);
// Add new owned objects for the return values of the task.
size_t num_returns = spec.NumReturns();
-if (spec.IsActorCreationTask() || spec.IsActorTask()) {
+if (spec.IsActorTask()) {
num_returns--;
}
for (size_t i = 0; i < num_returns; i++) {
@ -225,6 +230,11 @@ void TaskManager::RemoveFinishedTaskReferences(
inlined_ids.end());
}
}
+if (spec.IsActorTask()) {
+const auto actor_creation_return_id =
+spec.ActorCreationDummyObjectId().WithTransportType(TaskTransportType::DIRECT);
+plasma_dependencies.push_back(actor_creation_return_id);
+}
std::vector<ObjectID> deleted;
reference_counter_->UpdateFinishedTaskReferences(plasma_dependencies, borrower_addr,
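The two `task_manager.cc` hunks above make every submitted actor task hold a reference on the actor's creation dummy object until the task finishes, so an actor with in-flight tasks cannot exit even after all frontend handles are deleted. A sketch of the observable behavior, mirroring `test_distributed_actor_handle_deletion` above:

    import time

    import ray

    ray.init()

    @ray.remote
    class Slow:
        def work(self):
            time.sleep(1)
            return 1

    s = Slow.remote()
    ref = s.work.remote()  # the in-flight task references the dummy object
    del s                  # all handles gone, but the task's reference remains
    assert ray.get(ref) == 1  # the reference is dropped only when the task
                              # finishes; then the actor is allowed to exit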

View file

@ -632,9 +632,10 @@ TEST_F(ZeroNodeTest, TestTaskSpecPerf) {
/*is_detached*/ false,
/*is_asyncio*/ false};
const auto job_id = NextJobId();
-ActorHandle actor_handle(ActorID::Of(job_id, TaskID::ForDriverTask(job_id), 1), job_id,
-ObjectID::FromRandom(), function.GetLanguage(), true,
-function.GetFunctionDescriptor(), "");
+ActorHandle actor_handle(ActorID::Of(job_id, TaskID::ForDriverTask(job_id), 1),
+TaskID::Nil(), rpc::Address(), job_id, ObjectID::FromRandom(),
+function.GetLanguage(), true, function.GetFunctionDescriptor(),
+"");
// Manually create `num_tasks` task specs, and for each of them create a
// `PushTaskRequest`, this is to batch performance of TaskSpec
@ -748,8 +749,9 @@ TEST_F(ZeroNodeTest, TestWorkerContext) {
TEST_F(ZeroNodeTest, TestActorHandle) {
// Test actor handle serialization and deserialization round trip.
JobID job_id = NextJobId();
-ActorHandle original(ActorID::Of(job_id, TaskID::ForDriverTask(job_id), 0), job_id,
-ObjectID::FromRandom(), Language::PYTHON, /*is_direct_call=*/false,
+ActorHandle original(ActorID::Of(job_id, TaskID::ForDriverTask(job_id), 0),
+TaskID::Nil(), rpc::Address(), job_id, ObjectID::FromRandom(),
+Language::PYTHON, /*is_direct_call=*/false,
ray::FunctionDescriptorBuilder::BuildPython("", "", "", ""), "");
std::string output;
original.Serialize(&output);

View file

@ -22,9 +22,15 @@ using ray::rpc::ActorTableData;
namespace ray {
-Status CoreWorkerDirectActorTaskSubmitter::KillActor(const ActorID &actor_id) {
+void CoreWorkerDirectActorTaskSubmitter::KillActor(const ActorID &actor_id,
+bool force_kill) {
absl::MutexLock lock(&mu_);
-pending_force_kills_.insert(actor_id);
+auto inserted = pending_force_kills_.emplace(actor_id, force_kill);
+if (!inserted.second && force_kill) {
+// Overwrite the previous request to kill the actor if the new request is a
+// force kill.
+inserted.first->second = true;
+}
auto it = rpc_clients_.find(actor_id);
if (it == rpc_clients_.end()) {
// Actor is not yet created, or is being reconstructed, cache the request
@ -37,7 +43,6 @@ Status CoreWorkerDirectActorTaskSubmitter::KillActor(const ActorID &actor_id) {
} else {
SendPendingTasks(actor_id);
}
-return Status::OK();
}
Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spec) {
@ -138,11 +143,14 @@ void CoreWorkerDirectActorTaskSubmitter::SendPendingTasks(const ActorID &actor_i
RAY_CHECK(client);
// Check if there is a pending force kill. If there is, send it and disconnect the
// client.
-if (pending_force_kills_.find(actor_id) != pending_force_kills_.end()) {
+auto it = pending_force_kills_.find(actor_id);
+if (it != pending_force_kills_.end()) {
rpc::KillActorRequest request;
request.set_intended_actor_id(actor_id.Binary());
-RAY_CHECK_OK(client->KillActor(request, nullptr));
-pending_force_kills_.erase(actor_id);
+request.set_force_kill(it->second);
+// It's okay if this fails because this means the worker is already dead.
+RAY_UNUSED(client->KillActor(request, nullptr));
+pending_force_kills_.erase(it);
}
// Submit all pending requests.
@ -226,10 +234,6 @@ void CoreWorkerDirectTaskReceiver::HandlePushTask(
}
auto accept_callback = [this, reply, send_reply_callback, task_spec, resource_ids]() {
-// We have posted an exit task onto the main event loop,
-// so shouldn't bother executing any further work.
-if (exiting_) return;
auto num_returns = task_spec.NumReturns();
if (task_spec.IsActorCreationTask() || task_spec.IsActorTask()) {
// Decrease to account for the dummy object id.
@ -280,18 +284,12 @@ void CoreWorkerDirectTaskReceiver::HandlePushTask(
// Don't allow the worker to be reused, even though the reply status is OK.
// The worker will be shutting down shortly.
reply->set_worker_exiting(true);
-// In Python, SystemExit can only be raised on the main thread. To
-// work around this when we are executing tasks on worker threads,
-// we re-post the exit event explicitly on the main thread.
-exiting_ = true;
if (objects_valid) {
// This happens when max_calls is hit. We still need to return the objects.
send_reply_callback(Status::OK(), nullptr, nullptr);
} else {
send_reply_callback(status, nullptr, nullptr);
}
-task_main_io_service_.post(
-[this, status]() { exit_handler_(status.IsIntentionalSystemExit()); });
} else {
RAY_CHECK(objects_valid) << return_objects.size() << " " << num_returns;
send_reply_callback(status, nullptr, nullptr);

View file

@ -68,8 +68,9 @@ class CoreWorkerDirectActorTaskSubmitter {
/// Tell this actor to exit immediately.
///
/// \param[in] actor_id The actor_id of the actor to kill.
-/// \return Status::Invalid if the actor could not be killed.
-Status KillActor(const ActorID &actor_id);
+/// \param[in] force_kill Whether to force kill the actor, or let the actor
+/// try a clean exit.
+void KillActor(const ActorID &actor_id, bool force_kill);
/// Create connection to actor and send all pending tasks.
///
@ -134,7 +135,7 @@ class CoreWorkerDirectActorTaskSubmitter {
absl::flat_hash_map<ActorID, std::string> worker_ids_ GUARDED_BY(mu_);
/// Set of actor ids that should be force killed once a client is available.
-absl::flat_hash_set<ActorID> pending_force_kills_ GUARDED_BY(mu_);
+absl::flat_hash_map<ActorID, bool> pending_force_kills_ GUARDED_BY(mu_);
/// Map from actor id to the actor's pending requests. Each actor's requests
/// are ordered by the task number in the request.
@ -407,12 +408,10 @@ class CoreWorkerDirectTaskReceiver {
CoreWorkerDirectTaskReceiver(WorkerContext &worker_context,
std::shared_ptr<raylet::RayletClient> &local_raylet_client,
boost::asio::io_service &main_io_service,
-const TaskHandler &task_handler,
-const std::function<void(bool)> &exit_handler)
+const TaskHandler &task_handler)
: worker_context_(worker_context),
local_raylet_client_(local_raylet_client),
task_handler_(task_handler),
-exit_handler_(exit_handler),
task_main_io_service_(main_io_service) {}
/// Initialize this receiver. This must be called prior to use.
@ -441,8 +440,6 @@ class CoreWorkerDirectTaskReceiver {
WorkerContext &worker_context_;
/// The callback function to process a task.
TaskHandler task_handler_;
-/// The callback function to exit the worker.
-std::function<void(bool)> exit_handler_;
/// The IO event loop for running tasks on.
boost::asio::io_service &task_main_io_service_;
/// Factory for producing new core worker clients.
@ -457,8 +454,6 @@ class CoreWorkerDirectTaskReceiver {
/// Queue of pending requests per actor handle.
/// TODO(ekl) GC these queues once the handle is no longer active.
std::unordered_map<TaskID, std::unique_ptr<SchedulingQueue>> scheduling_queue_;
-/// Whether we are shutting down and not running further tasks.
-bool exiting_ = false;
};
} // namespace ray

View file

@ -20,11 +20,8 @@ namespace ray {
CoreWorkerRayletTaskReceiver::CoreWorkerRayletTaskReceiver(
const WorkerID &worker_id, std::shared_ptr<raylet::RayletClient> &raylet_client,
-const TaskHandler &task_handler, const std::function<void(bool)> &exit_handler)
-: worker_id_(worker_id),
-raylet_client_(raylet_client),
-task_handler_(task_handler),
-exit_handler_(exit_handler) {}
+const TaskHandler &task_handler)
+: worker_id_(worker_id), raylet_client_(raylet_client), task_handler_(task_handler) {}
void CoreWorkerRayletTaskReceiver::HandleAssignTask(
const rpc::AssignTaskRequest &request, rpc::AssignTaskReply *reply,
@ -66,7 +63,6 @@ void CoreWorkerRayletTaskReceiver::HandleAssignTask(
// transport.
auto status = task_handler_(task_spec, resource_ids, &results, &borrower_refs);
if (status.IsSystemExit()) {
-exit_handler_(status.IsIntentionalSystemExit());
return;
}

View file

@ -34,8 +34,7 @@ class CoreWorkerRayletTaskReceiver {
CoreWorkerRayletTaskReceiver(const WorkerID &worker_id,
std::shared_ptr<raylet::RayletClient> &raylet_client,
-const TaskHandler &task_handler,
-const std::function<void(bool)> &exit_handler);
+const TaskHandler &task_handler);
/// Handle a `AssignTask` request.
/// The implementation can handle this request asynchronously. When handling is done,
@ -56,8 +55,6 @@ class CoreWorkerRayletTaskReceiver {
std::shared_ptr<raylet::RayletClient> &raylet_client_;
/// The callback function to process a task.
TaskHandler task_handler_;
-/// The callback function to exit the worker.
-std::function<void(bool)> exit_handler_;
/// The callback to process arg wait complete.
std::function<void(int64_t)> on_wait_complete_;
};

View file

@ -27,26 +27,32 @@ message ActorHandle {
// ID of the actor.
bytes actor_id = 1;
+// The task or actor ID of the actor's owner.
+bytes owner_id = 2;
+// The address of the actor's owner.
+Address owner_address = 3;
// ID of the job that created the actor (it is possible that the handle
// exists on a job with a different job ID).
-bytes creation_job_id = 3;
+bytes creation_job_id = 4;
// Language of the actor.
-Language actor_language = 4;
+Language actor_language = 5;
// Function descriptor of actor creation task.
-FunctionDescriptor actor_creation_task_function_descriptor = 5;
+FunctionDescriptor actor_creation_task_function_descriptor = 6;
// The unique id of the dummy object returned by the actor creation task.
// It's used as a dependency for the first task.
// TODO: Remove this once scheduling is done by task counter only.
-bytes actor_cursor = 6;
+bytes actor_cursor = 7;
// Whether direct actor call is used.
-bool is_direct_call = 7;
+bool is_direct_call = 8;
// An extension field that is used for storing app-language-specific data.
-bytes extension_data = 8;
+bytes extension_data = 9;
}
message AssignTaskRequest {
@ -164,6 +170,8 @@ message WaitForObjectEvictionReply {
message KillActorRequest {
// ID of the actor that is intended to be killed.
bytes intended_actor_id = 1;
+// Whether to force kill the actor.
+bool force_kill = 2;
}
message KillActorReply {

View file

@ -62,7 +62,7 @@ class ExecutionTask:
self.task_id = task_pb.task_id
self.task_index = task_pb.task_index
self.worker_actor = ray.actor.ActorHandle.\
-_deserialization_helper(task_pb.worker_actor, False)
+_deserialization_helper(task_pb.worker_actor)
class ExecutionGraph: