diff --git a/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java index fb2eb39f2..486e689c3 100644 --- a/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java @@ -28,7 +28,6 @@ import io.ray.runtime.functionmanager.FunctionDescriptor; import io.ray.runtime.functionmanager.FunctionManager; import io.ray.runtime.functionmanager.PyFunctionDescriptor; import io.ray.runtime.functionmanager.RayFunction; -import io.ray.runtime.gcs.GcsClient; import io.ray.runtime.generated.Common; import io.ray.runtime.generated.Common.Language; import io.ray.runtime.object.ObjectRefImpl; @@ -54,7 +53,6 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal { protected TaskExecutor taskExecutor; protected FunctionManager functionManager; protected RuntimeContext runtimeContext; - protected GcsClient gcsClient; protected ObjectStore objectStore; protected TaskSubmitter taskSubmitter; @@ -217,19 +215,19 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal { @Override public PlacementGroup getPlacementGroup(PlacementGroupId id) { - return gcsClient.getPlacementGroupInfo(id); + return getGcsClient().getPlacementGroupInfo(id); } @Override public PlacementGroup getPlacementGroup(String name, String namespace) { return namespace == null - ? gcsClient.getPlacementGroupInfo(name, runtimeContext.getNamespace()) - : gcsClient.getPlacementGroupInfo(name, namespace); + ? getGcsClient().getPlacementGroupInfo(name, runtimeContext.getNamespace()) + : getGcsClient().getPlacementGroupInfo(name, namespace); } @Override public List getAllPlacementGroups() { - return gcsClient.getAllPlacementGroupInfo(); + return getGcsClient().getAllPlacementGroupInfo(); } @Override @@ -396,11 +394,6 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal { return runtimeContext; } - @Override - public GcsClient getGcsClient() { - return gcsClient; - } - @Override public void setIsContextSet(boolean isContextSet) { this.isContextSet.set(isContextSet); diff --git a/java/runtime/src/main/java/io/ray/runtime/RayDevRuntime.java b/java/runtime/src/main/java/io/ray/runtime/RayDevRuntime.java index e02023c79..e9c3a9882 100644 --- a/java/runtime/src/main/java/io/ray/runtime/RayDevRuntime.java +++ b/java/runtime/src/main/java/io/ray/runtime/RayDevRuntime.java @@ -8,6 +8,7 @@ import io.ray.api.placementgroup.PlacementGroup; import io.ray.api.runtimecontext.ResourceValue; import io.ray.runtime.config.RayConfig; import io.ray.runtime.context.LocalModeWorkerContext; +import io.ray.runtime.gcs.GcsClient; import io.ray.runtime.generated.Common.TaskSpec; import io.ray.runtime.object.LocalModeObjectStore; import io.ray.runtime.task.LocalModeTaskExecutor; @@ -82,6 +83,11 @@ public class RayDevRuntime extends AbstractRayRuntime { return (Optional) ((LocalModeTaskSubmitter) taskSubmitter).getActor(name); } + @Override + public GcsClient getGcsClient() { + throw new UnsupportedOperationException("Ray doesn't have gcs client in local mode."); + } + @Override public Object getAsyncContext() { return new AsyncContext(((LocalModeWorkerContext) workerContext).getCurrentTask()); diff --git a/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java b/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java index 23516604e..3fc97470e 100644 --- a/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java +++ b/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java @@ -42,6 +42,8 @@ public final class RayNativeRuntime extends AbstractRayRuntime { private boolean startRayHead = false; + private GcsClient gcsClient; + /** * In Java, GC runs in a standalone thread, and we can't control the exact timing of garbage * collection. By using this lock, when {@link NativeObjectStore#nativeRemoveLocalReference} is @@ -54,9 +56,9 @@ public final class RayNativeRuntime extends AbstractRayRuntime { super(rayConfig); } - private void updateSessionDir(GcsClient gcsClient) { + private void updateSessionDir() { // Fetch session dir from GCS. - final String sessionDir = gcsClient.getInternalKV("@namespace_session:session_dir"); + final String sessionDir = getGcsClient().getInternalKV("@namespace_session:session_dir"); Preconditions.checkNotNull(sessionDir); rayConfig.setSessionDir(sessionDir); } @@ -77,8 +79,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime { if (rayConfig.workerMode == WorkerType.DRIVER) { String tmpDir = "/tmp/ray/".concat(String.valueOf(System.currentTimeMillis())); JniUtils.loadLibrary(tmpDir, BinaryFileUtil.CORE_WORKER_JAVA_LIBRARY, true); - gcsClient = new GcsClient(rayConfig.getRedisAddress(), rayConfig.redisPassword); - updateSessionDir(gcsClient); + updateSessionDir(); Preconditions.checkNotNull(rayConfig.sessionDir); } else { // Expose ray ABI symbols which may be depended by other shared @@ -86,18 +87,17 @@ public final class RayNativeRuntime extends AbstractRayRuntime { // See BUILD.bazel:libcore_worker_library_java.so Preconditions.checkNotNull(rayConfig.sessionDir); JniUtils.loadLibrary(rayConfig.sessionDir, BinaryFileUtil.CORE_WORKER_JAVA_LIBRARY, true); - gcsClient = new GcsClient(rayConfig.getRedisAddress(), rayConfig.redisPassword); } if (rayConfig.workerMode == WorkerType.DRIVER) { - GcsNodeInfo nodeInfo = gcsClient.getNodeToConnectForDriver(rayConfig.nodeIp); + GcsNodeInfo nodeInfo = getGcsClient().getNodeToConnectForDriver(rayConfig.nodeIp); rayConfig.rayletSocketName = nodeInfo.getRayletSocketName(); rayConfig.objectStoreSocketName = nodeInfo.getObjectStoreSocketName(); rayConfig.nodeManagerPort = nodeInfo.getNodeManagerPort(); } if (rayConfig.workerMode == WorkerType.DRIVER && rayConfig.getJobId() == JobId.NIL) { - rayConfig.setJobId(gcsClient.nextJobId()); + rayConfig.setJobId(getGcsClient().nextJobId()); } int numWorkersPerProcess = rayConfig.workerMode == WorkerType.DRIVER ? 1 : rayConfig.numWorkersPerProcess; @@ -233,6 +233,18 @@ public final class RayNativeRuntime extends AbstractRayRuntime { String.format("Actor %s is exiting.", runtimeContext.getCurrentActorId())); } + @Override + public GcsClient getGcsClient() { + if (gcsClient == null) { + synchronized (this) { + if (gcsClient == null) { + gcsClient = new GcsClient(rayConfig.getRedisAddress(), rayConfig.redisPassword); + } + } + } + return gcsClient; + } + @Override public void run() { Preconditions.checkState(rayConfig.workerMode == WorkerType.WORKER);