diff --git a/.gitignore b/.gitignore index 9da7c3c1c..cd28233ac 100644 --- a/.gitignore +++ b/.gitignore @@ -150,6 +150,12 @@ build .vscode/ *.iml + +# Java java/**/target java/run java/**/lib +java/**/.settings +java/**/.classpath +java/**/.project + diff --git a/java/api/src/main/java/org/ray/api/Ray.java b/java/api/src/main/java/org/ray/api/Ray.java index ecf227265..60726a67b 100644 --- a/java/api/src/main/java/org/ray/api/Ray.java +++ b/java/api/src/main/java/org/ray/api/Ray.java @@ -85,8 +85,8 @@ public final class Ray extends Rpc { if (cls.getConstructor() == null) { System.err.println("class " + cls.getName() + " does not (actors must) have a constructor with no arguments"); - RayLog.core.error("class " + cls.getName() - + " does not (actors must) have a constructor with no arguments"); + RayLog.core.error("class {} does not (actors must) have a constructor with no arguments", + cls.getName()); } } catch (Exception e) { System.exit(1); diff --git a/java/api/src/main/java/org/ray/api/internal/RayConnector.java b/java/api/src/main/java/org/ray/api/internal/RayConnector.java index 3d54dcc54..3d50dead1 100644 --- a/java/api/src/main/java/org/ray/api/internal/RayConnector.java +++ b/java/api/src/main/java/org/ray/api/internal/RayConnector.java @@ -19,7 +19,7 @@ public class RayConnector { m.setAccessible(false); return api; } catch (ReflectiveOperationException | IllegalArgumentException | SecurityException e) { - RayLog.core.error("Load " + className + " class failed.", e); + RayLog.core.error("Load {} class failed.", className, e); throw new Error("RayApi is not successfully initiated."); } } diff --git a/java/checkstyle-suppressions.xml b/java/checkstyle-suppressions.xml index f84d9e1ca..0420ef87b 100644 --- a/java/checkstyle-suppressions.xml +++ b/java/checkstyle-suppressions.xml @@ -19,4 +19,5 @@ + diff --git a/java/cleanup.sh b/java/cleanup.sh index 3d73e626a..84b2110bd 100755 --- a/java/cleanup.sh +++ b/java/cleanup.sh @@ -5,4 +5,6 @@ pkill -9 plasma_store pkill -9 global_scheduler pkill -9 redis-server pkill -9 redis +pkill -9 raylet ps aux | grep ray | awk '{system("kill "$2);}' +rm /tmp/raylet* diff --git a/java/cli/src/main/java/org/ray/cli/RayCli.java b/java/cli/src/main/java/org/ray/cli/RayCli.java index 671da14e3..c99541f35 100644 --- a/java/cli/src/main/java/org/ray/cli/RayCli.java +++ b/java/cli/src/main/java/org/ray/cli/RayCli.java @@ -21,8 +21,9 @@ import org.ray.spi.PathConfig; import org.ray.spi.RemoteFunctionManager; import org.ray.spi.StateStoreProxy; import org.ray.spi.impl.NativeRemoteFunctionManager; +import org.ray.spi.impl.NonRayletStateStoreProxyImpl; +import org.ray.spi.impl.RayletStateStoreProxyImpl; import org.ray.spi.impl.RedisClient; -import org.ray.spi.impl.StateStoreProxyImpl; import org.ray.util.FileUtil; import org.ray.util.config.ConfigReader; import org.ray.util.logger.RayLog; @@ -47,7 +48,7 @@ public class RayCli { throw new RuntimeException("Ray head node start failed", e); } - RayLog.core.info("Started Ray head node. Redis address: " + manager.info().redisAddress); + RayLog.core.info("Started Ray head node. Redis address: {}", manager.info().redisAddress); return manager; } @@ -74,7 +75,7 @@ public class RayCli { // Init RayLog before using it. RayLog.init(params.working_directory); - RayLog.core.info("Using IP address " + params.node_ip_address + " for this node."); + RayLog.core.info("Using IP address {} for this node.", params.node_ip_address); RunManager manager; if (cmdStart.head) { manager = startRayHead(params, paths, config); @@ -152,7 +153,9 @@ public class RayCli { KeyValueStoreLink kvStore = new RedisClient(); kvStore.setAddr(cmdSubmit.redisAddress); - StateStoreProxy stateStoreProxy = new StateStoreProxyImpl(kvStore); + StateStoreProxy stateStoreProxy = params.use_raylet + ? new RayletStateStoreProxyImpl(kvStore) + : new NonRayletStateStoreProxyImpl(kvStore); stateStoreProxy.initializeGlobalState(); RemoteFunctionManager functionManager = new NativeRemoteFunctionManager(kvStore); diff --git a/java/common/src/main/java/org/ray/util/MethodId.java b/java/common/src/main/java/org/ray/util/MethodId.java index f32b33510..1b517f736 100644 --- a/java/common/src/main/java/org/ray/util/MethodId.java +++ b/java/common/src/main/java/org/ray/util/MethodId.java @@ -117,7 +117,7 @@ public final class MethodId { cls = Class .forName(className, true, loader == null ? this.getClass().getClassLoader() : loader); } catch (Throwable e) { - RayLog.core.error("Cannot load class " + className, e); + RayLog.core.error("Cannot load class {}", className, e); return null; } @@ -148,7 +148,7 @@ public final class MethodId { if (methods.size() != 1) { RayLog.core.error( - "Load method " + toString() + " failed as there are " + methods.size() + " definitions"); + "Load method {} failed as there are {} definitions.", toString(), methods.size()); return null; } diff --git a/java/common/src/main/java/org/ray/util/NetworkUtil.java b/java/common/src/main/java/org/ray/util/NetworkUtil.java index d5e48999c..4eddbd7df 100644 --- a/java/common/src/main/java/org/ray/util/NetworkUtil.java +++ b/java/common/src/main/java/org/ray/util/NetworkUtil.java @@ -35,9 +35,9 @@ public class NetworkUtil { return addr.getHostAddress(); } } - RayLog.core.warn("you may need to correctly specify [ray.java] net_interface in config"); + RayLog.core.warn("You need to correctly specify [ray.java] net_interface in config."); } catch (Exception e) { - RayLog.core.error("Can't get our ip address, use 127.0.0.1 as default.", e); + RayLog.core.error("Can't get ip address, use 127.0.0.1 as default.", e); } return "127.0.0.1"; diff --git a/java/common/src/main/java/org/ray/util/Sha1Digestor.java b/java/common/src/main/java/org/ray/util/Sha1Digestor.java index 2c0c67789..b9d520609 100644 --- a/java/common/src/main/java/org/ray/util/Sha1Digestor.java +++ b/java/common/src/main/java/org/ray/util/Sha1Digestor.java @@ -10,8 +10,8 @@ public class Sha1Digestor { try { return MessageDigest.getInstance("SHA1"); } catch (Exception e) { - RayLog.core.error("cannot get SHA1 MessageDigest", e); - throw new RuntimeException("cannot get SHA1 digest", e); + RayLog.core.error("Cannot get SHA1 MessageDigest", e); + throw new RuntimeException("Cannot get SHA1 digest", e); } }); diff --git a/java/pom.xml b/java/pom.xml index 89584b6ee..73854c96d 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -63,7 +63,7 @@ com.github.davidmoten flatbuffers-java - 1.7.0.1 + 1.9.0.1 diff --git a/java/prepare.sh b/java/prepare.sh index d59b92b09..b699b4a5f 100755 --- a/java/prepare.sh +++ b/java/prepare.sh @@ -47,12 +47,15 @@ declare -a nativeBinaries=( "./src/plasma/plasma_manager" "./src/local_scheduler/local_scheduler" "./src/global_scheduler/global_scheduler" + "./src/ray/raylet/raylet" + "./src/ray/raylet/raylet_monitor" ) declare -a nativeLibraries=( "./src/common/redis_module/libray_redis_module.so" "./src/local_scheduler/liblocal_scheduler_library_java.*" "./src/plasma/libplasma_java.*" + "./src/ray/raylet/*lib.a" ) declare -a javaBinaries=( diff --git a/java/ray.config.ini b/java/ray.config.ini index c9630e75b..a8a6b0981 100644 --- a/java/ray.config.ini +++ b/java/ray.config.ini @@ -62,6 +62,8 @@ deploy = false onebox_delay_seconds_before_run_app_logic = 0 +use_raylet = false + ; java class which main is served as the driver in a java worker driver_class = @@ -123,6 +125,7 @@ store = %CONFIG_FILE_DIR%/../build/src/plasma/plasma_store store_manager = %CONFIG_FILE_DIR%/../build/src/plasma/plasma_manager local_scheduler = %CONFIG_FILE_DIR%/../build/src/local_scheduler/local_scheduler global_scheduler = %CONFIG_FILE_DIR%/../build/src/global_scheduler/global_scheduler +raylet = %CONFIG_FILE_DIR%/../build/src/ray/raylet/raylet python_dir = %CONFIG_FILE_DIR%/../build/ java_runtime_rewritten_jars_dir = java_class_paths = ray.java.path.classes.source @@ -135,6 +138,7 @@ store = %CONFIG_FILE_DIR%/../build/src/plasma/plasma_store store_manager = %CONFIG_FILE_DIR%/../build/src/plasma/plasma_manager local_scheduler = %CONFIG_FILE_DIR%/../build/src/local_scheduler/local_scheduler global_scheduler = %CONFIG_FILE_DIR%/../build/src/global_scheduler/global_scheduler +raylet = %CONFIG_FILE_DIR%/../build/src/ray/raylet/raylet python_dir = %CONFIG_FILE_DIR%/../build/ java_runtime_rewritten_jars_dir = java_class_paths = ray.java.path.classes.package @@ -147,6 +151,7 @@ store = %CONFIG_FILE_DIR%/native/bin/plasma_store store_manager = %CONFIG_FILE_DIR%/native/bin/plasma_manager local_scheduler = %CONFIG_FILE_DIR%/native/bin/local_scheduler global_scheduler = %CONFIG_FILE_DIR%/native/bin/global_scheduler +raylet = %CONFIG_FILE_DIR%/native/bin/raylet python_dir = %CONFIG_FILE_DIR%/python java_runtime_rewritten_jars_dir = %CONFIG_FILE_DIR%/java/lib/ java_class_paths = ray.java.path.classes.deploy diff --git a/java/runtime-common/src/main/java/org/ray/core/RayRuntime.java b/java/runtime-common/src/main/java/org/ray/core/RayRuntime.java index 52c2b94e8..25bacf062 100644 --- a/java/runtime-common/src/main/java/org/ray/core/RayRuntime.java +++ b/java/runtime-common/src/main/java/org/ray/core/RayRuntime.java @@ -1,5 +1,6 @@ package org.ray.core; +import com.google.common.collect.ImmutableList; import java.lang.reflect.Constructor; import java.lang.reflect.InvocationTargetException; import java.util.ArrayList; @@ -122,7 +123,13 @@ public abstract class RayRuntime implements RayApi { functions = new LocalFunctionManager(remoteLoader); localSchedulerProxy = new LocalSchedulerProxy(slink); - objectStoreProxy = new ObjectStoreProxy(plink); + + if (!params.use_raylet) { + objectStoreProxy = new ObjectStoreProxy(plink); + } else { + objectStoreProxy = new ObjectStoreProxy(plink, slink); + } + worker = new Worker(localSchedulerProxy, functions); } @@ -188,7 +195,9 @@ public abstract class RayRuntime implements RayApi { public void putRaw(UniqueID taskId, UniqueID objectId, T obj, TMT metadata) { RayLog.core.info("Task " + taskId.toString() + " Object " + objectId.toString() + " put"); - localSchedulerProxy.markTaskPutDependency(taskId, objectId); + if (!params.use_raylet) { + localSchedulerProxy.markTaskPutDependency(taskId, objectId); + } objectStoreProxy.put(objectId, obj, metadata); } @@ -274,22 +283,32 @@ public abstract class RayRuntime implements RayApi { return worker.rpcWithReturnIndices(taskId, funcCls, lambda, returnCount, args); } + private List doGet(List objectIds, boolean isMetadata) throws TaskExecutionException { boolean wasBlocked = false; UniqueID taskId = getCurrentTaskId(); + try { int numObjectIds = objectIds.size(); // Do an initial fetch for remote objects. - dividedFetch(objectIds); + List> fetchBatches = + splitIntoBatches(objectIds, params.worker_fetch_request_size); + for (List batch : fetchBatches) { + if (!params.use_raylet) { + objectStoreProxy.fetch(batch); + } else { + localSchedulerProxy.reconstructObjects(batch, true); + } + } // Get the objects. We initially try to get the objects immediately. List> ret = objectStoreProxy .get(objectIds, params.default_first_check_timeout_ms, isMetadata); assert ret.size() == numObjectIds; - // mapping the object IDs that we haven't gotten yet to their original index in objectIds + // Mapping the object IDs that we haven't gotten yet to their original index in objectIds. Map unreadys = new HashMap<>(); for (int i = 0; i < numObjectIds; i++) { if (ret.get(i).getRight() != GetStatus.SUCCESS) { @@ -301,15 +320,22 @@ public abstract class RayRuntime implements RayApi { // Try reconstructing any objects we haven't gotten yet. Try to get them // until at least PlasmaLink.GET_TIMEOUT_MS milliseconds passes, then repeat. while (unreadys.size() > 0) { - for (UniqueID id : unreadys.keySet()) { - localSchedulerProxy.reconstructObject(id); - } - - // Do another fetch for objects that aren't available locally yet, in case - // they were evicted since the last fetch. List unreadyList = new ArrayList<>(unreadys.keySet()); + List> reconstructBatches = + splitIntoBatches(unreadyList, params.worker_fetch_request_size); - dividedFetch(unreadyList); + for (List batch : reconstructBatches) { + if (!params.use_raylet) { + for (UniqueID objectId : batch) { + localSchedulerProxy.reconstructObject(objectId, false); + } + // Do another fetch for objects that aren't available locally yet, in case + // they were evicted since the last fetch. + objectStoreProxy.fetch(batch); + } else { + localSchedulerProxy.reconstructObjects(batch, false); + } + } List> results = objectStoreProxy .get(unreadyList, params.default_get_check_interval_ms, isMetadata); @@ -329,9 +355,11 @@ public abstract class RayRuntime implements RayApi { RayLog.core .debug("Task " + taskId + " Objects " + Arrays.toString(objectIds.toArray()) + " get"); List finalRet = new ArrayList<>(); + for (Pair value : ret) { finalRet.add(value.getLeft()); } + return finalRet; } catch (TaskExecutionException e) { RayLog.core.error("Task " + taskId + " Objects " + Arrays.toString(objectIds.toArray()) @@ -344,68 +372,30 @@ public abstract class RayRuntime implements RayApi { localSchedulerProxy.notifyUnblocked(); } } - } private T doGet(UniqueID objectId, boolean isMetadata) throws TaskExecutionException { + ImmutableList objectIds = ImmutableList.of(objectId); + List results = doGet(objectIds, isMetadata); - boolean wasBlocked = false; - UniqueID taskId = getCurrentTaskId(); - try { - // Do an initial fetch. - objectStoreProxy.fetch(objectId); - - // Get the object. We initially try to get the object immediately. - Pair ret = objectStoreProxy - .get(objectId, params.default_first_check_timeout_ms, isMetadata); - - wasBlocked = (ret.getRight() != GetStatus.SUCCESS); - - // Try reconstructing the object. Try to get it until at least PlasmaLink.GET_TIMEOUT_MS - // milliseconds passes, then repeat. - while (ret.getRight() != GetStatus.SUCCESS) { - RayLog.core.warn( - "Task " + taskId + " Object " + objectId.toString() + " get failed, reconstruct ..."); - localSchedulerProxy.reconstructObject(objectId); - - // Do another fetch - objectStoreProxy.fetch(objectId); - - ret = objectStoreProxy.get(objectId, params.default_get_check_interval_ms, - isMetadata);//check the result every 5s, but it will return once available - } - RayLog.core.debug( - "Task " + taskId + " Object " + objectId.toString() + " get" + ", the result " + ret - .getLeft()); - return ret.getLeft(); - } catch (TaskExecutionException e) { - RayLog.core - .error("Task " + taskId + " Object " + objectId.toString() + " get with Exception", e); - throw e; - } finally { - // If the object was not able to get locally, let the local scheduler - // know that we're now unblocked. - if (wasBlocked) { - localSchedulerProxy.notifyUnblocked(); - } - } - + assert results.size() == 1; + return results.get(0); } - // We divide the fetch into smaller fetches so as to not block the manager - // for a prolonged period of time in a single call. - private void dividedFetch(List objectIds) { - int fetchSize = objectStoreProxy.getFetchSize(); + private List> splitIntoBatches(List objectIds, int batchSize) { + List> batches = new ArrayList<>(); + int objectsSize = objectIds.size(); - int numObjectIds = objectIds.size(); - for (int i = 0; i < numObjectIds; i += fetchSize) { - int endIndex = i + fetchSize; - if (endIndex < numObjectIds) { - objectStoreProxy.fetch(objectIds.subList(i, endIndex)); - } else { - objectStoreProxy.fetch(objectIds.subList(i, numObjectIds)); - } + for (int i = 0; i < objectsSize; i += batchSize) { + int endIndex = i + batchSize; + List batchIds = (endIndex < objectsSize) + ? objectIds.subList(i, endIndex) + : objectIds.subList(i, objectsSize); + + batches.add(batchIds); } + + return batches; } /** diff --git a/java/runtime-common/src/main/java/org/ray/core/model/RayParameters.java b/java/runtime-common/src/main/java/org/ray/core/model/RayParameters.java index f2079c64d..a7ca6ac62 100644 --- a/java/runtime-common/src/main/java/org/ray/core/model/RayParameters.java +++ b/java/runtime-common/src/main/java/org/ray/core/model/RayParameters.java @@ -112,6 +112,18 @@ public class RayParameters { @AConfig(comment = "delay seconds under onebox before app logic for debugging") public int onebox_delay_seconds_before_run_app_logic = 0; + @AConfig(comment = "whether to use raylet") + public boolean use_raylet = false; + + @AConfig(comment = "raylet socket name (e.g., /tmp/raylet1111") + public String raylet_socket_name = ""; + + @AConfig(comment = "raylet rpc listen port") + public int raylet_port = 35567; + + @AConfig(comment = "worker fetch request size") + public int worker_fetch_request_size = 10000; + public RayParameters(ConfigReader config) { if (null != config) { String networkInterface = config.getStringValue("ray.java", "network_interface", null, diff --git a/java/runtime-common/src/main/java/org/ray/spi/LocalSchedulerLink.java b/java/runtime-common/src/main/java/org/ray/spi/LocalSchedulerLink.java index 93852ea3c..b04b2508d 100644 --- a/java/runtime-common/src/main/java/org/ray/spi/LocalSchedulerLink.java +++ b/java/runtime-common/src/main/java/org/ray/spi/LocalSchedulerLink.java @@ -1,5 +1,6 @@ package org.ray.spi; +import java.util.List; import org.ray.api.UniqueID; import org.ray.spi.model.TaskSpec; @@ -14,7 +15,11 @@ public interface LocalSchedulerLink { void markTaskPutDependency(UniqueID taskId, UniqueID objectId); - void reconstructObject(UniqueID objectId); + void reconstructObject(UniqueID objectId, boolean fetchOnly); + + void reconstructObjects(List objectIds, boolean fetchOnly); void notifyUnblocked(); + + List wait(byte[][] objectIds, int timeoutMs, int numReturns); } diff --git a/java/runtime-common/src/main/java/org/ray/spi/LocalSchedulerProxy.java b/java/runtime-common/src/main/java/org/ray/spi/LocalSchedulerProxy.java index 8750f4aa0..f999f8e8a 100644 --- a/java/runtime-common/src/main/java/org/ray/spi/LocalSchedulerProxy.java +++ b/java/runtime-common/src/main/java/org/ray/spi/LocalSchedulerProxy.java @@ -1,13 +1,17 @@ package org.ray.spi; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.HashMap; +import java.util.List; import java.util.Map; +import org.ray.api.RayList; import org.ray.api.RayMap; import org.ray.api.RayObject; import org.ray.api.RayObjects; import org.ray.api.UniqueID; +import org.ray.api.WaitResult; import org.ray.core.ArgumentsBuilder; import org.ray.core.UniqueIdHelper; import org.ray.core.WorkerContext; @@ -124,11 +128,44 @@ public class LocalSchedulerProxy { scheduler.markTaskPutDependency(taskId, objectId); } - public void reconstructObject(UniqueID objectId) { - scheduler.reconstructObject(objectId); + public void reconstructObject(UniqueID objectId, boolean fetchOnly) { + scheduler.reconstructObject(objectId, fetchOnly); + } + + public void reconstructObjects(List objectIds, boolean fetchOnly) { + scheduler.reconstructObjects(objectIds, fetchOnly); } public void notifyUnblocked() { scheduler.notifyUnblocked(); } + + private static byte[][] getIdBytes(List objectIds) { + int size = objectIds.size(); + byte[][] ids = new byte[size][]; + for (int i = 0; i < size; i++) { + ids[i] = objectIds.get(i).getBytes(); + } + return ids; + } + + public WaitResult wait(RayList waitfor, int numReturns, int timeout) { + List ids = new ArrayList<>(); + for (RayObject obj : waitfor.Objects()) { + ids.add(obj.getId()); + } + List readys = scheduler.wait(getIdBytes(ids), timeout, numReturns); + + RayList readyObjs = new RayList<>(); + RayList remainObjs = new RayList<>(); + for (RayObject obj : waitfor.Objects()) { + if (readys.contains(obj.getId().getBytes())) { + readyObjs.add(obj); + } else { + remainObjs.add(obj); + } + } + + return new WaitResult<>(readyObjs, remainObjs); + } } diff --git a/java/runtime-common/src/main/java/org/ray/spi/ObjectStoreProxy.java b/java/runtime-common/src/main/java/org/ray/spi/ObjectStoreProxy.java index 3b4b34f8e..193afe537 100644 --- a/java/runtime-common/src/main/java/org/ray/spi/ObjectStoreProxy.java +++ b/java/runtime-common/src/main/java/org/ray/spi/ObjectStoreProxy.java @@ -10,6 +10,7 @@ import org.ray.api.UniqueID; import org.ray.api.WaitResult; import org.ray.core.Serializer; import org.ray.core.WorkerContext; +import org.ray.spi.LocalSchedulerLink; import org.ray.util.exception.TaskExecutionException; /** @@ -19,12 +20,19 @@ import org.ray.util.exception.TaskExecutionException; public class ObjectStoreProxy { private final ObjectStoreLink store; + private final LocalSchedulerLink localSchedulerLink; private final int getTimeoutMs = 1000; public ObjectStoreProxy(ObjectStoreLink store) { this.store = store; + this.localSchedulerLink = null; } + public ObjectStoreProxy(ObjectStoreLink store, LocalSchedulerLink localSchedulerLink) { + this.store = store; + this.localSchedulerLink = localSchedulerLink; + } + public Pair get(UniqueID objectId, boolean isMetadata) throws TaskExecutionException { return get(objectId, getTimeoutMs, isMetadata); @@ -88,7 +96,12 @@ public class ObjectStoreProxy { for (RayObject obj : waitfor.Objects()) { ids.add(obj.getId()); } - List readys = store.wait(getIdBytes(ids), timeout, numReturns); + List readys; + if (localSchedulerLink == null) { + readys = store.wait(getIdBytes(ids), timeout, numReturns); + } else { + readys = localSchedulerLink.wait(getIdBytes(ids), timeout, numReturns); + } RayList readyObjs = new RayList<>(); RayList remainObjs = new RayList<>(); @@ -103,19 +116,14 @@ public class ObjectStoreProxy { return new WaitResult<>(readyObjs, remainObjs); } - public void fetch(UniqueID objectId) { - store.fetch(objectId.getBytes()); - } - public void fetch(List objectIds) { - store.fetch(getIdBytes(objectIds)); + if (localSchedulerLink == null) { + store.fetch(getIdBytes(objectIds)); + } else { + localSchedulerLink.reconstructObjects(objectIds, true); + } } - public int getFetchSize() { - return 10000; - } - - public enum GetStatus { SUCCESS, FAILED } diff --git a/java/runtime-common/src/main/java/org/ray/spi/PathConfig.java b/java/runtime-common/src/main/java/org/ray/spi/PathConfig.java index f33e685b4..e8e0f1228 100644 --- a/java/runtime-common/src/main/java/org/ray/spi/PathConfig.java +++ b/java/runtime-common/src/main/java/org/ray/spi/PathConfig.java @@ -37,6 +37,9 @@ public class PathConfig { @AConfig(comment = "path to global scheduler") public String global_scheduler; + @AConfig(comment = "path to raylet") + public String raylet; + @AConfig(comment = "path to python directory") public String python_dir; diff --git a/java/runtime-common/src/main/java/org/ray/spi/model/AddressInfo.java b/java/runtime-common/src/main/java/org/ray/spi/model/AddressInfo.java index 7276278d0..dcf131d43 100644 --- a/java/runtime-common/src/main/java/org/ray/spi/model/AddressInfo.java +++ b/java/runtime-common/src/main/java/org/ray/spi/model/AddressInfo.java @@ -8,9 +8,11 @@ public class AddressInfo { public String managerName; public String storeName; public String schedulerName; + public String rayletSocketName; public int managerPort; public int workerCount; public String managerRpcAddr; public String storeRpcAddr; public String schedulerRpcAddr; + public String rayletRpcAddr; } diff --git a/java/runtime-dev/pom.xml b/java/runtime-dev/pom.xml index 5bc1bb6d4..891a333e1 100644 --- a/java/runtime-dev/pom.xml +++ b/java/runtime-dev/pom.xml @@ -1,5 +1,4 @@ - diff --git a/java/runtime-dev/src/main/java/org/ray/spi/impl/MockLocalScheduler.java b/java/runtime-dev/src/main/java/org/ray/spi/impl/MockLocalScheduler.java index 89bbee1b1..626f08473 100644 --- a/java/runtime-dev/src/main/java/org/ray/spi/impl/MockLocalScheduler.java +++ b/java/runtime-dev/src/main/java/org/ray/spi/impl/MockLocalScheduler.java @@ -1,5 +1,6 @@ package org.ray.spi.impl; +import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import org.ray.api.UniqueID; @@ -74,7 +75,12 @@ public class MockLocalScheduler implements LocalSchedulerLink { } @Override - public void reconstructObject(UniqueID objectId) { + public void reconstructObject(UniqueID objectId, boolean fetchOnly) { + + } + + @Override + public void reconstructObjects(List objectIds, boolean fetchOnly) { } @@ -82,4 +88,9 @@ public class MockLocalScheduler implements LocalSchedulerLink { public void notifyUnblocked() { } + + @Override + public List wait(byte[][] objectIds, int timeoutMs, int numReturns) { + return store.wait(objectIds, timeoutMs, numReturns); + } } diff --git a/java/runtime-native/src/main/java/org/ray/core/impl/RayNativeRuntime.java b/java/runtime-native/src/main/java/org/ray/core/impl/RayNativeRuntime.java index b1e708b21..bcb8f4c37 100644 --- a/java/runtime-native/src/main/java/org/ray/core/impl/RayNativeRuntime.java +++ b/java/runtime-native/src/main/java/org/ray/core/impl/RayNativeRuntime.java @@ -25,8 +25,9 @@ import org.ray.spi.RemoteFunctionManager; import org.ray.spi.StateStoreProxy; import org.ray.spi.impl.DefaultLocalSchedulerClient; import org.ray.spi.impl.NativeRemoteFunctionManager; +import org.ray.spi.impl.NonRayletStateStoreProxyImpl; +import org.ray.spi.impl.RayletStateStoreProxyImpl; import org.ray.spi.impl.RedisClient; -import org.ray.spi.impl.StateStoreProxyImpl; import org.ray.spi.model.AddressInfo; import org.ray.util.exception.TaskExecutionException; import org.ray.util.logger.RayLog; @@ -62,14 +63,19 @@ public class RayNativeRuntime extends RayRuntime { throw new Error("Redis address must be configured under Worker mode."); } startOnebox(params, pathConfig); - initStateStore(params.redis_address); + initStateStore(params.redis_address, params.use_raylet); } else { - initStateStore(params.redis_address); + initStateStore(params.redis_address, params.use_raylet); if (!isWorker) { - List nodes = stateStoreProxy.getAddressInfo(params.node_ip_address, 5); + List nodes = stateStoreProxy.getAddressInfo( + params.node_ip_address, params.redis_address, 5); params.object_store_name = nodes.get(0).storeName; - params.object_store_manager_name = nodes.get(0).managerName; - params.local_scheduler_name = nodes.get(0).schedulerName; + if (!params.use_raylet) { + params.object_store_manager_name = nodes.get(0).managerName; + params.local_scheduler_name = nodes.get(0).schedulerName; + } else { + params.raylet_socket_name = nodes.get(0).rayletSocketName; + } } } @@ -101,23 +107,45 @@ public class RayNativeRuntime extends RayRuntime { .getIntegerValue("ray", "plasma_default_release_delay", 0, "how many release requests should be delayed in plasma client"); - ObjectStoreLink plink = new PlasmaClient(params.object_store_name, params - .object_store_manager_name, releaseDelay); + if (!params.use_raylet) { + ObjectStoreLink plink = new PlasmaClient(params.object_store_name, + params.object_store_manager_name, releaseDelay); - LocalSchedulerLink slink = new DefaultLocalSchedulerClient( - params.local_scheduler_name, - WorkerContext.currentWorkerId(), - UniqueID.nil, - isWorker, - WorkerContext.currentTask().taskId, - 0 - ); + LocalSchedulerLink slink = new DefaultLocalSchedulerClient( + params.local_scheduler_name, + WorkerContext.currentWorkerId(), + UniqueID.nil, + isWorker, + WorkerContext.currentTask().taskId, + 0, + false + ); - init(slink, plink, funcMgr, pathConfig); + init(slink, plink, funcMgr, pathConfig); - // register - registerWorker(isWorker, params.node_ip_address, params.object_store_name, - params.object_store_manager_name, params.local_scheduler_name); + // register + registerWorker(isWorker, params.node_ip_address, params.object_store_name, + params.object_store_manager_name, params.local_scheduler_name); + } else { + + ObjectStoreLink plink = new PlasmaClient(params.object_store_name, "", releaseDelay); + + LocalSchedulerLink slink = new DefaultLocalSchedulerClient( + params.raylet_socket_name, + WorkerContext.currentWorkerId(), + UniqueID.nil, + isWorker, + WorkerContext.currentTask().taskId, + 0, + true + ); + + init(slink, plink, funcMgr, pathConfig); + + // register + registerWorker(isWorker, params.node_ip_address, params.object_store_name, + params.raylet_socket_name); + } } RayLog.core.info("RayNativeRuntime start with " @@ -152,19 +180,44 @@ public class RayNativeRuntime extends RayRuntime { params.object_store_name = manager.info().localStores.get(0).storeName; params.object_store_manager_name = manager.info().localStores.get(0).managerName; params.local_scheduler_name = manager.info().localStores.get(0).schedulerName; + params.raylet_socket_name = manager.info().localStores.get(0).rayletSocketName; //params.node_ip_address = NetworkUtil.getIpAddress(); } - private void initStateStore(String redisAddress) throws Exception { + private void initStateStore(String redisAddress, boolean useRaylet) throws Exception { kvStore = new RedisClient(); kvStore.setAddr(redisAddress); - stateStoreProxy = new StateStoreProxyImpl(kvStore); + stateStoreProxy = useRaylet + ? new RayletStateStoreProxyImpl(kvStore) + : new NonRayletStateStoreProxyImpl(kvStore); //stateStoreProxy.setStore(kvStore); stateStoreProxy.initializeGlobalState(); } private void registerWorker(boolean isWorker, String nodeIpAddress, String storeName, - String managerName, String schedulerName) { + String rayletSocketName) { + Map workerInfo = new HashMap<>(); + String workerId = new String(WorkerContext.currentWorkerId().getBytes()); + if (!isWorker) { + workerInfo.put("node_ip_address", nodeIpAddress); + workerInfo.put("driver_id", workerId); + workerInfo.put("start_time", String.valueOf(System.currentTimeMillis())); + workerInfo.put("plasma_store_socket", storeName); + workerInfo.put("raylet_socket", rayletSocketName); + workerInfo.put("name", System.getProperty("user.dir")); + //TODO: worker.redis_client.hmset(b"Drivers:" + worker.workerId, driver_info) + kvStore.hmset("Drivers:" + workerId, workerInfo); + } else { + workerInfo.put("node_ip_address", nodeIpAddress); + workerInfo.put("plasma_store_socket", storeName); + workerInfo.put("raylet_socket", rayletSocketName); + //TODO: b"Workers:" + worker.workerId, + kvStore.hmset("Workers:" + workerId, workerInfo); + } + } + + private void registerWorker(boolean isWorker, String nodeIpAddress, String storeName, + String managerName, String schedulerName) { Map workerInfo = new HashMap<>(); String workerId = new String(WorkerContext.currentWorkerId().getBytes()); if (!isWorker) { diff --git a/java/runtime-native/src/main/java/org/ray/format/gcs/ClientTableData.java b/java/runtime-native/src/main/java/org/ray/format/gcs/ClientTableData.java new file mode 100644 index 000000000..e7d71415d --- /dev/null +++ b/java/runtime-native/src/main/java/org/ray/format/gcs/ClientTableData.java @@ -0,0 +1,79 @@ +package org.ray.format.gcs; +// automatically generated by the FlatBuffers compiler, do not modify + +import java.nio.*; +import java.lang.*; +import com.google.flatbuffers.*; + +@SuppressWarnings("unused") +public final class ClientTableData extends Table { + public static ClientTableData getRootAsClientTableData(ByteBuffer _bb) { return getRootAsClientTableData(_bb, new ClientTableData()); } + public static ClientTableData getRootAsClientTableData(ByteBuffer _bb, ClientTableData obj) { _bb.order(ByteOrder.LITTLE_ENDIAN); return (obj.__assign(_bb.getInt(_bb.position()) + _bb.position(), _bb)); } + public void __init(int _i, ByteBuffer _bb) { bb_pos = _i; bb = _bb; } + public ClientTableData __assign(int _i, ByteBuffer _bb) { __init(_i, _bb); return this; } + + public String clientId() { int o = __offset(4); return o != 0 ? __string(o + bb_pos) : null; } + public ByteBuffer clientIdAsByteBuffer() { return __vector_as_bytebuffer(4, 1); } + public ByteBuffer clientIdInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 4, 1); } + public String nodeManagerAddress() { int o = __offset(6); return o != 0 ? __string(o + bb_pos) : null; } + public ByteBuffer nodeManagerAddressAsByteBuffer() { return __vector_as_bytebuffer(6, 1); } + public ByteBuffer nodeManagerAddressInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 6, 1); } + public String rayletSocketName() { int o = __offset(8); return o != 0 ? __string(o + bb_pos) : null; } + public ByteBuffer rayletSocketNameAsByteBuffer() { return __vector_as_bytebuffer(8, 1); } + public ByteBuffer rayletSocketNameInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 8, 1); } + public String objectStoreSocketName() { int o = __offset(10); return o != 0 ? __string(o + bb_pos) : null; } + public ByteBuffer objectStoreSocketNameAsByteBuffer() { return __vector_as_bytebuffer(10, 1); } + public ByteBuffer objectStoreSocketNameInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 10, 1); } + public int nodeManagerPort() { int o = __offset(12); return o != 0 ? bb.getInt(o + bb_pos) : 0; } + public int objectManagerPort() { int o = __offset(14); return o != 0 ? bb.getInt(o + bb_pos) : 0; } + public boolean isInsertion() { int o = __offset(16); return o != 0 ? 0!=bb.get(o + bb_pos) : false; } + public String resourcesTotalLabel(int j) { int o = __offset(18); return o != 0 ? __string(__vector(o) + j * 4) : null; } + public int resourcesTotalLabelLength() { int o = __offset(18); return o != 0 ? __vector_len(o) : 0; } + public double resourcesTotalCapacity(int j) { int o = __offset(20); return o != 0 ? bb.getDouble(__vector(o) + j * 8) : 0; } + public int resourcesTotalCapacityLength() { int o = __offset(20); return o != 0 ? __vector_len(o) : 0; } + public ByteBuffer resourcesTotalCapacityAsByteBuffer() { return __vector_as_bytebuffer(20, 8); } + public ByteBuffer resourcesTotalCapacityInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 20, 8); } + + public static int createClientTableData(FlatBufferBuilder builder, + int client_idOffset, + int node_manager_addressOffset, + int raylet_socket_nameOffset, + int object_store_socket_nameOffset, + int node_manager_port, + int object_manager_port, + boolean is_insertion, + int resources_total_labelOffset, + int resources_total_capacityOffset) { + builder.startObject(9); + ClientTableData.addResourcesTotalCapacity(builder, resources_total_capacityOffset); + ClientTableData.addResourcesTotalLabel(builder, resources_total_labelOffset); + ClientTableData.addObjectManagerPort(builder, object_manager_port); + ClientTableData.addNodeManagerPort(builder, node_manager_port); + ClientTableData.addObjectStoreSocketName(builder, object_store_socket_nameOffset); + ClientTableData.addRayletSocketName(builder, raylet_socket_nameOffset); + ClientTableData.addNodeManagerAddress(builder, node_manager_addressOffset); + ClientTableData.addClientId(builder, client_idOffset); + ClientTableData.addIsInsertion(builder, is_insertion); + return ClientTableData.endClientTableData(builder); + } + + public static void startClientTableData(FlatBufferBuilder builder) { builder.startObject(9); } + public static void addClientId(FlatBufferBuilder builder, int clientIdOffset) { builder.addOffset(0, clientIdOffset, 0); } + public static void addNodeManagerAddress(FlatBufferBuilder builder, int nodeManagerAddressOffset) { builder.addOffset(1, nodeManagerAddressOffset, 0); } + public static void addRayletSocketName(FlatBufferBuilder builder, int rayletSocketNameOffset) { builder.addOffset(2, rayletSocketNameOffset, 0); } + public static void addObjectStoreSocketName(FlatBufferBuilder builder, int objectStoreSocketNameOffset) { builder.addOffset(3, objectStoreSocketNameOffset, 0); } + public static void addNodeManagerPort(FlatBufferBuilder builder, int nodeManagerPort) { builder.addInt(4, nodeManagerPort, 0); } + public static void addObjectManagerPort(FlatBufferBuilder builder, int objectManagerPort) { builder.addInt(5, objectManagerPort, 0); } + public static void addIsInsertion(FlatBufferBuilder builder, boolean isInsertion) { builder.addBoolean(6, isInsertion, false); } + public static void addResourcesTotalLabel(FlatBufferBuilder builder, int resourcesTotalLabelOffset) { builder.addOffset(7, resourcesTotalLabelOffset, 0); } + public static int createResourcesTotalLabelVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } + public static void startResourcesTotalLabelVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } + public static void addResourcesTotalCapacity(FlatBufferBuilder builder, int resourcesTotalCapacityOffset) { builder.addOffset(8, resourcesTotalCapacityOffset, 0); } + public static int createResourcesTotalCapacityVector(FlatBufferBuilder builder, double[] data) { builder.startVector(8, data.length, 8); for (int i = data.length - 1; i >= 0; i--) builder.addDouble(data[i]); return builder.endVector(); } + public static void startResourcesTotalCapacityVector(FlatBufferBuilder builder, int numElems) { builder.startVector(8, numElems, 8); } + public static int endClientTableData(FlatBufferBuilder builder) { + int o = builder.endObject(); + return o; + } +} + diff --git a/java/runtime-native/src/main/java/org/ray/runner/RunInfo.java b/java/runtime-native/src/main/java/org/ray/runner/RunInfo.java index bb58cff7e..cddf74f13 100644 --- a/java/runtime-native/src/main/java/org/ray/runner/RunInfo.java +++ b/java/runtime-native/src/main/java/org/ray/runner/RunInfo.java @@ -35,7 +35,7 @@ public class RunInfo { public enum ProcessType { PT_WORKER, PT_LOCAL_SCHEDULER, PT_PLASMA_MANAGER, PT_PLASMA_STORE, - PT_GLOBAL_SCHEDULER, PT_REDIS_SERVER, PT_WEB_UI, + PT_GLOBAL_SCHEDULER, PT_REDIS_SERVER, PT_WEB_UI, PT_RAYLET, PT_DRIVER } } diff --git a/java/runtime-native/src/main/java/org/ray/runner/RunManager.java b/java/runtime-native/src/main/java/org/ray/runner/RunManager.java index 63f6d8435..43742133c 100644 --- a/java/runtime-native/src/main/java/org/ray/runner/RunManager.java +++ b/java/runtime-native/src/main/java/org/ray/runner/RunManager.java @@ -48,7 +48,7 @@ public class RunManager { private static boolean killProcess(Process p) { if (p.isAlive()) { - p.destroyForcibly(); + p.destroy(); return true; } else { return false; @@ -307,7 +307,7 @@ public class RunManager { redisClient.close(); // start global scheduler - if (params.include_global_scheduler) { + if (params.include_global_scheduler && !params.use_raylet) { startGlobalScheduler(params.working_directory + "/globalScheduler", params.redis_address, params.node_ip_address, params.redirect, params.cleanup); } @@ -340,49 +340,70 @@ public class RunManager { } } - // start object stores - for (int i = 0; i < params.num_local_schedulers; i++) { - AddressInfo info = new AddressInfo(); - // store - startObjectStore(i, info, params.working_directory + "/store", + AddressInfo info = new AddressInfo(); + + if (params.use_raylet) { + // Start object store + int rpcPort = params.object_store_rpc_port; + String storeName = "/tmp/plasma_store" + rpcPort; + + startObjectStore(0, info, params.working_directory + "/store", params.redis_address, params.node_ip_address, params.redirect, params.cleanup); - // store manager - startObjectManager(i, info, - params.working_directory + "/storeManager", params.redis_address, - params.node_ip_address, params.redirect, params.cleanup); + //Start raylet + startRaylet(storeName, info, params.num_cpus[0],params.num_gpus[0], + params.num_workers,params.working_directory + "/raylet", + params.redis_address, params.node_ip_address, params.redirect, params.cleanup); runInfo.localStores.add(info); - } + } else { + for (int i = 0; i < params.num_local_schedulers; i++) { + // Start object stores + startObjectStore(i, info, params.working_directory + "/store", + params.redis_address, params.node_ip_address, params.redirect, params.cleanup); - // start local scheduler - for (int i = 0; i < params.num_local_schedulers; i++) { - int workerCount = 0; + startObjectManager(i, info, + params.working_directory + "/storeManager", params.redis_address, + params.node_ip_address, params.redirect, params.cleanup); - if (params.start_workers_from_local_scheduler) { - workerCount = localNumWorkers[i]; - localNumWorkers[i] = 0; + // Start local scheduler + int workerCount = 0; + + if (params.start_workers_from_local_scheduler) { + workerCount = localNumWorkers[i]; + localNumWorkers[i] = 0; + } + + startLocalScheduler(i, info, + params.num_cpus[i], params.num_gpus[i], workerCount, + params.working_directory + "/localsc", params.redis_address, + params.node_ip_address, params.redirect, params.cleanup); + + runInfo.localStores.add(info); } - - startLocalScheduler(i, runInfo.localStores.get(i), - params.num_cpus[i], params.num_gpus[i], workerCount, - params.working_directory + "/localScheduler", params.redis_address, - params.node_ip_address, params.redirect, params.cleanup); } // start local workers - for (int i = 0; i < params.num_local_schedulers; i++) { - runInfo.localStores.get(i).workerCount = localNumWorkers[i]; - for (int j = 0; j < localNumWorkers[i]; j++) { - startWorker(runInfo.localStores.get(i).storeName, - runInfo.localStores.get(i).managerName, runInfo.localStores.get(i).schedulerName, - params.working_directory + "/worker" + i + "." + j, params.redis_address, - params.node_ip_address, UniqueID.nil, "", - params.redirect, params.cleanup); + if (!params.use_raylet) { + for (int i = 0; i < params.num_local_schedulers; i++) { + AddressInfo localStores = runInfo.localStores.get(i); + localStores.workerCount = localNumWorkers[i]; + for (int j = 0; j < localNumWorkers[i]; j++) { + startWorker(localStores.storeName, localStores.managerName, localStores.schedulerName, + params.working_directory + "/worker" + i + "." + j, params.redis_address, + params.node_ip_address, UniqueID.nil, "", params.redirect, params.cleanup); + } } } HashSet excludeTypes = new HashSet<>(); + if (!params.use_raylet) { + excludeTypes.add(RunInfo.ProcessType.PT_RAYLET); + } else { + excludeTypes.add(RunInfo.ProcessType.PT_LOCAL_SCHEDULER); + excludeTypes.add(RunInfo.ProcessType.PT_GLOBAL_SCHEDULER); + excludeTypes.add(RunInfo.ProcessType.PT_PLASMA_MANAGER); + } if (!checkAlive(excludeTypes)) { cleanup(true); throw new RuntimeException("Start Ray processes failed"); @@ -622,8 +643,8 @@ public class RunManager { cmd += " -m " + info.managerName; String workerCmd = null; - workerCmd = buildWorkerCommand(true, info.storeName, info.managerName, name, UniqueID.nil, - "", workDir + rpcPort, ip, redisAddress); + workerCmd = buildWorkerCommand(true, info.storeName, info.managerName, name, + UniqueID.nil, "", workDir + rpcPort, ip, redisAddress); cmd += " -w \"" + workerCmd + "\""; if (redisAddress.length() > 0) { @@ -656,6 +677,82 @@ public class RunManager { } } + private void startRaylet(String storeName, AddressInfo info, int numCpus, + int numGpus, int numWorkers, String workDir, + String redisAddress, String ip, boolean redirect, + boolean cleanup) { + + int rpcPort = params.raylet_port; + String rayletSocketName = "/tmp/raylet" + rpcPort; + + String filePath = paths.raylet; + + String workerCmd = null; + workerCmd = buildWorkerCommandRaylet(info.storeName, rayletSocketName, UniqueID.nil, + "", workDir + rpcPort, ip, redisAddress); + + int sep = redisAddress.indexOf(':'); + assert (sep != -1); + String gcsIp = redisAddress.substring(0, sep); + String gcsPort = redisAddress.substring(sep + 1); + + String resourceArgument = "GPU," + numGpus + ",CPU," + numCpus; + + String[] cmds = new String[]{filePath, rayletSocketName, storeName, ip, gcsIp, + gcsPort, "" + numWorkers, workerCmd, resourceArgument}; + + Process p = startProcess(cmds, null, RunInfo.ProcessType.PT_RAYLET, + workDir + rpcPort, redisAddress, ip, redirect, cleanup); + + if (p != null && p.isAlive()) { + try { + TimeUnit.MILLISECONDS.sleep(100); + } catch (InterruptedException e) { + e.printStackTrace(); + } + } + + if (p == null || !p.isAlive()) { + info.rayletSocketName = ""; + info.rayletRpcAddr = ""; + throw new RuntimeException("Failed to start raylet process."); + } else { + info.rayletSocketName = rayletSocketName; + info.rayletRpcAddr = ip + ":" + rpcPort; + } + } + + private String buildWorkerCommandRaylet(String storeName, String rayletSocketName, + UniqueID actorId, String actorClass, String workDir, + String ip, String redisAddress) { + String workerConfigs = "ray.java.start.object_store_name=" + storeName + + ";ray.java.start.raylet_socket_name=" + rayletSocketName + + ";ray.java.start.worker_mode=WORKER;ray.java.start.use_raylet=true"; + workerConfigs += ";ray.java.start.deploy=" + params.deploy; + if (!actorId.equals(UniqueID.nil)) { + workerConfigs += ";ray.java.start.actor_id=" + actorId; + } + if (!actorClass.equals("")) { + workerConfigs += ";ray.java.start.driver_class=" + actorClass; + } + + String jvmArgs = ""; + jvmArgs += " -Dlogging.path=" + params.working_directory + "/logs/workers"; + jvmArgs += " -Dlogging.file.name=core-*pid_suffix*"; + + return buildJavaProcessCommand( + RunInfo.ProcessType.PT_WORKER, + "org.ray.runner.worker.DefaultWorker", + "", + workerConfigs, + jvmArgs, + workDir, + ip, + redisAddress, + null + ); + } + private String buildWorkerCommand(boolean isFromLocalScheduler, String storeName, String storeManagerName, String localSchedulerName, UniqueID actorId, String actorClass, String workDir, String diff --git a/java/runtime-native/src/main/java/org/ray/spi/KeyValueStoreLink.java b/java/runtime-native/src/main/java/org/ray/spi/KeyValueStoreLink.java index 6d21fa8e8..b52ee6ab9 100644 --- a/java/runtime-native/src/main/java/org/ray/spi/KeyValueStoreLink.java +++ b/java/runtime-native/src/main/java/org/ray/spi/KeyValueStoreLink.java @@ -103,6 +103,15 @@ public interface KeyValueStoreLink { */ List lrange(final String key, final long start, final long end); + /** + * Return the set of elements of the sorted set stored at the specified key. + * @param key The specified key you want to query. + * @param start The start index of the range. + * @param end The end index of the range. + * @return The set of elements you queried. + */ + Set zrange(byte[] key, long start, long end); + /** * Rpush. * @return Integer reply, specifically, the number of elements inside the list after the push @@ -123,4 +132,7 @@ public interface KeyValueStoreLink { Long publish(byte[] channel, byte[] message); Object getImpl(); + + byte[] sendCommand(String command, int commandType, byte[] objectId); + } diff --git a/java/runtime-native/src/main/java/org/ray/spi/StateStoreProxy.java b/java/runtime-native/src/main/java/org/ray/spi/StateStoreProxy.java index 2ecaadc3a..c5995dba1 100644 --- a/java/runtime-native/src/main/java/org/ray/spi/StateStoreProxy.java +++ b/java/runtime-native/src/main/java/org/ray/spi/StateStoreProxy.java @@ -31,5 +31,7 @@ public interface StateStoreProxy { * getAddressInfo. * @return list of address information */ - List getAddressInfo(final String nodeIpAddress, int numRetries); + List getAddressInfo(final String nodeIpAddress, + final String redisAddress, + int numRetries); } diff --git a/java/runtime-native/src/main/java/org/ray/spi/impl/BaseStateStoreProxyImpl.java b/java/runtime-native/src/main/java/org/ray/spi/impl/BaseStateStoreProxyImpl.java new file mode 100644 index 000000000..9d01134f5 --- /dev/null +++ b/java/runtime-native/src/main/java/org/ray/spi/impl/BaseStateStoreProxyImpl.java @@ -0,0 +1,124 @@ +package org.ray.spi.impl; + +import java.io.UnsupportedEncodingException; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import org.ray.spi.KeyValueStoreLink; +import org.ray.spi.StateStoreProxy; +import org.ray.spi.model.AddressInfo; +import org.ray.util.logger.RayLog; + +/** + * Base class used to interface with the Ray control state. + */ +public abstract class BaseStateStoreProxyImpl implements StateStoreProxy { + + public KeyValueStoreLink rayKvStore; + public ArrayList shardStoreList = new ArrayList<>(); + + public BaseStateStoreProxyImpl(KeyValueStoreLink rayKvStore) { + this.rayKvStore = rayKvStore; + } + + @Override + public void setStore(KeyValueStoreLink rayKvStore) { + this.rayKvStore = rayKvStore; + } + + @Override + public synchronized void initializeGlobalState() throws Exception { + + String es; + + checkConnected(); + + String s = rayKvStore.get("NumRedisShards", null); + if (s == null) { + throw new Exception("NumRedisShards not found in redis."); + } + int numRedisShards = Integer.parseInt(s); + if (numRedisShards < 1) { + es = String.format("Expected at least one Redis shard, found %d", numRedisShards); + throw new Exception(es); + } + List ipAddressPorts = rayKvStore.lrange("RedisShards", 0, -1); + Set distinctIpAddress = new HashSet(ipAddressPorts); + if (distinctIpAddress.size() != numRedisShards) { + es = String.format("Expected %d Redis shard addresses, found2 %d.", numRedisShards, + distinctIpAddress.size()); + throw new Exception(es); + } + + shardStoreList.clear(); + for (String ipPort : distinctIpAddress) { + shardStoreList.add(new RedisClient(ipPort)); + } + + } + + public void checkConnected() throws Exception { + rayKvStore.checkConnected(); + } + + @Override + public synchronized Set keys(final String pattern) { + Set allKeys = new HashSet<>(); + Set tmpKey; + for (KeyValueStoreLink ashardStoreList : shardStoreList) { + tmpKey = ashardStoreList.keys(pattern); + allKeys.addAll(tmpKey); + } + + return allKeys; + } + + @Override + public List getAddressInfo(final String nodeIpAddress, + final String redisAddress, + int numRetries) { + int count = 0; + while (count < numRetries) { + try { + return doGetAddressInfo(nodeIpAddress, redisAddress); + } catch (Exception e) { + try { + RayLog.core.warn("Error occurred in BaseStateStoreProxyImpl getAddressInfo, " + + (numRetries - count) + " retries remaining", e); + TimeUnit.MILLISECONDS.sleep(1000); + } catch (InterruptedException ie) { + RayLog.core.error("error at BaseStateStoreProxyImpl getAddressInfo", e); + throw new RuntimeException(e); + } + } + count++; + } + throw new RuntimeException("cannot get address info from state store"); + } + + /** + * Get address info of one node from primary redis. + * This method only tries to get address info once, without any retry. + * + * @param nodeIpAddress Usually local ip address. + * @param redisAddress The primary redis address. + * @return A list of SchedulerInfo which contains node manager or local scheduler address info. + * @throws Exception No redis client exception. + */ + protected abstract List doGetAddressInfo(final String nodeIpAddress, + final String redisAddress) throws Exception; + + protected String charsetDecode(byte[] bs, String charset) throws UnsupportedEncodingException { + return new String(bs, charset); + } + + protected byte[] charsetEncode(String str, String charset) throws UnsupportedEncodingException { + if (str != null) { + return str.getBytes(charset); + } + return null; + } +} diff --git a/java/runtime-native/src/main/java/org/ray/spi/impl/DefaultLocalSchedulerClient.java b/java/runtime-native/src/main/java/org/ray/spi/impl/DefaultLocalSchedulerClient.java index 872f2dc64..f732efb37 100644 --- a/java/runtime-native/src/main/java/org/ray/spi/impl/DefaultLocalSchedulerClient.java +++ b/java/runtime-native/src/main/java/org/ray/spi/impl/DefaultLocalSchedulerClient.java @@ -24,20 +24,44 @@ public class DefaultLocalSchedulerClient implements LocalSchedulerLink { return bb; }); private long client = 0; + boolean useRaylet = false; - public DefaultLocalSchedulerClient(String schedulerSockName, UniqueID clientId, UniqueID actorId, - boolean isWorker, UniqueID driverId, long numGpus) { + public DefaultLocalSchedulerClient(String schedulerSockName, UniqueID clientId, + UniqueID actorId, boolean isWorker, UniqueID driverId, + long numGpus, boolean useRaylet) { client = _init(schedulerSockName, clientId.getBytes(), actorId.getBytes(), isWorker, - driverId.getBytes(), numGpus); + driverId.getBytes(), numGpus, useRaylet); + this.useRaylet = useRaylet; } - private static native long _init(String localSchedulerSocket, byte[] workerId, byte[] actorId, - boolean isWorker, byte[] driverTaskId, long numGpus); + private static native long _init(String localSchedulerSocket, byte[] workerId, + byte[] actorId, boolean isWorker, byte[] driverTaskId, + long numGpus, boolean useRaylet); private static native byte[] _computePutId(long client, byte[] taskId, int putIndex); private static native void _task_done(long client); + private static native boolean[] _waitObject(long conn, byte[][] objectIds, + int numReturns, int timeout, boolean waitLocal); + + @Override + public List wait(byte[][] objectIds, int timeoutMs, int numReturns) { + assert (useRaylet == true); + + boolean[] readys = _waitObject(client, objectIds, numReturns, timeoutMs, false); + + List ret = new ArrayList<>(); + for (int i = 0; i < readys.length; i++) { + if (readys[i]) { + ret.add(objectIds[i]); + } + } + + assert (ret.size() == readys.length); + return ret; + } + @Override public void submitTask(TaskSpec task) { ByteBuffer info = taskSpec2Info(task); @@ -45,12 +69,13 @@ public class DefaultLocalSchedulerClient implements LocalSchedulerLink { if (!task.actorId.isNil()) { a = task.cursorId.getBytes(); } - _submitTask(client, a, info, info.position(), info.remaining()); + + _submitTask(client, a, info, info.position(), info.remaining(), useRaylet); } @Override public TaskSpec getTaskTodo() { - byte[] bytes = _getTaskTodo(client); + byte[] bytes = _getTaskTodo(client, useRaylet); assert (null != bytes); ByteBuffer bb = ByteBuffer.wrap(bytes); return taskInfo2Spec(bb); @@ -62,8 +87,16 @@ public class DefaultLocalSchedulerClient implements LocalSchedulerLink { } @Override - public void reconstructObject(UniqueID objectId) { - _reconstruct_object(client, objectId.getBytes()); + public void reconstructObject(UniqueID objectId, boolean fetchOnly) { + List objects = new ArrayList<>(); + objects.add(objectId); + _reconstruct_objects(client, getIdBytes(objects), fetchOnly); + } + + @Override + public void reconstructObjects(List objectIds, boolean fetchOnly) { + RayLog.core.info("reconstruct objects {}", objectIds); + _reconstruct_objects(client, getIdBytes(objectIds), fetchOnly); } @Override @@ -73,12 +106,13 @@ public class DefaultLocalSchedulerClient implements LocalSchedulerLink { private static native void _notify_unblocked(long client); - private static native void _reconstruct_object(long client, byte[] objectId); + private static native void _reconstruct_objects(long client, byte[][] objectIds, + boolean fetchOnly); private static native void _put_object(long client, byte[] taskId, byte[] objectId); // return TaskInfo (in FlatBuffer) - private static native byte[] _getTaskTodo(long client); + private static native byte[] _getTaskTodo(long client, boolean useRaylet); public static TaskSpec taskInfo2Spec(ByteBuffer bb) { bb.order(ByteOrder.LITTLE_ENDIAN); @@ -162,7 +196,10 @@ public class DefaultLocalSchedulerClient implements LocalSchedulerLink { idOffsets[k] = fbb.createString(task.args[i].ids.get(k).toByteBuffer()); } objectIdOffset = fbb.createVectorOfTables(idOffsets); + } else { + objectIdOffset = fbb.createVectorOfTables(new int[0]); } + if (task.args[i].data != null) { dataOffset = fbb.createString(ByteBuffer.wrap(task.args[i].data)); } @@ -214,8 +251,17 @@ public class DefaultLocalSchedulerClient implements LocalSchedulerLink { } // task -> TaskInfo (with FlatBuffer) - private static native void _submitTask(long client, byte[] cursorId, /*Direct*/ByteBuffer task, - int pos, int sz); + protected static native void _submitTask(long client, byte[] cursorId, /*Direct*/ByteBuffer task, + int pos, int sz, boolean useRaylet); + + private static byte[][] getIdBytes(List objectIds) { + int size = objectIds.size(); + byte[][] ids = new byte[size][]; + for (int i = 0; i < size; i++) { + ids[i] = objectIds.get(i).getBytes(); + } + return ids; + } public void destroy() { _destroy(client); diff --git a/java/runtime-native/src/main/java/org/ray/spi/impl/StateStoreProxyImpl.java b/java/runtime-native/src/main/java/org/ray/spi/impl/NonRayletStateStoreProxyImpl.java similarity index 52% rename from java/runtime-native/src/main/java/org/ray/spi/impl/StateStoreProxyImpl.java rename to java/runtime-native/src/main/java/org/ray/spi/impl/NonRayletStateStoreProxyImpl.java index d6fd6a211..f00267df1 100644 --- a/java/runtime-native/src/main/java/org/ray/spi/impl/StateStoreProxyImpl.java +++ b/java/runtime-native/src/main/java/org/ray/spi/impl/NonRayletStateStoreProxyImpl.java @@ -1,96 +1,18 @@ package org.ray.spi.impl; -import java.io.UnsupportedEncodingException; import java.util.ArrayList; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; -import java.util.concurrent.TimeUnit; import org.ray.spi.KeyValueStoreLink; -import org.ray.spi.StateStoreProxy; import org.ray.spi.model.AddressInfo; -import org.ray.util.logger.RayLog; /** - * A class used to interface with the Ray control state. + * A class used to interface with the Ray control state for non-raylet. */ -public class StateStoreProxyImpl implements StateStoreProxy { - - public KeyValueStoreLink rayKvStore; - public ArrayList shardStoreList = new ArrayList<>(); - - public StateStoreProxyImpl(KeyValueStoreLink rayKvStore) { - this.rayKvStore = rayKvStore; - } - - public void setStore(KeyValueStoreLink rayKvStore) { - this.rayKvStore = rayKvStore; - } - - public synchronized void initializeGlobalState() throws Exception { - - String es; - - checkConnected(); - - String s = rayKvStore.get("NumRedisShards", null); - if (s == null) { - throw new Exception("NumRedisShards not found in redis."); - } - int numRedisShards = Integer.parseInt(s); - if (numRedisShards < 1) { - es = String.format("Expected at least one Redis shard, found %d", numRedisShards); - throw new Exception(es); - } - List ipAddressPorts = rayKvStore.lrange("RedisShards", 0, -1); - if (ipAddressPorts.size() != numRedisShards) { - es = String.format("Expected %d Redis shard addresses, found %d.", numRedisShards, - ipAddressPorts.size()); - throw new Exception(es); - } - - shardStoreList.clear(); - for (String ipPort : ipAddressPorts) { - shardStoreList.add(new RedisClient(ipPort)); - } - - } - - public void checkConnected() throws Exception { - rayKvStore.checkConnected(); - } - - public synchronized Set keys(final String pattern) { - Set allKeys = new HashSet<>(); - Set tmpKey; - for (KeyValueStoreLink ashardStoreList : shardStoreList) { - tmpKey = ashardStoreList.keys(pattern); - allKeys.addAll(tmpKey); - } - - return allKeys; - - } - - public List getAddressInfo(final String nodeIpAddress, int numRetries) { - int count = 0; - while (count < numRetries) { - try { - return getAddressInfoHelper(nodeIpAddress); - } catch (Exception e) { - try { - RayLog.core.warn("Error occurred in StateStoreProxyImpl getAddressInfo, " - + (numRetries - count) + " retries remaining", e); - TimeUnit.MILLISECONDS.sleep(1000); - } catch (InterruptedException ie) { - RayLog.core.error("error at StateStoreProxyImpl getAddressInfo", e); - throw new RuntimeException(e); - } - } - count++; - } - throw new RuntimeException("cannot get address info from state store"); +public class NonRayletStateStoreProxyImpl extends BaseStateStoreProxyImpl { + public NonRayletStateStoreProxyImpl(KeyValueStoreLink rayKvStore) { + super(rayKvStore); } /* @@ -108,9 +30,11 @@ public class StateStoreProxyImpl implements StateStoreProxy { * "manager_socket_name"(op) * "local_scheduler_socket_name"(op) */ - public List getAddressInfoHelper(final String nodeIpAddress) throws Exception { + @Override + public List doGetAddressInfo(final String nodeIpAddress, + final String redisAddress) throws Exception { if (this.rayKvStore == null) { - throw new Exception("no redis client when use getAddressInfoHelper"); + throw new Exception("no redis client when use doGetAddressInfo"); } List schedulerInfo = new ArrayList<>(); @@ -136,13 +60,13 @@ public class StateStoreProxyImpl implements StateStoreProxy { } else if (!info.containsKey("client_type".getBytes())) { throw new Exception("no client_type in any client"); } - + if (charsetDecode(info.get("node_ip_address".getBytes()), "US-ASCII") .equals(nodeIpAddress)) { String clientType = charsetDecode(info.get("client_type".getBytes()), "US-ASCII"); - if (clientType.equals("plasma_manager")) { + if ("plasma_manager".equals(clientType)) { plasmaManager.add(info); - } else if (clientType.equals("local_scheduler")) { + } else if ("local_scheduler".equals(clientType)) { localScheduler.add(info); } } @@ -157,9 +81,9 @@ public class StateStoreProxyImpl implements StateStoreProxy { for (int i = 0; i < plasmaManager.size(); i++) { AddressInfo si = new AddressInfo(); si.storeName = charsetDecode(plasmaManager.get(i).get("store_socket_name".getBytes()), - "US-ASCII"); + "US-ASCII"); si.managerName = charsetDecode(plasmaManager.get(i).get("manager_socket_name".getBytes()), - "US-ASCII"); + "US-ASCII"); byte[] rpc = plasmaManager.get(i).get("manager_rpc_name".getBytes()); if (rpc != null) { @@ -188,14 +112,4 @@ public class StateStoreProxyImpl implements StateStoreProxy { return schedulerInfo; } - private String charsetDecode(byte[] bs, String charset) throws UnsupportedEncodingException { - return new String(bs, charset); - } - - private byte[] charsetEncode(String str, String charset) throws UnsupportedEncodingException { - if (str != null) { - return str.getBytes(charset); - } - return null; - } } diff --git a/java/runtime-native/src/main/java/org/ray/spi/impl/RayletStateStoreProxyImpl.java b/java/runtime-native/src/main/java/org/ray/spi/impl/RayletStateStoreProxyImpl.java new file mode 100644 index 000000000..0cfa4f532 --- /dev/null +++ b/java/runtime-native/src/main/java/org/ray/spi/impl/RayletStateStoreProxyImpl.java @@ -0,0 +1,62 @@ +package org.ray.spi.impl; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.Set; +import org.ray.api.UniqueID; +import org.ray.format.gcs.ClientTableData; +import org.ray.spi.KeyValueStoreLink; +import org.ray.spi.model.AddressInfo; +import org.ray.util.NetworkUtil; + +/** + * A class used to interface with the Ray control state for raylet. + */ +public class RayletStateStoreProxyImpl extends BaseStateStoreProxyImpl { + + public RayletStateStoreProxyImpl(KeyValueStoreLink rayKvStore) { + super(rayKvStore); + } + + @Override + public List doGetAddressInfo(final String nodeIpAddress, + final String redisAddress) throws Exception { + if (this.rayKvStore == null) { + throw new Exception("no redis client when use doGetAddressInfo"); + } + List schedulerInfo = new ArrayList<>(); + + byte[] prefix = "CLIENT".getBytes(); + byte[] postfix = UniqueID.genNil().getBytes(); + byte[] clientKey = new byte[prefix.length + postfix.length]; + System.arraycopy(prefix, 0, clientKey, 0, prefix.length); + System.arraycopy(postfix, 0, clientKey, prefix.length, postfix.length); + + Set clients = rayKvStore.zrange(clientKey, 0, -1); + + for (byte[] clientMessage : clients) { + ByteBuffer bb = ByteBuffer.wrap(clientMessage); + ClientTableData client = ClientTableData.getRootAsClientTableData(bb); + String clientNodeIpAddress = client.nodeManagerAddress(); + + String localIpAddress = NetworkUtil.getIpAddress(null); + String redisIpAddress = redisAddress.substring(0, redisAddress.indexOf(':')); + + boolean headNodeAddress = "127.0.0.1".equals(clientNodeIpAddress) + && Objects.equals(redisIpAddress, localIpAddress); + boolean notHeadNodeAddress = Objects.equals(clientNodeIpAddress, nodeIpAddress); + + if (headNodeAddress || notHeadNodeAddress) { + AddressInfo si = new AddressInfo(); + si.storeName = client.objectStoreSocketName(); + si.rayletSocketName = client.rayletSocketName(); + si.managerRpcAddr = client.nodeManagerAddress(); + si.managerPort = client.nodeManagerPort(); + schedulerInfo.add(si); + } + } + return schedulerInfo; + } +} diff --git a/java/runtime-native/src/main/java/org/ray/spi/impl/RedisClient.java b/java/runtime-native/src/main/java/org/ray/spi/impl/RedisClient.java index 97f744f7d..5dec52e57 100644 --- a/java/runtime-native/src/main/java/org/ray/spi/impl/RedisClient.java +++ b/java/runtime-native/src/main/java/org/ray/spi/impl/RedisClient.java @@ -13,6 +13,7 @@ public class RedisClient implements KeyValueStoreLink { private String redisAddress; private JedisPool jedisPool; + private int handle = 0; public RedisClient() { } @@ -171,6 +172,13 @@ public class RedisClient implements KeyValueStoreLink { } } + @Override + public Set zrange(byte[] key, long start, long end) { + try (Jedis jedis = jedisPool.getResource()) { + return jedis.zrange(key, start, end); + } + } + @Override public Long rpush(String key, String... strings) { try (Jedis jedis = jedisPool.getResource()) { @@ -203,4 +211,20 @@ public class RedisClient implements KeyValueStoreLink { public Object getImpl() { return jedisPool; } + + @Override + public byte[] sendCommand(String command, int commandType, byte[] objectId) { + if (handle == 0) { + String[] ipPort = redisAddress.split(":"); + handle = connect(ipPort[0], Integer.parseInt(ipPort[1])); + } + return execute_command(handle, command, commandType, objectId); + } + + private static native int connect(String redisAddress, int port); + + private static native void disconnect(int handle); + + private static native byte[] execute_command(int handle, + String command, int commandType, byte[] objectId); } diff --git a/java/test.sh b/java/test.sh index d277473c2..3c7da10bf 100755 --- a/java/test.sh +++ b/java/test.sh @@ -10,7 +10,17 @@ mvn clean install -Dmaven.test.skip check_style=$(mvn checkstyle:check) echo "${check_style}" [[ ${check_style} =~ "BUILD FAILURE" ]] && exit 1 + +# test non-raylet +sed -i 's/^use_raylet.*$/use_raylet = false/g' $ROOT_DIR/../java/ray.config.ini mvn_test=$(mvn test) echo "${mvn_test}" [[ ${mvn_test} =~ "BUILD SUCCESS" ]] || exit 1 -popd \ No newline at end of file + +# test raylet +sed -i 's/^use_raylet.*$/use_raylet = true/g' $ROOT_DIR/../java/ray.config.ini +mvn_test=$(mvn test) +echo "${mvn_test}" +[[ ${mvn_test} =~ "BUILD SUCCESS" ]] || exit 1 + +popd diff --git a/src/local_scheduler/lib/java/org_ray_spi_impl_DefaultLocalSchedulerClient.cc b/src/local_scheduler/lib/java/org_ray_spi_impl_DefaultLocalSchedulerClient.cc index b43ca2801..7eef8596d 100644 --- a/src/local_scheduler/lib/java/org_ray_spi_impl_DefaultLocalSchedulerClient.cc +++ b/src/local_scheduler/lib/java/org_ray_spi_impl_DefaultLocalSchedulerClient.cc @@ -43,15 +43,15 @@ Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1init(JNIEnv *env, jbyteArray actorId, jboolean isWorker, jbyteArray driverId, - jlong numGpus) { + jlong numGpus, + jboolean useRaylet) { // native private static long _init(String localSchedulerSocket, // byte[] workerId, byte[] actorId, boolean isWorker, long numGpus); UniqueIdFromJByteArray worker_id(env, wid); UniqueIdFromJByteArray driver_id(env, driverId); const char *nativeString = env->GetStringUTFChars(sockName, JNI_FALSE); - bool use_raylet = false; auto client = LocalSchedulerConnection_init( - nativeString, *worker_id.PID, isWorker, *driver_id.PID, use_raylet); + nativeString, *worker_id.PID, isWorker, *driver_id.PID, useRaylet); env->ReleaseStringUTFChars(sockName, nativeString); return reinterpret_cast(client); } @@ -69,21 +69,30 @@ Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1submitTask( jbyteArray cursorId, jobject buff, jint pos, - jint sz) { + jint sz, + jboolean useRaylet) { // task -> TaskInfo (with FlatBuffer) // native private static void _submitTask(long client, /*Direct*/ByteBuffer // task); auto client = reinterpret_cast(c); - TaskSpec *task = - reinterpret_cast(env->GetDirectBufferAddress(buff)) + pos; + std::vector execution_dependencies; if (cursorId != nullptr) { UniqueIdFromJByteArray cursor_id(env, cursorId); execution_dependencies.push_back(*cursor_id.PID); } - TaskExecutionSpec taskExecutionSpec = - TaskExecutionSpec(execution_dependencies, task, sz); - local_scheduler_submit(client, taskExecutionSpec); + if (!useRaylet) { + TaskSpec *task = + reinterpret_cast(env->GetDirectBufferAddress(buff)) + pos; + TaskExecutionSpec taskExecutionSpec = + TaskExecutionSpec(execution_dependencies, task, sz); + local_scheduler_submit(client, taskExecutionSpec); + } else { + auto data = + reinterpret_cast(env->GetDirectBufferAddress(buff)) + pos; + ray::raylet::TaskSpecification task_spec(std::string(data, sz)); + local_scheduler_submit_raylet(client, execution_dependencies, task_spec); + } } /* @@ -92,15 +101,19 @@ Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1submitTask( * Signature: (J)[B */ JNIEXPORT jbyteArray JNICALL -Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1getTaskTodo(JNIEnv *env, - jclass, - jlong c) { +Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1getTaskTodo( + JNIEnv *env, + jclass, + jlong c, + jboolean useRaylet) { // native private static ByteBuffer _getTaskTodo(long client); auto client = reinterpret_cast(c); int64_t task_size = 0; // TODO: handle actor failure later - TaskSpec *spec = local_scheduler_get_task(client, &task_size); + TaskSpec *spec = !useRaylet + ? local_scheduler_get_task(client, &task_size) + : local_scheduler_get_task_raylet(client, &task_size); jbyteArray result; result = env->NewByteArray(task_size); @@ -178,20 +191,29 @@ Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1task_1done(JNIEnv *, /* * Class: org_ray_spi_impl_DefaultLocalSchedulerClient - * Method: _reconstruct_object + * Method: _reconstruct_objects * Signature: (J[B)V */ JNIEXPORT void JNICALL -Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1reconstruct_1object( +Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1reconstruct_1objects( JNIEnv *env, jclass, jlong c, - jbyteArray oid) { - // native private static void _reconstruct_object(long client, byte[] - // objectId); - UniqueIdFromJByteArray o(env, oid); + jobjectArray oids, + jboolean fetch_only) { + // native private static void _reconstruct_objects(long client, byte[][] + // objectIds, boolean fetchOnly); + + std::vector object_ids; + auto len = env->GetArrayLength(oids); + for (int i = 0; i < len; i++) { + jbyteArray oid = (jbyteArray) env->GetObjectArrayElement(oids, i); + UniqueIdFromJByteArray o(env, oid); + object_ids.push_back(*o.PID); + env->DeleteLocalRef(oid); + } auto client = reinterpret_cast(c); - local_scheduler_reconstruct_objects(client, {*o.PID}); + local_scheduler_reconstruct_objects(client, object_ids, fetch_only); } /* @@ -227,6 +249,55 @@ Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1put_1object( local_scheduler_put_object(client, *t.PID, *o.PID); } +JNIEXPORT jbooleanArray JNICALL +Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1waitObject( + JNIEnv *env, + jclass, + jlong c, + jobjectArray oids, + jint num_returns, + jint timeout_ms, + jboolean wait_local) { + std::vector object_ids; + auto len = env->GetArrayLength(oids); + for (int i = 0; i < len; i++) { + jbyteArray oid = (jbyteArray) env->GetObjectArrayElement(oids, i); + UniqueIdFromJByteArray o(env, oid); + object_ids.push_back(*o.PID); + env->DeleteLocalRef(oid); + } + + auto client = reinterpret_cast(c); + + // Invoke wait. + std::pair, std::vector> result = + local_scheduler_wait(client, object_ids, num_returns, timeout_ms, + static_cast(wait_local)); + + // Convert result to java object. + jboolean putValue = true; + jbooleanArray resultArray = env->NewBooleanArray(object_ids.size()); + for (uint i = 0; i < result.first.size(); ++i) { + for (uint j = 0; j < object_ids.size(); ++j) { + if (result.first[i] == object_ids[j]) { + env->SetBooleanArrayRegion(resultArray, j, 1, &putValue); + break; + } + } + } + + putValue = false; + for (uint i = 0; i < result.second.size(); ++i) { + for (uint j = 0; j < object_ids.size(); ++j) { + if (result.second[i] == object_ids[j]) { + env->SetBooleanArrayRegion(resultArray, j, 1, &putValue); + break; + } + } + } + return resultArray; +} + #ifdef __cplusplus } #endif diff --git a/src/local_scheduler/lib/java/org_ray_spi_impl_DefaultLocalSchedulerClient.h b/src/local_scheduler/lib/java/org_ray_spi_impl_DefaultLocalSchedulerClient.h index cb3822ca1..edd6574b6 100644 --- a/src/local_scheduler/lib/java/org_ray_spi_impl_DefaultLocalSchedulerClient.h +++ b/src/local_scheduler/lib/java/org_ray_spi_impl_DefaultLocalSchedulerClient.h @@ -20,7 +20,8 @@ Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1init(JNIEnv *, jbyteArray, jboolean, jbyteArray, - jlong); + jlong, + jboolean); /* * Class: org_ray_spi_impl_DefaultLocalSchedulerClient @@ -34,7 +35,8 @@ Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1submitTask(JNIEnv *, jbyteArray, jobject, jint, - jint); + jint, + jboolean); /* * Class: org_ray_spi_impl_DefaultLocalSchedulerClient @@ -44,7 +46,8 @@ Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1submitTask(JNIEnv *, JNIEXPORT jbyteArray JNICALL Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1getTaskTodo(JNIEnv *, jclass, - jlong); + jlong, + jboolean); /* * Class: org_ray_spi_impl_DefaultLocalSchedulerClient @@ -80,15 +83,16 @@ Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1task_1done(JNIEnv *, /* * Class: org_ray_spi_impl_DefaultLocalSchedulerClient - * Method: _reconstruct_object + * Method: _reconstruct_objects * Signature: (J[B)V */ JNIEXPORT void JNICALL -Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1reconstruct_1object( +Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1reconstruct_1objects( JNIEnv *, jclass, jlong, - jbyteArray); + jobjectArray, + jboolean); /* * Class: org_ray_spi_impl_DefaultLocalSchedulerClient @@ -112,6 +116,20 @@ Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1put_1object(JNIEnv *, jbyteArray, jbyteArray); +/* + * Class: org_ray_spi_impl_DefaultLocalSchedulerClient + * Method: _waitObject + * Signature: (J[[BIIZ)[Z + */ +JNIEXPORT jbooleanArray JNICALL +Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1waitObject(JNIEnv *, + jclass, + jlong, + jobjectArray, + jint, + jint, + jboolean); + #ifdef __cplusplus } #endif diff --git a/src/local_scheduler/local_scheduler_client.cc b/src/local_scheduler/local_scheduler_client.cc index 9923e0d2c..775c3518e 100644 --- a/src/local_scheduler/local_scheduler_client.cc +++ b/src/local_scheduler/local_scheduler_client.cc @@ -69,7 +69,7 @@ void local_scheduler_log_event(LocalSchedulerConnection *conn, } void local_scheduler_submit(LocalSchedulerConnection *conn, - TaskExecutionSpec &execution_spec) { + const TaskExecutionSpec &execution_spec) { flatbuffers::FlatBufferBuilder fbb; auto execution_dependencies = to_flatbuf(fbb, execution_spec.ExecutionDependencies()); @@ -86,7 +86,7 @@ void local_scheduler_submit(LocalSchedulerConnection *conn, void local_scheduler_submit_raylet( LocalSchedulerConnection *conn, const std::vector &execution_dependencies, - ray::raylet::TaskSpecification task_spec) { + const ray::raylet::TaskSpecification &task_spec) { flatbuffers::FlatBufferBuilder fbb; auto execution_dependencies_message = to_flatbuf(fbb, execution_dependencies); auto message = ray::local_scheduler::protocol::CreateSubmitTaskRequest( diff --git a/src/local_scheduler/local_scheduler_client.h b/src/local_scheduler/local_scheduler_client.h index e7eebdcdd..5026313d2 100644 --- a/src/local_scheduler/local_scheduler_client.h +++ b/src/local_scheduler/local_scheduler_client.h @@ -65,7 +65,7 @@ void LocalSchedulerConnection_free(LocalSchedulerConnection *conn); * @return Void. */ void local_scheduler_submit(LocalSchedulerConnection *conn, - TaskExecutionSpec &execution_spec); + const TaskExecutionSpec &execution_spec); /// Submit a task using the raylet code path. /// @@ -76,7 +76,7 @@ void local_scheduler_submit(LocalSchedulerConnection *conn, void local_scheduler_submit_raylet( LocalSchedulerConnection *conn, const std::vector &execution_dependencies, - ray::raylet::TaskSpecification task_spec); + const ray::raylet::TaskSpecification &task_spec); /** * Notify the local scheduler that this client is disconnecting gracefully. This diff --git a/src/ray/raylet/main.cc b/src/ray/raylet/main.cc index d3cdef460..46a8e139a 100644 --- a/src/ray/raylet/main.cc +++ b/src/ray/raylet/main.cc @@ -39,11 +39,10 @@ int main(int argc, char *argv[]) { RayConfig::instance().num_workers_per_process(); // Use a default worker that can execute empty tasks with dependencies. - std::stringstream worker_command_stream(worker_command); - std::string token; - while (getline(worker_command_stream, token, ' ')) { - node_manager_config.worker_command.push_back(token); - } + std::istringstream iss(worker_command); + std::vector results(std::istream_iterator{iss}, + std::istream_iterator()); + node_manager_config.worker_command.swap(results); node_manager_config.heartbeat_period_ms = RayConfig::instance().heartbeat_timeout_milliseconds(); @@ -84,8 +83,12 @@ int main(int argc, char *argv[]) { // Destroy the Raylet on a SIGTERM. The pointer to main_service is // guaranteed to be valid since this function will run the event loop // instead of returning immediately. - auto handler = [&main_service](const boost::system::error_code &error, - int signal_number) { main_service.stop(); }; + // We should stop the service and remove the local socket file. + auto handler = [&main_service, &raylet_socket_name]( + const boost::system::error_code &error, int signal_number) { + main_service.stop(); + remove(raylet_socket_name.c_str()); + }; boost::asio::signal_set signals(main_service, SIGTERM); signals.async_wait(handler); diff --git a/src/ray/raylet/task_spec.cc b/src/ray/raylet/task_spec.cc index 8488da3c4..40ba37f8f 100644 --- a/src/ray/raylet/task_spec.cc +++ b/src/ray/raylet/task_spec.cc @@ -37,6 +37,10 @@ TaskSpecification::TaskSpecification(const flatbuffers::String &string) { AssignSpecification(reinterpret_cast(string.data()), string.size()); } +TaskSpecification::TaskSpecification(const std::string &string) { + AssignSpecification(reinterpret_cast(string.data()), string.size()); +} + TaskSpecification::TaskSpecification( const UniqueID &driver_id, const TaskID &parent_task_id, int64_t parent_counter, const FunctionID &function_id, diff --git a/src/ray/raylet/task_spec.h b/src/ray/raylet/task_spec.h index d9e51dc96..7214b4aa0 100644 --- a/src/ray/raylet/task_spec.h +++ b/src/ray/raylet/task_spec.h @@ -109,6 +109,12 @@ class TaskSpecification { int64_t num_returns, const std::unordered_map &required_resources); + /// Deserialize a task specification from a flatbuffer's string data. + /// + /// \param string The string data for a serialized task specification + /// flatbuffer. + TaskSpecification(const std::string &string); + ~TaskSpecification() {} /// Serialize the TaskSpecification to a flatbuffer. diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc index 0e87d0eb3..24ee3bd76 100644 --- a/src/ray/raylet/worker_pool.cc +++ b/src/ray/raylet/worker_pool.cc @@ -121,6 +121,7 @@ void WorkerPool::RegisterWorker(std::shared_ptr worker) { auto pid = worker->Pid(); RAY_LOG(DEBUG) << "Registering worker with pid " << pid; registered_workers_.push_back(std::move(worker)); + auto it = starting_worker_processes_.find(pid); RAY_CHECK(it != starting_worker_processes_.end()); it->second--;