mirror of
https://github.com/vale981/ray
synced 2025-03-04 17:41:43 -05:00
[Java worker] Refactor object store and worker context on top of core worker (#5079)
This commit is contained in:
parent
e5be5fd46d
commit
806524384b
40 changed files with 1386 additions and 571 deletions
29
BUILD.bazel
29
BUILD.bazel
|
@ -647,13 +647,13 @@ pyx_library(
|
|||
)
|
||||
|
||||
cc_binary(
|
||||
name = "libraylet_library_java.so",
|
||||
srcs = [
|
||||
"src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h",
|
||||
"src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc",
|
||||
"src/ray/common/id.h",
|
||||
"src/ray/raylet/raylet_client.h",
|
||||
"src/ray/util/logging.h",
|
||||
name = "libcore_worker_library_java.so",
|
||||
srcs = glob([
|
||||
"src/ray/core_worker/lib/java/*.h",
|
||||
"src/ray/core_worker/lib/java/*.cc",
|
||||
"src/ray/raylet/lib/java/*.h",
|
||||
"src/ray/raylet/lib/java/*.cc",
|
||||
]) + [
|
||||
"@bazel_tools//tools/jdk:jni_header",
|
||||
] + select({
|
||||
"@bazel_tools//src/conditions:windows": ["@bazel_tools//tools/jdk:jni_md_header-windows"],
|
||||
|
@ -671,24 +671,23 @@ cc_binary(
|
|||
linkshared = 1,
|
||||
linkstatic = 1,
|
||||
deps = [
|
||||
"//:raylet_lib",
|
||||
"@plasma//:plasma_client",
|
||||
"//:core_worker_lib",
|
||||
],
|
||||
)
|
||||
|
||||
genrule(
|
||||
name = "raylet-jni-darwin-compat",
|
||||
srcs = [":libraylet_library_java.so"],
|
||||
outs = ["libraylet_library_java.dylib"],
|
||||
name = "core_worker-jni-darwin-compat",
|
||||
srcs = [":libcore_worker_library_java.so"],
|
||||
outs = ["libcore_worker_library_java.dylib"],
|
||||
cmd = "cp $< $@",
|
||||
output_to_bindir = 1,
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "raylet_library_java",
|
||||
name = "core_worker_library_java",
|
||||
srcs = select({
|
||||
"@bazel_tools//src/conditions:darwin": [":libraylet_library_java.dylib"],
|
||||
"//conditions:default": [":libraylet_library_java.so"],
|
||||
"@bazel_tools//src/conditions:darwin": [":libcore_worker_library_java.dylib"],
|
||||
"//conditions:default": [":libcore_worker_library_java.so"],
|
||||
}),
|
||||
visibility = ["//java:__subpackages__"],
|
||||
)
|
||||
|
|
|
@ -2,27 +2,6 @@ load("@com_github_google_flatbuffers//:build_defs.bzl", "flatbuffer_cc_library")
|
|||
|
||||
COPTS = ["-DARROW_USE_GLOG"]
|
||||
|
||||
java_library(
|
||||
name = "org_apache_arrow_arrow_plasma",
|
||||
srcs = glob(["java/plasma/src/main/java/**/*.java"]),
|
||||
data = [":plasma_client_java"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"@maven//:org_slf4j_slf4j_api",
|
||||
],
|
||||
)
|
||||
|
||||
java_binary(
|
||||
name = "org_apache_arrow_arrow_plasma_test",
|
||||
srcs = ["java/plasma/src/test/java/org/apache/arrow/plasma/PlasmaClientTest.java"],
|
||||
main_class = "org.apache.arrow.plasma.PlasmaClientTest",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":org_apache_arrow_arrow_plasma",
|
||||
"@maven//:junit_junit",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "arrow",
|
||||
srcs = [
|
||||
|
@ -145,15 +124,6 @@ genrule(
|
|||
output_to_bindir = 1,
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "plasma_client_java",
|
||||
srcs = select({
|
||||
"@bazel_tools//src/conditions:darwin": [":libplasma_java.dylib"],
|
||||
"//conditions:default": [":libplasma_java.so"],
|
||||
}),
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "plasma_lib",
|
||||
srcs = [
|
||||
|
|
|
@ -69,7 +69,6 @@ define_java_module(
|
|||
],
|
||||
deps = [
|
||||
":org_ray_ray_api",
|
||||
"@plasma//:org_apache_arrow_arrow_plasma",
|
||||
"@maven//:com_google_guava_guava",
|
||||
"@maven//:com_google_protobuf_protobuf_java",
|
||||
"@maven//:com_typesafe_config",
|
||||
|
@ -97,7 +96,6 @@ define_java_module(
|
|||
deps = [
|
||||
":org_ray_ray_api",
|
||||
":org_ray_ray_runtime",
|
||||
"@plasma//:org_apache_arrow_arrow_plasma",
|
||||
"@maven//:com_google_guava_guava",
|
||||
"@maven//:com_sun_xml_bind_jaxb_core",
|
||||
"@maven//:com_sun_xml_bind_jaxb_impl",
|
||||
|
@ -176,9 +174,8 @@ filegroup(
|
|||
"//:redis-server",
|
||||
"//:libray_redis_module.so",
|
||||
"//:raylet",
|
||||
"//:raylet_library_java",
|
||||
"//:core_worker_library_java",
|
||||
"@plasma//:plasma_store_server",
|
||||
"@plasma//:plasma_client_java",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -189,7 +186,6 @@ genrule(
|
|||
":all_java_proto",
|
||||
":java_native_deps",
|
||||
":copy_pom_file",
|
||||
"@plasma//:org_apache_arrow_arrow_plasma",
|
||||
],
|
||||
outs = ["gen_maven_deps.out"],
|
||||
cmd = """
|
||||
|
@ -208,9 +204,6 @@ genrule(
|
|||
chmod +w $$f
|
||||
cp $$f $$NATIVE_DEPS_DIR
|
||||
done
|
||||
# Install plasma jar to local maven repo.
|
||||
mvn install:install-file -Dfile=$(locations @plasma//:org_apache_arrow_arrow_plasma) -Dpackaging=jar \
|
||||
-DgroupId=org.apache.arrow -DartifactId=arrow-plasma -Dversion=0.13.0-SNAPSHOT
|
||||
echo $$(date) > $@
|
||||
""",
|
||||
local = 1,
|
||||
|
|
|
@ -24,11 +24,6 @@
|
|||
|
||||
<dependencyManagement>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.apache.arrow</groupId>
|
||||
<artifactId>arrow-plasma</artifactId>
|
||||
<version>0.13.0-SNAPSHOT</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</dependencyManagement>
|
||||
|
||||
|
|
|
@ -22,10 +22,6 @@
|
|||
<artifactId>ray-api</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.arrow</groupId>
|
||||
<artifactId>arrow-plasma</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.beust</groupId>
|
||||
<artifactId>jcommander</artifactId>
|
||||
|
|
|
@ -22,10 +22,6 @@
|
|||
<artifactId>ray-api</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.arrow</groupId>
|
||||
<artifactId>arrow-plasma</artifactId>
|
||||
</dependency>
|
||||
{generated_bzl_deps}
|
||||
</dependencies>
|
||||
|
||||
|
|
|
@ -1,7 +1,15 @@
|
|||
package org.ray.runtime;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.base.Strings;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.lang.reflect.Field;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Paths;
|
||||
import java.nio.file.StandardCopyOption;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
|
@ -72,6 +80,27 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
|||
protected RuntimeContext runtimeContext;
|
||||
protected GcsClient gcsClient;
|
||||
|
||||
static {
|
||||
try {
|
||||
LOGGER.debug("Loading native libraries.");
|
||||
// Load native libraries.
|
||||
String[] libraries = new String[]{"core_worker_library_java"};
|
||||
for (String library : libraries) {
|
||||
String fileName = System.mapLibraryName(library);
|
||||
// Copy the file from resources to a temp dir, and load the native library.
|
||||
File file = File.createTempFile(fileName, "");
|
||||
file.deleteOnExit();
|
||||
InputStream in = AbstractRayRuntime.class.getResourceAsStream("/" + fileName);
|
||||
Preconditions.checkNotNull(in, "{} doesn't exist.", fileName);
|
||||
Files.copy(in, Paths.get(file.getAbsolutePath()), StandardCopyOption.REPLACE_EXISTING);
|
||||
System.load(file.getAbsolutePath());
|
||||
}
|
||||
LOGGER.debug("Native libraries loaded.");
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("Couldn't load native libraries.", e);
|
||||
}
|
||||
}
|
||||
|
||||
public AbstractRayRuntime(RayConfig rayConfig) {
|
||||
this.rayConfig = rayConfig;
|
||||
functionManager = new FunctionManager(rayConfig.jobResourcePath);
|
||||
|
@ -79,6 +108,33 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
|||
runtimeContext = new RuntimeContextImpl(this);
|
||||
}
|
||||
|
||||
protected void resetLibraryPath() {
|
||||
if (rayConfig.libraryPath.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
String path = System.getProperty("java.library.path");
|
||||
if (Strings.isNullOrEmpty(path)) {
|
||||
path = "";
|
||||
} else {
|
||||
path += ":";
|
||||
}
|
||||
path += String.join(":", rayConfig.libraryPath);
|
||||
|
||||
// This is a hack to reset library path at runtime,
|
||||
// see https://stackoverflow.com/questions/15409223/.
|
||||
System.setProperty("java.library.path", path);
|
||||
// Set sys_paths to null so that java.library.path will be re-evaluated next time it is needed.
|
||||
final Field sysPathsField;
|
||||
try {
|
||||
sysPathsField = ClassLoader.class.getDeclaredField("sys_paths");
|
||||
sysPathsField.setAccessible(true);
|
||||
sysPathsField.set(null, null);
|
||||
} catch (NoSuchFieldException | IllegalAccessException e) {
|
||||
LOGGER.error("Failed to set library path.", e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Start runtime.
|
||||
*/
|
||||
|
@ -330,8 +386,8 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
|||
* Create the task specification.
|
||||
*
|
||||
* @param func The target remote function.
|
||||
* @param pyFunctionDescriptor Descriptor of the target Python function, if the task is a
|
||||
* Python task.
|
||||
* @param pyFunctionDescriptor Descriptor of the target Python function, if the task is a Python
|
||||
* task.
|
||||
* @param actor The actor handle. If the task is not an actor task, actor id must be NIL.
|
||||
* @param args The arguments for the remote function.
|
||||
* @param isActorCreationTask Whether this task is an actor creation task.
|
||||
|
|
|
@ -3,7 +3,7 @@ package org.ray.runtime;
|
|||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import org.ray.api.id.JobId;
|
||||
import org.ray.runtime.config.RayConfig;
|
||||
import org.ray.runtime.objectstore.MockObjectStore;
|
||||
import org.ray.runtime.objectstore.MockObjectInterface;
|
||||
import org.ray.runtime.objectstore.ObjectStoreProxy;
|
||||
import org.ray.runtime.raylet.MockRayletClient;
|
||||
|
||||
|
@ -13,19 +13,22 @@ public class RayDevRuntime extends AbstractRayRuntime {
|
|||
super(rayConfig);
|
||||
}
|
||||
|
||||
private MockObjectStore store;
|
||||
private MockObjectInterface objectInterface;
|
||||
|
||||
private AtomicInteger jobCounter = new AtomicInteger(0);
|
||||
|
||||
@Override
|
||||
public void start() {
|
||||
store = new MockObjectStore(this);
|
||||
// Reset library path at runtime.
|
||||
resetLibraryPath();
|
||||
|
||||
objectInterface = new MockObjectInterface(workerContext);
|
||||
if (rayConfig.getJobId().isNil()) {
|
||||
rayConfig.setJobId(nextJobId());
|
||||
}
|
||||
workerContext = new WorkerContext(rayConfig.workerMode,
|
||||
rayConfig.getJobId(), rayConfig.runMode);
|
||||
objectStoreProxy = new ObjectStoreProxy(this, null);
|
||||
objectStoreProxy = new ObjectStoreProxy(workerContext, objectInterface);
|
||||
rayletClient = new MockRayletClient(this, rayConfig.numberExecThreadsForDevRuntime);
|
||||
}
|
||||
|
||||
|
@ -34,8 +37,8 @@ public class RayDevRuntime extends AbstractRayRuntime {
|
|||
rayletClient.destroy();
|
||||
}
|
||||
|
||||
public MockObjectStore getObjectStore() {
|
||||
return store;
|
||||
public MockObjectInterface getObjectInterface() {
|
||||
return objectInterface;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -1,21 +1,13 @@
|
|||
package org.ray.runtime;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.base.Strings;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.lang.reflect.Field;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Paths;
|
||||
import java.nio.file.StandardCopyOption;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import org.ray.api.id.JobId;
|
||||
import org.ray.runtime.config.RayConfig;
|
||||
import org.ray.runtime.config.WorkerMode;
|
||||
import org.ray.runtime.gcs.GcsClient;
|
||||
import org.ray.runtime.gcs.RedisClient;
|
||||
import org.ray.runtime.generated.Common.WorkerType;
|
||||
import org.ray.runtime.objectstore.ObjectInterfaceImpl;
|
||||
import org.ray.runtime.objectstore.ObjectStoreProxy;
|
||||
import org.ray.runtime.raylet.RayletClientImpl;
|
||||
import org.ray.runtime.runner.RunManager;
|
||||
|
@ -31,58 +23,12 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
|||
|
||||
private RunManager manager = null;
|
||||
|
||||
static {
|
||||
try {
|
||||
LOGGER.debug("Loading native libraries.");
|
||||
// Load native libraries.
|
||||
String[] libraries = new String[]{"raylet_library_java", "plasma_java"};
|
||||
for (String library : libraries) {
|
||||
String fileName = System.mapLibraryName(library);
|
||||
// Copy the file from resources to a temp dir, and load the native library.
|
||||
File file = File.createTempFile(fileName, "");
|
||||
file.deleteOnExit();
|
||||
InputStream in = RayNativeRuntime.class.getResourceAsStream("/" + fileName);
|
||||
Preconditions.checkNotNull(in, "{} doesn't exist.", fileName);
|
||||
Files.copy(in, Paths.get(file.getAbsolutePath()), StandardCopyOption.REPLACE_EXISTING);
|
||||
System.load(file.getAbsolutePath());
|
||||
}
|
||||
LOGGER.debug("Native libraries loaded.");
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("Couldn't load native libraries.", e);
|
||||
}
|
||||
}
|
||||
private ObjectInterfaceImpl objectInterfaceImpl = null;
|
||||
|
||||
public RayNativeRuntime(RayConfig rayConfig) {
|
||||
super(rayConfig);
|
||||
}
|
||||
|
||||
private void resetLibraryPath() {
|
||||
if (rayConfig.libraryPath.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
String path = System.getProperty("java.library.path");
|
||||
if (Strings.isNullOrEmpty(path)) {
|
||||
path = "";
|
||||
} else {
|
||||
path += ":";
|
||||
}
|
||||
path += String.join(":", rayConfig.libraryPath);
|
||||
|
||||
// This is a hack to reset library path at runtime,
|
||||
// see https://stackoverflow.com/questions/15409223/.
|
||||
System.setProperty("java.library.path", path);
|
||||
// Set sys_paths to null so that java.library.path will be re-evaluated next time it is needed.
|
||||
final Field sysPathsField;
|
||||
try {
|
||||
sysPathsField = ClassLoader.class.getDeclaredField("sys_paths");
|
||||
sysPathsField.setAccessible(true);
|
||||
sysPathsField.set(null, null);
|
||||
} catch (NoSuchFieldException | IllegalAccessException e) {
|
||||
LOGGER.error("Failed to set library path.", e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void start() {
|
||||
// Reset library path at runtime.
|
||||
|
@ -101,16 +47,18 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
|||
|
||||
workerContext = new WorkerContext(rayConfig.workerMode,
|
||||
rayConfig.getJobId(), rayConfig.runMode);
|
||||
// TODO(qwang): Get object_store_socket_name and raylet_socket_name from Redis.
|
||||
objectStoreProxy = new ObjectStoreProxy(this, rayConfig.objectStoreSocketName);
|
||||
|
||||
rayletClient = new RayletClientImpl(
|
||||
rayConfig.rayletSocketName,
|
||||
workerContext.getCurrentWorkerId(),
|
||||
rayConfig.workerMode == WorkerMode.WORKER,
|
||||
rayConfig.workerMode == WorkerType.WORKER,
|
||||
workerContext.getCurrentJobId()
|
||||
);
|
||||
|
||||
// TODO(qwang): Get object_store_socket_name and raylet_socket_name from Redis.
|
||||
objectInterfaceImpl = new ObjectInterfaceImpl(workerContext, rayletClient,
|
||||
rayConfig.objectStoreSocketName);
|
||||
objectStoreProxy = new ObjectStoreProxy(workerContext, objectInterfaceImpl);
|
||||
|
||||
// register
|
||||
registerWorker();
|
||||
|
||||
|
@ -123,6 +71,8 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
|||
if (null != manager) {
|
||||
manager.cleanup();
|
||||
}
|
||||
objectInterfaceImpl.destroy();
|
||||
workerContext.destroy();
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -132,7 +82,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
|||
RedisClient redisClient = new RedisClient(rayConfig.getRedisAddress(), rayConfig.redisPassword);
|
||||
Map<String, String> workerInfo = new HashMap<>();
|
||||
String workerId = new String(workerContext.getCurrentWorkerId().getBytes());
|
||||
if (rayConfig.workerMode == WorkerMode.DRIVER) {
|
||||
if (rayConfig.workerMode == WorkerType.DRIVER) {
|
||||
workerInfo.put("node_ip_address", rayConfig.nodeIp);
|
||||
workerInfo.put("driver_id", workerId);
|
||||
workerInfo.put("start_time", String.valueOf(System.currentTimeMillis()));
|
||||
|
|
|
@ -1,37 +1,24 @@
|
|||
package org.ray.runtime;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.nio.ByteBuffer;
|
||||
import org.ray.api.id.JobId;
|
||||
import org.ray.api.id.TaskId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.config.RunMode;
|
||||
import org.ray.runtime.config.WorkerMode;
|
||||
import org.ray.runtime.generated.Common.WorkerType;
|
||||
import org.ray.runtime.raylet.RayletClientImpl;
|
||||
import org.ray.runtime.task.TaskSpec;
|
||||
import org.ray.runtime.util.IdUtil;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/**
|
||||
* This is a wrapper class for worker context of core worker.
|
||||
*/
|
||||
public class WorkerContext {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(WorkerContext.class);
|
||||
|
||||
private UniqueId workerId;
|
||||
|
||||
private ThreadLocal<TaskId> currentTaskId;
|
||||
|
||||
/**
|
||||
* Number of objects that have been put from current task.
|
||||
* The native pointer of worker context of core worker.
|
||||
*/
|
||||
private ThreadLocal<Integer> putIndex;
|
||||
|
||||
/**
|
||||
* Number of tasks that have been submitted from current task.
|
||||
*/
|
||||
private ThreadLocal<Integer> taskIndex;
|
||||
|
||||
private ThreadLocal<TaskSpec> currentTask;
|
||||
|
||||
private JobId currentJobId;
|
||||
private final long nativeWorkerContextPointer;
|
||||
|
||||
private ClassLoader currentClassLoader;
|
||||
|
||||
|
@ -45,31 +32,23 @@ public class WorkerContext {
|
|||
*/
|
||||
private RunMode runMode;
|
||||
|
||||
public WorkerContext(WorkerMode workerMode, JobId jobId, RunMode runMode) {
|
||||
public WorkerContext(WorkerType workerType, JobId jobId, RunMode runMode) {
|
||||
this.nativeWorkerContextPointer = nativeCreateWorkerContext(workerType.getNumber(), jobId.getBytes());
|
||||
mainThreadId = Thread.currentThread().getId();
|
||||
taskIndex = ThreadLocal.withInitial(() -> 0);
|
||||
putIndex = ThreadLocal.withInitial(() -> 0);
|
||||
currentTaskId = ThreadLocal.withInitial(TaskId::randomId);
|
||||
this.runMode = runMode;
|
||||
currentTask = ThreadLocal.withInitial(() -> null);
|
||||
currentClassLoader = null;
|
||||
if (workerMode == WorkerMode.DRIVER) {
|
||||
workerId = IdUtil.computeDriverId(jobId);
|
||||
currentTaskId.set(TaskId.randomId());
|
||||
currentJobId = jobId;
|
||||
} else {
|
||||
workerId = UniqueId.randomId();
|
||||
this.currentTaskId.set(TaskId.NIL);
|
||||
this.currentJobId = JobId.NIL;
|
||||
}
|
||||
}
|
||||
|
||||
public long getNativeWorkerContext() {
|
||||
return nativeWorkerContextPointer;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return For the main thread, this method returns the ID of this worker's current running task;
|
||||
* for other threads, this method returns a random ID.
|
||||
* for other threads, this method returns a random ID.
|
||||
*/
|
||||
public TaskId getCurrentTaskId() {
|
||||
return currentTaskId.get();
|
||||
return new TaskId(nativeGetCurrentTaskId(nativeWorkerContextPointer));
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -79,17 +58,14 @@ public class WorkerContext {
|
|||
public void setCurrentTask(TaskSpec task, ClassLoader classLoader) {
|
||||
if (runMode == RunMode.CLUSTER) {
|
||||
Preconditions.checkState(
|
||||
Thread.currentThread().getId() == mainThreadId,
|
||||
"This method should only be called from the main thread."
|
||||
Thread.currentThread().getId() == mainThreadId,
|
||||
"This method should only be called from the main thread."
|
||||
);
|
||||
}
|
||||
|
||||
Preconditions.checkNotNull(task);
|
||||
this.currentTaskId.set(task.taskId);
|
||||
this.currentJobId = task.jobId;
|
||||
taskIndex.set(0);
|
||||
putIndex.set(0);
|
||||
this.currentTask.set(task);
|
||||
byte[] taskSpec = RayletClientImpl.convertTaskSpecToProtobuf(task);
|
||||
nativeSetCurrentTask(nativeWorkerContextPointer, taskSpec);
|
||||
currentClassLoader = classLoader;
|
||||
}
|
||||
|
||||
|
@ -97,30 +73,28 @@ public class WorkerContext {
|
|||
* Increment the put index and return the new value.
|
||||
*/
|
||||
public int nextPutIndex() {
|
||||
putIndex.set(putIndex.get() + 1);
|
||||
return putIndex.get();
|
||||
return nativeGetNextPutIndex(nativeWorkerContextPointer);
|
||||
}
|
||||
|
||||
/**
|
||||
* Increment the task index and return the new value.
|
||||
*/
|
||||
public int nextTaskIndex() {
|
||||
taskIndex.set(taskIndex.get() + 1);
|
||||
return taskIndex.get();
|
||||
return nativeGetNextTaskIndex(nativeWorkerContextPointer);
|
||||
}
|
||||
|
||||
/**
|
||||
* @return The ID of the current worker.
|
||||
*/
|
||||
public UniqueId getCurrentWorkerId() {
|
||||
return workerId;
|
||||
return new UniqueId(nativeGetCurrentWorkerId(nativeWorkerContextPointer));
|
||||
}
|
||||
|
||||
/**
|
||||
* The ID of the current job.
|
||||
*/
|
||||
public JobId getCurrentJobId() {
|
||||
return currentJobId;
|
||||
return JobId.fromByteBuffer(nativeGetCurrentJobId(nativeWorkerContextPointer));
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -134,6 +108,32 @@ public class WorkerContext {
|
|||
* Get the current task.
|
||||
*/
|
||||
public TaskSpec getCurrentTask() {
|
||||
return this.currentTask.get();
|
||||
byte[] bytes = nativeGetCurrentTask(nativeWorkerContextPointer);
|
||||
if (bytes == null) {
|
||||
return null;
|
||||
}
|
||||
return RayletClientImpl.parseTaskSpecFromProtobuf(bytes);
|
||||
}
|
||||
|
||||
public void destroy() {
|
||||
nativeDestroy(nativeWorkerContextPointer);
|
||||
}
|
||||
|
||||
private static native long nativeCreateWorkerContext(int workerType, byte[] jobId);
|
||||
|
||||
private static native byte[] nativeGetCurrentTaskId(long nativeWorkerContextPointer);
|
||||
|
||||
private static native void nativeSetCurrentTask(long nativeWorkerContextPointer, byte[] taskSpec);
|
||||
|
||||
private static native byte[] nativeGetCurrentTask(long nativeWorkerContextPointer);
|
||||
|
||||
private static native ByteBuffer nativeGetCurrentJobId(long nativeWorkerContextPointer);
|
||||
|
||||
private static native byte[] nativeGetCurrentWorkerId(long nativeWorkerContextPointer);
|
||||
|
||||
private static native int nativeGetNextTaskIndex(long nativeWorkerContextPointer);
|
||||
|
||||
private static native int nativeGetNextPutIndex(long nativeWorkerContextPointer);
|
||||
|
||||
private static native void nativeDestroy(long nativeWorkerContextPointer);
|
||||
}
|
||||
|
|
|
@ -11,6 +11,7 @@ import java.util.ArrayList;
|
|||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.ray.api.id.JobId;
|
||||
import org.ray.runtime.generated.Common.WorkerType;
|
||||
import org.ray.runtime.util.NetworkUtil;
|
||||
import org.ray.runtime.util.ResourceUtil;
|
||||
import org.ray.runtime.util.StringUtil;
|
||||
|
@ -29,7 +30,7 @@ public class RayConfig {
|
|||
public static final String CUSTOM_CONFIG_FILE = "ray.conf";
|
||||
|
||||
public final String nodeIp;
|
||||
public final WorkerMode workerMode;
|
||||
public final WorkerType workerMode;
|
||||
public final RunMode runMode;
|
||||
public final Map<String, Double> resources;
|
||||
private JobId jobId;
|
||||
|
@ -62,7 +63,7 @@ public class RayConfig {
|
|||
public final int numberExecThreadsForDevRuntime;
|
||||
|
||||
private void validate() {
|
||||
if (workerMode == WorkerMode.WORKER) {
|
||||
if (workerMode == WorkerType.WORKER) {
|
||||
Preconditions.checkArgument(redisAddress != null,
|
||||
"Redis address must be set in worker mode.");
|
||||
}
|
||||
|
@ -78,14 +79,14 @@ public class RayConfig {
|
|||
|
||||
public RayConfig(Config config) {
|
||||
// Worker mode.
|
||||
WorkerMode localWorkerMode;
|
||||
WorkerType localWorkerMode;
|
||||
try {
|
||||
localWorkerMode = config.getEnum(WorkerMode.class, "ray.worker.mode");
|
||||
localWorkerMode = config.getEnum(WorkerType.class, "ray.worker.mode");
|
||||
} catch (ConfigException.Missing e) {
|
||||
localWorkerMode = WorkerMode.DRIVER;
|
||||
localWorkerMode = WorkerType.DRIVER;
|
||||
}
|
||||
workerMode = localWorkerMode;
|
||||
boolean isDriver = workerMode == WorkerMode.DRIVER;
|
||||
boolean isDriver = workerMode == WorkerType.DRIVER;
|
||||
// Run mode.
|
||||
runMode = config.getEnum(RunMode.class, "ray.run-mode");
|
||||
// Node ip.
|
||||
|
|
|
@ -1,6 +0,0 @@
|
|||
package org.ray.runtime.config;
|
||||
|
||||
public enum WorkerMode {
|
||||
DRIVER,
|
||||
WORKER
|
||||
}
|
|
@ -0,0 +1,98 @@
|
|||
package org.ray.runtime.objectstore;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.function.Consumer;
|
||||
import java.util.stream.Collectors;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.runtime.WorkerContext;
|
||||
import org.ray.runtime.util.IdUtil;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
public class MockObjectInterface implements ObjectInterface {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(MockObjectInterface.class);
|
||||
|
||||
private static final int GET_CHECK_INTERVAL_MS = 100;
|
||||
|
||||
private final Map<ObjectId, NativeRayObject> pool = new ConcurrentHashMap<>();
|
||||
private final List<Consumer<ObjectId>> objectPutCallbacks = new ArrayList<>();
|
||||
private final WorkerContext workerContext;
|
||||
|
||||
public MockObjectInterface(WorkerContext workerContext) {
|
||||
this.workerContext = workerContext;
|
||||
}
|
||||
|
||||
public void addObjectPutCallback(Consumer<ObjectId> callback) {
|
||||
this.objectPutCallbacks.add(callback);
|
||||
}
|
||||
|
||||
public boolean isObjectReady(ObjectId id) {
|
||||
return pool.containsKey(id);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ObjectId put(NativeRayObject obj) {
|
||||
ObjectId objectId = IdUtil.computePutId(workerContext.getCurrentTaskId(),
|
||||
workerContext.nextPutIndex());
|
||||
put(obj, objectId);
|
||||
return objectId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void put(NativeRayObject obj, ObjectId objectId) {
|
||||
Preconditions.checkNotNull(obj);
|
||||
Preconditions.checkNotNull(objectId);
|
||||
pool.putIfAbsent(objectId, obj);
|
||||
for (Consumer<ObjectId> callback : objectPutCallbacks) {
|
||||
callback.accept(objectId);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<NativeRayObject> get(List<ObjectId> objectIds, long timeoutMs) {
|
||||
waitInternal(objectIds, objectIds.size(), timeoutMs);
|
||||
return objectIds.stream().map(pool::get).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Boolean> wait(List<ObjectId> objectIds, int numObjects, long timeoutMs) {
|
||||
waitInternal(objectIds, numObjects, timeoutMs);
|
||||
return objectIds.stream().map(pool::containsKey).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private void waitInternal(List<ObjectId> objectIds, int numObjects, long timeoutMs) {
|
||||
int ready = 0;
|
||||
long remainingTime = timeoutMs;
|
||||
boolean firstCheck = true;
|
||||
while (ready < numObjects && (timeoutMs < 0 || remainingTime > 0)) {
|
||||
if (!firstCheck) {
|
||||
long sleepTime = Math.min(remainingTime, GET_CHECK_INTERVAL_MS);
|
||||
try {
|
||||
Thread.sleep(sleepTime);
|
||||
} catch (InterruptedException e) {
|
||||
LOGGER.warn("Got InterruptedException while sleeping.");
|
||||
}
|
||||
remainingTime -= sleepTime;
|
||||
}
|
||||
ready = 0;
|
||||
for (ObjectId objectId : objectIds) {
|
||||
if (pool.containsKey(objectId)) {
|
||||
ready += 1;
|
||||
}
|
||||
}
|
||||
firstCheck = false;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void delete(List<ObjectId> objectIds, boolean localOnly, boolean deleteCreatingTasks) {
|
||||
for (ObjectId objectId : objectIds) {
|
||||
pool.remove(objectId);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,148 +0,0 @@
|
|||
package org.ray.runtime.objectstore;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.function.Consumer;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import org.apache.arrow.plasma.ObjectStoreLink;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.runtime.RayDevRuntime;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/**
|
||||
* A mock implementation of {@code org.ray.spi.ObjectStoreLink}, which use Map to store data.
|
||||
*/
|
||||
public class MockObjectStore implements ObjectStoreLink {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(MockObjectStore.class);
|
||||
|
||||
private static final int GET_CHECK_INTERVAL_MS = 100;
|
||||
|
||||
private final RayDevRuntime runtime;
|
||||
private final Map<ObjectId, byte[]> data = new ConcurrentHashMap<>();
|
||||
private final Map<ObjectId, byte[]> metadata = new ConcurrentHashMap<>();
|
||||
private final List<Consumer<ObjectId>> objectPutCallbacks;
|
||||
|
||||
public MockObjectStore(RayDevRuntime runtime) {
|
||||
this.runtime = runtime;
|
||||
this.objectPutCallbacks = new ArrayList<>();
|
||||
}
|
||||
|
||||
public void addObjectPutCallback(Consumer<ObjectId> callback) {
|
||||
this.objectPutCallbacks.add(callback);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void put(byte[] objectId, byte[] value, byte[] metadataValue) {
|
||||
if (objectId == null || objectId.length == 0 || value == null) {
|
||||
LOGGER
|
||||
.error("{} cannot put null: {}, {}", logPrefix(), objectId, Arrays.toString(value));
|
||||
System.exit(-1);
|
||||
}
|
||||
ObjectId id = new ObjectId(objectId);
|
||||
data.put(id, value);
|
||||
if (metadataValue != null) {
|
||||
metadata.put(id, metadataValue);
|
||||
}
|
||||
for (Consumer<ObjectId> callback : objectPutCallbacks) {
|
||||
callback.accept(id);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] get(byte[] objectId, int timeoutMs, boolean isMetadata) {
|
||||
return get(new byte[][] {objectId}, timeoutMs, isMetadata).get(0);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<byte[]> get(byte[][] objectIds, int timeoutMs, boolean isMetadata) {
|
||||
return get(objectIds, timeoutMs)
|
||||
.stream()
|
||||
.map(data -> isMetadata ? data.metadata : data.data)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ObjectStoreData> get(byte[][] objectIds, int timeoutMs) {
|
||||
int ready = 0;
|
||||
int remainingTime = timeoutMs;
|
||||
boolean firstCheck = true;
|
||||
while (ready < objectIds.length && remainingTime > 0) {
|
||||
if (!firstCheck) {
|
||||
int sleepTime = Math.min(remainingTime, GET_CHECK_INTERVAL_MS);
|
||||
try {
|
||||
Thread.sleep(sleepTime);
|
||||
} catch (InterruptedException e) {
|
||||
LOGGER.warn("Got InterruptedException while sleeping.");
|
||||
}
|
||||
remainingTime -= sleepTime;
|
||||
}
|
||||
ready = 0;
|
||||
for (byte[] id : objectIds) {
|
||||
if (data.containsKey(new ObjectId(id))) {
|
||||
ready += 1;
|
||||
}
|
||||
}
|
||||
firstCheck = false;
|
||||
}
|
||||
ArrayList<ObjectStoreData> rets = new ArrayList<>();
|
||||
for (byte[] objId : objectIds) {
|
||||
ObjectId objectId = new ObjectId(objId);
|
||||
rets.add(new ObjectStoreData(metadata.get(objectId), data.get(objectId)));
|
||||
}
|
||||
return rets;
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] hash(byte[] objectId) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long evict(long numBytes) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void release(byte[] objectId) {
|
||||
return;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void delete(byte[] objectId) {
|
||||
return;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean contains(byte[] objectId) {
|
||||
return data.containsKey(new ObjectId(objectId));
|
||||
}
|
||||
|
||||
private String logPrefix() {
|
||||
return runtime.getWorkerContext().getCurrentTaskId() + "-" + getUserTrace() + " -> ";
|
||||
}
|
||||
|
||||
private String getUserTrace() {
|
||||
StackTraceElement[] stes = Thread.currentThread().getStackTrace();
|
||||
int k = 1;
|
||||
while (stes[k].getClassName().startsWith("org.ray")
|
||||
&& !stes[k].getClassName().contains("test")) {
|
||||
k++;
|
||||
}
|
||||
return stes[k].getFileName() + ":" + stes[k].getLineNumber();
|
||||
}
|
||||
|
||||
public boolean isObjectReady(ObjectId id) {
|
||||
return data.containsKey(id);
|
||||
}
|
||||
|
||||
public void free(ObjectId id) {
|
||||
data.remove(id);
|
||||
metadata.remove(id);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,13 @@
|
|||
package org.ray.runtime.objectstore;
|
||||
|
||||
public class NativeRayObject {
|
||||
|
||||
public byte[] data;
|
||||
public byte[] metadata;
|
||||
|
||||
public NativeRayObject(byte[] data, byte[] metadata) {
|
||||
this.data = data;
|
||||
this.metadata = metadata;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,54 @@
|
|||
package org.ray.runtime.objectstore;
|
||||
|
||||
import java.util.List;
|
||||
import org.ray.api.id.ObjectId;
|
||||
|
||||
/**
|
||||
* The interface that contains all worker methods that are related to object store.
|
||||
*/
|
||||
public interface ObjectInterface {
|
||||
|
||||
/**
|
||||
* Put an object into object store.
|
||||
*
|
||||
* @param obj The ray object.
|
||||
* @return Generated ID of the object.
|
||||
*/
|
||||
ObjectId put(NativeRayObject obj);
|
||||
|
||||
/**
|
||||
* Put an object with specified ID into object store.
|
||||
*
|
||||
* @param obj The ray object.
|
||||
* @param objectId Object ID specified by user.
|
||||
*/
|
||||
void put(NativeRayObject obj, ObjectId objectId);
|
||||
|
||||
/**
|
||||
* Get a list of objects from the object store.
|
||||
*
|
||||
* @param objectIds IDs of the objects to get.
|
||||
* @param timeoutMs Timeout in milliseconds, wait infinitely if it's negative.
|
||||
* @return Result list of objects data.
|
||||
*/
|
||||
List<NativeRayObject> get(List<ObjectId> objectIds, long timeoutMs);
|
||||
|
||||
/**
|
||||
* Wait for a list of objects to appear in the object store.
|
||||
*
|
||||
* @param objectIds IDs of the objects to wait for.
|
||||
* @param numObjects Number of objects that should appear.
|
||||
* @param timeoutMs Timeout in milliseconds, wait infinitely if it's negative.
|
||||
* @return A bitset that indicates each object has appeared or not.
|
||||
*/
|
||||
List<Boolean> wait(List<ObjectId> objectIds, int numObjects, long timeoutMs);
|
||||
|
||||
/**
|
||||
* Delete a list of objects from the object store.
|
||||
*
|
||||
* @param objectIds IDs of the objects to delete.
|
||||
* @param localOnly Whether only delete the objects in local node, or all nodes in the cluster.
|
||||
* @param deleteCreatingTasks Whether also delete the tasks that created these objects.
|
||||
*/
|
||||
void delete(List<ObjectId> objectIds, boolean localOnly, boolean deleteCreatingTasks);
|
||||
}
|
|
@ -0,0 +1,91 @@
|
|||
package org.ray.runtime.objectstore;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import org.ray.api.exception.RayException;
|
||||
import org.ray.api.id.BaseId;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
import org.ray.runtime.WorkerContext;
|
||||
import org.ray.runtime.raylet.RayletClient;
|
||||
import org.ray.runtime.raylet.RayletClientImpl;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/**
|
||||
* This is a wrapper class for core worker object interface.
|
||||
*/
|
||||
public class ObjectInterfaceImpl implements ObjectInterface {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(AbstractRayRuntime.class);
|
||||
|
||||
/**
|
||||
* The native pointer of core worker object interface.
|
||||
*/
|
||||
private final long nativeObjectInterfacePointer;
|
||||
|
||||
public ObjectInterfaceImpl(WorkerContext workerContext, RayletClient rayletClient,
|
||||
String storeSocketName) {
|
||||
this.nativeObjectInterfacePointer =
|
||||
nativeCreateObjectInterface(workerContext.getNativeWorkerContext(),
|
||||
((RayletClientImpl) rayletClient).getClient(), storeSocketName);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ObjectId put(NativeRayObject obj) {
|
||||
return new ObjectId(nativePut(nativeObjectInterfacePointer, obj));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void put(NativeRayObject obj, ObjectId objectId) {
|
||||
try {
|
||||
nativePut(nativeObjectInterfacePointer, objectId.getBytes(), obj);
|
||||
} catch (RayException e) {
|
||||
LOGGER.warn(e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<NativeRayObject> get(List<ObjectId> objectIds, long timeoutMs) {
|
||||
return nativeGet(nativeObjectInterfacePointer, toBinaryList(objectIds), timeoutMs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Boolean> wait(List<ObjectId> objectIds, int numObjects, long timeoutMs) {
|
||||
return nativeWait(nativeObjectInterfacePointer, toBinaryList(objectIds), numObjects, timeoutMs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void delete(List<ObjectId> objectIds, boolean localOnly, boolean deleteCreatingTasks) {
|
||||
nativeDelete(nativeObjectInterfacePointer, toBinaryList(objectIds), localOnly, deleteCreatingTasks);
|
||||
}
|
||||
|
||||
public void destroy() {
|
||||
nativeDestroy(nativeObjectInterfacePointer);
|
||||
}
|
||||
|
||||
private static List<byte[]> toBinaryList(List<ObjectId> ids) {
|
||||
return ids.stream().map(BaseId::getBytes).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private static native long nativeCreateObjectInterface(long nativeObjectInterface,
|
||||
long nativeRayletClient,
|
||||
String storeSocketName);
|
||||
|
||||
private static native byte[] nativePut(long nativeObjectInterface, NativeRayObject obj);
|
||||
|
||||
private static native void nativePut(long nativeObjectInterface, byte[] objectId,
|
||||
NativeRayObject obj);
|
||||
|
||||
private static native List<NativeRayObject> nativeGet(long nativeObjectInterface,
|
||||
List<byte[]> ids,
|
||||
long timeoutMs);
|
||||
|
||||
private static native List<Boolean> nativeWait(long nativeObjectInterface, List<byte[]> objectIds,
|
||||
int numObjects, long timeoutMs);
|
||||
|
||||
private static native void nativeDelete(long nativeObjectInterface, List<byte[]> objectIds,
|
||||
boolean localOnly, boolean deleteCreatingTasks);
|
||||
|
||||
private static native void nativeDestroy(long nativeObjectInterface);
|
||||
}
|
|
@ -4,20 +4,14 @@ import com.google.common.collect.ImmutableList;
|
|||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import org.apache.arrow.plasma.ObjectStoreLink;
|
||||
import org.apache.arrow.plasma.ObjectStoreLink.ObjectStoreData;
|
||||
import org.apache.arrow.plasma.PlasmaClient;
|
||||
import org.apache.arrow.plasma.exceptions.DuplicateObjectException;
|
||||
import org.ray.api.exception.RayActorException;
|
||||
import org.ray.api.exception.RayException;
|
||||
import org.ray.api.exception.RayTaskException;
|
||||
import org.ray.api.exception.RayWorkerException;
|
||||
import org.ray.api.exception.UnreconstructableException;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
import org.ray.runtime.RayDevRuntime;
|
||||
import org.ray.runtime.config.RunMode;
|
||||
import org.ray.runtime.WorkerContext;
|
||||
import org.ray.runtime.generated.Gcs.ErrorType;
|
||||
import org.ray.runtime.util.IdUtil;
|
||||
import org.ray.runtime.util.Serializer;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
@ -36,21 +30,18 @@ public class ObjectStoreProxy {
|
|||
private static final byte[] UNRECONSTRUCTABLE_EXCEPTION_META = String
|
||||
.valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE.getNumber()).getBytes();
|
||||
|
||||
private static final byte[] TASK_EXECUTION_EXCEPTION_META = String
|
||||
.valueOf(ErrorType.TASK_EXECUTION_EXCEPTION.getNumber()).getBytes();
|
||||
|
||||
private static final byte[] RAW_TYPE_META = "RAW".getBytes();
|
||||
|
||||
private final AbstractRayRuntime runtime;
|
||||
private final WorkerContext workerContext;
|
||||
|
||||
private static ThreadLocal<ObjectStoreLink> objectStore;
|
||||
private final ObjectInterface objectInterface;
|
||||
|
||||
public ObjectStoreProxy(AbstractRayRuntime runtime, String storeSocketName) {
|
||||
this.runtime = runtime;
|
||||
objectStore = ThreadLocal.withInitial(() -> {
|
||||
if (runtime.getRayConfig().runMode == RunMode.CLUSTER) {
|
||||
return new PlasmaClient(storeSocketName, "", 0);
|
||||
} else {
|
||||
return ((RayDevRuntime) runtime).getObjectStore();
|
||||
}
|
||||
});
|
||||
public ObjectStoreProxy(WorkerContext workerContext, ObjectInterface objectInterface) {
|
||||
this.workerContext = workerContext;
|
||||
this.objectInterface = objectInterface;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -75,46 +66,44 @@ public class ObjectStoreProxy {
|
|||
* @return A list of GetResult objects.
|
||||
*/
|
||||
public <T> List<GetResult<T>> get(List<ObjectId> ids, int timeoutMs) {
|
||||
byte[][] binaryIds = IdUtil.getIdBytes(ids);
|
||||
List<ObjectStoreData> dataAndMetaList = objectStore.get().get(binaryIds, timeoutMs);
|
||||
List<NativeRayObject> dataAndMetaList = objectInterface.get(ids, timeoutMs);
|
||||
|
||||
List<GetResult<T>> results = new ArrayList<>();
|
||||
for (int i = 0; i < dataAndMetaList.size(); i++) {
|
||||
byte[] meta = dataAndMetaList.get(i).metadata;
|
||||
byte[] data = dataAndMetaList.get(i).data;
|
||||
|
||||
NativeRayObject dataAndMeta = dataAndMetaList.get(i);
|
||||
GetResult<T> result;
|
||||
if (meta != null) {
|
||||
// If meta is not null, deserialize the object from meta.
|
||||
result = deserializeFromMeta(meta, data, ids.get(i));
|
||||
} else if (data != null) {
|
||||
// If data is not null, deserialize the Java object.
|
||||
Object object = Serializer.decode(data, runtime.getWorkerContext().getCurrentClassLoader());
|
||||
if (object instanceof RayException) {
|
||||
// If the object is a `RayException`, it means that an error occurred during task
|
||||
// execution.
|
||||
result = new GetResult<>(true, null, (RayException) object);
|
||||
if (dataAndMeta != null) {
|
||||
byte[] meta = dataAndMeta.metadata;
|
||||
byte[] data = dataAndMeta.data;
|
||||
if (meta != null && meta.length > 0) {
|
||||
// If meta is not null, deserialize the object from meta.
|
||||
result = deserializeFromMeta(meta, data,
|
||||
workerContext.getCurrentClassLoader(), ids.get(i));
|
||||
} else {
|
||||
// Otherwise, the object is valid.
|
||||
result = new GetResult<>(true, (T) object, null);
|
||||
// If data is not null, deserialize the Java object.
|
||||
Object object = Serializer.decode(data, workerContext.getCurrentClassLoader());
|
||||
if (object instanceof RayException) {
|
||||
// If the object is a `RayException`, it means that an error occurred during task
|
||||
// execution.
|
||||
result = new GetResult<>(true, null, (RayException) object);
|
||||
} else {
|
||||
// Otherwise, the object is valid.
|
||||
result = new GetResult<>(true, (T) object, null);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// If both meta and data are null, the object doesn't exist in object store.
|
||||
result = new GetResult<>(false, null, null);
|
||||
}
|
||||
|
||||
if (meta != null || data != null) {
|
||||
// Release the object from object store..
|
||||
objectStore.get().release(binaryIds[i]);
|
||||
}
|
||||
|
||||
results.add(result);
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private <T> GetResult<T> deserializeFromMeta(byte[] meta, byte[] data, ObjectId objectId) {
|
||||
private <T> GetResult<T> deserializeFromMeta(byte[] meta, byte[] data,
|
||||
ClassLoader classLoader, ObjectId objectId) {
|
||||
if (Arrays.equals(meta, RAW_TYPE_META)) {
|
||||
return (GetResult<T>) new GetResult<>(true, data, null);
|
||||
} else if (Arrays.equals(meta, WORKER_EXCEPTION_META)) {
|
||||
|
@ -123,6 +112,8 @@ public class ObjectStoreProxy {
|
|||
return new GetResult<>(true, null, RayActorException.INSTANCE);
|
||||
} else if (Arrays.equals(meta, UNRECONSTRUCTABLE_EXCEPTION_META)) {
|
||||
return new GetResult<>(true, null, new UnreconstructableException(objectId));
|
||||
} else if (Arrays.equals(meta, TASK_EXECUTION_EXCEPTION_META)) {
|
||||
return new GetResult<>(true, null, Serializer.decode(data, classLoader));
|
||||
}
|
||||
throw new IllegalArgumentException("Unrecognized metadata " + Arrays.toString(meta));
|
||||
}
|
||||
|
@ -134,16 +125,14 @@ public class ObjectStoreProxy {
|
|||
* @param object The object to put.
|
||||
*/
|
||||
public void put(ObjectId id, Object object) {
|
||||
try {
|
||||
if (object instanceof byte[]) {
|
||||
// If the object is a byte array, skip serializing it and use a special metadata to
|
||||
// indicate it's raw binary. So that this object can also be read by Python.
|
||||
objectStore.get().put(id.getBytes(), (byte[]) object, RAW_TYPE_META);
|
||||
} else {
|
||||
objectStore.get().put(id.getBytes(), Serializer.encode(object), null);
|
||||
}
|
||||
} catch (DuplicateObjectException e) {
|
||||
LOGGER.warn(e.getMessage());
|
||||
if (object instanceof byte[]) {
|
||||
// If the object is a byte array, skip serializing it and use a special metadata to
|
||||
// indicate it's raw binary. So that this object can also be read by Python.
|
||||
objectInterface.put(new NativeRayObject((byte[]) object, RAW_TYPE_META), id);
|
||||
} else if (object instanceof RayTaskException) {
|
||||
objectInterface.put(new NativeRayObject(Serializer.encode(object), TASK_EXECUTION_EXCEPTION_META), id);
|
||||
} else {
|
||||
objectInterface.put(new NativeRayObject(Serializer.encode(object), null), id);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -154,11 +143,7 @@ public class ObjectStoreProxy {
|
|||
* @param serializedObject The serialized object to put.
|
||||
*/
|
||||
public void putSerialized(ObjectId id, byte[] serializedObject) {
|
||||
try {
|
||||
objectStore.get().put(id.getBytes(), serializedObject, null);
|
||||
} catch (DuplicateObjectException e) {
|
||||
LOGGER.warn(e.getMessage());
|
||||
}
|
||||
objectInterface.put(new NativeRayObject(serializedObject, null), id);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -14,6 +14,7 @@ import java.util.concurrent.ConcurrentLinkedDeque;
|
|||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
|
||||
import java.util.stream.Collectors;
|
||||
import org.apache.commons.lang3.NotImplementedException;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.WaitResult;
|
||||
|
@ -23,7 +24,8 @@ import org.ray.api.id.TaskId;
|
|||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.RayDevRuntime;
|
||||
import org.ray.runtime.Worker;
|
||||
import org.ray.runtime.objectstore.MockObjectStore;
|
||||
import org.ray.runtime.objectstore.MockObjectInterface;
|
||||
import org.ray.runtime.objectstore.NativeRayObject;
|
||||
import org.ray.runtime.task.FunctionArg;
|
||||
import org.ray.runtime.task.TaskSpec;
|
||||
import org.slf4j.Logger;
|
||||
|
@ -37,7 +39,7 @@ public class MockRayletClient implements RayletClient {
|
|||
private static final Logger LOGGER = LoggerFactory.getLogger(MockRayletClient.class);
|
||||
|
||||
private final Map<ObjectId, Set<TaskSpec>> waitingTasks = new ConcurrentHashMap<>();
|
||||
private final MockObjectStore store;
|
||||
private final MockObjectInterface objectInterface;
|
||||
private final RayDevRuntime runtime;
|
||||
private final ExecutorService exec;
|
||||
private final Deque<Worker> idleWorkers;
|
||||
|
@ -46,8 +48,8 @@ public class MockRayletClient implements RayletClient {
|
|||
|
||||
public MockRayletClient(RayDevRuntime runtime, int numberThreads) {
|
||||
this.runtime = runtime;
|
||||
this.store = runtime.getObjectStore();
|
||||
store.addObjectPutCallback(this::onObjectPut);
|
||||
this.objectInterface = runtime.getObjectInterface();
|
||||
objectInterface.addObjectPutCallback(this::onObjectPut);
|
||||
// The thread pool that executes tasks in parallel.
|
||||
exec = Executors.newFixedThreadPool(numberThreads);
|
||||
idleWorkers = new ConcurrentLinkedDeque<>();
|
||||
|
@ -113,8 +115,8 @@ public class MockRayletClient implements RayletClient {
|
|||
// can be executed.
|
||||
if (task.isActorCreationTask() || task.isActorTask()) {
|
||||
ObjectId[] returnIds = task.returnIds;
|
||||
store.put(returnIds[returnIds.length - 1].getBytes(),
|
||||
new byte[]{}, new byte[]{});
|
||||
objectInterface.put(new NativeRayObject(new byte[] {}, new byte[] {}),
|
||||
returnIds[returnIds.length - 1]);
|
||||
}
|
||||
} finally {
|
||||
returnWorker(worker);
|
||||
|
@ -133,13 +135,13 @@ public class MockRayletClient implements RayletClient {
|
|||
// Check whether task arguments are ready.
|
||||
for (FunctionArg arg : spec.args) {
|
||||
if (arg.id != null) {
|
||||
if (!store.isObjectReady(arg.id)) {
|
||||
if (!objectInterface.isObjectReady(arg.id)) {
|
||||
unreadyObjects.add(arg.id);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (spec.isActorTask()) {
|
||||
if (!store.isObjectReady(spec.previousActorTaskDummyObjectId)) {
|
||||
if (!objectInterface.isObjectReady(spec.previousActorTaskDummyObjectId)) {
|
||||
unreadyObjects.add(spec.previousActorTaskDummyObjectId);
|
||||
}
|
||||
}
|
||||
|
@ -154,7 +156,7 @@ public class MockRayletClient implements RayletClient {
|
|||
|
||||
@Override
|
||||
public void fetchOrReconstruct(List<ObjectId> objectIds, boolean fetchOnly,
|
||||
TaskId currentTaskId) {
|
||||
TaskId currentTaskId) {
|
||||
|
||||
}
|
||||
|
||||
|
@ -170,20 +172,17 @@ public class MockRayletClient implements RayletClient {
|
|||
|
||||
@Override
|
||||
public <T> WaitResult<T> wait(List<RayObject<T>> waitFor, int numReturns, int
|
||||
timeoutMs, TaskId currentTaskId) {
|
||||
timeoutMs, TaskId currentTaskId) {
|
||||
if (waitFor == null || waitFor.isEmpty()) {
|
||||
return new WaitResult<>(ImmutableList.of(), ImmutableList.of());
|
||||
}
|
||||
|
||||
byte[][] ids = new byte[waitFor.size()][];
|
||||
for (int i = 0; i < waitFor.size(); i++) {
|
||||
ids[i] = waitFor.get(i).getId().getBytes();
|
||||
}
|
||||
List<ObjectId> ids = waitFor.stream().map(RayObject::getId).collect(Collectors.toList());
|
||||
List<RayObject<T>> readyList = new ArrayList<>();
|
||||
List<RayObject<T>> unreadyList = new ArrayList<>();
|
||||
List<byte[]> result = store.get(ids, timeoutMs, false);
|
||||
List<Boolean> result = objectInterface.wait(ids, ids.size(), timeoutMs);
|
||||
for (int i = 0; i < waitFor.size(); i++) {
|
||||
if (result.get(i) != null) {
|
||||
if (result.get(i)) {
|
||||
readyList.add(waitFor.get(i));
|
||||
} else {
|
||||
unreadyList.add(waitFor.get(i));
|
||||
|
@ -195,9 +194,7 @@ public class MockRayletClient implements RayletClient {
|
|||
@Override
|
||||
public void freePlasmaObjects(List<ObjectId> objectIds, boolean localOnly,
|
||||
boolean deleteCreatingTasks) {
|
||||
for (ObjectId id : objectIds) {
|
||||
store.free(id);
|
||||
}
|
||||
objectInterface.delete(objectIds, localOnly, deleteCreatingTasks);
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -4,8 +4,6 @@ import com.google.common.base.Preconditions;
|
|||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.protobuf.ByteString;
|
||||
import com.google.protobuf.InvalidProtocolBufferException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.nio.charset.Charset;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
|
@ -40,11 +38,15 @@ public class RayletClientImpl implements RayletClient {
|
|||
|
||||
// TODO(qwang): JobId parameter can be removed once we embed jobId in driverId.
|
||||
public RayletClientImpl(String schedulerSockName, UniqueId clientId,
|
||||
boolean isWorker, JobId jobId) {
|
||||
boolean isWorker, JobId jobId) {
|
||||
client = nativeInit(schedulerSockName, clientId.getBytes(),
|
||||
isWorker, jobId.getBytes());
|
||||
}
|
||||
|
||||
public long getClient() {
|
||||
return client;
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> WaitResult<T> wait(List<RayObject<T>> waitFor, int numReturns, int
|
||||
timeoutMs, TaskId currentTaskId) {
|
||||
|
@ -133,7 +135,7 @@ public class RayletClientImpl implements RayletClient {
|
|||
/**
|
||||
* Parse `TaskSpec` protobuf bytes.
|
||||
*/
|
||||
private static TaskSpec parseTaskSpecFromProtobuf(byte[] bytes) {
|
||||
public static TaskSpec parseTaskSpecFromProtobuf(byte[] bytes) {
|
||||
Common.TaskSpec taskSpec;
|
||||
try {
|
||||
taskSpec = Common.TaskSpec.parseFrom(bytes);
|
||||
|
@ -214,7 +216,7 @@ public class RayletClientImpl implements RayletClient {
|
|||
/**
|
||||
* Convert a `TaskSpec` to protobuf-serialized bytes.
|
||||
*/
|
||||
private static byte[] convertTaskSpecToProtobuf(TaskSpec task) {
|
||||
public static byte[] convertTaskSpecToProtobuf(TaskSpec task) {
|
||||
// Set common fields.
|
||||
Common.TaskSpec.Builder builder = Common.TaskSpec.newBuilder()
|
||||
.setJobId(ByteString.copyFrom(task.jobId.getBytes()))
|
||||
|
|
|
@ -154,18 +154,6 @@ public class IdUtil {
|
|||
}
|
||||
|
||||
|
||||
/**
|
||||
* Compute the driver id from the given job.
|
||||
*/
|
||||
public static UniqueId computeDriverId(JobId jobId) {
|
||||
byte[] bytes = new byte[UniqueId.LENGTH];
|
||||
System.arraycopy(jobId.getBytes(), 0, bytes, 0, jobId.size());
|
||||
Arrays.fill(bytes, jobId.size(), UniqueId.LENGTH, (byte)0xFF);
|
||||
ByteBuffer wbb = ByteBuffer.wrap(bytes);
|
||||
wbb.order(ByteOrder.LITTLE_ENDIAN);
|
||||
return new UniqueId(bytes);
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute the murmur hash code of this ID.
|
||||
*/
|
||||
|
|
|
@ -1,12 +1,18 @@
|
|||
package org.ray.api.test;
|
||||
|
||||
import java.time.Duration;
|
||||
import java.time.Instant;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.api.RayActor;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.TestUtils;
|
||||
import org.ray.api.exception.RayActorException;
|
||||
import org.ray.api.exception.RayException;
|
||||
import org.ray.api.exception.RayTaskException;
|
||||
import org.ray.api.exception.RayWorkerException;
|
||||
import org.ray.api.function.RayFunc0;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
|
@ -23,6 +29,15 @@ public class FailureTest extends BaseTest {
|
|||
return 0;
|
||||
}
|
||||
|
||||
public static int slowFunc() {
|
||||
try {
|
||||
Thread.sleep(10000);
|
||||
} catch (InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
public static class BadActor {
|
||||
|
||||
public BadActor(boolean failOnCreation) {
|
||||
|
@ -106,5 +121,26 @@ public class FailureTest extends BaseTest {
|
|||
// RayActorException.
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGetThrowsQuicklyWhenFoundException() {
|
||||
TestUtils.skipTestUnderSingleProcess();
|
||||
List<RayFunc0<Integer>> badFunctions = Arrays.asList(FailureTest::badFunc,
|
||||
FailureTest::badFunc2);
|
||||
for (RayFunc0<Integer> badFunc : badFunctions) {
|
||||
RayObject<Integer> obj1 = Ray.call(badFunc);
|
||||
RayObject<Integer> obj2 = Ray.call(FailureTest::slowFunc);
|
||||
Instant start = Instant.now();
|
||||
try {
|
||||
Ray.get(Arrays.asList(obj1.getId(), obj2.getId()));
|
||||
Assert.fail("Should throw RayException.");
|
||||
} catch (RayException e) {
|
||||
Instant end = Instant.now();
|
||||
long duration = Duration.between(start, end).toMillis();
|
||||
Assert.assertTrue(duration < 5000, "Should fail quickly. " +
|
||||
"Actual execution time: " + duration + " ms.");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,12 +1,10 @@
|
|||
package org.ray.api.test;
|
||||
|
||||
import org.apache.arrow.plasma.PlasmaClient;
|
||||
import org.apache.arrow.plasma.exceptions.DuplicateObjectException;
|
||||
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.api.TestUtils;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
import org.ray.runtime.objectstore.ObjectStoreProxy;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
|
@ -15,15 +13,13 @@ public class PlasmaStoreTest extends BaseTest {
|
|||
@Test
|
||||
public void testPutWithDuplicateId() {
|
||||
TestUtils.skipTestUnderSingleProcess();
|
||||
UniqueId objectId = UniqueId.randomId();
|
||||
ObjectId objectId = ObjectId.randomId();
|
||||
AbstractRayRuntime runtime = (AbstractRayRuntime) Ray.internal();
|
||||
PlasmaClient store = new PlasmaClient(runtime.getRayConfig().objectStoreSocketName, "", 0);
|
||||
store.put(objectId.getBytes(), new byte[]{}, new byte[]{});
|
||||
try {
|
||||
store.put(objectId.getBytes(), new byte[]{}, new byte[]{});
|
||||
Assert.fail("This line shouldn't be reached.");
|
||||
} catch (DuplicateObjectException e) {
|
||||
// Putting 2 objects with duplicate ID should throw DuplicateObjectException.
|
||||
}
|
||||
ObjectStoreProxy objectInterface = runtime.getObjectStoreProxy();
|
||||
objectInterface.put(objectId, 1);
|
||||
Assert.assertEquals(objectInterface.<Integer>get(objectId, -1).object, (Integer) 1);
|
||||
objectInterface.put(objectId, 2);
|
||||
// Putting 2 objects with duplicate ID should fail but ignored.
|
||||
Assert.assertEquals(objectInterface.<Integer>get(objectId, -1).object, (Integer) 1);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
package org.ray.api.test;
|
||||
|
||||
import org.ray.runtime.config.RayConfig;
|
||||
import org.ray.runtime.config.WorkerMode;
|
||||
import org.ray.runtime.generated.Common.WorkerType;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
|
@ -12,7 +12,7 @@ public class RayConfigTest {
|
|||
try {
|
||||
System.setProperty("ray.job.resource-path", "path/to/ray/job/resource/path");
|
||||
RayConfig rayConfig = RayConfig.create();
|
||||
Assert.assertEquals(WorkerMode.DRIVER, rayConfig.workerMode);
|
||||
Assert.assertEquals(WorkerType.DRIVER, rayConfig.workerMode);
|
||||
Assert.assertEquals("path/to/ray/job/resource/path", rayConfig.jobResourcePath);
|
||||
} finally {
|
||||
// Unset system properties.
|
||||
|
|
|
@ -151,3 +151,10 @@ RAY_CONFIG(uint32_t, num_actor_checkpoints_to_keep, 20)
|
|||
|
||||
/// Maximum number of ids in one batch to send to GCS to delete keys.
|
||||
RAY_CONFIG(uint32_t, maximum_gcs_deletion_batch_size, 1000)
|
||||
|
||||
/// When getting objects from object store, print a warning every this number of attempts.
|
||||
RAY_CONFIG(uint32_t, object_store_get_warn_per_num_attempts, 50)
|
||||
|
||||
/// When getting objects from object store, max number of ids to print in the warning
|
||||
/// message.
|
||||
RAY_CONFIG(uint32_t, object_store_get_max_ids_to_print_in_warning, 20)
|
||||
|
|
|
@ -9,9 +9,7 @@
|
|||
#include "ray/raylet/raylet_client.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
/// Type of this worker.
|
||||
enum class WorkerType { WORKER, DRIVER };
|
||||
using WorkerType = rpc::WorkerType;
|
||||
|
||||
/// Information about a remote function.
|
||||
struct RayFunction {
|
||||
|
|
|
@ -6,69 +6,81 @@ namespace ray {
|
|||
/// per-thread context for core worker.
|
||||
struct WorkerThreadContext {
|
||||
WorkerThreadContext()
|
||||
: current_task_id(TaskID::FromRandom()), task_index(0), put_index(0) {}
|
||||
: current_task_id_(TaskID::FromRandom()), task_index_(0), put_index_(0) {}
|
||||
|
||||
int GetNextTaskIndex() { return ++task_index; }
|
||||
int GetNextTaskIndex() { return ++task_index_; }
|
||||
|
||||
int GetNextPutIndex() { return ++put_index; }
|
||||
int GetNextPutIndex() { return ++put_index_; }
|
||||
|
||||
const TaskID &GetCurrentTaskID() const { return current_task_id; }
|
||||
const TaskID &GetCurrentTaskID() const { return current_task_id_; }
|
||||
|
||||
void SetCurrentTask(const TaskID &task_id) {
|
||||
current_task_id = task_id;
|
||||
task_index = 0;
|
||||
put_index = 0;
|
||||
std::shared_ptr<const TaskSpecification> GetCurrentTask() const {
|
||||
return current_task_;
|
||||
}
|
||||
|
||||
void SetCurrentTaskId(const TaskID &task_id) {
|
||||
current_task_id_ = task_id;
|
||||
task_index_ = 0;
|
||||
put_index_ = 0;
|
||||
}
|
||||
|
||||
void SetCurrentTask(const TaskSpecification &task_spec) {
|
||||
SetCurrentTask(task_spec.TaskId());
|
||||
SetCurrentTaskId(task_spec.TaskId());
|
||||
current_task_ = std::make_shared<const TaskSpecification>(task_spec);
|
||||
}
|
||||
|
||||
private:
|
||||
/// The task ID for current task.
|
||||
TaskID current_task_id;
|
||||
TaskID current_task_id_;
|
||||
|
||||
/// The current task.
|
||||
std::shared_ptr<const TaskSpecification> current_task_;
|
||||
|
||||
/// Number of tasks that have been submitted from current task.
|
||||
int task_index;
|
||||
int task_index_;
|
||||
|
||||
/// Number of objects that have been put from current task.
|
||||
int put_index;
|
||||
int put_index_;
|
||||
};
|
||||
|
||||
thread_local std::unique_ptr<WorkerThreadContext> WorkerContext::thread_context_ =
|
||||
nullptr;
|
||||
|
||||
WorkerContext::WorkerContext(WorkerType worker_type, const JobID &job_id)
|
||||
: worker_type(worker_type),
|
||||
worker_id(worker_type == WorkerType::DRIVER ? ComputeDriverIdFromJob(job_id)
|
||||
: WorkerID::FromRandom()),
|
||||
current_job_id(worker_type == WorkerType::DRIVER ? job_id : JobID::Nil()) {
|
||||
: worker_type_(worker_type),
|
||||
worker_id_(worker_type_ == WorkerType::DRIVER ? ComputeDriverIdFromJob(job_id)
|
||||
: WorkerID::FromRandom()),
|
||||
current_job_id_(worker_type_ == WorkerType::DRIVER ? job_id : JobID::Nil()) {
|
||||
// For worker main thread which initializes the WorkerContext,
|
||||
// set task_id according to whether current worker is a driver.
|
||||
// (For other threads it's set to random ID via GetThreadContext).
|
||||
GetThreadContext().SetCurrentTask(
|
||||
(worker_type == WorkerType::DRIVER) ? TaskID::FromRandom() : TaskID::Nil());
|
||||
GetThreadContext().SetCurrentTaskId(
|
||||
(worker_type_ == WorkerType::DRIVER) ? TaskID::FromRandom() : TaskID::Nil());
|
||||
}
|
||||
|
||||
const WorkerType WorkerContext::GetWorkerType() const { return worker_type; }
|
||||
const WorkerType WorkerContext::GetWorkerType() const { return worker_type_; }
|
||||
|
||||
const WorkerID &WorkerContext::GetWorkerID() const { return worker_id; }
|
||||
const WorkerID &WorkerContext::GetWorkerID() const { return worker_id_; }
|
||||
|
||||
int WorkerContext::GetNextTaskIndex() { return GetThreadContext().GetNextTaskIndex(); }
|
||||
|
||||
int WorkerContext::GetNextPutIndex() { return GetThreadContext().GetNextPutIndex(); }
|
||||
|
||||
const JobID &WorkerContext::GetCurrentJobID() const { return current_job_id; }
|
||||
const JobID &WorkerContext::GetCurrentJobID() const { return current_job_id_; }
|
||||
|
||||
const TaskID &WorkerContext::GetCurrentTaskID() const {
|
||||
return GetThreadContext().GetCurrentTaskID();
|
||||
}
|
||||
|
||||
void WorkerContext::SetCurrentTask(const TaskSpecification &task_spec) {
|
||||
current_job_id = task_spec.JobId();
|
||||
current_job_id_ = task_spec.JobId();
|
||||
GetThreadContext().SetCurrentTask(task_spec);
|
||||
}
|
||||
|
||||
std::shared_ptr<const TaskSpecification> WorkerContext::GetCurrentTask() const {
|
||||
return GetThreadContext().GetCurrentTask();
|
||||
}
|
||||
|
||||
WorkerThreadContext &WorkerContext::GetThreadContext() {
|
||||
if (thread_context_ == nullptr) {
|
||||
thread_context_ = std::unique_ptr<WorkerThreadContext>(new WorkerThreadContext());
|
||||
|
|
|
@ -22,19 +22,21 @@ class WorkerContext {
|
|||
|
||||
void SetCurrentTask(const TaskSpecification &task_spec);
|
||||
|
||||
std::shared_ptr<const TaskSpecification> GetCurrentTask() const;
|
||||
|
||||
int GetNextTaskIndex();
|
||||
|
||||
int GetNextPutIndex();
|
||||
|
||||
private:
|
||||
/// Type of the worker.
|
||||
const WorkerType worker_type;
|
||||
const WorkerType worker_type_;
|
||||
|
||||
/// ID for this worker.
|
||||
const WorkerID worker_id;
|
||||
const WorkerID worker_id_;
|
||||
|
||||
/// Job ID for this worker.
|
||||
JobID current_job_id;
|
||||
JobID current_job_id_;
|
||||
|
||||
private:
|
||||
static WorkerThreadContext &GetThreadContext();
|
||||
|
|
|
@ -15,7 +15,7 @@ CoreWorker::CoreWorker(
|
|||
task_interface_(worker_context_, raylet_client_),
|
||||
object_interface_(worker_context_, raylet_client_, store_socket) {
|
||||
int rpc_server_port = 0;
|
||||
if (worker_type_ == ray::WorkerType::WORKER) {
|
||||
if (worker_type_ == WorkerType::WORKER) {
|
||||
RAY_CHECK(execution_callback != nullptr);
|
||||
task_execution_interface_ = std::unique_ptr<CoreWorkerTaskExecutionInterface>(
|
||||
new CoreWorkerTaskExecutionInterface(worker_context_, raylet_client_,
|
||||
|
@ -28,8 +28,8 @@ CoreWorker::CoreWorker(
|
|||
// instead of crashing.
|
||||
raylet_client_ = std::unique_ptr<RayletClient>(new RayletClient(
|
||||
raylet_socket_, ClientID::FromBinary(worker_context_.GetWorkerID().Binary()),
|
||||
(worker_type_ == ray::WorkerType::WORKER), worker_context_.GetCurrentJobID(),
|
||||
language_, rpc_server_port));
|
||||
(worker_type_ == WorkerType::WORKER), worker_context_.GetCurrentJobID(), language_,
|
||||
rpc_server_port));
|
||||
}
|
||||
|
||||
} // namespace ray
|
||||
|
|
75
src/ray/core_worker/lib/java/jni_init.cc
Normal file
75
src/ray/core_worker/lib/java/jni_init.cc
Normal file
|
@ -0,0 +1,75 @@
|
|||
#include "ray/core_worker/lib/java/jni_utils.h"
|
||||
|
||||
jclass java_boolean_class;
|
||||
jmethodID java_boolean_init;
|
||||
|
||||
jclass java_list_class;
|
||||
jmethodID java_list_size;
|
||||
jmethodID java_list_get;
|
||||
jmethodID java_list_add;
|
||||
|
||||
jclass java_array_list_class;
|
||||
jmethodID java_array_list_init;
|
||||
jmethodID java_array_list_init_with_capacity;
|
||||
|
||||
jclass java_ray_exception_class;
|
||||
|
||||
jclass java_native_ray_object_class;
|
||||
jmethodID java_native_ray_object_init;
|
||||
jfieldID java_native_ray_object_data;
|
||||
jfieldID java_native_ray_object_metadata;
|
||||
|
||||
jint JNI_VERSION = JNI_VERSION_1_8;
|
||||
|
||||
inline jclass LoadClass(JNIEnv *env, const char *class_name) {
|
||||
jclass tempLocalClassRef = env->FindClass(class_name);
|
||||
jclass ret = (jclass)env->NewGlobalRef(tempLocalClassRef);
|
||||
env->DeleteLocalRef(tempLocalClassRef);
|
||||
return ret;
|
||||
}
|
||||
|
||||
/// Load and cache frequently-used Java classes and methods
|
||||
jint JNI_OnLoad(JavaVM *vm, void *reserved) {
|
||||
JNIEnv *env;
|
||||
if (vm->GetEnv(reinterpret_cast<void **>(&env), JNI_VERSION) != JNI_OK) {
|
||||
return JNI_ERR;
|
||||
}
|
||||
|
||||
java_boolean_class = LoadClass(env, "java/lang/Boolean");
|
||||
java_boolean_init = env->GetMethodID(java_boolean_class, "<init>", "(Z)V");
|
||||
|
||||
java_list_class = LoadClass(env, "java/util/List");
|
||||
java_list_size = env->GetMethodID(java_list_class, "size", "()I");
|
||||
java_list_get = env->GetMethodID(java_list_class, "get", "(I)Ljava/lang/Object;");
|
||||
java_list_add = env->GetMethodID(java_list_class, "add", "(Ljava/lang/Object;)Z");
|
||||
|
||||
java_array_list_class = LoadClass(env, "java/util/ArrayList");
|
||||
java_array_list_init = env->GetMethodID(java_array_list_class, "<init>", "()V");
|
||||
java_array_list_init_with_capacity =
|
||||
env->GetMethodID(java_array_list_class, "<init>", "(I)V");
|
||||
|
||||
java_ray_exception_class = LoadClass(env, "org/ray/api/exception/RayException");
|
||||
|
||||
java_native_ray_object_class =
|
||||
LoadClass(env, "org/ray/runtime/objectstore/NativeRayObject");
|
||||
java_native_ray_object_init =
|
||||
env->GetMethodID(java_native_ray_object_class, "<init>", "([B[B)V");
|
||||
java_native_ray_object_data =
|
||||
env->GetFieldID(java_native_ray_object_class, "data", "[B");
|
||||
java_native_ray_object_metadata =
|
||||
env->GetFieldID(java_native_ray_object_class, "metadata", "[B");
|
||||
|
||||
return JNI_VERSION;
|
||||
}
|
||||
|
||||
/// Unload java classes
|
||||
void JNI_OnUnload(JavaVM *vm, void *reserved) {
|
||||
JNIEnv *env;
|
||||
vm->GetEnv(reinterpret_cast<void **>(&env), JNI_VERSION);
|
||||
|
||||
env->DeleteGlobalRef(java_boolean_class);
|
||||
env->DeleteGlobalRef(java_list_class);
|
||||
env->DeleteGlobalRef(java_array_list_class);
|
||||
env->DeleteGlobalRef(java_ray_exception_class);
|
||||
env->DeleteGlobalRef(java_native_ray_object_class);
|
||||
}
|
180
src/ray/core_worker/lib/java/jni_utils.h
Normal file
180
src/ray/core_worker/lib/java/jni_utils.h
Normal file
|
@ -0,0 +1,180 @@
|
|||
#ifndef RAY_COMMON_JAVA_JNI_HELPER_H
|
||||
#define RAY_COMMON_JAVA_JNI_HELPER_H
|
||||
|
||||
#include <jni.h>
|
||||
#include "ray/common/buffer.h"
|
||||
#include "ray/common/id.h"
|
||||
#include "ray/common/status.h"
|
||||
#include "ray/core_worker/store_provider/store_provider.h"
|
||||
|
||||
/// Boolean class
|
||||
extern jclass java_boolean_class;
|
||||
/// Constructor of Boolean class
|
||||
extern jmethodID java_boolean_init;
|
||||
|
||||
/// List class
|
||||
extern jclass java_list_class;
|
||||
/// size method of List class
|
||||
extern jmethodID java_list_size;
|
||||
/// get method of List class
|
||||
extern jmethodID java_list_get;
|
||||
/// add method of List class
|
||||
extern jmethodID java_list_add;
|
||||
|
||||
/// ArrayList class
|
||||
extern jclass java_array_list_class;
|
||||
/// Constructor of ArrayList class
|
||||
extern jmethodID java_array_list_init;
|
||||
/// Constructor of ArrayList class with single parameter capacity
|
||||
extern jmethodID java_array_list_init_with_capacity;
|
||||
|
||||
/// RayException class
|
||||
extern jclass java_ray_exception_class;
|
||||
|
||||
/// NativeRayObject class
|
||||
extern jclass java_native_ray_object_class;
|
||||
/// Constructor of NativeRayObject class
|
||||
extern jmethodID java_native_ray_object_init;
|
||||
/// data field of NativeRayObject class
|
||||
extern jfieldID java_native_ray_object_data;
|
||||
/// metadata field of NativeRayObject class
|
||||
extern jfieldID java_native_ray_object_metadata;
|
||||
|
||||
/// Throws a Java RayException if the status is not OK.
|
||||
#define THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, ret) \
|
||||
{ \
|
||||
if (!(status).ok()) { \
|
||||
(env)->ThrowNew(java_ray_exception_class, (status).message().c_str()); \
|
||||
return (ret); \
|
||||
} \
|
||||
}
|
||||
|
||||
/// Convert a Java byte array to a C++ UniqueID.
|
||||
template <typename ID>
|
||||
inline ID JavaByteArrayToId(JNIEnv *env, const jbyteArray &bytes) {
|
||||
std::string id_str(ID::Size(), 0);
|
||||
env->GetByteArrayRegion(bytes, 0, ID::Size(),
|
||||
reinterpret_cast<jbyte *>(&id_str.front()));
|
||||
return ID::FromBinary(id_str);
|
||||
}
|
||||
|
||||
/// Convert C++ UniqueID to a Java byte array.
|
||||
template <typename ID>
|
||||
inline jbyteArray IdToJavaByteArray(JNIEnv *env, const ID &id) {
|
||||
jbyteArray array = env->NewByteArray(ID::Size());
|
||||
env->SetByteArrayRegion(array, 0, ID::Size(),
|
||||
reinterpret_cast<const jbyte *>(id.Data()));
|
||||
return array;
|
||||
}
|
||||
|
||||
/// Convert C++ UniqueID to a Java ByteBuffer.
|
||||
template <typename ID>
|
||||
inline jobject IdToJavaByteBuffer(JNIEnv *env, const ID &id) {
|
||||
return env->NewDirectByteBuffer(
|
||||
reinterpret_cast<void *>(const_cast<uint8_t *>(id.Data())), id.Size());
|
||||
}
|
||||
|
||||
/// Convert a Java String to C++ std::string.
|
||||
inline std::string JavaStringToNativeString(JNIEnv *env, jstring jstr) {
|
||||
const char *c_str = env->GetStringUTFChars(jstr, nullptr);
|
||||
std::string result(c_str);
|
||||
env->ReleaseStringUTFChars(static_cast<jstring>(jstr), c_str);
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Convert a Java List to C++ std::vector.
|
||||
template <typename NativeT>
|
||||
inline void JavaListToNativeVector(
|
||||
JNIEnv *env, jobject java_list, std::vector<NativeT> *native_vector,
|
||||
std::function<NativeT(JNIEnv *, jobject)> element_converter) {
|
||||
int size = env->CallIntMethod(java_list, java_list_size);
|
||||
native_vector->clear();
|
||||
for (int i = 0; i < size; i++) {
|
||||
native_vector->emplace_back(
|
||||
element_converter(env, env->CallObjectMethod(java_list, java_list_get, (jint)i)));
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert a C++ std::vector to a Java List.
|
||||
template <typename NativeT>
|
||||
inline jobject NativeVectorToJavaList(
|
||||
JNIEnv *env, const std::vector<NativeT> &native_vector,
|
||||
std::function<jobject(JNIEnv *, const NativeT &)> element_converter) {
|
||||
jobject java_list =
|
||||
env->NewObject(java_array_list_class, java_array_list_init_with_capacity,
|
||||
(jint)native_vector.size());
|
||||
for (const auto &item : native_vector) {
|
||||
env->CallVoidMethod(java_list, java_list_add, element_converter(env, item));
|
||||
}
|
||||
return java_list;
|
||||
}
|
||||
|
||||
/// Convert a C++ ray::Buffer to a Java byte array.
|
||||
inline jbyteArray NativeBufferToJavaByteArray(JNIEnv *env,
|
||||
const std::shared_ptr<ray::Buffer> buffer) {
|
||||
if (!buffer) {
|
||||
return nullptr;
|
||||
}
|
||||
jbyteArray java_byte_array = env->NewByteArray(buffer->Size());
|
||||
if (buffer->Size() > 0) {
|
||||
env->SetByteArrayRegion(java_byte_array, 0, buffer->Size(),
|
||||
reinterpret_cast<const jbyte *>(buffer->Data()));
|
||||
}
|
||||
return java_byte_array;
|
||||
}
|
||||
|
||||
/// A helper method to help access a Java NativeRayObject instance and ensure memory
|
||||
/// safety.
|
||||
///
|
||||
/// \param[in] java_obj The Java NativeRayObject object.
|
||||
/// \param[in] reader The callback function to access a C++ ray::RayObject instance.
|
||||
/// \return The return value of callback function.
|
||||
template <typename ReturnT>
|
||||
inline ReturnT ReadJavaNativeRayObject(
|
||||
JNIEnv *env, const jobject &java_obj,
|
||||
std::function<ReturnT(const std::shared_ptr<ray::RayObject> &)> reader) {
|
||||
if (!java_obj) {
|
||||
return reader(nullptr);
|
||||
}
|
||||
auto java_data = (jbyteArray)env->GetObjectField(java_obj, java_native_ray_object_data);
|
||||
auto java_metadata =
|
||||
(jbyteArray)env->GetObjectField(java_obj, java_native_ray_object_metadata);
|
||||
auto data_size = env->GetArrayLength(java_data);
|
||||
jbyte *data = data_size > 0 ? env->GetByteArrayElements(java_data, nullptr) : nullptr;
|
||||
auto metadata_size = java_metadata ? env->GetArrayLength(java_metadata) : 0;
|
||||
jbyte *metadata =
|
||||
metadata_size > 0 ? env->GetByteArrayElements(java_metadata, nullptr) : nullptr;
|
||||
auto data_buffer = std::make_shared<ray::LocalMemoryBuffer>(
|
||||
reinterpret_cast<uint8_t *>(data), data_size);
|
||||
auto metadata_buffer = java_metadata
|
||||
? std::make_shared<ray::LocalMemoryBuffer>(
|
||||
reinterpret_cast<uint8_t *>(metadata), metadata_size)
|
||||
: nullptr;
|
||||
|
||||
auto native_obj = std::make_shared<ray::RayObject>(data_buffer, metadata_buffer);
|
||||
auto result = reader(native_obj);
|
||||
|
||||
if (data) {
|
||||
env->ReleaseByteArrayElements(java_data, data, JNI_ABORT);
|
||||
}
|
||||
if (metadata) {
|
||||
env->ReleaseByteArrayElements(java_metadata, metadata, JNI_ABORT);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Convert a C++ ray::RayObject to a Java NativeRayObject.
|
||||
inline jobject ToJavaNativeRayObject(JNIEnv *env,
|
||||
const std::shared_ptr<ray::RayObject> &rayObject) {
|
||||
if (!rayObject) {
|
||||
return nullptr;
|
||||
}
|
||||
auto java_data = NativeBufferToJavaByteArray(env, rayObject->GetData());
|
||||
auto java_metadata = NativeBufferToJavaByteArray(env, rayObject->GetMetadata());
|
||||
auto java_obj = env->NewObject(java_native_ray_object_class,
|
||||
java_native_ray_object_init, java_data, java_metadata);
|
||||
return java_obj;
|
||||
}
|
||||
|
||||
#endif // RAY_COMMON_JAVA_JNI_HELPER_H
|
134
src/ray/core_worker/lib/java/org_ray_runtime_WorkerContext.cc
Normal file
134
src/ray/core_worker/lib/java/org_ray_runtime_WorkerContext.cc
Normal file
|
@ -0,0 +1,134 @@
|
|||
#include "ray/core_worker/lib/java/org_ray_runtime_WorkerContext.h"
|
||||
#include <jni.h>
|
||||
#include "ray/common/id.h"
|
||||
#include "ray/core_worker/context.h"
|
||||
#include "ray/core_worker/lib/java/jni_utils.h"
|
||||
|
||||
inline ray::WorkerContext *GetWorkerContextFromPointer(
|
||||
jlong nativeWorkerContextFromPointer) {
|
||||
return reinterpret_cast<ray::WorkerContext *>(nativeWorkerContextFromPointer);
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_WorkerContext
|
||||
* Method: nativeCreateWorkerContext
|
||||
* Signature: (I[B)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_org_ray_runtime_WorkerContext_nativeCreateWorkerContext(
|
||||
JNIEnv *env, jclass, jint workerType, jbyteArray jobId) {
|
||||
return reinterpret_cast<jlong>(
|
||||
new ray::WorkerContext(static_cast<ray::rpc::WorkerType>(workerType),
|
||||
JavaByteArrayToId<ray::JobID>(env, jobId)));
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_WorkerContext
|
||||
* Method: nativeGetCurrentTaskId
|
||||
* Signature: (J)[B
|
||||
*/
|
||||
JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_WorkerContext_nativeGetCurrentTaskId(
|
||||
JNIEnv *env, jclass, jlong nativeWorkerContextFromPointer) {
|
||||
auto task_id =
|
||||
GetWorkerContextFromPointer(nativeWorkerContextFromPointer)->GetCurrentTaskID();
|
||||
return IdToJavaByteArray<ray::TaskID>(env, task_id);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_WorkerContext
|
||||
* Method: nativeSetCurrentTask
|
||||
* Signature: (J[B)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_ray_runtime_WorkerContext_nativeSetCurrentTask(
|
||||
JNIEnv *env, jclass, jlong nativeWorkerContextFromPointer, jbyteArray taskSpec) {
|
||||
jbyte *data = env->GetByteArrayElements(taskSpec, NULL);
|
||||
jsize size = env->GetArrayLength(taskSpec);
|
||||
ray::rpc::TaskSpec task_spec_message;
|
||||
task_spec_message.ParseFromArray(data, size);
|
||||
env->ReleaseByteArrayElements(taskSpec, data, JNI_ABORT);
|
||||
|
||||
ray::TaskSpecification spec(task_spec_message);
|
||||
GetWorkerContextFromPointer(nativeWorkerContextFromPointer)->SetCurrentTask(spec);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_WorkerContext
|
||||
* Method: nativeGetCurrentTask
|
||||
* Signature: (J)[B
|
||||
*/
|
||||
JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_WorkerContext_nativeGetCurrentTask(
|
||||
JNIEnv *env, jclass, jlong nativeWorkerContextFromPointer) {
|
||||
auto spec =
|
||||
GetWorkerContextFromPointer(nativeWorkerContextFromPointer)->GetCurrentTask();
|
||||
if (!spec) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto task_message = spec->Serialize();
|
||||
jbyteArray result = env->NewByteArray(task_message.size());
|
||||
env->SetByteArrayRegion(
|
||||
result, 0, task_message.size(),
|
||||
reinterpret_cast<jbyte *>(const_cast<char *>(task_message.data())));
|
||||
return result;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_WorkerContext
|
||||
* Method: nativeGetCurrentJobId
|
||||
* Signature: (J)Ljava/nio/ByteBuffer;
|
||||
*/
|
||||
JNIEXPORT jobject JNICALL Java_org_ray_runtime_WorkerContext_nativeGetCurrentJobId(
|
||||
JNIEnv *env, jclass, jlong nativeWorkerContextFromPointer) {
|
||||
const auto &job_id =
|
||||
GetWorkerContextFromPointer(nativeWorkerContextFromPointer)->GetCurrentJobID();
|
||||
return IdToJavaByteBuffer<ray::JobID>(env, job_id);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_WorkerContext
|
||||
* Method: nativeGetCurrentWorkerId
|
||||
* Signature: (J)[B
|
||||
*/
|
||||
JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_WorkerContext_nativeGetCurrentWorkerId(
|
||||
JNIEnv *env, jclass, jlong nativeWorkerContextFromPointer) {
|
||||
auto worker_id =
|
||||
GetWorkerContextFromPointer(nativeWorkerContextFromPointer)->GetWorkerID();
|
||||
return IdToJavaByteArray<ray::WorkerID>(env, worker_id);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_WorkerContext
|
||||
* Method: nativeGetNextTaskIndex
|
||||
* Signature: (J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_org_ray_runtime_WorkerContext_nativeGetNextTaskIndex(
|
||||
JNIEnv *env, jclass, jlong nativeWorkerContextFromPointer) {
|
||||
return GetWorkerContextFromPointer(nativeWorkerContextFromPointer)->GetNextTaskIndex();
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_WorkerContext
|
||||
* Method: nativeGetNextPutIndex
|
||||
* Signature: (J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_org_ray_runtime_WorkerContext_nativeGetNextPutIndex(
|
||||
JNIEnv *env, jclass, jlong nativeWorkerContextFromPointer) {
|
||||
return GetWorkerContextFromPointer(nativeWorkerContextFromPointer)->GetNextPutIndex();
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_WorkerContext
|
||||
* Method: nativeDestroy
|
||||
* Signature: (J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_ray_runtime_WorkerContext_nativeDestroy(
|
||||
JNIEnv *env, jclass, jlong nativeWorkerContextFromPointer) {
|
||||
delete GetWorkerContextFromPointer(nativeWorkerContextFromPointer);
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
87
src/ray/core_worker/lib/java/org_ray_runtime_WorkerContext.h
Normal file
87
src/ray/core_worker/lib/java/org_ray_runtime_WorkerContext.h
Normal file
|
@ -0,0 +1,87 @@
|
|||
/* DO NOT EDIT THIS FILE - it is machine generated */
|
||||
#include <jni.h>
|
||||
/* Header for class org_ray_runtime_WorkerContext */
|
||||
|
||||
#ifndef _Included_org_ray_runtime_WorkerContext
|
||||
#define _Included_org_ray_runtime_WorkerContext
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
/*
|
||||
* Class: org_ray_runtime_WorkerContext
|
||||
* Method: nativeCreateWorkerContext
|
||||
* Signature: (I[B)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_org_ray_runtime_WorkerContext_nativeCreateWorkerContext(
|
||||
JNIEnv *, jclass, jint, jbyteArray);
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_WorkerContext
|
||||
* Method: nativeGetCurrentTaskId
|
||||
* Signature: (J)[B
|
||||
*/
|
||||
JNIEXPORT jbyteArray JNICALL
|
||||
Java_org_ray_runtime_WorkerContext_nativeGetCurrentTaskId(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_WorkerContext
|
||||
* Method: nativeSetCurrentTask
|
||||
* Signature: (J[B)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_ray_runtime_WorkerContext_nativeSetCurrentTask(
|
||||
JNIEnv *, jclass, jlong, jbyteArray);
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_WorkerContext
|
||||
* Method: nativeGetCurrentTask
|
||||
* Signature: (J)[B
|
||||
*/
|
||||
JNIEXPORT jbyteArray JNICALL
|
||||
Java_org_ray_runtime_WorkerContext_nativeGetCurrentTask(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_WorkerContext
|
||||
* Method: nativeGetCurrentJobId
|
||||
* Signature: (J)Ljava/nio/ByteBuffer;
|
||||
*/
|
||||
JNIEXPORT jobject JNICALL
|
||||
Java_org_ray_runtime_WorkerContext_nativeGetCurrentJobId(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_WorkerContext
|
||||
* Method: nativeGetCurrentWorkerId
|
||||
* Signature: (J)[B
|
||||
*/
|
||||
JNIEXPORT jbyteArray JNICALL
|
||||
Java_org_ray_runtime_WorkerContext_nativeGetCurrentWorkerId(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_WorkerContext
|
||||
* Method: nativeGetNextTaskIndex
|
||||
* Signature: (J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_org_ray_runtime_WorkerContext_nativeGetNextTaskIndex(JNIEnv *,
|
||||
jclass,
|
||||
jlong);
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_WorkerContext
|
||||
* Method: nativeGetNextPutIndex
|
||||
* Signature: (J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_org_ray_runtime_WorkerContext_nativeGetNextPutIndex(JNIEnv *,
|
||||
jclass,
|
||||
jlong);
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_WorkerContext
|
||||
* Method: nativeDestroy
|
||||
* Signature: (J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_ray_runtime_WorkerContext_nativeDestroy(JNIEnv *, jclass,
|
||||
jlong);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif
|
|
@ -0,0 +1,149 @@
|
|||
#include "ray/core_worker/lib/java/org_ray_runtime_objectstore_ObjectInterfaceImpl.h"
|
||||
#include <jni.h>
|
||||
#include "ray/common/id.h"
|
||||
#include "ray/core_worker/common.h"
|
||||
#include "ray/core_worker/lib/java/jni_utils.h"
|
||||
#include "ray/core_worker/object_interface.h"
|
||||
|
||||
inline ray::CoreWorkerObjectInterface *GetObjectInterfaceFromPointer(
|
||||
jlong nativeObjectInterfacePointer) {
|
||||
return reinterpret_cast<ray::CoreWorkerObjectInterface *>(nativeObjectInterfacePointer);
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_objectstore_ObjectInterfaceImpl
|
||||
* Method: nativeCreateObjectInterface
|
||||
* Signature: (JJLjava/lang/String;)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeCreateObjectInterface(
|
||||
JNIEnv *env, jclass, jlong nativeWorkerContext, jlong nativeRayletClient,
|
||||
jstring storeSocketName) {
|
||||
return reinterpret_cast<jlong>(new ray::CoreWorkerObjectInterface(
|
||||
*reinterpret_cast<ray::WorkerContext *>(nativeWorkerContext),
|
||||
*reinterpret_cast<std::unique_ptr<RayletClient> *>(nativeRayletClient),
|
||||
JavaStringToNativeString(env, storeSocketName)));
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_objectstore_ObjectInterfaceImpl
|
||||
* Method: nativePut
|
||||
* Signature: (JLorg/ray/runtime/objectstore/NativeRayObject;)[B
|
||||
*/
|
||||
JNIEXPORT jbyteArray JNICALL
|
||||
Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativePut__JLorg_ray_runtime_objectstore_NativeRayObject_2(
|
||||
JNIEnv *env, jclass, jlong nativeObjectInterfacePointer, jobject obj) {
|
||||
ray::Status status;
|
||||
ray::ObjectID object_id = ReadJavaNativeRayObject<ray::ObjectID>(
|
||||
env, obj,
|
||||
[nativeObjectInterfacePointer,
|
||||
&status](const std::shared_ptr<ray::RayObject> &rayObject) {
|
||||
RAY_CHECK(rayObject != nullptr);
|
||||
ray::ObjectID object_id;
|
||||
status = GetObjectInterfaceFromPointer(nativeObjectInterfacePointer)
|
||||
->Put(*rayObject, &object_id);
|
||||
return object_id;
|
||||
});
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr);
|
||||
return IdToJavaByteArray<ray::ObjectID>(env, object_id);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_objectstore_ObjectInterfaceImpl
|
||||
* Method: nativePut
|
||||
* Signature: (J[BLorg/ray/runtime/objectstore/NativeRayObject;)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativePut__J_3BLorg_ray_runtime_objectstore_NativeRayObject_2(
|
||||
JNIEnv *env, jclass, jlong nativeObjectInterfacePointer, jbyteArray objectId,
|
||||
jobject obj) {
|
||||
auto object_id = JavaByteArrayToId<ray::ObjectID>(env, objectId);
|
||||
auto status = ReadJavaNativeRayObject<ray::Status>(
|
||||
env, obj,
|
||||
[nativeObjectInterfacePointer,
|
||||
&object_id](const std::shared_ptr<ray::RayObject> &rayObject) {
|
||||
RAY_CHECK(rayObject != nullptr);
|
||||
return GetObjectInterfaceFromPointer(nativeObjectInterfacePointer)
|
||||
->Put(*rayObject, object_id);
|
||||
});
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_objectstore_ObjectInterfaceImpl
|
||||
* Method: nativeGet
|
||||
* Signature: (JLjava/util/List;J)Ljava/util/List;
|
||||
*/
|
||||
JNIEXPORT jobject JNICALL Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeGet(
|
||||
JNIEnv *env, jclass, jlong nativeObjectInterfacePointer, jobject ids,
|
||||
jlong timeoutMs) {
|
||||
std::vector<ray::ObjectID> object_ids;
|
||||
JavaListToNativeVector<ray::ObjectID>(
|
||||
env, ids, &object_ids, [](JNIEnv *env, jobject id) {
|
||||
return JavaByteArrayToId<ray::ObjectID>(env, static_cast<jbyteArray>(id));
|
||||
});
|
||||
std::vector<std::shared_ptr<ray::RayObject>> results;
|
||||
auto status = GetObjectInterfaceFromPointer(nativeObjectInterfacePointer)
|
||||
->Get(object_ids, (int64_t)timeoutMs, &results);
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr);
|
||||
return NativeVectorToJavaList<std::shared_ptr<ray::RayObject>>(env, results,
|
||||
ToJavaNativeRayObject);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_objectstore_ObjectInterfaceImpl
|
||||
* Method: nativeWait
|
||||
* Signature: (JLjava/util/List;IJ)Ljava/util/List;
|
||||
*/
|
||||
JNIEXPORT jobject JNICALL Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeWait(
|
||||
JNIEnv *env, jclass, jlong nativeObjectInterfacePointer, jobject objectIds,
|
||||
jint numObjects, jlong timeoutMs) {
|
||||
std::vector<ray::ObjectID> object_ids;
|
||||
JavaListToNativeVector<ray::ObjectID>(
|
||||
env, objectIds, &object_ids, [](JNIEnv *env, jobject id) {
|
||||
return JavaByteArrayToId<ray::ObjectID>(env, static_cast<jbyteArray>(id));
|
||||
});
|
||||
std::vector<bool> results;
|
||||
auto status = GetObjectInterfaceFromPointer(nativeObjectInterfacePointer)
|
||||
->Wait(object_ids, (int)numObjects, (int64_t)timeoutMs, &results);
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr);
|
||||
return NativeVectorToJavaList<bool>(env, results, [](JNIEnv *env, const bool &item) {
|
||||
return env->NewObject(java_boolean_class, java_boolean_init, (jboolean)item);
|
||||
});
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_objectstore_ObjectInterfaceImpl
|
||||
* Method: nativeDelete
|
||||
* Signature: (JLjava/util/List;ZZ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeDelete(
|
||||
JNIEnv *env, jclass, jlong nativeObjectInterfacePointer, jobject objectIds,
|
||||
jboolean localOnly, jboolean deleteCreatingTasks) {
|
||||
std::vector<ray::ObjectID> object_ids;
|
||||
JavaListToNativeVector<ray::ObjectID>(
|
||||
env, objectIds, &object_ids, [](JNIEnv *env, jobject id) {
|
||||
return JavaByteArrayToId<ray::ObjectID>(env, static_cast<jbyteArray>(id));
|
||||
});
|
||||
auto status = GetObjectInterfaceFromPointer(nativeObjectInterfacePointer)
|
||||
->Delete(object_ids, (bool)localOnly, (bool)deleteCreatingTasks);
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_objectstore_ObjectInterfaceImpl
|
||||
* Method: nativeDestroy
|
||||
* Signature: (J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeDestroy(
|
||||
JNIEnv *env, jclass, jlong nativeObjectInterfacePointer) {
|
||||
delete GetObjectInterfaceFromPointer(nativeObjectInterfacePointer);
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
|
@ -0,0 +1,72 @@
|
|||
/* DO NOT EDIT THIS FILE - it is machine generated */
|
||||
#include <jni.h>
|
||||
/* Header for class org_ray_runtime_objectstore_ObjectInterfaceImpl */
|
||||
|
||||
#ifndef _Included_org_ray_runtime_objectstore_ObjectInterfaceImpl
|
||||
#define _Included_org_ray_runtime_objectstore_ObjectInterfaceImpl
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
/*
|
||||
* Class: org_ray_runtime_objectstore_ObjectInterfaceImpl
|
||||
* Method: nativeCreateObjectInterface
|
||||
* Signature: (JJLjava/lang/String;)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeCreateObjectInterface(
|
||||
JNIEnv *, jclass, jlong, jlong, jstring);
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_objectstore_ObjectInterfaceImpl
|
||||
* Method: nativePut
|
||||
* Signature: (JLorg/ray/runtime/objectstore/NativeRayObject;)[B
|
||||
*/
|
||||
JNIEXPORT jbyteArray JNICALL
|
||||
Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativePut__JLorg_ray_runtime_objectstore_NativeRayObject_2(
|
||||
JNIEnv *, jclass, jlong, jobject);
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_objectstore_ObjectInterfaceImpl
|
||||
* Method: nativePut
|
||||
* Signature: (J[BLorg/ray/runtime/objectstore/NativeRayObject;)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativePut__J_3BLorg_ray_runtime_objectstore_NativeRayObject_2(
|
||||
JNIEnv *, jclass, jlong, jbyteArray, jobject);
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_objectstore_ObjectInterfaceImpl
|
||||
* Method: nativeGet
|
||||
* Signature: (JLjava/util/List;J)Ljava/util/List;
|
||||
*/
|
||||
JNIEXPORT jobject JNICALL Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeGet(
|
||||
JNIEnv *, jclass, jlong, jobject, jlong);
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_objectstore_ObjectInterfaceImpl
|
||||
* Method: nativeWait
|
||||
* Signature: (JLjava/util/List;IJ)Ljava/util/List;
|
||||
*/
|
||||
JNIEXPORT jobject JNICALL Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeWait(
|
||||
JNIEnv *, jclass, jlong, jobject, jint, jlong);
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_objectstore_ObjectInterfaceImpl
|
||||
* Method: nativeDelete
|
||||
* Signature: (JLjava/util/List;ZZ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeDelete(
|
||||
JNIEnv *, jclass, jlong, jobject, jboolean, jboolean);
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_objectstore_ObjectInterfaceImpl
|
||||
* Method: nativeDestroy
|
||||
* Signature: (J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeDestroy(
|
||||
JNIEnv *, jclass, jlong);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif
|
|
@ -3,6 +3,7 @@
|
|||
#include "ray/core_worker/context.h"
|
||||
#include "ray/core_worker/core_worker.h"
|
||||
#include "ray/core_worker/object_interface.h"
|
||||
#include "ray/protobuf/gcs.pb.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
|
@ -101,11 +102,14 @@ Status CoreWorkerPlasmaStoreProvider::Get(
|
|||
std::make_shared<PlasmaBuffer>(object_buffers[i].data),
|
||||
std::make_shared<PlasmaBuffer>(object_buffers[i].metadata));
|
||||
unready.erase(object_id);
|
||||
if (IsException(object_buffers[i])) {
|
||||
should_break = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
num_attempts += 1;
|
||||
// TODO(zhijunfu): log a message if attempted too many times.
|
||||
WarnIfAttemptedTooManyTimes(num_attempts, unready);
|
||||
}
|
||||
|
||||
if (was_blocked) {
|
||||
|
@ -144,4 +148,45 @@ Status CoreWorkerPlasmaStoreProvider::Delete(const std::vector<ObjectID> &object
|
|||
return raylet_client_->FreeObjects(object_ids, local_only, delete_creating_tasks);
|
||||
}
|
||||
|
||||
bool CoreWorkerPlasmaStoreProvider::IsException(const plasma::ObjectBuffer &buffer) {
|
||||
// TODO (kfstorm): metadata should be structured.
|
||||
const std::string metadata = buffer.metadata->ToString();
|
||||
const auto error_type_descriptor = ray::rpc::ErrorType_descriptor();
|
||||
for (int i = 0; i < error_type_descriptor->value_count(); i++) {
|
||||
const auto error_type_number = error_type_descriptor->value(i)->number();
|
||||
if (metadata == std::to_string(error_type_number)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void CoreWorkerPlasmaStoreProvider::WarnIfAttemptedTooManyTimes(
|
||||
int num_attempts, const std::unordered_map<ObjectID, int> &unready) {
|
||||
if (num_attempts % RayConfig::instance().object_store_get_warn_per_num_attempts() ==
|
||||
0) {
|
||||
std::ostringstream oss;
|
||||
size_t printed = 0;
|
||||
for (auto &entry : unready) {
|
||||
if (printed >=
|
||||
RayConfig::instance().object_store_get_max_ids_to_print_in_warning()) {
|
||||
break;
|
||||
}
|
||||
if (printed > 0) {
|
||||
oss << ", ";
|
||||
}
|
||||
oss << entry.first.Hex();
|
||||
}
|
||||
if (printed < unready.size()) {
|
||||
oss << ", etc";
|
||||
}
|
||||
RAY_LOG(WARNING)
|
||||
<< "Attempted " << num_attempts << " times to reconstruct objects, but "
|
||||
<< "some objects are still unavailable. If this message continues to print,"
|
||||
<< " it may indicate that object's creating task is hanging, or something wrong"
|
||||
<< " happened in raylet backend. " << unready.size()
|
||||
<< " object(s) pending: " << oss.str() << ".";
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ray
|
||||
|
|
|
@ -60,6 +60,20 @@ class CoreWorkerPlasmaStoreProvider : public CoreWorkerStoreProvider {
|
|||
bool delete_creating_tasks) override;
|
||||
|
||||
private:
|
||||
/// Whether the buffer represents an exception object.
|
||||
///
|
||||
/// \param[in] buffer the object buffer.
|
||||
/// \return Whether it represents an exception object.
|
||||
static bool IsException(const plasma::ObjectBuffer &buffer);
|
||||
|
||||
/// Print a warning if we've attempted too many times, but some objects are still
|
||||
/// unavailable.
|
||||
///
|
||||
/// \param[in] num_attemps The number of attempted times.
|
||||
/// \param[in] unready The unready objects.
|
||||
static void WarnIfAttemptedTooManyTimes(
|
||||
int num_attempts, const std::unordered_map<ObjectID, int> &unready);
|
||||
|
||||
/// Plasma store client.
|
||||
plasma::PlasmaClient store_client_;
|
||||
|
||||
|
|
|
@ -11,6 +11,12 @@ enum Language {
|
|||
CPP = 2;
|
||||
}
|
||||
|
||||
// Type of a worker.
|
||||
enum WorkerType {
|
||||
WORKER = 0;
|
||||
DRIVER = 1;
|
||||
}
|
||||
|
||||
// Type of a task.
|
||||
enum TaskType {
|
||||
// Normal task.
|
||||
|
|
|
@ -267,4 +267,6 @@ enum ErrorType {
|
|||
// 2) The object's creating task is already cleaned up from GCS (this currently
|
||||
// crashes raylet).
|
||||
OBJECT_UNRECONSTRUCTABLE = 2;
|
||||
// Indicates that a task failed due to user code failure.
|
||||
TASK_EXECUTION_EXCEPTION = 3;
|
||||
}
|
||||
|
|
|
@ -3,39 +3,14 @@
|
|||
#include <jni.h>
|
||||
|
||||
#include "ray/common/id.h"
|
||||
#include "ray/core_worker/lib/java/jni_utils.h"
|
||||
#include "ray/raylet/raylet_client.h"
|
||||
#include "ray/util/logging.h"
|
||||
|
||||
template <typename ID>
|
||||
class UniqueIdFromJByteArray {
|
||||
public:
|
||||
const ID &GetId() const { return id; }
|
||||
|
||||
UniqueIdFromJByteArray(JNIEnv *env, const jbyteArray &bytes) {
|
||||
std::string id_str(ID::Size(), 0);
|
||||
env->GetByteArrayRegion(bytes, 0, ID::Size(),
|
||||
reinterpret_cast<jbyte *>(&id_str.front()));
|
||||
id = ID::FromBinary(id_str);
|
||||
}
|
||||
|
||||
private:
|
||||
ID id;
|
||||
};
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
inline bool ThrowRayExceptionIfNotOK(JNIEnv *env, const ray::Status &status) {
|
||||
if (!status.ok()) {
|
||||
jclass exception_class = env->FindClass("org/ray/api/exception/RayException");
|
||||
env->ThrowNew(exception_class, status.message().c_str());
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_raylet_RayletClientImpl
|
||||
* Method: nativeInit
|
||||
|
@ -44,11 +19,11 @@ inline bool ThrowRayExceptionIfNotOK(JNIEnv *env, const ray::Status &status) {
|
|||
JNIEXPORT jlong JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeInit(
|
||||
JNIEnv *env, jclass, jstring sockName, jbyteArray workerId, jboolean isWorker,
|
||||
jbyteArray jobId) {
|
||||
UniqueIdFromJByteArray<ClientID> worker_id(env, workerId);
|
||||
UniqueIdFromJByteArray<JobID> job_id(env, jobId);
|
||||
const auto worker_id = JavaByteArrayToId<ClientID>(env, workerId);
|
||||
const auto job_id = JavaByteArrayToId<JobID>(env, jobId);
|
||||
const char *nativeString = env->GetStringUTFChars(sockName, JNI_FALSE);
|
||||
auto raylet_client = new RayletClient(nativeString, worker_id.GetId(), isWorker,
|
||||
job_id.GetId(), Language::JAVA);
|
||||
auto raylet_client = new std::unique_ptr<RayletClient>(
|
||||
new RayletClient(nativeString, worker_id, isWorker, job_id, Language::JAVA));
|
||||
env->ReleaseStringUTFChars(sockName, nativeString);
|
||||
return reinterpret_cast<jlong>(raylet_client);
|
||||
}
|
||||
|
@ -60,7 +35,7 @@ JNIEXPORT jlong JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeInit(
|
|||
*/
|
||||
JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSubmitTask(
|
||||
JNIEnv *env, jclass, jlong client, jbyteArray taskSpec) {
|
||||
auto raylet_client = reinterpret_cast<RayletClient *>(client);
|
||||
auto &raylet_client = *reinterpret_cast<std::unique_ptr<RayletClient> *>(client);
|
||||
|
||||
jbyte *data = env->GetByteArrayElements(taskSpec, NULL);
|
||||
jsize size = env->GetArrayLength(taskSpec);
|
||||
|
@ -70,7 +45,7 @@ JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSubmit
|
|||
|
||||
ray::TaskSpecification task_spec(task_spec_message);
|
||||
auto status = raylet_client->SubmitTask(task_spec);
|
||||
ThrowRayExceptionIfNotOK(env, status);
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
|
||||
}
|
||||
|
||||
/*
|
||||
|
@ -80,13 +55,11 @@ JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSubmit
|
|||
*/
|
||||
JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeGetTask(
|
||||
JNIEnv *env, jclass, jlong client) {
|
||||
auto raylet_client = reinterpret_cast<RayletClient *>(client);
|
||||
auto &raylet_client = *reinterpret_cast<std::unique_ptr<RayletClient> *>(client);
|
||||
|
||||
std::unique_ptr<ray::TaskSpecification> spec;
|
||||
auto status = raylet_client->GetTask(&spec);
|
||||
if (ThrowRayExceptionIfNotOK(env, status)) {
|
||||
return nullptr;
|
||||
}
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr);
|
||||
|
||||
// Serialize the task spec and copy to Java byte array.
|
||||
auto task_data = spec->Serialize();
|
||||
|
@ -109,8 +82,9 @@ JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_native
|
|||
*/
|
||||
JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeDestroy(
|
||||
JNIEnv *env, jclass, jlong client) {
|
||||
auto raylet_client = reinterpret_cast<RayletClient *>(client);
|
||||
ThrowRayExceptionIfNotOK(env, raylet_client->Disconnect());
|
||||
auto raylet_client = reinterpret_cast<std::unique_ptr<RayletClient> *>(client);
|
||||
auto status = (*raylet_client)->Disconnect();
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
|
||||
delete raylet_client;
|
||||
}
|
||||
|
||||
|
@ -128,15 +102,14 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeFetchOrReconstruct(
|
|||
for (int i = 0; i < len; i++) {
|
||||
jbyteArray object_id_bytes =
|
||||
static_cast<jbyteArray>(env->GetObjectArrayElement(objectIds, i));
|
||||
UniqueIdFromJByteArray<ObjectID> object_id(env, object_id_bytes);
|
||||
object_ids.push_back(object_id.GetId());
|
||||
const auto object_id = JavaByteArrayToId<ObjectID>(env, object_id_bytes);
|
||||
object_ids.push_back(object_id);
|
||||
env->DeleteLocalRef(object_id_bytes);
|
||||
}
|
||||
UniqueIdFromJByteArray<TaskID> current_task_id(env, currentTaskId);
|
||||
auto raylet_client = reinterpret_cast<RayletClient *>(client);
|
||||
auto status =
|
||||
raylet_client->FetchOrReconstruct(object_ids, fetchOnly, current_task_id.GetId());
|
||||
ThrowRayExceptionIfNotOK(env, status);
|
||||
const auto current_task_id = JavaByteArrayToId<TaskID>(env, currentTaskId);
|
||||
auto &raylet_client = *reinterpret_cast<std::unique_ptr<RayletClient> *>(client);
|
||||
auto status = raylet_client->FetchOrReconstruct(object_ids, fetchOnly, current_task_id);
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
|
||||
}
|
||||
|
||||
/*
|
||||
|
@ -146,10 +119,10 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeFetchOrReconstruct(
|
|||
*/
|
||||
JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotifyUnblocked(
|
||||
JNIEnv *env, jclass, jlong client, jbyteArray currentTaskId) {
|
||||
UniqueIdFromJByteArray<TaskID> current_task_id(env, currentTaskId);
|
||||
auto raylet_client = reinterpret_cast<RayletClient *>(client);
|
||||
auto status = raylet_client->NotifyUnblocked(current_task_id.GetId());
|
||||
ThrowRayExceptionIfNotOK(env, status);
|
||||
const auto current_task_id = JavaByteArrayToId<TaskID>(env, currentTaskId);
|
||||
auto &raylet_client = *reinterpret_cast<std::unique_ptr<RayletClient> *>(client);
|
||||
auto status = raylet_client->NotifyUnblocked(current_task_id);
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
|
||||
}
|
||||
|
||||
/*
|
||||
|
@ -166,22 +139,20 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeWaitObject(
|
|||
for (int i = 0; i < len; i++) {
|
||||
jbyteArray object_id_bytes =
|
||||
static_cast<jbyteArray>(env->GetObjectArrayElement(objectIds, i));
|
||||
UniqueIdFromJByteArray<ObjectID> object_id(env, object_id_bytes);
|
||||
object_ids.push_back(object_id.GetId());
|
||||
const auto object_id = JavaByteArrayToId<ObjectID>(env, object_id_bytes);
|
||||
object_ids.push_back(object_id);
|
||||
env->DeleteLocalRef(object_id_bytes);
|
||||
}
|
||||
UniqueIdFromJByteArray<TaskID> current_task_id(env, currentTaskId);
|
||||
const auto current_task_id = JavaByteArrayToId<TaskID>(env, currentTaskId);
|
||||
|
||||
auto raylet_client = reinterpret_cast<RayletClient *>(client);
|
||||
auto &raylet_client = *reinterpret_cast<std::unique_ptr<RayletClient> *>(client);
|
||||
|
||||
// Invoke wait.
|
||||
WaitResultPair result;
|
||||
auto status = raylet_client->Wait(object_ids, numReturns, timeoutMillis,
|
||||
static_cast<bool>(isWaitLocal),
|
||||
current_task_id.GetId(), &result);
|
||||
if (ThrowRayExceptionIfNotOK(env, status)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto status =
|
||||
raylet_client->Wait(object_ids, numReturns, timeoutMillis,
|
||||
static_cast<bool>(isWaitLocal), current_task_id, &result);
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr);
|
||||
|
||||
// Convert result to java object.
|
||||
jboolean put_value = true;
|
||||
|
@ -216,11 +187,10 @@ JNIEXPORT jbyteArray JNICALL
|
|||
Java_org_ray_runtime_raylet_RayletClientImpl_nativeGenerateTaskId(
|
||||
JNIEnv *env, jclass, jbyteArray jobId, jbyteArray parentTaskId,
|
||||
jint parent_task_counter) {
|
||||
UniqueIdFromJByteArray<JobID> job_id(env, jobId);
|
||||
UniqueIdFromJByteArray<TaskID> parent_task_id(env, parentTaskId);
|
||||
const auto job_id = JavaByteArrayToId<JobID>(env, jobId);
|
||||
const auto parent_task_id = JavaByteArrayToId<TaskID>(env, parentTaskId);
|
||||
|
||||
TaskID task_id =
|
||||
ray::GenerateTaskId(job_id.GetId(), parent_task_id.GetId(), parent_task_counter);
|
||||
TaskID task_id = ray::GenerateTaskId(job_id, parent_task_id, parent_task_counter);
|
||||
jbyteArray result = env->NewByteArray(task_id.Size());
|
||||
if (nullptr == result) {
|
||||
return nullptr;
|
||||
|
@ -245,13 +215,13 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeFreePlasmaObjects(
|
|||
for (int i = 0; i < len; i++) {
|
||||
jbyteArray object_id_bytes =
|
||||
static_cast<jbyteArray>(env->GetObjectArrayElement(objectIds, i));
|
||||
UniqueIdFromJByteArray<ObjectID> object_id(env, object_id_bytes);
|
||||
object_ids.push_back(object_id.GetId());
|
||||
const auto object_id = JavaByteArrayToId<ObjectID>(env, object_id_bytes);
|
||||
object_ids.push_back(object_id);
|
||||
env->DeleteLocalRef(object_id_bytes);
|
||||
}
|
||||
auto raylet_client = reinterpret_cast<RayletClient *>(client);
|
||||
auto &raylet_client = *reinterpret_cast<std::unique_ptr<RayletClient> *>(client);
|
||||
auto status = raylet_client->FreeObjects(object_ids, localOnly, deleteCreatingTasks);
|
||||
ThrowRayExceptionIfNotOK(env, status);
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
|
||||
}
|
||||
|
||||
/*
|
||||
|
@ -263,13 +233,11 @@ JNIEXPORT jbyteArray JNICALL
|
|||
Java_org_ray_runtime_raylet_RayletClientImpl_nativePrepareCheckpoint(JNIEnv *env, jclass,
|
||||
jlong client,
|
||||
jbyteArray actorId) {
|
||||
auto raylet_client = reinterpret_cast<RayletClient *>(client);
|
||||
UniqueIdFromJByteArray<ActorID> actor_id(env, actorId);
|
||||
auto &raylet_client = *reinterpret_cast<std::unique_ptr<RayletClient> *>(client);
|
||||
const auto actor_id = JavaByteArrayToId<ActorID>(env, actorId);
|
||||
ActorCheckpointID checkpoint_id;
|
||||
auto status = raylet_client->PrepareActorCheckpoint(actor_id.GetId(), checkpoint_id);
|
||||
if (ThrowRayExceptionIfNotOK(env, status)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto status = raylet_client->PrepareActorCheckpoint(actor_id, checkpoint_id);
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr);
|
||||
jbyteArray result = env->NewByteArray(checkpoint_id.Size());
|
||||
env->SetByteArrayRegion(result, 0, checkpoint_id.Size(),
|
||||
reinterpret_cast<const jbyte *>(checkpoint_id.Data()));
|
||||
|
@ -284,12 +252,11 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativePrepareCheckpoint(JNIEnv *env
|
|||
JNIEXPORT void JNICALL
|
||||
Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotifyActorResumedFromCheckpoint(
|
||||
JNIEnv *env, jclass, jlong client, jbyteArray actorId, jbyteArray checkpointId) {
|
||||
auto raylet_client = reinterpret_cast<RayletClient *>(client);
|
||||
UniqueIdFromJByteArray<ActorID> actor_id(env, actorId);
|
||||
UniqueIdFromJByteArray<ActorCheckpointID> checkpoint_id(env, checkpointId);
|
||||
auto status = raylet_client->NotifyActorResumedFromCheckpoint(actor_id.GetId(),
|
||||
checkpoint_id.GetId());
|
||||
ThrowRayExceptionIfNotOK(env, status);
|
||||
auto &raylet_client = *reinterpret_cast<std::unique_ptr<RayletClient> *>(client);
|
||||
const auto actor_id = JavaByteArrayToId<ActorID>(env, actorId);
|
||||
const auto checkpoint_id = JavaByteArrayToId<ActorCheckpointID>(env, checkpointId);
|
||||
auto status = raylet_client->NotifyActorResumedFromCheckpoint(actor_id, checkpoint_id);
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
|
||||
}
|
||||
|
||||
/*
|
||||
|
@ -300,14 +267,14 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotifyActorResumedFromCheckpo
|
|||
JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSetResource(
|
||||
JNIEnv *env, jclass, jlong client, jstring resourceName, jdouble capacity,
|
||||
jbyteArray nodeId) {
|
||||
auto raylet_client = reinterpret_cast<RayletClient *>(client);
|
||||
UniqueIdFromJByteArray<ClientID> node_id(env, nodeId);
|
||||
auto &raylet_client = *reinterpret_cast<std::unique_ptr<RayletClient> *>(client);
|
||||
const auto node_id = JavaByteArrayToId<ClientID>(env, nodeId);
|
||||
const char *native_resource_name = env->GetStringUTFChars(resourceName, JNI_FALSE);
|
||||
|
||||
auto status = raylet_client->SetResource(
|
||||
native_resource_name, static_cast<double>(capacity), node_id.GetId());
|
||||
auto status = raylet_client->SetResource(native_resource_name,
|
||||
static_cast<double>(capacity), node_id);
|
||||
env->ReleaseStringUTFChars(resourceName, native_resource_name);
|
||||
ThrowRayExceptionIfNotOK(env, status);
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
|
Loading…
Add table
Reference in a new issue