[Java] Support actor handle reference counting. (#21249)

This commit is contained in:
Qing Wang 2022-01-01 10:26:22 +08:00 committed by GitHub
parent 14ed7cfaaa
commit 340fbf53c0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 241 additions and 13 deletions

View file

@ -1,14 +1,23 @@
package io.ray.runtime.actor;
import com.google.common.base.FinalizableReferenceQueue;
import com.google.common.base.FinalizableWeakReference;
import com.google.common.base.Preconditions;
import com.google.common.collect.Sets;
import io.ray.api.BaseActorHandle;
import io.ray.api.Ray;
import io.ray.api.id.ActorId;
import io.ray.api.id.ObjectId;
import io.ray.runtime.RayRuntimeInternal;
import io.ray.runtime.generated.Common.Language;
import java.io.Externalizable;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.lang.ref.Reference;
import java.util.List;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
/**
* Abstract and language-independent implementation of actor handle for cluster mode. This is a
@ -16,19 +25,33 @@ import java.util.List;
*/
public abstract class NativeActorHandle implements BaseActorHandle, Externalizable {
private static final FinalizableReferenceQueue REFERENCE_QUEUE = new FinalizableReferenceQueue();
private static final Set<Reference<NativeActorHandle>> REFERENCES = Sets.newConcurrentHashSet();
/** ID of the actor. */
byte[] actorId;
/** ID of the actor handle. */
byte[] actorHandleId = new byte[ObjectId.LENGTH];
private Language language;
NativeActorHandle(byte[] actorId, Language language) {
Preconditions.checkState(!ActorId.fromBytes(actorId).isNil());
this.actorId = actorId;
this.language = language;
new NativeActorHandleReference(this);
}
/** Required by FST. */
NativeActorHandle() {}
NativeActorHandle() {
// Note there is no need to add local reference here since this is only used for FST.
}
public ObjectId getActorHandleId() {
return new ObjectId(actorHandleId);
}
public static NativeActorHandle create(byte[] actorId) {
Language language = Language.forNumber(nativeGetLanguage(actorId));
@ -58,7 +81,7 @@ public abstract class NativeActorHandle implements BaseActorHandle, Externalizab
@Override
public void writeExternal(ObjectOutput out) throws IOException {
out.writeObject(nativeSerialize(actorId));
out.writeObject(nativeSerialize(actorId, actorHandleId));
out.writeObject(language);
}
@ -66,6 +89,7 @@ public abstract class NativeActorHandle implements BaseActorHandle, Externalizab
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
actorId = nativeDeserialize((byte[]) in.readObject());
language = (Language) in.readObject();
new NativeActorHandleReference(this);
}
/**
@ -74,7 +98,7 @@ public abstract class NativeActorHandle implements BaseActorHandle, Externalizab
* @return the bytes of the actor handle
*/
public byte[] toBytes() {
return nativeSerialize(actorId);
return nativeSerialize(actorId, actorHandleId);
}
/**
@ -89,13 +113,42 @@ public abstract class NativeActorHandle implements BaseActorHandle, Externalizab
return create(actorId, language);
}
private static final class NativeActorHandleReference
extends FinalizableWeakReference<NativeActorHandle> {
private final AtomicBoolean removed;
private final byte[] workerId;
private final byte[] actorId;
public NativeActorHandleReference(NativeActorHandle handle) {
super(handle, REFERENCE_QUEUE);
this.actorId = handle.actorId;
RayRuntimeInternal runtime = (RayRuntimeInternal) Ray.internal();
this.workerId = runtime.getWorkerContext().getCurrentWorkerId().getBytes();
this.removed = new AtomicBoolean(false);
REFERENCES.add(this);
}
@Override
public void finalizeReferent() {
if (!removed.getAndSet(true)) {
REFERENCES.remove(this);
// It's possible that GC is executed after the runtime is shutdown.
if (Ray.isInitialized()) {
nativeRemoveActorHandleReference(workerId, actorId);
}
}
}
}
// TODO(chaokunyang) do we need to free the ActorHandle in core worker by using phantom reference?
private static native int nativeGetLanguage(byte[] actorId);
static native List<String> nativeGetActorCreationTaskFunctionDescriptor(byte[] actorId);
private static native byte[] nativeSerialize(byte[] actorId);
private static native byte[] nativeSerialize(byte[] actorId, byte[] actorHandleId);
private static native byte[] nativeDeserialize(byte[] data);
private static native void nativeRemoveActorHandleReference(byte[] workerId, byte[] actorId);
}

View file

@ -1,5 +1,6 @@
package io.ray.runtime.object;
import com.google.common.collect.ImmutableList;
import com.google.common.primitives.Bytes;
import com.google.protobuf.InvalidProtocolBufferException;
import io.ray.api.id.ActorId;
@ -159,7 +160,10 @@ public class ObjectSerializer {
// serializedBytes is MessagePack serialized bytes
// Only OBJECT_METADATA_TYPE_RAW is raw bytes,
// any other type should be the MessagePack serialized bytes.
return new NativeRayObject(serializedBytes, OBJECT_METADATA_TYPE_ACTOR_HANDLE);
NativeRayObject nativeRayObject =
new NativeRayObject(serializedBytes, OBJECT_METADATA_TYPE_ACTOR_HANDLE);
nativeRayObject.setContainedObjectIds(ImmutableList.of(actorHandle.getActorHandleId()));
return nativeRayObject;
} else {
try {
Pair<byte[], Boolean> serialized = Serializer.encode(object);

View file

@ -0,0 +1,109 @@
package io.ray.test;
import io.ray.api.ActorHandle;
import io.ray.api.ObjectRef;
import io.ray.api.Ray;
import io.ray.runtime.actor.NativeActorHandle;
import io.ray.runtime.exception.RayActorException;
import io.ray.runtime.util.SystemUtil;
import java.lang.ref.Reference;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.testng.Assert;
import org.testng.annotations.Test;
@Test(groups = {"cluster"})
public class ActorHandleReferenceCountTest {
private static final Logger LOG = LoggerFactory.getLogger(ActorHandleReferenceCountTest.class);
/**
* Because we can't explicitly GC an Java object. We use this helper method to manually remove an
* local reference.
*/
private static void del(ActorHandle<?> handle) {
try {
Field referencesField = NativeActorHandle.class.getDeclaredField("REFERENCES");
referencesField.setAccessible(true);
Set<?> references = (Set<?>) referencesField.get(null);
Class<?> referenceClass =
Class.forName("io.ray.runtime.actor.NativeActorHandle$NativeActorHandleReference");
Method finalizeReferentMethod = referenceClass.getDeclaredMethod("finalizeReferent");
finalizeReferentMethod.setAccessible(true);
for (Object reference : references) {
if (handle.equals(((Reference<?>) reference).get())) {
finalizeReferentMethod.invoke(reference);
break;
}
}
} catch (Exception e) {
throw new RuntimeException(e);
}
}
private static class MyActor {
public int getPid() {
return SystemUtil.pid();
}
public String hello() {
return "hello";
}
}
private static String foo(ActorHandle<MyActor> myActor, ActorHandle<SignalActor> signal) {
signal.task(SignalActor::waitSignal).remote().get();
String result = myActor.task(MyActor::hello).remote().get();
del(myActor);
return result;
}
public void testActorHandleReferenceCount() {
try {
System.setProperty("ray.job.num-java-workers-per-process", "1");
Ray.init();
ActorHandle<SignalActor> signal = Ray.actor(SignalActor::new).remote();
ActorHandle<MyActor> myActor = Ray.actor(MyActor::new).remote();
int pid = myActor.task(MyActor::getPid).remote().get();
// Pass the handle to another task that cannot run yet.
ObjectRef<String> helloObj =
Ray.task(ActorHandleReferenceCountTest::foo, myActor, signal).remote();
// Delete the original handle. The actor should not get killed yet.
del(myActor);
// Once the task finishes, the actor process should get killed.
signal.task(SignalActor::sendSignal).remote().get();
Assert.assertEquals("hello", helloObj.get());
Assert.assertTrue(TestUtils.waitForCondition(() -> !SystemUtil.isProcessAlive(pid), 10000));
} finally {
Ray.shutdown();
}
}
public void testRemoveActorHandleReferenceInMultipleThreadedActor() throws InterruptedException {
System.setProperty("ray.job.num-java-workers-per-process", "5");
try {
Ray.init();
ActorHandle<MyActor> myActor1 = Ray.actor(MyActor::new).remote();
int pid1 = myActor1.task(MyActor::getPid).remote().get();
ActorHandle<MyActor> myActor2 = Ray.actor(MyActor::new).remote();
int pid2 = myActor2.task(MyActor::getPid).remote().get();
Assert.assertEquals(pid1, pid2);
del(myActor1);
TimeUnit.SECONDS.sleep(5);
Assert.assertThrows(
RayActorException.class,
() -> {
myActor1.task(MyActor::hello).remote().get();
});
/// myActor2 shouldn't be killed.
Assert.assertEquals("hello", myActor2.task(MyActor::hello).remote().get());
} finally {
Ray.shutdown();
}
}
}

View file

@ -21,6 +21,7 @@
#include "ray/core_worker/actor_handle.h"
#include "ray/core_worker/common.h"
#include "ray/core_worker/core_worker.h"
#include "ray/core_worker/core_worker_process.h"
#ifdef __cplusplus
extern "C" {
@ -45,12 +46,14 @@ Java_io_ray_runtime_actor_NativeActorHandle_nativeGetActorCreationTaskFunctionDe
}
JNIEXPORT jbyteArray JNICALL Java_io_ray_runtime_actor_NativeActorHandle_nativeSerialize(
JNIEnv *env, jclass o, jbyteArray actorId) {
JNIEnv *env, jclass o, jbyteArray actorId, jbyteArray actorHandleId) {
auto actor_id = JavaByteArrayToId<ActorID>(env, actorId);
std::string output;
ObjectID actor_handle_id;
Status status = CoreWorkerProcess::GetCoreWorker().SerializeActorHandle(
actor_id, &output, &actor_handle_id);
env->SetByteArrayRegion(actorHandleId, 0, ObjectID::kLength,
reinterpret_cast<const jbyte *>(actor_handle_id.Data()));
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr);
return NativeStringToJavaByteArray(env, output);
}
@ -67,6 +70,20 @@ Java_io_ray_runtime_actor_NativeActorHandle_nativeDeserialize(JNIEnv *env, jclas
return IdToJavaByteArray<ActorID>(env, actor_id);
}
JNIEXPORT void JNICALL
Java_io_ray_runtime_actor_NativeActorHandle_nativeRemoveActorHandleReference(
JNIEnv *env, jclass clz, jbyteArray workerId, jbyteArray actorId) {
// We can't control the timing of Java GC, so it's normal that this method is called but
// core worker is shutting down (or already shut down). If we can't get a core worker
// instance here, skip calling the `RemoveLocalReference` method.
const auto worker_id = JavaByteArrayToId<ray::WorkerID>(env, workerId);
auto core_worker = CoreWorkerProcess::TryGetWorker(worker_id);
if (core_worker != nullptr) {
const auto actor_id = JavaByteArrayToId<ActorID>(env, actorId);
core_worker->RemoveActorHandleReference(actor_id);
}
}
#ifdef __cplusplus
}
#endif

