[Java worker] Refactor object store and worker context on top of core worker (#5079)

This commit is contained in:
Kai Yang 2019-07-16 20:58:02 +08:00 committed by Hao Chen
parent e5be5fd46d
commit 806524384b
40 changed files with 1386 additions and 571 deletions

View file

@ -647,13 +647,13 @@ pyx_library(
)
cc_binary(
name = "libraylet_library_java.so",
srcs = [
"src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h",
"src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc",
"src/ray/common/id.h",
"src/ray/raylet/raylet_client.h",
"src/ray/util/logging.h",
name = "libcore_worker_library_java.so",
srcs = glob([
"src/ray/core_worker/lib/java/*.h",
"src/ray/core_worker/lib/java/*.cc",
"src/ray/raylet/lib/java/*.h",
"src/ray/raylet/lib/java/*.cc",
]) + [
"@bazel_tools//tools/jdk:jni_header",
] + select({
"@bazel_tools//src/conditions:windows": ["@bazel_tools//tools/jdk:jni_md_header-windows"],
@ -671,24 +671,23 @@ cc_binary(
linkshared = 1,
linkstatic = 1,
deps = [
"//:raylet_lib",
"@plasma//:plasma_client",
"//:core_worker_lib",
],
)
genrule(
name = "raylet-jni-darwin-compat",
srcs = [":libraylet_library_java.so"],
outs = ["libraylet_library_java.dylib"],
name = "core_worker-jni-darwin-compat",
srcs = [":libcore_worker_library_java.so"],
outs = ["libcore_worker_library_java.dylib"],
cmd = "cp $< $@",
output_to_bindir = 1,
)
filegroup(
name = "raylet_library_java",
name = "core_worker_library_java",
srcs = select({
"@bazel_tools//src/conditions:darwin": [":libraylet_library_java.dylib"],
"//conditions:default": [":libraylet_library_java.so"],
"@bazel_tools//src/conditions:darwin": [":libcore_worker_library_java.dylib"],
"//conditions:default": [":libcore_worker_library_java.so"],
}),
visibility = ["//java:__subpackages__"],
)

View file

@ -2,27 +2,6 @@ load("@com_github_google_flatbuffers//:build_defs.bzl", "flatbuffer_cc_library")
COPTS = ["-DARROW_USE_GLOG"]
java_library(
name = "org_apache_arrow_arrow_plasma",
srcs = glob(["java/plasma/src/main/java/**/*.java"]),
data = [":plasma_client_java"],
visibility = ["//visibility:public"],
deps = [
"@maven//:org_slf4j_slf4j_api",
],
)
java_binary(
name = "org_apache_arrow_arrow_plasma_test",
srcs = ["java/plasma/src/test/java/org/apache/arrow/plasma/PlasmaClientTest.java"],
main_class = "org.apache.arrow.plasma.PlasmaClientTest",
visibility = ["//visibility:public"],
deps = [
":org_apache_arrow_arrow_plasma",
"@maven//:junit_junit",
],
)
cc_library(
name = "arrow",
srcs = [
@ -145,15 +124,6 @@ genrule(
output_to_bindir = 1,
)
filegroup(
name = "plasma_client_java",
srcs = select({
"@bazel_tools//src/conditions:darwin": [":libplasma_java.dylib"],
"//conditions:default": [":libplasma_java.so"],
}),
visibility = ["//visibility:public"],
)
cc_library(
name = "plasma_lib",
srcs = [

View file

@ -69,7 +69,6 @@ define_java_module(
],
deps = [
":org_ray_ray_api",
"@plasma//:org_apache_arrow_arrow_plasma",
"@maven//:com_google_guava_guava",
"@maven//:com_google_protobuf_protobuf_java",
"@maven//:com_typesafe_config",
@ -97,7 +96,6 @@ define_java_module(
deps = [
":org_ray_ray_api",
":org_ray_ray_runtime",
"@plasma//:org_apache_arrow_arrow_plasma",
"@maven//:com_google_guava_guava",
"@maven//:com_sun_xml_bind_jaxb_core",
"@maven//:com_sun_xml_bind_jaxb_impl",
@ -176,9 +174,8 @@ filegroup(
"//:redis-server",
"//:libray_redis_module.so",
"//:raylet",
"//:raylet_library_java",
"//:core_worker_library_java",
"@plasma//:plasma_store_server",
"@plasma//:plasma_client_java",
],
)
@ -189,7 +186,6 @@ genrule(
":all_java_proto",
":java_native_deps",
":copy_pom_file",
"@plasma//:org_apache_arrow_arrow_plasma",
],
outs = ["gen_maven_deps.out"],
cmd = """
@ -208,9 +204,6 @@ genrule(
chmod +w $$f
cp $$f $$NATIVE_DEPS_DIR
done
# Install plasma jar to local maven repo.
mvn install:install-file -Dfile=$(locations @plasma//:org_apache_arrow_arrow_plasma) -Dpackaging=jar \
-DgroupId=org.apache.arrow -DartifactId=arrow-plasma -Dversion=0.13.0-SNAPSHOT
echo $$(date) > $@
""",
local = 1,

View file

@ -24,11 +24,6 @@
<dependencyManagement>
<dependencies>
<dependency>
<groupId>org.apache.arrow</groupId>
<artifactId>arrow-plasma</artifactId>
<version>0.13.0-SNAPSHOT</version>
</dependency>
</dependencies>
</dependencyManagement>

View file

@ -22,10 +22,6 @@
<artifactId>ray-api</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.arrow</groupId>
<artifactId>arrow-plasma</artifactId>
</dependency>
<dependency>
<groupId>com.beust</groupId>
<artifactId>jcommander</artifactId>

View file

@ -22,10 +22,6 @@
<artifactId>ray-api</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.arrow</groupId>
<artifactId>arrow-plasma</artifactId>
</dependency>
{generated_bzl_deps}
</dependencies>

View file

@ -1,7 +1,15 @@
package org.ray.runtime;
import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
@ -72,6 +80,27 @@ public abstract class AbstractRayRuntime implements RayRuntime {
protected RuntimeContext runtimeContext;
protected GcsClient gcsClient;
static {
try {
LOGGER.debug("Loading native libraries.");
// Load native libraries.
String[] libraries = new String[]{"core_worker_library_java"};
for (String library : libraries) {
String fileName = System.mapLibraryName(library);
// Copy the file from resources to a temp dir, and load the native library.
File file = File.createTempFile(fileName, "");
file.deleteOnExit();
InputStream in = AbstractRayRuntime.class.getResourceAsStream("/" + fileName);
Preconditions.checkNotNull(in, "{} doesn't exist.", fileName);
Files.copy(in, Paths.get(file.getAbsolutePath()), StandardCopyOption.REPLACE_EXISTING);
System.load(file.getAbsolutePath());
}
LOGGER.debug("Native libraries loaded.");
} catch (IOException e) {
throw new RuntimeException("Couldn't load native libraries.", e);
}
}
public AbstractRayRuntime(RayConfig rayConfig) {
this.rayConfig = rayConfig;
functionManager = new FunctionManager(rayConfig.jobResourcePath);
@ -79,6 +108,33 @@ public abstract class AbstractRayRuntime implements RayRuntime {
runtimeContext = new RuntimeContextImpl(this);
}
protected void resetLibraryPath() {
if (rayConfig.libraryPath.isEmpty()) {
return;
}
String path = System.getProperty("java.library.path");
if (Strings.isNullOrEmpty(path)) {
path = "";
} else {
path += ":";
}
path += String.join(":", rayConfig.libraryPath);
// This is a hack to reset library path at runtime,
// see https://stackoverflow.com/questions/15409223/.
System.setProperty("java.library.path", path);
// Set sys_paths to null so that java.library.path will be re-evaluated next time it is needed.
final Field sysPathsField;
try {
sysPathsField = ClassLoader.class.getDeclaredField("sys_paths");
sysPathsField.setAccessible(true);
sysPathsField.set(null, null);
} catch (NoSuchFieldException | IllegalAccessException e) {
LOGGER.error("Failed to set library path.", e);
}
}
/**
* Start runtime.
*/
@ -330,8 +386,8 @@ public abstract class AbstractRayRuntime implements RayRuntime {
* Create the task specification.
*
* @param func The target remote function.
* @param pyFunctionDescriptor Descriptor of the target Python function, if the task is a
* Python task.
* @param pyFunctionDescriptor Descriptor of the target Python function, if the task is a Python
* task.
* @param actor The actor handle. If the task is not an actor task, actor id must be NIL.
* @param args The arguments for the remote function.
* @param isActorCreationTask Whether this task is an actor creation task.

View file

@ -3,7 +3,7 @@ package org.ray.runtime;
import java.util.concurrent.atomic.AtomicInteger;
import org.ray.api.id.JobId;
import org.ray.runtime.config.RayConfig;
import org.ray.runtime.objectstore.MockObjectStore;
import org.ray.runtime.objectstore.MockObjectInterface;
import org.ray.runtime.objectstore.ObjectStoreProxy;
import org.ray.runtime.raylet.MockRayletClient;
@ -13,19 +13,22 @@ public class RayDevRuntime extends AbstractRayRuntime {
super(rayConfig);
}
private MockObjectStore store;
private MockObjectInterface objectInterface;
private AtomicInteger jobCounter = new AtomicInteger(0);
@Override
public void start() {
store = new MockObjectStore(this);
// Reset library path at runtime.
resetLibraryPath();
objectInterface = new MockObjectInterface(workerContext);
if (rayConfig.getJobId().isNil()) {
rayConfig.setJobId(nextJobId());
}
workerContext = new WorkerContext(rayConfig.workerMode,
rayConfig.getJobId(), rayConfig.runMode);
objectStoreProxy = new ObjectStoreProxy(this, null);
objectStoreProxy = new ObjectStoreProxy(workerContext, objectInterface);
rayletClient = new MockRayletClient(this, rayConfig.numberExecThreadsForDevRuntime);
}
@ -34,8 +37,8 @@ public class RayDevRuntime extends AbstractRayRuntime {
rayletClient.destroy();
}
public MockObjectStore getObjectStore() {
return store;
public MockObjectInterface getObjectInterface() {
return objectInterface;
}
@Override

View file

@ -1,21 +1,13 @@
package org.ray.runtime;
import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.util.HashMap;
import java.util.Map;
import org.ray.api.id.JobId;
import org.ray.runtime.config.RayConfig;
import org.ray.runtime.config.WorkerMode;
import org.ray.runtime.gcs.GcsClient;
import org.ray.runtime.gcs.RedisClient;
import org.ray.runtime.generated.Common.WorkerType;
import org.ray.runtime.objectstore.ObjectInterfaceImpl;
import org.ray.runtime.objectstore.ObjectStoreProxy;
import org.ray.runtime.raylet.RayletClientImpl;
import org.ray.runtime.runner.RunManager;
@ -31,58 +23,12 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
private RunManager manager = null;
static {
try {
LOGGER.debug("Loading native libraries.");
// Load native libraries.
String[] libraries = new String[]{"raylet_library_java", "plasma_java"};
for (String library : libraries) {
String fileName = System.mapLibraryName(library);
// Copy the file from resources to a temp dir, and load the native library.
File file = File.createTempFile(fileName, "");
file.deleteOnExit();
InputStream in = RayNativeRuntime.class.getResourceAsStream("/" + fileName);
Preconditions.checkNotNull(in, "{} doesn't exist.", fileName);
Files.copy(in, Paths.get(file.getAbsolutePath()), StandardCopyOption.REPLACE_EXISTING);
System.load(file.getAbsolutePath());
}
LOGGER.debug("Native libraries loaded.");
} catch (IOException e) {
throw new RuntimeException("Couldn't load native libraries.", e);
}
}
private ObjectInterfaceImpl objectInterfaceImpl = null;
public RayNativeRuntime(RayConfig rayConfig) {
super(rayConfig);
}
private void resetLibraryPath() {
if (rayConfig.libraryPath.isEmpty()) {
return;
}
String path = System.getProperty("java.library.path");
if (Strings.isNullOrEmpty(path)) {
path = "";
} else {
path += ":";
}
path += String.join(":", rayConfig.libraryPath);
// This is a hack to reset library path at runtime,
// see https://stackoverflow.com/questions/15409223/.
System.setProperty("java.library.path", path);
// Set sys_paths to null so that java.library.path will be re-evaluated next time it is needed.
final Field sysPathsField;
try {
sysPathsField = ClassLoader.class.getDeclaredField("sys_paths");
sysPathsField.setAccessible(true);
sysPathsField.set(null, null);
} catch (NoSuchFieldException | IllegalAccessException e) {
LOGGER.error("Failed to set library path.", e);
}
}
@Override
public void start() {
// Reset library path at runtime.
@ -101,16 +47,18 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
workerContext = new WorkerContext(rayConfig.workerMode,
rayConfig.getJobId(), rayConfig.runMode);
// TODO(qwang): Get object_store_socket_name and raylet_socket_name from Redis.
objectStoreProxy = new ObjectStoreProxy(this, rayConfig.objectStoreSocketName);
rayletClient = new RayletClientImpl(
rayConfig.rayletSocketName,
workerContext.getCurrentWorkerId(),
rayConfig.workerMode == WorkerMode.WORKER,
rayConfig.workerMode == WorkerType.WORKER,
workerContext.getCurrentJobId()
);
// TODO(qwang): Get object_store_socket_name and raylet_socket_name from Redis.
objectInterfaceImpl = new ObjectInterfaceImpl(workerContext, rayletClient,
rayConfig.objectStoreSocketName);
objectStoreProxy = new ObjectStoreProxy(workerContext, objectInterfaceImpl);
// register
registerWorker();
@ -123,6 +71,8 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
if (null != manager) {
manager.cleanup();
}
objectInterfaceImpl.destroy();
workerContext.destroy();
}
/**
@ -132,7 +82,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
RedisClient redisClient = new RedisClient(rayConfig.getRedisAddress(), rayConfig.redisPassword);
Map<String, String> workerInfo = new HashMap<>();
String workerId = new String(workerContext.getCurrentWorkerId().getBytes());
if (rayConfig.workerMode == WorkerMode.DRIVER) {
if (rayConfig.workerMode == WorkerType.DRIVER) {
workerInfo.put("node_ip_address", rayConfig.nodeIp);
workerInfo.put("driver_id", workerId);
workerInfo.put("start_time", String.valueOf(System.currentTimeMillis()));

View file

@ -1,37 +1,24 @@
package org.ray.runtime;
import com.google.common.base.Preconditions;
import java.nio.ByteBuffer;
import org.ray.api.id.JobId;
import org.ray.api.id.TaskId;
import org.ray.api.id.UniqueId;
import org.ray.runtime.config.RunMode;
import org.ray.runtime.config.WorkerMode;
import org.ray.runtime.generated.Common.WorkerType;
import org.ray.runtime.raylet.RayletClientImpl;
import org.ray.runtime.task.TaskSpec;
import org.ray.runtime.util.IdUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* This is a wrapper class for worker context of core worker.
*/
public class WorkerContext {
private static final Logger LOGGER = LoggerFactory.getLogger(WorkerContext.class);
private UniqueId workerId;
private ThreadLocal<TaskId> currentTaskId;
/**
* Number of objects that have been put from current task.
* The native pointer of worker context of core worker.
*/
private ThreadLocal<Integer> putIndex;
/**
* Number of tasks that have been submitted from current task.
*/
private ThreadLocal<Integer> taskIndex;
private ThreadLocal<TaskSpec> currentTask;
private JobId currentJobId;
private final long nativeWorkerContextPointer;
private ClassLoader currentClassLoader;
@ -45,31 +32,23 @@ public class WorkerContext {
*/
private RunMode runMode;
public WorkerContext(WorkerMode workerMode, JobId jobId, RunMode runMode) {
public WorkerContext(WorkerType workerType, JobId jobId, RunMode runMode) {
this.nativeWorkerContextPointer = nativeCreateWorkerContext(workerType.getNumber(), jobId.getBytes());
mainThreadId = Thread.currentThread().getId();
taskIndex = ThreadLocal.withInitial(() -> 0);
putIndex = ThreadLocal.withInitial(() -> 0);
currentTaskId = ThreadLocal.withInitial(TaskId::randomId);
this.runMode = runMode;
currentTask = ThreadLocal.withInitial(() -> null);
currentClassLoader = null;
if (workerMode == WorkerMode.DRIVER) {
workerId = IdUtil.computeDriverId(jobId);
currentTaskId.set(TaskId.randomId());
currentJobId = jobId;
} else {
workerId = UniqueId.randomId();
this.currentTaskId.set(TaskId.NIL);
this.currentJobId = JobId.NIL;
}
}
public long getNativeWorkerContext() {
return nativeWorkerContextPointer;
}
/**
* @return For the main thread, this method returns the ID of this worker's current running task;
* for other threads, this method returns a random ID.
* for other threads, this method returns a random ID.
*/
public TaskId getCurrentTaskId() {
return currentTaskId.get();
return new TaskId(nativeGetCurrentTaskId(nativeWorkerContextPointer));
}
/**
@ -79,17 +58,14 @@ public class WorkerContext {
public void setCurrentTask(TaskSpec task, ClassLoader classLoader) {
if (runMode == RunMode.CLUSTER) {
Preconditions.checkState(
Thread.currentThread().getId() == mainThreadId,
"This method should only be called from the main thread."
Thread.currentThread().getId() == mainThreadId,
"This method should only be called from the main thread."
);
}
Preconditions.checkNotNull(task);
this.currentTaskId.set(task.taskId);
this.currentJobId = task.jobId;
taskIndex.set(0);
putIndex.set(0);
this.currentTask.set(task);
byte[] taskSpec = RayletClientImpl.convertTaskSpecToProtobuf(task);
nativeSetCurrentTask(nativeWorkerContextPointer, taskSpec);
currentClassLoader = classLoader;
}
@ -97,30 +73,28 @@ public class WorkerContext {
* Increment the put index and return the new value.
*/
public int nextPutIndex() {
putIndex.set(putIndex.get() + 1);
return putIndex.get();
return nativeGetNextPutIndex(nativeWorkerContextPointer);
}
/**
* Increment the task index and return the new value.
*/
public int nextTaskIndex() {
taskIndex.set(taskIndex.get() + 1);
return taskIndex.get();
return nativeGetNextTaskIndex(nativeWorkerContextPointer);
}
/**
* @return The ID of the current worker.
*/
public UniqueId getCurrentWorkerId() {
return workerId;
return new UniqueId(nativeGetCurrentWorkerId(nativeWorkerContextPointer));
}
/**
* The ID of the current job.
*/
public JobId getCurrentJobId() {
return currentJobId;
return JobId.fromByteBuffer(nativeGetCurrentJobId(nativeWorkerContextPointer));
}
/**
@ -134,6 +108,32 @@ public class WorkerContext {
* Get the current task.
*/
public TaskSpec getCurrentTask() {
return this.currentTask.get();
byte[] bytes = nativeGetCurrentTask(nativeWorkerContextPointer);
if (bytes == null) {
return null;
}
return RayletClientImpl.parseTaskSpecFromProtobuf(bytes);
}
public void destroy() {
nativeDestroy(nativeWorkerContextPointer);
}
private static native long nativeCreateWorkerContext(int workerType, byte[] jobId);
private static native byte[] nativeGetCurrentTaskId(long nativeWorkerContextPointer);
private static native void nativeSetCurrentTask(long nativeWorkerContextPointer, byte[] taskSpec);
private static native byte[] nativeGetCurrentTask(long nativeWorkerContextPointer);
private static native ByteBuffer nativeGetCurrentJobId(long nativeWorkerContextPointer);
private static native byte[] nativeGetCurrentWorkerId(long nativeWorkerContextPointer);
private static native int nativeGetNextTaskIndex(long nativeWorkerContextPointer);
private static native int nativeGetNextPutIndex(long nativeWorkerContextPointer);
private static native void nativeDestroy(long nativeWorkerContextPointer);
}

View file

@ -11,6 +11,7 @@ import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.ray.api.id.JobId;
import org.ray.runtime.generated.Common.WorkerType;
import org.ray.runtime.util.NetworkUtil;
import org.ray.runtime.util.ResourceUtil;
import org.ray.runtime.util.StringUtil;
@ -29,7 +30,7 @@ public class RayConfig {
public static final String CUSTOM_CONFIG_FILE = "ray.conf";
public final String nodeIp;
public final WorkerMode workerMode;
public final WorkerType workerMode;
public final RunMode runMode;
public final Map<String, Double> resources;
private JobId jobId;
@ -62,7 +63,7 @@ public class RayConfig {
public final int numberExecThreadsForDevRuntime;
private void validate() {
if (workerMode == WorkerMode.WORKER) {
if (workerMode == WorkerType.WORKER) {
Preconditions.checkArgument(redisAddress != null,
"Redis address must be set in worker mode.");
}
@ -78,14 +79,14 @@ public class RayConfig {
public RayConfig(Config config) {
// Worker mode.
WorkerMode localWorkerMode;
WorkerType localWorkerMode;
try {
localWorkerMode = config.getEnum(WorkerMode.class, "ray.worker.mode");
localWorkerMode = config.getEnum(WorkerType.class, "ray.worker.mode");
} catch (ConfigException.Missing e) {
localWorkerMode = WorkerMode.DRIVER;
localWorkerMode = WorkerType.DRIVER;
}
workerMode = localWorkerMode;
boolean isDriver = workerMode == WorkerMode.DRIVER;
boolean isDriver = workerMode == WorkerType.DRIVER;
// Run mode.
runMode = config.getEnum(RunMode.class, "ray.run-mode");
// Node ip.

View file

@ -1,6 +0,0 @@
package org.ray.runtime.config;
public enum WorkerMode {
DRIVER,
WORKER
}

View file

@ -0,0 +1,98 @@
package org.ray.runtime.objectstore;
import com.google.common.base.Preconditions;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import org.ray.api.id.ObjectId;
import org.ray.runtime.WorkerContext;
import org.ray.runtime.util.IdUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class MockObjectInterface implements ObjectInterface {
private static final Logger LOGGER = LoggerFactory.getLogger(MockObjectInterface.class);
private static final int GET_CHECK_INTERVAL_MS = 100;
private final Map<ObjectId, NativeRayObject> pool = new ConcurrentHashMap<>();
private final List<Consumer<ObjectId>> objectPutCallbacks = new ArrayList<>();
private final WorkerContext workerContext;
public MockObjectInterface(WorkerContext workerContext) {
this.workerContext = workerContext;
}
public void addObjectPutCallback(Consumer<ObjectId> callback) {
this.objectPutCallbacks.add(callback);
}
public boolean isObjectReady(ObjectId id) {
return pool.containsKey(id);
}
@Override
public ObjectId put(NativeRayObject obj) {
ObjectId objectId = IdUtil.computePutId(workerContext.getCurrentTaskId(),
workerContext.nextPutIndex());
put(obj, objectId);
return objectId;
}
@Override
public void put(NativeRayObject obj, ObjectId objectId) {
Preconditions.checkNotNull(obj);
Preconditions.checkNotNull(objectId);
pool.putIfAbsent(objectId, obj);
for (Consumer<ObjectId> callback : objectPutCallbacks) {
callback.accept(objectId);
}
}
@Override
public List<NativeRayObject> get(List<ObjectId> objectIds, long timeoutMs) {
waitInternal(objectIds, objectIds.size(), timeoutMs);
return objectIds.stream().map(pool::get).collect(Collectors.toList());
}
@Override
public List<Boolean> wait(List<ObjectId> objectIds, int numObjects, long timeoutMs) {
waitInternal(objectIds, numObjects, timeoutMs);
return objectIds.stream().map(pool::containsKey).collect(Collectors.toList());
}
private void waitInternal(List<ObjectId> objectIds, int numObjects, long timeoutMs) {
int ready = 0;
long remainingTime = timeoutMs;
boolean firstCheck = true;
while (ready < numObjects && (timeoutMs < 0 || remainingTime > 0)) {
if (!firstCheck) {
long sleepTime = Math.min(remainingTime, GET_CHECK_INTERVAL_MS);
try {
Thread.sleep(sleepTime);
} catch (InterruptedException e) {
LOGGER.warn("Got InterruptedException while sleeping.");
}
remainingTime -= sleepTime;
}
ready = 0;
for (ObjectId objectId : objectIds) {
if (pool.containsKey(objectId)) {
ready += 1;
}
}
firstCheck = false;
}
}
@Override
public void delete(List<ObjectId> objectIds, boolean localOnly, boolean deleteCreatingTasks) {
for (ObjectId objectId : objectIds) {
pool.remove(objectId);
}
}
}

View file

@ -1,148 +0,0 @@
package org.ray.runtime.objectstore;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import org.apache.arrow.plasma.ObjectStoreLink;
import org.ray.api.id.ObjectId;
import org.ray.runtime.RayDevRuntime;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* A mock implementation of {@code org.ray.spi.ObjectStoreLink}, which use Map to store data.
*/
public class MockObjectStore implements ObjectStoreLink {
private static final Logger LOGGER = LoggerFactory.getLogger(MockObjectStore.class);
private static final int GET_CHECK_INTERVAL_MS = 100;
private final RayDevRuntime runtime;
private final Map<ObjectId, byte[]> data = new ConcurrentHashMap<>();
private final Map<ObjectId, byte[]> metadata = new ConcurrentHashMap<>();
private final List<Consumer<ObjectId>> objectPutCallbacks;
public MockObjectStore(RayDevRuntime runtime) {
this.runtime = runtime;
this.objectPutCallbacks = new ArrayList<>();
}
public void addObjectPutCallback(Consumer<ObjectId> callback) {
this.objectPutCallbacks.add(callback);
}
@Override
public void put(byte[] objectId, byte[] value, byte[] metadataValue) {
if (objectId == null || objectId.length == 0 || value == null) {
LOGGER
.error("{} cannot put null: {}, {}", logPrefix(), objectId, Arrays.toString(value));
System.exit(-1);
}
ObjectId id = new ObjectId(objectId);
data.put(id, value);
if (metadataValue != null) {
metadata.put(id, metadataValue);
}
for (Consumer<ObjectId> callback : objectPutCallbacks) {
callback.accept(id);
}
}
@Override
public byte[] get(byte[] objectId, int timeoutMs, boolean isMetadata) {
return get(new byte[][] {objectId}, timeoutMs, isMetadata).get(0);
}
@Override
public List<byte[]> get(byte[][] objectIds, int timeoutMs, boolean isMetadata) {
return get(objectIds, timeoutMs)
.stream()
.map(data -> isMetadata ? data.metadata : data.data)
.collect(Collectors.toList());
}
@Override
public List<ObjectStoreData> get(byte[][] objectIds, int timeoutMs) {
int ready = 0;
int remainingTime = timeoutMs;
boolean firstCheck = true;
while (ready < objectIds.length && remainingTime > 0) {
if (!firstCheck) {
int sleepTime = Math.min(remainingTime, GET_CHECK_INTERVAL_MS);
try {
Thread.sleep(sleepTime);
} catch (InterruptedException e) {
LOGGER.warn("Got InterruptedException while sleeping.");
}
remainingTime -= sleepTime;
}
ready = 0;
for (byte[] id : objectIds) {
if (data.containsKey(new ObjectId(id))) {
ready += 1;
}
}
firstCheck = false;
}
ArrayList<ObjectStoreData> rets = new ArrayList<>();
for (byte[] objId : objectIds) {
ObjectId objectId = new ObjectId(objId);
rets.add(new ObjectStoreData(metadata.get(objectId), data.get(objectId)));
}
return rets;
}
@Override
public byte[] hash(byte[] objectId) {
return null;
}
@Override
public long evict(long numBytes) {
return 0;
}
@Override
public void release(byte[] objectId) {
return;
}
@Override
public void delete(byte[] objectId) {
return;
}
@Override
public boolean contains(byte[] objectId) {
return data.containsKey(new ObjectId(objectId));
}
private String logPrefix() {
return runtime.getWorkerContext().getCurrentTaskId() + "-" + getUserTrace() + " -> ";
}
private String getUserTrace() {
StackTraceElement[] stes = Thread.currentThread().getStackTrace();
int k = 1;
while (stes[k].getClassName().startsWith("org.ray")
&& !stes[k].getClassName().contains("test")) {
k++;
}
return stes[k].getFileName() + ":" + stes[k].getLineNumber();
}
public boolean isObjectReady(ObjectId id) {
return data.containsKey(id);
}
public void free(ObjectId id) {
data.remove(id);
metadata.remove(id);
}
}

View file

@ -0,0 +1,13 @@
package org.ray.runtime.objectstore;
public class NativeRayObject {
public byte[] data;
public byte[] metadata;
public NativeRayObject(byte[] data, byte[] metadata) {
this.data = data;
this.metadata = metadata;
}
}

View file

@ -0,0 +1,54 @@
package org.ray.runtime.objectstore;
import java.util.List;
import org.ray.api.id.ObjectId;
/**
* The interface that contains all worker methods that are related to object store.
*/
public interface ObjectInterface {
/**
* Put an object into object store.
*
* @param obj The ray object.
* @return Generated ID of the object.
*/
ObjectId put(NativeRayObject obj);
/**
* Put an object with specified ID into object store.
*
* @param obj The ray object.
* @param objectId Object ID specified by user.
*/
void put(NativeRayObject obj, ObjectId objectId);
/**
* Get a list of objects from the object store.
*
* @param objectIds IDs of the objects to get.
* @param timeoutMs Timeout in milliseconds, wait infinitely if it's negative.
* @return Result list of objects data.
*/
List<NativeRayObject> get(List<ObjectId> objectIds, long timeoutMs);
/**
* Wait for a list of objects to appear in the object store.
*
* @param objectIds IDs of the objects to wait for.
* @param numObjects Number of objects that should appear.
* @param timeoutMs Timeout in milliseconds, wait infinitely if it's negative.
* @return A bitset that indicates each object has appeared or not.
*/
List<Boolean> wait(List<ObjectId> objectIds, int numObjects, long timeoutMs);
/**
* Delete a list of objects from the object store.
*
* @param objectIds IDs of the objects to delete.
* @param localOnly Whether only delete the objects in local node, or all nodes in the cluster.
* @param deleteCreatingTasks Whether also delete the tasks that created these objects.
*/
void delete(List<ObjectId> objectIds, boolean localOnly, boolean deleteCreatingTasks);
}

View file

@ -0,0 +1,91 @@
package org.ray.runtime.objectstore;
import java.util.List;
import java.util.stream.Collectors;
import org.ray.api.exception.RayException;
import org.ray.api.id.BaseId;
import org.ray.api.id.ObjectId;
import org.ray.runtime.AbstractRayRuntime;
import org.ray.runtime.WorkerContext;
import org.ray.runtime.raylet.RayletClient;
import org.ray.runtime.raylet.RayletClientImpl;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* This is a wrapper class for core worker object interface.
*/
public class ObjectInterfaceImpl implements ObjectInterface {
private static final Logger LOGGER = LoggerFactory.getLogger(AbstractRayRuntime.class);
/**
* The native pointer of core worker object interface.
*/
private final long nativeObjectInterfacePointer;
public ObjectInterfaceImpl(WorkerContext workerContext, RayletClient rayletClient,
String storeSocketName) {
this.nativeObjectInterfacePointer =
nativeCreateObjectInterface(workerContext.getNativeWorkerContext(),
((RayletClientImpl) rayletClient).getClient(), storeSocketName);
}
@Override
public ObjectId put(NativeRayObject obj) {
return new ObjectId(nativePut(nativeObjectInterfacePointer, obj));
}
@Override
public void put(NativeRayObject obj, ObjectId objectId) {
try {
nativePut(nativeObjectInterfacePointer, objectId.getBytes(), obj);
} catch (RayException e) {
LOGGER.warn(e.getMessage());
}
}
@Override
public List<NativeRayObject> get(List<ObjectId> objectIds, long timeoutMs) {
return nativeGet(nativeObjectInterfacePointer, toBinaryList(objectIds), timeoutMs);
}
@Override
public List<Boolean> wait(List<ObjectId> objectIds, int numObjects, long timeoutMs) {
return nativeWait(nativeObjectInterfacePointer, toBinaryList(objectIds), numObjects, timeoutMs);
}
@Override
public void delete(List<ObjectId> objectIds, boolean localOnly, boolean deleteCreatingTasks) {
nativeDelete(nativeObjectInterfacePointer, toBinaryList(objectIds), localOnly, deleteCreatingTasks);
}
public void destroy() {
nativeDestroy(nativeObjectInterfacePointer);
}
private static List<byte[]> toBinaryList(List<ObjectId> ids) {
return ids.stream().map(BaseId::getBytes).collect(Collectors.toList());
}
private static native long nativeCreateObjectInterface(long nativeObjectInterface,
long nativeRayletClient,
String storeSocketName);
private static native byte[] nativePut(long nativeObjectInterface, NativeRayObject obj);
private static native void nativePut(long nativeObjectInterface, byte[] objectId,
NativeRayObject obj);
private static native List<NativeRayObject> nativeGet(long nativeObjectInterface,
List<byte[]> ids,
long timeoutMs);
private static native List<Boolean> nativeWait(long nativeObjectInterface, List<byte[]> objectIds,
int numObjects, long timeoutMs);
private static native void nativeDelete(long nativeObjectInterface, List<byte[]> objectIds,
boolean localOnly, boolean deleteCreatingTasks);
private static native void nativeDestroy(long nativeObjectInterface);
}

View file

@ -4,20 +4,14 @@ import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.arrow.plasma.ObjectStoreLink;
import org.apache.arrow.plasma.ObjectStoreLink.ObjectStoreData;
import org.apache.arrow.plasma.PlasmaClient;
import org.apache.arrow.plasma.exceptions.DuplicateObjectException;
import org.ray.api.exception.RayActorException;
import org.ray.api.exception.RayException;
import org.ray.api.exception.RayTaskException;
import org.ray.api.exception.RayWorkerException;
import org.ray.api.exception.UnreconstructableException;
import org.ray.api.id.ObjectId;
import org.ray.runtime.AbstractRayRuntime;
import org.ray.runtime.RayDevRuntime;
import org.ray.runtime.config.RunMode;
import org.ray.runtime.WorkerContext;
import org.ray.runtime.generated.Gcs.ErrorType;
import org.ray.runtime.util.IdUtil;
import org.ray.runtime.util.Serializer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -36,21 +30,18 @@ public class ObjectStoreProxy {
private static final byte[] UNRECONSTRUCTABLE_EXCEPTION_META = String
.valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE.getNumber()).getBytes();
private static final byte[] TASK_EXECUTION_EXCEPTION_META = String
.valueOf(ErrorType.TASK_EXECUTION_EXCEPTION.getNumber()).getBytes();
private static final byte[] RAW_TYPE_META = "RAW".getBytes();
private final AbstractRayRuntime runtime;
private final WorkerContext workerContext;
private static ThreadLocal<ObjectStoreLink> objectStore;
private final ObjectInterface objectInterface;
public ObjectStoreProxy(AbstractRayRuntime runtime, String storeSocketName) {
this.runtime = runtime;
objectStore = ThreadLocal.withInitial(() -> {
if (runtime.getRayConfig().runMode == RunMode.CLUSTER) {
return new PlasmaClient(storeSocketName, "", 0);
} else {
return ((RayDevRuntime) runtime).getObjectStore();
}
});
public ObjectStoreProxy(WorkerContext workerContext, ObjectInterface objectInterface) {
this.workerContext = workerContext;
this.objectInterface = objectInterface;
}
/**
@ -75,46 +66,44 @@ public class ObjectStoreProxy {
* @return A list of GetResult objects.
*/
public <T> List<GetResult<T>> get(List<ObjectId> ids, int timeoutMs) {
byte[][] binaryIds = IdUtil.getIdBytes(ids);
List<ObjectStoreData> dataAndMetaList = objectStore.get().get(binaryIds, timeoutMs);
List<NativeRayObject> dataAndMetaList = objectInterface.get(ids, timeoutMs);
List<GetResult<T>> results = new ArrayList<>();
for (int i = 0; i < dataAndMetaList.size(); i++) {
byte[] meta = dataAndMetaList.get(i).metadata;
byte[] data = dataAndMetaList.get(i).data;
NativeRayObject dataAndMeta = dataAndMetaList.get(i);
GetResult<T> result;
if (meta != null) {
// If meta is not null, deserialize the object from meta.
result = deserializeFromMeta(meta, data, ids.get(i));
} else if (data != null) {
// If data is not null, deserialize the Java object.
Object object = Serializer.decode(data, runtime.getWorkerContext().getCurrentClassLoader());
if (object instanceof RayException) {
// If the object is a `RayException`, it means that an error occurred during task
// execution.
result = new GetResult<>(true, null, (RayException) object);
if (dataAndMeta != null) {
byte[] meta = dataAndMeta.metadata;
byte[] data = dataAndMeta.data;
if (meta != null && meta.length > 0) {
// If meta is not null, deserialize the object from meta.
result = deserializeFromMeta(meta, data,
workerContext.getCurrentClassLoader(), ids.get(i));
} else {
// Otherwise, the object is valid.
result = new GetResult<>(true, (T) object, null);
// If data is not null, deserialize the Java object.
Object object = Serializer.decode(data, workerContext.getCurrentClassLoader());
if (object instanceof RayException) {
// If the object is a `RayException`, it means that an error occurred during task
// execution.
result = new GetResult<>(true, null, (RayException) object);
} else {
// Otherwise, the object is valid.
result = new GetResult<>(true, (T) object, null);
}
}
} else {
// If both meta and data are null, the object doesn't exist in object store.
result = new GetResult<>(false, null, null);
}
if (meta != null || data != null) {
// Release the object from object store..
objectStore.get().release(binaryIds[i]);
}
results.add(result);
}
return results;
}
@SuppressWarnings("unchecked")
private <T> GetResult<T> deserializeFromMeta(byte[] meta, byte[] data, ObjectId objectId) {
private <T> GetResult<T> deserializeFromMeta(byte[] meta, byte[] data,
ClassLoader classLoader, ObjectId objectId) {
if (Arrays.equals(meta, RAW_TYPE_META)) {
return (GetResult<T>) new GetResult<>(true, data, null);
} else if (Arrays.equals(meta, WORKER_EXCEPTION_META)) {
@ -123,6 +112,8 @@ public class ObjectStoreProxy {
return new GetResult<>(true, null, RayActorException.INSTANCE);
} else if (Arrays.equals(meta, UNRECONSTRUCTABLE_EXCEPTION_META)) {
return new GetResult<>(true, null, new UnreconstructableException(objectId));
} else if (Arrays.equals(meta, TASK_EXECUTION_EXCEPTION_META)) {
return new GetResult<>(true, null, Serializer.decode(data, classLoader));
}
throw new IllegalArgumentException("Unrecognized metadata " + Arrays.toString(meta));
}
@ -134,16 +125,14 @@ public class ObjectStoreProxy {
* @param object The object to put.
*/
public void put(ObjectId id, Object object) {
try {
if (object instanceof byte[]) {
// If the object is a byte array, skip serializing it and use a special metadata to
// indicate it's raw binary. So that this object can also be read by Python.
objectStore.get().put(id.getBytes(), (byte[]) object, RAW_TYPE_META);
} else {
objectStore.get().put(id.getBytes(), Serializer.encode(object), null);
}
} catch (DuplicateObjectException e) {
LOGGER.warn(e.getMessage());
if (object instanceof byte[]) {
// If the object is a byte array, skip serializing it and use a special metadata to
// indicate it's raw binary. So that this object can also be read by Python.
objectInterface.put(new NativeRayObject((byte[]) object, RAW_TYPE_META), id);
} else if (object instanceof RayTaskException) {
objectInterface.put(new NativeRayObject(Serializer.encode(object), TASK_EXECUTION_EXCEPTION_META), id);
} else {
objectInterface.put(new NativeRayObject(Serializer.encode(object), null), id);
}
}
@ -154,11 +143,7 @@ public class ObjectStoreProxy {
* @param serializedObject The serialized object to put.
*/
public void putSerialized(ObjectId id, byte[] serializedObject) {
try {
objectStore.get().put(id.getBytes(), serializedObject, null);
} catch (DuplicateObjectException e) {
LOGGER.warn(e.getMessage());
}
objectInterface.put(new NativeRayObject(serializedObject, null), id);
}
/**

View file

@ -14,6 +14,7 @@ import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.stream.Collectors;
import org.apache.commons.lang3.NotImplementedException;
import org.ray.api.RayObject;
import org.ray.api.WaitResult;
@ -23,7 +24,8 @@ import org.ray.api.id.TaskId;
import org.ray.api.id.UniqueId;
import org.ray.runtime.RayDevRuntime;
import org.ray.runtime.Worker;
import org.ray.runtime.objectstore.MockObjectStore;
import org.ray.runtime.objectstore.MockObjectInterface;
import org.ray.runtime.objectstore.NativeRayObject;
import org.ray.runtime.task.FunctionArg;
import org.ray.runtime.task.TaskSpec;
import org.slf4j.Logger;
@ -37,7 +39,7 @@ public class MockRayletClient implements RayletClient {
private static final Logger LOGGER = LoggerFactory.getLogger(MockRayletClient.class);
private final Map<ObjectId, Set<TaskSpec>> waitingTasks = new ConcurrentHashMap<>();
private final MockObjectStore store;
private final MockObjectInterface objectInterface;
private final RayDevRuntime runtime;
private final ExecutorService exec;
private final Deque<Worker> idleWorkers;
@ -46,8 +48,8 @@ public class MockRayletClient implements RayletClient {
public MockRayletClient(RayDevRuntime runtime, int numberThreads) {
this.runtime = runtime;
this.store = runtime.getObjectStore();
store.addObjectPutCallback(this::onObjectPut);
this.objectInterface = runtime.getObjectInterface();
objectInterface.addObjectPutCallback(this::onObjectPut);
// The thread pool that executes tasks in parallel.
exec = Executors.newFixedThreadPool(numberThreads);
idleWorkers = new ConcurrentLinkedDeque<>();
@ -113,8 +115,8 @@ public class MockRayletClient implements RayletClient {
// can be executed.
if (task.isActorCreationTask() || task.isActorTask()) {
ObjectId[] returnIds = task.returnIds;
store.put(returnIds[returnIds.length - 1].getBytes(),
new byte[]{}, new byte[]{});
objectInterface.put(new NativeRayObject(new byte[] {}, new byte[] {}),
returnIds[returnIds.length - 1]);
}
} finally {
returnWorker(worker);
@ -133,13 +135,13 @@ public class MockRayletClient implements RayletClient {
// Check whether task arguments are ready.
for (FunctionArg arg : spec.args) {
if (arg.id != null) {
if (!store.isObjectReady(arg.id)) {
if (!objectInterface.isObjectReady(arg.id)) {
unreadyObjects.add(arg.id);
}
}
}
if (spec.isActorTask()) {
if (!store.isObjectReady(spec.previousActorTaskDummyObjectId)) {
if (!objectInterface.isObjectReady(spec.previousActorTaskDummyObjectId)) {
unreadyObjects.add(spec.previousActorTaskDummyObjectId);
}
}
@ -154,7 +156,7 @@ public class MockRayletClient implements RayletClient {
@Override
public void fetchOrReconstruct(List<ObjectId> objectIds, boolean fetchOnly,
TaskId currentTaskId) {
TaskId currentTaskId) {
}
@ -170,20 +172,17 @@ public class MockRayletClient implements RayletClient {
@Override
public <T> WaitResult<T> wait(List<RayObject<T>> waitFor, int numReturns, int
timeoutMs, TaskId currentTaskId) {
timeoutMs, TaskId currentTaskId) {
if (waitFor == null || waitFor.isEmpty()) {
return new WaitResult<>(ImmutableList.of(), ImmutableList.of());
}
byte[][] ids = new byte[waitFor.size()][];
for (int i = 0; i < waitFor.size(); i++) {
ids[i] = waitFor.get(i).getId().getBytes();
}
List<ObjectId> ids = waitFor.stream().map(RayObject::getId).collect(Collectors.toList());
List<RayObject<T>> readyList = new ArrayList<>();
List<RayObject<T>> unreadyList = new ArrayList<>();
List<byte[]> result = store.get(ids, timeoutMs, false);
List<Boolean> result = objectInterface.wait(ids, ids.size(), timeoutMs);
for (int i = 0; i < waitFor.size(); i++) {
if (result.get(i) != null) {
if (result.get(i)) {
readyList.add(waitFor.get(i));
} else {
unreadyList.add(waitFor.get(i));
@ -195,9 +194,7 @@ public class MockRayletClient implements RayletClient {
@Override
public void freePlasmaObjects(List<ObjectId> objectIds, boolean localOnly,
boolean deleteCreatingTasks) {
for (ObjectId id : objectIds) {
store.free(id);
}
objectInterface.delete(objectIds, localOnly, deleteCreatingTasks);
}

View file

@ -4,8 +4,6 @@ import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Arrays;
@ -40,11 +38,15 @@ public class RayletClientImpl implements RayletClient {
// TODO(qwang): JobId parameter can be removed once we embed jobId in driverId.
public RayletClientImpl(String schedulerSockName, UniqueId clientId,
boolean isWorker, JobId jobId) {
boolean isWorker, JobId jobId) {
client = nativeInit(schedulerSockName, clientId.getBytes(),
isWorker, jobId.getBytes());
}
public long getClient() {
return client;
}
@Override
public <T> WaitResult<T> wait(List<RayObject<T>> waitFor, int numReturns, int
timeoutMs, TaskId currentTaskId) {
@ -133,7 +135,7 @@ public class RayletClientImpl implements RayletClient {
/**
* Parse `TaskSpec` protobuf bytes.
*/
private static TaskSpec parseTaskSpecFromProtobuf(byte[] bytes) {
public static TaskSpec parseTaskSpecFromProtobuf(byte[] bytes) {
Common.TaskSpec taskSpec;
try {
taskSpec = Common.TaskSpec.parseFrom(bytes);
@ -214,7 +216,7 @@ public class RayletClientImpl implements RayletClient {
/**
* Convert a `TaskSpec` to protobuf-serialized bytes.
*/
private static byte[] convertTaskSpecToProtobuf(TaskSpec task) {
public static byte[] convertTaskSpecToProtobuf(TaskSpec task) {
// Set common fields.
Common.TaskSpec.Builder builder = Common.TaskSpec.newBuilder()
.setJobId(ByteString.copyFrom(task.jobId.getBytes()))

View file

@ -154,18 +154,6 @@ public class IdUtil {
}
/**
* Compute the driver id from the given job.
*/
public static UniqueId computeDriverId(JobId jobId) {
byte[] bytes = new byte[UniqueId.LENGTH];
System.arraycopy(jobId.getBytes(), 0, bytes, 0, jobId.size());
Arrays.fill(bytes, jobId.size(), UniqueId.LENGTH, (byte)0xFF);
ByteBuffer wbb = ByteBuffer.wrap(bytes);
wbb.order(ByteOrder.LITTLE_ENDIAN);
return new UniqueId(bytes);
}
/**
* Compute the murmur hash code of this ID.
*/

View file

@ -1,12 +1,18 @@
package org.ray.api.test;
import java.time.Duration;
import java.time.Instant;
import java.util.Arrays;
import java.util.List;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.TestUtils;
import org.ray.api.exception.RayActorException;
import org.ray.api.exception.RayException;
import org.ray.api.exception.RayTaskException;
import org.ray.api.exception.RayWorkerException;
import org.ray.api.function.RayFunc0;
import org.testng.Assert;
import org.testng.annotations.Test;
@ -23,6 +29,15 @@ public class FailureTest extends BaseTest {
return 0;
}
public static int slowFunc() {
try {
Thread.sleep(10000);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
return 0;
}
public static class BadActor {
public BadActor(boolean failOnCreation) {
@ -106,5 +121,26 @@ public class FailureTest extends BaseTest {
// RayActorException.
}
}
@Test
public void testGetThrowsQuicklyWhenFoundException() {
TestUtils.skipTestUnderSingleProcess();
List<RayFunc0<Integer>> badFunctions = Arrays.asList(FailureTest::badFunc,
FailureTest::badFunc2);
for (RayFunc0<Integer> badFunc : badFunctions) {
RayObject<Integer> obj1 = Ray.call(badFunc);
RayObject<Integer> obj2 = Ray.call(FailureTest::slowFunc);
Instant start = Instant.now();
try {
Ray.get(Arrays.asList(obj1.getId(), obj2.getId()));
Assert.fail("Should throw RayException.");
} catch (RayException e) {
Instant end = Instant.now();
long duration = Duration.between(start, end).toMillis();
Assert.assertTrue(duration < 5000, "Should fail quickly. " +
"Actual execution time: " + duration + " ms.");
}
}
}
}

View file

@ -1,12 +1,10 @@
package org.ray.api.test;
import org.apache.arrow.plasma.PlasmaClient;
import org.apache.arrow.plasma.exceptions.DuplicateObjectException;
import org.ray.api.Ray;
import org.ray.api.TestUtils;
import org.ray.api.id.UniqueId;
import org.ray.api.id.ObjectId;
import org.ray.runtime.AbstractRayRuntime;
import org.ray.runtime.objectstore.ObjectStoreProxy;
import org.testng.Assert;
import org.testng.annotations.Test;
@ -15,15 +13,13 @@ public class PlasmaStoreTest extends BaseTest {
@Test
public void testPutWithDuplicateId() {
TestUtils.skipTestUnderSingleProcess();
UniqueId objectId = UniqueId.randomId();
ObjectId objectId = ObjectId.randomId();
AbstractRayRuntime runtime = (AbstractRayRuntime) Ray.internal();
PlasmaClient store = new PlasmaClient(runtime.getRayConfig().objectStoreSocketName, "", 0);
store.put(objectId.getBytes(), new byte[]{}, new byte[]{});
try {
store.put(objectId.getBytes(), new byte[]{}, new byte[]{});
Assert.fail("This line shouldn't be reached.");
} catch (DuplicateObjectException e) {
// Putting 2 objects with duplicate ID should throw DuplicateObjectException.
}
ObjectStoreProxy objectInterface = runtime.getObjectStoreProxy();
objectInterface.put(objectId, 1);
Assert.assertEquals(objectInterface.<Integer>get(objectId, -1).object, (Integer) 1);
objectInterface.put(objectId, 2);
// Putting 2 objects with duplicate ID should fail but ignored.
Assert.assertEquals(objectInterface.<Integer>get(objectId, -1).object, (Integer) 1);
}
}

View file

@ -1,7 +1,7 @@
package org.ray.api.test;
import org.ray.runtime.config.RayConfig;
import org.ray.runtime.config.WorkerMode;
import org.ray.runtime.generated.Common.WorkerType;
import org.testng.Assert;
import org.testng.annotations.Test;
@ -12,7 +12,7 @@ public class RayConfigTest {
try {
System.setProperty("ray.job.resource-path", "path/to/ray/job/resource/path");
RayConfig rayConfig = RayConfig.create();
Assert.assertEquals(WorkerMode.DRIVER, rayConfig.workerMode);
Assert.assertEquals(WorkerType.DRIVER, rayConfig.workerMode);
Assert.assertEquals("path/to/ray/job/resource/path", rayConfig.jobResourcePath);
} finally {
// Unset system properties.

View file

@ -151,3 +151,10 @@ RAY_CONFIG(uint32_t, num_actor_checkpoints_to_keep, 20)
/// Maximum number of ids in one batch to send to GCS to delete keys.
RAY_CONFIG(uint32_t, maximum_gcs_deletion_batch_size, 1000)
/// When getting objects from object store, print a warning every this number of attempts.
RAY_CONFIG(uint32_t, object_store_get_warn_per_num_attempts, 50)
/// When getting objects from object store, max number of ids to print in the warning
/// message.
RAY_CONFIG(uint32_t, object_store_get_max_ids_to_print_in_warning, 20)

View file

@ -9,9 +9,7 @@
#include "ray/raylet/raylet_client.h"
namespace ray {
/// Type of this worker.
enum class WorkerType { WORKER, DRIVER };
using WorkerType = rpc::WorkerType;
/// Information about a remote function.
struct RayFunction {

View file

@ -6,69 +6,81 @@ namespace ray {
/// per-thread context for core worker.
struct WorkerThreadContext {
WorkerThreadContext()
: current_task_id(TaskID::FromRandom()), task_index(0), put_index(0) {}
: current_task_id_(TaskID::FromRandom()), task_index_(0), put_index_(0) {}
int GetNextTaskIndex() { return ++task_index; }
int GetNextTaskIndex() { return ++task_index_; }
int GetNextPutIndex() { return ++put_index; }
int GetNextPutIndex() { return ++put_index_; }
const TaskID &GetCurrentTaskID() const { return current_task_id; }
const TaskID &GetCurrentTaskID() const { return current_task_id_; }
void SetCurrentTask(const TaskID &task_id) {
current_task_id = task_id;
task_index = 0;
put_index = 0;
std::shared_ptr<const TaskSpecification> GetCurrentTask() const {
return current_task_;
}
void SetCurrentTaskId(const TaskID &task_id) {
current_task_id_ = task_id;
task_index_ = 0;
put_index_ = 0;
}
void SetCurrentTask(const TaskSpecification &task_spec) {
SetCurrentTask(task_spec.TaskId());
SetCurrentTaskId(task_spec.TaskId());
current_task_ = std::make_shared<const TaskSpecification>(task_spec);
}
private:
/// The task ID for current task.
TaskID current_task_id;
TaskID current_task_id_;
/// The current task.
std::shared_ptr<const TaskSpecification> current_task_;
/// Number of tasks that have been submitted from current task.
int task_index;
int task_index_;
/// Number of objects that have been put from current task.
int put_index;
int put_index_;
};
thread_local std::unique_ptr<WorkerThreadContext> WorkerContext::thread_context_ =
nullptr;
WorkerContext::WorkerContext(WorkerType worker_type, const JobID &job_id)
: worker_type(worker_type),
worker_id(worker_type == WorkerType::DRIVER ? ComputeDriverIdFromJob(job_id)
: WorkerID::FromRandom()),
current_job_id(worker_type == WorkerType::DRIVER ? job_id : JobID::Nil()) {
: worker_type_(worker_type),
worker_id_(worker_type_ == WorkerType::DRIVER ? ComputeDriverIdFromJob(job_id)
: WorkerID::FromRandom()),
current_job_id_(worker_type_ == WorkerType::DRIVER ? job_id : JobID::Nil()) {
// For worker main thread which initializes the WorkerContext,
// set task_id according to whether current worker is a driver.
// (For other threads it's set to random ID via GetThreadContext).
GetThreadContext().SetCurrentTask(
(worker_type == WorkerType::DRIVER) ? TaskID::FromRandom() : TaskID::Nil());
GetThreadContext().SetCurrentTaskId(
(worker_type_ == WorkerType::DRIVER) ? TaskID::FromRandom() : TaskID::Nil());
}
const WorkerType WorkerContext::GetWorkerType() const { return worker_type; }
const WorkerType WorkerContext::GetWorkerType() const { return worker_type_; }
const WorkerID &WorkerContext::GetWorkerID() const { return worker_id; }
const WorkerID &WorkerContext::GetWorkerID() const { return worker_id_; }
int WorkerContext::GetNextTaskIndex() { return GetThreadContext().GetNextTaskIndex(); }
int WorkerContext::GetNextPutIndex() { return GetThreadContext().GetNextPutIndex(); }
const JobID &WorkerContext::GetCurrentJobID() const { return current_job_id; }
const JobID &WorkerContext::GetCurrentJobID() const { return current_job_id_; }
const TaskID &WorkerContext::GetCurrentTaskID() const {
return GetThreadContext().GetCurrentTaskID();
}
void WorkerContext::SetCurrentTask(const TaskSpecification &task_spec) {
current_job_id = task_spec.JobId();
current_job_id_ = task_spec.JobId();
GetThreadContext().SetCurrentTask(task_spec);
}
std::shared_ptr<const TaskSpecification> WorkerContext::GetCurrentTask() const {
return GetThreadContext().GetCurrentTask();
}
WorkerThreadContext &WorkerContext::GetThreadContext() {
if (thread_context_ == nullptr) {
thread_context_ = std::unique_ptr<WorkerThreadContext>(new WorkerThreadContext());

View file

@ -22,19 +22,21 @@ class WorkerContext {
void SetCurrentTask(const TaskSpecification &task_spec);
std::shared_ptr<const TaskSpecification> GetCurrentTask() const;
int GetNextTaskIndex();
int GetNextPutIndex();
private:
/// Type of the worker.
const WorkerType worker_type;
const WorkerType worker_type_;
/// ID for this worker.
const WorkerID worker_id;
const WorkerID worker_id_;
/// Job ID for this worker.
JobID current_job_id;
JobID current_job_id_;
private:
static WorkerThreadContext &GetThreadContext();

View file

@ -15,7 +15,7 @@ CoreWorker::CoreWorker(
task_interface_(worker_context_, raylet_client_),
object_interface_(worker_context_, raylet_client_, store_socket) {
int rpc_server_port = 0;
if (worker_type_ == ray::WorkerType::WORKER) {
if (worker_type_ == WorkerType::WORKER) {
RAY_CHECK(execution_callback != nullptr);
task_execution_interface_ = std::unique_ptr<CoreWorkerTaskExecutionInterface>(
new CoreWorkerTaskExecutionInterface(worker_context_, raylet_client_,
@ -28,8 +28,8 @@ CoreWorker::CoreWorker(
// instead of crashing.
raylet_client_ = std::unique_ptr<RayletClient>(new RayletClient(
raylet_socket_, ClientID::FromBinary(worker_context_.GetWorkerID().Binary()),
(worker_type_ == ray::WorkerType::WORKER), worker_context_.GetCurrentJobID(),
language_, rpc_server_port));
(worker_type_ == WorkerType::WORKER), worker_context_.GetCurrentJobID(), language_,
rpc_server_port));
}
} // namespace ray

View file

@ -0,0 +1,75 @@
#include "ray/core_worker/lib/java/jni_utils.h"
jclass java_boolean_class;
jmethodID java_boolean_init;
jclass java_list_class;
jmethodID java_list_size;
jmethodID java_list_get;
jmethodID java_list_add;
jclass java_array_list_class;
jmethodID java_array_list_init;
jmethodID java_array_list_init_with_capacity;
jclass java_ray_exception_class;
jclass java_native_ray_object_class;
jmethodID java_native_ray_object_init;
jfieldID java_native_ray_object_data;
jfieldID java_native_ray_object_metadata;
jint JNI_VERSION = JNI_VERSION_1_8;
inline jclass LoadClass(JNIEnv *env, const char *class_name) {
jclass tempLocalClassRef = env->FindClass(class_name);
jclass ret = (jclass)env->NewGlobalRef(tempLocalClassRef);
env->DeleteLocalRef(tempLocalClassRef);
return ret;
}
/// Load and cache frequently-used Java classes and methods
jint JNI_OnLoad(JavaVM *vm, void *reserved) {
JNIEnv *env;
if (vm->GetEnv(reinterpret_cast<void **>(&env), JNI_VERSION) != JNI_OK) {
return JNI_ERR;
}
java_boolean_class = LoadClass(env, "java/lang/Boolean");
java_boolean_init = env->GetMethodID(java_boolean_class, "<init>", "(Z)V");
java_list_class = LoadClass(env, "java/util/List");
java_list_size = env->GetMethodID(java_list_class, "size", "()I");
java_list_get = env->GetMethodID(java_list_class, "get", "(I)Ljava/lang/Object;");
java_list_add = env->GetMethodID(java_list_class, "add", "(Ljava/lang/Object;)Z");
java_array_list_class = LoadClass(env, "java/util/ArrayList");
java_array_list_init = env->GetMethodID(java_array_list_class, "<init>", "()V");
java_array_list_init_with_capacity =
env->GetMethodID(java_array_list_class, "<init>", "(I)V");
java_ray_exception_class = LoadClass(env, "org/ray/api/exception/RayException");
java_native_ray_object_class =
LoadClass(env, "org/ray/runtime/objectstore/NativeRayObject");
java_native_ray_object_init =
env->GetMethodID(java_native_ray_object_class, "<init>", "([B[B)V");
java_native_ray_object_data =
env->GetFieldID(java_native_ray_object_class, "data", "[B");
java_native_ray_object_metadata =
env->GetFieldID(java_native_ray_object_class, "metadata", "[B");
return JNI_VERSION;
}
/// Unload java classes
void JNI_OnUnload(JavaVM *vm, void *reserved) {
JNIEnv *env;
vm->GetEnv(reinterpret_cast<void **>(&env), JNI_VERSION);
env->DeleteGlobalRef(java_boolean_class);
env->DeleteGlobalRef(java_list_class);
env->DeleteGlobalRef(java_array_list_class);
env->DeleteGlobalRef(java_ray_exception_class);
env->DeleteGlobalRef(java_native_ray_object_class);
}

View file

@ -0,0 +1,180 @@
#ifndef RAY_COMMON_JAVA_JNI_HELPER_H
#define RAY_COMMON_JAVA_JNI_HELPER_H
#include <jni.h>
#include "ray/common/buffer.h"
#include "ray/common/id.h"
#include "ray/common/status.h"
#include "ray/core_worker/store_provider/store_provider.h"
/// Boolean class
extern jclass java_boolean_class;
/// Constructor of Boolean class
extern jmethodID java_boolean_init;
/// List class
extern jclass java_list_class;
/// size method of List class
extern jmethodID java_list_size;
/// get method of List class
extern jmethodID java_list_get;
/// add method of List class
extern jmethodID java_list_add;
/// ArrayList class
extern jclass java_array_list_class;
/// Constructor of ArrayList class
extern jmethodID java_array_list_init;
/// Constructor of ArrayList class with single parameter capacity
extern jmethodID java_array_list_init_with_capacity;
/// RayException class
extern jclass java_ray_exception_class;
/// NativeRayObject class
extern jclass java_native_ray_object_class;
/// Constructor of NativeRayObject class
extern jmethodID java_native_ray_object_init;
/// data field of NativeRayObject class
extern jfieldID java_native_ray_object_data;
/// metadata field of NativeRayObject class
extern jfieldID java_native_ray_object_metadata;
/// Throws a Java RayException if the status is not OK.
#define THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, ret) \
{ \
if (!(status).ok()) { \
(env)->ThrowNew(java_ray_exception_class, (status).message().c_str()); \
return (ret); \
} \
}
/// Convert a Java byte array to a C++ UniqueID.
template <typename ID>
inline ID JavaByteArrayToId(JNIEnv *env, const jbyteArray &bytes) {
std::string id_str(ID::Size(), 0);
env->GetByteArrayRegion(bytes, 0, ID::Size(),
reinterpret_cast<jbyte *>(&id_str.front()));
return ID::FromBinary(id_str);
}
/// Convert C++ UniqueID to a Java byte array.
template <typename ID>
inline jbyteArray IdToJavaByteArray(JNIEnv *env, const ID &id) {
jbyteArray array = env->NewByteArray(ID::Size());
env->SetByteArrayRegion(array, 0, ID::Size(),
reinterpret_cast<const jbyte *>(id.Data()));
return array;
}
/// Convert C++ UniqueID to a Java ByteBuffer.
template <typename ID>
inline jobject IdToJavaByteBuffer(JNIEnv *env, const ID &id) {
return env->NewDirectByteBuffer(
reinterpret_cast<void *>(const_cast<uint8_t *>(id.Data())), id.Size());
}
/// Convert a Java String to C++ std::string.
inline std::string JavaStringToNativeString(JNIEnv *env, jstring jstr) {
const char *c_str = env->GetStringUTFChars(jstr, nullptr);
std::string result(c_str);
env->ReleaseStringUTFChars(static_cast<jstring>(jstr), c_str);
return result;
}
/// Convert a Java List to C++ std::vector.
template <typename NativeT>
inline void JavaListToNativeVector(
JNIEnv *env, jobject java_list, std::vector<NativeT> *native_vector,
std::function<NativeT(JNIEnv *, jobject)> element_converter) {
int size = env->CallIntMethod(java_list, java_list_size);
native_vector->clear();
for (int i = 0; i < size; i++) {
native_vector->emplace_back(
element_converter(env, env->CallObjectMethod(java_list, java_list_get, (jint)i)));
}
}
/// Convert a C++ std::vector to a Java List.
template <typename NativeT>
inline jobject NativeVectorToJavaList(
JNIEnv *env, const std::vector<NativeT> &native_vector,
std::function<jobject(JNIEnv *, const NativeT &)> element_converter) {
jobject java_list =
env->NewObject(java_array_list_class, java_array_list_init_with_capacity,
(jint)native_vector.size());
for (const auto &item : native_vector) {
env->CallVoidMethod(java_list, java_list_add, element_converter(env, item));
}
return java_list;
}
/// Convert a C++ ray::Buffer to a Java byte array.
inline jbyteArray NativeBufferToJavaByteArray(JNIEnv *env,
const std::shared_ptr<ray::Buffer> buffer) {
if (!buffer) {
return nullptr;
}
jbyteArray java_byte_array = env->NewByteArray(buffer->Size());
if (buffer->Size() > 0) {
env->SetByteArrayRegion(java_byte_array, 0, buffer->Size(),
reinterpret_cast<const jbyte *>(buffer->Data()));
}
return java_byte_array;
}
/// A helper method to help access a Java NativeRayObject instance and ensure memory
/// safety.
///
/// \param[in] java_obj The Java NativeRayObject object.
/// \param[in] reader The callback function to access a C++ ray::RayObject instance.
/// \return The return value of callback function.
template <typename ReturnT>
inline ReturnT ReadJavaNativeRayObject(
JNIEnv *env, const jobject &java_obj,
std::function<ReturnT(const std::shared_ptr<ray::RayObject> &)> reader) {
if (!java_obj) {
return reader(nullptr);
}
auto java_data = (jbyteArray)env->GetObjectField(java_obj, java_native_ray_object_data);
auto java_metadata =
(jbyteArray)env->GetObjectField(java_obj, java_native_ray_object_metadata);
auto data_size = env->GetArrayLength(java_data);
jbyte *data = data_size > 0 ? env->GetByteArrayElements(java_data, nullptr) : nullptr;
auto metadata_size = java_metadata ? env->GetArrayLength(java_metadata) : 0;
jbyte *metadata =
metadata_size > 0 ? env->GetByteArrayElements(java_metadata, nullptr) : nullptr;
auto data_buffer = std::make_shared<ray::LocalMemoryBuffer>(
reinterpret_cast<uint8_t *>(data), data_size);
auto metadata_buffer = java_metadata
? std::make_shared<ray::LocalMemoryBuffer>(
reinterpret_cast<uint8_t *>(metadata), metadata_size)
: nullptr;
auto native_obj = std::make_shared<ray::RayObject>(data_buffer, metadata_buffer);
auto result = reader(native_obj);
if (data) {
env->ReleaseByteArrayElements(java_data, data, JNI_ABORT);
}
if (metadata) {
env->ReleaseByteArrayElements(java_metadata, metadata, JNI_ABORT);
}
return result;
}
/// Convert a C++ ray::RayObject to a Java NativeRayObject.
inline jobject ToJavaNativeRayObject(JNIEnv *env,
const std::shared_ptr<ray::RayObject> &rayObject) {
if (!rayObject) {
return nullptr;
}
auto java_data = NativeBufferToJavaByteArray(env, rayObject->GetData());
auto java_metadata = NativeBufferToJavaByteArray(env, rayObject->GetMetadata());
auto java_obj = env->NewObject(java_native_ray_object_class,
java_native_ray_object_init, java_data, java_metadata);
return java_obj;
}
#endif // RAY_COMMON_JAVA_JNI_HELPER_H

View file

@ -0,0 +1,134 @@
#include "ray/core_worker/lib/java/org_ray_runtime_WorkerContext.h"
#include <jni.h>
#include "ray/common/id.h"
#include "ray/core_worker/context.h"
#include "ray/core_worker/lib/java/jni_utils.h"
inline ray::WorkerContext *GetWorkerContextFromPointer(
jlong nativeWorkerContextFromPointer) {
return reinterpret_cast<ray::WorkerContext *>(nativeWorkerContextFromPointer);
}
#ifdef __cplusplus
extern "C" {
#endif
/*
* Class: org_ray_runtime_WorkerContext
* Method: nativeCreateWorkerContext
* Signature: (I[B)J
*/
JNIEXPORT jlong JNICALL Java_org_ray_runtime_WorkerContext_nativeCreateWorkerContext(
JNIEnv *env, jclass, jint workerType, jbyteArray jobId) {
return reinterpret_cast<jlong>(
new ray::WorkerContext(static_cast<ray::rpc::WorkerType>(workerType),
JavaByteArrayToId<ray::JobID>(env, jobId)));
}
/*
* Class: org_ray_runtime_WorkerContext
* Method: nativeGetCurrentTaskId
* Signature: (J)[B
*/
JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_WorkerContext_nativeGetCurrentTaskId(
JNIEnv *env, jclass, jlong nativeWorkerContextFromPointer) {
auto task_id =
GetWorkerContextFromPointer(nativeWorkerContextFromPointer)->GetCurrentTaskID();
return IdToJavaByteArray<ray::TaskID>(env, task_id);
}
/*
* Class: org_ray_runtime_WorkerContext
* Method: nativeSetCurrentTask
* Signature: (J[B)V
*/
JNIEXPORT void JNICALL Java_org_ray_runtime_WorkerContext_nativeSetCurrentTask(
JNIEnv *env, jclass, jlong nativeWorkerContextFromPointer, jbyteArray taskSpec) {
jbyte *data = env->GetByteArrayElements(taskSpec, NULL);
jsize size = env->GetArrayLength(taskSpec);
ray::rpc::TaskSpec task_spec_message;
task_spec_message.ParseFromArray(data, size);
env->ReleaseByteArrayElements(taskSpec, data, JNI_ABORT);
ray::TaskSpecification spec(task_spec_message);
GetWorkerContextFromPointer(nativeWorkerContextFromPointer)->SetCurrentTask(spec);
}
/*
* Class: org_ray_runtime_WorkerContext
* Method: nativeGetCurrentTask
* Signature: (J)[B
*/
JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_WorkerContext_nativeGetCurrentTask(
JNIEnv *env, jclass, jlong nativeWorkerContextFromPointer) {
auto spec =
GetWorkerContextFromPointer(nativeWorkerContextFromPointer)->GetCurrentTask();
if (!spec) {
return nullptr;
}
auto task_message = spec->Serialize();
jbyteArray result = env->NewByteArray(task_message.size());
env->SetByteArrayRegion(
result, 0, task_message.size(),
reinterpret_cast<jbyte *>(const_cast<char *>(task_message.data())));
return result;
}
/*
* Class: org_ray_runtime_WorkerContext
* Method: nativeGetCurrentJobId
* Signature: (J)Ljava/nio/ByteBuffer;
*/
JNIEXPORT jobject JNICALL Java_org_ray_runtime_WorkerContext_nativeGetCurrentJobId(
JNIEnv *env, jclass, jlong nativeWorkerContextFromPointer) {
const auto &job_id =
GetWorkerContextFromPointer(nativeWorkerContextFromPointer)->GetCurrentJobID();
return IdToJavaByteBuffer<ray::JobID>(env, job_id);
}
/*
* Class: org_ray_runtime_WorkerContext
* Method: nativeGetCurrentWorkerId
* Signature: (J)[B
*/
JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_WorkerContext_nativeGetCurrentWorkerId(
JNIEnv *env, jclass, jlong nativeWorkerContextFromPointer) {
auto worker_id =
GetWorkerContextFromPointer(nativeWorkerContextFromPointer)->GetWorkerID();
return IdToJavaByteArray<ray::WorkerID>(env, worker_id);
}
/*
* Class: org_ray_runtime_WorkerContext
* Method: nativeGetNextTaskIndex
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_org_ray_runtime_WorkerContext_nativeGetNextTaskIndex(
JNIEnv *env, jclass, jlong nativeWorkerContextFromPointer) {
return GetWorkerContextFromPointer(nativeWorkerContextFromPointer)->GetNextTaskIndex();
}
/*
* Class: org_ray_runtime_WorkerContext
* Method: nativeGetNextPutIndex
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_org_ray_runtime_WorkerContext_nativeGetNextPutIndex(
JNIEnv *env, jclass, jlong nativeWorkerContextFromPointer) {
return GetWorkerContextFromPointer(nativeWorkerContextFromPointer)->GetNextPutIndex();
}
/*
* Class: org_ray_runtime_WorkerContext
* Method: nativeDestroy
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_org_ray_runtime_WorkerContext_nativeDestroy(
JNIEnv *env, jclass, jlong nativeWorkerContextFromPointer) {
delete GetWorkerContextFromPointer(nativeWorkerContextFromPointer);
}
#ifdef __cplusplus
}
#endif

View file

@ -0,0 +1,87 @@
/* DO NOT EDIT THIS FILE - it is machine generated */
#include <jni.h>
/* Header for class org_ray_runtime_WorkerContext */
#ifndef _Included_org_ray_runtime_WorkerContext
#define _Included_org_ray_runtime_WorkerContext
#ifdef __cplusplus
extern "C" {
#endif
/*
* Class: org_ray_runtime_WorkerContext
* Method: nativeCreateWorkerContext
* Signature: (I[B)J
*/
JNIEXPORT jlong JNICALL Java_org_ray_runtime_WorkerContext_nativeCreateWorkerContext(
JNIEnv *, jclass, jint, jbyteArray);
/*
* Class: org_ray_runtime_WorkerContext
* Method: nativeGetCurrentTaskId
* Signature: (J)[B
*/
JNIEXPORT jbyteArray JNICALL
Java_org_ray_runtime_WorkerContext_nativeGetCurrentTaskId(JNIEnv *, jclass, jlong);
/*
* Class: org_ray_runtime_WorkerContext
* Method: nativeSetCurrentTask
* Signature: (J[B)V
*/
JNIEXPORT void JNICALL Java_org_ray_runtime_WorkerContext_nativeSetCurrentTask(
JNIEnv *, jclass, jlong, jbyteArray);
/*
* Class: org_ray_runtime_WorkerContext
* Method: nativeGetCurrentTask
* Signature: (J)[B
*/
JNIEXPORT jbyteArray JNICALL
Java_org_ray_runtime_WorkerContext_nativeGetCurrentTask(JNIEnv *, jclass, jlong);
/*
* Class: org_ray_runtime_WorkerContext
* Method: nativeGetCurrentJobId
* Signature: (J)Ljava/nio/ByteBuffer;
*/
JNIEXPORT jobject JNICALL
Java_org_ray_runtime_WorkerContext_nativeGetCurrentJobId(JNIEnv *, jclass, jlong);
/*
* Class: org_ray_runtime_WorkerContext
* Method: nativeGetCurrentWorkerId
* Signature: (J)[B
*/
JNIEXPORT jbyteArray JNICALL
Java_org_ray_runtime_WorkerContext_nativeGetCurrentWorkerId(JNIEnv *, jclass, jlong);
/*
* Class: org_ray_runtime_WorkerContext
* Method: nativeGetNextTaskIndex
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_org_ray_runtime_WorkerContext_nativeGetNextTaskIndex(JNIEnv *,
jclass,
jlong);
/*
* Class: org_ray_runtime_WorkerContext
* Method: nativeGetNextPutIndex
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_org_ray_runtime_WorkerContext_nativeGetNextPutIndex(JNIEnv *,
jclass,
jlong);
/*
* Class: org_ray_runtime_WorkerContext
* Method: nativeDestroy
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_org_ray_runtime_WorkerContext_nativeDestroy(JNIEnv *, jclass,
jlong);
#ifdef __cplusplus
}
#endif
#endif

View file

@ -0,0 +1,149 @@
#include "ray/core_worker/lib/java/org_ray_runtime_objectstore_ObjectInterfaceImpl.h"
#include <jni.h>
#include "ray/common/id.h"
#include "ray/core_worker/common.h"
#include "ray/core_worker/lib/java/jni_utils.h"
#include "ray/core_worker/object_interface.h"
inline ray::CoreWorkerObjectInterface *GetObjectInterfaceFromPointer(
jlong nativeObjectInterfacePointer) {
return reinterpret_cast<ray::CoreWorkerObjectInterface *>(nativeObjectInterfacePointer);
}
#ifdef __cplusplus
extern "C" {
#endif
/*
* Class: org_ray_runtime_objectstore_ObjectInterfaceImpl
* Method: nativeCreateObjectInterface
* Signature: (JJLjava/lang/String;)J
*/
JNIEXPORT jlong JNICALL
Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeCreateObjectInterface(
JNIEnv *env, jclass, jlong nativeWorkerContext, jlong nativeRayletClient,
jstring storeSocketName) {
return reinterpret_cast<jlong>(new ray::CoreWorkerObjectInterface(
*reinterpret_cast<ray::WorkerContext *>(nativeWorkerContext),
*reinterpret_cast<std::unique_ptr<RayletClient> *>(nativeRayletClient),
JavaStringToNativeString(env, storeSocketName)));
}
/*
* Class: org_ray_runtime_objectstore_ObjectInterfaceImpl
* Method: nativePut
* Signature: (JLorg/ray/runtime/objectstore/NativeRayObject;)[B
*/
JNIEXPORT jbyteArray JNICALL
Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativePut__JLorg_ray_runtime_objectstore_NativeRayObject_2(
JNIEnv *env, jclass, jlong nativeObjectInterfacePointer, jobject obj) {
ray::Status status;
ray::ObjectID object_id = ReadJavaNativeRayObject<ray::ObjectID>(
env, obj,
[nativeObjectInterfacePointer,
&status](const std::shared_ptr<ray::RayObject> &rayObject) {
RAY_CHECK(rayObject != nullptr);
ray::ObjectID object_id;
status = GetObjectInterfaceFromPointer(nativeObjectInterfacePointer)
->Put(*rayObject, &object_id);
return object_id;
});
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr);
return IdToJavaByteArray<ray::ObjectID>(env, object_id);
}
/*
* Class: org_ray_runtime_objectstore_ObjectInterfaceImpl
* Method: nativePut
* Signature: (J[BLorg/ray/runtime/objectstore/NativeRayObject;)V
*/
JNIEXPORT void JNICALL
Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativePut__J_3BLorg_ray_runtime_objectstore_NativeRayObject_2(
JNIEnv *env, jclass, jlong nativeObjectInterfacePointer, jbyteArray objectId,
jobject obj) {
auto object_id = JavaByteArrayToId<ray::ObjectID>(env, objectId);
auto status = ReadJavaNativeRayObject<ray::Status>(
env, obj,
[nativeObjectInterfacePointer,
&object_id](const std::shared_ptr<ray::RayObject> &rayObject) {
RAY_CHECK(rayObject != nullptr);
return GetObjectInterfaceFromPointer(nativeObjectInterfacePointer)
->Put(*rayObject, object_id);
});
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
}
/*
* Class: org_ray_runtime_objectstore_ObjectInterfaceImpl
* Method: nativeGet
* Signature: (JLjava/util/List;J)Ljava/util/List;
*/
JNIEXPORT jobject JNICALL Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeGet(
JNIEnv *env, jclass, jlong nativeObjectInterfacePointer, jobject ids,
jlong timeoutMs) {
std::vector<ray::ObjectID> object_ids;
JavaListToNativeVector<ray::ObjectID>(
env, ids, &object_ids, [](JNIEnv *env, jobject id) {
return JavaByteArrayToId<ray::ObjectID>(env, static_cast<jbyteArray>(id));
});
std::vector<std::shared_ptr<ray::RayObject>> results;
auto status = GetObjectInterfaceFromPointer(nativeObjectInterfacePointer)
->Get(object_ids, (int64_t)timeoutMs, &results);
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr);
return NativeVectorToJavaList<std::shared_ptr<ray::RayObject>>(env, results,
ToJavaNativeRayObject);
}
/*
* Class: org_ray_runtime_objectstore_ObjectInterfaceImpl
* Method: nativeWait
* Signature: (JLjava/util/List;IJ)Ljava/util/List;
*/
JNIEXPORT jobject JNICALL Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeWait(
JNIEnv *env, jclass, jlong nativeObjectInterfacePointer, jobject objectIds,
jint numObjects, jlong timeoutMs) {
std::vector<ray::ObjectID> object_ids;
JavaListToNativeVector<ray::ObjectID>(
env, objectIds, &object_ids, [](JNIEnv *env, jobject id) {
return JavaByteArrayToId<ray::ObjectID>(env, static_cast<jbyteArray>(id));
});
std::vector<bool> results;
auto status = GetObjectInterfaceFromPointer(nativeObjectInterfacePointer)
->Wait(object_ids, (int)numObjects, (int64_t)timeoutMs, &results);
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr);
return NativeVectorToJavaList<bool>(env, results, [](JNIEnv *env, const bool &item) {
return env->NewObject(java_boolean_class, java_boolean_init, (jboolean)item);
});
}
/*
* Class: org_ray_runtime_objectstore_ObjectInterfaceImpl
* Method: nativeDelete
* Signature: (JLjava/util/List;ZZ)V
*/
JNIEXPORT void JNICALL Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeDelete(
JNIEnv *env, jclass, jlong nativeObjectInterfacePointer, jobject objectIds,
jboolean localOnly, jboolean deleteCreatingTasks) {
std::vector<ray::ObjectID> object_ids;
JavaListToNativeVector<ray::ObjectID>(
env, objectIds, &object_ids, [](JNIEnv *env, jobject id) {
return JavaByteArrayToId<ray::ObjectID>(env, static_cast<jbyteArray>(id));
});
auto status = GetObjectInterfaceFromPointer(nativeObjectInterfacePointer)
->Delete(object_ids, (bool)localOnly, (bool)deleteCreatingTasks);
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
}
/*
* Class: org_ray_runtime_objectstore_ObjectInterfaceImpl
* Method: nativeDestroy
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeDestroy(
JNIEnv *env, jclass, jlong nativeObjectInterfacePointer) {
delete GetObjectInterfaceFromPointer(nativeObjectInterfacePointer);
}
#ifdef __cplusplus
}
#endif

View file

@ -0,0 +1,72 @@
/* DO NOT EDIT THIS FILE - it is machine generated */
#include <jni.h>
/* Header for class org_ray_runtime_objectstore_ObjectInterfaceImpl */
#ifndef _Included_org_ray_runtime_objectstore_ObjectInterfaceImpl
#define _Included_org_ray_runtime_objectstore_ObjectInterfaceImpl
#ifdef __cplusplus
extern "C" {
#endif
/*
* Class: org_ray_runtime_objectstore_ObjectInterfaceImpl
* Method: nativeCreateObjectInterface
* Signature: (JJLjava/lang/String;)J
*/
JNIEXPORT jlong JNICALL
Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeCreateObjectInterface(
JNIEnv *, jclass, jlong, jlong, jstring);
/*
* Class: org_ray_runtime_objectstore_ObjectInterfaceImpl
* Method: nativePut
* Signature: (JLorg/ray/runtime/objectstore/NativeRayObject;)[B
*/
JNIEXPORT jbyteArray JNICALL
Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativePut__JLorg_ray_runtime_objectstore_NativeRayObject_2(
JNIEnv *, jclass, jlong, jobject);
/*
* Class: org_ray_runtime_objectstore_ObjectInterfaceImpl
* Method: nativePut
* Signature: (J[BLorg/ray/runtime/objectstore/NativeRayObject;)V
*/
JNIEXPORT void JNICALL
Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativePut__J_3BLorg_ray_runtime_objectstore_NativeRayObject_2(
JNIEnv *, jclass, jlong, jbyteArray, jobject);
/*
* Class: org_ray_runtime_objectstore_ObjectInterfaceImpl
* Method: nativeGet
* Signature: (JLjava/util/List;J)Ljava/util/List;
*/
JNIEXPORT jobject JNICALL Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeGet(
JNIEnv *, jclass, jlong, jobject, jlong);
/*
* Class: org_ray_runtime_objectstore_ObjectInterfaceImpl
* Method: nativeWait
* Signature: (JLjava/util/List;IJ)Ljava/util/List;
*/
JNIEXPORT jobject JNICALL Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeWait(
JNIEnv *, jclass, jlong, jobject, jint, jlong);
/*
* Class: org_ray_runtime_objectstore_ObjectInterfaceImpl
* Method: nativeDelete
* Signature: (JLjava/util/List;ZZ)V
*/
JNIEXPORT void JNICALL Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeDelete(
JNIEnv *, jclass, jlong, jobject, jboolean, jboolean);
/*
* Class: org_ray_runtime_objectstore_ObjectInterfaceImpl
* Method: nativeDestroy
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_org_ray_runtime_objectstore_ObjectInterfaceImpl_nativeDestroy(
JNIEnv *, jclass, jlong);
#ifdef __cplusplus
}
#endif
#endif

View file

@ -3,6 +3,7 @@
#include "ray/core_worker/context.h"
#include "ray/core_worker/core_worker.h"
#include "ray/core_worker/object_interface.h"
#include "ray/protobuf/gcs.pb.h"
namespace ray {
@ -101,11 +102,14 @@ Status CoreWorkerPlasmaStoreProvider::Get(
std::make_shared<PlasmaBuffer>(object_buffers[i].data),
std::make_shared<PlasmaBuffer>(object_buffers[i].metadata));
unready.erase(object_id);
if (IsException(object_buffers[i])) {
should_break = true;
}
}
}
num_attempts += 1;
// TODO(zhijunfu): log a message if attempted too many times.
WarnIfAttemptedTooManyTimes(num_attempts, unready);
}
if (was_blocked) {
@ -144,4 +148,45 @@ Status CoreWorkerPlasmaStoreProvider::Delete(const std::vector<ObjectID> &object
return raylet_client_->FreeObjects(object_ids, local_only, delete_creating_tasks);
}
bool CoreWorkerPlasmaStoreProvider::IsException(const plasma::ObjectBuffer &buffer) {
// TODO (kfstorm): metadata should be structured.
const std::string metadata = buffer.metadata->ToString();
const auto error_type_descriptor = ray::rpc::ErrorType_descriptor();
for (int i = 0; i < error_type_descriptor->value_count(); i++) {
const auto error_type_number = error_type_descriptor->value(i)->number();
if (metadata == std::to_string(error_type_number)) {
return true;
}
}
return false;
}
void CoreWorkerPlasmaStoreProvider::WarnIfAttemptedTooManyTimes(
int num_attempts, const std::unordered_map<ObjectID, int> &unready) {
if (num_attempts % RayConfig::instance().object_store_get_warn_per_num_attempts() ==
0) {
std::ostringstream oss;
size_t printed = 0;
for (auto &entry : unready) {
if (printed >=
RayConfig::instance().object_store_get_max_ids_to_print_in_warning()) {
break;
}
if (printed > 0) {
oss << ", ";
}
oss << entry.first.Hex();
}
if (printed < unready.size()) {
oss << ", etc";
}
RAY_LOG(WARNING)
<< "Attempted " << num_attempts << " times to reconstruct objects, but "
<< "some objects are still unavailable. If this message continues to print,"
<< " it may indicate that object's creating task is hanging, or something wrong"
<< " happened in raylet backend. " << unready.size()
<< " object(s) pending: " << oss.str() << ".";
}
}
} // namespace ray

View file

@ -60,6 +60,20 @@ class CoreWorkerPlasmaStoreProvider : public CoreWorkerStoreProvider {
bool delete_creating_tasks) override;
private:
/// Whether the buffer represents an exception object.
///
/// \param[in] buffer the object buffer.
/// \return Whether it represents an exception object.
static bool IsException(const plasma::ObjectBuffer &buffer);
/// Print a warning if we've attempted too many times, but some objects are still
/// unavailable.
///
/// \param[in] num_attemps The number of attempted times.
/// \param[in] unready The unready objects.
static void WarnIfAttemptedTooManyTimes(
int num_attempts, const std::unordered_map<ObjectID, int> &unready);
/// Plasma store client.
plasma::PlasmaClient store_client_;

View file

@ -11,6 +11,12 @@ enum Language {
CPP = 2;
}
// Type of a worker.
enum WorkerType {
WORKER = 0;
DRIVER = 1;
}
// Type of a task.
enum TaskType {
// Normal task.

View file

@ -267,4 +267,6 @@ enum ErrorType {
// 2) The object's creating task is already cleaned up from GCS (this currently
// crashes raylet).
OBJECT_UNRECONSTRUCTABLE = 2;
// Indicates that a task failed due to user code failure.
TASK_EXECUTION_EXCEPTION = 3;
}

View file

@ -3,39 +3,14 @@
#include <jni.h>
#include "ray/common/id.h"
#include "ray/core_worker/lib/java/jni_utils.h"
#include "ray/raylet/raylet_client.h"
#include "ray/util/logging.h"
template <typename ID>
class UniqueIdFromJByteArray {
public:
const ID &GetId() const { return id; }
UniqueIdFromJByteArray(JNIEnv *env, const jbyteArray &bytes) {
std::string id_str(ID::Size(), 0);
env->GetByteArrayRegion(bytes, 0, ID::Size(),
reinterpret_cast<jbyte *>(&id_str.front()));
id = ID::FromBinary(id_str);
}
private:
ID id;
};
#ifdef __cplusplus
extern "C" {
#endif
inline bool ThrowRayExceptionIfNotOK(JNIEnv *env, const ray::Status &status) {
if (!status.ok()) {
jclass exception_class = env->FindClass("org/ray/api/exception/RayException");
env->ThrowNew(exception_class, status.message().c_str());
return true;
} else {
return false;
}
}
/*
* Class: org_ray_runtime_raylet_RayletClientImpl
* Method: nativeInit
@ -44,11 +19,11 @@ inline bool ThrowRayExceptionIfNotOK(JNIEnv *env, const ray::Status &status) {
JNIEXPORT jlong JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeInit(
JNIEnv *env, jclass, jstring sockName, jbyteArray workerId, jboolean isWorker,
jbyteArray jobId) {
UniqueIdFromJByteArray<ClientID> worker_id(env, workerId);
UniqueIdFromJByteArray<JobID> job_id(env, jobId);
const auto worker_id = JavaByteArrayToId<ClientID>(env, workerId);
const auto job_id = JavaByteArrayToId<JobID>(env, jobId);
const char *nativeString = env->GetStringUTFChars(sockName, JNI_FALSE);
auto raylet_client = new RayletClient(nativeString, worker_id.GetId(), isWorker,
job_id.GetId(), Language::JAVA);
auto raylet_client = new std::unique_ptr<RayletClient>(
new RayletClient(nativeString, worker_id, isWorker, job_id, Language::JAVA));
env->ReleaseStringUTFChars(sockName, nativeString);
return reinterpret_cast<jlong>(raylet_client);
}
@ -60,7 +35,7 @@ JNIEXPORT jlong JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeInit(
*/
JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSubmitTask(
JNIEnv *env, jclass, jlong client, jbyteArray taskSpec) {
auto raylet_client = reinterpret_cast<RayletClient *>(client);
auto &raylet_client = *reinterpret_cast<std::unique_ptr<RayletClient> *>(client);
jbyte *data = env->GetByteArrayElements(taskSpec, NULL);
jsize size = env->GetArrayLength(taskSpec);
@ -70,7 +45,7 @@ JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSubmit
ray::TaskSpecification task_spec(task_spec_message);
auto status = raylet_client->SubmitTask(task_spec);
ThrowRayExceptionIfNotOK(env, status);
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
}
/*
@ -80,13 +55,11 @@ JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSubmit
*/
JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeGetTask(
JNIEnv *env, jclass, jlong client) {
auto raylet_client = reinterpret_cast<RayletClient *>(client);
auto &raylet_client = *reinterpret_cast<std::unique_ptr<RayletClient> *>(client);
std::unique_ptr<ray::TaskSpecification> spec;
auto status = raylet_client->GetTask(&spec);
if (ThrowRayExceptionIfNotOK(env, status)) {
return nullptr;
}
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr);
// Serialize the task spec and copy to Java byte array.
auto task_data = spec->Serialize();
@ -109,8 +82,9 @@ JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_native
*/
JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeDestroy(
JNIEnv *env, jclass, jlong client) {
auto raylet_client = reinterpret_cast<RayletClient *>(client);
ThrowRayExceptionIfNotOK(env, raylet_client->Disconnect());
auto raylet_client = reinterpret_cast<std::unique_ptr<RayletClient> *>(client);
auto status = (*raylet_client)->Disconnect();
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
delete raylet_client;
}
@ -128,15 +102,14 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeFetchOrReconstruct(
for (int i = 0; i < len; i++) {
jbyteArray object_id_bytes =
static_cast<jbyteArray>(env->GetObjectArrayElement(objectIds, i));
UniqueIdFromJByteArray<ObjectID> object_id(env, object_id_bytes);
object_ids.push_back(object_id.GetId());
const auto object_id = JavaByteArrayToId<ObjectID>(env, object_id_bytes);
object_ids.push_back(object_id);
env->DeleteLocalRef(object_id_bytes);
}
UniqueIdFromJByteArray<TaskID> current_task_id(env, currentTaskId);
auto raylet_client = reinterpret_cast<RayletClient *>(client);
auto status =
raylet_client->FetchOrReconstruct(object_ids, fetchOnly, current_task_id.GetId());
ThrowRayExceptionIfNotOK(env, status);
const auto current_task_id = JavaByteArrayToId<TaskID>(env, currentTaskId);
auto &raylet_client = *reinterpret_cast<std::unique_ptr<RayletClient> *>(client);
auto status = raylet_client->FetchOrReconstruct(object_ids, fetchOnly, current_task_id);
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
}
/*
@ -146,10 +119,10 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeFetchOrReconstruct(
*/
JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotifyUnblocked(
JNIEnv *env, jclass, jlong client, jbyteArray currentTaskId) {
UniqueIdFromJByteArray<TaskID> current_task_id(env, currentTaskId);
auto raylet_client = reinterpret_cast<RayletClient *>(client);
auto status = raylet_client->NotifyUnblocked(current_task_id.GetId());
ThrowRayExceptionIfNotOK(env, status);
const auto current_task_id = JavaByteArrayToId<TaskID>(env, currentTaskId);
auto &raylet_client = *reinterpret_cast<std::unique_ptr<RayletClient> *>(client);
auto status = raylet_client->NotifyUnblocked(current_task_id);
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
}
/*
@ -166,22 +139,20 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeWaitObject(
for (int i = 0; i < len; i++) {
jbyteArray object_id_bytes =
static_cast<jbyteArray>(env->GetObjectArrayElement(objectIds, i));
UniqueIdFromJByteArray<ObjectID> object_id(env, object_id_bytes);
object_ids.push_back(object_id.GetId());
const auto object_id = JavaByteArrayToId<ObjectID>(env, object_id_bytes);
object_ids.push_back(object_id);
env->DeleteLocalRef(object_id_bytes);
}
UniqueIdFromJByteArray<TaskID> current_task_id(env, currentTaskId);
const auto current_task_id = JavaByteArrayToId<TaskID>(env, currentTaskId);
auto raylet_client = reinterpret_cast<RayletClient *>(client);
auto &raylet_client = *reinterpret_cast<std::unique_ptr<RayletClient> *>(client);
// Invoke wait.
WaitResultPair result;
auto status = raylet_client->Wait(object_ids, numReturns, timeoutMillis,
static_cast<bool>(isWaitLocal),
current_task_id.GetId(), &result);
if (ThrowRayExceptionIfNotOK(env, status)) {
return nullptr;
}
auto status =
raylet_client->Wait(object_ids, numReturns, timeoutMillis,
static_cast<bool>(isWaitLocal), current_task_id, &result);
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr);
// Convert result to java object.
jboolean put_value = true;
@ -216,11 +187,10 @@ JNIEXPORT jbyteArray JNICALL
Java_org_ray_runtime_raylet_RayletClientImpl_nativeGenerateTaskId(
JNIEnv *env, jclass, jbyteArray jobId, jbyteArray parentTaskId,
jint parent_task_counter) {
UniqueIdFromJByteArray<JobID> job_id(env, jobId);
UniqueIdFromJByteArray<TaskID> parent_task_id(env, parentTaskId);
const auto job_id = JavaByteArrayToId<JobID>(env, jobId);
const auto parent_task_id = JavaByteArrayToId<TaskID>(env, parentTaskId);
TaskID task_id =
ray::GenerateTaskId(job_id.GetId(), parent_task_id.GetId(), parent_task_counter);
TaskID task_id = ray::GenerateTaskId(job_id, parent_task_id, parent_task_counter);
jbyteArray result = env->NewByteArray(task_id.Size());
if (nullptr == result) {
return nullptr;
@ -245,13 +215,13 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeFreePlasmaObjects(
for (int i = 0; i < len; i++) {
jbyteArray object_id_bytes =
static_cast<jbyteArray>(env->GetObjectArrayElement(objectIds, i));
UniqueIdFromJByteArray<ObjectID> object_id(env, object_id_bytes);
object_ids.push_back(object_id.GetId());
const auto object_id = JavaByteArrayToId<ObjectID>(env, object_id_bytes);
object_ids.push_back(object_id);
env->DeleteLocalRef(object_id_bytes);
}
auto raylet_client = reinterpret_cast<RayletClient *>(client);
auto &raylet_client = *reinterpret_cast<std::unique_ptr<RayletClient> *>(client);
auto status = raylet_client->FreeObjects(object_ids, localOnly, deleteCreatingTasks);
ThrowRayExceptionIfNotOK(env, status);
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
}
/*
@ -263,13 +233,11 @@ JNIEXPORT jbyteArray JNICALL
Java_org_ray_runtime_raylet_RayletClientImpl_nativePrepareCheckpoint(JNIEnv *env, jclass,
jlong client,
jbyteArray actorId) {
auto raylet_client = reinterpret_cast<RayletClient *>(client);
UniqueIdFromJByteArray<ActorID> actor_id(env, actorId);
auto &raylet_client = *reinterpret_cast<std::unique_ptr<RayletClient> *>(client);
const auto actor_id = JavaByteArrayToId<ActorID>(env, actorId);
ActorCheckpointID checkpoint_id;
auto status = raylet_client->PrepareActorCheckpoint(actor_id.GetId(), checkpoint_id);
if (ThrowRayExceptionIfNotOK(env, status)) {
return nullptr;
}
auto status = raylet_client->PrepareActorCheckpoint(actor_id, checkpoint_id);
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr);
jbyteArray result = env->NewByteArray(checkpoint_id.Size());
env->SetByteArrayRegion(result, 0, checkpoint_id.Size(),
reinterpret_cast<const jbyte *>(checkpoint_id.Data()));
@ -284,12 +252,11 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativePrepareCheckpoint(JNIEnv *env
JNIEXPORT void JNICALL
Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotifyActorResumedFromCheckpoint(
JNIEnv *env, jclass, jlong client, jbyteArray actorId, jbyteArray checkpointId) {
auto raylet_client = reinterpret_cast<RayletClient *>(client);
UniqueIdFromJByteArray<ActorID> actor_id(env, actorId);
UniqueIdFromJByteArray<ActorCheckpointID> checkpoint_id(env, checkpointId);
auto status = raylet_client->NotifyActorResumedFromCheckpoint(actor_id.GetId(),
checkpoint_id.GetId());
ThrowRayExceptionIfNotOK(env, status);
auto &raylet_client = *reinterpret_cast<std::unique_ptr<RayletClient> *>(client);
const auto actor_id = JavaByteArrayToId<ActorID>(env, actorId);
const auto checkpoint_id = JavaByteArrayToId<ActorCheckpointID>(env, checkpointId);
auto status = raylet_client->NotifyActorResumedFromCheckpoint(actor_id, checkpoint_id);
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
}
/*
@ -300,14 +267,14 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotifyActorResumedFromCheckpo
JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSetResource(
JNIEnv *env, jclass, jlong client, jstring resourceName, jdouble capacity,
jbyteArray nodeId) {
auto raylet_client = reinterpret_cast<RayletClient *>(client);
UniqueIdFromJByteArray<ClientID> node_id(env, nodeId);
auto &raylet_client = *reinterpret_cast<std::unique_ptr<RayletClient> *>(client);
const auto node_id = JavaByteArrayToId<ClientID>(env, nodeId);
const char *native_resource_name = env->GetStringUTFChars(resourceName, JNI_FALSE);
auto status = raylet_client->SetResource(
native_resource_name, static_cast<double>(capacity), node_id.GetId());
auto status = raylet_client->SetResource(native_resource_name,
static_cast<double>(capacity), node_id);
env->ReleaseStringUTFChars(resourceName, native_resource_name);
ThrowRayExceptionIfNotOK(env, status);
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
}
#ifdef __cplusplus