mirror of
https://github.com/vale981/ray
synced 2025-03-04 17:41:43 -05:00
[Core] Remove multiple core workers in one process 1/n. (#24147)
This is the 1st PR to remove the code path of multiple core workers in one process. This PR is aiming to remove the flags and APIs related to `num_workers`. After this PR checking in, we needn't to consider the multiple core workers any longer. The further following PRs are related to the deeper logic refactor, like eliminating the gap between core worker and core worker process, removing the logic related to multiple workers from workerpool, gcs and etc. **BREAK CHANGE** This PR removes these APIs: - Ray.wrapRunnable(); - Ray.wrapCallable(); - Ray.setAsyncContext(); - Ray.getAsyncContext(); And the following APIs are not allowed to invoke in a user-created thread in local mode: - Ray.getRuntimeContext().getCurrentActorId(); - Ray.getRuntimeContext().getCurrentTaskId() Note that this PR shouldn't be merged to 1.x.
This commit is contained in:
parent
1d5e6d908d
commit
eb29895dbb
57 changed files with 137 additions and 1200 deletions
|
@ -139,7 +139,6 @@ void ProcessHelper::RayStart(CoreWorkerOptions::TaskExecutionCallback callback)
|
|||
options.node_manager_port = ConfigInternal::Instance().node_manager_port;
|
||||
options.raylet_ip_address = node_ip;
|
||||
options.driver_name = "cpp_worker";
|
||||
options.num_workers = 1;
|
||||
options.metrics_agent_port = -1;
|
||||
options.task_execution_callback = callback;
|
||||
options.startup_token = ConfigInternal::Instance().startup_token;
|
||||
|
|
|
@ -5,7 +5,6 @@ import io.ray.api.runtime.RayRuntimeFactory;
|
|||
import io.ray.api.runtimecontext.RuntimeContext;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.Callable;
|
||||
|
||||
/** This class contains all public APIs of Ray. */
|
||||
public final class Ray extends RayCall {
|
||||
|
@ -205,52 +204,6 @@ public final class Ray extends RayCall {
|
|||
return internal().getActor(name, namespace);
|
||||
}
|
||||
|
||||
/**
|
||||
* If users want to use Ray API in their own threads, call this method to get the async context
|
||||
* and then call {@link #setAsyncContext} at the beginning of the new thread.
|
||||
*
|
||||
* @return The async context.
|
||||
*/
|
||||
public static Object getAsyncContext() {
|
||||
return internal().getAsyncContext();
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the async context for the current thread.
|
||||
*
|
||||
* @param asyncContext The async context to set.
|
||||
*/
|
||||
public static void setAsyncContext(Object asyncContext) {
|
||||
internal().setAsyncContext(asyncContext);
|
||||
}
|
||||
|
||||
// TODO (kfstorm): add the `rollbackAsyncContext` API to allow rollbacking the async context of
|
||||
// the current thread to the one before `setAsyncContext` is called.
|
||||
|
||||
// TODO (kfstorm): unify the `wrap*` methods.
|
||||
|
||||
/**
|
||||
* If users want to use Ray API in their own threads, they should wrap their {@link Runnable}
|
||||
* objects with this method.
|
||||
*
|
||||
* @param runnable The runnable to wrap.
|
||||
* @return The wrapped runnable.
|
||||
*/
|
||||
public static Runnable wrapRunnable(Runnable runnable) {
|
||||
return internal().wrapRunnable(runnable);
|
||||
}
|
||||
|
||||
/**
|
||||
* If users want to use Ray API in their own threads, they should wrap their {@link Callable}
|
||||
* objects with this method.
|
||||
*
|
||||
* @param callable The callable to wrap.
|
||||
* @return The wrapped callable.
|
||||
*/
|
||||
public static <T> Callable<T> wrapCallable(Callable<T> callable) {
|
||||
return internal().wrapCallable(callable);
|
||||
}
|
||||
|
||||
/** Get the underlying runtime instance. */
|
||||
public static RayRuntime internal() {
|
||||
if (runtime == null) {
|
||||
|
|
|
@ -24,7 +24,6 @@ import io.ray.api.runtimeenv.RuntimeEnv;
|
|||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.Callable;
|
||||
|
||||
/** Base interface of a Ray runtime. */
|
||||
public interface RayRuntime {
|
||||
|
@ -207,26 +206,6 @@ public interface RayRuntime {
|
|||
|
||||
RuntimeContext getRuntimeContext();
|
||||
|
||||
Object getAsyncContext();
|
||||
|
||||
void setAsyncContext(Object asyncContext);
|
||||
|
||||
/**
|
||||
* Wrap a {@link Runnable} with necessary context capture.
|
||||
*
|
||||
* @param runnable The runnable to wrap.
|
||||
* @return The wrapped runnable.
|
||||
*/
|
||||
Runnable wrapRunnable(Runnable runnable);
|
||||
|
||||
/**
|
||||
* Wrap a {@link Callable} with necessary context capture.
|
||||
*
|
||||
* @param callable The callable to wrap.
|
||||
* @return The wrapped callable.
|
||||
*/
|
||||
<T> Callable<T> wrapCallable(Callable<T> callable);
|
||||
|
||||
/** Intentionally exit the current actor. */
|
||||
void exitActor();
|
||||
|
||||
|
|
|
@ -23,10 +23,7 @@ public class ActorPerformanceTestBase {
|
|||
boolean hasReturn,
|
||||
boolean ignoreReturn,
|
||||
int argSize,
|
||||
boolean useDirectByteBuffer,
|
||||
int numJavaWorkerPerProcess) {
|
||||
System.setProperty(
|
||||
"ray.job.num-java-workers-per-process", String.valueOf(numJavaWorkerPerProcess));
|
||||
boolean useDirectByteBuffer) {
|
||||
System.setProperty("ray.raylet.startup-token", "0");
|
||||
Ray.init();
|
||||
try {
|
||||
|
|
|
@ -13,7 +13,6 @@ public class ActorPerformanceTestCase1 {
|
|||
final int argSize = 0;
|
||||
final boolean useDirectByteBuffer = false;
|
||||
final boolean ignoreReturn = false;
|
||||
final int numJavaWorkerPerProcess = 1;
|
||||
ActorPerformanceTestBase.run(
|
||||
args,
|
||||
layers,
|
||||
|
@ -21,7 +20,6 @@ public class ActorPerformanceTestCase1 {
|
|||
hasReturn,
|
||||
ignoreReturn,
|
||||
argSize,
|
||||
useDirectByteBuffer,
|
||||
numJavaWorkerPerProcess);
|
||||
useDirectByteBuffer);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -31,7 +31,6 @@ import io.ray.runtime.functionmanager.FunctionDescriptor;
|
|||
import io.ray.runtime.functionmanager.FunctionManager;
|
||||
import io.ray.runtime.functionmanager.PyFunctionDescriptor;
|
||||
import io.ray.runtime.functionmanager.RayFunction;
|
||||
import io.ray.runtime.generated.Common;
|
||||
import io.ray.runtime.generated.Common.Language;
|
||||
import io.ray.runtime.object.ObjectRefImpl;
|
||||
import io.ray.runtime.object.ObjectStore;
|
||||
|
@ -46,7 +45,6 @@ import java.util.ArrayList;
|
|||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.Callable;
|
||||
import java.util.stream.Collectors;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
@ -67,13 +65,8 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
|
|||
|
||||
private static ParallelActorContextImpl parallelActorContextImpl = new ParallelActorContextImpl();
|
||||
|
||||
/** Whether the required thread context is set on the current thread. */
|
||||
final ThreadLocal<Boolean> isContextSet = ThreadLocal.withInitial(() -> false);
|
||||
|
||||
public AbstractRayRuntime(RayConfig rayConfig) {
|
||||
this.rayConfig = rayConfig;
|
||||
setIsContextSet(rayConfig.workerMode == Common.WorkerType.DRIVER);
|
||||
functionManager = new FunctionManager(rayConfig.codeSearchPath);
|
||||
runtimeContext = new RuntimeContextImpl(this);
|
||||
}
|
||||
|
||||
|
@ -158,7 +151,7 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
|
|||
|
||||
@Override
|
||||
public ObjectRef call(RayFunc func, Object[] args, CallOptions options) {
|
||||
RayFunction rayFunction = functionManager.getFunction(workerContext.getCurrentJobId(), func);
|
||||
RayFunction rayFunction = functionManager.getFunction(func);
|
||||
FunctionDescriptor functionDescriptor = rayFunction.functionDescriptor;
|
||||
Optional<Class<?>> returnType = rayFunction.getReturnType();
|
||||
return callNormalFunction(functionDescriptor, args, returnType, options);
|
||||
|
@ -176,7 +169,7 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
|
|||
@Override
|
||||
public ObjectRef callActor(
|
||||
ActorHandle<?> actor, RayFunc func, Object[] args, CallOptions options) {
|
||||
RayFunction rayFunction = functionManager.getFunction(workerContext.getCurrentJobId(), func);
|
||||
RayFunction rayFunction = functionManager.getFunction(func);
|
||||
FunctionDescriptor functionDescriptor = rayFunction.functionDescriptor;
|
||||
Optional<Class<?>> returnType = rayFunction.getReturnType();
|
||||
return callActorFunction(actor, functionDescriptor, args, returnType, options);
|
||||
|
@ -201,8 +194,7 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
|
|||
public <T> ActorHandle<T> createActor(
|
||||
RayFunc actorFactoryFunc, Object[] args, ActorCreationOptions options) {
|
||||
FunctionDescriptor functionDescriptor =
|
||||
functionManager.getFunction(workerContext.getCurrentJobId(), actorFactoryFunc)
|
||||
.functionDescriptor;
|
||||
functionManager.getFunction(actorFactoryFunc).functionDescriptor;
|
||||
return (ActorHandle<T>) createActorImpl(functionDescriptor, args, options);
|
||||
}
|
||||
|
||||
|
@ -256,31 +248,6 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
|
|||
return (T) taskSubmitter.getActor(actorId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setAsyncContext(Object asyncContext) {
|
||||
isContextSet.set(true);
|
||||
}
|
||||
|
||||
@Override
|
||||
public final Runnable wrapRunnable(Runnable runnable) {
|
||||
Object asyncContext = getAsyncContext();
|
||||
return () -> {
|
||||
try (RayAsyncContextUpdater updater = new RayAsyncContextUpdater(asyncContext, this)) {
|
||||
runnable.run();
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public final <T> Callable<T> wrapCallable(Callable<T> callable) {
|
||||
Object asyncContext = getAsyncContext();
|
||||
return () -> {
|
||||
try (RayAsyncContextUpdater updater = new RayAsyncContextUpdater(asyncContext, this)) {
|
||||
return callable.call();
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public ConcurrencyGroup createConcurrencyGroup(
|
||||
String name, int maxConcurrency, List<RayFunc> funcs) {
|
||||
|
@ -387,34 +354,6 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
|
|||
return actor;
|
||||
}
|
||||
|
||||
/// An auto closable class that is used for updating the async context when invoking Ray APIs.
|
||||
private static final class RayAsyncContextUpdater implements AutoCloseable {
|
||||
|
||||
private AbstractRayRuntime runtime;
|
||||
|
||||
private boolean oldIsContextSet;
|
||||
|
||||
private Object oldAsyncContext = null;
|
||||
|
||||
public RayAsyncContextUpdater(Object asyncContext, AbstractRayRuntime runtime) {
|
||||
this.runtime = runtime;
|
||||
oldIsContextSet = runtime.isContextSet.get();
|
||||
if (oldIsContextSet) {
|
||||
oldAsyncContext = runtime.getAsyncContext();
|
||||
}
|
||||
runtime.setAsyncContext(asyncContext);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
if (oldIsContextSet) {
|
||||
runtime.setAsyncContext(oldAsyncContext);
|
||||
} else {
|
||||
runtime.setIsContextSet(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
abstract List<ObjectId> getCurrentReturnIds(int numReturns, ActorId actorId);
|
||||
|
||||
@Override
|
||||
|
@ -446,11 +385,6 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
|
|||
return runtimeContext;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setIsContextSet(boolean isContextSet) {
|
||||
this.isContextSet.set(isContextSet);
|
||||
}
|
||||
|
||||
/// A helper to validate if the prepared return ids is as expected.
|
||||
void validatePreparedReturnIds(List<ObjectId> preparedReturnIds, List<ObjectId> realReturnIds) {
|
||||
if (rayConfig.runMode == RunMode.CLUSTER) {
|
||||
|
|
|
@ -24,9 +24,7 @@ public class ConcurrencyGroupImpl implements ConcurrencyGroup {
|
|||
funcs.forEach(
|
||||
func -> {
|
||||
RayFunction rayFunc =
|
||||
((RayRuntimeInternal) Ray.internal())
|
||||
.getFunctionManager()
|
||||
.getFunction(Ray.getRuntimeContext().getCurrentJobId(), func);
|
||||
((RayRuntimeInternal) Ray.internal()).getFunctionManager().getFunction(func);
|
||||
functionDescriptors.add(rayFunc.getFunctionDescriptor());
|
||||
});
|
||||
}
|
||||
|
|
|
@ -32,10 +32,7 @@ public class DefaultRayRuntimeFactory implements RayRuntimeFactory {
|
|||
rayConfig.runMode == RunMode.SINGLE_PROCESS
|
||||
? new RayDevRuntime(rayConfig)
|
||||
: new RayNativeRuntime(rayConfig);
|
||||
RayRuntimeInternal runtime =
|
||||
rayConfig.numWorkersPerProcess > 1
|
||||
? RayRuntimeProxy.newInstance(innerRuntime)
|
||||
: innerRuntime;
|
||||
RayRuntimeInternal runtime = innerRuntime;
|
||||
runtime.start();
|
||||
return runtime;
|
||||
} catch (Exception e) {
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
package io.ray.runtime;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import io.ray.api.BaseActorHandle;
|
||||
import io.ray.api.id.ActorId;
|
||||
import io.ray.api.id.JobId;
|
||||
|
@ -10,6 +9,7 @@ import io.ray.api.placementgroup.PlacementGroup;
|
|||
import io.ray.api.runtimecontext.ResourceValue;
|
||||
import io.ray.runtime.config.RayConfig;
|
||||
import io.ray.runtime.context.LocalModeWorkerContext;
|
||||
import io.ray.runtime.functionmanager.FunctionManager;
|
||||
import io.ray.runtime.gcs.GcsClient;
|
||||
import io.ray.runtime.generated.Common.TaskSpec;
|
||||
import io.ray.runtime.object.LocalModeObjectStore;
|
||||
|
@ -24,13 +24,9 @@ import java.util.List;
|
|||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
public class RayDevRuntime extends AbstractRayRuntime {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(RayDevRuntime.class);
|
||||
|
||||
private AtomicInteger jobCounter = new AtomicInteger(0);
|
||||
|
||||
public RayDevRuntime(RayConfig rayConfig) {
|
||||
|
@ -49,6 +45,7 @@ public class RayDevRuntime extends AbstractRayRuntime {
|
|||
taskExecutor = new LocalModeTaskExecutor(this);
|
||||
workerContext = new LocalModeWorkerContext(rayConfig.getJobId());
|
||||
objectStore = new LocalModeObjectStore(workerContext);
|
||||
functionManager = new FunctionManager(rayConfig.codeSearchPath);
|
||||
taskSubmitter =
|
||||
new LocalModeTaskSubmitter(this, taskExecutor, (LocalModeObjectStore) objectStore);
|
||||
((LocalModeObjectStore) objectStore)
|
||||
|
@ -90,19 +87,6 @@ public class RayDevRuntime extends AbstractRayRuntime {
|
|||
throw new UnsupportedOperationException("Ray doesn't have gcs client in local mode.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object getAsyncContext() {
|
||||
return new AsyncContext(((LocalModeWorkerContext) workerContext).getCurrentTask());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setAsyncContext(Object asyncContext) {
|
||||
Preconditions.checkNotNull(asyncContext);
|
||||
TaskSpec task = ((AsyncContext) asyncContext).task;
|
||||
((LocalModeWorkerContext) workerContext).setCurrentTask(task);
|
||||
super.setAsyncContext(asyncContext);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, List<ResourceValue>> getAvailableResourceIds() {
|
||||
throw new UnsupportedOperationException("Ray doesn't support get resources ids in local mode.");
|
||||
|
|
|
@ -14,6 +14,7 @@ import io.ray.api.runtimecontext.ResourceValue;
|
|||
import io.ray.runtime.config.RayConfig;
|
||||
import io.ray.runtime.context.NativeWorkerContext;
|
||||
import io.ray.runtime.exception.RayIntentionalSystemExitException;
|
||||
import io.ray.runtime.functionmanager.FunctionManager;
|
||||
import io.ray.runtime.gcs.GcsClient;
|
||||
import io.ray.runtime.gcs.GcsClientOptions;
|
||||
import io.ray.runtime.generated.Common.WorkerType;
|
||||
|
@ -102,14 +103,13 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
|||
if (rayConfig.workerMode == WorkerType.DRIVER && rayConfig.getJobId() == JobId.NIL) {
|
||||
rayConfig.setJobId(getGcsClient().nextJobId());
|
||||
}
|
||||
int numWorkersPerProcess =
|
||||
rayConfig.workerMode == WorkerType.DRIVER ? 1 : rayConfig.numWorkersPerProcess;
|
||||
// Make sure the job id has been set already.
|
||||
functionManager = new FunctionManager(rayConfig.codeSearchPath);
|
||||
|
||||
byte[] serializedJobConfig = null;
|
||||
if (rayConfig.workerMode == WorkerType.DRIVER) {
|
||||
JobConfig.Builder jobConfigBuilder =
|
||||
JobConfig.newBuilder()
|
||||
.setNumJavaWorkersPerProcess(rayConfig.numWorkersPerProcess)
|
||||
.addAllJvmOptions(rayConfig.jvmOptionsForJavaWorker)
|
||||
.addAllCodeSearchPath(rayConfig.codeSearchPath)
|
||||
.setRayNamespace(rayConfig.namespace);
|
||||
|
@ -150,7 +150,6 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
|||
rayConfig.rayletSocketName,
|
||||
(rayConfig.workerMode == WorkerType.DRIVER ? rayConfig.getJobId() : JobId.NIL).getBytes(),
|
||||
new GcsClientOptions(rayConfig),
|
||||
numWorkersPerProcess,
|
||||
rayConfig.logDir,
|
||||
serializedJobConfig,
|
||||
rayConfig.getStartupToken(),
|
||||
|
@ -223,19 +222,6 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
|||
nativeKillActor(actor.getId().getBytes(), noRestart);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object getAsyncContext() {
|
||||
return new AsyncContext(
|
||||
workerContext.getCurrentWorkerId(), workerContext.getCurrentClassLoader());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setAsyncContext(Object asyncContext) {
|
||||
nativeSetCoreWorker(((AsyncContext) asyncContext).workerId.getBytes());
|
||||
workerContext.setCurrentClassLoader(((AsyncContext) asyncContext).currentClassLoader);
|
||||
super.setAsyncContext(asyncContext);
|
||||
}
|
||||
|
||||
@Override
|
||||
List<ObjectId> getCurrentReturnIds(int numReturns, ActorId actorId) {
|
||||
List<byte[]> ret = nativeGetCurrentReturnIds(numReturns, actorId.getBytes());
|
||||
|
@ -289,7 +275,6 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
|||
String rayletSocket,
|
||||
byte[] jobId,
|
||||
GcsClientOptions gcsClientOptions,
|
||||
int numWorkersPerProcess,
|
||||
String logDir,
|
||||
byte[] serializedJobConfig,
|
||||
int startupToken,
|
||||
|
|
|
@ -26,7 +26,5 @@ public interface RayRuntimeInternal extends RayRuntime {
|
|||
|
||||
GcsClient getGcsClient();
|
||||
|
||||
void setIsContextSet(boolean isContextSet);
|
||||
|
||||
void run();
|
||||
}
|
||||
|
|
|
@ -1,80 +0,0 @@
|
|||
package io.ray.runtime;
|
||||
|
||||
import io.ray.api.runtime.RayRuntime;
|
||||
import io.ray.runtime.config.RunMode;
|
||||
import io.ray.runtime.exception.RayException;
|
||||
import java.lang.reflect.InvocationHandler;
|
||||
import java.lang.reflect.InvocationTargetException;
|
||||
import java.lang.reflect.Method;
|
||||
|
||||
/**
|
||||
* Protect a ray runtime with context checks for all methods of {@link RayRuntime} (except {@link
|
||||
* RayRuntime#shutdown(boolean)}).
|
||||
*/
|
||||
public class RayRuntimeProxy implements InvocationHandler {
|
||||
|
||||
/** The original runtime. */
|
||||
private AbstractRayRuntime obj;
|
||||
|
||||
private RayRuntimeProxy(AbstractRayRuntime obj) {
|
||||
this.obj = obj;
|
||||
}
|
||||
|
||||
public AbstractRayRuntime getRuntimeObject() {
|
||||
return obj;
|
||||
}
|
||||
|
||||
/** Generate a new instance of {@link RayRuntimeInternal} with additional context check. */
|
||||
static RayRuntimeInternal newInstance(AbstractRayRuntime obj) {
|
||||
return (RayRuntimeInternal)
|
||||
java.lang.reflect.Proxy.newProxyInstance(
|
||||
obj.getClass().getClassLoader(),
|
||||
new Class<?>[] {RayRuntimeInternal.class},
|
||||
new RayRuntimeProxy(obj));
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
|
||||
if (isInterfaceMethod(method)
|
||||
&& !method.getName().equals("shutdown")
|
||||
&& !method.getName().equals("setAsyncContext")) {
|
||||
checkIsContextSet();
|
||||
}
|
||||
try {
|
||||
return method.invoke(obj, args);
|
||||
} catch (InvocationTargetException e) {
|
||||
if (e.getCause() != null) {
|
||||
throw e.getCause();
|
||||
} else {
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** Whether the method is defined in the {@link RayRuntime} interface. */
|
||||
private boolean isInterfaceMethod(Method method) {
|
||||
try {
|
||||
RayRuntime.class.getMethod(method.getName(), method.getParameterTypes());
|
||||
return true;
|
||||
} catch (NoSuchMethodException e) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if thread context is set.
|
||||
*
|
||||
* <p>This method should be invoked at the beginning of most public methods of {@link RayRuntime},
|
||||
* otherwise the native code might crash due to thread local core worker was not set. We check it
|
||||
* for {@link AbstractRayRuntime} instead of {@link RayNativeRuntime} because we want to catch the
|
||||
* error even if the application runs in {@link RunMode#SINGLE_PROCESS} mode.
|
||||
*/
|
||||
private void checkIsContextSet() {
|
||||
if (!obj.isContextSet.get()) {
|
||||
throw new RayException(
|
||||
"`Ray.wrap***` is not called on the current thread."
|
||||
+ " If you want to use Ray API in your own threads,"
|
||||
+ " please wrap your executable with `Ray.wrap***`.");
|
||||
}
|
||||
}
|
||||
}
|
|
@ -75,8 +75,6 @@ public class RayConfig {
|
|||
|
||||
public final List<String> headArgs;
|
||||
|
||||
public final int numWorkersPerProcess;
|
||||
|
||||
public final String namespace;
|
||||
|
||||
public final List<String> jvmOptionsForJavaWorker;
|
||||
|
@ -190,8 +188,6 @@ public class RayConfig {
|
|||
}
|
||||
codeSearchPath = Arrays.asList(codeSearchPathString.split(":"));
|
||||
|
||||
numWorkersPerProcess = config.getInt("ray.job.num-java-workers-per-process");
|
||||
|
||||
startupToken = config.getInt("ray.raylet.startup-token");
|
||||
|
||||
/// Driver needn't this config item.
|
||||
|
|
|
@ -48,31 +48,21 @@ public class LocalModeWorkerContext implements WorkerContext {
|
|||
@Override
|
||||
public ActorId getCurrentActorId() {
|
||||
TaskSpec taskSpec = currentTask.get();
|
||||
if (taskSpec == null) {
|
||||
return ActorId.NIL;
|
||||
}
|
||||
checkTaskSpecNotNull(taskSpec);
|
||||
return LocalModeTaskSubmitter.getActorId(taskSpec);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ClassLoader getCurrentClassLoader() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setCurrentClassLoader(ClassLoader currentClassLoader) {}
|
||||
|
||||
@Override
|
||||
public TaskType getCurrentTaskType() {
|
||||
TaskSpec taskSpec = currentTask.get();
|
||||
Preconditions.checkNotNull(taskSpec, "Current task is not set.");
|
||||
checkTaskSpecNotNull(taskSpec);
|
||||
return taskSpec.getType();
|
||||
}
|
||||
|
||||
@Override
|
||||
public TaskId getCurrentTaskId() {
|
||||
TaskSpec taskSpec = currentTask.get();
|
||||
Preconditions.checkState(taskSpec != null);
|
||||
checkTaskSpecNotNull(taskSpec);
|
||||
return TaskId.fromBytes(taskSpec.getTaskId().toByteArray());
|
||||
}
|
||||
|
||||
|
@ -85,7 +75,9 @@ public class LocalModeWorkerContext implements WorkerContext {
|
|||
currentTask.set(taskSpec);
|
||||
}
|
||||
|
||||
public TaskSpec getCurrentTask() {
|
||||
return currentTask.get();
|
||||
private static void checkTaskSpecNotNull(TaskSpec taskSpec) {
|
||||
Preconditions.checkNotNull(
|
||||
taskSpec,
|
||||
"Current task is not set. Maybe you invoked this API in a user-created thread not managed by Ray. Invoking this API in a user-created thread is not supported yet in local mode. You can switch to cluster mode.");
|
||||
}
|
||||
}
|
||||
|
|
|
@ -12,7 +12,7 @@ import java.nio.ByteBuffer;
|
|||
/** Worker context for cluster mode. This is a wrapper class for worker context of core worker. */
|
||||
public class NativeWorkerContext implements WorkerContext {
|
||||
|
||||
private final ThreadLocal<ClassLoader> currentClassLoader = new ThreadLocal<>();
|
||||
private ClassLoader currentClassLoader = null;
|
||||
|
||||
@Override
|
||||
public UniqueId getCurrentWorkerId() {
|
||||
|
@ -29,18 +29,6 @@ public class NativeWorkerContext implements WorkerContext {
|
|||
return ActorId.fromByteBuffer(nativeGetCurrentActorId());
|
||||
}
|
||||
|
||||
@Override
|
||||
public ClassLoader getCurrentClassLoader() {
|
||||
return currentClassLoader.get();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setCurrentClassLoader(ClassLoader currentClassLoader) {
|
||||
if (this.currentClassLoader.get() != currentClassLoader) {
|
||||
this.currentClassLoader.set(currentClassLoader);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public TaskType getCurrentTaskType() {
|
||||
return TaskType.forNumber(nativeGetCurrentTaskType());
|
||||
|
|
|
@ -19,15 +19,6 @@ public interface WorkerContext {
|
|||
/** ID of the current actor. */
|
||||
ActorId getCurrentActorId();
|
||||
|
||||
/**
|
||||
* The class loader that is associated with the current job. It's used for locating classes when
|
||||
* dealing with serialization and deserialization in {@link Serializer}.
|
||||
*/
|
||||
ClassLoader getCurrentClassLoader();
|
||||
|
||||
/** Set the current class loader. */
|
||||
void setCurrentClassLoader(ClassLoader currentClassLoader);
|
||||
|
||||
/** Type of the current task. */
|
||||
TaskType getCurrentTaskType();
|
||||
|
||||
|
|
|
@ -2,7 +2,6 @@ package io.ray.runtime.functionmanager;
|
|||
|
||||
import com.google.common.collect.Lists;
|
||||
import io.ray.api.function.RayFunc;
|
||||
import io.ray.api.id.JobId;
|
||||
import io.ray.runtime.util.LambdaUtils;
|
||||
import java.io.File;
|
||||
import java.lang.invoke.SerializedLambda;
|
||||
|
@ -34,7 +33,7 @@ import org.objectweb.asm.Type;
|
|||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/** Manages functions by job id. */
|
||||
/** Manages functions in the current worker. */
|
||||
public class FunctionManager {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(FunctionManager.class);
|
||||
|
@ -50,8 +49,8 @@ public class FunctionManager {
|
|||
private static final ThreadLocal<WeakHashMap<Class<? extends RayFunc>, JavaFunctionDescriptor>>
|
||||
RAY_FUNC_CACHE = ThreadLocal.withInitial(WeakHashMap::new);
|
||||
|
||||
/** Mapping from the job id to the functions that belong to this job. */
|
||||
private ConcurrentMap<JobId, JobFunctionTable> jobFunctionTables = new ConcurrentHashMap<>();
|
||||
/** The table that manages functions. */
|
||||
private final JobFunctionTable jobFunctionTable;
|
||||
|
||||
/** The resource path which we can load the job's jar resources. */
|
||||
private final List<String> codeSearchPath;
|
||||
|
@ -63,16 +62,20 @@ public class FunctionManager {
|
|||
*/
|
||||
public FunctionManager(List<String> codeSearchPath) {
|
||||
this.codeSearchPath = codeSearchPath;
|
||||
jobFunctionTable = createJobFunctionTable();
|
||||
}
|
||||
|
||||
public ClassLoader getClassLoader() {
|
||||
return jobFunctionTable.classLoader;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the RayFunction from a RayFunc instance (a lambda).
|
||||
*
|
||||
* @param jobId current job id.
|
||||
* @param func The lambda.
|
||||
* @return A RayFunction object.
|
||||
*/
|
||||
public RayFunction getFunction(JobId jobId, RayFunc func) {
|
||||
public RayFunction getFunction(RayFunc func) {
|
||||
JavaFunctionDescriptor functionDescriptor = RAY_FUNC_CACHE.get().get(func.getClass());
|
||||
if (functionDescriptor == null) {
|
||||
// It's OK to not lock here, because it's OK to have multiple JavaFunctionDescriptor instances
|
||||
|
@ -84,31 +87,21 @@ public class FunctionManager {
|
|||
functionDescriptor = new JavaFunctionDescriptor(className, methodName, signature);
|
||||
RAY_FUNC_CACHE.get().put(func.getClass(), functionDescriptor);
|
||||
}
|
||||
return getFunction(jobId, functionDescriptor);
|
||||
return getFunction(functionDescriptor);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the RayFunction from a function descriptor.
|
||||
*
|
||||
* @param jobId Current job id.
|
||||
* @param functionDescriptor The function descriptor.
|
||||
* @return A RayFunction object.
|
||||
*/
|
||||
public RayFunction getFunction(JobId jobId, JavaFunctionDescriptor functionDescriptor) {
|
||||
JobFunctionTable jobFunctionTable = jobFunctionTables.get(jobId);
|
||||
if (jobFunctionTable == null) {
|
||||
synchronized (this) {
|
||||
jobFunctionTable = jobFunctionTables.get(jobId);
|
||||
if (jobFunctionTable == null) {
|
||||
jobFunctionTable = createJobFunctionTable(jobId);
|
||||
jobFunctionTables.put(jobId, jobFunctionTable);
|
||||
}
|
||||
}
|
||||
}
|
||||
public RayFunction getFunction(JavaFunctionDescriptor functionDescriptor) {
|
||||
return jobFunctionTable.getFunction(functionDescriptor);
|
||||
}
|
||||
|
||||
private JobFunctionTable createJobFunctionTable(JobId jobId) {
|
||||
/** A helper that creates function table. */
|
||||
private JobFunctionTable createJobFunctionTable() {
|
||||
ClassLoader classLoader;
|
||||
if (codeSearchPath == null || codeSearchPath.isEmpty()) {
|
||||
classLoader = getClass().getClassLoader();
|
||||
|
@ -145,7 +138,7 @@ public class FunctionManager {
|
|||
})
|
||||
.toArray(URL[]::new);
|
||||
classLoader = new URLClassLoader(urls);
|
||||
LOGGER.debug("Resource loaded for job {} from path {}.", jobId, urls);
|
||||
LOGGER.debug("Resource loaded from path {}.", urls);
|
||||
}
|
||||
|
||||
return new JobFunctionTable(classLoader);
|
||||
|
|
|
@ -496,7 +496,6 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
|
|||
? objectStore.getRaw(Collections.singletonList(arg.id), -1).get(0)
|
||||
: arg.value)
|
||||
.collect(Collectors.toList());
|
||||
runtime.setIsContextSet(true);
|
||||
((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(taskSpec);
|
||||
|
||||
((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentWorkerId(workerId);
|
||||
|
@ -513,9 +512,7 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
|
|||
// Set this flag to true is necessary because at the end of `taskExecutor.execute()`,
|
||||
// this flag will be set to false. And `runtime.getWorkerContext()` requires it to be
|
||||
// true.
|
||||
runtime.setIsContextSet(true);
|
||||
((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(null);
|
||||
runtime.setIsContextSet(false);
|
||||
List<ObjectId> returnIds = getReturnIds(taskSpec);
|
||||
for (int i = 0; i < returnIds.size(); i++) {
|
||||
NativeRayObject putObject;
|
||||
|
|
|
@ -32,6 +32,7 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
|
|||
|
||||
protected final RayRuntimeInternal runtime;
|
||||
|
||||
// TODO(qwang): Use actorContext instead later.
|
||||
private final ConcurrentHashMap<UniqueId, T> actorContextMap = new ConcurrentHashMap<>();
|
||||
|
||||
private final ThreadLocal<RayFunction> localRayFunction = new ThreadLocal<>();
|
||||
|
@ -66,7 +67,7 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
|
|||
private RayFunction getRayFunction(List<String> rayFunctionInfo) {
|
||||
JobId jobId = runtime.getWorkerContext().getCurrentJobId();
|
||||
JavaFunctionDescriptor functionDescriptor = parseFunctionDescriptor(rayFunctionInfo);
|
||||
return runtime.getFunctionManager().getFunction(jobId, functionDescriptor);
|
||||
return runtime.getFunctionManager().getFunction(functionDescriptor);
|
||||
}
|
||||
|
||||
/** The return value indicates which parameters are ByteBuffer. */
|
||||
|
@ -93,7 +94,6 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
|
|||
}
|
||||
|
||||
protected List<NativeRayObject> execute(List<String> rayFunctionInfo, List<Object> argsBytes) {
|
||||
runtime.setIsContextSet(true);
|
||||
TaskType taskType = runtime.getWorkerContext().getCurrentTaskType();
|
||||
TaskId taskId = runtime.getWorkerContext().getCurrentTaskId();
|
||||
LOGGER.debug("Executing task {} {}", taskId, rayFunctionInfo);
|
||||
|
@ -108,7 +108,6 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
|
|||
}
|
||||
|
||||
List<NativeRayObject> returnObjects = new ArrayList<>();
|
||||
ClassLoader oldLoader = Thread.currentThread().getContextClassLoader();
|
||||
// Find the executable object.
|
||||
|
||||
RayFunction rayFunction = localRayFunction.get();
|
||||
|
@ -121,7 +120,6 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
|
|||
rayFunction = getRayFunction(rayFunctionInfo);
|
||||
}
|
||||
Thread.currentThread().setContextClassLoader(rayFunction.classLoader);
|
||||
runtime.getWorkerContext().setCurrentClassLoader(rayFunction.classLoader);
|
||||
|
||||
// Get local actor object and arguments.
|
||||
Object actor = null;
|
||||
|
@ -215,10 +213,6 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
|
|||
} else {
|
||||
throw new RayActorException(e);
|
||||
}
|
||||
} finally {
|
||||
Thread.currentThread().setContextClassLoader(oldLoader);
|
||||
runtime.getWorkerContext().setCurrentClassLoader(null);
|
||||
runtime.setIsContextSet(false);
|
||||
}
|
||||
return returnObjects;
|
||||
}
|
||||
|
|
|
@ -47,7 +47,7 @@ public final class MethodUtils {
|
|||
/// This code path indicates that here might be in another thread of a worker.
|
||||
/// So try to load the class from URLClassLoader of this worker.
|
||||
ClassLoader cl =
|
||||
((RayRuntimeInternal) Ray.internal()).getWorkerContext().getCurrentClassLoader();
|
||||
((RayRuntimeInternal) Ray.internal()).getFunctionManager().getClassLoader();
|
||||
actorClz = Class.forName(className, true, cl);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
|
|
|
@ -28,9 +28,7 @@ public class ParallelActorContextImpl implements ParallelActorContext {
|
|||
|
||||
FunctionManager functionManager = ((RayRuntimeInternal) Ray.internal()).getFunctionManager();
|
||||
JavaFunctionDescriptor functionDescriptor =
|
||||
functionManager
|
||||
.getFunction(Ray.getRuntimeContext().getCurrentJobId(), ctorFunc)
|
||||
.getFunctionDescriptor();
|
||||
functionManager.getFunction(ctorFunc).getFunctionDescriptor();
|
||||
ActorHandle<ParallelActorExecutorImpl> parallelExecutorHandle =
|
||||
Ray.actor(ParallelActorExecutorImpl::new, parallelism, functionDescriptor)
|
||||
.setConcurrencyGroups(concurrencyGroups)
|
||||
|
@ -46,9 +44,7 @@ public class ParallelActorContextImpl implements ParallelActorContext {
|
|||
((ParallelActorHandleImpl) parallelActorHandle).getExecutor();
|
||||
FunctionManager functionManager = ((RayRuntimeInternal) Ray.internal()).getFunctionManager();
|
||||
JavaFunctionDescriptor functionDescriptor =
|
||||
functionManager
|
||||
.getFunction(Ray.getRuntimeContext().getCurrentJobId(), func)
|
||||
.getFunctionDescriptor();
|
||||
functionManager.getFunction(func).getFunctionDescriptor();
|
||||
ObjectRef<Object> ret =
|
||||
parallelExecutor
|
||||
.task(ParallelActorExecutorImpl::execute, instanceId, functionDescriptor, args)
|
||||
|
|
|
@ -23,9 +23,7 @@ public class ParallelActorExecutorImpl {
|
|||
throws InvocationTargetException, IllegalAccessException {
|
||||
|
||||
functionManager = ((RayRuntimeInternal) Ray.internal()).getFunctionManager();
|
||||
RayFunction init =
|
||||
functionManager.getFunction(
|
||||
Ray.getRuntimeContext().getCurrentJobId(), javaFunctionDescriptor);
|
||||
RayFunction init = functionManager.getFunction(javaFunctionDescriptor);
|
||||
Thread.currentThread().setContextClassLoader(init.classLoader);
|
||||
for (int i = 0; i < parallelism; ++i) {
|
||||
Object instance = init.getMethod().invoke(null, null);
|
||||
|
@ -35,8 +33,7 @@ public class ParallelActorExecutorImpl {
|
|||
|
||||
public Object execute(int instanceId, JavaFunctionDescriptor functionDescriptor, Object[] args)
|
||||
throws IllegalAccessException, InvocationTargetException {
|
||||
RayFunction func =
|
||||
functionManager.getFunction(Ray.getRuntimeContext().getCurrentJobId(), functionDescriptor);
|
||||
RayFunction func = functionManager.getFunction(functionDescriptor);
|
||||
Preconditions.checkState(instances.containsKey(instanceId));
|
||||
return func.getMethod().invoke(instances.get(instanceId), args);
|
||||
}
|
||||
|
|
|
@ -27,8 +27,6 @@ ray {
|
|||
// search path for user code. This will be used as `CLASSPATH` in Java,
|
||||
// and `PYTHONPATH` in Python.
|
||||
code-search-path: ""
|
||||
/// The number of java worker per worker process.
|
||||
num-java-workers-per-process: 1
|
||||
/// The jvm options for java workers of the job.
|
||||
jvm-options: []
|
||||
|
||||
|
|
|
@ -61,6 +61,8 @@ public class FunctionManagerTest {
|
|||
}
|
||||
}
|
||||
|
||||
private static final JobId JOB_ID = JobId.fromInt(1);
|
||||
|
||||
private static RayFunc0<Object> fooFunc;
|
||||
private static RayFunc1<ChildClass, Object> childClassBarFunc;
|
||||
private static RayFunc0<ChildClass> childClassConstructor;
|
||||
|
@ -95,17 +97,17 @@ public class FunctionManagerTest {
|
|||
public void testGetFunctionFromRayFunc() {
|
||||
final FunctionManager functionManager = new FunctionManager(null);
|
||||
// Test normal function.
|
||||
RayFunction func = functionManager.getFunction(JobId.NIL, fooFunc);
|
||||
RayFunction func = functionManager.getFunction(fooFunc);
|
||||
Assert.assertFalse(func.isConstructor());
|
||||
Assert.assertEquals(func.getFunctionDescriptor(), fooDescriptor);
|
||||
|
||||
// Test actor method
|
||||
func = functionManager.getFunction(JobId.NIL, childClassBarFunc);
|
||||
func = functionManager.getFunction(childClassBarFunc);
|
||||
Assert.assertFalse(func.isConstructor());
|
||||
Assert.assertEquals(func.getFunctionDescriptor(), childClassBarDescriptor);
|
||||
|
||||
// Test actor constructor
|
||||
func = functionManager.getFunction(JobId.NIL, childClassConstructor);
|
||||
func = functionManager.getFunction(childClassConstructor);
|
||||
Assert.assertTrue(func.isConstructor());
|
||||
Assert.assertEquals(func.getFunctionDescriptor(), childClassConstructorDescriptor);
|
||||
}
|
||||
|
@ -114,17 +116,17 @@ public class FunctionManagerTest {
|
|||
public void testGetFunctionFromFunctionDescriptor() {
|
||||
final FunctionManager functionManager = new FunctionManager(null);
|
||||
// Test normal function.
|
||||
RayFunction func = functionManager.getFunction(JobId.NIL, fooDescriptor);
|
||||
RayFunction func = functionManager.getFunction(fooDescriptor);
|
||||
Assert.assertFalse(func.isConstructor());
|
||||
Assert.assertEquals(func.getFunctionDescriptor(), fooDescriptor);
|
||||
|
||||
// Test actor method
|
||||
func = functionManager.getFunction(JobId.NIL, childClassBarDescriptor);
|
||||
func = functionManager.getFunction(childClassBarDescriptor);
|
||||
Assert.assertFalse(func.isConstructor());
|
||||
Assert.assertEquals(func.getFunctionDescriptor(), childClassBarDescriptor);
|
||||
|
||||
// Test actor constructor
|
||||
func = functionManager.getFunction(JobId.NIL, childClassConstructorDescriptor);
|
||||
func = functionManager.getFunction(childClassConstructorDescriptor);
|
||||
Assert.assertTrue(func.isConstructor());
|
||||
Assert.assertEquals(func.getFunctionDescriptor(), childClassConstructorDescriptor);
|
||||
|
||||
|
@ -133,7 +135,6 @@ public class FunctionManagerTest {
|
|||
RuntimeException.class,
|
||||
() -> {
|
||||
functionManager.getFunction(
|
||||
JobId.NIL,
|
||||
new JavaFunctionDescriptor(
|
||||
FunctionManagerTest.class.getName(), "overloadFunction", ""));
|
||||
});
|
||||
|
@ -146,11 +147,10 @@ public class FunctionManagerTest {
|
|||
fooDescriptor =
|
||||
new JavaFunctionDescriptor(ParentClass.class.getName(), "foo", "()Ljava/lang/Object;");
|
||||
Assert.assertEquals(
|
||||
functionManager.getFunction(JobId.NIL, fooDescriptor).executable.getDeclaringClass(),
|
||||
functionManager.getFunction(fooDescriptor).executable.getDeclaringClass(),
|
||||
ParentClass.class);
|
||||
RayFunction fooFunc =
|
||||
functionManager.getFunction(
|
||||
JobId.NIL,
|
||||
new JavaFunctionDescriptor(ChildClass.class.getName(), "foo", "()Ljava/lang/Object;"));
|
||||
Assert.assertEquals(fooFunc.executable.getDeclaringClass(), ParentClass.class);
|
||||
|
||||
|
@ -159,21 +159,16 @@ public class FunctionManagerTest {
|
|||
childClassBarDescriptor =
|
||||
new JavaFunctionDescriptor(ParentClass.class.getName(), "bar", "()Ljava/lang/Object;");
|
||||
Assert.assertEquals(
|
||||
functionManager
|
||||
.getFunction(JobId.NIL, childClassBarDescriptor)
|
||||
.executable
|
||||
.getDeclaringClass(),
|
||||
functionManager.getFunction(childClassBarDescriptor).executable.getDeclaringClass(),
|
||||
ParentClass.class);
|
||||
RayFunction barFunc =
|
||||
functionManager.getFunction(
|
||||
JobId.NIL,
|
||||
new JavaFunctionDescriptor(ChildClass.class.getName(), "bar", "()Ljava/lang/Object;"));
|
||||
Assert.assertEquals(barFunc.executable.getDeclaringClass(), ChildClass.class);
|
||||
|
||||
// Check interface default methods.
|
||||
RayFunction interfaceNameFunc =
|
||||
functionManager.getFunction(
|
||||
JobId.NIL,
|
||||
new JavaFunctionDescriptor(
|
||||
ChildClass.class.getName(), "interfaceName", "()Ljava/lang/String;"));
|
||||
Assert.assertEquals(
|
||||
|
@ -220,7 +215,6 @@ public class FunctionManagerTest {
|
|||
|
||||
@Test
|
||||
public void testGetFunctionFromLocalResource() throws Exception {
|
||||
JobId jobId = JobId.fromInt(1);
|
||||
final String codeSearchPath = FileUtils.getTempDirectoryPath() + "/ray_test_resources/";
|
||||
File jobResourceDir = new File(codeSearchPath);
|
||||
FileUtils.deleteQuietly(jobResourceDir);
|
||||
|
@ -250,7 +244,7 @@ public class FunctionManagerTest {
|
|||
new JavaFunctionDescriptor("DemoApp", "hello", "()Ljava/lang/String;");
|
||||
final FunctionManager functionManager =
|
||||
new FunctionManager(Collections.singletonList(codeSearchPath));
|
||||
RayFunction func = functionManager.getFunction(jobId, descriptor);
|
||||
RayFunction func = functionManager.getFunction(descriptor);
|
||||
Assert.assertEquals(func.getFunctionDescriptor(), descriptor);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package io.ray.serve;
|
||||
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
import io.ray.api.Ray;
|
||||
import io.ray.runtime.metric.Count;
|
||||
import io.ray.runtime.metric.Metrics;
|
||||
import io.ray.runtime.metric.TagKey;
|
||||
|
@ -47,8 +46,6 @@ public class HttpProxy implements ServeProxy {
|
|||
|
||||
private ProxyRouter proxyRouter;
|
||||
|
||||
private Object asyncContext = Ray.getAsyncContext();
|
||||
|
||||
@Override
|
||||
public void init(Map<String, String> config, ProxyRouter proxyRouter) {
|
||||
this.port =
|
||||
|
@ -101,8 +98,6 @@ public class HttpProxy implements ServeProxy {
|
|||
ClassicHttpRequest request, ClassicHttpResponse response, HttpContext context)
|
||||
throws HttpException, IOException {
|
||||
|
||||
Ray.setAsyncContext(asyncContext);
|
||||
|
||||
int code = HttpURLConnection.HTTP_OK;
|
||||
Object result = null;
|
||||
String route = request.getPath();
|
||||
|
|
|
@ -122,8 +122,7 @@ for file in "$docdemo_path"*.java; do
|
|||
file=${file#"$docdemo_path"}
|
||||
class=${file%".java"}
|
||||
echo "Running $class"
|
||||
java -cp bazel-bin/java/all_tests_shaded.jar -Dray.job.num-java-workers-per-process=1\
|
||||
-Dray.raylet.startup-token=0 "io.ray.docdemo.$class"
|
||||
java -cp bazel-bin/java/all_tests_shaded.jar -Dray.raylet.startup-token=0 "io.ray.docdemo.$class"
|
||||
done
|
||||
popd
|
||||
|
||||
|
|
|
@ -4,13 +4,11 @@ import io.ray.api.ActorHandle;
|
|||
import io.ray.api.ObjectRef;
|
||||
import io.ray.api.Ray;
|
||||
import io.ray.runtime.actor.NativeActorHandle;
|
||||
import io.ray.runtime.exception.RayActorException;
|
||||
import io.ray.runtime.util.SystemUtil;
|
||||
import java.lang.ref.Reference;
|
||||
import java.lang.reflect.Field;
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.testng.Assert;
|
||||
|
@ -65,7 +63,6 @@ public class ActorHandleReferenceCountTest {
|
|||
|
||||
public void testActorHandleReferenceCount() {
|
||||
try {
|
||||
System.setProperty("ray.job.num-java-workers-per-process", "1");
|
||||
Ray.init();
|
||||
ActorHandle<SignalActor> signal = Ray.actor(SignalActor::new).remote();
|
||||
ActorHandle<MyActor> myActor = Ray.actor(MyActor::new).remote();
|
||||
|
@ -83,27 +80,4 @@ public class ActorHandleReferenceCountTest {
|
|||
Ray.shutdown();
|
||||
}
|
||||
}
|
||||
|
||||
public void testRemoveActorHandleReferenceInMultipleThreadedActor() throws InterruptedException {
|
||||
System.setProperty("ray.job.num-java-workers-per-process", "5");
|
||||
try {
|
||||
Ray.init();
|
||||
ActorHandle<MyActor> myActor1 = Ray.actor(MyActor::new).remote();
|
||||
int pid1 = myActor1.task(MyActor::getPid).remote().get();
|
||||
ActorHandle<MyActor> myActor2 = Ray.actor(MyActor::new).remote();
|
||||
int pid2 = myActor2.task(MyActor::getPid).remote().get();
|
||||
Assert.assertEquals(pid1, pid2);
|
||||
del(myActor1);
|
||||
TimeUnit.SECONDS.sleep(5);
|
||||
Assert.assertThrows(
|
||||
RayActorException.class,
|
||||
() -> {
|
||||
myActor1.task(MyActor::hello).remote().get();
|
||||
});
|
||||
/// myActor2 shouldn't be killed.
|
||||
Assert.assertEquals("hello", myActor2.task(MyActor::hello).remote().get());
|
||||
} finally {
|
||||
Ray.shutdown();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,167 +0,0 @@
|
|||
package io.ray.test;
|
||||
|
||||
import io.ray.api.ActorHandle;
|
||||
import io.ray.api.BaseActorHandle;
|
||||
import io.ray.api.ObjectRef;
|
||||
import io.ray.api.options.ActorCreationOptions;
|
||||
import io.ray.api.options.CallOptions;
|
||||
import io.ray.runtime.AbstractRayRuntime;
|
||||
import io.ray.runtime.functionmanager.FunctionDescriptor;
|
||||
import io.ray.runtime.functionmanager.JavaFunctionDescriptor;
|
||||
import java.io.File;
|
||||
import java.lang.reflect.Method;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Paths;
|
||||
import java.util.Optional;
|
||||
import javax.tools.JavaCompiler;
|
||||
import javax.tools.ToolProvider;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.BeforeClass;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
public class ClassLoaderTest extends BaseTest {
|
||||
|
||||
private final String codeSearchPath =
|
||||
FileUtils.getTempDirectoryPath() + "/ray_test/ClassLoaderTest";
|
||||
|
||||
@BeforeClass
|
||||
public void setUp() {
|
||||
// The potential issue of multiple `ClassLoader` instances for the same job on multi-threading
|
||||
// scenario only occurs if the classes are loaded from the job code search path.
|
||||
System.setProperty("ray.job.code-search-path", codeSearchPath);
|
||||
}
|
||||
|
||||
@Test(groups = {"cluster"})
|
||||
public void testClassLoaderInMultiThreading() throws Exception {
|
||||
File jobResourceDir = new File(codeSearchPath);
|
||||
FileUtils.deleteQuietly(jobResourceDir);
|
||||
jobResourceDir.mkdirs();
|
||||
jobResourceDir.deleteOnExit();
|
||||
|
||||
// In this test case the class is expected to be loaded from the job code search path,
|
||||
// so we need to put the compiled class file into the job code search path and load it
|
||||
// later.
|
||||
String testJavaFile =
|
||||
""
|
||||
+ "import java.lang.management.ManagementFactory;\n"
|
||||
+ "import java.lang.management.RuntimeMXBean;\n"
|
||||
+ "\n"
|
||||
+ "public class ClassLoaderTester {\n"
|
||||
+ "\n"
|
||||
+ " static volatile int value;\n"
|
||||
+ "\n"
|
||||
+ " public int getPid() {\n"
|
||||
+ " RuntimeMXBean runtime = ManagementFactory.getRuntimeMXBean();\n"
|
||||
+ " String name = runtime.getName();\n"
|
||||
+ " int index = name.indexOf(\"@\");\n"
|
||||
+ " if (index != -1) {\n"
|
||||
+ " return Integer.parseInt(name.substring(0, index));\n"
|
||||
+ " } else {\n"
|
||||
+ " throw new RuntimeException(\"parse pid error:\" + name);\n"
|
||||
+ " }\n"
|
||||
+ " }\n"
|
||||
+ "\n"
|
||||
+ " public int increase() throws InterruptedException {\n"
|
||||
+ " return increaseInternal();\n"
|
||||
+ " }\n"
|
||||
+ "\n"
|
||||
+ " public static synchronized int increaseInternal() throws InterruptedException {\n"
|
||||
+ " int oldValue = value;\n"
|
||||
+ " Thread.sleep(10 * 1000);\n"
|
||||
+ " value = oldValue + 1;\n"
|
||||
+ " return value;\n"
|
||||
+ " }\n"
|
||||
+ "\n"
|
||||
+ " public int getClassLoaderHashCode() {\n"
|
||||
+ " return this.getClass().getClassLoader().hashCode();\n"
|
||||
+ " }\n"
|
||||
+ "}";
|
||||
|
||||
// Write the demo java file to the job code search path.
|
||||
String javaFilePath = codeSearchPath + "/ClassLoaderTester.java";
|
||||
Files.write(Paths.get(javaFilePath), testJavaFile.getBytes());
|
||||
|
||||
// Compile the java file.
|
||||
JavaCompiler compiler = ToolProvider.getSystemJavaCompiler();
|
||||
int result = compiler.run(null, null, null, "-d", codeSearchPath, javaFilePath);
|
||||
if (result != 0) {
|
||||
throw new RuntimeException("Couldn't compile ClassLoaderTester.java.");
|
||||
}
|
||||
|
||||
FunctionDescriptor constructor =
|
||||
new JavaFunctionDescriptor("ClassLoaderTester", "<init>", "()V");
|
||||
ActorHandle<?> actor1 = createActor(constructor);
|
||||
FunctionDescriptor getPid = new JavaFunctionDescriptor("ClassLoaderTester", "getPid", "()I");
|
||||
int pid =
|
||||
this.<Integer>callActorFunction(actor1, getPid, new Object[0], Optional.of(Integer.class))
|
||||
.get();
|
||||
ActorHandle<?> actor2;
|
||||
while (true) {
|
||||
// Create another actor which share the same process of actor 1.
|
||||
actor2 = createActor(constructor);
|
||||
int actor2Pid =
|
||||
this.<Integer>callActorFunction(actor2, getPid, new Object[0], Optional.of(Integer.class))
|
||||
.get();
|
||||
if (actor2Pid == pid) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
FunctionDescriptor getClassLoaderHashCode =
|
||||
new JavaFunctionDescriptor("ClassLoaderTester", "getClassLoaderHashCode", "()I");
|
||||
ObjectRef<Integer> hashCode1 =
|
||||
callActorFunction(
|
||||
actor1, getClassLoaderHashCode, new Object[0], Optional.of(Integer.class));
|
||||
ObjectRef<Integer> hashCode2 =
|
||||
callActorFunction(
|
||||
actor2, getClassLoaderHashCode, new Object[0], Optional.of(Integer.class));
|
||||
Assert.assertEquals(hashCode1.get(), hashCode2.get());
|
||||
|
||||
FunctionDescriptor increase =
|
||||
new JavaFunctionDescriptor("ClassLoaderTester", "increase", "()I");
|
||||
ObjectRef<Integer> value1 =
|
||||
callActorFunction(actor1, increase, new Object[0], Optional.of(Integer.class));
|
||||
ObjectRef<Integer> value2 =
|
||||
callActorFunction(actor2, increase, new Object[0], Optional.of(Integer.class));
|
||||
Assert.assertNotEquals(value1.get(), value2.get());
|
||||
}
|
||||
|
||||
private ActorHandle<?> createActor(FunctionDescriptor functionDescriptor) throws Exception {
|
||||
Method createActorMethod =
|
||||
AbstractRayRuntime.class.getDeclaredMethod(
|
||||
"createActorImpl",
|
||||
FunctionDescriptor.class,
|
||||
Object[].class,
|
||||
ActorCreationOptions.class);
|
||||
createActorMethod.setAccessible(true);
|
||||
return (ActorHandle<?>)
|
||||
createActorMethod.invoke(
|
||||
TestUtils.getUnderlyingRuntime(), functionDescriptor, new Object[0], null);
|
||||
}
|
||||
|
||||
private <T> ObjectRef<T> callActorFunction(
|
||||
ActorHandle<?> rayActor,
|
||||
FunctionDescriptor functionDescriptor,
|
||||
Object[] args,
|
||||
Optional<Class<?>> returnType)
|
||||
throws Exception {
|
||||
Method callActorFunctionMethod =
|
||||
AbstractRayRuntime.class.getDeclaredMethod(
|
||||
"callActorFunction",
|
||||
BaseActorHandle.class,
|
||||
FunctionDescriptor.class,
|
||||
Object[].class,
|
||||
Optional.class,
|
||||
CallOptions.class);
|
||||
callActorFunctionMethod.setAccessible(true);
|
||||
return (ObjectRef<T>)
|
||||
callActorFunctionMethod.invoke(
|
||||
TestUtils.getUnderlyingRuntime(),
|
||||
rayActor,
|
||||
functionDescriptor,
|
||||
args,
|
||||
returnType,
|
||||
new CallOptions.Builder().build());
|
||||
}
|
||||
}
|
|
@ -55,7 +55,6 @@ public class DefaultActorLifetimeTest {
|
|||
System.setProperty("ray.job.default-actor-lifetime", defaultActorLifetime.name());
|
||||
}
|
||||
try {
|
||||
System.setProperty("ray.job.num-java-workers-per-process", "1");
|
||||
Ray.init();
|
||||
|
||||
/// 1. create owner and invoke createChildActor.
|
||||
|
|
|
@ -74,38 +74,6 @@ public class ExitActorTest extends BaseTest {
|
|||
Assert.assertThrows(RayActorException.class, obj::get);
|
||||
}
|
||||
|
||||
public void testExitActorInMultiWorker() {
|
||||
Assert.assertTrue(TestUtils.getNumWorkersPerProcess() > 1);
|
||||
ActorHandle<ExitingActor> actor1 =
|
||||
Ray.actor(ExitingActor::new).setMaxRestarts(ActorCreationOptions.INFINITE_RESTART).remote();
|
||||
int pid = actor1.task(ExitingActor::getPid).remote().get();
|
||||
Assert.assertEquals(
|
||||
1, (int) actor1.task(ExitingActor::getSizeOfActorContextMap).remote().get());
|
||||
ActorHandle<ExitingActor> actor2;
|
||||
while (true) {
|
||||
// Create another actor which share the same process of actor 1.
|
||||
actor2 =
|
||||
Ray.actor(ExitingActor::new).setMaxRestarts(ActorCreationOptions.NO_RESTART).remote();
|
||||
int actor2Pid = actor2.task(ExitingActor::getPid).remote().get();
|
||||
if (actor2Pid == pid) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Assert.assertEquals(
|
||||
2, (int) actor1.task(ExitingActor::getSizeOfActorContextMap).remote().get());
|
||||
Assert.assertEquals(
|
||||
2, (int) actor2.task(ExitingActor::getSizeOfActorContextMap).remote().get());
|
||||
ObjectRef<Boolean> obj1 = actor1.task(ExitingActor::exit).remote();
|
||||
Assert.assertThrows(RayActorException.class, obj1::get);
|
||||
Assert.assertTrue(SystemUtil.isProcessAlive(pid));
|
||||
// Actor 2 shouldn't exit or be reconstructed.
|
||||
Assert.assertEquals(1, (int) actor2.task(ExitingActor::incr).remote().get());
|
||||
Assert.assertEquals(
|
||||
1, (int) actor2.task(ExitingActor::getSizeOfActorContextMap).remote().get());
|
||||
Assert.assertEquals(pid, (int) actor2.task(ExitingActor::getPid).remote().get());
|
||||
Assert.assertTrue(SystemUtil.isProcessAlive(pid));
|
||||
}
|
||||
|
||||
public void testExitActorWithDynamicOptions() {
|
||||
ActorHandle<ExitingActor> actor =
|
||||
Ray.actor(ExitingActor::new)
|
||||
|
|
|
@ -15,7 +15,6 @@ public class ExitActorTest2 extends BaseTest {
|
|||
|
||||
@BeforeClass
|
||||
public void setUp() {
|
||||
System.setProperty("ray.job.num-java-workers-per-process", "1");
|
||||
System.setProperty("ray.raylet.startup-token", "0");
|
||||
}
|
||||
|
||||
|
|
|
@ -23,10 +23,6 @@ public class FailureTest extends BaseTest {
|
|||
|
||||
@BeforeClass
|
||||
public void setUp() {
|
||||
// This is needed by `testGetThrowsQuicklyWhenFoundException`.
|
||||
// Set one worker per process. Otherwise, if `badFunc2` and `slowFunc` run in the same
|
||||
// process, `sleep` will delay `System.exit`.
|
||||
System.setProperty("ray.job.num-java-workers-per-process", "1");
|
||||
System.setProperty("ray.raylet.startup-token", "0");
|
||||
}
|
||||
|
||||
|
|
|
@ -11,7 +11,6 @@ public class JobConfigTest extends BaseTest {
|
|||
|
||||
@BeforeClass
|
||||
public void setupJobConfig() {
|
||||
System.setProperty("ray.job.num-java-workers-per-process", "3");
|
||||
System.setProperty("ray.raylet.startup-token", "0");
|
||||
System.setProperty("ray.job.jvm-options.0", "-DX=999");
|
||||
System.setProperty("ray.job.jvm-options.1", "-DY=998");
|
||||
|
@ -33,10 +32,6 @@ public class JobConfigTest extends BaseTest {
|
|||
Assert.assertEquals("998", Ray.task(JobConfigTest::getJvmOptions, "Y").remote().get());
|
||||
}
|
||||
|
||||
public void testNumJavaWorkersPerProcess() {
|
||||
Assert.assertEquals(TestUtils.getNumWorkersPerProcess(), 3);
|
||||
}
|
||||
|
||||
public void testInActor() {
|
||||
ActorHandle<MyActor> actor = Ray.actor(MyActor::new).remote();
|
||||
|
||||
|
|
|
@ -15,7 +15,6 @@ public class KillActorTest extends BaseTest {
|
|||
|
||||
@BeforeClass
|
||||
public void setUp() {
|
||||
System.setProperty("ray.job.num-java-workers-per-process", "1");
|
||||
System.setProperty("ray.raylet.startup-token", "0");
|
||||
}
|
||||
|
||||
|
|
|
@ -51,14 +51,13 @@ public class MultiThreadingTest extends BaseTest {
|
|||
final Object[] result = new Object[1];
|
||||
Thread thread =
|
||||
new Thread(
|
||||
Ray.wrapRunnable(
|
||||
() -> {
|
||||
try {
|
||||
result[0] = Ray.getRuntimeContext().getCurrentActorId();
|
||||
} catch (Exception e) {
|
||||
result[0] = e;
|
||||
}
|
||||
}));
|
||||
() -> {
|
||||
try {
|
||||
result[0] = Ray.getRuntimeContext().getCurrentActorId();
|
||||
} catch (Exception e) {
|
||||
result[0] = e;
|
||||
}
|
||||
});
|
||||
thread.start();
|
||||
thread.join();
|
||||
if (result[0] instanceof Exception) {
|
||||
|
@ -140,6 +139,8 @@ public class MultiThreadingTest extends BaseTest {
|
|||
Assert.assertEquals("ok", obj.get());
|
||||
}
|
||||
|
||||
/// SINGLE_PROCESS mode doesn't support this API.
|
||||
@Test(groups = {"cluster"})
|
||||
public void testGetCurrentActorId() {
|
||||
ActorHandle<ActorIdTester> actorIdTester = Ray.actor(ActorIdTester::new).remote();
|
||||
ActorId actorId = actorIdTester.task(ActorIdTester::getCurrentActorId).remote().get();
|
||||
|
@ -162,104 +163,6 @@ public class MultiThreadingTest extends BaseTest {
|
|||
};
|
||||
}
|
||||
|
||||
static boolean testMissingWrapRunnable() throws InterruptedException {
|
||||
{
|
||||
Runnable[] runnables = generateRunnables();
|
||||
// It's OK to run them in main thread.
|
||||
for (Runnable runnable : runnables) {
|
||||
runnable.run();
|
||||
}
|
||||
}
|
||||
|
||||
Throwable[] throwable = new Throwable[1];
|
||||
|
||||
{
|
||||
Runnable[] runnables = generateRunnables();
|
||||
Thread thread =
|
||||
new Thread(
|
||||
Ray.wrapRunnable(
|
||||
() -> {
|
||||
try {
|
||||
// It would be OK to run them in another thread if wrapped the runnable.
|
||||
for (Runnable runnable : runnables) {
|
||||
runnable.run();
|
||||
}
|
||||
} catch (Throwable ex) {
|
||||
throwable[0] = ex;
|
||||
}
|
||||
}));
|
||||
thread.start();
|
||||
thread.join();
|
||||
if (throwable[0] != null) {
|
||||
throw new RuntimeException("Exception occurred in thread.", throwable[0]);
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
Runnable[] runnables = generateRunnables();
|
||||
Thread thread =
|
||||
new Thread(
|
||||
() -> {
|
||||
try {
|
||||
// It wouldn't be OK to run them in another thread if not wrapped the runnable.
|
||||
for (Runnable runnable : runnables) {
|
||||
Assert.expectThrows(RuntimeException.class, runnable::run);
|
||||
}
|
||||
} catch (Throwable ex) {
|
||||
throwable[0] = ex;
|
||||
}
|
||||
});
|
||||
thread.start();
|
||||
thread.join();
|
||||
if (throwable[0] != null) {
|
||||
throw new RuntimeException("Exception occurred in thread.", throwable[0]);
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
Runnable[] runnables = generateRunnables();
|
||||
Runnable[] wrappedRunnables = new Runnable[runnables.length];
|
||||
for (int i = 0; i < runnables.length; i++) {
|
||||
wrappedRunnables[i] = Ray.wrapRunnable(runnables[i]);
|
||||
}
|
||||
// It would be OK to run the wrapped runnables in the current thread.
|
||||
for (Runnable runnable : wrappedRunnables) {
|
||||
runnable.run();
|
||||
}
|
||||
|
||||
// It would be OK to invoke Ray APIs after executing a wrapped runnable in the current thread.
|
||||
wrappedRunnables[0].run();
|
||||
runnables[0].run();
|
||||
}
|
||||
|
||||
// Return true here to make the caller.remote() returns an ObjectRef.
|
||||
return true;
|
||||
}
|
||||
|
||||
public void testMissingWrapRunnableInWorker() {
|
||||
Ray.task(MultiThreadingTest::testMissingWrapRunnable).remote().get();
|
||||
}
|
||||
|
||||
public void testGetAndSetAsyncContext() throws InterruptedException {
|
||||
Object asyncContext = Ray.getAsyncContext();
|
||||
Throwable[] throwable = new Throwable[1];
|
||||
Thread thread =
|
||||
new Thread(
|
||||
() -> {
|
||||
try {
|
||||
Ray.setAsyncContext(asyncContext);
|
||||
Ray.put(1);
|
||||
} catch (Throwable ex) {
|
||||
throwable[0] = ex;
|
||||
}
|
||||
});
|
||||
thread.start();
|
||||
thread.join();
|
||||
if (throwable[0] != null) {
|
||||
throw new RuntimeException("Exception occurred in thread.", throwable[0]);
|
||||
}
|
||||
}
|
||||
|
||||
private static void runTestCaseInMultipleThreads(Runnable testCase, int numRepeats) {
|
||||
ExecutorService service = Executors.newFixedThreadPool(NUM_THREADS);
|
||||
|
||||
|
@ -267,14 +170,13 @@ public class MultiThreadingTest extends BaseTest {
|
|||
List<Future<String>> futures = new ArrayList<>();
|
||||
for (int i = 0; i < NUM_THREADS; i++) {
|
||||
Callable<String> task =
|
||||
Ray.wrapCallable(
|
||||
() -> {
|
||||
for (int j = 0; j < numRepeats; j++) {
|
||||
TimeUnit.MILLISECONDS.sleep(1);
|
||||
testCase.run();
|
||||
}
|
||||
return "ok";
|
||||
});
|
||||
() -> {
|
||||
for (int j = 0; j < numRepeats; j++) {
|
||||
TimeUnit.MILLISECONDS.sleep(1);
|
||||
testCase.run();
|
||||
}
|
||||
return "ok";
|
||||
};
|
||||
futures.add(service.submit(task));
|
||||
}
|
||||
for (Future<String> future : futures) {
|
||||
|
@ -288,35 +190,4 @@ public class MultiThreadingTest extends BaseTest {
|
|||
service.shutdown();
|
||||
}
|
||||
}
|
||||
|
||||
private static boolean testGetAsyncContextAndSetAsyncContext() throws Exception {
|
||||
final Object asyncContext = Ray.getAsyncContext();
|
||||
final Object[] result = new Object[1];
|
||||
Thread thread =
|
||||
new Thread(
|
||||
() -> {
|
||||
try {
|
||||
Ray.setAsyncContext(asyncContext);
|
||||
Ray.put(0);
|
||||
} catch (Exception e) {
|
||||
result[0] = e;
|
||||
}
|
||||
});
|
||||
thread.start();
|
||||
thread.join();
|
||||
if (result[0] instanceof Exception) {
|
||||
throw (Exception) result[0];
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
public void testGetAsyncContextAndSetAsyncContextInDriver() throws Exception {
|
||||
Assert.assertTrue(testGetAsyncContextAndSetAsyncContext());
|
||||
}
|
||||
|
||||
public void testGetAsyncContextAndSetAsyncContextInWorker() {
|
||||
ObjectRef<Boolean> obj =
|
||||
Ray.task(MultiThreadingTest::testGetAsyncContextAndSetAsyncContext).remote();
|
||||
Assert.assertTrue(obj.get());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,7 +4,6 @@ import com.google.common.collect.ImmutableList;
|
|||
import io.ray.api.ActorHandle;
|
||||
import io.ray.api.Ray;
|
||||
import io.ray.api.runtimeenv.RuntimeEnv;
|
||||
import io.ray.runtime.util.SystemUtil;
|
||||
import java.util.List;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
@ -30,10 +29,6 @@ public class RuntimeEnvTest {
|
|||
return System.getenv(key);
|
||||
}
|
||||
|
||||
public int getPid() {
|
||||
return SystemUtil.pid();
|
||||
}
|
||||
|
||||
public boolean findClass(String className) {
|
||||
try {
|
||||
Class.forName(className);
|
||||
|
@ -45,7 +40,6 @@ public class RuntimeEnvTest {
|
|||
}
|
||||
|
||||
public void testPerJobEnvVars() {
|
||||
System.setProperty("ray.job.num-java-workers-per-process", "1");
|
||||
System.setProperty("ray.job.runtime-env.env-vars.KEY1", "A");
|
||||
System.setProperty("ray.job.runtime-env.env-vars.KEY2", "B");
|
||||
|
||||
|
@ -61,87 +55,6 @@ public class RuntimeEnvTest {
|
|||
}
|
||||
}
|
||||
|
||||
public void testPerActorEnvVars() {
|
||||
/// This is used to test that actors with runtime envs will not reuse worker process.
|
||||
System.setProperty("ray.job.num-java-workers-per-process", "2");
|
||||
try {
|
||||
Ray.init();
|
||||
int pid1 = 0;
|
||||
int pid2 = 0;
|
||||
{
|
||||
RuntimeEnv runtimeEnv =
|
||||
new RuntimeEnv.Builder()
|
||||
.addEnvVar("KEY1", "A")
|
||||
.addEnvVar("KEY2", "B")
|
||||
.addEnvVar("KEY1", "C")
|
||||
.build();
|
||||
|
||||
ActorHandle<A> actor1 = Ray.actor(A::new).setRuntimeEnv(runtimeEnv).remote();
|
||||
String val = actor1.task(A::getEnv, "KEY1").remote().get();
|
||||
Assert.assertEquals(val, "C");
|
||||
val = actor1.task(A::getEnv, "KEY2").remote().get();
|
||||
Assert.assertEquals(val, "B");
|
||||
|
||||
pid1 = actor1.task(A::getPid).remote().get();
|
||||
}
|
||||
|
||||
{
|
||||
/// Because we didn't set them for actor2 , all should be null.
|
||||
ActorHandle<A> actor2 = Ray.actor(A::new).remote();
|
||||
String val = actor2.task(A::getEnv, "KEY1").remote().get();
|
||||
Assert.assertNull(val);
|
||||
val = actor2.task(A::getEnv, "KEY2").remote().get();
|
||||
Assert.assertNull(val);
|
||||
pid2 = actor2.task(A::getPid).remote().get();
|
||||
}
|
||||
|
||||
// actor1 and actor2 shouldn't be in one process because they have
|
||||
// different runtime env.
|
||||
Assert.assertNotEquals(pid1, pid2);
|
||||
} finally {
|
||||
Ray.shutdown();
|
||||
}
|
||||
}
|
||||
|
||||
public void testPerActorEnvVarsOverwritePerJobEnvVars() {
|
||||
System.setProperty("ray.job.num-java-workers-per-process", "2");
|
||||
System.setProperty("ray.job.runtime-env.env-vars.KEY1", "A");
|
||||
System.setProperty("ray.job.runtime-env.env-vars.KEY2", "B");
|
||||
|
||||
int pid1 = 0;
|
||||
int pid2 = 0;
|
||||
try {
|
||||
Ray.init();
|
||||
{
|
||||
RuntimeEnv runtimeEnv = new RuntimeEnv.Builder().addEnvVar("KEY1", "C").build();
|
||||
|
||||
ActorHandle<A> actor1 = Ray.actor(A::new).setRuntimeEnv(runtimeEnv).remote();
|
||||
String val = actor1.task(A::getEnv, "KEY1").remote().get();
|
||||
Assert.assertEquals(val, "C");
|
||||
val = actor1.task(A::getEnv, "KEY2").remote().get();
|
||||
Assert.assertEquals(val, "B");
|
||||
pid1 = actor1.task(A::getPid).remote().get();
|
||||
}
|
||||
|
||||
{
|
||||
/// Because we didn't set them for actor2 explicitly, it should use the per job
|
||||
/// runtime env.
|
||||
ActorHandle<A> actor2 = Ray.actor(A::new).remote();
|
||||
String val = actor2.task(A::getEnv, "KEY1").remote().get();
|
||||
Assert.assertEquals(val, "A");
|
||||
val = actor2.task(A::getEnv, "KEY2").remote().get();
|
||||
Assert.assertEquals(val, "B");
|
||||
pid2 = actor2.task(A::getPid).remote().get();
|
||||
}
|
||||
|
||||
// actor1 and actor2 shouldn't be in one process because they have
|
||||
// different runtime env.
|
||||
Assert.assertNotEquals(pid1, pid2);
|
||||
} finally {
|
||||
Ray.shutdown();
|
||||
}
|
||||
}
|
||||
|
||||
private static String getEnvVar(String key) {
|
||||
return System.getenv(key);
|
||||
}
|
||||
|
|
|
@ -3,9 +3,7 @@ package io.ray.test;
|
|||
import com.google.common.base.Preconditions;
|
||||
import io.ray.api.ObjectRef;
|
||||
import io.ray.api.Ray;
|
||||
import io.ray.runtime.AbstractRayRuntime;
|
||||
import io.ray.runtime.RayRuntimeInternal;
|
||||
import io.ray.runtime.RayRuntimeProxy;
|
||||
import io.ray.runtime.config.RayConfig;
|
||||
import io.ray.runtime.config.RunMode;
|
||||
import io.ray.runtime.task.ArgumentsBuilder;
|
||||
|
@ -129,20 +127,7 @@ public class TestUtils {
|
|||
}
|
||||
|
||||
public static RayRuntimeInternal getUnderlyingRuntime() {
|
||||
if (Ray.internal() instanceof AbstractRayRuntime) {
|
||||
return (RayRuntimeInternal) Ray.internal();
|
||||
}
|
||||
RayRuntimeProxy proxy =
|
||||
(RayRuntimeProxy) (java.lang.reflect.Proxy.getInvocationHandler(Ray.internal()));
|
||||
return proxy.getRuntimeObject();
|
||||
}
|
||||
|
||||
private static int getNumWorkersPerProcessRemoteFunction() {
|
||||
return TestUtils.getRuntime().getRayConfig().numWorkersPerProcess;
|
||||
}
|
||||
|
||||
public static int getNumWorkersPerProcess() {
|
||||
return Ray.task(TestUtils::getNumWorkersPerProcessRemoteFunction).remote().get();
|
||||
return (RayRuntimeInternal) Ray.internal();
|
||||
}
|
||||
|
||||
public static ProcessBuilder buildDriver(Class<?> mainClass, String[] args) {
|
||||
|
|
|
@ -1,6 +0,0 @@
|
|||
ray {
|
||||
job {
|
||||
# Enable multi-worker feature in Java test
|
||||
num-java-workers-per-process: 10
|
||||
}
|
||||
}
|
|
@ -1116,7 +1116,6 @@ cdef class CoreWorker:
|
|||
options.unhandled_exception_handler = unhandled_exception_handler
|
||||
options.get_lang_stack = get_py_stack
|
||||
options.is_local_mode = local_mode
|
||||
options.num_workers = 1
|
||||
options.kill_main = kill_main_task
|
||||
options.terminate_asyncio_thread = terminate_asyncio_thread
|
||||
options.serialized_job_config = serialized_job_config
|
||||
|
|
|
@ -8,8 +8,6 @@ class JobConfig:
|
|||
"""A class used to store the configurations of a job.
|
||||
|
||||
Attributes:
|
||||
num_java_workers_per_process (int): The number of java workers per
|
||||
worker process.
|
||||
jvm_options (str[]): The jvm options for java workers of the job.
|
||||
code_search_path (list): A list of directories or jar files that
|
||||
specify the search path for user code. This will be used as
|
||||
|
@ -22,7 +20,6 @@ class JobConfig:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
num_java_workers_per_process=1,
|
||||
jvm_options=None,
|
||||
code_search_path=None,
|
||||
runtime_env=None,
|
||||
|
@ -31,7 +28,6 @@ class JobConfig:
|
|||
ray_namespace=None,
|
||||
default_actor_lifetime="non_detached",
|
||||
):
|
||||
self.num_java_workers_per_process = num_java_workers_per_process
|
||||
self.jvm_options = jvm_options or []
|
||||
self.code_search_path = code_search_path or []
|
||||
# It's difficult to find the error that caused by the
|
||||
|
@ -108,7 +104,6 @@ class JobConfig:
|
|||
pb.ray_namespace = str(uuid.uuid4())
|
||||
else:
|
||||
pb.ray_namespace = self.ray_namespace
|
||||
pb.num_java_workers_per_process = self.num_java_workers_per_process
|
||||
pb.jvm_options.extend(self.jvm_options)
|
||||
pb.code_search_path.extend(self.code_search_path)
|
||||
for k, v in self.metadata.items():
|
||||
|
@ -147,9 +142,6 @@ class JobConfig:
|
|||
Generates a JobConfig object from json.
|
||||
"""
|
||||
return cls(
|
||||
num_java_workers_per_process=job_config_json.get(
|
||||
"num_java_workers_per_process", 1
|
||||
),
|
||||
jvm_options=job_config_json.get("jvm_options", None),
|
||||
code_search_path=job_config_json.get("code_search_path", None),
|
||||
runtime_env=job_config_json.get("runtime_env", None),
|
||||
|
|
|
@ -545,19 +545,16 @@ def test_k8s_cpu(use_cgroups_v2: bool):
|
|||
|
||||
|
||||
def test_sync_job_config(shutdown_only):
|
||||
num_java_workers_per_process = 8
|
||||
runtime_env = {"env_vars": {"key": "value"}}
|
||||
|
||||
ray.init(
|
||||
job_config=ray.job_config.JobConfig(
|
||||
num_java_workers_per_process=num_java_workers_per_process,
|
||||
runtime_env=runtime_env,
|
||||
)
|
||||
)
|
||||
|
||||
# Check that the job config is synchronized at the driver side.
|
||||
job_config = ray.worker.global_worker.core_worker.get_job_config()
|
||||
assert job_config.num_java_workers_per_process == num_java_workers_per_process
|
||||
job_runtime_env = RuntimeEnv.deserialize(
|
||||
job_config.runtime_env_info.serialized_runtime_env
|
||||
)
|
||||
|
@ -571,7 +568,6 @@ def test_sync_job_config(shutdown_only):
|
|||
# Check that the job config is synchronized at the worker side.
|
||||
job_config = gcs_utils.JobConfig()
|
||||
job_config.ParseFromString(ray.get(get_job_config.remote()))
|
||||
assert job_config.num_java_workers_per_process == num_java_workers_per_process
|
||||
job_runtime_env = RuntimeEnv.deserialize(
|
||||
job_config.runtime_env_info.serialized_runtime_env
|
||||
)
|
||||
|
|
|
@ -3084,15 +3084,6 @@ void CoreWorker::HandleKillActor(const rpc::KillActorRequest &request,
|
|||
if (request.no_restart()) {
|
||||
Disconnect();
|
||||
}
|
||||
if (options_.num_workers > 1) {
|
||||
// TODO (kfstorm): Should we add some kind of check before sending the killing
|
||||
// request?
|
||||
RAY_LOG(ERROR)
|
||||
<< "Killing an actor which is running in a worker process with multiple "
|
||||
"workers will also kill other actors in this process. To avoid this, "
|
||||
"please create the Java actor with some dynamic options to make it being "
|
||||
"hosted in a dedicated worker process.";
|
||||
}
|
||||
// NOTE(hchen): Use `QuickExit()` to force-exit this process without doing cleanup.
|
||||
// `exit()` will destruct static objects in an incorrect order, which will lead to
|
||||
// core dumps.
|
||||
|
|
|
@ -76,7 +76,6 @@ struct CoreWorkerOptions {
|
|||
get_lang_stack(nullptr),
|
||||
kill_main(nullptr),
|
||||
is_local_mode(false),
|
||||
num_workers(0),
|
||||
terminate_asyncio_thread(nullptr),
|
||||
serialized_job_config(""),
|
||||
metrics_agent_port(-1),
|
||||
|
@ -148,8 +147,6 @@ struct CoreWorkerOptions {
|
|||
std::function<bool()> kill_main;
|
||||
/// Is local mode being used.
|
||||
bool is_local_mode;
|
||||
/// The number of workers to be started in the current process.
|
||||
int num_workers;
|
||||
/// The function to destroy asyncio event and loops.
|
||||
std::function<void()> terminate_asyncio_thread;
|
||||
/// Serialized representation of JobConfig.
|
||||
|
|
|
@ -82,10 +82,9 @@ thread_local std::weak_ptr<CoreWorker> CoreWorkerProcessImpl::thread_local_core_
|
|||
|
||||
CoreWorkerProcessImpl::CoreWorkerProcessImpl(const CoreWorkerOptions &options)
|
||||
: options_(options),
|
||||
global_worker_id_(
|
||||
options.worker_type == WorkerType::DRIVER
|
||||
? ComputeDriverIdFromJob(options_.job_id)
|
||||
: (options_.num_workers == 1 ? WorkerID::FromRandom() : WorkerID::Nil())) {
|
||||
global_worker_id_(options.worker_type == WorkerType::DRIVER
|
||||
? ComputeDriverIdFromJob(options_.job_id)
|
||||
: WorkerID::FromRandom()) {
|
||||
if (options_.enable_logging) {
|
||||
std::stringstream app_name;
|
||||
app_name << LanguageString(options_.language) << "-core-"
|
||||
|
@ -111,12 +110,6 @@ CoreWorkerProcessImpl::CoreWorkerProcessImpl(const CoreWorkerOptions &options)
|
|||
<< "install_failure_signal_handler must be false because ray log is disabled.";
|
||||
}
|
||||
|
||||
RAY_CHECK(options_.num_workers > 0);
|
||||
if (options_.worker_type == WorkerType::DRIVER) {
|
||||
// Driver process can only contain one worker.
|
||||
RAY_CHECK(options_.num_workers == 1);
|
||||
}
|
||||
|
||||
RAY_LOG(INFO) << "Constructing CoreWorkerProcess. pid: " << getpid();
|
||||
|
||||
// NOTE(kfstorm): any initialization depending on RayConfig must happen after this line.
|
||||
|
@ -250,18 +243,14 @@ bool CoreWorkerProcessImpl::ShouldCreateGlobalWorkerOnConstruction() const {
|
|||
// worker APIs before `CoreWorkerProcess::RunTaskExecutionLoop` is called. So we need
|
||||
// to create the worker instance here. One example of invocations is
|
||||
// https://github.com/ray-project/ray/blob/45ce40e5d44801193220d2c546be8de0feeef988/python/ray/worker.py#L1281.
|
||||
return options_.num_workers == 1 && (options_.worker_type == WorkerType::DRIVER ||
|
||||
options_.language == Language::PYTHON);
|
||||
return (options_.worker_type == WorkerType::DRIVER ||
|
||||
options_.language == Language::PYTHON);
|
||||
}
|
||||
|
||||
std::shared_ptr<CoreWorker> CoreWorkerProcessImpl::GetWorker(
|
||||
const WorkerID &worker_id) const {
|
||||
absl::ReaderMutexLock lock(&mutex_);
|
||||
auto it = workers_.find(worker_id);
|
||||
if (it != workers_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
return nullptr;
|
||||
return global_worker_;
|
||||
}
|
||||
|
||||
std::shared_ptr<CoreWorker> CoreWorkerProcessImpl::GetGlobalWorker() {
|
||||
|
@ -275,13 +264,9 @@ std::shared_ptr<CoreWorker> CoreWorkerProcessImpl::CreateWorker() {
|
|||
global_worker_id_ != WorkerID::Nil() ? global_worker_id_ : WorkerID::FromRandom());
|
||||
RAY_LOG(DEBUG) << "Worker " << worker->GetWorkerID() << " is created.";
|
||||
absl::WriterMutexLock lock(&mutex_);
|
||||
if (options_.num_workers == 1) {
|
||||
global_worker_ = worker;
|
||||
}
|
||||
global_worker_ = worker;
|
||||
thread_local_core_worker_ = worker;
|
||||
|
||||
workers_.emplace(worker->GetWorkerID(), worker);
|
||||
RAY_CHECK(workers_.size() <= static_cast<size_t>(options_.num_workers));
|
||||
return worker;
|
||||
}
|
||||
|
||||
|
@ -293,10 +278,7 @@ void CoreWorkerProcessImpl::RemoveWorker(std::shared_ptr<CoreWorker> worker) {
|
|||
RAY_CHECK(thread_local_core_worker_.lock() == worker);
|
||||
}
|
||||
thread_local_core_worker_.reset();
|
||||
{
|
||||
workers_.erase(worker->GetWorkerID());
|
||||
RAY_LOG(INFO) << "Removed worker " << worker->GetWorkerID();
|
||||
}
|
||||
RAY_LOG(INFO) << "Removed worker " << worker->GetWorkerID();
|
||||
if (global_worker_ == worker) {
|
||||
global_worker_ = nullptr;
|
||||
}
|
||||
|
@ -304,31 +286,14 @@ void CoreWorkerProcessImpl::RemoveWorker(std::shared_ptr<CoreWorker> worker) {
|
|||
|
||||
void CoreWorkerProcessImpl::RunWorkerTaskExecutionLoop() {
|
||||
RAY_CHECK(options_.worker_type == WorkerType::WORKER);
|
||||
if (options_.num_workers == 1) {
|
||||
// Run the task loop in the current thread only if the number of workers is 1.
|
||||
auto worker = GetGlobalWorker();
|
||||
if (!worker) {
|
||||
worker = CreateWorker();
|
||||
}
|
||||
worker->RunTaskExecutionLoop();
|
||||
RAY_LOG(INFO) << "Task execution loop terminated. Removing the global worker.";
|
||||
RemoveWorker(worker);
|
||||
} else {
|
||||
std::vector<std::thread> worker_threads;
|
||||
for (int i = 0; i < options_.num_workers; i++) {
|
||||
worker_threads.emplace_back([this, i] {
|
||||
SetThreadName("worker.task" + std::to_string(i));
|
||||
auto worker = CreateWorker();
|
||||
worker->RunTaskExecutionLoop();
|
||||
RAY_LOG(INFO) << "Task execution loop terminated for a thread "
|
||||
<< std::to_string(i) << ". Removing a worker.";
|
||||
RemoveWorker(worker);
|
||||
});
|
||||
}
|
||||
for (auto &thread : worker_threads) {
|
||||
thread.join();
|
||||
}
|
||||
// Run the task loop in the current thread only if the number of workers is 1.
|
||||
auto worker = GetGlobalWorker();
|
||||
if (!worker) {
|
||||
worker = CreateWorker();
|
||||
}
|
||||
worker->RunTaskExecutionLoop();
|
||||
RAY_LOG(INFO) << "Task execution loop terminated. Removing the global worker.";
|
||||
RemoveWorker(worker);
|
||||
}
|
||||
|
||||
void CoreWorkerProcessImpl::ShutdownDriver() {
|
||||
|
@ -342,42 +307,30 @@ void CoreWorkerProcessImpl::ShutdownDriver() {
|
|||
}
|
||||
|
||||
CoreWorker &CoreWorkerProcessImpl::GetCoreWorkerForCurrentThread() {
|
||||
if (options_.num_workers == 1) {
|
||||
auto global_worker = GetGlobalWorker();
|
||||
if (ShouldCreateGlobalWorkerOnConstruction() && !global_worker) {
|
||||
// This could only happen when the worker has already been shutdown.
|
||||
// In this case, we should exit without crashing.
|
||||
// TODO (scv119): A better solution could be returning error code
|
||||
// and handling it at language frontend.
|
||||
if (options_.worker_type == WorkerType::DRIVER) {
|
||||
RAY_LOG(ERROR)
|
||||
<< "The global worker has already been shutdown. This happens when "
|
||||
"the language frontend accesses the Ray's worker after it is "
|
||||
"shutdown. The process will exit";
|
||||
} else {
|
||||
RAY_LOG(INFO) << "The global worker has already been shutdown. This happens when "
|
||||
"the language frontend accesses the Ray's worker after it is "
|
||||
"shutdown. The process will exit";
|
||||
}
|
||||
QuickExit();
|
||||
auto global_worker = GetGlobalWorker();
|
||||
if (ShouldCreateGlobalWorkerOnConstruction() && !global_worker) {
|
||||
// This could only happen when the worker has already been shutdown.
|
||||
// In this case, we should exit without crashing.
|
||||
// TODO (scv119): A better solution could be returning error code
|
||||
// and handling it at language frontend.
|
||||
if (options_.worker_type == WorkerType::DRIVER) {
|
||||
RAY_LOG(ERROR) << "The global worker has already been shutdown. This happens when "
|
||||
"the language frontend accesses the Ray's worker after it is "
|
||||
"shutdown. The process will exit";
|
||||
} else {
|
||||
RAY_LOG(INFO) << "The global worker has already been shutdown. This happens when "
|
||||
"the language frontend accesses the Ray's worker after it is "
|
||||
"shutdown. The process will exit";
|
||||
}
|
||||
RAY_CHECK(global_worker) << "global_worker_ must not be NULL";
|
||||
return *global_worker;
|
||||
QuickExit();
|
||||
}
|
||||
auto ptr = thread_local_core_worker_.lock();
|
||||
RAY_CHECK(ptr != nullptr)
|
||||
<< "The current thread is not bound with a core worker instance.";
|
||||
return *ptr;
|
||||
RAY_CHECK(global_worker) << "global_worker_ must not be NULL";
|
||||
return *global_worker;
|
||||
}
|
||||
|
||||
void CoreWorkerProcessImpl::SetThreadLocalWorkerById(const WorkerID &worker_id) {
|
||||
if (options_.num_workers == 1) {
|
||||
RAY_CHECK(GetGlobalWorker()->GetWorkerID() == worker_id);
|
||||
return;
|
||||
}
|
||||
auto worker = GetWorker(worker_id);
|
||||
RAY_CHECK(worker) << "Worker " << worker_id << " not found.";
|
||||
thread_local_core_worker_ = GetWorker(worker_id);
|
||||
RAY_CHECK(GetGlobalWorker()->GetWorkerID() == worker_id);
|
||||
return;
|
||||
}
|
||||
|
||||
} // namespace core
|
||||
|
|
|
@ -25,7 +25,6 @@ class CoreWorker;
|
|||
/// CoreWorkerOptions options = {
|
||||
/// WorkerType::DRIVER, // worker_type
|
||||
/// ..., // other arguments
|
||||
/// 1, // num_workers
|
||||
/// };
|
||||
/// CoreWorkerProcess::Initialize(options);
|
||||
///
|
||||
|
@ -36,7 +35,6 @@ class CoreWorker;
|
|||
/// CoreWorkerOptions options = {
|
||||
/// WorkerType::WORKER, // worker_type
|
||||
/// ..., // other arguments
|
||||
/// num_workers, // num_workers
|
||||
/// };
|
||||
/// CoreWorkerProcess::Initialize(options);
|
||||
/// ... // Do other stuff
|
||||
|
@ -52,14 +50,7 @@ class CoreWorker;
|
|||
/// threads, remember to call `CoreWorkerProcess::SetCurrentThreadWorkerId(worker_id)`
|
||||
/// once in the new thread before calling core worker APIs, to associate the current
|
||||
/// thread with a worker. You can obtain the worker ID via
|
||||
/// `CoreWorkerProcess::GetCoreWorker()->GetWorkerID()`. Currently a Java worker process
|
||||
/// starts multiple workers by default, but can be configured to start only 1 worker by
|
||||
/// speicifying `num_java_workers_per_process` in the job config.
|
||||
///
|
||||
/// If only 1 worker is started (either because the worker type is driver, or the
|
||||
/// `num_workers` in `CoreWorkerOptions` is set to 1), all threads will be automatically
|
||||
/// associated to the only worker. Then no need to call `SetCurrentThreadWorkerId` in
|
||||
/// your own threads. Currently a Python worker process starts only 1 worker.
|
||||
/// `CoreWorkerProcess::GetCoreWorker()->GetWorkerID()`.
|
||||
///
|
||||
/// How does core worker process dealloation work?
|
||||
///
|
||||
|
@ -195,9 +186,6 @@ class CoreWorkerProcessImpl {
|
|||
/// The worker ID of the global worker, if the number of workers is 1.
|
||||
const WorkerID global_worker_id_;
|
||||
|
||||
/// Map from worker ID to worker.
|
||||
absl::flat_hash_map<WorkerID, std::shared_ptr<CoreWorker>> workers_ GUARDED_BY(mutex_);
|
||||
|
||||
/// To protect access to workers_ and global_worker_
|
||||
mutable absl::Mutex mutex_;
|
||||
};
|
||||
|
|
|
@ -103,7 +103,6 @@ Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(JNIEnv *env,
|
|||
jstring rayletSocket,
|
||||
jbyteArray jobId,
|
||||
jobject gcsClientOptions,
|
||||
jint numWorkersPerProcess,
|
||||
jstring logDir,
|
||||
jbyteArray jobConfig,
|
||||
jint startupToken,
|
||||
|
@ -289,7 +288,6 @@ Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(JNIEnv *env,
|
|||
options.task_execution_callback = task_execution_callback;
|
||||
options.on_worker_shutdown = on_worker_shutdown;
|
||||
options.gc_collect = gc_collect;
|
||||
options.num_workers = static_cast<int>(numWorkersPerProcess);
|
||||
options.serialized_job_config = serialized_job_config;
|
||||
options.metrics_agent_port = -1;
|
||||
options.startup_token = startupToken;
|
||||
|
|
|
@ -25,7 +25,7 @@ extern "C" {
|
|||
* Class: io_ray_runtime_RayNativeRuntime
|
||||
* Method: nativeInitialize
|
||||
* Signature:
|
||||
* (ILjava/lang/String;ILjava/lang/String;Ljava/lang/String;Ljava/lang/String;[BLio/ray/runtime/gcs/GcsClientOptions;ILjava/lang/String;[BII)V
|
||||
* (ILjava/lang/String;ILjava/lang/String;Ljava/lang/String;Ljava/lang/String;[BLio/ray/runtime/gcs/GcsClientOptions;Ljava/lang/String;[BII)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(JNIEnv *,
|
||||
jclass,
|
||||
|
@ -37,7 +37,6 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(JNI
|
|||
jstring,
|
||||
jbyteArray,
|
||||
jobject,
|
||||
jint,
|
||||
jstring,
|
||||
jbyteArray,
|
||||
jint,
|
||||
|
|
|
@ -140,7 +140,6 @@ class CoreWorkerTest : public ::testing::Test {
|
|||
options.node_manager_port = node_manager_port;
|
||||
options.raylet_ip_address = "127.0.0.1";
|
||||
options.driver_name = "core_worker_test";
|
||||
options.num_workers = 1;
|
||||
options.metrics_agent_port = -1;
|
||||
CoreWorkerProcess::Initialize(options);
|
||||
}
|
||||
|
|
|
@ -51,7 +51,6 @@ class MockWorker {
|
|||
options.raylet_ip_address = "127.0.0.1";
|
||||
options.task_execution_callback =
|
||||
std::bind(&MockWorker::ExecuteTask, this, _1, _2, _3, _4, _5, _6, _7, _8, _9);
|
||||
options.num_workers = 1;
|
||||
options.metrics_agent_port = -1;
|
||||
options.startup_token = startup_token;
|
||||
CoreWorkerProcess::Initialize(options);
|
||||
|
|
|
@ -686,12 +686,8 @@ void GcsActorManager::PollOwnerForActorOutOfScope(
|
|||
if (node_it != owners_.end() && node_it->second.count(owner_id)) {
|
||||
// Only destroy the actor if its owner is still alive. The actor may
|
||||
// have already been destroyed if the owner died.
|
||||
// For multiple actors in one process, if one actor is out of scope,
|
||||
// We shouldn't force kill the actor because other actors in the process
|
||||
// are still alive.
|
||||
auto force_kill =
|
||||
get_job_config_(actor_id.JobId())->num_java_workers_per_process() <= 1;
|
||||
DestroyActor(actor_id, GenActorOutOfScopeCause(GetActor(actor_id)), force_kill);
|
||||
DestroyActor(
|
||||
actor_id, GenActorOutOfScopeCause(GetActor(actor_id)), /*force_kill=*/true);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
|
|
@ -89,7 +89,7 @@ TEST_F(GcsJobManagerTest, TestGetJobConfig) {
|
|||
auto job_id2 = JobID::FromInt(2);
|
||||
gcs::GcsInitData gcs_init_data(gcs_table_storage_);
|
||||
gcs_job_manager.Initialize(/*init_data=*/gcs_init_data);
|
||||
auto add_job_request1 = Mocker::GenAddJobRequest(job_id1, "namespace_1", 4);
|
||||
auto add_job_request1 = Mocker::GenAddJobRequest(job_id1, "namespace_1");
|
||||
|
||||
rpc::AddJobReply empty_reply;
|
||||
|
||||
|
@ -97,19 +97,16 @@ TEST_F(GcsJobManagerTest, TestGetJobConfig) {
|
|||
*add_job_request1,
|
||||
&empty_reply,
|
||||
[](Status, std::function<void()>, std::function<void()>) {});
|
||||
auto add_job_request2 = Mocker::GenAddJobRequest(job_id2, "namespace_2", 8);
|
||||
auto add_job_request2 = Mocker::GenAddJobRequest(job_id2, "namespace_2");
|
||||
gcs_job_manager.HandleAddJob(
|
||||
*add_job_request2,
|
||||
&empty_reply,
|
||||
[](Status, std::function<void()>, std::function<void()>) {});
|
||||
|
||||
auto job_config1 = gcs_job_manager.GetJobConfig(job_id1);
|
||||
ASSERT_EQ("namespace_1", job_config1->ray_namespace());
|
||||
ASSERT_EQ(4, job_config1->num_java_workers_per_process());
|
||||
|
||||
auto job_config2 = gcs_job_manager.GetJobConfig(job_id2);
|
||||
ASSERT_EQ("namespace_2", job_config2->ray_namespace());
|
||||
ASSERT_EQ(8, job_config2->num_java_workers_per_process());
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
|
|
|
@ -239,12 +239,9 @@ struct Mocker {
|
|||
}
|
||||
|
||||
static std::shared_ptr<rpc::AddJobRequest> GenAddJobRequest(
|
||||
const JobID &job_id,
|
||||
const std::string &ray_namespace,
|
||||
uint32_t num_java_worker_per_process) {
|
||||
const JobID &job_id, const std::string &ray_namespace) {
|
||||
auto job_config_data = std::make_shared<rpc::JobConfig>();
|
||||
job_config_data->set_ray_namespace(ray_namespace);
|
||||
job_config_data->set_num_java_workers_per_process(num_java_worker_per_process);
|
||||
|
||||
auto job_table_data = std::make_shared<rpc::JobTableData>();
|
||||
job_table_data->set_job_id(job_id.Binary());
|
||||
|
|
|
@ -39,8 +39,8 @@ using RestoreSpilledObjectCallback =
|
|||
/// A struct that includes info about the object.
|
||||
struct ObjectInfo {
|
||||
ObjectID object_id;
|
||||
int64_t data_size;
|
||||
int64_t metadata_size;
|
||||
int64_t data_size = 0;
|
||||
int64_t metadata_size = 0;
|
||||
/// Owner's raylet ID.
|
||||
NodeID owner_raylet_id;
|
||||
/// Owner's IP address.
|
||||
|
|
|
@ -260,8 +260,6 @@ message JobConfig {
|
|||
NON_DETACHED = 1;
|
||||
}
|
||||
|
||||
// The number of java workers per worker process.
|
||||
uint32 num_java_workers_per_process = 1;
|
||||
// The jvm options for java workers of the job.
|
||||
repeated string jvm_options = 2;
|
||||
// A list of directories or files (jar files or dynamic libraries) that specify the
|
||||
|
|
|
@ -198,14 +198,12 @@ void WorkerPool::update_worker_startup_token_counter() {
|
|||
|
||||
void WorkerPool::AddWorkerProcess(
|
||||
State &state,
|
||||
const int workers_to_start,
|
||||
const rpc::WorkerType worker_type,
|
||||
const Process &proc,
|
||||
const std::chrono::high_resolution_clock::time_point &start,
|
||||
const rpc::RuntimeEnvInfo &runtime_env_info) {
|
||||
state.worker_processes.emplace(worker_startup_token_counter_,
|
||||
WorkerProcessInfo{workers_to_start,
|
||||
workers_to_start,
|
||||
WorkerProcessInfo{/*is_pending_registration=*/true,
|
||||
{},
|
||||
worker_type,
|
||||
proc,
|
||||
|
@ -248,7 +246,7 @@ std::tuple<Process, StartupToken> WorkerPool::StartWorkerProcess(
|
|||
int starting_workers = 0;
|
||||
for (auto &entry : state.worker_processes) {
|
||||
if (entry.second.worker_type == worker_type) {
|
||||
starting_workers += entry.second.num_starting_workers;
|
||||
starting_workers += entry.second.is_pending_registration ? 1 : 0;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -268,13 +266,6 @@ std::tuple<Process, StartupToken> WorkerPool::StartWorkerProcess(
|
|||
<< rpc::WorkerType_Name(worker_type) << ", current pool has "
|
||||
<< state.idle.size() << " workers";
|
||||
|
||||
int workers_to_start = 1;
|
||||
if (dynamic_options.empty()) {
|
||||
if (language == Language::JAVA) {
|
||||
workers_to_start = job_config->num_java_workers_per_process();
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::string> options;
|
||||
|
||||
// Append Ray-defined per-job options here
|
||||
|
@ -312,12 +303,6 @@ std::tuple<Process, StartupToken> WorkerPool::StartWorkerProcess(
|
|||
}
|
||||
}
|
||||
|
||||
// Append Ray-defined per-process options here
|
||||
if (language == Language::JAVA) {
|
||||
options.push_back("-Dray.job.num-java-workers-per-process=" +
|
||||
std::to_string(workers_to_start));
|
||||
}
|
||||
|
||||
// Append startup-token for JAVA here
|
||||
if (language == Language::JAVA) {
|
||||
options.push_back("-Dray.raylet.startup-token=" +
|
||||
|
@ -460,13 +445,12 @@ std::tuple<Process, StartupToken> WorkerPool::StartWorkerProcess(
|
|||
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
|
||||
stats::ProcessStartupTimeMs.Record(duration.count());
|
||||
stats::NumWorkersStarted.Record(1);
|
||||
RAY_LOG(INFO) << "Started worker process of " << workers_to_start
|
||||
<< " worker(s) with pid " << proc.GetId() << ", the token "
|
||||
RAY_LOG(INFO) << "Started worker process with pid " << proc.GetId() << ", the token is "
|
||||
<< worker_startup_token_counter_;
|
||||
AdjustWorkerOomScore(proc.GetId());
|
||||
MonitorStartingWorkerProcess(
|
||||
proc, worker_startup_token_counter_, language, worker_type);
|
||||
AddWorkerProcess(state, workers_to_start, worker_type, proc, start, runtime_env_info);
|
||||
AddWorkerProcess(state, worker_type, proc, start, runtime_env_info);
|
||||
StartupToken worker_startup_token = worker_startup_token_counter_;
|
||||
update_worker_startup_token_counter();
|
||||
if (IsIOWorkerType(worker_type)) {
|
||||
|
@ -513,7 +497,7 @@ void WorkerPool::MonitorStartingWorkerProcess(const Process &proc,
|
|||
// Since this process times out to start, remove it from worker_processes
|
||||
// to avoid the zombie worker.
|
||||
auto it = state.worker_processes.find(proc_startup_token);
|
||||
if (it != state.worker_processes.end() && it->second.num_starting_workers != 0) {
|
||||
if (it != state.worker_processes.end() && it->second.is_pending_registration) {
|
||||
RAY_LOG(ERROR)
|
||||
<< "Some workers of the worker process(" << proc.GetId()
|
||||
<< ") have not registered within the timeout. "
|
||||
|
@ -747,12 +731,10 @@ void WorkerPool::OnWorkerStarted(const std::shared_ptr<WorkerInterface> &worker)
|
|||
|
||||
auto it = state.worker_processes.find(worker_startup_token);
|
||||
if (it != state.worker_processes.end()) {
|
||||
it->second.num_starting_workers--;
|
||||
it->second.is_pending_registration = false;
|
||||
it->second.alive_started_workers.insert(worker);
|
||||
if (it->second.num_starting_workers == 0) {
|
||||
// We may have slots to start more workers now.
|
||||
TryStartIOWorkers(worker->GetLanguage());
|
||||
}
|
||||
// We may have slots to start more workers now.
|
||||
TryStartIOWorkers(worker->GetLanguage());
|
||||
}
|
||||
const auto &worker_type = worker->GetWorkerType();
|
||||
if (IsIOWorkerType(worker_type)) {
|
||||
|
@ -1044,8 +1026,7 @@ void WorkerPool::TryKillingIdleWorkers() {
|
|||
auto &worker_state = GetStateForLanguage(idle_worker->GetLanguage());
|
||||
|
||||
auto it = worker_state.worker_processes.find(worker_startup_token);
|
||||
if (it != worker_state.worker_processes.end() &&
|
||||
it->second.num_starting_workers > 0) {
|
||||
if (it != worker_state.worker_processes.end() && it->second.is_pending_registration) {
|
||||
// A Java worker process may hold multiple workers.
|
||||
// Some workers of this process are pending registration. Skip killing this worker.
|
||||
continue;
|
||||
|
@ -1349,7 +1330,7 @@ void WorkerPool::PrestartWorkers(const TaskSpecification &task_spec,
|
|||
// The number of available workers that can be used for this task spec.
|
||||
int num_usable_workers = state.idle.size();
|
||||
for (auto &entry : state.worker_processes) {
|
||||
num_usable_workers += entry.second.num_starting_workers;
|
||||
num_usable_workers += entry.second.is_pending_registration ? 1 : 0;
|
||||
}
|
||||
// Some existing workers may be holding less than 1 CPU each, so we should
|
||||
// start as many workers as needed to fill up the remaining CPUs.
|
||||
|
@ -1376,10 +1357,10 @@ void WorkerPool::DisconnectWorker(const std::shared_ptr<WorkerInterface> &worker
|
|||
if (!RemoveWorker(it->second.alive_started_workers, worker)) {
|
||||
// Worker is either starting or started,
|
||||
// if it's not started, we should remove it from starting.
|
||||
it->second.num_starting_workers--;
|
||||
it->second.is_pending_registration = false;
|
||||
}
|
||||
if (it->second.alive_started_workers.size() == 0 &&
|
||||
it->second.num_starting_workers == 0) {
|
||||
!it->second.is_pending_registration) {
|
||||
DeleteRuntimeEnvIfPossible(it->second.runtime_env_info.serialized_runtime_env());
|
||||
RemoveWorkerProcess(state, worker->GetStartupToken());
|
||||
}
|
||||
|
@ -1508,7 +1489,8 @@ void WorkerPool::WarnAboutSize() {
|
|||
num_workers_started_or_registered +=
|
||||
static_cast<int64_t>(state.registered_workers.size());
|
||||
for (const auto &starting_process : state.worker_processes) {
|
||||
num_workers_started_or_registered += starting_process.second.num_starting_workers;
|
||||
num_workers_started_or_registered +=
|
||||
starting_process.second.is_pending_registration ? 0 : 1;
|
||||
}
|
||||
// Don't count IO workers towards the warning message threshold.
|
||||
num_workers_started_or_registered -= RayConfig::instance().max_io_workers() * 2;
|
||||
|
|
|
@ -473,10 +473,8 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface {
|
|||
|
||||
/// Some basic information about the worker process.
|
||||
struct WorkerProcessInfo {
|
||||
/// The number of workers in the worker process.
|
||||
int num_workers;
|
||||
/// The number of pending registration workers in the worker process.
|
||||
int num_starting_workers;
|
||||
/// Whether this worker is pending registration or is started.
|
||||
bool is_pending_registration = true;
|
||||
/// The started workers which is alive.
|
||||
std::unordered_set<std::shared_ptr<WorkerInterface>> alive_started_workers;
|
||||
/// The type of the worker.
|
||||
|
@ -684,7 +682,6 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface {
|
|||
void DeleteRuntimeEnvIfPossible(const std::string &serialized_runtime_env);
|
||||
|
||||
void AddWorkerProcess(State &state,
|
||||
const int workers_to_start,
|
||||
const rpc::WorkerType worker_type,
|
||||
const Process &proc,
|
||||
const std::chrono::high_resolution_clock::time_point &start,
|
||||
|
|
|
@ -26,8 +26,7 @@ namespace ray {
|
|||
|
||||
namespace raylet {
|
||||
|
||||
int NUM_WORKERS_PER_PROCESS_JAVA = 3;
|
||||
int MAXIMUM_STARTUP_CONCURRENCY = 5;
|
||||
int MAXIMUM_STARTUP_CONCURRENCY = 15;
|
||||
int MAX_IO_WORKER_SIZE = 2;
|
||||
int POOL_SIZE_SOFT_LIMIT = 5;
|
||||
int WORKER_REGISTER_TIMEOUT_SECONDS = 3;
|
||||
|
@ -37,10 +36,6 @@ const std::string BAD_RUNTIME_ENV_ERROR_MSG = "bad runtime env";
|
|||
|
||||
std::vector<Language> LANGUAGES = {Language::PYTHON, Language::JAVA};
|
||||
|
||||
static inline std::string GetNumJavaWorkersPerProcessSystemProperty(int num) {
|
||||
return std::string("-Dray.job.num-java-workers-per-process=") + std::to_string(num);
|
||||
}
|
||||
|
||||
class MockWorkerClient : public rpc::CoreWorkerClientInterface {
|
||||
public:
|
||||
MockWorkerClient(instrumented_io_context &io_service) : io_service_(io_service) {}
|
||||
|
@ -194,7 +189,7 @@ class WorkerPoolMock : public WorkerPool {
|
|||
int total = 0;
|
||||
for (auto &state_entry : states_by_lang_) {
|
||||
for (auto &process_entry : state_entry.second.worker_processes) {
|
||||
total += process_entry.second.num_starting_workers;
|
||||
total += process_entry.second.is_pending_registration ? 1 : 0;
|
||||
}
|
||||
}
|
||||
return total;
|
||||
|
@ -204,7 +199,7 @@ class WorkerPoolMock : public WorkerPool {
|
|||
int total = 0;
|
||||
for (auto &entry : states_by_lang_) {
|
||||
for (auto process : entry.second.worker_processes) {
|
||||
if (process.second.num_starting_workers != 0) {
|
||||
if (process.second.is_pending_registration) {
|
||||
total += 1;
|
||||
}
|
||||
}
|
||||
|
@ -313,7 +308,6 @@ class WorkerPoolMock : public WorkerPool {
|
|||
if (pushed_it == pushedProcesses_.end()) {
|
||||
int runtime_env_hash = 0;
|
||||
bool is_java = false;
|
||||
bool has_dynamic_options = false;
|
||||
// Parses runtime env hash to make sure the pushed workers can be popped out.
|
||||
for (auto command_args : it->second) {
|
||||
std::string runtime_env_key = "--runtime-env-hash=";
|
||||
|
@ -326,14 +320,9 @@ class WorkerPoolMock : public WorkerPool {
|
|||
if (pos != std::string::npos) {
|
||||
is_java = true;
|
||||
}
|
||||
pos = command_args.find("-X");
|
||||
if (pos != std::string::npos) {
|
||||
has_dynamic_options = true;
|
||||
}
|
||||
}
|
||||
// TODO(SongGuyang): support C++ language workers.
|
||||
int num_workers =
|
||||
(is_java && !has_dynamic_options) ? NUM_WORKERS_PER_PROCESS_JAVA : 1;
|
||||
int num_workers = 1;
|
||||
RAY_CHECK(timeout_worker_number <= num_workers)
|
||||
<< "The timeout worker number cannot exceed the total number of workers";
|
||||
auto register_workers = num_workers - timeout_worker_number;
|
||||
|
@ -471,7 +460,6 @@ class WorkerPoolTest : public ::testing::Test {
|
|||
worker_pool_ = std::make_unique<WorkerPoolMock>(
|
||||
io_service_, worker_commands, mock_worker_rpc_clients_);
|
||||
rpc::JobConfig job_config;
|
||||
job_config.set_num_java_workers_per_process(NUM_WORKERS_PER_PROCESS_JAVA);
|
||||
RegisterDriver(Language::PYTHON, JOB_ID, job_config);
|
||||
}
|
||||
|
||||
|
@ -489,18 +477,7 @@ class WorkerPoolTest : public ::testing::Test {
|
|||
ASSERT_TRUE(worker_pool_->NumWorkerProcessesStarting() <=
|
||||
expected_worker_process_count);
|
||||
Process prev = worker_pool_->LastStartedWorkerProcess();
|
||||
if (!std::equal_to<Process>()(last_started_worker_process, prev)) {
|
||||
last_started_worker_process = prev;
|
||||
const auto &real_command =
|
||||
worker_pool_->GetWorkerCommand(last_started_worker_process);
|
||||
if (language == Language::JAVA) {
|
||||
auto it = std::find(
|
||||
real_command.begin(),
|
||||
real_command.end(),
|
||||
GetNumJavaWorkersPerProcessSystemProperty(num_workers_per_process));
|
||||
ASSERT_NE(it, real_command.end());
|
||||
}
|
||||
} else {
|
||||
if (std::equal_to<Process>()(last_started_worker_process, prev)) {
|
||||
ASSERT_EQ(worker_pool_->NumWorkerProcessesStarting(),
|
||||
expected_worker_process_count);
|
||||
ASSERT_TRUE(i >= expected_worker_process_count);
|
||||
|
@ -618,9 +595,7 @@ TEST_F(WorkerPoolTest, HandleWorkerRegistration) {
|
|||
auto [proc, token] = worker_pool_->StartWorkerProcess(
|
||||
Language::JAVA, rpc::WorkerType::WORKER, JOB_ID, &status);
|
||||
std::vector<std::shared_ptr<WorkerInterface>> workers;
|
||||
for (int i = 0; i < NUM_WORKERS_PER_PROCESS_JAVA; i++) {
|
||||
workers.push_back(worker_pool_->CreateWorker(Process(), Language::JAVA));
|
||||
}
|
||||
workers.push_back(worker_pool_->CreateWorker(Process(), Language::JAVA));
|
||||
for (const auto &worker : workers) {
|
||||
// Check that there's still a starting worker process
|
||||
// before all workers have been registered
|
||||
|
@ -670,7 +645,7 @@ TEST_F(WorkerPoolTest, StartupPythonWorkerProcessCount) {
|
|||
}
|
||||
|
||||
TEST_F(WorkerPoolTest, StartupJavaWorkerProcessCount) {
|
||||
TestStartupWorkerProcessCount(Language::JAVA, NUM_WORKERS_PER_PROCESS_JAVA);
|
||||
TestStartupWorkerProcessCount(Language::JAVA, 1);
|
||||
}
|
||||
|
||||
TEST_F(WorkerPoolTest, InitialWorkerProcessCount) {
|
||||
|
@ -756,7 +731,6 @@ TEST_F(WorkerPoolTest, StartWorkerWithDynamicOptionsCommand) {
|
|||
|
||||
rpc::JobConfig job_config = rpc::JobConfig();
|
||||
job_config.add_code_search_path("/test/code_search_path");
|
||||
job_config.set_num_java_workers_per_process(NUM_WORKERS_PER_PROCESS_JAVA);
|
||||
job_config.add_jvm_options("-Xmx1g");
|
||||
job_config.add_jvm_options("-Xms500m");
|
||||
job_config.add_jvm_options("-Dmy-job.hello=world");
|
||||
|
@ -779,7 +753,6 @@ TEST_F(WorkerPoolTest, StartWorkerWithDynamicOptionsCommand) {
|
|||
expected_command.end(),
|
||||
{"-Xmx1g", "-Xms500m", "-Dmy-job.hello=world", "-Dmy-job.foo=bar"});
|
||||
// Ray-defined per-process options
|
||||
expected_command.push_back(GetNumJavaWorkersPerProcessSystemProperty(1));
|
||||
expected_command.push_back("-Dray.raylet.startup-token=0");
|
||||
expected_command.push_back("-Dray.internal.runtime-env-hash=1");
|
||||
// User-defined per-process options
|
||||
|
@ -1162,46 +1135,6 @@ TEST_F(WorkerPoolTest, DeleteWorkerPushPop) {
|
|||
});
|
||||
}
|
||||
|
||||
TEST_F(WorkerPoolTest, NoPopOnCrashedWorkerProcess) {
|
||||
// Start a Java worker process.
|
||||
PopWorkerStatus status;
|
||||
auto [proc, token] = worker_pool_->StartWorkerProcess(
|
||||
Language::JAVA, rpc::WorkerType::WORKER, JOB_ID, &status);
|
||||
auto worker1 = worker_pool_->CreateWorker(Process(), Language::JAVA);
|
||||
auto worker2 = worker_pool_->CreateWorker(Process(), Language::JAVA);
|
||||
|
||||
// We now imitate worker process crashing while core worker initializing.
|
||||
|
||||
// 1. we register both workers.
|
||||
RAY_CHECK_OK(worker_pool_->RegisterWorker(
|
||||
worker1, proc.GetId(), worker_pool_->GetStartupToken(proc), [](Status, int) {}));
|
||||
RAY_CHECK_OK(worker_pool_->RegisterWorker(
|
||||
worker2, proc.GetId(), worker_pool_->GetStartupToken(proc), [](Status, int) {}));
|
||||
|
||||
// 2. announce worker port for worker 1. When interacting with worker pool, it's
|
||||
// PushWorker.
|
||||
worker_pool_->PushWorker(worker1);
|
||||
|
||||
// 3. kill the worker process. Now let's assume that Raylet found that the connection
|
||||
// with worker 1 disconnected first.
|
||||
worker_pool_->DisconnectWorker(
|
||||
worker1, /*disconnect_type=*/rpc::WorkerExitType::SYSTEM_ERROR_EXIT);
|
||||
|
||||
// 4. but the RPC for announcing worker port for worker 2 is already in Raylet input
|
||||
// buffer. So now Raylet needs to handle worker 2.
|
||||
worker_pool_->PushWorker(worker2);
|
||||
|
||||
// 5. Let's try to pop a worker to execute a task. Worker 2 shouldn't be popped because
|
||||
// the process has crashed.
|
||||
const auto task_spec = ExampleTaskSpec();
|
||||
ASSERT_NE(worker_pool_->PopWorkerSync(task_spec), worker1);
|
||||
ASSERT_NE(worker_pool_->PopWorkerSync(task_spec), worker2);
|
||||
|
||||
// 6. Now Raylet disconnects with worker 2.
|
||||
worker_pool_->DisconnectWorker(
|
||||
worker2, /*disconnect_type=*/rpc::WorkerExitType::SYSTEM_ERROR_EXIT);
|
||||
}
|
||||
|
||||
TEST_F(WorkerPoolTest, TestWorkerCapping) {
|
||||
auto job_id = JOB_ID;
|
||||
|
||||
|
@ -1423,8 +1356,7 @@ TEST_F(WorkerPoolTest, TestWorkerCappingWithExitDelay) {
|
|||
PopWorkerStatus status;
|
||||
auto [proc, token] = worker_pool_->StartWorkerProcess(
|
||||
language, rpc::WorkerType::WORKER, JOB_ID, &status);
|
||||
int workers_to_start =
|
||||
language == Language::JAVA ? NUM_WORKERS_PER_PROCESS_JAVA : 1;
|
||||
int workers_to_start = 1;
|
||||
for (int j = 0; j < workers_to_start; j++) {
|
||||
auto worker = worker_pool_->CreateWorker(Process(), language);
|
||||
worker->SetStartupToken(worker_pool_->GetStartupToken(proc));
|
||||
|
@ -1644,77 +1576,6 @@ TEST_F(WorkerPoolTest, RuntimeEnvUriReferenceWorkerLevel) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST_F(WorkerPoolTest, RuntimeEnvUriReferenceWithMultipleWorkers) {
|
||||
auto job_id = JOB_ID;
|
||||
std::string uri = "s3://567";
|
||||
auto runtime_env_info = ExampleRuntimeEnvInfo({uri}, false);
|
||||
rpc::JobConfig job_config;
|
||||
job_config.set_num_java_workers_per_process(NUM_WORKERS_PER_PROCESS_JAVA);
|
||||
job_config.mutable_runtime_env_info()->CopyFrom(runtime_env_info);
|
||||
// Start job without eager installed runtime env.
|
||||
worker_pool_->HandleJobStarted(job_id, job_config);
|
||||
ASSERT_EQ(GetReferenceCount(runtime_env_info.serialized_runtime_env()), 0);
|
||||
|
||||
// First part, test normal case with all worker registered.
|
||||
{
|
||||
// Start actors with runtime env. The Java actors will trigger a multi-worker process.
|
||||
std::vector<std::shared_ptr<WorkerInterface>> workers;
|
||||
for (int i = 0; i < NUM_WORKERS_PER_PROCESS_JAVA; i++) {
|
||||
auto actor_creation_id = ActorID::Of(job_id, TaskID::ForDriverTask(job_id), i + 1);
|
||||
const auto actor_creation_task_spec =
|
||||
ExampleTaskSpec(ActorID::Nil(),
|
||||
Language::JAVA,
|
||||
job_id,
|
||||
actor_creation_id,
|
||||
{},
|
||||
TaskID::FromRandom(JobID::Nil()),
|
||||
runtime_env_info);
|
||||
auto popped_actor_worker = worker_pool_->PopWorkerSync(actor_creation_task_spec);
|
||||
ASSERT_NE(popped_actor_worker, nullptr);
|
||||
workers.push_back(popped_actor_worker);
|
||||
ASSERT_EQ(GetReferenceCount(runtime_env_info.serialized_runtime_env()), 1);
|
||||
}
|
||||
// Make sure only one worker process has been started.
|
||||
ASSERT_EQ(worker_pool_->GetProcessSize(), 1);
|
||||
// Disconnect all actor workers.
|
||||
for (auto &worker : workers) {
|
||||
worker_pool_->DisconnectWorker(worker, rpc::WorkerExitType::IDLE_EXIT);
|
||||
}
|
||||
ASSERT_EQ(GetReferenceCount(runtime_env_info.serialized_runtime_env()), 0);
|
||||
}
|
||||
|
||||
// Second part, test corner case with some worker registration timeout.
|
||||
{
|
||||
// Start one actor with runtime env. The Java actor will trigger a multi-worker
|
||||
// process.
|
||||
auto actor_creation_id = ActorID::Of(job_id, TaskID::ForDriverTask(job_id), 1);
|
||||
const auto actor_creation_task_spec =
|
||||
ExampleTaskSpec(ActorID::Nil(),
|
||||
Language::JAVA,
|
||||
job_id,
|
||||
actor_creation_id,
|
||||
{},
|
||||
TaskID::FromRandom(JobID::Nil()),
|
||||
runtime_env_info);
|
||||
PopWorkerStatus status;
|
||||
// Only one worker registration. All the other worker registration times out.
|
||||
auto popped_actor_worker = worker_pool_->PopWorkerSync(
|
||||
actor_creation_task_spec, true, &status, NUM_WORKERS_PER_PROCESS_JAVA - 1);
|
||||
ASSERT_EQ(GetReferenceCount(runtime_env_info.serialized_runtime_env()), 1);
|
||||
// Disconnect actor worker.
|
||||
worker_pool_->DisconnectWorker(popped_actor_worker, rpc::WorkerExitType::IDLE_EXIT);
|
||||
ASSERT_EQ(GetReferenceCount(runtime_env_info.serialized_runtime_env()), 1);
|
||||
// Sleep for a while to wait worker registration timeout.
|
||||
std::this_thread::sleep_for(
|
||||
std::chrono::seconds(WORKER_REGISTER_TIMEOUT_SECONDS + 1));
|
||||
ASSERT_EQ(GetReferenceCount(runtime_env_info.serialized_runtime_env()), 0);
|
||||
}
|
||||
|
||||
// Finish the job.
|
||||
worker_pool_->HandleJobFinished(job_id);
|
||||
ASSERT_EQ(GetReferenceCount(runtime_env_info.serialized_runtime_env()), 0);
|
||||
}
|
||||
|
||||
TEST_F(WorkerPoolTest, CacheWorkersByRuntimeEnvHash) {
|
||||
///
|
||||
/// Check that a worker can be popped only if there is a
|
||||
|
|
Loading…
Add table
Reference in a new issue