Support multiple core workers in one process (#7623)

This commit is contained in:
Kai Yang 2020-04-07 11:01:47 +08:00 committed by GitHub
parent e91595f955
commit 48b48cc8c2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
90 changed files with 2014 additions and 1411 deletions

View file

@ -74,7 +74,10 @@ python_grpc_compile(
proto_library(
name = "gcs_service_proto",
srcs = ["src/ray/protobuf/gcs_service.proto"],
deps = [":gcs_proto"],
deps = [
":common_proto",
":gcs_proto",
],
)
cc_proto_library(

View file

@ -12,8 +12,8 @@ namespace api {
LocalModeRayRuntime::LocalModeRayRuntime(std::shared_ptr<RayConfig> config) {
config_ = config;
worker_ =
std::unique_ptr<WorkerContext>(new WorkerContext(WorkerType::DRIVER, JobID::Nil()));
worker_ = std::unique_ptr<WorkerContext>(new WorkerContext(
WorkerType::DRIVER, ComputeDriverIdFromJob(JobID::Nil()), JobID::Nil()));
object_store_ = std::unique_ptr<ObjectStore>(new LocalModeObjectStore(*this));
task_submitter_ = std::unique_ptr<TaskSubmitter>(new LocalModeTaskSubmitter(*this));
}

View file

@ -44,7 +44,7 @@ public final class Ray extends RayCall {
/**
* Shutdown Ray runtime.
*/
public static void shutdown() {
public static synchronized void shutdown() {
if (runtime != null) {
runtime.shutdown();
runtime = null;
@ -137,6 +137,11 @@ public final class Ray extends RayCall {
runtime.setAsyncContext(asyncContext);
}
// TODO (kfstorm): add a `rollbackAsyncContext` API that rolls the async context of the
// current thread back to the one that was in effect before `setAsyncContext` was called.
// TODO (kfstorm): unify the `wrap*` methods.
/**
* If users want to use Ray API in their own threads, they should wrap their {@link Runnable}
* objects with this method.
@ -155,7 +160,7 @@ public final class Ray extends RayCall {
* @param callable The callable to wrap.
* @return The wrapped callable.
*/
public static Callable wrapCallable(Callable callable) {
public static <T> Callable<T> wrapCallable(Callable<T> callable) {
return runtime.wrapCallable(callable);
}

View file

@ -159,6 +159,7 @@ public interface RayRuntime {
/**
* Wrap a {@link Runnable} with necessary context capture.
*
* @param runnable The runnable to wrap.
* @return The wrapped runnable.
*/
@ -166,8 +167,9 @@ public interface RayRuntime {
/**
* Wrap a {@link Callable} with necessary context capture.
*
* @param callable The callable to wrap.
* @return The wrapped callable.
*/
Callable wrapCallable(Callable callable);
<T> Callable<T> wrapCallable(Callable<T> callable);
}

View file

@ -0,0 +1,45 @@
#!/usr/bin/env bash
# Regenerates the JNI headers under src/ray/core_worker/lib/java from the compiled Java classes.
# Abort on the first failing command and echo every command as it runs.
set -e
set -x
# Work from the directory that contains this script so the relative paths below resolve.
cd "$(dirname "$0")"
# Build the jar that javah reads the classes from; the build is run from the parent directory.
(cd .. && bazel build //java:all_tests_deploy.jar)
# Generate the JNI header for one Java class, format it, and install it (prefixed with the
# license header below) into ../src/ray/core_worker/lib/java.
#
# Arguments:
#   $1 - fully qualified Java class name, e.g. org.ray.runtime.RayNativeRuntime
function generate_one()
{
  local class_name="$1"
  # Header file name: the class name with every '.' replaced by '_', plus ".h".
  local file="${class_name//./_}.h"
  local target="../src/ray/core_worker/lib/java/$file"

  javah -classpath ../bazel-bin/java/all_tests_deploy.jar "$class_name"
  clang-format -i "$file"

  # Start the installed header with the license text; this overwrites any previous version.
  cat <<EOF > "$target"
// Copyright 2017 The Ray Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
EOF
  # Append the generated header, then drop the temporary copy from the working directory.
  cat "$file" >> "$target"
  rm -f "$file"
}
# Regenerate the header of every Java class listed here.
generate_one org.ray.runtime.RayNativeRuntime
generate_one org.ray.runtime.task.NativeTaskSubmitter
generate_one org.ray.runtime.context.NativeWorkerContext
generate_one org.ray.runtime.actor.NativeRayActor
generate_one org.ray.runtime.object.NativeObjectStore
generate_one org.ray.runtime.task.NativeTaskExecutor
# Remove empty files
# NOTE(review): presumably javah leaves these nested-class headers in the working directory as a
# by-product of the calls above -- confirm. `rm -f` keeps this a no-op when they are absent.
rm -f org_ray_runtime_RayNativeRuntime_AsyncContext.h
rm -f org_ray_runtime_task_NativeTaskExecutor_NativeActorContext.h

View file

@ -19,7 +19,6 @@ import org.ray.api.function.RayFuncVoid;
import org.ray.api.id.ObjectId;
import org.ray.api.options.ActorCreationOptions;
import org.ray.api.options.CallOptions;
import org.ray.api.runtime.RayRuntime;
import org.ray.api.runtimecontext.RuntimeContext;
import org.ray.runtime.config.RayConfig;
import org.ray.runtime.context.RuntimeContextImpl;
@ -29,6 +28,7 @@ import org.ray.runtime.functionmanager.FunctionManager;
import org.ray.runtime.functionmanager.PyFunctionDescriptor;
import org.ray.runtime.gcs.GcsClient;
import org.ray.runtime.generated.Common.Language;
import org.ray.runtime.generated.Common.WorkerType;
import org.ray.runtime.object.ObjectStore;
import org.ray.runtime.object.RayObjectImpl;
import org.ray.runtime.task.ArgumentsBuilder;
@ -41,7 +41,7 @@ import org.slf4j.LoggerFactory;
/**
* Core functionality to implement Ray APIs.
*/
public abstract class AbstractRayRuntime implements RayRuntime {
public abstract class AbstractRayRuntime implements RayRuntimeInternal {
private static final Logger LOGGER = LoggerFactory.getLogger(AbstractRayRuntime.class);
public static final String PYTHON_INIT_METHOD_NAME = "__init__";
@ -55,9 +55,15 @@ public abstract class AbstractRayRuntime implements RayRuntime {
protected TaskSubmitter taskSubmitter;
protected WorkerContext workerContext;
public AbstractRayRuntime(RayConfig rayConfig, FunctionManager functionManager) {
/**
* Whether the required thread context is set on the current thread.
*/
final ThreadLocal<Boolean> isContextSet = ThreadLocal.withInitial(() -> false);
public AbstractRayRuntime(RayConfig rayConfig) {
this.rayConfig = rayConfig;
this.functionManager = functionManager;
setIsContextSet(rayConfig.workerMode == WorkerType.DRIVER);
functionManager = new FunctionManager(rayConfig.jobResourcePath);
runtimeContext = new RuntimeContextImpl(this);
}
@ -161,13 +167,54 @@ public abstract class AbstractRayRuntime implements RayRuntime {
}
@Override
public Runnable wrapRunnable(Runnable runnable) {
return runnable;
public void setAsyncContext(Object asyncContext) {
isContextSet.set(true);
}
// TODO (kfstorm): Simplify the duplicate code in wrap*** methods.
/**
 * Wrap a {@link Runnable} so that it runs with the async context of the thread that called this
 * method.
 *
 * <p>The context is read with {@link #getAsyncContext()} when this method is called, not when the
 * returned runnable is executed.
 *
 * @param runnable The runnable to wrap.
 * @return A runnable that installs the captured context, runs {@code runnable}, and afterwards
 *     puts the executing thread's context state back the way it found it.
 */
@Override
public final Runnable wrapRunnable(Runnable runnable) {
// Captured on the wrapping thread; the lambda below may run on a different thread.
Object asyncContext = getAsyncContext();
return () -> {
// Remember the executing thread's own state so it can be restored in the finally block.
boolean oldIsContextSet = isContextSet.get();
Object oldAsyncContext = null;
if (oldIsContextSet) {
// Only read the current context when the flag says one has been set on this thread.
oldAsyncContext = getAsyncContext();
}
setAsyncContext(asyncContext);
try {
runnable.run();
} finally {
if (oldIsContextSet) {
setAsyncContext(oldAsyncContext);
} else {
// There was no previous context: only clear the flag that setAsyncContext turned on.
setIsContextSet(false);
}
}
};
}
@Override
public Callable wrapCallable(Callable callable) {
return callable;
public final <T> Callable<T> wrapCallable(Callable<T> callable) {
// Read the context on the wrapping thread; the lambda below may run on a different thread.
Object asyncContext = getAsyncContext();
return () -> {
// Remember the executing thread's own state so it can be restored in the finally block.
boolean oldIsContextSet = isContextSet.get();
Object oldAsyncContext = null;
if (oldIsContextSet) {
// Only read the current context when the flag says one has been set on this thread.
oldAsyncContext = getAsyncContext();
}
setAsyncContext(asyncContext);
try {
return callable.call();
} finally {
// Runs whether the callable returned or threw.
if (oldIsContextSet) {
setAsyncContext(oldAsyncContext);
} else {
// There was no previous context: only clear the flag that setAsyncContext turned on.
setIsContextSet(false);
}
}
};
}
private RayObject callNormalFunction(FunctionDescriptor functionDescriptor,
@ -209,18 +256,22 @@ public abstract class AbstractRayRuntime implements RayRuntime {
return actor;
}
@Override
public WorkerContext getWorkerContext() {
return workerContext;
}
@Override
public ObjectStore getObjectStore() {
return objectStore;
}
@Override
public FunctionManager getFunctionManager() {
return functionManager;
}
@Override
public RayConfig getRayConfig() {
return rayConfig;
}
@ -229,7 +280,13 @@ public abstract class AbstractRayRuntime implements RayRuntime {
return runtimeContext;
}
@Override
public GcsClient getGcsClient() {
return gcsClient;
}
@Override
public void setIsContextSet(boolean isContextSet) {
this.isContextSet.set(isContextSet);
}
}

View file

@ -4,8 +4,6 @@ import org.ray.api.runtime.RayRuntime;
import org.ray.api.runtime.RayRuntimeFactory;
import org.ray.runtime.config.RayConfig;
import org.ray.runtime.config.RunMode;
import org.ray.runtime.functionmanager.FunctionManager;
import org.ray.runtime.generated.Common.WorkerType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -20,17 +18,13 @@ public class DefaultRayRuntimeFactory implements RayRuntimeFactory {
public RayRuntime createRayRuntime() {
RayConfig rayConfig = RayConfig.getInstance();
try {
FunctionManager functionManager = new FunctionManager(rayConfig.jobResourcePath);
RayRuntime runtime;
if (rayConfig.runMode == RunMode.SINGLE_PROCESS) {
runtime = new RayDevRuntime(rayConfig, functionManager);
} else {
if (rayConfig.workerMode == WorkerType.DRIVER) {
runtime = new RayNativeRuntime(rayConfig, functionManager);
} else {
runtime = new RayMultiWorkerNativeRuntime(rayConfig, functionManager);
}
}
AbstractRayRuntime innerRuntime = rayConfig.runMode == RunMode.SINGLE_PROCESS
? new RayDevRuntime(rayConfig)
: new RayNativeRuntime(rayConfig);
RayRuntimeInternal runtime = rayConfig.numWorkersPerProcess > 1
? RayRuntimeProxy.newInstance(innerRuntime)
: innerRuntime;
runtime.start();
return runtime;
} catch (Exception e) {
LOGGER.error("Failed to initialize ray runtime", e);

View file

@ -1,12 +1,12 @@
package org.ray.runtime;
import com.google.common.base.Preconditions;
import java.util.concurrent.atomic.AtomicInteger;
import org.ray.api.BaseActor;
import org.ray.api.id.JobId;
import org.ray.api.id.UniqueId;
import org.ray.runtime.config.RayConfig;
import org.ray.runtime.context.LocalModeWorkerContext;
import org.ray.runtime.functionmanager.FunctionManager;
import org.ray.runtime.object.LocalModeObjectStore;
import org.ray.runtime.task.LocalModeTaskExecutor;
import org.ray.runtime.task.LocalModeTaskSubmitter;
@ -19,18 +19,31 @@ public class RayDevRuntime extends AbstractRayRuntime {
private AtomicInteger jobCounter = new AtomicInteger(0);
public RayDevRuntime(RayConfig rayConfig, FunctionManager functionManager) {
super(rayConfig, functionManager);
public RayDevRuntime(RayConfig rayConfig) {
super(rayConfig);
}
@Override
public void start() {
if (rayConfig.getJobId().isNil()) {
rayConfig.setJobId(nextJobId());
}
taskExecutor = new LocalModeTaskExecutor(this);
workerContext = new LocalModeWorkerContext(rayConfig.getJobId());
objectStore = new LocalModeObjectStore(workerContext);
taskSubmitter = new LocalModeTaskSubmitter(this, (LocalModeObjectStore) objectStore,
rayConfig.numberExecThreadsForDevRuntime);
taskSubmitter = new LocalModeTaskSubmitter(this, taskExecutor,
(LocalModeObjectStore) objectStore);
((LocalModeObjectStore) objectStore).addObjectPutCallback(
objectId -> ((LocalModeTaskSubmitter) taskSubmitter).onObjectPut(objectId));
objectId -> {
if (taskSubmitter != null) {
((LocalModeTaskSubmitter) taskSubmitter).onObjectPut(objectId);
}
});
}
@Override
public void run() {
throw new UnsupportedOperationException();
}
@Override
@ -60,6 +73,8 @@ public class RayDevRuntime extends AbstractRayRuntime {
@Override
public void setAsyncContext(Object asyncContext) {
Preconditions.checkArgument(asyncContext == null);
super.setAsyncContext(asyncContext);
}
private JobId nextJobId() {

View file

@ -1,215 +0,0 @@
package org.ray.runtime;
import com.google.common.base.Preconditions;
import java.util.List;
import java.util.concurrent.Callable;
import org.ray.api.BaseActor;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.RayPyActor;
import org.ray.api.WaitResult;
import org.ray.api.function.PyActorClass;
import org.ray.api.function.PyActorMethod;
import org.ray.api.function.PyRemoteFunction;
import org.ray.api.function.RayFunc;
import org.ray.api.id.ObjectId;
import org.ray.api.id.UniqueId;
import org.ray.api.options.ActorCreationOptions;
import org.ray.api.options.CallOptions;
import org.ray.api.runtime.RayRuntime;
import org.ray.api.runtimecontext.RuntimeContext;
import org.ray.runtime.config.RayConfig;
import org.ray.runtime.config.RunMode;
import org.ray.runtime.functionmanager.FunctionManager;
import org.ray.runtime.generated.Common.WorkerType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * This is a proxy runtime for multi-worker support. It holds multiple {@link RayNativeRuntime}
 * instances and redirects calls to the correct one based on thread context.
 *
 * <p>Each worker thread creates its own {@link RayNativeRuntime} and registers it as the runtime
 * of that thread. Other threads must install a runtime through {@link #setAsyncContext(Object)}
 * (for example via {@link #wrapRunnable(Runnable)} or {@link #wrapCallable(Callable)}) before
 * calling any of the delegating methods.
 */
public class RayMultiWorkerNativeRuntime implements RayRuntime {

  private static final Logger LOGGER = LoggerFactory.getLogger(RayMultiWorkerNativeRuntime.class);

  private final FunctionManager functionManager;

  /**
   * The number of workers per worker process.
   */
  private final int numWorkers;

  /**
   * The worker threads.
   */
  private final Thread[] threads;

  /**
   * The {@link RayNativeRuntime} instances of workers. A slot stays {@code null} until the
   * corresponding worker thread has created its runtime.
   */
  private final RayNativeRuntime[] runtimes;

  /**
   * The {@link RayNativeRuntime} instance of current thread.
   */
  private final ThreadLocal<RayNativeRuntime> currentThreadRuntime = new ThreadLocal<>();

  /**
   * Prepare (but do not start) one worker thread per configured worker.
   *
   * @param rayConfig The configuration; must describe a cluster-mode worker process with
   *     {@code numWorkersPerProcess > 0}.
   * @param functionManager The function manager handed to every {@link RayNativeRuntime}.
   */
  public RayMultiWorkerNativeRuntime(RayConfig rayConfig, FunctionManager functionManager) {
    this.functionManager = functionManager;
    Preconditions.checkState(
        rayConfig.runMode == RunMode.CLUSTER && rayConfig.workerMode == WorkerType.WORKER);
    Preconditions.checkState(rayConfig.numWorkersPerProcess > 0,
        "numWorkersPerProcess must be greater than 0.");
    numWorkers = rayConfig.numWorkersPerProcess;
    runtimes = new RayNativeRuntime[numWorkers];
    threads = new Thread[numWorkers];

    LOGGER.info("Starting {} workers.", numWorkers);

    for (int i = 0; i < numWorkers; i++) {
      final int workerIndex = i;
      threads[i] = new Thread(() -> {
        RayNativeRuntime runtime = new RayNativeRuntime(rayConfig, functionManager);
        // NOTE(review): this slot is written on the worker thread and read by shutdown() on
        // another thread without synchronization -- confirm whether visibility matters here.
        runtimes[workerIndex] = runtime;
        currentThreadRuntime.set(runtime);
        runtime.run();
      });
    }
  }

  /**
   * Start all worker threads and block until every one of them has finished.
   */
  public void run() {
    for (Thread thread : threads) {
      thread.start();
    }
    joinAllThreads();
  }

  /**
   * Shut down every runtime that has been created so far, then wait for the worker threads.
   */
  @Override
  public void shutdown() {
    for (RayNativeRuntime runtime : runtimes) {
      // A slot is null when its worker thread was never started or has not created its runtime
      // yet; skip it instead of failing with a NullPointerException.
      if (runtime != null) {
        runtime.shutdown();
      }
    }
    joinAllThreads();
  }

  /**
   * Wait for all worker threads to terminate.
   *
   * @throws RuntimeException wrapping the {@link InterruptedException} if the calling thread is
   *     interrupted while waiting; the interrupt status is restored before throwing.
   */
  private void joinAllThreads() {
    for (Thread thread : threads) {
      try {
        thread.join();
      } catch (InterruptedException e) {
        // Keep the interrupt visible to callers further up the stack.
        Thread.currentThread().interrupt();
        throw new RuntimeException(e);
      }
    }
  }

  /**
   * Get the runtime registered for the calling thread.
   *
   * @return The runtime of the calling thread, never {@code null}.
   * @throws NullPointerException if no runtime has been set on the calling thread.
   */
  public RayNativeRuntime getCurrentRuntime() {
    RayNativeRuntime currentRuntime = currentThreadRuntime.get();
    Preconditions.checkNotNull(currentRuntime,
        "RayRuntime is not set on current thread."
            + " If you want to use Ray API in your own threads,"
            + " please wrap your `Runnable`s or `Callable`s with"
            + " `Ray.wrapRunnable` or `Ray.wrapCallable`.");
    return currentRuntime;
  }

  // The methods below simply forward to the runtime of the calling thread.

  @Override
  public <T> RayObject<T> put(T obj) {
    return getCurrentRuntime().put(obj);
  }

  @Override
  public <T> T get(ObjectId objectId) {
    return getCurrentRuntime().get(objectId);
  }

  @Override
  public <T> List<T> get(List<ObjectId> objectIds) {
    return getCurrentRuntime().get(objectIds);
  }

  @Override
  public <T> WaitResult<T> wait(List<RayObject<T>> waitList, int numReturns, int timeoutMs) {
    return getCurrentRuntime().wait(waitList, numReturns, timeoutMs);
  }

  @Override
  public void free(List<ObjectId> objectIds, boolean localOnly, boolean deleteCreatingTasks) {
    getCurrentRuntime().free(objectIds, localOnly, deleteCreatingTasks);
  }

  @Override
  public void setResource(String resourceName, double capacity, UniqueId nodeId) {
    getCurrentRuntime().setResource(resourceName, capacity, nodeId);
  }

  @Override
  public void killActor(BaseActor actor, boolean noReconstruction) {
    getCurrentRuntime().killActor(actor, noReconstruction);
  }

  @Override
  public RayObject call(RayFunc func, Object[] args, CallOptions options) {
    return getCurrentRuntime().call(func, args, options);
  }

  @Override
  public RayObject call(PyRemoteFunction pyRemoteFunction, Object[] args,
      CallOptions options) {
    return getCurrentRuntime().call(pyRemoteFunction, args, options);
  }

  @Override
  public RayObject callActor(RayActor<?> actor, RayFunc func, Object[] args) {
    return getCurrentRuntime().callActor(actor, func, args);
  }

  @Override
  public RayObject callActor(RayPyActor pyActor, PyActorMethod pyActorMethod, Object[] args) {
    return getCurrentRuntime().callActor(pyActor, pyActorMethod, args);
  }

  @Override
  public <T> RayActor<T> createActor(RayFunc actorFactoryFunc, Object[] args,
      ActorCreationOptions options) {
    return getCurrentRuntime().createActor(actorFactoryFunc, args, options);
  }

  @Override
  public RayPyActor createActor(PyActorClass pyActorClass, Object[] args,
      ActorCreationOptions options) {
    return getCurrentRuntime().createActor(pyActorClass, args, options);
  }

  @Override
  public RuntimeContext getRuntimeContext() {
    return getCurrentRuntime().getRuntimeContext();
  }

  /**
   * The async context of this runtime is the {@link RayNativeRuntime} of the calling thread.
   */
  @Override
  public Object getAsyncContext() {
    return getCurrentRuntime();
  }

  /**
   * Register the given {@link RayNativeRuntime} as the runtime of the calling thread.
   */
  @Override
  public void setAsyncContext(Object asyncContext) {
    currentThreadRuntime.set((RayNativeRuntime) asyncContext);
  }

  /**
   * Wrap a runnable so that it first installs the runtime of the thread that called this method.
   */
  @Override
  public Runnable wrapRunnable(Runnable runnable) {
    // Captured on the wrapping thread; the lambda may run on a different thread.
    Object asyncContext = getAsyncContext();
    return () -> {
      setAsyncContext(asyncContext);
      runnable.run();
    };
  }

  /**
   * Wrap a callable so that it first installs the runtime of the thread that called this method.
   */
  @Override
  public Callable wrapCallable(Callable callable) {
    // Captured on the wrapping thread; the lambda may run on a different thread.
    Object asyncContext = getAsyncContext();
    return () -> {
      setAsyncContext(asyncContext);
      return callable.call();
    };
  }
}

View file

@ -3,7 +3,6 @@ package org.ray.runtime;
import com.google.common.base.Preconditions;
import java.io.File;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import org.apache.commons.io.FileUtils;
import org.ray.api.BaseActor;
@ -11,7 +10,6 @@ import org.ray.api.id.JobId;
import org.ray.api.id.UniqueId;
import org.ray.runtime.config.RayConfig;
import org.ray.runtime.context.NativeWorkerContext;
import org.ray.runtime.functionmanager.FunctionManager;
import org.ray.runtime.gcs.GcsClient;
import org.ray.runtime.gcs.GcsClientOptions;
import org.ray.runtime.gcs.RedisClient;
@ -34,10 +32,6 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
private RunManager manager = null;
/**
* The native pointer of core worker.
*/
private long nativeCoreWorkerPointer;
static {
LOGGER.debug("Loading native libraries.");
@ -55,14 +49,13 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
JniUtils.loadLibrary("core_worker_library_java", true);
LOGGER.debug("Native libraries loaded.");
// Reset library path at runtime.
resetLibraryPath(rayConfig);
try {
FileUtils.forceMkdir(new File(rayConfig.logDir));
} catch (IOException e) {
throw new RuntimeException("Failed to create the log directory.", e);
}
nativeSetup(rayConfig.logDir, rayConfig.rayletConfigParameters);
Runtime.getRuntime().addShutdownHook(new Thread(RayNativeRuntime::nativeShutdownHook));
}
private static void resetLibraryPath(RayConfig rayConfig) {
@ -71,11 +64,12 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
JniUtils.resetLibraryPath(libraryPath);
}
public RayNativeRuntime(RayConfig rayConfig, FunctionManager functionManager) {
super(rayConfig, functionManager);
// Reset library path at runtime.
resetLibraryPath(rayConfig);
public RayNativeRuntime(RayConfig rayConfig) {
super(rayConfig);
}
@Override
public void start() {
if (rayConfig.getRedisAddress() == null) {
manager = new RunManager(rayConfig);
manager.startRayProcesses(true);
@ -86,21 +80,21 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
if (rayConfig.getJobId() == JobId.NIL) {
rayConfig.setJobId(gcsClient.nextJobId());
}
int numWorkersPerProcess =
rayConfig.workerMode == WorkerType.DRIVER ? 1 : rayConfig.numWorkersPerProcess;
// TODO(qwang): Get object_store_socket_name and raylet_socket_name from Redis.
nativeCoreWorkerPointer = nativeInitCoreWorker(rayConfig.workerMode.getNumber(),
rayConfig.objectStoreSocketName, rayConfig.rayletSocketName,
nativeInitialize(rayConfig.workerMode.getNumber(),
rayConfig.nodeIp, rayConfig.getNodeManagerPort(),
rayConfig.workerMode == WorkerType.DRIVER ? System.getProperty("user.dir") : "",
rayConfig.objectStoreSocketName, rayConfig.rayletSocketName,
(rayConfig.workerMode == WorkerType.DRIVER ? rayConfig.getJobId() : JobId.NIL).getBytes(),
new GcsClientOptions(rayConfig));
Preconditions.checkState(nativeCoreWorkerPointer != 0);
new GcsClientOptions(rayConfig), numWorkersPerProcess,
rayConfig.logDir, rayConfig.rayletConfigParameters);
workerContext = new NativeWorkerContext(nativeCoreWorkerPointer);
taskExecutor = new NativeTaskExecutor(nativeCoreWorkerPointer, this);
objectStore = new NativeObjectStore(workerContext, nativeCoreWorkerPointer);
taskSubmitter = new NativeTaskSubmitter(nativeCoreWorkerPointer);
// register
registerWorker();
taskExecutor = new NativeTaskExecutor(this);
workerContext = new NativeWorkerContext();
objectStore = new NativeObjectStore(workerContext);
taskSubmitter = new NativeTaskSubmitter();
LOGGER.info("RayNativeRuntime started with store {}, raylet {}",
rayConfig.objectStoreSocketName, rayConfig.rayletSocketName);
@ -108,10 +102,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
@Override
public void shutdown() {
if (nativeCoreWorkerPointer != 0) {
nativeDestroyCoreWorker(nativeCoreWorkerPointer);
nativeCoreWorkerPointer = 0;
}
nativeShutdown();
if (null != manager) {
manager.cleanup();
manager = null;
@ -131,75 +122,56 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
if (nodeId == null) {
nodeId = UniqueId.NIL;
}
nativeSetResource(nativeCoreWorkerPointer, resourceName, capacity, nodeId.getBytes());
nativeSetResource(resourceName, capacity, nodeId.getBytes());
}
@Override
public void killActor(BaseActor actor, boolean noReconstruction) {
nativeKillActor(nativeCoreWorkerPointer, actor.getId().getBytes(), noReconstruction);
nativeKillActor(actor.getId().getBytes(), noReconstruction);
}
@Override
public Object getAsyncContext() {
return null;
return new AsyncContext(workerContext.getCurrentWorkerId(),
workerContext.getCurrentClassLoader());
}
@Override
public void setAsyncContext(Object asyncContext) {
nativeSetCoreWorker(((AsyncContext) asyncContext).workerId.getBytes());
workerContext.setCurrentClassLoader(((AsyncContext) asyncContext).currentClassLoader);
super.setAsyncContext(asyncContext);
}
@Override
public void run() {
nativeRunTaskExecutor(nativeCoreWorkerPointer);
Preconditions.checkState(rayConfig.workerMode == WorkerType.WORKER);
nativeRunTaskExecutor(taskExecutor);
}
public long getNativeCoreWorkerPointer() {
return nativeCoreWorkerPointer;
}
/**
 * Native initialization routine called from {@link #start()}.
 *
 * @param workerMode Number of the worker type this process runs as.
 * @param nodeIpAddress IP address of the node.
 * @param nodeManagerPort Port of the node manager.
 * @param driverName The working directory for a driver, an empty string otherwise.
 * @param storeSocket Name of the object store socket.
 * @param rayletSocket Name of the raylet socket.
 * @param jobId Bytes of the job ID for a driver, of the nil job ID otherwise.
 * @param gcsClientOptions Options for the GCS client.
 * @param numWorkersPerProcess Number of workers in this process; 1 for a driver.
 * @param logDir The log directory.
 * @param rayletConfigParameters The raylet config parameters.
 */
private static native void nativeInitialize(int workerMode, String nodeIpAddress,
    int nodeManagerPort, String driverName, String storeSocket, String rayletSocket,
    byte[] jobId, GcsClientOptions gcsClientOptions, int numWorkersPerProcess,
    String logDir, Map<String, String> rayletConfigParameters);
public TaskExecutor getTaskExecutor() {
return taskExecutor;
}
private static native void nativeRunTaskExecutor(TaskExecutor taskExecutor);
/**
* Register this worker or driver to GCS.
*/
private void registerWorker() {
RedisClient redisClient = new RedisClient(rayConfig.getRedisAddress(), rayConfig.redisPassword);
Map<String, String> workerInfo = new HashMap<>();
String workerId = new String(workerContext.getCurrentWorkerId().getBytes());
if (rayConfig.workerMode == WorkerType.DRIVER) {
workerInfo.put("node_ip_address", rayConfig.nodeIp);
workerInfo.put("driver_id", workerId);
workerInfo.put("start_time", String.valueOf(System.currentTimeMillis()));
workerInfo.put("plasma_store_socket", rayConfig.objectStoreSocketName);
workerInfo.put("raylet_socket", rayConfig.rayletSocketName);
workerInfo.put("name", System.getProperty("user.dir"));
//TODO: worker.redis_client.hmset(b"Drivers:" + worker.workerId, driver_info)
redisClient.hmset("Drivers:" + workerId, workerInfo);
} else {
workerInfo.put("node_ip_address", rayConfig.nodeIp);
workerInfo.put("plasma_store_socket", rayConfig.objectStoreSocketName);
workerInfo.put("raylet_socket", rayConfig.rayletSocketName);
//TODO: b"Workers:" + worker.workerId,
redisClient.hmset("Workers:" + workerId, workerInfo);
private static native void nativeShutdown();
private static native void nativeSetResource(String resourceName, double capacity, byte[] nodeId);
private static native void nativeKillActor(byte[] actorId, boolean noReconstruction);
private static native void nativeSetCoreWorker(byte[] workerId);
/**
 * The value returned by {@link #getAsyncContext()} and consumed by
 * {@link #setAsyncContext(Object)}: the worker ID and the class loader that were current when
 * the context was captured.
 */
static class AsyncContext {
/** ID of the worker that was current when this context was captured. */
public final UniqueId workerId;
/** Class loader that was current when this context was captured. */
public final ClassLoader currentClassLoader;
AsyncContext(UniqueId workerId, ClassLoader currentClassLoader) {
this.workerId = workerId;
this.currentClassLoader = currentClassLoader;
}
}
private static native long nativeInitCoreWorker(int workerMode, String storeSocket,
String rayletSocket, String nodeIpAddress, int nodeManagerPort, byte[] jobId,
GcsClientOptions gcsClientOptions);
private static native void nativeRunTaskExecutor(long nativeCoreWorkerPointer);
private static native void nativeDestroyCoreWorker(long nativeCoreWorkerPointer);
private static native void nativeSetup(String logDir, Map<String, String> rayletConfigParameters);
private static native void nativeShutdownHook();
private static native void nativeSetResource(long conn, String resourceName, double capacity,
byte[] nodeId);
private static native void nativeKillActor(long nativeCoreWorkerPointer, byte[] actorId,
boolean noReconstruction);
}

View file

@ -0,0 +1,33 @@
package org.ray.runtime;
import org.ray.api.runtime.RayRuntime;
import org.ray.runtime.config.RayConfig;
import org.ray.runtime.context.WorkerContext;
import org.ray.runtime.functionmanager.FunctionManager;
import org.ray.runtime.gcs.GcsClient;
import org.ray.runtime.object.ObjectStore;
/**
 * This interface is required to make {@link RayRuntimeProxy} work.
 *
 * <p>{@link RayRuntimeProxy} builds a dynamic proxy that implements this interface, so every
 * runtime method that has to be callable on the proxied runtime is declared either here or in
 * {@link RayRuntime}.
 */
public interface RayRuntimeInternal extends RayRuntime {
/**
 * Start runtime.
 */
void start();
/** Returns the {@link WorkerContext} of this runtime. */
WorkerContext getWorkerContext();
/** Returns the {@link ObjectStore} of this runtime. */
ObjectStore getObjectStore();
/** Returns the {@link FunctionManager} of this runtime. */
FunctionManager getFunctionManager();
/** Returns the {@link RayConfig} this runtime was created with. */
RayConfig getRayConfig();
/** Returns the {@link GcsClient} of this runtime. */
GcsClient getGcsClient();
/**
 * Record whether the required thread context is set on the calling thread.
 *
 * @param isContextSet The new value of the flag for the calling thread.
 */
void setIsContextSet(boolean isContextSet);
/**
 * Run the runtime as a worker.
 * NOTE(review): the native runtime only accepts this in worker mode and the single-process
 * runtime rejects it with {@link UnsupportedOperationException} -- confirm for other
 * implementations.
 */
void run();
}

View file

@ -0,0 +1,83 @@
package org.ray.runtime;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import org.ray.api.exception.RayException;
import org.ray.api.runtime.RayRuntime;
import org.ray.runtime.config.RunMode;
/**
 * Protect a ray runtime with context checks for all methods of {@link RayRuntime}, except
 * {@link RayRuntime#shutdown()} and {@link RayRuntime#setAsyncContext(Object)}. The latter is
 * exempt because it is the call that establishes the context in the first place.
 */
public class RayRuntimeProxy implements InvocationHandler {

  /**
   * The original runtime.
   */
  private final AbstractRayRuntime obj;

  private RayRuntimeProxy(AbstractRayRuntime obj) {
    this.obj = obj;
  }

  /**
   * Returns the original runtime that this handler forwards to.
   */
  public AbstractRayRuntime getRuntimeObject() {
    return obj;
  }

  /**
   * Generate a new instance of {@link RayRuntimeInternal} with additional context check.
   *
   * @param obj The runtime to protect.
   * @return A dynamic proxy that implements {@link RayRuntimeInternal} and forwards to
   *     {@code obj}.
   */
  static RayRuntimeInternal newInstance(AbstractRayRuntime obj) {
    return (RayRuntimeInternal) java.lang.reflect.Proxy
        .newProxyInstance(obj.getClass().getClassLoader(), new Class<?>[]{RayRuntimeInternal.class},
            new RayRuntimeProxy(obj));
  }

  /**
   * Forward a call made on the proxy to the original runtime, checking the thread context first
   * when the method requires it.
   *
   * @throws Throwable whatever the original method threw; the reflective
   *     {@link InvocationTargetException} wrapper is removed when it carries a cause.
   */
  @Override
  public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
    // Methods are exempted by name, so every overload of an exempt method is exempt too.
    if (isInterfaceMethod(method) && !method.getName().equals("shutdown") && !method.getName()
        .equals("setAsyncContext")) {
      checkIsContextSet();
    }
    try {
      return method.invoke(obj, args);
    } catch (InvocationTargetException e) {
      // Rethrow the exception of the runtime itself rather than the reflection wrapper.
      if (e.getCause() != null) {
        throw e.getCause();
      } else {
        throw e;
      }
    }
  }

  /**
   * Whether the method is defined in the {@link RayRuntime} interface.
   */
  private boolean isInterfaceMethod(Method method) {
    try {
      // Matched by name and parameter types; getMethod throws when there is no such method.
      RayRuntime.class.getMethod(method.getName(), method.getParameterTypes());
      return true;
    } catch (NoSuchMethodException e) {
      return false;
    }
  }

  /**
   * Check if thread context is set.
   *
   * <p>This method should be invoked at the beginning of most public methods of
   * {@link RayRuntime}, otherwise the native code might crash due to thread local core worker
   * was not set. We check it for {@link AbstractRayRuntime} instead of {@link RayNativeRuntime}
   * because we want to catch the error even if the application runs in
   * {@link RunMode#SINGLE_PROCESS} mode.
   *
   * @throws RayException if the context flag of the calling thread is not set.
   */
  private void checkIsContextSet() {
    if (!obj.isContextSet.get()) {
      throw new RayException(
          "`Ray.wrap***` is not called on the current thread."
              + " If you want to use Ray API in your own threads,"
              + " please wrap your executable with `Ray.wrap***`.");
    }
  }
}

View file

@ -7,11 +7,7 @@ import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.List;
import org.ray.api.BaseActor;
import org.ray.api.Ray;
import org.ray.api.id.ActorId;
import org.ray.api.runtime.RayRuntime;
import org.ray.runtime.RayMultiWorkerNativeRuntime;
import org.ray.runtime.RayNativeRuntime;
import org.ray.runtime.generated.Common.Language;
/**
@ -20,20 +16,17 @@ import org.ray.runtime.generated.Common.Language;
*/
public abstract class NativeRayActor implements BaseActor, Externalizable {
/**
* Address of core worker.
*/
long nativeCoreWorkerPointer;
/**
* ID of the actor.
*/
byte[] actorId;
NativeRayActor(long nativeCoreWorkerPointer, byte[] actorId) {
Preconditions.checkState(nativeCoreWorkerPointer != 0);
private Language language;
NativeRayActor(byte[] actorId, Language language) {
Preconditions.checkState(!ActorId.fromBytes(actorId).isNil());
this.nativeCoreWorkerPointer = nativeCoreWorkerPointer;
this.actorId = actorId;
this.language = language;
}
/**
@ -42,14 +35,12 @@ public abstract class NativeRayActor implements BaseActor, Externalizable {
NativeRayActor() {
}
public static NativeRayActor create(long nativeCoreWorkerPointer, byte[] actorId,
Language language) {
Preconditions.checkState(nativeCoreWorkerPointer != 0);
public static NativeRayActor create(byte[] actorId, Language language) {
switch (language) {
case JAVA:
return new NativeRayJavaActor(nativeCoreWorkerPointer, actorId);
return new NativeRayJavaActor(actorId);
case PYTHON:
return new NativeRayPyActor(nativeCoreWorkerPointer, actorId);
return new NativeRayPyActor(actorId);
default:
throw new IllegalStateException("Unknown actor handle language: " + language);
}
@ -61,18 +52,19 @@ public abstract class NativeRayActor implements BaseActor, Externalizable {
}
public Language getLanguage() {
return Language.forNumber(nativeGetLanguage(nativeCoreWorkerPointer, actorId));
return language;
}
@Override
public void writeExternal(ObjectOutput out) throws IOException {
out.writeObject(toBytes());
out.writeObject(nativeSerialize(actorId));
out.writeObject(language);
}
@Override
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
nativeCoreWorkerPointer = getNativeCoreWorkerPointer();
actorId = nativeDeserialize(nativeCoreWorkerPointer, (byte[]) in.readObject());
actorId = nativeDeserialize((byte[]) in.readObject());
language = (Language) in.readObject();
}
/**
@ -81,7 +73,7 @@ public abstract class NativeRayActor implements BaseActor, Externalizable {
* @return the bytes of the actor handle
*/
public byte[] toBytes() {
return nativeSerialize(nativeCoreWorkerPointer, actorId);
return nativeSerialize(actorId);
}
/**
@ -90,21 +82,10 @@ public abstract class NativeRayActor implements BaseActor, Externalizable {
* @return the bytes of an actor handle
*/
public static NativeRayActor fromBytes(byte[] bytes) {
long nativeCoreWorkerPointer = getNativeCoreWorkerPointer();
byte[] actorId = nativeDeserialize(nativeCoreWorkerPointer, bytes);
Language language = Language.forNumber(nativeGetLanguage(nativeCoreWorkerPointer, actorId));
byte[] actorId = nativeDeserialize(bytes);
Language language = Language.forNumber(nativeGetLanguage(actorId));
Preconditions.checkNotNull(language);
return create(nativeCoreWorkerPointer, actorId, language);
}
private static long getNativeCoreWorkerPointer() {
RayRuntime runtime = Ray.internal();
if (runtime instanceof RayMultiWorkerNativeRuntime) {
runtime = ((RayMultiWorkerNativeRuntime) runtime).getCurrentRuntime();
}
Preconditions.checkState(runtime instanceof RayNativeRuntime);
return ((RayNativeRuntime) runtime).getNativeCoreWorkerPointer();
return create(actorId, language);
}
@Override
@ -112,13 +93,11 @@ public abstract class NativeRayActor implements BaseActor, Externalizable {
// TODO(zhijunfu): do we need to free the ActorHandle in core worker?
}
private static native int nativeGetLanguage(
long nativeCoreWorkerPointer, byte[] actorId);
private static native int nativeGetLanguage(byte[] actorId);
static native List<String> nativeGetActorCreationTaskFunctionDescriptor(
long nativeCoreWorkerPointer, byte[] actorId);
static native List<String> nativeGetActorCreationTaskFunctionDescriptor(byte[] actorId);
private static native byte[] nativeSerialize(long nativeCoreWorkerPointer, byte[] actorId);
private static native byte[] nativeSerialize(byte[] actorId);
private static native byte[] nativeDeserialize(long nativeCoreWorkerPointer, byte[] data);
private static native byte[] nativeDeserialize(byte[] data);
}

View file

@ -11,8 +11,8 @@ import org.ray.runtime.generated.Common.Language;
*/
public class NativeRayJavaActor extends NativeRayActor implements RayActor {
NativeRayJavaActor(long nativeCoreWorkerPointer, byte[] actorId) {
super(nativeCoreWorkerPointer, actorId);
NativeRayJavaActor(byte[] actorId) {
super(actorId, Language.JAVA);
}
/**

View file

@ -11,8 +11,8 @@ import org.ray.runtime.generated.Common.Language;
*/
public class NativeRayPyActor extends NativeRayActor implements RayPyActor {
NativeRayPyActor(long nativeCoreWorkerPointer, byte[] actorId) {
super(nativeCoreWorkerPointer, actorId);
NativeRayPyActor(byte[] actorId) {
super(actorId, Language.PYTHON);
}
/**
@ -24,12 +24,12 @@ public class NativeRayPyActor extends NativeRayActor implements RayPyActor {
@Override
public String getModuleName() {
return nativeGetActorCreationTaskFunctionDescriptor(nativeCoreWorkerPointer, actorId).get(0);
return nativeGetActorCreationTaskFunctionDescriptor(actorId).get(0);
}
@Override
public String getClassName() {
return nativeGetActorCreationTaskFunctionDescriptor(nativeCoreWorkerPointer, actorId).get(1);
return nativeGetActorCreationTaskFunctionDescriptor(actorId).get(1);
}
@Override

View file

@ -93,10 +93,6 @@ public class RayConfig {
}
}
/**
* Number of threads that execute tasks.
*/
public final int numberExecThreadsForDevRuntime;
public final int numWorkersPerProcess;
@ -220,9 +216,6 @@ public class RayConfig {
jobResourcePath = null;
}
// Number of threads that execute tasks.
numberExecThreadsForDevRuntime = config.getInt("ray.dev-runtime.execution-parallelism");
numWorkersPerProcess = config.getInt("ray.raylet.config.num_workers_per_process_java");
gcsServiceEnabled = System.getenv("RAY_GCS_SERVICE_ENABLED") == null ||

View file

@ -16,6 +16,7 @@ public class LocalModeWorkerContext implements WorkerContext {
private final JobId jobId;
private ThreadLocal<TaskSpec> currentTask = new ThreadLocal<>();
private final ThreadLocal<UniqueId> currentWorkerId = new ThreadLocal<>();
public LocalModeWorkerContext(JobId jobId) {
this.jobId = jobId;
@ -23,7 +24,11 @@ public class LocalModeWorkerContext implements WorkerContext {
@Override
public UniqueId getCurrentWorkerId() {
throw new UnsupportedOperationException();
return currentWorkerId.get();
}
/**
 * Stores the given worker ID in the thread-local {@code currentWorkerId}, so the value is
 * only visible to the thread that calls this method.
 *
 * @param workerId the worker ID to associate with the calling thread
 */
public void setCurrentWorkerId(UniqueId workerId) {
currentWorkerId.set(workerId);
}
@Override

View file

@ -12,61 +12,52 @@ import org.ray.runtime.generated.Common.TaskType;
*/
public class NativeWorkerContext implements WorkerContext {
/**
* The native pointer of core worker.
*/
private final long nativeCoreWorkerPointer;
private ClassLoader currentClassLoader;
public NativeWorkerContext(long nativeCoreWorkerPointer) {
this.nativeCoreWorkerPointer = nativeCoreWorkerPointer;
}
private final ThreadLocal<ClassLoader> currentClassLoader = new ThreadLocal<>();
@Override
public UniqueId getCurrentWorkerId() {
return UniqueId.fromByteBuffer(nativeGetCurrentWorkerId(nativeCoreWorkerPointer));
return UniqueId.fromByteBuffer(nativeGetCurrentWorkerId());
}
@Override
public JobId getCurrentJobId() {
return JobId.fromByteBuffer(nativeGetCurrentJobId(nativeCoreWorkerPointer));
return JobId.fromByteBuffer(nativeGetCurrentJobId());
}
@Override
public ActorId getCurrentActorId() {
return ActorId.fromByteBuffer(nativeGetCurrentActorId(nativeCoreWorkerPointer));
return ActorId.fromByteBuffer(nativeGetCurrentActorId());
}
@Override
public ClassLoader getCurrentClassLoader() {
return currentClassLoader;
return currentClassLoader.get();
}
@Override
public void setCurrentClassLoader(ClassLoader currentClassLoader) {
if (this.currentClassLoader != currentClassLoader) {
this.currentClassLoader = currentClassLoader;
if (this.currentClassLoader.get() != currentClassLoader) {
this.currentClassLoader.set(currentClassLoader);
}
}
@Override
public TaskType getCurrentTaskType() {
return TaskType.forNumber(nativeGetCurrentTaskType(nativeCoreWorkerPointer));
return TaskType.forNumber(nativeGetCurrentTaskType());
}
@Override
public TaskId getCurrentTaskId() {
return TaskId.fromByteBuffer(nativeGetCurrentTaskId(nativeCoreWorkerPointer));
return TaskId.fromByteBuffer(nativeGetCurrentTaskId());
}
private static native int nativeGetCurrentTaskType(long nativeCoreWorkerPointer);
private static native int nativeGetCurrentTaskType();
private static native ByteBuffer nativeGetCurrentTaskId(long nativeCoreWorkerPointer);
private static native ByteBuffer nativeGetCurrentTaskId();
private static native ByteBuffer nativeGetCurrentJobId(long nativeCoreWorkerPointer);
private static native ByteBuffer nativeGetCurrentJobId();
private static native ByteBuffer nativeGetCurrentWorkerId(long nativeCoreWorkerPointer);
private static native ByteBuffer nativeGetCurrentWorkerId();
private static native ByteBuffer nativeGetCurrentActorId(long nativeCoreWorkerPointer);
private static native ByteBuffer nativeGetCurrentActorId();
}

View file

@ -6,15 +6,15 @@ import org.ray.api.id.ActorId;
import org.ray.api.id.JobId;
import org.ray.api.runtimecontext.NodeInfo;
import org.ray.api.runtimecontext.RuntimeContext;
import org.ray.runtime.AbstractRayRuntime;
import org.ray.runtime.RayRuntimeInternal;
import org.ray.runtime.config.RunMode;
import org.ray.runtime.generated.Common.TaskType;
public class RuntimeContextImpl implements RuntimeContext {
private AbstractRayRuntime runtime;
private RayRuntimeInternal runtime;
public RuntimeContextImpl(AbstractRayRuntime runtime) {
public RuntimeContextImpl(RayRuntimeInternal runtime) {
this.runtime = runtime;
}

View file

@ -15,56 +15,48 @@ public class NativeObjectStore extends ObjectStore {
private static final Logger LOGGER = LoggerFactory.getLogger(NativeObjectStore.class);
/**
* The native pointer of core worker.
*/
private final long nativeCoreWorkerPointer;
public NativeObjectStore(WorkerContext workerContext, long nativeCoreWorkerPointer) {
public NativeObjectStore(WorkerContext workerContext) {
super(workerContext);
this.nativeCoreWorkerPointer = nativeCoreWorkerPointer;
}
@Override
public ObjectId putRaw(NativeRayObject obj) {
return new ObjectId(nativePut(nativeCoreWorkerPointer, obj));
return new ObjectId(nativePut(obj));
}
@Override
public void putRaw(NativeRayObject obj, ObjectId objectId) {
nativePut(nativeCoreWorkerPointer, objectId.getBytes(), obj);
nativePut(objectId.getBytes(), obj);
}
@Override
public List<NativeRayObject> getRaw(List<ObjectId> objectIds, long timeoutMs) {
return nativeGet(nativeCoreWorkerPointer, toBinaryList(objectIds), timeoutMs);
return nativeGet(toBinaryList(objectIds), timeoutMs);
}
@Override
public List<Boolean> wait(List<ObjectId> objectIds, int numObjects, long timeoutMs) {
return nativeWait(nativeCoreWorkerPointer, toBinaryList(objectIds), numObjects, timeoutMs);
return nativeWait(toBinaryList(objectIds), numObjects, timeoutMs);
}
@Override
public void delete(List<ObjectId> objectIds, boolean localOnly, boolean deleteCreatingTasks) {
nativeDelete(nativeCoreWorkerPointer, toBinaryList(objectIds), localOnly, deleteCreatingTasks);
nativeDelete(toBinaryList(objectIds), localOnly, deleteCreatingTasks);
}
/**
 * Converts a list of object IDs into a list of their binary forms by mapping each ID through
 * {@code BaseId.getBytes()}. The result has the same order as the input.
 *
 * @param ids the object IDs to convert
 * @return one byte array per input ID
 */
private static List<byte[]> toBinaryList(List<ObjectId> ids) {
return ids.stream().map(BaseId::getBytes).collect(Collectors.toList());
}
private static native byte[] nativePut(long nativeCoreWorkerPointer, NativeRayObject obj);
private static native byte[] nativePut(NativeRayObject obj);
private static native void nativePut(long nativeCoreWorkerPointer, byte[] objectId,
NativeRayObject obj);
private static native void nativePut(byte[] objectId, NativeRayObject obj);
private static native List<NativeRayObject> nativeGet(long nativeCoreWorkerPointer,
List<byte[]> ids, long timeoutMs);
private static native List<NativeRayObject> nativeGet(List<byte[]> ids, long timeoutMs);
private static native List<Boolean> nativeWait(long nativeCoreWorkerPointer,
List<byte[]> objectIds, int numObjects, long timeoutMs);
private static native List<Boolean> nativeWait(List<byte[]> objectIds, int numObjects,
long timeoutMs);
private static native void nativeDelete(long nativeCoreWorkerPointer, List<byte[]> objectIds,
boolean localOnly, boolean deleteCreatingTasks);
private static native void nativeDelete(List<byte[]> objectIds, boolean localOnly,
boolean deleteCreatingTasks);
}

View file

@ -1,9 +1,7 @@
package org.ray.runtime.runner.worker;
import org.ray.api.Ray;
import org.ray.api.runtime.RayRuntime;
import org.ray.runtime.RayMultiWorkerNativeRuntime;
import org.ray.runtime.RayNativeRuntime;
import org.ray.runtime.RayRuntimeInternal;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -25,14 +23,7 @@ public class DefaultWorker {
});
Ray.init();
LOGGER.info("Worker started.");
RayRuntime runtime = Ray.internal();
if (runtime instanceof RayNativeRuntime) {
((RayNativeRuntime)runtime).run();
} else if (runtime instanceof RayMultiWorkerNativeRuntime) {
((RayMultiWorkerNativeRuntime)runtime).run();
} else {
throw new RuntimeException("Unknown RayRuntime: " + runtime);
}
((RayRuntimeInternal) Ray.internal()).run();
} catch (Exception e) {
LOGGER.error("Failed to start worker.", e);
}

View file

@ -6,9 +6,7 @@ import java.util.List;
import org.ray.api.Ray;
import org.ray.api.RayObject;
import org.ray.api.id.ObjectId;
import org.ray.api.runtime.RayRuntime;
import org.ray.runtime.AbstractRayRuntime;
import org.ray.runtime.RayMultiWorkerNativeRuntime;
import org.ray.runtime.RayRuntimeInternal;
import org.ray.runtime.generated.Common.Language;
import org.ray.runtime.object.NativeRayObject;
import org.ray.runtime.object.ObjectSerializer;
@ -43,11 +41,7 @@ public class ArgumentsBuilder {
} else {
value = ObjectSerializer.serialize(arg);
if (value.data.length > LARGEST_SIZE_PASS_BY_VALUE) {
RayRuntime runtime = Ray.internal();
if (runtime instanceof RayMultiWorkerNativeRuntime) {
runtime = ((RayMultiWorkerNativeRuntime) runtime).getCurrentRuntime();
}
id = ((AbstractRayRuntime) runtime).getObjectStore()
id = ((RayRuntimeInternal) Ray.internal()).getObjectStore()
.putRaw(value);
value = null;
}

View file

@ -1,17 +1,40 @@
package org.ray.runtime.task;
import org.ray.api.id.ActorId;
import org.ray.runtime.AbstractRayRuntime;
import org.ray.api.id.UniqueId;
import org.ray.runtime.RayRuntimeInternal;
import org.ray.runtime.task.LocalModeTaskExecutor.LocalActorContext;
/**
* Task executor for local mode.
*/
public class LocalModeTaskExecutor extends TaskExecutor {
public class LocalModeTaskExecutor extends TaskExecutor<LocalActorContext> {
public LocalModeTaskExecutor(AbstractRayRuntime runtime) {
/**
 * Actor context used by the local-mode task executor. In addition to the state inherited from
 * {@link TaskExecutor.ActorContext}, it remembers the worker ID it was created with.
 */
static class LocalActorContext extends TaskExecutor.ActorContext {
/**
 * The worker ID of the actor. Fixed at construction time.
 */
private final UniqueId workerId;
/**
 * @param workerId the worker ID to record for this actor
 */
public LocalActorContext(UniqueId workerId) {
this.workerId = workerId;
}
/**
 * @return the worker ID passed to the constructor
 */
public UniqueId getWorkerId() {
return workerId;
}
}
public LocalModeTaskExecutor(RayRuntimeInternal runtime) {
super(runtime);
}
/**
 * Creates a new local actor context that records the worker ID currently reported by the
 * runtime's worker context.
 *
 * @return a new {@link LocalActorContext} bound to the current worker ID
 */
@Override
protected LocalActorContext createActorContext() {
return new LocalActorContext(runtime.getWorkerContext().getCurrentWorkerId());
}
/**
 * Checkpoint-saving hook. The local-mode executor leaves it as a no-op.
 *
 * @param actor the actor object (unused)
 * @param actorId the ID of the actor (unused)
 */
@Override
protected void maybeSaveCheckpoint(Object actor, ActorId actorId) {
// Intentionally empty.
}

View file

@ -4,16 +4,15 @@ import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.protobuf.ByteString;
import java.nio.ByteBuffer;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.stream.Collectors;
@ -21,9 +20,10 @@ import org.ray.api.BaseActor;
import org.ray.api.id.ActorId;
import org.ray.api.id.ObjectId;
import org.ray.api.id.TaskId;
import org.ray.api.id.UniqueId;
import org.ray.api.options.ActorCreationOptions;
import org.ray.api.options.CallOptions;
import org.ray.runtime.RayDevRuntime;
import org.ray.runtime.RayRuntimeInternal;
import org.ray.runtime.actor.LocalModeRayActor;
import org.ray.runtime.context.LocalModeWorkerContext;
import org.ray.runtime.functionmanager.FunctionDescriptor;
@ -37,6 +37,7 @@ import org.ray.runtime.generated.Common.TaskSpec;
import org.ray.runtime.generated.Common.TaskType;
import org.ray.runtime.object.LocalModeObjectStore;
import org.ray.runtime.object.NativeRayObject;
import org.ray.runtime.task.TaskExecutor.ActorContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -49,7 +50,8 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
private final Map<ObjectId, Set<TaskSpec>> waitingTasks = new HashMap<>();
private final Object taskAndObjectLock = new Object();
private final RayDevRuntime runtime;
private final RayRuntimeInternal runtime;
private final TaskExecutor taskExecutor;
private final LocalModeObjectStore objectStore;
/// The thread pool to execute actor tasks.
@ -58,17 +60,16 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
/// The thread pool to execute normal tasks.
private final ExecutorService normalTaskExecutorService;
private final Deque<TaskExecutor> idleTaskExecutors = new ArrayDeque<>();
private final Map<ActorId, TaskExecutor> actorTaskExecutors = new HashMap<>();
private final Object taskExecutorLock = new Object();
private final ThreadLocal<TaskExecutor> currentTaskExecutor = new ThreadLocal<>();
public LocalModeTaskSubmitter(RayDevRuntime runtime, LocalModeObjectStore objectStore,
int numberThreads) {
private final Map<ActorId, ActorContext> actorContexts = new ConcurrentHashMap<>();
public LocalModeTaskSubmitter(RayRuntimeInternal runtime, TaskExecutor taskExecutor,
LocalModeObjectStore objectStore) {
this.runtime = runtime;
this.taskExecutor = taskExecutor;
this.objectStore = objectStore;
// The thread pool that executes normal tasks in parallel.
normalTaskExecutorService = Executors.newFixedThreadPool(numberThreads);
normalTaskExecutorService = Executors.newCachedThreadPool();
// The thread pool that executes actor tasks in parallel.
actorTaskExecutorServices = new HashMap<>();
}
@ -88,46 +89,6 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
}
}
/**
* Get the worker of current thread. <br> NOTE: Cannot be used for multi-threading in worker.
*/
public TaskExecutor getCurrentTaskExecutor() {
return currentTaskExecutor.get();
}
/**
* Get a worker from the worker pool to run the given task.
*/
private TaskExecutor getTaskExecutor(TaskSpec task) {
TaskExecutor taskExecutor;
synchronized (taskExecutorLock) {
if (task.getType() == TaskType.ACTOR_TASK) {
taskExecutor = actorTaskExecutors.get(getActorId(task));
} else if (task.getType() == TaskType.ACTOR_CREATION_TASK) {
taskExecutor = new LocalModeTaskExecutor(runtime);
actorTaskExecutors.put(getActorId(task), taskExecutor);
} else if (idleTaskExecutors.size() > 0) {
taskExecutor = idleTaskExecutors.pop();
} else {
taskExecutor = new LocalModeTaskExecutor(runtime);
}
}
currentTaskExecutor.set(taskExecutor);
return taskExecutor;
}
/**
* Return the worker to the worker pool.
*/
private void returnTaskExecutor(TaskExecutor worker, TaskSpec taskSpec) {
currentTaskExecutor.remove();
synchronized (taskExecutorLock) {
if (taskSpec.getType() == TaskType.NORMAL_TASK) {
idleTaskExecutors.push(worker);
}
}
}
private Set<ObjectId> getUnreadyObjects(TaskSpec taskSpec) {
Set<ObjectId> unreadyObjects = new HashSet<>();
// Check whether task arguments are ready.
@ -257,32 +218,11 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
Set<ObjectId> unreadyObjects = getUnreadyObjects(taskSpec);
final Runnable runnable = () -> {
TaskExecutor taskExecutor = getTaskExecutor(taskSpec);
try {
List<NativeRayObject> args = getFunctionArgs(taskSpec).stream()
.map(arg -> arg.id != null ?
objectStore.getRaw(Collections.singletonList(arg.id), -1).get(0)
: arg.value)
.collect(Collectors.toList());
((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(taskSpec);
List<NativeRayObject> returnObjects = taskExecutor
.execute(getJavaFunctionDescriptor(taskSpec).toList(), args);
((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(null);
List<ObjectId> returnIds = getReturnIds(taskSpec);
for (int i = 0; i < returnIds.size(); i++) {
NativeRayObject putObject;
if (i >= returnObjects.size()) {
// If the task is an actor task or an actor creation task,
// put the dummy object in object store, so those tasks which depends on it
// can be executed.
putObject = new NativeRayObject(new byte[] {1}, null);
} else {
putObject = returnObjects.get(i);
}
objectStore.putRaw(putObject, returnIds.get(i));
}
} finally {
returnTaskExecutor(taskExecutor, taskSpec);
executeTask(taskSpec);
} catch (Exception ex) {
LOGGER.error("Unexpected exception when executing a task.", ex);
System.exit(-1);
}
};
@ -313,6 +253,52 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
}
}
/**
 * Executes a single task on the calling thread and stores its results in the object store.
 *
 * <p>Steps: look up the actor context (actor tasks only), resolve the task arguments, install
 * the task and its worker ID in the worker context, register the actor context with the task
 * executor, run the function, and finally write one object per return ID.
 *
 * @param taskSpec the specification of the task to run
 */
private void executeTask(TaskSpec taskSpec) {
  ActorContext actorContext = null;
  if (taskSpec.getType() == TaskType.ACTOR_TASK) {
    actorContext = actorContexts.get(getActorId(taskSpec));
    Preconditions.checkNotNull(actorContext);
  }
  // Arguments passed by reference are fetched from the object store (timeout argument -1);
  // arguments passed by value are used as-is.
  List<NativeRayObject> args = getFunctionArgs(taskSpec).stream()
      .map(arg -> arg.id != null ?
          objectStore.getRaw(Collections.singletonList(arg.id), -1).get(0)
          : arg.value)
      .collect(Collectors.toList());
  runtime.setIsContextSet(true);
  ((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(taskSpec);
  // An actor task reuses the worker ID recorded in its actor context; every other task gets a
  // fresh random worker ID.
  UniqueId workerId = actorContext != null
      ? ((LocalModeTaskExecutor.LocalActorContext) actorContext).getWorkerId()
      : UniqueId.randomId();
  ((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentWorkerId(workerId);
  // Register the actor context only after the worker ID has been installed for this thread.
  // The task executor stores the context under the current worker ID, so registering earlier
  // would use whatever ID a previous task left on this thread (or null on a fresh thread).
  taskExecutor.setActorContext(actorContext);
  List<NativeRayObject> returnObjects = taskExecutor
      .execute(getJavaFunctionDescriptor(taskSpec).toList(), args);
  if (taskSpec.getType() == TaskType.ACTOR_CREATION_TASK) {
    // Update the actor context map ASAP, in case objectStore.putRaw below triggers the next
    // actor task on this actor.
    actorContexts.put(getActorId(taskSpec), taskExecutor.getActorContext());
  }
  // Setting this flag to true again is necessary because `taskExecutor.execute()` sets it to
  // false when it finishes, and `runtime.getWorkerContext()` requires it to be true.
  runtime.setIsContextSet(true);
  ((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(null);
  runtime.setIsContextSet(false);
  List<ObjectId> returnIds = getReturnIds(taskSpec);
  for (int i = 0; i < returnIds.size(); i++) {
    NativeRayObject putObject;
    if (i >= returnObjects.size()) {
      // If the task is an actor task or an actor creation task,
      // put a dummy object in the object store, so that tasks which depend on it
      // can be executed.
      putObject = new NativeRayObject(new byte[]{1}, null);
    } else {
      putObject = returnObjects.get(i);
    }
    objectStore.putRaw(putObject, returnIds.get(i));
  }
}
private static JavaFunctionDescriptor getJavaFunctionDescriptor(TaskSpec taskSpec) {
org.ray.runtime.generated.Common.FunctionDescriptor functionDescriptor =
taskSpec.getFunctionDescriptor();

View file

@ -8,39 +8,42 @@ import org.ray.api.Checkpointable.Checkpoint;
import org.ray.api.Checkpointable.CheckpointContext;
import org.ray.api.id.ActorId;
import org.ray.api.id.UniqueId;
import org.ray.runtime.AbstractRayRuntime;
import org.ray.runtime.RayRuntimeInternal;
import org.ray.runtime.task.NativeTaskExecutor.NativeActorContext;
/**
* Task executor for cluster mode.
*/
public class NativeTaskExecutor extends TaskExecutor {
public class NativeTaskExecutor extends TaskExecutor<NativeActorContext> {
// TODO(hchen): Use the C++ config.
private static final int NUM_ACTOR_CHECKPOINTS_TO_KEEP = 20;
/**
* The native pointer of core worker.
*/
private final long nativeCoreWorkerPointer;
static class NativeActorContext extends TaskExecutor.ActorContext {
/**
* Number of tasks executed since last actor checkpoint.
*/
private int numTasksSinceLastCheckpoint = 0;
/**
* Number of tasks executed since last actor checkpoint.
*/
private int numTasksSinceLastCheckpoint = 0;
/**
* IDs of this actor's previous checkpoints.
*/
private List<UniqueId> checkpointIds;
/**
* IDs of this actor's previous checkpoints.
*/
private List<UniqueId> checkpointIds;
/**
* Timestamp of the last actor checkpoint.
*/
private long lastCheckpointTimestamp = 0;
/**
* Timestamp of the last actor checkpoint.
*/
private long lastCheckpointTimestamp = 0;
}
public NativeTaskExecutor(long nativeCoreWorkerPointer, AbstractRayRuntime runtime) {
public NativeTaskExecutor(RayRuntimeInternal runtime) {
super(runtime);
this.nativeCoreWorkerPointer = nativeCoreWorkerPointer;
}
/**
 * Creates the actor context used by the cluster-mode executor.
 *
 * @return a newly constructed {@link NativeActorContext}
 */
@Override
protected NativeActorContext createActorContext() {
return new NativeActorContext();
}
@Override
@ -48,15 +51,18 @@ public class NativeTaskExecutor extends TaskExecutor {
if (!(actor instanceof Checkpointable)) {
return;
}
NativeActorContext actorContext = getActorContext();
CheckpointContext checkpointContext = new CheckpointContext(actorId,
++numTasksSinceLastCheckpoint, System.currentTimeMillis() - lastCheckpointTimestamp);
++actorContext.numTasksSinceLastCheckpoint,
System.currentTimeMillis() - actorContext.lastCheckpointTimestamp);
Checkpointable checkpointable = (Checkpointable) actor;
if (!checkpointable.shouldCheckpoint(checkpointContext)) {
return;
}
numTasksSinceLastCheckpoint = 0;
lastCheckpointTimestamp = System.currentTimeMillis();
UniqueId checkpointId = new UniqueId(nativePrepareCheckpoint(nativeCoreWorkerPointer));
actorContext.numTasksSinceLastCheckpoint = 0;
actorContext.lastCheckpointTimestamp = System.currentTimeMillis();
UniqueId checkpointId = new UniqueId(nativePrepareCheckpoint());
List<UniqueId> checkpointIds = actorContext.checkpointIds;
checkpointIds.add(checkpointId);
if (checkpointIds.size() > NUM_ACTOR_CHECKPOINTS_TO_KEEP) {
((Checkpointable) actor).checkpointExpired(actorId, checkpointIds.get(0));
@ -70,9 +76,10 @@ public class NativeTaskExecutor extends TaskExecutor {
if (!(actor instanceof Checkpointable)) {
return;
}
numTasksSinceLastCheckpoint = 0;
lastCheckpointTimestamp = System.currentTimeMillis();
checkpointIds = new ArrayList<>();
NativeActorContext actorContext = getActorContext();
actorContext.numTasksSinceLastCheckpoint = 0;
actorContext.lastCheckpointTimestamp = System.currentTimeMillis();
actorContext.checkpointIds = new ArrayList<>();
List<Checkpoint> availableCheckpoints
= runtime.getGcsClient().getCheckpointsForActor(actorId);
if (availableCheckpoints.isEmpty()) {
@ -90,13 +97,11 @@ public class NativeTaskExecutor extends TaskExecutor {
Preconditions.checkArgument(checkpointValid,
"'loadCheckpoint' must return a checkpoint ID that exists in the "
+ "'availableCheckpoints' list, or null.");
nativeNotifyActorResumedFromCheckpoint(nativeCoreWorkerPointer, checkpointId.getBytes());
nativeNotifyActorResumedFromCheckpoint(checkpointId.getBytes());
}
}
private static native byte[] nativePrepareCheckpoint(long nativeCoreWorkerPointer);
private static native byte[] nativePrepareCheckpoint();
private static native void nativeNotifyActorResumedFromCheckpoint(long nativeCoreWorkerPointer,
byte[] checkpointId);
private static native void nativeNotifyActorResumedFromCheckpoint(byte[] checkpointId);
}

View file

@ -15,30 +15,18 @@ import org.ray.runtime.functionmanager.FunctionDescriptor;
*/
public class NativeTaskSubmitter implements TaskSubmitter {
/**
* The native pointer of core worker.
*/
private final long nativeCoreWorkerPointer;
public NativeTaskSubmitter(long nativeCoreWorkerPointer) {
this.nativeCoreWorkerPointer = nativeCoreWorkerPointer;
}
@Override
public List<ObjectId> submitTask(FunctionDescriptor functionDescriptor, List<FunctionArg> args,
int numReturns, CallOptions options) {
List<byte[]> returnIds = nativeSubmitTask(nativeCoreWorkerPointer, functionDescriptor, args,
numReturns, options);
List<byte[]> returnIds = nativeSubmitTask(functionDescriptor, args, numReturns, options);
return returnIds.stream().map(ObjectId::new).collect(Collectors.toList());
}
@Override
public BaseActor createActor(FunctionDescriptor functionDescriptor, List<FunctionArg> args,
ActorCreationOptions options) {
byte[] actorId = nativeCreateActor(nativeCoreWorkerPointer, functionDescriptor, args,
options);
return NativeRayActor.create(nativeCoreWorkerPointer, actorId,
functionDescriptor.getLanguage());
byte[] actorId = nativeCreateActor(functionDescriptor, args, options);
return NativeRayActor.create(actorId, functionDescriptor.getLanguage());
}
@Override
@ -46,24 +34,18 @@ public class NativeTaskSubmitter implements TaskSubmitter {
BaseActor actor, FunctionDescriptor functionDescriptor,
List<FunctionArg> args, int numReturns, CallOptions options) {
Preconditions.checkState(actor instanceof NativeRayActor);
List<byte[]> returnIds = nativeSubmitActorTask(nativeCoreWorkerPointer,
actor.getId().getBytes(), functionDescriptor, args, numReturns,
options);
List<byte[]> returnIds = nativeSubmitActorTask(actor.getId().getBytes(),
functionDescriptor, args, numReturns, options);
return returnIds.stream().map(ObjectId::new).collect(Collectors.toList());
}
private static native List<byte[]> nativeSubmitTask(
long nativeCoreWorkerPointer,
private static native List<byte[]> nativeSubmitTask(FunctionDescriptor functionDescriptor,
List<FunctionArg> args, int numReturns, CallOptions callOptions);
private static native byte[] nativeCreateActor(FunctionDescriptor functionDescriptor,
List<FunctionArg> args, ActorCreationOptions actorCreationOptions);
private static native List<byte[]> nativeSubmitActorTask(byte[] actorId,
FunctionDescriptor functionDescriptor, List<FunctionArg> args, int numReturns,
CallOptions callOptions);
private static native byte[] nativeCreateActor(
long nativeCoreWorkerPointer,
FunctionDescriptor functionDescriptor, List<FunctionArg> args,
ActorCreationOptions actorCreationOptions);
private static native List<byte[]> nativeSubmitActorTask(
long nativeCoreWorkerPointer,
byte[] actorId, FunctionDescriptor functionDescriptor, List<FunctionArg> args,
int numReturns, CallOptions callOptions);
}

View file

@ -1,6 +1,7 @@
package org.ray.runtime.task;
import com.google.common.base.Preconditions;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
@ -9,58 +10,75 @@ import org.ray.api.id.ActorId;
import org.ray.api.id.JobId;
import org.ray.api.id.TaskId;
import org.ray.api.id.UniqueId;
import org.ray.runtime.AbstractRayRuntime;
import org.ray.runtime.config.RayConfig;
import org.ray.runtime.config.RunMode;
import org.ray.runtime.RayRuntimeInternal;
import org.ray.runtime.functionmanager.JavaFunctionDescriptor;
import org.ray.runtime.functionmanager.RayFunction;
import org.ray.runtime.generated.Common.TaskType;
import org.ray.runtime.object.NativeRayObject;
import org.ray.runtime.object.ObjectSerializer;
import org.ray.runtime.task.TaskExecutor.ActorContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* The task executor, which executes tasks assigned by raylet continuously.
*/
public abstract class TaskExecutor {
public abstract class TaskExecutor<T extends ActorContext> {
private static final Logger LOGGER = LoggerFactory.getLogger(TaskExecutor.class);
// A helper map that lets JNI code get the corresponding executor for a given worker.
private static ConcurrentHashMap<UniqueId, TaskExecutor> taskExecutors
= new ConcurrentHashMap<>();
protected final RayRuntimeInternal runtime;
protected final AbstractRayRuntime runtime;
private final ConcurrentHashMap<UniqueId, T> actorContextMap = new ConcurrentHashMap<>();
/**
* The current actor object, if this worker is an actor, otherwise null.
*/
protected Object currentActor = null;
static class ActorContext {
/**
* The exception that failed the actor creation task, if any.
*/
private Exception actorCreationException = null;
/**
* The current actor object, if this worker is an actor, otherwise null.
*/
Object currentActor = null;
protected TaskExecutor(AbstractRayRuntime runtime) {
this.runtime = runtime;
if (RayConfig.getInstance().runMode == RunMode.CLUSTER) {
taskExecutors.put(runtime.getWorkerContext().getCurrentWorkerId(), this);
}
/**
* The exception that failed the actor creation task, if any.
*/
Throwable actorCreationException = null;
}
public static TaskExecutor get(byte[] workerId) {
return taskExecutors.get(new UniqueId(workerId));
TaskExecutor(RayRuntimeInternal runtime) {
this.runtime = runtime;
}
/**
 * Creates a new, executor-specific actor context.
 *
 * @return a new actor context of this executor's context type
 */
protected abstract T createActorContext();
/**
 * Looks up the actor context stored under the worker ID that the runtime's worker context
 * currently reports.
 *
 * @return the stored actor context, or {@code null} if none is stored for that worker ID
 */
T getActorContext() {
return actorContextMap.get(runtime.getWorkerContext().getCurrentWorkerId());
}
/**
 * Stores the given actor context under the worker ID that the runtime's worker context
 * currently reports. A {@code null} context is ignored: nothing is added, and any context
 * already stored for that worker ID is left untouched.
 *
 * @param actorContext the context to store, or {@code null} to do nothing
 */
void setActorContext(T actorContext) {
if (actorContext == null) {
// ConcurrentHashMap doesn't allow null values. So just return here.
return;
}
this.actorContextMap.put(runtime.getWorkerContext().getCurrentWorkerId(), actorContext);
}
protected List<NativeRayObject> execute(List<String> rayFunctionInfo,
List<NativeRayObject> argsBytes) {
runtime.setIsContextSet(true);
JobId jobId = runtime.getWorkerContext().getCurrentJobId();
TaskType taskType = runtime.getWorkerContext().getCurrentTaskType();
TaskId taskId = runtime.getWorkerContext().getCurrentTaskId();
LOGGER.debug("Executing task {}", taskId);
T actorContext = null;
if (taskType == TaskType.ACTOR_CREATION_TASK) {
actorContext = createActorContext();
setActorContext(actorContext);
} else if (taskType == TaskType.ACTOR_TASK) {
actorContext = getActorContext();
Preconditions.checkNotNull(actorContext);
}
List<NativeRayObject> returnObjects = new ArrayList<>();
ClassLoader oldLoader = Thread.currentThread().getContextClassLoader();
JavaFunctionDescriptor functionDescriptor = parseFunctionDescriptor(rayFunctionInfo);
@ -74,19 +92,26 @@ public abstract class TaskExecutor {
// Get local actor object and arguments.
Object actor = null;
if (taskType == TaskType.ACTOR_TASK) {
if (actorCreationException != null) {
throw actorCreationException;
if (actorContext.actorCreationException != null) {
throw actorContext.actorCreationException;
}
actor = currentActor;
actor = actorContext.currentActor;
}
Object[] args = ArgumentsBuilder.unwrap(argsBytes, rayFunction.classLoader);
// Execute the task.
Object result;
if (!rayFunction.isConstructor()) {
result = rayFunction.getMethod().invoke(actor, args);
} else {
result = rayFunction.getConstructor().newInstance(args);
try {
if (!rayFunction.isConstructor()) {
result = rayFunction.getMethod().invoke(actor, args);
} else {
result = rayFunction.getConstructor().newInstance(args);
}
} catch (InvocationTargetException e) {
if (e.getCause() != null) {
throw e.getCause();
} else {
throw e;
}
}
// Set result
if (taskType != TaskType.ACTOR_CREATION_TASK) {
@ -100,10 +125,10 @@ public abstract class TaskExecutor {
} else {
// TODO (kfstorm): handle checkpoint in core worker.
maybeLoadCheckpoint(result, runtime.getWorkerContext().getCurrentActorId());
currentActor = result;
actorContext.currentActor = result;
}
LOGGER.debug("Finished executing task {}", taskId);
} catch (Exception e) {
} catch (Throwable e) {
LOGGER.error("Error executing task " + taskId, e);
if (taskType != TaskType.ACTOR_CREATION_TASK) {
boolean hasReturn = rayFunction != null && rayFunction.hasReturn();
@ -113,11 +138,12 @@ public abstract class TaskExecutor {
.serialize(new RayTaskException("Error executing task " + taskId, e)));
}
} else {
actorCreationException = e;
actorContext.actorCreationException = e;
}
} finally {
Thread.currentThread().setContextClassLoader(oldLoader);
runtime.getWorkerContext().setCurrentClassLoader(null);
runtime.setIsContextSet(false);
}
return returnObjects;
}

View file

@ -98,12 +98,6 @@ ray {
}
}
// ----------------------------
// configurations under SINGLE_PROCESS mode
// ----------------------------
dev-runtime {
// Number of threads that you process tasks
execution-parallelism: 10
}
// Whether we enable job manager to submit and manage job.
enable-job-manager: false
}

View file

@ -42,6 +42,10 @@ public class TestProgressListener implements IInvokedMethodListener, ITestListen
/**
 * Reports a failed test and, if the result carries a throwable, prints its stack trace.
 *
 * @param result The result of the failed test.
 */
@Override
public void onTestFailure(ITestResult result) {
  printInfo("TEST FAILURE", getFullTestName(result));
  final Throwable failureCause = result.getThrowable();
  if (failureCause == null) {
    // No throwable attached to the result, nothing more to print.
    return;
  }
  failureCause.printStackTrace();
}
@Override

View file

@ -1,11 +1,9 @@
package org.ray.api;
import com.google.common.base.Preconditions;
import java.io.Serializable;
import java.util.function.Supplier;
import org.ray.api.runtime.RayRuntime;
import org.ray.runtime.AbstractRayRuntime;
import org.ray.runtime.RayMultiWorkerNativeRuntime;
import org.ray.runtime.RayRuntimeInternal;
import org.ray.runtime.RayRuntimeProxy;
import org.ray.runtime.config.RunMode;
import org.testng.Assert;
import org.testng.SkipException;
@ -79,12 +77,13 @@ public class TestUtils {
Assert.assertEquals(obj.get(), "hi");
}
public static AbstractRayRuntime getRuntime() {
RayRuntime runtime = Ray.internal();
if (runtime instanceof RayMultiWorkerNativeRuntime) {
runtime = ((RayMultiWorkerNativeRuntime) runtime).getCurrentRuntime();
}
Preconditions.checkState(runtime instanceof AbstractRayRuntime);
return (AbstractRayRuntime) runtime;
/**
 * Returns the object obtained from {@code Ray.internal()}, cast to
 * {@link RayRuntimeInternal}.
 */
public static RayRuntimeInternal getRuntime() {
return (RayRuntimeInternal) Ray.internal();
}
/**
 * Returns the runtime object held by the {@link RayRuntimeProxy} that serves as the
 * invocation handler of the proxy returned by {@code Ray.internal()}.
 */
public static RayRuntimeInternal getUnderlyingRuntime() {
  final java.lang.reflect.InvocationHandler handler =
      java.lang.reflect.Proxy.getInvocationHandler(Ray.internal());
  final RayRuntimeProxy runtimeProxy = (RayRuntimeProxy) handler;
  return runtimeProxy.getRuntimeObject();
}
}

View file

@ -134,7 +134,7 @@ public class ClassLoaderTest extends BaseTest {
FunctionDescriptor.class, Object[].class, ActorCreationOptions.class);
createActorMethod.setAccessible(true);
return (RayActor<?>) createActorMethod
.invoke(TestUtils.getRuntime(), functionDescriptor, new Object[0], null);
.invoke(TestUtils.getUnderlyingRuntime(), functionDescriptor, new Object[0], null);
}
private <T> RayObject<T> callActorFunction(RayActor<?> rayActor,
@ -143,6 +143,6 @@ public class ClassLoaderTest extends BaseTest {
BaseActor.class, FunctionDescriptor.class, Object[].class, int.class);
callActorFunctionMethod.setAccessible(true);
return (RayObject<T>) callActorFunctionMethod
.invoke(TestUtils.getRuntime(), rayActor, functionDescriptor, args, numReturns);
.invoke(TestUtils.getUnderlyingRuntime(), rayActor, functionDescriptor, args, numReturns);
}
}

View file

@ -29,7 +29,8 @@ public class ClientExceptionTest extends BaseTest {
try {
TimeUnit.SECONDS.sleep(1);
// kill raylet
RunManager runManager = ((RayNativeRuntime) TestUtils.getRuntime()).getRunManager();
RunManager runManager =
((RayNativeRuntime) TestUtils.getUnderlyingRuntime()).getRunManager();
for (Process process : runManager.getProcesses("raylet")) {
runManager.terminateProcess("raylet", process);
}

View file

@ -14,6 +14,7 @@ import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.TestUtils;
import org.ray.api.WaitResult;
import org.ray.api.exception.RayException;
import org.ray.api.id.ActorId;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -135,6 +136,105 @@ public class MultiThreadingTest extends BaseTest {
Assert.assertEquals(actorId, actorIdTester.getId());
}
/**
 * Checks that Ray APIs can be invoked from the calling thread and from a thread whose runnable
 * was wrapped with {@code Ray.wrapRunnable}, and that they throw {@link RayException} when
 * invoked from a thread that was not wrapped.
 *
 * <p>Failures on the helper threads are captured as {@link Throwable} rather than
 * {@link Exception}: a failed {@code Assert.expectThrows} raises an {@link AssertionError},
 * which would otherwise terminate the helper thread unnoticed and let the test pass.
 *
 * @return Always {@code true}, so that this method can be submitted through {@code Ray.call}
 *     and yield a {@code RayObject}.
 * @throws InterruptedException If interrupted while waiting for a helper thread.
 */
static boolean testMissingWrapRunnable() throws InterruptedException {
  final RayObject<Integer> fooObject = Ray.put(1);
  final RayActor<Echo> fooActor = Ray.createActor(Echo::new);

  final Runnable[] runnables = new Runnable[]{
      () -> Ray.put(1),
      () -> Ray.get(fooObject.getId()),
      fooObject::get,
      () -> Ray.wait(ImmutableList.of(fooObject)),
      Ray::getRuntimeContext,
      () -> Ray.call(MultiThreadingTest::echo, 1),
      () -> Ray.createActor(Echo::new),
      () -> fooActor.call(Echo::echo, 1),
  };

  // It's OK to run them in the calling thread.
  for (Runnable runnable : runnables) {
    runnable.run();
  }

  // Slot for whatever a helper thread fails with, including assertion errors.
  final Throwable[] failure = new Throwable[1];

  Thread thread = new Thread(Ray.wrapRunnable(() -> {
    try {
      // It's OK to run them in another thread if the runnable was wrapped.
      for (Runnable runnable : runnables) {
        runnable.run();
      }
    } catch (Throwable ex) {
      failure[0] = ex;
    }
  }));
  thread.start();
  thread.join();
  if (failure[0] != null) {
    throw new RuntimeException("Exception occurred in thread.", failure[0]);
  }

  thread = new Thread(() -> {
    try {
      // It's not OK to run them in another thread if the runnable was not wrapped.
      for (Runnable runnable : runnables) {
        Assert.expectThrows(RayException.class, runnable::run);
      }
    } catch (Throwable ex) {
      // Also reached when expectThrows fails, since it reports via AssertionError.
      failure[0] = ex;
    }
  });
  thread.start();
  thread.join();
  if (failure[0] != null) {
    throw new RuntimeException("Exception occurred in thread.", failure[0]);
  }

  Runnable[] wrappedRunnables = new Runnable[runnables.length];
  for (int i = 0; i < runnables.length; i++) {
    wrappedRunnables[i] = Ray.wrapRunnable(runnables[i]);
  }
  // It's OK to run the wrapped runnables in the current thread.
  for (Runnable runnable : wrappedRunnables) {
    runnable.run();
  }

  // It's OK to invoke Ray APIs after executing a wrapped runnable in the current thread.
  wrappedRunnables[0].run();
  runnables[0].run();

  // Return true so that Ray.call on this method returns a RayObject.
  return true;
}
/**
 * Runs {@code testMissingWrapRunnable} directly on the thread executing this test.
 */
@Test
public void testMissingWrapRunnableInDriver() throws InterruptedException {
testMissingWrapRunnable();
}
/**
 * Submits {@code testMissingWrapRunnable} through {@code Ray.call} and blocks on its result.
 */
@Test
public void testMissingWrapRunnableInWorker() {
// get() waits for the submitted call to finish.
Ray.call(MultiThreadingTest::testMissingWrapRunnable).get();
}
/**
 * Checks that a plain thread can use Ray APIs after installing, via
 * {@code Ray.setAsyncContext}, the async context captured on the test thread.
 *
 * <p>Any {@link Throwable} raised on the helper thread is captured and rethrown on the test
 * thread; catching only {@link Exception} would let an {@link Error} go unnoticed and the
 * test would pass.
 *
 * @throws InterruptedException If interrupted while waiting for the helper thread.
 */
@Test
public void testGetAndSetAsyncContext() throws InterruptedException {
  final Object asyncContext = Ray.getAsyncContext();
  // Slot for whatever the helper thread fails with.
  final Throwable[] failure = new Throwable[1];

  Thread thread = new Thread(() -> {
    try {
      Ray.setAsyncContext(asyncContext);
      Ray.put(1);
    } catch (Throwable ex) {
      failure[0] = ex;
    }
  });
  thread.start();
  thread.join();

  if (failure[0] != null) {
    throw new RuntimeException("Exception occurred in thread.", failure[0]);
  }
}
private static void runTestCaseInMultipleThreads(Runnable testCase, int numRepeats) {
ExecutorService service = Executors.newFixedThreadPool(NUM_THREADS);

View file

@ -17,7 +17,7 @@ from ray.includes.common cimport (
CBuffer,
CRayObject
)
from ray.includes.libcoreworker cimport CCoreWorker
from ray.includes.libcoreworker cimport CFiberEvent
from ray.includes.unique_ids cimport (
CObjectID,
CActorID
@ -72,7 +72,7 @@ cdef class ActorID(BaseID):
cdef class CoreWorker:
cdef:
unique_ptr[CCoreWorker] core_worker
c_bool is_driver
object async_thread
object async_event_loop
object plasma_event_handler
@ -85,6 +85,7 @@ cdef class CoreWorker:
cdef store_task_outputs(
self, worker, outputs, const c_vector[CObjectID] return_ids,
c_vector[shared_ptr[CRayObject]] *returns)
cdef yield_current_fiber(self, CFiberEvent &fiber_event)
cdef class FunctionDescriptor:
cdef:

View file

@ -69,7 +69,8 @@ from ray.includes.unique_ids cimport (
)
from ray.includes.libcoreworker cimport (
CActorCreationOptions,
CCoreWorker,
CCoreWorkerOptions,
CCoreWorkerProcess,
CTaskOptions,
ResourceMappingType,
CFiberEvent,
@ -312,7 +313,7 @@ cdef execute_task(
dict execution_infos = manager.execution_infos
CoreWorker core_worker = worker.core_worker
JobID job_id = core_worker.get_current_job_id()
CTaskID task_id = core_worker.core_worker.get().GetCurrentTaskId()
TaskID task_id = core_worker.get_current_task_id()
CFiberEvent task_done_event
# Automatically restrict the GPUs available to this task.
@ -339,7 +340,7 @@ cdef execute_task(
function_name = execution_info.function_name
extra_data = (b'{"name": ' + function_name.encode("ascii") +
b' "task_id": ' + task_id.Hex() + b'}')
b' "task_id": ' + task_id.hex().encode("ascii") + b'}')
if <int>task_type == <int>TASK_TYPE_NORMAL_TASK:
title = "ray::{}()".format(function_name)
@ -396,9 +397,7 @@ cdef execute_task(
monitor_state.unregister_coroutine(coroutine)
future.add_done_callback(callback)
with nogil:
(core_worker.core_worker.get()
.YieldCurrentFiber(task_done_event))
core_worker.yield_current_fiber(task_done_event)
return future.result()
@ -499,8 +498,7 @@ cdef CRayStatus task_execution_handler(
const c_vector[shared_ptr[CRayObject]] &c_args,
const c_vector[CObjectID] &c_arg_reference_ids,
const c_vector[CObjectID] &c_return_ids,
c_vector[shared_ptr[CRayObject]] *returns,
const CWorkerID &c_worker_id) nogil:
c_vector[shared_ptr[CRayObject]] *returns) nogil:
with gil:
try:
@ -645,43 +643,76 @@ cdef class CoreWorker:
def __cinit__(self, is_driver, store_socket, raylet_socket,
JobID job_id, GcsClientOptions gcs_options, log_dir,
node_ip_address, node_manager_port, local_mode):
use_driver = is_driver or local_mode
self.core_worker.reset(new CCoreWorker(
WORKER_TYPE_DRIVER if use_driver else WORKER_TYPE_WORKER,
LANGUAGE_PYTHON, store_socket.encode("ascii"),
raylet_socket.encode("ascii"), job_id.native(),
gcs_options.native()[0], log_dir.encode("utf-8"),
node_ip_address.encode("utf-8"), node_manager_port,
task_execution_handler, check_signals, gc_collect,
get_py_stack, True, local_mode))
node_ip_address, node_manager_port, local_mode,
driver_name, stdout_file, stderr_file):
self.is_driver = is_driver
self.is_local_mode = local_mode
cdef CCoreWorkerOptions options = CCoreWorkerOptions()
options.worker_type = (
WORKER_TYPE_DRIVER if is_driver else WORKER_TYPE_WORKER)
options.language = LANGUAGE_PYTHON
options.store_socket = store_socket.encode("ascii")
options.raylet_socket = raylet_socket.encode("ascii")
options.job_id = job_id.native()
options.gcs_options = gcs_options.native()[0]
options.log_dir = log_dir.encode("utf-8")
options.install_failure_signal_handler = True
options.node_ip_address = node_ip_address.encode("utf-8")
options.node_manager_port = node_manager_port
options.driver_name = driver_name
options.stdout_file = stdout_file
options.stderr_file = stderr_file
options.task_execution_callback = task_execution_handler
options.check_signals = check_signals
options.gc_collect = gc_collect
options.get_lang_stack = get_py_stack
options.ref_counting_enabled = True
options.is_local_mode = local_mode
options.num_workers = 1
CCoreWorkerProcess.Initialize(options)
def __dealloc__(self):
with nogil:
# If it's a worker, the core worker process should have been
# shutdown. So we can't call
# `CCoreWorkerProcess.GetCoreWorker().GetWorkerType()` here.
# Instead, we use the cached `is_driver` flag to test if it's a
# driver.
if self.is_driver:
CCoreWorkerProcess.Shutdown()
def run_task_loop(self):
with nogil:
self.core_worker.get().StartExecutingTasks()
CCoreWorkerProcess.RunTaskExecutionLoop()
def get_current_task_id(self):
return TaskID(self.core_worker.get().GetCurrentTaskId().Binary())
return TaskID(
CCoreWorkerProcess.GetCoreWorker().GetCurrentTaskId().Binary())
def get_current_job_id(self):
return JobID(self.core_worker.get().GetCurrentJobId().Binary())
return JobID(
CCoreWorkerProcess.GetCoreWorker().GetCurrentJobId().Binary())
def get_actor_id(self):
return ActorID(self.core_worker.get().GetActorId().Binary())
return ActorID(
CCoreWorkerProcess.GetCoreWorker().GetActorId().Binary())
def set_webui_display(self, key, message):
self.core_worker.get().SetWebuiDisplay(key, message)
CCoreWorkerProcess.GetCoreWorker().SetWebuiDisplay(key, message)
def set_actor_title(self, title):
self.core_worker.get().SetActorTitle(title)
CCoreWorkerProcess.GetCoreWorker().SetActorTitle(title)
def set_plasma_added_callback(self, plasma_event_handler):
self.plasma_event_handler = plasma_event_handler
self.core_worker.get().SetPlasmaAddedCallback(async_plasma_callback)
CCoreWorkerProcess.GetCoreWorker().SetPlasmaAddedCallback(
async_plasma_callback)
def subscribe_to_plasma_object(self, ObjectID object_id):
self.core_worker.get().SubscribeToPlasmaAdd(object_id.native())
CCoreWorkerProcess.GetCoreWorker().SubscribeToPlasmaAdd(
object_id.native())
def get_plasma_event_handler(self):
return self.plasma_event_handler
@ -694,7 +725,7 @@ cdef class CoreWorker:
c_vector[CObjectID] c_object_ids = ObjectIDsToVector(object_ids)
with nogil:
check_status(self.core_worker.get().Get(
check_status(CCoreWorkerProcess.GetCoreWorker().Get(
c_object_ids, timeout_ms, &results))
return RayObjectsToDataMetadataPairs(results)
@ -705,7 +736,7 @@ cdef class CoreWorker:
CObjectID c_object_id = object_id.native()
with nogil:
check_status(self.core_worker.get().Contains(
check_status(CCoreWorkerProcess.GetCoreWorker().Contains(
c_object_id, &has_object))
return has_object
@ -716,13 +747,13 @@ cdef class CoreWorker:
CObjectID *c_object_id, shared_ptr[CBuffer] *data):
if object_id is None:
with nogil:
check_status(self.core_worker.get().Create(
check_status(CCoreWorkerProcess.GetCoreWorker().Create(
metadata, data_size, contained_ids,
c_object_id, data))
else:
c_object_id[0] = object_id.native()
with nogil:
check_status(self.core_worker.get().Create(
check_status(CCoreWorkerProcess.GetCoreWorker().Create(
metadata, data_size,
c_object_id[0], data))
@ -752,7 +783,7 @@ cdef class CoreWorker:
write_serialized_object(serialized_object, data)
if self.is_local_mode:
c_object_id_vector.push_back(c_object_id)
check_status(self.core_worker.get().Put(
check_status(CCoreWorkerProcess.GetCoreWorker().Put(
CRayObject(data, metadata, c_object_id_vector),
c_object_id_vector, c_object_id))
else:
@ -760,7 +791,7 @@ cdef class CoreWorker:
# Using custom object IDs is not supported because we can't
# track their lifecycle, so we don't pin the object in this
# case.
check_status(self.core_worker.get().Seal(
check_status(CCoreWorkerProcess.GetCoreWorker().Seal(
c_object_id,
pin_object and object_id is None))
@ -775,7 +806,7 @@ cdef class CoreWorker:
wait_ids = ObjectIDsToVector(object_ids)
with nogil:
check_status(self.core_worker.get().Wait(
check_status(CCoreWorkerProcess.GetCoreWorker().Wait(
wait_ids, num_returns, timeout_ms, &results))
assert len(results) == len(object_ids)
@ -795,19 +826,19 @@ cdef class CoreWorker:
c_vector[CObjectID] free_ids = ObjectIDsToVector(object_ids)
with nogil:
check_status(self.core_worker.get().Delete(
check_status(CCoreWorkerProcess.GetCoreWorker().Delete(
free_ids, local_only, delete_creating_tasks))
def global_gc(self):
with nogil:
self.core_worker.get().TriggerGlobalGC()
CCoreWorkerProcess.GetCoreWorker().TriggerGlobalGC()
def set_object_store_client_options(self, client_name,
int64_t limit_bytes):
try:
logger.debug("Setting plasma memory limit to {} for {}".format(
limit_bytes, client_name))
check_status(self.core_worker.get().SetClientOptions(
check_status(CCoreWorkerProcess.GetCoreWorker().SetClientOptions(
client_name.encode("ascii"), limit_bytes))
except RayError as e:
self.dump_object_store_memory_usage()
@ -820,7 +851,7 @@ cdef class CoreWorker:
limit_bytes, client_name, e))
def dump_object_store_memory_usage(self):
message = self.core_worker.get().MemoryUsageString()
message = CCoreWorkerProcess.GetCoreWorker().MemoryUsageString()
logger.warning("Local object store memory usage:\n{}\n".format(
message.decode("utf-8")))
@ -847,7 +878,7 @@ cdef class CoreWorker:
prepare_args(self, args, &args_vector)
with nogil:
check_status(self.core_worker.get().SubmitTask(
check_status(CCoreWorkerProcess.GetCoreWorker().SubmitTask(
ray_function, args_vector, task_options, &return_ids,
max_retries))
@ -880,7 +911,7 @@ cdef class CoreWorker:
prepare_args(self, args, &args_vector)
with nogil:
check_status(self.core_worker.get().CreateActor(
check_status(CCoreWorkerProcess.GetCoreWorker().CreateActor(
ray_function, args_vector,
CActorCreationOptions(
max_reconstructions, max_concurrency,
@ -916,10 +947,11 @@ cdef class CoreWorker:
prepare_args(self, args, &args_vector)
with nogil:
check_status(self.core_worker.get().SubmitActorTask(
c_actor_id,
ray_function,
args_vector, task_options, &return_ids))
check_status(
CCoreWorkerProcess.GetCoreWorker().SubmitActorTask(
c_actor_id,
ray_function,
args_vector, task_options, &return_ids))
return VectorToObjectIDs(return_ids)
@ -928,13 +960,13 @@ cdef class CoreWorker:
CActorID c_actor_id = actor_id.native()
with nogil:
check_status(self.core_worker.get().KillActor(
check_status(CCoreWorkerProcess.GetCoreWorker().KillActor(
c_actor_id, True, no_reconstruction))
def resource_ids(self):
cdef:
ResourceMappingType resource_mapping = (
self.core_worker.get().GetResourceIDs())
CCoreWorkerProcess.GetCoreWorker().GetResourceIDs())
unordered_map[
c_string, c_vector[pair[int64_t, double]]
].iterator iterator = resource_mapping.begin()
@ -955,13 +987,14 @@ cdef class CoreWorker:
def profile_event(self, c_string event_type, object extra_data=None):
return ProfileEvent.make(
self.core_worker.get().CreateProfileEvent(event_type),
CCoreWorkerProcess.GetCoreWorker().CreateProfileEvent(event_type),
extra_data)
def remove_actor_handle_reference(self, ActorID actor_id):
cdef:
CActorID c_actor_id = actor_id.native()
self.core_worker.get().RemoveActorHandleReference(c_actor_id)
CCoreWorkerProcess.GetCoreWorker().RemoveActorHandleReference(
c_actor_id)
def deserialize_and_register_actor_handle(self, const c_string &bytes,
ObjectID
@ -974,9 +1007,10 @@ cdef class CoreWorker:
worker = ray.worker.global_worker
worker.check_connected()
manager = worker.function_actor_manager
c_actor_id = self.core_worker.get().DeserializeAndRegisterActorHandle(
bytes, c_outer_object_id)
check_status(self.core_worker.get().GetActorHandle(
c_actor_id = (CCoreWorkerProcess.GetCoreWorker()
.DeserializeAndRegisterActorHandle(
bytes, c_outer_object_id))
check_status(CCoreWorkerProcess.GetCoreWorker().GetActorHandle(
c_actor_id, &c_actor_handle))
actor_id = ActorID(c_actor_id.Binary())
job_id = JobID(c_actor_handle.CreationJobID().Binary())
@ -1017,24 +1051,26 @@ cdef class CoreWorker:
cdef:
c_string output
CObjectID c_actor_handle_id
check_status(self.core_worker.get().SerializeActorHandle(
check_status(CCoreWorkerProcess.GetCoreWorker().SerializeActorHandle(
actor_id.native(), &output, &c_actor_handle_id))
return output, ObjectID(c_actor_handle_id.Binary())
def add_object_id_reference(self, ObjectID object_id):
# Note: faster to not release GIL for short-running op.
self.core_worker.get().AddLocalReference(object_id.native())
CCoreWorkerProcess.GetCoreWorker().AddLocalReference(
object_id.native())
def remove_object_id_reference(self, ObjectID object_id):
# Note: faster to not release GIL for short-running op.
self.core_worker.get().RemoveLocalReference(object_id.native())
CCoreWorkerProcess.GetCoreWorker().RemoveLocalReference(
object_id.native())
def serialize_and_promote_object_id(self, ObjectID object_id):
cdef:
CObjectID c_object_id = object_id.native()
CTaskID c_owner_id = CTaskID.Nil()
CAddress c_owner_address = CAddress()
self.core_worker.get().PromoteToPlasmaAndGetOwnershipInfo(
CCoreWorkerProcess.GetCoreWorker().PromoteToPlasmaAndGetOwnershipInfo(
c_object_id, &c_owner_id, &c_owner_address)
return (object_id,
TaskID(c_owner_id.Binary()),
@ -1053,11 +1089,12 @@ cdef class CoreWorker:
CAddress c_owner_address = CAddress()
c_owner_address.ParseFromString(serialized_owner_address)
self.core_worker.get().RegisterOwnershipInfoAndResolveFuture(
(CCoreWorkerProcess.GetCoreWorker()
.RegisterOwnershipInfoAndResolveFuture(
c_object_id,
c_outer_object_id,
c_owner_id,
c_owner_address)
c_owner_address))
cdef store_task_outputs(
self, worker, outputs, const c_vector[CObjectID] return_ids,
@ -1088,8 +1125,10 @@ cdef class CoreWorker:
ObjectIDsToVector(serialized_object.contained_object_ids))
with nogil:
check_status(self.core_worker.get().AllocateReturnObjects(
return_ids, data_sizes, metadatas, contained_ids, returns))
check_status(CCoreWorkerProcess.GetCoreWorker()
.AllocateReturnObjects(
return_ids, data_sizes, metadatas, contained_ids,
returns))
for i, serialized_object in enumerate(serialized_objects):
# A nullptr is returned if the object already exists.
@ -1099,7 +1138,7 @@ cdef class CoreWorker:
if self.is_local_mode:
return_ids_vector.push_back(return_ids[i])
check_status(
self.core_worker.get().Put(
CCoreWorkerProcess.GetCoreWorker().Put(
CRayObject(returns[0][i].get().GetData(),
returns[0][i].get().GetMetadata(),
return_ids_vector),
@ -1138,7 +1177,7 @@ cdef class CoreWorker:
future = asyncio.run_coroutine_threadsafe(coroutine, loop)
future.add_done_callback(lambda _: event.Notify())
with nogil:
(self.core_worker.get()
(CCoreWorkerProcess.GetCoreWorker()
.YieldCurrentFiber(event))
return future.result()
@ -1149,14 +1188,20 @@ cdef class CoreWorker:
self.async_thread.join()
def current_actor_is_asyncio(self):
return self.core_worker.get().GetWorkerContext().CurrentActorIsAsync()
return (CCoreWorkerProcess.GetCoreWorker().GetWorkerContext()
.CurrentActorIsAsync())
cdef yield_current_fiber(self, CFiberEvent &fiber_event):
with nogil:
CCoreWorkerProcess.GetCoreWorker().YieldCurrentFiber(fiber_event)
def get_all_reference_counts(self):
cdef:
unordered_map[CObjectID, pair[size_t, size_t]] c_ref_counts
unordered_map[CObjectID, pair[size_t, size_t]].iterator it
c_ref_counts = self.core_worker.get().GetAllReferenceCounts()
c_ref_counts = (
CCoreWorkerProcess.GetCoreWorker().GetAllReferenceCounts())
it = c_ref_counts.begin()
ref_counts = {}
@ -1170,7 +1215,7 @@ cdef class CoreWorker:
return ref_counts
def in_memory_store_get_async(self, ObjectID object_id, future):
self.core_worker.get().GetAsync(
CCoreWorkerProcess.GetCoreWorker().GetAsync(
object_id.native(),
async_set_result_callback,
async_retry_with_plasma_callback,
@ -1178,7 +1223,7 @@ cdef class CoreWorker:
def push_error(self, JobID job_id, error_type, error_message,
double timestamp):
check_status(self.core_worker.get().PushError(
check_status(CCoreWorkerProcess.GetCoreWorker().PushError(
job_id.native(), error_type.encode("ascii"),
error_message.encode("ascii"), timestamp))
@ -1190,18 +1235,21 @@ cdef class CoreWorker:
# PrepareActorCheckpoint will wait for raylet's reply, release
# the GIL so other Python threads can run.
with nogil:
check_status(self.core_worker.get().PrepareActorCheckpoint(
c_actor_id, &checkpoint_id))
check_status(
CCoreWorkerProcess.GetCoreWorker()
.PrepareActorCheckpoint(c_actor_id, &checkpoint_id))
return ActorCheckpointID(checkpoint_id.Binary())
def notify_actor_resumed_from_checkpoint(self, ActorID actor_id,
ActorCheckpointID checkpoint_id):
check_status(self.core_worker.get().NotifyActorResumedFromCheckpoint(
actor_id.native(), checkpoint_id.native()))
check_status(
CCoreWorkerProcess.GetCoreWorker()
.NotifyActorResumedFromCheckpoint(
actor_id.native(), checkpoint_id.native()))
def set_resource(self, basestring resource_name,
double capacity, ClientID client_id):
self.core_worker.get().SetResource(
CCoreWorkerProcess.GetCoreWorker().SetResource(
resource_name.encode("ascii"), capacity,
CClientID.FromBinary(client_id.binary()))

View file

@ -17,7 +17,6 @@ from ray.includes.unique_ids cimport (
CJobID,
CTaskID,
CObjectID,
CWorkerID,
)
from ray.includes.common cimport (
CAddress,
@ -80,31 +79,9 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
c_string ExtensionData() const
cdef cppclass CCoreWorker "ray::CoreWorker":
CCoreWorker(const CWorkerType worker_type, const CLanguage language,
const c_string &store_socket,
const c_string &raylet_socket, const CJobID &job_id,
const CGcsClientOptions &gcs_options,
const c_string &log_dir, const c_string &node_ip_address,
int node_manager_port,
CRayStatus (
CTaskType task_type,
const CRayFunction &ray_function,
const unordered_map[c_string, double] &resources,
const c_vector[shared_ptr[CRayObject]] &args,
const c_vector[CObjectID] &arg_reference_ids,
const c_vector[CObjectID] &return_ids,
c_vector[shared_ptr[CRayObject]] *returns,
const CWorkerID &worker_id) nogil,
CRayStatus() nogil,
void() nogil,
void(c_string *stack_out) nogil,
c_bool ref_counting_enabled,
c_bool local_worker)
CWorkerType &GetWorkerType()
CLanguage &GetLanguage()
void StartExecutingTasks()
CRayStatus SubmitTask(
const CRayFunction &function, const c_vector[CTaskArg] &args,
const CTaskOptions &options, c_vector[CObjectID] *return_ids,
@ -206,3 +183,46 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
void SetPlasmaAddedCallback(plasma_callback_function callback)
void SubscribeToPlasmaAdd(const CObjectID &object_id)
cdef cppclass CCoreWorkerOptions "ray::CoreWorkerOptions":
CWorkerType worker_type
CLanguage language
c_string store_socket
c_string raylet_socket
CJobID job_id
CGcsClientOptions gcs_options
c_string log_dir
c_bool install_failure_signal_handler
c_string node_ip_address
int node_manager_port
c_string driver_name
c_string stdout_file
c_string stderr_file
(CRayStatus(
CTaskType task_type,
const CRayFunction &ray_function,
const unordered_map[c_string, double] &resources,
const c_vector[shared_ptr[CRayObject]] &args,
const c_vector[CObjectID] &arg_reference_ids,
const c_vector[CObjectID] &return_ids,
c_vector[shared_ptr[CRayObject]] *returns) nogil
) task_execution_callback
(CRayStatus() nogil) check_signals
(void() nogil) gc_collect
(void(c_string *stack_out) nogil) get_lang_stack
c_bool ref_counting_enabled
c_bool is_local_mode
int num_workers
CCoreWorkerOptions()
cdef cppclass CCoreWorkerProcess "ray::CoreWorkerProcess":
@staticmethod
void Initialize(const CCoreWorkerOptions &options)
# Only call this in CoreWorker.__cinit__,
# use CoreWorker.core_worker to access C++ CoreWorker.
@staticmethod
CCoreWorker &GetCoreWorker()
@staticmethod
void Shutdown()
@staticmethod
void RunTaskExecutionLoop()

View file

@ -1173,27 +1173,14 @@ def connect(node,
ray.state.state._initialize_global_state(
node.redis_address, redis_password=node.redis_password)
# Register the worker with Redis.
driver_name = ""
log_stdout_file_name = ""
log_stderr_file_name = ""
if mode == SCRIPT_MODE:
# The concept of a driver is the same as the concept of a "job".
# Register the driver/job with Redis here.
import __main__ as main
driver_info = {
"node_ip_address": node.node_ip_address,
"driver_id": worker.worker_id,
"start_time": time.time(),
"plasma_store_socket": node.plasma_store_socket_name,
"raylet_socket": node.raylet_socket_name,
"name": (main.__file__
if hasattr(main, "__file__") else "INTERACTIVE MODE")
}
worker.redis_client.hmset(b"Drivers:" + worker.worker_id, driver_info)
driver_name = (main.__file__
if hasattr(main, "__file__") else "INTERACTIVE MODE")
elif mode == WORKER_MODE:
# Register the worker with Redis.
worker_dict = {
"node_ip_address": node.node_ip_address,
"plasma_store_socket": node.plasma_store_socket_name,
}
# Check the RedirectOutput key in Redis and based on its value redirect
# worker output and error to their own files.
# This key is set in services.py when Redis is started.
@ -1224,14 +1211,12 @@ def connect(node,
print("Ray worker pid: {}".format(os.getpid()), file=sys.stderr)
sys.stdout.flush()
sys.stderr.flush()
worker_dict["stdout_file"] = os.path.abspath(
log_stdout_file_name = os.path.abspath(
(log_stdout_file
if log_stdout_file is not None else sys.stdout).name)
worker_dict["stderr_file"] = os.path.abspath(
log_stderr_file_name = os.path.abspath(
(log_stderr_file
if log_stderr_file is not None else sys.stderr).name)
worker.redis_client.hmset(b"Workers:" + worker.worker_id, worker_dict)
elif not LOCAL_MODE:
raise ValueError(
"Invalid worker mode. Expected DRIVER, WORKER or LOCAL.")
@ -1242,9 +1227,19 @@ def connect(node,
node.redis_password,
)
worker.core_worker = ray._raylet.CoreWorker(
(mode == SCRIPT_MODE), node.plasma_store_socket_name,
node.raylet_socket_name, job_id, gcs_options, node.get_logs_dir_path(),
node.node_ip_address, node.node_manager_port, mode == LOCAL_MODE)
(mode == SCRIPT_MODE or mode == LOCAL_MODE),
node.plasma_store_socket_name,
node.raylet_socket_name,
job_id,
gcs_options,
node.get_logs_dir_path(),
node.node_ip_address,
node.node_manager_port,
(mode == LOCAL_MODE),
driver_name,
log_stdout_file_name,
log_stderr_file_name,
)
if driver_object_store_memory is not None:
worker.core_worker.set_object_store_client_options(

View file

@ -63,10 +63,10 @@ struct WorkerThreadContext {
thread_local std::unique_ptr<WorkerThreadContext> WorkerContext::thread_context_ =
nullptr;
WorkerContext::WorkerContext(WorkerType worker_type, const JobID &job_id)
WorkerContext::WorkerContext(WorkerType worker_type, const WorkerID &worker_id,
const JobID &job_id)
: worker_type_(worker_type),
worker_id_(worker_type_ == WorkerType::DRIVER ? ComputeDriverIdFromJob(job_id)
: WorkerID::FromRandom()),
worker_id_(worker_id),
current_job_id_(worker_type_ == WorkerType::DRIVER ? job_id : JobID::Nil()),
current_actor_id_(ActorID::Nil()),
main_thread_id_(boost::this_thread::get_id()) {

View file

@ -26,7 +26,7 @@ struct WorkerThreadContext;
class WorkerContext {
public:
WorkerContext(WorkerType worker_type, const JobID &job_id);
WorkerContext(WorkerType worker_type, const WorkerID &worker_id, const JobID &job_id);
const WorkerType GetWorkerType() const;

View file

@ -75,63 +75,214 @@ void GroupObjectIdsByStoreProvider(const std::vector<ObjectID> &object_ids,
namespace ray {
CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
const std::string &store_socket, const std::string &raylet_socket,
const JobID &job_id, const gcs::GcsClientOptions &gcs_options,
const std::string &log_dir, const std::string &node_ip_address,
int node_manager_port,
const TaskExecutionCallback &task_execution_callback,
std::function<Status()> check_signals,
std::function<void()> gc_collect,
std::function<void(std::string *)> get_lang_stack,
bool ref_counting_enabled, bool local_mode)
: worker_type_(worker_type),
language_(language),
log_dir_(log_dir),
ref_counting_enabled_(ref_counting_enabled),
is_local_mode_(local_mode),
check_signals_(check_signals),
gc_collect_(gc_collect),
get_call_site_(RayConfig::instance().record_ref_creation_sites() ? get_lang_stack
: nullptr),
worker_context_(worker_type, job_id),
std::unique_ptr<CoreWorkerProcess> CoreWorkerProcess::instance_;
thread_local std::weak_ptr<CoreWorker> CoreWorkerProcess::current_core_worker_;
/// Creates the process-wide CoreWorkerProcess instance from `options`.
/// RAY_CHECKs that no instance exists yet.
void CoreWorkerProcess::Initialize(const CoreWorkerOptions &options) {
  RAY_CHECK(!instance_) << "The process is already initialized for core worker.";
  // `new` is used directly, as in the rest of this file.
  instance_.reset(new CoreWorkerProcess(options));
}
/// Tears down the process-wide core worker. Intended for driver processes only.
/// Does nothing when no CoreWorkerProcess instance exists. Otherwise the global
/// worker is disconnected, shut down and removed, and then the instance is destroyed.
void CoreWorkerProcess::Shutdown() {
if (!instance_) {
// Never initialized or already shut down: nothing to do.
return;
}
RAY_CHECK(instance_->options_.worker_type == WorkerType::DRIVER)
<< "The `Shutdown` interface is for driver only.";
RAY_CHECK(instance_->global_worker_);
instance_->global_worker_->Disconnect();
instance_->global_worker_->Shutdown();
// NOTE(review): presumably this unregisters the worker from `workers_`, which
// ~CoreWorkerProcess checks to be empty; confirm against RemoveWorker.
instance_->RemoveWorker(instance_->global_worker_);
instance_.reset();
}
/// Returns true while a CoreWorkerProcess instance exists.
bool CoreWorkerProcess::IsInitialized() {
  return static_cast<bool>(instance_);
}
/// Builds the per-process core worker state from `options`.
///
/// The global worker ID is derived from the job ID for drivers, is random for a
/// single-worker non-driver process, and is nil when the process hosts several workers.
/// Logging is started here only when `options.log_dir` is non-empty.
CoreWorkerProcess::CoreWorkerProcess(const CoreWorkerOptions &options)
    : options_(options),
      global_worker_id_(
          options.worker_type == WorkerType::DRIVER
              ? ComputeDriverIdFromJob(options_.job_id)
              : (options_.num_workers == 1 ? WorkerID::FromRandom() : WorkerID::Nil())) {
  // Initialize logging if log_dir is passed. Otherwise, it must be initialized
  // and cleaned up by the caller.
  const bool start_logging = !options_.log_dir.empty();
  if (start_logging) {
    std::stringstream app_name;
    app_name << LanguageString(options_.language) << "-core-"
             << WorkerTypeString(options_.worker_type);
    // A nil ID (multi-worker process) is left out of the log name.
    if (!global_worker_id_.IsNil()) {
      app_name << "-" << global_worker_id_;
    }
    RayLog::StartRayLog(app_name.str(), RayLogLevel::INFO, options_.log_dir);
    if (options_.install_failure_signal_handler) {
      RayLog::InstallFailureSignalHandler();
    }
  }

  const bool is_driver = options_.worker_type == WorkerType::DRIVER;
  RAY_CHECK(options_.num_workers > 0);
  if (is_driver) {
    // Driver process can only contain one worker.
    RAY_CHECK(options_.num_workers == 1);
  }

  RAY_LOG(INFO) << "Constructing CoreWorkerProcess. pid: " << getpid();

  // In a single-worker process the worker instance must be created here if:
  // 1. This is a driver process. In this case, the driver is ready to use right after
  //    CoreWorkerProcess::Initialize.
  // 2. This is a Python worker process. In this case, Python will invoke some core
  //    worker APIs before `CoreWorkerProcess::RunTaskExecutionLoop` is called. One
  //    example of such invocations is
  //    https://github.com/ray-project/ray/blob/45ce40e5d44801193220d2c546be8de0feeef988/python/ray/worker.py#L1281.
  const bool single_worker = options_.num_workers == 1;
  if (single_worker && (is_driver || options_.language == Language::PYTHON)) {
    CreateWorker();
  }
}
// Verifies that every worker has already been removed, then shuts down logging if
// this object was the one that started it.
CoreWorkerProcess::~CoreWorkerProcess() {
  RAY_LOG(INFO) << "Destructing CoreWorkerProcess. pid: " << getpid();

  {
    // All `CoreWorker` instances must be gone by the time the process is destroyed.
    absl::ReaderMutexLock lock(&worker_map_mutex_);
    RAY_CHECK(workers_.empty());
  }

  // Logging was started by the constructor only when a log directory was given.
  if (options_.log_dir.empty()) {
    return;
  }
  RayLog::ShutDownRayLog();
}
// Checks that the process-wide singleton exists; used as a guard by the static APIs.
void CoreWorkerProcess::EnsureInitialized() {
  RAY_CHECK(instance_ != nullptr)
      << "The core worker process is not initialized yet or already "
      << "shutdown.";
}
// Returns the core worker to use from the calling thread.
//
// In a single-worker process this is always the global worker. Otherwise it is the
// worker previously bound to this thread, and it is a checked error if none is bound.
CoreWorker &CoreWorkerProcess::GetCoreWorker() {
  EnsureInitialized();

  if (instance_->options_.num_workers != 1) {
    const auto bound_worker = current_core_worker_.lock();
    RAY_CHECK(bound_worker != nullptr)
        << "The current thread is not bound with a core worker instance.";
    return *bound_worker;
  }

  return *instance_->global_worker_;
}
// Binds the calling thread to the worker identified by `worker_id`.
//
// In a single-worker process there is nothing to bind, so the call only verifies that
// `worker_id` is the global worker's ID.
void CoreWorkerProcess::SetCurrentThreadWorkerId(const WorkerID &worker_id) {
  EnsureInitialized();

  const bool single_worker = instance_->options_.num_workers == 1;
  if (single_worker) {
    RAY_CHECK(instance_->global_worker_->GetWorkerID() == worker_id);
  } else {
    current_core_worker_ = instance_->GetWorker(worker_id);
  }
}
// Looks up a registered worker by ID under a reader lock on the worker map.
// It is a checked error if no worker with that ID is registered.
std::shared_ptr<CoreWorker> CoreWorkerProcess::GetWorker(
    const WorkerID &worker_id) const {
  absl::ReaderMutexLock lock(&worker_map_mutex_);
  const auto entry = workers_.find(worker_id);
  RAY_CHECK(entry != workers_.end()) << "Worker " << worker_id << " not found.";
  return entry->second;
}
std::shared_ptr<CoreWorker> CoreWorkerProcess::CreateWorker() {
auto worker = std::make_shared<CoreWorker>(
options_,
global_worker_id_ != WorkerID::Nil() ? global_worker_id_ : WorkerID::FromRandom());
RAY_LOG(INFO) << "Worker " << worker->GetWorkerID() << " is created.";
if (options_.num_workers == 1) {
global_worker_ = worker;
}
current_core_worker_ = worker;
absl::MutexLock lock(&worker_map_mutex_);
workers_.emplace(worker->GetWorkerID(), worker);
RAY_CHECK(workers_.size() <= static_cast<size_t>(options_.num_workers));
return worker;
}
void CoreWorkerProcess::RemoveWorker(std::shared_ptr<CoreWorker> worker) {
worker->WaitForShutdown();
if (global_worker_) {
RAY_CHECK(global_worker_ == worker);
} else {
RAY_CHECK(current_core_worker_.lock() == worker);
}
current_core_worker_.reset();
{
absl::MutexLock lock(&worker_map_mutex_);
workers_.erase(worker->GetWorkerID());
RAY_LOG(INFO) << "Removed worker " << worker->GetWorkerID();
}
if (global_worker_ == worker) {
global_worker_ = nullptr;
}
}
void CoreWorkerProcess::RunTaskExecutionLoop() {
EnsureInitialized();
RAY_CHECK(instance_->options_.worker_type == WorkerType::WORKER);
if (instance_->options_.num_workers == 1) {
// Run the task loop in the current thread only if the number of workers is 1.
auto worker =
instance_->global_worker_ ? instance_->global_worker_ : instance_->CreateWorker();
worker->RunTaskExecutionLoop();
instance_->RemoveWorker(worker);
} else {
std::vector<std::thread> worker_threads;
for (int i = 0; i < instance_->options_.num_workers; i++) {
worker_threads.emplace_back([]() {
auto worker = instance_->CreateWorker();
worker->RunTaskExecutionLoop();
instance_->RemoveWorker(worker);
});
}
for (auto &thread : worker_threads) {
thread.join();
}
}
instance_.reset();
}
CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_id)
: options_(options),
get_call_site_(RayConfig::instance().record_ref_creation_sites()
? options_.get_lang_stack
: nullptr),
worker_context_(options_.worker_type, worker_id, options_.job_id),
io_work_(io_service_),
client_call_manager_(new rpc::ClientCallManager(io_service_)),
death_check_timer_(io_service_),
internal_timer_(io_service_),
core_worker_server_(WorkerTypeString(worker_type), 0 /* let grpc choose a port */),
core_worker_server_(WorkerTypeString(options_.worker_type),
0 /* let grpc choose a port */),
task_queue_length_(0),
num_executed_tasks_(0),
task_execution_service_work_(task_execution_service_),
task_execution_callback_(task_execution_callback),
resource_ids_(new ResourceMappingType()),
grpc_service_(io_service_, *this) {
// Initialize logging if log_dir is passed. Otherwise, it must be initialized
// and cleaned up by the caller.
if (log_dir_ != "") {
std::stringstream app_name;
app_name << LanguageString(language_) << "-" << WorkerTypeString(worker_type_) << "-"
<< worker_context_.GetWorkerID();
RayLog::StartRayLog(app_name.str(), RayLogLevel::INFO, log_dir_);
RayLog::InstallFailureSignalHandler();
}
// Initialize gcs client.
if (RayConfig::instance().gcs_service_enabled()) {
gcs_client_ = std::make_shared<ray::gcs::ServiceBasedGcsClient>(gcs_options);
gcs_client_ = std::make_shared<ray::gcs::ServiceBasedGcsClient>(options_.gcs_options);
} else {
gcs_client_ = std::make_shared<ray::gcs::RedisGcsClient>(gcs_options);
gcs_client_ = std::make_shared<ray::gcs::RedisGcsClient>(options_.gcs_options);
}
RAY_CHECK_OK(gcs_client_->Connect(io_service_));
RegisterToGcs();
actor_manager_ = std::unique_ptr<ActorManager>(new ActorManager(gcs_client_->Actors()));
// Initialize profiler.
profiler_ = std::make_shared<worker::Profiler>(worker_context_, node_ip_address,
io_service_, gcs_client_);
profiler_ = std::make_shared<worker::Profiler>(
worker_context_, options_.node_ip_address, io_service_, gcs_client_);
// Initialize task receivers.
if (worker_type_ == WorkerType::WORKER || is_local_mode_) {
RAY_CHECK(task_execution_callback_ != nullptr);
if (options_.worker_type == WorkerType::WORKER || options_.is_local_mode) {
RAY_CHECK(options_.task_execution_callback != nullptr);
auto execute_task =
std::bind(&CoreWorker::ExecuteTask, this, std::placeholders::_1,
std::placeholders::_2, std::placeholders::_3, std::placeholders::_4);
@ -154,17 +305,18 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
// so that the worker (java/python .etc) can retrieve and handle the error
// instead of crashing.
auto grpc_client = rpc::NodeManagerWorkerClient::make(
node_ip_address, node_manager_port, *client_call_manager_);
options_.node_ip_address, options_.node_manager_port, *client_call_manager_);
ClientID local_raylet_id;
local_raylet_client_ = std::shared_ptr<raylet::RayletClient>(new raylet::RayletClient(
io_service_, std::move(grpc_client), raylet_socket, worker_context_.GetWorkerID(),
(worker_type_ == ray::WorkerType::WORKER), worker_context_.GetCurrentJobID(),
language_, &local_raylet_id, core_worker_server_.GetPort()));
io_service_, std::move(grpc_client), options_.raylet_socket, GetWorkerID(),
(options_.worker_type == ray::WorkerType::WORKER),
worker_context_.GetCurrentJobID(), options_.language, &local_raylet_id,
core_worker_server_.GetPort()));
connected_ = true;
// Set our own address.
RAY_CHECK(!local_raylet_id.IsNil());
rpc_address_.set_ip_address(node_ip_address);
rpc_address_.set_ip_address(options_.node_ip_address);
rpc_address_.set_port(core_worker_server_.GetPort());
rpc_address_.set_raylet_id(local_raylet_id.Binary());
rpc_address_.set_worker_id(worker_context_.GetWorkerID().Binary());
@ -179,20 +331,21 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
new rpc::CoreWorkerClient(addr, *client_call_manager_));
});
if (worker_type_ == ray::WorkerType::WORKER) {
if (options_.worker_type == ray::WorkerType::WORKER) {
death_check_timer_.expires_from_now(boost::asio::chrono::milliseconds(
RayConfig::instance().raylet_death_check_interval_milliseconds()));
death_check_timer_.async_wait(boost::bind(&CoreWorker::CheckForRayletFailure, this));
death_check_timer_.async_wait(
boost::bind(&CoreWorker::CheckForRayletFailure, this, _1));
}
internal_timer_.expires_from_now(
boost::asio::chrono::milliseconds(kInternalHeartbeatMillis));
internal_timer_.async_wait(boost::bind(&CoreWorker::InternalHeartbeat, this));
internal_timer_.async_wait(boost::bind(&CoreWorker::InternalHeartbeat, this, _1));
io_thread_ = std::thread(&CoreWorker::RunIOService, this);
plasma_store_provider_.reset(new CoreWorkerPlasmaStoreProvider(
store_socket, local_raylet_client_, check_signals_,
options_.store_socket, local_raylet_client_, options_.check_signals,
/*evict_if_full=*/RayConfig::instance().object_pinning_enabled(),
boost::bind(&CoreWorker::TriggerGlobalGC, this),
boost::bind(&CoreWorker::CurrentCallSite, this)));
@ -201,8 +354,8 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
RAY_LOG(DEBUG) << "Promoting object to plasma " << obj_id;
RAY_CHECK_OK(Put(obj, /*contained_object_ids=*/{}, obj_id, /*pin_object=*/true));
},
ref_counting_enabled ? reference_counter_ : nullptr, local_raylet_client_,
check_signals_));
options_.ref_counting_enabled ? reference_counter_ : nullptr, local_raylet_client_,
options_.check_signals));
task_manager_.reset(new TaskManager(
memory_store_, reference_counter_, actor_manager_,
@ -222,16 +375,17 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
// driver creates an object that is later evicted, we should notify the
// user that we're unable to reconstruct the object, since we cannot
// rerun the driver.
if (worker_type_ == WorkerType::DRIVER) {
if (options_.worker_type == WorkerType::DRIVER) {
TaskSpecBuilder builder;
const TaskID task_id = TaskID::ForDriverTask(worker_context_.GetCurrentJobID());
builder.SetDriverTaskSpec(task_id, language_, worker_context_.GetCurrentJobID(),
builder.SetDriverTaskSpec(task_id, options_.language,
worker_context_.GetCurrentJobID(),
TaskID::ComputeDriverTaskId(worker_context_.GetWorkerID()),
GetCallerId(), rpc_address_);
std::shared_ptr<gcs::TaskTableData> data = std::make_shared<gcs::TaskTableData>();
data->mutable_task()->mutable_task_spec()->CopyFrom(builder.Build().GetMessage());
if (!is_local_mode_) {
if (!options_.is_local_mode) {
RAY_CHECK_OK(gcs_client_->Tasks().AsyncAdd(data, nullptr));
}
SetCurrentTaskId(task_id);
@ -262,17 +416,9 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
}
}
CoreWorker::~CoreWorker() {
io_service_.stop();
io_thread_.join();
if (log_dir_ != "") {
RayLog::ShutDownRayLog();
}
}
void CoreWorker::Shutdown() {
io_service_.stop();
if (worker_type_ == WorkerType::WORKER) {
if (options_.worker_type == WorkerType::WORKER) {
task_execution_service_.stop();
}
}
@ -356,6 +502,17 @@ void CoreWorker::RunIOService() {
io_service_.run();
}
// Blocks until the IO thread has exited. For WORKER-type workers it additionally
// checks that the task execution service has already been stopped.
void CoreWorker::WaitForShutdown() {
  if (io_thread_.joinable()) {
    io_thread_.join();
  }

  if (options_.worker_type != WorkerType::WORKER) {
    return;
  }
  RAY_CHECK(task_execution_service_.stopped());
}
// Returns this worker's ID as stored in its worker context.
const WorkerID &CoreWorker::GetWorkerID() const {
  return worker_context_.GetWorkerID();
}
void CoreWorker::SetCurrentTaskId(const TaskID &task_id) {
worker_context_.SetCurrentTaskId(task_id);
main_thread_task_id_ = task_id;
@ -375,8 +532,41 @@ void CoreWorker::SetCurrentTaskId(const TaskID &task_id) {
}
}
void CoreWorker::CheckForRayletFailure() {
// If the raylet fails, we will be reassigned to init (PID=1).
// Collects this worker's descriptive key/value info and registers it with the GCS.
// Drivers additionally report their ID, start time and (if set) name.
void CoreWorker::RegisterToGcs() {
  std::unordered_map<std::string, std::string> worker_info;
  const auto &worker_id = GetWorkerID();

  // Adds an optional field only when it has a value.
  auto emplace_if_set = [&worker_info](const char *key, const std::string &value) {
    if (!value.empty()) {
      worker_info.emplace(key, value);
    }
  };

  worker_info.emplace("node_ip_address", options_.node_ip_address);
  worker_info.emplace("plasma_store_socket", options_.store_socket);
  worker_info.emplace("raylet_socket", options_.raylet_socket);

  const bool is_driver = options_.worker_type == WorkerType::DRIVER;
  if (is_driver) {
    // Wall-clock start time, in milliseconds since the system clock's epoch.
    const auto since_epoch = std::chrono::system_clock::now().time_since_epoch();
    const auto start_time_ms =
        std::chrono::duration_cast<std::chrono::milliseconds>(since_epoch).count();
    worker_info.emplace("driver_id", worker_id.Binary());
    worker_info.emplace("start_time", std::to_string(start_time_ms));
    emplace_if_set("name", options_.driver_name);
  }

  emplace_if_set("stdout_file", options_.stdout_file);
  emplace_if_set("stderr_file", options_.stderr_file);

  RAY_CHECK_OK(gcs_client_->Workers().AsyncRegisterWorker(options_.worker_type, worker_id,
                                                          worker_info, nullptr));
}
void CoreWorker::CheckForRayletFailure(const boost::system::error_code &error) {
if (error == boost::asio::error::operation_aborted) {
return;
}
// If the raylet fails, we will be reassigned to init (PID=1).
if (getppid() == 1) {
RAY_LOG(ERROR) << "Raylet failed. Shutting down.";
Shutdown();
@ -387,10 +577,15 @@ void CoreWorker::CheckForRayletFailure() {
death_check_timer_.expiry() +
boost::asio::chrono::milliseconds(
RayConfig::instance().raylet_death_check_interval_milliseconds()));
death_check_timer_.async_wait(boost::bind(&CoreWorker::CheckForRayletFailure, this));
death_check_timer_.async_wait(
boost::bind(&CoreWorker::CheckForRayletFailure, this, _1));
}
void CoreWorker::InternalHeartbeat() {
void CoreWorker::InternalHeartbeat(const boost::system::error_code &error) {
if (error == boost::asio::error::operation_aborted) {
return;
}
absl::MutexLock lock(&mutex_);
while (!to_resubmit_.empty() && current_time_ms() > to_resubmit_.front().first) {
RAY_CHECK_OK(direct_task_submitter_->SubmitTask(to_resubmit_.front().second));
@ -398,7 +593,7 @@ void CoreWorker::InternalHeartbeat() {
}
internal_timer_.expires_at(internal_timer_.expiry() +
boost::asio::chrono::milliseconds(kInternalHeartbeatMillis));
internal_timer_.async_wait(boost::bind(&CoreWorker::InternalHeartbeat, this));
internal_timer_.async_wait(boost::bind(&CoreWorker::InternalHeartbeat, this, _1));
}
std::unordered_map<ObjectID, std::pair<size_t, size_t>>
@ -445,7 +640,7 @@ void CoreWorker::RegisterOwnershipInfoAndResolveFuture(
reference_counter_->AddBorrowedObject(object_id, outer_object_id, owner_id,
owner_address);
RAY_CHECK(!owner_id.IsNil() || is_local_mode_);
RAY_CHECK(!owner_id.IsNil() || options_.is_local_mode);
// We will ask the owner about the object until the object is
// created or we can no longer reach the owner.
future_resolver_->ResolveFutureAsync(object_id, owner_id, owner_address);
@ -471,7 +666,7 @@ Status CoreWorker::Put(const RayObject &object,
const std::vector<ObjectID> &contained_object_ids,
const ObjectID &object_id, bool pin_object) {
bool object_exists;
if (is_local_mode_) {
if (options_.is_local_mode) {
RAY_CHECK(memory_store_->Put(object, object_id));
return Status::OK();
}
@ -505,7 +700,7 @@ Status CoreWorker::Create(const std::shared_ptr<Buffer> &metadata, const size_t
worker_context_.GetNextPutIndex(),
static_cast<uint8_t>(TaskTransportType::DIRECT));
if (is_local_mode_) {
if (options_.is_local_mode) {
*data = std::make_shared<LocalMemoryBuffer>(data_size);
} else {
RAY_RETURN_NOT_OK(
@ -523,7 +718,7 @@ Status CoreWorker::Create(const std::shared_ptr<Buffer> &metadata, const size_t
Status CoreWorker::Create(const std::shared_ptr<Buffer> &metadata, const size_t data_size,
const ObjectID &object_id, std::shared_ptr<Buffer> *data) {
if (is_local_mode_) {
if (options_.is_local_mode) {
return Status::NotImplemented(
"Creating an object with a pre-existing ObjectID is not supported in local mode");
} else {
@ -791,7 +986,7 @@ TaskID CoreWorker::GetCallerId() const {
Status CoreWorker::PushError(const JobID &job_id, const std::string &type,
const std::string &error_message, double timestamp) {
if (is_local_mode_) {
if (options_.is_local_mode) {
RAY_LOG(ERROR) << "Pushed Error with JobID: " << job_id << " of type: " << type
<< " with message: " << error_message << " at time: " << timestamp;
return Status::OK();
@ -831,7 +1026,7 @@ Status CoreWorker::SubmitTask(const RayFunction &function,
rpc_address_, function, args, task_options.num_returns,
task_options.resources, required_resources, return_ids);
TaskSpecification task_spec = builder.Build();
if (is_local_mode_) {
if (options_.is_local_mode) {
return ExecuteTaskLocalMode(task_spec);
} else {
task_manager_->AddPendingTask(GetCallerId(), rpc_address_, task_spec,
@ -866,7 +1061,7 @@ Status CoreWorker::CreateActor(const RayFunction &function,
*return_actor_id = actor_id;
TaskSpecification task_spec = builder.Build();
Status status;
if (is_local_mode_) {
if (options_.is_local_mode) {
status = ExecuteTaskLocalMode(task_spec);
} else {
task_manager_->AddPendingTask(
@ -914,7 +1109,7 @@ Status CoreWorker::SubmitActorTask(const ActorID &actor_id, const RayFunction &f
// Submit task.
Status status;
TaskSpecification task_spec = builder.Build();
if (is_local_mode_) {
if (options_.is_local_mode) {
return ExecuteTaskLocalMode(task_spec, actor_id);
}
task_manager_->AddPendingTask(GetCallerId(), rpc_address_, task_spec,
@ -1055,7 +1250,7 @@ std::unique_ptr<worker::ProfileEvent> CoreWorker::CreateProfileEvent(
new worker::ProfileEvent(profiler_, event_type));
}
void CoreWorker::StartExecutingTasks() { task_execution_service_.run(); }
void CoreWorker::RunTaskExecutionLoop() { task_execution_service_.run(); }
Status CoreWorker::AllocateReturnObjects(
const std::vector<ObjectID> &object_ids, const std::vector<size_t> &data_sizes,
@ -1066,7 +1261,7 @@ Status CoreWorker::AllocateReturnObjects(
RAY_CHECK(object_ids.size() == data_sizes.size());
return_objects->resize(object_ids.size(), nullptr);
rpc::Address owner_address(is_local_mode_
rpc::Address owner_address(options_.is_local_mode
? rpc::Address()
: worker_context_.GetCurrentTask()->CallerAddress());
@ -1083,8 +1278,9 @@ Status CoreWorker::AllocateReturnObjects(
}
// Allocate a buffer for the return object.
if (is_local_mode_ || static_cast<int64_t>(data_sizes[i]) <
RayConfig::instance().max_direct_call_object_size()) {
if (options_.is_local_mode ||
static_cast<int64_t>(data_sizes[i]) <
RayConfig::instance().max_direct_call_object_size()) {
data_buffer = std::make_shared<LocalMemoryBuffer>(data_sizes[i]);
} else {
RAY_RETURN_NOT_OK(
@ -1113,7 +1309,7 @@ Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec,
resource_ids_ = resource_ids;
}
if (!is_local_mode_) {
if (!options_.is_local_mode) {
worker_context_.SetCurrentTask(task_spec);
SetCurrentTaskId(task_spec.TaskId());
}
@ -1162,13 +1358,17 @@ Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec,
SetCallerCreationTimestamp();
}
status = task_execution_callback_(
// Because we support concurrent actor calls, we need to update the
// worker ID for the current thread.
CoreWorkerProcess::SetCurrentThreadWorkerId(GetWorkerID());
status = options_.task_execution_callback(
task_type, func, task_spec.GetRequiredResources().GetResourceMap(), args,
arg_reference_ids, return_ids, return_objects, worker_context_.GetWorkerID());
arg_reference_ids, return_ids, return_objects);
absl::optional<rpc::Address> caller_address(
is_local_mode_ ? absl::optional<rpc::Address>()
: worker_context_.GetCurrentTask()->CallerAddress());
options_.is_local_mode ? absl::optional<rpc::Address>()
: worker_context_.GetCurrentTask()->CallerAddress());
for (size_t i = 0; i < return_objects->size(); i++) {
// The object is nullptr if it already existed in the object store.
if (!return_objects->at(i)) {
@ -1196,7 +1396,7 @@ Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec,
RAY_LOG(DEBUG) << "Decrementing ref for borrowed ID " << borrowed_id;
reference_counter_->RemoveLocalReference(borrowed_id, &deleted);
}
if (ref_counting_enabled_) {
if (options_.ref_counting_enabled) {
memory_store_->Delete(deleted);
}
@ -1210,7 +1410,7 @@ Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec,
"reference counting, and may cause problems in the object store.";
}
if (!is_local_mode_) {
if (!options_.is_local_mode) {
SetCurrentTaskId(TaskID::Nil());
worker_context_.ResetCurrentTask(task_spec);
}
@ -1263,7 +1463,7 @@ Status CoreWorker::BuildArgsForExecutor(const TaskSpecification &task,
// Direct call type objects that weren't inlined have been promoted to plasma.
// We need to put an OBJECT_IN_PLASMA error here so the subsequent call to Get()
// properly redirects to the plasma store.
if (task.ArgId(i, 0).IsDirectCallType() && !is_local_mode_) {
if (task.ArgId(i, 0).IsDirectCallType() && !options_.is_local_mode) {
RAY_UNUSED(memory_store_->Put(RayObject(rpc::ErrorType::OBJECT_IN_PLASMA),
task.ArgId(i, 0)));
}
@ -1309,7 +1509,7 @@ Status CoreWorker::BuildArgsForExecutor(const TaskSpecification &task,
// Fetch by-reference arguments directly from the plasma store.
bool got_exception = false;
absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> result_map;
if (is_local_mode_) {
if (options_.is_local_mode) {
RAY_RETURN_NOT_OK(
memory_store_->Get(by_ref_ids, -1, worker_context_, &result_map, &got_exception));
} else {
@ -1472,7 +1672,16 @@ void CoreWorker::HandleKillActor(const rpc::KillActorRequest &request,
if (request.no_reconstruction()) {
RAY_IGNORE_EXPR(local_raylet_client_->Disconnect());
}
if (log_dir_ != "") {
if (options_.num_workers > 1) {
// TODO (kfstorm): Should we add some kind of check before sending the killing
// request?
RAY_LOG(ERROR)
<< "Killing an actor which is running in a worker process with multiple "
"workers will also kill other actors in this process. To avoid this, "
"please create the Java actor with some dynamic options to make it being "
"hosted in a dedicated worker process.";
}
if (options_.log_dir != "") {
RayLog::ShutDownRayLog();
}
exit(1);
@ -1522,8 +1731,8 @@ void CoreWorker::HandleGetCoreWorkerStats(const rpc::GetCoreWorkerStatsRequest &
void CoreWorker::HandleLocalGC(const rpc::LocalGCRequest &request,
rpc::LocalGCReply *reply,
rpc::SendReplyCallback send_reply_callback) {
if (gc_collect_ != nullptr) {
gc_collect_();
if (options_.gc_collect != nullptr) {
options_.gc_collect();
send_reply_callback(Status::OK(), nullptr, nullptr);
} else {
send_reply_callback(Status::NotImplemented("GC callback not defined"), nullptr,
@ -1573,7 +1782,7 @@ void CoreWorker::HandlePlasmaObjectReady(const rpc::PlasmaObjectReadyRequest &re
void CoreWorker::SetActorId(const ActorID &actor_id) {
absl::MutexLock lock(&mutex_);
if (!is_local_mode_) {
if (!options_.is_local_mode) {
RAY_CHECK(actor_id_.IsNil());
}
actor_id_ = actor_id;

View file

@ -50,10 +50,9 @@
namespace ray {
/// The root class that contains all the core and language-independent functionalities
/// of the worker. This class is supposed to be used to implement app-language (Java,
/// Python, etc) workers.
class CoreWorker : public rpc::CoreWorkerServiceHandler {
class CoreWorker;
struct CoreWorkerOptions {
// Callback that must be implemented and provided by the language-specific worker
// frontend to execute tasks and return their results.
using TaskExecutionCallback = std::function<Status(
@ -62,60 +61,242 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
const std::vector<std::shared_ptr<RayObject>> &args,
const std::vector<ObjectID> &arg_reference_ids,
const std::vector<ObjectID> &return_ids,
std::vector<std::shared_ptr<RayObject>> *results, const ray::WorkerID &worker_id)>;
std::vector<std::shared_ptr<RayObject>> *results)>;
/// Type of this worker (i.e., DRIVER or WORKER).
WorkerType worker_type;
/// Application language of this worker (i.e., PYTHON or JAVA).
Language language;
/// Object store socket to connect to.
std::string store_socket;
/// Raylet socket to connect to.
std::string raylet_socket;
/// Job ID of this worker.
JobID job_id;
/// Options for the GCS client.
gcs::GcsClientOptions gcs_options;
/// Directory to write logs to. If this is empty, logs won't be written to a file.
std::string log_dir;
/// If false, will not call `RayLog::InstallFailureSignalHandler()`.
bool install_failure_signal_handler;
/// IP address of the node.
std::string node_ip_address;
/// Port of the local raylet.
int node_manager_port;
/// The name of the driver.
std::string driver_name;
/// The stdout file of this process.
std::string stdout_file;
/// The stderr file of this process.
std::string stderr_file;
/// Language worker callback to execute tasks.
TaskExecutionCallback task_execution_callback;
/// Application-language callback to check for signals that have been received
/// since calling into C++. This will be called periodically (at least every
/// 1s) during long-running operations. If the function returns anything but StatusOK,
/// any long-running operations in the core worker will short circuit and return that
/// status.
std::function<Status()> check_signals;
/// Application-language callback to trigger garbage collection in the language
/// runtime. This is required to free distributed references that may otherwise
/// be held up in garbage objects.
std::function<void()> gc_collect;
/// Language worker callback to get the current call stack.
std::function<void(std::string *)> get_lang_stack;
/// Whether to enable object ref counting.
bool ref_counting_enabled;
/// Is local mode being used.
bool is_local_mode;
/// The number of workers to be started in the current process.
int num_workers;
};
/// Lifecycle management of one or more `CoreWorker` instances in a process.
///
/// To start a driver in the current process:
///     CoreWorkerOptions options = {
///         WorkerType::DRIVER,             // worker_type
///         ...,                            // other arguments
///         1,                              // num_workers
///     };
///     CoreWorkerProcess::Initialize(options);
///
/// To shutdown a driver in the current process:
///     CoreWorkerProcess::Shutdown();
///
/// To start one or more workers in the current process:
///     CoreWorkerOptions options = {
///         WorkerType::WORKER,             // worker_type
///         ...,                            // other arguments
///         num_workers,                    // num_workers
///     };
///     CoreWorkerProcess::Initialize(options);
///     ...                                 // Do other stuff
///     CoreWorkerProcess::RunTaskExecutionLoop();
///
/// To shutdown a worker in the current process, return a system exit status (with status
/// code `IntentionalSystemExit` or `UnexpectedSystemExit`) in the task execution
/// callback.
///
/// If more than 1 worker is started, only the threads which invoke the
/// `task_execution_callback` will be automatically associated with the corresponding
/// worker. If you started your own threads and you want to use core worker APIs in these
/// threads, remember to call `CoreWorkerProcess::SetCurrentThreadWorkerId(worker_id)`
/// once in the new thread before calling core worker APIs, to associate the current
/// thread with a worker. You can obtain the worker ID via
/// `CoreWorkerProcess::GetCoreWorker().GetWorkerID()`. Currently a Java worker process
/// starts multiple workers by default, but can be configured to start only 1 worker by
/// overwriting the internal config `num_workers_per_process_java`.
///
/// If only 1 worker is started (either because the worker type is driver, or the
/// `num_workers` in `CoreWorkerOptions` is set to 1), all threads will be automatically
/// associated to the only worker. Then no need to call `SetCurrentThreadWorkerId` in
/// your own threads. Currently a Python worker process starts only 1 worker.
class CoreWorkerProcess {
public:
///
/// Public methods used in both DRIVER and WORKER mode.
///
/// Initialize core workers at the process level.
/// It is a checked error to call this when the process is already initialized.
///
/// \param[in] options The various initialization options.
static void Initialize(const CoreWorkerOptions &options);
/// Get the core worker associated with the current thread.
/// If the process hosts exactly one worker, that worker is returned regardless of the
/// calling thread. Otherwise it is a checked error if the calling thread is not bound
/// to a worker.
/// NOTE (kfstorm): Here we return a reference instead of a `shared_ptr` to make sure
/// `CoreWorkerProcess` has full control of the destruction timing of `CoreWorker`.
static CoreWorker &GetCoreWorker();
/// Set the core worker associated with the current thread by worker ID.
/// Currently used by Java worker only.
/// If the process hosts exactly one worker, this only verifies that `worker_id` is
/// that worker's ID.
///
/// \param worker_id The worker ID of the core worker instance.
static void SetCurrentThreadWorkerId(const WorkerID &worker_id);
/// Whether the current process has been initialized for core worker.
static bool IsInitialized();
///
/// Public methods used in DRIVER mode only.
///
/// Shutdown the driver completely at the process level.
/// Does nothing if the process is not initialized.
static void Shutdown();
///
/// Public methods used in WORKER mode only.
///
/// Start receiving and executing tasks.
/// Returns once the task execution loops of all workers have finished; the process
/// instance is destroyed before returning.
static void RunTaskExecutionLoop();
// The destructor is not to be used as a public API, but it's required by smart
// pointers.
~CoreWorkerProcess();
private:
/// Create a `CoreWorkerProcess` with proper options.
///
/// \param[in] options The various initialization options.
CoreWorkerProcess(const CoreWorkerOptions &options);
/// Check that the core worker environment is initialized for this process.
///
/// \return Void.
static void EnsureInitialized();
/// Get the `CoreWorker` instance by worker ID.
/// It is a checked error if no worker with this ID is registered.
///
/// \param[in] worker_id The worker ID.
/// \return The `CoreWorker` instance.
std::shared_ptr<CoreWorker> GetWorker(const WorkerID &worker_id) const
LOCKS_EXCLUDED(worker_map_mutex_);
/// Create a new `CoreWorker` instance, bind it to the calling thread and register it
/// in the worker map.
///
/// \return The newly created `CoreWorker` instance.
std::shared_ptr<CoreWorker> CreateWorker() LOCKS_EXCLUDED(worker_map_mutex_);
/// Remove an existing `CoreWorker` instance.
/// Waits for the worker to finish shutting down before unregistering it.
///
/// \param[in] worker The existing `CoreWorker` instance.
/// \return Void.
void RemoveWorker(std::shared_ptr<CoreWorker> worker) LOCKS_EXCLUDED(worker_map_mutex_);
/// The global instance of `CoreWorkerProcess`.
static std::unique_ptr<CoreWorkerProcess> instance_;
/// The various options.
const CoreWorkerOptions options_;
/// The core worker instance associated with the current thread.
/// Use weak_ptr here to avoid memory leak due to multi-threading.
static thread_local std::weak_ptr<CoreWorker> current_core_worker_;
/// The only core worker instance, if the number of workers is 1.
std::shared_ptr<CoreWorker> global_worker_;
/// The worker ID of the global worker, if the number of workers is 1.
/// Nil when the process hosts more than one worker.
const WorkerID global_worker_id_;
/// Map from worker ID to worker.
std::unordered_map<WorkerID, std::shared_ptr<CoreWorker>> workers_
GUARDED_BY(worker_map_mutex_);
/// To protect accessing the `workers_` map.
mutable absl::Mutex worker_map_mutex_;
};
/// The root class that contains all the core and language-independent functionalities
/// of the worker. This class is supposed to be used to implement app-language (Java,
/// Python, etc) workers.
class CoreWorker : public rpc::CoreWorkerServiceHandler {
public:
/// Construct a CoreWorker instance.
///
/// \param[in] worker_type Type of this worker.
/// \param[in] language Language of this worker.
/// \param[in] store_socket Object store socket to connect to.
/// \param[in] raylet_socket Raylet socket to connect to.
/// \param[in] job_id Job ID of this worker.
/// \param[in] gcs_options Options for the GCS client.
/// \param[in] log_dir Directory to write logs to. If this is empty, logs
/// won't be written to a file.
/// \param[in] node_ip_address IP address of the node.
/// \param[in] node_manager_port Port of the local raylet.
/// \param[in] task_execution_callback Language worker callback to execute tasks.
/// \param[in] check_signals Language worker function to check for signals and handle
/// them. If the function returns anything but StatusOK, any long-running
/// operations in the core worker will short circuit and return that status.
/// \param[in] ref_counting_enabled Whether to enable object ref counting.
/// \param[in] options The various initialization options.
/// \param[in] worker_id ID of this worker.
CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_id);
CoreWorker(CoreWorker const &) = delete;
void operator=(CoreWorker const &other) = delete;
///
/// Public methods used by `CoreWorkerProcess` and `CoreWorker` itself.
///
/// NOTE(zhijunfu): the constructor would throw if a failure happens.
CoreWorker(const WorkerType worker_type, const Language language,
const std::string &store_socket, const std::string &raylet_socket,
const JobID &job_id, const gcs::GcsClientOptions &gcs_options,
const std::string &log_dir, const std::string &node_ip_address,
int node_manager_port, const TaskExecutionCallback &task_execution_callback,
std::function<Status()> check_signals = nullptr,
std::function<void()> gc_collect = nullptr,
std::function<void(std::string *)> get_lang_stack = nullptr,
bool ref_counting_enabled = false, bool local_mode = false);
virtual ~CoreWorker();
void Exit(bool intentional);
/// Gracefully disconnect the worker from other components of ray. e.g. Raylet.
/// If this function is called during shutdown, Raylet will treat it as an intentional
/// disconnect.
///
/// \return Void.
void Disconnect();
WorkerType GetWorkerType() const { return worker_type_; }
/// Shut down the worker completely.
///
/// \return void.
void Shutdown();
Language GetLanguage() const { return language_; }
/// Block the current thread until the worker is shut down.
void WaitForShutdown();
/// Start receiving and executing tasks.
/// \return void.
void RunTaskExecutionLoop();
const WorkerID &GetWorkerID() const;
WorkerType GetWorkerType() const { return options_.worker_type; }
Language GetLanguage() const { return options_.language; }
WorkerContext &GetWorkerContext() { return worker_context_; }
raylet::RayletClient &GetRayletClient() { return *local_raylet_client_; }
const TaskID &GetCurrentTaskId() const { return worker_context_.GetCurrentTaskID(); }
void SetCurrentTaskId(const TaskID &task_id);
const JobID &GetCurrentJobId() const { return worker_context_.GetCurrentJobID(); }
void SetActorId(const ActorID &actor_id);
void SetWebuiDisplay(const std::string &key, const std::string &message);
void SetActorTitle(const std::string &title);
@ -139,7 +320,7 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
std::vector<ObjectID> deleted;
reference_counter_->RemoveLocalReference(object_id, &deleted);
// TODO(ilr): better way of keeping an object from being deleted
if (ref_counting_enabled_ && !is_local_mode_) {
if (options_.ref_counting_enabled && !options_.is_local_mode) {
memory_store_->Delete(deleted);
}
}
@ -448,10 +629,7 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
/// Create a profile event with a reference to the core worker's profiler.
std::unique_ptr<worker::ProfileEvent> CreateProfileEvent(const std::string &event_type);
/// Start receiving and executing tasks.
/// \return void.
void StartExecutingTasks();
public:
/// Allocate the return objects for an executing task. The caller should write into the
/// data buffers of the allocated buffers.
///
@ -566,18 +744,25 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
void SubscribeToPlasmaAdd(const ObjectID &object_id);
private:
void SetCurrentTaskId(const TaskID &task_id);
void SetActorId(const ActorID &actor_id);
/// Run the io_service_ event loop. This should be called in a background thread.
void RunIOService();
/// Shut down the worker completely.
/// \return void.
void Shutdown();
/// (WORKER mode only) Exit the worker. This is the entrypoint used to shutdown a
/// worker.
void Exit(bool intentional);
/// Register this worker or driver to GCS.
void RegisterToGcs();
/// Check if the raylet has failed. If so, shutdown.
void CheckForRayletFailure();
void CheckForRayletFailure(const boost::system::error_code &error);
/// Heartbeat for internal bookkeeping.
void InternalHeartbeat();
void InternalHeartbeat(const boost::system::error_code &error);
///
/// Private methods related to task submission.
@ -682,30 +867,7 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
}
}
/// Type of this worker (i.e., DRIVER or WORKER).
const WorkerType worker_type_;
/// Application language of this worker (i.e., PYTHON or JAVA).
const Language language_;
/// Directory where log files are written.
const std::string log_dir_;
/// Whether local reference counting is enabled.
const bool ref_counting_enabled_;
/// Is local mode being used.
const bool is_local_mode_;
/// Application-language callback to check for signals that have been received
/// since calling into C++. This will be called periodically (at least every
/// 1s) during long-running operations.
std::function<Status()> check_signals_;
/// Application-language callback to trigger garbage collection in the language
/// runtime. This is required to free distributed references that may otherwise
/// be held up in garbage objects.
std::function<void()> gc_collect_;
const CoreWorkerOptions options_;
/// Callback to get the current language (e.g., Python) call site.
std::function<void(std::string *)> get_call_site_;
@ -843,9 +1005,6 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
/// Profiler including a background thread that pushes profiling events to the GCS.
std::shared_ptr<worker::Profiler> profiler_;
/// Task execution callback.
TaskExecutionCallback task_execution_callback_;
/// A map from resource name to the resource IDs that are currently reserved
/// for this worker. Each pair consists of the resource ID and the fraction
/// of that resource allocated for this worker. This is set on task assignment.

View file

@ -82,7 +82,6 @@ jfieldID java_native_ray_object_metadata;
jclass java_task_executor_class;
jmethodID java_task_executor_execute;
jmethodID java_task_executor_get;
JavaVM *jvm;
@ -197,9 +196,6 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) {
env->GetMethodID(java_task_executor_class, "execute",
"(Ljava/util/List;Ljava/util/List;)Ljava/util/List;");
java_task_executor_get = env->GetStaticMethodID(
java_task_executor_class, "get", "([B)Lorg/ray/runtime/task/TaskExecutor;");
return CURRENT_JNI_VERSION;
}

View file

@ -141,9 +141,6 @@ extern jclass java_task_executor_class;
/// execute method of TaskExecutor class
extern jmethodID java_task_executor_execute;
/// The `get` method in TaskExecutor class
extern jmethodID java_task_executor_get;
#define CURRENT_JNI_VERSION JNI_VERSION_1_8
extern JavaVM *jvm;

View file

@ -20,6 +20,7 @@
#include "ray/core_worker/lib/java/jni_utils.h"
thread_local JNIEnv *local_env = nullptr;
jobject java_task_executor = nullptr;
inline ray::gcs::GcsClientOptions ToGcsClientOptions(JNIEnv *env,
jobject gcs_client_options) {
@ -36,15 +37,20 @@ inline ray::gcs::GcsClientOptions ToGcsClientOptions(JNIEnv *env,
extern "C" {
#endif
JNIEXPORT jlong JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeInitCoreWorker(
JNIEnv *env, jclass, jint workerMode, jstring storeSocket, jstring rayletSocket,
jstring nodeIpAddress, jint nodeManagerPort, jbyteArray jobId,
jobject gcsClientOptions) {
auto native_store_socket = JavaStringToNativeString(env, storeSocket);
auto native_raylet_socket = JavaStringToNativeString(env, rayletSocket);
auto job_id = JavaByteArrayToId<ray::JobID>(env, jobId);
auto gcs_client_options = ToGcsClientOptions(env, gcsClientOptions);
auto node_ip_address = JavaStringToNativeString(env, nodeIpAddress);
JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeInitialize(
JNIEnv *env, jclass, jint workerMode, jstring nodeIpAddress, jint nodeManagerPort,
jstring driverName, jstring storeSocket, jstring rayletSocket, jbyteArray jobId,
jobject gcsClientOptions, jint numWorkersPerProcess, jstring logDir,
jobject rayletConfigParameters) {
auto raylet_config = JavaMapToNativeMap<std::string, std::string>(
env, rayletConfigParameters,
[](JNIEnv *env, jobject java_key) {
return JavaStringToNativeString(env, (jstring)java_key);
},
[](JNIEnv *env, jobject java_value) {
return JavaStringToNativeString(env, (jstring)java_value);
});
RayConfig::instance().initialize(raylet_config);
auto task_execution_callback =
[](ray::TaskType task_type, const ray::RayFunction &ray_function,
@ -52,8 +58,7 @@ JNIEXPORT jlong JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeInitCoreWork
const std::vector<std::shared_ptr<ray::RayObject>> &args,
const std::vector<ObjectID> &arg_reference_ids,
const std::vector<ObjectID> &return_ids,
std::vector<std::shared_ptr<ray::RayObject>> *results,
const ray::WorkerID &worker_id) {
std::vector<std::shared_ptr<ray::RayObject>> *results) {
JNIEnv *env = local_env;
if (!env) {
// Attach the native thread to JVM.
@ -64,12 +69,7 @@ JNIEXPORT jlong JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeInitCoreWork
}
RAY_CHECK(env);
auto worker_id_bytes = IdToJavaByteArray<ray::WorkerID>(env, worker_id);
jobject local_java_task_executor = env->CallStaticObjectMethod(
java_task_executor_class, java_task_executor_get, worker_id_bytes);
RAY_CHECK(local_java_task_executor);
RAY_CHECK(java_task_executor);
// convert RayFunction
jobject ray_function_array_list = NativeRayFunctionDescriptorToJavaStringList(
env, ray_function.GetFunctionDescriptor());
@ -80,7 +80,7 @@ JNIEXPORT jlong JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeInitCoreWork
// invoke Java method
jobject java_return_objects =
env->CallObjectMethod(local_java_task_executor, java_task_executor_execute,
env->CallObjectMethod(java_task_executor, java_task_executor_execute,
ray_function_array_list, args_array_list);
RAY_CHECK_JAVA_EXCEPTION(env);
std::vector<std::shared_ptr<ray::RayObject>> return_objects;
@ -99,81 +99,70 @@ JNIEXPORT jlong JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeInitCoreWork
return ray::Status::OK();
};
try {
auto core_worker = new ray::CoreWorker(
static_cast<ray::WorkerType>(workerMode), ::Language::JAVA, native_store_socket,
native_raylet_socket, job_id, gcs_client_options, /*log_dir=*/"", node_ip_address,
nodeManagerPort, task_execution_callback);
return reinterpret_cast<jlong>(core_worker);
} catch (const std::exception &e) {
std::ostringstream oss;
oss << "Failed to construct core worker: " << e.what();
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, ray::Status::Invalid(oss.str()), 0);
return 0; // To keep the compiler from complaining
}
ray::CoreWorkerOptions options = {
static_cast<ray::WorkerType>(workerMode), // worker_type
ray::Language::JAVA, // language
JavaStringToNativeString(env, storeSocket), // store_socket
JavaStringToNativeString(env, rayletSocket), // raylet_socket
JavaByteArrayToId<ray::JobID>(env, jobId), // job_id
ToGcsClientOptions(env, gcsClientOptions), // gcs_options
JavaStringToNativeString(env, logDir), // log_dir
// TODO (kfstorm): JVM would crash if install_failure_signal_handler was set to true
false, // install_failure_signal_handler
JavaStringToNativeString(env, nodeIpAddress), // node_ip_address
static_cast<int>(nodeManagerPort), // node_manager_port
JavaStringToNativeString(env, driverName), // driver_name
"", // stdout_file
"", // stderr_file
task_execution_callback, // task_execution_callback
nullptr, // check_signals
nullptr, // gc_collect
nullptr, // get_lang_stack
false, // ref_counting_enabled
false, // is_local_mode
static_cast<int>(numWorkersPerProcess), // num_workers
};
ray::CoreWorkerProcess::Initialize(options);
}
/// Runs the task execution loop on behalf of the Java runtime.
///
/// The `jlong` variant caches `env` in the thread-local `local_env` while
/// `CoreWorker::StartExecutingTasks()` runs. The `jobject` variant publishes the
/// Java task executor through the file-level `java_task_executor` while
/// `CoreWorkerProcess::RunTaskExecutionLoop()` runs and clears it afterwards.
///
/// NOTE(review): `javaTaskExecutor` is stored as received, without
/// `NewGlobalRef`. JNI local references are only valid on the thread that
/// received them, so this presumably relies on the execution callback running
/// on this same thread -- confirm for the multi-worker case.
JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeRunTaskExecutor(
JNIEnv *env, jclass o, jlong nativeCoreWorkerPointer) {
local_env = env;
auto core_worker = reinterpret_cast<ray::CoreWorker *>(nativeCoreWorkerPointer);
core_worker->StartExecutingTasks();
local_env = nullptr;
JNIEnv *env, jclass o, jobject javaTaskExecutor) {
java_task_executor = javaTaskExecutor;
ray::CoreWorkerProcess::RunTaskExecutionLoop();
java_task_executor = nullptr;
}
/// Tears down the core worker whose address was handed to Java as a `jlong`:
/// the worker is disconnected first and then deleted.
JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeDestroyCoreWorker(
    JNIEnv *env, jclass o, jlong nativeCoreWorkerPointer) {
  ray::CoreWorker *const worker =
      reinterpret_cast<ray::CoreWorker *>(nativeCoreWorkerPointer);
  // Disconnect must happen while the object is still alive.
  worker->Disconnect();
  delete worker;
}
/// Native setup for the Java worker: starts Ray logging under the name
/// "java_worker" at INFO level using `logDir`, then converts
/// `rayletConfigParameters` to a native string map and passes it to
/// `RayConfig::instance().initialize`.
JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeSetup(
    JNIEnv *env, jclass, jstring logDir, jobject rayletConfigParameters) {
  const std::string native_log_dir = JavaStringToNativeString(env, logDir);
  ray::RayLog::StartRayLog("java_worker", ray::RayLogLevel::INFO, native_log_dir);
  // TODO (kfstorm): We can't InstallFailureSignalHandler here, because JVM already
  // installed its own signal handler. It's possible to fix this by chaining signal
  // handlers. But it's not easy. See
  // https://docs.oracle.com/javase/9/troubleshoot/handle-signals-and-exceptions.htm.

  // Keys and values are both converted as Java strings, so one converter serves
  // both positions.
  const auto to_native_string = [](JNIEnv *jni_env, jobject java_string) {
    return JavaStringToNativeString(jni_env, static_cast<jstring>(java_string));
  };
  auto config_map = JavaMapToNativeMap<std::string, std::string>(
      env, rayletConfigParameters, to_native_string, to_native_string);
  RayConfig::instance().initialize(config_map);
}
/// Native shutdown entry point called from the Java runtime.
/// `nativeShutdownHook` only stops Ray logging via `RayLog::ShutDownRayLog()`;
/// `nativeShutdown` delegates to `CoreWorkerProcess::Shutdown()`.
JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeShutdownHook(JNIEnv *,
jclass) {
ray::RayLog::ShutDownRayLog();
JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeShutdown(JNIEnv *env,
jclass o) {
ray::CoreWorkerProcess::Shutdown();
}
/// Calls `CoreWorker::SetResource` with the resource name, capacity and node ID
/// supplied from Java. The core worker is taken either from the `jlong` handle
/// or from `CoreWorkerProcess::GetCoreWorker()`, depending on the signature.
/// A non-OK status is handed to THROW_EXCEPTION_AND_RETURN_IF_NOT_OK
/// (presumably raising a Java exception -- macro defined elsewhere).
JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeSetResource(
JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jstring resourceName,
jdouble capacity, jbyteArray nodeId) {
JNIEnv *env, jclass, jstring resourceName, jdouble capacity, jbyteArray nodeId) {
const auto node_id = JavaByteArrayToId<ClientID>(env, nodeId);
// NOTE(review): the second argument of GetStringUTFChars is a `jboolean *isCopy`
// out-parameter, so JNI_FALSE is acting as a null pointer here; `nullptr` would
// state the intent. The returned pointer is also used without a null check.
const char *native_resource_name = env->GetStringUTFChars(resourceName, JNI_FALSE);
auto status =
reinterpret_cast<ray::CoreWorker *>(nativeCoreWorkerPointer)
->SetResource(native_resource_name, static_cast<double>(capacity), node_id);
auto status = ray::CoreWorkerProcess::GetCoreWorker().SetResource(
native_resource_name, static_cast<double>(capacity), node_id);
// Release the UTF chars before the status check so they are freed on both paths.
env->ReleaseStringUTFChars(resourceName, native_resource_name);
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
}
/// Calls `CoreWorker::KillActor` for the actor identified by `actorId`, always
/// with `force_kill` set to true and with `noReconstruction` passed through.
/// The core worker is taken either from the `jlong` handle or from
/// `CoreWorkerProcess::GetCoreWorker()`, depending on the signature.
JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeKillActor(
JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jbyteArray actorId,
jboolean noReconstruction) {
auto core_worker = reinterpret_cast<ray::CoreWorker *>(nativeCoreWorkerPointer);
auto status = core_worker->KillActor(JavaByteArrayToId<ActorID>(env, actorId),
/*force_kill=*/true, noReconstruction);
JNIEnv *env, jclass, jbyteArray actorId, jboolean noReconstruction) {
auto status = ray::CoreWorkerProcess::GetCoreWorker().KillActor(
JavaByteArrayToId<ActorID>(env, actorId),
/*force_kill=*/true, noReconstruction);
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
}
/// Decodes `workerId` into a `ray::WorkerID` and forwards it to
/// `CoreWorkerProcess::SetCurrentThreadWorkerId`.
JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeSetCoreWorker(
    JNIEnv *env, jclass, jbyteArray workerId) {
  ray::CoreWorkerProcess::SetCurrentThreadWorkerId(
      JavaByteArrayToId<ray::WorkerID>(env, workerId));
}
#ifdef __cplusplus
}
#endif

View file

@ -23,61 +23,55 @@ extern "C" {
#endif
/*
* Class: org_ray_runtime_RayNativeRuntime
* Method: nativeInitCoreWorker
* Method: nativeInitialize
* Signature:
* (ILjava/lang/String;Ljava/lang/String;Ljava/lang/String;I[BLorg/ray/runtime/gcs/GcsClientOptions;)J
* (ILjava/lang/String;ILjava/lang/String;Ljava/lang/String;Ljava/lang/String;[BLorg/ray/runtime/gcs/GcsClientOptions;ILjava/lang/String;Ljava/util/Map;)V
*/
JNIEXPORT jlong JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeInitCoreWorker(
JNIEnv *, jclass, jint, jstring, jstring, jstring, jint, jbyteArray, jobject);
JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeInitialize(
JNIEnv *, jclass, jint, jstring, jint, jstring, jstring, jstring, jbyteArray, jobject,
jint, jstring, jobject);
/*
* Class: org_ray_runtime_RayNativeRuntime
* Method: nativeRunTaskExecutor
* Signature: (J)V
* Signature: (Lorg/ray/runtime/task/TaskExecutor;)V
*/
JNIEXPORT void JNICALL
Java_org_ray_runtime_RayNativeRuntime_nativeRunTaskExecutor(JNIEnv *, jclass, jlong);
Java_org_ray_runtime_RayNativeRuntime_nativeRunTaskExecutor(JNIEnv *, jclass, jobject);
/*
* Class: org_ray_runtime_RayNativeRuntime
* Method: nativeDestroyCoreWorker
* Signature: (J)V
*/
JNIEXPORT void JNICALL
Java_org_ray_runtime_RayNativeRuntime_nativeDestroyCoreWorker(JNIEnv *, jclass, jlong);
/*
* Class: org_ray_runtime_RayNativeRuntime
* Method: nativeSetup
* Signature: (Ljava/lang/String;Ljava/util/Map;)V
*/
JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeSetup(JNIEnv *, jclass,
jstring,
jobject);
/*
* Class: org_ray_runtime_RayNativeRuntime
* Method: nativeShutdownHook
* Method: nativeShutdown
* Signature: ()V
*/
JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeShutdownHook(JNIEnv *,
jclass);
JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeShutdown(JNIEnv *,
jclass);
/*
* Class: org_ray_runtime_RayNativeRuntime
* Method: nativeSetResource
* Signature: (JLjava/lang/String;D[B)V
* Signature: (Ljava/lang/String;D[B)V
*/
JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeSetResource(
JNIEnv *, jclass, jlong, jstring, jdouble, jbyteArray);
JNIEnv *, jclass, jstring, jdouble, jbyteArray);
/*
* Class: org_ray_runtime_RayNativeRuntime
* Method: nativeKillActor
* Signature: (J[BZ)V
* Signature: ([BZ)V
*/
JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeKillActor(
JNIEnv *, jclass, jlong, jbyteArray, jboolean);
JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeKillActor(JNIEnv *,
jclass,
jbyteArray,
jboolean);
/*
* Class: org_ray_runtime_RayNativeRuntime
* Method: nativeSetCoreWorker
* Signature: ([B)V
*/
JNIEXPORT void JNICALL
Java_org_ray_runtime_RayNativeRuntime_nativeSetCoreWorker(JNIEnv *, jclass, jbyteArray);
#ifdef __cplusplus
}

View file

@ -15,47 +15,44 @@
#include "ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.h"
#include <jni.h>
#include "ray/common/id.h"
#include "ray/core_worker/actor_handle.h"
#include "ray/core_worker/common.h"
#include "ray/core_worker/core_worker.h"
#include "ray/core_worker/lib/java/jni_utils.h"
/// Reinterprets the `jlong` handle passed across JNI as the address of a
/// `ray::CoreWorker` and returns a reference to it. The handle is not validated.
inline ray::CoreWorker &GetCoreWorker(jlong nativeCoreWorkerPointer) {
  auto *worker = reinterpret_cast<ray::CoreWorker *>(nativeCoreWorkerPointer);
  return *worker;
}
#ifdef __cplusplus
extern "C" {
#endif
/// Looks up the actor handle for `actorId` through `CoreWorker::GetActorHandle`
/// and returns the value of its `ActorLanguage()`.
/// NOTE(review): the failure value passed to the macro below is `false` although
/// the function returns `jint`; presumably a placeholder since the macro is
/// expected to raise a Java exception -- confirm.
JNIEXPORT jint JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeGetLanguage(
JNIEnv *env, jclass o, jlong nativeCoreWorkerPointer, jbyteArray actorId) {
JNIEnv *env, jclass o, jbyteArray actorId) {
auto actor_id = JavaByteArrayToId<ray::ActorID>(env, actorId);
ray::ActorHandle *native_actor_handle = nullptr;
auto status = GetCoreWorker(nativeCoreWorkerPointer)
.GetActorHandle(actor_id, &native_actor_handle);
auto status = ray::CoreWorkerProcess::GetCoreWorker().GetActorHandle(
actor_id, &native_actor_handle);
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, false);
return native_actor_handle->ActorLanguage();
}
/// Looks up the actor handle for `actorId` through `CoreWorker::GetActorHandle`
/// and returns its actor-creation-task function descriptor, converted with
/// `NativeRayFunctionDescriptorToJavaStringList`. `nullptr` is the failure value
/// handed to THROW_EXCEPTION_AND_RETURN_IF_NOT_OK when the lookup status is not OK.
JNIEXPORT jobject JNICALL
Java_org_ray_runtime_actor_NativeRayActor_nativeGetActorCreationTaskFunctionDescriptor(
JNIEnv *env, jclass o, jlong nativeCoreWorkerPointer, jbyteArray actorId) {
JNIEnv *env, jclass o, jbyteArray actorId) {
auto actor_id = JavaByteArrayToId<ray::ActorID>(env, actorId);
ray::ActorHandle *native_actor_handle = nullptr;
auto status = GetCoreWorker(nativeCoreWorkerPointer)
.GetActorHandle(actor_id, &native_actor_handle);
auto status = ray::CoreWorkerProcess::GetCoreWorker().GetActorHandle(
actor_id, &native_actor_handle);
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr);
auto function_descriptor = native_actor_handle->ActorCreationTaskFunctionDescriptor();
return NativeRayFunctionDescriptorToJavaStringList(env, function_descriptor);
}
JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeSerialize(
JNIEnv *env, jclass o, jlong nativeCoreWorkerPointer, jbyteArray actorId) {
JNIEnv *env, jclass o, jbyteArray actorId) {
auto actor_id = JavaByteArrayToId<ray::ActorID>(env, actorId);
std::string output;
ObjectID actor_handle_id;
ray::Status status = GetCoreWorker(nativeCoreWorkerPointer)
.SerializeActorHandle(actor_id, &output, &actor_handle_id);
ray::Status status = ray::CoreWorkerProcess::GetCoreWorker().SerializeActorHandle(
actor_id, &output, &actor_handle_id);
jbyteArray bytes = env->NewByteArray(output.size());
env->SetByteArrayRegion(bytes, 0, output.size(),
reinterpret_cast<const jbyte *>(output.c_str()));
@ -63,13 +60,13 @@ JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeSer
}
/// Copies the bytes of `data` into a `std::string`, passes it to
/// `CoreWorker::DeserializeAndRegisterActorHandle` with a nil outer object ID,
/// and returns the resulting actor ID as a Java byte array.
/// An empty input fails the RAY_CHECK below.
JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeDeserialize(
JNIEnv *env, jclass o, jlong nativeCoreWorkerPointer, jbyteArray data) {
JNIEnv *env, jclass o, jbyteArray data) {
auto buffer = JavaByteArrayToNativeBuffer(env, data);
RAY_CHECK(buffer->Size() > 0);
auto binary = std::string(reinterpret_cast<char *>(buffer->Data()), buffer->Size());
auto actor_id =
GetCoreWorker(nativeCoreWorkerPointer)
.DeserializeAndRegisterActorHandle(binary, /*outer_object_id=*/ObjectID::Nil());
ray::CoreWorkerProcess::GetCoreWorker().DeserializeAndRegisterActorHandle(
binary, /*outer_object_id=*/ObjectID::Nil());
return IdToJavaByteArray<ray::ActorID>(env, actor_id);
}

View file

@ -24,35 +24,35 @@ extern "C" {
/*
* Class: org_ray_runtime_actor_NativeRayActor
* Method: nativeGetLanguage
* Signature: (J[B)I
* Signature: ([B)I
*/
JNIEXPORT jint JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeGetLanguage(
JNIEnv *, jclass, jlong, jbyteArray);
JNIEXPORT jint JNICALL
Java_org_ray_runtime_actor_NativeRayActor_nativeGetLanguage(JNIEnv *, jclass, jbyteArray);
/*
* Class: org_ray_runtime_actor_NativeRayActor
* Method: nativeGetActorCreationTaskFunctionDescriptor
* Signature: (J[B)Ljava/util/List;
* Signature: ([B)Ljava/util/List;
*/
JNIEXPORT jobject JNICALL
Java_org_ray_runtime_actor_NativeRayActor_nativeGetActorCreationTaskFunctionDescriptor(
JNIEnv *, jclass, jlong, jbyteArray);
JNIEnv *, jclass, jbyteArray);
/*
* Class: org_ray_runtime_actor_NativeRayActor
* Method: nativeSerialize
* Signature: (J[B)[B
* Signature: ([B)[B
*/
JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeSerialize(
JNIEnv *, jclass, jlong, jbyteArray);
JNIEXPORT jbyteArray JNICALL
Java_org_ray_runtime_actor_NativeRayActor_nativeSerialize(JNIEnv *, jclass, jbyteArray);
/*
* Class: org_ray_runtime_actor_NativeRayActor
* Method: nativeDeserialize
* Signature: (J[B)[B
* Signature: ([B)[B
*/
JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeDeserialize(
JNIEnv *, jclass, jlong, jbyteArray);
JNIEXPORT jbyteArray JNICALL
Java_org_ray_runtime_actor_NativeRayActor_nativeDeserialize(JNIEnv *, jclass, jbyteArray);
#ifdef __cplusplus
}

View file

@ -19,51 +19,48 @@
#include "ray/core_worker/core_worker.h"
#include "ray/core_worker/lib/java/jni_utils.h"
/// Reinterprets the `jlong` handle passed across JNI as the address of a
/// `ray::CoreWorker` and returns that worker's `WorkerContext`. The handle is
/// not validated.
inline ray::WorkerContext &GetWorkerContextFromPointer(jlong nativeCoreWorkerPointer) {
  auto *worker = reinterpret_cast<ray::CoreWorker *>(nativeCoreWorkerPointer);
  return worker->GetWorkerContext();
}
#ifdef __cplusplus
extern "C" {
#endif
/// Returns the `type()` field of the current task's message, cast to int.
/// The current task is read from the worker context; the RAY_CHECK below fails
/// with "Current task is not set." when there is none.
JNIEXPORT jint JNICALL
Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentTaskType(
JNIEnv *env, jclass, jlong nativeCoreWorkerPointer) {
auto task_spec = GetWorkerContextFromPointer(nativeCoreWorkerPointer).GetCurrentTask();
Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentTaskType(JNIEnv *env,
jclass) {
auto task_spec =
ray::CoreWorkerProcess::GetCoreWorker().GetWorkerContext().GetCurrentTask();
RAY_CHECK(task_spec) << "Current task is not set.";
return static_cast<int>(task_spec->GetMessage().type());
}
/// Returns the worker context's current task ID, converted with
/// `IdToJavaByteBuffer<ray::TaskID>`.
JNIEXPORT jobject JNICALL
Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentTaskId(
JNIEnv *env, jclass, jlong nativeCoreWorkerPointer) {
Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentTaskId(JNIEnv *env,
jclass) {
const ray::TaskID &task_id =
GetWorkerContextFromPointer(nativeCoreWorkerPointer).GetCurrentTaskID();
ray::CoreWorkerProcess::GetCoreWorker().GetWorkerContext().GetCurrentTaskID();
return IdToJavaByteBuffer<ray::TaskID>(env, task_id);
}
/// Returns the worker context's current job ID, converted with
/// `IdToJavaByteBuffer<ray::JobID>`.
JNIEXPORT jobject JNICALL
Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentJobId(
JNIEnv *env, jclass, jlong nativeCoreWorkerPointer) {
Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentJobId(JNIEnv *env,
jclass) {
const auto &job_id =
GetWorkerContextFromPointer(nativeCoreWorkerPointer).GetCurrentJobID();
ray::CoreWorkerProcess::GetCoreWorker().GetWorkerContext().GetCurrentJobID();
return IdToJavaByteBuffer<ray::JobID>(env, job_id);
}
/// Returns the worker context's worker ID, converted with
/// `IdToJavaByteBuffer<ray::WorkerID>`.
JNIEXPORT jobject JNICALL
Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentWorkerId(
JNIEnv *env, jclass, jlong nativeCoreWorkerPointer) {
Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentWorkerId(JNIEnv *env,
jclass) {
const auto &worker_id =
GetWorkerContextFromPointer(nativeCoreWorkerPointer).GetWorkerID();
ray::CoreWorkerProcess::GetCoreWorker().GetWorkerContext().GetWorkerID();
return IdToJavaByteBuffer<ray::WorkerID>(env, worker_id);
}
/// Returns the worker context's current actor ID, converted with
/// `IdToJavaByteBuffer<ray::ActorID>`.
JNIEXPORT jobject JNICALL
Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentActorId(
JNIEnv *env, jclass, jlong nativeCoreWorkerPointer) {
Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentActorId(JNIEnv *env,
jclass) {
const auto &actor_id =
GetWorkerContextFromPointer(nativeCoreWorkerPointer).GetCurrentActorID();
ray::CoreWorkerProcess::GetCoreWorker().GetWorkerContext().GetCurrentActorID();
return IdToJavaByteBuffer<ray::ActorID>(env, actor_id);
}

View file

@ -24,47 +24,45 @@ extern "C" {
/*
* Class: org_ray_runtime_context_NativeWorkerContext
* Method: nativeGetCurrentTaskType
* Signature: (J)I
* Signature: ()I
*/
JNIEXPORT jint JNICALL
Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentTaskType(JNIEnv *,
jclass, jlong);
jclass);
/*
* Class: org_ray_runtime_context_NativeWorkerContext
* Method: nativeGetCurrentTaskId
* Signature: (J)Ljava/nio/ByteBuffer;
* Signature: ()Ljava/nio/ByteBuffer;
*/
JNIEXPORT jobject JNICALL
Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentTaskId(JNIEnv *, jclass,
jlong);
Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentTaskId(JNIEnv *, jclass);
/*
* Class: org_ray_runtime_context_NativeWorkerContext
* Method: nativeGetCurrentJobId
* Signature: (J)Ljava/nio/ByteBuffer;
* Signature: ()Ljava/nio/ByteBuffer;
*/
JNIEXPORT jobject JNICALL
Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentJobId(JNIEnv *, jclass,
jlong);
Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentJobId(JNIEnv *, jclass);
/*
* Class: org_ray_runtime_context_NativeWorkerContext
* Method: nativeGetCurrentWorkerId
* Signature: (J)Ljava/nio/ByteBuffer;
* Signature: ()Ljava/nio/ByteBuffer;
*/
JNIEXPORT jobject JNICALL
Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentWorkerId(JNIEnv *,
jclass, jlong);
jclass);
/*
* Class: org_ray_runtime_context_NativeWorkerContext
* Method: nativeGetCurrentActorId
* Signature: (J)Ljava/nio/ByteBuffer;
* Signature: ()Ljava/nio/ByteBuffer;
*/
JNIEXPORT jobject JNICALL
Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentActorId(JNIEnv *, jclass,
jlong);
Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentActorId(JNIEnv *,
jclass);
#ifdef __cplusplus
}

View file

@ -24,55 +24,51 @@ extern "C" {
#endif
/// `nativePut` overload that lets the core worker choose the object ID.
/// Converts `obj` to a native RayObject, calls `CoreWorker::Put` with an empty
/// second argument and an `ObjectID` out-parameter, and returns that ID as a
/// Java byte array. A null converted object fails the RAY_CHECK below.
JNIEXPORT jbyteArray JNICALL
Java_org_ray_runtime_object_NativeObjectStore_nativePut__JLorg_ray_runtime_object_NativeRayObject_2(
JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jobject obj) {
Java_org_ray_runtime_object_NativeObjectStore_nativePut__Lorg_ray_runtime_object_NativeRayObject_2(
JNIEnv *env, jclass, jobject obj) {
auto ray_object = JavaNativeRayObjectToNativeRayObject(env, obj);
RAY_CHECK(ray_object != nullptr);
ray::ObjectID object_id;
auto status = reinterpret_cast<ray::CoreWorker *>(nativeCoreWorkerPointer)
->Put(*ray_object, {}, &object_id);
auto status = ray::CoreWorkerProcess::GetCoreWorker().Put(*ray_object, {}, &object_id);
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr);
return IdToJavaByteArray<ray::ObjectID>(env, object_id);
}
/// `nativePut` overload that stores `obj` under the caller-supplied `objectId`.
/// Converts both arguments to native types and calls `CoreWorker::Put` with the
/// given ID. A null converted object fails the RAY_CHECK below.
JNIEXPORT void JNICALL
Java_org_ray_runtime_object_NativeObjectStore_nativePut__J_3BLorg_ray_runtime_object_NativeRayObject_2(
JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jbyteArray objectId,
jobject obj) {
Java_org_ray_runtime_object_NativeObjectStore_nativePut___3BLorg_ray_runtime_object_NativeRayObject_2(
JNIEnv *env, jclass, jbyteArray objectId, jobject obj) {
auto object_id = JavaByteArrayToId<ray::ObjectID>(env, objectId);
auto ray_object = JavaNativeRayObjectToNativeRayObject(env, obj);
RAY_CHECK(ray_object != nullptr);
auto status = reinterpret_cast<ray::CoreWorker *>(nativeCoreWorkerPointer)
->Put(*ray_object, {}, object_id);
auto status = ray::CoreWorkerProcess::GetCoreWorker().Put(*ray_object, {}, object_id);
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
}
/// Converts the Java list `ids` (elements treated as byte arrays) into native
/// object IDs, calls `CoreWorker::Get` with `timeoutMs`, and returns the
/// resulting RayObjects as a Java list built with
/// `NativeRayObjectToJavaNativeRayObject`.
JNIEXPORT jobject JNICALL Java_org_ray_runtime_object_NativeObjectStore_nativeGet(
JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jobject ids, jlong timeoutMs) {
JNIEnv *env, jclass, jobject ids, jlong timeoutMs) {
std::vector<ray::ObjectID> object_ids;
JavaListToNativeVector<ray::ObjectID>(
env, ids, &object_ids, [](JNIEnv *env, jobject id) {
return JavaByteArrayToId<ray::ObjectID>(env, static_cast<jbyteArray>(id));
});
std::vector<std::shared_ptr<ray::RayObject>> results;
auto status = reinterpret_cast<ray::CoreWorker *>(nativeCoreWorkerPointer)
->Get(object_ids, (int64_t)timeoutMs, &results);
auto status = ray::CoreWorkerProcess::GetCoreWorker().Get(object_ids,
(int64_t)timeoutMs, &results);
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr);
return NativeVectorToJavaList<std::shared_ptr<ray::RayObject>>(
env, results, NativeRayObjectToJavaNativeRayObject);
}
JNIEXPORT jobject JNICALL Java_org_ray_runtime_object_NativeObjectStore_nativeWait(
JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jobject objectIds,
jint numObjects, jlong timeoutMs) {
JNIEnv *env, jclass, jobject objectIds, jint numObjects, jlong timeoutMs) {
std::vector<ray::ObjectID> object_ids;
JavaListToNativeVector<ray::ObjectID>(
env, objectIds, &object_ids, [](JNIEnv *env, jobject id) {
return JavaByteArrayToId<ray::ObjectID>(env, static_cast<jbyteArray>(id));
});
std::vector<bool> results;
auto status = reinterpret_cast<ray::CoreWorker *>(nativeCoreWorkerPointer)
->Wait(object_ids, (int)numObjects, (int64_t)timeoutMs, &results);
auto status = ray::CoreWorkerProcess::GetCoreWorker().Wait(
object_ids, (int)numObjects, (int64_t)timeoutMs, &results);
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr);
return NativeVectorToJavaList<bool>(env, results, [](JNIEnv *env, const bool &item) {
return env->NewObject(java_boolean_class, java_boolean_init, (jboolean)item);
@ -80,15 +76,15 @@ JNIEXPORT jobject JNICALL Java_org_ray_runtime_object_NativeObjectStore_nativeWa
}
/// Converts the Java list `objectIds` (elements treated as byte arrays) into
/// native object IDs and calls `CoreWorker::Delete` with the `localOnly` and
/// `deleteCreatingTasks` flags passed through as bools.
JNIEXPORT void JNICALL Java_org_ray_runtime_object_NativeObjectStore_nativeDelete(
JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jobject objectIds,
jboolean localOnly, jboolean deleteCreatingTasks) {
JNIEnv *env, jclass, jobject objectIds, jboolean localOnly,
jboolean deleteCreatingTasks) {
std::vector<ray::ObjectID> object_ids;
JavaListToNativeVector<ray::ObjectID>(
env, objectIds, &object_ids, [](JNIEnv *env, jobject id) {
return JavaByteArrayToId<ray::ObjectID>(env, static_cast<jbyteArray>(id));
});
auto status = reinterpret_cast<ray::CoreWorker *>(nativeCoreWorkerPointer)
->Delete(object_ids, (bool)localOnly, (bool)deleteCreatingTasks);
auto status = ray::CoreWorkerProcess::GetCoreWorker().Delete(
object_ids, (bool)localOnly, (bool)deleteCreatingTasks);
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
}

View file

@ -24,44 +24,44 @@ extern "C" {
/*
* Class: org_ray_runtime_object_NativeObjectStore
* Method: nativePut
* Signature: (JLorg/ray/runtime/object/NativeRayObject;)[B
* Signature: (Lorg/ray/runtime/object/NativeRayObject;)[B
*/
JNIEXPORT jbyteArray JNICALL
Java_org_ray_runtime_object_NativeObjectStore_nativePut__JLorg_ray_runtime_object_NativeRayObject_2(
JNIEnv *, jclass, jlong, jobject);
Java_org_ray_runtime_object_NativeObjectStore_nativePut__Lorg_ray_runtime_object_NativeRayObject_2(
JNIEnv *, jclass, jobject);
/*
* Class: org_ray_runtime_object_NativeObjectStore
* Method: nativePut
* Signature: (J[BLorg/ray/runtime/object/NativeRayObject;)V
* Signature: ([BLorg/ray/runtime/object/NativeRayObject;)V
*/
JNIEXPORT void JNICALL
Java_org_ray_runtime_object_NativeObjectStore_nativePut__J_3BLorg_ray_runtime_object_NativeRayObject_2(
JNIEnv *, jclass, jlong, jbyteArray, jobject);
Java_org_ray_runtime_object_NativeObjectStore_nativePut___3BLorg_ray_runtime_object_NativeRayObject_2(
JNIEnv *, jclass, jbyteArray, jobject);
/*
* Class: org_ray_runtime_object_NativeObjectStore
* Method: nativeGet
* Signature: (JLjava/util/List;J)Ljava/util/List;
* Signature: (Ljava/util/List;J)Ljava/util/List;
*/
JNIEXPORT jobject JNICALL Java_org_ray_runtime_object_NativeObjectStore_nativeGet(
JNIEnv *, jclass, jlong, jobject, jlong);
JNIEXPORT jobject JNICALL
Java_org_ray_runtime_object_NativeObjectStore_nativeGet(JNIEnv *, jclass, jobject, jlong);
/*
* Class: org_ray_runtime_object_NativeObjectStore
* Method: nativeWait
* Signature: (JLjava/util/List;IJ)Ljava/util/List;
* Signature: (Ljava/util/List;IJ)Ljava/util/List;
*/
JNIEXPORT jobject JNICALL Java_org_ray_runtime_object_NativeObjectStore_nativeWait(
JNIEnv *, jclass, jlong, jobject, jint, jlong);
JNIEnv *, jclass, jobject, jint, jlong);
/*
* Class: org_ray_runtime_object_NativeObjectStore
* Method: nativeDelete
* Signature: (JLjava/util/List;ZZ)V
* Signature: (Ljava/util/List;ZZ)V
*/
JNIEXPORT void JNICALL Java_org_ray_runtime_object_NativeObjectStore_nativeDelete(
JNIEnv *, jclass, jlong, jobject, jboolean, jboolean);
JNIEnv *, jclass, jobject, jboolean, jboolean);
#ifdef __cplusplus
}

View file

@ -27,9 +27,9 @@ extern "C" {
using ray::ClientID;
JNIEXPORT jbyteArray JNICALL
Java_org_ray_runtime_task_NativeTaskExecutor_nativePrepareCheckpoint(
JNIEnv *env, jclass, jlong nativeCoreWorkerPointer) {
auto &core_worker = *reinterpret_cast<ray::CoreWorker *>(nativeCoreWorkerPointer);
Java_org_ray_runtime_task_NativeTaskExecutor_nativePrepareCheckpoint(JNIEnv *env,
jclass) {
auto &core_worker = ray::CoreWorkerProcess::GetCoreWorker();
const auto &actor_id = core_worker.GetWorkerContext().GetCurrentActorID();
const auto &task_spec = core_worker.GetWorkerContext().GetCurrentTask();
RAY_CHECK(task_spec->IsActorTask());
@ -44,11 +44,12 @@ Java_org_ray_runtime_task_NativeTaskExecutor_nativePrepareCheckpoint(
JNIEXPORT void JNICALL
Java_org_ray_runtime_task_NativeTaskExecutor_nativeNotifyActorResumedFromCheckpoint(
JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jbyteArray checkpointId) {
auto &core_worker = *reinterpret_cast<ray::CoreWorker *>(nativeCoreWorkerPointer);
const auto &actor_id = core_worker.GetWorkerContext().GetCurrentActorID();
JNIEnv *env, jclass, jbyteArray checkpointId) {
const auto &actor_id =
ray::CoreWorkerProcess::GetCoreWorker().GetWorkerContext().GetCurrentActorID();
const auto checkpoint_id = JavaByteArrayToId<ActorCheckpointID>(env, checkpointId);
auto status = core_worker.NotifyActorResumedFromCheckpoint(actor_id, checkpoint_id);
auto status = ray::CoreWorkerProcess::GetCoreWorker().NotifyActorResumedFromCheckpoint(
actor_id, checkpoint_id);
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
}

View file

@ -26,20 +26,19 @@ extern "C" {
/*
* Class: org_ray_runtime_task_NativeTaskExecutor
* Method: nativePrepareCheckpoint
* Signature: (J)[B
* Signature: ()[B
*/
JNIEXPORT jbyteArray JNICALL
Java_org_ray_runtime_task_NativeTaskExecutor_nativePrepareCheckpoint(JNIEnv *, jclass,
jlong);
Java_org_ray_runtime_task_NativeTaskExecutor_nativePrepareCheckpoint(JNIEnv *, jclass);
/*
* Class: org_ray_runtime_task_NativeTaskExecutor
* Method: nativeNotifyActorResumedFromCheckpoint
* Signature: (J[B)V
* Signature: ([B)V
*/
JNIEXPORT void JNICALL
Java_org_ray_runtime_task_NativeTaskExecutor_nativeNotifyActorResumedFromCheckpoint(
JNIEnv *, jclass, jlong, jbyteArray);
JNIEnv *, jclass, jbyteArray);
#ifdef __cplusplus
}

View file

@ -19,10 +19,6 @@
#include "ray/core_worker/core_worker.h"
#include "ray/core_worker/lib/java/jni_utils.h"
inline ray::CoreWorker &GetCoreWorker(jlong nativeCoreWorkerPointer) {
return *reinterpret_cast<ray::CoreWorker *>(nativeCoreWorkerPointer);
}
inline ray::RayFunction ToRayFunction(JNIEnv *env, jobject functionDescriptor) {
std::vector<std::string> function_descriptor_list;
jobject list =
@ -127,17 +123,17 @@ extern "C" {
#endif
JNIEXPORT jobject JNICALL Java_org_ray_runtime_task_NativeTaskSubmitter_nativeSubmitTask(
JNIEnv *env, jclass p, jlong nativeCoreWorkerPointer, jobject functionDescriptor,
jobject args, jint numReturns, jobject callOptions) {
JNIEnv *env, jclass p, jobject functionDescriptor, jobject args, jint numReturns,
jobject callOptions) {
auto ray_function = ToRayFunction(env, functionDescriptor);
auto task_args = ToTaskArgs(env, args);
auto task_options = ToTaskOptions(env, numReturns, callOptions);
std::vector<ObjectID> return_ids;
// TODO (kfstorm): Allow setting `max_retries` via `CallOptions`.
auto status = GetCoreWorker(nativeCoreWorkerPointer)
.SubmitTask(ray_function, task_args, task_options, &return_ids,
/*max_retries=*/0);
auto status = ray::CoreWorkerProcess::GetCoreWorker().SubmitTask(
ray_function, task_args, task_options, &return_ids,
/*max_retries=*/0);
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr);
@ -146,16 +142,16 @@ JNIEXPORT jobject JNICALL Java_org_ray_runtime_task_NativeTaskSubmitter_nativeSu
JNIEXPORT jbyteArray JNICALL
Java_org_ray_runtime_task_NativeTaskSubmitter_nativeCreateActor(
JNIEnv *env, jclass p, jlong nativeCoreWorkerPointer, jobject functionDescriptor,
jobject args, jobject actorCreationOptions) {
JNIEnv *env, jclass p, jobject functionDescriptor, jobject args,
jobject actorCreationOptions) {
auto ray_function = ToRayFunction(env, functionDescriptor);
auto task_args = ToTaskArgs(env, args);
auto actor_creation_options = ToActorCreationOptions(env, actorCreationOptions);
ray::ActorID actor_id;
auto status = GetCoreWorker(nativeCoreWorkerPointer)
.CreateActor(ray_function, task_args, actor_creation_options,
/*extension_data*/ "", &actor_id);
ActorID actor_id;
auto status = ray::CoreWorkerProcess::GetCoreWorker().CreateActor(
ray_function, task_args, actor_creation_options,
/*extension_data*/ "", &actor_id);
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr);
return IdToJavaByteArray<ray::ActorID>(env, actor_id);
@ -163,17 +159,16 @@ Java_org_ray_runtime_task_NativeTaskSubmitter_nativeCreateActor(
JNIEXPORT jobject JNICALL
Java_org_ray_runtime_task_NativeTaskSubmitter_nativeSubmitActorTask(
JNIEnv *env, jclass p, jlong nativeCoreWorkerPointer, jbyteArray actorId,
jobject functionDescriptor, jobject args, jint numReturns, jobject callOptions) {
JNIEnv *env, jclass p, jbyteArray actorId, jobject functionDescriptor, jobject args,
jint numReturns, jobject callOptions) {
auto actor_id = JavaByteArrayToId<ray::ActorID>(env, actorId);
auto ray_function = ToRayFunction(env, functionDescriptor);
auto task_args = ToTaskArgs(env, args);
auto task_options = ToTaskOptions(env, numReturns, callOptions);
std::vector<ObjectID> return_ids;
auto status =
GetCoreWorker(nativeCoreWorkerPointer)
.SubmitActorTask(actor_id, ray_function, task_args, task_options, &return_ids);
auto status = ray::CoreWorkerProcess::GetCoreWorker().SubmitActorTask(
actor_id, ray_function, task_args, task_options, &return_ids);
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr);
return NativeIdVectorToJavaByteArrayList(env, return_ids);

View file

@ -25,33 +25,32 @@ extern "C" {
* Class: org_ray_runtime_task_NativeTaskSubmitter
* Method: nativeSubmitTask
* Signature:
* (JLorg/ray/runtime/functionmanager/FunctionDescriptor;Ljava/util/List;ILorg/ray/api/options/CallOptions;)Ljava/util/List;
* (Lorg/ray/runtime/functionmanager/FunctionDescriptor;Ljava/util/List;ILorg/ray/api/options/CallOptions;)Ljava/util/List;
*/
JNIEXPORT jobject JNICALL Java_org_ray_runtime_task_NativeTaskSubmitter_nativeSubmitTask(
JNIEnv *, jclass, jlong, jobject, jobject, jint, jobject);
JNIEnv *, jclass, jobject, jobject, jint, jobject);
/*
* Class: org_ray_runtime_task_NativeTaskSubmitter
* Method: nativeCreateActor
* Signature:
* (JLorg/ray/runtime/functionmanager/FunctionDescriptor;Ljava/util/List;Lorg/ray/api/options/ActorCreationOptions;)[B
* (Lorg/ray/runtime/functionmanager/FunctionDescriptor;Ljava/util/List;Lorg/ray/api/options/ActorCreationOptions;)[B
*/
JNIEXPORT jbyteArray JNICALL
Java_org_ray_runtime_task_NativeTaskSubmitter_nativeCreateActor(JNIEnv *, jclass, jlong,
jobject, jobject,
jobject);
Java_org_ray_runtime_task_NativeTaskSubmitter_nativeCreateActor(JNIEnv *, jclass, jobject,
jobject, jobject);
/*
* Class: org_ray_runtime_task_NativeTaskSubmitter
* Method: nativeSubmitActorTask
* Signature:
* (J[BLorg/ray/runtime/functionmanager/FunctionDescriptor;Ljava/util/List;ILorg/ray/api/options/CallOptions;)Ljava/util/List;
* ([BLorg/ray/runtime/functionmanager/FunctionDescriptor;Ljava/util/List;ILorg/ray/api/options/CallOptions;)Ljava/util/List;
*/
JNIEXPORT jobject JNICALL
Java_org_ray_runtime_task_NativeTaskSubmitter_nativeSubmitActorTask(JNIEnv *, jclass,
jlong, jbyteArray,
jobject, jobject,
jint, jobject);
jbyteArray, jobject,
jobject, jint,
jobject);
#ifdef __cplusplus
}

View file

@ -340,7 +340,7 @@ TEST(MemoryStoreIntegrationTest, TestSimple) {
RAY_CHECK(store.Put(buffer, id1));
ASSERT_EQ(store.Size(), 1);
std::vector<std::shared_ptr<RayObject>> results;
WorkerContext ctx(WorkerType::WORKER, JobID::Nil());
WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::Nil());
RAY_CHECK_OK(store.Get({id1}, /*num_objects*/ 1, /*timeout_ms*/ -1, ctx,
/*remove_after_get*/ true, &results));
ASSERT_EQ(results.size(), 1);

View file

@ -57,8 +57,7 @@ static void flushall_redis(void) {
redisFree(context);
}
ActorID CreateActorHelper(CoreWorker &worker,
std::unordered_map<std::string, double> &resources,
ActorID CreateActorHelper(std::unordered_map<std::string, double> &resources,
uint64_t max_reconstructions) {
std::unique_ptr<ActorHandle> actor_handle;
@ -78,8 +77,8 @@ ActorID CreateActorHelper(CoreWorker &worker,
// Create an actor.
ActorID actor_id;
RAY_CHECK_OK(
worker.CreateActor(func, args, actor_options, /*extension_data*/ "", &actor_id));
RAY_CHECK_OK(CoreWorkerProcess::GetCoreWorker().CreateActor(
func, args, actor_options, /*extension_data*/ "", &actor_id));
return actor_id;
}
@ -90,7 +89,8 @@ std::string MetadataToString(std::shared_ptr<RayObject> obj) {
class CoreWorkerTest : public ::testing::Test {
public:
CoreWorkerTest(int num_nodes) : gcs_options_("127.0.0.1", 6379, "") {
CoreWorkerTest(int num_nodes)
: num_nodes_(num_nodes), gcs_options_("127.0.0.1", 6379, "") {
#ifdef _WIN32
RAY_CHECK(false) << "port system() calls to Windows before running this test";
#endif
@ -250,9 +250,39 @@ class CoreWorkerTest : public ::testing::Test {
ASSERT_TRUE(system(("rm -f " + gcs_server_pid).c_str()) == 0);
}
void SetUp() {}
// Per-test setup. When the fixture was created with at least one node, builds
// driver-type `CoreWorkerOptions` pointing at the first raylet / object store
// sockets and passes them to `CoreWorkerProcess::Initialize`. Fixtures with
// zero nodes skip initialization entirely.
void SetUp() {
if (num_nodes_ > 0) {
// NOTE: this is a positional initializer; the trailing comments name the
// field each value is meant to populate, so the order must be kept.
CoreWorkerOptions options = {
WorkerType::DRIVER, // worker_type
Language::PYTHON, // language
raylet_store_socket_names_[0], // store_socket
raylet_socket_names_[0], // raylet_socket
NextJobId(), // job_id
gcs_options_, // gcs_options
"", // log_dir
true, // install_failure_signal_handler
"127.0.0.1", // node_ip_address
node_manager_port, // node_manager_port
"core_worker_test", // driver_name
"", // stdout_file
"", // stderr_file
nullptr, // task_execution_callback
nullptr, // check_signals
nullptr, // gc_collect
nullptr, // get_lang_stack
true, // ref_counting_enabled
false, // is_local_mode
1, // num_workers
};
CoreWorkerProcess::Initialize(options);
}
}
void TearDown() {}
// Per-test teardown. Calls `CoreWorkerProcess::Shutdown` under the same
// `num_nodes_ > 0` condition that guards initialization in `SetUp`.
void TearDown() {
if (num_nodes_ > 0) {
CoreWorkerProcess::Shutdown();
}
}
// Test normal tasks.
void TestNormalTask(std::unordered_map<std::string, double> &resources);
@ -271,13 +301,14 @@ class CoreWorkerTest : public ::testing::Test {
void TestActorReconstruction(std::unordered_map<std::string, double> &resources);
protected:
bool WaitForDirectCallActorState(CoreWorker &worker, const ActorID &actor_id,
bool wait_alive, int timeout_ms);
bool WaitForDirectCallActorState(const ActorID &actor_id, bool wait_alive,
int timeout_ms);
// Get the pid for the worker process that runs the actor.
int GetActorPid(CoreWorker &worker, const ActorID &actor_id,
int GetActorPid(const ActorID &actor_id,
std::unordered_map<std::string, double> &resources);
int num_nodes_;
std::vector<std::string> raylet_socket_names_;
std::vector<std::string> raylet_store_socket_names_;
std::string raylet_monitor_pid_;
@ -285,18 +316,19 @@ class CoreWorkerTest : public ::testing::Test {
std::string gcs_server_pid_;
};
bool CoreWorkerTest::WaitForDirectCallActorState(CoreWorker &worker,
const ActorID &actor_id, bool wait_alive,
bool CoreWorkerTest::WaitForDirectCallActorState(const ActorID &actor_id, bool wait_alive,
int timeout_ms) {
auto condition_func = [&worker, actor_id, wait_alive]() -> bool {
bool actor_alive = worker.direct_actor_submitter_->IsActorAlive(actor_id);
auto condition_func = [actor_id, wait_alive]() -> bool {
bool actor_alive =
CoreWorkerProcess::GetCoreWorker().direct_actor_submitter_->IsActorAlive(
actor_id);
return wait_alive ? actor_alive : !actor_alive;
};
return WaitForCondition(condition_func, timeout_ms);
}
int CoreWorkerTest::GetActorPid(CoreWorker &worker, const ActorID &actor_id,
int CoreWorkerTest::GetActorPid(const ActorID &actor_id,
std::unordered_map<std::string, double> &resources) {
std::vector<TaskArg> args;
TaskOptions options{1, resources};
@ -304,10 +336,11 @@ int CoreWorkerTest::GetActorPid(CoreWorker &worker, const ActorID &actor_id,
RayFunction func{Language::PYTHON, ray::FunctionDescriptorBuilder::BuildPython(
"GetWorkerPid", "", "", "")};
RAY_CHECK_OK(worker.SubmitActorTask(actor_id, func, args, options, &return_ids));
RAY_CHECK_OK(CoreWorkerProcess::GetCoreWorker().SubmitActorTask(actor_id, func, args,
options, &return_ids));
std::vector<std::shared_ptr<ray::RayObject>> results;
RAY_CHECK_OK(worker.Get(return_ids, -1, &results));
RAY_CHECK_OK(CoreWorkerProcess::GetCoreWorker().Get(return_ids, -1, &results));
if (nullptr == results[0]->GetData()) {
// If failed to get actor process pid, return -1
@ -320,9 +353,7 @@ int CoreWorkerTest::GetActorPid(CoreWorker &worker, const ActorID &actor_id,
}
void CoreWorkerTest::TestNormalTask(std::unordered_map<std::string, double> &resources) {
CoreWorker driver(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0],
raylet_socket_names_[0], NextJobId(), gcs_options_, "", "127.0.0.1",
node_manager_port, nullptr);
auto &driver = CoreWorkerProcess::GetCoreWorker();
// Test for tasks with by-value and by-ref args.
{
@ -364,11 +395,9 @@ void CoreWorkerTest::TestNormalTask(std::unordered_map<std::string, double> &res
}
void CoreWorkerTest::TestActorTask(std::unordered_map<std::string, double> &resources) {
CoreWorker driver(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0],
raylet_socket_names_[0], NextJobId(), gcs_options_, "", "127.0.0.1",
node_manager_port, nullptr);
auto &driver = CoreWorkerProcess::GetCoreWorker();
auto actor_id = CreateActorHelper(driver, resources, 1000);
auto actor_id = CreateActorHelper(resources, 1000);
// Test submitting some tasks with by-value args for that actor.
{
@ -452,18 +481,16 @@ void CoreWorkerTest::TestActorTask(std::unordered_map<std::string, double> &reso
void CoreWorkerTest::TestActorReconstruction(
std::unordered_map<std::string, double> &resources) {
CoreWorker driver(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0],
raylet_socket_names_[0], NextJobId(), gcs_options_, "", "127.0.0.1",
node_manager_port, nullptr);
auto &driver = CoreWorkerProcess::GetCoreWorker();
// creating actor.
auto actor_id = CreateActorHelper(driver, resources, 1000);
auto actor_id = CreateActorHelper(resources, 1000);
// Wait for actor alive event.
ASSERT_TRUE(WaitForDirectCallActorState(driver, actor_id, true, 30 * 1000 /* 30s */));
ASSERT_TRUE(WaitForDirectCallActorState(actor_id, true, 30 * 1000 /* 30s */));
RAY_LOG(INFO) << "actor has been created";
auto pid = GetActorPid(driver, actor_id, resources);
auto pid = GetActorPid(actor_id, resources);
RAY_CHECK(pid != -1);
// Test submitting some tasks with by-value args for that actor.
@ -477,9 +504,8 @@ void CoreWorkerTest::TestActorReconstruction(
ASSERT_EQ(system("pkill mock_worker"), 0);
// Wait for actor restruction event, and then for alive event.
auto check_actor_restart_func = [this, pid, &driver, &actor_id,
&resources]() -> bool {
auto new_pid = GetActorPid(driver, actor_id, resources);
auto check_actor_restart_func = [this, pid, &actor_id, &resources]() -> bool {
auto new_pid = GetActorPid(actor_id, resources);
return new_pid != -1 && new_pid != pid;
};
ASSERT_TRUE(WaitForCondition(check_actor_restart_func, 30 * 1000 /* 30s */));
@ -514,12 +540,10 @@ void CoreWorkerTest::TestActorReconstruction(
void CoreWorkerTest::TestActorFailure(
std::unordered_map<std::string, double> &resources) {
CoreWorker driver(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0],
raylet_socket_names_[0], NextJobId(), gcs_options_, "", "127.0.0.1",
node_manager_port, nullptr);
auto &driver = CoreWorkerProcess::GetCoreWorker();
// creating actor.
auto actor_id = CreateActorHelper(driver, resources, 0 /* not reconstructable */);
auto actor_id = CreateActorHelper(resources, 0 /* not reconstructable */);
// Test submitting some tasks with by-value args for that actor.
{
@ -666,16 +690,14 @@ TEST_F(ZeroNodeTest, TestTaskSpecPerf) {
}
TEST_F(SingleNodeTest, TestDirectActorTaskSubmissionPerf) {
CoreWorker driver(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0],
raylet_socket_names_[0], JobID::FromInt(1), gcs_options_, "",
"127.0.0.1", node_manager_port, nullptr);
auto &driver = CoreWorkerProcess::GetCoreWorker();
std::vector<ObjectID> object_ids;
// Create an actor.
std::unordered_map<std::string, double> resources;
auto actor_id = CreateActorHelper(driver, resources,
auto actor_id = CreateActorHelper(resources,
/*max_reconstructions=*/0);
// wait for actor creation finish.
ASSERT_TRUE(WaitForDirectCallActorState(driver, actor_id, true, 30 * 1000 /* 30s */));
ASSERT_TRUE(WaitForDirectCallActorState(actor_id, true, 30 * 1000 /* 30s */));
// Test submitting some tasks with by-value args for that actor.
int64_t start_ms = current_time_ms();
const int num_tasks = 100000;
@ -713,7 +735,7 @@ TEST_F(SingleNodeTest, TestDirectActorTaskSubmissionPerf) {
TEST_F(ZeroNodeTest, TestWorkerContext) {
auto job_id = NextJobId();
WorkerContext context(WorkerType::WORKER, job_id);
WorkerContext context(WorkerType::WORKER, WorkerID::FromRandom(), job_id);
ASSERT_TRUE(context.GetCurrentTaskID().IsNil());
ASSERT_EQ(context.GetNextTaskIndex(), 1);
ASSERT_EQ(context.GetNextTaskIndex(), 2);
@ -779,7 +801,7 @@ TEST_F(SingleNodeTest, TestMemoryStoreProvider) {
absl::flat_hash_set<ObjectID> wait_results;
ObjectID nonexistent_id = ObjectID::FromRandom().WithDirectTransportType();
WorkerContext ctx(WorkerType::WORKER, JobID::Nil());
WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::Nil());
wait_ids.insert(nonexistent_id);
RAY_CHECK_OK(provider.Wait(wait_ids, ids.size() + 1, 100, ctx, &wait_results));
ASSERT_EQ(wait_results.size(), ids.size());
@ -880,10 +902,7 @@ TEST_F(SingleNodeTest, TestMemoryStoreProvider) {
}
TEST_F(SingleNodeTest, TestObjectInterface) {
CoreWorker core_worker(WorkerType::DRIVER, Language::PYTHON,
raylet_store_socket_names_[0], raylet_socket_names_[0],
JobID::FromInt(1), gcs_options_, "", "127.0.0.1",
node_manager_port, nullptr);
auto &core_worker = CoreWorkerProcess::GetCoreWorker();
uint8_t array1[] = {1, 2, 3, 4, 5, 6, 7, 8};
uint8_t array2[] = {10, 11, 12, 13, 14, 15};

View file

@ -233,7 +233,7 @@ rpc::PushTaskRequest CreatePushTaskRequestHelper(ActorID actor_id, int64_t count
class MockWorkerContext : public WorkerContext {
public:
MockWorkerContext(WorkerType worker_type, const JobID &job_id)
: WorkerContext(worker_type, job_id) {
: WorkerContext(worker_type, WorkerID::FromRandom(), job_id) {
current_actor_is_direct_call_ = true;
}
};

View file

@ -33,13 +33,34 @@ namespace ray {
class MockWorker {
public:
MockWorker(const std::string &store_socket, const std::string &raylet_socket,
int node_manager_port, const gcs::GcsClientOptions &gcs_options)
: worker_(WorkerType::WORKER, Language::PYTHON, store_socket, raylet_socket,
JobID::FromInt(1), gcs_options, /*log_dir=*/"",
/*node_id_address=*/"127.0.0.1", node_manager_port,
std::bind(&MockWorker::ExecuteTask, this, _1, _2, _3, _4, _5, _6, _7)) {}
int node_manager_port, const gcs::GcsClientOptions &gcs_options) {
// Options for the mock worker. This is a positional initializer; the trailing
// comments name the field each value is meant to populate, so the order must
// be kept.
CoreWorkerOptions options = {
WorkerType::WORKER, // worker_type
Language::PYTHON, // language
store_socket, // store_socket
raylet_socket, // raylet_socket
JobID::FromInt(1), // job_id
gcs_options, // gcs_options
"", // log_dir
true, // install_failure_signal_handler
"127.0.0.1", // node_ip_address
node_manager_port, // node_manager_port
"", // driver_name
"", // stdout_file
"", // stderr_file
// Tasks are executed by this object's `ExecuteTask` member function.
std::bind(&MockWorker::ExecuteTask, this, _1, _2, _3, _4, _5, _6,
_7), // task_execution_callback
nullptr, // check_signals
nullptr, // gc_collect
nullptr, // get_lang_stack
true, // ref_counting_enabled
false, // is_local_mode
1, // num_workers
};
CoreWorkerProcess::Initialize(options);
}
void StartExecutingTasks() { worker_.StartExecutingTasks(); }
void RunTaskExecutionLoop() { CoreWorkerProcess::RunTaskExecutionLoop(); }
private:
Status ExecuteTask(TaskType task_type, const RayFunction &ray_function,
@ -112,7 +133,6 @@ class MockWorker {
return Status::OK();
}
CoreWorker worker_;
int64_t prev_seq_no_ = 0;
};
@ -126,6 +146,6 @@ int main(int argc, char **argv) {
ray::gcs::GcsClientOptions gcs_options("127.0.0.1", 6379, "");
ray::MockWorker worker(store_socket, raylet_socket, node_manager_port, gcs_options);
worker.StartExecutingTasks();
worker.RunTaskExecutionLoop();
return 0;
}

View file

@ -37,7 +37,7 @@ class MockWaiter : public DependencyWaiter {
TEST(SchedulingQueueTest, TestInOrder) {
boost::asio::io_service io_service;
MockWaiter waiter;
WorkerContext context(WorkerType::WORKER, JobID::Nil());
WorkerContext context(WorkerType::WORKER, WorkerID::FromRandom(), JobID::Nil());
SchedulingQueue queue(io_service, waiter, context);
int n_ok = 0;
int n_rej = 0;
@ -58,7 +58,7 @@ TEST(SchedulingQueueTest, TestWaitForObjects) {
ObjectID obj3 = ObjectID::FromRandom();
boost::asio::io_service io_service;
MockWaiter waiter;
WorkerContext context(WorkerType::WORKER, JobID::Nil());
WorkerContext context(WorkerType::WORKER, WorkerID::FromRandom(), JobID::Nil());
SchedulingQueue queue(io_service, waiter, context);
int n_ok = 0;
int n_rej = 0;
@ -84,7 +84,7 @@ TEST(SchedulingQueueTest, TestWaitForObjectsNotSubjectToSeqTimeout) {
ObjectID obj1 = ObjectID::FromRandom();
boost::asio::io_service io_service;
MockWaiter waiter;
WorkerContext context(WorkerType::WORKER, JobID::Nil());
WorkerContext context(WorkerType::WORKER, WorkerID::FromRandom(), JobID::Nil());
SchedulingQueue queue(io_service, waiter, context);
int n_ok = 0;
int n_rej = 0;
@ -102,7 +102,7 @@ TEST(SchedulingQueueTest, TestWaitForObjectsNotSubjectToSeqTimeout) {
TEST(SchedulingQueueTest, TestOutOfOrder) {
boost::asio::io_service io_service;
MockWaiter waiter;
WorkerContext context(WorkerType::WORKER, JobID::Nil());
WorkerContext context(WorkerType::WORKER, WorkerID::FromRandom(), JobID::Nil());
SchedulingQueue queue(io_service, waiter, context);
int n_ok = 0;
int n_rej = 0;
@ -120,7 +120,7 @@ TEST(SchedulingQueueTest, TestOutOfOrder) {
TEST(SchedulingQueueTest, TestSeqWaitTimeout) {
boost::asio::io_service io_service;
MockWaiter waiter;
WorkerContext context(WorkerType::WORKER, JobID::Nil());
WorkerContext context(WorkerType::WORKER, WorkerID::FromRandom(), JobID::Nil());
SchedulingQueue queue(io_service, waiter, context);
int n_ok = 0;
int n_rej = 0;
@ -143,7 +143,7 @@ TEST(SchedulingQueueTest, TestSeqWaitTimeout) {
TEST(SchedulingQueueTest, TestSkipAlreadyProcessedByClient) {
boost::asio::io_service io_service;
MockWaiter waiter;
WorkerContext context(WorkerType::WORKER, JobID::Nil());
WorkerContext context(WorkerType::WORKER, WorkerID::FromRandom(), JobID::Nil());
SchedulingQueue queue(io_service, waiter, context);
int n_ok = 0;
int n_rej = 0;

View file

@ -79,7 +79,7 @@ TEST_F(TaskManagerTest, TestTaskSuccess) {
ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId()));
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 3);
auto return_id = spec.ReturnId(0, TaskTransportType::DIRECT);
WorkerContext ctx(WorkerType::WORKER, JobID::FromInt(0));
WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0));
rpc::PushTaskReply reply;
auto return_object = reply.add_return_objects();
@ -119,7 +119,7 @@ TEST_F(TaskManagerTest, TestTaskFailure) {
ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId()));
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 3);
auto return_id = spec.ReturnId(0, TaskTransportType::DIRECT);
WorkerContext ctx(WorkerType::WORKER, JobID::FromInt(0));
WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0));
auto error = rpc::ErrorType::WORKER_DIED;
manager_.PendingTaskFailed(spec.TaskId(), error);
@ -155,7 +155,7 @@ TEST_F(TaskManagerTest, TestTaskRetry) {
ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId()));
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 3);
auto return_id = spec.ReturnId(0, TaskTransportType::DIRECT);
WorkerContext ctx(WorkerType::WORKER, JobID::FromInt(0));
WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0));
auto error = rpc::ErrorType::WORKER_DIED;
for (int i = 0; i < num_retries; i++) {

View file

@ -568,6 +568,17 @@ class WorkerInfoAccessor {
const std::shared_ptr<rpc::WorkerFailureData> &data_ptr,
const StatusCallback &callback) = 0;
/// Register a worker to GCS asynchronously.
///
/// \param worker_type The type of the worker.
/// \param worker_id The ID of the worker.
/// \param worker_info The information of the worker, as string key/value pairs.
/// \param callback Callback that receives the status of the registration.
/// Implementations check it before invoking it, so an empty callback is allowed.
/// \return Status.
virtual Status AsyncRegisterWorker(
rpc::WorkerType worker_type, const WorkerID &worker_id,
const std::unordered_map<std::string, std::string> &worker_info,
const StatusCallback &callback) = 0;
protected:
WorkerInfoAccessor() = default;
};

View file

@ -45,6 +45,8 @@ class GcsClientOptions {
password_(password),
is_test_client_(is_test_client) {}
GcsClientOptions() {}
// GCS server address
std::string server_ip_;
int server_port_;

View file

@ -875,5 +875,25 @@ Status ServiceBasedWorkerInfoAccessor::AsyncReportWorkerFailure(
return Status::OK();
}
/// Register a worker with the GCS service through the `RegisterWorker` RPC.
///
/// \param worker_type The type of the worker.
/// \param worker_id The ID of the worker.
/// \param worker_info The information of the worker, as string key/value pairs.
/// \param callback If non-empty, invoked with the status delivered by the RPC.
/// \return Always `Status::OK()`; the RPC outcome is reported through `callback`.
Status ServiceBasedWorkerInfoAccessor::AsyncRegisterWorker(
    rpc::WorkerType worker_type, const WorkerID &worker_id,
    const std::unordered_map<std::string, std::string> &worker_info,
    const StatusCallback &callback) {
  RAY_LOG(DEBUG) << "Registering the worker. worker id = " << worker_id;

  // Translate the arguments into the RPC request message.
  rpc::RegisterWorkerRequest register_request;
  register_request.set_worker_type(worker_type);
  register_request.set_worker_id(worker_id.Binary());
  auto *info_map = register_request.mutable_worker_info();
  info_map->insert(worker_info.begin(), worker_info.end());

  // Reply handler: forward the status to the caller (if it asked for it) and
  // then log completion.
  auto on_reply = [worker_id, callback](const Status &status,
                                        const rpc::RegisterWorkerReply &reply) {
    if (callback) {
      callback(status);
    }
    RAY_LOG(DEBUG) << "Finished registering worker. worker id = " << worker_id;
  };
  client_impl_->GetGcsRpcClient().RegisterWorker(register_request, on_reply);
  return Status::OK();
}
} // namespace gcs
} // namespace ray

View file

@ -329,6 +329,11 @@ class ServiceBasedWorkerInfoAccessor : public WorkerInfoAccessor {
Status AsyncReportWorkerFailure(const std::shared_ptr<rpc::WorkerFailureData> &data_ptr,
const StatusCallback &callback) override;
Status AsyncRegisterWorker(
rpc::WorkerType worker_type, const WorkerID &worker_id,
const std::unordered_map<std::string, std::string> &worker_info,
const StatusCallback &callback) override;
private:
ServiceBasedGcsClient *client_impl_;

View file

@ -40,5 +40,27 @@ void DefaultWorkerInfoHandler::HandleReportWorkerFailure(
RAY_LOG(DEBUG) << "Finished reporting worker failure, " << worker_address.DebugString();
}
/// Handle a `RegisterWorker` request by storing the worker's information
/// through the GCS client's worker accessor.
///
/// \param request The request carrying the worker type, ID and info map.
/// \param reply The reply whose status is filled in before it is sent.
/// \param send_reply_callback Callback used to send `reply` to the client.
void DefaultWorkerInfoHandler::HandleRegisterWorker(
    const RegisterWorkerRequest &request, RegisterWorkerReply *reply,
    SendReplyCallback send_reply_callback) {
  auto worker_type = request.worker_type();
  auto worker_id = WorkerID::FromBinary(request.worker_id());
  auto worker_info = MapFromProtobuf(request.worker_info());

  auto on_done = [worker_id, reply, send_reply_callback](Status status) {
    if (!status.ok()) {
      RAY_LOG(ERROR) << "Failed to register worker " << worker_id;
    }
    // Reply with the real outcome so that the client can observe a failed
    // registration instead of always being told it succeeded.
    GCS_RPC_SEND_REPLY(send_reply_callback, reply, status);
  };

  Status status = gcs_client_.Workers().AsyncRegisterWorker(worker_type, worker_id,
                                                            worker_info, on_done);
  // A non-OK return is treated as "the accessor will not call `on_done`", so
  // the reply is sent from here. NOTE(review): this relies on the accessor not
  // invoking the callback itself when it returns an error; otherwise the reply
  // would be sent twice.
  if (!status.ok()) {
    on_done(status);
  }
  RAY_LOG(DEBUG) << "Finished registering worker " << worker_id;
}
} // namespace rpc
} // namespace ray

View file

@ -31,6 +31,10 @@ class DefaultWorkerInfoHandler : public rpc::WorkerInfoHandler {
ReportWorkerFailureReply *reply,
SendReplyCallback send_reply_callback) override;
void HandleRegisterWorker(const RegisterWorkerRequest &request,
RegisterWorkerReply *reply,
SendReplyCallback send_reply_callback) override;
private:
gcs::RedisGcsClient &gcs_client_;
};

View file

@ -747,6 +747,30 @@ Status RedisWorkerInfoAccessor::AsyncReportWorkerFailure(
return worker_failure_table.Add(JobID::Nil(), worker_id, data_ptr, on_done);
}
/// Register a worker by writing its information into a Redis hash with `HMSET`.
///
/// The hash key is "Drivers:<binary worker id>" for drivers and
/// "Workers:<binary worker id>" for every other worker type; each entry of
/// `worker_info` becomes one field/value pair of the hash.
///
/// \param worker_type The type of the worker.
/// \param worker_id The ID of the worker.
/// \param worker_info The information of the worker, as string key/value pairs.
/// \param callback If non-empty, invoked once the command was submitted
/// successfully. It is NOT invoked when submission fails; in that case the
/// error is reported only through the return value, so a caller that handles a
/// non-OK return itself does not see the failure twice.
/// \return The status of submitting the command.
Status RedisWorkerInfoAccessor::AsyncRegisterWorker(
    rpc::WorkerType worker_type, const WorkerID &worker_id,
    const std::unordered_map<std::string, std::string> &worker_info,
    const StatusCallback &callback) {
  std::vector<std::string> args;
  // Command name + key + one field and one value per info entry.
  args.reserve(2 + 2 * worker_info.size());
  args.emplace_back("HMSET");
  if (worker_type == rpc::WorkerType::DRIVER) {
    args.emplace_back("Drivers:" + worker_id.Binary());
  } else {
    args.emplace_back("Workers:" + worker_id.Binary());
  }
  for (const auto &entry : worker_info) {
    args.push_back(entry.first);
    args.push_back(entry.second);
  }

  auto status = client_impl_->primary_context()->RunArgvAsync(args);
  if (status.ok() && callback) {
    // TODO (kfstorm): Invoke the callback asynchronously.
    callback(status);
  }
  return status;
}
} // namespace gcs
} // namespace ray

View file

@ -397,6 +397,11 @@ class RedisWorkerInfoAccessor : public WorkerInfoAccessor {
Status AsyncReportWorkerFailure(const std::shared_ptr<WorkerFailureData> &data_ptr,
const StatusCallback &callback) override;
Status AsyncRegisterWorker(
rpc::WorkerType worker_type, const WorkerID &worker_id,
const std::unordered_map<std::string, std::string> &worker_info,
const StatusCallback &callback) override;
private:
RedisGcsClient *client_impl_{nullptr};

View file

@ -16,6 +16,7 @@ syntax = "proto3";
package ray.rpc;
import "src/ray/protobuf/common.proto";
import "src/ray/protobuf/gcs.proto";
message GcsStatus {
@ -345,8 +346,23 @@ message ReportWorkerFailureReply {
GcsStatus status = 1;
}
// Request of the `RegisterWorker` RPC.
message RegisterWorkerRequest {
/// The type of the worker.
WorkerType worker_type = 1;
/// The ID of the worker.
bytes worker_id = 2;
/// The information of the worker in a dictionary.
map<string, bytes> worker_info = 3;
}
// Reply of the `RegisterWorker` RPC.
message RegisterWorkerReply {
/// The status of the registration.
GcsStatus status = 1;
}
// Service for worker info access.
service WorkerInfoGcsService {
// Report a worker failure to GCS Service.
rpc ReportWorkerFailure(ReportWorkerFailureRequest) returns (ReportWorkerFailureReply);
// Register a worker to GCS Service.
rpc RegisterWorker(RegisterWorkerRequest) returns (RegisterWorkerReply);
}

View file

@ -182,6 +182,10 @@ class GcsRpcClient {
VOID_GCS_RPC_CLIENT_METHOD(WorkerInfoGcsService, ReportWorkerFailure,
worker_info_grpc_client_, )
/// Register a worker to GCS Service.
VOID_GCS_RPC_CLIENT_METHOD(WorkerInfoGcsService, RegisterWorker,
worker_info_grpc_client_, )
private:
void Init(const std::string &address, const int port,
ClientCallManager &client_call_manager) {

View file

@ -390,6 +390,10 @@ class WorkerInfoGcsServiceHandler {
virtual void HandleReportWorkerFailure(const ReportWorkerFailureRequest &request,
ReportWorkerFailureReply *reply,
SendReplyCallback send_reply_callback) = 0;
virtual void HandleRegisterWorker(const RegisterWorkerRequest &request,
RegisterWorkerReply *reply,
SendReplyCallback send_reply_callback) = 0;
};
/// The `GrpcService` for `WorkerInfoGcsService`.
@ -409,6 +413,7 @@ class WorkerInfoGrpcService : public GrpcService {
const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
std::vector<std::unique_ptr<ServerCallFactory>> *server_call_factories) override {
WORKER_INFO_SERVICE_RPC_HANDLER(ReportWorkerFailure);
WORKER_INFO_SERVICE_RPC_HANDLER(RegisterWorker);
}
private:

View file

@ -1,6 +1,5 @@
package org.ray.streaming.runtime.transfer;
import com.google.common.base.Preconditions;
import org.ray.runtime.RayNativeRuntime;
import org.ray.runtime.functionmanager.FunctionDescriptor;
import org.ray.runtime.functionmanager.JavaFunctionDescriptor;
@ -24,16 +23,12 @@ public class TransferHandler {
private long writerClientNative;
private long readerClientNative;
public TransferHandler(long coreWorkerNative,
JavaFunctionDescriptor writerAsyncFunc,
public TransferHandler(JavaFunctionDescriptor writerAsyncFunc,
JavaFunctionDescriptor writerSyncFunc,
JavaFunctionDescriptor readerAsyncFunc,
JavaFunctionDescriptor readerSyncFunc) {
Preconditions.checkArgument(coreWorkerNative != 0);
writerClientNative = createWriterClientNative(
coreWorkerNative, writerAsyncFunc, writerSyncFunc);
readerClientNative = createReaderClientNative(
coreWorkerNative, readerAsyncFunc, readerSyncFunc);
writerClientNative = createWriterClientNative(writerAsyncFunc, writerSyncFunc);
readerClientNative = createReaderClientNative(readerAsyncFunc, readerSyncFunc);
}
public void onWriterMessage(byte[] buffer) {
@ -53,12 +48,10 @@ public class TransferHandler {
}
private native long createWriterClientNative(
long coreWorkerNative,
FunctionDescriptor asyncFunc,
FunctionDescriptor syncFunc);
private native long createReaderClientNative(
long coreWorkerNative,
FunctionDescriptor asyncFunc,
FunctionDescriptor syncFunc);

View file

@ -2,9 +2,6 @@ package org.ray.streaming.runtime.worker;
import java.io.Serializable;
import java.util.Map;
import org.ray.api.Ray;
import org.ray.runtime.RayMultiWorkerNativeRuntime;
import org.ray.runtime.functionmanager.JavaFunctionDescriptor;
import org.ray.streaming.runtime.core.graph.ExecutionGraph;
import org.ray.streaming.runtime.core.graph.ExecutionNode;
@ -62,7 +59,6 @@ public class JobWorker implements Serializable {
Config.CHANNEL_TYPE, Config.DEFAULT_CHANNEL_TYPE);
if (channelType.equals(Config.NATIVE_CHANNEL)) {
transferHandler = new TransferHandler(
getNativeCoreWorker(),
new JavaFunctionDescriptor(JobWorker.class.getName(), "onWriterMessage", "([B)V"),
new JavaFunctionDescriptor(JobWorker.class.getName(), "onWriterMessageSync", "([B)[B"),
new JavaFunctionDescriptor(JobWorker.class.getName(), "onReaderMessage", "([B)V"),
@ -148,13 +144,4 @@ public class JobWorker implements Serializable {
public byte[] onWriterMessageSync(byte[] buffer) {
return transferHandler.onWriterMessageSync(buffer);
}
/**
 * Looks up the native core worker pointer of the current runtime.
 *
 * @return the pointer reported by the current runtime when {@code Ray.internal()} is a
 *     {@code RayMultiWorkerNativeRuntime}; {@code 0} otherwise.
 */
private static long getNativeCoreWorker() {
  // Any other runtime type has no pointer to report.
  if (!(Ray.internal() instanceof RayMultiWorkerNativeRuntime)) {
    return 0;
  }
  RayMultiWorkerNativeRuntime multiWorkerRuntime =
      (RayMultiWorkerNativeRuntime) Ray.internal();
  return multiWorkerRuntime.getCurrentRuntime().getNativeCoreWorkerPointer();
}
}

View file

@ -10,7 +10,6 @@ import java.util.Random;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.id.ActorId;
import org.ray.runtime.RayMultiWorkerNativeRuntime;
import org.ray.runtime.actor.NativeRayActor;
import org.ray.runtime.functionmanager.JavaFunctionDescriptor;
import org.ray.streaming.runtime.transfer.ChannelID;
@ -29,8 +28,7 @@ public class Worker {
protected TransferHandler transferHandler = null;
public Worker() {
transferHandler = new TransferHandler(((RayMultiWorkerNativeRuntime) Ray.internal())
.getCurrentRuntime().getNativeCoreWorkerPointer(),
transferHandler = new TransferHandler(
new JavaFunctionDescriptor(Worker.class.getName(),
"onWriterMessage", "([B)V"),
new JavaFunctionDescriptor(Worker.class.getName(),

View file

@ -31,7 +31,6 @@ from ray.includes.unique_ids cimport (
CTaskID,
CObjectID,
)
from ray.includes.libcoreworker cimport CCoreWorker
cdef extern from "status.h" namespace "ray::streaming" nogil:
cdef cppclass CStreamingStatus "ray::streaming::StreamingStatus":
@ -100,15 +99,13 @@ cdef extern from "message/message_bundle.h" namespace "ray::streaming" nogil:
cdef extern from "queue/queue_client.h" namespace "ray::streaming" nogil:
cdef cppclass CReaderClient "ray::streaming::ReaderClient":
CReaderClient(CCoreWorker *core_worker,
CRayFunction &async_func,
CReaderClient(CRayFunction &async_func,
CRayFunction &sync_func)
void OnReaderMessage(shared_ptr[CLocalMemoryBuffer] buffer);
shared_ptr[CLocalMemoryBuffer] OnReaderMessageSync(shared_ptr[CLocalMemoryBuffer] buffer);
cdef cppclass CWriterClient "ray::streaming::WriterClient":
CWriterClient(CCoreWorker *core_worker,
CRayFunction &async_func,
CWriterClient(CRayFunction &async_func,
CRayFunction &sync_func)
void OnWriterMessage(shared_ptr[CLocalMemoryBuffer] buffer);
shared_ptr[CLocalMemoryBuffer] OnWriterMessageSync(shared_ptr[CLocalMemoryBuffer] buffer);

View file

@ -19,14 +19,11 @@ from ray.includes.unique_ids cimport (
)
from ray._raylet cimport (
Buffer,
CoreWorker,
ActorID,
ObjectID,
FunctionDescriptor,
)
from ray.includes.libcoreworker cimport CCoreWorker
cimport ray.streaming.includes.libstreaming as libstreaming
from ray.streaming.includes.libstreaming cimport (
CStreamingStatus,
@ -52,16 +49,14 @@ cdef class ReaderClient:
CReaderClient *client
def __cinit__(self,
CoreWorker worker,
FunctionDescriptor async_func,
FunctionDescriptor sync_func):
cdef:
CCoreWorker *core_worker = worker.core_worker.get()
CRayFunction async_native_func
CRayFunction sync_native_func
async_native_func = CRayFunction(LANGUAGE_PYTHON, async_func.descriptor)
sync_native_func = CRayFunction(LANGUAGE_PYTHON, sync_func.descriptor)
self.client = new CReaderClient(core_worker, async_native_func, sync_native_func)
self.client = new CReaderClient(async_native_func, sync_native_func)
def __dealloc__(self):
del self.client
@ -91,16 +86,14 @@ cdef class WriterClient:
CWriterClient * client
def __cinit__(self,
CoreWorker worker,
FunctionDescriptor async_func,
FunctionDescriptor sync_func):
cdef:
CCoreWorker *core_worker = worker.core_worker.get()
CRayFunction async_native_func
CRayFunction sync_native_func
async_native_func = CRayFunction(LANGUAGE_PYTHON, async_func.descriptor)
sync_native_func = CRayFunction(LANGUAGE_PYTHON, sync_func.descriptor)
self.client = new CWriterClient(core_worker, async_native_func, sync_native_func)
self.client = new CWriterClient(async_native_func, sync_native_func)
def __dealloc__(self):
del self.client

View file

@ -48,7 +48,6 @@ class JobWorker(object):
self.task_id, self.stream_processor))
if self.config.get(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL):
core_worker = ray.worker.global_worker.core_worker
reader_async_func = PythonFunctionDescriptor(
__name__, self.on_reader_message.__name__,
self.__class__.__name__)
@ -56,7 +55,7 @@ class JobWorker(object):
__name__, self.on_reader_message_sync.__name__,
self.__class__.__name__)
self.reader_client = _streaming.ReaderClient(
core_worker, reader_async_func, reader_sync_func)
reader_async_func, reader_sync_func)
writer_async_func = PythonFunctionDescriptor(
__name__, self.on_writer_message.__name__,
self.__class__.__name__)
@ -64,7 +63,7 @@ class JobWorker(object):
__name__, self.on_writer_message_sync.__name__,
self.__class__.__name__)
self.writer_client = _streaming.WriterClient(
core_worker, writer_async_func, writer_sync_func)
writer_async_func, writer_sync_func)
self.task = self.create_stream_task()
self.task.start()

View file

@ -12,21 +12,20 @@ from ray.streaming.config import Config
@ray.remote
class Worker:
def __init__(self):
core_worker = ray.worker.global_worker.core_worker
writer_async_func = PythonFunctionDescriptor(
__name__, self.on_writer_message.__name__, self.__class__.__name__)
writer_sync_func = PythonFunctionDescriptor(
__name__, self.on_writer_message_sync.__name__,
self.__class__.__name__)
self.writer_client = _streaming.WriterClient(
core_worker, writer_async_func, writer_sync_func)
self.writer_client = _streaming.WriterClient(writer_async_func,
writer_sync_func)
reader_async_func = PythonFunctionDescriptor(
__name__, self.on_reader_message.__name__, self.__class__.__name__)
reader_sync_func = PythonFunctionDescriptor(
__name__, self.on_reader_message_sync.__name__,
self.__class__.__name__)
self.reader_client = _streaming.ReaderClient(
core_worker, reader_async_func, reader_sync_func)
self.reader_client = _streaming.ReaderClient(reader_async_func,
reader_sync_func)
self.writer = None
self.output_channel_id = None
self.reader = None

View file

@ -1,6 +1,7 @@
#include <unordered_set>
#include "event_service.h"
namespace ray {
namespace streaming {
@ -105,7 +106,11 @@ Event &EventQueue::Front() {
}
EventService::EventService(uint32_t event_size)
: event_queue_(std::make_shared<EventQueue>(event_size)), stop_flag_(false) {}
: worker_id_(CoreWorkerProcess::IsInitialized()
? CoreWorkerProcess::GetCoreWorker().GetWorkerID()
: WorkerID::Nil()),
event_queue_(std::make_shared<EventQueue>(event_size)),
stop_flag_(false) {}
EventService::~EventService() {
stop_flag_ = true;
// No need to join if loop thread has never been created.
@ -154,6 +159,9 @@ void EventService::Execute(Event &event) {
}
void EventService::LoopThreadHandler() {
if (CoreWorkerProcess::IsInitialized()) {
CoreWorkerProcess::SetCurrentThreadWorkerId(worker_id_);
}
while (true) {
if (stop_flag_) {
break;

View file

@ -7,6 +7,7 @@
#include <unordered_map>
#include "channel.h"
#include "ray/core_worker/core_worker.h"
#include "ring_buffer.h"
#include "util/streaming_util.h"
@ -127,6 +128,7 @@ class EventService {
void LoopThreadHandler();
private:
WorkerID worker_id_;
std::unordered_map<EventType, Handle, EnumTypeHash> event_handle_map_;
std::shared_ptr<EventQueue> event_queue_;
std::shared_ptr<std::thread> loop_thread_;

View file

@ -14,25 +14,19 @@ static std::shared_ptr<ray::LocalMemoryBuffer> JByteArrayToBuffer(JNIEnv *env,
JNIEXPORT jlong JNICALL
Java_org_ray_streaming_runtime_transfer_TransferHandler_createWriterClientNative(
JNIEnv *env, jobject this_obj, jlong core_worker_ptr, jobject async_func,
jobject sync_func) {
JNIEnv *env, jobject this_obj, jobject async_func, jobject sync_func) {
auto ray_async_func = FunctionDescriptorToRayFunction(env, async_func);
auto ray_sync_func = FunctionDescriptorToRayFunction(env, sync_func);
auto *writer_client =
new WriterClient(reinterpret_cast<ray::CoreWorker *>(core_worker_ptr),
ray_async_func, ray_sync_func);
auto *writer_client = new WriterClient(ray_async_func, ray_sync_func);
return reinterpret_cast<jlong>(writer_client);
}
JNIEXPORT jlong JNICALL
Java_org_ray_streaming_runtime_transfer_TransferHandler_createReaderClientNative(
JNIEnv *env, jobject this_obj, jlong core_worker_ptr, jobject async_func,
jobject sync_func) {
JNIEnv *env, jobject this_obj, jobject async_func, jobject sync_func) {
ray::RayFunction ray_async_func = FunctionDescriptorToRayFunction(env, async_func);
ray::RayFunction ray_sync_func = FunctionDescriptorToRayFunction(env, sync_func);
auto *reader_client =
new ReaderClient(reinterpret_cast<ray::CoreWorker *>(core_worker_ptr),
ray_async_func, ray_sync_func);
auto *reader_client = new ReaderClient(ray_async_func, ray_sync_func);
return reinterpret_cast<jlong>(reader_client);
}

View file

@ -12,48 +12,58 @@ extern "C" {
* Method: createWriterClientNative
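* Signature: (Lorg/ray/runtime/functionmanager/FunctionDescriptor;Lorg/ray/runtime/functionmanager/FunctionDescriptor;)J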
*/
JNIEXPORT jlong JNICALL Java_org_ray_streaming_runtime_transfer_TransferHandler_createWriterClientNative
(JNIEnv *, jobject, jlong, jobject, jobject);
JNIEXPORT jlong JNICALL
Java_org_ray_streaming_runtime_transfer_TransferHandler_createWriterClientNative(JNIEnv *,
jobject,
jobject,
jobject);
/*
* Class: org_ray_streaming_runtime_transfer_TransferHandler
* Method: createReaderClientNative
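* Signature: (Lorg/ray/runtime/functionmanager/FunctionDescriptor;Lorg/ray/runtime/functionmanager/FunctionDescriptor;)J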
*/
JNIEXPORT jlong JNICALL Java_org_ray_streaming_runtime_transfer_TransferHandler_createReaderClientNative
(JNIEnv *, jobject, jlong, jobject, jobject);
JNIEXPORT jlong JNICALL
Java_org_ray_streaming_runtime_transfer_TransferHandler_createReaderClientNative(JNIEnv *,
jobject,
jobject,
jobject);
/*
* Class: org_ray_streaming_runtime_transfer_TransferHandler
* Method: handleWriterMessageNative
* Signature: (J[B)V
*/
JNIEXPORT void JNICALL Java_org_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageNative
(JNIEnv *, jobject, jlong, jbyteArray);
JNIEXPORT void JNICALL
Java_org_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageNative(
JNIEnv *, jobject, jlong, jbyteArray);
/*
* Class: org_ray_streaming_runtime_transfer_TransferHandler
* Method: handleWriterMessageSyncNative
* Signature: (J[B)[B
*/
JNIEXPORT jbyteArray JNICALL Java_org_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageSyncNative
(JNIEnv *, jobject, jlong, jbyteArray);
JNIEXPORT jbyteArray JNICALL
Java_org_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageSyncNative(
JNIEnv *, jobject, jlong, jbyteArray);
/*
* Class: org_ray_streaming_runtime_transfer_TransferHandler
* Method: handleReaderMessageNative
* Signature: (J[B)V
*/
JNIEXPORT void JNICALL Java_org_ray_streaming_runtime_transfer_TransferHandler_handleReaderMessageNative
(JNIEnv *, jobject, jlong, jbyteArray);
JNIEXPORT void JNICALL
Java_org_ray_streaming_runtime_transfer_TransferHandler_handleReaderMessageNative(
JNIEnv *, jobject, jlong, jbyteArray);
/*
* Class: org_ray_streaming_runtime_transfer_TransferHandler
* Method: handleReaderMessageSyncNative
* Signature: (J[B)[B
*/
JNIEXPORT jbyteArray JNICALL Java_org_ray_streaming_runtime_transfer_TransferHandler_handleReaderMessageSyncNative
(JNIEnv *, jobject, jlong, jbyteArray);
JNIEXPORT jbyteArray JNICALL
Java_org_ray_streaming_runtime_transfer_TransferHandler_handleReaderMessageSyncNative(
JNIEnv *, jobject, jlong, jbyteArray);
#ifdef __cplusplus
}

View file

@ -14,16 +14,14 @@ namespace streaming {
class ReaderClient {
public:
/// Construct a ReaderClient object.
/// \param[in] core_worker CoreWorker C++ pointer of current actor
/// \param[in] async_func DataReader's raycall function descriptor to be called by
/// DataWriter, asynchronous semantics
/// \param[in] sync_func DataReader's raycall function descriptor to be called by
/// DataWriter, synchronous semantics
ReaderClient(CoreWorker *core_worker, RayFunction &async_func, RayFunction &sync_func)
: core_worker_(core_worker) {
// The two peer functions are assigned through the class name, i.e. they are class-level
// members of DownstreamQueueMessageHandler rather than per-client state, so constructing
// another ReaderClient overwrites them.
ReaderClient(RayFunction &async_func, RayFunction &sync_func) {
DownstreamQueueMessageHandler::peer_async_function_ = async_func;
DownstreamQueueMessageHandler::peer_sync_function_ = sync_func;
// The handler is requested for the actor ID taken from the worker context of the core
// worker returned by CoreWorkerProcess::GetCoreWorker().
downstream_handler_ = ray::streaming::DownstreamQueueMessageHandler::CreateService(
CoreWorkerProcess::GetCoreWorker().GetWorkerContext().GetCurrentActorID());
}
/// Post buffer to downstream queue service, asynchronously.
@ -34,19 +32,17 @@ class ReaderClient {
std::shared_ptr<LocalMemoryBuffer> buffer);
private:
CoreWorker *core_worker_;
std::shared_ptr<DownstreamQueueMessageHandler> downstream_handler_;
};
/// Interface of streaming queue for DataWriter. Similar to ReaderClient.
class WriterClient {
public:
WriterClient(CoreWorker *core_worker, RayFunction &async_func, RayFunction &sync_func)
: core_worker_(core_worker) {
/// Construct a WriterClient object.
/// \param[in] async_func function descriptor stored as
/// UpstreamQueueMessageHandler::peer_async_function_
/// \param[in] sync_func function descriptor stored as
/// UpstreamQueueMessageHandler::peer_sync_function_
// NOTE(review): both peer functions are assigned through the class name, so they look
// like class-level members shared by every WriterClient - confirm.
WriterClient(RayFunction &async_func, RayFunction &sync_func) {
UpstreamQueueMessageHandler::peer_async_function_ = async_func;
UpstreamQueueMessageHandler::peer_sync_function_ = sync_func;
// The handler is requested for the actor ID taken from the worker context of the core
// worker returned by CoreWorkerProcess::GetCoreWorker().
upstream_handler_ = ray::streaming::UpstreamQueueMessageHandler::CreateService(
CoreWorkerProcess::GetCoreWorker().GetWorkerContext().GetCurrentActorID());
}
void OnWriterMessage(std::shared_ptr<LocalMemoryBuffer> buffer);
@ -54,7 +50,6 @@ class WriterClient {
std::shared_ptr<LocalMemoryBuffer> buffer);
private:
CoreWorker *core_worker_;
std::shared_ptr<UpstreamQueueMessageHandler> upstream_handler_;
};
} // namespace streaming

View file

@ -85,8 +85,8 @@ std::shared_ptr<Transport> QueueMessageHandler::GetOutTransport(
void QueueMessageHandler::SetPeerActorID(const ObjectID &queue_id,
const ActorID &actor_id) {
actors_.emplace(queue_id, actor_id);
out_transports_.emplace(
queue_id, std::make_shared<ray::streaming::Transport>(core_worker_, actor_id));
out_transports_.emplace(queue_id,
std::make_shared<ray::streaming::Transport>(actor_id));
}
ActorID QueueMessageHandler::GetPeerActorID(const ObjectID &queue_id) {
@ -113,10 +113,9 @@ void QueueMessageHandler::Stop() {
}
std::shared_ptr<UpstreamQueueMessageHandler> UpstreamQueueMessageHandler::CreateService(
CoreWorker *core_worker, const ActorID &actor_id) {
const ActorID &actor_id) {
if (nullptr == upstream_handler_) {
upstream_handler_ =
std::make_shared<UpstreamQueueMessageHandler>(core_worker, actor_id);
upstream_handler_ = std::make_shared<UpstreamQueueMessageHandler>(actor_id);
}
return upstream_handler_;
}
@ -247,11 +246,9 @@ void UpstreamQueueMessageHandler::ReleaseAllUpQueues() {
}
std::shared_ptr<DownstreamQueueMessageHandler>
DownstreamQueueMessageHandler::CreateService(CoreWorker *core_worker,
const ActorID &actor_id) {
DownstreamQueueMessageHandler::CreateService(const ActorID &actor_id) {
if (nullptr == downstream_handler_) {
downstream_handler_ =
std::make_shared<DownstreamQueueMessageHandler>(core_worker, actor_id);
downstream_handler_ = std::make_shared<DownstreamQueueMessageHandler>(actor_id);
}
return downstream_handler_;
}

View file

@ -24,16 +24,9 @@ namespace streaming {
class QueueMessageHandler {
public:
/// Construct a QueueMessageHandler instance.
/// \param[in] core_worker CoreWorker C++ pointer of current actor, used to call Core
/// Worker's api.
/// For Python worker, the pointer can be obtained from
/// ray.worker.global_worker.core_worker; For Java worker, obtained from
/// RayNativeRuntime object through java reflection.
/// \param[in] actor_id actor id of current actor.
QueueMessageHandler(CoreWorker *core_worker, const ActorID &actor_id)
: core_worker_(core_worker),
actor_id_(actor_id),
queue_dummy_work_(queue_service_) {
QueueMessageHandler(const ActorID &actor_id)
: actor_id_(actor_id), queue_dummy_work_(queue_service_) {
Start();
}
@ -87,8 +80,6 @@ class QueueMessageHandler {
/// Runs `queue_service_`.
// NOTE(review): judging by the name this is the body of the queue thread; confirm
// against Start(), which is not visible here.
void QueueThreadCallback() { queue_service_.run(); }
protected:
/// CoreWorker C++ pointer of current actor
CoreWorker *core_worker_;
/// actor_id actor id of current actor
ActorID actor_id_;
/// Helper function, parse message buffer to Message object.
@ -111,8 +102,7 @@ class QueueMessageHandler {
class UpstreamQueueMessageHandler : public QueueMessageHandler {
public:
/// Construct an UpstreamQueueMessageHandler instance.
UpstreamQueueMessageHandler(CoreWorker *core_worker, const ActorID &actor_id)
: QueueMessageHandler(core_worker, actor_id) {}
UpstreamQueueMessageHandler(const ActorID &actor_id) : QueueMessageHandler(actor_id) {}
/// Create an upstream queue.
/// \param[in] queue_id queue id of the queue to be created.
/// \param[in] peer_actor_id actor id of peer actor.
@ -140,7 +130,7 @@ class UpstreamQueueMessageHandler : public QueueMessageHandler {
std::function<void(std::shared_ptr<LocalMemoryBuffer>)> callback) override;
static std::shared_ptr<UpstreamQueueMessageHandler> CreateService(
CoreWorker *core_worker, const ActorID &actor_id);
const ActorID &actor_id);
static std::shared_ptr<UpstreamQueueMessageHandler> GetService();
static RayFunction peer_sync_function_;
@ -157,8 +147,8 @@ class UpstreamQueueMessageHandler : public QueueMessageHandler {
/// DownstreamQueueMessageHandler holds and manages all downstream queues of current actor.
class DownstreamQueueMessageHandler : public QueueMessageHandler {
public:
DownstreamQueueMessageHandler(CoreWorker *core_worker, const ActorID &actor_id)
: QueueMessageHandler(core_worker, actor_id) {}
DownstreamQueueMessageHandler(const ActorID &actor_id)
: QueueMessageHandler(actor_id) {}
std::shared_ptr<ReaderQueue> CreateDownstreamQueue(const ObjectID &queue_id,
const ActorID &peer_actor_id);
bool DownstreamQueueExists(const ObjectID &queue_id);
@ -178,7 +168,7 @@ class DownstreamQueueMessageHandler : public QueueMessageHandler {
std::function<void(std::shared_ptr<LocalMemoryBuffer>)> callback);
static std::shared_ptr<DownstreamQueueMessageHandler> CreateService(
CoreWorker *core_worker, const ActorID &actor_id);
const ActorID &actor_id);
static std::shared_ptr<DownstreamQueueMessageHandler> GetService();
static RayFunction peer_sync_function_;
static RayFunction peer_async_function_;

View file

@ -28,10 +28,9 @@ void Transport::SendInternal(std::shared_ptr<LocalMemoryBuffer> buffer,
args.emplace_back(TaskArg::PassByValue(std::make_shared<RayObject>(
std::move(buffer), meta, std::vector<ObjectID>(), true)));
STREAMING_CHECK(core_worker_ != nullptr);
std::vector<std::shared_ptr<RayObject>> results;
ray::Status st =
core_worker_->SubmitActorTask(peer_actor_id_, function, args, options, &return_ids);
ray::Status st = CoreWorkerProcess::GetCoreWorker().SubmitActorTask(
peer_actor_id_, function, args, options, &return_ids);
if (!st.ok()) {
STREAMING_LOG(ERROR) << "SubmitActorTask failed. " << st;
}
@ -50,7 +49,8 @@ std::shared_ptr<LocalMemoryBuffer> Transport::SendForResult(
SendInternal(buffer, function, TASK_OPTION_RETURN_NUM_1, return_ids);
std::vector<std::shared_ptr<RayObject>> results;
Status get_st = core_worker_->Get(return_ids, timeout_ms, &results);
Status get_st =
CoreWorkerProcess::GetCoreWorker().Get(return_ids, timeout_ms, &results);
if (!get_st.ok()) {
STREAMING_LOG(ERROR) << "Get fail.";
return nullptr;

View file

@ -13,11 +13,10 @@ namespace streaming {
class Transport {
public:
/// Construct a Transport object.
/// \param[in] core_worker CoreWorker C++ pointer of current actor, which we call direct
/// actor call interface with.
/// \param[in] peer_actor_id actor id of peer actor.
Transport(CoreWorker *core_worker, const ActorID &peer_actor_id)
: core_worker_(core_worker), peer_actor_id_(peer_actor_id) {}
Transport(const ActorID &peer_actor_id)
: worker_id_(CoreWorkerProcess::GetCoreWorker().GetWorkerID()),
peer_actor_id_(peer_actor_id) {}
virtual ~Transport() = default;
/// Send buffer asynchronously, peer's `function` will be called.
@ -55,7 +54,7 @@ class Transport {
std::vector<ObjectID> &return_ids);
private:
CoreWorker *core_worker_;
WorkerID worker_id_;
ActorID peer_actor_id_;
};
} // namespace streaming

View file

@ -20,11 +20,9 @@ namespace streaming {
class StreamingQueueTestSuite {
public:
StreamingQueueTestSuite(std::shared_ptr<CoreWorker> core_worker, ActorID &peer_actor_id,
std::vector<ObjectID> queue_ids,
StreamingQueueTestSuite(ActorID &peer_actor_id, std::vector<ObjectID> queue_ids,
std::vector<ObjectID> rescale_queue_ids)
: core_worker_(core_worker),
peer_actor_id_(peer_actor_id),
: peer_actor_id_(peer_actor_id),
queue_ids_(queue_ids),
rescale_queue_ids_(rescale_queue_ids) {}
@ -52,7 +50,6 @@ class StreamingQueueTestSuite {
std::string current_test_;
bool status_;
std::shared_ptr<std::thread> executor_thread_;
std::shared_ptr<CoreWorker> core_worker_;
ActorID peer_actor_id_;
std::vector<ObjectID> queue_ids_;
std::vector<ObjectID> rescale_queue_ids_;
@ -60,11 +57,9 @@ class StreamingQueueTestSuite {
class StreamingQueueWriterTestSuite : public StreamingQueueTestSuite {
public:
StreamingQueueWriterTestSuite(std::shared_ptr<CoreWorker> core_worker,
ActorID &peer_actor_id, std::vector<ObjectID> queue_ids,
StreamingQueueWriterTestSuite(ActorID &peer_actor_id, std::vector<ObjectID> queue_ids,
std::vector<ObjectID> rescale_queue_ids)
: StreamingQueueTestSuite(core_worker, peer_actor_id, queue_ids,
rescale_queue_ids) {
: StreamingQueueTestSuite(peer_actor_id, queue_ids, rescale_queue_ids) {
test_func_map_ = {
{"streaming_writer_exactly_once_test",
std::bind(&StreamingQueueWriterTestSuite::StreamingWriterExactlyOnceTest,
@ -135,11 +130,9 @@ class StreamingQueueWriterTestSuite : public StreamingQueueTestSuite {
class StreamingQueueReaderTestSuite : public StreamingQueueTestSuite {
public:
StreamingQueueReaderTestSuite(std::shared_ptr<CoreWorker> core_worker,
ActorID peer_actor_id, std::vector<ObjectID> queue_ids,
StreamingQueueReaderTestSuite(ActorID peer_actor_id, std::vector<ObjectID> queue_ids,
std::vector<ObjectID> rescale_queue_ids)
: StreamingQueueTestSuite(core_worker, peer_actor_id, queue_ids,
rescale_queue_ids) {
: StreamingQueueTestSuite(peer_actor_id, queue_ids, rescale_queue_ids) {
test_func_map_ = {
{"streaming_writer_exactly_once_test",
std::bind(&StreamingQueueReaderTestSuite::StreamingWriterExactlyOnceTest,
@ -247,7 +240,7 @@ class StreamingQueueReaderTestSuite : public StreamingQueueTestSuite {
class TestSuiteFactory {
public:
static std::shared_ptr<StreamingQueueTestSuite> CreateTestSuite(
std::shared_ptr<CoreWorker> worker, std::shared_ptr<TestInitMessage> message) {
std::shared_ptr<TestInitMessage> message) {
std::shared_ptr<StreamingQueueTestSuite> test_suite = nullptr;
std::string suite_name = message->TestSuiteName();
queue::protobuf::StreamingQueueTestRole role = message->Role();
@ -258,14 +251,14 @@ class TestSuiteFactory {
if (role == queue::protobuf::StreamingQueueTestRole::WRITER) {
if (suite_name == "StreamingWriterTest") {
test_suite = std::make_shared<StreamingQueueWriterTestSuite>(
worker, peer_actor_id, queue_ids, rescale_queue_ids);
peer_actor_id, queue_ids, rescale_queue_ids);
} else {
STREAMING_CHECK(false) << "unsurported suite_name: " << suite_name;
}
} else {
if (suite_name == "StreamingWriterTest") {
test_suite = std::make_shared<StreamingQueueReaderTestSuite>(
worker, peer_actor_id, queue_ids, rescale_queue_ids);
peer_actor_id, queue_ids, rescale_queue_ids);
} else {
STREAMING_CHECK(false) << "unsupported suite_name: " << suite_name;
}
@ -280,10 +273,30 @@ class StreamingWorker {
StreamingWorker(const std::string &store_socket, const std::string &raylet_socket,
int node_manager_port, const gcs::GcsClientOptions &gcs_options)
: test_suite_(nullptr), peer_actor_handle_(nullptr) {
worker_ = std::make_shared<CoreWorker>(
WorkerType::WORKER, Language::PYTHON, store_socket, raylet_socket,
JobID::FromInt(1), gcs_options, "", "127.0.0.1", node_manager_port,
std::bind(&StreamingWorker::ExecuteTask, this, _1, _2, _3, _4, _5, _6, _7));
// Positional brace-initialization of CoreWorkerOptions; the trailing comment on each
// line names the field that value is meant for.
// NOTE(review): the order must match the struct declaration, which is not visible here.
CoreWorkerOptions options = {
WorkerType::WORKER, // worker_type
Language::PYTHON, // language
store_socket, // store_socket
raylet_socket, // raylet_socket
JobID::FromInt(1), // job_id
gcs_options, // gcs_options
"", // log_dir
true, // install_failure_signal_handler
"127.0.0.1", // node_ip_address
node_manager_port, // node_manager_port
"", // driver_name
"", // stdout_file
"", // stderr_file
std::bind(&StreamingWorker::ExecuteTask, this, _1, _2, _3, _4, _5, _6,
_7), // task_execution_callback
nullptr, // check_signals
nullptr, // gc_collect
nullptr, // get_lang_stack
true, // ref_counting_enabled
false, // is_local_mode
1, // num_workers
};
CoreWorkerProcess::Initialize(options);
RayFunction reader_async_call_func{ray::Language::PYTHON,
ray::FunctionDescriptorBuilder::BuildPython(
@ -298,16 +311,16 @@ class StreamingWorker {
ray::Language::PYTHON,
ray::FunctionDescriptorBuilder::BuildPython("writer_sync_call_func", "", "", "")};
reader_client_ = std::make_shared<ReaderClient>(worker_.get(), reader_async_call_func,
reader_sync_call_func);
writer_client_ = std::make_shared<WriterClient>(worker_.get(), writer_async_call_func,
writer_sync_call_func);
reader_client_ =
std::make_shared<ReaderClient>(reader_async_call_func, reader_sync_call_func);
writer_client_ =
std::make_shared<WriterClient>(writer_async_call_func, writer_sync_call_func);
STREAMING_LOG(INFO) << "StreamingWorker constructor";
}
void StartExecutingTasks() {
void RunTaskExecutionLoop() {
// Start executing tasks.
worker_->StartExecutingTasks();
CoreWorkerProcess::RunTaskExecutionLoop();
}
private:
@ -403,7 +416,8 @@ class StreamingWorker {
STREAMING_LOG(INFO) << "Init message: " << message->ToString();
std::string actor_handle_serialized = message->ActorHandleSerialized();
worker_->DeserializeAndRegisterActorHandle(actor_handle_serialized, ObjectID::Nil());
CoreWorkerProcess::GetCoreWorker().DeserializeAndRegisterActorHandle(
actor_handle_serialized, ObjectID::Nil());
std::shared_ptr<ActorHandle> actor_handle(new ActorHandle(actor_handle_serialized));
STREAMING_CHECK(actor_handle != nullptr);
STREAMING_LOG(INFO) << " actor id from handle: " << actor_handle->GetActorID();
@ -421,12 +435,11 @@ class StreamingWorker {
STREAMING_LOG(INFO) << "rescale queue: " << qid;
}
test_suite_ = TestSuiteFactory::CreateTestSuite(worker_, message);
test_suite_ = TestSuiteFactory::CreateTestSuite(message);
STREAMING_CHECK(test_suite_ != nullptr);
}
private:
std::shared_ptr<CoreWorker> worker_;
std::shared_ptr<ReaderClient> reader_client_;
std::shared_ptr<WriterClient> writer_client_;
std::shared_ptr<std::thread> test_thread_;
@ -446,6 +459,6 @@ int main(int argc, char **argv) {
ray::gcs::GcsClientOptions gcs_options("127.0.0.1", 6379, "");
ray::streaming::StreamingWorker worker(store_socket, raylet_socket, node_manager_port,
gcs_options);
worker.StartExecutingTasks();
worker.RunTaskExecutionLoop();
return 0;
}

View file

@ -153,11 +153,12 @@ class StreamingQueueTestBase : public ::testing::TestWithParam<uint64_t> {
ASSERT_TRUE(system(("rm -rf " + raylet_socket_name + ".pid").c_str()) == 0);
}
void InitWorker(CoreWorker &driver, ActorID &self_actor_id, ActorID &peer_actor_id,
void InitWorker(ActorID &self_actor_id, ActorID &peer_actor_id,
const queue::protobuf::StreamingQueueTestRole role,
const std::vector<ObjectID> &queue_ids,
const std::vector<ObjectID> &rescale_queue_ids, std::string suite_name,
std::string test_name, uint64_t param) {
auto &driver = CoreWorkerProcess::GetCoreWorker();
std::string forked_serialized_str;
ObjectID actor_handle_id;
Status st = driver.SerializeActorHandle(peer_actor_id, &forked_serialized_str,
@ -179,7 +180,8 @@ class StreamingQueueTestBase : public ::testing::TestWithParam<uint64_t> {
RAY_CHECK_OK(driver.SubmitActorTask(self_actor_id, func, args, options, &return_ids));
}
void SubmitTestToActor(CoreWorker &driver, ActorID &actor_id, const std::string test) {
void SubmitTestToActor(ActorID &actor_id, const std::string test) {
auto &driver = CoreWorkerProcess::GetCoreWorker();
uint8_t data[8];
auto buffer = std::make_shared<LocalMemoryBuffer>(data, 8, true);
std::vector<TaskArg> args;
@ -194,7 +196,8 @@ class StreamingQueueTestBase : public ::testing::TestWithParam<uint64_t> {
RAY_CHECK_OK(driver.SubmitActorTask(actor_id, func, args, options, &return_ids));
}
bool CheckCurTest(CoreWorker &driver, ActorID &actor_id, const std::string test_name) {
bool CheckCurTest(ActorID &actor_id, const std::string test_name) {
auto &driver = CoreWorkerProcess::GetCoreWorker();
uint8_t data[8];
auto buffer = std::make_shared<LocalMemoryBuffer>(data, 8, true);
std::vector<TaskArg> args;
@ -255,8 +258,7 @@ class StreamingQueueTestBase : public ::testing::TestWithParam<uint64_t> {
return message->Status();
}
ActorID CreateActorHelper(CoreWorker &worker,
const std::unordered_map<std::string, double> &resources,
ActorID CreateActorHelper(const std::unordered_map<std::string, double> &resources,
bool is_direct_call, uint64_t max_reconstructions) {
std::unique_ptr<ActorHandle> actor_handle;
@ -277,8 +279,8 @@ class StreamingQueueTestBase : public ::testing::TestWithParam<uint64_t> {
// Create an actor.
ActorID actor_id;
RAY_CHECK_OK(
worker.CreateActor(func, args, actor_options, /*extension_data*/ "", &actor_id));
RAY_CHECK_OK(CoreWorkerProcess::GetCoreWorker().CreateActor(
func, args, actor_options, /*extension_data*/ "", &actor_id));
return actor_id;
}
@ -305,33 +307,54 @@ class StreamingQueueTestBase : public ::testing::TestWithParam<uint64_t> {
}
STREAMING_LOG(INFO) << "Sub process: writer.";
CoreWorker driver(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0],
raylet_socket_names_[0], NextJobId(), gcs_options_, "", "127.0.0.1",
node_manager_port_, nullptr);
// Options for the driver-type core worker used by this test process.
// The initializer is positional: the trailing comment on each line names the
// field that value is meant to populate, so keep the order intact when editing.
CoreWorkerOptions options = {
WorkerType::DRIVER, // worker_type
Language::PYTHON, // language
raylet_store_socket_names_[0], // store_socket
raylet_socket_names_[0], // raylet_socket
NextJobId(), // job_id
gcs_options_, // gcs_options
"", // log_dir
true, // install_failure_signal_handler
"127.0.0.1", // node_ip_address
node_manager_port_, // node_manager_port
"queue_tests", // driver_name
"", // stdout_file
"", // stderr_file
nullptr, // task_execution_callback
nullptr, // check_signals
nullptr, // gc_collect
nullptr, // get_lang_stack
true, // ref_counting_enabled
false, // is_local_mode
1, // num_workers
};
// NOTE(review): presumably calls CoreWorkerProcess::Initialize(options) here and
// CoreWorkerProcess::Shutdown when core_worker_raii goes out of scope -- confirm
// against the InitShutdownRAII definition.
InitShutdownRAII core_worker_raii(CoreWorkerProcess::Initialize,
CoreWorkerProcess::Shutdown, options);
// Create writer and reader actors
std::unordered_map<std::string, double> resources;
auto actor_id_writer = CreateActorHelper(driver, resources, true, 0);
auto actor_id_reader = CreateActorHelper(driver, resources, true, 0);
auto actor_id_writer = CreateActorHelper(resources, true, 0);
auto actor_id_reader = CreateActorHelper(resources, true, 0);
InitWorker(driver, actor_id_writer, actor_id_reader,
InitWorker(actor_id_writer, actor_id_reader,
queue::protobuf::StreamingQueueTestRole::WRITER, queue_id_vec,
rescale_queue_id_vec, suite_name, test_name, GetParam());
InitWorker(driver, actor_id_reader, actor_id_writer,
InitWorker(actor_id_reader, actor_id_writer,
queue::protobuf::StreamingQueueTestRole::READER, queue_id_vec,
rescale_queue_id_vec, suite_name, test_name, GetParam());
std::this_thread::sleep_for(std::chrono::milliseconds(2000));
SubmitTestToActor(driver, actor_id_writer, test_name);
SubmitTestToActor(driver, actor_id_reader, test_name);
SubmitTestToActor(actor_id_writer, test_name);
SubmitTestToActor(actor_id_reader, test_name);
uint64_t slept_time_ms = 0;
while (slept_time_ms < timeout_ms) {
std::this_thread::sleep_for(std::chrono::milliseconds(5 * 1000));
STREAMING_LOG(INFO) << "Check test status.";
if (CheckCurTest(driver, actor_id_writer, test_name) &&
CheckCurTest(driver, actor_id_reader, test_name)) {
if (CheckCurTest(actor_id_writer, test_name) &&
CheckCurTest(actor_id_reader, test_name)) {
STREAMING_LOG(INFO) << "Test Success, Exit.";
return;
}