From 806524384b93fe6a776f1883dc2e09b3010573f9 Mon Sep 17 00:00:00 2001 From: Kai Yang Date: Tue, 16 Jul 2019 20:58:02 +0800 Subject: [PATCH] [Java worker] Refactor object store and worker context on top of core worker (#5079) --- BUILD.bazel | 29 ++- bazel/BUILD.plasma | 30 --- java/BUILD.bazel | 9 +- java/pom.xml | 5 - java/runtime/pom.xml | 4 - java/runtime/pom_template.xml | 4 - .../org/ray/runtime/AbstractRayRuntime.java | 60 +++++- .../java/org/ray/runtime/RayDevRuntime.java | 15 +- .../org/ray/runtime/RayNativeRuntime.java | 74 ++----- .../java/org/ray/runtime/WorkerContext.java | 102 +++++----- .../org/ray/runtime/config/RayConfig.java | 13 +- .../org/ray/runtime/config/WorkerMode.java | 6 - .../objectstore/MockObjectInterface.java | 98 ++++++++++ .../runtime/objectstore/MockObjectStore.java | 148 -------------- .../runtime/objectstore/NativeRayObject.java | 13 ++ .../runtime/objectstore/ObjectInterface.java | 54 ++++++ .../objectstore/ObjectInterfaceImpl.java | 91 +++++++++ .../runtime/objectstore/ObjectStoreProxy.java | 99 ++++------ .../ray/runtime/raylet/MockRayletClient.java | 35 ++-- .../ray/runtime/raylet/RayletClientImpl.java | 12 +- .../java/org/ray/runtime/util/IdUtil.java | 12 -- .../java/org/ray/api/test/FailureTest.java | 36 ++++ .../org/ray/api/test/PlasmaStoreTest.java | 22 +-- .../java/org/ray/api/test/RayConfigTest.java | 4 +- src/ray/common/ray_config_def.h | 7 + src/ray/core_worker/common.h | 4 +- src/ray/core_worker/context.cc | 56 +++--- src/ray/core_worker/context.h | 8 +- src/ray/core_worker/core_worker.cc | 6 +- src/ray/core_worker/lib/java/jni_init.cc | 75 ++++++++ src/ray/core_worker/lib/java/jni_utils.h | 180 ++++++++++++++++++ .../lib/java/org_ray_runtime_WorkerContext.cc | 134 +++++++++++++ .../lib/java/org_ray_runtime_WorkerContext.h | 87 +++++++++ ...runtime_objectstore_ObjectInterfaceImpl.cc | 149 +++++++++++++++ ..._runtime_objectstore_ObjectInterfaceImpl.h | 72 +++++++ .../store_provider/plasma_store_provider.cc | 47 ++++- 
.../store_provider/plasma_store_provider.h | 14 ++ src/ray/protobuf/common.proto | 6 + src/ray/protobuf/gcs.proto | 2 + ...org_ray_runtime_raylet_RayletClientImpl.cc | 135 +++++-------- 40 files changed, 1386 insertions(+), 571 deletions(-) delete mode 100644 java/runtime/src/main/java/org/ray/runtime/config/WorkerMode.java create mode 100644 java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectInterface.java delete mode 100644 java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectStore.java create mode 100644 java/runtime/src/main/java/org/ray/runtime/objectstore/NativeRayObject.java create mode 100644 java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectInterface.java create mode 100644 java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectInterfaceImpl.java create mode 100644 src/ray/core_worker/lib/java/jni_init.cc create mode 100644 src/ray/core_worker/lib/java/jni_utils.h create mode 100644 src/ray/core_worker/lib/java/org_ray_runtime_WorkerContext.cc create mode 100644 src/ray/core_worker/lib/java/org_ray_runtime_WorkerContext.h create mode 100644 src/ray/core_worker/lib/java/org_ray_runtime_objectstore_ObjectInterfaceImpl.cc create mode 100644 src/ray/core_worker/lib/java/org_ray_runtime_objectstore_ObjectInterfaceImpl.h diff --git a/BUILD.bazel b/BUILD.bazel index 95fd33a52..cc3da5139 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -647,13 +647,13 @@ pyx_library( ) cc_binary( - name = "libraylet_library_java.so", - srcs = [ - "src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h", - "src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc", - "src/ray/common/id.h", - "src/ray/raylet/raylet_client.h", - "src/ray/util/logging.h", + name = "libcore_worker_library_java.so", + srcs = glob([ + "src/ray/core_worker/lib/java/*.h", + "src/ray/core_worker/lib/java/*.cc", + "src/ray/raylet/lib/java/*.h", + "src/ray/raylet/lib/java/*.cc", + ]) + [ "@bazel_tools//tools/jdk:jni_header", ] + select({ 
"@bazel_tools//src/conditions:windows": ["@bazel_tools//tools/jdk:jni_md_header-windows"], @@ -671,24 +671,23 @@ cc_binary( linkshared = 1, linkstatic = 1, deps = [ - "//:raylet_lib", - "@plasma//:plasma_client", + "//:core_worker_lib", ], ) genrule( - name = "raylet-jni-darwin-compat", - srcs = [":libraylet_library_java.so"], - outs = ["libraylet_library_java.dylib"], + name = "core_worker-jni-darwin-compat", + srcs = [":libcore_worker_library_java.so"], + outs = ["libcore_worker_library_java.dylib"], cmd = "cp $< $@", output_to_bindir = 1, ) filegroup( - name = "raylet_library_java", + name = "core_worker_library_java", srcs = select({ - "@bazel_tools//src/conditions:darwin": [":libraylet_library_java.dylib"], - "//conditions:default": [":libraylet_library_java.so"], + "@bazel_tools//src/conditions:darwin": [":libcore_worker_library_java.dylib"], + "//conditions:default": [":libcore_worker_library_java.so"], }), visibility = ["//java:__subpackages__"], ) diff --git a/bazel/BUILD.plasma b/bazel/BUILD.plasma index ff0fe3e14..5a8b57afc 100644 --- a/bazel/BUILD.plasma +++ b/bazel/BUILD.plasma @@ -2,27 +2,6 @@ load("@com_github_google_flatbuffers//:build_defs.bzl", "flatbuffer_cc_library") COPTS = ["-DARROW_USE_GLOG"] -java_library( - name = "org_apache_arrow_arrow_plasma", - srcs = glob(["java/plasma/src/main/java/**/*.java"]), - data = [":plasma_client_java"], - visibility = ["//visibility:public"], - deps = [ - "@maven//:org_slf4j_slf4j_api", - ], -) - -java_binary( - name = "org_apache_arrow_arrow_plasma_test", - srcs = ["java/plasma/src/test/java/org/apache/arrow/plasma/PlasmaClientTest.java"], - main_class = "org.apache.arrow.plasma.PlasmaClientTest", - visibility = ["//visibility:public"], - deps = [ - ":org_apache_arrow_arrow_plasma", - "@maven//:junit_junit", - ], -) - cc_library( name = "arrow", srcs = [ @@ -145,15 +124,6 @@ genrule( output_to_bindir = 1, ) -filegroup( - name = "plasma_client_java", - srcs = select({ - "@bazel_tools//src/conditions:darwin": 
[":libplasma_java.dylib"], - "//conditions:default": [":libplasma_java.so"], - }), - visibility = ["//visibility:public"], -) - cc_library( name = "plasma_lib", srcs = [ diff --git a/java/BUILD.bazel b/java/BUILD.bazel index b9c43424f..37ef5b93b 100644 --- a/java/BUILD.bazel +++ b/java/BUILD.bazel @@ -69,7 +69,6 @@ define_java_module( ], deps = [ ":org_ray_ray_api", - "@plasma//:org_apache_arrow_arrow_plasma", "@maven//:com_google_guava_guava", "@maven//:com_google_protobuf_protobuf_java", "@maven//:com_typesafe_config", @@ -97,7 +96,6 @@ define_java_module( deps = [ ":org_ray_ray_api", ":org_ray_ray_runtime", - "@plasma//:org_apache_arrow_arrow_plasma", "@maven//:com_google_guava_guava", "@maven//:com_sun_xml_bind_jaxb_core", "@maven//:com_sun_xml_bind_jaxb_impl", @@ -176,9 +174,8 @@ filegroup( "//:redis-server", "//:libray_redis_module.so", "//:raylet", - "//:raylet_library_java", + "//:core_worker_library_java", "@plasma//:plasma_store_server", - "@plasma//:plasma_client_java", ], ) @@ -189,7 +186,6 @@ genrule( ":all_java_proto", ":java_native_deps", ":copy_pom_file", - "@plasma//:org_apache_arrow_arrow_plasma", ], outs = ["gen_maven_deps.out"], cmd = """ @@ -208,9 +204,6 @@ genrule( chmod +w $$f cp $$f $$NATIVE_DEPS_DIR done - # Install plasma jar to local maven repo. 
- mvn install:install-file -Dfile=$(locations @plasma//:org_apache_arrow_arrow_plasma) -Dpackaging=jar \ - -DgroupId=org.apache.arrow -DartifactId=arrow-plasma -Dversion=0.13.0-SNAPSHOT echo $$(date) > $@ """, local = 1, diff --git a/java/pom.xml b/java/pom.xml index bf7a41229..912b803de 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -24,11 +24,6 @@ - - org.apache.arrow - arrow-plasma - 0.13.0-SNAPSHOT - diff --git a/java/runtime/pom.xml b/java/runtime/pom.xml index aba612b36..3c40f7ffc 100644 --- a/java/runtime/pom.xml +++ b/java/runtime/pom.xml @@ -22,10 +22,6 @@ ray-api ${project.version} - - org.apache.arrow - arrow-plasma - com.beust jcommander diff --git a/java/runtime/pom_template.xml b/java/runtime/pom_template.xml index 9200bd6c6..10a36bfce 100644 --- a/java/runtime/pom_template.xml +++ b/java/runtime/pom_template.xml @@ -22,10 +22,6 @@ ray-api ${project.version} - - org.apache.arrow - arrow-plasma - {generated_bzl_deps} diff --git a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java index 2d51f113a..831de1acf 100644 --- a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java @@ -1,7 +1,15 @@ package org.ray.runtime; import com.google.common.base.Preconditions; +import com.google.common.base.Strings; import com.google.common.collect.ImmutableList; +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.lang.reflect.Field; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.nio.file.StandardCopyOption; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -72,6 +80,27 @@ public abstract class AbstractRayRuntime implements RayRuntime { protected RuntimeContext runtimeContext; protected GcsClient gcsClient; + static { + try { + LOGGER.debug("Loading native libraries."); + // Load native libraries. 
+ String[] libraries = new String[]{"core_worker_library_java"}; + for (String library : libraries) { + String fileName = System.mapLibraryName(library); + // Copy the file from resources to a temp dir, and load the native library. + File file = File.createTempFile(fileName, ""); + file.deleteOnExit(); + InputStream in = AbstractRayRuntime.class.getResourceAsStream("/" + fileName); + Preconditions.checkNotNull(in, "{} doesn't exist.", fileName); + Files.copy(in, Paths.get(file.getAbsolutePath()), StandardCopyOption.REPLACE_EXISTING); + System.load(file.getAbsolutePath()); + } + LOGGER.debug("Native libraries loaded."); + } catch (IOException e) { + throw new RuntimeException("Couldn't load native libraries.", e); + } + } + public AbstractRayRuntime(RayConfig rayConfig) { this.rayConfig = rayConfig; functionManager = new FunctionManager(rayConfig.jobResourcePath); @@ -79,6 +108,33 @@ public abstract class AbstractRayRuntime implements RayRuntime { runtimeContext = new RuntimeContextImpl(this); } + protected void resetLibraryPath() { + if (rayConfig.libraryPath.isEmpty()) { + return; + } + + String path = System.getProperty("java.library.path"); + if (Strings.isNullOrEmpty(path)) { + path = ""; + } else { + path += ":"; + } + path += String.join(":", rayConfig.libraryPath); + + // This is a hack to reset library path at runtime, + // see https://stackoverflow.com/questions/15409223/. + System.setProperty("java.library.path", path); + // Set sys_paths to null so that java.library.path will be re-evaluated next time it is needed. + final Field sysPathsField; + try { + sysPathsField = ClassLoader.class.getDeclaredField("sys_paths"); + sysPathsField.setAccessible(true); + sysPathsField.set(null, null); + } catch (NoSuchFieldException | IllegalAccessException e) { + LOGGER.error("Failed to set library path.", e); + } + } + /** * Start runtime. */ @@ -330,8 +386,8 @@ public abstract class AbstractRayRuntime implements RayRuntime { * Create the task specification. 
* * @param func The target remote function. - * @param pyFunctionDescriptor Descriptor of the target Python function, if the task is a - * Python task. + * @param pyFunctionDescriptor Descriptor of the target Python function, if the task is a Python + * task. * @param actor The actor handle. If the task is not an actor task, actor id must be NIL. * @param args The arguments for the remote function. * @param isActorCreationTask Whether this task is an actor creation task. diff --git a/java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java b/java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java index a53d59bc8..a491d89e5 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java @@ -3,7 +3,7 @@ package org.ray.runtime; import java.util.concurrent.atomic.AtomicInteger; import org.ray.api.id.JobId; import org.ray.runtime.config.RayConfig; -import org.ray.runtime.objectstore.MockObjectStore; +import org.ray.runtime.objectstore.MockObjectInterface; import org.ray.runtime.objectstore.ObjectStoreProxy; import org.ray.runtime.raylet.MockRayletClient; @@ -13,19 +13,22 @@ public class RayDevRuntime extends AbstractRayRuntime { super(rayConfig); } - private MockObjectStore store; + private MockObjectInterface objectInterface; private AtomicInteger jobCounter = new AtomicInteger(0); @Override public void start() { - store = new MockObjectStore(this); + // Reset library path at runtime. 
+ resetLibraryPath(); + + objectInterface = new MockObjectInterface(workerContext); if (rayConfig.getJobId().isNil()) { rayConfig.setJobId(nextJobId()); } workerContext = new WorkerContext(rayConfig.workerMode, rayConfig.getJobId(), rayConfig.runMode); - objectStoreProxy = new ObjectStoreProxy(this, null); + objectStoreProxy = new ObjectStoreProxy(workerContext, objectInterface); rayletClient = new MockRayletClient(this, rayConfig.numberExecThreadsForDevRuntime); } @@ -34,8 +37,8 @@ public class RayDevRuntime extends AbstractRayRuntime { rayletClient.destroy(); } - public MockObjectStore getObjectStore() { - return store; + public MockObjectInterface getObjectInterface() { + return objectInterface; } @Override diff --git a/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java b/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java index 8d98b18f4..cf804ee02 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java @@ -1,21 +1,13 @@ package org.ray.runtime; -import com.google.common.base.Preconditions; -import com.google.common.base.Strings; -import java.io.File; -import java.io.IOException; -import java.io.InputStream; -import java.lang.reflect.Field; -import java.nio.file.Files; -import java.nio.file.Paths; -import java.nio.file.StandardCopyOption; import java.util.HashMap; import java.util.Map; import org.ray.api.id.JobId; import org.ray.runtime.config.RayConfig; -import org.ray.runtime.config.WorkerMode; import org.ray.runtime.gcs.GcsClient; import org.ray.runtime.gcs.RedisClient; +import org.ray.runtime.generated.Common.WorkerType; +import org.ray.runtime.objectstore.ObjectInterfaceImpl; import org.ray.runtime.objectstore.ObjectStoreProxy; import org.ray.runtime.raylet.RayletClientImpl; import org.ray.runtime.runner.RunManager; @@ -31,58 +23,12 @@ public final class RayNativeRuntime extends AbstractRayRuntime { private RunManager manager = null; - 
static { - try { - LOGGER.debug("Loading native libraries."); - // Load native libraries. - String[] libraries = new String[]{"raylet_library_java", "plasma_java"}; - for (String library : libraries) { - String fileName = System.mapLibraryName(library); - // Copy the file from resources to a temp dir, and load the native library. - File file = File.createTempFile(fileName, ""); - file.deleteOnExit(); - InputStream in = RayNativeRuntime.class.getResourceAsStream("/" + fileName); - Preconditions.checkNotNull(in, "{} doesn't exist.", fileName); - Files.copy(in, Paths.get(file.getAbsolutePath()), StandardCopyOption.REPLACE_EXISTING); - System.load(file.getAbsolutePath()); - } - LOGGER.debug("Native libraries loaded."); - } catch (IOException e) { - throw new RuntimeException("Couldn't load native libraries.", e); - } - } + private ObjectInterfaceImpl objectInterfaceImpl = null; public RayNativeRuntime(RayConfig rayConfig) { super(rayConfig); } - private void resetLibraryPath() { - if (rayConfig.libraryPath.isEmpty()) { - return; - } - - String path = System.getProperty("java.library.path"); - if (Strings.isNullOrEmpty(path)) { - path = ""; - } else { - path += ":"; - } - path += String.join(":", rayConfig.libraryPath); - - // This is a hack to reset library path at runtime, - // see https://stackoverflow.com/questions/15409223/. - System.setProperty("java.library.path", path); - // Set sys_paths to null so that java.library.path will be re-evaluated next time it is needed. - final Field sysPathsField; - try { - sysPathsField = ClassLoader.class.getDeclaredField("sys_paths"); - sysPathsField.setAccessible(true); - sysPathsField.set(null, null); - } catch (NoSuchFieldException | IllegalAccessException e) { - LOGGER.error("Failed to set library path.", e); - } - } - @Override public void start() { // Reset library path at runtime. 
@@ -101,16 +47,18 @@ public final class RayNativeRuntime extends AbstractRayRuntime { workerContext = new WorkerContext(rayConfig.workerMode, rayConfig.getJobId(), rayConfig.runMode); - // TODO(qwang): Get object_store_socket_name and raylet_socket_name from Redis. - objectStoreProxy = new ObjectStoreProxy(this, rayConfig.objectStoreSocketName); - rayletClient = new RayletClientImpl( rayConfig.rayletSocketName, workerContext.getCurrentWorkerId(), - rayConfig.workerMode == WorkerMode.WORKER, + rayConfig.workerMode == WorkerType.WORKER, workerContext.getCurrentJobId() ); + // TODO(qwang): Get object_store_socket_name and raylet_socket_name from Redis. + objectInterfaceImpl = new ObjectInterfaceImpl(workerContext, rayletClient, + rayConfig.objectStoreSocketName); + objectStoreProxy = new ObjectStoreProxy(workerContext, objectInterfaceImpl); + // register registerWorker(); @@ -123,6 +71,8 @@ public final class RayNativeRuntime extends AbstractRayRuntime { if (null != manager) { manager.cleanup(); } + objectInterfaceImpl.destroy(); + workerContext.destroy(); } /** @@ -132,7 +82,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime { RedisClient redisClient = new RedisClient(rayConfig.getRedisAddress(), rayConfig.redisPassword); Map workerInfo = new HashMap<>(); String workerId = new String(workerContext.getCurrentWorkerId().getBytes()); - if (rayConfig.workerMode == WorkerMode.DRIVER) { + if (rayConfig.workerMode == WorkerType.DRIVER) { workerInfo.put("node_ip_address", rayConfig.nodeIp); workerInfo.put("driver_id", workerId); workerInfo.put("start_time", String.valueOf(System.currentTimeMillis())); diff --git a/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java b/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java index 828d39cb5..4153e732a 100644 --- a/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java +++ b/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java @@ -1,37 +1,24 @@ package org.ray.runtime; import 
com.google.common.base.Preconditions; +import java.nio.ByteBuffer; import org.ray.api.id.JobId; import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; import org.ray.runtime.config.RunMode; -import org.ray.runtime.config.WorkerMode; +import org.ray.runtime.generated.Common.WorkerType; +import org.ray.runtime.raylet.RayletClientImpl; import org.ray.runtime.task.TaskSpec; -import org.ray.runtime.util.IdUtil; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; +/** + * This is a wrapper class for worker context of core worker. + */ public class WorkerContext { - private static final Logger LOGGER = LoggerFactory.getLogger(WorkerContext.class); - - private UniqueId workerId; - - private ThreadLocal currentTaskId; - /** - * Number of objects that have been put from current task. + * The native pointer of worker context of core worker. */ - private ThreadLocal putIndex; - - /** - * Number of tasks that have been submitted from current task. - */ - private ThreadLocal taskIndex; - - private ThreadLocal currentTask; - - private JobId currentJobId; + private final long nativeWorkerContextPointer; private ClassLoader currentClassLoader; @@ -45,31 +32,23 @@ public class WorkerContext { */ private RunMode runMode; - public WorkerContext(WorkerMode workerMode, JobId jobId, RunMode runMode) { + public WorkerContext(WorkerType workerType, JobId jobId, RunMode runMode) { + this.nativeWorkerContextPointer = nativeCreateWorkerContext(workerType.getNumber(), jobId.getBytes()); mainThreadId = Thread.currentThread().getId(); - taskIndex = ThreadLocal.withInitial(() -> 0); - putIndex = ThreadLocal.withInitial(() -> 0); - currentTaskId = ThreadLocal.withInitial(TaskId::randomId); this.runMode = runMode; - currentTask = ThreadLocal.withInitial(() -> null); currentClassLoader = null; - if (workerMode == WorkerMode.DRIVER) { - workerId = IdUtil.computeDriverId(jobId); - currentTaskId.set(TaskId.randomId()); - currentJobId = jobId; - } else { - workerId = UniqueId.randomId(); - 
this.currentTaskId.set(TaskId.NIL); - this.currentJobId = JobId.NIL; - } + } + + public long getNativeWorkerContext() { + return nativeWorkerContextPointer; } /** * @return For the main thread, this method returns the ID of this worker's current running task; - * for other threads, this method returns a random ID. + * for other threads, this method returns a random ID. */ public TaskId getCurrentTaskId() { - return currentTaskId.get(); + return new TaskId(nativeGetCurrentTaskId(nativeWorkerContextPointer)); } /** @@ -79,17 +58,14 @@ public class WorkerContext { public void setCurrentTask(TaskSpec task, ClassLoader classLoader) { if (runMode == RunMode.CLUSTER) { Preconditions.checkState( - Thread.currentThread().getId() == mainThreadId, - "This method should only be called from the main thread." + Thread.currentThread().getId() == mainThreadId, + "This method should only be called from the main thread." ); } Preconditions.checkNotNull(task); - this.currentTaskId.set(task.taskId); - this.currentJobId = task.jobId; - taskIndex.set(0); - putIndex.set(0); - this.currentTask.set(task); + byte[] taskSpec = RayletClientImpl.convertTaskSpecToProtobuf(task); + nativeSetCurrentTask(nativeWorkerContextPointer, taskSpec); currentClassLoader = classLoader; } @@ -97,30 +73,28 @@ public class WorkerContext { * Increment the put index and return the new value. */ public int nextPutIndex() { - putIndex.set(putIndex.get() + 1); - return putIndex.get(); + return nativeGetNextPutIndex(nativeWorkerContextPointer); } /** * Increment the task index and return the new value. */ public int nextTaskIndex() { - taskIndex.set(taskIndex.get() + 1); - return taskIndex.get(); + return nativeGetNextTaskIndex(nativeWorkerContextPointer); } /** * @return The ID of the current worker. */ public UniqueId getCurrentWorkerId() { - return workerId; + return new UniqueId(nativeGetCurrentWorkerId(nativeWorkerContextPointer)); } /** * The ID of the current job. 
*/ public JobId getCurrentJobId() { - return currentJobId; + return JobId.fromByteBuffer(nativeGetCurrentJobId(nativeWorkerContextPointer)); } /** @@ -134,6 +108,32 @@ public class WorkerContext { * Get the current task. */ public TaskSpec getCurrentTask() { - return this.currentTask.get(); + byte[] bytes = nativeGetCurrentTask(nativeWorkerContextPointer); + if (bytes == null) { + return null; + } + return RayletClientImpl.parseTaskSpecFromProtobuf(bytes); } + + public void destroy() { + nativeDestroy(nativeWorkerContextPointer); + } + + private static native long nativeCreateWorkerContext(int workerType, byte[] jobId); + + private static native byte[] nativeGetCurrentTaskId(long nativeWorkerContextPointer); + + private static native void nativeSetCurrentTask(long nativeWorkerContextPointer, byte[] taskSpec); + + private static native byte[] nativeGetCurrentTask(long nativeWorkerContextPointer); + + private static native ByteBuffer nativeGetCurrentJobId(long nativeWorkerContextPointer); + + private static native byte[] nativeGetCurrentWorkerId(long nativeWorkerContextPointer); + + private static native int nativeGetNextTaskIndex(long nativeWorkerContextPointer); + + private static native int nativeGetNextPutIndex(long nativeWorkerContextPointer); + + private static native void nativeDestroy(long nativeWorkerContextPointer); } diff --git a/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java b/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java index e67c88d59..1e90d68f4 100644 --- a/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java +++ b/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java @@ -11,6 +11,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; import org.ray.api.id.JobId; +import org.ray.runtime.generated.Common.WorkerType; import org.ray.runtime.util.NetworkUtil; import org.ray.runtime.util.ResourceUtil; import org.ray.runtime.util.StringUtil; @@ -29,7 +30,7 @@ public class 
RayConfig { public static final String CUSTOM_CONFIG_FILE = "ray.conf"; public final String nodeIp; - public final WorkerMode workerMode; + public final WorkerType workerMode; public final RunMode runMode; public final Map resources; private JobId jobId; @@ -62,7 +63,7 @@ public class RayConfig { public final int numberExecThreadsForDevRuntime; private void validate() { - if (workerMode == WorkerMode.WORKER) { + if (workerMode == WorkerType.WORKER) { Preconditions.checkArgument(redisAddress != null, "Redis address must be set in worker mode."); } @@ -78,14 +79,14 @@ public class RayConfig { public RayConfig(Config config) { // Worker mode. - WorkerMode localWorkerMode; + WorkerType localWorkerMode; try { - localWorkerMode = config.getEnum(WorkerMode.class, "ray.worker.mode"); + localWorkerMode = config.getEnum(WorkerType.class, "ray.worker.mode"); } catch (ConfigException.Missing e) { - localWorkerMode = WorkerMode.DRIVER; + localWorkerMode = WorkerType.DRIVER; } workerMode = localWorkerMode; - boolean isDriver = workerMode == WorkerMode.DRIVER; + boolean isDriver = workerMode == WorkerType.DRIVER; // Run mode. runMode = config.getEnum(RunMode.class, "ray.run-mode"); // Node ip. 
diff --git a/java/runtime/src/main/java/org/ray/runtime/config/WorkerMode.java b/java/runtime/src/main/java/org/ray/runtime/config/WorkerMode.java deleted file mode 100644 index 947159c3b..000000000 --- a/java/runtime/src/main/java/org/ray/runtime/config/WorkerMode.java +++ /dev/null @@ -1,6 +0,0 @@ -package org.ray.runtime.config; - -public enum WorkerMode { - DRIVER, - WORKER -} diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectInterface.java b/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectInterface.java new file mode 100644 index 000000000..8ec855bca --- /dev/null +++ b/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectInterface.java @@ -0,0 +1,98 @@ +package org.ray.runtime.objectstore; + +import com.google.common.base.Preconditions; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Consumer; +import java.util.stream.Collectors; +import org.ray.api.id.ObjectId; +import org.ray.runtime.WorkerContext; +import org.ray.runtime.util.IdUtil; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class MockObjectInterface implements ObjectInterface { + + private static final Logger LOGGER = LoggerFactory.getLogger(MockObjectInterface.class); + + private static final int GET_CHECK_INTERVAL_MS = 100; + + private final Map pool = new ConcurrentHashMap<>(); + private final List> objectPutCallbacks = new ArrayList<>(); + private final WorkerContext workerContext; + + public MockObjectInterface(WorkerContext workerContext) { + this.workerContext = workerContext; + } + + public void addObjectPutCallback(Consumer callback) { + this.objectPutCallbacks.add(callback); + } + + public boolean isObjectReady(ObjectId id) { + return pool.containsKey(id); + } + + @Override + public ObjectId put(NativeRayObject obj) { + ObjectId objectId = IdUtil.computePutId(workerContext.getCurrentTaskId(), + 
workerContext.nextPutIndex()); + put(obj, objectId); + return objectId; + } + + @Override + public void put(NativeRayObject obj, ObjectId objectId) { + Preconditions.checkNotNull(obj); + Preconditions.checkNotNull(objectId); + pool.putIfAbsent(objectId, obj); + for (Consumer callback : objectPutCallbacks) { + callback.accept(objectId); + } + } + + @Override + public List get(List objectIds, long timeoutMs) { + waitInternal(objectIds, objectIds.size(), timeoutMs); + return objectIds.stream().map(pool::get).collect(Collectors.toList()); + } + + @Override + public List wait(List objectIds, int numObjects, long timeoutMs) { + waitInternal(objectIds, numObjects, timeoutMs); + return objectIds.stream().map(pool::containsKey).collect(Collectors.toList()); + } + + private void waitInternal(List objectIds, int numObjects, long timeoutMs) { + int ready = 0; + long remainingTime = timeoutMs; + boolean firstCheck = true; + while (ready < numObjects && (timeoutMs < 0 || remainingTime > 0)) { + if (!firstCheck) { + long sleepTime = Math.min(remainingTime, GET_CHECK_INTERVAL_MS); + try { + Thread.sleep(sleepTime); + } catch (InterruptedException e) { + LOGGER.warn("Got InterruptedException while sleeping."); + } + remainingTime -= sleepTime; + } + ready = 0; + for (ObjectId objectId : objectIds) { + if (pool.containsKey(objectId)) { + ready += 1; + } + } + firstCheck = false; + } + } + + @Override + public void delete(List objectIds, boolean localOnly, boolean deleteCreatingTasks) { + for (ObjectId objectId : objectIds) { + pool.remove(objectId); + } + } +} diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectStore.java b/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectStore.java deleted file mode 100644 index f3d64c834..000000000 --- a/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectStore.java +++ /dev/null @@ -1,148 +0,0 @@ -package org.ray.runtime.objectstore; - -import java.util.ArrayList; -import java.util.Arrays; 
-import java.util.List; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import java.util.function.Consumer; -import java.util.stream.Collectors; - -import org.apache.arrow.plasma.ObjectStoreLink; -import org.ray.api.id.ObjectId; -import org.ray.runtime.RayDevRuntime; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * A mock implementation of {@code org.ray.spi.ObjectStoreLink}, which use Map to store data. - */ -public class MockObjectStore implements ObjectStoreLink { - - private static final Logger LOGGER = LoggerFactory.getLogger(MockObjectStore.class); - - private static final int GET_CHECK_INTERVAL_MS = 100; - - private final RayDevRuntime runtime; - private final Map data = new ConcurrentHashMap<>(); - private final Map metadata = new ConcurrentHashMap<>(); - private final List> objectPutCallbacks; - - public MockObjectStore(RayDevRuntime runtime) { - this.runtime = runtime; - this.objectPutCallbacks = new ArrayList<>(); - } - - public void addObjectPutCallback(Consumer callback) { - this.objectPutCallbacks.add(callback); - } - - @Override - public void put(byte[] objectId, byte[] value, byte[] metadataValue) { - if (objectId == null || objectId.length == 0 || value == null) { - LOGGER - .error("{} cannot put null: {}, {}", logPrefix(), objectId, Arrays.toString(value)); - System.exit(-1); - } - ObjectId id = new ObjectId(objectId); - data.put(id, value); - if (metadataValue != null) { - metadata.put(id, metadataValue); - } - for (Consumer callback : objectPutCallbacks) { - callback.accept(id); - } - } - - @Override - public byte[] get(byte[] objectId, int timeoutMs, boolean isMetadata) { - return get(new byte[][] {objectId}, timeoutMs, isMetadata).get(0); - } - - @Override - public List get(byte[][] objectIds, int timeoutMs, boolean isMetadata) { - return get(objectIds, timeoutMs) - .stream() - .map(data -> isMetadata ? 
data.metadata : data.data) - .collect(Collectors.toList()); - } - - @Override - public List get(byte[][] objectIds, int timeoutMs) { - int ready = 0; - int remainingTime = timeoutMs; - boolean firstCheck = true; - while (ready < objectIds.length && remainingTime > 0) { - if (!firstCheck) { - int sleepTime = Math.min(remainingTime, GET_CHECK_INTERVAL_MS); - try { - Thread.sleep(sleepTime); - } catch (InterruptedException e) { - LOGGER.warn("Got InterruptedException while sleeping."); - } - remainingTime -= sleepTime; - } - ready = 0; - for (byte[] id : objectIds) { - if (data.containsKey(new ObjectId(id))) { - ready += 1; - } - } - firstCheck = false; - } - ArrayList rets = new ArrayList<>(); - for (byte[] objId : objectIds) { - ObjectId objectId = new ObjectId(objId); - rets.add(new ObjectStoreData(metadata.get(objectId), data.get(objectId))); - } - return rets; - } - - @Override - public byte[] hash(byte[] objectId) { - return null; - } - - @Override - public long evict(long numBytes) { - return 0; - } - - @Override - public void release(byte[] objectId) { - return; - } - - @Override - public void delete(byte[] objectId) { - return; - } - - @Override - public boolean contains(byte[] objectId) { - return data.containsKey(new ObjectId(objectId)); - } - - private String logPrefix() { - return runtime.getWorkerContext().getCurrentTaskId() + "-" + getUserTrace() + " -> "; - } - - private String getUserTrace() { - StackTraceElement[] stes = Thread.currentThread().getStackTrace(); - int k = 1; - while (stes[k].getClassName().startsWith("org.ray") - && !stes[k].getClassName().contains("test")) { - k++; - } - return stes[k].getFileName() + ":" + stes[k].getLineNumber(); - } - - public boolean isObjectReady(ObjectId id) { - return data.containsKey(id); - } - - public void free(ObjectId id) { - data.remove(id); - metadata.remove(id); - } -} diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/NativeRayObject.java 
b/java/runtime/src/main/java/org/ray/runtime/objectstore/NativeRayObject.java new file mode 100644 index 000000000..7146765c2 --- /dev/null +++ b/java/runtime/src/main/java/org/ray/runtime/objectstore/NativeRayObject.java @@ -0,0 +1,13 @@ +package org.ray.runtime.objectstore; + +public class NativeRayObject { + + public byte[] data; + public byte[] metadata; + + public NativeRayObject(byte[] data, byte[] metadata) { + this.data = data; + this.metadata = metadata; + } +} + diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectInterface.java b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectInterface.java new file mode 100644 index 000000000..5780dbd6c --- /dev/null +++ b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectInterface.java @@ -0,0 +1,54 @@ +package org.ray.runtime.objectstore; + +import java.util.List; +import org.ray.api.id.ObjectId; + +/** + * The interface that contains all worker methods that are related to object store. + */ +public interface ObjectInterface { + + /** + * Put an object into object store. + * + * @param obj The ray object. + * @return Generated ID of the object. + */ + ObjectId put(NativeRayObject obj); + + /** + * Put an object with specified ID into object store. + * + * @param obj The ray object. + * @param objectId Object ID specified by user. + */ + void put(NativeRayObject obj, ObjectId objectId); + + /** + * Get a list of objects from the object store. + * + * @param objectIds IDs of the objects to get. + * @param timeoutMs Timeout in milliseconds, wait infinitely if it's negative. + * @return Result list of objects data. + */ + List get(List objectIds, long timeoutMs); + + /** + * Wait for a list of objects to appear in the object store. + * + * @param objectIds IDs of the objects to wait for. + * @param numObjects Number of objects that should appear. + * @param timeoutMs Timeout in milliseconds, wait infinitely if it's negative. 
+ * @return A bitset that indicates each object has appeared or not. + */ + List wait(List objectIds, int numObjects, long timeoutMs); + + /** + * Delete a list of objects from the object store. + * + * @param objectIds IDs of the objects to delete. + * @param localOnly Whether only delete the objects in local node, or all nodes in the cluster. + * @param deleteCreatingTasks Whether also delete the tasks that created these objects. + */ + void delete(List objectIds, boolean localOnly, boolean deleteCreatingTasks); +} diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectInterfaceImpl.java b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectInterfaceImpl.java new file mode 100644 index 000000000..5e1774808 --- /dev/null +++ b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectInterfaceImpl.java @@ -0,0 +1,91 @@ +package org.ray.runtime.objectstore; + +import java.util.List; +import java.util.stream.Collectors; +import org.ray.api.exception.RayException; +import org.ray.api.id.BaseId; +import org.ray.api.id.ObjectId; +import org.ray.runtime.AbstractRayRuntime; +import org.ray.runtime.WorkerContext; +import org.ray.runtime.raylet.RayletClient; +import org.ray.runtime.raylet.RayletClientImpl; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * This is a wrapper class for core worker object interface. + */ +public class ObjectInterfaceImpl implements ObjectInterface { + + private static final Logger LOGGER = LoggerFactory.getLogger(AbstractRayRuntime.class); + + /** + * The native pointer of core worker object interface. 
+ */ + private final long nativeObjectInterfacePointer; + + public ObjectInterfaceImpl(WorkerContext workerContext, RayletClient rayletClient, + String storeSocketName) { + this.nativeObjectInterfacePointer = + nativeCreateObjectInterface(workerContext.getNativeWorkerContext(), + ((RayletClientImpl) rayletClient).getClient(), storeSocketName); + } + + @Override + public ObjectId put(NativeRayObject obj) { + return new ObjectId(nativePut(nativeObjectInterfacePointer, obj)); + } + + @Override + public void put(NativeRayObject obj, ObjectId objectId) { + try { + nativePut(nativeObjectInterfacePointer, objectId.getBytes(), obj); + } catch (RayException e) { + LOGGER.warn(e.getMessage()); + } + } + + @Override + public List get(List objectIds, long timeoutMs) { + return nativeGet(nativeObjectInterfacePointer, toBinaryList(objectIds), timeoutMs); + } + + @Override + public List wait(List objectIds, int numObjects, long timeoutMs) { + return nativeWait(nativeObjectInterfacePointer, toBinaryList(objectIds), numObjects, timeoutMs); + } + + @Override + public void delete(List objectIds, boolean localOnly, boolean deleteCreatingTasks) { + nativeDelete(nativeObjectInterfacePointer, toBinaryList(objectIds), localOnly, deleteCreatingTasks); + } + + public void destroy() { + nativeDestroy(nativeObjectInterfacePointer); + } + + private static List toBinaryList(List ids) { + return ids.stream().map(BaseId::getBytes).collect(Collectors.toList()); + } + + private static native long nativeCreateObjectInterface(long nativeObjectInterface, + long nativeRayletClient, + String storeSocketName); + + private static native byte[] nativePut(long nativeObjectInterface, NativeRayObject obj); + + private static native void nativePut(long nativeObjectInterface, byte[] objectId, + NativeRayObject obj); + + private static native List nativeGet(long nativeObjectInterface, + List ids, + long timeoutMs); + + private static native List nativeWait(long nativeObjectInterface, List objectIds, + int 
numObjects, long timeoutMs); + + private static native void nativeDelete(long nativeObjectInterface, List objectIds, + boolean localOnly, boolean deleteCreatingTasks); + + private static native void nativeDestroy(long nativeObjectInterface); +} diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java index 1a7e4701c..5470d719b 100644 --- a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java +++ b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java @@ -4,20 +4,14 @@ import com.google.common.collect.ImmutableList; import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import org.apache.arrow.plasma.ObjectStoreLink; -import org.apache.arrow.plasma.ObjectStoreLink.ObjectStoreData; -import org.apache.arrow.plasma.PlasmaClient; -import org.apache.arrow.plasma.exceptions.DuplicateObjectException; import org.ray.api.exception.RayActorException; import org.ray.api.exception.RayException; +import org.ray.api.exception.RayTaskException; import org.ray.api.exception.RayWorkerException; import org.ray.api.exception.UnreconstructableException; import org.ray.api.id.ObjectId; -import org.ray.runtime.AbstractRayRuntime; -import org.ray.runtime.RayDevRuntime; -import org.ray.runtime.config.RunMode; +import org.ray.runtime.WorkerContext; import org.ray.runtime.generated.Gcs.ErrorType; -import org.ray.runtime.util.IdUtil; import org.ray.runtime.util.Serializer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -36,21 +30,18 @@ public class ObjectStoreProxy { private static final byte[] UNRECONSTRUCTABLE_EXCEPTION_META = String .valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE.getNumber()).getBytes(); + private static final byte[] TASK_EXECUTION_EXCEPTION_META = String + .valueOf(ErrorType.TASK_EXECUTION_EXCEPTION.getNumber()).getBytes(); + private static final byte[] RAW_TYPE_META = 
"RAW".getBytes(); - private final AbstractRayRuntime runtime; + private final WorkerContext workerContext; - private static ThreadLocal objectStore; + private final ObjectInterface objectInterface; - public ObjectStoreProxy(AbstractRayRuntime runtime, String storeSocketName) { - this.runtime = runtime; - objectStore = ThreadLocal.withInitial(() -> { - if (runtime.getRayConfig().runMode == RunMode.CLUSTER) { - return new PlasmaClient(storeSocketName, "", 0); - } else { - return ((RayDevRuntime) runtime).getObjectStore(); - } - }); + public ObjectStoreProxy(WorkerContext workerContext, ObjectInterface objectInterface) { + this.workerContext = workerContext; + this.objectInterface = objectInterface; } /** @@ -75,46 +66,44 @@ public class ObjectStoreProxy { * @return A list of GetResult objects. */ public List> get(List ids, int timeoutMs) { - byte[][] binaryIds = IdUtil.getIdBytes(ids); - List dataAndMetaList = objectStore.get().get(binaryIds, timeoutMs); + List dataAndMetaList = objectInterface.get(ids, timeoutMs); List> results = new ArrayList<>(); for (int i = 0; i < dataAndMetaList.size(); i++) { - byte[] meta = dataAndMetaList.get(i).metadata; - byte[] data = dataAndMetaList.get(i).data; - + NativeRayObject dataAndMeta = dataAndMetaList.get(i); GetResult result; - if (meta != null) { - // If meta is not null, deserialize the object from meta. - result = deserializeFromMeta(meta, data, ids.get(i)); - } else if (data != null) { - // If data is not null, deserialize the Java object. - Object object = Serializer.decode(data, runtime.getWorkerContext().getCurrentClassLoader()); - if (object instanceof RayException) { - // If the object is a `RayException`, it means that an error occurred during task - // execution. 
- result = new GetResult<>(true, null, (RayException) object); + if (dataAndMeta != null) { + byte[] meta = dataAndMeta.metadata; + byte[] data = dataAndMeta.data; + if (meta != null && meta.length > 0) { + // If meta is not null, deserialize the object from meta. + result = deserializeFromMeta(meta, data, + workerContext.getCurrentClassLoader(), ids.get(i)); } else { - // Otherwise, the object is valid. - result = new GetResult<>(true, (T) object, null); + // If data is not null, deserialize the Java object. + Object object = Serializer.decode(data, workerContext.getCurrentClassLoader()); + if (object instanceof RayException) { + // If the object is a `RayException`, it means that an error occurred during task + // execution. + result = new GetResult<>(true, null, (RayException) object); + } else { + // Otherwise, the object is valid. + result = new GetResult<>(true, (T) object, null); + } } } else { // If both meta and data are null, the object doesn't exist in object store. result = new GetResult<>(false, null, null); } - if (meta != null || data != null) { - // Release the object from object store.. 
- objectStore.get().release(binaryIds[i]); - } - results.add(result); } return results; } @SuppressWarnings("unchecked") - private GetResult deserializeFromMeta(byte[] meta, byte[] data, ObjectId objectId) { + private GetResult deserializeFromMeta(byte[] meta, byte[] data, + ClassLoader classLoader, ObjectId objectId) { if (Arrays.equals(meta, RAW_TYPE_META)) { return (GetResult) new GetResult<>(true, data, null); } else if (Arrays.equals(meta, WORKER_EXCEPTION_META)) { @@ -123,6 +112,8 @@ public class ObjectStoreProxy { return new GetResult<>(true, null, RayActorException.INSTANCE); } else if (Arrays.equals(meta, UNRECONSTRUCTABLE_EXCEPTION_META)) { return new GetResult<>(true, null, new UnreconstructableException(objectId)); + } else if (Arrays.equals(meta, TASK_EXECUTION_EXCEPTION_META)) { + return new GetResult<>(true, null, Serializer.decode(data, classLoader)); } throw new IllegalArgumentException("Unrecognized metadata " + Arrays.toString(meta)); } @@ -134,16 +125,14 @@ public class ObjectStoreProxy { * @param object The object to put. */ public void put(ObjectId id, Object object) { - try { - if (object instanceof byte[]) { - // If the object is a byte array, skip serializing it and use a special metadata to - // indicate it's raw binary. So that this object can also be read by Python. - objectStore.get().put(id.getBytes(), (byte[]) object, RAW_TYPE_META); - } else { - objectStore.get().put(id.getBytes(), Serializer.encode(object), null); - } - } catch (DuplicateObjectException e) { - LOGGER.warn(e.getMessage()); + if (object instanceof byte[]) { + // If the object is a byte array, skip serializing it and use a special metadata to + // indicate it's raw binary. So that this object can also be read by Python. 
+ objectInterface.put(new NativeRayObject((byte[]) object, RAW_TYPE_META), id); + } else if (object instanceof RayTaskException) { + objectInterface.put(new NativeRayObject(Serializer.encode(object), TASK_EXECUTION_EXCEPTION_META), id); + } else { + objectInterface.put(new NativeRayObject(Serializer.encode(object), null), id); } } @@ -154,11 +143,7 @@ public class ObjectStoreProxy { * @param serializedObject The serialized object to put. */ public void putSerialized(ObjectId id, byte[] serializedObject) { - try { - objectStore.get().put(id.getBytes(), serializedObject, null); - } catch (DuplicateObjectException e) { - LOGGER.warn(e.getMessage()); - } + objectInterface.put(new NativeRayObject(serializedObject, null), id); } /** diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java b/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java index 0dc8f4c9e..38995bf9b 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java @@ -14,6 +14,7 @@ import java.util.concurrent.ConcurrentLinkedDeque; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.stream.Collectors; import org.apache.commons.lang3.NotImplementedException; import org.ray.api.RayObject; import org.ray.api.WaitResult; @@ -23,7 +24,8 @@ import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; import org.ray.runtime.RayDevRuntime; import org.ray.runtime.Worker; -import org.ray.runtime.objectstore.MockObjectStore; +import org.ray.runtime.objectstore.MockObjectInterface; +import org.ray.runtime.objectstore.NativeRayObject; import org.ray.runtime.task.FunctionArg; import org.ray.runtime.task.TaskSpec; import org.slf4j.Logger; @@ -37,7 +39,7 @@ public class MockRayletClient implements RayletClient { private static final Logger LOGGER = LoggerFactory.getLogger(MockRayletClient.class); private final Map> 
waitingTasks = new ConcurrentHashMap<>(); - private final MockObjectStore store; + private final MockObjectInterface objectInterface; private final RayDevRuntime runtime; private final ExecutorService exec; private final Deque idleWorkers; @@ -46,8 +48,8 @@ public class MockRayletClient implements RayletClient { public MockRayletClient(RayDevRuntime runtime, int numberThreads) { this.runtime = runtime; - this.store = runtime.getObjectStore(); - store.addObjectPutCallback(this::onObjectPut); + this.objectInterface = runtime.getObjectInterface(); + objectInterface.addObjectPutCallback(this::onObjectPut); // The thread pool that executes tasks in parallel. exec = Executors.newFixedThreadPool(numberThreads); idleWorkers = new ConcurrentLinkedDeque<>(); @@ -113,8 +115,8 @@ public class MockRayletClient implements RayletClient { // can be executed. if (task.isActorCreationTask() || task.isActorTask()) { ObjectId[] returnIds = task.returnIds; - store.put(returnIds[returnIds.length - 1].getBytes(), - new byte[]{}, new byte[]{}); + objectInterface.put(new NativeRayObject(new byte[] {}, new byte[] {}), + returnIds[returnIds.length - 1]); } } finally { returnWorker(worker); @@ -133,13 +135,13 @@ public class MockRayletClient implements RayletClient { // Check whether task arguments are ready. 
for (FunctionArg arg : spec.args) { if (arg.id != null) { - if (!store.isObjectReady(arg.id)) { + if (!objectInterface.isObjectReady(arg.id)) { unreadyObjects.add(arg.id); } } } if (spec.isActorTask()) { - if (!store.isObjectReady(spec.previousActorTaskDummyObjectId)) { + if (!objectInterface.isObjectReady(spec.previousActorTaskDummyObjectId)) { unreadyObjects.add(spec.previousActorTaskDummyObjectId); } } @@ -154,7 +156,7 @@ public class MockRayletClient implements RayletClient { @Override public void fetchOrReconstruct(List objectIds, boolean fetchOnly, - TaskId currentTaskId) { + TaskId currentTaskId) { } @@ -170,20 +172,17 @@ public class MockRayletClient implements RayletClient { @Override public WaitResult wait(List> waitFor, int numReturns, int - timeoutMs, TaskId currentTaskId) { + timeoutMs, TaskId currentTaskId) { if (waitFor == null || waitFor.isEmpty()) { return new WaitResult<>(ImmutableList.of(), ImmutableList.of()); } - byte[][] ids = new byte[waitFor.size()][]; - for (int i = 0; i < waitFor.size(); i++) { - ids[i] = waitFor.get(i).getId().getBytes(); - } + List ids = waitFor.stream().map(RayObject::getId).collect(Collectors.toList()); List> readyList = new ArrayList<>(); List> unreadyList = new ArrayList<>(); - List result = store.get(ids, timeoutMs, false); + List result = objectInterface.wait(ids, ids.size(), timeoutMs); for (int i = 0; i < waitFor.size(); i++) { - if (result.get(i) != null) { + if (result.get(i)) { readyList.add(waitFor.get(i)); } else { unreadyList.add(waitFor.get(i)); @@ -195,9 +194,7 @@ public class MockRayletClient implements RayletClient { @Override public void freePlasmaObjects(List objectIds, boolean localOnly, boolean deleteCreatingTasks) { - for (ObjectId id : objectIds) { - store.free(id); - } + objectInterface.delete(objectIds, localOnly, deleteCreatingTasks); } diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java 
index 059edbe67..a1e11141e 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java @@ -4,8 +4,6 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.protobuf.ByteString; import com.google.protobuf.InvalidProtocolBufferException; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; import java.nio.charset.Charset; import java.util.ArrayList; import java.util.Arrays; @@ -40,11 +38,15 @@ public class RayletClientImpl implements RayletClient { // TODO(qwang): JobId parameter can be removed once we embed jobId in driverId. public RayletClientImpl(String schedulerSockName, UniqueId clientId, - boolean isWorker, JobId jobId) { + boolean isWorker, JobId jobId) { client = nativeInit(schedulerSockName, clientId.getBytes(), isWorker, jobId.getBytes()); } + public long getClient() { + return client; + } + @Override public WaitResult wait(List> waitFor, int numReturns, int timeoutMs, TaskId currentTaskId) { @@ -133,7 +135,7 @@ public class RayletClientImpl implements RayletClient { /** * Parse `TaskSpec` protobuf bytes. */ - private static TaskSpec parseTaskSpecFromProtobuf(byte[] bytes) { + public static TaskSpec parseTaskSpecFromProtobuf(byte[] bytes) { Common.TaskSpec taskSpec; try { taskSpec = Common.TaskSpec.parseFrom(bytes); @@ -214,7 +216,7 @@ public class RayletClientImpl implements RayletClient { /** * Convert a `TaskSpec` to protobuf-serialized bytes. */ - private static byte[] convertTaskSpecToProtobuf(TaskSpec task) { + public static byte[] convertTaskSpecToProtobuf(TaskSpec task) { // Set common fields. 
Common.TaskSpec.Builder builder = Common.TaskSpec.newBuilder() .setJobId(ByteString.copyFrom(task.jobId.getBytes())) diff --git a/java/runtime/src/main/java/org/ray/runtime/util/IdUtil.java b/java/runtime/src/main/java/org/ray/runtime/util/IdUtil.java index 8a96bc57a..6f9c95ea4 100644 --- a/java/runtime/src/main/java/org/ray/runtime/util/IdUtil.java +++ b/java/runtime/src/main/java/org/ray/runtime/util/IdUtil.java @@ -154,18 +154,6 @@ public class IdUtil { } - /** - * Compute the driver id from the given job. - */ - public static UniqueId computeDriverId(JobId jobId) { - byte[] bytes = new byte[UniqueId.LENGTH]; - System.arraycopy(jobId.getBytes(), 0, bytes, 0, jobId.size()); - Arrays.fill(bytes, jobId.size(), UniqueId.LENGTH, (byte)0xFF); - ByteBuffer wbb = ByteBuffer.wrap(bytes); - wbb.order(ByteOrder.LITTLE_ENDIAN); - return new UniqueId(bytes); - } - /** * Compute the murmur hash code of this ID. */ diff --git a/java/test/src/main/java/org/ray/api/test/FailureTest.java b/java/test/src/main/java/org/ray/api/test/FailureTest.java index 6d47a2fc9..b47b010ae 100644 --- a/java/test/src/main/java/org/ray/api/test/FailureTest.java +++ b/java/test/src/main/java/org/ray/api/test/FailureTest.java @@ -1,12 +1,18 @@ package org.ray.api.test; +import java.time.Duration; +import java.time.Instant; +import java.util.Arrays; +import java.util.List; import org.ray.api.Ray; import org.ray.api.RayActor; import org.ray.api.RayObject; import org.ray.api.TestUtils; import org.ray.api.exception.RayActorException; +import org.ray.api.exception.RayException; import org.ray.api.exception.RayTaskException; import org.ray.api.exception.RayWorkerException; +import org.ray.api.function.RayFunc0; import org.testng.Assert; import org.testng.annotations.Test; @@ -23,6 +29,15 @@ public class FailureTest extends BaseTest { return 0; } + public static int slowFunc() { + try { + Thread.sleep(10000); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + return 0; + } + public 
static class BadActor { public BadActor(boolean failOnCreation) { @@ -106,5 +121,26 @@ public class FailureTest extends BaseTest { // RayActorException. } } + + @Test + public void testGetThrowsQuicklyWhenFoundException() { + TestUtils.skipTestUnderSingleProcess(); + List> badFunctions = Arrays.asList(FailureTest::badFunc, + FailureTest::badFunc2); + for (RayFunc0 badFunc : badFunctions) { + RayObject obj1 = Ray.call(badFunc); + RayObject obj2 = Ray.call(FailureTest::slowFunc); + Instant start = Instant.now(); + try { + Ray.get(Arrays.asList(obj1.getId(), obj2.getId())); + Assert.fail("Should throw RayException."); + } catch (RayException e) { + Instant end = Instant.now(); + long duration = Duration.between(start, end).toMillis(); + Assert.assertTrue(duration < 5000, "Should fail quickly. " + + "Actual execution time: " + duration + " ms."); + } + } + } } diff --git a/java/test/src/main/java/org/ray/api/test/PlasmaStoreTest.java b/java/test/src/main/java/org/ray/api/test/PlasmaStoreTest.java index 7abc3f421..84adba6d7 100644 --- a/java/test/src/main/java/org/ray/api/test/PlasmaStoreTest.java +++ b/java/test/src/main/java/org/ray/api/test/PlasmaStoreTest.java @@ -1,12 +1,10 @@ package org.ray.api.test; -import org.apache.arrow.plasma.PlasmaClient; -import org.apache.arrow.plasma.exceptions.DuplicateObjectException; - import org.ray.api.Ray; import org.ray.api.TestUtils; -import org.ray.api.id.UniqueId; +import org.ray.api.id.ObjectId; import org.ray.runtime.AbstractRayRuntime; +import org.ray.runtime.objectstore.ObjectStoreProxy; import org.testng.Assert; import org.testng.annotations.Test; @@ -15,15 +13,13 @@ public class PlasmaStoreTest extends BaseTest { @Test public void testPutWithDuplicateId() { TestUtils.skipTestUnderSingleProcess(); - UniqueId objectId = UniqueId.randomId(); + ObjectId objectId = ObjectId.randomId(); AbstractRayRuntime runtime = (AbstractRayRuntime) Ray.internal(); - PlasmaClient store = new 
PlasmaClient(runtime.getRayConfig().objectStoreSocketName, "", 0); - store.put(objectId.getBytes(), new byte[]{}, new byte[]{}); - try { - store.put(objectId.getBytes(), new byte[]{}, new byte[]{}); - Assert.fail("This line shouldn't be reached."); - } catch (DuplicateObjectException e) { - // Putting 2 objects with duplicate ID should throw DuplicateObjectException. - } + ObjectStoreProxy objectInterface = runtime.getObjectStoreProxy(); + objectInterface.put(objectId, 1); + Assert.assertEquals(objectInterface.get(objectId, -1).object, (Integer) 1); + objectInterface.put(objectId, 2); + // Putting 2 objects with duplicate ID should fail but ignored. + Assert.assertEquals(objectInterface.get(objectId, -1).object, (Integer) 1); } } diff --git a/java/test/src/main/java/org/ray/api/test/RayConfigTest.java b/java/test/src/main/java/org/ray/api/test/RayConfigTest.java index 5b6834e5e..ebc342722 100644 --- a/java/test/src/main/java/org/ray/api/test/RayConfigTest.java +++ b/java/test/src/main/java/org/ray/api/test/RayConfigTest.java @@ -1,7 +1,7 @@ package org.ray.api.test; import org.ray.runtime.config.RayConfig; -import org.ray.runtime.config.WorkerMode; +import org.ray.runtime.generated.Common.WorkerType; import org.testng.Assert; import org.testng.annotations.Test; @@ -12,7 +12,7 @@ public class RayConfigTest { try { System.setProperty("ray.job.resource-path", "path/to/ray/job/resource/path"); RayConfig rayConfig = RayConfig.create(); - Assert.assertEquals(WorkerMode.DRIVER, rayConfig.workerMode); + Assert.assertEquals(WorkerType.DRIVER, rayConfig.workerMode); Assert.assertEquals("path/to/ray/job/resource/path", rayConfig.jobResourcePath); } finally { // Unset system properties. 
diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h index 3888e3ba4..462699002 100644 --- a/src/ray/common/ray_config_def.h +++ b/src/ray/common/ray_config_def.h @@ -151,3 +151,10 @@ RAY_CONFIG(uint32_t, num_actor_checkpoints_to_keep, 20) /// Maximum number of ids in one batch to send to GCS to delete keys. RAY_CONFIG(uint32_t, maximum_gcs_deletion_batch_size, 1000) + +/// When getting objects from object store, print a warning every this number of attempts. +RAY_CONFIG(uint32_t, object_store_get_warn_per_num_attempts, 50) + +/// When getting objects from object store, max number of ids to print in the warning +/// message. +RAY_CONFIG(uint32_t, object_store_get_max_ids_to_print_in_warning, 20) diff --git a/src/ray/core_worker/common.h b/src/ray/core_worker/common.h index aabb3fa83..d265aa536 100644 --- a/src/ray/core_worker/common.h +++ b/src/ray/core_worker/common.h @@ -9,9 +9,7 @@ #include "ray/raylet/raylet_client.h" namespace ray { - -/// Type of this worker. -enum class WorkerType { WORKER, DRIVER }; +using WorkerType = rpc::WorkerType; /// Information about a remote function. struct RayFunction { diff --git a/src/ray/core_worker/context.cc b/src/ray/core_worker/context.cc index c5d7e7857..b655e4588 100644 --- a/src/ray/core_worker/context.cc +++ b/src/ray/core_worker/context.cc @@ -6,69 +6,81 @@ namespace ray { /// per-thread context for core worker. 
struct WorkerThreadContext { WorkerThreadContext() - : current_task_id(TaskID::FromRandom()), task_index(0), put_index(0) {} + : current_task_id_(TaskID::FromRandom()), task_index_(0), put_index_(0) {} - int GetNextTaskIndex() { return ++task_index; } + int GetNextTaskIndex() { return ++task_index_; } - int GetNextPutIndex() { return ++put_index; } + int GetNextPutIndex() { return ++put_index_; } - const TaskID &GetCurrentTaskID() const { return current_task_id; } + const TaskID &GetCurrentTaskID() const { return current_task_id_; } - void SetCurrentTask(const TaskID &task_id) { - current_task_id = task_id; - task_index = 0; - put_index = 0; + std::shared_ptr GetCurrentTask() const { + return current_task_; + } + + void SetCurrentTaskId(const TaskID &task_id) { + current_task_id_ = task_id; + task_index_ = 0; + put_index_ = 0; } void SetCurrentTask(const TaskSpecification &task_spec) { - SetCurrentTask(task_spec.TaskId()); + SetCurrentTaskId(task_spec.TaskId()); + current_task_ = std::make_shared(task_spec); } private: /// The task ID for current task. - TaskID current_task_id; + TaskID current_task_id_; + + /// The current task. + std::shared_ptr current_task_; /// Number of tasks that have been submitted from current task. - int task_index; + int task_index_; /// Number of objects that have been put from current task. - int put_index; + int put_index_; }; thread_local std::unique_ptr WorkerContext::thread_context_ = nullptr; WorkerContext::WorkerContext(WorkerType worker_type, const JobID &job_id) - : worker_type(worker_type), - worker_id(worker_type == WorkerType::DRIVER ? ComputeDriverIdFromJob(job_id) - : WorkerID::FromRandom()), - current_job_id(worker_type == WorkerType::DRIVER ? job_id : JobID::Nil()) { + : worker_type_(worker_type), + worker_id_(worker_type_ == WorkerType::DRIVER ? ComputeDriverIdFromJob(job_id) + : WorkerID::FromRandom()), + current_job_id_(worker_type_ == WorkerType::DRIVER ? 
job_id : JobID::Nil()) { // For worker main thread which initializes the WorkerContext, // set task_id according to whether current worker is a driver. // (For other threads it's set to random ID via GetThreadContext). - GetThreadContext().SetCurrentTask( - (worker_type == WorkerType::DRIVER) ? TaskID::FromRandom() : TaskID::Nil()); + GetThreadContext().SetCurrentTaskId( + (worker_type_ == WorkerType::DRIVER) ? TaskID::FromRandom() : TaskID::Nil()); } -const WorkerType WorkerContext::GetWorkerType() const { return worker_type; } +const WorkerType WorkerContext::GetWorkerType() const { return worker_type_; } -const WorkerID &WorkerContext::GetWorkerID() const { return worker_id; } +const WorkerID &WorkerContext::GetWorkerID() const { return worker_id_; } int WorkerContext::GetNextTaskIndex() { return GetThreadContext().GetNextTaskIndex(); } int WorkerContext::GetNextPutIndex() { return GetThreadContext().GetNextPutIndex(); } -const JobID &WorkerContext::GetCurrentJobID() const { return current_job_id; } +const JobID &WorkerContext::GetCurrentJobID() const { return current_job_id_; } const TaskID &WorkerContext::GetCurrentTaskID() const { return GetThreadContext().GetCurrentTaskID(); } void WorkerContext::SetCurrentTask(const TaskSpecification &task_spec) { - current_job_id = task_spec.JobId(); + current_job_id_ = task_spec.JobId(); GetThreadContext().SetCurrentTask(task_spec); } +std::shared_ptr WorkerContext::GetCurrentTask() const { + return GetThreadContext().GetCurrentTask(); +} + WorkerThreadContext &WorkerContext::GetThreadContext() { if (thread_context_ == nullptr) { thread_context_ = std::unique_ptr(new WorkerThreadContext()); diff --git a/src/ray/core_worker/context.h b/src/ray/core_worker/context.h index 629249103..8405501d3 100644 --- a/src/ray/core_worker/context.h +++ b/src/ray/core_worker/context.h @@ -22,19 +22,21 @@ class WorkerContext { void SetCurrentTask(const TaskSpecification &task_spec); + std::shared_ptr GetCurrentTask() const; + int 
GetNextTaskIndex(); int GetNextPutIndex(); private: /// Type of the worker. - const WorkerType worker_type; + const WorkerType worker_type_; /// ID for this worker. - const WorkerID worker_id; + const WorkerID worker_id_; /// Job ID for this worker. - JobID current_job_id; + JobID current_job_id_; private: static WorkerThreadContext &GetThreadContext(); diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 6fa560f27..e49ca9972 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -15,7 +15,7 @@ CoreWorker::CoreWorker( task_interface_(worker_context_, raylet_client_), object_interface_(worker_context_, raylet_client_, store_socket) { int rpc_server_port = 0; - if (worker_type_ == ray::WorkerType::WORKER) { + if (worker_type_ == WorkerType::WORKER) { RAY_CHECK(execution_callback != nullptr); task_execution_interface_ = std::unique_ptr( new CoreWorkerTaskExecutionInterface(worker_context_, raylet_client_, @@ -28,8 +28,8 @@ CoreWorker::CoreWorker( // instead of crashing. 
raylet_client_ = std::unique_ptr(new RayletClient( raylet_socket_, ClientID::FromBinary(worker_context_.GetWorkerID().Binary()), - (worker_type_ == ray::WorkerType::WORKER), worker_context_.GetCurrentJobID(), - language_, rpc_server_port)); + (worker_type_ == WorkerType::WORKER), worker_context_.GetCurrentJobID(), language_, + rpc_server_port)); } } // namespace ray diff --git a/src/ray/core_worker/lib/java/jni_init.cc b/src/ray/core_worker/lib/java/jni_init.cc new file mode 100644 index 000000000..6c66f8f2f --- /dev/null +++ b/src/ray/core_worker/lib/java/jni_init.cc @@ -0,0 +1,75 @@ +#include "ray/core_worker/lib/java/jni_utils.h" + +jclass java_boolean_class; +jmethodID java_boolean_init; + +jclass java_list_class; +jmethodID java_list_size; +jmethodID java_list_get; +jmethodID java_list_add; + +jclass java_array_list_class; +jmethodID java_array_list_init; +jmethodID java_array_list_init_with_capacity; + +jclass java_ray_exception_class; + +jclass java_native_ray_object_class; +jmethodID java_native_ray_object_init; +jfieldID java_native_ray_object_data; +jfieldID java_native_ray_object_metadata; + +jint JNI_VERSION = JNI_VERSION_1_8; + +inline jclass LoadClass(JNIEnv *env, const char *class_name) { + jclass tempLocalClassRef = env->FindClass(class_name); + jclass ret = (jclass)env->NewGlobalRef(tempLocalClassRef); + env->DeleteLocalRef(tempLocalClassRef); + return ret; +} + +/// Load and cache frequently-used Java classes and methods +jint JNI_OnLoad(JavaVM *vm, void *reserved) { + JNIEnv *env; + if (vm->GetEnv(reinterpret_cast(&env), JNI_VERSION) != JNI_OK) { + return JNI_ERR; + } + + java_boolean_class = LoadClass(env, "java/lang/Boolean"); + java_boolean_init = env->GetMethodID(java_boolean_class, "", "(Z)V"); + + java_list_class = LoadClass(env, "java/util/List"); + java_list_size = env->GetMethodID(java_list_class, "size", "()I"); + java_list_get = env->GetMethodID(java_list_class, "get", "(I)Ljava/lang/Object;"); + java_list_add = 
env->GetMethodID(java_list_class, "add", "(Ljava/lang/Object;)Z"); + + java_array_list_class = LoadClass(env, "java/util/ArrayList"); + java_array_list_init = env->GetMethodID(java_array_list_class, "", "()V"); + java_array_list_init_with_capacity = + env->GetMethodID(java_array_list_class, "", "(I)V"); + + java_ray_exception_class = LoadClass(env, "org/ray/api/exception/RayException"); + + java_native_ray_object_class = + LoadClass(env, "org/ray/runtime/objectstore/NativeRayObject"); + java_native_ray_object_init = + env->GetMethodID(java_native_ray_object_class, "", "([B[B)V"); + java_native_ray_object_data = + env->GetFieldID(java_native_ray_object_class, "data", "[B"); + java_native_ray_object_metadata = + env->GetFieldID(java_native_ray_object_class, "metadata", "[B"); + + return JNI_VERSION; +} + +/// Unload java classes +void JNI_OnUnload(JavaVM *vm, void *reserved) { + JNIEnv *env; + vm->GetEnv(reinterpret_cast(&env), JNI_VERSION); + + env->DeleteGlobalRef(java_boolean_class); + env->DeleteGlobalRef(java_list_class); + env->DeleteGlobalRef(java_array_list_class); + env->DeleteGlobalRef(java_ray_exception_class); + env->DeleteGlobalRef(java_native_ray_object_class); +} diff --git a/src/ray/core_worker/lib/java/jni_utils.h b/src/ray/core_worker/lib/java/jni_utils.h new file mode 100644 index 000000000..d0f4ca8a5 --- /dev/null +++ b/src/ray/core_worker/lib/java/jni_utils.h @@ -0,0 +1,180 @@ +#ifndef RAY_COMMON_JAVA_JNI_HELPER_H +#define RAY_COMMON_JAVA_JNI_HELPER_H + +#include +#include "ray/common/buffer.h" +#include "ray/common/id.h" +#include "ray/common/status.h" +#include "ray/core_worker/store_provider/store_provider.h" + +/// Boolean class +extern jclass java_boolean_class; +/// Constructor of Boolean class +extern jmethodID java_boolean_init; + +/// List class +extern jclass java_list_class; +/// size method of List class +extern jmethodID java_list_size; +/// get method of List class +extern jmethodID java_list_get; +/// add method of List class 
+extern jmethodID java_list_add; + +/// ArrayList class +extern jclass java_array_list_class; +/// Constructor of ArrayList class +extern jmethodID java_array_list_init; +/// Constructor of ArrayList class with single parameter capacity +extern jmethodID java_array_list_init_with_capacity; + +/// RayException class +extern jclass java_ray_exception_class; + +/// NativeRayObject class +extern jclass java_native_ray_object_class; +/// Constructor of NativeRayObject class +extern jmethodID java_native_ray_object_init; +/// data field of NativeRayObject class +extern jfieldID java_native_ray_object_data; +/// metadata field of NativeRayObject class +extern jfieldID java_native_ray_object_metadata; + +/// Throws a Java RayException if the status is not OK. +#define THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, ret) \ + { \ + if (!(status).ok()) { \ + (env)->ThrowNew(java_ray_exception_class, (status).message().c_str()); \ + return (ret); \ + } \ + } + +/// Convert a Java byte array to a C++ UniqueID. +template +inline ID JavaByteArrayToId(JNIEnv *env, const jbyteArray &bytes) { + std::string id_str(ID::Size(), 0); + env->GetByteArrayRegion(bytes, 0, ID::Size(), + reinterpret_cast(&id_str.front())); + return ID::FromBinary(id_str); +} + +/// Convert C++ UniqueID to a Java byte array. +template +inline jbyteArray IdToJavaByteArray(JNIEnv *env, const ID &id) { + jbyteArray array = env->NewByteArray(ID::Size()); + env->SetByteArrayRegion(array, 0, ID::Size(), + reinterpret_cast(id.Data())); + return array; +} + +/// Convert C++ UniqueID to a Java ByteBuffer. +template +inline jobject IdToJavaByteBuffer(JNIEnv *env, const ID &id) { + return env->NewDirectByteBuffer( + reinterpret_cast(const_cast(id.Data())), id.Size()); +} + +/// Convert a Java String to C++ std::string. 
+inline std::string JavaStringToNativeString(JNIEnv *env, jstring jstr) { + const char *c_str = env->GetStringUTFChars(jstr, nullptr); + std::string result(c_str); + env->ReleaseStringUTFChars(static_cast(jstr), c_str); + return result; +} + +/// Convert a Java List to C++ std::vector. +template +inline void JavaListToNativeVector( + JNIEnv *env, jobject java_list, std::vector *native_vector, + std::function element_converter) { + int size = env->CallIntMethod(java_list, java_list_size); + native_vector->clear(); + for (int i = 0; i < size; i++) { + native_vector->emplace_back( + element_converter(env, env->CallObjectMethod(java_list, java_list_get, (jint)i))); + } +} + +/// Convert a C++ std::vector to a Java List. +template +inline jobject NativeVectorToJavaList( + JNIEnv *env, const std::vector &native_vector, + std::function element_converter) { + jobject java_list = + env->NewObject(java_array_list_class, java_array_list_init_with_capacity, + (jint)native_vector.size()); + for (const auto &item : native_vector) { + env->CallVoidMethod(java_list, java_list_add, element_converter(env, item)); + } + return java_list; +} + +/// Convert a C++ ray::Buffer to a Java byte array. +inline jbyteArray NativeBufferToJavaByteArray(JNIEnv *env, + const std::shared_ptr buffer) { + if (!buffer) { + return nullptr; + } + jbyteArray java_byte_array = env->NewByteArray(buffer->Size()); + if (buffer->Size() > 0) { + env->SetByteArrayRegion(java_byte_array, 0, buffer->Size(), + reinterpret_cast(buffer->Data())); + } + return java_byte_array; +} + +/// A helper method to help access a Java NativeRayObject instance and ensure memory +/// safety. +/// +/// \param[in] java_obj The Java NativeRayObject object. +/// \param[in] reader The callback function to access a C++ ray::RayObject instance. +/// \return The return value of callback function. 
+template +inline ReturnT ReadJavaNativeRayObject( + JNIEnv *env, const jobject &java_obj, + std::function &)> reader) { + if (!java_obj) { + return reader(nullptr); + } + auto java_data = (jbyteArray)env->GetObjectField(java_obj, java_native_ray_object_data); + auto java_metadata = + (jbyteArray)env->GetObjectField(java_obj, java_native_ray_object_metadata); + auto data_size = env->GetArrayLength(java_data); + jbyte *data = data_size > 0 ? env->GetByteArrayElements(java_data, nullptr) : nullptr; + auto metadata_size = java_metadata ? env->GetArrayLength(java_metadata) : 0; + jbyte *metadata = + metadata_size > 0 ? env->GetByteArrayElements(java_metadata, nullptr) : nullptr; + auto data_buffer = std::make_shared( + reinterpret_cast(data), data_size); + auto metadata_buffer = java_metadata + ? std::make_shared( + reinterpret_cast(metadata), metadata_size) + : nullptr; + + auto native_obj = std::make_shared(data_buffer, metadata_buffer); + auto result = reader(native_obj); + + if (data) { + env->ReleaseByteArrayElements(java_data, data, JNI_ABORT); + } + if (metadata) { + env->ReleaseByteArrayElements(java_metadata, metadata, JNI_ABORT); + } + + return result; +} + +/// Convert a C++ ray::RayObject to a Java NativeRayObject. 
+inline jobject ToJavaNativeRayObject(JNIEnv *env, + const std::shared_ptr &rayObject) { + if (!rayObject) { + return nullptr; + } + auto java_data = NativeBufferToJavaByteArray(env, rayObject->GetData()); + auto java_metadata = NativeBufferToJavaByteArray(env, rayObject->GetMetadata()); + auto java_obj = env->NewObject(java_native_ray_object_class, + java_native_ray_object_init, java_data, java_metadata); + return java_obj; +} + +#endif // RAY_COMMON_JAVA_JNI_HELPER_H diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_WorkerContext.cc b/src/ray/core_worker/lib/java/org_ray_runtime_WorkerContext.cc new file mode 100644 index 000000000..2c91dcdaa --- /dev/null +++ b/src/ray/core_worker/lib/java/org_ray_runtime_WorkerContext.cc @@ -0,0 +1,134 @@ +#include "ray/core_worker/lib/java/org_ray_runtime_WorkerContext.h" +#include +#include "ray/common/id.h" +#include "ray/core_worker/context.h" +#include "ray/core_worker/lib/java/jni_utils.h" + +inline ray::WorkerContext *GetWorkerContextFromPointer( + jlong nativeWorkerContextFromPointer) { + return reinterpret_cast(nativeWorkerContextFromPointer); +} + +#ifdef __cplusplus +extern "C" { +#endif + +/* + * Class: org_ray_runtime_WorkerContext + * Method: nativeCreateWorkerContext + * Signature: (I[B)J + */ +JNIEXPORT jlong JNICALL Java_org_ray_runtime_WorkerContext_nativeCreateWorkerContext( + JNIEnv *env, jclass, jint workerType, jbyteArray jobId) { + return reinterpret_cast( + new ray::WorkerContext(static_cast(workerType), + JavaByteArrayToId(env, jobId))); +} + +/* + * Class: org_ray_runtime_WorkerContext + * Method: nativeGetCurrentTaskId + * Signature: (J)[B + */ +JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_WorkerContext_nativeGetCurrentTaskId( + JNIEnv *env, jclass, jlong nativeWorkerContextFromPointer) { + auto task_id = + GetWorkerContextFromPointer(nativeWorkerContextFromPointer)->GetCurrentTaskID(); + return IdToJavaByteArray(env, task_id); +} + +/* + * Class: org_ray_runtime_WorkerContext + * Method: 
nativeSetCurrentTask + * Signature: (J[B)V + */ +JNIEXPORT void JNICALL Java_org_ray_runtime_WorkerContext_nativeSetCurrentTask( + JNIEnv *env, jclass, jlong nativeWorkerContextFromPointer, jbyteArray taskSpec) { + jbyte *data = env->GetByteArrayElements(taskSpec, NULL); + jsize size = env->GetArrayLength(taskSpec); + ray::rpc::TaskSpec task_spec_message; + task_spec_message.ParseFromArray(data, size); + env->ReleaseByteArrayElements(taskSpec, data, JNI_ABORT); + + ray::TaskSpecification spec(task_spec_message); + GetWorkerContextFromPointer(nativeWorkerContextFromPointer)->SetCurrentTask(spec); +} + +/* + * Class: org_ray_runtime_WorkerContext + * Method: nativeGetCurrentTask + * Signature: (J)[B + */ +JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_WorkerContext_nativeGetCurrentTask( + JNIEnv *env, jclass, jlong nativeWorkerContextFromPointer) { + auto spec = + GetWorkerContextFromPointer(nativeWorkerContextFromPointer)->GetCurrentTask(); + if (!spec) { + return nullptr; + } + + auto task_message = spec->Serialize(); + jbyteArray result = env->NewByteArray(task_message.size()); + env->SetByteArrayRegion( + result, 0, task_message.size(), + reinterpret_cast(const_cast(task_message.data()))); + return result; +} + +/* + * Class: org_ray_runtime_WorkerContext + * Method: nativeGetCurrentJobId + * Signature: (J)Ljava/nio/ByteBuffer; + */ +JNIEXPORT jobject JNICALL Java_org_ray_runtime_WorkerContext_nativeGetCurrentJobId( + JNIEnv *env, jclass, jlong nativeWorkerContextFromPointer) { + const auto &job_id = + GetWorkerContextFromPointer(nativeWorkerContextFromPointer)->GetCurrentJobID(); + return IdToJavaByteBuffer(env, job_id); +} + +/* + * Class: org_ray_runtime_WorkerContext + * Method: nativeGetCurrentWorkerId + * Signature: (J)[B + */ +JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_WorkerContext_nativeGetCurrentWorkerId( + JNIEnv *env, jclass, jlong nativeWorkerContextFromPointer) { + auto worker_id = + 
GetWorkerContextFromPointer(nativeWorkerContextFromPointer)->GetWorkerID(); + return IdToJavaByteArray(env, worker_id); +} + +/* + * Class: org_ray_runtime_WorkerContext + * Method: nativeGetNextTaskIndex + * Signature: (J)I + */ +JNIEXPORT jint JNICALL Java_org_ray_runtime_WorkerContext_nativeGetNextTaskIndex( + JNIEnv *env, jclass, jlong nativeWorkerContextFromPointer) { + return GetWorkerContextFromPointer(nativeWorkerContextFromPointer)->GetNextTaskIndex(); +} + +/* + * Class: org_ray_runtime_WorkerContext + * Method: nativeGetNextPutIndex + * Signature: (J)I + */ +JNIEXPORT jint JNICALL Java_org_ray_runtime_WorkerContext_nativeGetNextPutIndex( + JNIEnv *env, jclass, jlong nativeWorkerContextFromPointer) { + return GetWorkerContextFromPointer(nativeWorkerContextFromPointer)->GetNextPutIndex(); +} + +/* + * Class: org_ray_runtime_WorkerContext + * Method: nativeDestroy + * Signature: (J)V + */ +JNIEXPORT void JNICALL Java_org_ray_runtime_WorkerContext_nativeDestroy( + JNIEnv *env, jclass, jlong nativeWorkerContextFromPointer) { + delete GetWorkerContextFromPointer(nativeWorkerContextFromPointer); +} + +#ifdef __cplusplus +} +#endif diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_WorkerContext.h b/src/ray/core_worker/lib/java/org_ray_runtime_WorkerContext.h new file mode 100644 index 000000000..df9c60a56 --- /dev/null +++ b/src/ray/core_worker/lib/java/org_ray_runtime_WorkerContext.h @@ -0,0 +1,87 @@ +/* DO NOT EDIT THIS FILE - it is machine generated */ +#include +/* Header for class org_ray_runtime_WorkerContext */ + +#ifndef _Included_org_ray_runtime_WorkerContext +#define _Included_org_ray_runtime_WorkerContext +#ifdef __cplusplus +extern "C" { +#endif +/* + * Class: org_ray_runtime_WorkerContext + * Method: nativeCreateWorkerContext + * Signature: (I[B)J + */ +JNIEXPORT jlong JNICALL Java_org_ray_runtime_WorkerContext_nativeCreateWorkerContext( + JNIEnv *, jclass, jint, jbyteArray); + +/* + * Class: org_ray_runtime_WorkerContext + * Method: 
nativeGetCurrentTaskId + * Signature: (J)[B + */ +JNIEXPORT jbyteArray JNICALL +Java_org_ray_runtime_WorkerContext_nativeGetCurrentTaskId(JNIEnv *, jclass, jlong); + +/* + * Class: org_ray_runtime_WorkerContext + * Method: nativeSetCurrentTask + * Signature: (J[B)V + */ +JNIEXPORT void JNICALL Java_org_ray_runtime_WorkerContext_nativeSetCurrentTask( + JNIEnv *, jclass, jlong, jbyteArray); + +/* + * Class: org_ray_runtime_WorkerContext + * Method: nativeGetCurrentTask + * Signature: (J)[B + */ +JNIEXPORT jbyteArray JNICALL +Java_org_ray_runtime_WorkerContext_nativeGetCurrentTask(JNIEnv *, jclass, jlong); + +/* + * Class: org_ray_runtime_WorkerContext + * Method: nativeGetCurrentJobId + * Signature: (J)Ljava/nio/ByteBuffer; + */ +JNIEXPORT jobject JNICALL +Java_org_ray_runtime_WorkerContext_nativeGetCurrentJobId(JNIEnv *, jclass, jlong); + +/* + * Class: org_ray_runtime_WorkerContext + * Method: nativeGetCurrentWorkerId + * Signature: (J)[B + */ +JNIEXPORT jbyteArray JNICALL +Java_org_ray_runtime_WorkerContext_nativeGetCurrentWorkerId(JNIEnv *, jclass, jlong); + +/* + * Class: org_ray_runtime_WorkerContext + * Method: nativeGetNextTaskIndex + * Signature: (J)I + */ +JNIEXPORT jint JNICALL Java_org_ray_runtime_WorkerContext_nativeGetNextTaskIndex(JNIEnv *, + jclass, + jlong); + +/* + * Class: org_ray_runtime_WorkerContext + * Method: nativeGetNextPutIndex + * Signature: (J)I + */ +JNIEXPORT jint JNICALL Java_org_ray_runtime_WorkerContext_nativeGetNextPutIndex(JNIEnv *, + jclass, + jlong); + +/* + * Class: org_ray_runtime_WorkerContext + * Method: nativeDestroy + * Signature: (J)V + */ +JNIEXPORT void JNICALL Java_org_ray_runtime_WorkerContext_nativeDestroy(JNIEnv *, jclass, + jlong); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_objectstore_ObjectInterfaceImpl.cc b/src/ray/core_worker/lib/java/org_ray_runtime_objectstore_ObjectInterfaceImpl.cc new file mode 100644 index 000000000..3c7bb43a0 --- /dev/null +++ 
b/src/ray/core_worker/lib/java/org_ray_runtime_objectstore_ObjectInterfaceImpl.cc @@ -0,0 +1,149 @@ +#include "ray/core_worker/lib/java/org_ray_runtime_objectstore_ObjectInterfaceImpl.h" +#include +#include "ray/common/id.h" +#include "ray/core_worker/common.h" +#include "ray/core_worker/lib/java/jni_utils.h" +#include "ray/core_worker/object_interface.h" + +inline ray::CoreWorkerObjectInterface *GetObjectInterfaceFromPointer( + jlong nativeObjectInterfacePointer) { + return reinterpret_cast(nativeObjectInterfacePointer); +} + +#ifdef __cplusplus +extern "C" { +#endif + +/* + * Class: org_ray_runtime_objectstore_ObjectInterfaceImpl + * Method: nativeCreateObjectInterface + * Signature: (JJLjava/lang/String;)J + */ +JNIEXPORT jlong JNICALL +Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeCreateObjectInterface( + JNIEnv *env, jclass, jlong nativeWorkerContext, jlong nativeRayletClient, + jstring storeSocketName) { + return reinterpret_cast(new ray::CoreWorkerObjectInterface( + *reinterpret_cast(nativeWorkerContext), + *reinterpret_cast *>(nativeRayletClient), + JavaStringToNativeString(env, storeSocketName))); +} + +/* + * Class: org_ray_runtime_objectstore_ObjectInterfaceImpl + * Method: nativePut + * Signature: (JLorg/ray/runtime/objectstore/NativeRayObject;)[B + */ +JNIEXPORT jbyteArray JNICALL +Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativePut__JLorg_ray_runtime_objectstore_NativeRayObject_2( + JNIEnv *env, jclass, jlong nativeObjectInterfacePointer, jobject obj) { + ray::Status status; + ray::ObjectID object_id = ReadJavaNativeRayObject( + env, obj, + [nativeObjectInterfacePointer, + &status](const std::shared_ptr &rayObject) { + RAY_CHECK(rayObject != nullptr); + ray::ObjectID object_id; + status = GetObjectInterfaceFromPointer(nativeObjectInterfacePointer) + ->Put(*rayObject, &object_id); + return object_id; + }); + THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); + return IdToJavaByteArray(env, object_id); +} + +/* + * 
Class: org_ray_runtime_objectstore_ObjectInterfaceImpl + * Method: nativePut + * Signature: (J[BLorg/ray/runtime/objectstore/NativeRayObject;)V + */ +JNIEXPORT void JNICALL +Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativePut__J_3BLorg_ray_runtime_objectstore_NativeRayObject_2( + JNIEnv *env, jclass, jlong nativeObjectInterfacePointer, jbyteArray objectId, + jobject obj) { + auto object_id = JavaByteArrayToId(env, objectId); + auto status = ReadJavaNativeRayObject( + env, obj, + [nativeObjectInterfacePointer, + &object_id](const std::shared_ptr &rayObject) { + RAY_CHECK(rayObject != nullptr); + return GetObjectInterfaceFromPointer(nativeObjectInterfacePointer) + ->Put(*rayObject, object_id); + }); + THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); +} + +/* + * Class: org_ray_runtime_objectstore_ObjectInterfaceImpl + * Method: nativeGet + * Signature: (JLjava/util/List;J)Ljava/util/List; + */ +JNIEXPORT jobject JNICALL Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeGet( + JNIEnv *env, jclass, jlong nativeObjectInterfacePointer, jobject ids, + jlong timeoutMs) { + std::vector object_ids; + JavaListToNativeVector( + env, ids, &object_ids, [](JNIEnv *env, jobject id) { + return JavaByteArrayToId(env, static_cast(id)); + }); + std::vector> results; + auto status = GetObjectInterfaceFromPointer(nativeObjectInterfacePointer) + ->Get(object_ids, (int64_t)timeoutMs, &results); + THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); + return NativeVectorToJavaList>(env, results, + ToJavaNativeRayObject); +} + +/* + * Class: org_ray_runtime_objectstore_ObjectInterfaceImpl + * Method: nativeWait + * Signature: (JLjava/util/List;IJ)Ljava/util/List; + */ +JNIEXPORT jobject JNICALL Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeWait( + JNIEnv *env, jclass, jlong nativeObjectInterfacePointer, jobject objectIds, + jint numObjects, jlong timeoutMs) { + std::vector object_ids; + JavaListToNativeVector( + env, objectIds, &object_ids, 
[](JNIEnv *env, jobject id) { + return JavaByteArrayToId(env, static_cast(id)); + }); + std::vector results; + auto status = GetObjectInterfaceFromPointer(nativeObjectInterfacePointer) + ->Wait(object_ids, (int)numObjects, (int64_t)timeoutMs, &results); + THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); + return NativeVectorToJavaList(env, results, [](JNIEnv *env, const bool &item) { + return env->NewObject(java_boolean_class, java_boolean_init, (jboolean)item); + }); +} + +/* + * Class: org_ray_runtime_objectstore_ObjectInterfaceImpl + * Method: nativeDelete + * Signature: (JLjava/util/List;ZZ)V + */ +JNIEXPORT void JNICALL Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeDelete( + JNIEnv *env, jclass, jlong nativeObjectInterfacePointer, jobject objectIds, + jboolean localOnly, jboolean deleteCreatingTasks) { + std::vector object_ids; + JavaListToNativeVector( + env, objectIds, &object_ids, [](JNIEnv *env, jobject id) { + return JavaByteArrayToId(env, static_cast(id)); + }); + auto status = GetObjectInterfaceFromPointer(nativeObjectInterfacePointer) + ->Delete(object_ids, (bool)localOnly, (bool)deleteCreatingTasks); + THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); +} + +/* + * Class: org_ray_runtime_objectstore_ObjectInterfaceImpl + * Method: nativeDestroy + * Signature: (J)V + */ +JNIEXPORT void JNICALL Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeDestroy( + JNIEnv *env, jclass, jlong nativeObjectInterfacePointer) { + delete GetObjectInterfaceFromPointer(nativeObjectInterfacePointer); +} + +#ifdef __cplusplus +} +#endif diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_objectstore_ObjectInterfaceImpl.h b/src/ray/core_worker/lib/java/org_ray_runtime_objectstore_ObjectInterfaceImpl.h new file mode 100644 index 000000000..0ea41535e --- /dev/null +++ b/src/ray/core_worker/lib/java/org_ray_runtime_objectstore_ObjectInterfaceImpl.h @@ -0,0 +1,72 @@ +/* DO NOT EDIT THIS FILE - it is machine generated */ +#include 
+/* Header for class org_ray_runtime_objectstore_ObjectInterfaceImpl */ + +#ifndef _Included_org_ray_runtime_objectstore_ObjectInterfaceImpl +#define _Included_org_ray_runtime_objectstore_ObjectInterfaceImpl +#ifdef __cplusplus +extern "C" { +#endif +/* + * Class: org_ray_runtime_objectstore_ObjectInterfaceImpl + * Method: nativeCreateObjectInterface + * Signature: (JJLjava/lang/String;)J + */ +JNIEXPORT jlong JNICALL +Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeCreateObjectInterface( + JNIEnv *, jclass, jlong, jlong, jstring); + +/* + * Class: org_ray_runtime_objectstore_ObjectInterfaceImpl + * Method: nativePut + * Signature: (JLorg/ray/runtime/objectstore/NativeRayObject;)[B + */ +JNIEXPORT jbyteArray JNICALL +Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativePut__JLorg_ray_runtime_objectstore_NativeRayObject_2( + JNIEnv *, jclass, jlong, jobject); + +/* + * Class: org_ray_runtime_objectstore_ObjectInterfaceImpl + * Method: nativePut + * Signature: (J[BLorg/ray/runtime/objectstore/NativeRayObject;)V + */ +JNIEXPORT void JNICALL +Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativePut__J_3BLorg_ray_runtime_objectstore_NativeRayObject_2( + JNIEnv *, jclass, jlong, jbyteArray, jobject); + +/* + * Class: org_ray_runtime_objectstore_ObjectInterfaceImpl + * Method: nativeGet + * Signature: (JLjava/util/List;J)Ljava/util/List; + */ +JNIEXPORT jobject JNICALL Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeGet( + JNIEnv *, jclass, jlong, jobject, jlong); + +/* + * Class: org_ray_runtime_objectstore_ObjectInterfaceImpl + * Method: nativeWait + * Signature: (JLjava/util/List;IJ)Ljava/util/List; + */ +JNIEXPORT jobject JNICALL Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeWait( + JNIEnv *, jclass, jlong, jobject, jint, jlong); + +/* + * Class: org_ray_runtime_objectstore_ObjectInterfaceImpl + * Method: nativeDelete + * Signature: (JLjava/util/List;ZZ)V + */ +JNIEXPORT void JNICALL 
Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeDelete( + JNIEnv *, jclass, jlong, jobject, jboolean, jboolean); + +/* + * Class: org_ray_runtime_objectstore_ObjectInterfaceImpl + * Method: nativeDestroy + * Signature: (J)V + */ +JNIEXPORT void JNICALL Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeDestroy( + JNIEnv *, jclass, jlong); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.cc b/src/ray/core_worker/store_provider/plasma_store_provider.cc index 53c330dc0..5a59420a5 100644 --- a/src/ray/core_worker/store_provider/plasma_store_provider.cc +++ b/src/ray/core_worker/store_provider/plasma_store_provider.cc @@ -3,6 +3,7 @@ #include "ray/core_worker/context.h" #include "ray/core_worker/core_worker.h" #include "ray/core_worker/object_interface.h" +#include "ray/protobuf/gcs.pb.h" namespace ray { @@ -101,11 +102,14 @@ Status CoreWorkerPlasmaStoreProvider::Get( std::make_shared(object_buffers[i].data), std::make_shared(object_buffers[i].metadata)); unready.erase(object_id); + if (IsException(object_buffers[i])) { + should_break = true; + } } } num_attempts += 1; - // TODO(zhijunfu): log a message if attempted too many times. + WarnIfAttemptedTooManyTimes(num_attempts, unready); } if (was_blocked) { @@ -144,4 +148,45 @@ Status CoreWorkerPlasmaStoreProvider::Delete(const std::vector &object return raylet_client_->FreeObjects(object_ids, local_only, delete_creating_tasks); } +bool CoreWorkerPlasmaStoreProvider::IsException(const plasma::ObjectBuffer &buffer) { + // TODO (kfstorm): metadata should be structured. 
+ const std::string metadata = buffer.metadata->ToString(); + const auto error_type_descriptor = ray::rpc::ErrorType_descriptor(); + for (int i = 0; i < error_type_descriptor->value_count(); i++) { + const auto error_type_number = error_type_descriptor->value(i)->number(); + if (metadata == std::to_string(error_type_number)) { + return true; + } + } + return false; +} + +void CoreWorkerPlasmaStoreProvider::WarnIfAttemptedTooManyTimes( + int num_attempts, const std::unordered_map &unready) { + if (num_attempts % RayConfig::instance().object_store_get_warn_per_num_attempts() == + 0) { + std::ostringstream oss; + size_t printed = 0; + for (auto &entry : unready) { + if (printed >= + RayConfig::instance().object_store_get_max_ids_to_print_in_warning()) { + break; + } + if (printed++ > 0) {  // Post-increment: count each ID written, so the cap and the "etc" suffix take effect. + oss << ", "; + } + oss << entry.first.Hex(); + } + if (printed < unready.size()) { + oss << ", etc"; + } + RAY_LOG(WARNING) + << "Attempted " << num_attempts << " times to reconstruct objects, but " + << "some objects are still unavailable. If this message continues to print," + << " it may indicate that object's creating task is hanging, or something wrong" + << " happened in raylet backend. " << unready.size() + << " object(s) pending: " << oss.str() << "."; + } +} + } // namespace ray diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.h b/src/ray/core_worker/store_provider/plasma_store_provider.h index 9aa2f914a..89ecb2ea2 100644 --- a/src/ray/core_worker/store_provider/plasma_store_provider.h +++ b/src/ray/core_worker/store_provider/plasma_store_provider.h @@ -60,6 +60,20 @@ class CoreWorkerPlasmaStoreProvider : public CoreWorkerStoreProvider { bool delete_creating_tasks) override; private: + /// Whether the buffer represents an exception object. + /// + /// \param[in] buffer the object buffer. + /// \return Whether it represents an exception object. 
+ static bool IsException(const plasma::ObjectBuffer &buffer); + + /// Print a warning if we've attempted too many times, but some objects are still + /// unavailable. + /// + /// \param[in] num_attempts The number of attempts made so far. + /// \param[in] unready The unready objects. + static void WarnIfAttemptedTooManyTimes( + int num_attempts, const std::unordered_map &unready); + /// Plasma store client. plasma::PlasmaClient store_client_; diff --git a/src/ray/protobuf/common.proto b/src/ray/protobuf/common.proto index 8d14c004a..48a3f1cf7 100644 --- a/src/ray/protobuf/common.proto +++ b/src/ray/protobuf/common.proto @@ -11,6 +11,12 @@ enum Language { CPP = 2; } +// Type of a worker. +enum WorkerType { + WORKER = 0; + DRIVER = 1; +} + // Type of a task. enum TaskType { // Normal task. diff --git a/src/ray/protobuf/gcs.proto b/src/ray/protobuf/gcs.proto index 5dcc36662..05d97750f 100644 --- a/src/ray/protobuf/gcs.proto +++ b/src/ray/protobuf/gcs.proto @@ -267,4 +267,6 @@ enum ErrorType { // 2) The object's creating task is already cleaned up from GCS (this currently // crashes raylet). OBJECT_UNRECONSTRUCTABLE = 2; + // Indicates that a task failed due to user code failure. 
+ TASK_EXECUTION_EXCEPTION = 3; } diff --git a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc index ac6d33b9d..fb4390fb5 100644 --- a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc +++ b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc @@ -3,39 +3,14 @@ #include #include "ray/common/id.h" +#include "ray/core_worker/lib/java/jni_utils.h" #include "ray/raylet/raylet_client.h" #include "ray/util/logging.h" -template -class UniqueIdFromJByteArray { - public: - const ID &GetId() const { return id; } - - UniqueIdFromJByteArray(JNIEnv *env, const jbyteArray &bytes) { - std::string id_str(ID::Size(), 0); - env->GetByteArrayRegion(bytes, 0, ID::Size(), - reinterpret_cast(&id_str.front())); - id = ID::FromBinary(id_str); - } - - private: - ID id; -}; - #ifdef __cplusplus extern "C" { #endif -inline bool ThrowRayExceptionIfNotOK(JNIEnv *env, const ray::Status &status) { - if (!status.ok()) { - jclass exception_class = env->FindClass("org/ray/api/exception/RayException"); - env->ThrowNew(exception_class, status.message().c_str()); - return true; - } else { - return false; - } -} - /* * Class: org_ray_runtime_raylet_RayletClientImpl * Method: nativeInit @@ -44,11 +19,11 @@ inline bool ThrowRayExceptionIfNotOK(JNIEnv *env, const ray::Status &status) { JNIEXPORT jlong JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeInit( JNIEnv *env, jclass, jstring sockName, jbyteArray workerId, jboolean isWorker, jbyteArray jobId) { - UniqueIdFromJByteArray worker_id(env, workerId); - UniqueIdFromJByteArray job_id(env, jobId); + const auto worker_id = JavaByteArrayToId(env, workerId); + const auto job_id = JavaByteArrayToId(env, jobId); const char *nativeString = env->GetStringUTFChars(sockName, JNI_FALSE); - auto raylet_client = new RayletClient(nativeString, worker_id.GetId(), isWorker, - job_id.GetId(), Language::JAVA); + auto raylet_client = new 
std::unique_ptr( + new RayletClient(nativeString, worker_id, isWorker, job_id, Language::JAVA)); env->ReleaseStringUTFChars(sockName, nativeString); return reinterpret_cast(raylet_client); } @@ -60,7 +35,7 @@ JNIEXPORT jlong JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeInit( */ JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSubmitTask( JNIEnv *env, jclass, jlong client, jbyteArray taskSpec) { - auto raylet_client = reinterpret_cast(client); + auto &raylet_client = *reinterpret_cast *>(client); jbyte *data = env->GetByteArrayElements(taskSpec, NULL); jsize size = env->GetArrayLength(taskSpec); @@ -70,7 +45,7 @@ JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSubmit ray::TaskSpecification task_spec(task_spec_message); auto status = raylet_client->SubmitTask(task_spec); - ThrowRayExceptionIfNotOK(env, status); + THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); } /* @@ -80,13 +55,11 @@ JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSubmit */ JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeGetTask( JNIEnv *env, jclass, jlong client) { - auto raylet_client = reinterpret_cast(client); + auto &raylet_client = *reinterpret_cast *>(client); std::unique_ptr spec; auto status = raylet_client->GetTask(&spec); - if (ThrowRayExceptionIfNotOK(env, status)) { - return nullptr; - } + THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); // Serialize the task spec and copy to Java byte array. 
auto task_data = spec->Serialize(); @@ -109,8 +82,9 @@ JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_native */ JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeDestroy( JNIEnv *env, jclass, jlong client) { - auto raylet_client = reinterpret_cast(client); - ThrowRayExceptionIfNotOK(env, raylet_client->Disconnect()); + auto raylet_client = reinterpret_cast *>(client); + auto status = (*raylet_client)->Disconnect(); + THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); delete raylet_client; } @@ -128,15 +102,14 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeFetchOrReconstruct( for (int i = 0; i < len; i++) { jbyteArray object_id_bytes = static_cast(env->GetObjectArrayElement(objectIds, i)); - UniqueIdFromJByteArray object_id(env, object_id_bytes); - object_ids.push_back(object_id.GetId()); + const auto object_id = JavaByteArrayToId(env, object_id_bytes); + object_ids.push_back(object_id); env->DeleteLocalRef(object_id_bytes); } - UniqueIdFromJByteArray current_task_id(env, currentTaskId); - auto raylet_client = reinterpret_cast(client); - auto status = - raylet_client->FetchOrReconstruct(object_ids, fetchOnly, current_task_id.GetId()); - ThrowRayExceptionIfNotOK(env, status); + const auto current_task_id = JavaByteArrayToId(env, currentTaskId); + auto &raylet_client = *reinterpret_cast *>(client); + auto status = raylet_client->FetchOrReconstruct(object_ids, fetchOnly, current_task_id); + THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); } /* @@ -146,10 +119,10 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeFetchOrReconstruct( */ JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotifyUnblocked( JNIEnv *env, jclass, jlong client, jbyteArray currentTaskId) { - UniqueIdFromJByteArray current_task_id(env, currentTaskId); - auto raylet_client = reinterpret_cast(client); - auto status = raylet_client->NotifyUnblocked(current_task_id.GetId()); - 
ThrowRayExceptionIfNotOK(env, status); + const auto current_task_id = JavaByteArrayToId(env, currentTaskId); + auto &raylet_client = *reinterpret_cast *>(client); + auto status = raylet_client->NotifyUnblocked(current_task_id); + THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); } /* @@ -166,22 +139,20 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeWaitObject( for (int i = 0; i < len; i++) { jbyteArray object_id_bytes = static_cast(env->GetObjectArrayElement(objectIds, i)); - UniqueIdFromJByteArray object_id(env, object_id_bytes); - object_ids.push_back(object_id.GetId()); + const auto object_id = JavaByteArrayToId(env, object_id_bytes); + object_ids.push_back(object_id); env->DeleteLocalRef(object_id_bytes); } - UniqueIdFromJByteArray current_task_id(env, currentTaskId); + const auto current_task_id = JavaByteArrayToId(env, currentTaskId); - auto raylet_client = reinterpret_cast(client); + auto &raylet_client = *reinterpret_cast *>(client); // Invoke wait. WaitResultPair result; - auto status = raylet_client->Wait(object_ids, numReturns, timeoutMillis, - static_cast(isWaitLocal), - current_task_id.GetId(), &result); - if (ThrowRayExceptionIfNotOK(env, status)) { - return nullptr; - } + auto status = + raylet_client->Wait(object_ids, numReturns, timeoutMillis, + static_cast(isWaitLocal), current_task_id, &result); + THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); // Convert result to java object. 
jboolean put_value = true; @@ -216,11 +187,10 @@ JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeGenerateTaskId( JNIEnv *env, jclass, jbyteArray jobId, jbyteArray parentTaskId, jint parent_task_counter) { - UniqueIdFromJByteArray job_id(env, jobId); - UniqueIdFromJByteArray parent_task_id(env, parentTaskId); + const auto job_id = JavaByteArrayToId(env, jobId); + const auto parent_task_id = JavaByteArrayToId(env, parentTaskId); - TaskID task_id = - ray::GenerateTaskId(job_id.GetId(), parent_task_id.GetId(), parent_task_counter); + TaskID task_id = ray::GenerateTaskId(job_id, parent_task_id, parent_task_counter); jbyteArray result = env->NewByteArray(task_id.Size()); if (nullptr == result) { return nullptr; @@ -245,13 +215,13 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeFreePlasmaObjects( for (int i = 0; i < len; i++) { jbyteArray object_id_bytes = static_cast(env->GetObjectArrayElement(objectIds, i)); - UniqueIdFromJByteArray object_id(env, object_id_bytes); - object_ids.push_back(object_id.GetId()); + const auto object_id = JavaByteArrayToId(env, object_id_bytes); + object_ids.push_back(object_id); env->DeleteLocalRef(object_id_bytes); } - auto raylet_client = reinterpret_cast(client); + auto &raylet_client = *reinterpret_cast *>(client); auto status = raylet_client->FreeObjects(object_ids, localOnly, deleteCreatingTasks); - ThrowRayExceptionIfNotOK(env, status); + THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); } /* @@ -263,13 +233,11 @@ JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativePrepareCheckpoint(JNIEnv *env, jclass, jlong client, jbyteArray actorId) { - auto raylet_client = reinterpret_cast(client); - UniqueIdFromJByteArray actor_id(env, actorId); + auto &raylet_client = *reinterpret_cast *>(client); + const auto actor_id = JavaByteArrayToId(env, actorId); ActorCheckpointID checkpoint_id; - auto status = raylet_client->PrepareActorCheckpoint(actor_id.GetId(), 
checkpoint_id); - if (ThrowRayExceptionIfNotOK(env, status)) { - return nullptr; - } + auto status = raylet_client->PrepareActorCheckpoint(actor_id, checkpoint_id); + THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); jbyteArray result = env->NewByteArray(checkpoint_id.Size()); env->SetByteArrayRegion(result, 0, checkpoint_id.Size(), reinterpret_cast(checkpoint_id.Data())); @@ -284,12 +252,11 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativePrepareCheckpoint(JNIEnv *env JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotifyActorResumedFromCheckpoint( JNIEnv *env, jclass, jlong client, jbyteArray actorId, jbyteArray checkpointId) { - auto raylet_client = reinterpret_cast(client); - UniqueIdFromJByteArray actor_id(env, actorId); - UniqueIdFromJByteArray checkpoint_id(env, checkpointId); - auto status = raylet_client->NotifyActorResumedFromCheckpoint(actor_id.GetId(), - checkpoint_id.GetId()); - ThrowRayExceptionIfNotOK(env, status); + auto &raylet_client = *reinterpret_cast *>(client); + const auto actor_id = JavaByteArrayToId(env, actorId); + const auto checkpoint_id = JavaByteArrayToId(env, checkpointId); + auto status = raylet_client->NotifyActorResumedFromCheckpoint(actor_id, checkpoint_id); + THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); } /* @@ -300,14 +267,14 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotifyActorResumedFromCheckpo JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSetResource( JNIEnv *env, jclass, jlong client, jstring resourceName, jdouble capacity, jbyteArray nodeId) { - auto raylet_client = reinterpret_cast(client); - UniqueIdFromJByteArray node_id(env, nodeId); + auto &raylet_client = *reinterpret_cast *>(client); + const auto node_id = JavaByteArrayToId(env, nodeId); const char *native_resource_name = env->GetStringUTFChars(resourceName, JNI_FALSE); - auto status = raylet_client->SetResource( - native_resource_name, static_cast(capacity), 
node_id.GetId()); + auto status = raylet_client->SetResource(native_resource_name, + static_cast(capacity), node_id); env->ReleaseStringUTFChars(resourceName, native_resource_name); - ThrowRayExceptionIfNotOK(env, status); + THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); } #ifdef __cplusplus