View file

@ -41,10 +41,10 @@ Java_io_ray_runtime_actor_NativeActorHandle_nativeGetActorCreationTaskFunctionDe
/*
* Class: io_ray_runtime_actor_NativeActorHandle
* Method: nativeSerialize
* Signature: ([B)[B
* Signature: ([B[B)[B
*/
JNIEXPORT jbyteArray JNICALL
Java_io_ray_runtime_actor_NativeActorHandle_nativeSerialize(JNIEnv *, jclass, jbyteArray);
JNIEXPORT jbyteArray JNICALL Java_io_ray_runtime_actor_NativeActorHandle_nativeSerialize(
JNIEnv *, jclass, jbyteArray, jbyteArray);
/*
* Class: io_ray_runtime_actor_NativeActorHandle
@ -55,6 +55,17 @@ JNIEXPORT jbyteArray JNICALL
Java_io_ray_runtime_actor_NativeActorHandle_nativeDeserialize(JNIEnv *, jclass,
jbyteArray);
/*
* Class: io_ray_runtime_actor_NativeActorHandle
* Method: nativeRemoveActorHandleReference
* Signature: ([B[B)V
*/
JNIEXPORT void JNICALL
Java_io_ray_runtime_actor_NativeActorHandle_nativeRemoveActorHandleReference(JNIEnv *,
jclass,
jbyteArray,
jbyteArray);
#ifdef __cplusplus
}
#endif

