diff --git a/.gitignore b/.gitignore
index 9da7c3c1c..cd28233ac 100644
--- a/.gitignore
+++ b/.gitignore
@@ -150,6 +150,12 @@ build
.vscode/
*.iml
+
+# Java
java/**/target
java/run
java/**/lib
+java/**/.settings
+java/**/.classpath
+java/**/.project
+
diff --git a/java/api/src/main/java/org/ray/api/Ray.java b/java/api/src/main/java/org/ray/api/Ray.java
index ecf227265..60726a67b 100644
--- a/java/api/src/main/java/org/ray/api/Ray.java
+++ b/java/api/src/main/java/org/ray/api/Ray.java
@@ -85,8 +85,8 @@ public final class Ray extends Rpc {
if (cls.getConstructor() == null) {
System.err.println("class " + cls.getName()
+ " does not (actors must) have a constructor with no arguments");
- RayLog.core.error("class " + cls.getName()
- + " does not (actors must) have a constructor with no arguments");
+ RayLog.core.error("class {} does not (actors must) have a constructor with no arguments",
+ cls.getName());
}
} catch (Exception e) {
System.exit(1);
diff --git a/java/api/src/main/java/org/ray/api/internal/RayConnector.java b/java/api/src/main/java/org/ray/api/internal/RayConnector.java
index 3d54dcc54..3d50dead1 100644
--- a/java/api/src/main/java/org/ray/api/internal/RayConnector.java
+++ b/java/api/src/main/java/org/ray/api/internal/RayConnector.java
@@ -19,7 +19,7 @@ public class RayConnector {
m.setAccessible(false);
return api;
} catch (ReflectiveOperationException | IllegalArgumentException | SecurityException e) {
- RayLog.core.error("Load " + className + " class failed.", e);
+ RayLog.core.error("Load {} class failed.", className, e);
throw new Error("RayApi is not successfully initiated.");
}
}
diff --git a/java/checkstyle-suppressions.xml b/java/checkstyle-suppressions.xml
index f84d9e1ca..0420ef87b 100644
--- a/java/checkstyle-suppressions.xml
+++ b/java/checkstyle-suppressions.xml
@@ -19,4 +19,5 @@
+
diff --git a/java/cleanup.sh b/java/cleanup.sh
index 3d73e626a..84b2110bd 100755
--- a/java/cleanup.sh
+++ b/java/cleanup.sh
@@ -5,4 +5,6 @@ pkill -9 plasma_store
pkill -9 global_scheduler
pkill -9 redis-server
pkill -9 redis
+pkill -9 raylet
ps aux | grep ray | awk '{system("kill "$2);}'
+rm -f /tmp/raylet*  # -f: stay quiet and succeed when no /tmp/raylet* files are left
diff --git a/java/cli/src/main/java/org/ray/cli/RayCli.java b/java/cli/src/main/java/org/ray/cli/RayCli.java
index 671da14e3..c99541f35 100644
--- a/java/cli/src/main/java/org/ray/cli/RayCli.java
+++ b/java/cli/src/main/java/org/ray/cli/RayCli.java
@@ -21,8 +21,9 @@ import org.ray.spi.PathConfig;
import org.ray.spi.RemoteFunctionManager;
import org.ray.spi.StateStoreProxy;
import org.ray.spi.impl.NativeRemoteFunctionManager;
+import org.ray.spi.impl.NonRayletStateStoreProxyImpl;
+import org.ray.spi.impl.RayletStateStoreProxyImpl;
import org.ray.spi.impl.RedisClient;
-import org.ray.spi.impl.StateStoreProxyImpl;
import org.ray.util.FileUtil;
import org.ray.util.config.ConfigReader;
import org.ray.util.logger.RayLog;
@@ -47,7 +48,7 @@ public class RayCli {
throw new RuntimeException("Ray head node start failed", e);
}
- RayLog.core.info("Started Ray head node. Redis address: " + manager.info().redisAddress);
+ RayLog.core.info("Started Ray head node. Redis address: {}", manager.info().redisAddress);
return manager;
}
@@ -74,7 +75,7 @@ public class RayCli {
// Init RayLog before using it.
RayLog.init(params.working_directory);
- RayLog.core.info("Using IP address " + params.node_ip_address + " for this node.");
+ RayLog.core.info("Using IP address {} for this node.", params.node_ip_address);
RunManager manager;
if (cmdStart.head) {
manager = startRayHead(params, paths, config);
@@ -152,7 +153,9 @@ public class RayCli {
KeyValueStoreLink kvStore = new RedisClient();
kvStore.setAddr(cmdSubmit.redisAddress);
- StateStoreProxy stateStoreProxy = new StateStoreProxyImpl(kvStore);
+ StateStoreProxy stateStoreProxy = params.use_raylet
+ ? new RayletStateStoreProxyImpl(kvStore)
+ : new NonRayletStateStoreProxyImpl(kvStore);
stateStoreProxy.initializeGlobalState();
RemoteFunctionManager functionManager = new NativeRemoteFunctionManager(kvStore);
diff --git a/java/common/src/main/java/org/ray/util/MethodId.java b/java/common/src/main/java/org/ray/util/MethodId.java
index f32b33510..1b517f736 100644
--- a/java/common/src/main/java/org/ray/util/MethodId.java
+++ b/java/common/src/main/java/org/ray/util/MethodId.java
@@ -117,7 +117,7 @@ public final class MethodId {
cls = Class
.forName(className, true, loader == null ? this.getClass().getClassLoader() : loader);
} catch (Throwable e) {
- RayLog.core.error("Cannot load class " + className, e);
+ RayLog.core.error("Cannot load class {}", className, e);
return null;
}
@@ -148,7 +148,7 @@ public final class MethodId {
if (methods.size() != 1) {
RayLog.core.error(
- "Load method " + toString() + " failed as there are " + methods.size() + " definitions");
+ "Load method {} failed as there are {} definitions.", toString(), methods.size());
return null;
}
diff --git a/java/common/src/main/java/org/ray/util/NetworkUtil.java b/java/common/src/main/java/org/ray/util/NetworkUtil.java
index d5e48999c..4eddbd7df 100644
--- a/java/common/src/main/java/org/ray/util/NetworkUtil.java
+++ b/java/common/src/main/java/org/ray/util/NetworkUtil.java
@@ -35,9 +35,9 @@ public class NetworkUtil {
return addr.getHostAddress();
}
}
- RayLog.core.warn("you may need to correctly specify [ray.java] net_interface in config");
+ RayLog.core.warn("You need to correctly specify [ray.java] net_interface in config.");
} catch (Exception e) {
- RayLog.core.error("Can't get our ip address, use 127.0.0.1 as default.", e);
+ RayLog.core.error("Can't get ip address, use 127.0.0.1 as default.", e);
}
return "127.0.0.1";
diff --git a/java/common/src/main/java/org/ray/util/Sha1Digestor.java b/java/common/src/main/java/org/ray/util/Sha1Digestor.java
index 2c0c67789..b9d520609 100644
--- a/java/common/src/main/java/org/ray/util/Sha1Digestor.java
+++ b/java/common/src/main/java/org/ray/util/Sha1Digestor.java
@@ -10,8 +10,8 @@ public class Sha1Digestor {
try {
return MessageDigest.getInstance("SHA1");
} catch (Exception e) {
- RayLog.core.error("cannot get SHA1 MessageDigest", e);
- throw new RuntimeException("cannot get SHA1 digest", e);
+ RayLog.core.error("Cannot get SHA1 MessageDigest", e);
+ throw new RuntimeException("Cannot get SHA1 digest", e);
}
});
diff --git a/java/pom.xml b/java/pom.xml
index 89584b6ee..73854c96d 100644
--- a/java/pom.xml
+++ b/java/pom.xml
@@ -63,7 +63,7 @@
com.github.davidmoten
flatbuffers-java
- 1.7.0.1
+ 1.9.0.1
diff --git a/java/prepare.sh b/java/prepare.sh
index d59b92b09..b699b4a5f 100755
--- a/java/prepare.sh
+++ b/java/prepare.sh
@@ -47,12 +47,15 @@ declare -a nativeBinaries=(
"./src/plasma/plasma_manager"
"./src/local_scheduler/local_scheduler"
"./src/global_scheduler/global_scheduler"
+ "./src/ray/raylet/raylet"
+ "./src/ray/raylet/raylet_monitor"
)
declare -a nativeLibraries=(
"./src/common/redis_module/libray_redis_module.so"
"./src/local_scheduler/liblocal_scheduler_library_java.*"
"./src/plasma/libplasma_java.*"
+ "./src/ray/raylet/*lib.a"
)
declare -a javaBinaries=(
diff --git a/java/ray.config.ini b/java/ray.config.ini
index c9630e75b..a8a6b0981 100644
--- a/java/ray.config.ini
+++ b/java/ray.config.ini
@@ -62,6 +62,8 @@ deploy = false
onebox_delay_seconds_before_run_app_logic = 0
+use_raylet = false
+
; java class which main is served as the driver in a java worker
driver_class =
@@ -123,6 +125,7 @@ store = %CONFIG_FILE_DIR%/../build/src/plasma/plasma_store
store_manager = %CONFIG_FILE_DIR%/../build/src/plasma/plasma_manager
local_scheduler = %CONFIG_FILE_DIR%/../build/src/local_scheduler/local_scheduler
global_scheduler = %CONFIG_FILE_DIR%/../build/src/global_scheduler/global_scheduler
+raylet = %CONFIG_FILE_DIR%/../build/src/ray/raylet/raylet
python_dir = %CONFIG_FILE_DIR%/../build/
java_runtime_rewritten_jars_dir =
java_class_paths = ray.java.path.classes.source
@@ -135,6 +138,7 @@ store = %CONFIG_FILE_DIR%/../build/src/plasma/plasma_store
store_manager = %CONFIG_FILE_DIR%/../build/src/plasma/plasma_manager
local_scheduler = %CONFIG_FILE_DIR%/../build/src/local_scheduler/local_scheduler
global_scheduler = %CONFIG_FILE_DIR%/../build/src/global_scheduler/global_scheduler
+raylet = %CONFIG_FILE_DIR%/../build/src/ray/raylet/raylet
python_dir = %CONFIG_FILE_DIR%/../build/
java_runtime_rewritten_jars_dir =
java_class_paths = ray.java.path.classes.package
@@ -147,6 +151,7 @@ store = %CONFIG_FILE_DIR%/native/bin/plasma_store
store_manager = %CONFIG_FILE_DIR%/native/bin/plasma_manager
local_scheduler = %CONFIG_FILE_DIR%/native/bin/local_scheduler
global_scheduler = %CONFIG_FILE_DIR%/native/bin/global_scheduler
+raylet = %CONFIG_FILE_DIR%/native/bin/raylet
python_dir = %CONFIG_FILE_DIR%/python
java_runtime_rewritten_jars_dir = %CONFIG_FILE_DIR%/java/lib/
java_class_paths = ray.java.path.classes.deploy
diff --git a/java/runtime-common/src/main/java/org/ray/core/RayRuntime.java b/java/runtime-common/src/main/java/org/ray/core/RayRuntime.java
index 52c2b94e8..25bacf062 100644
--- a/java/runtime-common/src/main/java/org/ray/core/RayRuntime.java
+++ b/java/runtime-common/src/main/java/org/ray/core/RayRuntime.java
@@ -1,5 +1,6 @@
package org.ray.core;
+import com.google.common.collect.ImmutableList;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
@@ -122,7 +123,13 @@ public abstract class RayRuntime implements RayApi {
functions = new LocalFunctionManager(remoteLoader);
localSchedulerProxy = new LocalSchedulerProxy(slink);
- objectStoreProxy = new ObjectStoreProxy(plink);
+
+ if (!params.use_raylet) {
+ objectStoreProxy = new ObjectStoreProxy(plink);
+ } else {
+ objectStoreProxy = new ObjectStoreProxy(plink, slink);
+ }
+
worker = new Worker(localSchedulerProxy, functions);
}
@@ -188,7 +195,9 @@ public abstract class RayRuntime implements RayApi {
public void putRaw(UniqueID taskId, UniqueID objectId, T obj, TMT metadata) {
RayLog.core.info("Task " + taskId.toString() + " Object " + objectId.toString() + " put");
- localSchedulerProxy.markTaskPutDependency(taskId, objectId);
+ if (!params.use_raylet) {
+ localSchedulerProxy.markTaskPutDependency(taskId, objectId);
+ }
objectStoreProxy.put(objectId, obj, metadata);
}
@@ -274,22 +283,32 @@ public abstract class RayRuntime implements RayApi {
return worker.rpcWithReturnIndices(taskId, funcCls, lambda, returnCount, args);
}
+
private List doGet(List objectIds, boolean isMetadata)
throws TaskExecutionException {
boolean wasBlocked = false;
UniqueID taskId = getCurrentTaskId();
+
try {
int numObjectIds = objectIds.size();
// Do an initial fetch for remote objects.
- dividedFetch(objectIds);
+ List> fetchBatches =
+ splitIntoBatches(objectIds, params.worker_fetch_request_size);
+ for (List batch : fetchBatches) {
+ if (!params.use_raylet) {
+ objectStoreProxy.fetch(batch);
+ } else {
+ localSchedulerProxy.reconstructObjects(batch, true);
+ }
+ }
// Get the objects. We initially try to get the objects immediately.
List> ret = objectStoreProxy
.get(objectIds, params.default_first_check_timeout_ms, isMetadata);
assert ret.size() == numObjectIds;
- // mapping the object IDs that we haven't gotten yet to their original index in objectIds
+ // Mapping the object IDs that we haven't gotten yet to their original index in objectIds.
Map unreadys = new HashMap<>();
for (int i = 0; i < numObjectIds; i++) {
if (ret.get(i).getRight() != GetStatus.SUCCESS) {
@@ -301,15 +320,22 @@ public abstract class RayRuntime implements RayApi {
// Try reconstructing any objects we haven't gotten yet. Try to get them
// until at least PlasmaLink.GET_TIMEOUT_MS milliseconds passes, then repeat.
while (unreadys.size() > 0) {
- for (UniqueID id : unreadys.keySet()) {
- localSchedulerProxy.reconstructObject(id);
- }
-
- // Do another fetch for objects that aren't available locally yet, in case
- // they were evicted since the last fetch.
List unreadyList = new ArrayList<>(unreadys.keySet());
+ List> reconstructBatches =
+ splitIntoBatches(unreadyList, params.worker_fetch_request_size);
- dividedFetch(unreadyList);
+ for (List batch : reconstructBatches) {
+ if (!params.use_raylet) {
+ for (UniqueID objectId : batch) {
+ localSchedulerProxy.reconstructObject(objectId, false);
+ }
+ // Do another fetch for objects that aren't available locally yet, in case
+ // they were evicted since the last fetch.
+ objectStoreProxy.fetch(batch);
+ } else {
+ localSchedulerProxy.reconstructObjects(batch, false);
+ }
+ }
List> results = objectStoreProxy
.get(unreadyList, params.default_get_check_interval_ms, isMetadata);
@@ -329,9 +355,11 @@ public abstract class RayRuntime implements RayApi {
RayLog.core
.debug("Task " + taskId + " Objects " + Arrays.toString(objectIds.toArray()) + " get");
List finalRet = new ArrayList<>();
+
for (Pair value : ret) {
finalRet.add(value.getLeft());
}
+
return finalRet;
} catch (TaskExecutionException e) {
RayLog.core.error("Task " + taskId + " Objects " + Arrays.toString(objectIds.toArray())
@@ -344,68 +372,30 @@ public abstract class RayRuntime implements RayApi {
localSchedulerProxy.notifyUnblocked();
}
}
-
}
private T doGet(UniqueID objectId, boolean isMetadata) throws TaskExecutionException {
+ ImmutableList objectIds = ImmutableList.of(objectId);
+ List results = doGet(objectIds, isMetadata);
- boolean wasBlocked = false;
- UniqueID taskId = getCurrentTaskId();
- try {
- // Do an initial fetch.
- objectStoreProxy.fetch(objectId);
-
- // Get the object. We initially try to get the object immediately.
- Pair ret = objectStoreProxy
- .get(objectId, params.default_first_check_timeout_ms, isMetadata);
-
- wasBlocked = (ret.getRight() != GetStatus.SUCCESS);
-
- // Try reconstructing the object. Try to get it until at least PlasmaLink.GET_TIMEOUT_MS
- // milliseconds passes, then repeat.
- while (ret.getRight() != GetStatus.SUCCESS) {
- RayLog.core.warn(
- "Task " + taskId + " Object " + objectId.toString() + " get failed, reconstruct ...");
- localSchedulerProxy.reconstructObject(objectId);
-
- // Do another fetch
- objectStoreProxy.fetch(objectId);
-
- ret = objectStoreProxy.get(objectId, params.default_get_check_interval_ms,
- isMetadata);//check the result every 5s, but it will return once available
- }
- RayLog.core.debug(
- "Task " + taskId + " Object " + objectId.toString() + " get" + ", the result " + ret
- .getLeft());
- return ret.getLeft();
- } catch (TaskExecutionException e) {
- RayLog.core
- .error("Task " + taskId + " Object " + objectId.toString() + " get with Exception", e);
- throw e;
- } finally {
- // If the object was not able to get locally, let the local scheduler
- // know that we're now unblocked.
- if (wasBlocked) {
- localSchedulerProxy.notifyUnblocked();
- }
- }
-
+ assert results.size() == 1;
+ return results.get(0);
}
- // We divide the fetch into smaller fetches so as to not block the manager
- // for a prolonged period of time in a single call.
- private void dividedFetch(List objectIds) {
- int fetchSize = objectStoreProxy.getFetchSize();
+ private List> splitIntoBatches(List objectIds, int batchSize) {
+ List> batches = new ArrayList<>();
+ int objectsSize = objectIds.size();
- int numObjectIds = objectIds.size();
- for (int i = 0; i < numObjectIds; i += fetchSize) {
- int endIndex = i + fetchSize;
- if (endIndex < numObjectIds) {
- objectStoreProxy.fetch(objectIds.subList(i, endIndex));
- } else {
- objectStoreProxy.fetch(objectIds.subList(i, numObjectIds));
- }
+    if (batchSize <= 0) {
+      throw new IllegalArgumentException("batchSize must be positive, got " + batchSize);
+    }
+    for (int i = 0; i < objectsSize; i += batchSize) {
+      // Clamp the end index: the final batch may hold fewer than batchSize IDs.
+      int endIndex = Math.min(i + batchSize, objectsSize);
+      batches.add(objectIds.subList(i, endIndex));
}
+
+ return batches;
}
/**
diff --git a/java/runtime-common/src/main/java/org/ray/core/model/RayParameters.java b/java/runtime-common/src/main/java/org/ray/core/model/RayParameters.java
index f2079c64d..a7ca6ac62 100644
--- a/java/runtime-common/src/main/java/org/ray/core/model/RayParameters.java
+++ b/java/runtime-common/src/main/java/org/ray/core/model/RayParameters.java
@@ -112,6 +112,18 @@ public class RayParameters {
@AConfig(comment = "delay seconds under onebox before app logic for debugging")
public int onebox_delay_seconds_before_run_app_logic = 0;
+ @AConfig(comment = "whether to use raylet")
+ public boolean use_raylet = false;
+
+  @AConfig(comment = "raylet socket name (e.g., /tmp/raylet1111)")
+  public String raylet_socket_name = "";
+
+ @AConfig(comment = "raylet rpc listen port")
+ public int raylet_port = 35567;
+
+ @AConfig(comment = "worker fetch request size")
+ public int worker_fetch_request_size = 10000;
+
public RayParameters(ConfigReader config) {
if (null != config) {
String networkInterface = config.getStringValue("ray.java", "network_interface", null,
diff --git a/java/runtime-common/src/main/java/org/ray/spi/LocalSchedulerLink.java b/java/runtime-common/src/main/java/org/ray/spi/LocalSchedulerLink.java
index 93852ea3c..b04b2508d 100644
--- a/java/runtime-common/src/main/java/org/ray/spi/LocalSchedulerLink.java
+++ b/java/runtime-common/src/main/java/org/ray/spi/LocalSchedulerLink.java
@@ -1,5 +1,6 @@
package org.ray.spi;
+import java.util.List;
import org.ray.api.UniqueID;
import org.ray.spi.model.TaskSpec;
@@ -14,7 +15,11 @@ public interface LocalSchedulerLink {
void markTaskPutDependency(UniqueID taskId, UniqueID objectId);
- void reconstructObject(UniqueID objectId);
+ void reconstructObject(UniqueID objectId, boolean fetchOnly);
+
+ void reconstructObjects(List objectIds, boolean fetchOnly);
void notifyUnblocked();
+
+ List wait(byte[][] objectIds, int timeoutMs, int numReturns);
}
diff --git a/java/runtime-common/src/main/java/org/ray/spi/LocalSchedulerProxy.java b/java/runtime-common/src/main/java/org/ray/spi/LocalSchedulerProxy.java
index 8750f4aa0..f999f8e8a 100644
--- a/java/runtime-common/src/main/java/org/ray/spi/LocalSchedulerProxy.java
+++ b/java/runtime-common/src/main/java/org/ray/spi/LocalSchedulerProxy.java
@@ -1,13 +1,17 @@
package org.ray.spi;
+import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
+import java.util.List;
import java.util.Map;
+import org.ray.api.RayList;
import org.ray.api.RayMap;
import org.ray.api.RayObject;
import org.ray.api.RayObjects;
import org.ray.api.UniqueID;
+import org.ray.api.WaitResult;
import org.ray.core.ArgumentsBuilder;
import org.ray.core.UniqueIdHelper;
import org.ray.core.WorkerContext;
@@ -124,11 +128,44 @@ public class LocalSchedulerProxy {
scheduler.markTaskPutDependency(taskId, objectId);
}
- public void reconstructObject(UniqueID objectId) {
- scheduler.reconstructObject(objectId);
+ public void reconstructObject(UniqueID objectId, boolean fetchOnly) {
+ scheduler.reconstructObject(objectId, fetchOnly);
+ }
+
+ public void reconstructObjects(List objectIds, boolean fetchOnly) {
+ scheduler.reconstructObjects(objectIds, fetchOnly);
}
public void notifyUnblocked() {
scheduler.notifyUnblocked();
}
+
+  /**
+   * Collects the raw bytes of each ID into an array, preserving list order.
+   */
+  private static byte[][] getIdBytes(List objectIds) {
+    byte[][] idBytes = new byte[objectIds.size()][];
+    Arrays.setAll(idBytes, index -> objectIds.get(index).getBytes());
+    return idBytes;
+  }
+
+  /** Waits via the scheduler link, then splits waitfor into ready and remaining objects. */
+  public WaitResult wait(RayList waitfor, int numReturns, int timeout) {
+    List ids = new ArrayList<>();
+    for (RayObject obj : waitfor.Objects()) {
+      ids.add(obj.getId());
+    }
+    List readys = scheduler.wait(getIdBytes(ids), timeout, numReturns);
+    RayList readyObjs = new RayList<>();
+    RayList remainObjs = new RayList<>();
+    for (RayObject obj : waitfor.Objects()) {
+      // Match IDs by content; List.contains on byte[] only matches the same array instance.
+      if (readys.stream().anyMatch(ready -> Arrays.equals(ready, obj.getId().getBytes()))) {
+        readyObjs.add(obj);
+      } else {
+        remainObjs.add(obj);
+      }
+    }
+    return new WaitResult<>(readyObjs, remainObjs);
+  }
}
diff --git a/java/runtime-common/src/main/java/org/ray/spi/ObjectStoreProxy.java b/java/runtime-common/src/main/java/org/ray/spi/ObjectStoreProxy.java
index 3b4b34f8e..193afe537 100644
--- a/java/runtime-common/src/main/java/org/ray/spi/ObjectStoreProxy.java
+++ b/java/runtime-common/src/main/java/org/ray/spi/ObjectStoreProxy.java
@@ -10,6 +10,7 @@ import org.ray.api.UniqueID;
import org.ray.api.WaitResult;
import org.ray.core.Serializer;
import org.ray.core.WorkerContext;
+import org.ray.spi.LocalSchedulerLink;
import org.ray.util.exception.TaskExecutionException;
/**
@@ -19,12 +20,19 @@ import org.ray.util.exception.TaskExecutionException;
public class ObjectStoreProxy {
private final ObjectStoreLink store;
+ private final LocalSchedulerLink localSchedulerLink;
private final int getTimeoutMs = 1000;
public ObjectStoreProxy(ObjectStoreLink store) {
this.store = store;
+ this.localSchedulerLink = null;
}
+ public ObjectStoreProxy(ObjectStoreLink store, LocalSchedulerLink localSchedulerLink) {
+ this.store = store;
+ this.localSchedulerLink = localSchedulerLink;
+ }
+
public Pair get(UniqueID objectId, boolean isMetadata)
throws TaskExecutionException {
return get(objectId, getTimeoutMs, isMetadata);
@@ -88,7 +96,12 @@ public class ObjectStoreProxy {
for (RayObject obj : waitfor.Objects()) {
ids.add(obj.getId());
}
- List readys = store.wait(getIdBytes(ids), timeout, numReturns);
+ List readys;
+ if (localSchedulerLink == null) {
+ readys = store.wait(getIdBytes(ids), timeout, numReturns);
+ } else {
+ readys = localSchedulerLink.wait(getIdBytes(ids), timeout, numReturns);
+ }
RayList readyObjs = new RayList<>();
RayList remainObjs = new RayList<>();
@@ -103,19 +116,14 @@ public class ObjectStoreProxy {
return new WaitResult<>(readyObjs, remainObjs);
}
- public void fetch(UniqueID objectId) {
- store.fetch(objectId.getBytes());
- }
-
public void fetch(List objectIds) {
- store.fetch(getIdBytes(objectIds));
+ if (localSchedulerLink == null) {
+ store.fetch(getIdBytes(objectIds));
+ } else {
+ localSchedulerLink.reconstructObjects(objectIds, true);
+ }
}
- public int getFetchSize() {
- return 10000;
- }
-
-
public enum GetStatus {
SUCCESS, FAILED
}
diff --git a/java/runtime-common/src/main/java/org/ray/spi/PathConfig.java b/java/runtime-common/src/main/java/org/ray/spi/PathConfig.java
index f33e685b4..e8e0f1228 100644
--- a/java/runtime-common/src/main/java/org/ray/spi/PathConfig.java
+++ b/java/runtime-common/src/main/java/org/ray/spi/PathConfig.java
@@ -37,6 +37,9 @@ public class PathConfig {
@AConfig(comment = "path to global scheduler")
public String global_scheduler;
+ @AConfig(comment = "path to raylet")
+ public String raylet;
+
@AConfig(comment = "path to python directory")
public String python_dir;
diff --git a/java/runtime-common/src/main/java/org/ray/spi/model/AddressInfo.java b/java/runtime-common/src/main/java/org/ray/spi/model/AddressInfo.java
index 7276278d0..dcf131d43 100644
--- a/java/runtime-common/src/main/java/org/ray/spi/model/AddressInfo.java
+++ b/java/runtime-common/src/main/java/org/ray/spi/model/AddressInfo.java
@@ -8,9 +8,11 @@ public class AddressInfo {
public String managerName;
public String storeName;
public String schedulerName;
+ public String rayletSocketName;
public int managerPort;
public int workerCount;
public String managerRpcAddr;
public String storeRpcAddr;
public String schedulerRpcAddr;
+ public String rayletRpcAddr;
}
diff --git a/java/runtime-dev/pom.xml b/java/runtime-dev/pom.xml
index 5bc1bb6d4..891a333e1 100644
--- a/java/runtime-dev/pom.xml
+++ b/java/runtime-dev/pom.xml
@@ -1,5 +1,4 @@
-
diff --git a/java/runtime-dev/src/main/java/org/ray/spi/impl/MockLocalScheduler.java b/java/runtime-dev/src/main/java/org/ray/spi/impl/MockLocalScheduler.java
index 89bbee1b1..626f08473 100644
--- a/java/runtime-dev/src/main/java/org/ray/spi/impl/MockLocalScheduler.java
+++ b/java/runtime-dev/src/main/java/org/ray/spi/impl/MockLocalScheduler.java
@@ -1,5 +1,6 @@
package org.ray.spi.impl;
+import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.ray.api.UniqueID;
@@ -74,7 +75,12 @@ public class MockLocalScheduler implements LocalSchedulerLink {
}
@Override
- public void reconstructObject(UniqueID objectId) {
+ public void reconstructObject(UniqueID objectId, boolean fetchOnly) {
+
+ }
+
+ @Override
+ public void reconstructObjects(List objectIds, boolean fetchOnly) {
}
@@ -82,4 +88,9 @@ public class MockLocalScheduler implements LocalSchedulerLink {
public void notifyUnblocked() {
}
+
+ @Override
+ public List wait(byte[][] objectIds, int timeoutMs, int numReturns) {
+ return store.wait(objectIds, timeoutMs, numReturns);
+ }
}
diff --git a/java/runtime-native/src/main/java/org/ray/core/impl/RayNativeRuntime.java b/java/runtime-native/src/main/java/org/ray/core/impl/RayNativeRuntime.java
index b1e708b21..bcb8f4c37 100644
--- a/java/runtime-native/src/main/java/org/ray/core/impl/RayNativeRuntime.java
+++ b/java/runtime-native/src/main/java/org/ray/core/impl/RayNativeRuntime.java
@@ -25,8 +25,9 @@ import org.ray.spi.RemoteFunctionManager;
import org.ray.spi.StateStoreProxy;
import org.ray.spi.impl.DefaultLocalSchedulerClient;
import org.ray.spi.impl.NativeRemoteFunctionManager;
+import org.ray.spi.impl.NonRayletStateStoreProxyImpl;
+import org.ray.spi.impl.RayletStateStoreProxyImpl;
import org.ray.spi.impl.RedisClient;
-import org.ray.spi.impl.StateStoreProxyImpl;
import org.ray.spi.model.AddressInfo;
import org.ray.util.exception.TaskExecutionException;
import org.ray.util.logger.RayLog;
@@ -62,14 +63,19 @@ public class RayNativeRuntime extends RayRuntime {
throw new Error("Redis address must be configured under Worker mode.");
}
startOnebox(params, pathConfig);
- initStateStore(params.redis_address);
+ initStateStore(params.redis_address, params.use_raylet);
} else {
- initStateStore(params.redis_address);
+ initStateStore(params.redis_address, params.use_raylet);
if (!isWorker) {
- List nodes = stateStoreProxy.getAddressInfo(params.node_ip_address, 5);
+ List nodes = stateStoreProxy.getAddressInfo(
+ params.node_ip_address, params.redis_address, 5);
params.object_store_name = nodes.get(0).storeName;
- params.object_store_manager_name = nodes.get(0).managerName;
- params.local_scheduler_name = nodes.get(0).schedulerName;
+ if (!params.use_raylet) {
+ params.object_store_manager_name = nodes.get(0).managerName;
+ params.local_scheduler_name = nodes.get(0).schedulerName;
+ } else {
+ params.raylet_socket_name = nodes.get(0).rayletSocketName;
+ }
}
}
@@ -101,23 +107,45 @@ public class RayNativeRuntime extends RayRuntime {
.getIntegerValue("ray", "plasma_default_release_delay", 0,
"how many release requests should be delayed in plasma client");
- ObjectStoreLink plink = new PlasmaClient(params.object_store_name, params
- .object_store_manager_name, releaseDelay);
+ if (!params.use_raylet) {
+ ObjectStoreLink plink = new PlasmaClient(params.object_store_name,
+ params.object_store_manager_name, releaseDelay);
- LocalSchedulerLink slink = new DefaultLocalSchedulerClient(
- params.local_scheduler_name,
- WorkerContext.currentWorkerId(),
- UniqueID.nil,
- isWorker,
- WorkerContext.currentTask().taskId,
- 0
- );
+ LocalSchedulerLink slink = new DefaultLocalSchedulerClient(
+ params.local_scheduler_name,
+ WorkerContext.currentWorkerId(),
+ UniqueID.nil,
+ isWorker,
+ WorkerContext.currentTask().taskId,
+ 0,
+ false
+ );
- init(slink, plink, funcMgr, pathConfig);
+ init(slink, plink, funcMgr, pathConfig);
- // register
- registerWorker(isWorker, params.node_ip_address, params.object_store_name,
- params.object_store_manager_name, params.local_scheduler_name);
+ // register
+ registerWorker(isWorker, params.node_ip_address, params.object_store_name,
+ params.object_store_manager_name, params.local_scheduler_name);
+ } else {
+
+ ObjectStoreLink plink = new PlasmaClient(params.object_store_name, "", releaseDelay);
+
+ LocalSchedulerLink slink = new DefaultLocalSchedulerClient(
+ params.raylet_socket_name,
+ WorkerContext.currentWorkerId(),
+ UniqueID.nil,
+ isWorker,
+ WorkerContext.currentTask().taskId,
+ 0,
+ true
+ );
+
+ init(slink, plink, funcMgr, pathConfig);
+
+ // register
+ registerWorker(isWorker, params.node_ip_address, params.object_store_name,
+ params.raylet_socket_name);
+ }
}
RayLog.core.info("RayNativeRuntime start with "
@@ -152,19 +180,44 @@ public class RayNativeRuntime extends RayRuntime {
params.object_store_name = manager.info().localStores.get(0).storeName;
params.object_store_manager_name = manager.info().localStores.get(0).managerName;
params.local_scheduler_name = manager.info().localStores.get(0).schedulerName;
+ params.raylet_socket_name = manager.info().localStores.get(0).rayletSocketName;
//params.node_ip_address = NetworkUtil.getIpAddress();
}
- private void initStateStore(String redisAddress) throws Exception {
+ private void initStateStore(String redisAddress, boolean useRaylet) throws Exception {
kvStore = new RedisClient();
kvStore.setAddr(redisAddress);
- stateStoreProxy = new StateStoreProxyImpl(kvStore);
+ stateStoreProxy = useRaylet
+ ? new RayletStateStoreProxyImpl(kvStore)
+ : new NonRayletStateStoreProxyImpl(kvStore);
//stateStoreProxy.setStore(kvStore);
stateStoreProxy.initializeGlobalState();
}
private void registerWorker(boolean isWorker, String nodeIpAddress, String storeName,
- String managerName, String schedulerName) {
+ String rayletSocketName) {
+ Map workerInfo = new HashMap<>();
+ String workerId = new String(WorkerContext.currentWorkerId().getBytes());
+ if (!isWorker) {
+ workerInfo.put("node_ip_address", nodeIpAddress);
+ workerInfo.put("driver_id", workerId);
+ workerInfo.put("start_time", String.valueOf(System.currentTimeMillis()));
+ workerInfo.put("plasma_store_socket", storeName);
+ workerInfo.put("raylet_socket", rayletSocketName);
+ workerInfo.put("name", System.getProperty("user.dir"));
+ //TODO: worker.redis_client.hmset(b"Drivers:" + worker.workerId, driver_info)
+ kvStore.hmset("Drivers:" + workerId, workerInfo);
+ } else {
+ workerInfo.put("node_ip_address", nodeIpAddress);
+ workerInfo.put("plasma_store_socket", storeName);
+ workerInfo.put("raylet_socket", rayletSocketName);
+ //TODO: b"Workers:" + worker.workerId,
+ kvStore.hmset("Workers:" + workerId, workerInfo);
+ }
+ }
+
+ private void registerWorker(boolean isWorker, String nodeIpAddress, String storeName,
+ String managerName, String schedulerName) {
Map workerInfo = new HashMap<>();
String workerId = new String(WorkerContext.currentWorkerId().getBytes());
if (!isWorker) {
diff --git a/java/runtime-native/src/main/java/org/ray/format/gcs/ClientTableData.java b/java/runtime-native/src/main/java/org/ray/format/gcs/ClientTableData.java
new file mode 100644
index 000000000..e7d71415d
--- /dev/null
+++ b/java/runtime-native/src/main/java/org/ray/format/gcs/ClientTableData.java
@@ -0,0 +1,79 @@
+package org.ray.format.gcs;
+// automatically generated by the FlatBuffers compiler, do not modify
+
+import java.nio.*;
+import java.lang.*;
+import com.google.flatbuffers.*;
+
+@SuppressWarnings("unused")
+public final class ClientTableData extends Table {
+ public static ClientTableData getRootAsClientTableData(ByteBuffer _bb) { return getRootAsClientTableData(_bb, new ClientTableData()); }
+ public static ClientTableData getRootAsClientTableData(ByteBuffer _bb, ClientTableData obj) { _bb.order(ByteOrder.LITTLE_ENDIAN); return (obj.__assign(_bb.getInt(_bb.position()) + _bb.position(), _bb)); }
+ public void __init(int _i, ByteBuffer _bb) { bb_pos = _i; bb = _bb; }
+ public ClientTableData __assign(int _i, ByteBuffer _bb) { __init(_i, _bb); return this; }
+
+ public String clientId() { int o = __offset(4); return o != 0 ? __string(o + bb_pos) : null; }
+ public ByteBuffer clientIdAsByteBuffer() { return __vector_as_bytebuffer(4, 1); }
+ public ByteBuffer clientIdInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 4, 1); }
+ public String nodeManagerAddress() { int o = __offset(6); return o != 0 ? __string(o + bb_pos) : null; }
+ public ByteBuffer nodeManagerAddressAsByteBuffer() { return __vector_as_bytebuffer(6, 1); }
+ public ByteBuffer nodeManagerAddressInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 6, 1); }
+ public String rayletSocketName() { int o = __offset(8); return o != 0 ? __string(o + bb_pos) : null; }
+ public ByteBuffer rayletSocketNameAsByteBuffer() { return __vector_as_bytebuffer(8, 1); }
+ public ByteBuffer rayletSocketNameInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 8, 1); }
+ public String objectStoreSocketName() { int o = __offset(10); return o != 0 ? __string(o + bb_pos) : null; }
+ public ByteBuffer objectStoreSocketNameAsByteBuffer() { return __vector_as_bytebuffer(10, 1); }
+ public ByteBuffer objectStoreSocketNameInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 10, 1); }
+ public int nodeManagerPort() { int o = __offset(12); return o != 0 ? bb.getInt(o + bb_pos) : 0; }
+ public int objectManagerPort() { int o = __offset(14); return o != 0 ? bb.getInt(o + bb_pos) : 0; }
+ public boolean isInsertion() { int o = __offset(16); return o != 0 ? 0!=bb.get(o + bb_pos) : false; }
+ public String resourcesTotalLabel(int j) { int o = __offset(18); return o != 0 ? __string(__vector(o) + j * 4) : null; }
+ public int resourcesTotalLabelLength() { int o = __offset(18); return o != 0 ? __vector_len(o) : 0; }
+ public double resourcesTotalCapacity(int j) { int o = __offset(20); return o != 0 ? bb.getDouble(__vector(o) + j * 8) : 0; }
+ public int resourcesTotalCapacityLength() { int o = __offset(20); return o != 0 ? __vector_len(o) : 0; }
+ public ByteBuffer resourcesTotalCapacityAsByteBuffer() { return __vector_as_bytebuffer(20, 8); }
+ public ByteBuffer resourcesTotalCapacityInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 20, 8); }
+
+ public static int createClientTableData(FlatBufferBuilder builder,
+ int client_idOffset,
+ int node_manager_addressOffset,
+ int raylet_socket_nameOffset,
+ int object_store_socket_nameOffset,
+ int node_manager_port,
+ int object_manager_port,
+ boolean is_insertion,
+ int resources_total_labelOffset,
+ int resources_total_capacityOffset) {
+ builder.startObject(9);
+ ClientTableData.addResourcesTotalCapacity(builder, resources_total_capacityOffset);
+ ClientTableData.addResourcesTotalLabel(builder, resources_total_labelOffset);
+ ClientTableData.addObjectManagerPort(builder, object_manager_port);
+ ClientTableData.addNodeManagerPort(builder, node_manager_port);
+ ClientTableData.addObjectStoreSocketName(builder, object_store_socket_nameOffset);
+ ClientTableData.addRayletSocketName(builder, raylet_socket_nameOffset);
+ ClientTableData.addNodeManagerAddress(builder, node_manager_addressOffset);
+ ClientTableData.addClientId(builder, client_idOffset);
+ ClientTableData.addIsInsertion(builder, is_insertion);
+ return ClientTableData.endClientTableData(builder);
+ }
+
+ public static void startClientTableData(FlatBufferBuilder builder) { builder.startObject(9); }
+ public static void addClientId(FlatBufferBuilder builder, int clientIdOffset) { builder.addOffset(0, clientIdOffset, 0); }
+ public static void addNodeManagerAddress(FlatBufferBuilder builder, int nodeManagerAddressOffset) { builder.addOffset(1, nodeManagerAddressOffset, 0); }
+ public static void addRayletSocketName(FlatBufferBuilder builder, int rayletSocketNameOffset) { builder.addOffset(2, rayletSocketNameOffset, 0); }
+ public static void addObjectStoreSocketName(FlatBufferBuilder builder, int objectStoreSocketNameOffset) { builder.addOffset(3, objectStoreSocketNameOffset, 0); }
+ public static void addNodeManagerPort(FlatBufferBuilder builder, int nodeManagerPort) { builder.addInt(4, nodeManagerPort, 0); }
+ public static void addObjectManagerPort(FlatBufferBuilder builder, int objectManagerPort) { builder.addInt(5, objectManagerPort, 0); }
+ public static void addIsInsertion(FlatBufferBuilder builder, boolean isInsertion) { builder.addBoolean(6, isInsertion, false); }
+ public static void addResourcesTotalLabel(FlatBufferBuilder builder, int resourcesTotalLabelOffset) { builder.addOffset(7, resourcesTotalLabelOffset, 0); }
+ public static int createResourcesTotalLabelVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
+ public static void startResourcesTotalLabelVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
+ public static void addResourcesTotalCapacity(FlatBufferBuilder builder, int resourcesTotalCapacityOffset) { builder.addOffset(8, resourcesTotalCapacityOffset, 0); }
+ public static int createResourcesTotalCapacityVector(FlatBufferBuilder builder, double[] data) { builder.startVector(8, data.length, 8); for (int i = data.length - 1; i >= 0; i--) builder.addDouble(data[i]); return builder.endVector(); }
+ public static void startResourcesTotalCapacityVector(FlatBufferBuilder builder, int numElems) { builder.startVector(8, numElems, 8); }
+ public static int endClientTableData(FlatBufferBuilder builder) {
+ int o = builder.endObject();
+ return o;
+ }
+}
+
diff --git a/java/runtime-native/src/main/java/org/ray/runner/RunInfo.java b/java/runtime-native/src/main/java/org/ray/runner/RunInfo.java
index bb58cff7e..cddf74f13 100644
--- a/java/runtime-native/src/main/java/org/ray/runner/RunInfo.java
+++ b/java/runtime-native/src/main/java/org/ray/runner/RunInfo.java
@@ -35,7 +35,7 @@ public class RunInfo {
public enum ProcessType {
PT_WORKER, PT_LOCAL_SCHEDULER, PT_PLASMA_MANAGER, PT_PLASMA_STORE,
- PT_GLOBAL_SCHEDULER, PT_REDIS_SERVER, PT_WEB_UI,
+ PT_GLOBAL_SCHEDULER, PT_REDIS_SERVER, PT_WEB_UI, PT_RAYLET,
PT_DRIVER
}
}
diff --git a/java/runtime-native/src/main/java/org/ray/runner/RunManager.java b/java/runtime-native/src/main/java/org/ray/runner/RunManager.java
index 63f6d8435..43742133c 100644
--- a/java/runtime-native/src/main/java/org/ray/runner/RunManager.java
+++ b/java/runtime-native/src/main/java/org/ray/runner/RunManager.java
@@ -48,7 +48,7 @@ public class RunManager {
private static boolean killProcess(Process p) {
if (p.isAlive()) {
- p.destroyForcibly();
+ p.destroy();
return true;
} else {
return false;
@@ -307,7 +307,7 @@ public class RunManager {
redisClient.close();
// start global scheduler
- if (params.include_global_scheduler) {
+ if (params.include_global_scheduler && !params.use_raylet) {
startGlobalScheduler(params.working_directory + "/globalScheduler",
params.redis_address, params.node_ip_address, params.redirect, params.cleanup);
}
@@ -340,49 +340,70 @@ public class RunManager {
}
}
- // start object stores
- for (int i = 0; i < params.num_local_schedulers; i++) {
- AddressInfo info = new AddressInfo();
- // store
- startObjectStore(i, info, params.working_directory + "/store",
+ AddressInfo info = new AddressInfo();
+
+ if (params.use_raylet) {
+ // Start object store
+ int rpcPort = params.object_store_rpc_port;
+ String storeName = "/tmp/plasma_store" + rpcPort;
+
+ startObjectStore(0, info, params.working_directory + "/store",
params.redis_address, params.node_ip_address, params.redirect, params.cleanup);
- // store manager
- startObjectManager(i, info,
- params.working_directory + "/storeManager", params.redis_address,
- params.node_ip_address, params.redirect, params.cleanup);
+ //Start raylet
+ startRaylet(storeName, info, params.num_cpus[0],params.num_gpus[0],
+ params.num_workers,params.working_directory + "/raylet",
+ params.redis_address, params.node_ip_address, params.redirect, params.cleanup);
runInfo.localStores.add(info);
- }
+ } else {
+      for (int i = 0; i < params.num_local_schedulers; i++) {
+        // Each local scheduler needs its own AddressInfo. Reusing the single instance
+        // declared before the raylet/non-raylet branch would add the same object to
+        // runInfo.localStores on every iteration, so all entries would end up holding
+        // the addresses of the last scheduler started.
+        AddressInfo localInfo = new AddressInfo();
+
+        // Start object stores
+        startObjectStore(i, localInfo, params.working_directory + "/store",
+            params.redis_address, params.node_ip_address, params.redirect, params.cleanup);
- // start local scheduler
- for (int i = 0; i < params.num_local_schedulers; i++) {
- int workerCount = 0;
+        startObjectManager(i, localInfo,
+            params.working_directory + "/storeManager", params.redis_address,
+            params.node_ip_address, params.redirect, params.cleanup);
- if (params.start_workers_from_local_scheduler) {
- workerCount = localNumWorkers[i];
- localNumWorkers[i] = 0;
+        // Start local scheduler
+        int workerCount = 0;
+
+        if (params.start_workers_from_local_scheduler) {
+          // Workers launched by the scheduler itself must not be started again below.
+          workerCount = localNumWorkers[i];
+          localNumWorkers[i] = 0;
+        }
+
+        startLocalScheduler(i, localInfo,
+            params.num_cpus[i], params.num_gpus[i], workerCount,
+            params.working_directory + "/localsc", params.redis_address,
+            params.node_ip_address, params.redirect, params.cleanup);
+
+        runInfo.localStores.add(localInfo);
}
-
- startLocalScheduler(i, runInfo.localStores.get(i),
- params.num_cpus[i], params.num_gpus[i], workerCount,
- params.working_directory + "/localScheduler", params.redis_address,
- params.node_ip_address, params.redirect, params.cleanup);
}
// start local workers
- for (int i = 0; i < params.num_local_schedulers; i++) {
- runInfo.localStores.get(i).workerCount = localNumWorkers[i];
- for (int j = 0; j < localNumWorkers[i]; j++) {
- startWorker(runInfo.localStores.get(i).storeName,
- runInfo.localStores.get(i).managerName, runInfo.localStores.get(i).schedulerName,
- params.working_directory + "/worker" + i + "." + j, params.redis_address,
- params.node_ip_address, UniqueID.nil, "",
- params.redirect, params.cleanup);
+ if (!params.use_raylet) {
+ for (int i = 0; i < params.num_local_schedulers; i++) {
+ AddressInfo localStores = runInfo.localStores.get(i);
+ localStores.workerCount = localNumWorkers[i];
+ for (int j = 0; j < localNumWorkers[i]; j++) {
+ startWorker(localStores.storeName, localStores.managerName, localStores.schedulerName,
+ params.working_directory + "/worker" + i + "." + j, params.redis_address,
+ params.node_ip_address, UniqueID.nil, "", params.redirect, params.cleanup);
+ }
}
}
HashSet excludeTypes = new HashSet<>();
+ if (!params.use_raylet) {
+ excludeTypes.add(RunInfo.ProcessType.PT_RAYLET);
+ } else {
+ excludeTypes.add(RunInfo.ProcessType.PT_LOCAL_SCHEDULER);
+ excludeTypes.add(RunInfo.ProcessType.PT_GLOBAL_SCHEDULER);
+ excludeTypes.add(RunInfo.ProcessType.PT_PLASMA_MANAGER);
+ }
if (!checkAlive(excludeTypes)) {
cleanup(true);
throw new RuntimeException("Start Ray processes failed");
@@ -622,8 +643,8 @@ public class RunManager {
cmd += " -m " + info.managerName;
String workerCmd = null;
- workerCmd = buildWorkerCommand(true, info.storeName, info.managerName, name, UniqueID.nil,
- "", workDir + rpcPort, ip, redisAddress);
+ workerCmd = buildWorkerCommand(true, info.storeName, info.managerName, name,
+ UniqueID.nil, "", workDir + rpcPort, ip, redisAddress);
cmd += " -w \"" + workerCmd + "\"";
if (redisAddress.length() > 0) {
@@ -656,6 +677,82 @@ public class RunManager {
}
}
+  /**
+   * Launches the raylet process for this node and records its endpoints in {@code info}.
+   *
+   * @param storeName socket name of the object store, passed to the raylet binary.
+   * @param info address record; the worker command reads {@code info.storeName}, and the
+   *     raylet socket name and rpc address are written back into it.
+   * @param numCpus CPU capacity advertised in the resource argument.
+   * @param numGpus GPU capacity advertised in the resource argument.
+   * @param numWorkers worker count passed to the raylet binary.
+   * @param workDir working directory prefix; the raylet port is appended to it.
+   * @param redisAddress redis address in {@code ip:port} form.
+   * @param ip ip address of this node.
+   * @param redirect forwarded unchanged to startProcess.
+   * @param cleanup forwarded unchanged to startProcess.
+   * @throws IllegalArgumentException if {@code redisAddress} contains no ':' separator.
+   * @throws RuntimeException if the raylet process is not alive after start-up.
+   */
+  private void startRaylet(String storeName, AddressInfo info, int numCpus,
+                           int numGpus, int numWorkers, String workDir,
+                           String redisAddress, String ip, boolean redirect,
+                           boolean cleanup) {
+
+    // Validate explicitly: an assert is skipped unless the JVM runs with -ea, in which
+    // case a malformed address would only show up as a StringIndexOutOfBoundsException.
+    int sep = redisAddress.indexOf(':');
+    if (sep == -1) {
+      throw new IllegalArgumentException(
+          "Invalid redis address (expected ip:port): " + redisAddress);
+    }
+    String gcsIp = redisAddress.substring(0, sep);
+    String gcsPort = redisAddress.substring(sep + 1);
+
+    int rpcPort = params.raylet_port;
+    String rayletSocketName = "/tmp/raylet" + rpcPort;
+
+    String filePath = paths.raylet;
+
+    String workerCmd = buildWorkerCommandRaylet(info.storeName, rayletSocketName, UniqueID.nil,
+        "", workDir + rpcPort, ip, redisAddress);
+
+    String resourceArgument = "GPU," + numGpus + ",CPU," + numCpus;
+
+    String[] cmds = new String[]{filePath, rayletSocketName, storeName, ip, gcsIp,
+        gcsPort, "" + numWorkers, workerCmd, resourceArgument};
+
+    Process p = startProcess(cmds, null, RunInfo.ProcessType.PT_RAYLET,
+        workDir + rpcPort, redisAddress, ip, redirect, cleanup);
+
+    if (p != null && p.isAlive()) {
+      // Give the process a moment so an immediate crash is caught by the check below.
+      try {
+        TimeUnit.MILLISECONDS.sleep(100);
+      } catch (InterruptedException e) {
+        // Keep the interrupt visible to callers instead of swallowing it.
+        Thread.currentThread().interrupt();
+      }
+    }
+
+    if (p == null || !p.isAlive()) {
+      info.rayletSocketName = "";
+      info.rayletRpcAddr = "";
+      throw new RuntimeException("Failed to start raylet process.");
+    } else {
+      info.rayletSocketName = rayletSocketName;
+      info.rayletRpcAddr = ip + ":" + rpcPort;
+    }
+  }
+
+  /**
+   * Assembles the command line used to start a default worker in raylet mode.
+   *
+   * @param storeName value of ray.java.start.object_store_name.
+   * @param rayletSocketName value of ray.java.start.raylet_socket_name.
+   * @param actorId added as ray.java.start.actor_id unless it equals UniqueID.nil.
+   * @param actorClass added as ray.java.start.driver_class unless it is empty.
+   * @param workDir forwarded to buildJavaProcessCommand.
+   * @param ip forwarded to buildJavaProcessCommand.
+   * @param redisAddress forwarded to buildJavaProcessCommand.
+   * @return the command string produced by buildJavaProcessCommand.
+   */
+  private String buildWorkerCommandRaylet(String storeName, String rayletSocketName,
+                                          UniqueID actorId, String actorClass, String workDir,
+                                          String ip, String redisAddress) {
+    // Optional settings collapse to an empty string when they do not apply.
+    final String actorIdConfig = actorId.equals(UniqueID.nil)
+        ? ""
+        : ";ray.java.start.actor_id=" + actorId;
+    final String driverClassConfig = actorClass.isEmpty()
+        ? ""
+        : ";ray.java.start.driver_class=" + actorClass;
+
+    final String workerConfigs = "ray.java.start.object_store_name=" + storeName
+        + ";ray.java.start.raylet_socket_name=" + rayletSocketName
+        + ";ray.java.start.worker_mode=WORKER;ray.java.start.use_raylet=true"
+        + ";ray.java.start.deploy=" + params.deploy
+        + actorIdConfig
+        + driverClassConfig;
+
+    final String jvmArgs = " -Dlogging.path=" + params.working_directory + "/logs/workers"
+        + " -Dlogging.file.name=core-*pid_suffix*";
+
+    return buildJavaProcessCommand(
+        RunInfo.ProcessType.PT_WORKER,
+        "org.ray.runner.worker.DefaultWorker",
+        "",
+        workerConfigs,
+        jvmArgs,
+        workDir,
+        ip,
+        redisAddress,
+        null
+    );
+  }
+
private String buildWorkerCommand(boolean isFromLocalScheduler, String storeName,
String storeManagerName, String localSchedulerName,
UniqueID actorId, String actorClass, String workDir, String
diff --git a/java/runtime-native/src/main/java/org/ray/spi/KeyValueStoreLink.java b/java/runtime-native/src/main/java/org/ray/spi/KeyValueStoreLink.java
index 6d21fa8e8..b52ee6ab9 100644
--- a/java/runtime-native/src/main/java/org/ray/spi/KeyValueStoreLink.java
+++ b/java/runtime-native/src/main/java/org/ray/spi/KeyValueStoreLink.java
@@ -103,6 +103,15 @@ public interface KeyValueStoreLink {
*/
List lrange(final String key, final long start, final long end);
+ /**
+ * Return the set of elements of the sorted set stored at the specified key.
+ * @param key The specified key you want to query.
+ * @param start The start index of the range.
+ * @param end The end index of the range.
+ * @return The set of elements you queried.
+ */
+ Set zrange(byte[] key, long start, long end);
+
/**
* Rpush.
* @return Integer reply, specifically, the number of elements inside the list after the push
@@ -123,4 +132,7 @@ public interface KeyValueStoreLink {
Long publish(byte[] channel, byte[] message);
Object getImpl();
+
+ byte[] sendCommand(String command, int commandType, byte[] objectId);
+
}
diff --git a/java/runtime-native/src/main/java/org/ray/spi/StateStoreProxy.java b/java/runtime-native/src/main/java/org/ray/spi/StateStoreProxy.java
index 2ecaadc3a..c5995dba1 100644
--- a/java/runtime-native/src/main/java/org/ray/spi/StateStoreProxy.java
+++ b/java/runtime-native/src/main/java/org/ray/spi/StateStoreProxy.java
@@ -31,5 +31,7 @@ public interface StateStoreProxy {
* getAddressInfo.
* @return list of address information
*/
- List getAddressInfo(final String nodeIpAddress, int numRetries);
+ List getAddressInfo(final String nodeIpAddress,
+ final String redisAddress,
+ int numRetries);
}
diff --git a/java/runtime-native/src/main/java/org/ray/spi/impl/BaseStateStoreProxyImpl.java b/java/runtime-native/src/main/java/org/ray/spi/impl/BaseStateStoreProxyImpl.java
new file mode 100644
index 000000000..9d01134f5
--- /dev/null
+++ b/java/runtime-native/src/main/java/org/ray/spi/impl/BaseStateStoreProxyImpl.java
@@ -0,0 +1,124 @@
+package org.ray.spi.impl;
+
+import java.io.UnsupportedEncodingException;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.TimeUnit;
+import org.ray.spi.KeyValueStoreLink;
+import org.ray.spi.StateStoreProxy;
+import org.ray.spi.model.AddressInfo;
+import org.ray.util.logger.RayLog;
+
+/**
+ * Base class used to interface with the Ray control state.
+ *
+ * <p>Holds the primary key-value store link plus one client per redis shard, and wraps
+ * the subclass-provided {@link #doGetAddressInfo(String, String)} in a retry loop.
+ */
+public abstract class BaseStateStoreProxyImpl implements StateStoreProxy {
+
+  public KeyValueStoreLink rayKvStore;
+  public ArrayList<KeyValueStoreLink> shardStoreList = new ArrayList<>();
+
+  public BaseStateStoreProxyImpl(KeyValueStoreLink rayKvStore) {
+    this.rayKvStore = rayKvStore;
+  }
+
+  @Override
+  public void setStore(KeyValueStoreLink rayKvStore) {
+    this.rayKvStore = rayKvStore;
+  }
+
+  /**
+   * Reads the shard layout from the primary store and (re)creates one client per
+   * distinct shard address.
+   *
+   * @throws Exception if "NumRedisShards" is missing or smaller than one, or if the
+   *     number of distinct "RedisShards" addresses does not match it.
+   */
+  @Override
+  public synchronized void initializeGlobalState() throws Exception {
+    checkConnected();
+
+    String s = rayKvStore.get("NumRedisShards", null);
+    if (s == null) {
+      throw new Exception("NumRedisShards not found in redis.");
+    }
+    int numRedisShards = Integer.parseInt(s);
+    if (numRedisShards < 1) {
+      throw new Exception(
+          String.format("Expected at least one Redis shard, found %d", numRedisShards));
+    }
+
+    // Duplicate list entries are collapsed before comparing against the expected count.
+    List<String> ipAddressPorts = rayKvStore.lrange("RedisShards", 0, -1);
+    Set<String> distinctIpAddress = new HashSet<String>(ipAddressPorts);
+    if (distinctIpAddress.size() != numRedisShards) {
+      throw new Exception(
+          String.format("Expected %d Redis shard addresses, found %d.", numRedisShards,
+              distinctIpAddress.size()));
+    }
+
+    shardStoreList.clear();
+    for (String ipPort : distinctIpAddress) {
+      shardStoreList.add(new RedisClient(ipPort));
+    }
+  }
+
+  public void checkConnected() throws Exception {
+    rayKvStore.checkConnected();
+  }
+
+  /**
+   * Collects the keys matching {@code pattern} from every shard client.
+   *
+   * @param pattern pattern forwarded to each shard's keys call.
+   * @return the union of the keys returned by all shards.
+   */
+  @Override
+  public synchronized Set keys(final String pattern) {
+    Set allKeys = new HashSet<>();
+    for (KeyValueStoreLink shardStore : shardStoreList) {
+      allKeys.addAll(shardStore.keys(pattern));
+    }
+    return allKeys;
+  }
+
+  /**
+   * Looks up address info, retrying with a one second pause between attempts.
+   *
+   * @param nodeIpAddress forwarded to doGetAddressInfo.
+   * @param redisAddress forwarded to doGetAddressInfo.
+   * @param numRetries maximum number of attempts.
+   * @return the list returned by the first successful attempt.
+   * @throws RuntimeException if every attempt fails (the last failure is attached as the
+   *     cause) or if the thread is interrupted while waiting between attempts.
+   */
+  @Override
+  public List getAddressInfo(final String nodeIpAddress,
+                             final String redisAddress,
+                             int numRetries) {
+    Exception lastError = null;
+    for (int attempt = 0; attempt < numRetries; attempt++) {
+      try {
+        return doGetAddressInfo(nodeIpAddress, redisAddress);
+      } catch (Exception e) {
+        lastError = e;
+        // The attempt that just failed no longer counts as remaining.
+        int remaining = numRetries - attempt - 1;
+        RayLog.core.warn("Error occurred in BaseStateStoreProxyImpl getAddressInfo, "
+            + remaining + " retries remaining", e);
+        if (remaining == 0) {
+          // Nothing left to retry, so do not sleep before reporting the failure.
+          break;
+        }
+        try {
+          TimeUnit.MILLISECONDS.sleep(1000);
+        } catch (InterruptedException ie) {
+          // Restore the interrupt flag and report the interruption itself; the lookup
+          // failure that led here is kept as a suppressed exception.
+          Thread.currentThread().interrupt();
+          ie.addSuppressed(e);
+          RayLog.core.error("error at BaseStateStoreProxyImpl getAddressInfo", ie);
+          throw new RuntimeException(ie);
+        }
+      }
+    }
+    throw new RuntimeException("cannot get address info from state store", lastError);
+  }
+
+  /**
+   * Get address info of one node from primary redis.
+   * This method only tries to get address info once, without any retry.
+   *
+   * @param nodeIpAddress Usually local ip address.
+   * @param redisAddress The primary redis address.
+   * @return A list of address info entries (node manager or local scheduler addresses).
+   * @throws Exception If the lookup fails, e.g. when no redis client is set.
+   */
+  protected abstract List doGetAddressInfo(final String nodeIpAddress,
+                                           final String redisAddress) throws Exception;
+
+  /** Decodes {@code bs} using the named charset. */
+  protected String charsetDecode(byte[] bs, String charset) throws UnsupportedEncodingException {
+    return new String(bs, charset);
+  }
+
+  /** Encodes {@code str} using the named charset; a null string yields null. */
+  protected byte[] charsetEncode(String str, String charset) throws UnsupportedEncodingException {
+    if (str != null) {
+      return str.getBytes(charset);
+    }
+    return null;
+  }
+}
diff --git a/java/runtime-native/src/main/java/org/ray/spi/impl/DefaultLocalSchedulerClient.java b/java/runtime-native/src/main/java/org/ray/spi/impl/DefaultLocalSchedulerClient.java
index 872f2dc64..f732efb37 100644
--- a/java/runtime-native/src/main/java/org/ray/spi/impl/DefaultLocalSchedulerClient.java
+++ b/java/runtime-native/src/main/java/org/ray/spi/impl/DefaultLocalSchedulerClient.java
@@ -24,20 +24,44 @@ public class DefaultLocalSchedulerClient implements LocalSchedulerLink {
return bb;
});
private long client = 0;
+ boolean useRaylet = false;
- public DefaultLocalSchedulerClient(String schedulerSockName, UniqueID clientId, UniqueID actorId,
- boolean isWorker, UniqueID driverId, long numGpus) {
+ public DefaultLocalSchedulerClient(String schedulerSockName, UniqueID clientId,
+ UniqueID actorId, boolean isWorker, UniqueID driverId,
+ long numGpus, boolean useRaylet) {
client = _init(schedulerSockName, clientId.getBytes(), actorId.getBytes(), isWorker,
- driverId.getBytes(), numGpus);
+ driverId.getBytes(), numGpus, useRaylet);
+ this.useRaylet = useRaylet;
}
- private static native long _init(String localSchedulerSocket, byte[] workerId, byte[] actorId,
- boolean isWorker, byte[] driverTaskId, long numGpus);
+ private static native long _init(String localSchedulerSocket, byte[] workerId,
+ byte[] actorId, boolean isWorker, byte[] driverTaskId,
+ long numGpus, boolean useRaylet);
private static native byte[] _computePutId(long client, byte[] taskId, int putIndex);
private static native void _task_done(long client);
+ private static native boolean[] _waitObject(long conn, byte[][] objectIds,
+ int numReturns, int timeout, boolean waitLocal);
+
+  /**
+   * Waits on the given objects through the native client and returns the ready ones.
+   *
+   * @param objectIds ids of the objects to wait for.
+   * @param timeoutMs timeout forwarded to the native call.
+   * @param numReturns number of returns forwarded to the native call.
+   * @return the ids whose entry in the native result is true, in input order. The list
+   *     can be shorter than {@code objectIds}.
+   */
+  @Override
+  public List wait(byte[][] objectIds, int timeoutMs, int numReturns) {
+    assert useRaylet;
+
+    boolean[] readys = _waitObject(client, objectIds, numReturns, timeoutMs, false);
+
+    // Keep only the ids flagged as ready. Not every object has to be ready when the
+    // native call returns, so the result size is deliberately not asserted against
+    // readys.length.
+    List ret = new ArrayList<>();
+    for (int i = 0; i < readys.length; i++) {
+      if (readys[i]) {
+        ret.add(objectIds[i]);
+      }
+    }
+    return ret;
+  }
+
@Override
public void submitTask(TaskSpec task) {
ByteBuffer info = taskSpec2Info(task);
@@ -45,12 +69,13 @@ public class DefaultLocalSchedulerClient implements LocalSchedulerLink {
if (!task.actorId.isNil()) {
a = task.cursorId.getBytes();
}
- _submitTask(client, a, info, info.position(), info.remaining());
+
+ _submitTask(client, a, info, info.position(), info.remaining(), useRaylet);
}
@Override
public TaskSpec getTaskTodo() {
- byte[] bytes = _getTaskTodo(client);
+ byte[] bytes = _getTaskTodo(client, useRaylet);
assert (null != bytes);
ByteBuffer bb = ByteBuffer.wrap(bytes);
return taskInfo2Spec(bb);
@@ -62,8 +87,16 @@ public class DefaultLocalSchedulerClient implements LocalSchedulerLink {
}
+  /**
+   * Forwards a single object id to the native _reconstruct_objects call.
+   *
+   * @param objectId id whose bytes are passed to the native call.
+   * @param fetchOnly forwarded unchanged to the native call.
+   */
@Override
- public void reconstructObject(UniqueID objectId) {
- _reconstruct_object(client, objectId.getBytes());
+  public void reconstructObject(UniqueID objectId, boolean fetchOnly) {
+    // One-element id array; no temporary list is needed for a single object.
+    byte[][] idBytes = new byte[][] {objectId.getBytes()};
+    _reconstruct_objects(client, idBytes, fetchOnly);
}
+
+  /**
+   * Logs the given ids and forwards them to the native _reconstruct_objects call.
+   *
+   * @param objectIds ids whose bytes are passed to the native call.
+   * @param fetchOnly forwarded unchanged to the native call.
+   */
+  @Override
+  public void reconstructObjects(List objectIds, boolean fetchOnly) {
+    RayLog.core.info("reconstruct objects {}", objectIds);
+    byte[][] idBytes = getIdBytes(objectIds);
+    _reconstruct_objects(client, idBytes, fetchOnly);
+  }
@Override
@@ -73,12 +106,13 @@ public class DefaultLocalSchedulerClient implements LocalSchedulerLink {
private static native void _notify_unblocked(long client);
- private static native void _reconstruct_object(long client, byte[] objectId);
+ private static native void _reconstruct_objects(long client, byte[][] objectIds,
+ boolean fetchOnly);
private static native void _put_object(long client, byte[] taskId, byte[] objectId);
// return TaskInfo (in FlatBuffer)
- private static native byte[] _getTaskTodo(long client);
+ private static native byte[] _getTaskTodo(long client, boolean useRaylet);
public static TaskSpec taskInfo2Spec(ByteBuffer bb) {
bb.order(ByteOrder.LITTLE_ENDIAN);
@@ -162,7 +196,10 @@ public class DefaultLocalSchedulerClient implements LocalSchedulerLink {
idOffsets[k] = fbb.createString(task.args[i].ids.get(k).toByteBuffer());
}
objectIdOffset = fbb.createVectorOfTables(idOffsets);
+ } else {
+ objectIdOffset = fbb.createVectorOfTables(new int[0]);
}
+
if (task.args[i].data != null) {
dataOffset = fbb.createString(ByteBuffer.wrap(task.args[i].data));
}
@@ -214,8 +251,17 @@ public class DefaultLocalSchedulerClient implements LocalSchedulerLink {
}
// task -> TaskInfo (with FlatBuffer)
- private static native void _submitTask(long client, byte[] cursorId, /*Direct*/ByteBuffer task,
- int pos, int sz);
+ protected static native void _submitTask(long client, byte[] cursorId, /*Direct*/ByteBuffer task,
+ int pos, int sz, boolean useRaylet);
+
+  /**
+   * Converts a list of ids into an array holding each id's bytes, in list order.
+   *
+   * @param objectIds ids to convert.
+   * @return one byte array per id.
+   */
+  private static byte[][] getIdBytes(List objectIds) {
+    byte[][] idBytes = new byte[objectIds.size()][];
+    int next = 0;
+    while (next < idBytes.length) {
+      idBytes[next] = objectIds.get(next).getBytes();
+      next++;
+    }
+    return idBytes;
+  }
public void destroy() {
_destroy(client);
diff --git a/java/runtime-native/src/main/java/org/ray/spi/impl/StateStoreProxyImpl.java b/java/runtime-native/src/main/java/org/ray/spi/impl/NonRayletStateStoreProxyImpl.java
similarity index 52%
rename from java/runtime-native/src/main/java/org/ray/spi/impl/StateStoreProxyImpl.java
rename to java/runtime-native/src/main/java/org/ray/spi/impl/NonRayletStateStoreProxyImpl.java
index d6fd6a211..f00267df1 100644
--- a/java/runtime-native/src/main/java/org/ray/spi/impl/StateStoreProxyImpl.java
+++ b/java/runtime-native/src/main/java/org/ray/spi/impl/NonRayletStateStoreProxyImpl.java
@@ -1,96 +1,18 @@
package org.ray.spi.impl;
-import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
-import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
-import java.util.concurrent.TimeUnit;
import org.ray.spi.KeyValueStoreLink;
-import org.ray.spi.StateStoreProxy;
import org.ray.spi.model.AddressInfo;
-import org.ray.util.logger.RayLog;
/**
- * A class used to interface with the Ray control state.
+ * A class used to interface with the Ray control state for non-raylet.
*/
-public class StateStoreProxyImpl implements StateStoreProxy {
-
- public KeyValueStoreLink rayKvStore;
- public ArrayList shardStoreList = new ArrayList<>();
-
- public StateStoreProxyImpl(KeyValueStoreLink rayKvStore) {
- this.rayKvStore = rayKvStore;
- }
-
- public void setStore(KeyValueStoreLink rayKvStore) {
- this.rayKvStore = rayKvStore;
- }
-
- public synchronized void initializeGlobalState() throws Exception {
-
- String es;
-
- checkConnected();
-
- String s = rayKvStore.get("NumRedisShards", null);
- if (s == null) {
- throw new Exception("NumRedisShards not found in redis.");
- }
- int numRedisShards = Integer.parseInt(s);
- if (numRedisShards < 1) {
- es = String.format("Expected at least one Redis shard, found %d", numRedisShards);
- throw new Exception(es);
- }
- List ipAddressPorts = rayKvStore.lrange("RedisShards", 0, -1);
- if (ipAddressPorts.size() != numRedisShards) {
- es = String.format("Expected %d Redis shard addresses, found %d.", numRedisShards,
- ipAddressPorts.size());
- throw new Exception(es);
- }
-
- shardStoreList.clear();
- for (String ipPort : ipAddressPorts) {
- shardStoreList.add(new RedisClient(ipPort));
- }
-
- }
-
- public void checkConnected() throws Exception {
- rayKvStore.checkConnected();
- }
-
- public synchronized Set keys(final String pattern) {
- Set allKeys = new HashSet<>();
- Set tmpKey;
- for (KeyValueStoreLink ashardStoreList : shardStoreList) {
- tmpKey = ashardStoreList.keys(pattern);
- allKeys.addAll(tmpKey);
- }
-
- return allKeys;
-
- }
-
- public List getAddressInfo(final String nodeIpAddress, int numRetries) {
- int count = 0;
- while (count < numRetries) {
- try {
- return getAddressInfoHelper(nodeIpAddress);
- } catch (Exception e) {
- try {
- RayLog.core.warn("Error occurred in StateStoreProxyImpl getAddressInfo, "
- + (numRetries - count) + " retries remaining", e);
- TimeUnit.MILLISECONDS.sleep(1000);
- } catch (InterruptedException ie) {
- RayLog.core.error("error at StateStoreProxyImpl getAddressInfo", e);
- throw new RuntimeException(e);
- }
- }
- count++;
- }
- throw new RuntimeException("cannot get address info from state store");
+public class NonRayletStateStoreProxyImpl extends BaseStateStoreProxyImpl {
+ public NonRayletStateStoreProxyImpl(KeyValueStoreLink rayKvStore) {
+ super(rayKvStore);
}
/*
@@ -108,9 +30,11 @@ public class StateStoreProxyImpl implements StateStoreProxy {
* "manager_socket_name"(op)
* "local_scheduler_socket_name"(op)
*/
- public List getAddressInfoHelper(final String nodeIpAddress) throws Exception {
+ @Override
+ public List doGetAddressInfo(final String nodeIpAddress,
+ final String redisAddress) throws Exception {
if (this.rayKvStore == null) {
- throw new Exception("no redis client when use getAddressInfoHelper");
+ throw new Exception("no redis client when use doGetAddressInfo");
}
List schedulerInfo = new ArrayList<>();
@@ -136,13 +60,13 @@ public class StateStoreProxyImpl implements StateStoreProxy {
} else if (!info.containsKey("client_type".getBytes())) {
throw new Exception("no client_type in any client");
}
-
+
if (charsetDecode(info.get("node_ip_address".getBytes()), "US-ASCII")
.equals(nodeIpAddress)) {
String clientType = charsetDecode(info.get("client_type".getBytes()), "US-ASCII");
- if (clientType.equals("plasma_manager")) {
+ if ("plasma_manager".equals(clientType)) {
plasmaManager.add(info);
- } else if (clientType.equals("local_scheduler")) {
+ } else if ("local_scheduler".equals(clientType)) {
localScheduler.add(info);
}
}
@@ -157,9 +81,9 @@ public class StateStoreProxyImpl implements StateStoreProxy {
for (int i = 0; i < plasmaManager.size(); i++) {
AddressInfo si = new AddressInfo();
si.storeName = charsetDecode(plasmaManager.get(i).get("store_socket_name".getBytes()),
- "US-ASCII");
+ "US-ASCII");
si.managerName = charsetDecode(plasmaManager.get(i).get("manager_socket_name".getBytes()),
- "US-ASCII");
+ "US-ASCII");
byte[] rpc = plasmaManager.get(i).get("manager_rpc_name".getBytes());
if (rpc != null) {
@@ -188,14 +112,4 @@ public class StateStoreProxyImpl implements StateStoreProxy {
return schedulerInfo;
}
- private String charsetDecode(byte[] bs, String charset) throws UnsupportedEncodingException {
- return new String(bs, charset);
- }
-
- private byte[] charsetEncode(String str, String charset) throws UnsupportedEncodingException {
- if (str != null) {
- return str.getBytes(charset);
- }
- return null;
- }
}
diff --git a/java/runtime-native/src/main/java/org/ray/spi/impl/RayletStateStoreProxyImpl.java b/java/runtime-native/src/main/java/org/ray/spi/impl/RayletStateStoreProxyImpl.java
new file mode 100644
index 000000000..0cfa4f532
--- /dev/null
+++ b/java/runtime-native/src/main/java/org/ray/spi/impl/RayletStateStoreProxyImpl.java
@@ -0,0 +1,87 @@
+package org.ray.spi.impl;
+
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Objects;
+import java.util.Set;
+import org.ray.api.UniqueID;
+import org.ray.format.gcs.ClientTableData;
+import org.ray.spi.KeyValueStoreLink;
+import org.ray.spi.model.AddressInfo;
+import org.ray.util.NetworkUtil;
+
+/**
+ * State-store proxy for the raylet code path. It reads the serialized
+ * {@code ClientTableData} entries from the key-value store and turns the ones that belong to
+ * a given node into {@link AddressInfo} records.
+ */
+public class RayletStateStoreProxyImpl extends BaseStateStoreProxyImpl {
+
+  public RayletStateStoreProxyImpl(KeyValueStoreLink rayKvStore) {
+    super(rayKvStore);
+  }
+
+  /**
+   * Collects the address info of every registered client that matches the given node.
+   *
+   * <p>A client matches when its node manager address equals {@code nodeIpAddress}, or when
+   * that address is {@code 127.0.0.1} and the IP part of {@code redisAddress} equals the local
+   * IP reported by {@code NetworkUtil.getIpAddress(null)}.
+   *
+   * @param nodeIpAddress IP address of the node whose clients are wanted.
+   * @param redisAddress redis address in {@code ip:port} form.
+   * @return one {@link AddressInfo} per matching client, empty if nothing matches.
+   * @throws Exception if no key-value store is attached, or if clients exist and
+   *     {@code redisAddress} contains no {@code ':'}.
+   */
+  @Override
+  public List doGetAddressInfo(final String nodeIpAddress,
+      final String redisAddress) throws Exception {
+    if (this.rayKvStore == null) {
+      throw new Exception("no redis client when use doGetAddressInfo");
+    }
+    List<AddressInfo> schedulerInfo = new ArrayList<>();
+
+    // Key of the client entries: the bytes of "CLIENT" followed by the nil unique ID.
+    byte[] prefix = "CLIENT".getBytes();
+    byte[] postfix = UniqueID.genNil().getBytes();
+    byte[] clientKey = new byte[prefix.length + postfix.length];
+    System.arraycopy(prefix, 0, clientKey, 0, prefix.length);
+    System.arraycopy(postfix, 0, clientKey, prefix.length, postfix.length);
+
+    // The element type is spelled out so the for-each below can bind byte[] elements.
+    Set<byte[]> clients = rayKvStore.zrange(clientKey, 0, -1);
+    if (clients.isEmpty()) {
+      return schedulerInfo;
+    }
+
+    // These values do not depend on the client, so resolve them once instead of per entry.
+    final String localIpAddress = NetworkUtil.getIpAddress(null);
+    final int portSeparator = redisAddress.indexOf(':');
+    if (portSeparator < 0) {
+      throw new Exception("invalid redis address, expected ip:port: " + redisAddress);
+    }
+    final String redisIpAddress = redisAddress.substring(0, portSeparator);
+    final boolean redisIsLocal = Objects.equals(redisIpAddress, localIpAddress);
+
+    for (byte[] clientMessage : clients) {
+      ClientTableData client =
+          ClientTableData.getRootAsClientTableData(ByteBuffer.wrap(clientMessage));
+      String clientNodeIpAddress = client.nodeManagerAddress();
+
+      boolean headNodeAddress = "127.0.0.1".equals(clientNodeIpAddress) && redisIsLocal;
+      boolean notHeadNodeAddress = Objects.equals(clientNodeIpAddress, nodeIpAddress);
+
+      if (headNodeAddress || notHeadNodeAddress) {
+        AddressInfo si = new AddressInfo();
+        si.storeName = client.objectStoreSocketName();
+        si.rayletSocketName = client.rayletSocketName();
+        si.managerRpcAddr = clientNodeIpAddress;
+        si.managerPort = client.nodeManagerPort();
+        schedulerInfo.add(si);
+      }
+    }
+    return schedulerInfo;
+  }
+}
diff --git a/java/runtime-native/src/main/java/org/ray/spi/impl/RedisClient.java b/java/runtime-native/src/main/java/org/ray/spi/impl/RedisClient.java
index 97f744f7d..5dec52e57 100644
--- a/java/runtime-native/src/main/java/org/ray/spi/impl/RedisClient.java
+++ b/java/runtime-native/src/main/java/org/ray/spi/impl/RedisClient.java
@@ -13,6 +13,7 @@ public class RedisClient implements KeyValueStoreLink {
private String redisAddress;
private JedisPool jedisPool;
+ private int handle = 0;
public RedisClient() {
}
@@ -171,6 +172,13 @@ public class RedisClient implements KeyValueStoreLink {
}
}
+  @Override
+  public Set zrange(byte[] key, long start, long end) {
+    try (Jedis pooled = jedisPool.getResource()) {  // closed when the try block exits
+      return pooled.zrange(key, start, end);  // forwards the range query unchanged
+    }
+  }
+
@Override
public Long rpush(String key, String... strings) {
try (Jedis jedis = jedisPool.getResource()) {
@@ -203,4 +211,20 @@ public class RedisClient implements KeyValueStoreLink {
public Object getImpl() {
return jedisPool;
}
+
+  @Override
+  public byte[] sendCommand(String command, int commandType, byte[] objectId) {
+    if (handle == 0) {  // first call: open the native connection and cache its handle
+      final String[] hostAndPort = redisAddress.split(":");
+      handle = connect(hostAndPort[0], Integer.parseInt(hostAndPort[1]));
+    }
+    return execute_command(handle, command, commandType, objectId);
+  }
+
+ private static native int connect(String redisAddress, int port);  // result kept in handle
+
+ private static native void disconnect(int handle);  // NOTE(review): no caller in this hunk
+
+ private static native byte[] execute_command(int handle,
+ String command, int commandType, byte[] objectId);  // invoked by sendCommand
}
diff --git a/java/test.sh b/java/test.sh
index d277473c2..3c7da10bf 100755
--- a/java/test.sh
+++ b/java/test.sh
@@ -10,7 +10,17 @@ mvn clean install -Dmaven.test.skip
check_style=$(mvn checkstyle:check)
echo "${check_style}"
[[ ${check_style} =~ "BUILD FAILURE" ]] && exit 1
+
+# test non-raylet: set use_raylet = false in ray.config.ini before this run
+sed -i 's/^use_raylet.*$/use_raylet = false/g' "$ROOT_DIR/../java/ray.config.ini"
mvn_test=$(mvn test)
echo "${mvn_test}"
[[ ${mvn_test} =~ "BUILD SUCCESS" ]] || exit 1
-popd
\ No newline at end of file
+
+# test raylet: set use_raylet = true in ray.config.ini and run the suite again
+sed -i 's/^use_raylet.*$/use_raylet = true/g' "$ROOT_DIR/../java/ray.config.ini"
+mvn_test=$(mvn test)
+echo "${mvn_test}"
+[[ ${mvn_test} =~ "BUILD SUCCESS" ]] || exit 1
+
+popd
diff --git a/src/local_scheduler/lib/java/org_ray_spi_impl_DefaultLocalSchedulerClient.cc b/src/local_scheduler/lib/java/org_ray_spi_impl_DefaultLocalSchedulerClient.cc
index b43ca2801..7eef8596d 100644
--- a/src/local_scheduler/lib/java/org_ray_spi_impl_DefaultLocalSchedulerClient.cc
+++ b/src/local_scheduler/lib/java/org_ray_spi_impl_DefaultLocalSchedulerClient.cc
@@ -43,15 +43,15 @@ Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1init(JNIEnv *env,
jbyteArray actorId,
jboolean isWorker,
jbyteArray driverId,
- jlong numGpus) {
+ jlong numGpus,
+ jboolean useRaylet) {
// native private static long _init(String localSchedulerSocket,
// byte[] workerId, byte[] actorId, boolean isWorker, long numGpus);
UniqueIdFromJByteArray worker_id(env, wid);
UniqueIdFromJByteArray driver_id(env, driverId);
const char *nativeString = env->GetStringUTFChars(sockName, JNI_FALSE);
- bool use_raylet = false;
auto client = LocalSchedulerConnection_init(
- nativeString, *worker_id.PID, isWorker, *driver_id.PID, use_raylet);
+ nativeString, *worker_id.PID, isWorker, *driver_id.PID, useRaylet);
env->ReleaseStringUTFChars(sockName, nativeString);
return reinterpret_cast(client);
}
@@ -69,21 +69,30 @@ Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1submitTask(
jbyteArray cursorId,
jobject buff,
jint pos,
- jint sz) {
+ jint sz,
+ jboolean useRaylet) {
// task -> TaskInfo (with FlatBuffer)
// native private static void _submitTask(long client, /*Direct*/ByteBuffer
// task);
auto client = reinterpret_cast(c);
- TaskSpec *task =
- reinterpret_cast(env->GetDirectBufferAddress(buff)) + pos;
+
std::vector execution_dependencies;
if (cursorId != nullptr) {
UniqueIdFromJByteArray cursor_id(env, cursorId);
execution_dependencies.push_back(*cursor_id.PID);
}
- TaskExecutionSpec taskExecutionSpec =
- TaskExecutionSpec(execution_dependencies, task, sz);
- local_scheduler_submit(client, taskExecutionSpec);
+ if (!useRaylet) {
+ TaskSpec *task =
+ reinterpret_cast(env->GetDirectBufferAddress(buff)) + pos;
+ TaskExecutionSpec taskExecutionSpec =
+ TaskExecutionSpec(execution_dependencies, task, sz);
+ local_scheduler_submit(client, taskExecutionSpec);
+ } else {
+ auto data =
+ reinterpret_cast(env->GetDirectBufferAddress(buff)) + pos;
+ ray::raylet::TaskSpecification task_spec(std::string(data, sz));
+ local_scheduler_submit_raylet(client, execution_dependencies, task_spec);
+ }
}
/*
@@ -92,15 +101,19 @@ Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1submitTask(
- * Signature: (J)[B
+ * Signature: (JZ)[B
*/
JNIEXPORT jbyteArray JNICALL
-Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1getTaskTodo(JNIEnv *env,
- jclass,
- jlong c) {
+Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1getTaskTodo(
+ JNIEnv *env,
+ jclass,
+ jlong c,
+ jboolean useRaylet) {
// native private static ByteBuffer _getTaskTodo(long client);
auto client = reinterpret_cast(c);
int64_t task_size = 0;
// TODO: handle actor failure later
- TaskSpec *spec = local_scheduler_get_task(client, &task_size);
+ TaskSpec *spec = !useRaylet
+ ? local_scheduler_get_task(client, &task_size)
+ : local_scheduler_get_task_raylet(client, &task_size);
jbyteArray result;
result = env->NewByteArray(task_size);
@@ -178,20 +191,37 @@ Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1task_1done(JNIEnv *,
/*
* Class: org_ray_spi_impl_DefaultLocalSchedulerClient
- * Method: _reconstruct_object
+ * Method: _reconstruct_objects
- * Signature: (J[B)V
+ * Signature: (J[[BZ)V
*/
JNIEXPORT void JNICALL
-Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1reconstruct_1object(
+Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1reconstruct_1objects(
JNIEnv *env,
jclass,
jlong c,
- jbyteArray oid) {
- // native private static void _reconstruct_object(long client, byte[]
- // objectId);
- UniqueIdFromJByteArray o(env, oid);
+    jobjectArray oids,
+    jboolean fetch_only) {
+  // native private static void _reconstruct_objects(long client, byte[][]
+  // objectIds, boolean fetchOnly);
+
+  // Copy each Java byte[] id into the native id list.
+  std::vector object_ids;
+  const jsize len = env->GetArrayLength(oids);
+  for (jsize i = 0; i < len; i++) {
+    jbyteArray oid =
+        static_cast<jbyteArray>(env->GetObjectArrayElement(oids, i));
+    {
+      // Inner scope: `o` is destroyed before `oid` is deleted below, in case
+      // its destructor still uses the array reference.
+      UniqueIdFromJByteArray o(env, oid);
+      object_ids.push_back(*o.PID);
+    }
+    // Drop the per-element local reference right away rather than keeping
+    // one per id alive until the native call returns.
+    env->DeleteLocalRef(oid);
+  }
auto client = reinterpret_cast(c);
- local_scheduler_reconstruct_objects(client, {*o.PID});
+  local_scheduler_reconstruct_objects(client, object_ids, fetch_only);
}
/*
@@ -227,6 +249,79 @@ Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1put_1object(
local_scheduler_put_object(client, *t.PID, *o.PID);
}
+/*
+ * Class:     org_ray_spi_impl_DefaultLocalSchedulerClient
+ * Method:    _waitObject
+ * Signature: (J[[BIIZ)[Z
+ *
+ * Calls local_scheduler_wait on the given object ids and returns a boolean
+ * array with one slot per entry of `oids`: true for ids in the first list of
+ * the result, false for ids in the second list (presumably the ready and
+ * not-ready sets -- confirm against local_scheduler_wait).
+ */
+JNIEXPORT jbooleanArray JNICALL
+Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1waitObject(
+    JNIEnv *env,
+    jclass,
+    jlong c,
+    jobjectArray oids,
+    jint num_returns,
+    jint timeout_ms,
+    jboolean wait_local) {
+  // Copy each Java byte[] id into the native id list.
+  std::vector object_ids;
+  const jsize len = env->GetArrayLength(oids);
+  for (jsize i = 0; i < len; i++) {
+    jbyteArray oid =
+        static_cast<jbyteArray>(env->GetObjectArrayElement(oids, i));
+    {
+      // Inner scope: `o` is destroyed before `oid` is deleted below, in case
+      // its destructor still uses the array reference.
+      UniqueIdFromJByteArray o(env, oid);
+      object_ids.push_back(*o.PID);
+    }
+    env->DeleteLocalRef(oid);
+  }
+
+  auto client = reinterpret_cast(c);
+
+  // Invoke wait.
+  auto result = local_scheduler_wait(client, object_ids, num_returns,
+                                     timeout_ms, wait_local != JNI_FALSE);
+
+  // Convert result to java object.
+  jbooleanArray resultArray =
+      env->NewBooleanArray(static_cast<jsize>(object_ids.size()));
+  if (resultArray == nullptr) {
+    // NewBooleanArray returns NULL when the array cannot be constructed.
+    return nullptr;
+  }
+
+  // Slots of ids found in the first result list become true.
+  jboolean putValue = true;
+  for (size_t i = 0; i < result.first.size(); ++i) {
+    for (size_t j = 0; j < object_ids.size(); ++j) {
+      if (result.first[i] == object_ids[j]) {
+        env->SetBooleanArrayRegion(resultArray, static_cast<jsize>(j), 1,
+                                   &putValue);
+        break;
+      }
+    }
+  }
+
+  // Slots of ids found in the second result list become false.
+  putValue = false;
+  for (size_t i = 0; i < result.second.size(); ++i) {
+    for (size_t j = 0; j < object_ids.size(); ++j) {
+      if (result.second[i] == object_ids[j]) {
+        env->SetBooleanArrayRegion(resultArray, static_cast<jsize>(j), 1,
+                                   &putValue);
+        break;
+      }
+    }
+  }
+  return resultArray;
+}
+
#ifdef __cplusplus
}
#endif
diff --git a/src/local_scheduler/lib/java/org_ray_spi_impl_DefaultLocalSchedulerClient.h b/src/local_scheduler/lib/java/org_ray_spi_impl_DefaultLocalSchedulerClient.h
index cb3822ca1..edd6574b6 100644
--- a/src/local_scheduler/lib/java/org_ray_spi_impl_DefaultLocalSchedulerClient.h
+++ b/src/local_scheduler/lib/java/org_ray_spi_impl_DefaultLocalSchedulerClient.h
@@ -20,7 +20,8 @@ Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1init(JNIEnv *,
jbyteArray,
jboolean,
jbyteArray,
- jlong);
+ jlong,
+ jboolean);
/*
* Class: org_ray_spi_impl_DefaultLocalSchedulerClient
@@ -34,7 +35,8 @@ Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1submitTask(JNIEnv *,
jbyteArray,
jobject,
jint,
- jint);
+ jint,
+ jboolean);
/*
* Class: org_ray_spi_impl_DefaultLocalSchedulerClient
@@ -44,7 +46,8 @@ Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1submitTask(JNIEnv *,
JNIEXPORT jbyteArray JNICALL
Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1getTaskTodo(JNIEnv *,
jclass,
- jlong);
+ jlong,
+ jboolean);
/*
* Class: org_ray_spi_impl_DefaultLocalSchedulerClient
@@ -80,15 +83,16 @@ Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1task_1done(JNIEnv *,
/*
* Class: org_ray_spi_impl_DefaultLocalSchedulerClient
- * Method: _reconstruct_object
+ * Method: _reconstruct_objects
- * Signature: (J[B)V
+ * Signature: (J[[BZ)V
*/
JNIEXPORT void JNICALL
-Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1reconstruct_1object(
+Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1reconstruct_1objects(
JNIEnv *,
jclass,
jlong,
- jbyteArray);
+ jobjectArray,
+ jboolean);
/*
* Class: org_ray_spi_impl_DefaultLocalSchedulerClient
@@ -112,6 +116,20 @@ Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1put_1object(JNIEnv *,
jbyteArray,
jbyteArray);
+/*
+ * Class: org_ray_spi_impl_DefaultLocalSchedulerClient
+ * Method: _waitObject
+ * Signature: (J[[BIIZ)[Z   i.e. (long, byte[][], int, int, boolean) -> boolean[]
+ */
+JNIEXPORT jbooleanArray JNICALL
+Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1waitObject(JNIEnv *,
+ jclass,
+ jlong, /* c: value cast to the native client pointer */
+ jobjectArray, /* oids: array of byte[] object ids */
+ jint, /* num_returns */
+ jint, /* timeout_ms */
+ jboolean); /* wait_local */
+
#ifdef __cplusplus
}
#endif
diff --git a/src/local_scheduler/local_scheduler_client.cc b/src/local_scheduler/local_scheduler_client.cc
index 9923e0d2c..775c3518e 100644
--- a/src/local_scheduler/local_scheduler_client.cc
+++ b/src/local_scheduler/local_scheduler_client.cc
@@ -69,7 +69,7 @@ void local_scheduler_log_event(LocalSchedulerConnection *conn,
}
void local_scheduler_submit(LocalSchedulerConnection *conn,
- TaskExecutionSpec &execution_spec) {
+ const TaskExecutionSpec &execution_spec) {
flatbuffers::FlatBufferBuilder fbb;
auto execution_dependencies =
to_flatbuf(fbb, execution_spec.ExecutionDependencies());
@@ -86,7 +86,7 @@ void local_scheduler_submit(LocalSchedulerConnection *conn,
void local_scheduler_submit_raylet(
LocalSchedulerConnection *conn,
const std::vector &execution_dependencies,
- ray::raylet::TaskSpecification task_spec) {
+ const ray::raylet::TaskSpecification &task_spec) {
flatbuffers::FlatBufferBuilder fbb;
auto execution_dependencies_message = to_flatbuf(fbb, execution_dependencies);
auto message = ray::local_scheduler::protocol::CreateSubmitTaskRequest(
diff --git a/src/local_scheduler/local_scheduler_client.h b/src/local_scheduler/local_scheduler_client.h
index e7eebdcdd..5026313d2 100644
--- a/src/local_scheduler/local_scheduler_client.h
+++ b/src/local_scheduler/local_scheduler_client.h
@@ -65,7 +65,7 @@ void LocalSchedulerConnection_free(LocalSchedulerConnection *conn);
* @return Void.
*/
void local_scheduler_submit(LocalSchedulerConnection *conn,
- TaskExecutionSpec &execution_spec);
+ const TaskExecutionSpec &execution_spec);
/// Submit a task using the raylet code path.
///
@@ -76,7 +76,7 @@ void local_scheduler_submit(LocalSchedulerConnection *conn,
void local_scheduler_submit_raylet(
LocalSchedulerConnection *conn,
const std::vector &execution_dependencies,
- ray::raylet::TaskSpecification task_spec);
+ const ray::raylet::TaskSpecification &task_spec);
/**
* Notify the local scheduler that this client is disconnecting gracefully. This
diff --git a/src/ray/raylet/main.cc b/src/ray/raylet/main.cc
index d3cdef460..46a8e139a 100644
--- a/src/ray/raylet/main.cc
+++ b/src/ray/raylet/main.cc
@@ -39,11 +39,10 @@ int main(int argc, char *argv[]) {
RayConfig::instance().num_workers_per_process();
// Use a default worker that can execute empty tasks with dependencies.
- std::stringstream worker_command_stream(worker_command);
- std::string token;
- while (getline(worker_command_stream, token, ' ')) {
- node_manager_config.worker_command.push_back(token);
- }
+ std::istringstream iss(worker_command);
+ std::vector results(std::istream_iterator{iss},
+ std::istream_iterator());
+ node_manager_config.worker_command.swap(results);
node_manager_config.heartbeat_period_ms =
RayConfig::instance().heartbeat_timeout_milliseconds();
@@ -84,8 +83,12 @@ int main(int argc, char *argv[]) {
// Destroy the Raylet on a SIGTERM. The pointer to main_service is
// guaranteed to be valid since this function will run the event loop
// instead of returning immediately.
- auto handler = [&main_service](const boost::system::error_code &error,
- int signal_number) { main_service.stop(); };
+ // We should stop the service and remove the local socket file.
+ auto handler = [&main_service, &raylet_socket_name](
+ const boost::system::error_code &error, int signal_number) {
+ main_service.stop();
+ remove(raylet_socket_name.c_str());
+ };
boost::asio::signal_set signals(main_service, SIGTERM);
signals.async_wait(handler);
diff --git a/src/ray/raylet/task_spec.cc b/src/ray/raylet/task_spec.cc
index 8488da3c4..40ba37f8f 100644
--- a/src/ray/raylet/task_spec.cc
+++ b/src/ray/raylet/task_spec.cc
@@ -37,6 +37,10 @@ TaskSpecification::TaskSpecification(const flatbuffers::String &string) {
AssignSpecification(reinterpret_cast(string.data()), string.size());
}
+TaskSpecification::TaskSpecification(const std::string &string) {
+ AssignSpecification(reinterpret_cast(string.data()), string.size());  // raw bytes + length
+}
+
TaskSpecification::TaskSpecification(
const UniqueID &driver_id, const TaskID &parent_task_id, int64_t parent_counter,
const FunctionID &function_id,
diff --git a/src/ray/raylet/task_spec.h b/src/ray/raylet/task_spec.h
index d9e51dc96..7214b4aa0 100644
--- a/src/ray/raylet/task_spec.h
+++ b/src/ray/raylet/task_spec.h
@@ -109,6 +109,12 @@ class TaskSpecification {
int64_t num_returns,
const std::unordered_map &required_resources);
+ /// Deserialize a task specification from the bytes held in a std::string.
+ ///
+ /// \param string The bytes of a serialized task specification flatbuffer,
+ /// stored in a std::string.
+ TaskSpecification(const std::string &string);
+
~TaskSpecification() {}
/// Serialize the TaskSpecification to a flatbuffer.
diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc
index 0e87d0eb3..24ee3bd76 100644
--- a/src/ray/raylet/worker_pool.cc
+++ b/src/ray/raylet/worker_pool.cc
@@ -121,6 +121,7 @@ void WorkerPool::RegisterWorker(std::shared_ptr worker) {
auto pid = worker->Pid();
RAY_LOG(DEBUG) << "Registering worker with pid " << pid;
registered_workers_.push_back(std::move(worker));
+
auto it = starting_worker_processes_.find(pid);
RAY_CHECK(it != starting_worker_processes_.end());
it->second--;