From eb29895dbb491552b9b957cfc01e5f72855e6bb0 Mon Sep 17 00:00:00 2001 From: Qing Wang Date: Thu, 19 May 2022 00:36:22 +0800 Subject: [PATCH] [Core] Remove multiple core workers in one process 1/n. (#24147) This is the 1st PR to remove the code path of multiple core workers in one process. This PR aims to remove the flags and APIs related to `num_workers`. After this PR is merged, we no longer need to consider multiple core workers. The follow-up PRs will cover the deeper logic refactoring, such as eliminating the gap between core worker and core worker process, and removing the logic related to multiple workers from the worker pool, GCS, etc. **BREAKING CHANGE** This PR removes these APIs: - Ray.wrapRunnable(); - Ray.wrapCallable(); - Ray.setAsyncContext(); - Ray.getAsyncContext(); And the following APIs must not be invoked in a user-created thread in local mode: - Ray.getRuntimeContext().getCurrentActorId(); - Ray.getRuntimeContext().getCurrentTaskId() Note that this PR shouldn't be merged to 1.x. 
--- cpp/src/ray/util/process_helper.cc | 1 - java/api/src/main/java/io/ray/api/Ray.java | 47 ----- .../java/io/ray/api/runtime/RayRuntime.java | 21 --- .../test/ActorPerformanceTestBase.java | 5 +- .../test/ActorPerformanceTestCase1.java | 4 +- .../io/ray/runtime/AbstractRayRuntime.java | 72 +------- .../io/ray/runtime/ConcurrencyGroupImpl.java | 4 +- .../ray/runtime/DefaultRayRuntimeFactory.java | 5 +- .../java/io/ray/runtime/RayDevRuntime.java | 20 +-- .../java/io/ray/runtime/RayNativeRuntime.java | 21 +-- .../io/ray/runtime/RayRuntimeInternal.java | 2 - .../java/io/ray/runtime/RayRuntimeProxy.java | 80 --------- .../java/io/ray/runtime/config/RayConfig.java | 4 - .../context/LocalModeWorkerContext.java | 22 +-- .../runtime/context/NativeWorkerContext.java | 14 +- .../io/ray/runtime/context/WorkerContext.java | 9 - .../functionmanager/FunctionManager.java | 35 ++-- .../runtime/task/LocalModeTaskSubmitter.java | 3 - .../io/ray/runtime/task/TaskExecutor.java | 10 +- .../java/io/ray/runtime/util/MethodUtils.java | 2 +- .../ParallelActorContextImpl.java | 8 +- .../ParallelActorExecutorImpl.java | 7 +- .../src/main/resources/ray.default.conf | 2 - .../functionmanager/FunctionManagerTest.java | 28 ++- .../src/main/java/io/ray/serve/HttpProxy.java | 5 - java/test.sh | 3 +- .../test/ActorHandleReferenceCountTest.java | 26 --- .../java/io/ray/test/ClassLoaderTest.java | 167 ------------------ .../io/ray/test/DefaultActorLifetimeTest.java | 1 - .../main/java/io/ray/test/ExitActorTest.java | 32 ---- .../main/java/io/ray/test/ExitActorTest2.java | 1 - .../main/java/io/ray/test/FailureTest.java | 4 - .../main/java/io/ray/test/JobConfigTest.java | 5 - .../main/java/io/ray/test/KillActorTest.java | 1 - .../java/io/ray/test/MultiThreadingTest.java | 161 ++--------------- .../main/java/io/ray/test/RuntimeEnvTest.java | 87 --------- .../src/main/java/io/ray/test/TestUtils.java | 17 +- java/test/src/main/resources/ray.conf | 6 - python/ray/_raylet.pyx | 1 - python/ray/job_config.py 
| 8 - python/ray/tests/test_advanced_8.py | 4 - src/ray/core_worker/core_worker.cc | 9 - src/ray/core_worker/core_worker_options.h | 3 - src/ray/core_worker/core_worker_process.cc | 115 ++++-------- src/ray/core_worker/core_worker_process.h | 14 +- .../java/io_ray_runtime_RayNativeRuntime.cc | 2 - .../java/io_ray_runtime_RayNativeRuntime.h | 3 +- src/ray/core_worker/test/core_worker_test.cc | 1 - src/ray/core_worker/test/mock_worker.cc | 1 - src/ray/gcs/gcs_server/gcs_actor_manager.cc | 8 +- .../gcs_server/test/gcs_job_manager_test.cc | 7 +- src/ray/gcs/test/gcs_test_util.h | 5 +- src/ray/object_manager/common.h | 4 +- src/ray/protobuf/gcs.proto | 2 - src/ray/raylet/worker_pool.cc | 46 ++--- src/ray/raylet/worker_pool.h | 7 +- src/ray/raylet/worker_pool_test.cc | 155 +--------------- 57 files changed, 137 insertions(+), 1200 deletions(-) delete mode 100644 java/runtime/src/main/java/io/ray/runtime/RayRuntimeProxy.java delete mode 100644 java/test/src/main/java/io/ray/test/ClassLoaderTest.java delete mode 100644 java/test/src/main/resources/ray.conf diff --git a/cpp/src/ray/util/process_helper.cc b/cpp/src/ray/util/process_helper.cc index e1e7ba611..b5286cf7a 100644 --- a/cpp/src/ray/util/process_helper.cc +++ b/cpp/src/ray/util/process_helper.cc @@ -139,7 +139,6 @@ void ProcessHelper::RayStart(CoreWorkerOptions::TaskExecutionCallback callback) options.node_manager_port = ConfigInternal::Instance().node_manager_port; options.raylet_ip_address = node_ip; options.driver_name = "cpp_worker"; - options.num_workers = 1; options.metrics_agent_port = -1; options.task_execution_callback = callback; options.startup_token = ConfigInternal::Instance().startup_token; diff --git a/java/api/src/main/java/io/ray/api/Ray.java b/java/api/src/main/java/io/ray/api/Ray.java index 5dc3c4e6a..a6be0cd04 100644 --- a/java/api/src/main/java/io/ray/api/Ray.java +++ b/java/api/src/main/java/io/ray/api/Ray.java @@ -5,7 +5,6 @@ import io.ray.api.runtime.RayRuntimeFactory; import 
io.ray.api.runtimecontext.RuntimeContext; import java.util.List; import java.util.Optional; -import java.util.concurrent.Callable; /** This class contains all public APIs of Ray. */ public final class Ray extends RayCall { @@ -205,52 +204,6 @@ public final class Ray extends RayCall { return internal().getActor(name, namespace); } - /** - * If users want to use Ray API in their own threads, call this method to get the async context - * and then call {@link #setAsyncContext} at the beginning of the new thread. - * - * @return The async context. - */ - public static Object getAsyncContext() { - return internal().getAsyncContext(); - } - - /** - * Set the async context for the current thread. - * - * @param asyncContext The async context to set. - */ - public static void setAsyncContext(Object asyncContext) { - internal().setAsyncContext(asyncContext); - } - - // TODO (kfstorm): add the `rollbackAsyncContext` API to allow rollbacking the async context of - // the current thread to the one before `setAsyncContext` is called. - - // TODO (kfstorm): unify the `wrap*` methods. - - /** - * If users want to use Ray API in their own threads, they should wrap their {@link Runnable} - * objects with this method. - * - * @param runnable The runnable to wrap. - * @return The wrapped runnable. - */ - public static Runnable wrapRunnable(Runnable runnable) { - return internal().wrapRunnable(runnable); - } - - /** - * If users want to use Ray API in their own threads, they should wrap their {@link Callable} - * objects with this method. - * - * @param callable The callable to wrap. - * @return The wrapped callable. - */ - public static Callable wrapCallable(Callable callable) { - return internal().wrapCallable(callable); - } - /** Get the underlying runtime instance. 
*/ public static RayRuntime internal() { if (runtime == null) { diff --git a/java/api/src/main/java/io/ray/api/runtime/RayRuntime.java b/java/api/src/main/java/io/ray/api/runtime/RayRuntime.java index ca2350a2b..c2e084c99 100644 --- a/java/api/src/main/java/io/ray/api/runtime/RayRuntime.java +++ b/java/api/src/main/java/io/ray/api/runtime/RayRuntime.java @@ -24,7 +24,6 @@ import io.ray.api.runtimeenv.RuntimeEnv; import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.concurrent.Callable; /** Base interface of a Ray runtime. */ public interface RayRuntime { @@ -207,26 +206,6 @@ public interface RayRuntime { RuntimeContext getRuntimeContext(); - Object getAsyncContext(); - - void setAsyncContext(Object asyncContext); - - /** - * Wrap a {@link Runnable} with necessary context capture. - * - * @param runnable The runnable to wrap. - * @return The wrapped runnable. - */ - Runnable wrapRunnable(Runnable runnable); - - /** - * Wrap a {@link Callable} with necessary context capture. - * - * @param callable The callable to wrap. - * @return The wrapped callable. - */ - Callable wrapCallable(Callable callable); - /** Intentionally exit the current actor. 
*/ void exitActor(); diff --git a/java/performance_test/src/main/java/io/ray/performancetest/test/ActorPerformanceTestBase.java b/java/performance_test/src/main/java/io/ray/performancetest/test/ActorPerformanceTestBase.java index e118567aa..bcff2ed7f 100644 --- a/java/performance_test/src/main/java/io/ray/performancetest/test/ActorPerformanceTestBase.java +++ b/java/performance_test/src/main/java/io/ray/performancetest/test/ActorPerformanceTestBase.java @@ -23,10 +23,7 @@ public class ActorPerformanceTestBase { boolean hasReturn, boolean ignoreReturn, int argSize, - boolean useDirectByteBuffer, - int numJavaWorkerPerProcess) { - System.setProperty( - "ray.job.num-java-workers-per-process", String.valueOf(numJavaWorkerPerProcess)); + boolean useDirectByteBuffer) { System.setProperty("ray.raylet.startup-token", "0"); Ray.init(); try { diff --git a/java/performance_test/src/main/java/io/ray/performancetest/test/ActorPerformanceTestCase1.java b/java/performance_test/src/main/java/io/ray/performancetest/test/ActorPerformanceTestCase1.java index c095d39ba..036c13409 100644 --- a/java/performance_test/src/main/java/io/ray/performancetest/test/ActorPerformanceTestCase1.java +++ b/java/performance_test/src/main/java/io/ray/performancetest/test/ActorPerformanceTestCase1.java @@ -13,7 +13,6 @@ public class ActorPerformanceTestCase1 { final int argSize = 0; final boolean useDirectByteBuffer = false; final boolean ignoreReturn = false; - final int numJavaWorkerPerProcess = 1; ActorPerformanceTestBase.run( args, layers, @@ -21,7 +20,6 @@ public class ActorPerformanceTestCase1 { hasReturn, ignoreReturn, argSize, - useDirectByteBuffer, - numJavaWorkerPerProcess); + useDirectByteBuffer); } } diff --git a/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java index 4a8a9f547..7659a6b35 100644 --- a/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java +++ 
b/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java @@ -31,7 +31,6 @@ import io.ray.runtime.functionmanager.FunctionDescriptor; import io.ray.runtime.functionmanager.FunctionManager; import io.ray.runtime.functionmanager.PyFunctionDescriptor; import io.ray.runtime.functionmanager.RayFunction; -import io.ray.runtime.generated.Common; import io.ray.runtime.generated.Common.Language; import io.ray.runtime.object.ObjectRefImpl; import io.ray.runtime.object.ObjectStore; @@ -46,7 +45,6 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.concurrent.Callable; import java.util.stream.Collectors; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -67,13 +65,8 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal { private static ParallelActorContextImpl parallelActorContextImpl = new ParallelActorContextImpl(); - /** Whether the required thread context is set on the current thread. */ - final ThreadLocal isContextSet = ThreadLocal.withInitial(() -> false); - public AbstractRayRuntime(RayConfig rayConfig) { this.rayConfig = rayConfig; - setIsContextSet(rayConfig.workerMode == Common.WorkerType.DRIVER); - functionManager = new FunctionManager(rayConfig.codeSearchPath); runtimeContext = new RuntimeContextImpl(this); } @@ -158,7 +151,7 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal { @Override public ObjectRef call(RayFunc func, Object[] args, CallOptions options) { - RayFunction rayFunction = functionManager.getFunction(workerContext.getCurrentJobId(), func); + RayFunction rayFunction = functionManager.getFunction(func); FunctionDescriptor functionDescriptor = rayFunction.functionDescriptor; Optional> returnType = rayFunction.getReturnType(); return callNormalFunction(functionDescriptor, args, returnType, options); @@ -176,7 +169,7 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal { @Override public ObjectRef 
callActor( ActorHandle actor, RayFunc func, Object[] args, CallOptions options) { - RayFunction rayFunction = functionManager.getFunction(workerContext.getCurrentJobId(), func); + RayFunction rayFunction = functionManager.getFunction(func); FunctionDescriptor functionDescriptor = rayFunction.functionDescriptor; Optional> returnType = rayFunction.getReturnType(); return callActorFunction(actor, functionDescriptor, args, returnType, options); @@ -201,8 +194,7 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal { public ActorHandle createActor( RayFunc actorFactoryFunc, Object[] args, ActorCreationOptions options) { FunctionDescriptor functionDescriptor = - functionManager.getFunction(workerContext.getCurrentJobId(), actorFactoryFunc) - .functionDescriptor; + functionManager.getFunction(actorFactoryFunc).functionDescriptor; return (ActorHandle) createActorImpl(functionDescriptor, args, options); } @@ -256,31 +248,6 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal { return (T) taskSubmitter.getActor(actorId); } - @Override - public void setAsyncContext(Object asyncContext) { - isContextSet.set(true); - } - - @Override - public final Runnable wrapRunnable(Runnable runnable) { - Object asyncContext = getAsyncContext(); - return () -> { - try (RayAsyncContextUpdater updater = new RayAsyncContextUpdater(asyncContext, this)) { - runnable.run(); - } - }; - } - - @Override - public final Callable wrapCallable(Callable callable) { - Object asyncContext = getAsyncContext(); - return () -> { - try (RayAsyncContextUpdater updater = new RayAsyncContextUpdater(asyncContext, this)) { - return callable.call(); - } - }; - } - @Override public ConcurrencyGroup createConcurrencyGroup( String name, int maxConcurrency, List funcs) { @@ -387,34 +354,6 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal { return actor; } - /// An auto closable class that is used for updating the async context when invoking Ray APIs. 
- private static final class RayAsyncContextUpdater implements AutoCloseable { - - private AbstractRayRuntime runtime; - - private boolean oldIsContextSet; - - private Object oldAsyncContext = null; - - public RayAsyncContextUpdater(Object asyncContext, AbstractRayRuntime runtime) { - this.runtime = runtime; - oldIsContextSet = runtime.isContextSet.get(); - if (oldIsContextSet) { - oldAsyncContext = runtime.getAsyncContext(); - } - runtime.setAsyncContext(asyncContext); - } - - @Override - public void close() { - if (oldIsContextSet) { - runtime.setAsyncContext(oldAsyncContext); - } else { - runtime.setIsContextSet(false); - } - } - } - abstract List getCurrentReturnIds(int numReturns, ActorId actorId); @Override @@ -446,11 +385,6 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal { return runtimeContext; } - @Override - public void setIsContextSet(boolean isContextSet) { - this.isContextSet.set(isContextSet); - } - /// A helper to validate if the prepared return ids is as expected. 
void validatePreparedReturnIds(List preparedReturnIds, List realReturnIds) { if (rayConfig.runMode == RunMode.CLUSTER) { diff --git a/java/runtime/src/main/java/io/ray/runtime/ConcurrencyGroupImpl.java b/java/runtime/src/main/java/io/ray/runtime/ConcurrencyGroupImpl.java index b5b2e4ebd..53ac57da5 100644 --- a/java/runtime/src/main/java/io/ray/runtime/ConcurrencyGroupImpl.java +++ b/java/runtime/src/main/java/io/ray/runtime/ConcurrencyGroupImpl.java @@ -24,9 +24,7 @@ public class ConcurrencyGroupImpl implements ConcurrencyGroup { funcs.forEach( func -> { RayFunction rayFunc = - ((RayRuntimeInternal) Ray.internal()) - .getFunctionManager() - .getFunction(Ray.getRuntimeContext().getCurrentJobId(), func); + ((RayRuntimeInternal) Ray.internal()).getFunctionManager().getFunction(func); functionDescriptors.add(rayFunc.getFunctionDescriptor()); }); } diff --git a/java/runtime/src/main/java/io/ray/runtime/DefaultRayRuntimeFactory.java b/java/runtime/src/main/java/io/ray/runtime/DefaultRayRuntimeFactory.java index 2356a70c7..a9461b820 100644 --- a/java/runtime/src/main/java/io/ray/runtime/DefaultRayRuntimeFactory.java +++ b/java/runtime/src/main/java/io/ray/runtime/DefaultRayRuntimeFactory.java @@ -32,10 +32,7 @@ public class DefaultRayRuntimeFactory implements RayRuntimeFactory { rayConfig.runMode == RunMode.SINGLE_PROCESS ? new RayDevRuntime(rayConfig) : new RayNativeRuntime(rayConfig); - RayRuntimeInternal runtime = - rayConfig.numWorkersPerProcess > 1 - ? 
RayRuntimeProxy.newInstance(innerRuntime) - : innerRuntime; + RayRuntimeInternal runtime = innerRuntime; runtime.start(); return runtime; } catch (Exception e) { diff --git a/java/runtime/src/main/java/io/ray/runtime/RayDevRuntime.java b/java/runtime/src/main/java/io/ray/runtime/RayDevRuntime.java index 080a31fcc..cba3ff298 100644 --- a/java/runtime/src/main/java/io/ray/runtime/RayDevRuntime.java +++ b/java/runtime/src/main/java/io/ray/runtime/RayDevRuntime.java @@ -1,6 +1,5 @@ package io.ray.runtime; -import com.google.common.base.Preconditions; import io.ray.api.BaseActorHandle; import io.ray.api.id.ActorId; import io.ray.api.id.JobId; @@ -10,6 +9,7 @@ import io.ray.api.placementgroup.PlacementGroup; import io.ray.api.runtimecontext.ResourceValue; import io.ray.runtime.config.RayConfig; import io.ray.runtime.context.LocalModeWorkerContext; +import io.ray.runtime.functionmanager.FunctionManager; import io.ray.runtime.gcs.GcsClient; import io.ray.runtime.generated.Common.TaskSpec; import io.ray.runtime.object.LocalModeObjectStore; @@ -24,13 +24,9 @@ import java.util.List; import java.util.Map; import java.util.Optional; import java.util.concurrent.atomic.AtomicInteger; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; public class RayDevRuntime extends AbstractRayRuntime { - private static final Logger LOGGER = LoggerFactory.getLogger(RayDevRuntime.class); - private AtomicInteger jobCounter = new AtomicInteger(0); public RayDevRuntime(RayConfig rayConfig) { @@ -49,6 +45,7 @@ public class RayDevRuntime extends AbstractRayRuntime { taskExecutor = new LocalModeTaskExecutor(this); workerContext = new LocalModeWorkerContext(rayConfig.getJobId()); objectStore = new LocalModeObjectStore(workerContext); + functionManager = new FunctionManager(rayConfig.codeSearchPath); taskSubmitter = new LocalModeTaskSubmitter(this, taskExecutor, (LocalModeObjectStore) objectStore); ((LocalModeObjectStore) objectStore) @@ -90,19 +87,6 @@ public class RayDevRuntime extends 
AbstractRayRuntime { throw new UnsupportedOperationException("Ray doesn't have gcs client in local mode."); } - @Override - public Object getAsyncContext() { - return new AsyncContext(((LocalModeWorkerContext) workerContext).getCurrentTask()); - } - - @Override - public void setAsyncContext(Object asyncContext) { - Preconditions.checkNotNull(asyncContext); - TaskSpec task = ((AsyncContext) asyncContext).task; - ((LocalModeWorkerContext) workerContext).setCurrentTask(task); - super.setAsyncContext(asyncContext); - } - @Override public Map> getAvailableResourceIds() { throw new UnsupportedOperationException("Ray doesn't support get resources ids in local mode."); diff --git a/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java b/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java index b53d655f6..3328413c5 100644 --- a/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java +++ b/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java @@ -14,6 +14,7 @@ import io.ray.api.runtimecontext.ResourceValue; import io.ray.runtime.config.RayConfig; import io.ray.runtime.context.NativeWorkerContext; import io.ray.runtime.exception.RayIntentionalSystemExitException; +import io.ray.runtime.functionmanager.FunctionManager; import io.ray.runtime.gcs.GcsClient; import io.ray.runtime.gcs.GcsClientOptions; import io.ray.runtime.generated.Common.WorkerType; @@ -102,14 +103,13 @@ public final class RayNativeRuntime extends AbstractRayRuntime { if (rayConfig.workerMode == WorkerType.DRIVER && rayConfig.getJobId() == JobId.NIL) { rayConfig.setJobId(getGcsClient().nextJobId()); } - int numWorkersPerProcess = - rayConfig.workerMode == WorkerType.DRIVER ? 1 : rayConfig.numWorkersPerProcess; + // Make sure the job id has been set already. 
+ functionManager = new FunctionManager(rayConfig.codeSearchPath); byte[] serializedJobConfig = null; if (rayConfig.workerMode == WorkerType.DRIVER) { JobConfig.Builder jobConfigBuilder = JobConfig.newBuilder() - .setNumJavaWorkersPerProcess(rayConfig.numWorkersPerProcess) .addAllJvmOptions(rayConfig.jvmOptionsForJavaWorker) .addAllCodeSearchPath(rayConfig.codeSearchPath) .setRayNamespace(rayConfig.namespace); @@ -150,7 +150,6 @@ public final class RayNativeRuntime extends AbstractRayRuntime { rayConfig.rayletSocketName, (rayConfig.workerMode == WorkerType.DRIVER ? rayConfig.getJobId() : JobId.NIL).getBytes(), new GcsClientOptions(rayConfig), - numWorkersPerProcess, rayConfig.logDir, serializedJobConfig, rayConfig.getStartupToken(), @@ -223,19 +222,6 @@ public final class RayNativeRuntime extends AbstractRayRuntime { nativeKillActor(actor.getId().getBytes(), noRestart); } - @Override - public Object getAsyncContext() { - return new AsyncContext( - workerContext.getCurrentWorkerId(), workerContext.getCurrentClassLoader()); - } - - @Override - public void setAsyncContext(Object asyncContext) { - nativeSetCoreWorker(((AsyncContext) asyncContext).workerId.getBytes()); - workerContext.setCurrentClassLoader(((AsyncContext) asyncContext).currentClassLoader); - super.setAsyncContext(asyncContext); - } - @Override List getCurrentReturnIds(int numReturns, ActorId actorId) { List ret = nativeGetCurrentReturnIds(numReturns, actorId.getBytes()); @@ -289,7 +275,6 @@ public final class RayNativeRuntime extends AbstractRayRuntime { String rayletSocket, byte[] jobId, GcsClientOptions gcsClientOptions, - int numWorkersPerProcess, String logDir, byte[] serializedJobConfig, int startupToken, diff --git a/java/runtime/src/main/java/io/ray/runtime/RayRuntimeInternal.java b/java/runtime/src/main/java/io/ray/runtime/RayRuntimeInternal.java index 8fc14caa3..fd1a23b90 100644 --- a/java/runtime/src/main/java/io/ray/runtime/RayRuntimeInternal.java +++ 
b/java/runtime/src/main/java/io/ray/runtime/RayRuntimeInternal.java @@ -26,7 +26,5 @@ public interface RayRuntimeInternal extends RayRuntime { GcsClient getGcsClient(); - void setIsContextSet(boolean isContextSet); - void run(); } diff --git a/java/runtime/src/main/java/io/ray/runtime/RayRuntimeProxy.java b/java/runtime/src/main/java/io/ray/runtime/RayRuntimeProxy.java deleted file mode 100644 index 4a427295c..000000000 --- a/java/runtime/src/main/java/io/ray/runtime/RayRuntimeProxy.java +++ /dev/null @@ -1,80 +0,0 @@ -package io.ray.runtime; - -import io.ray.api.runtime.RayRuntime; -import io.ray.runtime.config.RunMode; -import io.ray.runtime.exception.RayException; -import java.lang.reflect.InvocationHandler; -import java.lang.reflect.InvocationTargetException; -import java.lang.reflect.Method; - -/** - * Protect a ray runtime with context checks for all methods of {@link RayRuntime} (except {@link - * RayRuntime#shutdown(boolean)}). - */ -public class RayRuntimeProxy implements InvocationHandler { - - /** The original runtime. */ - private AbstractRayRuntime obj; - - private RayRuntimeProxy(AbstractRayRuntime obj) { - this.obj = obj; - } - - public AbstractRayRuntime getRuntimeObject() { - return obj; - } - - /** Generate a new instance of {@link RayRuntimeInternal} with additional context check. 
*/ - static RayRuntimeInternal newInstance(AbstractRayRuntime obj) { - return (RayRuntimeInternal) - java.lang.reflect.Proxy.newProxyInstance( - obj.getClass().getClassLoader(), - new Class[] {RayRuntimeInternal.class}, - new RayRuntimeProxy(obj)); - } - - @Override - public Object invoke(Object proxy, Method method, Object[] args) throws Throwable { - if (isInterfaceMethod(method) - && !method.getName().equals("shutdown") - && !method.getName().equals("setAsyncContext")) { - checkIsContextSet(); - } - try { - return method.invoke(obj, args); - } catch (InvocationTargetException e) { - if (e.getCause() != null) { - throw e.getCause(); - } else { - throw e; - } - } - } - - /** Whether the method is defined in the {@link RayRuntime} interface. */ - private boolean isInterfaceMethod(Method method) { - try { - RayRuntime.class.getMethod(method.getName(), method.getParameterTypes()); - return true; - } catch (NoSuchMethodException e) { - return false; - } - } - - /** - * Check if thread context is set. - * - *

This method should be invoked at the beginning of most public methods of {@link RayRuntime}, - * otherwise the native code might crash due to thread local core worker was not set. We check it - * for {@link AbstractRayRuntime} instead of {@link RayNativeRuntime} because we want to catch the - * error even if the application runs in {@link RunMode#SINGLE_PROCESS} mode. - */ - private void checkIsContextSet() { - if (!obj.isContextSet.get()) { - throw new RayException( - "`Ray.wrap***` is not called on the current thread." - + " If you want to use Ray API in your own threads," - + " please wrap your executable with `Ray.wrap***`."); - } - } -} diff --git a/java/runtime/src/main/java/io/ray/runtime/config/RayConfig.java b/java/runtime/src/main/java/io/ray/runtime/config/RayConfig.java index dd0e2ef3a..89dada1eb 100644 --- a/java/runtime/src/main/java/io/ray/runtime/config/RayConfig.java +++ b/java/runtime/src/main/java/io/ray/runtime/config/RayConfig.java @@ -75,8 +75,6 @@ public class RayConfig { public final List headArgs; - public final int numWorkersPerProcess; - public final String namespace; public final List jvmOptionsForJavaWorker; @@ -190,8 +188,6 @@ public class RayConfig { } codeSearchPath = Arrays.asList(codeSearchPathString.split(":")); - numWorkersPerProcess = config.getInt("ray.job.num-java-workers-per-process"); - startupToken = config.getInt("ray.raylet.startup-token"); /// Driver needn't this config item. 
diff --git a/java/runtime/src/main/java/io/ray/runtime/context/LocalModeWorkerContext.java b/java/runtime/src/main/java/io/ray/runtime/context/LocalModeWorkerContext.java index 76ba443a2..312a86c5a 100644 --- a/java/runtime/src/main/java/io/ray/runtime/context/LocalModeWorkerContext.java +++ b/java/runtime/src/main/java/io/ray/runtime/context/LocalModeWorkerContext.java @@ -48,31 +48,21 @@ public class LocalModeWorkerContext implements WorkerContext { @Override public ActorId getCurrentActorId() { TaskSpec taskSpec = currentTask.get(); - if (taskSpec == null) { - return ActorId.NIL; - } + checkTaskSpecNotNull(taskSpec); return LocalModeTaskSubmitter.getActorId(taskSpec); } - @Override - public ClassLoader getCurrentClassLoader() { - return null; - } - - @Override - public void setCurrentClassLoader(ClassLoader currentClassLoader) {} - @Override public TaskType getCurrentTaskType() { TaskSpec taskSpec = currentTask.get(); - Preconditions.checkNotNull(taskSpec, "Current task is not set."); + checkTaskSpecNotNull(taskSpec); return taskSpec.getType(); } @Override public TaskId getCurrentTaskId() { TaskSpec taskSpec = currentTask.get(); - Preconditions.checkState(taskSpec != null); + checkTaskSpecNotNull(taskSpec); return TaskId.fromBytes(taskSpec.getTaskId().toByteArray()); } @@ -85,7 +75,9 @@ public class LocalModeWorkerContext implements WorkerContext { currentTask.set(taskSpec); } - public TaskSpec getCurrentTask() { - return currentTask.get(); + private static void checkTaskSpecNotNull(TaskSpec taskSpec) { + Preconditions.checkNotNull( + taskSpec, + "Current task is not set. Maybe you invoked this API in a user-created thread not managed by Ray. Invoking this API in a user-created thread is not supported yet in local mode. 
You can switch to cluster mode."); } } diff --git a/java/runtime/src/main/java/io/ray/runtime/context/NativeWorkerContext.java b/java/runtime/src/main/java/io/ray/runtime/context/NativeWorkerContext.java index d20d48d0a..467dd0169 100644 --- a/java/runtime/src/main/java/io/ray/runtime/context/NativeWorkerContext.java +++ b/java/runtime/src/main/java/io/ray/runtime/context/NativeWorkerContext.java @@ -12,7 +12,7 @@ import java.nio.ByteBuffer; /** Worker context for cluster mode. This is a wrapper class for worker context of core worker. */ public class NativeWorkerContext implements WorkerContext { - private final ThreadLocal currentClassLoader = new ThreadLocal<>(); + private ClassLoader currentClassLoader = null; @Override public UniqueId getCurrentWorkerId() { @@ -29,18 +29,6 @@ public class NativeWorkerContext implements WorkerContext { return ActorId.fromByteBuffer(nativeGetCurrentActorId()); } - @Override - public ClassLoader getCurrentClassLoader() { - return currentClassLoader.get(); - } - - @Override - public void setCurrentClassLoader(ClassLoader currentClassLoader) { - if (this.currentClassLoader.get() != currentClassLoader) { - this.currentClassLoader.set(currentClassLoader); - } - } - @Override public TaskType getCurrentTaskType() { return TaskType.forNumber(nativeGetCurrentTaskType()); diff --git a/java/runtime/src/main/java/io/ray/runtime/context/WorkerContext.java b/java/runtime/src/main/java/io/ray/runtime/context/WorkerContext.java index ff750f651..3e3ac01e9 100644 --- a/java/runtime/src/main/java/io/ray/runtime/context/WorkerContext.java +++ b/java/runtime/src/main/java/io/ray/runtime/context/WorkerContext.java @@ -19,15 +19,6 @@ public interface WorkerContext { /** ID of the current actor. */ ActorId getCurrentActorId(); - /** - * The class loader that is associated with the current job. It's used for locating classes when - * dealing with serialization and deserialization in {@link Serializer}. 
- */ - ClassLoader getCurrentClassLoader(); - - /** Set the current class loader. */ - void setCurrentClassLoader(ClassLoader currentClassLoader); - /** Type of the current task. */ TaskType getCurrentTaskType(); diff --git a/java/runtime/src/main/java/io/ray/runtime/functionmanager/FunctionManager.java b/java/runtime/src/main/java/io/ray/runtime/functionmanager/FunctionManager.java index 75e9c2821..3ad60b434 100644 --- a/java/runtime/src/main/java/io/ray/runtime/functionmanager/FunctionManager.java +++ b/java/runtime/src/main/java/io/ray/runtime/functionmanager/FunctionManager.java @@ -2,7 +2,6 @@ package io.ray.runtime.functionmanager; import com.google.common.collect.Lists; import io.ray.api.function.RayFunc; -import io.ray.api.id.JobId; import io.ray.runtime.util.LambdaUtils; import java.io.File; import java.lang.invoke.SerializedLambda; @@ -34,7 +33,7 @@ import org.objectweb.asm.Type; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** Manages functions by job id. */ +/** Manages functions in the current worker. */ public class FunctionManager { private static final Logger LOGGER = LoggerFactory.getLogger(FunctionManager.class); @@ -50,8 +49,8 @@ public class FunctionManager { private static final ThreadLocal, JavaFunctionDescriptor>> RAY_FUNC_CACHE = ThreadLocal.withInitial(WeakHashMap::new); - /** Mapping from the job id to the functions that belong to this job. */ - private ConcurrentMap jobFunctionTables = new ConcurrentHashMap<>(); + /** The table that manages functions. */ + private final JobFunctionTable jobFunctionTable; /** The resource path which we can load the job's jar resources. 
*/ private final List codeSearchPath; @@ -63,16 +62,20 @@ public class FunctionManager { */ public FunctionManager(List codeSearchPath) { this.codeSearchPath = codeSearchPath; + jobFunctionTable = createJobFunctionTable(); + } + + public ClassLoader getClassLoader() { + return jobFunctionTable.classLoader; } /** * Get the RayFunction from a RayFunc instance (a lambda). * - * @param jobId current job id. * @param func The lambda. * @return A RayFunction object. */ - public RayFunction getFunction(JobId jobId, RayFunc func) { + public RayFunction getFunction(RayFunc func) { JavaFunctionDescriptor functionDescriptor = RAY_FUNC_CACHE.get().get(func.getClass()); if (functionDescriptor == null) { // It's OK to not lock here, because it's OK to have multiple JavaFunctionDescriptor instances @@ -84,31 +87,21 @@ public class FunctionManager { functionDescriptor = new JavaFunctionDescriptor(className, methodName, signature); RAY_FUNC_CACHE.get().put(func.getClass(), functionDescriptor); } - return getFunction(jobId, functionDescriptor); + return getFunction(functionDescriptor); } /** * Get the RayFunction from a function descriptor. * - * @param jobId Current job id. * @param functionDescriptor The function descriptor. * @return A RayFunction object. */ - public RayFunction getFunction(JobId jobId, JavaFunctionDescriptor functionDescriptor) { - JobFunctionTable jobFunctionTable = jobFunctionTables.get(jobId); - if (jobFunctionTable == null) { - synchronized (this) { - jobFunctionTable = jobFunctionTables.get(jobId); - if (jobFunctionTable == null) { - jobFunctionTable = createJobFunctionTable(jobId); - jobFunctionTables.put(jobId, jobFunctionTable); - } - } - } + public RayFunction getFunction(JavaFunctionDescriptor functionDescriptor) { return jobFunctionTable.getFunction(functionDescriptor); } - private JobFunctionTable createJobFunctionTable(JobId jobId) { + /** A helper that creates function table. 
*/ + private JobFunctionTable createJobFunctionTable() { ClassLoader classLoader; if (codeSearchPath == null || codeSearchPath.isEmpty()) { classLoader = getClass().getClassLoader(); @@ -145,7 +138,7 @@ public class FunctionManager { }) .toArray(URL[]::new); classLoader = new URLClassLoader(urls); - LOGGER.debug("Resource loaded for job {} from path {}.", jobId, urls); + LOGGER.debug("Resource loaded from path {}.", urls); } return new JobFunctionTable(classLoader); diff --git a/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java b/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java index 0e5687657..c6376c639 100644 --- a/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java +++ b/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java @@ -496,7 +496,6 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { ? objectStore.getRaw(Collections.singletonList(arg.id), -1).get(0) : arg.value) .collect(Collectors.toList()); - runtime.setIsContextSet(true); ((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(taskSpec); ((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentWorkerId(workerId); @@ -513,9 +512,7 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { // Set this flag to true is necessary because at the end of `taskExecutor.execute()`, // this flag will be set to false. And `runtime.getWorkerContext()` requires it to be // true. 
- runtime.setIsContextSet(true); ((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(null); - runtime.setIsContextSet(false); List returnIds = getReturnIds(taskSpec); for (int i = 0; i < returnIds.size(); i++) { NativeRayObject putObject; diff --git a/java/runtime/src/main/java/io/ray/runtime/task/TaskExecutor.java b/java/runtime/src/main/java/io/ray/runtime/task/TaskExecutor.java index 9cce4cf02..1c7d17414 100644 --- a/java/runtime/src/main/java/io/ray/runtime/task/TaskExecutor.java +++ b/java/runtime/src/main/java/io/ray/runtime/task/TaskExecutor.java @@ -32,6 +32,7 @@ public abstract class TaskExecutor { protected final RayRuntimeInternal runtime; + // TODO(qwang): Use actorContext instead later. private final ConcurrentHashMap actorContextMap = new ConcurrentHashMap<>(); private final ThreadLocal localRayFunction = new ThreadLocal<>(); @@ -66,7 +67,7 @@ public abstract class TaskExecutor { private RayFunction getRayFunction(List rayFunctionInfo) { JobId jobId = runtime.getWorkerContext().getCurrentJobId(); JavaFunctionDescriptor functionDescriptor = parseFunctionDescriptor(rayFunctionInfo); - return runtime.getFunctionManager().getFunction(jobId, functionDescriptor); + return runtime.getFunctionManager().getFunction(functionDescriptor); } /** The return value indicates which parameters are ByteBuffer. */ @@ -93,7 +94,6 @@ public abstract class TaskExecutor { } protected List execute(List rayFunctionInfo, List argsBytes) { - runtime.setIsContextSet(true); TaskType taskType = runtime.getWorkerContext().getCurrentTaskType(); TaskId taskId = runtime.getWorkerContext().getCurrentTaskId(); LOGGER.debug("Executing task {} {}", taskId, rayFunctionInfo); @@ -108,7 +108,6 @@ public abstract class TaskExecutor { } List returnObjects = new ArrayList<>(); - ClassLoader oldLoader = Thread.currentThread().getContextClassLoader(); // Find the executable object. 
RayFunction rayFunction = localRayFunction.get(); @@ -121,7 +120,6 @@ public abstract class TaskExecutor { rayFunction = getRayFunction(rayFunctionInfo); } Thread.currentThread().setContextClassLoader(rayFunction.classLoader); - runtime.getWorkerContext().setCurrentClassLoader(rayFunction.classLoader); // Get local actor object and arguments. Object actor = null; @@ -215,10 +213,6 @@ public abstract class TaskExecutor { } else { throw new RayActorException(e); } - } finally { - Thread.currentThread().setContextClassLoader(oldLoader); - runtime.getWorkerContext().setCurrentClassLoader(null); - runtime.setIsContextSet(false); } return returnObjects; } diff --git a/java/runtime/src/main/java/io/ray/runtime/util/MethodUtils.java b/java/runtime/src/main/java/io/ray/runtime/util/MethodUtils.java index 8f59102f6..b6523562f 100644 --- a/java/runtime/src/main/java/io/ray/runtime/util/MethodUtils.java +++ b/java/runtime/src/main/java/io/ray/runtime/util/MethodUtils.java @@ -47,7 +47,7 @@ public final class MethodUtils { /// This code path indicates that here might be in another thread of a worker. /// So try to load the class from URLClassLoader of this worker. 
ClassLoader cl = - ((RayRuntimeInternal) Ray.internal()).getWorkerContext().getCurrentClassLoader(); + ((RayRuntimeInternal) Ray.internal()).getFunctionManager().getClassLoader(); actorClz = Class.forName(className, true, cl); } } catch (Exception e) { diff --git a/java/runtime/src/main/java/io/ray/runtime/utils/parallelactor/ParallelActorContextImpl.java b/java/runtime/src/main/java/io/ray/runtime/utils/parallelactor/ParallelActorContextImpl.java index cd0e3ebb4..9da1f4cd0 100644 --- a/java/runtime/src/main/java/io/ray/runtime/utils/parallelactor/ParallelActorContextImpl.java +++ b/java/runtime/src/main/java/io/ray/runtime/utils/parallelactor/ParallelActorContextImpl.java @@ -28,9 +28,7 @@ public class ParallelActorContextImpl implements ParallelActorContext { FunctionManager functionManager = ((RayRuntimeInternal) Ray.internal()).getFunctionManager(); JavaFunctionDescriptor functionDescriptor = - functionManager - .getFunction(Ray.getRuntimeContext().getCurrentJobId(), ctorFunc) - .getFunctionDescriptor(); + functionManager.getFunction(ctorFunc).getFunctionDescriptor(); ActorHandle parallelExecutorHandle = Ray.actor(ParallelActorExecutorImpl::new, parallelism, functionDescriptor) .setConcurrencyGroups(concurrencyGroups) @@ -46,9 +44,7 @@ public class ParallelActorContextImpl implements ParallelActorContext { ((ParallelActorHandleImpl) parallelActorHandle).getExecutor(); FunctionManager functionManager = ((RayRuntimeInternal) Ray.internal()).getFunctionManager(); JavaFunctionDescriptor functionDescriptor = - functionManager - .getFunction(Ray.getRuntimeContext().getCurrentJobId(), func) - .getFunctionDescriptor(); + functionManager.getFunction(func).getFunctionDescriptor(); ObjectRef ret = parallelExecutor .task(ParallelActorExecutorImpl::execute, instanceId, functionDescriptor, args) diff --git a/java/runtime/src/main/java/io/ray/runtime/utils/parallelactor/ParallelActorExecutorImpl.java 
b/java/runtime/src/main/java/io/ray/runtime/utils/parallelactor/ParallelActorExecutorImpl.java index f3d481381..3836303e1 100644 --- a/java/runtime/src/main/java/io/ray/runtime/utils/parallelactor/ParallelActorExecutorImpl.java +++ b/java/runtime/src/main/java/io/ray/runtime/utils/parallelactor/ParallelActorExecutorImpl.java @@ -23,9 +23,7 @@ public class ParallelActorExecutorImpl { throws InvocationTargetException, IllegalAccessException { functionManager = ((RayRuntimeInternal) Ray.internal()).getFunctionManager(); - RayFunction init = - functionManager.getFunction( - Ray.getRuntimeContext().getCurrentJobId(), javaFunctionDescriptor); + RayFunction init = functionManager.getFunction(javaFunctionDescriptor); Thread.currentThread().setContextClassLoader(init.classLoader); for (int i = 0; i < parallelism; ++i) { Object instance = init.getMethod().invoke(null, null); @@ -35,8 +33,7 @@ public class ParallelActorExecutorImpl { public Object execute(int instanceId, JavaFunctionDescriptor functionDescriptor, Object[] args) throws IllegalAccessException, InvocationTargetException { - RayFunction func = - functionManager.getFunction(Ray.getRuntimeContext().getCurrentJobId(), functionDescriptor); + RayFunction func = functionManager.getFunction(functionDescriptor); Preconditions.checkState(instances.containsKey(instanceId)); return func.getMethod().invoke(instances.get(instanceId), args); } diff --git a/java/runtime/src/main/resources/ray.default.conf b/java/runtime/src/main/resources/ray.default.conf index 0ff149281..cf73f321d 100644 --- a/java/runtime/src/main/resources/ray.default.conf +++ b/java/runtime/src/main/resources/ray.default.conf @@ -27,8 +27,6 @@ ray { // search path for user code. This will be used as `CLASSPATH` in Java, // and `PYTHONPATH` in Python. code-search-path: "" - /// The number of java worker per worker process. - num-java-workers-per-process: 1 /// The jvm options for java workers of the job. 
jvm-options: [] diff --git a/java/runtime/src/test/java/io/ray/runtime/functionmanager/FunctionManagerTest.java b/java/runtime/src/test/java/io/ray/runtime/functionmanager/FunctionManagerTest.java index e693a4da5..f1da038e7 100644 --- a/java/runtime/src/test/java/io/ray/runtime/functionmanager/FunctionManagerTest.java +++ b/java/runtime/src/test/java/io/ray/runtime/functionmanager/FunctionManagerTest.java @@ -61,6 +61,8 @@ public class FunctionManagerTest { } } + private static final JobId JOB_ID = JobId.fromInt(1); + private static RayFunc0 fooFunc; private static RayFunc1 childClassBarFunc; private static RayFunc0 childClassConstructor; @@ -95,17 +97,17 @@ public class FunctionManagerTest { public void testGetFunctionFromRayFunc() { final FunctionManager functionManager = new FunctionManager(null); // Test normal function. - RayFunction func = functionManager.getFunction(JobId.NIL, fooFunc); + RayFunction func = functionManager.getFunction(fooFunc); Assert.assertFalse(func.isConstructor()); Assert.assertEquals(func.getFunctionDescriptor(), fooDescriptor); // Test actor method - func = functionManager.getFunction(JobId.NIL, childClassBarFunc); + func = functionManager.getFunction(childClassBarFunc); Assert.assertFalse(func.isConstructor()); Assert.assertEquals(func.getFunctionDescriptor(), childClassBarDescriptor); // Test actor constructor - func = functionManager.getFunction(JobId.NIL, childClassConstructor); + func = functionManager.getFunction(childClassConstructor); Assert.assertTrue(func.isConstructor()); Assert.assertEquals(func.getFunctionDescriptor(), childClassConstructorDescriptor); } @@ -114,17 +116,17 @@ public class FunctionManagerTest { public void testGetFunctionFromFunctionDescriptor() { final FunctionManager functionManager = new FunctionManager(null); // Test normal function. 
- RayFunction func = functionManager.getFunction(JobId.NIL, fooDescriptor); + RayFunction func = functionManager.getFunction(fooDescriptor); Assert.assertFalse(func.isConstructor()); Assert.assertEquals(func.getFunctionDescriptor(), fooDescriptor); // Test actor method - func = functionManager.getFunction(JobId.NIL, childClassBarDescriptor); + func = functionManager.getFunction(childClassBarDescriptor); Assert.assertFalse(func.isConstructor()); Assert.assertEquals(func.getFunctionDescriptor(), childClassBarDescriptor); // Test actor constructor - func = functionManager.getFunction(JobId.NIL, childClassConstructorDescriptor); + func = functionManager.getFunction(childClassConstructorDescriptor); Assert.assertTrue(func.isConstructor()); Assert.assertEquals(func.getFunctionDescriptor(), childClassConstructorDescriptor); @@ -133,7 +135,6 @@ public class FunctionManagerTest { RuntimeException.class, () -> { functionManager.getFunction( - JobId.NIL, new JavaFunctionDescriptor( FunctionManagerTest.class.getName(), "overloadFunction", "")); }); @@ -146,11 +147,10 @@ public class FunctionManagerTest { fooDescriptor = new JavaFunctionDescriptor(ParentClass.class.getName(), "foo", "()Ljava/lang/Object;"); Assert.assertEquals( - functionManager.getFunction(JobId.NIL, fooDescriptor).executable.getDeclaringClass(), + functionManager.getFunction(fooDescriptor).executable.getDeclaringClass(), ParentClass.class); RayFunction fooFunc = functionManager.getFunction( - JobId.NIL, new JavaFunctionDescriptor(ChildClass.class.getName(), "foo", "()Ljava/lang/Object;")); Assert.assertEquals(fooFunc.executable.getDeclaringClass(), ParentClass.class); @@ -159,21 +159,16 @@ public class FunctionManagerTest { childClassBarDescriptor = new JavaFunctionDescriptor(ParentClass.class.getName(), "bar", "()Ljava/lang/Object;"); Assert.assertEquals( - functionManager - .getFunction(JobId.NIL, childClassBarDescriptor) - .executable - .getDeclaringClass(), + 
functionManager.getFunction(childClassBarDescriptor).executable.getDeclaringClass(), ParentClass.class); RayFunction barFunc = functionManager.getFunction( - JobId.NIL, new JavaFunctionDescriptor(ChildClass.class.getName(), "bar", "()Ljava/lang/Object;")); Assert.assertEquals(barFunc.executable.getDeclaringClass(), ChildClass.class); // Check interface default methods. RayFunction interfaceNameFunc = functionManager.getFunction( - JobId.NIL, new JavaFunctionDescriptor( ChildClass.class.getName(), "interfaceName", "()Ljava/lang/String;")); Assert.assertEquals( @@ -220,7 +215,6 @@ public class FunctionManagerTest { @Test public void testGetFunctionFromLocalResource() throws Exception { - JobId jobId = JobId.fromInt(1); final String codeSearchPath = FileUtils.getTempDirectoryPath() + "/ray_test_resources/"; File jobResourceDir = new File(codeSearchPath); FileUtils.deleteQuietly(jobResourceDir); @@ -250,7 +244,7 @@ public class FunctionManagerTest { new JavaFunctionDescriptor("DemoApp", "hello", "()Ljava/lang/String;"); final FunctionManager functionManager = new FunctionManager(Collections.singletonList(codeSearchPath)); - RayFunction func = functionManager.getFunction(jobId, descriptor); + RayFunction func = functionManager.getFunction(descriptor); Assert.assertEquals(func.getFunctionDescriptor(), descriptor); } } diff --git a/java/serve/src/main/java/io/ray/serve/HttpProxy.java b/java/serve/src/main/java/io/ray/serve/HttpProxy.java index 809337e75..0c5c6018a 100644 --- a/java/serve/src/main/java/io/ray/serve/HttpProxy.java +++ b/java/serve/src/main/java/io/ray/serve/HttpProxy.java @@ -1,7 +1,6 @@ package io.ray.serve; import com.google.common.collect.ImmutableMap; -import io.ray.api.Ray; import io.ray.runtime.metric.Count; import io.ray.runtime.metric.Metrics; import io.ray.runtime.metric.TagKey; @@ -47,8 +46,6 @@ public class HttpProxy implements ServeProxy { private ProxyRouter proxyRouter; - private Object asyncContext = Ray.getAsyncContext(); - @Override public 
void init(Map config, ProxyRouter proxyRouter) { this.port = @@ -101,8 +98,6 @@ public class HttpProxy implements ServeProxy { ClassicHttpRequest request, ClassicHttpResponse response, HttpContext context) throws HttpException, IOException { - Ray.setAsyncContext(asyncContext); - int code = HttpURLConnection.HTTP_OK; Object result = null; String route = request.getPath(); diff --git a/java/test.sh b/java/test.sh index ba5f16ae4..7a8bdc637 100755 --- a/java/test.sh +++ b/java/test.sh @@ -122,8 +122,7 @@ for file in "$docdemo_path"*.java; do file=${file#"$docdemo_path"} class=${file%".java"} echo "Running $class" - java -cp bazel-bin/java/all_tests_shaded.jar -Dray.job.num-java-workers-per-process=1\ - -Dray.raylet.startup-token=0 "io.ray.docdemo.$class" + java -cp bazel-bin/java/all_tests_shaded.jar -Dray.raylet.startup-token=0 "io.ray.docdemo.$class" done popd diff --git a/java/test/src/main/java/io/ray/test/ActorHandleReferenceCountTest.java b/java/test/src/main/java/io/ray/test/ActorHandleReferenceCountTest.java index 13f8af022..42b24fa5e 100644 --- a/java/test/src/main/java/io/ray/test/ActorHandleReferenceCountTest.java +++ b/java/test/src/main/java/io/ray/test/ActorHandleReferenceCountTest.java @@ -4,13 +4,11 @@ import io.ray.api.ActorHandle; import io.ray.api.ObjectRef; import io.ray.api.Ray; import io.ray.runtime.actor.NativeActorHandle; -import io.ray.runtime.exception.RayActorException; import io.ray.runtime.util.SystemUtil; import java.lang.ref.Reference; import java.lang.reflect.Field; import java.lang.reflect.Method; import java.util.Set; -import java.util.concurrent.TimeUnit; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.testng.Assert; @@ -65,7 +63,6 @@ public class ActorHandleReferenceCountTest { public void testActorHandleReferenceCount() { try { - System.setProperty("ray.job.num-java-workers-per-process", "1"); Ray.init(); ActorHandle signal = Ray.actor(SignalActor::new).remote(); ActorHandle myActor = 
Ray.actor(MyActor::new).remote(); @@ -83,27 +80,4 @@ public class ActorHandleReferenceCountTest { Ray.shutdown(); } } - - public void testRemoveActorHandleReferenceInMultipleThreadedActor() throws InterruptedException { - System.setProperty("ray.job.num-java-workers-per-process", "5"); - try { - Ray.init(); - ActorHandle myActor1 = Ray.actor(MyActor::new).remote(); - int pid1 = myActor1.task(MyActor::getPid).remote().get(); - ActorHandle myActor2 = Ray.actor(MyActor::new).remote(); - int pid2 = myActor2.task(MyActor::getPid).remote().get(); - Assert.assertEquals(pid1, pid2); - del(myActor1); - TimeUnit.SECONDS.sleep(5); - Assert.assertThrows( - RayActorException.class, - () -> { - myActor1.task(MyActor::hello).remote().get(); - }); - /// myActor2 shouldn't be killed. - Assert.assertEquals("hello", myActor2.task(MyActor::hello).remote().get()); - } finally { - Ray.shutdown(); - } - } } diff --git a/java/test/src/main/java/io/ray/test/ClassLoaderTest.java b/java/test/src/main/java/io/ray/test/ClassLoaderTest.java deleted file mode 100644 index 249af8eba..000000000 --- a/java/test/src/main/java/io/ray/test/ClassLoaderTest.java +++ /dev/null @@ -1,167 +0,0 @@ -package io.ray.test; - -import io.ray.api.ActorHandle; -import io.ray.api.BaseActorHandle; -import io.ray.api.ObjectRef; -import io.ray.api.options.ActorCreationOptions; -import io.ray.api.options.CallOptions; -import io.ray.runtime.AbstractRayRuntime; -import io.ray.runtime.functionmanager.FunctionDescriptor; -import io.ray.runtime.functionmanager.JavaFunctionDescriptor; -import java.io.File; -import java.lang.reflect.Method; -import java.nio.file.Files; -import java.nio.file.Paths; -import java.util.Optional; -import javax.tools.JavaCompiler; -import javax.tools.ToolProvider; -import org.apache.commons.io.FileUtils; -import org.testng.Assert; -import org.testng.annotations.BeforeClass; -import org.testng.annotations.Test; - -public class ClassLoaderTest extends BaseTest { - - private final String codeSearchPath 
= - FileUtils.getTempDirectoryPath() + "/ray_test/ClassLoaderTest"; - - @BeforeClass - public void setUp() { - // The potential issue of multiple `ClassLoader` instances for the same job on multi-threading - // scenario only occurs if the classes are loaded from the job code search path. - System.setProperty("ray.job.code-search-path", codeSearchPath); - } - - @Test(groups = {"cluster"}) - public void testClassLoaderInMultiThreading() throws Exception { - File jobResourceDir = new File(codeSearchPath); - FileUtils.deleteQuietly(jobResourceDir); - jobResourceDir.mkdirs(); - jobResourceDir.deleteOnExit(); - - // In this test case the class is expected to be loaded from the job code search path, - // so we need to put the compiled class file into the job code search path and load it - // later. - String testJavaFile = - "" - + "import java.lang.management.ManagementFactory;\n" - + "import java.lang.management.RuntimeMXBean;\n" - + "\n" - + "public class ClassLoaderTester {\n" - + "\n" - + " static volatile int value;\n" - + "\n" - + " public int getPid() {\n" - + " RuntimeMXBean runtime = ManagementFactory.getRuntimeMXBean();\n" - + " String name = runtime.getName();\n" - + " int index = name.indexOf(\"@\");\n" - + " if (index != -1) {\n" - + " return Integer.parseInt(name.substring(0, index));\n" - + " } else {\n" - + " throw new RuntimeException(\"parse pid error:\" + name);\n" - + " }\n" - + " }\n" - + "\n" - + " public int increase() throws InterruptedException {\n" - + " return increaseInternal();\n" - + " }\n" - + "\n" - + " public static synchronized int increaseInternal() throws InterruptedException {\n" - + " int oldValue = value;\n" - + " Thread.sleep(10 * 1000);\n" - + " value = oldValue + 1;\n" - + " return value;\n" - + " }\n" - + "\n" - + " public int getClassLoaderHashCode() {\n" - + " return this.getClass().getClassLoader().hashCode();\n" - + " }\n" - + "}"; - - // Write the demo java file to the job code search path. 
- String javaFilePath = codeSearchPath + "/ClassLoaderTester.java"; - Files.write(Paths.get(javaFilePath), testJavaFile.getBytes()); - - // Compile the java file. - JavaCompiler compiler = ToolProvider.getSystemJavaCompiler(); - int result = compiler.run(null, null, null, "-d", codeSearchPath, javaFilePath); - if (result != 0) { - throw new RuntimeException("Couldn't compile ClassLoaderTester.java."); - } - - FunctionDescriptor constructor = - new JavaFunctionDescriptor("ClassLoaderTester", "", "()V"); - ActorHandle actor1 = createActor(constructor); - FunctionDescriptor getPid = new JavaFunctionDescriptor("ClassLoaderTester", "getPid", "()I"); - int pid = - this.callActorFunction(actor1, getPid, new Object[0], Optional.of(Integer.class)) - .get(); - ActorHandle actor2; - while (true) { - // Create another actor which share the same process of actor 1. - actor2 = createActor(constructor); - int actor2Pid = - this.callActorFunction(actor2, getPid, new Object[0], Optional.of(Integer.class)) - .get(); - if (actor2Pid == pid) { - break; - } - } - - FunctionDescriptor getClassLoaderHashCode = - new JavaFunctionDescriptor("ClassLoaderTester", "getClassLoaderHashCode", "()I"); - ObjectRef hashCode1 = - callActorFunction( - actor1, getClassLoaderHashCode, new Object[0], Optional.of(Integer.class)); - ObjectRef hashCode2 = - callActorFunction( - actor2, getClassLoaderHashCode, new Object[0], Optional.of(Integer.class)); - Assert.assertEquals(hashCode1.get(), hashCode2.get()); - - FunctionDescriptor increase = - new JavaFunctionDescriptor("ClassLoaderTester", "increase", "()I"); - ObjectRef value1 = - callActorFunction(actor1, increase, new Object[0], Optional.of(Integer.class)); - ObjectRef value2 = - callActorFunction(actor2, increase, new Object[0], Optional.of(Integer.class)); - Assert.assertNotEquals(value1.get(), value2.get()); - } - - private ActorHandle createActor(FunctionDescriptor functionDescriptor) throws Exception { - Method createActorMethod = - 
AbstractRayRuntime.class.getDeclaredMethod( - "createActorImpl", - FunctionDescriptor.class, - Object[].class, - ActorCreationOptions.class); - createActorMethod.setAccessible(true); - return (ActorHandle) - createActorMethod.invoke( - TestUtils.getUnderlyingRuntime(), functionDescriptor, new Object[0], null); - } - - private ObjectRef callActorFunction( - ActorHandle rayActor, - FunctionDescriptor functionDescriptor, - Object[] args, - Optional> returnType) - throws Exception { - Method callActorFunctionMethod = - AbstractRayRuntime.class.getDeclaredMethod( - "callActorFunction", - BaseActorHandle.class, - FunctionDescriptor.class, - Object[].class, - Optional.class, - CallOptions.class); - callActorFunctionMethod.setAccessible(true); - return (ObjectRef) - callActorFunctionMethod.invoke( - TestUtils.getUnderlyingRuntime(), - rayActor, - functionDescriptor, - args, - returnType, - new CallOptions.Builder().build()); - } -} diff --git a/java/test/src/main/java/io/ray/test/DefaultActorLifetimeTest.java b/java/test/src/main/java/io/ray/test/DefaultActorLifetimeTest.java index a1664313e..f9d6785fe 100644 --- a/java/test/src/main/java/io/ray/test/DefaultActorLifetimeTest.java +++ b/java/test/src/main/java/io/ray/test/DefaultActorLifetimeTest.java @@ -55,7 +55,6 @@ public class DefaultActorLifetimeTest { System.setProperty("ray.job.default-actor-lifetime", defaultActorLifetime.name()); } try { - System.setProperty("ray.job.num-java-workers-per-process", "1"); Ray.init(); /// 1. create owner and invoke createChildActor. 
diff --git a/java/test/src/main/java/io/ray/test/ExitActorTest.java b/java/test/src/main/java/io/ray/test/ExitActorTest.java index f402cb607..6b1bcb135 100644 --- a/java/test/src/main/java/io/ray/test/ExitActorTest.java +++ b/java/test/src/main/java/io/ray/test/ExitActorTest.java @@ -74,38 +74,6 @@ public class ExitActorTest extends BaseTest { Assert.assertThrows(RayActorException.class, obj::get); } - public void testExitActorInMultiWorker() { - Assert.assertTrue(TestUtils.getNumWorkersPerProcess() > 1); - ActorHandle actor1 = - Ray.actor(ExitingActor::new).setMaxRestarts(ActorCreationOptions.INFINITE_RESTART).remote(); - int pid = actor1.task(ExitingActor::getPid).remote().get(); - Assert.assertEquals( - 1, (int) actor1.task(ExitingActor::getSizeOfActorContextMap).remote().get()); - ActorHandle actor2; - while (true) { - // Create another actor which share the same process of actor 1. - actor2 = - Ray.actor(ExitingActor::new).setMaxRestarts(ActorCreationOptions.NO_RESTART).remote(); - int actor2Pid = actor2.task(ExitingActor::getPid).remote().get(); - if (actor2Pid == pid) { - break; - } - } - Assert.assertEquals( - 2, (int) actor1.task(ExitingActor::getSizeOfActorContextMap).remote().get()); - Assert.assertEquals( - 2, (int) actor2.task(ExitingActor::getSizeOfActorContextMap).remote().get()); - ObjectRef obj1 = actor1.task(ExitingActor::exit).remote(); - Assert.assertThrows(RayActorException.class, obj1::get); - Assert.assertTrue(SystemUtil.isProcessAlive(pid)); - // Actor 2 shouldn't exit or be reconstructed. 
- Assert.assertEquals(1, (int) actor2.task(ExitingActor::incr).remote().get()); - Assert.assertEquals( - 1, (int) actor2.task(ExitingActor::getSizeOfActorContextMap).remote().get()); - Assert.assertEquals(pid, (int) actor2.task(ExitingActor::getPid).remote().get()); - Assert.assertTrue(SystemUtil.isProcessAlive(pid)); - } - public void testExitActorWithDynamicOptions() { ActorHandle actor = Ray.actor(ExitingActor::new) diff --git a/java/test/src/main/java/io/ray/test/ExitActorTest2.java b/java/test/src/main/java/io/ray/test/ExitActorTest2.java index c12a6cb26..e639ee1f3 100644 --- a/java/test/src/main/java/io/ray/test/ExitActorTest2.java +++ b/java/test/src/main/java/io/ray/test/ExitActorTest2.java @@ -15,7 +15,6 @@ public class ExitActorTest2 extends BaseTest { @BeforeClass public void setUp() { - System.setProperty("ray.job.num-java-workers-per-process", "1"); System.setProperty("ray.raylet.startup-token", "0"); } diff --git a/java/test/src/main/java/io/ray/test/FailureTest.java b/java/test/src/main/java/io/ray/test/FailureTest.java index 9035bfea3..2191d7b5e 100644 --- a/java/test/src/main/java/io/ray/test/FailureTest.java +++ b/java/test/src/main/java/io/ray/test/FailureTest.java @@ -23,10 +23,6 @@ public class FailureTest extends BaseTest { @BeforeClass public void setUp() { - // This is needed by `testGetThrowsQuicklyWhenFoundException`. - // Set one worker per process. Otherwise, if `badFunc2` and `slowFunc` run in the same - // process, `sleep` will delay `System.exit`. 
- System.setProperty("ray.job.num-java-workers-per-process", "1"); System.setProperty("ray.raylet.startup-token", "0"); } diff --git a/java/test/src/main/java/io/ray/test/JobConfigTest.java b/java/test/src/main/java/io/ray/test/JobConfigTest.java index 29795773a..e5b781692 100644 --- a/java/test/src/main/java/io/ray/test/JobConfigTest.java +++ b/java/test/src/main/java/io/ray/test/JobConfigTest.java @@ -11,7 +11,6 @@ public class JobConfigTest extends BaseTest { @BeforeClass public void setupJobConfig() { - System.setProperty("ray.job.num-java-workers-per-process", "3"); System.setProperty("ray.raylet.startup-token", "0"); System.setProperty("ray.job.jvm-options.0", "-DX=999"); System.setProperty("ray.job.jvm-options.1", "-DY=998"); @@ -33,10 +32,6 @@ public class JobConfigTest extends BaseTest { Assert.assertEquals("998", Ray.task(JobConfigTest::getJvmOptions, "Y").remote().get()); } - public void testNumJavaWorkersPerProcess() { - Assert.assertEquals(TestUtils.getNumWorkersPerProcess(), 3); - } - public void testInActor() { ActorHandle actor = Ray.actor(MyActor::new).remote(); diff --git a/java/test/src/main/java/io/ray/test/KillActorTest.java b/java/test/src/main/java/io/ray/test/KillActorTest.java index 2c4469ad6..b4d8838a9 100644 --- a/java/test/src/main/java/io/ray/test/KillActorTest.java +++ b/java/test/src/main/java/io/ray/test/KillActorTest.java @@ -15,7 +15,6 @@ public class KillActorTest extends BaseTest { @BeforeClass public void setUp() { - System.setProperty("ray.job.num-java-workers-per-process", "1"); System.setProperty("ray.raylet.startup-token", "0"); } diff --git a/java/test/src/main/java/io/ray/test/MultiThreadingTest.java b/java/test/src/main/java/io/ray/test/MultiThreadingTest.java index 42e0dd2a9..81629245c 100644 --- a/java/test/src/main/java/io/ray/test/MultiThreadingTest.java +++ b/java/test/src/main/java/io/ray/test/MultiThreadingTest.java @@ -51,14 +51,13 @@ public class MultiThreadingTest extends BaseTest { final Object[] result = new 
Object[1]; Thread thread = new Thread( - Ray.wrapRunnable( - () -> { - try { - result[0] = Ray.getRuntimeContext().getCurrentActorId(); - } catch (Exception e) { - result[0] = e; - } - })); + () -> { + try { + result[0] = Ray.getRuntimeContext().getCurrentActorId(); + } catch (Exception e) { + result[0] = e; + } + }); thread.start(); thread.join(); if (result[0] instanceof Exception) { @@ -140,6 +139,8 @@ public class MultiThreadingTest extends BaseTest { Assert.assertEquals("ok", obj.get()); } + /// SINGLE_PROCESS mode doesn't support this API. + @Test(groups = {"cluster"}) public void testGetCurrentActorId() { ActorHandle actorIdTester = Ray.actor(ActorIdTester::new).remote(); ActorId actorId = actorIdTester.task(ActorIdTester::getCurrentActorId).remote().get(); @@ -162,104 +163,6 @@ public class MultiThreadingTest extends BaseTest { }; } - static boolean testMissingWrapRunnable() throws InterruptedException { - { - Runnable[] runnables = generateRunnables(); - // It's OK to run them in main thread. - for (Runnable runnable : runnables) { - runnable.run(); - } - } - - Throwable[] throwable = new Throwable[1]; - - { - Runnable[] runnables = generateRunnables(); - Thread thread = - new Thread( - Ray.wrapRunnable( - () -> { - try { - // It would be OK to run them in another thread if wrapped the runnable. - for (Runnable runnable : runnables) { - runnable.run(); - } - } catch (Throwable ex) { - throwable[0] = ex; - } - })); - thread.start(); - thread.join(); - if (throwable[0] != null) { - throw new RuntimeException("Exception occurred in thread.", throwable[0]); - } - } - - { - Runnable[] runnables = generateRunnables(); - Thread thread = - new Thread( - () -> { - try { - // It wouldn't be OK to run them in another thread if not wrapped the runnable. 
- for (Runnable runnable : runnables) { - Assert.expectThrows(RuntimeException.class, runnable::run); - } - } catch (Throwable ex) { - throwable[0] = ex; - } - }); - thread.start(); - thread.join(); - if (throwable[0] != null) { - throw new RuntimeException("Exception occurred in thread.", throwable[0]); - } - } - - { - Runnable[] runnables = generateRunnables(); - Runnable[] wrappedRunnables = new Runnable[runnables.length]; - for (int i = 0; i < runnables.length; i++) { - wrappedRunnables[i] = Ray.wrapRunnable(runnables[i]); - } - // It would be OK to run the wrapped runnables in the current thread. - for (Runnable runnable : wrappedRunnables) { - runnable.run(); - } - - // It would be OK to invoke Ray APIs after executing a wrapped runnable in the current thread. - wrappedRunnables[0].run(); - runnables[0].run(); - } - - // Return true here to make the caller.remote() returns an ObjectRef. - return true; - } - - public void testMissingWrapRunnableInWorker() { - Ray.task(MultiThreadingTest::testMissingWrapRunnable).remote().get(); - } - - public void testGetAndSetAsyncContext() throws InterruptedException { - Object asyncContext = Ray.getAsyncContext(); - Throwable[] throwable = new Throwable[1]; - Thread thread = - new Thread( - () -> { - try { - Ray.setAsyncContext(asyncContext); - Ray.put(1); - } catch (Throwable ex) { - throwable[0] = ex; - } - }); - thread.start(); - thread.join(); - if (throwable[0] != null) { - throw new RuntimeException("Exception occurred in thread.", throwable[0]); - } - } - private static void runTestCaseInMultipleThreads(Runnable testCase, int numRepeats) { ExecutorService service = Executors.newFixedThreadPool(NUM_THREADS); @@ -267,14 +170,13 @@ public class MultiThreadingTest extends BaseTest { List> futures = new ArrayList<>(); for (int i = 0; i < NUM_THREADS; i++) { Callable task = - Ray.wrapCallable( - () -> { - for (int j = 0; j < numRepeats; j++) { - TimeUnit.MILLISECONDS.sleep(1); - testCase.run(); - } - return "ok"; - }); + 
() -> { + for (int j = 0; j < numRepeats; j++) { + TimeUnit.MILLISECONDS.sleep(1); + testCase.run(); + } + return "ok"; + }; futures.add(service.submit(task)); } for (Future future : futures) { @@ -288,35 +190,4 @@ public class MultiThreadingTest extends BaseTest { service.shutdown(); } } - - private static boolean testGetAsyncContextAndSetAsyncContext() throws Exception { - final Object asyncContext = Ray.getAsyncContext(); - final Object[] result = new Object[1]; - Thread thread = - new Thread( - () -> { - try { - Ray.setAsyncContext(asyncContext); - Ray.put(0); - } catch (Exception e) { - result[0] = e; - } - }); - thread.start(); - thread.join(); - if (result[0] instanceof Exception) { - throw (Exception) result[0]; - } - return true; - } - - public void testGetAsyncContextAndSetAsyncContextInDriver() throws Exception { - Assert.assertTrue(testGetAsyncContextAndSetAsyncContext()); - } - - public void testGetAsyncContextAndSetAsyncContextInWorker() { - ObjectRef obj = - Ray.task(MultiThreadingTest::testGetAsyncContextAndSetAsyncContext).remote(); - Assert.assertTrue(obj.get()); - } } diff --git a/java/test/src/main/java/io/ray/test/RuntimeEnvTest.java b/java/test/src/main/java/io/ray/test/RuntimeEnvTest.java index a31c60d0f..8419f754a 100644 --- a/java/test/src/main/java/io/ray/test/RuntimeEnvTest.java +++ b/java/test/src/main/java/io/ray/test/RuntimeEnvTest.java @@ -4,7 +4,6 @@ import com.google.common.collect.ImmutableList; import io.ray.api.ActorHandle; import io.ray.api.Ray; import io.ray.api.runtimeenv.RuntimeEnv; -import io.ray.runtime.util.SystemUtil; import java.util.List; import org.testng.Assert; import org.testng.annotations.Test; @@ -30,10 +29,6 @@ public class RuntimeEnvTest { return System.getenv(key); } - public int getPid() { - return SystemUtil.pid(); - } - public boolean findClass(String className) { try { Class.forName(className); @@ -45,7 +40,6 @@ public class RuntimeEnvTest { } public void testPerJobEnvVars() { - 
System.setProperty("ray.job.num-java-workers-per-process", "1"); System.setProperty("ray.job.runtime-env.env-vars.KEY1", "A"); System.setProperty("ray.job.runtime-env.env-vars.KEY2", "B"); @@ -61,87 +55,6 @@ public class RuntimeEnvTest { } } - public void testPerActorEnvVars() { - /// This is used to test that actors with runtime envs will not reuse worker process. - System.setProperty("ray.job.num-java-workers-per-process", "2"); - try { - Ray.init(); - int pid1 = 0; - int pid2 = 0; - { - RuntimeEnv runtimeEnv = - new RuntimeEnv.Builder() - .addEnvVar("KEY1", "A") - .addEnvVar("KEY2", "B") - .addEnvVar("KEY1", "C") - .build(); - - ActorHandle actor1 = Ray.actor(A::new).setRuntimeEnv(runtimeEnv).remote(); - String val = actor1.task(A::getEnv, "KEY1").remote().get(); - Assert.assertEquals(val, "C"); - val = actor1.task(A::getEnv, "KEY2").remote().get(); - Assert.assertEquals(val, "B"); - - pid1 = actor1.task(A::getPid).remote().get(); - } - - { - /// Because we didn't set them for actor2 , all should be null. - ActorHandle actor2 = Ray.actor(A::new).remote(); - String val = actor2.task(A::getEnv, "KEY1").remote().get(); - Assert.assertNull(val); - val = actor2.task(A::getEnv, "KEY2").remote().get(); - Assert.assertNull(val); - pid2 = actor2.task(A::getPid).remote().get(); - } - - // actor1 and actor2 shouldn't be in one process because they have - // different runtime env. 
- Assert.assertNotEquals(pid1, pid2); - } finally { - Ray.shutdown(); - } - } - - public void testPerActorEnvVarsOverwritePerJobEnvVars() { - System.setProperty("ray.job.num-java-workers-per-process", "2"); - System.setProperty("ray.job.runtime-env.env-vars.KEY1", "A"); - System.setProperty("ray.job.runtime-env.env-vars.KEY2", "B"); - - int pid1 = 0; - int pid2 = 0; - try { - Ray.init(); - { - RuntimeEnv runtimeEnv = new RuntimeEnv.Builder().addEnvVar("KEY1", "C").build(); - - ActorHandle actor1 = Ray.actor(A::new).setRuntimeEnv(runtimeEnv).remote(); - String val = actor1.task(A::getEnv, "KEY1").remote().get(); - Assert.assertEquals(val, "C"); - val = actor1.task(A::getEnv, "KEY2").remote().get(); - Assert.assertEquals(val, "B"); - pid1 = actor1.task(A::getPid).remote().get(); - } - - { - /// Because we didn't set them for actor2 explicitly, it should use the per job - /// runtime env. - ActorHandle actor2 = Ray.actor(A::new).remote(); - String val = actor2.task(A::getEnv, "KEY1").remote().get(); - Assert.assertEquals(val, "A"); - val = actor2.task(A::getEnv, "KEY2").remote().get(); - Assert.assertEquals(val, "B"); - pid2 = actor2.task(A::getPid).remote().get(); - } - - // actor1 and actor2 shouldn't be in one process because they have - // different runtime env. 
- Assert.assertNotEquals(pid1, pid2); - } finally { - Ray.shutdown(); - } - } - private static String getEnvVar(String key) { return System.getenv(key); } diff --git a/java/test/src/main/java/io/ray/test/TestUtils.java b/java/test/src/main/java/io/ray/test/TestUtils.java index 71a0ece59..5db5e0e8e 100644 --- a/java/test/src/main/java/io/ray/test/TestUtils.java +++ b/java/test/src/main/java/io/ray/test/TestUtils.java @@ -3,9 +3,7 @@ package io.ray.test; import com.google.common.base.Preconditions; import io.ray.api.ObjectRef; import io.ray.api.Ray; -import io.ray.runtime.AbstractRayRuntime; import io.ray.runtime.RayRuntimeInternal; -import io.ray.runtime.RayRuntimeProxy; import io.ray.runtime.config.RayConfig; import io.ray.runtime.config.RunMode; import io.ray.runtime.task.ArgumentsBuilder; @@ -129,20 +127,7 @@ public class TestUtils { } public static RayRuntimeInternal getUnderlyingRuntime() { - if (Ray.internal() instanceof AbstractRayRuntime) { - return (RayRuntimeInternal) Ray.internal(); - } - RayRuntimeProxy proxy = - (RayRuntimeProxy) (java.lang.reflect.Proxy.getInvocationHandler(Ray.internal())); - return proxy.getRuntimeObject(); - } - - private static int getNumWorkersPerProcessRemoteFunction() { - return TestUtils.getRuntime().getRayConfig().numWorkersPerProcess; - } - - public static int getNumWorkersPerProcess() { - return Ray.task(TestUtils::getNumWorkersPerProcessRemoteFunction).remote().get(); + return (RayRuntimeInternal) Ray.internal(); } public static ProcessBuilder buildDriver(Class mainClass, String[] args) { diff --git a/java/test/src/main/resources/ray.conf b/java/test/src/main/resources/ray.conf deleted file mode 100644 index b838c0075..000000000 --- a/java/test/src/main/resources/ray.conf +++ /dev/null @@ -1,6 +0,0 @@ -ray { - job { - # Enable multi-worker feature in Java test - num-java-workers-per-process: 10 - } -} diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index e1be9aa2b..2ef826199 100644 --- a/python/ray/_raylet.pyx 
+++ b/python/ray/_raylet.pyx @@ -1116,7 +1116,6 @@ cdef class CoreWorker: options.unhandled_exception_handler = unhandled_exception_handler options.get_lang_stack = get_py_stack options.is_local_mode = local_mode - options.num_workers = 1 options.kill_main = kill_main_task options.terminate_asyncio_thread = terminate_asyncio_thread options.serialized_job_config = serialized_job_config diff --git a/python/ray/job_config.py b/python/ray/job_config.py index 0116a4012..c220e37b5 100644 --- a/python/ray/job_config.py +++ b/python/ray/job_config.py @@ -8,8 +8,6 @@ class JobConfig: """A class used to store the configurations of a job. Attributes: - num_java_workers_per_process (int): The number of java workers per - worker process. jvm_options (str[]): The jvm options for java workers of the job. code_search_path (list): A list of directories or jar files that specify the search path for user code. This will be used as @@ -22,7 +20,6 @@ class JobConfig: def __init__( self, - num_java_workers_per_process=1, jvm_options=None, code_search_path=None, runtime_env=None, @@ -31,7 +28,6 @@ class JobConfig: ray_namespace=None, default_actor_lifetime="non_detached", ): - self.num_java_workers_per_process = num_java_workers_per_process self.jvm_options = jvm_options or [] self.code_search_path = code_search_path or [] # It's difficult to find the error that caused by the @@ -108,7 +104,6 @@ class JobConfig: pb.ray_namespace = str(uuid.uuid4()) else: pb.ray_namespace = self.ray_namespace - pb.num_java_workers_per_process = self.num_java_workers_per_process pb.jvm_options.extend(self.jvm_options) pb.code_search_path.extend(self.code_search_path) for k, v in self.metadata.items(): @@ -147,9 +142,6 @@ class JobConfig: Generates a JobConfig object from json. 
""" return cls( - num_java_workers_per_process=job_config_json.get( - "num_java_workers_per_process", 1 - ), jvm_options=job_config_json.get("jvm_options", None), code_search_path=job_config_json.get("code_search_path", None), runtime_env=job_config_json.get("runtime_env", None), diff --git a/python/ray/tests/test_advanced_8.py b/python/ray/tests/test_advanced_8.py index a9fb0dd10..230539222 100644 --- a/python/ray/tests/test_advanced_8.py +++ b/python/ray/tests/test_advanced_8.py @@ -545,19 +545,16 @@ def test_k8s_cpu(use_cgroups_v2: bool): def test_sync_job_config(shutdown_only): - num_java_workers_per_process = 8 runtime_env = {"env_vars": {"key": "value"}} ray.init( job_config=ray.job_config.JobConfig( - num_java_workers_per_process=num_java_workers_per_process, runtime_env=runtime_env, ) ) # Check that the job config is synchronized at the driver side. job_config = ray.worker.global_worker.core_worker.get_job_config() - assert job_config.num_java_workers_per_process == num_java_workers_per_process job_runtime_env = RuntimeEnv.deserialize( job_config.runtime_env_info.serialized_runtime_env ) @@ -571,7 +568,6 @@ def test_sync_job_config(shutdown_only): # Check that the job config is synchronized at the worker side. job_config = gcs_utils.JobConfig() job_config.ParseFromString(ray.get(get_job_config.remote())) - assert job_config.num_java_workers_per_process == num_java_workers_per_process job_runtime_env = RuntimeEnv.deserialize( job_config.runtime_env_info.serialized_runtime_env ) diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index af98e6a93..96cf4a155 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -3084,15 +3084,6 @@ void CoreWorker::HandleKillActor(const rpc::KillActorRequest &request, if (request.no_restart()) { Disconnect(); } - if (options_.num_workers > 1) { - // TODO (kfstorm): Should we add some kind of check before sending the killing - // request? 
- RAY_LOG(ERROR) - << "Killing an actor which is running in a worker process with multiple " - "workers will also kill other actors in this process. To avoid this, " - "please create the Java actor with some dynamic options to make it being " - "hosted in a dedicated worker process."; - } // NOTE(hchen): Use `QuickExit()` to force-exit this process without doing cleanup. // `exit()` will destruct static objects in an incorrect order, which will lead to // core dumps. diff --git a/src/ray/core_worker/core_worker_options.h b/src/ray/core_worker/core_worker_options.h index 21d8d1816..5703c6544 100644 --- a/src/ray/core_worker/core_worker_options.h +++ b/src/ray/core_worker/core_worker_options.h @@ -76,7 +76,6 @@ struct CoreWorkerOptions { get_lang_stack(nullptr), kill_main(nullptr), is_local_mode(false), - num_workers(0), terminate_asyncio_thread(nullptr), serialized_job_config(""), metrics_agent_port(-1), @@ -148,8 +147,6 @@ struct CoreWorkerOptions { std::function kill_main; /// Is local mode being used. bool is_local_mode; - /// The number of workers to be started in the current process. - int num_workers; /// The function to destroy asyncio event and loops. std::function terminate_asyncio_thread; /// Serialized representation of JobConfig. diff --git a/src/ray/core_worker/core_worker_process.cc b/src/ray/core_worker/core_worker_process.cc index eca59f3bf..16221eabc 100644 --- a/src/ray/core_worker/core_worker_process.cc +++ b/src/ray/core_worker/core_worker_process.cc @@ -82,10 +82,9 @@ thread_local std::weak_ptr CoreWorkerProcessImpl::thread_local_core_ CoreWorkerProcessImpl::CoreWorkerProcessImpl(const CoreWorkerOptions &options) : options_(options), - global_worker_id_( - options.worker_type == WorkerType::DRIVER - ? ComputeDriverIdFromJob(options_.job_id) - : (options_.num_workers == 1 ? WorkerID::FromRandom() : WorkerID::Nil())) { + global_worker_id_(options.worker_type == WorkerType::DRIVER + ? 
ComputeDriverIdFromJob(options_.job_id) + : WorkerID::FromRandom()) { if (options_.enable_logging) { std::stringstream app_name; app_name << LanguageString(options_.language) << "-core-" @@ -111,12 +110,6 @@ CoreWorkerProcessImpl::CoreWorkerProcessImpl(const CoreWorkerOptions &options) << "install_failure_signal_handler must be false because ray log is disabled."; } - RAY_CHECK(options_.num_workers > 0); - if (options_.worker_type == WorkerType::DRIVER) { - // Driver process can only contain one worker. - RAY_CHECK(options_.num_workers == 1); - } - RAY_LOG(INFO) << "Constructing CoreWorkerProcess. pid: " << getpid(); // NOTE(kfstorm): any initialization depending on RayConfig must happen after this line. @@ -250,18 +243,14 @@ bool CoreWorkerProcessImpl::ShouldCreateGlobalWorkerOnConstruction() const { // worker APIs before `CoreWorkerProcess::RunTaskExecutionLoop` is called. So we need // to create the worker instance here. One example of invocations is // https://github.com/ray-project/ray/blob/45ce40e5d44801193220d2c546be8de0feeef988/python/ray/worker.py#L1281. - return options_.num_workers == 1 && (options_.worker_type == WorkerType::DRIVER || - options_.language == Language::PYTHON); + return (options_.worker_type == WorkerType::DRIVER || + options_.language == Language::PYTHON); } std::shared_ptr CoreWorkerProcessImpl::GetWorker( const WorkerID &worker_id) const { absl::ReaderMutexLock lock(&mutex_); - auto it = workers_.find(worker_id); - if (it != workers_.end()) { - return it->second; - } - return nullptr; + return global_worker_; } std::shared_ptr CoreWorkerProcessImpl::GetGlobalWorker() { @@ -275,13 +264,9 @@ std::shared_ptr CoreWorkerProcessImpl::CreateWorker() { global_worker_id_ != WorkerID::Nil() ? 
global_worker_id_ : WorkerID::FromRandom()); RAY_LOG(DEBUG) << "Worker " << worker->GetWorkerID() << " is created."; absl::WriterMutexLock lock(&mutex_); - if (options_.num_workers == 1) { - global_worker_ = worker; - } + global_worker_ = worker; thread_local_core_worker_ = worker; - workers_.emplace(worker->GetWorkerID(), worker); - RAY_CHECK(workers_.size() <= static_cast(options_.num_workers)); return worker; } @@ -293,10 +278,7 @@ void CoreWorkerProcessImpl::RemoveWorker(std::shared_ptr worker) { RAY_CHECK(thread_local_core_worker_.lock() == worker); } thread_local_core_worker_.reset(); - { - workers_.erase(worker->GetWorkerID()); - RAY_LOG(INFO) << "Removed worker " << worker->GetWorkerID(); - } + RAY_LOG(INFO) << "Removed worker " << worker->GetWorkerID(); if (global_worker_ == worker) { global_worker_ = nullptr; } @@ -304,31 +286,14 @@ void CoreWorkerProcessImpl::RemoveWorker(std::shared_ptr worker) { void CoreWorkerProcessImpl::RunWorkerTaskExecutionLoop() { RAY_CHECK(options_.worker_type == WorkerType::WORKER); - if (options_.num_workers == 1) { - // Run the task loop in the current thread only if the number of workers is 1. - auto worker = GetGlobalWorker(); - if (!worker) { - worker = CreateWorker(); - } - worker->RunTaskExecutionLoop(); - RAY_LOG(INFO) << "Task execution loop terminated. Removing the global worker."; - RemoveWorker(worker); - } else { - std::vector worker_threads; - for (int i = 0; i < options_.num_workers; i++) { - worker_threads.emplace_back([this, i] { - SetThreadName("worker.task" + std::to_string(i)); - auto worker = CreateWorker(); - worker->RunTaskExecutionLoop(); - RAY_LOG(INFO) << "Task execution loop terminated for a thread " - << std::to_string(i) << ". Removing a worker."; - RemoveWorker(worker); - }); - } - for (auto &thread : worker_threads) { - thread.join(); - } + // Run the task loop in the current thread only if the number of workers is 1. 
+ auto worker = GetGlobalWorker(); + if (!worker) { + worker = CreateWorker(); } + worker->RunTaskExecutionLoop(); + RAY_LOG(INFO) << "Task execution loop terminated. Removing the global worker."; + RemoveWorker(worker); } void CoreWorkerProcessImpl::ShutdownDriver() { @@ -342,42 +307,30 @@ void CoreWorkerProcessImpl::ShutdownDriver() { } CoreWorker &CoreWorkerProcessImpl::GetCoreWorkerForCurrentThread() { - if (options_.num_workers == 1) { - auto global_worker = GetGlobalWorker(); - if (ShouldCreateGlobalWorkerOnConstruction() && !global_worker) { - // This could only happen when the worker has already been shutdown. - // In this case, we should exit without crashing. - // TODO (scv119): A better solution could be returning error code - // and handling it at language frontend. - if (options_.worker_type == WorkerType::DRIVER) { - RAY_LOG(ERROR) - << "The global worker has already been shutdown. This happens when " - "the language frontend accesses the Ray's worker after it is " - "shutdown. The process will exit"; - } else { - RAY_LOG(INFO) << "The global worker has already been shutdown. This happens when " - "the language frontend accesses the Ray's worker after it is " - "shutdown. The process will exit"; - } - QuickExit(); + auto global_worker = GetGlobalWorker(); + if (ShouldCreateGlobalWorkerOnConstruction() && !global_worker) { + // This could only happen when the worker has already been shutdown. + // In this case, we should exit without crashing. + // TODO (scv119): A better solution could be returning error code + // and handling it at language frontend. + if (options_.worker_type == WorkerType::DRIVER) { + RAY_LOG(ERROR) << "The global worker has already been shutdown. This happens when " + "the language frontend accesses the Ray's worker after it is " + "shutdown. The process will exit"; + } else { + RAY_LOG(INFO) << "The global worker has already been shutdown. 
This happens when " + "the language frontend accesses the Ray's worker after it is " + "shutdown. The process will exit"; } - RAY_CHECK(global_worker) << "global_worker_ must not be NULL"; - return *global_worker; + QuickExit(); } - auto ptr = thread_local_core_worker_.lock(); - RAY_CHECK(ptr != nullptr) - << "The current thread is not bound with a core worker instance."; - return *ptr; + RAY_CHECK(global_worker) << "global_worker_ must not be NULL"; + return *global_worker; } void CoreWorkerProcessImpl::SetThreadLocalWorkerById(const WorkerID &worker_id) { - if (options_.num_workers == 1) { - RAY_CHECK(GetGlobalWorker()->GetWorkerID() == worker_id); - return; - } - auto worker = GetWorker(worker_id); - RAY_CHECK(worker) << "Worker " << worker_id << " not found."; - thread_local_core_worker_ = GetWorker(worker_id); + RAY_CHECK(GetGlobalWorker()->GetWorkerID() == worker_id); + return; } } // namespace core diff --git a/src/ray/core_worker/core_worker_process.h b/src/ray/core_worker/core_worker_process.h index df8e10b87..15067be3a 100644 --- a/src/ray/core_worker/core_worker_process.h +++ b/src/ray/core_worker/core_worker_process.h @@ -25,7 +25,6 @@ class CoreWorker; /// CoreWorkerOptions options = { /// WorkerType::DRIVER, // worker_type /// ..., // other arguments -/// 1, // num_workers /// }; /// CoreWorkerProcess::Initialize(options); /// @@ -36,7 +35,6 @@ class CoreWorker; /// CoreWorkerOptions options = { /// WorkerType::WORKER, // worker_type /// ..., // other arguments -/// num_workers, // num_workers /// }; /// CoreWorkerProcess::Initialize(options); /// ... // Do other stuff @@ -52,14 +50,7 @@ class CoreWorker; /// threads, remember to call `CoreWorkerProcess::SetCurrentThreadWorkerId(worker_id)` /// once in the new thread before calling core worker APIs, to associate the current /// thread with a worker. You can obtain the worker ID via -/// `CoreWorkerProcess::GetCoreWorker()->GetWorkerID()`. 
Currently a Java worker process -/// starts multiple workers by default, but can be configured to start only 1 worker by -/// speicifying `num_java_workers_per_process` in the job config. -/// -/// If only 1 worker is started (either because the worker type is driver, or the -/// `num_workers` in `CoreWorkerOptions` is set to 1), all threads will be automatically -/// associated to the only worker. Then no need to call `SetCurrentThreadWorkerId` in -/// your own threads. Currently a Python worker process starts only 1 worker. +/// `CoreWorkerProcess::GetCoreWorker()->GetWorkerID()`. /// /// How does core worker process dealloation work? /// @@ -195,9 +186,6 @@ class CoreWorkerProcessImpl { /// The worker ID of the global worker, if the number of workers is 1. const WorkerID global_worker_id_; - /// Map from worker ID to worker. - absl::flat_hash_map> workers_ GUARDED_BY(mutex_); - /// To protect access to workers_ and global_worker_ mutable absl::Mutex mutex_; }; diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc index 33433838d..edad6bbcd 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc @@ -103,7 +103,6 @@ Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(JNIEnv *env, jstring rayletSocket, jbyteArray jobId, jobject gcsClientOptions, - jint numWorkersPerProcess, jstring logDir, jbyteArray jobConfig, jint startupToken, @@ -289,7 +288,6 @@ Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(JNIEnv *env, options.task_execution_callback = task_execution_callback; options.on_worker_shutdown = on_worker_shutdown; options.gc_collect = gc_collect; - options.num_workers = static_cast(numWorkersPerProcess); options.serialized_job_config = serialized_job_config; options.metrics_agent_port = -1; options.startup_token = startupToken; diff --git 
a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.h b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.h index 785a7965a..6650799ce 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.h +++ b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.h @@ -25,7 +25,7 @@ extern "C" { * Class: io_ray_runtime_RayNativeRuntime * Method: nativeInitialize * Signature: - * (ILjava/lang/String;ILjava/lang/String;Ljava/lang/String;Ljava/lang/String;[BLio/ray/runtime/gcs/GcsClientOptions;ILjava/lang/String;[BII)V + * (ILjava/lang/String;ILjava/lang/String;Ljava/lang/String;Ljava/lang/String;[BLio/ray/runtime/gcs/GcsClientOptions;Ljava/lang/String;[BII)V */ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(JNIEnv *, jclass, @@ -37,7 +37,6 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(JNI jstring, jbyteArray, jobject, - jint, jstring, jbyteArray, jint, diff --git a/src/ray/core_worker/test/core_worker_test.cc b/src/ray/core_worker/test/core_worker_test.cc index 968fec57e..497d493ce 100644 --- a/src/ray/core_worker/test/core_worker_test.cc +++ b/src/ray/core_worker/test/core_worker_test.cc @@ -140,7 +140,6 @@ class CoreWorkerTest : public ::testing::Test { options.node_manager_port = node_manager_port; options.raylet_ip_address = "127.0.0.1"; options.driver_name = "core_worker_test"; - options.num_workers = 1; options.metrics_agent_port = -1; CoreWorkerProcess::Initialize(options); } diff --git a/src/ray/core_worker/test/mock_worker.cc b/src/ray/core_worker/test/mock_worker.cc index 91ce0bb9a..9c9ec6ec3 100644 --- a/src/ray/core_worker/test/mock_worker.cc +++ b/src/ray/core_worker/test/mock_worker.cc @@ -51,7 +51,6 @@ class MockWorker { options.raylet_ip_address = "127.0.0.1"; options.task_execution_callback = std::bind(&MockWorker::ExecuteTask, this, _1, _2, _3, _4, _5, _6, _7, _8, _9); - options.num_workers = 1; options.metrics_agent_port = -1; options.startup_token = 
startup_token; CoreWorkerProcess::Initialize(options); diff --git a/src/ray/gcs/gcs_server/gcs_actor_manager.cc b/src/ray/gcs/gcs_server/gcs_actor_manager.cc index fcfbd7260..3b27acf7f 100644 --- a/src/ray/gcs/gcs_server/gcs_actor_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_actor_manager.cc @@ -686,12 +686,8 @@ void GcsActorManager::PollOwnerForActorOutOfScope( if (node_it != owners_.end() && node_it->second.count(owner_id)) { // Only destroy the actor if its owner is still alive. The actor may // have already been destroyed if the owner died. - // For multiple actors in one process, if one actor is out of scope, - // We shouldn't force kill the actor because other actors in the process - // are still alive. - auto force_kill = - get_job_config_(actor_id.JobId())->num_java_workers_per_process() <= 1; - DestroyActor(actor_id, GenActorOutOfScopeCause(GetActor(actor_id)), force_kill); + DestroyActor( + actor_id, GenActorOutOfScopeCause(GetActor(actor_id)), /*force_kill=*/true); } }); } diff --git a/src/ray/gcs/gcs_server/test/gcs_job_manager_test.cc b/src/ray/gcs/gcs_server/test/gcs_job_manager_test.cc index b1ba094e8..506da34a1 100644 --- a/src/ray/gcs/gcs_server/test/gcs_job_manager_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_job_manager_test.cc @@ -89,7 +89,7 @@ TEST_F(GcsJobManagerTest, TestGetJobConfig) { auto job_id2 = JobID::FromInt(2); gcs::GcsInitData gcs_init_data(gcs_table_storage_); gcs_job_manager.Initialize(/*init_data=*/gcs_init_data); - auto add_job_request1 = Mocker::GenAddJobRequest(job_id1, "namespace_1", 4); + auto add_job_request1 = Mocker::GenAddJobRequest(job_id1, "namespace_1"); rpc::AddJobReply empty_reply; @@ -97,19 +97,16 @@ TEST_F(GcsJobManagerTest, TestGetJobConfig) { *add_job_request1, &empty_reply, [](Status, std::function, std::function) {}); - auto add_job_request2 = Mocker::GenAddJobRequest(job_id2, "namespace_2", 8); + auto add_job_request2 = Mocker::GenAddJobRequest(job_id2, "namespace_2"); gcs_job_manager.HandleAddJob( 
*add_job_request2, &empty_reply, [](Status, std::function, std::function) {}); - auto job_config1 = gcs_job_manager.GetJobConfig(job_id1); ASSERT_EQ("namespace_1", job_config1->ray_namespace()); - ASSERT_EQ(4, job_config1->num_java_workers_per_process()); auto job_config2 = gcs_job_manager.GetJobConfig(job_id2); ASSERT_EQ("namespace_2", job_config2->ray_namespace()); - ASSERT_EQ(8, job_config2->num_java_workers_per_process()); } int main(int argc, char **argv) { diff --git a/src/ray/gcs/test/gcs_test_util.h b/src/ray/gcs/test/gcs_test_util.h index e9cca3866..21c54c146 100644 --- a/src/ray/gcs/test/gcs_test_util.h +++ b/src/ray/gcs/test/gcs_test_util.h @@ -239,12 +239,9 @@ struct Mocker { } static std::shared_ptr GenAddJobRequest( - const JobID &job_id, - const std::string &ray_namespace, - uint32_t num_java_worker_per_process) { + const JobID &job_id, const std::string &ray_namespace) { auto job_config_data = std::make_shared(); job_config_data->set_ray_namespace(ray_namespace); - job_config_data->set_num_java_workers_per_process(num_java_worker_per_process); auto job_table_data = std::make_shared(); job_table_data->set_job_id(job_id.Binary()); diff --git a/src/ray/object_manager/common.h b/src/ray/object_manager/common.h index 01da4a1d3..66829d251 100644 --- a/src/ray/object_manager/common.h +++ b/src/ray/object_manager/common.h @@ -39,8 +39,8 @@ using RestoreSpilledObjectCallback = /// A struct that includes info about the object. struct ObjectInfo { ObjectID object_id; - int64_t data_size; - int64_t metadata_size; + int64_t data_size = 0; + int64_t metadata_size = 0; /// Owner's raylet ID. NodeID owner_raylet_id; /// Owner's IP address. diff --git a/src/ray/protobuf/gcs.proto b/src/ray/protobuf/gcs.proto index 34dec8bd5..0fb170c03 100644 --- a/src/ray/protobuf/gcs.proto +++ b/src/ray/protobuf/gcs.proto @@ -260,8 +260,6 @@ message JobConfig { NON_DETACHED = 1; } - // The number of java workers per worker process. 
- uint32 num_java_workers_per_process = 1; // The jvm options for java workers of the job. repeated string jvm_options = 2; // A list of directories or files (jar files or dynamic libraries) that specify the diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc index 4390477a8..7cbc0de70 100644 --- a/src/ray/raylet/worker_pool.cc +++ b/src/ray/raylet/worker_pool.cc @@ -198,14 +198,12 @@ void WorkerPool::update_worker_startup_token_counter() { void WorkerPool::AddWorkerProcess( State &state, - const int workers_to_start, const rpc::WorkerType worker_type, const Process &proc, const std::chrono::high_resolution_clock::time_point &start, const rpc::RuntimeEnvInfo &runtime_env_info) { state.worker_processes.emplace(worker_startup_token_counter_, - WorkerProcessInfo{workers_to_start, - workers_to_start, + WorkerProcessInfo{/*is_pending_registration=*/true, {}, worker_type, proc, @@ -248,7 +246,7 @@ std::tuple WorkerPool::StartWorkerProcess( int starting_workers = 0; for (auto &entry : state.worker_processes) { if (entry.second.worker_type == worker_type) { - starting_workers += entry.second.num_starting_workers; + starting_workers += entry.second.is_pending_registration ? 
1 : 0; } } @@ -268,13 +266,6 @@ std::tuple WorkerPool::StartWorkerProcess( << rpc::WorkerType_Name(worker_type) << ", current pool has " << state.idle.size() << " workers"; - int workers_to_start = 1; - if (dynamic_options.empty()) { - if (language == Language::JAVA) { - workers_to_start = job_config->num_java_workers_per_process(); - } - } - std::vector options; // Append Ray-defined per-job options here @@ -312,12 +303,6 @@ std::tuple WorkerPool::StartWorkerProcess( } } - // Append Ray-defined per-process options here - if (language == Language::JAVA) { - options.push_back("-Dray.job.num-java-workers-per-process=" + - std::to_string(workers_to_start)); - } - // Append startup-token for JAVA here if (language == Language::JAVA) { options.push_back("-Dray.raylet.startup-token=" + @@ -460,13 +445,12 @@ std::tuple WorkerPool::StartWorkerProcess( auto duration = std::chrono::duration_cast(end - start); stats::ProcessStartupTimeMs.Record(duration.count()); stats::NumWorkersStarted.Record(1); - RAY_LOG(INFO) << "Started worker process of " << workers_to_start - << " worker(s) with pid " << proc.GetId() << ", the token " + RAY_LOG(INFO) << "Started worker process with pid " << proc.GetId() << ", the token is " << worker_startup_token_counter_; AdjustWorkerOomScore(proc.GetId()); MonitorStartingWorkerProcess( proc, worker_startup_token_counter_, language, worker_type); - AddWorkerProcess(state, workers_to_start, worker_type, proc, start, runtime_env_info); + AddWorkerProcess(state, worker_type, proc, start, runtime_env_info); StartupToken worker_startup_token = worker_startup_token_counter_; update_worker_startup_token_counter(); if (IsIOWorkerType(worker_type)) { @@ -513,7 +497,7 @@ void WorkerPool::MonitorStartingWorkerProcess(const Process &proc, // Since this process times out to start, remove it from worker_processes // to avoid the zombie worker. 
auto it = state.worker_processes.find(proc_startup_token); - if (it != state.worker_processes.end() && it->second.num_starting_workers != 0) { + if (it != state.worker_processes.end() && it->second.is_pending_registration) { RAY_LOG(ERROR) << "Some workers of the worker process(" << proc.GetId() << ") have not registered within the timeout. " @@ -747,12 +731,10 @@ void WorkerPool::OnWorkerStarted(const std::shared_ptr &worker) auto it = state.worker_processes.find(worker_startup_token); if (it != state.worker_processes.end()) { - it->second.num_starting_workers--; + it->second.is_pending_registration = false; it->second.alive_started_workers.insert(worker); - if (it->second.num_starting_workers == 0) { - // We may have slots to start more workers now. - TryStartIOWorkers(worker->GetLanguage()); - } + // We may have slots to start more workers now. + TryStartIOWorkers(worker->GetLanguage()); } const auto &worker_type = worker->GetWorkerType(); if (IsIOWorkerType(worker_type)) { @@ -1044,8 +1026,7 @@ void WorkerPool::TryKillingIdleWorkers() { auto &worker_state = GetStateForLanguage(idle_worker->GetLanguage()); auto it = worker_state.worker_processes.find(worker_startup_token); - if (it != worker_state.worker_processes.end() && - it->second.num_starting_workers > 0) { + if (it != worker_state.worker_processes.end() && it->second.is_pending_registration) { // A Java worker process may hold multiple workers. // Some workers of this process are pending registration. Skip killing this worker. continue; @@ -1349,7 +1330,7 @@ void WorkerPool::PrestartWorkers(const TaskSpecification &task_spec, // The number of available workers that can be used for this task spec. int num_usable_workers = state.idle.size(); for (auto &entry : state.worker_processes) { - num_usable_workers += entry.second.num_starting_workers; + num_usable_workers += entry.second.is_pending_registration ? 
1 : 0; } // Some existing workers may be holding less than 1 CPU each, so we should // start as many workers as needed to fill up the remaining CPUs. @@ -1376,10 +1357,10 @@ void WorkerPool::DisconnectWorker(const std::shared_ptr &worker if (!RemoveWorker(it->second.alive_started_workers, worker)) { // Worker is either starting or started, // if it's not started, we should remove it from starting. - it->second.num_starting_workers--; + it->second.is_pending_registration = false; } if (it->second.alive_started_workers.size() == 0 && - it->second.num_starting_workers == 0) { + !it->second.is_pending_registration) { DeleteRuntimeEnvIfPossible(it->second.runtime_env_info.serialized_runtime_env()); RemoveWorkerProcess(state, worker->GetStartupToken()); } @@ -1508,7 +1489,8 @@ void WorkerPool::WarnAboutSize() { num_workers_started_or_registered += static_cast(state.registered_workers.size()); for (const auto &starting_process : state.worker_processes) { - num_workers_started_or_registered += starting_process.second.num_starting_workers; + num_workers_started_or_registered += + starting_process.second.is_pending_registration ? 1 : 0; } // Don't count IO workers towards the warning message threshold. num_workers_started_or_registered -= RayConfig::instance().max_io_workers() * 2; diff --git a/src/ray/raylet/worker_pool.h b/src/ray/raylet/worker_pool.h index 4e94ade38..3bd3da859 100644 --- a/src/ray/raylet/worker_pool.h +++ b/src/ray/raylet/worker_pool.h @@ -473,10 +473,8 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface { /// Some basic information about the worker process. struct WorkerProcessInfo { - /// The number of workers in the worker process. - int num_workers; - /// The number of pending registration workers in the worker process. - int num_starting_workers; + /// Whether this worker is pending registration or is started. + bool is_pending_registration = true; /// The started workers which is alive.
std::unordered_set> alive_started_workers; /// The type of the worker. @@ -684,7 +682,6 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface { void DeleteRuntimeEnvIfPossible(const std::string &serialized_runtime_env); void AddWorkerProcess(State &state, - const int workers_to_start, const rpc::WorkerType worker_type, const Process &proc, const std::chrono::high_resolution_clock::time_point &start, diff --git a/src/ray/raylet/worker_pool_test.cc b/src/ray/raylet/worker_pool_test.cc index a0be09325..c9d4255f4 100644 --- a/src/ray/raylet/worker_pool_test.cc +++ b/src/ray/raylet/worker_pool_test.cc @@ -26,8 +26,7 @@ namespace ray { namespace raylet { -int NUM_WORKERS_PER_PROCESS_JAVA = 3; -int MAXIMUM_STARTUP_CONCURRENCY = 5; +int MAXIMUM_STARTUP_CONCURRENCY = 15; int MAX_IO_WORKER_SIZE = 2; int POOL_SIZE_SOFT_LIMIT = 5; int WORKER_REGISTER_TIMEOUT_SECONDS = 3; @@ -37,10 +36,6 @@ const std::string BAD_RUNTIME_ENV_ERROR_MSG = "bad runtime env"; std::vector LANGUAGES = {Language::PYTHON, Language::JAVA}; -static inline std::string GetNumJavaWorkersPerProcessSystemProperty(int num) { - return std::string("-Dray.job.num-java-workers-per-process=") + std::to_string(num); -} - class MockWorkerClient : public rpc::CoreWorkerClientInterface { public: MockWorkerClient(instrumented_io_context &io_service) : io_service_(io_service) {} @@ -194,7 +189,7 @@ class WorkerPoolMock : public WorkerPool { int total = 0; for (auto &state_entry : states_by_lang_) { for (auto &process_entry : state_entry.second.worker_processes) { - total += process_entry.second.num_starting_workers; + total += process_entry.second.is_pending_registration ? 
1 : 0; } } return total; @@ -204,7 +199,7 @@ class WorkerPoolMock : public WorkerPool { int total = 0; for (auto &entry : states_by_lang_) { for (auto process : entry.second.worker_processes) { - if (process.second.num_starting_workers != 0) { + if (process.second.is_pending_registration) { total += 1; } } @@ -313,7 +308,6 @@ class WorkerPoolMock : public WorkerPool { if (pushed_it == pushedProcesses_.end()) { int runtime_env_hash = 0; bool is_java = false; - bool has_dynamic_options = false; // Parses runtime env hash to make sure the pushed workers can be popped out. for (auto command_args : it->second) { std::string runtime_env_key = "--runtime-env-hash="; @@ -326,14 +320,9 @@ class WorkerPoolMock : public WorkerPool { if (pos != std::string::npos) { is_java = true; } - pos = command_args.find("-X"); - if (pos != std::string::npos) { - has_dynamic_options = true; - } } // TODO(SongGuyang): support C++ language workers. - int num_workers = - (is_java && !has_dynamic_options) ? NUM_WORKERS_PER_PROCESS_JAVA : 1; + int num_workers = 1; RAY_CHECK(timeout_worker_number <= num_workers) << "The timeout worker number cannot exceed the total number of workers"; auto register_workers = num_workers - timeout_worker_number; @@ -471,7 +460,6 @@ class WorkerPoolTest : public ::testing::Test { worker_pool_ = std::make_unique( io_service_, worker_commands, mock_worker_rpc_clients_); rpc::JobConfig job_config; - job_config.set_num_java_workers_per_process(NUM_WORKERS_PER_PROCESS_JAVA); RegisterDriver(Language::PYTHON, JOB_ID, job_config); } @@ -489,18 +477,7 @@ class WorkerPoolTest : public ::testing::Test { ASSERT_TRUE(worker_pool_->NumWorkerProcessesStarting() <= expected_worker_process_count); Process prev = worker_pool_->LastStartedWorkerProcess(); - if (!std::equal_to()(last_started_worker_process, prev)) { - last_started_worker_process = prev; - const auto &real_command = - worker_pool_->GetWorkerCommand(last_started_worker_process); - if (language == Language::JAVA) { - 
auto it = std::find( - real_command.begin(), - real_command.end(), - GetNumJavaWorkersPerProcessSystemProperty(num_workers_per_process)); - ASSERT_NE(it, real_command.end()); - } - } else { + if (std::equal_to()(last_started_worker_process, prev)) { ASSERT_EQ(worker_pool_->NumWorkerProcessesStarting(), expected_worker_process_count); ASSERT_TRUE(i >= expected_worker_process_count); @@ -618,9 +595,7 @@ TEST_F(WorkerPoolTest, HandleWorkerRegistration) { auto [proc, token] = worker_pool_->StartWorkerProcess( Language::JAVA, rpc::WorkerType::WORKER, JOB_ID, &status); std::vector> workers; - for (int i = 0; i < NUM_WORKERS_PER_PROCESS_JAVA; i++) { - workers.push_back(worker_pool_->CreateWorker(Process(), Language::JAVA)); - } + workers.push_back(worker_pool_->CreateWorker(Process(), Language::JAVA)); for (const auto &worker : workers) { // Check that there's still a starting worker process // before all workers have been registered @@ -670,7 +645,7 @@ TEST_F(WorkerPoolTest, StartupPythonWorkerProcessCount) { } TEST_F(WorkerPoolTest, StartupJavaWorkerProcessCount) { - TestStartupWorkerProcessCount(Language::JAVA, NUM_WORKERS_PER_PROCESS_JAVA); + TestStartupWorkerProcessCount(Language::JAVA, 1); } TEST_F(WorkerPoolTest, InitialWorkerProcessCount) { @@ -756,7 +731,6 @@ TEST_F(WorkerPoolTest, StartWorkerWithDynamicOptionsCommand) { rpc::JobConfig job_config = rpc::JobConfig(); job_config.add_code_search_path("/test/code_search_path"); - job_config.set_num_java_workers_per_process(NUM_WORKERS_PER_PROCESS_JAVA); job_config.add_jvm_options("-Xmx1g"); job_config.add_jvm_options("-Xms500m"); job_config.add_jvm_options("-Dmy-job.hello=world"); @@ -779,7 +753,6 @@ TEST_F(WorkerPoolTest, StartWorkerWithDynamicOptionsCommand) { expected_command.end(), {"-Xmx1g", "-Xms500m", "-Dmy-job.hello=world", "-Dmy-job.foo=bar"}); // Ray-defined per-process options - expected_command.push_back(GetNumJavaWorkersPerProcessSystemProperty(1)); 
expected_command.push_back("-Dray.raylet.startup-token=0"); expected_command.push_back("-Dray.internal.runtime-env-hash=1"); // User-defined per-process options @@ -1162,46 +1135,6 @@ TEST_F(WorkerPoolTest, DeleteWorkerPushPop) { }); } -TEST_F(WorkerPoolTest, NoPopOnCrashedWorkerProcess) { - // Start a Java worker process. - PopWorkerStatus status; - auto [proc, token] = worker_pool_->StartWorkerProcess( - Language::JAVA, rpc::WorkerType::WORKER, JOB_ID, &status); - auto worker1 = worker_pool_->CreateWorker(Process(), Language::JAVA); - auto worker2 = worker_pool_->CreateWorker(Process(), Language::JAVA); - - // We now imitate worker process crashing while core worker initializing. - - // 1. we register both workers. - RAY_CHECK_OK(worker_pool_->RegisterWorker( - worker1, proc.GetId(), worker_pool_->GetStartupToken(proc), [](Status, int) {})); - RAY_CHECK_OK(worker_pool_->RegisterWorker( - worker2, proc.GetId(), worker_pool_->GetStartupToken(proc), [](Status, int) {})); - - // 2. announce worker port for worker 1. When interacting with worker pool, it's - // PushWorker. - worker_pool_->PushWorker(worker1); - - // 3. kill the worker process. Now let's assume that Raylet found that the connection - // with worker 1 disconnected first. - worker_pool_->DisconnectWorker( - worker1, /*disconnect_type=*/rpc::WorkerExitType::SYSTEM_ERROR_EXIT); - - // 4. but the RPC for announcing worker port for worker 2 is already in Raylet input - // buffer. So now Raylet needs to handle worker 2. - worker_pool_->PushWorker(worker2); - - // 5. Let's try to pop a worker to execute a task. Worker 2 shouldn't be popped because - // the process has crashed. - const auto task_spec = ExampleTaskSpec(); - ASSERT_NE(worker_pool_->PopWorkerSync(task_spec), worker1); - ASSERT_NE(worker_pool_->PopWorkerSync(task_spec), worker2); - - // 6. Now Raylet disconnects with worker 2. 
- worker_pool_->DisconnectWorker( - worker2, /*disconnect_type=*/rpc::WorkerExitType::SYSTEM_ERROR_EXIT); -} - TEST_F(WorkerPoolTest, TestWorkerCapping) { auto job_id = JOB_ID; @@ -1423,8 +1356,7 @@ TEST_F(WorkerPoolTest, TestWorkerCappingWithExitDelay) { PopWorkerStatus status; auto [proc, token] = worker_pool_->StartWorkerProcess( language, rpc::WorkerType::WORKER, JOB_ID, &status); - int workers_to_start = - language == Language::JAVA ? NUM_WORKERS_PER_PROCESS_JAVA : 1; + int workers_to_start = 1; for (int j = 0; j < workers_to_start; j++) { auto worker = worker_pool_->CreateWorker(Process(), language); worker->SetStartupToken(worker_pool_->GetStartupToken(proc)); @@ -1644,77 +1576,6 @@ TEST_F(WorkerPoolTest, RuntimeEnvUriReferenceWorkerLevel) { } } -TEST_F(WorkerPoolTest, RuntimeEnvUriReferenceWithMultipleWorkers) { - auto job_id = JOB_ID; - std::string uri = "s3://567"; - auto runtime_env_info = ExampleRuntimeEnvInfo({uri}, false); - rpc::JobConfig job_config; - job_config.set_num_java_workers_per_process(NUM_WORKERS_PER_PROCESS_JAVA); - job_config.mutable_runtime_env_info()->CopyFrom(runtime_env_info); - // Start job without eager installed runtime env. - worker_pool_->HandleJobStarted(job_id, job_config); - ASSERT_EQ(GetReferenceCount(runtime_env_info.serialized_runtime_env()), 0); - - // First part, test normal case with all worker registered. - { - // Start actors with runtime env. The Java actors will trigger a multi-worker process. 
- std::vector> workers; - for (int i = 0; i < NUM_WORKERS_PER_PROCESS_JAVA; i++) { - auto actor_creation_id = ActorID::Of(job_id, TaskID::ForDriverTask(job_id), i + 1); - const auto actor_creation_task_spec = - ExampleTaskSpec(ActorID::Nil(), - Language::JAVA, - job_id, - actor_creation_id, - {}, - TaskID::FromRandom(JobID::Nil()), - runtime_env_info); - auto popped_actor_worker = worker_pool_->PopWorkerSync(actor_creation_task_spec); - ASSERT_NE(popped_actor_worker, nullptr); - workers.push_back(popped_actor_worker); - ASSERT_EQ(GetReferenceCount(runtime_env_info.serialized_runtime_env()), 1); - } - // Make sure only one worker process has been started. - ASSERT_EQ(worker_pool_->GetProcessSize(), 1); - // Disconnect all actor workers. - for (auto &worker : workers) { - worker_pool_->DisconnectWorker(worker, rpc::WorkerExitType::IDLE_EXIT); - } - ASSERT_EQ(GetReferenceCount(runtime_env_info.serialized_runtime_env()), 0); - } - - // Second part, test corner case with some worker registration timeout. - { - // Start one actor with runtime env. The Java actor will trigger a multi-worker - // process. - auto actor_creation_id = ActorID::Of(job_id, TaskID::ForDriverTask(job_id), 1); - const auto actor_creation_task_spec = - ExampleTaskSpec(ActorID::Nil(), - Language::JAVA, - job_id, - actor_creation_id, - {}, - TaskID::FromRandom(JobID::Nil()), - runtime_env_info); - PopWorkerStatus status; - // Only one worker registration. All the other worker registration times out. - auto popped_actor_worker = worker_pool_->PopWorkerSync( - actor_creation_task_spec, true, &status, NUM_WORKERS_PER_PROCESS_JAVA - 1); - ASSERT_EQ(GetReferenceCount(runtime_env_info.serialized_runtime_env()), 1); - // Disconnect actor worker. - worker_pool_->DisconnectWorker(popped_actor_worker, rpc::WorkerExitType::IDLE_EXIT); - ASSERT_EQ(GetReferenceCount(runtime_env_info.serialized_runtime_env()), 1); - // Sleep for a while to wait worker registration timeout. 
- std::this_thread::sleep_for( - std::chrono::seconds(WORKER_REGISTER_TIMEOUT_SECONDS + 1)); - ASSERT_EQ(GetReferenceCount(runtime_env_info.serialized_runtime_env()), 0); - } - - // Finish the job. - worker_pool_->HandleJobFinished(job_id); - ASSERT_EQ(GetReferenceCount(runtime_env_info.serialized_runtime_env()), 0); -} - TEST_F(WorkerPoolTest, CacheWorkersByRuntimeEnvHash) { /// /// Check that a worker can be popped only if there is a