[Core] Remove multiple core workers in one process 1/n. (#24147)

This is the first PR to remove the code path for multiple core workers in one process. This PR aims to remove the flags and APIs related to `num_workers`.
After this PR is merged, we no longer need to consider multiple core workers.

Follow-up PRs will cover the deeper logic refactoring, such as eliminating the gap between core worker and core worker process, and removing the logic related to multiple workers from the worker pool, GCS, etc.

**BREAKING CHANGE**
This PR removes these APIs:
- Ray.wrapRunnable();
- Ray.wrapCallable();
- Ray.setAsyncContext();
- Ray.getAsyncContext();

And the following APIs must not be invoked from a user-created thread in local mode:
- Ray.getRuntimeContext().getCurrentActorId();
- Ray.getRuntimeContext().getCurrentTaskId()

Note that this PR shouldn't be merged to 1.x.
This commit is contained in:
Qing Wang 2022-05-19 00:36:22 +08:00 committed by GitHub
parent 1d5e6d908d
commit eb29895dbb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
57 changed files with 137 additions and 1200 deletions

View file

@ -139,7 +139,6 @@ void ProcessHelper::RayStart(CoreWorkerOptions::TaskExecutionCallback callback)
options.node_manager_port = ConfigInternal::Instance().node_manager_port;
options.raylet_ip_address = node_ip;
options.driver_name = "cpp_worker";
options.num_workers = 1;
options.metrics_agent_port = -1;
options.task_execution_callback = callback;
options.startup_token = ConfigInternal::Instance().startup_token;

View file

@ -5,7 +5,6 @@ import io.ray.api.runtime.RayRuntimeFactory;
import io.ray.api.runtimecontext.RuntimeContext;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.Callable;
/** This class contains all public APIs of Ray. */
public final class Ray extends RayCall {
@ -205,52 +204,6 @@ public final class Ray extends RayCall {
return internal().getActor(name, namespace);
}
/**
* If users want to use Ray API in their own threads, call this method to get the async context
* and then call {@link #setAsyncContext} at the beginning of the new thread.
*
* @return The async context.
*/
public static Object getAsyncContext() {
return internal().getAsyncContext();
}
/**
* Set the async context for the current thread.
*
* @param asyncContext The async context to set.
*/
public static void setAsyncContext(Object asyncContext) {
internal().setAsyncContext(asyncContext);
}
// TODO (kfstorm): add the `rollbackAsyncContext` API to allow rollbacking the async context of
// the current thread to the one before `setAsyncContext` is called.
// TODO (kfstorm): unify the `wrap*` methods.
/**
* If users want to use Ray API in their own threads, they should wrap their {@link Runnable}
* objects with this method.
*
* @param runnable The runnable to wrap.
* @return The wrapped runnable.
*/
public static Runnable wrapRunnable(Runnable runnable) {
return internal().wrapRunnable(runnable);
}
/**
* If users want to use Ray API in their own threads, they should wrap their {@link Callable}
* objects with this method.
*
* @param callable The callable to wrap.
* @return The wrapped callable.
*/
public static <T> Callable<T> wrapCallable(Callable<T> callable) {
return internal().wrapCallable(callable);
}
/** Get the underlying runtime instance. */
public static RayRuntime internal() {
if (runtime == null) {

View file

@ -24,7 +24,6 @@ import io.ray.api.runtimeenv.RuntimeEnv;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.Callable;
/** Base interface of a Ray runtime. */
public interface RayRuntime {
@ -207,26 +206,6 @@ public interface RayRuntime {
RuntimeContext getRuntimeContext();
Object getAsyncContext();
void setAsyncContext(Object asyncContext);
/**
* Wrap a {@link Runnable} with necessary context capture.
*
* @param runnable The runnable to wrap.
* @return The wrapped runnable.
*/
Runnable wrapRunnable(Runnable runnable);
/**
* Wrap a {@link Callable} with necessary context capture.
*
* @param callable The callable to wrap.
* @return The wrapped callable.
*/
<T> Callable<T> wrapCallable(Callable<T> callable);
/** Intentionally exit the current actor. */
void exitActor();

View file

@ -23,10 +23,7 @@ public class ActorPerformanceTestBase {
boolean hasReturn,
boolean ignoreReturn,
int argSize,
boolean useDirectByteBuffer,
int numJavaWorkerPerProcess) {
System.setProperty(
"ray.job.num-java-workers-per-process", String.valueOf(numJavaWorkerPerProcess));
boolean useDirectByteBuffer) {
System.setProperty("ray.raylet.startup-token", "0");
Ray.init();
try {

View file

@ -13,7 +13,6 @@ public class ActorPerformanceTestCase1 {
final int argSize = 0;
final boolean useDirectByteBuffer = false;
final boolean ignoreReturn = false;
final int numJavaWorkerPerProcess = 1;
ActorPerformanceTestBase.run(
args,
layers,
@ -21,7 +20,6 @@ public class ActorPerformanceTestCase1 {
hasReturn,
ignoreReturn,
argSize,
useDirectByteBuffer,
numJavaWorkerPerProcess);
useDirectByteBuffer);
}
}

View file

@ -31,7 +31,6 @@ import io.ray.runtime.functionmanager.FunctionDescriptor;
import io.ray.runtime.functionmanager.FunctionManager;
import io.ray.runtime.functionmanager.PyFunctionDescriptor;
import io.ray.runtime.functionmanager.RayFunction;
import io.ray.runtime.generated.Common;
import io.ray.runtime.generated.Common.Language;
import io.ray.runtime.object.ObjectRefImpl;
import io.ray.runtime.object.ObjectStore;
@ -46,7 +45,6 @@ import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.Callable;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -67,13 +65,8 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
private static ParallelActorContextImpl parallelActorContextImpl = new ParallelActorContextImpl();
/** Whether the required thread context is set on the current thread. */
final ThreadLocal<Boolean> isContextSet = ThreadLocal.withInitial(() -> false);
public AbstractRayRuntime(RayConfig rayConfig) {
this.rayConfig = rayConfig;
setIsContextSet(rayConfig.workerMode == Common.WorkerType.DRIVER);
functionManager = new FunctionManager(rayConfig.codeSearchPath);
runtimeContext = new RuntimeContextImpl(this);
}
@ -158,7 +151,7 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
@Override
public ObjectRef call(RayFunc func, Object[] args, CallOptions options) {
RayFunction rayFunction = functionManager.getFunction(workerContext.getCurrentJobId(), func);
RayFunction rayFunction = functionManager.getFunction(func);
FunctionDescriptor functionDescriptor = rayFunction.functionDescriptor;
Optional<Class<?>> returnType = rayFunction.getReturnType();
return callNormalFunction(functionDescriptor, args, returnType, options);
@ -176,7 +169,7 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
@Override
public ObjectRef callActor(
ActorHandle<?> actor, RayFunc func, Object[] args, CallOptions options) {
RayFunction rayFunction = functionManager.getFunction(workerContext.getCurrentJobId(), func);
RayFunction rayFunction = functionManager.getFunction(func);
FunctionDescriptor functionDescriptor = rayFunction.functionDescriptor;
Optional<Class<?>> returnType = rayFunction.getReturnType();
return callActorFunction(actor, functionDescriptor, args, returnType, options);
@ -201,8 +194,7 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
public <T> ActorHandle<T> createActor(
RayFunc actorFactoryFunc, Object[] args, ActorCreationOptions options) {
FunctionDescriptor functionDescriptor =
functionManager.getFunction(workerContext.getCurrentJobId(), actorFactoryFunc)
.functionDescriptor;
functionManager.getFunction(actorFactoryFunc).functionDescriptor;
return (ActorHandle<T>) createActorImpl(functionDescriptor, args, options);
}
@ -256,31 +248,6 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
return (T) taskSubmitter.getActor(actorId);
}
@Override
public void setAsyncContext(Object asyncContext) {
isContextSet.set(true);
}
@Override
public final Runnable wrapRunnable(Runnable runnable) {
Object asyncContext = getAsyncContext();
return () -> {
try (RayAsyncContextUpdater updater = new RayAsyncContextUpdater(asyncContext, this)) {
runnable.run();
}
};
}
@Override
public final <T> Callable<T> wrapCallable(Callable<T> callable) {
Object asyncContext = getAsyncContext();
return () -> {
try (RayAsyncContextUpdater updater = new RayAsyncContextUpdater(asyncContext, this)) {
return callable.call();
}
};
}
@Override
public ConcurrencyGroup createConcurrencyGroup(
String name, int maxConcurrency, List<RayFunc> funcs) {
@ -387,34 +354,6 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
return actor;
}
/// An auto closable class that is used for updating the async context when invoking Ray APIs.
private static final class RayAsyncContextUpdater implements AutoCloseable {
private AbstractRayRuntime runtime;
private boolean oldIsContextSet;
private Object oldAsyncContext = null;
public RayAsyncContextUpdater(Object asyncContext, AbstractRayRuntime runtime) {
this.runtime = runtime;
oldIsContextSet = runtime.isContextSet.get();
if (oldIsContextSet) {
oldAsyncContext = runtime.getAsyncContext();
}
runtime.setAsyncContext(asyncContext);
}
@Override
public void close() {
if (oldIsContextSet) {
runtime.setAsyncContext(oldAsyncContext);
} else {
runtime.setIsContextSet(false);
}
}
}
abstract List<ObjectId> getCurrentReturnIds(int numReturns, ActorId actorId);
@Override
@ -446,11 +385,6 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
return runtimeContext;
}
@Override
public void setIsContextSet(boolean isContextSet) {
this.isContextSet.set(isContextSet);
}
/// A helper to validate if the prepared return ids is as expected.
void validatePreparedReturnIds(List<ObjectId> preparedReturnIds, List<ObjectId> realReturnIds) {
if (rayConfig.runMode == RunMode.CLUSTER) {

View file

@ -24,9 +24,7 @@ public class ConcurrencyGroupImpl implements ConcurrencyGroup {
funcs.forEach(
func -> {
RayFunction rayFunc =
((RayRuntimeInternal) Ray.internal())
.getFunctionManager()
.getFunction(Ray.getRuntimeContext().getCurrentJobId(), func);
((RayRuntimeInternal) Ray.internal()).getFunctionManager().getFunction(func);
functionDescriptors.add(rayFunc.getFunctionDescriptor());
});
}

View file

@ -32,10 +32,7 @@ public class DefaultRayRuntimeFactory implements RayRuntimeFactory {
rayConfig.runMode == RunMode.SINGLE_PROCESS
? new RayDevRuntime(rayConfig)
: new RayNativeRuntime(rayConfig);
RayRuntimeInternal runtime =
rayConfig.numWorkersPerProcess > 1
? RayRuntimeProxy.newInstance(innerRuntime)
: innerRuntime;
RayRuntimeInternal runtime = innerRuntime;
runtime.start();
return runtime;
} catch (Exception e) {

View file

@ -1,6 +1,5 @@
package io.ray.runtime;
import com.google.common.base.Preconditions;
import io.ray.api.BaseActorHandle;
import io.ray.api.id.ActorId;
import io.ray.api.id.JobId;
@ -10,6 +9,7 @@ import io.ray.api.placementgroup.PlacementGroup;
import io.ray.api.runtimecontext.ResourceValue;
import io.ray.runtime.config.RayConfig;
import io.ray.runtime.context.LocalModeWorkerContext;
import io.ray.runtime.functionmanager.FunctionManager;
import io.ray.runtime.gcs.GcsClient;
import io.ray.runtime.generated.Common.TaskSpec;
import io.ray.runtime.object.LocalModeObjectStore;
@ -24,13 +24,9 @@ import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class RayDevRuntime extends AbstractRayRuntime {
private static final Logger LOGGER = LoggerFactory.getLogger(RayDevRuntime.class);
private AtomicInteger jobCounter = new AtomicInteger(0);
public RayDevRuntime(RayConfig rayConfig) {
@ -49,6 +45,7 @@ public class RayDevRuntime extends AbstractRayRuntime {
taskExecutor = new LocalModeTaskExecutor(this);
workerContext = new LocalModeWorkerContext(rayConfig.getJobId());
objectStore = new LocalModeObjectStore(workerContext);
functionManager = new FunctionManager(rayConfig.codeSearchPath);
taskSubmitter =
new LocalModeTaskSubmitter(this, taskExecutor, (LocalModeObjectStore) objectStore);
((LocalModeObjectStore) objectStore)
@ -90,19 +87,6 @@ public class RayDevRuntime extends AbstractRayRuntime {
throw new UnsupportedOperationException("Ray doesn't have gcs client in local mode.");
}
@Override
public Object getAsyncContext() {
return new AsyncContext(((LocalModeWorkerContext) workerContext).getCurrentTask());
}
@Override
public void setAsyncContext(Object asyncContext) {
Preconditions.checkNotNull(asyncContext);
TaskSpec task = ((AsyncContext) asyncContext).task;
((LocalModeWorkerContext) workerContext).setCurrentTask(task);
super.setAsyncContext(asyncContext);
}
@Override
public Map<String, List<ResourceValue>> getAvailableResourceIds() {
throw new UnsupportedOperationException("Ray doesn't support get resources ids in local mode.");

View file

@ -14,6 +14,7 @@ import io.ray.api.runtimecontext.ResourceValue;
import io.ray.runtime.config.RayConfig;
import io.ray.runtime.context.NativeWorkerContext;
import io.ray.runtime.exception.RayIntentionalSystemExitException;
import io.ray.runtime.functionmanager.FunctionManager;
import io.ray.runtime.gcs.GcsClient;
import io.ray.runtime.gcs.GcsClientOptions;
import io.ray.runtime.generated.Common.WorkerType;
@ -102,14 +103,13 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
if (rayConfig.workerMode == WorkerType.DRIVER && rayConfig.getJobId() == JobId.NIL) {
rayConfig.setJobId(getGcsClient().nextJobId());
}
int numWorkersPerProcess =
rayConfig.workerMode == WorkerType.DRIVER ? 1 : rayConfig.numWorkersPerProcess;
// Make sure the job id has been set already.
functionManager = new FunctionManager(rayConfig.codeSearchPath);
byte[] serializedJobConfig = null;
if (rayConfig.workerMode == WorkerType.DRIVER) {
JobConfig.Builder jobConfigBuilder =
JobConfig.newBuilder()
.setNumJavaWorkersPerProcess(rayConfig.numWorkersPerProcess)
.addAllJvmOptions(rayConfig.jvmOptionsForJavaWorker)
.addAllCodeSearchPath(rayConfig.codeSearchPath)
.setRayNamespace(rayConfig.namespace);
@ -150,7 +150,6 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
rayConfig.rayletSocketName,
(rayConfig.workerMode == WorkerType.DRIVER ? rayConfig.getJobId() : JobId.NIL).getBytes(),
new GcsClientOptions(rayConfig),
numWorkersPerProcess,
rayConfig.logDir,
serializedJobConfig,
rayConfig.getStartupToken(),
@ -223,19 +222,6 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
nativeKillActor(actor.getId().getBytes(), noRestart);
}
@Override
public Object getAsyncContext() {
return new AsyncContext(
workerContext.getCurrentWorkerId(), workerContext.getCurrentClassLoader());
}
@Override
public void setAsyncContext(Object asyncContext) {
nativeSetCoreWorker(((AsyncContext) asyncContext).workerId.getBytes());
workerContext.setCurrentClassLoader(((AsyncContext) asyncContext).currentClassLoader);
super.setAsyncContext(asyncContext);
}
@Override
List<ObjectId> getCurrentReturnIds(int numReturns, ActorId actorId) {
List<byte[]> ret = nativeGetCurrentReturnIds(numReturns, actorId.getBytes());
@ -289,7 +275,6 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
String rayletSocket,
byte[] jobId,
GcsClientOptions gcsClientOptions,
int numWorkersPerProcess,
String logDir,
byte[] serializedJobConfig,
int startupToken,

View file

@ -26,7 +26,5 @@ public interface RayRuntimeInternal extends RayRuntime {
GcsClient getGcsClient();
void setIsContextSet(boolean isContextSet);
void run();
}

View file

@ -1,80 +0,0 @@
package io.ray.runtime;
import io.ray.api.runtime.RayRuntime;
import io.ray.runtime.config.RunMode;
import io.ray.runtime.exception.RayException;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
/**
* Protect a ray runtime with context checks for all methods of {@link RayRuntime} (except {@link
* RayRuntime#shutdown(boolean)}).
*/
public class RayRuntimeProxy implements InvocationHandler {
/** The original runtime. */
private AbstractRayRuntime obj;
private RayRuntimeProxy(AbstractRayRuntime obj) {
this.obj = obj;
}
public AbstractRayRuntime getRuntimeObject() {
return obj;
}
/** Generate a new instance of {@link RayRuntimeInternal} with additional context check. */
static RayRuntimeInternal newInstance(AbstractRayRuntime obj) {
return (RayRuntimeInternal)
java.lang.reflect.Proxy.newProxyInstance(
obj.getClass().getClassLoader(),
new Class<?>[] {RayRuntimeInternal.class},
new RayRuntimeProxy(obj));
}
@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
if (isInterfaceMethod(method)
&& !method.getName().equals("shutdown")
&& !method.getName().equals("setAsyncContext")) {
checkIsContextSet();
}
try {
return method.invoke(obj, args);
} catch (InvocationTargetException e) {
if (e.getCause() != null) {
throw e.getCause();
} else {
throw e;
}
}
}
/** Whether the method is defined in the {@link RayRuntime} interface. */
private boolean isInterfaceMethod(Method method) {
try {
RayRuntime.class.getMethod(method.getName(), method.getParameterTypes());
return true;
} catch (NoSuchMethodException e) {
return false;
}
}
/**
* Check if thread context is set.
*
* <p>This method should be invoked at the beginning of most public methods of {@link RayRuntime},
* otherwise the native code might crash due to thread local core worker was not set. We check it
* for {@link AbstractRayRuntime} instead of {@link RayNativeRuntime} because we want to catch the
* error even if the application runs in {@link RunMode#SINGLE_PROCESS} mode.
*/
private void checkIsContextSet() {
if (!obj.isContextSet.get()) {
throw new RayException(
"`Ray.wrap***` is not called on the current thread."
+ " If you want to use Ray API in your own threads,"
+ " please wrap your executable with `Ray.wrap***`.");
}
}
}

View file

@ -75,8 +75,6 @@ public class RayConfig {
public final List<String> headArgs;
public final int numWorkersPerProcess;
public final String namespace;
public final List<String> jvmOptionsForJavaWorker;
@ -190,8 +188,6 @@ public class RayConfig {
}
codeSearchPath = Arrays.asList(codeSearchPathString.split(":"));
numWorkersPerProcess = config.getInt("ray.job.num-java-workers-per-process");
startupToken = config.getInt("ray.raylet.startup-token");
/// Driver needn't this config item.

View file

@ -48,31 +48,21 @@ public class LocalModeWorkerContext implements WorkerContext {
@Override
public ActorId getCurrentActorId() {
TaskSpec taskSpec = currentTask.get();
if (taskSpec == null) {
return ActorId.NIL;
}
checkTaskSpecNotNull(taskSpec);
return LocalModeTaskSubmitter.getActorId(taskSpec);
}
@Override
public ClassLoader getCurrentClassLoader() {
return null;
}
@Override
public void setCurrentClassLoader(ClassLoader currentClassLoader) {}
@Override
public TaskType getCurrentTaskType() {
TaskSpec taskSpec = currentTask.get();
Preconditions.checkNotNull(taskSpec, "Current task is not set.");
checkTaskSpecNotNull(taskSpec);
return taskSpec.getType();
}
@Override
public TaskId getCurrentTaskId() {
TaskSpec taskSpec = currentTask.get();
Preconditions.checkState(taskSpec != null);
checkTaskSpecNotNull(taskSpec);
return TaskId.fromBytes(taskSpec.getTaskId().toByteArray());
}
@ -85,7 +75,9 @@ public class LocalModeWorkerContext implements WorkerContext {
currentTask.set(taskSpec);
}
public TaskSpec getCurrentTask() {
return currentTask.get();
private static void checkTaskSpecNotNull(TaskSpec taskSpec) {
Preconditions.checkNotNull(
taskSpec,
"Current task is not set. Maybe you invoked this API in a user-created thread not managed by Ray. Invoking this API in a user-created thread is not supported yet in local mode. You can switch to cluster mode.");
}
}

View file

@ -12,7 +12,7 @@ import java.nio.ByteBuffer;
/** Worker context for cluster mode. This is a wrapper class for worker context of core worker. */
public class NativeWorkerContext implements WorkerContext {
private final ThreadLocal<ClassLoader> currentClassLoader = new ThreadLocal<>();
private ClassLoader currentClassLoader = null;
@Override
public UniqueId getCurrentWorkerId() {
@ -29,18 +29,6 @@ public class NativeWorkerContext implements WorkerContext {
return ActorId.fromByteBuffer(nativeGetCurrentActorId());
}
@Override
public ClassLoader getCurrentClassLoader() {
return currentClassLoader.get();
}
@Override
public void setCurrentClassLoader(ClassLoader currentClassLoader) {
if (this.currentClassLoader.get() != currentClassLoader) {
this.currentClassLoader.set(currentClassLoader);
}
}
@Override
public TaskType getCurrentTaskType() {
return TaskType.forNumber(nativeGetCurrentTaskType());

View file

@ -19,15 +19,6 @@ public interface WorkerContext {
/** ID of the current actor. */
ActorId getCurrentActorId();
/**
* The class loader that is associated with the current job. It's used for locating classes when
* dealing with serialization and deserialization in {@link Serializer}.
*/
ClassLoader getCurrentClassLoader();
/** Set the current class loader. */
void setCurrentClassLoader(ClassLoader currentClassLoader);
/** Type of the current task. */
TaskType getCurrentTaskType();

View file

@ -2,7 +2,6 @@ package io.ray.runtime.functionmanager;
import com.google.common.collect.Lists;
import io.ray.api.function.RayFunc;
import io.ray.api.id.JobId;
import io.ray.runtime.util.LambdaUtils;
import java.io.File;
import java.lang.invoke.SerializedLambda;
@ -34,7 +33,7 @@ import org.objectweb.asm.Type;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/** Manages functions by job id. */
/** Manages functions in the current worker. */
public class FunctionManager {
private static final Logger LOGGER = LoggerFactory.getLogger(FunctionManager.class);
@ -50,8 +49,8 @@ public class FunctionManager {
private static final ThreadLocal<WeakHashMap<Class<? extends RayFunc>, JavaFunctionDescriptor>>
RAY_FUNC_CACHE = ThreadLocal.withInitial(WeakHashMap::new);
/** Mapping from the job id to the functions that belong to this job. */
private ConcurrentMap<JobId, JobFunctionTable> jobFunctionTables = new ConcurrentHashMap<>();
/** The table that manages functions. */
private final JobFunctionTable jobFunctionTable;
/** The resource path which we can load the job's jar resources. */
private final List<String> codeSearchPath;
@ -63,16 +62,20 @@ public class FunctionManager {
*/
public FunctionManager(List<String> codeSearchPath) {
this.codeSearchPath = codeSearchPath;
jobFunctionTable = createJobFunctionTable();
}
public ClassLoader getClassLoader() {
return jobFunctionTable.classLoader;
}
/**
* Get the RayFunction from a RayFunc instance (a lambda).
*
* @param jobId current job id.
* @param func The lambda.
* @return A RayFunction object.
*/
public RayFunction getFunction(JobId jobId, RayFunc func) {
public RayFunction getFunction(RayFunc func) {
JavaFunctionDescriptor functionDescriptor = RAY_FUNC_CACHE.get().get(func.getClass());
if (functionDescriptor == null) {
// It's OK to not lock here, because it's OK to have multiple JavaFunctionDescriptor instances
@ -84,31 +87,21 @@ public class FunctionManager {
functionDescriptor = new JavaFunctionDescriptor(className, methodName, signature);
RAY_FUNC_CACHE.get().put(func.getClass(), functionDescriptor);
}
return getFunction(jobId, functionDescriptor);
return getFunction(functionDescriptor);
}
/**
* Get the RayFunction from a function descriptor.
*
* @param jobId Current job id.
* @param functionDescriptor The function descriptor.
* @return A RayFunction object.
*/
public RayFunction getFunction(JobId jobId, JavaFunctionDescriptor functionDescriptor) {
JobFunctionTable jobFunctionTable = jobFunctionTables.get(jobId);
if (jobFunctionTable == null) {
synchronized (this) {
jobFunctionTable = jobFunctionTables.get(jobId);
if (jobFunctionTable == null) {
jobFunctionTable = createJobFunctionTable(jobId);
jobFunctionTables.put(jobId, jobFunctionTable);
}
}
}
public RayFunction getFunction(JavaFunctionDescriptor functionDescriptor) {
return jobFunctionTable.getFunction(functionDescriptor);
}
private JobFunctionTable createJobFunctionTable(JobId jobId) {
/** A helper that creates function table. */
private JobFunctionTable createJobFunctionTable() {
ClassLoader classLoader;
if (codeSearchPath == null || codeSearchPath.isEmpty()) {
classLoader = getClass().getClassLoader();
@ -145,7 +138,7 @@ public class FunctionManager {
})
.toArray(URL[]::new);
classLoader = new URLClassLoader(urls);
LOGGER.debug("Resource loaded for job {} from path {}.", jobId, urls);
LOGGER.debug("Resource loaded from path {}.", urls);
}
return new JobFunctionTable(classLoader);

View file

@ -496,7 +496,6 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
? objectStore.getRaw(Collections.singletonList(arg.id), -1).get(0)
: arg.value)
.collect(Collectors.toList());
runtime.setIsContextSet(true);
((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(taskSpec);
((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentWorkerId(workerId);
@ -513,9 +512,7 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
// Set this flag to true is necessary because at the end of `taskExecutor.execute()`,
// this flag will be set to false. And `runtime.getWorkerContext()` requires it to be
// true.
runtime.setIsContextSet(true);
((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(null);
runtime.setIsContextSet(false);
List<ObjectId> returnIds = getReturnIds(taskSpec);
for (int i = 0; i < returnIds.size(); i++) {
NativeRayObject putObject;

View file

@ -32,6 +32,7 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
protected final RayRuntimeInternal runtime;
// TODO(qwang): Use actorContext instead later.
private final ConcurrentHashMap<UniqueId, T> actorContextMap = new ConcurrentHashMap<>();
private final ThreadLocal<RayFunction> localRayFunction = new ThreadLocal<>();
@ -66,7 +67,7 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
private RayFunction getRayFunction(List<String> rayFunctionInfo) {
JobId jobId = runtime.getWorkerContext().getCurrentJobId();
JavaFunctionDescriptor functionDescriptor = parseFunctionDescriptor(rayFunctionInfo);
return runtime.getFunctionManager().getFunction(jobId, functionDescriptor);
return runtime.getFunctionManager().getFunction(functionDescriptor);
}
/** The return value indicates which parameters are ByteBuffer. */
@ -93,7 +94,6 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
}
protected List<NativeRayObject> execute(List<String> rayFunctionInfo, List<Object> argsBytes) {
runtime.setIsContextSet(true);
TaskType taskType = runtime.getWorkerContext().getCurrentTaskType();
TaskId taskId = runtime.getWorkerContext().getCurrentTaskId();
LOGGER.debug("Executing task {} {}", taskId, rayFunctionInfo);
@ -108,7 +108,6 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
}
List<NativeRayObject> returnObjects = new ArrayList<>();
ClassLoader oldLoader = Thread.currentThread().getContextClassLoader();
// Find the executable object.
RayFunction rayFunction = localRayFunction.get();
@ -121,7 +120,6 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
rayFunction = getRayFunction(rayFunctionInfo);
}
Thread.currentThread().setContextClassLoader(rayFunction.classLoader);
runtime.getWorkerContext().setCurrentClassLoader(rayFunction.classLoader);
// Get local actor object and arguments.
Object actor = null;
@ -215,10 +213,6 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
} else {
throw new RayActorException(e);
}
} finally {
Thread.currentThread().setContextClassLoader(oldLoader);
runtime.getWorkerContext().setCurrentClassLoader(null);
runtime.setIsContextSet(false);
}
return returnObjects;
}

View file

@ -47,7 +47,7 @@ public final class MethodUtils {
/// This code path indicates that here might be in another thread of a worker.
/// So try to load the class from URLClassLoader of this worker.
ClassLoader cl =
((RayRuntimeInternal) Ray.internal()).getWorkerContext().getCurrentClassLoader();
((RayRuntimeInternal) Ray.internal()).getFunctionManager().getClassLoader();
actorClz = Class.forName(className, true, cl);
}
} catch (Exception e) {

View file

@ -28,9 +28,7 @@ public class ParallelActorContextImpl implements ParallelActorContext {
FunctionManager functionManager = ((RayRuntimeInternal) Ray.internal()).getFunctionManager();
JavaFunctionDescriptor functionDescriptor =
functionManager
.getFunction(Ray.getRuntimeContext().getCurrentJobId(), ctorFunc)
.getFunctionDescriptor();
functionManager.getFunction(ctorFunc).getFunctionDescriptor();
ActorHandle<ParallelActorExecutorImpl> parallelExecutorHandle =
Ray.actor(ParallelActorExecutorImpl::new, parallelism, functionDescriptor)
.setConcurrencyGroups(concurrencyGroups)
@ -46,9 +44,7 @@ public class ParallelActorContextImpl implements ParallelActorContext {
((ParallelActorHandleImpl) parallelActorHandle).getExecutor();
FunctionManager functionManager = ((RayRuntimeInternal) Ray.internal()).getFunctionManager();
JavaFunctionDescriptor functionDescriptor =
functionManager
.getFunction(Ray.getRuntimeContext().getCurrentJobId(), func)
.getFunctionDescriptor();
functionManager.getFunction(func).getFunctionDescriptor();
ObjectRef<Object> ret =
parallelExecutor
.task(ParallelActorExecutorImpl::execute, instanceId, functionDescriptor, args)

View file

@ -23,9 +23,7 @@ public class ParallelActorExecutorImpl {
throws InvocationTargetException, IllegalAccessException {
functionManager = ((RayRuntimeInternal) Ray.internal()).getFunctionManager();
RayFunction init =
functionManager.getFunction(
Ray.getRuntimeContext().getCurrentJobId(), javaFunctionDescriptor);
RayFunction init = functionManager.getFunction(javaFunctionDescriptor);
Thread.currentThread().setContextClassLoader(init.classLoader);
for (int i = 0; i < parallelism; ++i) {
Object instance = init.getMethod().invoke(null, null);
@ -35,8 +33,7 @@ public class ParallelActorExecutorImpl {
public Object execute(int instanceId, JavaFunctionDescriptor functionDescriptor, Object[] args)
throws IllegalAccessException, InvocationTargetException {
RayFunction func =
functionManager.getFunction(Ray.getRuntimeContext().getCurrentJobId(), functionDescriptor);
RayFunction func = functionManager.getFunction(functionDescriptor);
Preconditions.checkState(instances.containsKey(instanceId));
return func.getMethod().invoke(instances.get(instanceId), args);
}

View file

@ -27,8 +27,6 @@ ray {
// search path for user code. This will be used as `CLASSPATH` in Java,
// and `PYTHONPATH` in Python.
code-search-path: ""
/// The number of java worker per worker process.
num-java-workers-per-process: 1
/// The jvm options for java workers of the job.
jvm-options: []

View file

@ -61,6 +61,8 @@ public class FunctionManagerTest {
}
}
private static final JobId JOB_ID = JobId.fromInt(1);
private static RayFunc0<Object> fooFunc;
private static RayFunc1<ChildClass, Object> childClassBarFunc;
private static RayFunc0<ChildClass> childClassConstructor;
@ -95,17 +97,17 @@ public class FunctionManagerTest {
public void testGetFunctionFromRayFunc() {
final FunctionManager functionManager = new FunctionManager(null);
// Test normal function.
RayFunction func = functionManager.getFunction(JobId.NIL, fooFunc);
RayFunction func = functionManager.getFunction(fooFunc);
Assert.assertFalse(func.isConstructor());
Assert.assertEquals(func.getFunctionDescriptor(), fooDescriptor);
// Test actor method
func = functionManager.getFunction(JobId.NIL, childClassBarFunc);
func = functionManager.getFunction(childClassBarFunc);
Assert.assertFalse(func.isConstructor());
Assert.assertEquals(func.getFunctionDescriptor(), childClassBarDescriptor);
// Test actor constructor
func = functionManager.getFunction(JobId.NIL, childClassConstructor);
func = functionManager.getFunction(childClassConstructor);
Assert.assertTrue(func.isConstructor());
Assert.assertEquals(func.getFunctionDescriptor(), childClassConstructorDescriptor);
}
@ -114,17 +116,17 @@ public class FunctionManagerTest {
public void testGetFunctionFromFunctionDescriptor() {
final FunctionManager functionManager = new FunctionManager(null);
// Test normal function.
RayFunction func = functionManager.getFunction(JobId.NIL, fooDescriptor);
RayFunction func = functionManager.getFunction(fooDescriptor);
Assert.assertFalse(func.isConstructor());
Assert.assertEquals(func.getFunctionDescriptor(), fooDescriptor);
// Test actor method
func = functionManager.getFunction(JobId.NIL, childClassBarDescriptor);
func = functionManager.getFunction(childClassBarDescriptor);
Assert.assertFalse(func.isConstructor());
Assert.assertEquals(func.getFunctionDescriptor(), childClassBarDescriptor);
// Test actor constructor
func = functionManager.getFunction(JobId.NIL, childClassConstructorDescriptor);
func = functionManager.getFunction(childClassConstructorDescriptor);
Assert.assertTrue(func.isConstructor());
Assert.assertEquals(func.getFunctionDescriptor(), childClassConstructorDescriptor);
@ -133,7 +135,6 @@ public class FunctionManagerTest {
RuntimeException.class,
() -> {
functionManager.getFunction(
JobId.NIL,
new JavaFunctionDescriptor(
FunctionManagerTest.class.getName(), "overloadFunction", ""));
});
@ -146,11 +147,10 @@ public class FunctionManagerTest {
fooDescriptor =
new JavaFunctionDescriptor(ParentClass.class.getName(), "foo", "()Ljava/lang/Object;");
Assert.assertEquals(
functionManager.getFunction(JobId.NIL, fooDescriptor).executable.getDeclaringClass(),
functionManager.getFunction(fooDescriptor).executable.getDeclaringClass(),
ParentClass.class);
RayFunction fooFunc =
functionManager.getFunction(
JobId.NIL,
new JavaFunctionDescriptor(ChildClass.class.getName(), "foo", "()Ljava/lang/Object;"));
Assert.assertEquals(fooFunc.executable.getDeclaringClass(), ParentClass.class);
@ -159,21 +159,16 @@ public class FunctionManagerTest {
childClassBarDescriptor =
new JavaFunctionDescriptor(ParentClass.class.getName(), "bar", "()Ljava/lang/Object;");
Assert.assertEquals(
functionManager
.getFunction(JobId.NIL, childClassBarDescriptor)
.executable
.getDeclaringClass(),
functionManager.getFunction(childClassBarDescriptor).executable.getDeclaringClass(),
ParentClass.class);
RayFunction barFunc =
functionManager.getFunction(
JobId.NIL,
new JavaFunctionDescriptor(ChildClass.class.getName(), "bar", "()Ljava/lang/Object;"));
Assert.assertEquals(barFunc.executable.getDeclaringClass(), ChildClass.class);
// Check interface default methods.
RayFunction interfaceNameFunc =
functionManager.getFunction(
JobId.NIL,
new JavaFunctionDescriptor(
ChildClass.class.getName(), "interfaceName", "()Ljava/lang/String;"));
Assert.assertEquals(
@ -220,7 +215,6 @@ public class FunctionManagerTest {
@Test
public void testGetFunctionFromLocalResource() throws Exception {
JobId jobId = JobId.fromInt(1);
final String codeSearchPath = FileUtils.getTempDirectoryPath() + "/ray_test_resources/";
File jobResourceDir = new File(codeSearchPath);
FileUtils.deleteQuietly(jobResourceDir);
@ -250,7 +244,7 @@ public class FunctionManagerTest {
new JavaFunctionDescriptor("DemoApp", "hello", "()Ljava/lang/String;");
final FunctionManager functionManager =
new FunctionManager(Collections.singletonList(codeSearchPath));
RayFunction func = functionManager.getFunction(jobId, descriptor);
RayFunction func = functionManager.getFunction(descriptor);
Assert.assertEquals(func.getFunctionDescriptor(), descriptor);
}
}

View file

@ -1,7 +1,6 @@
package io.ray.serve;
import com.google.common.collect.ImmutableMap;
import io.ray.api.Ray;
import io.ray.runtime.metric.Count;
import io.ray.runtime.metric.Metrics;
import io.ray.runtime.metric.TagKey;
@ -47,8 +46,6 @@ public class HttpProxy implements ServeProxy {
private ProxyRouter proxyRouter;
private Object asyncContext = Ray.getAsyncContext();
@Override
public void init(Map<String, String> config, ProxyRouter proxyRouter) {
this.port =
@ -101,8 +98,6 @@ public class HttpProxy implements ServeProxy {
ClassicHttpRequest request, ClassicHttpResponse response, HttpContext context)
throws HttpException, IOException {
Ray.setAsyncContext(asyncContext);
int code = HttpURLConnection.HTTP_OK;
Object result = null;
String route = request.getPath();

View file

@ -122,8 +122,7 @@ for file in "$docdemo_path"*.java; do
file=${file#"$docdemo_path"}
class=${file%".java"}
echo "Running $class"
java -cp bazel-bin/java/all_tests_shaded.jar -Dray.job.num-java-workers-per-process=1\
-Dray.raylet.startup-token=0 "io.ray.docdemo.$class"
java -cp bazel-bin/java/all_tests_shaded.jar -Dray.raylet.startup-token=0 "io.ray.docdemo.$class"
done
popd

View file

@ -4,13 +4,11 @@ import io.ray.api.ActorHandle;
import io.ray.api.ObjectRef;
import io.ray.api.Ray;
import io.ray.runtime.actor.NativeActorHandle;
import io.ray.runtime.exception.RayActorException;
import io.ray.runtime.util.SystemUtil;
import java.lang.ref.Reference;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.testng.Assert;
@ -65,7 +63,6 @@ public class ActorHandleReferenceCountTest {
public void testActorHandleReferenceCount() {
try {
System.setProperty("ray.job.num-java-workers-per-process", "1");
Ray.init();
ActorHandle<SignalActor> signal = Ray.actor(SignalActor::new).remote();
ActorHandle<MyActor> myActor = Ray.actor(MyActor::new).remote();
@ -83,27 +80,4 @@ public class ActorHandleReferenceCountTest {
Ray.shutdown();
}
}
/**
 * Checks that releasing the handle of one actor does not take down a second actor that reports
 * the same pid.
 *
 * <p>The test sets {@code ray.job.num-java-workers-per-process} to 5, creates two {@code MyActor}
 * instances, asserts that both report the same pid, passes the first handle to {@code del},
 * waits five seconds, and then expects a call on the first actor to fail with {@link
 * RayActorException} while the second actor still answers {@code "hello"}.
 *
 * @throws InterruptedException if the five-second wait is interrupted
 */
public void testRemoveActorHandleReferenceInMultipleThreadedActor() throws InterruptedException {
// Must be set before Ray.init() below.
System.setProperty("ray.job.num-java-workers-per-process", "5");
try {
Ray.init();
ActorHandle<MyActor> myActor1 = Ray.actor(MyActor::new).remote();
int pid1 = myActor1.task(MyActor::getPid).remote().get();
ActorHandle<MyActor> myActor2 = Ray.actor(MyActor::new).remote();
int pid2 = myActor2.task(MyActor::getPid).remote().get();
// Precondition of the scenario: both actors report the same pid.
Assert.assertEquals(pid1, pid2);
// NOTE(review): `del` is defined elsewhere in this class; presumably it drops the local
// reference held by the handle -- confirm against its definition.
del(myActor1);
// Fixed wait before checking that the first actor is no longer reachable.
TimeUnit.SECONDS.sleep(5);
Assert.assertThrows(
RayActorException.class,
() -> {
myActor1.task(MyActor::hello).remote().get();
});
/// myActor2 shouldn't be killed.
Assert.assertEquals("hello", myActor2.task(MyActor::hello).remote().get());
} finally {
Ray.shutdown();
}
}
}

View file

@ -1,167 +0,0 @@
package io.ray.test;
import io.ray.api.ActorHandle;
import io.ray.api.BaseActorHandle;
import io.ray.api.ObjectRef;
import io.ray.api.options.ActorCreationOptions;
import io.ray.api.options.CallOptions;
import io.ray.runtime.AbstractRayRuntime;
import io.ray.runtime.functionmanager.FunctionDescriptor;
import io.ray.runtime.functionmanager.JavaFunctionDescriptor;
import java.io.File;
import java.lang.reflect.Method;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Optional;
import javax.tools.JavaCompiler;
import javax.tools.ToolProvider;
import org.apache.commons.io.FileUtils;
import org.testng.Assert;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;
/**
 * Cluster test that loads an actor class from the job code search path and checks how it is shared
 * between two actors that report the same pid.
 *
 * <p>The test writes a small Java source file into the code search path, compiles it with the
 * system Java compiler, and then creates and calls the actors through reflective access to {@code
 * AbstractRayRuntime#createActorImpl} and {@code AbstractRayRuntime#callActorFunction}, because
 * the compiled class is not on the test's own class path.
 */
public class ClassLoaderTest extends BaseTest {

  /** Directory used as the job code search path; the generated class file is written here. */
  private final String codeSearchPath =
      FileUtils.getTempDirectoryPath() + "/ray_test/ClassLoaderTest";

  /** Points {@code ray.job.code-search-path} at {@link #codeSearchPath}. */
  @BeforeClass
  public void setUp() {
    // The potential issue of multiple `ClassLoader` instances for the same job on multi-threading
    // scenario only occurs if the classes are loaded from the job code search path.
    System.setProperty("ray.job.code-search-path", codeSearchPath);
  }

  /**
   * Creates two {@code ClassLoaderTester} actors that report the same pid and asserts that
   *
   * <ul>
   *   <li>both report the same class loader hash code, and
   *   <li>two calls to {@code increase}, which updates a static counter under a static
   *       synchronized method, return different values.
   * </ul>
   *
   * @throws Exception if writing or compiling the source fails, or a reflective call fails
   */
  @Test(groups = {"cluster"})
  public void testClassLoaderInMultiThreading() throws Exception {
    File jobResourceDir = new File(codeSearchPath);
    FileUtils.deleteQuietly(jobResourceDir);
    jobResourceDir.mkdirs();
    jobResourceDir.deleteOnExit();

    // In this test case the class is expected to be loaded from the job code search path,
    // so we need to put the compiled class file into the job code search path and load it
    // later.
    String testJavaFile =
        ""
            + "import java.lang.management.ManagementFactory;\n"
            + "import java.lang.management.RuntimeMXBean;\n"
            + "\n"
            + "public class ClassLoaderTester {\n"
            + "\n"
            + " static volatile int value;\n"
            + "\n"
            + " public int getPid() {\n"
            + " RuntimeMXBean runtime = ManagementFactory.getRuntimeMXBean();\n"
            + " String name = runtime.getName();\n"
            + " int index = name.indexOf(\"@\");\n"
            + " if (index != -1) {\n"
            + " return Integer.parseInt(name.substring(0, index));\n"
            + " } else {\n"
            + " throw new RuntimeException(\"parse pid error:\" + name);\n"
            + " }\n"
            + " }\n"
            + "\n"
            + " public int increase() throws InterruptedException {\n"
            + " return increaseInternal();\n"
            + " }\n"
            + "\n"
            + " public static synchronized int increaseInternal() throws InterruptedException {\n"
            + " int oldValue = value;\n"
            + " Thread.sleep(10 * 1000);\n"
            + " value = oldValue + 1;\n"
            + " return value;\n"
            + " }\n"
            + "\n"
            + " public int getClassLoaderHashCode() {\n"
            + " return this.getClass().getClassLoader().hashCode();\n"
            + " }\n"
            + "}";

    // Write the demo java file to the job code search path. The charset is spelled out so the
    // bytes on disk do not depend on the platform default.
    String javaFilePath = codeSearchPath + "/ClassLoaderTester.java";
    Files.write(
        Paths.get(javaFilePath), testJavaFile.getBytes(java.nio.charset.StandardCharsets.UTF_8));

    // Compile the java file. ToolProvider returns null when the runtime ships no compiler
    // (for example a JRE-only installation); report that explicitly instead of failing with a
    // bare NullPointerException on compiler.run(...).
    JavaCompiler compiler = ToolProvider.getSystemJavaCompiler();
    if (compiler == null) {
      throw new RuntimeException(
          "No system Java compiler is available; this test has to run on a JDK.");
    }
    int result = compiler.run(null, null, null, "-d", codeSearchPath, javaFilePath);
    if (result != 0) {
      throw new RuntimeException("Couldn't compile ClassLoaderTester.java.");
    }

    FunctionDescriptor constructor =
        new JavaFunctionDescriptor("ClassLoaderTester", "<init>", "()V");
    ActorHandle<?> actor1 = createActor(constructor);
    FunctionDescriptor getPid = new JavaFunctionDescriptor("ClassLoaderTester", "getPid", "()I");
    int pid =
        this.<Integer>callActorFunction(actor1, getPid, new Object[0], Optional.of(Integer.class))
            .get();
    ActorHandle<?> actor2;
    // NOTE(review): this loop has no upper bound; it keeps creating actors until one reports the
    // same pid as actor1.
    while (true) {
      // Create another actor that shares the process of actor 1.
      actor2 = createActor(constructor);
      int actor2Pid =
          this.<Integer>callActorFunction(actor2, getPid, new Object[0], Optional.of(Integer.class))
              .get();
      if (actor2Pid == pid) {
        break;
      }
    }

    // Both actors must observe the same class loader for ClassLoaderTester.
    FunctionDescriptor getClassLoaderHashCode =
        new JavaFunctionDescriptor("ClassLoaderTester", "getClassLoaderHashCode", "()I");
    ObjectRef<Integer> hashCode1 =
        callActorFunction(
            actor1, getClassLoaderHashCode, new Object[0], Optional.of(Integer.class));
    ObjectRef<Integer> hashCode2 =
        callActorFunction(
            actor2, getClassLoaderHashCode, new Object[0], Optional.of(Integer.class));
    Assert.assertEquals(hashCode1.get(), hashCode2.get());

    // `increase` bumps a static counter inside a static synchronized method. Both calls are
    // submitted before either result is read; the results are expected to differ.
    FunctionDescriptor increase =
        new JavaFunctionDescriptor("ClassLoaderTester", "increase", "()I");
    ObjectRef<Integer> value1 =
        callActorFunction(actor1, increase, new Object[0], Optional.of(Integer.class));
    ObjectRef<Integer> value2 =
        callActorFunction(actor2, increase, new Object[0], Optional.of(Integer.class));
    Assert.assertNotEquals(value1.get(), value2.get());
  }

  /**
   * Creates an actor by reflectively invoking {@code AbstractRayRuntime#createActorImpl} on the
   * runtime returned by {@code TestUtils.getUnderlyingRuntime()}, with no constructor arguments
   * and {@code null} creation options.
   *
   * @param functionDescriptor descriptor of the actor constructor
   * @return the handle returned by {@code createActorImpl}
   * @throws Exception if the method cannot be looked up or the invocation fails
   */
  private ActorHandle<?> createActor(FunctionDescriptor functionDescriptor) throws Exception {
    Method createActorMethod =
        AbstractRayRuntime.class.getDeclaredMethod(
            "createActorImpl",
            FunctionDescriptor.class,
            Object[].class,
            ActorCreationOptions.class);
    createActorMethod.setAccessible(true);
    return (ActorHandle<?>)
        createActorMethod.invoke(
            TestUtils.getUnderlyingRuntime(), functionDescriptor, new Object[0], null);
  }

  /**
   * Calls an actor method by reflectively invoking {@code AbstractRayRuntime#callActorFunction}
   * with default {@link CallOptions}.
   *
   * @param rayActor handle of the actor to call
   * @param functionDescriptor descriptor of the method to call
   * @param args arguments passed to the method
   * @param returnType declared return type of the method
   * @param <T> type the caller expects the returned reference to hold
   * @return the object reference returned by {@code callActorFunction}
   * @throws Exception if the method cannot be looked up or the invocation fails
   */
  // The reflective call returns Object; the element type cannot be checked at runtime and is
  // chosen by the caller.
  @SuppressWarnings("unchecked")
  private <T> ObjectRef<T> callActorFunction(
      ActorHandle<?> rayActor,
      FunctionDescriptor functionDescriptor,
      Object[] args,
      Optional<Class<?>> returnType)
      throws Exception {
    Method callActorFunctionMethod =
        AbstractRayRuntime.class.getDeclaredMethod(
            "callActorFunction",
            BaseActorHandle.class,
            FunctionDescriptor.class,
            Object[].class,
            Optional.class,
            CallOptions.class);
    callActorFunctionMethod.setAccessible(true);
    return (ObjectRef<T>)
        callActorFunctionMethod.invoke(
            TestUtils.getUnderlyingRuntime(),
            rayActor,
            functionDescriptor,
            args,
            returnType,
            new CallOptions.Builder().build());
  }
}

View file

@ -55,7 +55,6 @@ public class DefaultActorLifetimeTest {
System.setProperty("ray.job.default-actor-lifetime", defaultActorLifetime.name());
}
try {
System.setProperty("ray.job.num-java-workers-per-process", "1");
Ray.init();
/// 1. create owner and invoke createChildActor.

View file

@ -74,38 +74,6 @@ public class ExitActorTest extends BaseTest {
Assert.assertThrows(RayActorException.class, obj::get);
}
/**
 * Checks that one actor exiting leaves a second actor with the same pid untouched.
 *
 * <p>Requires more than one worker per process. The test creates a restartable actor, keeps
 * creating non-restartable actors until one reports the same pid, makes the first actor exit, and
 * then asserts that the process is still alive and the second actor still answers from that pid.
 */
public void testExitActorInMultiWorker() {
Assert.assertTrue(TestUtils.getNumWorkersPerProcess() > 1);
ActorHandle<ExitingActor> actor1 =
Ray.actor(ExitingActor::new).setMaxRestarts(ActorCreationOptions.INFINITE_RESTART).remote();
int pid = actor1.task(ExitingActor::getPid).remote().get();
// With only actor1 created, it reports an actor context map of size 1.
Assert.assertEquals(
1, (int) actor1.task(ExitingActor::getSizeOfActorContextMap).remote().get());
ActorHandle<ExitingActor> actor2;
// NOTE(review): unbounded loop; it only ends once a new actor reports the same pid as actor1.
while (true) {
// Create another actor that shares the process of actor 1.
actor2 =
Ray.actor(ExitingActor::new).setMaxRestarts(ActorCreationOptions.NO_RESTART).remote();
int actor2Pid = actor2.task(ExitingActor::getPid).remote().get();
if (actor2Pid == pid) {
break;
}
}
// Both actors now report a context map of size 2.
Assert.assertEquals(
2, (int) actor1.task(ExitingActor::getSizeOfActorContextMap).remote().get());
Assert.assertEquals(
2, (int) actor2.task(ExitingActor::getSizeOfActorContextMap).remote().get());
// Making actor1 exit fails its pending call but must not end the process.
ObjectRef<Boolean> obj1 = actor1.task(ExitingActor::exit).remote();
Assert.assertThrows(RayActorException.class, obj1::get);
Assert.assertTrue(SystemUtil.isProcessAlive(pid));
// Actor 2 shouldn't exit or be reconstructed.
// NOTE(review): presumably `incr` bumps a per-actor counter, so a result of 1 means this is its
// first call on the original instance -- confirm against ExitingActor.
Assert.assertEquals(1, (int) actor2.task(ExitingActor::incr).remote().get());
Assert.assertEquals(
1, (int) actor2.task(ExitingActor::getSizeOfActorContextMap).remote().get());
Assert.assertEquals(pid, (int) actor2.task(ExitingActor::getPid).remote().get());
Assert.assertTrue(SystemUtil.isProcessAlive(pid));
}
public void testExitActorWithDynamicOptions() {
ActorHandle<ExitingActor> actor =
Ray.actor(ExitingActor::new)

View file

@ -15,7 +15,6 @@ public class ExitActorTest2 extends BaseTest {
@BeforeClass
public void setUp() {
System.setProperty("ray.job.num-java-workers-per-process", "1");
System.setProperty("ray.raylet.startup-token", "0");
}

View file

@ -23,10 +23,6 @@ public class FailureTest extends BaseTest {
@BeforeClass
public void setUp() {
// This is needed by `testGetThrowsQuicklyWhenFoundException`.
// Set one worker per process. Otherwise, if `badFunc2` and `slowFunc` run in the same
// process, `sleep` will delay `System.exit`.
System.setProperty("ray.job.num-java-workers-per-process", "1");
System.setProperty("ray.raylet.startup-token", "0");
}

View file

@ -11,7 +11,6 @@ public class JobConfigTest extends BaseTest {
@BeforeClass
public void setupJobConfig() {
System.setProperty("ray.job.num-java-workers-per-process", "3");
System.setProperty("ray.raylet.startup-token", "0");
System.setProperty("ray.job.jvm-options.0", "-DX=999");
System.setProperty("ray.job.jvm-options.1", "-DY=998");
@ -33,10 +32,6 @@ public class JobConfigTest extends BaseTest {
Assert.assertEquals("998", Ray.task(JobConfigTest::getJvmOptions, "Y").remote().get());
}
/**
 * Asserts that {@code TestUtils.getNumWorkersPerProcess()} reports 3, the value that {@code
 * setupJobConfig} assigns to {@code ray.job.num-java-workers-per-process}.
 */
public void testNumJavaWorkersPerProcess() {
Assert.assertEquals(TestUtils.getNumWorkersPerProcess(), 3);
}
public void testInActor() {
ActorHandle<MyActor> actor = Ray.actor(MyActor::new).remote();

View file

@ -15,7 +15,6 @@ public class KillActorTest extends BaseTest {
@BeforeClass
public void setUp() {
System.setProperty("ray.job.num-java-workers-per-process", "1");
System.setProperty("ray.raylet.startup-token", "0");
}

View file

@ -51,14 +51,13 @@ public class MultiThreadingTest extends BaseTest {
final Object[] result = new Object[1];
Thread thread =
new Thread(
Ray.wrapRunnable(
() -> {
try {
result[0] = Ray.getRuntimeContext().getCurrentActorId();
} catch (Exception e) {
result[0] = e;
}
}));
() -> {
try {
result[0] = Ray.getRuntimeContext().getCurrentActorId();
} catch (Exception e) {
result[0] = e;
}
});
thread.start();
thread.join();
if (result[0] instanceof Exception) {
@ -140,6 +139,8 @@ public class MultiThreadingTest extends BaseTest {
Assert.assertEquals("ok", obj.get());
}
/// SINGLE_PROCESS mode doesn't support this API.
@Test(groups = {"cluster"})
public void testGetCurrentActorId() {
ActorHandle<ActorIdTester> actorIdTester = Ray.actor(ActorIdTester::new).remote();
ActorId actorId = actorIdTester.task(ActorIdTester::getCurrentActorId).remote().get();
@ -162,104 +163,6 @@ public class MultiThreadingTest extends BaseTest {
};
}
/**
 * Runs the runnables produced by {@code generateRunnables()} in four settings and checks where
 * they are expected to succeed or fail:
 *
 * <ol>
 *   <li>directly in the calling thread: must succeed;
 *   <li>in a new thread whose body is wrapped with {@code Ray.wrapRunnable}: must succeed;
 *   <li>in a new thread whose body is not wrapped: every runnable must throw a {@link
 *       RuntimeException};
 *   <li>individually wrapped and run in the calling thread, followed by one more wrapped run and
 *       one unwrapped run: must succeed.
 * </ol>
 *
 * @return always {@code true}, so that a remote invocation yields an object reference
 * @throws InterruptedException if joining one of the helper threads is interrupted
 */
static boolean testMissingWrapRunnable() throws InterruptedException {
  // 1. It's OK to run them in main thread.
  final Runnable[] callerThreadJobs = generateRunnables();
  for (int i = 0; i < callerThreadJobs.length; i++) {
    callerThreadJobs[i].run();
  }

  // Shared slot for a failure raised inside either helper thread below.
  final Throwable[] failure = new Throwable[1];

  // 2. It would be OK to run them in another thread if wrapped the runnable.
  final Runnable[] wrappedThreadJobs = generateRunnables();
  Runnable wrappedBody =
      Ray.wrapRunnable(
          () -> {
            try {
              for (int i = 0; i < wrappedThreadJobs.length; i++) {
                wrappedThreadJobs[i].run();
              }
            } catch (Throwable ex) {
              failure[0] = ex;
            }
          });
  Thread wrappedThread = new Thread(wrappedBody);
  wrappedThread.start();
  wrappedThread.join();
  if (failure[0] != null) {
    throw new RuntimeException("Exception occurred in thread.", failure[0]);
  }

  // 3. It wouldn't be OK to run them in another thread if not wrapped the runnable.
  final Runnable[] plainThreadJobs = generateRunnables();
  Runnable plainBody =
      () -> {
        try {
          for (int i = 0; i < plainThreadJobs.length; i++) {
            Runnable job = plainThreadJobs[i];
            Assert.expectThrows(RuntimeException.class, job::run);
          }
        } catch (Throwable ex) {
          failure[0] = ex;
        }
      };
  Thread plainThread = new Thread(plainBody);
  plainThread.start();
  plainThread.join();
  if (failure[0] != null) {
    throw new RuntimeException("Exception occurred in thread.", failure[0]);
  }

  // 4. Wrap each runnable on its own and run everything in the current thread.
  final Runnable[] rawJobs = generateRunnables();
  final Runnable[] wrappedJobs = new Runnable[rawJobs.length];
  for (int i = 0; i < rawJobs.length; i++) {
    wrappedJobs[i] = Ray.wrapRunnable(rawJobs[i]);
  }
  // It would be OK to run the wrapped runnables in the current thread.
  for (int i = 0; i < wrappedJobs.length; i++) {
    wrappedJobs[i].run();
  }
  // It would be OK to invoke Ray APIs after executing a wrapped runnable in the current thread.
  wrappedJobs[0].run();
  rawJobs[0].run();

  // Return true here to make the caller.remote() returns an ObjectRef.
  return true;
}
/**
 * Submits {@code testMissingWrapRunnable} as a Ray task and blocks on its result, so the checks
 * run inside the task rather than in the calling thread.
 */
public void testMissingWrapRunnableInWorker() {
Ray.task(MultiThreadingTest::testMissingWrapRunnable).remote().get();
}
/**
 * Captures the async context in the calling thread, installs it in a freshly started thread via
 * {@code Ray.setAsyncContext}, and calls {@code Ray.put} there. Anything thrown inside that
 * thread is rethrown to the caller wrapped in a {@link RuntimeException}.
 *
 * @throws InterruptedException if joining the helper thread is interrupted
 */
public void testGetAndSetAsyncContext() throws InterruptedException {
  final Object capturedContext = Ray.getAsyncContext();
  // Single-element array so the lambda can hand a failure back to this thread.
  final Throwable[] failure = new Throwable[1];
  Runnable body =
      () -> {
        try {
          Ray.setAsyncContext(capturedContext);
          Ray.put(1);
        } catch (Throwable ex) {
          failure[0] = ex;
        }
      };
  Thread helper = new Thread(body);
  helper.start();
  helper.join();
  if (failure[0] == null) {
    return;
  }
  throw new RuntimeException("Exception occurred in thread.", failure[0]);
}
private static void runTestCaseInMultipleThreads(Runnable testCase, int numRepeats) {
ExecutorService service = Executors.newFixedThreadPool(NUM_THREADS);
@ -267,14 +170,13 @@ public class MultiThreadingTest extends BaseTest {
List<Future<String>> futures = new ArrayList<>();
for (int i = 0; i < NUM_THREADS; i++) {
Callable<String> task =
Ray.wrapCallable(
() -> {
for (int j = 0; j < numRepeats; j++) {
TimeUnit.MILLISECONDS.sleep(1);
testCase.run();
}
return "ok";
});
() -> {
for (int j = 0; j < numRepeats; j++) {
TimeUnit.MILLISECONDS.sleep(1);
testCase.run();
}
return "ok";
};
futures.add(service.submit(task));
}
for (Future<String> future : futures) {
@ -288,35 +190,4 @@ public class MultiThreadingTest extends BaseTest {
service.shutdown();
}
}
/**
 * Captures the async context in the calling thread, installs it in a freshly started thread via
 * {@code Ray.setAsyncContext}, and calls {@code Ray.put} there.
 *
 * @return {@code true} when the helper thread finished without an exception
 * @throws Exception the exception caught inside the helper thread, rethrown as is
 */
private static boolean testGetAsyncContextAndSetAsyncContext() throws Exception {
  final Object capturedContext = Ray.getAsyncContext();
  // Single-element array so the lambda can hand a failure back to this thread.
  final Object[] outcome = new Object[1];
  Runnable body =
      () -> {
        try {
          Ray.setAsyncContext(capturedContext);
          Ray.put(0);
        } catch (Exception e) {
          outcome[0] = e;
        }
      };
  Thread helper = new Thread(body);
  helper.start();
  helper.join();
  Object recorded = outcome[0];
  if (!(recorded instanceof Exception)) {
    return true;
  }
  throw (Exception) recorded;
}
/**
 * Calls {@code testGetAsyncContextAndSetAsyncContext} directly, without going through a Ray
 * task, and asserts that it returns {@code true}.
 *
 * @throws Exception whatever the helper rethrows from its thread
 */
public void testGetAsyncContextAndSetAsyncContextInDriver() throws Exception {
Assert.assertTrue(testGetAsyncContextAndSetAsyncContext());
}
/**
 * Submits {@code testGetAsyncContextAndSetAsyncContext} as a Ray task and asserts that the
 * returned object resolves to {@code true}.
 */
public void testGetAsyncContextAndSetAsyncContextInWorker() {
ObjectRef<Boolean> obj =
Ray.task(MultiThreadingTest::testGetAsyncContextAndSetAsyncContext).remote();
Assert.assertTrue(obj.get());
}
}

View file

@ -4,7 +4,6 @@ import com.google.common.collect.ImmutableList;
import io.ray.api.ActorHandle;
import io.ray.api.Ray;
import io.ray.api.runtimeenv.RuntimeEnv;
import io.ray.runtime.util.SystemUtil;
import java.util.List;
import org.testng.Assert;
import org.testng.annotations.Test;
@ -30,10 +29,6 @@ public class RuntimeEnvTest {
return System.getenv(key);
}
/**
 * Returns the value of {@code SystemUtil.pid()}.
 *
 * <p>NOTE(review): presumably the pid of the process hosting this actor; the tests in this file
 * compare it across actors to tell whether they share a process -- confirm against SystemUtil.
 */
public int getPid() {
return SystemUtil.pid();
}
public boolean findClass(String className) {
try {
Class.forName(className);
@ -45,7 +40,6 @@ public class RuntimeEnvTest {
}
public void testPerJobEnvVars() {
System.setProperty("ray.job.num-java-workers-per-process", "1");
System.setProperty("ray.job.runtime-env.env-vars.KEY1", "A");
System.setProperty("ray.job.runtime-env.env-vars.KEY2", "B");
@ -61,87 +55,6 @@ public class RuntimeEnvTest {
}
}
/**
 * Checks that environment variables set through an actor's {@link RuntimeEnv} are visible in that
 * actor only, and that an actor with a runtime env does not report the same pid as one without.
 */
public void testPerActorEnvVars() {
/// This is used to test that actors with runtime envs will not reuse worker process.
System.setProperty("ray.job.num-java-workers-per-process", "2");
try {
Ray.init();
int pid1 = 0;
int pid2 = 0;
{
// KEY1 is added twice; the assertion below expects the later value "C" to win.
RuntimeEnv runtimeEnv =
new RuntimeEnv.Builder()
.addEnvVar("KEY1", "A")
.addEnvVar("KEY2", "B")
.addEnvVar("KEY1", "C")
.build();
ActorHandle<A> actor1 = Ray.actor(A::new).setRuntimeEnv(runtimeEnv).remote();
String val = actor1.task(A::getEnv, "KEY1").remote().get();
Assert.assertEquals(val, "C");
val = actor1.task(A::getEnv, "KEY2").remote().get();
Assert.assertEquals(val, "B");
pid1 = actor1.task(A::getPid).remote().get();
}
{
/// Because we didn't set them for actor2, all should be null.
ActorHandle<A> actor2 = Ray.actor(A::new).remote();
String val = actor2.task(A::getEnv, "KEY1").remote().get();
Assert.assertNull(val);
val = actor2.task(A::getEnv, "KEY2").remote().get();
Assert.assertNull(val);
pid2 = actor2.task(A::getPid).remote().get();
}
// actor1 and actor2 shouldn't be in one process because they have
// different runtime env.
Assert.assertNotEquals(pid1, pid2);
} finally {
Ray.shutdown();
}
}
/**
 * Checks how per-actor environment variables combine with per-job ones.
 *
 * <p>The job defines {@code KEY1=A} and {@code KEY2=B} through system properties. An actor whose
 * {@link RuntimeEnv} sets {@code KEY1=C} must see {@code C} for {@code KEY1} and still see the
 * job value {@code B} for {@code KEY2}; an actor without its own runtime env must see both job
 * values. The two actors must also report different pids.
 */
public void testPerActorEnvVarsOverwritePerJobEnvVars() {
System.setProperty("ray.job.num-java-workers-per-process", "2");
System.setProperty("ray.job.runtime-env.env-vars.KEY1", "A");
System.setProperty("ray.job.runtime-env.env-vars.KEY2", "B");
int pid1 = 0;
int pid2 = 0;
try {
Ray.init();
{
// Only KEY1 is overridden for actor1; KEY2 is expected to come from the job config.
RuntimeEnv runtimeEnv = new RuntimeEnv.Builder().addEnvVar("KEY1", "C").build();
ActorHandle<A> actor1 = Ray.actor(A::new).setRuntimeEnv(runtimeEnv).remote();
String val = actor1.task(A::getEnv, "KEY1").remote().get();
Assert.assertEquals(val, "C");
val = actor1.task(A::getEnv, "KEY2").remote().get();
Assert.assertEquals(val, "B");
pid1 = actor1.task(A::getPid).remote().get();
}
{
/// Because we didn't set them for actor2 explicitly, it should use the per job
/// runtime env.
ActorHandle<A> actor2 = Ray.actor(A::new).remote();
String val = actor2.task(A::getEnv, "KEY1").remote().get();
Assert.assertEquals(val, "A");
val = actor2.task(A::getEnv, "KEY2").remote().get();
Assert.assertEquals(val, "B");
pid2 = actor2.task(A::getPid).remote().get();
}
// actor1 and actor2 shouldn't be in one process because they have
// different runtime env.
Assert.assertNotEquals(pid1, pid2);
} finally {
Ray.shutdown();
}
}
/**
 * Looks up an environment variable of the current process.
 *
 * @param key name of the environment variable
 * @return the value from {@link System#getenv(String)}, or {@code null} if the variable is not
 *     defined
 */
private static String getEnvVar(String key) {
return System.getenv(key);
}

View file

@ -3,9 +3,7 @@ package io.ray.test;
import com.google.common.base.Preconditions;
import io.ray.api.ObjectRef;
import io.ray.api.Ray;
import io.ray.runtime.AbstractRayRuntime;
import io.ray.runtime.RayRuntimeInternal;
import io.ray.runtime.RayRuntimeProxy;
import io.ray.runtime.config.RayConfig;
import io.ray.runtime.config.RunMode;
import io.ray.runtime.task.ArgumentsBuilder;
@ -129,20 +127,7 @@ public class TestUtils {
}
public static RayRuntimeInternal getUnderlyingRuntime() {
if (Ray.internal() instanceof AbstractRayRuntime) {
return (RayRuntimeInternal) Ray.internal();
}
RayRuntimeProxy proxy =
(RayRuntimeProxy) (java.lang.reflect.Proxy.getInvocationHandler(Ray.internal()));
return proxy.getRuntimeObject();
}
/**
 * Reads {@code numWorkersPerProcess} from the Ray config of the runtime returned by {@code
 * TestUtils.getRuntime()}. {@code getNumWorkersPerProcess} submits this method as a Ray task, so
 * the value is read wherever that task executes.
 */
private static int getNumWorkersPerProcessRemoteFunction() {
return TestUtils.getRuntime().getRayConfig().numWorkersPerProcess;
}
public static int getNumWorkersPerProcess() {
return Ray.task(TestUtils::getNumWorkersPerProcessRemoteFunction).remote().get();
return (RayRuntimeInternal) Ray.internal();
}
public static ProcessBuilder buildDriver(Class<?> mainClass, String[] args) {

View file

@ -1,6 +0,0 @@
ray {
job {
# Enable multi-worker feature in Java test
num-java-workers-per-process: 10
}
}

View file

@ -1116,7 +1116,6 @@ cdef class CoreWorker:
options.unhandled_exception_handler = unhandled_exception_handler
options.get_lang_stack = get_py_stack
options.is_local_mode = local_mode
options.num_workers = 1
options.kill_main = kill_main_task
options.terminate_asyncio_thread = terminate_asyncio_thread
options.serialized_job_config = serialized_job_config

View file

@ -8,8 +8,6 @@ class JobConfig:
"""A class used to store the configurations of a job.
Attributes:
num_java_workers_per_process (int): The number of java workers per
worker process.
jvm_options (str[]): The jvm options for java workers of the job.
code_search_path (list): A list of directories or jar files that
specify the search path for user code. This will be used as
@ -22,7 +20,6 @@ class JobConfig:
def __init__(
self,
num_java_workers_per_process=1,
jvm_options=None,
code_search_path=None,
runtime_env=None,
@ -31,7 +28,6 @@ class JobConfig:
ray_namespace=None,
default_actor_lifetime="non_detached",
):
self.num_java_workers_per_process = num_java_workers_per_process
self.jvm_options = jvm_options or []
self.code_search_path = code_search_path or []
# It's difficult to find the error that caused by the
@ -108,7 +104,6 @@ class JobConfig:
pb.ray_namespace = str(uuid.uuid4())
else:
pb.ray_namespace = self.ray_namespace
pb.num_java_workers_per_process = self.num_java_workers_per_process
pb.jvm_options.extend(self.jvm_options)
pb.code_search_path.extend(self.code_search_path)
for k, v in self.metadata.items():
@ -147,9 +142,6 @@ class JobConfig:
Generates a JobConfig object from json.
"""
return cls(
num_java_workers_per_process=job_config_json.get(
"num_java_workers_per_process", 1
),
jvm_options=job_config_json.get("jvm_options", None),
code_search_path=job_config_json.get("code_search_path", None),
runtime_env=job_config_json.get("runtime_env", None),

View file

@ -545,19 +545,16 @@ def test_k8s_cpu(use_cgroups_v2: bool):
def test_sync_job_config(shutdown_only):
num_java_workers_per_process = 8
runtime_env = {"env_vars": {"key": "value"}}
ray.init(
job_config=ray.job_config.JobConfig(
num_java_workers_per_process=num_java_workers_per_process,
runtime_env=runtime_env,
)
)
# Check that the job config is synchronized at the driver side.
job_config = ray.worker.global_worker.core_worker.get_job_config()
assert job_config.num_java_workers_per_process == num_java_workers_per_process
job_runtime_env = RuntimeEnv.deserialize(
job_config.runtime_env_info.serialized_runtime_env
)
@ -571,7 +568,6 @@ def test_sync_job_config(shutdown_only):
# Check that the job config is synchronized at the worker side.
job_config = gcs_utils.JobConfig()
job_config.ParseFromString(ray.get(get_job_config.remote()))
assert job_config.num_java_workers_per_process == num_java_workers_per_process
job_runtime_env = RuntimeEnv.deserialize(
job_config.runtime_env_info.serialized_runtime_env
)

View file

@ -3084,15 +3084,6 @@ void CoreWorker::HandleKillActor(const rpc::KillActorRequest &request,
if (request.no_restart()) {
Disconnect();
}
if (options_.num_workers > 1) {
// TODO (kfstorm): Should we add some kind of check before sending the killing
// request?
RAY_LOG(ERROR)
<< "Killing an actor which is running in a worker process with multiple "
"workers will also kill other actors in this process. To avoid this, "
"please create the Java actor with some dynamic options to make it being "
"hosted in a dedicated worker process.";
}
// NOTE(hchen): Use `QuickExit()` to force-exit this process without doing cleanup.
// `exit()` will destruct static objects in an incorrect order, which will lead to
// core dumps.

View file

@ -76,7 +76,6 @@ struct CoreWorkerOptions {
get_lang_stack(nullptr),
kill_main(nullptr),
is_local_mode(false),
num_workers(0),
terminate_asyncio_thread(nullptr),
serialized_job_config(""),
metrics_agent_port(-1),
@ -148,8 +147,6 @@ struct CoreWorkerOptions {
std::function<bool()> kill_main;
/// Is local mode being used.
bool is_local_mode;
/// The number of workers to be started in the current process.
int num_workers;
/// The function to destroy asyncio event and loops.
std::function<void()> terminate_asyncio_thread;
/// Serialized representation of JobConfig.

View file

@ -82,10 +82,9 @@ thread_local std::weak_ptr<CoreWorker> CoreWorkerProcessImpl::thread_local_core_
CoreWorkerProcessImpl::CoreWorkerProcessImpl(const CoreWorkerOptions &options)
: options_(options),
global_worker_id_(
options.worker_type == WorkerType::DRIVER
? ComputeDriverIdFromJob(options_.job_id)
: (options_.num_workers == 1 ? WorkerID::FromRandom() : WorkerID::Nil())) {
global_worker_id_(options.worker_type == WorkerType::DRIVER
? ComputeDriverIdFromJob(options_.job_id)
: WorkerID::FromRandom()) {
if (options_.enable_logging) {
std::stringstream app_name;
app_name << LanguageString(options_.language) << "-core-"
@ -111,12 +110,6 @@ CoreWorkerProcessImpl::CoreWorkerProcessImpl(const CoreWorkerOptions &options)
<< "install_failure_signal_handler must be false because ray log is disabled.";
}
RAY_CHECK(options_.num_workers > 0);
if (options_.worker_type == WorkerType::DRIVER) {
// Driver process can only contain one worker.
RAY_CHECK(options_.num_workers == 1);
}
RAY_LOG(INFO) << "Constructing CoreWorkerProcess. pid: " << getpid();
// NOTE(kfstorm): any initialization depending on RayConfig must happen after this line.
@ -250,18 +243,14 @@ bool CoreWorkerProcessImpl::ShouldCreateGlobalWorkerOnConstruction() const {
// worker APIs before `CoreWorkerProcess::RunTaskExecutionLoop` is called. So we need
// to create the worker instance here. One example of invocations is
// https://github.com/ray-project/ray/blob/45ce40e5d44801193220d2c546be8de0feeef988/python/ray/worker.py#L1281.
return options_.num_workers == 1 && (options_.worker_type == WorkerType::DRIVER ||
options_.language == Language::PYTHON);
return (options_.worker_type == WorkerType::DRIVER ||
options_.language == Language::PYTHON);
}
std::shared_ptr<CoreWorker> CoreWorkerProcessImpl::GetWorker(
const WorkerID &worker_id) const {
absl::ReaderMutexLock lock(&mutex_);
auto it = workers_.find(worker_id);
if (it != workers_.end()) {
return it->second;
}
return nullptr;
return global_worker_;
}
std::shared_ptr<CoreWorker> CoreWorkerProcessImpl::GetGlobalWorker() {
@ -275,13 +264,9 @@ std::shared_ptr<CoreWorker> CoreWorkerProcessImpl::CreateWorker() {
global_worker_id_ != WorkerID::Nil() ? global_worker_id_ : WorkerID::FromRandom());
RAY_LOG(DEBUG) << "Worker " << worker->GetWorkerID() << " is created.";
absl::WriterMutexLock lock(&mutex_);
if (options_.num_workers == 1) {
global_worker_ = worker;
}
global_worker_ = worker;
thread_local_core_worker_ = worker;
workers_.emplace(worker->GetWorkerID(), worker);
RAY_CHECK(workers_.size() <= static_cast<size_t>(options_.num_workers));
return worker;
}
@ -293,10 +278,7 @@ void CoreWorkerProcessImpl::RemoveWorker(std::shared_ptr<CoreWorker> worker) {
RAY_CHECK(thread_local_core_worker_.lock() == worker);
}
thread_local_core_worker_.reset();
{
workers_.erase(worker->GetWorkerID());
RAY_LOG(INFO) << "Removed worker " << worker->GetWorkerID();
}
RAY_LOG(INFO) << "Removed worker " << worker->GetWorkerID();
if (global_worker_ == worker) {
global_worker_ = nullptr;
}
@ -304,31 +286,14 @@ void CoreWorkerProcessImpl::RemoveWorker(std::shared_ptr<CoreWorker> worker) {
void CoreWorkerProcessImpl::RunWorkerTaskExecutionLoop() {
RAY_CHECK(options_.worker_type == WorkerType::WORKER);
if (options_.num_workers == 1) {
// Run the task loop in the current thread only if the number of workers is 1.
auto worker = GetGlobalWorker();
if (!worker) {
worker = CreateWorker();
}
worker->RunTaskExecutionLoop();
RAY_LOG(INFO) << "Task execution loop terminated. Removing the global worker.";
RemoveWorker(worker);
} else {
std::vector<std::thread> worker_threads;
for (int i = 0; i < options_.num_workers; i++) {
worker_threads.emplace_back([this, i] {
SetThreadName("worker.task" + std::to_string(i));
auto worker = CreateWorker();
worker->RunTaskExecutionLoop();
RAY_LOG(INFO) << "Task execution loop terminated for a thread "
<< std::to_string(i) << ". Removing a worker.";
RemoveWorker(worker);
});
}
for (auto &thread : worker_threads) {
thread.join();
}
// Run the task loop in the current thread, since a process now hosts exactly one worker.
auto worker = GetGlobalWorker();
if (!worker) {
worker = CreateWorker();
}
worker->RunTaskExecutionLoop();
RAY_LOG(INFO) << "Task execution loop terminated. Removing the global worker.";
RemoveWorker(worker);
}
void CoreWorkerProcessImpl::ShutdownDriver() {
@ -342,42 +307,30 @@ void CoreWorkerProcessImpl::ShutdownDriver() {
}
CoreWorker &CoreWorkerProcessImpl::GetCoreWorkerForCurrentThread() {
if (options_.num_workers == 1) {
auto global_worker = GetGlobalWorker();
if (ShouldCreateGlobalWorkerOnConstruction() && !global_worker) {
// This could only happen when the worker has already been shutdown.
// In this case, we should exit without crashing.
// TODO (scv119): A better solution could be returning error code
// and handling it at language frontend.
if (options_.worker_type == WorkerType::DRIVER) {
RAY_LOG(ERROR)
<< "The global worker has already been shutdown. This happens when "
"the language frontend accesses the Ray's worker after it is "
"shutdown. The process will exit";
} else {
RAY_LOG(INFO) << "The global worker has already been shutdown. This happens when "
"the language frontend accesses the Ray's worker after it is "
"shutdown. The process will exit";
}
QuickExit();
auto global_worker = GetGlobalWorker();
if (ShouldCreateGlobalWorkerOnConstruction() && !global_worker) {
// This could only happen when the worker has already been shutdown.
// In this case, we should exit without crashing.
// TODO (scv119): A better solution could be returning error code
// and handling it at language frontend.
if (options_.worker_type == WorkerType::DRIVER) {
RAY_LOG(ERROR) << "The global worker has already been shutdown. This happens when "
"the language frontend accesses the Ray's worker after it is "
"shutdown. The process will exit";
} else {
RAY_LOG(INFO) << "The global worker has already been shutdown. This happens when "
"the language frontend accesses the Ray's worker after it is "
"shutdown. The process will exit";
}
RAY_CHECK(global_worker) << "global_worker_ must not be NULL";
return *global_worker;
QuickExit();
}
auto ptr = thread_local_core_worker_.lock();
RAY_CHECK(ptr != nullptr)
<< "The current thread is not bound with a core worker instance.";
return *ptr;
RAY_CHECK(global_worker) << "global_worker_ must not be NULL";
return *global_worker;
}
void CoreWorkerProcessImpl::SetThreadLocalWorkerById(const WorkerID &worker_id) {
if (options_.num_workers == 1) {
RAY_CHECK(GetGlobalWorker()->GetWorkerID() == worker_id);
return;
}
auto worker = GetWorker(worker_id);
RAY_CHECK(worker) << "Worker " << worker_id << " not found.";
thread_local_core_worker_ = GetWorker(worker_id);
RAY_CHECK(GetGlobalWorker()->GetWorkerID() == worker_id);
return;
}
} // namespace core

View file

@ -25,7 +25,6 @@ class CoreWorker;
/// CoreWorkerOptions options = {
/// WorkerType::DRIVER, // worker_type
/// ..., // other arguments
/// 1, // num_workers
/// };
/// CoreWorkerProcess::Initialize(options);
///
@ -36,7 +35,6 @@ class CoreWorker;
/// CoreWorkerOptions options = {
/// WorkerType::WORKER, // worker_type
/// ..., // other arguments
/// num_workers, // num_workers
/// };
/// CoreWorkerProcess::Initialize(options);
/// ... // Do other stuff
@ -52,14 +50,7 @@ class CoreWorker;
/// threads, remember to call `CoreWorkerProcess::SetCurrentThreadWorkerId(worker_id)`
/// once in the new thread before calling core worker APIs, to associate the current
/// thread with a worker. You can obtain the worker ID via
/// `CoreWorkerProcess::GetCoreWorker()->GetWorkerID()`. Currently a Java worker process
/// starts multiple workers by default, but can be configured to start only 1 worker by
/// speicifying `num_java_workers_per_process` in the job config.
///
/// If only 1 worker is started (either because the worker type is driver, or the
/// `num_workers` in `CoreWorkerOptions` is set to 1), all threads will be automatically
/// associated to the only worker. Then no need to call `SetCurrentThreadWorkerId` in
/// your own threads. Currently a Python worker process starts only 1 worker.
/// `CoreWorkerProcess::GetCoreWorker()->GetWorkerID()`.
///
/// How does core worker process deallocation work?
///
@ -195,9 +186,6 @@ class CoreWorkerProcessImpl {
/// The worker ID of the global worker (each process hosts exactly one worker).
const WorkerID global_worker_id_;
/// Map from worker ID to worker.
absl::flat_hash_map<WorkerID, std::shared_ptr<CoreWorker>> workers_ GUARDED_BY(mutex_);
/// To protect access to workers_ and global_worker_
mutable absl::Mutex mutex_;
};

View file

@ -103,7 +103,6 @@ Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(JNIEnv *env,
jstring rayletSocket,
jbyteArray jobId,
jobject gcsClientOptions,
jint numWorkersPerProcess,
jstring logDir,
jbyteArray jobConfig,
jint startupToken,
@ -289,7 +288,6 @@ Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(JNIEnv *env,
options.task_execution_callback = task_execution_callback;
options.on_worker_shutdown = on_worker_shutdown;
options.gc_collect = gc_collect;
options.num_workers = static_cast<int>(numWorkersPerProcess);
options.serialized_job_config = serialized_job_config;
options.metrics_agent_port = -1;
options.startup_token = startupToken;

View file

@ -25,7 +25,7 @@ extern "C" {
* Class: io_ray_runtime_RayNativeRuntime
* Method: nativeInitialize
* Signature:
* (ILjava/lang/String;ILjava/lang/String;Ljava/lang/String;Ljava/lang/String;[BLio/ray/runtime/gcs/GcsClientOptions;ILjava/lang/String;[BII)V
* (ILjava/lang/String;ILjava/lang/String;Ljava/lang/String;Ljava/lang/String;[BLio/ray/runtime/gcs/GcsClientOptions;Ljava/lang/String;[BII)V
*/
JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(JNIEnv *,
jclass,
@ -37,7 +37,6 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(JNI
jstring,
jbyteArray,
jobject,
jint,
jstring,
jbyteArray,
jint,

View file

@ -140,7 +140,6 @@ class CoreWorkerTest : public ::testing::Test {
options.node_manager_port = node_manager_port;
options.raylet_ip_address = "127.0.0.1";
options.driver_name = "core_worker_test";
options.num_workers = 1;
options.metrics_agent_port = -1;
CoreWorkerProcess::Initialize(options);
}

View file

@ -51,7 +51,6 @@ class MockWorker {
options.raylet_ip_address = "127.0.0.1";
options.task_execution_callback =
std::bind(&MockWorker::ExecuteTask, this, _1, _2, _3, _4, _5, _6, _7, _8, _9);
options.num_workers = 1;
options.metrics_agent_port = -1;
options.startup_token = startup_token;
CoreWorkerProcess::Initialize(options);

View file

@ -686,12 +686,8 @@ void GcsActorManager::PollOwnerForActorOutOfScope(
if (node_it != owners_.end() && node_it->second.count(owner_id)) {
// Only destroy the actor if its owner is still alive. The actor may
// have already been destroyed if the owner died.
// For multiple actors in one process, if one actor is out of scope,
// We shouldn't force kill the actor because other actors in the process
// are still alive.
auto force_kill =
get_job_config_(actor_id.JobId())->num_java_workers_per_process() <= 1;
DestroyActor(actor_id, GenActorOutOfScopeCause(GetActor(actor_id)), force_kill);
DestroyActor(
actor_id, GenActorOutOfScopeCause(GetActor(actor_id)), /*force_kill=*/true);
}
});
}

View file

@ -89,7 +89,7 @@ TEST_F(GcsJobManagerTest, TestGetJobConfig) {
auto job_id2 = JobID::FromInt(2);
gcs::GcsInitData gcs_init_data(gcs_table_storage_);
gcs_job_manager.Initialize(/*init_data=*/gcs_init_data);
auto add_job_request1 = Mocker::GenAddJobRequest(job_id1, "namespace_1", 4);
auto add_job_request1 = Mocker::GenAddJobRequest(job_id1, "namespace_1");
rpc::AddJobReply empty_reply;
@ -97,19 +97,16 @@ TEST_F(GcsJobManagerTest, TestGetJobConfig) {
*add_job_request1,
&empty_reply,
[](Status, std::function<void()>, std::function<void()>) {});
auto add_job_request2 = Mocker::GenAddJobRequest(job_id2, "namespace_2", 8);
auto add_job_request2 = Mocker::GenAddJobRequest(job_id2, "namespace_2");
gcs_job_manager.HandleAddJob(
*add_job_request2,
&empty_reply,
[](Status, std::function<void()>, std::function<void()>) {});
auto job_config1 = gcs_job_manager.GetJobConfig(job_id1);
ASSERT_EQ("namespace_1", job_config1->ray_namespace());
ASSERT_EQ(4, job_config1->num_java_workers_per_process());
auto job_config2 = gcs_job_manager.GetJobConfig(job_id2);
ASSERT_EQ("namespace_2", job_config2->ray_namespace());
ASSERT_EQ(8, job_config2->num_java_workers_per_process());
}
int main(int argc, char **argv) {

View file

@ -239,12 +239,9 @@ struct Mocker {
}
static std::shared_ptr<rpc::AddJobRequest> GenAddJobRequest(
const JobID &job_id,
const std::string &ray_namespace,
uint32_t num_java_worker_per_process) {
const JobID &job_id, const std::string &ray_namespace) {
auto job_config_data = std::make_shared<rpc::JobConfig>();
job_config_data->set_ray_namespace(ray_namespace);
job_config_data->set_num_java_workers_per_process(num_java_worker_per_process);
auto job_table_data = std::make_shared<rpc::JobTableData>();
job_table_data->set_job_id(job_id.Binary());

View file

@ -39,8 +39,8 @@ using RestoreSpilledObjectCallback =
/// A struct that includes info about the object.
struct ObjectInfo {
ObjectID object_id;
int64_t data_size;
int64_t metadata_size;
int64_t data_size = 0;
int64_t metadata_size = 0;
/// Owner's raylet ID.
NodeID owner_raylet_id;
/// Owner's IP address.

View file

@ -260,8 +260,6 @@ message JobConfig {
NON_DETACHED = 1;
}
// The number of java workers per worker process.
uint32 num_java_workers_per_process = 1;
// The jvm options for java workers of the job.
repeated string jvm_options = 2;
// A list of directories or files (jar files or dynamic libraries) that specify the

View file

@ -198,14 +198,12 @@ void WorkerPool::update_worker_startup_token_counter() {
void WorkerPool::AddWorkerProcess(
State &state,
const int workers_to_start,
const rpc::WorkerType worker_type,
const Process &proc,
const std::chrono::high_resolution_clock::time_point &start,
const rpc::RuntimeEnvInfo &runtime_env_info) {
state.worker_processes.emplace(worker_startup_token_counter_,
WorkerProcessInfo{workers_to_start,
workers_to_start,
WorkerProcessInfo{/*is_pending_registration=*/true,
{},
worker_type,
proc,
@ -248,7 +246,7 @@ std::tuple<Process, StartupToken> WorkerPool::StartWorkerProcess(
int starting_workers = 0;
for (auto &entry : state.worker_processes) {
if (entry.second.worker_type == worker_type) {
starting_workers += entry.second.num_starting_workers;
starting_workers += entry.second.is_pending_registration ? 1 : 0;
}
}
@ -268,13 +266,6 @@ std::tuple<Process, StartupToken> WorkerPool::StartWorkerProcess(
<< rpc::WorkerType_Name(worker_type) << ", current pool has "
<< state.idle.size() << " workers";
int workers_to_start = 1;
if (dynamic_options.empty()) {
if (language == Language::JAVA) {
workers_to_start = job_config->num_java_workers_per_process();
}
}
std::vector<std::string> options;
// Append Ray-defined per-job options here
@ -312,12 +303,6 @@ std::tuple<Process, StartupToken> WorkerPool::StartWorkerProcess(
}
}
// Append Ray-defined per-process options here
if (language == Language::JAVA) {
options.push_back("-Dray.job.num-java-workers-per-process=" +
std::to_string(workers_to_start));
}
// Append startup-token for JAVA here
if (language == Language::JAVA) {
options.push_back("-Dray.raylet.startup-token=" +
@ -460,13 +445,12 @@ std::tuple<Process, StartupToken> WorkerPool::StartWorkerProcess(
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
stats::ProcessStartupTimeMs.Record(duration.count());
stats::NumWorkersStarted.Record(1);
RAY_LOG(INFO) << "Started worker process of " << workers_to_start
<< " worker(s) with pid " << proc.GetId() << ", the token "
RAY_LOG(INFO) << "Started worker process with pid " << proc.GetId() << ", the token is "
<< worker_startup_token_counter_;
AdjustWorkerOomScore(proc.GetId());
MonitorStartingWorkerProcess(
proc, worker_startup_token_counter_, language, worker_type);
AddWorkerProcess(state, workers_to_start, worker_type, proc, start, runtime_env_info);
AddWorkerProcess(state, worker_type, proc, start, runtime_env_info);
StartupToken worker_startup_token = worker_startup_token_counter_;
update_worker_startup_token_counter();
if (IsIOWorkerType(worker_type)) {
@ -513,7 +497,7 @@ void WorkerPool::MonitorStartingWorkerProcess(const Process &proc,
// Since this process times out to start, remove it from worker_processes
// to avoid the zombie worker.
auto it = state.worker_processes.find(proc_startup_token);
if (it != state.worker_processes.end() && it->second.num_starting_workers != 0) {
if (it != state.worker_processes.end() && it->second.is_pending_registration) {
RAY_LOG(ERROR)
<< "Some workers of the worker process(" << proc.GetId()
<< ") have not registered within the timeout. "
@ -747,12 +731,10 @@ void WorkerPool::OnWorkerStarted(const std::shared_ptr<WorkerInterface> &worker)
auto it = state.worker_processes.find(worker_startup_token);
if (it != state.worker_processes.end()) {
it->second.num_starting_workers--;
it->second.is_pending_registration = false;
it->second.alive_started_workers.insert(worker);
if (it->second.num_starting_workers == 0) {
// We may have slots to start more workers now.
TryStartIOWorkers(worker->GetLanguage());
}
// We may have slots to start more workers now.
TryStartIOWorkers(worker->GetLanguage());
}
const auto &worker_type = worker->GetWorkerType();
if (IsIOWorkerType(worker_type)) {
@ -1044,8 +1026,7 @@ void WorkerPool::TryKillingIdleWorkers() {
auto &worker_state = GetStateForLanguage(idle_worker->GetLanguage());
auto it = worker_state.worker_processes.find(worker_startup_token);
if (it != worker_state.worker_processes.end() &&
it->second.num_starting_workers > 0) {
if (it != worker_state.worker_processes.end() && it->second.is_pending_registration) {
// The worker process that owns this worker is still pending registration.
// Skip killing this worker.
continue;
@ -1349,7 +1330,7 @@ void WorkerPool::PrestartWorkers(const TaskSpecification &task_spec,
// The number of available workers that can be used for this task spec.
int num_usable_workers = state.idle.size();
for (auto &entry : state.worker_processes) {
num_usable_workers += entry.second.num_starting_workers;
num_usable_workers += entry.second.is_pending_registration ? 1 : 0;
}
// Some existing workers may be holding less than 1 CPU each, so we should
// start as many workers as needed to fill up the remaining CPUs.
@ -1376,10 +1357,10 @@ void WorkerPool::DisconnectWorker(const std::shared_ptr<WorkerInterface> &worker
if (!RemoveWorker(it->second.alive_started_workers, worker)) {
// Worker is either starting or started,
// if it's not started, we should remove it from starting.
it->second.num_starting_workers--;
it->second.is_pending_registration = false;
}
if (it->second.alive_started_workers.size() == 0 &&
it->second.num_starting_workers == 0) {
!it->second.is_pending_registration) {
DeleteRuntimeEnvIfPossible(it->second.runtime_env_info.serialized_runtime_env());
RemoveWorkerProcess(state, worker->GetStartupToken());
}
@ -1508,7 +1489,8 @@ void WorkerPool::WarnAboutSize() {
num_workers_started_or_registered +=
static_cast<int64_t>(state.registered_workers.size());
for (const auto &starting_process : state.worker_processes) {
num_workers_started_or_registered += starting_process.second.num_starting_workers;
num_workers_started_or_registered +=
starting_process.second.is_pending_registration ? 0 : 1;
}
// Don't count IO workers towards the warning message threshold.
num_workers_started_or_registered -= RayConfig::instance().max_io_workers() * 2;

View file

@ -473,10 +473,8 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface {
/// Some basic information about the worker process.
struct WorkerProcessInfo {
/// The number of workers in the worker process.
int num_workers;
/// The number of pending registration workers in the worker process.
int num_starting_workers;
/// Whether this worker is pending registration or is started.
bool is_pending_registration = true;
/// The started workers which are alive.
std::unordered_set<std::shared_ptr<WorkerInterface>> alive_started_workers;
/// The type of the worker.
@ -684,7 +682,6 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface {
void DeleteRuntimeEnvIfPossible(const std::string &serialized_runtime_env);
void AddWorkerProcess(State &state,
const int workers_to_start,
const rpc::WorkerType worker_type,
const Process &proc,
const std::chrono::high_resolution_clock::time_point &start,

View file

@ -26,8 +26,7 @@ namespace ray {
namespace raylet {
int NUM_WORKERS_PER_PROCESS_JAVA = 3;
int MAXIMUM_STARTUP_CONCURRENCY = 5;
int MAXIMUM_STARTUP_CONCURRENCY = 15;
int MAX_IO_WORKER_SIZE = 2;
int POOL_SIZE_SOFT_LIMIT = 5;
int WORKER_REGISTER_TIMEOUT_SECONDS = 3;
@ -37,10 +36,6 @@ const std::string BAD_RUNTIME_ENV_ERROR_MSG = "bad runtime env";
std::vector<Language> LANGUAGES = {Language::PYTHON, Language::JAVA};
static inline std::string GetNumJavaWorkersPerProcessSystemProperty(int num) {
return std::string("-Dray.job.num-java-workers-per-process=") + std::to_string(num);
}
class MockWorkerClient : public rpc::CoreWorkerClientInterface {
public:
MockWorkerClient(instrumented_io_context &io_service) : io_service_(io_service) {}
@ -194,7 +189,7 @@ class WorkerPoolMock : public WorkerPool {
int total = 0;
for (auto &state_entry : states_by_lang_) {
for (auto &process_entry : state_entry.second.worker_processes) {
total += process_entry.second.num_starting_workers;
total += process_entry.second.is_pending_registration ? 1 : 0;
}
}
return total;
@ -204,7 +199,7 @@ class WorkerPoolMock : public WorkerPool {
int total = 0;
for (auto &entry : states_by_lang_) {
for (auto process : entry.second.worker_processes) {
if (process.second.num_starting_workers != 0) {
if (process.second.is_pending_registration) {
total += 1;
}
}
@ -313,7 +308,6 @@ class WorkerPoolMock : public WorkerPool {
if (pushed_it == pushedProcesses_.end()) {
int runtime_env_hash = 0;
bool is_java = false;
bool has_dynamic_options = false;
// Parses runtime env hash to make sure the pushed workers can be popped out.
for (auto command_args : it->second) {
std::string runtime_env_key = "--runtime-env-hash=";
@ -326,14 +320,9 @@ class WorkerPoolMock : public WorkerPool {
if (pos != std::string::npos) {
is_java = true;
}
pos = command_args.find("-X");
if (pos != std::string::npos) {
has_dynamic_options = true;
}
}
// TODO(SongGuyang): support C++ language workers.
int num_workers =
(is_java && !has_dynamic_options) ? NUM_WORKERS_PER_PROCESS_JAVA : 1;
int num_workers = 1;
RAY_CHECK(timeout_worker_number <= num_workers)
<< "The timeout worker number cannot exceed the total number of workers";
auto register_workers = num_workers - timeout_worker_number;
@ -471,7 +460,6 @@ class WorkerPoolTest : public ::testing::Test {
worker_pool_ = std::make_unique<WorkerPoolMock>(
io_service_, worker_commands, mock_worker_rpc_clients_);
rpc::JobConfig job_config;
job_config.set_num_java_workers_per_process(NUM_WORKERS_PER_PROCESS_JAVA);
RegisterDriver(Language::PYTHON, JOB_ID, job_config);
}
@ -489,18 +477,7 @@ class WorkerPoolTest : public ::testing::Test {
ASSERT_TRUE(worker_pool_->NumWorkerProcessesStarting() <=
expected_worker_process_count);
Process prev = worker_pool_->LastStartedWorkerProcess();
if (!std::equal_to<Process>()(last_started_worker_process, prev)) {
last_started_worker_process = prev;
const auto &real_command =
worker_pool_->GetWorkerCommand(last_started_worker_process);
if (language == Language::JAVA) {
auto it = std::find(
real_command.begin(),
real_command.end(),
GetNumJavaWorkersPerProcessSystemProperty(num_workers_per_process));
ASSERT_NE(it, real_command.end());
}
} else {
if (std::equal_to<Process>()(last_started_worker_process, prev)) {
ASSERT_EQ(worker_pool_->NumWorkerProcessesStarting(),
expected_worker_process_count);
ASSERT_TRUE(i >= expected_worker_process_count);
@ -618,9 +595,7 @@ TEST_F(WorkerPoolTest, HandleWorkerRegistration) {
auto [proc, token] = worker_pool_->StartWorkerProcess(
Language::JAVA, rpc::WorkerType::WORKER, JOB_ID, &status);
std::vector<std::shared_ptr<WorkerInterface>> workers;
for (int i = 0; i < NUM_WORKERS_PER_PROCESS_JAVA; i++) {
workers.push_back(worker_pool_->CreateWorker(Process(), Language::JAVA));
}
workers.push_back(worker_pool_->CreateWorker(Process(), Language::JAVA));
for (const auto &worker : workers) {
// Check that there's still a starting worker process
// before all workers have been registered
@ -670,7 +645,7 @@ TEST_F(WorkerPoolTest, StartupPythonWorkerProcessCount) {
}
TEST_F(WorkerPoolTest, StartupJavaWorkerProcessCount) {
TestStartupWorkerProcessCount(Language::JAVA, NUM_WORKERS_PER_PROCESS_JAVA);
TestStartupWorkerProcessCount(Language::JAVA, 1);
}
TEST_F(WorkerPoolTest, InitialWorkerProcessCount) {
@ -756,7 +731,6 @@ TEST_F(WorkerPoolTest, StartWorkerWithDynamicOptionsCommand) {
rpc::JobConfig job_config = rpc::JobConfig();
job_config.add_code_search_path("/test/code_search_path");
job_config.set_num_java_workers_per_process(NUM_WORKERS_PER_PROCESS_JAVA);
job_config.add_jvm_options("-Xmx1g");
job_config.add_jvm_options("-Xms500m");
job_config.add_jvm_options("-Dmy-job.hello=world");
@ -779,7 +753,6 @@ TEST_F(WorkerPoolTest, StartWorkerWithDynamicOptionsCommand) {
expected_command.end(),
{"-Xmx1g", "-Xms500m", "-Dmy-job.hello=world", "-Dmy-job.foo=bar"});
// Ray-defined per-process options
expected_command.push_back(GetNumJavaWorkersPerProcessSystemProperty(1));
expected_command.push_back("-Dray.raylet.startup-token=0");
expected_command.push_back("-Dray.internal.runtime-env-hash=1");
// User-defined per-process options
@ -1162,46 +1135,6 @@ TEST_F(WorkerPoolTest, DeleteWorkerPushPop) {
});
}
TEST_F(WorkerPoolTest, NoPopOnCrashedWorkerProcess) {
// Start a Java worker process.
PopWorkerStatus status;
auto [proc, token] = worker_pool_->StartWorkerProcess(
Language::JAVA, rpc::WorkerType::WORKER, JOB_ID, &status);
auto worker1 = worker_pool_->CreateWorker(Process(), Language::JAVA);
auto worker2 = worker_pool_->CreateWorker(Process(), Language::JAVA);
// We now imitate worker process crashing while core worker initializing.
// 1. we register both workers.
RAY_CHECK_OK(worker_pool_->RegisterWorker(
worker1, proc.GetId(), worker_pool_->GetStartupToken(proc), [](Status, int) {}));
RAY_CHECK_OK(worker_pool_->RegisterWorker(
worker2, proc.GetId(), worker_pool_->GetStartupToken(proc), [](Status, int) {}));
// 2. announce worker port for worker 1. When interacting with worker pool, it's
// PushWorker.
worker_pool_->PushWorker(worker1);
// 3. kill the worker process. Now let's assume that Raylet found that the connection
// with worker 1 disconnected first.
worker_pool_->DisconnectWorker(
worker1, /*disconnect_type=*/rpc::WorkerExitType::SYSTEM_ERROR_EXIT);
// 4. but the RPC for announcing worker port for worker 2 is already in Raylet input
// buffer. So now Raylet needs to handle worker 2.
worker_pool_->PushWorker(worker2);
// 5. Let's try to pop a worker to execute a task. Worker 2 shouldn't be popped because
// the process has crashed.
const auto task_spec = ExampleTaskSpec();
ASSERT_NE(worker_pool_->PopWorkerSync(task_spec), worker1);
ASSERT_NE(worker_pool_->PopWorkerSync(task_spec), worker2);
// 6. Now Raylet disconnects with worker 2.
worker_pool_->DisconnectWorker(
worker2, /*disconnect_type=*/rpc::WorkerExitType::SYSTEM_ERROR_EXIT);
}
TEST_F(WorkerPoolTest, TestWorkerCapping) {
auto job_id = JOB_ID;
@ -1423,8 +1356,7 @@ TEST_F(WorkerPoolTest, TestWorkerCappingWithExitDelay) {
PopWorkerStatus status;
auto [proc, token] = worker_pool_->StartWorkerProcess(
language, rpc::WorkerType::WORKER, JOB_ID, &status);
int workers_to_start =
language == Language::JAVA ? NUM_WORKERS_PER_PROCESS_JAVA : 1;
int workers_to_start = 1;
for (int j = 0; j < workers_to_start; j++) {
auto worker = worker_pool_->CreateWorker(Process(), language);
worker->SetStartupToken(worker_pool_->GetStartupToken(proc));
@ -1644,77 +1576,6 @@ TEST_F(WorkerPoolTest, RuntimeEnvUriReferenceWorkerLevel) {
}
}
TEST_F(WorkerPoolTest, RuntimeEnvUriReferenceWithMultipleWorkers) {
auto job_id = JOB_ID;
std::string uri = "s3://567";
auto runtime_env_info = ExampleRuntimeEnvInfo({uri}, false);
rpc::JobConfig job_config;
job_config.set_num_java_workers_per_process(NUM_WORKERS_PER_PROCESS_JAVA);
job_config.mutable_runtime_env_info()->CopyFrom(runtime_env_info);
// Start job without eager installed runtime env.
worker_pool_->HandleJobStarted(job_id, job_config);
ASSERT_EQ(GetReferenceCount(runtime_env_info.serialized_runtime_env()), 0);
// First part, test normal case with all worker registered.
{
// Start actors with runtime env. The Java actors will trigger a multi-worker process.
std::vector<std::shared_ptr<WorkerInterface>> workers;
for (int i = 0; i < NUM_WORKERS_PER_PROCESS_JAVA; i++) {
auto actor_creation_id = ActorID::Of(job_id, TaskID::ForDriverTask(job_id), i + 1);
const auto actor_creation_task_spec =
ExampleTaskSpec(ActorID::Nil(),
Language::JAVA,
job_id,
actor_creation_id,
{},
TaskID::FromRandom(JobID::Nil()),
runtime_env_info);
auto popped_actor_worker = worker_pool_->PopWorkerSync(actor_creation_task_spec);
ASSERT_NE(popped_actor_worker, nullptr);
workers.push_back(popped_actor_worker);
ASSERT_EQ(GetReferenceCount(runtime_env_info.serialized_runtime_env()), 1);
}
// Make sure only one worker process has been started.
ASSERT_EQ(worker_pool_->GetProcessSize(), 1);
// Disconnect all actor workers.
for (auto &worker : workers) {
worker_pool_->DisconnectWorker(worker, rpc::WorkerExitType::IDLE_EXIT);
}
ASSERT_EQ(GetReferenceCount(runtime_env_info.serialized_runtime_env()), 0);
}
// Second part, test corner case with some worker registration timeout.
{
// Start one actor with runtime env. The Java actor will trigger a multi-worker
// process.
auto actor_creation_id = ActorID::Of(job_id, TaskID::ForDriverTask(job_id), 1);
const auto actor_creation_task_spec =
ExampleTaskSpec(ActorID::Nil(),
Language::JAVA,
job_id,
actor_creation_id,
{},
TaskID::FromRandom(JobID::Nil()),
runtime_env_info);
PopWorkerStatus status;
// Only one worker registration. All the other worker registration times out.
auto popped_actor_worker = worker_pool_->PopWorkerSync(
actor_creation_task_spec, true, &status, NUM_WORKERS_PER_PROCESS_JAVA - 1);
ASSERT_EQ(GetReferenceCount(runtime_env_info.serialized_runtime_env()), 1);
// Disconnect actor worker.
worker_pool_->DisconnectWorker(popped_actor_worker, rpc::WorkerExitType::IDLE_EXIT);
ASSERT_EQ(GetReferenceCount(runtime_env_info.serialized_runtime_env()), 1);
// Sleep for a while to wait worker registration timeout.
std::this_thread::sleep_for(
std::chrono::seconds(WORKER_REGISTER_TIMEOUT_SECONDS + 1));
ASSERT_EQ(GetReferenceCount(runtime_env_info.serialized_runtime_env()), 0);
}
// Finish the job.
worker_pool_->HandleJobFinished(job_id);
ASSERT_EQ(GetReferenceCount(runtime_env_info.serialized_runtime_env()), 0);
}
TEST_F(WorkerPoolTest, CacheWorkersByRuntimeEnvHash) {
///
/// Check that a worker can be popped only if there is a