diff --git a/cpp/src/ray/util/process_helper.cc b/cpp/src/ray/util/process_helper.cc index e1e7ba611..b5286cf7a 100644 --- a/cpp/src/ray/util/process_helper.cc +++ b/cpp/src/ray/util/process_helper.cc @@ -139,7 +139,6 @@ void ProcessHelper::RayStart(CoreWorkerOptions::TaskExecutionCallback callback) options.node_manager_port = ConfigInternal::Instance().node_manager_port; options.raylet_ip_address = node_ip; options.driver_name = "cpp_worker"; - options.num_workers = 1; options.metrics_agent_port = -1; options.task_execution_callback = callback; options.startup_token = ConfigInternal::Instance().startup_token; diff --git a/java/api/src/main/java/io/ray/api/Ray.java b/java/api/src/main/java/io/ray/api/Ray.java index 5dc3c4e6a..a6be0cd04 100644 --- a/java/api/src/main/java/io/ray/api/Ray.java +++ b/java/api/src/main/java/io/ray/api/Ray.java @@ -5,7 +5,6 @@ import io.ray.api.runtime.RayRuntimeFactory; import io.ray.api.runtimecontext.RuntimeContext; import java.util.List; import java.util.Optional; -import java.util.concurrent.Callable; /** This class contains all public APIs of Ray. */ public final class Ray extends RayCall { @@ -205,52 +204,6 @@ public final class Ray extends RayCall { return internal().getActor(name, namespace); } - /** - * If users want to use Ray API in their own threads, call this method to get the async context - * and then call {@link #setAsyncContext} at the beginning of the new thread. - * - * @return The async context. - */ - public static Object getAsyncContext() { - return internal().getAsyncContext(); - } - - /** - * Set the async context for the current thread. - * - * @param asyncContext The async context to set. - */ - public static void setAsyncContext(Object asyncContext) { - internal().setAsyncContext(asyncContext); - } - - // TODO (kfstorm): add the `rollbackAsyncContext` API to allow rollbacking the async context of - // the current thread to the one before `setAsyncContext` is called. - - // TODO (kfstorm): unify the `wrap*` methods. - - /** - * If users want to use Ray API in their own threads, they should wrap their {@link Runnable} - * objects with this method. - * - * @param runnable The runnable to wrap. - * @return The wrapped runnable. - */ - public static Runnable wrapRunnable(Runnable runnable) { - return internal().wrapRunnable(runnable); - } - - /** - * If users want to use Ray API in their own threads, they should wrap their {@link Callable} - * objects with this method. - * - * @param callable The callable to wrap. - * @return The wrapped callable. - */ - public static Callable wrapCallable(Callable callable) { - return internal().wrapCallable(callable); - } - /** Get the underlying runtime instance. */ public static RayRuntime internal() { if (runtime == null) { diff --git a/java/api/src/main/java/io/ray/api/runtime/RayRuntime.java b/java/api/src/main/java/io/ray/api/runtime/RayRuntime.java index ca2350a2b..c2e084c99 100644 --- a/java/api/src/main/java/io/ray/api/runtime/RayRuntime.java +++ b/java/api/src/main/java/io/ray/api/runtime/RayRuntime.java @@ -24,7 +24,6 @@ import io.ray.api.runtimeenv.RuntimeEnv; import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.concurrent.Callable; /** Base interface of a Ray runtime. */ public interface RayRuntime { @@ -207,26 +206,6 @@ public interface RayRuntime { RuntimeContext getRuntimeContext(); - Object getAsyncContext(); - - void setAsyncContext(Object asyncContext); - - /** - * Wrap a {@link Runnable} with necessary context capture. - * - * @param runnable The runnable to wrap. - * @return The wrapped runnable. - */ - Runnable wrapRunnable(Runnable runnable); - - /** - * Wrap a {@link Callable} with necessary context capture. - * - * @param callable The callable to wrap. - * @return The wrapped callable. - */ - Callable wrapCallable(Callable callable); - /** Intentionally exit the current actor. */ void exitActor(); diff --git a/java/performance_test/src/main/java/io/ray/performancetest/test/ActorPerformanceTestBase.java b/java/performance_test/src/main/java/io/ray/performancetest/test/ActorPerformanceTestBase.java index e118567aa..bcff2ed7f 100644 --- a/java/performance_test/src/main/java/io/ray/performancetest/test/ActorPerformanceTestBase.java +++ b/java/performance_test/src/main/java/io/ray/performancetest/test/ActorPerformanceTestBase.java @@ -23,10 +23,7 @@ public class ActorPerformanceTestBase { boolean hasReturn, boolean ignoreReturn, int argSize, - boolean useDirectByteBuffer, - int numJavaWorkerPerProcess) { - System.setProperty( - "ray.job.num-java-workers-per-process", String.valueOf(numJavaWorkerPerProcess)); + boolean useDirectByteBuffer) { System.setProperty("ray.raylet.startup-token", "0"); Ray.init(); try { diff --git a/java/performance_test/src/main/java/io/ray/performancetest/test/ActorPerformanceTestCase1.java b/java/performance_test/src/main/java/io/ray/performancetest/test/ActorPerformanceTestCase1.java index c095d39ba..036c13409 100644 --- a/java/performance_test/src/main/java/io/ray/performancetest/test/ActorPerformanceTestCase1.java +++ b/java/performance_test/src/main/java/io/ray/performancetest/test/ActorPerformanceTestCase1.java @@ -13,7 +13,6 @@ public class ActorPerformanceTestCase1 { final int argSize = 0; final boolean useDirectByteBuffer = false; final boolean ignoreReturn = false; - final int numJavaWorkerPerProcess = 1; ActorPerformanceTestBase.run( args, layers, @@ -21,7 +20,6 @@ public class ActorPerformanceTestCase1 { hasReturn, ignoreReturn, argSize, - useDirectByteBuffer, - numJavaWorkerPerProcess); + useDirectByteBuffer); } } diff --git a/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java index 4a8a9f547..7659a6b35 100644 --- a/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java @@ -31,7 +31,6 @@ import io.ray.runtime.functionmanager.FunctionDescriptor; import io.ray.runtime.functionmanager.FunctionManager; import io.ray.runtime.functionmanager.PyFunctionDescriptor; import io.ray.runtime.functionmanager.RayFunction; -import io.ray.runtime.generated.Common; import io.ray.runtime.generated.Common.Language; import io.ray.runtime.object.ObjectRefImpl; import io.ray.runtime.object.ObjectStore; @@ -46,7 +45,6 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.concurrent.Callable; import java.util.stream.Collectors; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -67,13 +65,8 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal { private static ParallelActorContextImpl parallelActorContextImpl = new ParallelActorContextImpl(); - /** Whether the required thread context is set on the current thread. */ - final ThreadLocal isContextSet = ThreadLocal.withInitial(() -> false); - public AbstractRayRuntime(RayConfig rayConfig) { this.rayConfig = rayConfig; - setIsContextSet(rayConfig.workerMode == Common.WorkerType.DRIVER); - functionManager = new FunctionManager(rayConfig.codeSearchPath); runtimeContext = new RuntimeContextImpl(this); } @@ -158,7 +151,7 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal { @Override public ObjectRef call(RayFunc func, Object[] args, CallOptions options) { - RayFunction rayFunction = functionManager.getFunction(workerContext.getCurrentJobId(), func); + RayFunction rayFunction = functionManager.getFunction(func); FunctionDescriptor functionDescriptor = rayFunction.functionDescriptor; Optional> returnType = rayFunction.getReturnType(); return callNormalFunction(functionDescriptor, args, returnType, options); @@ -176,7 +169,7 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal { @Override public ObjectRef callActor( ActorHandle actor, RayFunc func, Object[] args, CallOptions options) { - RayFunction rayFunction = functionManager.getFunction(workerContext.getCurrentJobId(), func); + RayFunction rayFunction = functionManager.getFunction(func); FunctionDescriptor functionDescriptor = rayFunction.functionDescriptor; Optional> returnType = rayFunction.getReturnType(); return callActorFunction(actor, functionDescriptor, args, returnType, options); @@ -201,8 +194,7 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal { public ActorHandle createActor( RayFunc actorFactoryFunc, Object[] args, ActorCreationOptions options) { FunctionDescriptor functionDescriptor = - functionManager.getFunction(workerContext.getCurrentJobId(), actorFactoryFunc) - .functionDescriptor; + functionManager.getFunction(actorFactoryFunc).functionDescriptor; return (ActorHandle) createActorImpl(functionDescriptor, args, options); } @@ -256,31 +248,6 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal { return (T) taskSubmitter.getActor(actorId); } - @Override - public void setAsyncContext(Object asyncContext) { - isContextSet.set(true); - } - - @Override - public final Runnable wrapRunnable(Runnable runnable) { - Object asyncContext = getAsyncContext(); - return () -> { - try (RayAsyncContextUpdater updater = new RayAsyncContextUpdater(asyncContext, this)) { - runnable.run(); - } - }; - } - - @Override - public final Callable wrapCallable(Callable callable) { - Object asyncContext = getAsyncContext(); - return () -> { - try (RayAsyncContextUpdater updater = new RayAsyncContextUpdater(asyncContext, this)) { - return callable.call(); - } - }; - } - @Override public ConcurrencyGroup createConcurrencyGroup( String name, int maxConcurrency, List funcs) { @@ -387,34 +354,6 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal { return actor; } - /// An auto closable class that is used for updating the async context when invoking Ray APIs. - private static final class RayAsyncContextUpdater implements AutoCloseable { - - private AbstractRayRuntime runtime; - - private boolean oldIsContextSet; - - private Object oldAsyncContext = null; - - public RayAsyncContextUpdater(Object asyncContext, AbstractRayRuntime runtime) { - this.runtime = runtime; - oldIsContextSet = runtime.isContextSet.get(); - if (oldIsContextSet) { - oldAsyncContext = runtime.getAsyncContext(); - } - runtime.setAsyncContext(asyncContext); - } - - @Override - public void close() { - if (oldIsContextSet) { - runtime.setAsyncContext(oldAsyncContext); - } else { - runtime.setIsContextSet(false); - } - } - } - abstract List getCurrentReturnIds(int numReturns, ActorId actorId); @Override @@ -446,11 +385,6 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal { return runtimeContext; } - @Override - public void setIsContextSet(boolean isContextSet) { - this.isContextSet.set(isContextSet); - } - /// A helper to validate if the prepared return ids is as expected. void validatePreparedReturnIds(List preparedReturnIds, List realReturnIds) { if (rayConfig.runMode == RunMode.CLUSTER) { diff --git a/java/runtime/src/main/java/io/ray/runtime/ConcurrencyGroupImpl.java b/java/runtime/src/main/java/io/ray/runtime/ConcurrencyGroupImpl.java index b5b2e4ebd..53ac57da5 100644 --- a/java/runtime/src/main/java/io/ray/runtime/ConcurrencyGroupImpl.java +++ b/java/runtime/src/main/java/io/ray/runtime/ConcurrencyGroupImpl.java @@ -24,9 +24,7 @@ public class ConcurrencyGroupImpl implements ConcurrencyGroup { funcs.forEach( func -> { RayFunction rayFunc = - ((RayRuntimeInternal) Ray.internal()) - .getFunctionManager() - .getFunction(Ray.getRuntimeContext().getCurrentJobId(), func); + ((RayRuntimeInternal) Ray.internal()).getFunctionManager().getFunction(func); functionDescriptors.add(rayFunc.getFunctionDescriptor()); }); } diff --git a/java/runtime/src/main/java/io/ray/runtime/DefaultRayRuntimeFactory.java b/java/runtime/src/main/java/io/ray/runtime/DefaultRayRuntimeFactory.java index 2356a70c7..a9461b820 100644 --- a/java/runtime/src/main/java/io/ray/runtime/DefaultRayRuntimeFactory.java +++ b/java/runtime/src/main/java/io/ray/runtime/DefaultRayRuntimeFactory.java @@ -32,10 +32,7 @@ public class DefaultRayRuntimeFactory implements RayRuntimeFactory { rayConfig.runMode == RunMode.SINGLE_PROCESS ? new RayDevRuntime(rayConfig) : new RayNativeRuntime(rayConfig); - RayRuntimeInternal runtime = - rayConfig.numWorkersPerProcess > 1 - ? RayRuntimeProxy.newInstance(innerRuntime) - : innerRuntime; + RayRuntimeInternal runtime = innerRuntime; runtime.start(); return runtime; } catch (Exception e) { diff --git a/java/runtime/src/main/java/io/ray/runtime/RayDevRuntime.java b/java/runtime/src/main/java/io/ray/runtime/RayDevRuntime.java index 080a31fcc..cba3ff298 100644 --- a/java/runtime/src/main/java/io/ray/runtime/RayDevRuntime.java +++ b/java/runtime/src/main/java/io/ray/runtime/RayDevRuntime.java @@ -1,6 +1,5 @@ package io.ray.runtime; -import com.google.common.base.Preconditions; import io.ray.api.BaseActorHandle; import io.ray.api.id.ActorId; import io.ray.api.id.JobId; @@ -10,6 +9,7 @@ import io.ray.api.placementgroup.PlacementGroup; import io.ray.api.runtimecontext.ResourceValue; import io.ray.runtime.config.RayConfig; import io.ray.runtime.context.LocalModeWorkerContext; +import io.ray.runtime.functionmanager.FunctionManager; import io.ray.runtime.gcs.GcsClient; import io.ray.runtime.generated.Common.TaskSpec; import io.ray.runtime.object.LocalModeObjectStore; @@ -24,13 +24,9 @@ import java.util.List; import java.util.Map; import java.util.Optional; import java.util.concurrent.atomic.AtomicInteger; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; public class RayDevRuntime extends AbstractRayRuntime { - private static final Logger LOGGER = LoggerFactory.getLogger(RayDevRuntime.class); - private AtomicInteger jobCounter = new AtomicInteger(0); public RayDevRuntime(RayConfig rayConfig) { @@ -49,6 +45,7 @@ public class RayDevRuntime extends AbstractRayRuntime { taskExecutor = new LocalModeTaskExecutor(this); workerContext = new LocalModeWorkerContext(rayConfig.getJobId()); objectStore = new LocalModeObjectStore(workerContext); + functionManager = new FunctionManager(rayConfig.codeSearchPath); taskSubmitter = new LocalModeTaskSubmitter(this, taskExecutor, (LocalModeObjectStore) objectStore); ((LocalModeObjectStore) objectStore) @@ -90,19 +87,6 @@ public class RayDevRuntime extends AbstractRayRuntime { throw new UnsupportedOperationException("Ray doesn't have gcs client in local mode."); } - @Override - public Object getAsyncContext() { - return new AsyncContext(((LocalModeWorkerContext) workerContext).getCurrentTask()); - } - - @Override - public void setAsyncContext(Object asyncContext) { - Preconditions.checkNotNull(asyncContext); - TaskSpec task = ((AsyncContext) asyncContext).task; - ((LocalModeWorkerContext) workerContext).setCurrentTask(task); - super.setAsyncContext(asyncContext); - } - @Override public Map> getAvailableResourceIds() { throw new UnsupportedOperationException("Ray doesn't support get resources ids in local mode."); diff --git a/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java b/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java index b53d655f6..3328413c5 100644 --- a/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java +++ b/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java @@ -14,6 +14,7 @@ import io.ray.api.runtimecontext.ResourceValue; import io.ray.runtime.config.RayConfig; import io.ray.runtime.context.NativeWorkerContext; import io.ray.runtime.exception.RayIntentionalSystemExitException; +import io.ray.runtime.functionmanager.FunctionManager; import io.ray.runtime.gcs.GcsClient; import io.ray.runtime.gcs.GcsClientOptions; import io.ray.runtime.generated.Common.WorkerType; @@ -102,14 +103,13 @@ public final class RayNativeRuntime extends AbstractRayRuntime { if (rayConfig.workerMode == WorkerType.DRIVER && rayConfig.getJobId() == JobId.NIL) { rayConfig.setJobId(getGcsClient().nextJobId()); } - int numWorkersPerProcess = - rayConfig.workerMode == WorkerType.DRIVER ? 1 : rayConfig.numWorkersPerProcess; + // Make sure the job id has been set already. + functionManager = new FunctionManager(rayConfig.codeSearchPath); byte[] serializedJobConfig = null; if (rayConfig.workerMode == WorkerType.DRIVER) { JobConfig.Builder jobConfigBuilder = JobConfig.newBuilder() - .setNumJavaWorkersPerProcess(rayConfig.numWorkersPerProcess) .addAllJvmOptions(rayConfig.jvmOptionsForJavaWorker) .addAllCodeSearchPath(rayConfig.codeSearchPath) .setRayNamespace(rayConfig.namespace); @@ -150,7 +150,6 @@ public final class RayNativeRuntime extends AbstractRayRuntime { rayConfig.rayletSocketName, (rayConfig.workerMode == WorkerType.DRIVER ? rayConfig.getJobId() : JobId.NIL).getBytes(), new GcsClientOptions(rayConfig), - numWorkersPerProcess, rayConfig.logDir, serializedJobConfig, rayConfig.getStartupToken(), @@ -223,19 +222,6 @@ public final class RayNativeRuntime extends AbstractRayRuntime { nativeKillActor(actor.getId().getBytes(), noRestart); } - @Override - public Object getAsyncContext() { - return new AsyncContext( - workerContext.getCurrentWorkerId(), workerContext.getCurrentClassLoader()); - } - - @Override - public void setAsyncContext(Object asyncContext) { - nativeSetCoreWorker(((AsyncContext) asyncContext).workerId.getBytes()); - workerContext.setCurrentClassLoader(((AsyncContext) asyncContext).currentClassLoader); - super.setAsyncContext(asyncContext); - } - @Override List getCurrentReturnIds(int numReturns, ActorId actorId) { List ret = nativeGetCurrentReturnIds(numReturns, actorId.getBytes()); @@ -289,7 +275,6 @@ public final class RayNativeRuntime extends AbstractRayRuntime { String rayletSocket, byte[] jobId, GcsClientOptions gcsClientOptions, - int numWorkersPerProcess, String logDir, byte[] serializedJobConfig, int startupToken, diff --git a/java/runtime/src/main/java/io/ray/runtime/RayRuntimeInternal.java b/java/runtime/src/main/java/io/ray/runtime/RayRuntimeInternal.java index 8fc14caa3..fd1a23b90 100644 --- a/java/runtime/src/main/java/io/ray/runtime/RayRuntimeInternal.java +++ b/java/runtime/src/main/java/io/ray/runtime/RayRuntimeInternal.java @@ -26,7 +26,5 @@ public interface RayRuntimeInternal extends RayRuntime { GcsClient getGcsClient(); - void setIsContextSet(boolean isContextSet); - void run(); } diff --git a/java/runtime/src/main/java/io/ray/runtime/RayRuntimeProxy.java b/java/runtime/src/main/java/io/ray/runtime/RayRuntimeProxy.java deleted file mode 100644 index 4a427295c..000000000 --- a/java/runtime/src/main/java/io/ray/runtime/RayRuntimeProxy.java +++ /dev/null @@ -1,80 +0,0 @@ -package io.ray.runtime; - -import io.ray.api.runtime.RayRuntime; -import io.ray.runtime.config.RunMode; -import io.ray.runtime.exception.RayException; -import java.lang.reflect.InvocationHandler; -import java.lang.reflect.InvocationTargetException; -import java.lang.reflect.Method; - -/** - * Protect a ray runtime with context checks for all methods of {@link RayRuntime} (except {@link - * RayRuntime#shutdown(boolean)}). - */ -public class RayRuntimeProxy implements InvocationHandler { - - /** The original runtime. */ - private AbstractRayRuntime obj; - - private RayRuntimeProxy(AbstractRayRuntime obj) { - this.obj = obj; - } - - public AbstractRayRuntime getRuntimeObject() { - return obj; - } - - /** Generate a new instance of {@link RayRuntimeInternal} with additional context check. */ - static RayRuntimeInternal newInstance(AbstractRayRuntime obj) { - return (RayRuntimeInternal) - java.lang.reflect.Proxy.newProxyInstance( - obj.getClass().getClassLoader(), - new Class[] {RayRuntimeInternal.class}, - new RayRuntimeProxy(obj)); - } - - @Override - public Object invoke(Object proxy, Method method, Object[] args) throws Throwable { - if (isInterfaceMethod(method) - && !method.getName().equals("shutdown") - && !method.getName().equals("setAsyncContext")) { - checkIsContextSet(); - } - try { - return method.invoke(obj, args); - } catch (InvocationTargetException e) { - if (e.getCause() != null) { - throw e.getCause(); - } else { - throw e; - } - } - } - - /** Whether the method is defined in the {@link RayRuntime} interface. */ - private boolean isInterfaceMethod(Method method) { - try { - RayRuntime.class.getMethod(method.getName(), method.getParameterTypes()); - return true; - } catch (NoSuchMethodException e) { - return false; - } - } - - /** - * Check if thread context is set. - * - *

This method should be invoked at the beginning of most public methods of {@link RayRuntime}, - * otherwise the native code might crash due to thread local core worker was not set. We check it - * for {@link AbstractRayRuntime} instead of {@link RayNativeRuntime} because we want to catch the - * error even if the application runs in {@link RunMode#SINGLE_PROCESS} mode. - */ - private void checkIsContextSet() { - if (!obj.isContextSet.get()) { - throw new RayException( - "`Ray.wrap***` is not called on the current thread." - + " If you want to use Ray API in your own threads," - + " please wrap your executable with `Ray.wrap***`."); - } - } -} diff --git a/java/runtime/src/main/java/io/ray/runtime/config/RayConfig.java b/java/runtime/src/main/java/io/ray/runtime/config/RayConfig.java index dd0e2ef3a..89dada1eb 100644 --- a/java/runtime/src/main/java/io/ray/runtime/config/RayConfig.java +++ b/java/runtime/src/main/java/io/ray/runtime/config/RayConfig.java @@ -75,8 +75,6 @@ public class RayConfig { public final List headArgs; - public final int numWorkersPerProcess; - public final String namespace; public final List jvmOptionsForJavaWorker; @@ -190,8 +188,6 @@ public class RayConfig { } codeSearchPath = Arrays.asList(codeSearchPathString.split(":")); - numWorkersPerProcess = config.getInt("ray.job.num-java-workers-per-process"); - startupToken = config.getInt("ray.raylet.startup-token"); /// Driver needn't this config item. diff --git a/java/runtime/src/main/java/io/ray/runtime/context/LocalModeWorkerContext.java b/java/runtime/src/main/java/io/ray/runtime/context/LocalModeWorkerContext.java index 76ba443a2..312a86c5a 100644 --- a/java/runtime/src/main/java/io/ray/runtime/context/LocalModeWorkerContext.java +++ b/java/runtime/src/main/java/io/ray/runtime/context/LocalModeWorkerContext.java @@ -48,31 +48,21 @@ public class LocalModeWorkerContext implements WorkerContext { @Override public ActorId getCurrentActorId() { TaskSpec taskSpec = currentTask.get(); - if (taskSpec == null) { - return ActorId.NIL; - } + checkTaskSpecNotNull(taskSpec); return LocalModeTaskSubmitter.getActorId(taskSpec); } - @Override - public ClassLoader getCurrentClassLoader() { - return null; - } - - @Override - public void setCurrentClassLoader(ClassLoader currentClassLoader) {} - @Override public TaskType getCurrentTaskType() { TaskSpec taskSpec = currentTask.get(); - Preconditions.checkNotNull(taskSpec, "Current task is not set."); + checkTaskSpecNotNull(taskSpec); return taskSpec.getType(); } @Override public TaskId getCurrentTaskId() { TaskSpec taskSpec = currentTask.get(); - Preconditions.checkState(taskSpec != null); + checkTaskSpecNotNull(taskSpec); return TaskId.fromBytes(taskSpec.getTaskId().toByteArray()); } @@ -85,7 +75,9 @@ public class LocalModeWorkerContext implements WorkerContext { currentTask.set(taskSpec); } - public TaskSpec getCurrentTask() { - return currentTask.get(); + private static void checkTaskSpecNotNull(TaskSpec taskSpec) { + Preconditions.checkNotNull( + taskSpec, + "Current task is not set. Maybe you invoked this API in a user-created thread not managed by Ray. Invoking this API in a user-created thread is not supported yet in local mode. You can switch to cluster mode."); } } diff --git a/java/runtime/src/main/java/io/ray/runtime/context/NativeWorkerContext.java b/java/runtime/src/main/java/io/ray/runtime/context/NativeWorkerContext.java index d20d48d0a..467dd0169 100644 --- a/java/runtime/src/main/java/io/ray/runtime/context/NativeWorkerContext.java +++ b/java/runtime/src/main/java/io/ray/runtime/context/NativeWorkerContext.java @@ -12,7 +12,7 @@ import java.nio.ByteBuffer; /** Worker context for cluster mode. This is a wrapper class for worker context of core worker. */ public class NativeWorkerContext implements WorkerContext { - private final ThreadLocal currentClassLoader = new ThreadLocal<>(); + private ClassLoader currentClassLoader = null; @Override public UniqueId getCurrentWorkerId() { @@ -29,18 +29,6 @@ public class NativeWorkerContext implements WorkerContext { return ActorId.fromByteBuffer(nativeGetCurrentActorId()); } - @Override - public ClassLoader getCurrentClassLoader() { - return currentClassLoader.get(); - } - - @Override - public void setCurrentClassLoader(ClassLoader currentClassLoader) { - if (this.currentClassLoader.get() != currentClassLoader) { - this.currentClassLoader.set(currentClassLoader); - } - } - @Override public TaskType getCurrentTaskType() { return TaskType.forNumber(nativeGetCurrentTaskType()); diff --git a/java/runtime/src/main/java/io/ray/runtime/context/WorkerContext.java b/java/runtime/src/main/java/io/ray/runtime/context/WorkerContext.java index ff750f651..3e3ac01e9 100644 --- a/java/runtime/src/main/java/io/ray/runtime/context/WorkerContext.java +++ b/java/runtime/src/main/java/io/ray/runtime/context/WorkerContext.java @@ -19,15 +19,6 @@ public interface WorkerContext { /** ID of the current actor. */ ActorId getCurrentActorId(); - /** - * The class loader that is associated with the current job. It's used for locating classes when - * dealing with serialization and deserialization in {@link Serializer}. - */ - ClassLoader getCurrentClassLoader(); - - /** Set the current class loader. */ - void setCurrentClassLoader(ClassLoader currentClassLoader); - /** Type of the current task. */ TaskType getCurrentTaskType(); diff --git a/java/runtime/src/main/java/io/ray/runtime/functionmanager/FunctionManager.java b/java/runtime/src/main/java/io/ray/runtime/functionmanager/FunctionManager.java index 75e9c2821..3ad60b434 100644 --- a/java/runtime/src/main/java/io/ray/runtime/functionmanager/FunctionManager.java +++ b/java/runtime/src/main/java/io/ray/runtime/functionmanager/FunctionManager.java @@ -2,7 +2,6 @@ package io.ray.runtime.functionmanager; import com.google.common.collect.Lists; import io.ray.api.function.RayFunc; -import io.ray.api.id.JobId; import io.ray.runtime.util.LambdaUtils; import java.io.File; import java.lang.invoke.SerializedLambda; @@ -34,7 +33,7 @@ import org.objectweb.asm.Type; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** Manages functions by job id. */ +/** Manages functions in the current worker. */ public class FunctionManager { private static final Logger LOGGER = LoggerFactory.getLogger(FunctionManager.class); @@ -50,8 +49,8 @@ public class FunctionManager { private static final ThreadLocal, JavaFunctionDescriptor>> RAY_FUNC_CACHE = ThreadLocal.withInitial(WeakHashMap::new); - /** Mapping from the job id to the functions that belong to this job. */ - private ConcurrentMap jobFunctionTables = new ConcurrentHashMap<>(); + /** The table that manages functions. */ + private final JobFunctionTable jobFunctionTable; /** The resource path which we can load the job's jar resources. */ private final List codeSearchPath; @@ -63,16 +62,20 @@ public class FunctionManager { */ public FunctionManager(List codeSearchPath) { this.codeSearchPath = codeSearchPath; + jobFunctionTable = createJobFunctionTable(); + } + + public ClassLoader getClassLoader() { + return jobFunctionTable.classLoader; } /** * Get the RayFunction from a RayFunc instance (a lambda). * - * @param jobId current job id. * @param func The lambda. * @return A RayFunction object. */ - public RayFunction getFunction(JobId jobId, RayFunc func) { + public RayFunction getFunction(RayFunc func) { JavaFunctionDescriptor functionDescriptor = RAY_FUNC_CACHE.get().get(func.getClass()); if (functionDescriptor == null) { // It's OK to not lock here, because it's OK to have multiple JavaFunctionDescriptor instances @@ -84,31 +87,21 @@ public class FunctionManager { functionDescriptor = new JavaFunctionDescriptor(className, methodName, signature); RAY_FUNC_CACHE.get().put(func.getClass(), functionDescriptor); } - return getFunction(jobId, functionDescriptor); + return getFunction(functionDescriptor); } /** * Get the RayFunction from a function descriptor. * - * @param jobId Current job id. * @param functionDescriptor The function descriptor. * @return A RayFunction object. */ - public RayFunction getFunction(JobId jobId, JavaFunctionDescriptor functionDescriptor) { - JobFunctionTable jobFunctionTable = jobFunctionTables.get(jobId); - if (jobFunctionTable == null) { - synchronized (this) { - jobFunctionTable = jobFunctionTables.get(jobId); - if (jobFunctionTable == null) { - jobFunctionTable = createJobFunctionTable(jobId); - jobFunctionTables.put(jobId, jobFunctionTable); - } - } - } + public RayFunction getFunction(JavaFunctionDescriptor functionDescriptor) { return jobFunctionTable.getFunction(functionDescriptor); } - private JobFunctionTable createJobFunctionTable(JobId jobId) { + /** A helper that creates function table. */ + private JobFunctionTable createJobFunctionTable() { ClassLoader classLoader; if (codeSearchPath == null || codeSearchPath.isEmpty()) { classLoader = getClass().getClassLoader(); @@ -145,7 +138,7 @@ public class FunctionManager { }) .toArray(URL[]::new); classLoader = new URLClassLoader(urls); - LOGGER.debug("Resource loaded for job {} from path {}.", jobId, urls); + LOGGER.debug("Resource loaded from path {}.", urls); } return new JobFunctionTable(classLoader); diff --git a/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java b/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java index 0e5687657..c6376c639 100644 --- a/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java +++ b/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java @@ -496,7 +496,6 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { ? objectStore.getRaw(Collections.singletonList(arg.id), -1).get(0) : arg.value) .collect(Collectors.toList()); - runtime.setIsContextSet(true); ((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(taskSpec); ((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentWorkerId(workerId); @@ -513,9 +512,7 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { // Set this flag to true is necessary because at the end of `taskExecutor.execute()`, // this flag will be set to false. And `runtime.getWorkerContext()` requires it to be // true. - runtime.setIsContextSet(true); ((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(null); - runtime.setIsContextSet(false); List returnIds = getReturnIds(taskSpec); for (int i = 0; i < returnIds.size(); i++) { NativeRayObject putObject; diff --git a/java/runtime/src/main/java/io/ray/runtime/task/TaskExecutor.java b/java/runtime/src/main/java/io/ray/runtime/task/TaskExecutor.java index 9cce4cf02..1c7d17414 100644 --- a/java/runtime/src/main/java/io/ray/runtime/task/TaskExecutor.java +++ b/java/runtime/src/main/java/io/ray/runtime/task/TaskExecutor.java @@ -32,6 +32,7 @@ public abstract class TaskExecutor { protected final RayRuntimeInternal runtime; + // TODO(qwang): Use actorContext instead later. private final ConcurrentHashMap actorContextMap = new ConcurrentHashMap<>(); private final ThreadLocal localRayFunction = new ThreadLocal<>(); @@ -66,7 +67,7 @@ public abstract class TaskExecutor { private RayFunction getRayFunction(List rayFunctionInfo) { JobId jobId = runtime.getWorkerContext().getCurrentJobId(); JavaFunctionDescriptor functionDescriptor = parseFunctionDescriptor(rayFunctionInfo); - return runtime.getFunctionManager().getFunction(jobId, functionDescriptor); + return runtime.getFunctionManager().getFunction(functionDescriptor); } /** The return value indicates which parameters are ByteBuffer. */ @@ -93,7 +94,6 @@ public abstract class TaskExecutor { } protected List execute(List rayFunctionInfo, List argsBytes) { - runtime.setIsContextSet(true); TaskType taskType = runtime.getWorkerContext().getCurrentTaskType(); TaskId taskId = runtime.getWorkerContext().getCurrentTaskId(); LOGGER.debug("Executing task {} {}", taskId, rayFunctionInfo); @@ -108,7 +108,6 @@ public abstract class TaskExecutor { } List returnObjects = new ArrayList<>(); - ClassLoader oldLoader = Thread.currentThread().getContextClassLoader(); // Find the executable object. RayFunction rayFunction = localRayFunction.get(); @@ -121,7 +120,6 @@ public abstract class TaskExecutor { rayFunction = getRayFunction(rayFunctionInfo); } Thread.currentThread().setContextClassLoader(rayFunction.classLoader); - runtime.getWorkerContext().setCurrentClassLoader(rayFunction.classLoader); // Get local actor object and arguments. Object actor = null; @@ -215,10 +213,6 @@ public abstract class TaskExecutor { } else { throw new RayActorException(e); } - } finally { - Thread.currentThread().setContextClassLoader(oldLoader); - runtime.getWorkerContext().setCurrentClassLoader(null); - runtime.setIsContextSet(false); } return returnObjects; } diff --git a/java/runtime/src/main/java/io/ray/runtime/util/MethodUtils.java b/java/runtime/src/main/java/io/ray/runtime/util/MethodUtils.java index 8f59102f6..b6523562f 100644 --- a/java/runtime/src/main/java/io/ray/runtime/util/MethodUtils.java +++ b/java/runtime/src/main/java/io/ray/runtime/util/MethodUtils.java @@ -47,7 +47,7 @@ public final class MethodUtils { /// This code path indicates that here might be in another thread of a worker. /// So try to load the class from URLClassLoader of this worker. ClassLoader cl = - ((RayRuntimeInternal) Ray.internal()).getWorkerContext().getCurrentClassLoader(); + ((RayRuntimeInternal) Ray.internal()).getFunctionManager().getClassLoader(); actorClz = Class.forName(className, true, cl); } } catch (Exception e) { diff --git a/java/runtime/src/main/java/io/ray/runtime/utils/parallelactor/ParallelActorContextImpl.java b/java/runtime/src/main/java/io/ray/runtime/utils/parallelactor/ParallelActorContextImpl.java index cd0e3ebb4..9da1f4cd0 100644 --- a/java/runtime/src/main/java/io/ray/runtime/utils/parallelactor/ParallelActorContextImpl.java +++ b/java/runtime/src/main/java/io/ray/runtime/utils/parallelactor/ParallelActorContextImpl.java @@ -28,9 +28,7 @@ public class ParallelActorContextImpl implements ParallelActorContext { FunctionManager functionManager = ((RayRuntimeInternal) Ray.internal()).getFunctionManager(); JavaFunctionDescriptor functionDescriptor = - functionManager - .getFunction(Ray.getRuntimeContext().getCurrentJobId(), ctorFunc) - .getFunctionDescriptor(); + functionManager.getFunction(ctorFunc).getFunctionDescriptor(); ActorHandle parallelExecutorHandle = Ray.actor(ParallelActorExecutorImpl::new, parallelism, functionDescriptor) .setConcurrencyGroups(concurrencyGroups) @@ -46,9 +44,7 @@ public class ParallelActorContextImpl implements ParallelActorContext { ((ParallelActorHandleImpl) parallelActorHandle).getExecutor(); FunctionManager functionManager = ((RayRuntimeInternal) Ray.internal()).getFunctionManager(); JavaFunctionDescriptor functionDescriptor = - functionManager - .getFunction(Ray.getRuntimeContext().getCurrentJobId(), func) - .getFunctionDescriptor(); + functionManager.getFunction(func).getFunctionDescriptor(); ObjectRef ret = parallelExecutor .task(ParallelActorExecutorImpl::execute, instanceId, functionDescriptor, args) diff --git a/java/runtime/src/main/java/io/ray/runtime/utils/parallelactor/ParallelActorExecutorImpl.java b/java/runtime/src/main/java/io/ray/runtime/utils/parallelactor/ParallelActorExecutorImpl.java index f3d481381..3836303e1 100644 --- a/java/runtime/src/main/java/io/ray/runtime/utils/parallelactor/ParallelActorExecutorImpl.java +++ b/java/runtime/src/main/java/io/ray/runtime/utils/parallelactor/ParallelActorExecutorImpl.java @@ -23,9 +23,7 @@ public class ParallelActorExecutorImpl { throws InvocationTargetException, IllegalAccessException { functionManager = ((RayRuntimeInternal) Ray.internal()).getFunctionManager(); - RayFunction init = - functionManager.getFunction( - Ray.getRuntimeContext().getCurrentJobId(), javaFunctionDescriptor); + RayFunction init = functionManager.getFunction(javaFunctionDescriptor); Thread.currentThread().setContextClassLoader(init.classLoader); for (int i = 0; i < parallelism; ++i) { Object instance = init.getMethod().invoke(null, null); @@ -35,8 +33,7 @@ public class ParallelActorExecutorImpl { public Object execute(int instanceId, JavaFunctionDescriptor functionDescriptor, Object[] args) throws IllegalAccessException, InvocationTargetException { - RayFunction func = - functionManager.getFunction(Ray.getRuntimeContext().getCurrentJobId(), functionDescriptor); + RayFunction func = functionManager.getFunction(functionDescriptor); Preconditions.checkState(instances.containsKey(instanceId)); return func.getMethod().invoke(instances.get(instanceId), args); } diff --git a/java/runtime/src/main/resources/ray.default.conf b/java/runtime/src/main/resources/ray.default.conf index 0ff149281..cf73f321d 100644 --- a/java/runtime/src/main/resources/ray.default.conf +++ b/java/runtime/src/main/resources/ray.default.conf @@ -27,8 +27,6 @@ ray { // search path for user code. This will be used as `CLASSPATH` in Java, // and `PYTHONPATH` in Python. code-search-path: "" - /// The number of java worker per worker process. - num-java-workers-per-process: 1 /// The jvm options for java workers of the job. jvm-options: [] diff --git a/java/runtime/src/test/java/io/ray/runtime/functionmanager/FunctionManagerTest.java b/java/runtime/src/test/java/io/ray/runtime/functionmanager/FunctionManagerTest.java index e693a4da5..f1da038e7 100644 --- a/java/runtime/src/test/java/io/ray/runtime/functionmanager/FunctionManagerTest.java +++ b/java/runtime/src/test/java/io/ray/runtime/functionmanager/FunctionManagerTest.java @@ -61,6 +61,8 @@ public class FunctionManagerTest { } } + private static final JobId JOB_ID = JobId.fromInt(1); + private static RayFunc0 fooFunc; private static RayFunc1 childClassBarFunc; private static RayFunc0 childClassConstructor; @@ -95,17 +97,17 @@ public class FunctionManagerTest { public void testGetFunctionFromRayFunc() { final FunctionManager functionManager = new FunctionManager(null); // Test normal function. - RayFunction func = functionManager.getFunction(JobId.NIL, fooFunc); + RayFunction func = functionManager.getFunction(fooFunc); Assert.assertFalse(func.isConstructor()); Assert.assertEquals(func.getFunctionDescriptor(), fooDescriptor); // Test actor method - func = functionManager.getFunction(JobId.NIL, childClassBarFunc); + func = functionManager.getFunction(childClassBarFunc); Assert.assertFalse(func.isConstructor()); Assert.assertEquals(func.getFunctionDescriptor(), childClassBarDescriptor); // Test actor constructor - func = functionManager.getFunction(JobId.NIL, childClassConstructor); + func = functionManager.getFunction(childClassConstructor); Assert.assertTrue(func.isConstructor()); Assert.assertEquals(func.getFunctionDescriptor(), childClassConstructorDescriptor); } @@ -114,17 +116,17 @@ public class FunctionManagerTest { public void testGetFunctionFromFunctionDescriptor() { final FunctionManager functionManager = new FunctionManager(null); // Test normal function. - RayFunction func = functionManager.getFunction(JobId.NIL, fooDescriptor); + RayFunction func = functionManager.getFunction(fooDescriptor); Assert.assertFalse(func.isConstructor()); Assert.assertEquals(func.getFunctionDescriptor(), fooDescriptor); // Test actor method - func = functionManager.getFunction(JobId.NIL, childClassBarDescriptor); + func = functionManager.getFunction(childClassBarDescriptor); Assert.assertFalse(func.isConstructor()); Assert.assertEquals(func.getFunctionDescriptor(), childClassBarDescriptor); // Test actor constructor - func = functionManager.getFunction(JobId.NIL, childClassConstructorDescriptor); + func = functionManager.getFunction(childClassConstructorDescriptor); Assert.assertTrue(func.isConstructor()); Assert.assertEquals(func.getFunctionDescriptor(), childClassConstructorDescriptor); @@ -133,7 +135,6 @@ public class FunctionManagerTest { RuntimeException.class, () -> { functionManager.getFunction( - JobId.NIL, new JavaFunctionDescriptor( FunctionManagerTest.class.getName(), "overloadFunction", "")); }); @@ -146,11 +147,10 @@ public class FunctionManagerTest { fooDescriptor = new JavaFunctionDescriptor(ParentClass.class.getName(), "foo", "()Ljava/lang/Object;"); Assert.assertEquals( - functionManager.getFunction(JobId.NIL, fooDescriptor).executable.getDeclaringClass(), + functionManager.getFunction(fooDescriptor).executable.getDeclaringClass(), ParentClass.class); RayFunction fooFunc = functionManager.getFunction( - JobId.NIL, new JavaFunctionDescriptor(ChildClass.class.getName(), "foo", "()Ljava/lang/Object;")); Assert.assertEquals(fooFunc.executable.getDeclaringClass(), ParentClass.class); @@ -159,21 +159,16 @@ public class FunctionManagerTest { childClassBarDescriptor = new JavaFunctionDescriptor(ParentClass.class.getName(), "bar", "()Ljava/lang/Object;"); Assert.assertEquals( - functionManager - .getFunction(JobId.NIL, childClassBarDescriptor) - .executable - .getDeclaringClass(), + functionManager.getFunction(childClassBarDescriptor).executable.getDeclaringClass(), ParentClass.class); RayFunction barFunc = functionManager.getFunction( - JobId.NIL, new JavaFunctionDescriptor(ChildClass.class.getName(), "bar", "()Ljava/lang/Object;")); Assert.assertEquals(barFunc.executable.getDeclaringClass(), ChildClass.class); // Check interface default methods. RayFunction interfaceNameFunc = functionManager.getFunction( - JobId.NIL, new JavaFunctionDescriptor( ChildClass.class.getName(), "interfaceName", "()Ljava/lang/String;")); Assert.assertEquals( @@ -220,7 +215,6 @@ public class FunctionManagerTest { @Test public void testGetFunctionFromLocalResource() throws Exception { - JobId jobId = JobId.fromInt(1); final String codeSearchPath = FileUtils.getTempDirectoryPath() + "/ray_test_resources/"; File jobResourceDir = new File(codeSearchPath); FileUtils.deleteQuietly(jobResourceDir); @@ -250,7 +244,7 @@ public class FunctionManagerTest { new JavaFunctionDescriptor("DemoApp", "hello", "()Ljava/lang/String;"); final FunctionManager functionManager = new FunctionManager(Collections.singletonList(codeSearchPath)); - RayFunction func = functionManager.getFunction(jobId, descriptor); + RayFunction func = functionManager.getFunction(descriptor); Assert.assertEquals(func.getFunctionDescriptor(), descriptor); } } diff --git a/java/serve/src/main/java/io/ray/serve/HttpProxy.java b/java/serve/src/main/java/io/ray/serve/HttpProxy.java index 809337e75..0c5c6018a 100644 --- a/java/serve/src/main/java/io/ray/serve/HttpProxy.java +++ b/java/serve/src/main/java/io/ray/serve/HttpProxy.java @@ -1,7 +1,6 @@ package io.ray.serve; import com.google.common.collect.ImmutableMap; -import io.ray.api.Ray; import io.ray.runtime.metric.Count; import io.ray.runtime.metric.Metrics; import io.ray.runtime.metric.TagKey; @@ -47,8 +46,6 @@ public class HttpProxy implements ServeProxy { private ProxyRouter proxyRouter; - private Object asyncContext = Ray.getAsyncContext(); - @Override public void init(Map config, ProxyRouter proxyRouter) { this.port = @@ -101,8 +98,6 @@ public class HttpProxy implements ServeProxy { ClassicHttpRequest request, ClassicHttpResponse response, HttpContext context) throws HttpException, IOException { - Ray.setAsyncContext(asyncContext); - int code = HttpURLConnection.HTTP_OK; Object result = null; String route = request.getPath(); diff --git a/java/test.sh b/java/test.sh index ba5f16ae4..7a8bdc637 100755 --- a/java/test.sh +++ b/java/test.sh @@ -122,8 +122,7 @@ for file in "$docdemo_path"*.java; do file=${file#"$docdemo_path"} class=${file%".java"} echo "Running $class" - java -cp bazel-bin/java/all_tests_shaded.jar -Dray.job.num-java-workers-per-process=1\ - -Dray.raylet.startup-token=0 "io.ray.docdemo.$class" + java -cp bazel-bin/java/all_tests_shaded.jar -Dray.raylet.startup-token=0 "io.ray.docdemo.$class" done popd diff --git a/java/test/src/main/java/io/ray/test/ActorHandleReferenceCountTest.java b/java/test/src/main/java/io/ray/test/ActorHandleReferenceCountTest.java index 13f8af022..42b24fa5e 100644 --- a/java/test/src/main/java/io/ray/test/ActorHandleReferenceCountTest.java +++ b/java/test/src/main/java/io/ray/test/ActorHandleReferenceCountTest.java @@ -4,13 +4,11 @@ import io.ray.api.ActorHandle; import io.ray.api.ObjectRef; import io.ray.api.Ray; import io.ray.runtime.actor.NativeActorHandle; -import io.ray.runtime.exception.RayActorException; import io.ray.runtime.util.SystemUtil; import java.lang.ref.Reference; import java.lang.reflect.Field; import java.lang.reflect.Method; import java.util.Set; -import java.util.concurrent.TimeUnit; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.testng.Assert; @@ -65,7 +63,6 @@ public class ActorHandleReferenceCountTest { public void testActorHandleReferenceCount() { try { - System.setProperty("ray.job.num-java-workers-per-process", "1"); Ray.init(); ActorHandle signal = Ray.actor(SignalActor::new).remote(); ActorHandle myActor = Ray.actor(MyActor::new).remote(); @@ -83,27 +80,4 @@ public class ActorHandleReferenceCountTest { Ray.shutdown(); } } - - public void testRemoveActorHandleReferenceInMultipleThreadedActor() throws InterruptedException { - System.setProperty("ray.job.num-java-workers-per-process", "5"); - try { - Ray.init(); - ActorHandle myActor1 = Ray.actor(MyActor::new).remote(); - int pid1 = myActor1.task(MyActor::getPid).remote().get(); - ActorHandle myActor2 = Ray.actor(MyActor::new).remote(); - int pid2 = myActor2.task(MyActor::getPid).remote().get(); - Assert.assertEquals(pid1, pid2); - del(myActor1); - TimeUnit.SECONDS.sleep(5); - Assert.assertThrows( - RayActorException.class, - () -> { - myActor1.task(MyActor::hello).remote().get(); - }); - /// myActor2 shouldn't be killed. - Assert.assertEquals("hello", myActor2.task(MyActor::hello).remote().get()); - } finally { - Ray.shutdown(); - } - } } diff --git a/java/test/src/main/java/io/ray/test/ClassLoaderTest.java b/java/test/src/main/java/io/ray/test/ClassLoaderTest.java deleted file mode 100644 index 249af8eba..000000000 --- a/java/test/src/main/java/io/ray/test/ClassLoaderTest.java +++ /dev/null @@ -1,167 +0,0 @@ -package io.ray.test; - -import io.ray.api.ActorHandle; -import io.ray.api.BaseActorHandle; -import io.ray.api.ObjectRef; -import io.ray.api.options.ActorCreationOptions; -import io.ray.api.options.CallOptions; -import io.ray.runtime.AbstractRayRuntime; -import io.ray.runtime.functionmanager.FunctionDescriptor; -import io.ray.runtime.functionmanager.JavaFunctionDescriptor; -import java.io.File; -import java.lang.reflect.Method; -import java.nio.file.Files; -import java.nio.file.Paths; -import java.util.Optional; -import javax.tools.JavaCompiler; -import javax.tools.ToolProvider; -import org.apache.commons.io.FileUtils; -import org.testng.Assert; -import org.testng.annotations.BeforeClass; -import org.testng.annotations.Test; - -public class ClassLoaderTest extends BaseTest { - - private final String codeSearchPath = - FileUtils.getTempDirectoryPath() + "/ray_test/ClassLoaderTest"; - - @BeforeClass - public void setUp() { - // The potential issue of multiple `ClassLoader` instances for the same job on multi-threading - // scenario only occurs if the classes are loaded from the job code search path. - System.setProperty("ray.job.code-search-path", codeSearchPath); - } - - @Test(groups = {"cluster"}) - public void testClassLoaderInMultiThreading() throws Exception { - File jobResourceDir = new File(codeSearchPath); - FileUtils.deleteQuietly(jobResourceDir); - jobResourceDir.mkdirs(); - jobResourceDir.deleteOnExit(); - - // In this test case the class is expected to be loaded from the job code search path, - // so we need to put the compiled class file into the job code search path and load it - // later. - String testJavaFile = - "" - + "import java.lang.management.ManagementFactory;\n" - + "import java.lang.management.RuntimeMXBean;\n" - + "\n" - + "public class ClassLoaderTester {\n" - + "\n" - + " static volatile int value;\n" - + "\n" - + " public int getPid() {\n" - + " RuntimeMXBean runtime = ManagementFactory.getRuntimeMXBean();\n" - + " String name = runtime.getName();\n" - + " int index = name.indexOf(\"@\");\n" - + " if (index != -1) {\n" - + " return Integer.parseInt(name.substring(0, index));\n" - + " } else {\n" - + " throw new RuntimeException(\"parse pid error:\" + name);\n" - + " }\n" - + " }\n" - + "\n" - + " public int increase() throws InterruptedException {\n" - + " return increaseInternal();\n" - + " }\n" - + "\n" - + " public static synchronized int increaseInternal() throws InterruptedException {\n" - + " int oldValue = value;\n" - + " Thread.sleep(10 * 1000);\n" - + " value = oldValue + 1;\n" - + " return value;\n" - + " }\n" - + "\n" - + " public int getClassLoaderHashCode() {\n" - + " return this.getClass().getClassLoader().hashCode();\n" - + " }\n" - + "}"; - - // Write the demo java file to the job code search path. - String javaFilePath = codeSearchPath + "/ClassLoaderTester.java"; - Files.write(Paths.get(javaFilePath), testJavaFile.getBytes()); - - // Compile the java file. - JavaCompiler compiler = ToolProvider.getSystemJavaCompiler(); - int result = compiler.run(null, null, null, "-d", codeSearchPath, javaFilePath); - if (result != 0) { - throw new RuntimeException("Couldn't compile ClassLoaderTester.java."); - } - - FunctionDescriptor constructor = - new JavaFunctionDescriptor("ClassLoaderTester", "", "()V"); - ActorHandle actor1 = createActor(constructor); - FunctionDescriptor getPid = new JavaFunctionDescriptor("ClassLoaderTester", "getPid", "()I"); - int pid = - this.callActorFunction(actor1, getPid, new Object[0], Optional.of(Integer.class)) - .get(); - ActorHandle actor2; - while (true) { - // Create another actor which share the same process of actor 1. - actor2 = createActor(constructor); - int actor2Pid = - this.callActorFunction(actor2, getPid, new Object[0], Optional.of(Integer.class)) - .get(); - if (actor2Pid == pid) { - break; - } - } - - FunctionDescriptor getClassLoaderHashCode = - new JavaFunctionDescriptor("ClassLoaderTester", "getClassLoaderHashCode", "()I"); - ObjectRef hashCode1 = - callActorFunction( - actor1, getClassLoaderHashCode, new Object[0], Optional.of(Integer.class)); - ObjectRef hashCode2 = - callActorFunction( - actor2, getClassLoaderHashCode, new Object[0], Optional.of(Integer.class)); - Assert.assertEquals(hashCode1.get(), hashCode2.get()); - - FunctionDescriptor increase = - new JavaFunctionDescriptor("ClassLoaderTester", "increase", "()I"); - ObjectRef value1 = - callActorFunction(actor1, increase, new Object[0], Optional.of(Integer.class)); - ObjectRef value2 = - callActorFunction(actor2, increase, new Object[0], Optional.of(Integer.class)); - Assert.assertNotEquals(value1.get(), value2.get()); - } - - private ActorHandle createActor(FunctionDescriptor functionDescriptor) throws Exception { - Method createActorMethod = - AbstractRayRuntime.class.getDeclaredMethod( - "createActorImpl", - FunctionDescriptor.class, - Object[].class, - ActorCreationOptions.class); - createActorMethod.setAccessible(true); - return (ActorHandle) - createActorMethod.invoke( - TestUtils.getUnderlyingRuntime(), functionDescriptor, new Object[0], null); - } - - private ObjectRef callActorFunction( - ActorHandle rayActor, - FunctionDescriptor functionDescriptor, - Object[] args, - Optional> returnType) - throws Exception { - Method callActorFunctionMethod = - AbstractRayRuntime.class.getDeclaredMethod( - "callActorFunction", - BaseActorHandle.class, - FunctionDescriptor.class, - Object[].class, - Optional.class, - CallOptions.class); - callActorFunctionMethod.setAccessible(true); - return (ObjectRef) - callActorFunctionMethod.invoke( - TestUtils.getUnderlyingRuntime(), - rayActor, - functionDescriptor, - args, - returnType, - new CallOptions.Builder().build()); - } -} diff --git a/java/test/src/main/java/io/ray/test/DefaultActorLifetimeTest.java b/java/test/src/main/java/io/ray/test/DefaultActorLifetimeTest.java index a1664313e..f9d6785fe 100644 --- a/java/test/src/main/java/io/ray/test/DefaultActorLifetimeTest.java +++ b/java/test/src/main/java/io/ray/test/DefaultActorLifetimeTest.java @@ -55,7 +55,6 @@ public class DefaultActorLifetimeTest { System.setProperty("ray.job.default-actor-lifetime", defaultActorLifetime.name()); } try { - System.setProperty("ray.job.num-java-workers-per-process", "1"); Ray.init(); /// 1. create owner and invoke createChildActor. diff --git a/java/test/src/main/java/io/ray/test/ExitActorTest.java b/java/test/src/main/java/io/ray/test/ExitActorTest.java index f402cb607..6b1bcb135 100644 --- a/java/test/src/main/java/io/ray/test/ExitActorTest.java +++ b/java/test/src/main/java/io/ray/test/ExitActorTest.java @@ -74,38 +74,6 @@ public class ExitActorTest extends BaseTest { Assert.assertThrows(RayActorException.class, obj::get); } - public void testExitActorInMultiWorker() { - Assert.assertTrue(TestUtils.getNumWorkersPerProcess() > 1); - ActorHandle actor1 = - Ray.actor(ExitingActor::new).setMaxRestarts(ActorCreationOptions.INFINITE_RESTART).remote(); - int pid = actor1.task(ExitingActor::getPid).remote().get(); - Assert.assertEquals( - 1, (int) actor1.task(ExitingActor::getSizeOfActorContextMap).remote().get()); - ActorHandle actor2; - while (true) { - // Create another actor which share the same process of actor 1. - actor2 = - Ray.actor(ExitingActor::new).setMaxRestarts(ActorCreationOptions.NO_RESTART).remote(); - int actor2Pid = actor2.task(ExitingActor::getPid).remote().get(); - if (actor2Pid == pid) { - break; - } - } - Assert.assertEquals( - 2, (int) actor1.task(ExitingActor::getSizeOfActorContextMap).remote().get()); - Assert.assertEquals( - 2, (int) actor2.task(ExitingActor::getSizeOfActorContextMap).remote().get()); - ObjectRef obj1 = actor1.task(ExitingActor::exit).remote(); - Assert.assertThrows(RayActorException.class, obj1::get); - Assert.assertTrue(SystemUtil.isProcessAlive(pid)); - // Actor 2 shouldn't exit or be reconstructed. - Assert.assertEquals(1, (int) actor2.task(ExitingActor::incr).remote().get()); - Assert.assertEquals( - 1, (int) actor2.task(ExitingActor::getSizeOfActorContextMap).remote().get()); - Assert.assertEquals(pid, (int) actor2.task(ExitingActor::getPid).remote().get()); - Assert.assertTrue(SystemUtil.isProcessAlive(pid)); - } - public void testExitActorWithDynamicOptions() { ActorHandle actor = Ray.actor(ExitingActor::new) diff --git a/java/test/src/main/java/io/ray/test/ExitActorTest2.java b/java/test/src/main/java/io/ray/test/ExitActorTest2.java index c12a6cb26..e639ee1f3 100644 --- a/java/test/src/main/java/io/ray/test/ExitActorTest2.java +++ b/java/test/src/main/java/io/ray/test/ExitActorTest2.java @@ -15,7 +15,6 @@ public class ExitActorTest2 extends BaseTest { @BeforeClass public void setUp() { - System.setProperty("ray.job.num-java-workers-per-process", "1"); System.setProperty("ray.raylet.startup-token", "0"); } diff --git a/java/test/src/main/java/io/ray/test/FailureTest.java b/java/test/src/main/java/io/ray/test/FailureTest.java index 9035bfea3..2191d7b5e 100644 --- a/java/test/src/main/java/io/ray/test/FailureTest.java +++ b/java/test/src/main/java/io/ray/test/FailureTest.java @@ -23,10 +23,6 @@ public class FailureTest extends BaseTest { @BeforeClass public void setUp() { - // This is needed by `testGetThrowsQuicklyWhenFoundException`. - // Set one worker per process. Otherwise, if `badFunc2` and `slowFunc` run in the same - // process, `sleep` will delay `System.exit`. - System.setProperty("ray.job.num-java-workers-per-process", "1"); System.setProperty("ray.raylet.startup-token", "0"); } diff --git a/java/test/src/main/java/io/ray/test/JobConfigTest.java b/java/test/src/main/java/io/ray/test/JobConfigTest.java index 29795773a..e5b781692 100644 --- a/java/test/src/main/java/io/ray/test/JobConfigTest.java +++ b/java/test/src/main/java/io/ray/test/JobConfigTest.java @@ -11,7 +11,6 @@ public class JobConfigTest extends BaseTest { @BeforeClass public void setupJobConfig() { - System.setProperty("ray.job.num-java-workers-per-process", "3"); System.setProperty("ray.raylet.startup-token", "0"); System.setProperty("ray.job.jvm-options.0", "-DX=999"); System.setProperty("ray.job.jvm-options.1", "-DY=998"); @@ -33,10 +32,6 @@ public class JobConfigTest extends BaseTest { Assert.assertEquals("998", Ray.task(JobConfigTest::getJvmOptions, "Y").remote().get()); } - public void testNumJavaWorkersPerProcess() { - Assert.assertEquals(TestUtils.getNumWorkersPerProcess(), 3); - } - public void testInActor() { ActorHandle actor = Ray.actor(MyActor::new).remote(); diff --git a/java/test/src/main/java/io/ray/test/KillActorTest.java b/java/test/src/main/java/io/ray/test/KillActorTest.java index 2c4469ad6..b4d8838a9 100644 --- a/java/test/src/main/java/io/ray/test/KillActorTest.java +++ b/java/test/src/main/java/io/ray/test/KillActorTest.java @@ -15,7 +15,6 @@ public class KillActorTest extends BaseTest { @BeforeClass public void setUp() { - System.setProperty("ray.job.num-java-workers-per-process", "1"); System.setProperty("ray.raylet.startup-token", "0"); } diff --git a/java/test/src/main/java/io/ray/test/MultiThreadingTest.java b/java/test/src/main/java/io/ray/test/MultiThreadingTest.java index 42e0dd2a9..81629245c 100644 --- a/java/test/src/main/java/io/ray/test/MultiThreadingTest.java +++ b/java/test/src/main/java/io/ray/test/MultiThreadingTest.java @@ -51,14 +51,13 @@ public class MultiThreadingTest extends BaseTest { final Object[] result = new Object[1]; Thread thread = new Thread( - Ray.wrapRunnable( - () -> { - try { - result[0] = Ray.getRuntimeContext().getCurrentActorId(); - } catch (Exception e) { - result[0] = e; - } - })); + () -> { + try { + result[0] = Ray.getRuntimeContext().getCurrentActorId(); + } catch (Exception e) { + result[0] = e; + } + }); thread.start(); thread.join(); if (result[0] instanceof Exception) { @@ -140,6 +139,8 @@ public class MultiThreadingTest extends BaseTest { Assert.assertEquals("ok", obj.get()); } + /// SINGLE_PROCESS mode doesn't support this API. + @Test(groups = {"cluster"}) public void testGetCurrentActorId() { ActorHandle actorIdTester = Ray.actor(ActorIdTester::new).remote(); ActorId actorId = actorIdTester.task(ActorIdTester::getCurrentActorId).remote().get(); @@ -162,104 +163,6 @@ public class MultiThreadingTest extends BaseTest { }; } - static boolean testMissingWrapRunnable() throws InterruptedException { - { - Runnable[] runnables = generateRunnables(); - // It's OK to run them in main thread. - for (Runnable runnable : runnables) { - runnable.run(); - } - } - - Throwable[] throwable = new Throwable[1]; - - { - Runnable[] runnables = generateRunnables(); - Thread thread = - new Thread( - Ray.wrapRunnable( - () -> { - try { - // It would be OK to run them in another thread if wrapped the runnable. - for (Runnable runnable : runnables) { - runnable.run(); - } - } catch (Throwable ex) { - throwable[0] = ex; - } - })); - thread.start(); - thread.join(); - if (throwable[0] != null) { - throw new RuntimeException("Exception occurred in thread.", throwable[0]); - } - } - - { - Runnable[] runnables = generateRunnables(); - Thread thread = - new Thread( - () -> { - try { - // It wouldn't be OK to run them in another thread if not wrapped the runnable. - for (Runnable runnable : runnables) { - Assert.expectThrows(RuntimeException.class, runnable::run); - } - } catch (Throwable ex) { - throwable[0] = ex; - } - }); - thread.start(); - thread.join(); - if (throwable[0] != null) { - throw new RuntimeException("Exception occurred in thread.", throwable[0]); - } - } - - { - Runnable[] runnables = generateRunnables(); - Runnable[] wrappedRunnables = new Runnable[runnables.length]; - for (int i = 0; i < runnables.length; i++) { - wrappedRunnables[i] = Ray.wrapRunnable(runnables[i]); - } - // It would be OK to run the wrapped runnables in the current thread. - for (Runnable runnable : wrappedRunnables) { - runnable.run(); - } - - // It would be OK to invoke Ray APIs after executing a wrapped runnable in the current thread. - wrappedRunnables[0].run(); - runnables[0].run(); - } - - // Return true here to make the caller.remote() returns an ObjectRef. - return true; - } - - public void testMissingWrapRunnableInWorker() { - Ray.task(MultiThreadingTest::testMissingWrapRunnable).remote().get(); - } - - public void testGetAndSetAsyncContext() throws InterruptedException { - Object asyncContext = Ray.getAsyncContext(); - Throwable[] throwable = new Throwable[1]; - Thread thread = - new Thread( - () -> { - try { - Ray.setAsyncContext(asyncContext); - Ray.put(1); - } catch (Throwable ex) { - throwable[0] = ex; - } - }); - thread.start(); - thread.join(); - if (throwable[0] != null) { - throw new RuntimeException("Exception occurred in thread.", throwable[0]); - } - } - private static void runTestCaseInMultipleThreads(Runnable testCase, int numRepeats) { ExecutorService service = Executors.newFixedThreadPool(NUM_THREADS); @@ -267,14 +170,13 @@ public class MultiThreadingTest extends BaseTest { List> futures = new ArrayList<>(); for (int i = 0; i < NUM_THREADS; i++) { Callable task = - Ray.wrapCallable( - () -> { - for (int j = 0; j < numRepeats; j++) { - TimeUnit.MILLISECONDS.sleep(1); - testCase.run(); - } - return "ok"; - }); + () -> { + for (int j = 0; j < numRepeats; j++) { + TimeUnit.MILLISECONDS.sleep(1); + testCase.run(); + } + return "ok"; + }; futures.add(service.submit(task)); } for (Future future : futures) { @@ -288,35 +190,4 @@ public class MultiThreadingTest extends BaseTest { service.shutdown(); } } - - private static boolean testGetAsyncContextAndSetAsyncContext() throws Exception { - final Object asyncContext = Ray.getAsyncContext(); - final Object[] result = new Object[1]; - Thread thread = - new Thread( - () -> { - try { - Ray.setAsyncContext(asyncContext); - Ray.put(0); - } catch (Exception e) { - result[0] = e; - } - }); - thread.start(); - thread.join(); - if (result[0] instanceof Exception) { - throw (Exception) result[0]; - } - return true; - } - - public void testGetAsyncContextAndSetAsyncContextInDriver() throws Exception { - Assert.assertTrue(testGetAsyncContextAndSetAsyncContext()); - } - - public void testGetAsyncContextAndSetAsyncContextInWorker() { - ObjectRef obj = - Ray.task(MultiThreadingTest::testGetAsyncContextAndSetAsyncContext).remote(); - Assert.assertTrue(obj.get()); - } } diff --git a/java/test/src/main/java/io/ray/test/RuntimeEnvTest.java b/java/test/src/main/java/io/ray/test/RuntimeEnvTest.java index a31c60d0f..8419f754a 100644 --- a/java/test/src/main/java/io/ray/test/RuntimeEnvTest.java +++ b/java/test/src/main/java/io/ray/test/RuntimeEnvTest.java @@ -4,7 +4,6 @@ import com.google.common.collect.ImmutableList; import io.ray.api.ActorHandle; import io.ray.api.Ray; import io.ray.api.runtimeenv.RuntimeEnv; -import io.ray.runtime.util.SystemUtil; import java.util.List; import org.testng.Assert; import org.testng.annotations.Test; @@ -30,10 +29,6 @@ public class RuntimeEnvTest { return System.getenv(key); } - public int getPid() { - return SystemUtil.pid(); - } - public boolean findClass(String className) { try { Class.forName(className); @@ -45,7 +40,6 @@ public class RuntimeEnvTest { } public void testPerJobEnvVars() { - System.setProperty("ray.job.num-java-workers-per-process", "1"); System.setProperty("ray.job.runtime-env.env-vars.KEY1", "A"); System.setProperty("ray.job.runtime-env.env-vars.KEY2", "B"); @@ -61,87 +55,6 @@ public class RuntimeEnvTest { } } - public void testPerActorEnvVars() { - /// This is used to test that actors with runtime envs will not reuse worker process. - System.setProperty("ray.job.num-java-workers-per-process", "2"); - try { - Ray.init(); - int pid1 = 0; - int pid2 = 0; - { - RuntimeEnv runtimeEnv = - new RuntimeEnv.Builder() - .addEnvVar("KEY1", "A") - .addEnvVar("KEY2", "B") - .addEnvVar("KEY1", "C") - .build(); - - ActorHandle actor1 = Ray.actor(A::new).setRuntimeEnv(runtimeEnv).remote(); - String val = actor1.task(A::getEnv, "KEY1").remote().get(); - Assert.assertEquals(val, "C"); - val = actor1.task(A::getEnv, "KEY2").remote().get(); - Assert.assertEquals(val, "B"); - - pid1 = actor1.task(A::getPid).remote().get(); - } - - { - /// Because we didn't set them for actor2 , all should be null. - ActorHandle actor2 = Ray.actor(A::new).remote(); - String val = actor2.task(A::getEnv, "KEY1").remote().get(); - Assert.assertNull(val); - val = actor2.task(A::getEnv, "KEY2").remote().get(); - Assert.assertNull(val); - pid2 = actor2.task(A::getPid).remote().get(); - } - - // actor1 and actor2 shouldn't be in one process because they have - // different runtime env. - Assert.assertNotEquals(pid1, pid2); - } finally { - Ray.shutdown(); - } - } - - public void testPerActorEnvVarsOverwritePerJobEnvVars() { - System.setProperty("ray.job.num-java-workers-per-process", "2"); - System.setProperty("ray.job.runtime-env.env-vars.KEY1", "A"); - System.setProperty("ray.job.runtime-env.env-vars.KEY2", "B"); - - int pid1 = 0; - int pid2 = 0; - try { - Ray.init(); - { - RuntimeEnv runtimeEnv = new RuntimeEnv.Builder().addEnvVar("KEY1", "C").build(); - - ActorHandle actor1 = Ray.actor(A::new).setRuntimeEnv(runtimeEnv).remote(); - String val = actor1.task(A::getEnv, "KEY1").remote().get(); - Assert.assertEquals(val, "C"); - val = actor1.task(A::getEnv, "KEY2").remote().get(); - Assert.assertEquals(val, "B"); - pid1 = actor1.task(A::getPid).remote().get(); - } - - { - /// Because we didn't set them for actor2 explicitly, it should use the per job - /// runtime env. - ActorHandle actor2 = Ray.actor(A::new).remote(); - String val = actor2.task(A::getEnv, "KEY1").remote().get(); - Assert.assertEquals(val, "A"); - val = actor2.task(A::getEnv, "KEY2").remote().get(); - Assert.assertEquals(val, "B"); - pid2 = actor2.task(A::getPid).remote().get(); - } - - // actor1 and actor2 shouldn't be in one process because they have - // different runtime env. - Assert.assertNotEquals(pid1, pid2); - } finally { - Ray.shutdown(); - } - } - private static String getEnvVar(String key) { return System.getenv(key); } diff --git a/java/test/src/main/java/io/ray/test/TestUtils.java b/java/test/src/main/java/io/ray/test/TestUtils.java index 71a0ece59..5db5e0e8e 100644 --- a/java/test/src/main/java/io/ray/test/TestUtils.java +++ b/java/test/src/main/java/io/ray/test/TestUtils.java @@ -3,9 +3,7 @@ package io.ray.test; import com.google.common.base.Preconditions; import io.ray.api.ObjectRef; import io.ray.api.Ray; -import io.ray.runtime.AbstractRayRuntime; import io.ray.runtime.RayRuntimeInternal; -import io.ray.runtime.RayRuntimeProxy; import io.ray.runtime.config.RayConfig; import io.ray.runtime.config.RunMode; import io.ray.runtime.task.ArgumentsBuilder; @@ -129,20 +127,7 @@ public class TestUtils { } public static RayRuntimeInternal getUnderlyingRuntime() { - if (Ray.internal() instanceof AbstractRayRuntime) { - return (RayRuntimeInternal) Ray.internal(); - } - RayRuntimeProxy proxy = - (RayRuntimeProxy) (java.lang.reflect.Proxy.getInvocationHandler(Ray.internal())); - return proxy.getRuntimeObject(); - } - - private static int getNumWorkersPerProcessRemoteFunction() { - return TestUtils.getRuntime().getRayConfig().numWorkersPerProcess; - } - - public static int getNumWorkersPerProcess() { - return Ray.task(TestUtils::getNumWorkersPerProcessRemoteFunction).remote().get(); + return (RayRuntimeInternal) Ray.internal(); } public static ProcessBuilder buildDriver(Class mainClass, String[] args) { diff --git a/java/test/src/main/resources/ray.conf b/java/test/src/main/resources/ray.conf deleted file mode 100644 index b838c0075..000000000 --- a/java/test/src/main/resources/ray.conf +++ /dev/null @@ -1,6 +0,0 @@ -ray { - job { - # Enable multi-worker feature in Java test - num-java-workers-per-process: 10 - } -} diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index e1be9aa2b..2ef826199 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -1116,7 +1116,6 @@ cdef class CoreWorker: options.unhandled_exception_handler = unhandled_exception_handler options.get_lang_stack = get_py_stack options.is_local_mode = local_mode - options.num_workers = 1 options.kill_main = kill_main_task options.terminate_asyncio_thread = terminate_asyncio_thread options.serialized_job_config = serialized_job_config diff --git a/python/ray/job_config.py b/python/ray/job_config.py index 0116a4012..c220e37b5 100644 --- a/python/ray/job_config.py +++ b/python/ray/job_config.py @@ -8,8 +8,6 @@ class JobConfig: """A class used to store the configurations of a job. Attributes: - num_java_workers_per_process (int): The number of java workers per - worker process. jvm_options (str[]): The jvm options for java workers of the job. code_search_path (list): A list of directories or jar files that specify the search path for user code. This will be used as @@ -22,7 +20,6 @@ class JobConfig: def __init__( self, - num_java_workers_per_process=1, jvm_options=None, code_search_path=None, runtime_env=None, @@ -31,7 +28,6 @@ class JobConfig: ray_namespace=None, default_actor_lifetime="non_detached", ): - self.num_java_workers_per_process = num_java_workers_per_process self.jvm_options = jvm_options or [] self.code_search_path = code_search_path or [] # It's difficult to find the error that caused by the @@ -108,7 +104,6 @@ class JobConfig: pb.ray_namespace = str(uuid.uuid4()) else: pb.ray_namespace = self.ray_namespace - pb.num_java_workers_per_process = self.num_java_workers_per_process pb.jvm_options.extend(self.jvm_options) pb.code_search_path.extend(self.code_search_path) for k, v in self.metadata.items(): @@ -147,9 +142,6 @@ class JobConfig: Generates a JobConfig object from json. """ return cls( - num_java_workers_per_process=job_config_json.get( - "num_java_workers_per_process", 1 - ), jvm_options=job_config_json.get("jvm_options", None), code_search_path=job_config_json.get("code_search_path", None), runtime_env=job_config_json.get("runtime_env", None), diff --git a/python/ray/tests/test_advanced_8.py b/python/ray/tests/test_advanced_8.py index a9fb0dd10..230539222 100644 --- a/python/ray/tests/test_advanced_8.py +++ b/python/ray/tests/test_advanced_8.py @@ -545,19 +545,16 @@ def test_k8s_cpu(use_cgroups_v2: bool): def test_sync_job_config(shutdown_only): - num_java_workers_per_process = 8 runtime_env = {"env_vars": {"key": "value"}} ray.init( job_config=ray.job_config.JobConfig( - num_java_workers_per_process=num_java_workers_per_process, runtime_env=runtime_env, ) ) # Check that the job config is synchronized at the driver side. job_config = ray.worker.global_worker.core_worker.get_job_config() - assert job_config.num_java_workers_per_process == num_java_workers_per_process job_runtime_env = RuntimeEnv.deserialize( job_config.runtime_env_info.serialized_runtime_env ) @@ -571,7 +568,6 @@ def test_sync_job_config(shutdown_only): # Check that the job config is synchronized at the worker side. job_config = gcs_utils.JobConfig() job_config.ParseFromString(ray.get(get_job_config.remote())) - assert job_config.num_java_workers_per_process == num_java_workers_per_process job_runtime_env = RuntimeEnv.deserialize( job_config.runtime_env_info.serialized_runtime_env ) diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index af98e6a93..96cf4a155 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -3084,15 +3084,6 @@ void CoreWorker::HandleKillActor(const rpc::KillActorRequest &request, if (request.no_restart()) { Disconnect(); } - if (options_.num_workers > 1) { - // TODO (kfstorm): Should we add some kind of check before sending the killing - // request? - RAY_LOG(ERROR) - << "Killing an actor which is running in a worker process with multiple " - "workers will also kill other actors in this process. To avoid this, " - "please create the Java actor with some dynamic options to make it being " - "hosted in a dedicated worker process."; - } // NOTE(hchen): Use `QuickExit()` to force-exit this process without doing cleanup. // `exit()` will destruct static objects in an incorrect order, which will lead to // core dumps. diff --git a/src/ray/core_worker/core_worker_options.h b/src/ray/core_worker/core_worker_options.h index 21d8d1816..5703c6544 100644 --- a/src/ray/core_worker/core_worker_options.h +++ b/src/ray/core_worker/core_worker_options.h @@ -76,7 +76,6 @@ struct CoreWorkerOptions { get_lang_stack(nullptr), kill_main(nullptr), is_local_mode(false), - num_workers(0), terminate_asyncio_thread(nullptr), serialized_job_config(""), metrics_agent_port(-1), @@ -148,8 +147,6 @@ struct CoreWorkerOptions { std::function kill_main; /// Is local mode being used. bool is_local_mode; - /// The number of workers to be started in the current process. - int num_workers; /// The function to destroy asyncio event and loops. std::function terminate_asyncio_thread; /// Serialized representation of JobConfig. diff --git a/src/ray/core_worker/core_worker_process.cc b/src/ray/core_worker/core_worker_process.cc index eca59f3bf..16221eabc 100644 --- a/src/ray/core_worker/core_worker_process.cc +++ b/src/ray/core_worker/core_worker_process.cc @@ -82,10 +82,9 @@ thread_local std::weak_ptr CoreWorkerProcessImpl::thread_local_core_ CoreWorkerProcessImpl::CoreWorkerProcessImpl(const CoreWorkerOptions &options) : options_(options), - global_worker_id_( - options.worker_type == WorkerType::DRIVER - ? ComputeDriverIdFromJob(options_.job_id) - : (options_.num_workers == 1 ? WorkerID::FromRandom() : WorkerID::Nil())) { + global_worker_id_(options.worker_type == WorkerType::DRIVER + ? ComputeDriverIdFromJob(options_.job_id) + : WorkerID::FromRandom()) { if (options_.enable_logging) { std::stringstream app_name; app_name << LanguageString(options_.language) << "-core-" @@ -111,12 +110,6 @@ CoreWorkerProcessImpl::CoreWorkerProcessImpl(const CoreWorkerOptions &options) << "install_failure_signal_handler must be false because ray log is disabled."; } - RAY_CHECK(options_.num_workers > 0); - if (options_.worker_type == WorkerType::DRIVER) { - // Driver process can only contain one worker. - RAY_CHECK(options_.num_workers == 1); - } - RAY_LOG(INFO) << "Constructing CoreWorkerProcess. pid: " << getpid(); // NOTE(kfstorm): any initialization depending on RayConfig must happen after this line. @@ -250,18 +243,14 @@ bool CoreWorkerProcessImpl::ShouldCreateGlobalWorkerOnConstruction() const { // worker APIs before `CoreWorkerProcess::RunTaskExecutionLoop` is called. So we need // to create the worker instance here. One example of invocations is // https://github.com/ray-project/ray/blob/45ce40e5d44801193220d2c546be8de0feeef988/python/ray/worker.py#L1281. - return options_.num_workers == 1 && (options_.worker_type == WorkerType::DRIVER || - options_.language == Language::PYTHON); + return (options_.worker_type == WorkerType::DRIVER || + options_.language == Language::PYTHON); } std::shared_ptr CoreWorkerProcessImpl::GetWorker( const WorkerID &worker_id) const { absl::ReaderMutexLock lock(&mutex_); - auto it = workers_.find(worker_id); - if (it != workers_.end()) { - return it->second; - } - return nullptr; + return global_worker_; } std::shared_ptr CoreWorkerProcessImpl::GetGlobalWorker() { @@ -275,13 +264,9 @@ std::shared_ptr CoreWorkerProcessImpl::CreateWorker() { global_worker_id_ != WorkerID::Nil() ? global_worker_id_ : WorkerID::FromRandom()); RAY_LOG(DEBUG) << "Worker " << worker->GetWorkerID() << " is created."; absl::WriterMutexLock lock(&mutex_); - if (options_.num_workers == 1) { - global_worker_ = worker; - } + global_worker_ = worker; thread_local_core_worker_ = worker; - workers_.emplace(worker->GetWorkerID(), worker); - RAY_CHECK(workers_.size() <= static_cast(options_.num_workers)); return worker; } @@ -293,10 +278,7 @@ void CoreWorkerProcessImpl::RemoveWorker(std::shared_ptr worker) { RAY_CHECK(thread_local_core_worker_.lock() == worker); } thread_local_core_worker_.reset(); - { - workers_.erase(worker->GetWorkerID()); - RAY_LOG(INFO) << "Removed worker " << worker->GetWorkerID(); - } + RAY_LOG(INFO) << "Removed worker " << worker->GetWorkerID(); if (global_worker_ == worker) { global_worker_ = nullptr; } @@ -304,31 +286,14 @@ void CoreWorkerProcessImpl::RemoveWorker(std::shared_ptr worker) { void CoreWorkerProcessImpl::RunWorkerTaskExecutionLoop() { RAY_CHECK(options_.worker_type == WorkerType::WORKER); - if (options_.num_workers == 1) { - // Run the task loop in the current thread only if the number of workers is 1. - auto worker = GetGlobalWorker(); - if (!worker) { - worker = CreateWorker(); - } - worker->RunTaskExecutionLoop(); - RAY_LOG(INFO) << "Task execution loop terminated. Removing the global worker."; - RemoveWorker(worker); - } else { - std::vector worker_threads; - for (int i = 0; i < options_.num_workers; i++) { - worker_threads.emplace_back([this, i] { - SetThreadName("worker.task" + std::to_string(i)); - auto worker = CreateWorker(); - worker->RunTaskExecutionLoop(); - RAY_LOG(INFO) << "Task execution loop terminated for a thread " - << std::to_string(i) << ". Removing a worker."; - RemoveWorker(worker); - }); - } - for (auto &thread : worker_threads) { - thread.join(); - } + // Run the task loop in the current thread only if the number of workers is 1. + auto worker = GetGlobalWorker(); + if (!worker) { + worker = CreateWorker(); } + worker->RunTaskExecutionLoop(); + RAY_LOG(INFO) << "Task execution loop terminated. Removing the global worker."; + RemoveWorker(worker); } void CoreWorkerProcessImpl::ShutdownDriver() { @@ -342,42 +307,30 @@ void CoreWorkerProcessImpl::ShutdownDriver() { } CoreWorker &CoreWorkerProcessImpl::GetCoreWorkerForCurrentThread() { - if (options_.num_workers == 1) { - auto global_worker = GetGlobalWorker(); - if (ShouldCreateGlobalWorkerOnConstruction() && !global_worker) { - // This could only happen when the worker has already been shutdown. - // In this case, we should exit without crashing. - // TODO (scv119): A better solution could be returning error code - // and handling it at language frontend. - if (options_.worker_type == WorkerType::DRIVER) { - RAY_LOG(ERROR) - << "The global worker has already been shutdown. This happens when " - "the language frontend accesses the Ray's worker after it is " - "shutdown. The process will exit"; - } else { - RAY_LOG(INFO) << "The global worker has already been shutdown. This happens when " - "the language frontend accesses the Ray's worker after it is " - "shutdown. The process will exit"; - } - QuickExit(); + auto global_worker = GetGlobalWorker(); + if (ShouldCreateGlobalWorkerOnConstruction() && !global_worker) { + // This could only happen when the worker has already been shutdown. + // In this case, we should exit without crashing. + // TODO (scv119): A better solution could be returning error code + // and handling it at language frontend. + if (options_.worker_type == WorkerType::DRIVER) { + RAY_LOG(ERROR) << "The global worker has already been shutdown. This happens when " + "the language frontend accesses the Ray's worker after it is " + "shutdown. The process will exit"; + } else { + RAY_LOG(INFO) << "The global worker has already been shutdown. This happens when " + "the language frontend accesses the Ray's worker after it is " + "shutdown. The process will exit"; } - RAY_CHECK(global_worker) << "global_worker_ must not be NULL"; - return *global_worker; + QuickExit(); } - auto ptr = thread_local_core_worker_.lock(); - RAY_CHECK(ptr != nullptr) - << "The current thread is not bound with a core worker instance."; - return *ptr; + RAY_CHECK(global_worker) << "global_worker_ must not be NULL"; + return *global_worker; } void CoreWorkerProcessImpl::SetThreadLocalWorkerById(const WorkerID &worker_id) { - if (options_.num_workers == 1) { - RAY_CHECK(GetGlobalWorker()->GetWorkerID() == worker_id); - return; - } - auto worker = GetWorker(worker_id); - RAY_CHECK(worker) << "Worker " << worker_id << " not found."; - thread_local_core_worker_ = GetWorker(worker_id); + RAY_CHECK(GetGlobalWorker()->GetWorkerID() == worker_id); + return; } } // namespace core diff --git a/src/ray/core_worker/core_worker_process.h b/src/ray/core_worker/core_worker_process.h index df8e10b87..15067be3a 100644 --- a/src/ray/core_worker/core_worker_process.h +++ b/src/ray/core_worker/core_worker_process.h @@ -25,7 +25,6 @@ class CoreWorker; /// CoreWorkerOptions options = { /// WorkerType::DRIVER, // worker_type /// ..., // other arguments -/// 1, // num_workers /// }; /// CoreWorkerProcess::Initialize(options); /// @@ -36,7 +35,6 @@ class CoreWorker; /// CoreWorkerOptions options = { /// WorkerType::WORKER, // worker_type /// ..., // other arguments -/// num_workers, // num_workers /// }; /// CoreWorkerProcess::Initialize(options); /// ... // Do other stuff @@ -52,14 +50,7 @@ class CoreWorker; /// threads, remember to call `CoreWorkerProcess::SetCurrentThreadWorkerId(worker_id)` /// once in the new thread before calling core worker APIs, to associate the current /// thread with a worker. You can obtain the worker ID via -/// `CoreWorkerProcess::GetCoreWorker()->GetWorkerID()`. Currently a Java worker process -/// starts multiple workers by default, but can be configured to start only 1 worker by -/// speicifying `num_java_workers_per_process` in the job config. -/// -/// If only 1 worker is started (either because the worker type is driver, or the -/// `num_workers` in `CoreWorkerOptions` is set to 1), all threads will be automatically -/// associated to the only worker. Then no need to call `SetCurrentThreadWorkerId` in -/// your own threads. Currently a Python worker process starts only 1 worker. +/// `CoreWorkerProcess::GetCoreWorker()->GetWorkerID()`. /// /// How does core worker process dealloation work? /// @@ -195,9 +186,6 @@ class CoreWorkerProcessImpl { /// The worker ID of the global worker, if the number of workers is 1. const WorkerID global_worker_id_; - /// Map from worker ID to worker. - absl::flat_hash_map> workers_ GUARDED_BY(mutex_); - /// To protect access to workers_ and global_worker_ mutable absl::Mutex mutex_; }; diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc index 33433838d..edad6bbcd 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc @@ -103,7 +103,6 @@ Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(JNIEnv *env, jstring rayletSocket, jbyteArray jobId, jobject gcsClientOptions, - jint numWorkersPerProcess, jstring logDir, jbyteArray jobConfig, jint startupToken, @@ -289,7 +288,6 @@ Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(JNIEnv *env, options.task_execution_callback = task_execution_callback; options.on_worker_shutdown = on_worker_shutdown; options.gc_collect = gc_collect; - options.num_workers = static_cast(numWorkersPerProcess); options.serialized_job_config = serialized_job_config; options.metrics_agent_port = -1; options.startup_token = startupToken; diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.h b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.h index 785a7965a..6650799ce 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.h +++ b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.h @@ -25,7 +25,7 @@ extern "C" { * Class: io_ray_runtime_RayNativeRuntime * Method: nativeInitialize * Signature: - * (ILjava/lang/String;ILjava/lang/String;Ljava/lang/String;Ljava/lang/String;[BLio/ray/runtime/gcs/GcsClientOptions;ILjava/lang/String;[BII)V + * (ILjava/lang/String;ILjava/lang/String;Ljava/lang/String;Ljava/lang/String;[BLio/ray/runtime/gcs/GcsClientOptions;Ljava/lang/String;[BII)V */ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(JNIEnv *, jclass, @@ -37,7 +37,6 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(JNI jstring, jbyteArray, jobject, - jint, jstring, jbyteArray, jint, diff --git a/src/ray/core_worker/test/core_worker_test.cc b/src/ray/core_worker/test/core_worker_test.cc index 968fec57e..497d493ce 100644 --- a/src/ray/core_worker/test/core_worker_test.cc +++ b/src/ray/core_worker/test/core_worker_test.cc @@ -140,7 +140,6 @@ class CoreWorkerTest : public ::testing::Test { options.node_manager_port = node_manager_port; options.raylet_ip_address = "127.0.0.1"; options.driver_name = "core_worker_test"; - options.num_workers = 1; options.metrics_agent_port = -1; CoreWorkerProcess::Initialize(options); } diff --git a/src/ray/core_worker/test/mock_worker.cc b/src/ray/core_worker/test/mock_worker.cc index 91ce0bb9a..9c9ec6ec3 100644 --- a/src/ray/core_worker/test/mock_worker.cc +++ b/src/ray/core_worker/test/mock_worker.cc @@ -51,7 +51,6 @@ class MockWorker { options.raylet_ip_address = "127.0.0.1"; options.task_execution_callback = std::bind(&MockWorker::ExecuteTask, this, _1, _2, _3, _4, _5, _6, _7, _8, _9); - options.num_workers = 1; options.metrics_agent_port = -1; options.startup_token = startup_token; CoreWorkerProcess::Initialize(options); diff --git a/src/ray/gcs/gcs_server/gcs_actor_manager.cc b/src/ray/gcs/gcs_server/gcs_actor_manager.cc index fcfbd7260..3b27acf7f 100644 --- a/src/ray/gcs/gcs_server/gcs_actor_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_actor_manager.cc @@ -686,12 +686,8 @@ void GcsActorManager::PollOwnerForActorOutOfScope( if (node_it != owners_.end() && node_it->second.count(owner_id)) { // Only destroy the actor if its owner is still alive. The actor may // have already been destroyed if the owner died. - // For multiple actors in one process, if one actor is out of scope, - // We shouldn't force kill the actor because other actors in the process - // are still alive. - auto force_kill = - get_job_config_(actor_id.JobId())->num_java_workers_per_process() <= 1; - DestroyActor(actor_id, GenActorOutOfScopeCause(GetActor(actor_id)), force_kill); + DestroyActor( + actor_id, GenActorOutOfScopeCause(GetActor(actor_id)), /*force_kill=*/true); } }); } diff --git a/src/ray/gcs/gcs_server/test/gcs_job_manager_test.cc b/src/ray/gcs/gcs_server/test/gcs_job_manager_test.cc index b1ba094e8..506da34a1 100644 --- a/src/ray/gcs/gcs_server/test/gcs_job_manager_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_job_manager_test.cc @@ -89,7 +89,7 @@ TEST_F(GcsJobManagerTest, TestGetJobConfig) { auto job_id2 = JobID::FromInt(2); gcs::GcsInitData gcs_init_data(gcs_table_storage_); gcs_job_manager.Initialize(/*init_data=*/gcs_init_data); - auto add_job_request1 = Mocker::GenAddJobRequest(job_id1, "namespace_1", 4); + auto add_job_request1 = Mocker::GenAddJobRequest(job_id1, "namespace_1"); rpc::AddJobReply empty_reply; @@ -97,19 +97,16 @@ TEST_F(GcsJobManagerTest, TestGetJobConfig) { *add_job_request1, &empty_reply, [](Status, std::function, std::function) {}); - auto add_job_request2 = Mocker::GenAddJobRequest(job_id2, "namespace_2", 8); + auto add_job_request2 = Mocker::GenAddJobRequest(job_id2, "namespace_2"); gcs_job_manager.HandleAddJob( *add_job_request2, &empty_reply, [](Status, std::function, std::function) {}); - auto job_config1 = gcs_job_manager.GetJobConfig(job_id1); ASSERT_EQ("namespace_1", job_config1->ray_namespace()); - ASSERT_EQ(4, job_config1->num_java_workers_per_process()); auto job_config2 = gcs_job_manager.GetJobConfig(job_id2); ASSERT_EQ("namespace_2", job_config2->ray_namespace()); - ASSERT_EQ(8, job_config2->num_java_workers_per_process()); } int main(int argc, char **argv) { diff --git a/src/ray/gcs/test/gcs_test_util.h b/src/ray/gcs/test/gcs_test_util.h index e9cca3866..21c54c146 100644 --- a/src/ray/gcs/test/gcs_test_util.h +++ b/src/ray/gcs/test/gcs_test_util.h @@ -239,12 +239,9 @@ struct Mocker { } static std::shared_ptr GenAddJobRequest( - const JobID &job_id, - const std::string &ray_namespace, - uint32_t num_java_worker_per_process) { + const JobID &job_id, const std::string &ray_namespace) { auto job_config_data = std::make_shared(); job_config_data->set_ray_namespace(ray_namespace); - job_config_data->set_num_java_workers_per_process(num_java_worker_per_process); auto job_table_data = std::make_shared(); job_table_data->set_job_id(job_id.Binary()); diff --git a/src/ray/object_manager/common.h b/src/ray/object_manager/common.h index 01da4a1d3..66829d251 100644 --- a/src/ray/object_manager/common.h +++ b/src/ray/object_manager/common.h @@ -39,8 +39,8 @@ using RestoreSpilledObjectCallback = /// A struct that includes info about the object. struct ObjectInfo { ObjectID object_id; - int64_t data_size; - int64_t metadata_size; + int64_t data_size = 0; + int64_t metadata_size = 0; /// Owner's raylet ID. NodeID owner_raylet_id; /// Owner's IP address. diff --git a/src/ray/protobuf/gcs.proto b/src/ray/protobuf/gcs.proto index 34dec8bd5..0fb170c03 100644 --- a/src/ray/protobuf/gcs.proto +++ b/src/ray/protobuf/gcs.proto @@ -260,8 +260,6 @@ message JobConfig { NON_DETACHED = 1; } - // The number of java workers per worker process. - uint32 num_java_workers_per_process = 1; // The jvm options for java workers of the job. repeated string jvm_options = 2; // A list of directories or files (jar files or dynamic libraries) that specify the diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc index 4390477a8..7cbc0de70 100644 --- a/src/ray/raylet/worker_pool.cc +++ b/src/ray/raylet/worker_pool.cc @@ -198,14 +198,12 @@ void WorkerPool::update_worker_startup_token_counter() { void WorkerPool::AddWorkerProcess( State &state, - const int workers_to_start, const rpc::WorkerType worker_type, const Process &proc, const std::chrono::high_resolution_clock::time_point &start, const rpc::RuntimeEnvInfo &runtime_env_info) { state.worker_processes.emplace(worker_startup_token_counter_, - WorkerProcessInfo{workers_to_start, - workers_to_start, + WorkerProcessInfo{/*is_pending_registration=*/true, {}, worker_type, proc, @@ -248,7 +246,7 @@ std::tuple WorkerPool::StartWorkerProcess( int starting_workers = 0; for (auto &entry : state.worker_processes) { if (entry.second.worker_type == worker_type) { - starting_workers += entry.second.num_starting_workers; + starting_workers += entry.second.is_pending_registration ? 1 : 0; } } @@ -268,13 +266,6 @@ std::tuple WorkerPool::StartWorkerProcess( << rpc::WorkerType_Name(worker_type) << ", current pool has " << state.idle.size() << " workers"; - int workers_to_start = 1; - if (dynamic_options.empty()) { - if (language == Language::JAVA) { - workers_to_start = job_config->num_java_workers_per_process(); - } - } - std::vector options; // Append Ray-defined per-job options here @@ -312,12 +303,6 @@ std::tuple WorkerPool::StartWorkerProcess( } } - // Append Ray-defined per-process options here - if (language == Language::JAVA) { - options.push_back("-Dray.job.num-java-workers-per-process=" + - std::to_string(workers_to_start)); - } - // Append startup-token for JAVA here if (language == Language::JAVA) { options.push_back("-Dray.raylet.startup-token=" + @@ -460,13 +445,12 @@ std::tuple WorkerPool::StartWorkerProcess( auto duration = std::chrono::duration_cast(end - start); stats::ProcessStartupTimeMs.Record(duration.count()); stats::NumWorkersStarted.Record(1); - RAY_LOG(INFO) << "Started worker process of " << workers_to_start - << " worker(s) with pid " << proc.GetId() << ", the token " + RAY_LOG(INFO) << "Started worker process with pid " << proc.GetId() << ", the token is " << worker_startup_token_counter_; AdjustWorkerOomScore(proc.GetId()); MonitorStartingWorkerProcess( proc, worker_startup_token_counter_, language, worker_type); - AddWorkerProcess(state, workers_to_start, worker_type, proc, start, runtime_env_info); + AddWorkerProcess(state, worker_type, proc, start, runtime_env_info); StartupToken worker_startup_token = worker_startup_token_counter_; update_worker_startup_token_counter(); if (IsIOWorkerType(worker_type)) { @@ -513,7 +497,7 @@ void WorkerPool::MonitorStartingWorkerProcess(const Process &proc, // Since this process times out to start, remove it from worker_processes // to avoid the zombie worker. auto it = state.worker_processes.find(proc_startup_token); - if (it != state.worker_processes.end() && it->second.num_starting_workers != 0) { + if (it != state.worker_processes.end() && it->second.is_pending_registration) { RAY_LOG(ERROR) << "Some workers of the worker process(" << proc.GetId() << ") have not registered within the timeout. " @@ -747,12 +731,10 @@ void WorkerPool::OnWorkerStarted(const std::shared_ptr &worker) auto it = state.worker_processes.find(worker_startup_token); if (it != state.worker_processes.end()) { - it->second.num_starting_workers--; + it->second.is_pending_registration = false; it->second.alive_started_workers.insert(worker); - if (it->second.num_starting_workers == 0) { - // We may have slots to start more workers now. - TryStartIOWorkers(worker->GetLanguage()); - } + // We may have slots to start more workers now. + TryStartIOWorkers(worker->GetLanguage()); } const auto &worker_type = worker->GetWorkerType(); if (IsIOWorkerType(worker_type)) { @@ -1044,8 +1026,7 @@ void WorkerPool::TryKillingIdleWorkers() { auto &worker_state = GetStateForLanguage(idle_worker->GetLanguage()); auto it = worker_state.worker_processes.find(worker_startup_token); - if (it != worker_state.worker_processes.end() && - it->second.num_starting_workers > 0) { + if (it != worker_state.worker_processes.end() && it->second.is_pending_registration) { // A Java worker process may hold multiple workers. // Some workers of this process are pending registration. Skip killing this worker. continue; @@ -1349,7 +1330,7 @@ void WorkerPool::PrestartWorkers(const TaskSpecification &task_spec, // The number of available workers that can be used for this task spec. int num_usable_workers = state.idle.size(); for (auto &entry : state.worker_processes) { - num_usable_workers += entry.second.num_starting_workers; + num_usable_workers += entry.second.is_pending_registration ? 1 : 0; } // Some existing workers may be holding less than 1 CPU each, so we should // start as many workers as needed to fill up the remaining CPUs. @@ -1376,10 +1357,10 @@ void WorkerPool::DisconnectWorker(const std::shared_ptr &worker if (!RemoveWorker(it->second.alive_started_workers, worker)) { // Worker is either starting or started, // if it's not started, we should remove it from starting. - it->second.num_starting_workers--; + it->second.is_pending_registration = false; } if (it->second.alive_started_workers.size() == 0 && - it->second.num_starting_workers == 0) { + !it->second.is_pending_registration) { DeleteRuntimeEnvIfPossible(it->second.runtime_env_info.serialized_runtime_env()); RemoveWorkerProcess(state, worker->GetStartupToken()); } @@ -1508,7 +1489,8 @@ void WorkerPool::WarnAboutSize() { num_workers_started_or_registered += static_cast(state.registered_workers.size()); for (const auto &starting_process : state.worker_processes) { - num_workers_started_or_registered += starting_process.second.num_starting_workers; + num_workers_started_or_registered += + starting_process.second.is_pending_registration ? 0 : 1; } // Don't count IO workers towards the warning message threshold. num_workers_started_or_registered -= RayConfig::instance().max_io_workers() * 2; diff --git a/src/ray/raylet/worker_pool.h b/src/ray/raylet/worker_pool.h index 4e94ade38..3bd3da859 100644 --- a/src/ray/raylet/worker_pool.h +++ b/src/ray/raylet/worker_pool.h @@ -473,10 +473,8 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface { /// Some basic information about the worker process. struct WorkerProcessInfo { - /// The number of workers in the worker process. - int num_workers; - /// The number of pending registration workers in the worker process. - int num_starting_workers; + /// Whether this worker is pending registration or is started. + bool is_pending_registration = true; /// The started workers which is alive. std::unordered_set> alive_started_workers; /// The type of the worker. @@ -684,7 +682,6 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface { void DeleteRuntimeEnvIfPossible(const std::string &serialized_runtime_env); void AddWorkerProcess(State &state, - const int workers_to_start, const rpc::WorkerType worker_type, const Process &proc, const std::chrono::high_resolution_clock::time_point &start, diff --git a/src/ray/raylet/worker_pool_test.cc b/src/ray/raylet/worker_pool_test.cc index a0be09325..c9d4255f4 100644 --- a/src/ray/raylet/worker_pool_test.cc +++ b/src/ray/raylet/worker_pool_test.cc @@ -26,8 +26,7 @@ namespace ray { namespace raylet { -int NUM_WORKERS_PER_PROCESS_JAVA = 3; -int MAXIMUM_STARTUP_CONCURRENCY = 5; +int MAXIMUM_STARTUP_CONCURRENCY = 15; int MAX_IO_WORKER_SIZE = 2; int POOL_SIZE_SOFT_LIMIT = 5; int WORKER_REGISTER_TIMEOUT_SECONDS = 3; @@ -37,10 +36,6 @@ const std::string BAD_RUNTIME_ENV_ERROR_MSG = "bad runtime env"; std::vector LANGUAGES = {Language::PYTHON, Language::JAVA}; -static inline std::string GetNumJavaWorkersPerProcessSystemProperty(int num) { - return std::string("-Dray.job.num-java-workers-per-process=") + std::to_string(num); -} - class MockWorkerClient : public rpc::CoreWorkerClientInterface { public: MockWorkerClient(instrumented_io_context &io_service) : io_service_(io_service) {} @@ -194,7 +189,7 @@ class WorkerPoolMock : public WorkerPool { int total = 0; for (auto &state_entry : states_by_lang_) { for (auto &process_entry : state_entry.second.worker_processes) { - total += process_entry.second.num_starting_workers; + total += process_entry.second.is_pending_registration ? 1 : 0; } } return total; @@ -204,7 +199,7 @@ class WorkerPoolMock : public WorkerPool { int total = 0; for (auto &entry : states_by_lang_) { for (auto process : entry.second.worker_processes) { - if (process.second.num_starting_workers != 0) { + if (process.second.is_pending_registration) { total += 1; } } @@ -313,7 +308,6 @@ class WorkerPoolMock : public WorkerPool { if (pushed_it == pushedProcesses_.end()) { int runtime_env_hash = 0; bool is_java = false; - bool has_dynamic_options = false; // Parses runtime env hash to make sure the pushed workers can be popped out. for (auto command_args : it->second) { std::string runtime_env_key = "--runtime-env-hash="; @@ -326,14 +320,9 @@ class WorkerPoolMock : public WorkerPool { if (pos != std::string::npos) { is_java = true; } - pos = command_args.find("-X"); - if (pos != std::string::npos) { - has_dynamic_options = true; - } } // TODO(SongGuyang): support C++ language workers. - int num_workers = - (is_java && !has_dynamic_options) ? NUM_WORKERS_PER_PROCESS_JAVA : 1; + int num_workers = 1; RAY_CHECK(timeout_worker_number <= num_workers) << "The timeout worker number cannot exceed the total number of workers"; auto register_workers = num_workers - timeout_worker_number; @@ -471,7 +460,6 @@ class WorkerPoolTest : public ::testing::Test { worker_pool_ = std::make_unique( io_service_, worker_commands, mock_worker_rpc_clients_); rpc::JobConfig job_config; - job_config.set_num_java_workers_per_process(NUM_WORKERS_PER_PROCESS_JAVA); RegisterDriver(Language::PYTHON, JOB_ID, job_config); } @@ -489,18 +477,7 @@ class WorkerPoolTest : public ::testing::Test { ASSERT_TRUE(worker_pool_->NumWorkerProcessesStarting() <= expected_worker_process_count); Process prev = worker_pool_->LastStartedWorkerProcess(); - if (!std::equal_to()(last_started_worker_process, prev)) { - last_started_worker_process = prev; - const auto &real_command = - worker_pool_->GetWorkerCommand(last_started_worker_process); - if (language == Language::JAVA) { - auto it = std::find( - real_command.begin(), - real_command.end(), - GetNumJavaWorkersPerProcessSystemProperty(num_workers_per_process)); - ASSERT_NE(it, real_command.end()); - } - } else { + if (std::equal_to()(last_started_worker_process, prev)) { ASSERT_EQ(worker_pool_->NumWorkerProcessesStarting(), expected_worker_process_count); ASSERT_TRUE(i >= expected_worker_process_count); @@ -618,9 +595,7 @@ TEST_F(WorkerPoolTest, HandleWorkerRegistration) { auto [proc, token] = worker_pool_->StartWorkerProcess( Language::JAVA, rpc::WorkerType::WORKER, JOB_ID, &status); std::vector> workers; - for (int i = 0; i < NUM_WORKERS_PER_PROCESS_JAVA; i++) { - workers.push_back(worker_pool_->CreateWorker(Process(), Language::JAVA)); - } + workers.push_back(worker_pool_->CreateWorker(Process(), Language::JAVA)); for (const auto &worker : workers) { // Check that there's still a starting worker process // before all workers have been registered @@ -670,7 +645,7 @@ TEST_F(WorkerPoolTest, StartupPythonWorkerProcessCount) { } TEST_F(WorkerPoolTest, StartupJavaWorkerProcessCount) { - TestStartupWorkerProcessCount(Language::JAVA, NUM_WORKERS_PER_PROCESS_JAVA); + TestStartupWorkerProcessCount(Language::JAVA, 1); } TEST_F(WorkerPoolTest, InitialWorkerProcessCount) { @@ -756,7 +731,6 @@ TEST_F(WorkerPoolTest, StartWorkerWithDynamicOptionsCommand) { rpc::JobConfig job_config = rpc::JobConfig(); job_config.add_code_search_path("/test/code_search_path"); - job_config.set_num_java_workers_per_process(NUM_WORKERS_PER_PROCESS_JAVA); job_config.add_jvm_options("-Xmx1g"); job_config.add_jvm_options("-Xms500m"); job_config.add_jvm_options("-Dmy-job.hello=world"); @@ -779,7 +753,6 @@ TEST_F(WorkerPoolTest, StartWorkerWithDynamicOptionsCommand) { expected_command.end(), {"-Xmx1g", "-Xms500m", "-Dmy-job.hello=world", "-Dmy-job.foo=bar"}); // Ray-defined per-process options - expected_command.push_back(GetNumJavaWorkersPerProcessSystemProperty(1)); expected_command.push_back("-Dray.raylet.startup-token=0"); expected_command.push_back("-Dray.internal.runtime-env-hash=1"); // User-defined per-process options @@ -1162,46 +1135,6 @@ TEST_F(WorkerPoolTest, DeleteWorkerPushPop) { }); } -TEST_F(WorkerPoolTest, NoPopOnCrashedWorkerProcess) { - // Start a Java worker process. - PopWorkerStatus status; - auto [proc, token] = worker_pool_->StartWorkerProcess( - Language::JAVA, rpc::WorkerType::WORKER, JOB_ID, &status); - auto worker1 = worker_pool_->CreateWorker(Process(), Language::JAVA); - auto worker2 = worker_pool_->CreateWorker(Process(), Language::JAVA); - - // We now imitate worker process crashing while core worker initializing. - - // 1. we register both workers. - RAY_CHECK_OK(worker_pool_->RegisterWorker( - worker1, proc.GetId(), worker_pool_->GetStartupToken(proc), [](Status, int) {})); - RAY_CHECK_OK(worker_pool_->RegisterWorker( - worker2, proc.GetId(), worker_pool_->GetStartupToken(proc), [](Status, int) {})); - - // 2. announce worker port for worker 1. When interacting with worker pool, it's - // PushWorker. - worker_pool_->PushWorker(worker1); - - // 3. kill the worker process. Now let's assume that Raylet found that the connection - // with worker 1 disconnected first. - worker_pool_->DisconnectWorker( - worker1, /*disconnect_type=*/rpc::WorkerExitType::SYSTEM_ERROR_EXIT); - - // 4. but the RPC for announcing worker port for worker 2 is already in Raylet input - // buffer. So now Raylet needs to handle worker 2. - worker_pool_->PushWorker(worker2); - - // 5. Let's try to pop a worker to execute a task. Worker 2 shouldn't be popped because - // the process has crashed. - const auto task_spec = ExampleTaskSpec(); - ASSERT_NE(worker_pool_->PopWorkerSync(task_spec), worker1); - ASSERT_NE(worker_pool_->PopWorkerSync(task_spec), worker2); - - // 6. Now Raylet disconnects with worker 2. - worker_pool_->DisconnectWorker( - worker2, /*disconnect_type=*/rpc::WorkerExitType::SYSTEM_ERROR_EXIT); -} - TEST_F(WorkerPoolTest, TestWorkerCapping) { auto job_id = JOB_ID; @@ -1423,8 +1356,7 @@ TEST_F(WorkerPoolTest, TestWorkerCappingWithExitDelay) { PopWorkerStatus status; auto [proc, token] = worker_pool_->StartWorkerProcess( language, rpc::WorkerType::WORKER, JOB_ID, &status); - int workers_to_start = - language == Language::JAVA ? NUM_WORKERS_PER_PROCESS_JAVA : 1; + int workers_to_start = 1; for (int j = 0; j < workers_to_start; j++) { auto worker = worker_pool_->CreateWorker(Process(), language); worker->SetStartupToken(worker_pool_->GetStartupToken(proc)); @@ -1644,77 +1576,6 @@ TEST_F(WorkerPoolTest, RuntimeEnvUriReferenceWorkerLevel) { } } -TEST_F(WorkerPoolTest, RuntimeEnvUriReferenceWithMultipleWorkers) { - auto job_id = JOB_ID; - std::string uri = "s3://567"; - auto runtime_env_info = ExampleRuntimeEnvInfo({uri}, false); - rpc::JobConfig job_config; - job_config.set_num_java_workers_per_process(NUM_WORKERS_PER_PROCESS_JAVA); - job_config.mutable_runtime_env_info()->CopyFrom(runtime_env_info); - // Start job without eager installed runtime env. - worker_pool_->HandleJobStarted(job_id, job_config); - ASSERT_EQ(GetReferenceCount(runtime_env_info.serialized_runtime_env()), 0); - - // First part, test normal case with all worker registered. - { - // Start actors with runtime env. The Java actors will trigger a multi-worker process. - std::vector> workers; - for (int i = 0; i < NUM_WORKERS_PER_PROCESS_JAVA; i++) { - auto actor_creation_id = ActorID::Of(job_id, TaskID::ForDriverTask(job_id), i + 1); - const auto actor_creation_task_spec = - ExampleTaskSpec(ActorID::Nil(), - Language::JAVA, - job_id, - actor_creation_id, - {}, - TaskID::FromRandom(JobID::Nil()), - runtime_env_info); - auto popped_actor_worker = worker_pool_->PopWorkerSync(actor_creation_task_spec); - ASSERT_NE(popped_actor_worker, nullptr); - workers.push_back(popped_actor_worker); - ASSERT_EQ(GetReferenceCount(runtime_env_info.serialized_runtime_env()), 1); - } - // Make sure only one worker process has been started. - ASSERT_EQ(worker_pool_->GetProcessSize(), 1); - // Disconnect all actor workers. - for (auto &worker : workers) { - worker_pool_->DisconnectWorker(worker, rpc::WorkerExitType::IDLE_EXIT); - } - ASSERT_EQ(GetReferenceCount(runtime_env_info.serialized_runtime_env()), 0); - } - - // Second part, test corner case with some worker registration timeout. - { - // Start one actor with runtime env. The Java actor will trigger a multi-worker - // process. - auto actor_creation_id = ActorID::Of(job_id, TaskID::ForDriverTask(job_id), 1); - const auto actor_creation_task_spec = - ExampleTaskSpec(ActorID::Nil(), - Language::JAVA, - job_id, - actor_creation_id, - {}, - TaskID::FromRandom(JobID::Nil()), - runtime_env_info); - PopWorkerStatus status; - // Only one worker registration. All the other worker registration times out. - auto popped_actor_worker = worker_pool_->PopWorkerSync( - actor_creation_task_spec, true, &status, NUM_WORKERS_PER_PROCESS_JAVA - 1); - ASSERT_EQ(GetReferenceCount(runtime_env_info.serialized_runtime_env()), 1); - // Disconnect actor worker. - worker_pool_->DisconnectWorker(popped_actor_worker, rpc::WorkerExitType::IDLE_EXIT); - ASSERT_EQ(GetReferenceCount(runtime_env_info.serialized_runtime_env()), 1); - // Sleep for a while to wait worker registration timeout. - std::this_thread::sleep_for( - std::chrono::seconds(WORKER_REGISTER_TIMEOUT_SECONDS + 1)); - ASSERT_EQ(GetReferenceCount(runtime_env_info.serialized_runtime_env()), 0); - } - - // Finish the job. - worker_pool_->HandleJobFinished(job_id); - ASSERT_EQ(GetReferenceCount(runtime_env_info.serialized_runtime_env()), 0); -} - TEST_F(WorkerPoolTest, CacheWorkersByRuntimeEnvHash) { /// /// Check that a worker can be popped only if there is a