mirror of
https://github.com/vale981/ray
synced 2025-03-05 18:11:42 -05:00
[Java] Avoid data copy from C++ to Java for ByteBuffer type (#9033)
This commit is contained in:
parent
6346c70792
commit
bfa0605282
9 changed files with 125 additions and 16 deletions
|
@ -7,6 +7,7 @@ import io.ray.api.exception.UnreconstructableException;
|
|||
import io.ray.api.id.ObjectId;
|
||||
import io.ray.runtime.generated.Gcs.ErrorType;
|
||||
import io.ray.runtime.serializer.Serializer;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.Arrays;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
|
||||
|
@ -45,6 +46,9 @@ public class ObjectSerializer {
|
|||
if (meta != null && meta.length > 0) {
|
||||
// If meta is not null, deserialize the object from meta.
|
||||
if (Arrays.equals(meta, OBJECT_METADATA_TYPE_RAW)) {
|
||||
if (objectType == ByteBuffer.class) {
|
||||
return ByteBuffer.wrap(data);
|
||||
}
|
||||
return data;
|
||||
} else if (Arrays.equals(meta, OBJECT_METADATA_TYPE_CROSS_LANGUAGE) ||
|
||||
Arrays.equals(meta, OBJECT_METADATA_TYPE_JAVA)) {
|
||||
|
@ -81,6 +85,17 @@ public class ObjectSerializer {
|
|||
// If the object is a byte array, skip serializing it and use a special metadata to
|
||||
// indicate it's raw binary. So that this object can also be read by Python.
|
||||
return new NativeRayObject((byte[]) object, OBJECT_METADATA_TYPE_RAW);
|
||||
} else if (object instanceof ByteBuffer) {
|
||||
// Serialize ByteBuffer to raw bytes.
|
||||
ByteBuffer buffer = (ByteBuffer) object;
|
||||
byte[] bytes;
|
||||
if (buffer.hasArray()) {
|
||||
bytes = buffer.array();
|
||||
} else {
|
||||
bytes = new byte[buffer.remaining()];
|
||||
buffer.get(bytes);
|
||||
}
|
||||
return new NativeRayObject(bytes, OBJECT_METADATA_TYPE_RAW);
|
||||
} else if (object instanceof RayTaskException) {
|
||||
byte[] serializedBytes = Serializer.encode(object).getLeft();
|
||||
return new NativeRayObject(serializedBytes, TASK_EXECUTION_EXCEPTION_META);
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
package io.ray.runtime.task;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import io.ray.api.ObjectRef;
|
||||
import io.ray.api.Ray;
|
||||
import io.ray.api.id.ObjectId;
|
||||
|
@ -7,6 +8,7 @@ import io.ray.runtime.RayRuntimeInternal;
|
|||
import io.ray.runtime.generated.Common.Language;
|
||||
import io.ray.runtime.object.NativeRayObject;
|
||||
import io.ray.runtime.object.ObjectSerializer;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
@ -68,12 +70,19 @@ public class ArgumentsBuilder {
|
|||
}
|
||||
|
||||
/**
|
||||
* Convert list of NativeRayObject to real function arguments.
|
||||
* Convert list of NativeRayObject/ByteBuffer to real function arguments.
|
||||
*/
|
||||
public static Object[] unwrap(List<NativeRayObject> args, Class<?>[] types) {
|
||||
public static Object[] unwrap(List<Object> args, Class<?>[] types) {
|
||||
Object[] realArgs = new Object[args.size()];
|
||||
for (int i = 0; i < args.size(); i++) {
|
||||
realArgs[i] = ObjectSerializer.deserialize(args.get(i), null, types[i]);
|
||||
Object arg = args.get(i);
|
||||
Preconditions.checkState(arg instanceof ByteBuffer || arg instanceof NativeRayObject);
|
||||
if (arg instanceof ByteBuffer) {
|
||||
Preconditions.checkState(types[i] == ByteBuffer.class);
|
||||
realArgs[i] = arg;
|
||||
} else {
|
||||
realArgs[i] = ObjectSerializer.deserialize((NativeRayObject) arg, null, types[i]);
|
||||
}
|
||||
}
|
||||
return realArgs;
|
||||
}
|
||||
|
|
|
@ -311,8 +311,9 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
|
|||
? ((LocalModeTaskExecutor.LocalActorContext) actorContext).getWorkerId()
|
||||
: UniqueId.randomId();
|
||||
((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentWorkerId(workerId);
|
||||
List<NativeRayObject> returnObjects = taskExecutor
|
||||
.execute(getJavaFunctionDescriptor(taskSpec).toList(), args);
|
||||
List<String> rayFunctionInfo = getJavaFunctionDescriptor(taskSpec).toList();
|
||||
taskExecutor.checkByteBufferArguments(rayFunctionInfo);
|
||||
List<NativeRayObject> returnObjects = taskExecutor.execute(rayFunctionInfo, args);
|
||||
if (taskSpec.getType() == TaskType.ACTOR_CREATION_TASK) {
|
||||
// Update actor context map ASAP in case objectStore.putRaw triggered the next actor task
|
||||
// on this actor.
|
||||
|
|
|
@ -13,6 +13,7 @@ import io.ray.runtime.generated.Common.TaskType;
|
|||
import io.ray.runtime.object.NativeRayObject;
|
||||
import io.ray.runtime.object.ObjectSerializer;
|
||||
import java.lang.reflect.InvocationTargetException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
@ -30,6 +31,8 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
|
|||
|
||||
private final ConcurrentHashMap<UniqueId, T> actorContextMap = new ConcurrentHashMap<>();
|
||||
|
||||
private final ThreadLocal<RayFunction> localRayFunction = new ThreadLocal<>();
|
||||
|
||||
static class ActorContext {
|
||||
|
||||
/**
|
||||
|
@ -61,10 +64,34 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
|
|||
this.actorContextMap.put(runtime.getWorkerContext().getCurrentWorkerId(), actorContext);
|
||||
}
|
||||
|
||||
protected List<NativeRayObject> execute(List<String> rayFunctionInfo,
|
||||
List<NativeRayObject> argsBytes) {
|
||||
runtime.setIsContextSet(true);
|
||||
private RayFunction getRayFunction(List<String> rayFunctionInfo) {
|
||||
JobId jobId = runtime.getWorkerContext().getCurrentJobId();
|
||||
JavaFunctionDescriptor functionDescriptor = parseFunctionDescriptor(rayFunctionInfo);
|
||||
return runtime.getFunctionManager().getFunction(jobId, functionDescriptor);
|
||||
}
|
||||
|
||||
/**
|
||||
* The return value indicates which parameters are ByteBuffer.
|
||||
*/
|
||||
protected boolean[] checkByteBufferArguments(List<String> rayFunctionInfo) {
|
||||
localRayFunction.set(null);
|
||||
try {
|
||||
localRayFunction.set(getRayFunction(rayFunctionInfo));
|
||||
} catch (Throwable e) {
|
||||
// Ignore the exception.
|
||||
return null;
|
||||
}
|
||||
Class<?>[] types = localRayFunction.get().executable.getParameterTypes();
|
||||
boolean[] results = new boolean[types.length];
|
||||
for (int i = 0; i < types.length; i++) {
|
||||
results[i] = types[i] == ByteBuffer.class;
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
protected List<NativeRayObject> execute(List<String> rayFunctionInfo,
|
||||
List<Object> argsBytes) {
|
||||
runtime.setIsContextSet(true);
|
||||
TaskType taskType = runtime.getWorkerContext().getCurrentTaskType();
|
||||
TaskId taskId = runtime.getWorkerContext().getCurrentTaskId();
|
||||
LOGGER.debug("Executing task {}", taskId);
|
||||
|
@ -80,11 +107,14 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
|
|||
|
||||
List<NativeRayObject> returnObjects = new ArrayList<>();
|
||||
ClassLoader oldLoader = Thread.currentThread().getContextClassLoader();
|
||||
JavaFunctionDescriptor functionDescriptor = parseFunctionDescriptor(rayFunctionInfo);
|
||||
RayFunction rayFunction = null;
|
||||
RayFunction rayFunction = localRayFunction.get();
|
||||
try {
|
||||
// Find the executable object.
|
||||
rayFunction = runtime.getFunctionManager().getFunction(jobId, functionDescriptor);
|
||||
if (rayFunction == null) {
|
||||
// Failed to get RayFunction in checkByteBufferArguments. Redo here to throw
|
||||
// the exception again.
|
||||
rayFunction = getRayFunction(rayFunctionInfo);
|
||||
}
|
||||
Thread.currentThread().setContextClassLoader(rayFunction.classLoader);
|
||||
runtime.getWorkerContext().setCurrentClassLoader(rayFunction.classLoader);
|
||||
|
||||
|
@ -132,7 +162,7 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
|
|||
LOGGER.error("Error executing task " + taskId, e);
|
||||
if (taskType != TaskType.ACTOR_CREATION_TASK) {
|
||||
boolean hasReturn = rayFunction != null && rayFunction.hasReturn();
|
||||
boolean isCrossLanguage = functionDescriptor.signature.equals("");
|
||||
boolean isCrossLanguage = parseFunctionDescriptor(rayFunctionInfo).signature.equals("");
|
||||
if (hasReturn || isCrossLanguage) {
|
||||
returnObjects.add(ObjectSerializer
|
||||
.serialize(new RayTaskException("Error executing task " + taskId, e)));
|
||||
|
|
|
@ -4,6 +4,8 @@ import com.google.common.collect.ImmutableList;
|
|||
import com.google.common.collect.ImmutableMap;
|
||||
import io.ray.api.Ray;
|
||||
import io.ray.api.id.ObjectId;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.testng.Assert;
|
||||
|
@ -63,6 +65,10 @@ public class RayCallTest extends BaseTest {
|
|||
TestUtils.getRuntime().getObjectStore().put(1, objectId);
|
||||
}
|
||||
|
||||
private static ByteBuffer testByteBuffer(ByteBuffer buffer) {
|
||||
return buffer;
|
||||
}
|
||||
|
||||
/**
|
||||
* Test calling and returning different types.
|
||||
*/
|
||||
|
@ -82,6 +88,11 @@ public class RayCallTest extends BaseTest {
|
|||
Assert.assertEquals(map, Ray.task(RayCallTest::testMap, map).remote().get());
|
||||
TestUtils.LargeObject largeObject = new TestUtils.LargeObject();
|
||||
Assert.assertNotNull(Ray.task(RayCallTest::testLargeObject, largeObject).remote().get());
|
||||
ByteBuffer buffer1 = ByteBuffer.wrap("foo".getBytes(StandardCharsets.UTF_8));
|
||||
ByteBuffer buffer2 = Ray.task(RayCallTest::testByteBuffer, buffer1).remote().get();
|
||||
byte[] bytes = new byte[buffer2.remaining()];
|
||||
buffer2.get(bytes);
|
||||
Assert.assertEquals("foo", new String(bytes, StandardCharsets.UTF_8));
|
||||
|
||||
// TODO(edoakes): this test doesn't work now that we've switched to direct call
|
||||
// mode. To make it work, we need to implement the same protocol for resolving
|
||||
|
|
|
@ -1666,8 +1666,12 @@ Status CoreWorker::GetAndPinArgsForExecutor(const TaskSpecification &task,
|
|||
metadata = std::make_shared<LocalMemoryBuffer>(
|
||||
const_cast<uint8_t *>(task.ArgMetadata(i)), task.ArgMetadataSize(i));
|
||||
}
|
||||
args->at(i) = std::make_shared<RayObject>(data, metadata, task.ArgInlinedIds(i),
|
||||
/*copy_data*/ true);
|
||||
// NOTE: this is a workaround to avoid an extra copy for Java workers.
|
||||
// Python workers need this copy to pass test case
|
||||
// test_inline_arg_memory_corruption.
|
||||
bool copy_data = options_.language == Language::PYTHON;
|
||||
args->at(i) =
|
||||
std::make_shared<RayObject>(data, metadata, task.ArgInlinedIds(i), copy_data);
|
||||
arg_reference_ids->at(i) = ObjectID::Nil();
|
||||
// The task borrows all ObjectIDs that were serialized in the inlined
|
||||
// arguments. The task will receive references to these IDs, so it is
|
||||
|
|
|
@ -43,6 +43,35 @@ inline ray::gcs::GcsClientOptions ToGcsClientOptions(JNIEnv *env,
|
|||
return ray::gcs::GcsClientOptions(ip, port, password, /*is_test_client=*/false);
|
||||
}
|
||||
|
||||
jobject ToJavaArgs(JNIEnv *env, jbooleanArray java_check_results,
|
||||
const std::vector<std::shared_ptr<ray::RayObject>> &args) {
|
||||
if (java_check_results == nullptr) {
|
||||
// If `java_check_results` is null, it means that `checkByteBufferArguments`
|
||||
// failed. In this case, just return null here. The args won't be used anyway.
|
||||
return nullptr;
|
||||
} else {
|
||||
jboolean *check_results = env->GetBooleanArrayElements(java_check_results, nullptr);
|
||||
size_t i = 0;
|
||||
jobject args_array_list = NativeVectorToJavaList<std::shared_ptr<ray::RayObject>>(
|
||||
env, args,
|
||||
[check_results, &i](JNIEnv *env,
|
||||
const std::shared_ptr<ray::RayObject> &native_object) {
|
||||
if (*(check_results + (i++))) {
|
||||
// If the type of this argument is ByteBuffer, we create a
|
||||
// DirectByteBuffer here To avoid data copy.
|
||||
// TODO: Check native_object->GetMetadata() == "RAW"
|
||||
jobject obj = env->NewDirectByteBuffer(native_object->GetData()->Data(),
|
||||
native_object->GetData()->Size());
|
||||
RAY_CHECK(obj);
|
||||
return obj;
|
||||
}
|
||||
return NativeRayObjectToJavaNativeRayObject(env, native_object);
|
||||
});
|
||||
env->ReleaseBooleanArrayElements(java_check_results, check_results, JNI_ABORT);
|
||||
return args_array_list;
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
@ -100,8 +129,12 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(
|
|||
|
||||
// convert args
|
||||
// TODO (kfstorm): Avoid copying binary data from Java to C++
|
||||
jobject args_array_list = NativeVectorToJavaList<std::shared_ptr<ray::RayObject>>(
|
||||
env, args, NativeRayObjectToJavaNativeRayObject);
|
||||
jbooleanArray java_check_results =
|
||||
static_cast<jbooleanArray>(env->CallObjectMethod(
|
||||
java_task_executor, java_task_executor_parse_function_arguments,
|
||||
ray_function_array_list));
|
||||
RAY_CHECK_JAVA_EXCEPTION(env);
|
||||
jobject args_array_list = ToJavaArgs(env, java_check_results, args);
|
||||
|
||||
// invoke Java method
|
||||
jobject java_return_objects =
|
||||
|
@ -120,6 +153,7 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(
|
|||
}
|
||||
}
|
||||
|
||||
env->DeleteLocalRef(java_check_results);
|
||||
env->DeleteLocalRef(java_return_objects);
|
||||
env->DeleteLocalRef(args_array_list);
|
||||
return ray::Status::OK();
|
||||
|
|
|
@ -86,6 +86,7 @@ jfieldID java_native_ray_object_data;
|
|||
jfieldID java_native_ray_object_metadata;
|
||||
|
||||
jclass java_task_executor_class;
|
||||
jmethodID java_task_executor_parse_function_arguments;
|
||||
jmethodID java_task_executor_execute;
|
||||
|
||||
JavaVM *jvm;
|
||||
|
@ -205,6 +206,8 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) {
|
|||
env->GetFieldID(java_native_ray_object_class, "metadata", "[B");
|
||||
|
||||
java_task_executor_class = LoadClass(env, "io/ray/runtime/task/TaskExecutor");
|
||||
java_task_executor_parse_function_arguments = env->GetMethodID(
|
||||
java_task_executor_class, "checkByteBufferArguments", "(Ljava/util/List;)[Z");
|
||||
java_task_executor_execute =
|
||||
env->GetMethodID(java_task_executor_class, "execute",
|
||||
"(Ljava/util/List;Ljava/util/List;)Ljava/util/List;");
|
||||
|
|
|
@ -148,6 +148,8 @@ extern jfieldID java_native_ray_object_metadata;
|
|||
|
||||
/// TaskExecutor class
|
||||
extern jclass java_task_executor_class;
|
||||
/// checkByteBufferArguments method of TaskExecutor class
|
||||
extern jmethodID java_task_executor_parse_function_arguments;
|
||||
/// execute method of TaskExecutor class
|
||||
extern jmethodID java_task_executor_execute;
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue