[Java] Fix instanceof RayPyActor (#6377)

This commit is contained in:
Kai Yang 2019-12-07 16:28:29 +08:00 committed by Hao Chen
parent 7e9fddf3ed
commit eb912b68b1
7 changed files with 102 additions and 49 deletions

View file

@ -8,40 +8,54 @@ import java.io.ObjectOutput;
import java.util.List;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.RayPyActor;
import org.ray.api.id.ActorId;
import org.ray.api.id.UniqueId;
import org.ray.api.runtime.RayRuntime;
import org.ray.runtime.AbstractRayRuntime;
import org.ray.runtime.RayNativeRuntime;
import org.ray.runtime.RayMultiWorkerNativeRuntime;
import org.ray.runtime.RayNativeRuntime;
import org.ray.runtime.generated.Common.Language;
/**
* RayActor implementation for cluster mode. This is a wrapper class for C++ ActorHandle.
* RayActor abstract language-independent implementation for cluster mode. This is a wrapper class
* for C++ ActorHandle.
*/
public class NativeRayActor implements RayActor, RayPyActor, Externalizable {
public abstract class NativeRayActor implements RayActor, Externalizable {
/**
* Address of core worker.
*/
private long nativeCoreWorkerPointer;
long nativeCoreWorkerPointer;
/**
* ID of the actor.
*/
private byte[] actorId;
byte[] actorId;
public NativeRayActor(long nativeCoreWorkerPointer, byte[] actorId) {
private Language language;
NativeRayActor(long nativeCoreWorkerPointer, byte[] actorId, Language language) {
Preconditions.checkState(nativeCoreWorkerPointer != 0);
Preconditions.checkState(!ActorId.fromBytes(actorId).isNil());
this.nativeCoreWorkerPointer = nativeCoreWorkerPointer;
this.actorId = actorId;
this.language = language;
}
/**
* Required by FST
*/
public NativeRayActor() {
NativeRayActor() {
}
public static NativeRayActor create(long nativeCoreWorkerPointer, byte[] actorId,
Language language) {
Preconditions.checkState(nativeCoreWorkerPointer != 0);
switch (language) {
case JAVA:
return new NativeRayJavaActor(nativeCoreWorkerPointer, actorId);
case PYTHON:
return new NativeRayPyActor(nativeCoreWorkerPointer, actorId);
default:
throw new IllegalStateException("Unknown actor handle language: " + language);
}
}
@Override
@ -50,30 +64,17 @@ public class NativeRayActor implements RayActor, RayPyActor, Externalizable {
}
public Language getLanguage() {
return Language.forNumber(nativeGetLanguage(nativeCoreWorkerPointer, actorId));
return language;
}
public boolean isDirectCallActor() {
return nativeIsDirectCallActor(nativeCoreWorkerPointer, actorId);
}
@Override
public String getModuleName() {
Preconditions.checkState(getLanguage() == Language.PYTHON);
return nativeGetActorCreationTaskFunctionDescriptor(
nativeCoreWorkerPointer, actorId).get(0);
}
@Override
public String getClassName() {
Preconditions.checkState(getLanguage() == Language.PYTHON);
return nativeGetActorCreationTaskFunctionDescriptor(
nativeCoreWorkerPointer, actorId).get(1);
}
@Override
public void writeExternal(ObjectOutput out) throws IOException {
out.writeObject(nativeSerialize(nativeCoreWorkerPointer, actorId));
out.writeObject(language);
}
@Override
@ -82,11 +83,11 @@ public class NativeRayActor implements RayActor, RayPyActor, Externalizable {
if (runtime instanceof RayMultiWorkerNativeRuntime) {
runtime = ((RayMultiWorkerNativeRuntime) runtime).getCurrentRuntime();
}
Preconditions.checkState(runtime instanceof RayNativeRuntime);
nativeCoreWorkerPointer = ((RayNativeRuntime)runtime).getNativeCoreWorkerPointer();
nativeCoreWorkerPointer = ((RayNativeRuntime) runtime).getNativeCoreWorkerPointer();
actorId = nativeDeserialize(nativeCoreWorkerPointer, (byte[]) in.readObject());
language = (Language) in.readObject();
}
@Override
@ -94,11 +95,9 @@ public class NativeRayActor implements RayActor, RayPyActor, Externalizable {
// TODO(zhijunfu): do we need to free the ActorHandle in core worker?
}
private static native int nativeGetLanguage(long nativeCoreWorkerPointer, byte[] actorId);
private static native boolean nativeIsDirectCallActor(long nativeCoreWorkerPointer, byte[] actorId);
private static native List<String> nativeGetActorCreationTaskFunctionDescriptor(
static native List<String> nativeGetActorCreationTaskFunctionDescriptor(
long nativeCoreWorkerPointer, byte[] actorId);
private static native byte[] nativeSerialize(long nativeCoreWorkerPointer, byte[] actorId);

View file

@ -0,0 +1,29 @@
package org.ray.runtime.actor;
import com.google.common.base.Preconditions;
import java.io.IOException;
import java.io.ObjectInput;
import org.ray.runtime.generated.Common.Language;
/**
* RayActor Java implementation for cluster mode.
*/
public class NativeRayJavaActor extends NativeRayActor {
NativeRayJavaActor(long nativeCoreWorkerPointer, byte[] actorId) {
super(nativeCoreWorkerPointer, actorId, Language.JAVA);
}
/**
* Required by FST
*/
public NativeRayJavaActor() {
super();
}
@Override
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
super.readExternal(in);
Preconditions.checkState(getLanguage() == Language.JAVA);
}
}

View file

@ -0,0 +1,40 @@
package org.ray.runtime.actor;
import com.google.common.base.Preconditions;
import java.io.IOException;
import java.io.ObjectInput;
import org.ray.api.RayPyActor;
import org.ray.runtime.generated.Common.Language;
/**
* RayActor Python implementation for cluster mode.
*/
public class NativeRayPyActor extends NativeRayActor implements RayPyActor {
NativeRayPyActor(long nativeCoreWorkerPointer, byte[] actorId) {
super(nativeCoreWorkerPointer, actorId, Language.PYTHON);
}
/**
* Required by FST
*/
public NativeRayPyActor() {
super();
}
@Override
public String getModuleName() {
return nativeGetActorCreationTaskFunctionDescriptor(nativeCoreWorkerPointer, actorId).get(0);
}
@Override
public String getClassName() {
return nativeGetActorCreationTaskFunctionDescriptor(nativeCoreWorkerPointer, actorId).get(1);
}
@Override
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
super.readExternal(in);
Preconditions.checkState(getLanguage() == Language.PYTHON);
}
}

View file

@ -37,7 +37,7 @@ public class NativeTaskSubmitter implements TaskSubmitter {
ActorCreationOptions options) {
byte[] actorId = nativeCreateActor(nativeCoreWorkerPointer, functionDescriptor, args,
options);
return new NativeRayActor(nativeCoreWorkerPointer, actorId);
return NativeRayActor.create(nativeCoreWorkerPointer, actorId, functionDescriptor.getLanguage());
}
@Override

View file

@ -7,6 +7,7 @@ import java.util.concurrent.TimeUnit;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.RayPyActor;
import org.ray.api.TestUtils;
import org.ray.api.TestUtils.LargeObject;
import org.ray.api.annotation.RayRemote;
@ -50,6 +51,8 @@ public class ActorTest extends BaseTest {
// Test creating an actor from a constructor
RayActor<Counter> actor = Ray.createActor(Counter::new, 1);
Assert.assertNotEquals(actor.getId(), UniqueId.NIL);
// A java actor is not a python actor
Assert.assertFalse(actor instanceof RayPyActor);
// Test calling an actor
Assert.assertEquals(Integer.valueOf(1), Ray.call(Counter::getValue, actor).get());
Ray.call(Counter::increase, actor, 1);

View file

@ -13,16 +13,6 @@ inline ray::CoreWorker &GetCoreWorker(jlong nativeCoreWorkerPointer) {
extern "C" {
#endif
JNIEXPORT jint JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeGetLanguage(
JNIEnv *env, jclass o, jlong nativeCoreWorkerPointer, jbyteArray actorId) {
auto actor_id = JavaByteArrayToId<ray::ActorID>(env, actorId);
ray::ActorHandle *native_actor_handle = nullptr;
auto status = GetCoreWorker(nativeCoreWorkerPointer)
.GetActorHandle(actor_id, &native_actor_handle);
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (jint)0);
return (jint)native_actor_handle->ActorLanguage();
}
JNIEXPORT jboolean JNICALL
Java_org_ray_runtime_actor_NativeRayActor_nativeIsDirectCallActor(
JNIEnv *env, jclass o, jlong nativeCoreWorkerPointer, jbyteArray actorId) {

View file

@ -7,14 +7,6 @@
#ifdef __cplusplus
extern "C" {
#endif
/*
* Class: org_ray_runtime_actor_NativeRayActor
* Method: nativeGetLanguage
* Signature: (J[B)I
*/
JNIEXPORT jint JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeGetLanguage(
JNIEnv *, jclass, jlong, jbyteArray);
/*
* Class: org_ray_runtime_actor_NativeRayActor
* Method: nativeIsDirectCallActor