diff --git a/java/runtime/src/main/java/io/ray/runtime/object/ObjectSerializer.java b/java/runtime/src/main/java/io/ray/runtime/object/ObjectSerializer.java index d45aba662..35ab285a8 100644 --- a/java/runtime/src/main/java/io/ray/runtime/object/ObjectSerializer.java +++ b/java/runtime/src/main/java/io/ray/runtime/object/ObjectSerializer.java @@ -7,6 +7,7 @@ import io.ray.api.exception.UnreconstructableException; import io.ray.api.id.ObjectId; import io.ray.runtime.generated.Gcs.ErrorType; import io.ray.runtime.serializer.Serializer; +import java.nio.ByteBuffer; import java.util.Arrays; import org.apache.commons.lang3.tuple.Pair; @@ -45,6 +46,9 @@ public class ObjectSerializer { if (meta != null && meta.length > 0) { // If meta is not null, deserialize the object from meta. if (Arrays.equals(meta, OBJECT_METADATA_TYPE_RAW)) { + if (objectType == ByteBuffer.class) { + return ByteBuffer.wrap(data); + } return data; } else if (Arrays.equals(meta, OBJECT_METADATA_TYPE_CROSS_LANGUAGE) || Arrays.equals(meta, OBJECT_METADATA_TYPE_JAVA)) { @@ -81,6 +85,17 @@ public class ObjectSerializer { // If the object is a byte array, skip serializing it and use a special metadata to // indicate it's raw binary. So that this object can also be read by Python. return new NativeRayObject((byte[]) object, OBJECT_METADATA_TYPE_RAW); + } else if (object instanceof ByteBuffer) { + // Serialize ByteBuffer to raw bytes. 
+ ByteBuffer buffer = (ByteBuffer) object; + byte[] bytes; + if (buffer.hasArray()) { + bytes = buffer.array(); + } else { + bytes = new byte[buffer.remaining()]; + buffer.get(bytes); + } + return new NativeRayObject(bytes, OBJECT_METADATA_TYPE_RAW); } else if (object instanceof RayTaskException) { byte[] serializedBytes = Serializer.encode(object).getLeft(); return new NativeRayObject(serializedBytes, TASK_EXECUTION_EXCEPTION_META); diff --git a/java/runtime/src/main/java/io/ray/runtime/task/ArgumentsBuilder.java b/java/runtime/src/main/java/io/ray/runtime/task/ArgumentsBuilder.java index dfc28a021..7d1b30ffd 100644 --- a/java/runtime/src/main/java/io/ray/runtime/task/ArgumentsBuilder.java +++ b/java/runtime/src/main/java/io/ray/runtime/task/ArgumentsBuilder.java @@ -1,5 +1,6 @@ package io.ray.runtime.task; +import com.google.common.base.Preconditions; import io.ray.api.ObjectRef; import io.ray.api.Ray; import io.ray.api.id.ObjectId; @@ -7,6 +8,7 @@ import io.ray.runtime.RayRuntimeInternal; import io.ray.runtime.generated.Common.Language; import io.ray.runtime.object.NativeRayObject; import io.ray.runtime.object.ObjectSerializer; +import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -68,12 +70,19 @@ public class ArgumentsBuilder { } /** - * Convert list of NativeRayObject to real function arguments. + * Convert list of NativeRayObject/ByteBuffer to real function arguments. 
*/ - public static Object[] unwrap(List args, Class[] types) { + public static Object[] unwrap(List args, Class[] types) { Object[] realArgs = new Object[args.size()]; for (int i = 0; i < args.size(); i++) { - realArgs[i] = ObjectSerializer.deserialize(args.get(i), null, types[i]); + Object arg = args.get(i); + Preconditions.checkState(arg instanceof ByteBuffer || arg instanceof NativeRayObject); + if (arg instanceof ByteBuffer) { + Preconditions.checkState(types[i] == ByteBuffer.class); + realArgs[i] = arg; + } else { + realArgs[i] = ObjectSerializer.deserialize((NativeRayObject) arg, null, types[i]); + } } return realArgs; } diff --git a/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java b/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java index 2358982a6..d729692d2 100644 --- a/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java +++ b/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java @@ -311,8 +311,9 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { ? ((LocalModeTaskExecutor.LocalActorContext) actorContext).getWorkerId() : UniqueId.randomId(); ((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentWorkerId(workerId); - List returnObjects = taskExecutor - .execute(getJavaFunctionDescriptor(taskSpec).toList(), args); + List rayFunctionInfo = getJavaFunctionDescriptor(taskSpec).toList(); + taskExecutor.checkByteBufferArguments(rayFunctionInfo); + List returnObjects = taskExecutor.execute(rayFunctionInfo, args); if (taskSpec.getType() == TaskType.ACTOR_CREATION_TASK) { // Update actor context map ASAP in case objectStore.putRaw triggered the next actor task // on this actor. 
diff --git a/java/runtime/src/main/java/io/ray/runtime/task/TaskExecutor.java b/java/runtime/src/main/java/io/ray/runtime/task/TaskExecutor.java index 7519ab753..115e422f7 100644 --- a/java/runtime/src/main/java/io/ray/runtime/task/TaskExecutor.java +++ b/java/runtime/src/main/java/io/ray/runtime/task/TaskExecutor.java @@ -13,6 +13,7 @@ import io.ray.runtime.generated.Common.TaskType; import io.ray.runtime.object.NativeRayObject; import io.ray.runtime.object.ObjectSerializer; import java.lang.reflect.InvocationTargetException; +import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.List; import java.util.concurrent.ConcurrentHashMap; @@ -30,6 +31,8 @@ public abstract class TaskExecutor { private final ConcurrentHashMap actorContextMap = new ConcurrentHashMap<>(); + private final ThreadLocal localRayFunction = new ThreadLocal<>(); + static class ActorContext { /** @@ -61,10 +64,34 @@ public abstract class TaskExecutor { this.actorContextMap.put(runtime.getWorkerContext().getCurrentWorkerId(), actorContext); } - protected List execute(List rayFunctionInfo, - List argsBytes) { - runtime.setIsContextSet(true); + private RayFunction getRayFunction(List rayFunctionInfo) { JobId jobId = runtime.getWorkerContext().getCurrentJobId(); + JavaFunctionDescriptor functionDescriptor = parseFunctionDescriptor(rayFunctionInfo); + return runtime.getFunctionManager().getFunction(jobId, functionDescriptor); + } + + /** + * The return value indicates which parameters are ByteBuffer. + */ + protected boolean[] checkByteBufferArguments(List rayFunctionInfo) { + localRayFunction.set(null); + try { + localRayFunction.set(getRayFunction(rayFunctionInfo)); + } catch (Throwable e) { + // Ignore the exception. 
+ return null; + } + Class[] types = localRayFunction.get().executable.getParameterTypes(); + boolean[] results = new boolean[types.length]; + for (int i = 0; i < types.length; i++) { + results[i] = types[i] == ByteBuffer.class; + } + return results; + } + + protected List execute(List rayFunctionInfo, + List argsBytes) { + runtime.setIsContextSet(true); TaskType taskType = runtime.getWorkerContext().getCurrentTaskType(); TaskId taskId = runtime.getWorkerContext().getCurrentTaskId(); LOGGER.debug("Executing task {}", taskId); @@ -80,11 +107,14 @@ public abstract class TaskExecutor { List returnObjects = new ArrayList<>(); ClassLoader oldLoader = Thread.currentThread().getContextClassLoader(); - JavaFunctionDescriptor functionDescriptor = parseFunctionDescriptor(rayFunctionInfo); - RayFunction rayFunction = null; + RayFunction rayFunction = localRayFunction.get(); try { // Find the executable object. - rayFunction = runtime.getFunctionManager().getFunction(jobId, functionDescriptor); + if (rayFunction == null) { + // Failed to get RayFunction in checkByteBufferArguments. Redo here to throw + // the exception again. 
+ rayFunction = getRayFunction(rayFunctionInfo); + } Thread.currentThread().setContextClassLoader(rayFunction.classLoader); runtime.getWorkerContext().setCurrentClassLoader(rayFunction.classLoader); @@ -132,7 +162,7 @@ public abstract class TaskExecutor { LOGGER.error("Error executing task " + taskId, e); if (taskType != TaskType.ACTOR_CREATION_TASK) { boolean hasReturn = rayFunction != null && rayFunction.hasReturn(); - boolean isCrossLanguage = functionDescriptor.signature.equals(""); + boolean isCrossLanguage = parseFunctionDescriptor(rayFunctionInfo).signature.equals(""); if (hasReturn || isCrossLanguage) { returnObjects.add(ObjectSerializer .serialize(new RayTaskException("Error executing task " + taskId, e))); diff --git a/java/test/src/main/java/io/ray/test/RayCallTest.java b/java/test/src/main/java/io/ray/test/RayCallTest.java index d56f99adf..acb4f149e 100644 --- a/java/test/src/main/java/io/ray/test/RayCallTest.java +++ b/java/test/src/main/java/io/ray/test/RayCallTest.java @@ -4,6 +4,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.ray.api.Ray; import io.ray.api.id.ObjectId; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Map; import org.testng.Assert; @@ -63,6 +65,10 @@ public class RayCallTest extends BaseTest { TestUtils.getRuntime().getObjectStore().put(1, objectId); } + private static ByteBuffer testByteBuffer(ByteBuffer buffer) { + return buffer; + } + /** * Test calling and returning different types. 
*/ @@ -82,6 +88,11 @@ public class RayCallTest extends BaseTest { Assert.assertEquals(map, Ray.task(RayCallTest::testMap, map).remote().get()); TestUtils.LargeObject largeObject = new TestUtils.LargeObject(); Assert.assertNotNull(Ray.task(RayCallTest::testLargeObject, largeObject).remote().get()); + ByteBuffer buffer1 = ByteBuffer.wrap("foo".getBytes(StandardCharsets.UTF_8)); + ByteBuffer buffer2 = Ray.task(RayCallTest::testByteBuffer, buffer1).remote().get(); + byte[] bytes = new byte[buffer2.remaining()]; + buffer2.get(bytes); + Assert.assertEquals("foo", new String(bytes, StandardCharsets.UTF_8)); // TODO(edoakes): this test doesn't work now that we've switched to direct call // mode. To make it work, we need to implement the same protocol for resolving diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 9bcdb0d5a..8bdcb20f5 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -1666,8 +1666,12 @@ Status CoreWorker::GetAndPinArgsForExecutor(const TaskSpecification &task, metadata = std::make_shared( const_cast(task.ArgMetadata(i)), task.ArgMetadataSize(i)); } - args->at(i) = std::make_shared(data, metadata, task.ArgInlinedIds(i), - /*copy_data*/ true); + // NOTE: this is a workaround to avoid an extra copy for Java workers. + // Python workers need this copy to pass test case + // test_inline_arg_memory_corruption. + bool copy_data = options_.language == Language::PYTHON; + args->at(i) = + std::make_shared(data, metadata, task.ArgInlinedIds(i), copy_data); arg_reference_ids->at(i) = ObjectID::Nil(); // The task borrows all ObjectIDs that were serialized in the inlined // arguments. 
The task will receive references to these IDs, so it is diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc index 8f4b9d0ec..c602fca27 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc @@ -43,6 +43,35 @@ inline ray::gcs::GcsClientOptions ToGcsClientOptions(JNIEnv *env, return ray::gcs::GcsClientOptions(ip, port, password, /*is_test_client=*/false); } +jobject ToJavaArgs(JNIEnv *env, jbooleanArray java_check_results, + const std::vector> &args) { + if (java_check_results == nullptr) { + // If `java_check_results` is null, it means that `checkByteBufferArguments` + // failed. In this case, just return null here. The args won't be used anyway. + return nullptr; + } else { + jboolean *check_results = env->GetBooleanArrayElements(java_check_results, nullptr); + size_t i = 0; + jobject args_array_list = NativeVectorToJavaList>( + env, args, + [check_results, &i](JNIEnv *env, + const std::shared_ptr &native_object) { + if (*(check_results + (i++))) { + // If the type of this argument is ByteBuffer, we create a + // DirectByteBuffer here to avoid data copy. 
+ // TODO: Check native_object->GetMetadata() == "RAW" + jobject obj = env->NewDirectByteBuffer(native_object->GetData()->Data(), + native_object->GetData()->Size()); + RAY_CHECK(obj); + return obj; + } + return NativeRayObjectToJavaNativeRayObject(env, native_object); + }); + env->ReleaseBooleanArrayElements(java_check_results, check_results, JNI_ABORT); + return args_array_list; + } +} + #ifdef __cplusplus extern "C" { #endif @@ -100,8 +129,12 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize( // convert args // TODO (kfstorm): Avoid copying binary data from Java to C++ - jobject args_array_list = NativeVectorToJavaList>( - env, args, NativeRayObjectToJavaNativeRayObject); + jbooleanArray java_check_results = + static_cast(env->CallObjectMethod( + java_task_executor, java_task_executor_parse_function_arguments, + ray_function_array_list)); + RAY_CHECK_JAVA_EXCEPTION(env); + jobject args_array_list = ToJavaArgs(env, java_check_results, args); // invoke Java method jobject java_return_objects = @@ -120,6 +153,7 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize( } } + env->DeleteLocalRef(java_check_results); env->DeleteLocalRef(java_return_objects); env->DeleteLocalRef(args_array_list); return ray::Status::OK(); diff --git a/src/ray/core_worker/lib/java/jni_init.cc b/src/ray/core_worker/lib/java/jni_init.cc index 504c05eca..c3078b114 100644 --- a/src/ray/core_worker/lib/java/jni_init.cc +++ b/src/ray/core_worker/lib/java/jni_init.cc @@ -86,6 +86,7 @@ jfieldID java_native_ray_object_data; jfieldID java_native_ray_object_metadata; jclass java_task_executor_class; +jmethodID java_task_executor_parse_function_arguments; jmethodID java_task_executor_execute; JavaVM *jvm; @@ -205,6 +206,8 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) { env->GetFieldID(java_native_ray_object_class, "metadata", "[B"); java_task_executor_class = LoadClass(env, "io/ray/runtime/task/TaskExecutor"); + 
java_task_executor_parse_function_arguments = env->GetMethodID( + java_task_executor_class, "checkByteBufferArguments", "(Ljava/util/List;)[Z"); java_task_executor_execute = env->GetMethodID(java_task_executor_class, "execute", "(Ljava/util/List;Ljava/util/List;)Ljava/util/List;"); diff --git a/src/ray/core_worker/lib/java/jni_utils.h b/src/ray/core_worker/lib/java/jni_utils.h index b6bcaaa63..8a2ebd65b 100644 --- a/src/ray/core_worker/lib/java/jni_utils.h +++ b/src/ray/core_worker/lib/java/jni_utils.h @@ -148,6 +148,8 @@ extern jfieldID java_native_ray_object_metadata; /// TaskExecutor class extern jclass java_task_executor_class; +/// checkByteBufferArguments method of TaskExecutor class +extern jmethodID java_task_executor_parse_function_arguments; /// execute method of TaskExecutor class extern jmethodID java_task_executor_execute;