View file

@ -159,6 +159,7 @@ GcsActorManager::GcsActorManager(
std::shared_ptr<GcsPublisher> gcs_publisher, RuntimeEnvManager &runtime_env_manager,
std::function<void(const ActorID &)> destroy_owned_placement_group_if_needed,
std::function<std::string(const JobID &)> get_ray_namespace,
std::function<int32_t(const JobID &)> get_num_java_workers_per_process,
std::function<void(std::function<void(void)>, boost::posix_time::milliseconds)>
run_delayed,
const rpc::ClientFactoryFn &worker_client_factory)
@ -169,6 +170,7 @@ GcsActorManager::GcsActorManager(
worker_client_factory_(worker_client_factory),
destroy_owned_placement_group_if_needed_(destroy_owned_placement_group_if_needed),
get_ray_namespace_(get_ray_namespace),
get_num_java_workers_per_process_(std::move(get_num_java_workers_per_process)),
runtime_env_manager_(runtime_env_manager),
run_delayed_(run_delayed),
actor_gc_delay_(RayConfig::instance().gcs_actor_table_min_duration_ms()) {
@ -644,13 +646,19 @@ void GcsActorManager::PollOwnerForActorOutOfScope(
if (node_it != owners_.end() && node_it->second.count(owner_id)) {
// Only destroy the actor if its owner is still alive. The actor may
// have already been destroyed if the owner died.
DestroyActor(actor_id, GenActorOutOfScopeCause());
// For multiple actors in one process, if one actor is out of scope,
// We shouldn't force kill the actor because other actors in the process
// are still alive.
auto force_kill = get_num_java_workers_per_process_(actor_id.JobId()) <= 1;
DestroyActor(actor_id, GenActorOutOfScopeCause(), force_kill);
}
});
}
void GcsActorManager::DestroyActor(const ActorID &actor_id,
const rpc::ActorDeathCause &death_cause) {
const rpc::ActorDeathCause &death_cause,
bool force_kill) {
RAY_LOG(INFO) << "Destroying actor, actor id = " << actor_id
<< ", job id = " << actor_id.JobId();
actor_to_register_callbacks_.erase(actor_id);
@ -700,7 +708,7 @@ void GcsActorManager::DestroyActor(const ActorID &actor_id,
if (node_it != created_actors_.end() && node_it->second.count(worker_id)) {
// The actor has already been created. Destroy the process by force-killing
// it.
NotifyCoreWorkerToKillActor(actor);
NotifyCoreWorkerToKillActor(actor, force_kill);
RAY_CHECK(node_it->second.erase(actor->GetWorkerID()));
if (node_it->second.empty()) {
created_actors_.erase(node_it);

View file

@ -201,6 +201,7 @@ class GcsActorManager : public rpc::ActorInfoHandler {
std::shared_ptr<GcsPublisher> gcs_publisher, RuntimeEnvManager &runtime_env_manager,
std::function<void(const ActorID &)> destroy_ownded_placement_group_if_needed,
std::function<std::string(const JobID &)> get_ray_namespace,
std::function<int32_t(const JobID &)> get_num_java_workers_per_process,
std::function<void(std::function<void(void)>, boost::posix_time::milliseconds)>
run_delayed,
const rpc::ClientFactoryFn &worker_client_factory = nullptr);
@ -384,7 +385,9 @@ class GcsActorManager : public rpc::ActorInfoHandler {
///
/// \param[in] actor_id The actor id to destroy.
/// \param[in] death_cause The reason why actor is destroyed.
void DestroyActor(const ActorID &actor_id, const rpc::ActorDeathCause &death_cause);
/// \param[in] force_kill Whether destory the actor forcelly.
void DestroyActor(const ActorID &actor_id, const rpc::ActorDeathCause &death_cause,
bool force_kill = true);
/// Get unresolved actors that were submitted from the specified node.
absl::flat_hash_map<WorkerID, absl::flat_hash_set<ActorID>>
@ -517,6 +520,9 @@ class GcsActorManager : public rpc::ActorInfoHandler {
/// A callback to get the namespace an actor belongs to based on its job id. This is
/// necessary for actor creation.
std::function<std::string(const JobID &)> get_ray_namespace_;
/// A callback to get the number of java workers per process config item by the
/// given job id. It is necessary for deciding whether we should clear the Java actor.
std::function<int32_t(const JobID &)> get_num_java_workers_per_process_;
RuntimeEnvManager &runtime_env_manager_;
/// Run a function on a delay. This is useful for guaranteeing data will be
/// accessible for a minimum amount of time.

View file

@ -25,6 +25,8 @@ void GcsJobManager::Initialize(const GcsInitData &gcs_init_data) {
const auto &job_table_data = pair.second;
const auto &ray_namespace = job_table_data.config().ray_namespace();
ray_namespaces_[job_id] = ray_namespace;
cache_num_java_worker_per_processes_[job_id] =
job_table_data.config().num_java_workers_per_process();
}
}
@ -54,6 +56,8 @@ void GcsJobManager::HandleAddJob(const rpc::AddJobRequest &request,
RAY_LOG(INFO) << "Finished adding job, job id = " << job_id
<< ", driver pid = " << mutable_job_table_data.driver_pid();
ray_namespaces_[job_id] = mutable_job_table_data.config().ray_namespace();
cache_num_java_worker_per_processes_[job_id] =
mutable_job_table_data.config().num_java_workers_per_process();
}
GCS_RPC_SEND_REPLY(send_reply_callback, reply, status);
};
@ -173,5 +177,12 @@ std::string GcsJobManager::GetRayNamespace(const JobID &job_id) const {
return it->second;
}
int32_t GcsJobManager::GetNumJavaWorkersPerProcess(const JobID &job_id) const {
auto it = cache_num_java_worker_per_processes_.find(job_id);
RAY_CHECK(it != cache_num_java_worker_per_processes_.end())
<< "Couldn't find job with id: " << job_id;
return it->second;
}
} // namespace gcs
} // namespace ray

View file

@ -59,6 +59,8 @@ class GcsJobManager : public rpc::JobInfoHandler {
std::string GetRayNamespace(const JobID &job_id) const;
int32_t GetNumJavaWorkersPerProcess(const JobID &job_id) const;
private:
std::shared_ptr<GcsTableStorage> gcs_table_storage_;
std::shared_ptr<GcsPublisher> gcs_publisher_;
@ -69,6 +71,9 @@ class GcsJobManager : public rpc::JobInfoHandler {
/// A cached mapping from job id to namespace.
std::unordered_map<JobID, std::string> ray_namespaces_;
/// A cached mapping from job id to num_java_workers_per_process.
std::unordered_map<JobID, int32_t> cache_num_java_worker_per_processes_;
ray::RuntimeEnvManager &runtime_env_manager_;
void ClearJobInfos(const JobID &job_id);

View file

@ -325,6 +325,9 @@ void GcsServer::InitGcsActorManager(const GcsInitData &gcs_init_data) {
gcs_placement_group_manager_->CleanPlacementGroupIfNeededWhenActorDead(actor_id);
},
[this](const JobID &job_id) { return gcs_job_manager_->GetRayNamespace(job_id); },
[this](const JobID &job_id) {
return gcs_job_manager_->GetNumJavaWorkersPerProcess(job_id);
},
[this](std::function<void(void)> fn, boost::posix_time::milliseconds delay) {
boost::asio::deadline_timer timer(main_service_);
timer.expires_from_now(delay);

View file

@ -111,6 +111,7 @@ class GcsActorManagerTest : public ::testing::Test {
io_service_, mock_actor_scheduler_, gcs_table_storage_, gcs_publisher_,
*runtime_env_mgr_, [](const ActorID &actor_id) {},
[this](const JobID &job_id) { return job_namespace_table_[job_id]; },
[](const JobID &job_id) { return /*num_java_worker_per_process=*/1; },
[this](std::function<void(void)> fn, boost::posix_time::milliseconds delay) {
if (skip_delay_) {
fn();