[Java] Avoid data copy from C++ to Java for ByteBuffer type (#9033)

This commit is contained in:
Kai Yang 2020-07-22 16:25:32 +08:00 committed by GitHub
parent 6346c70792
commit bfa0605282
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 125 additions and 16 deletions

View file

@ -7,6 +7,7 @@ import io.ray.api.exception.UnreconstructableException;
import io.ray.api.id.ObjectId;
import io.ray.runtime.generated.Gcs.ErrorType;
import io.ray.runtime.serializer.Serializer;
import java.nio.ByteBuffer;
import java.util.Arrays;
import org.apache.commons.lang3.tuple.Pair;
@ -45,6 +46,9 @@ public class ObjectSerializer {
if (meta != null && meta.length > 0) {
// If meta is not null, deserialize the object from meta.
if (Arrays.equals(meta, OBJECT_METADATA_TYPE_RAW)) {
if (objectType == ByteBuffer.class) {
return ByteBuffer.wrap(data);
}
return data;
} else if (Arrays.equals(meta, OBJECT_METADATA_TYPE_CROSS_LANGUAGE) ||
Arrays.equals(meta, OBJECT_METADATA_TYPE_JAVA)) {
@ -81,6 +85,17 @@ public class ObjectSerializer {
// If the object is a byte array, skip serializing it and use a special metadata to
// indicate it's raw binary. So that this object can also be read by Python.
return new NativeRayObject((byte[]) object, OBJECT_METADATA_TYPE_RAW);
} else if (object instanceof ByteBuffer) {
// Serialize ByteBuffer to raw bytes.
ByteBuffer buffer = (ByteBuffer) object;
byte[] bytes;
if (buffer.hasArray()) {
bytes = buffer.array();
} else {
bytes = new byte[buffer.remaining()];
buffer.get(bytes);
}
return new NativeRayObject(bytes, OBJECT_METADATA_TYPE_RAW);
} else if (object instanceof RayTaskException) {
byte[] serializedBytes = Serializer.encode(object).getLeft();
return new NativeRayObject(serializedBytes, TASK_EXECUTION_EXCEPTION_META);

View file

@ -1,5 +1,6 @@
package io.ray.runtime.task;
import com.google.common.base.Preconditions;
import io.ray.api.ObjectRef;
import io.ray.api.Ray;
import io.ray.api.id.ObjectId;
@ -7,6 +8,7 @@ import io.ray.runtime.RayRuntimeInternal;
import io.ray.runtime.generated.Common.Language;
import io.ray.runtime.object.NativeRayObject;
import io.ray.runtime.object.ObjectSerializer;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
@ -68,12 +70,19 @@ public class ArgumentsBuilder {
}
/**
* Convert list of NativeRayObject to real function arguments.
* Convert list of NativeRayObject/ByteBuffer to real function arguments.
*/
public static Object[] unwrap(List<NativeRayObject> args, Class<?>[] types) {
public static Object[] unwrap(List<Object> args, Class<?>[] types) {
Object[] realArgs = new Object[args.size()];
for (int i = 0; i < args.size(); i++) {
realArgs[i] = ObjectSerializer.deserialize(args.get(i), null, types[i]);
Object arg = args.get(i);
Preconditions.checkState(arg instanceof ByteBuffer || arg instanceof NativeRayObject);
if (arg instanceof ByteBuffer) {
Preconditions.checkState(types[i] == ByteBuffer.class);
realArgs[i] = arg;
} else {
realArgs[i] = ObjectSerializer.deserialize((NativeRayObject) arg, null, types[i]);
}
}
return realArgs;
}

View file

@ -311,8 +311,9 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
? ((LocalModeTaskExecutor.LocalActorContext) actorContext).getWorkerId()
: UniqueId.randomId();
((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentWorkerId(workerId);
List<NativeRayObject> returnObjects = taskExecutor
.execute(getJavaFunctionDescriptor(taskSpec).toList(), args);
List<String> rayFunctionInfo = getJavaFunctionDescriptor(taskSpec).toList();
taskExecutor.checkByteBufferArguments(rayFunctionInfo);
List<NativeRayObject> returnObjects = taskExecutor.execute(rayFunctionInfo, args);
if (taskSpec.getType() == TaskType.ACTOR_CREATION_TASK) {
// Update actor context map ASAP in case objectStore.putRaw triggered the next actor task
// on this actor.

View file

@ -13,6 +13,7 @@ import io.ray.runtime.generated.Common.TaskType;
import io.ray.runtime.object.NativeRayObject;
import io.ray.runtime.object.ObjectSerializer;
import java.lang.reflect.InvocationTargetException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
@ -30,6 +31,8 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
private final ConcurrentHashMap<UniqueId, T> actorContextMap = new ConcurrentHashMap<>();
private final ThreadLocal<RayFunction> localRayFunction = new ThreadLocal<>();
static class ActorContext {
/**
@ -61,10 +64,34 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
this.actorContextMap.put(runtime.getWorkerContext().getCurrentWorkerId(), actorContext);
}
protected List<NativeRayObject> execute(List<String> rayFunctionInfo,
List<NativeRayObject> argsBytes) {
runtime.setIsContextSet(true);
/**
 * Looks up the {@link RayFunction} that corresponds to the given function descriptor, using the
 * job id currently recorded in the worker context.
 *
 * @param rayFunctionInfo the function descriptor in list form, as accepted by
 *     {@code parseFunctionDescriptor}
 * @return the result of the function manager lookup for the current job and parsed descriptor
 */
private RayFunction getRayFunction(List<String> rayFunctionInfo) {
  JobId currentJobId = runtime.getWorkerContext().getCurrentJobId();
  JavaFunctionDescriptor descriptor = parseFunctionDescriptor(rayFunctionInfo);
  return runtime.getFunctionManager().getFunction(currentJobId, descriptor);
}
/**
 * Resolves the function described by {@code rayFunctionInfo}, remembers it in the
 * {@code localRayFunction} thread-local for the subsequent {@code execute} call on the same
 * thread, and reports which of its parameters are declared as {@link ByteBuffer}.
 *
 * @param rayFunctionInfo the function descriptor in list form
 * @return an array with one entry per declared parameter, {@code true} where the parameter type
 *     is exactly {@code ByteBuffer}; or {@code null} if the function could not be resolved. In
 *     the {@code null} case the thread-local is left cleared so that {@code execute} repeats the
 *     lookup and surfaces the failure itself.
 */
protected boolean[] checkByteBufferArguments(List<String> rayFunctionInfo) {
  // Clear any function left over from a previous call on this thread.
  localRayFunction.set(null);
  RayFunction resolved;
  try {
    resolved = getRayFunction(rayFunctionInfo);
  } catch (Throwable ignored) {
    // Deliberately swallowed: execute() redoes the lookup and reports the error.
    return null;
  }
  localRayFunction.set(resolved);

  Class<?>[] parameterTypes = resolved.executable.getParameterTypes();
  boolean[] isByteBuffer = new boolean[parameterTypes.length];
  for (int index = 0; index < parameterTypes.length; index++) {
    isByteBuffer[index] = parameterTypes[index] == ByteBuffer.class;
  }
  return isByteBuffer;
}
protected List<NativeRayObject> execute(List<String> rayFunctionInfo,
List<Object> argsBytes) {
runtime.setIsContextSet(true);
TaskType taskType = runtime.getWorkerContext().getCurrentTaskType();
TaskId taskId = runtime.getWorkerContext().getCurrentTaskId();
LOGGER.debug("Executing task {}", taskId);
@ -80,11 +107,14 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
List<NativeRayObject> returnObjects = new ArrayList<>();
ClassLoader oldLoader = Thread.currentThread().getContextClassLoader();
JavaFunctionDescriptor functionDescriptor = parseFunctionDescriptor(rayFunctionInfo);
RayFunction rayFunction = null;
RayFunction rayFunction = localRayFunction.get();
try {
// Find the executable object.
rayFunction = runtime.getFunctionManager().getFunction(jobId, functionDescriptor);
if (rayFunction == null) {
// Failed to get RayFunction in checkByteBufferArguments. Redo here to throw
// the exception again.
rayFunction = getRayFunction(rayFunctionInfo);
}
Thread.currentThread().setContextClassLoader(rayFunction.classLoader);
runtime.getWorkerContext().setCurrentClassLoader(rayFunction.classLoader);
@ -132,7 +162,7 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
LOGGER.error("Error executing task " + taskId, e);
if (taskType != TaskType.ACTOR_CREATION_TASK) {
boolean hasReturn = rayFunction != null && rayFunction.hasReturn();
boolean isCrossLanguage = functionDescriptor.signature.equals("");
boolean isCrossLanguage = parseFunctionDescriptor(rayFunctionInfo).signature.equals("");
if (hasReturn || isCrossLanguage) {
returnObjects.add(ObjectSerializer
.serialize(new RayTaskException("Error executing task " + taskId, e)));

View file

@ -4,6 +4,8 @@ import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.ray.api.Ray;
import io.ray.api.id.ObjectId;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import org.testng.Assert;
@ -63,6 +65,10 @@ public class RayCallTest extends BaseTest {
TestUtils.getRuntime().getObjectStore().put(1, objectId);
}
/**
 * Returns the given buffer unchanged. Submitted as a remote task by this test class to check
 * that a {@link ByteBuffer} argument and return value survive a task round trip.
 */
private static ByteBuffer testByteBuffer(ByteBuffer buffer) {
return buffer;
}
/**
* Test calling and returning different types.
*/
@ -82,6 +88,11 @@ public class RayCallTest extends BaseTest {
Assert.assertEquals(map, Ray.task(RayCallTest::testMap, map).remote().get());
TestUtils.LargeObject largeObject = new TestUtils.LargeObject();
Assert.assertNotNull(Ray.task(RayCallTest::testLargeObject, largeObject).remote().get());
ByteBuffer buffer1 = ByteBuffer.wrap("foo".getBytes(StandardCharsets.UTF_8));
ByteBuffer buffer2 = Ray.task(RayCallTest::testByteBuffer, buffer1).remote().get();
byte[] bytes = new byte[buffer2.remaining()];
buffer2.get(bytes);
Assert.assertEquals("foo", new String(bytes, StandardCharsets.UTF_8));
// TODO(edoakes): this test doesn't work now that we've switched to direct call
// mode. To make it work, we need to implement the same protocol for resolving

View file

@ -1666,8 +1666,12 @@ Status CoreWorker::GetAndPinArgsForExecutor(const TaskSpecification &task,
metadata = std::make_shared<LocalMemoryBuffer>(
const_cast<uint8_t *>(task.ArgMetadata(i)), task.ArgMetadataSize(i));
}
args->at(i) = std::make_shared<RayObject>(data, metadata, task.ArgInlinedIds(i),
/*copy_data*/ true);
// NOTE: this is a workaround to avoid an extra copy for Java workers.
// Python workers need this copy to pass test case
// test_inline_arg_memory_corruption.
bool copy_data = options_.language == Language::PYTHON;
args->at(i) =
std::make_shared<RayObject>(data, metadata, task.ArgInlinedIds(i), copy_data);
arg_reference_ids->at(i) = ObjectID::Nil();
// The task borrows all ObjectIDs that were serialized in the inlined
// arguments. The task will receive references to these IDs, so it is

View file

@ -43,6 +43,35 @@ inline ray::gcs::GcsClientOptions ToGcsClientOptions(JNIEnv *env,
return ray::gcs::GcsClientOptions(ip, port, password, /*is_test_client=*/false);
}
/// Convert the native task arguments into a Java list to be passed to the Java
/// task executor.
///
/// \param env The JNI environment of the calling thread.
/// \param java_check_results The `boolean[]` returned by
/// `TaskExecutor.checkByteBufferArguments`; entry `k` is true when the k-th
/// parameter of the target function is a `ByteBuffer`. May be null.
/// \param args The native arguments of the task.
/// \return A Java list with one element per entry of `args`: a direct
/// `ByteBuffer` wrapping the native data for flagged arguments, and the result
/// of `NativeRayObjectToJavaNativeRayObject` otherwise. Returns nullptr when
/// `java_check_results` is null.
jobject ToJavaArgs(JNIEnv *env, jbooleanArray java_check_results,
                   const std::vector<std::shared_ptr<ray::RayObject>> &args) {
  if (java_check_results == nullptr) {
    // If `java_check_results` is null, it means that `checkByteBufferArguments`
    // failed. In this case, just return null here. The args won't be used anyway.
    return nullptr;
  }

  const jsize num_check_results = env->GetArrayLength(java_check_results);
  jboolean *check_results = env->GetBooleanArrayElements(java_check_results, nullptr);
  // GetBooleanArrayElements returns null on failure; never dereference it blindly.
  RAY_CHECK(check_results);

  // NOTE(review): this relies on NativeVectorToJavaList invoking the converter
  // exactly once per element, in order -- confirm against its implementation.
  jsize i = 0;
  jobject args_array_list = NativeVectorToJavaList<std::shared_ptr<ray::RayObject>>(
      env, args,
      [check_results, num_check_results, &i](
          JNIEnv *env, const std::shared_ptr<ray::RayObject> &native_object) {
        const jsize index = i++;
        // Arguments beyond the declared parameter count are never treated as
        // ByteBuffer. This avoids reading past the end of the flag array when
        // the caller supplied more arguments than the function declares; the
        // mismatch is then left for the Java side to handle.
        if (index < num_check_results && check_results[index]) {
          // If the type of this argument is ByteBuffer, we create a
          // DirectByteBuffer here to avoid data copy.
          // TODO: Check native_object->GetMetadata() == "RAW"
          // NOTE(review): the resulting buffer aliases memory owned by
          // `native_object`; presumably it must not be used from Java after the
          // native argument is released -- confirm.
          // NOTE(review): assumes GetData() is non-null for such arguments --
          // confirm.
          jobject obj = env->NewDirectByteBuffer(native_object->GetData()->Data(),
                                                 native_object->GetData()->Size());
          RAY_CHECK(obj);
          return obj;
        }
        return NativeRayObjectToJavaNativeRayObject(env, native_object);
      });
  // The flags were only read, so there is nothing to copy back (JNI_ABORT).
  env->ReleaseBooleanArrayElements(java_check_results, check_results, JNI_ABORT);
  return args_array_list;
}
#ifdef __cplusplus
extern "C" {
#endif
@ -100,8 +129,12 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(
// convert args
// TODO (kfstorm): Avoid copying binary data from Java to C++
jobject args_array_list = NativeVectorToJavaList<std::shared_ptr<ray::RayObject>>(
env, args, NativeRayObjectToJavaNativeRayObject);
jbooleanArray java_check_results =
static_cast<jbooleanArray>(env->CallObjectMethod(
java_task_executor, java_task_executor_parse_function_arguments,
ray_function_array_list));
RAY_CHECK_JAVA_EXCEPTION(env);
jobject args_array_list = ToJavaArgs(env, java_check_results, args);
// invoke Java method
jobject java_return_objects =
@ -120,6 +153,7 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(
}
}
env->DeleteLocalRef(java_check_results);
env->DeleteLocalRef(java_return_objects);
env->DeleteLocalRef(args_array_list);
return ray::Status::OK();

View file

@ -86,6 +86,7 @@ jfieldID java_native_ray_object_data;
jfieldID java_native_ray_object_metadata;
jclass java_task_executor_class;
jmethodID java_task_executor_parse_function_arguments;
jmethodID java_task_executor_execute;
JavaVM *jvm;
@ -205,6 +206,8 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) {
env->GetFieldID(java_native_ray_object_class, "metadata", "[B");
java_task_executor_class = LoadClass(env, "io/ray/runtime/task/TaskExecutor");
java_task_executor_parse_function_arguments = env->GetMethodID(
java_task_executor_class, "checkByteBufferArguments", "(Ljava/util/List;)[Z");
java_task_executor_execute =
env->GetMethodID(java_task_executor_class, "execute",
"(Ljava/util/List;Ljava/util/List;)Ljava/util/List;");

View file

@ -148,6 +148,8 @@ extern jfieldID java_native_ray_object_metadata;
/// TaskExecutor class
extern jclass java_task_executor_class;
/// checkByteBufferArguments method of TaskExecutor class
extern jmethodID java_task_executor_parse_function_arguments;
/// execute method of TaskExecutor class
extern jmethodID java_task_executor_execute;