mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
Support multiple core workers in one process (#7623)
This commit is contained in:
parent
e91595f955
commit
48b48cc8c2
90 changed files with 2014 additions and 1411 deletions
|
@ -74,7 +74,10 @@ python_grpc_compile(
|
|||
proto_library(
|
||||
name = "gcs_service_proto",
|
||||
srcs = ["src/ray/protobuf/gcs_service.proto"],
|
||||
deps = [":gcs_proto"],
|
||||
deps = [
|
||||
":common_proto",
|
||||
":gcs_proto",
|
||||
],
|
||||
)
|
||||
|
||||
cc_proto_library(
|
||||
|
|
|
@ -12,8 +12,8 @@ namespace api {
|
|||
|
||||
LocalModeRayRuntime::LocalModeRayRuntime(std::shared_ptr<RayConfig> config) {
|
||||
config_ = config;
|
||||
worker_ =
|
||||
std::unique_ptr<WorkerContext>(new WorkerContext(WorkerType::DRIVER, JobID::Nil()));
|
||||
worker_ = std::unique_ptr<WorkerContext>(new WorkerContext(
|
||||
WorkerType::DRIVER, ComputeDriverIdFromJob(JobID::Nil()), JobID::Nil()));
|
||||
object_store_ = std::unique_ptr<ObjectStore>(new LocalModeObjectStore(*this));
|
||||
task_submitter_ = std::unique_ptr<TaskSubmitter>(new LocalModeTaskSubmitter(*this));
|
||||
}
|
||||
|
|
|
@ -44,7 +44,7 @@ public final class Ray extends RayCall {
|
|||
/**
|
||||
* Shutdown Ray runtime.
|
||||
*/
|
||||
public static void shutdown() {
|
||||
public static synchronized void shutdown() {
|
||||
if (runtime != null) {
|
||||
runtime.shutdown();
|
||||
runtime = null;
|
||||
|
@ -137,6 +137,11 @@ public final class Ray extends RayCall {
|
|||
runtime.setAsyncContext(asyncContext);
|
||||
}
|
||||
|
||||
// TODO (kfstorm): add the `rollbackAsyncContext` API to allow rolling back the async context of
|
||||
// the current thread to the one before `setAsyncContext` is called.
|
||||
|
||||
// TODO (kfstorm): unify the `wrap*` methods.
|
||||
|
||||
/**
|
||||
* If users want to use Ray API in their own threads, they should wrap their {@link Runnable}
|
||||
* objects with this method.
|
||||
|
@ -155,7 +160,7 @@ public final class Ray extends RayCall {
|
|||
* @param callable The callable to wrap.
|
||||
* @return The wrapped callable.
|
||||
*/
|
||||
public static Callable wrapCallable(Callable callable) {
|
||||
public static <T> Callable<T> wrapCallable(Callable<T> callable) {
|
||||
return runtime.wrapCallable(callable);
|
||||
}
|
||||
|
||||
|
|
|
@ -159,6 +159,7 @@ public interface RayRuntime {
|
|||
|
||||
/**
|
||||
* Wrap a {@link Runnable} with necessary context capture.
|
||||
*
|
||||
* @param runnable The runnable to wrap.
|
||||
* @return The wrapped runnable.
|
||||
*/
|
||||
|
@ -166,8 +167,9 @@ public interface RayRuntime {
|
|||
|
||||
/**
|
||||
* Wrap a {@link Callable} with necessary context capture.
|
||||
*
|
||||
* @param callable The callable to wrap.
|
||||
* @return The wrapped callable.
|
||||
*/
|
||||
Callable wrapCallable(Callable callable);
|
||||
<T> Callable<T> wrapCallable(Callable<T> callable);
|
||||
}
|
||||
|
|
45
java/generate_jni_header_files.sh
Executable file
45
java/generate_jni_header_files.sh
Executable file
|
@ -0,0 +1,45 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
set -e
|
||||
set -x
|
||||
|
||||
cd "$(dirname "$0")"
|
||||
|
||||
(cd .. && bazel build //java:all_tests_deploy.jar)
|
||||
|
||||
# Generate the JNI header for one Java class with `javah`, format it with
# clang-format, and install it (prefixed with the license header) into
# ../src/ray/core_worker/lib/java/.
#
# Arguments:
#   $1 - fully-qualified Java class name, e.g. org.ray.runtime.RayNativeRuntime
function generate_one()
{
  # The header is named after the class, with dots replaced by underscores.
  local file="${1//./_}.h"
  local dest="../src/ray/core_worker/lib/java/$file"

  # Quote every expansion so unusual class names or paths cannot be word-split
  # or glob-expanded.
  javah -classpath ../bazel-bin/java/all_tests_deploy.jar "$1"
  clang-format -i "$file"

  # Write the license header first, then append the generated header below it.
  cat <<EOF > "$dest"
// Copyright 2017 The Ray Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

EOF
  cat "$file" >> "$dest"
  # The temporary header in the current directory is no longer needed.
  rm -f "$file"
}
|
||||
|
||||
generate_one org.ray.runtime.RayNativeRuntime
|
||||
generate_one org.ray.runtime.task.NativeTaskSubmitter
|
||||
generate_one org.ray.runtime.context.NativeWorkerContext
|
||||
generate_one org.ray.runtime.actor.NativeRayActor
|
||||
generate_one org.ray.runtime.object.NativeObjectStore
|
||||
generate_one org.ray.runtime.task.NativeTaskExecutor
|
||||
|
||||
# Remove empty files
|
||||
rm -f org_ray_runtime_RayNativeRuntime_AsyncContext.h
|
||||
rm -f org_ray_runtime_task_NativeTaskExecutor_NativeActorContext.h
|
|
@ -19,7 +19,6 @@ import org.ray.api.function.RayFuncVoid;
|
|||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.api.options.ActorCreationOptions;
|
||||
import org.ray.api.options.CallOptions;
|
||||
import org.ray.api.runtime.RayRuntime;
|
||||
import org.ray.api.runtimecontext.RuntimeContext;
|
||||
import org.ray.runtime.config.RayConfig;
|
||||
import org.ray.runtime.context.RuntimeContextImpl;
|
||||
|
@ -29,6 +28,7 @@ import org.ray.runtime.functionmanager.FunctionManager;
|
|||
import org.ray.runtime.functionmanager.PyFunctionDescriptor;
|
||||
import org.ray.runtime.gcs.GcsClient;
|
||||
import org.ray.runtime.generated.Common.Language;
|
||||
import org.ray.runtime.generated.Common.WorkerType;
|
||||
import org.ray.runtime.object.ObjectStore;
|
||||
import org.ray.runtime.object.RayObjectImpl;
|
||||
import org.ray.runtime.task.ArgumentsBuilder;
|
||||
|
@ -41,7 +41,7 @@ import org.slf4j.LoggerFactory;
|
|||
/**
|
||||
* Core functionality to implement Ray APIs.
|
||||
*/
|
||||
public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
public abstract class AbstractRayRuntime implements RayRuntimeInternal {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(AbstractRayRuntime.class);
|
||||
public static final String PYTHON_INIT_METHOD_NAME = "__init__";
|
||||
|
@ -55,9 +55,15 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
|||
protected TaskSubmitter taskSubmitter;
|
||||
protected WorkerContext workerContext;
|
||||
|
||||
public AbstractRayRuntime(RayConfig rayConfig, FunctionManager functionManager) {
|
||||
/**
|
||||
* Whether the required thread context is set on the current thread.
|
||||
*/
|
||||
final ThreadLocal<Boolean> isContextSet = ThreadLocal.withInitial(() -> false);
|
||||
|
||||
public AbstractRayRuntime(RayConfig rayConfig) {
|
||||
this.rayConfig = rayConfig;
|
||||
this.functionManager = functionManager;
|
||||
setIsContextSet(rayConfig.workerMode == WorkerType.DRIVER);
|
||||
functionManager = new FunctionManager(rayConfig.jobResourcePath);
|
||||
runtimeContext = new RuntimeContextImpl(this);
|
||||
}
|
||||
|
||||
|
@ -161,13 +167,54 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
|||
}
|
||||
|
||||
@Override
|
||||
public Runnable wrapRunnable(Runnable runnable) {
|
||||
return runnable;
|
||||
public void setAsyncContext(Object asyncContext) {
|
||||
isContextSet.set(true);
|
||||
}
|
||||
|
||||
// TODO (kfstorm): Simplify the duplicate code in wrap*** methods.
|
||||
|
||||
@Override
|
||||
public final Runnable wrapRunnable(Runnable runnable) {
|
||||
Object asyncContext = getAsyncContext();
|
||||
return () -> {
|
||||
boolean oldIsContextSet = isContextSet.get();
|
||||
Object oldAsyncContext = null;
|
||||
if (oldIsContextSet) {
|
||||
oldAsyncContext = getAsyncContext();
|
||||
}
|
||||
setAsyncContext(asyncContext);
|
||||
try {
|
||||
runnable.run();
|
||||
} finally {
|
||||
if (oldIsContextSet) {
|
||||
setAsyncContext(oldAsyncContext);
|
||||
} else {
|
||||
setIsContextSet(false);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public Callable wrapCallable(Callable callable) {
|
||||
return callable;
|
||||
public final <T> Callable<T> wrapCallable(Callable<T> callable) {
|
||||
Object asyncContext = getAsyncContext();
|
||||
return () -> {
|
||||
boolean oldIsContextSet = isContextSet.get();
|
||||
Object oldAsyncContext = null;
|
||||
if (oldIsContextSet) {
|
||||
oldAsyncContext = getAsyncContext();
|
||||
}
|
||||
setAsyncContext(asyncContext);
|
||||
try {
|
||||
return callable.call();
|
||||
} finally {
|
||||
if (oldIsContextSet) {
|
||||
setAsyncContext(oldAsyncContext);
|
||||
} else {
|
||||
setIsContextSet(false);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
private RayObject callNormalFunction(FunctionDescriptor functionDescriptor,
|
||||
|
@ -209,18 +256,22 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
|||
return actor;
|
||||
}
|
||||
|
||||
@Override
|
||||
public WorkerContext getWorkerContext() {
|
||||
return workerContext;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ObjectStore getObjectStore() {
|
||||
return objectStore;
|
||||
}
|
||||
|
||||
@Override
|
||||
public FunctionManager getFunctionManager() {
|
||||
return functionManager;
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayConfig getRayConfig() {
|
||||
return rayConfig;
|
||||
}
|
||||
|
@ -229,7 +280,13 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
|||
return runtimeContext;
|
||||
}
|
||||
|
||||
@Override
|
||||
public GcsClient getGcsClient() {
|
||||
return gcsClient;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setIsContextSet(boolean isContextSet) {
|
||||
this.isContextSet.set(isContextSet);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,8 +4,6 @@ import org.ray.api.runtime.RayRuntime;
|
|||
import org.ray.api.runtime.RayRuntimeFactory;
|
||||
import org.ray.runtime.config.RayConfig;
|
||||
import org.ray.runtime.config.RunMode;
|
||||
import org.ray.runtime.functionmanager.FunctionManager;
|
||||
import org.ray.runtime.generated.Common.WorkerType;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
|
@ -20,17 +18,13 @@ public class DefaultRayRuntimeFactory implements RayRuntimeFactory {
|
|||
public RayRuntime createRayRuntime() {
|
||||
RayConfig rayConfig = RayConfig.getInstance();
|
||||
try {
|
||||
FunctionManager functionManager = new FunctionManager(rayConfig.jobResourcePath);
|
||||
RayRuntime runtime;
|
||||
if (rayConfig.runMode == RunMode.SINGLE_PROCESS) {
|
||||
runtime = new RayDevRuntime(rayConfig, functionManager);
|
||||
} else {
|
||||
if (rayConfig.workerMode == WorkerType.DRIVER) {
|
||||
runtime = new RayNativeRuntime(rayConfig, functionManager);
|
||||
} else {
|
||||
runtime = new RayMultiWorkerNativeRuntime(rayConfig, functionManager);
|
||||
}
|
||||
}
|
||||
AbstractRayRuntime innerRuntime = rayConfig.runMode == RunMode.SINGLE_PROCESS
|
||||
? new RayDevRuntime(rayConfig)
|
||||
: new RayNativeRuntime(rayConfig);
|
||||
RayRuntimeInternal runtime = rayConfig.numWorkersPerProcess > 1
|
||||
? RayRuntimeProxy.newInstance(innerRuntime)
|
||||
: innerRuntime;
|
||||
runtime.start();
|
||||
return runtime;
|
||||
} catch (Exception e) {
|
||||
LOGGER.error("Failed to initialize ray runtime", e);
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
package org.ray.runtime;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import org.ray.api.BaseActor;
|
||||
import org.ray.api.id.JobId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.config.RayConfig;
|
||||
import org.ray.runtime.context.LocalModeWorkerContext;
|
||||
import org.ray.runtime.functionmanager.FunctionManager;
|
||||
import org.ray.runtime.object.LocalModeObjectStore;
|
||||
import org.ray.runtime.task.LocalModeTaskExecutor;
|
||||
import org.ray.runtime.task.LocalModeTaskSubmitter;
|
||||
|
@ -19,18 +19,31 @@ public class RayDevRuntime extends AbstractRayRuntime {
|
|||
|
||||
private AtomicInteger jobCounter = new AtomicInteger(0);
|
||||
|
||||
public RayDevRuntime(RayConfig rayConfig, FunctionManager functionManager) {
|
||||
super(rayConfig, functionManager);
|
||||
public RayDevRuntime(RayConfig rayConfig) {
|
||||
super(rayConfig);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void start() {
|
||||
if (rayConfig.getJobId().isNil()) {
|
||||
rayConfig.setJobId(nextJobId());
|
||||
}
|
||||
taskExecutor = new LocalModeTaskExecutor(this);
|
||||
workerContext = new LocalModeWorkerContext(rayConfig.getJobId());
|
||||
objectStore = new LocalModeObjectStore(workerContext);
|
||||
taskSubmitter = new LocalModeTaskSubmitter(this, (LocalModeObjectStore) objectStore,
|
||||
rayConfig.numberExecThreadsForDevRuntime);
|
||||
taskSubmitter = new LocalModeTaskSubmitter(this, taskExecutor,
|
||||
(LocalModeObjectStore) objectStore);
|
||||
((LocalModeObjectStore) objectStore).addObjectPutCallback(
|
||||
objectId -> ((LocalModeTaskSubmitter) taskSubmitter).onObjectPut(objectId));
|
||||
objectId -> {
|
||||
if (taskSubmitter != null) {
|
||||
((LocalModeTaskSubmitter) taskSubmitter).onObjectPut(objectId);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -60,6 +73,8 @@ public class RayDevRuntime extends AbstractRayRuntime {
|
|||
|
||||
@Override
|
||||
public void setAsyncContext(Object asyncContext) {
|
||||
Preconditions.checkArgument(asyncContext == null);
|
||||
super.setAsyncContext(asyncContext);
|
||||
}
|
||||
|
||||
private JobId nextJobId() {
|
||||
|
|
|
@ -1,215 +0,0 @@
|
|||
package org.ray.runtime;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.Callable;
|
||||
import org.ray.api.BaseActor;
|
||||
import org.ray.api.RayActor;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.RayPyActor;
|
||||
import org.ray.api.WaitResult;
|
||||
import org.ray.api.function.PyActorClass;
|
||||
import org.ray.api.function.PyActorMethod;
|
||||
import org.ray.api.function.PyRemoteFunction;
|
||||
import org.ray.api.function.RayFunc;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.api.options.ActorCreationOptions;
|
||||
import org.ray.api.options.CallOptions;
|
||||
import org.ray.api.runtime.RayRuntime;
|
||||
import org.ray.api.runtimecontext.RuntimeContext;
|
||||
import org.ray.runtime.config.RayConfig;
|
||||
import org.ray.runtime.config.RunMode;
|
||||
import org.ray.runtime.functionmanager.FunctionManager;
|
||||
import org.ray.runtime.generated.Common.WorkerType;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/**
|
||||
* This is a proxy runtime for multi-worker support. It holds multiple {@link RayNativeRuntime}
|
||||
* instances and redirects calls to the correct one based on thread context.
|
||||
*/
|
||||
public class RayMultiWorkerNativeRuntime implements RayRuntime {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(RayMultiWorkerNativeRuntime.class);
|
||||
|
||||
private final FunctionManager functionManager;
|
||||
|
||||
/**
|
||||
* The number of workers per worker process.
|
||||
*/
|
||||
private final int numWorkers;
|
||||
/**
|
||||
* The worker threads.
|
||||
*/
|
||||
private final Thread[] threads;
|
||||
/**
|
||||
* The {@link RayNativeRuntime} instances of workers.
|
||||
*/
|
||||
private final RayNativeRuntime[] runtimes;
|
||||
/**
|
||||
* The {@link RayNativeRuntime} instance of current thread.
|
||||
*/
|
||||
private final ThreadLocal<RayNativeRuntime> currentThreadRuntime = new ThreadLocal<>();
|
||||
|
||||
public RayMultiWorkerNativeRuntime(RayConfig rayConfig, FunctionManager functionManager) {
|
||||
this.functionManager = functionManager;
|
||||
Preconditions.checkState(
|
||||
rayConfig.runMode == RunMode.CLUSTER && rayConfig.workerMode == WorkerType.WORKER);
|
||||
Preconditions.checkState(rayConfig.numWorkersPerProcess > 0,
|
||||
"numWorkersPerProcess must be greater than 0.");
|
||||
numWorkers = rayConfig.numWorkersPerProcess;
|
||||
runtimes = new RayNativeRuntime[numWorkers];
|
||||
threads = new Thread[numWorkers];
|
||||
|
||||
LOGGER.info("Starting {} workers.", numWorkers);
|
||||
|
||||
for (int i = 0; i < numWorkers; i++) {
|
||||
final int workerIndex = i;
|
||||
threads[i] = new Thread(() -> {
|
||||
RayNativeRuntime runtime = new RayNativeRuntime(rayConfig, functionManager);
|
||||
runtimes[workerIndex] = runtime;
|
||||
currentThreadRuntime.set(runtime);
|
||||
runtime.run();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
public void run() {
|
||||
for (int i = 0; i < numWorkers; i++) {
|
||||
threads[i].start();
|
||||
}
|
||||
for (int i = 0; i < numWorkers; i++) {
|
||||
try {
|
||||
threads[i].join();
|
||||
} catch (InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void shutdown() {
|
||||
for (int i = 0; i < numWorkers; i++) {
|
||||
runtimes[i].shutdown();
|
||||
}
|
||||
for (int i = 0; i < numWorkers; i++) {
|
||||
try {
|
||||
threads[i].join();
|
||||
} catch (InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public RayNativeRuntime getCurrentRuntime() {
|
||||
RayNativeRuntime currentRuntime = currentThreadRuntime.get();
|
||||
Preconditions.checkNotNull(currentRuntime,
|
||||
"RayRuntime is not set on current thread."
|
||||
+ " If you want to use Ray API in your own threads,"
|
||||
+ " please wrap your `Runnable`s or `Callable`s with"
|
||||
+ " `Ray.wrapRunnable` or `Ray.wrapCallable`.");
|
||||
return currentRuntime;
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> RayObject<T> put(T obj) {
|
||||
return getCurrentRuntime().put(obj);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> T get(ObjectId objectId) {
|
||||
return getCurrentRuntime().get(objectId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> List<T> get(List<ObjectId> objectIds) {
|
||||
return getCurrentRuntime().get(objectIds);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> WaitResult<T> wait(List<RayObject<T>> waitList, int numReturns, int timeoutMs) {
|
||||
return getCurrentRuntime().wait(waitList, numReturns, timeoutMs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void free(List<ObjectId> objectIds, boolean localOnly, boolean deleteCreatingTasks) {
|
||||
getCurrentRuntime().free(objectIds, localOnly, deleteCreatingTasks);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setResource(String resourceName, double capacity, UniqueId nodeId) {
|
||||
getCurrentRuntime().setResource(resourceName, capacity, nodeId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void killActor(BaseActor actor, boolean noReconstruction) {
|
||||
getCurrentRuntime().killActor(actor, noReconstruction);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayObject call(RayFunc func, Object[] args, CallOptions options) {
|
||||
return getCurrentRuntime().call(func, args, options);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayObject call(PyRemoteFunction pyRemoteFunction, Object[] args,
|
||||
CallOptions options) {
|
||||
return getCurrentRuntime().call(pyRemoteFunction, args, options);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayObject callActor(RayActor<?> actor, RayFunc func, Object[] args) {
|
||||
return getCurrentRuntime().callActor(actor, func, args);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayObject callActor(RayPyActor pyActor, PyActorMethod pyActorMethod, Object[] args) {
|
||||
return getCurrentRuntime().callActor(pyActor, pyActorMethod, args);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> RayActor<T> createActor(RayFunc actorFactoryFunc, Object[] args,
|
||||
ActorCreationOptions options) {
|
||||
return getCurrentRuntime().createActor(actorFactoryFunc, args, options);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayPyActor createActor(PyActorClass pyActorClass, Object[] args,
|
||||
ActorCreationOptions options) {
|
||||
return getCurrentRuntime().createActor(pyActorClass, args, options);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RuntimeContext getRuntimeContext() {
|
||||
return getCurrentRuntime().getRuntimeContext();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object getAsyncContext() {
|
||||
return getCurrentRuntime();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setAsyncContext(Object asyncContext) {
|
||||
currentThreadRuntime.set((RayNativeRuntime) asyncContext);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Runnable wrapRunnable(Runnable runnable) {
|
||||
Object asyncContext = getAsyncContext();
|
||||
return () -> {
|
||||
setAsyncContext(asyncContext);
|
||||
runnable.run();
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public Callable wrapCallable(Callable callable) {
|
||||
Object asyncContext = getAsyncContext();
|
||||
return () -> {
|
||||
setAsyncContext(asyncContext);
|
||||
return callable.call();
|
||||
};
|
||||
}
|
||||
}
|
|
@ -3,7 +3,6 @@ package org.ray.runtime;
|
|||
import com.google.common.base.Preconditions;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.ray.api.BaseActor;
|
||||
|
@ -11,7 +10,6 @@ import org.ray.api.id.JobId;
|
|||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.config.RayConfig;
|
||||
import org.ray.runtime.context.NativeWorkerContext;
|
||||
import org.ray.runtime.functionmanager.FunctionManager;
|
||||
import org.ray.runtime.gcs.GcsClient;
|
||||
import org.ray.runtime.gcs.GcsClientOptions;
|
||||
import org.ray.runtime.gcs.RedisClient;
|
||||
|
@ -34,10 +32,6 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
|||
|
||||
private RunManager manager = null;
|
||||
|
||||
/**
|
||||
* The native pointer of core worker.
|
||||
*/
|
||||
private long nativeCoreWorkerPointer;
|
||||
|
||||
static {
|
||||
LOGGER.debug("Loading native libraries.");
|
||||
|
@ -55,14 +49,13 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
|||
|
||||
JniUtils.loadLibrary("core_worker_library_java", true);
|
||||
LOGGER.debug("Native libraries loaded.");
|
||||
// Reset library path at runtime.
|
||||
resetLibraryPath(rayConfig);
|
||||
try {
|
||||
FileUtils.forceMkdir(new File(rayConfig.logDir));
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("Failed to create the log directory.", e);
|
||||
}
|
||||
nativeSetup(rayConfig.logDir, rayConfig.rayletConfigParameters);
|
||||
Runtime.getRuntime().addShutdownHook(new Thread(RayNativeRuntime::nativeShutdownHook));
|
||||
}
|
||||
|
||||
private static void resetLibraryPath(RayConfig rayConfig) {
|
||||
|
@ -71,11 +64,12 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
|||
JniUtils.resetLibraryPath(libraryPath);
|
||||
}
|
||||
|
||||
public RayNativeRuntime(RayConfig rayConfig, FunctionManager functionManager) {
|
||||
super(rayConfig, functionManager);
|
||||
// Reset library path at runtime.
|
||||
resetLibraryPath(rayConfig);
|
||||
public RayNativeRuntime(RayConfig rayConfig) {
|
||||
super(rayConfig);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void start() {
|
||||
if (rayConfig.getRedisAddress() == null) {
|
||||
manager = new RunManager(rayConfig);
|
||||
manager.startRayProcesses(true);
|
||||
|
@ -86,21 +80,21 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
|||
if (rayConfig.getJobId() == JobId.NIL) {
|
||||
rayConfig.setJobId(gcsClient.nextJobId());
|
||||
}
|
||||
int numWorkersPerProcess =
|
||||
rayConfig.workerMode == WorkerType.DRIVER ? 1 : rayConfig.numWorkersPerProcess;
|
||||
// TODO(qwang): Get object_store_socket_name and raylet_socket_name from Redis.
|
||||
nativeCoreWorkerPointer = nativeInitCoreWorker(rayConfig.workerMode.getNumber(),
|
||||
rayConfig.objectStoreSocketName, rayConfig.rayletSocketName,
|
||||
nativeInitialize(rayConfig.workerMode.getNumber(),
|
||||
rayConfig.nodeIp, rayConfig.getNodeManagerPort(),
|
||||
rayConfig.workerMode == WorkerType.DRIVER ? System.getProperty("user.dir") : "",
|
||||
rayConfig.objectStoreSocketName, rayConfig.rayletSocketName,
|
||||
(rayConfig.workerMode == WorkerType.DRIVER ? rayConfig.getJobId() : JobId.NIL).getBytes(),
|
||||
new GcsClientOptions(rayConfig));
|
||||
Preconditions.checkState(nativeCoreWorkerPointer != 0);
|
||||
new GcsClientOptions(rayConfig), numWorkersPerProcess,
|
||||
rayConfig.logDir, rayConfig.rayletConfigParameters);
|
||||
|
||||
workerContext = new NativeWorkerContext(nativeCoreWorkerPointer);
|
||||
taskExecutor = new NativeTaskExecutor(nativeCoreWorkerPointer, this);
|
||||
objectStore = new NativeObjectStore(workerContext, nativeCoreWorkerPointer);
|
||||
taskSubmitter = new NativeTaskSubmitter(nativeCoreWorkerPointer);
|
||||
|
||||
// register
|
||||
registerWorker();
|
||||
taskExecutor = new NativeTaskExecutor(this);
|
||||
workerContext = new NativeWorkerContext();
|
||||
objectStore = new NativeObjectStore(workerContext);
|
||||
taskSubmitter = new NativeTaskSubmitter();
|
||||
|
||||
LOGGER.info("RayNativeRuntime started with store {}, raylet {}",
|
||||
rayConfig.objectStoreSocketName, rayConfig.rayletSocketName);
|
||||
|
@ -108,10 +102,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
|||
|
||||
@Override
|
||||
public void shutdown() {
|
||||
if (nativeCoreWorkerPointer != 0) {
|
||||
nativeDestroyCoreWorker(nativeCoreWorkerPointer);
|
||||
nativeCoreWorkerPointer = 0;
|
||||
}
|
||||
nativeShutdown();
|
||||
if (null != manager) {
|
||||
manager.cleanup();
|
||||
manager = null;
|
||||
|
@ -131,75 +122,56 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
|||
if (nodeId == null) {
|
||||
nodeId = UniqueId.NIL;
|
||||
}
|
||||
nativeSetResource(nativeCoreWorkerPointer, resourceName, capacity, nodeId.getBytes());
|
||||
nativeSetResource(resourceName, capacity, nodeId.getBytes());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void killActor(BaseActor actor, boolean noReconstruction) {
|
||||
nativeKillActor(nativeCoreWorkerPointer, actor.getId().getBytes(), noReconstruction);
|
||||
nativeKillActor(actor.getId().getBytes(), noReconstruction);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object getAsyncContext() {
|
||||
return null;
|
||||
return new AsyncContext(workerContext.getCurrentWorkerId(),
|
||||
workerContext.getCurrentClassLoader());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setAsyncContext(Object asyncContext) {
|
||||
nativeSetCoreWorker(((AsyncContext) asyncContext).workerId.getBytes());
|
||||
workerContext.setCurrentClassLoader(((AsyncContext) asyncContext).currentClassLoader);
|
||||
super.setAsyncContext(asyncContext);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
nativeRunTaskExecutor(nativeCoreWorkerPointer);
|
||||
Preconditions.checkState(rayConfig.workerMode == WorkerType.WORKER);
|
||||
nativeRunTaskExecutor(taskExecutor);
|
||||
}
|
||||
|
||||
public long getNativeCoreWorkerPointer() {
|
||||
return nativeCoreWorkerPointer;
|
||||
}
|
||||
private static native void nativeInitialize(int workerMode, String ndoeIpAddress,
|
||||
int nodeManagerPort, String driverName, String storeSocket, String rayletSocket,
|
||||
byte[] jobId, GcsClientOptions gcsClientOptions, int numWorkersPerProcess,
|
||||
String logDir, Map<String, String> rayletConfigParameters);
|
||||
|
||||
public TaskExecutor getTaskExecutor() {
|
||||
return taskExecutor;
|
||||
}
|
||||
private static native void nativeRunTaskExecutor(TaskExecutor taskExecutor);
|
||||
|
||||
/**
|
||||
* Register this worker or driver to GCS.
|
||||
*/
|
||||
private void registerWorker() {
|
||||
RedisClient redisClient = new RedisClient(rayConfig.getRedisAddress(), rayConfig.redisPassword);
|
||||
Map<String, String> workerInfo = new HashMap<>();
|
||||
String workerId = new String(workerContext.getCurrentWorkerId().getBytes());
|
||||
if (rayConfig.workerMode == WorkerType.DRIVER) {
|
||||
workerInfo.put("node_ip_address", rayConfig.nodeIp);
|
||||
workerInfo.put("driver_id", workerId);
|
||||
workerInfo.put("start_time", String.valueOf(System.currentTimeMillis()));
|
||||
workerInfo.put("plasma_store_socket", rayConfig.objectStoreSocketName);
|
||||
workerInfo.put("raylet_socket", rayConfig.rayletSocketName);
|
||||
workerInfo.put("name", System.getProperty("user.dir"));
|
||||
//TODO: worker.redis_client.hmset(b"Drivers:" + worker.workerId, driver_info)
|
||||
redisClient.hmset("Drivers:" + workerId, workerInfo);
|
||||
} else {
|
||||
workerInfo.put("node_ip_address", rayConfig.nodeIp);
|
||||
workerInfo.put("plasma_store_socket", rayConfig.objectStoreSocketName);
|
||||
workerInfo.put("raylet_socket", rayConfig.rayletSocketName);
|
||||
//TODO: b"Workers:" + worker.workerId,
|
||||
redisClient.hmset("Workers:" + workerId, workerInfo);
|
||||
private static native void nativeShutdown();
|
||||
|
||||
private static native void nativeSetResource(String resourceName, double capacity, byte[] nodeId);
|
||||
|
||||
private static native void nativeKillActor(byte[] actorId, boolean noReconstruction);
|
||||
|
||||
private static native void nativeSetCoreWorker(byte[] workerId);
|
||||
|
||||
static class AsyncContext {
|
||||
|
||||
public final UniqueId workerId;
|
||||
public final ClassLoader currentClassLoader;
|
||||
|
||||
AsyncContext(UniqueId workerId, ClassLoader currentClassLoader) {
|
||||
this.workerId = workerId;
|
||||
this.currentClassLoader = currentClassLoader;
|
||||
}
|
||||
}
|
||||
|
||||
private static native long nativeInitCoreWorker(int workerMode, String storeSocket,
|
||||
String rayletSocket, String nodeIpAddress, int nodeManagerPort, byte[] jobId,
|
||||
GcsClientOptions gcsClientOptions);
|
||||
|
||||
private static native void nativeRunTaskExecutor(long nativeCoreWorkerPointer);
|
||||
|
||||
private static native void nativeDestroyCoreWorker(long nativeCoreWorkerPointer);
|
||||
|
||||
private static native void nativeSetup(String logDir, Map<String, String> rayletConfigParameters);
|
||||
|
||||
private static native void nativeShutdownHook();
|
||||
|
||||
private static native void nativeSetResource(long conn, String resourceName, double capacity,
|
||||
byte[] nodeId);
|
||||
|
||||
private static native void nativeKillActor(long nativeCoreWorkerPointer, byte[] actorId,
|
||||
boolean noReconstruction);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
package org.ray.runtime;
|
||||
|
||||
import org.ray.api.runtime.RayRuntime;
|
||||
import org.ray.runtime.config.RayConfig;
|
||||
import org.ray.runtime.context.WorkerContext;
|
||||
import org.ray.runtime.functionmanager.FunctionManager;
|
||||
import org.ray.runtime.gcs.GcsClient;
|
||||
import org.ray.runtime.object.ObjectStore;
|
||||
|
||||
/**
|
||||
* This interface is required to make {@link RayRuntimeProxy} work.
|
||||
*/
|
||||
public interface RayRuntimeInternal extends RayRuntime {
|
||||
|
||||
/**
|
||||
* Start runtime.
|
||||
*/
|
||||
void start();
|
||||
|
||||
/** Returns the worker context of this runtime. */
WorkerContext getWorkerContext();
|
||||
|
||||
/** Returns the object store of this runtime. */
ObjectStore getObjectStore();
|
||||
|
||||
/** Returns the function manager of this runtime. */
FunctionManager getFunctionManager();
|
||||
|
||||
/** Returns the configuration this runtime was created with. */
RayConfig getRayConfig();
|
||||
|
||||
/** Returns the GCS client of this runtime. */
GcsClient getGcsClient();
|
||||
|
||||
/** Records whether the required thread context is set on the calling thread. */
void setIsContextSet(boolean isContextSet);
|
||||
|
||||
/**
 * Run the runtime as a worker. NOTE(review): the native runtime requires worker mode here and
 * the single-process dev runtime throws {@link UnsupportedOperationException} - confirm this
 * is the intended contract for all implementations.
 */
void run();
|
||||
}
|
|
@ -0,0 +1,83 @@
|
|||
package org.ray.runtime;
|
||||
|
||||
import java.lang.reflect.InvocationHandler;
|
||||
import java.lang.reflect.InvocationTargetException;
|
||||
import java.lang.reflect.Method;
|
||||
import org.ray.api.exception.RayException;
|
||||
import org.ray.api.runtime.RayRuntime;
|
||||
import org.ray.runtime.config.RunMode;
|
||||
|
||||
/**
|
||||
* Protect a ray runtime with context checks for all methods of {@link RayRuntime} (except {@link
|
||||
* RayRuntime#shutdown(boolean)}).
|
||||
*/
|
||||
public class RayRuntimeProxy implements InvocationHandler {
|
||||
|
||||
/**
|
||||
* The original runtime.
|
||||
*/
|
||||
private AbstractRayRuntime obj;
|
||||
|
||||
private RayRuntimeProxy(AbstractRayRuntime obj) {
|
||||
this.obj = obj;
|
||||
}
|
||||
|
||||
public AbstractRayRuntime getRuntimeObject() {
|
||||
return obj;
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a new instance of {@link RayRuntimeInternal} with additional context check.
|
||||
*/
|
||||
static RayRuntimeInternal newInstance(AbstractRayRuntime obj) {
|
||||
return (RayRuntimeInternal) java.lang.reflect.Proxy
|
||||
.newProxyInstance(obj.getClass().getClassLoader(), new Class<?>[]{RayRuntimeInternal.class},
|
||||
new RayRuntimeProxy(obj));
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
|
||||
if (isInterfaceMethod(method) && !method.getName().equals("shutdown") && !method.getName()
|
||||
.equals("setAsyncContext")) {
|
||||
checkIsContextSet();
|
||||
}
|
||||
try {
|
||||
return method.invoke(obj, args);
|
||||
} catch (InvocationTargetException e) {
|
||||
if (e.getCause() != null) {
|
||||
throw e.getCause();
|
||||
} else {
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Whether the method is defined in the {@link RayRuntime} interface.
|
||||
*/
|
||||
private boolean isInterfaceMethod(Method method) {
|
||||
try {
|
||||
RayRuntime.class.getMethod(method.getName(), method.getParameterTypes());
|
||||
return true;
|
||||
} catch (NoSuchMethodException e) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if thread context is set.
|
||||
* <p/>
|
||||
* This method should be invoked at the beginning of most public methods of {@link RayRuntime},
|
||||
* otherwise the native code might crash due to thread local core worker was not set. We check it
|
||||
* for {@link AbstractRayRuntime} instead of {@link RayNativeRuntime} because we want to catch the
|
||||
* error even if the application runs in {@link RunMode#SINGLE_PROCESS} mode.
|
||||
*/
|
||||
private void checkIsContextSet() {
|
||||
if (!obj.isContextSet.get()) {
|
||||
throw new RayException(
|
||||
"`Ray.wrap***` is not called on the current thread."
|
||||
+ " If you want to use Ray API in your own threads,"
|
||||
+ " please wrap your executable with `Ray.wrap***`.");
|
||||
}
|
||||
}
|
||||
}
|
|
@ -7,11 +7,7 @@ import java.io.ObjectInput;
|
|||
import java.io.ObjectOutput;
|
||||
import java.util.List;
|
||||
import org.ray.api.BaseActor;
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.api.id.ActorId;
|
||||
import org.ray.api.runtime.RayRuntime;
|
||||
import org.ray.runtime.RayMultiWorkerNativeRuntime;
|
||||
import org.ray.runtime.RayNativeRuntime;
|
||||
import org.ray.runtime.generated.Common.Language;
|
||||
|
||||
/**
|
||||
|
@ -20,20 +16,17 @@ import org.ray.runtime.generated.Common.Language;
|
|||
*/
|
||||
public abstract class NativeRayActor implements BaseActor, Externalizable {
|
||||
|
||||
/**
|
||||
* Address of core worker.
|
||||
*/
|
||||
long nativeCoreWorkerPointer;
|
||||
/**
|
||||
* ID of the actor.
|
||||
*/
|
||||
byte[] actorId;
|
||||
|
||||
NativeRayActor(long nativeCoreWorkerPointer, byte[] actorId) {
|
||||
Preconditions.checkState(nativeCoreWorkerPointer != 0);
|
||||
private Language language;
|
||||
|
||||
NativeRayActor(byte[] actorId, Language language) {
|
||||
Preconditions.checkState(!ActorId.fromBytes(actorId).isNil());
|
||||
this.nativeCoreWorkerPointer = nativeCoreWorkerPointer;
|
||||
this.actorId = actorId;
|
||||
this.language = language;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -42,14 +35,12 @@ public abstract class NativeRayActor implements BaseActor, Externalizable {
|
|||
NativeRayActor() {
|
||||
}
|
||||
|
||||
public static NativeRayActor create(long nativeCoreWorkerPointer, byte[] actorId,
|
||||
Language language) {
|
||||
Preconditions.checkState(nativeCoreWorkerPointer != 0);
|
||||
public static NativeRayActor create(byte[] actorId, Language language) {
|
||||
switch (language) {
|
||||
case JAVA:
|
||||
return new NativeRayJavaActor(nativeCoreWorkerPointer, actorId);
|
||||
return new NativeRayJavaActor(actorId);
|
||||
case PYTHON:
|
||||
return new NativeRayPyActor(nativeCoreWorkerPointer, actorId);
|
||||
return new NativeRayPyActor(actorId);
|
||||
default:
|
||||
throw new IllegalStateException("Unknown actor handle language: " + language);
|
||||
}
|
||||
|
@ -61,18 +52,19 @@ public abstract class NativeRayActor implements BaseActor, Externalizable {
|
|||
}
|
||||
|
||||
public Language getLanguage() {
|
||||
return Language.forNumber(nativeGetLanguage(nativeCoreWorkerPointer, actorId));
|
||||
return language;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeExternal(ObjectOutput out) throws IOException {
|
||||
out.writeObject(toBytes());
|
||||
out.writeObject(nativeSerialize(actorId));
|
||||
out.writeObject(language);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
|
||||
nativeCoreWorkerPointer = getNativeCoreWorkerPointer();
|
||||
actorId = nativeDeserialize(nativeCoreWorkerPointer, (byte[]) in.readObject());
|
||||
actorId = nativeDeserialize((byte[]) in.readObject());
|
||||
language = (Language) in.readObject();
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -81,7 +73,7 @@ public abstract class NativeRayActor implements BaseActor, Externalizable {
|
|||
* @return the bytes of the actor handle
|
||||
*/
|
||||
public byte[] toBytes() {
|
||||
return nativeSerialize(nativeCoreWorkerPointer, actorId);
|
||||
return nativeSerialize(actorId);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -90,21 +82,10 @@ public abstract class NativeRayActor implements BaseActor, Externalizable {
|
|||
* @return the actor handle deserialized from the bytes
|
||||
*/
|
||||
public static NativeRayActor fromBytes(byte[] bytes) {
|
||||
long nativeCoreWorkerPointer = getNativeCoreWorkerPointer();
|
||||
byte[] actorId = nativeDeserialize(nativeCoreWorkerPointer, bytes);
|
||||
Language language = Language.forNumber(nativeGetLanguage(nativeCoreWorkerPointer, actorId));
|
||||
byte[] actorId = nativeDeserialize(bytes);
|
||||
Language language = Language.forNumber(nativeGetLanguage(actorId));
|
||||
Preconditions.checkNotNull(language);
|
||||
return create(nativeCoreWorkerPointer, actorId, language);
|
||||
}
|
||||
|
||||
private static long getNativeCoreWorkerPointer() {
|
||||
RayRuntime runtime = Ray.internal();
|
||||
if (runtime instanceof RayMultiWorkerNativeRuntime) {
|
||||
runtime = ((RayMultiWorkerNativeRuntime) runtime).getCurrentRuntime();
|
||||
}
|
||||
Preconditions.checkState(runtime instanceof RayNativeRuntime);
|
||||
|
||||
return ((RayNativeRuntime) runtime).getNativeCoreWorkerPointer();
|
||||
return create(actorId, language);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -112,13 +93,11 @@ public abstract class NativeRayActor implements BaseActor, Externalizable {
|
|||
// TODO(zhijunfu): do we need to free the ActorHandle in core worker?
|
||||
}
|
||||
|
||||
private static native int nativeGetLanguage(
|
||||
long nativeCoreWorkerPointer, byte[] actorId);
|
||||
private static native int nativeGetLanguage(byte[] actorId);
|
||||
|
||||
static native List<String> nativeGetActorCreationTaskFunctionDescriptor(
|
||||
long nativeCoreWorkerPointer, byte[] actorId);
|
||||
static native List<String> nativeGetActorCreationTaskFunctionDescriptor(byte[] actorId);
|
||||
|
||||
private static native byte[] nativeSerialize(long nativeCoreWorkerPointer, byte[] actorId);
|
||||
private static native byte[] nativeSerialize(byte[] actorId);
|
||||
|
||||
private static native byte[] nativeDeserialize(long nativeCoreWorkerPointer, byte[] data);
|
||||
private static native byte[] nativeDeserialize(byte[] data);
|
||||
}
|
||||
|
|
|
@ -11,8 +11,8 @@ import org.ray.runtime.generated.Common.Language;
|
|||
*/
|
||||
public class NativeRayJavaActor extends NativeRayActor implements RayActor {
|
||||
|
||||
NativeRayJavaActor(long nativeCoreWorkerPointer, byte[] actorId) {
|
||||
super(nativeCoreWorkerPointer, actorId);
|
||||
NativeRayJavaActor(byte[] actorId) {
|
||||
super(actorId, Language.JAVA);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -11,8 +11,8 @@ import org.ray.runtime.generated.Common.Language;
|
|||
*/
|
||||
public class NativeRayPyActor extends NativeRayActor implements RayPyActor {
|
||||
|
||||
NativeRayPyActor(long nativeCoreWorkerPointer, byte[] actorId) {
|
||||
super(nativeCoreWorkerPointer, actorId);
|
||||
NativeRayPyActor(byte[] actorId) {
|
||||
super(actorId, Language.PYTHON);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -24,12 +24,12 @@ public class NativeRayPyActor extends NativeRayActor implements RayPyActor {
|
|||
|
||||
@Override
|
||||
public String getModuleName() {
|
||||
return nativeGetActorCreationTaskFunctionDescriptor(nativeCoreWorkerPointer, actorId).get(0);
|
||||
return nativeGetActorCreationTaskFunctionDescriptor(actorId).get(0);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getClassName() {
|
||||
return nativeGetActorCreationTaskFunctionDescriptor(nativeCoreWorkerPointer, actorId).get(1);
|
||||
return nativeGetActorCreationTaskFunctionDescriptor(actorId).get(1);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -93,10 +93,6 @@ public class RayConfig {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Number of threads that execute tasks.
|
||||
*/
|
||||
public final int numberExecThreadsForDevRuntime;
|
||||
|
||||
public final int numWorkersPerProcess;
|
||||
|
||||
|
@ -220,9 +216,6 @@ public class RayConfig {
|
|||
jobResourcePath = null;
|
||||
}
|
||||
|
||||
// Number of threads that execute tasks.
|
||||
numberExecThreadsForDevRuntime = config.getInt("ray.dev-runtime.execution-parallelism");
|
||||
|
||||
numWorkersPerProcess = config.getInt("ray.raylet.config.num_workers_per_process_java");
|
||||
|
||||
gcsServiceEnabled = System.getenv("RAY_GCS_SERVICE_ENABLED") == null ||
|
||||
|
|
|
@ -16,6 +16,7 @@ public class LocalModeWorkerContext implements WorkerContext {
|
|||
|
||||
private final JobId jobId;
|
||||
private ThreadLocal<TaskSpec> currentTask = new ThreadLocal<>();
|
||||
private final ThreadLocal<UniqueId> currentWorkerId = new ThreadLocal<>();
|
||||
|
||||
public LocalModeWorkerContext(JobId jobId) {
|
||||
this.jobId = jobId;
|
||||
|
@ -23,7 +24,11 @@ public class LocalModeWorkerContext implements WorkerContext {
|
|||
|
||||
@Override
|
||||
public UniqueId getCurrentWorkerId() {
|
||||
throw new UnsupportedOperationException();
|
||||
return currentWorkerId.get();
|
||||
}
|
||||
|
||||
public void setCurrentWorkerId(UniqueId workerId) {
|
||||
currentWorkerId.set(workerId);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -12,61 +12,52 @@ import org.ray.runtime.generated.Common.TaskType;
|
|||
*/
|
||||
public class NativeWorkerContext implements WorkerContext {
|
||||
|
||||
/**
|
||||
* The native pointer of core worker.
|
||||
*/
|
||||
private final long nativeCoreWorkerPointer;
|
||||
|
||||
private ClassLoader currentClassLoader;
|
||||
|
||||
public NativeWorkerContext(long nativeCoreWorkerPointer) {
|
||||
this.nativeCoreWorkerPointer = nativeCoreWorkerPointer;
|
||||
}
|
||||
private final ThreadLocal<ClassLoader> currentClassLoader = new ThreadLocal<>();
|
||||
|
||||
@Override
|
||||
public UniqueId getCurrentWorkerId() {
|
||||
return UniqueId.fromByteBuffer(nativeGetCurrentWorkerId(nativeCoreWorkerPointer));
|
||||
return UniqueId.fromByteBuffer(nativeGetCurrentWorkerId());
|
||||
}
|
||||
|
||||
@Override
|
||||
public JobId getCurrentJobId() {
|
||||
return JobId.fromByteBuffer(nativeGetCurrentJobId(nativeCoreWorkerPointer));
|
||||
return JobId.fromByteBuffer(nativeGetCurrentJobId());
|
||||
}
|
||||
|
||||
@Override
|
||||
public ActorId getCurrentActorId() {
|
||||
return ActorId.fromByteBuffer(nativeGetCurrentActorId(nativeCoreWorkerPointer));
|
||||
return ActorId.fromByteBuffer(nativeGetCurrentActorId());
|
||||
}
|
||||
|
||||
@Override
|
||||
public ClassLoader getCurrentClassLoader() {
|
||||
return currentClassLoader;
|
||||
return currentClassLoader.get();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setCurrentClassLoader(ClassLoader currentClassLoader) {
|
||||
if (this.currentClassLoader != currentClassLoader) {
|
||||
this.currentClassLoader = currentClassLoader;
|
||||
if (this.currentClassLoader.get() != currentClassLoader) {
|
||||
this.currentClassLoader.set(currentClassLoader);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public TaskType getCurrentTaskType() {
|
||||
return TaskType.forNumber(nativeGetCurrentTaskType(nativeCoreWorkerPointer));
|
||||
return TaskType.forNumber(nativeGetCurrentTaskType());
|
||||
}
|
||||
|
||||
@Override
|
||||
public TaskId getCurrentTaskId() {
|
||||
return TaskId.fromByteBuffer(nativeGetCurrentTaskId(nativeCoreWorkerPointer));
|
||||
return TaskId.fromByteBuffer(nativeGetCurrentTaskId());
|
||||
}
|
||||
|
||||
private static native int nativeGetCurrentTaskType(long nativeCoreWorkerPointer);
|
||||
private static native int nativeGetCurrentTaskType();
|
||||
|
||||
private static native ByteBuffer nativeGetCurrentTaskId(long nativeCoreWorkerPointer);
|
||||
private static native ByteBuffer nativeGetCurrentTaskId();
|
||||
|
||||
private static native ByteBuffer nativeGetCurrentJobId(long nativeCoreWorkerPointer);
|
||||
private static native ByteBuffer nativeGetCurrentJobId();
|
||||
|
||||
private static native ByteBuffer nativeGetCurrentWorkerId(long nativeCoreWorkerPointer);
|
||||
private static native ByteBuffer nativeGetCurrentWorkerId();
|
||||
|
||||
private static native ByteBuffer nativeGetCurrentActorId(long nativeCoreWorkerPointer);
|
||||
private static native ByteBuffer nativeGetCurrentActorId();
|
||||
}
|
||||
|
|
|
@ -6,15 +6,15 @@ import org.ray.api.id.ActorId;
|
|||
import org.ray.api.id.JobId;
|
||||
import org.ray.api.runtimecontext.NodeInfo;
|
||||
import org.ray.api.runtimecontext.RuntimeContext;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
import org.ray.runtime.RayRuntimeInternal;
|
||||
import org.ray.runtime.config.RunMode;
|
||||
import org.ray.runtime.generated.Common.TaskType;
|
||||
|
||||
public class RuntimeContextImpl implements RuntimeContext {
|
||||
|
||||
private AbstractRayRuntime runtime;
|
||||
private RayRuntimeInternal runtime;
|
||||
|
||||
public RuntimeContextImpl(AbstractRayRuntime runtime) {
|
||||
public RuntimeContextImpl(RayRuntimeInternal runtime) {
|
||||
this.runtime = runtime;
|
||||
}
|
||||
|
||||
|
|
|
@ -15,56 +15,48 @@ public class NativeObjectStore extends ObjectStore {
|
|||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(NativeObjectStore.class);
|
||||
|
||||
/**
|
||||
* The native pointer of core worker.
|
||||
*/
|
||||
private final long nativeCoreWorkerPointer;
|
||||
|
||||
public NativeObjectStore(WorkerContext workerContext, long nativeCoreWorkerPointer) {
|
||||
public NativeObjectStore(WorkerContext workerContext) {
|
||||
super(workerContext);
|
||||
this.nativeCoreWorkerPointer = nativeCoreWorkerPointer;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ObjectId putRaw(NativeRayObject obj) {
|
||||
return new ObjectId(nativePut(nativeCoreWorkerPointer, obj));
|
||||
return new ObjectId(nativePut(obj));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void putRaw(NativeRayObject obj, ObjectId objectId) {
|
||||
nativePut(nativeCoreWorkerPointer, objectId.getBytes(), obj);
|
||||
nativePut(objectId.getBytes(), obj);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<NativeRayObject> getRaw(List<ObjectId> objectIds, long timeoutMs) {
|
||||
return nativeGet(nativeCoreWorkerPointer, toBinaryList(objectIds), timeoutMs);
|
||||
return nativeGet(toBinaryList(objectIds), timeoutMs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Boolean> wait(List<ObjectId> objectIds, int numObjects, long timeoutMs) {
|
||||
return nativeWait(nativeCoreWorkerPointer, toBinaryList(objectIds), numObjects, timeoutMs);
|
||||
return nativeWait(toBinaryList(objectIds), numObjects, timeoutMs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void delete(List<ObjectId> objectIds, boolean localOnly, boolean deleteCreatingTasks) {
|
||||
nativeDelete(nativeCoreWorkerPointer, toBinaryList(objectIds), localOnly, deleteCreatingTasks);
|
||||
nativeDelete(toBinaryList(objectIds), localOnly, deleteCreatingTasks);
|
||||
}
|
||||
|
||||
private static List<byte[]> toBinaryList(List<ObjectId> ids) {
|
||||
return ids.stream().map(BaseId::getBytes).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private static native byte[] nativePut(long nativeCoreWorkerPointer, NativeRayObject obj);
|
||||
private static native byte[] nativePut(NativeRayObject obj);
|
||||
|
||||
private static native void nativePut(long nativeCoreWorkerPointer, byte[] objectId,
|
||||
NativeRayObject obj);
|
||||
private static native void nativePut(byte[] objectId, NativeRayObject obj);
|
||||
|
||||
private static native List<NativeRayObject> nativeGet(long nativeCoreWorkerPointer,
|
||||
List<byte[]> ids, long timeoutMs);
|
||||
private static native List<NativeRayObject> nativeGet(List<byte[]> ids, long timeoutMs);
|
||||
|
||||
private static native List<Boolean> nativeWait(long nativeCoreWorkerPointer,
|
||||
List<byte[]> objectIds, int numObjects, long timeoutMs);
|
||||
private static native List<Boolean> nativeWait(List<byte[]> objectIds, int numObjects,
|
||||
long timeoutMs);
|
||||
|
||||
private static native void nativeDelete(long nativeCoreWorkerPointer, List<byte[]> objectIds,
|
||||
boolean localOnly, boolean deleteCreatingTasks);
|
||||
private static native void nativeDelete(List<byte[]> objectIds, boolean localOnly,
|
||||
boolean deleteCreatingTasks);
|
||||
}
|
||||
|
|
|
@ -1,9 +1,7 @@
|
|||
package org.ray.runtime.runner.worker;
|
||||
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.api.runtime.RayRuntime;
|
||||
import org.ray.runtime.RayMultiWorkerNativeRuntime;
|
||||
import org.ray.runtime.RayNativeRuntime;
|
||||
import org.ray.runtime.RayRuntimeInternal;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
|
@ -25,14 +23,7 @@ public class DefaultWorker {
|
|||
});
|
||||
Ray.init();
|
||||
LOGGER.info("Worker started.");
|
||||
RayRuntime runtime = Ray.internal();
|
||||
if (runtime instanceof RayNativeRuntime) {
|
||||
((RayNativeRuntime)runtime).run();
|
||||
} else if (runtime instanceof RayMultiWorkerNativeRuntime) {
|
||||
((RayMultiWorkerNativeRuntime)runtime).run();
|
||||
} else {
|
||||
throw new RuntimeException("Unknown RayRuntime: " + runtime);
|
||||
}
|
||||
((RayRuntimeInternal) Ray.internal()).run();
|
||||
} catch (Exception e) {
|
||||
LOGGER.error("Failed to start worker.", e);
|
||||
}
|
||||
|
|
|
@ -6,9 +6,7 @@ import java.util.List;
|
|||
import org.ray.api.Ray;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.api.runtime.RayRuntime;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
import org.ray.runtime.RayMultiWorkerNativeRuntime;
|
||||
import org.ray.runtime.RayRuntimeInternal;
|
||||
import org.ray.runtime.generated.Common.Language;
|
||||
import org.ray.runtime.object.NativeRayObject;
|
||||
import org.ray.runtime.object.ObjectSerializer;
|
||||
|
@ -43,11 +41,7 @@ public class ArgumentsBuilder {
|
|||
} else {
|
||||
value = ObjectSerializer.serialize(arg);
|
||||
if (value.data.length > LARGEST_SIZE_PASS_BY_VALUE) {
|
||||
RayRuntime runtime = Ray.internal();
|
||||
if (runtime instanceof RayMultiWorkerNativeRuntime) {
|
||||
runtime = ((RayMultiWorkerNativeRuntime) runtime).getCurrentRuntime();
|
||||
}
|
||||
id = ((AbstractRayRuntime) runtime).getObjectStore()
|
||||
id = ((RayRuntimeInternal) Ray.internal()).getObjectStore()
|
||||
.putRaw(value);
|
||||
value = null;
|
||||
}
|
||||
|
|
|
@ -1,17 +1,40 @@
|
|||
package org.ray.runtime.task;
|
||||
|
||||
import org.ray.api.id.ActorId;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.RayRuntimeInternal;
|
||||
import org.ray.runtime.task.LocalModeTaskExecutor.LocalActorContext;
|
||||
|
||||
/**
|
||||
* Task executor for local mode.
|
||||
*/
|
||||
public class LocalModeTaskExecutor extends TaskExecutor {
|
||||
public class LocalModeTaskExecutor extends TaskExecutor<LocalActorContext> {
|
||||
|
||||
public LocalModeTaskExecutor(AbstractRayRuntime runtime) {
|
||||
static class LocalActorContext extends TaskExecutor.ActorContext {
|
||||
|
||||
/**
|
||||
* The worker ID of the actor.
|
||||
*/
|
||||
private final UniqueId workerId;
|
||||
|
||||
public LocalActorContext(UniqueId workerId) {
|
||||
this.workerId = workerId;
|
||||
}
|
||||
|
||||
public UniqueId getWorkerId() {
|
||||
return workerId;
|
||||
}
|
||||
}
|
||||
|
||||
public LocalModeTaskExecutor(RayRuntimeInternal runtime) {
|
||||
super(runtime);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected LocalActorContext createActorContext() {
|
||||
return new LocalActorContext(runtime.getWorkerContext().getCurrentWorkerId());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void maybeSaveCheckpoint(Object actor, ActorId actorId) {
|
||||
}
|
||||
|
|
|
@ -4,16 +4,15 @@ import com.google.common.base.Preconditions;
|
|||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.protobuf.ByteString;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.ArrayDeque;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.Deque;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Random;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.stream.Collectors;
|
||||
|
@ -21,9 +20,10 @@ import org.ray.api.BaseActor;
|
|||
import org.ray.api.id.ActorId;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.api.id.TaskId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.api.options.ActorCreationOptions;
|
||||
import org.ray.api.options.CallOptions;
|
||||
import org.ray.runtime.RayDevRuntime;
|
||||
import org.ray.runtime.RayRuntimeInternal;
|
||||
import org.ray.runtime.actor.LocalModeRayActor;
|
||||
import org.ray.runtime.context.LocalModeWorkerContext;
|
||||
import org.ray.runtime.functionmanager.FunctionDescriptor;
|
||||
|
@ -37,6 +37,7 @@ import org.ray.runtime.generated.Common.TaskSpec;
|
|||
import org.ray.runtime.generated.Common.TaskType;
|
||||
import org.ray.runtime.object.LocalModeObjectStore;
|
||||
import org.ray.runtime.object.NativeRayObject;
|
||||
import org.ray.runtime.task.TaskExecutor.ActorContext;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
|
@ -49,7 +50,8 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
|
|||
|
||||
private final Map<ObjectId, Set<TaskSpec>> waitingTasks = new HashMap<>();
|
||||
private final Object taskAndObjectLock = new Object();
|
||||
private final RayDevRuntime runtime;
|
||||
private final RayRuntimeInternal runtime;
|
||||
private final TaskExecutor taskExecutor;
|
||||
private final LocalModeObjectStore objectStore;
|
||||
|
||||
/// The thread pool to execute actor tasks.
|
||||
|
@ -58,17 +60,16 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
|
|||
/// The thread pool to execute normal tasks.
|
||||
private final ExecutorService normalTaskExecutorService;
|
||||
|
||||
private final Deque<TaskExecutor> idleTaskExecutors = new ArrayDeque<>();
|
||||
private final Map<ActorId, TaskExecutor> actorTaskExecutors = new HashMap<>();
|
||||
private final Object taskExecutorLock = new Object();
|
||||
private final ThreadLocal<TaskExecutor> currentTaskExecutor = new ThreadLocal<>();
|
||||
|
||||
public LocalModeTaskSubmitter(RayDevRuntime runtime, LocalModeObjectStore objectStore,
|
||||
int numberThreads) {
|
||||
private final Map<ActorId, ActorContext> actorContexts = new ConcurrentHashMap<>();
|
||||
|
||||
public LocalModeTaskSubmitter(RayRuntimeInternal runtime, TaskExecutor taskExecutor,
|
||||
LocalModeObjectStore objectStore) {
|
||||
this.runtime = runtime;
|
||||
this.taskExecutor = taskExecutor;
|
||||
this.objectStore = objectStore;
|
||||
// The thread pool that executes normal tasks in parallel.
|
||||
normalTaskExecutorService = Executors.newFixedThreadPool(numberThreads);
|
||||
normalTaskExecutorService = Executors.newCachedThreadPool();
|
||||
// The thread pool that executes actor tasks in parallel.
|
||||
actorTaskExecutorServices = new HashMap<>();
|
||||
}
|
||||
|
@ -88,46 +89,6 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the worker of current thread. <br> NOTE: Cannot be used for multi-threading in worker.
|
||||
*/
|
||||
public TaskExecutor getCurrentTaskExecutor() {
|
||||
return currentTaskExecutor.get();
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a worker from the worker pool to run the given task.
|
||||
*/
|
||||
private TaskExecutor getTaskExecutor(TaskSpec task) {
|
||||
TaskExecutor taskExecutor;
|
||||
synchronized (taskExecutorLock) {
|
||||
if (task.getType() == TaskType.ACTOR_TASK) {
|
||||
taskExecutor = actorTaskExecutors.get(getActorId(task));
|
||||
} else if (task.getType() == TaskType.ACTOR_CREATION_TASK) {
|
||||
taskExecutor = new LocalModeTaskExecutor(runtime);
|
||||
actorTaskExecutors.put(getActorId(task), taskExecutor);
|
||||
} else if (idleTaskExecutors.size() > 0) {
|
||||
taskExecutor = idleTaskExecutors.pop();
|
||||
} else {
|
||||
taskExecutor = new LocalModeTaskExecutor(runtime);
|
||||
}
|
||||
}
|
||||
currentTaskExecutor.set(taskExecutor);
|
||||
return taskExecutor;
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the worker to the worker pool.
|
||||
*/
|
||||
private void returnTaskExecutor(TaskExecutor worker, TaskSpec taskSpec) {
|
||||
currentTaskExecutor.remove();
|
||||
synchronized (taskExecutorLock) {
|
||||
if (taskSpec.getType() == TaskType.NORMAL_TASK) {
|
||||
idleTaskExecutors.push(worker);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private Set<ObjectId> getUnreadyObjects(TaskSpec taskSpec) {
|
||||
Set<ObjectId> unreadyObjects = new HashSet<>();
|
||||
// Check whether task arguments are ready.
|
||||
|
@ -257,32 +218,11 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
|
|||
Set<ObjectId> unreadyObjects = getUnreadyObjects(taskSpec);
|
||||
|
||||
final Runnable runnable = () -> {
|
||||
TaskExecutor taskExecutor = getTaskExecutor(taskSpec);
|
||||
try {
|
||||
List<NativeRayObject> args = getFunctionArgs(taskSpec).stream()
|
||||
.map(arg -> arg.id != null ?
|
||||
objectStore.getRaw(Collections.singletonList(arg.id), -1).get(0)
|
||||
: arg.value)
|
||||
.collect(Collectors.toList());
|
||||
((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(taskSpec);
|
||||
List<NativeRayObject> returnObjects = taskExecutor
|
||||
.execute(getJavaFunctionDescriptor(taskSpec).toList(), args);
|
||||
((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(null);
|
||||
List<ObjectId> returnIds = getReturnIds(taskSpec);
|
||||
for (int i = 0; i < returnIds.size(); i++) {
|
||||
NativeRayObject putObject;
|
||||
if (i >= returnObjects.size()) {
|
||||
// If the task is an actor task or an actor creation task,
|
||||
// put the dummy object in object store, so those tasks which depends on it
|
||||
// can be executed.
|
||||
putObject = new NativeRayObject(new byte[] {1}, null);
|
||||
} else {
|
||||
putObject = returnObjects.get(i);
|
||||
}
|
||||
objectStore.putRaw(putObject, returnIds.get(i));
|
||||
}
|
||||
} finally {
|
||||
returnTaskExecutor(taskExecutor, taskSpec);
|
||||
executeTask(taskSpec);
|
||||
} catch (Exception ex) {
|
||||
LOGGER.error("Unexpected exception when executing a task.", ex);
|
||||
System.exit(-1);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -313,6 +253,52 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
|
|||
}
|
||||
}
|
||||
|
||||
private void executeTask(TaskSpec taskSpec) {
|
||||
ActorContext actorContext = null;
|
||||
if (taskSpec.getType() == TaskType.ACTOR_TASK) {
|
||||
actorContext = actorContexts.get(getActorId(taskSpec));
|
||||
Preconditions.checkNotNull(actorContext);
|
||||
}
|
||||
taskExecutor.setActorContext(actorContext);
|
||||
List<NativeRayObject> args = getFunctionArgs(taskSpec).stream()
|
||||
.map(arg -> arg.id != null ?
|
||||
objectStore.getRaw(Collections.singletonList(arg.id), -1).get(0)
|
||||
: arg.value)
|
||||
.collect(Collectors.toList());
|
||||
runtime.setIsContextSet(true);
|
||||
((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(taskSpec);
|
||||
UniqueId workerId = actorContext != null
|
||||
? ((LocalModeTaskExecutor.LocalActorContext) actorContext).getWorkerId()
|
||||
: UniqueId.randomId();
|
||||
((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentWorkerId(workerId);
|
||||
List<NativeRayObject> returnObjects = taskExecutor
|
||||
.execute(getJavaFunctionDescriptor(taskSpec).toList(), args);
|
||||
if (taskSpec.getType() == TaskType.ACTOR_CREATION_TASK) {
|
||||
// Update actor context map ASAP in case objectStore.putRaw triggered the next actor task
|
||||
// on this actor.
|
||||
actorContexts.put(getActorId(taskSpec), taskExecutor.getActorContext());
|
||||
}
|
||||
// Set this flag to true is necessary because at the end of `taskExecutor.execute()`,
|
||||
// this flag will be set to false. And `runtime.getWorkerContext()` requires it to be
|
||||
// true.
|
||||
runtime.setIsContextSet(true);
|
||||
((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(null);
|
||||
runtime.setIsContextSet(false);
|
||||
List<ObjectId> returnIds = getReturnIds(taskSpec);
|
||||
for (int i = 0; i < returnIds.size(); i++) {
|
||||
NativeRayObject putObject;
|
||||
if (i >= returnObjects.size()) {
|
||||
// If the task is an actor task or an actor creation task,
|
||||
// put the dummy object in object store, so those tasks which depends on it
|
||||
// can be executed.
|
||||
putObject = new NativeRayObject(new byte[]{1}, null);
|
||||
} else {
|
||||
putObject = returnObjects.get(i);
|
||||
}
|
||||
objectStore.putRaw(putObject, returnIds.get(i));
|
||||
}
|
||||
}
|
||||
|
||||
private static JavaFunctionDescriptor getJavaFunctionDescriptor(TaskSpec taskSpec) {
|
||||
org.ray.runtime.generated.Common.FunctionDescriptor functionDescriptor =
|
||||
taskSpec.getFunctionDescriptor();
|
||||
|
|
|
@ -8,39 +8,42 @@ import org.ray.api.Checkpointable.Checkpoint;
|
|||
import org.ray.api.Checkpointable.CheckpointContext;
|
||||
import org.ray.api.id.ActorId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
import org.ray.runtime.RayRuntimeInternal;
|
||||
import org.ray.runtime.task.NativeTaskExecutor.NativeActorContext;
|
||||
|
||||
/**
|
||||
* Task executor for cluster mode.
|
||||
*/
|
||||
public class NativeTaskExecutor extends TaskExecutor {
|
||||
public class NativeTaskExecutor extends TaskExecutor<NativeActorContext> {
|
||||
|
||||
// TODO(hchen): Use the C++ config.
|
||||
private static final int NUM_ACTOR_CHECKPOINTS_TO_KEEP = 20;
|
||||
|
||||
/**
|
||||
* The native pointer of core worker.
|
||||
*/
|
||||
private final long nativeCoreWorkerPointer;
|
||||
static class NativeActorContext extends TaskExecutor.ActorContext {
|
||||
|
||||
/**
|
||||
* Number of tasks executed since last actor checkpoint.
|
||||
*/
|
||||
private int numTasksSinceLastCheckpoint = 0;
|
||||
/**
|
||||
* Number of tasks executed since last actor checkpoint.
|
||||
*/
|
||||
private int numTasksSinceLastCheckpoint = 0;
|
||||
|
||||
/**
|
||||
* IDs of this actor's previous checkpoints.
|
||||
*/
|
||||
private List<UniqueId> checkpointIds;
|
||||
/**
|
||||
* IDs of this actor's previous checkpoints.
|
||||
*/
|
||||
private List<UniqueId> checkpointIds;
|
||||
|
||||
/**
|
||||
* Timestamp of the last actor checkpoint.
|
||||
*/
|
||||
private long lastCheckpointTimestamp = 0;
|
||||
/**
|
||||
* Timestamp of the last actor checkpoint.
|
||||
*/
|
||||
private long lastCheckpointTimestamp = 0;
|
||||
}
|
||||
|
||||
public NativeTaskExecutor(long nativeCoreWorkerPointer, AbstractRayRuntime runtime) {
|
||||
public NativeTaskExecutor(RayRuntimeInternal runtime) {
|
||||
super(runtime);
|
||||
this.nativeCoreWorkerPointer = nativeCoreWorkerPointer;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NativeActorContext createActorContext() {
|
||||
return new NativeActorContext();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -48,15 +51,18 @@ public class NativeTaskExecutor extends TaskExecutor {
|
|||
if (!(actor instanceof Checkpointable)) {
|
||||
return;
|
||||
}
|
||||
NativeActorContext actorContext = getActorContext();
|
||||
CheckpointContext checkpointContext = new CheckpointContext(actorId,
|
||||
++numTasksSinceLastCheckpoint, System.currentTimeMillis() - lastCheckpointTimestamp);
|
||||
++actorContext.numTasksSinceLastCheckpoint,
|
||||
System.currentTimeMillis() - actorContext.lastCheckpointTimestamp);
|
||||
Checkpointable checkpointable = (Checkpointable) actor;
|
||||
if (!checkpointable.shouldCheckpoint(checkpointContext)) {
|
||||
return;
|
||||
}
|
||||
numTasksSinceLastCheckpoint = 0;
|
||||
lastCheckpointTimestamp = System.currentTimeMillis();
|
||||
UniqueId checkpointId = new UniqueId(nativePrepareCheckpoint(nativeCoreWorkerPointer));
|
||||
actorContext.numTasksSinceLastCheckpoint = 0;
|
||||
actorContext.lastCheckpointTimestamp = System.currentTimeMillis();
|
||||
UniqueId checkpointId = new UniqueId(nativePrepareCheckpoint());
|
||||
List<UniqueId> checkpointIds = actorContext.checkpointIds;
|
||||
checkpointIds.add(checkpointId);
|
||||
if (checkpointIds.size() > NUM_ACTOR_CHECKPOINTS_TO_KEEP) {
|
||||
((Checkpointable) actor).checkpointExpired(actorId, checkpointIds.get(0));
|
||||
|
@ -70,9 +76,10 @@ public class NativeTaskExecutor extends TaskExecutor {
|
|||
if (!(actor instanceof Checkpointable)) {
|
||||
return;
|
||||
}
|
||||
numTasksSinceLastCheckpoint = 0;
|
||||
lastCheckpointTimestamp = System.currentTimeMillis();
|
||||
checkpointIds = new ArrayList<>();
|
||||
NativeActorContext actorContext = getActorContext();
|
||||
actorContext.numTasksSinceLastCheckpoint = 0;
|
||||
actorContext.lastCheckpointTimestamp = System.currentTimeMillis();
|
||||
actorContext.checkpointIds = new ArrayList<>();
|
||||
List<Checkpoint> availableCheckpoints
|
||||
= runtime.getGcsClient().getCheckpointsForActor(actorId);
|
||||
if (availableCheckpoints.isEmpty()) {
|
||||
|
@ -90,13 +97,11 @@ public class NativeTaskExecutor extends TaskExecutor {
|
|||
Preconditions.checkArgument(checkpointValid,
|
||||
"'loadCheckpoint' must return a checkpoint ID that exists in the "
|
||||
+ "'availableCheckpoints' list, or null.");
|
||||
|
||||
nativeNotifyActorResumedFromCheckpoint(nativeCoreWorkerPointer, checkpointId.getBytes());
|
||||
nativeNotifyActorResumedFromCheckpoint(checkpointId.getBytes());
|
||||
}
|
||||
}
|
||||
|
||||
private static native byte[] nativePrepareCheckpoint(long nativeCoreWorkerPointer);
|
||||
private static native byte[] nativePrepareCheckpoint();
|
||||
|
||||
private static native void nativeNotifyActorResumedFromCheckpoint(long nativeCoreWorkerPointer,
|
||||
byte[] checkpointId);
|
||||
private static native void nativeNotifyActorResumedFromCheckpoint(byte[] checkpointId);
|
||||
}
|
||||
|
|
|
@ -15,30 +15,18 @@ import org.ray.runtime.functionmanager.FunctionDescriptor;
|
|||
*/
|
||||
public class NativeTaskSubmitter implements TaskSubmitter {
|
||||
|
||||
/**
|
||||
* The native pointer of core worker.
|
||||
*/
|
||||
private final long nativeCoreWorkerPointer;
|
||||
|
||||
public NativeTaskSubmitter(long nativeCoreWorkerPointer) {
|
||||
this.nativeCoreWorkerPointer = nativeCoreWorkerPointer;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ObjectId> submitTask(FunctionDescriptor functionDescriptor, List<FunctionArg> args,
|
||||
int numReturns, CallOptions options) {
|
||||
List<byte[]> returnIds = nativeSubmitTask(nativeCoreWorkerPointer, functionDescriptor, args,
|
||||
numReturns, options);
|
||||
List<byte[]> returnIds = nativeSubmitTask(functionDescriptor, args, numReturns, options);
|
||||
return returnIds.stream().map(ObjectId::new).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@Override
|
||||
public BaseActor createActor(FunctionDescriptor functionDescriptor, List<FunctionArg> args,
|
||||
ActorCreationOptions options) {
|
||||
byte[] actorId = nativeCreateActor(nativeCoreWorkerPointer, functionDescriptor, args,
|
||||
options);
|
||||
return NativeRayActor.create(nativeCoreWorkerPointer, actorId,
|
||||
functionDescriptor.getLanguage());
|
||||
byte[] actorId = nativeCreateActor(functionDescriptor, args, options);
|
||||
return NativeRayActor.create(actorId, functionDescriptor.getLanguage());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -46,24 +34,18 @@ public class NativeTaskSubmitter implements TaskSubmitter {
|
|||
BaseActor actor, FunctionDescriptor functionDescriptor,
|
||||
List<FunctionArg> args, int numReturns, CallOptions options) {
|
||||
Preconditions.checkState(actor instanceof NativeRayActor);
|
||||
List<byte[]> returnIds = nativeSubmitActorTask(nativeCoreWorkerPointer,
|
||||
actor.getId().getBytes(), functionDescriptor, args, numReturns,
|
||||
options);
|
||||
List<byte[]> returnIds = nativeSubmitActorTask(actor.getId().getBytes(),
|
||||
functionDescriptor, args, numReturns, options);
|
||||
return returnIds.stream().map(ObjectId::new).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private static native List<byte[]> nativeSubmitTask(
|
||||
long nativeCoreWorkerPointer,
|
||||
private static native List<byte[]> nativeSubmitTask(FunctionDescriptor functionDescriptor,
|
||||
List<FunctionArg> args, int numReturns, CallOptions callOptions);
|
||||
|
||||
private static native byte[] nativeCreateActor(FunctionDescriptor functionDescriptor,
|
||||
List<FunctionArg> args, ActorCreationOptions actorCreationOptions);
|
||||
|
||||
private static native List<byte[]> nativeSubmitActorTask(byte[] actorId,
|
||||
FunctionDescriptor functionDescriptor, List<FunctionArg> args, int numReturns,
|
||||
CallOptions callOptions);
|
||||
|
||||
private static native byte[] nativeCreateActor(
|
||||
long nativeCoreWorkerPointer,
|
||||
FunctionDescriptor functionDescriptor, List<FunctionArg> args,
|
||||
ActorCreationOptions actorCreationOptions);
|
||||
|
||||
private static native List<byte[]> nativeSubmitActorTask(
|
||||
long nativeCoreWorkerPointer,
|
||||
byte[] actorId, FunctionDescriptor functionDescriptor, List<FunctionArg> args,
|
||||
int numReturns, CallOptions callOptions);
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package org.ray.runtime.task;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.lang.reflect.InvocationTargetException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
@ -9,58 +10,75 @@ import org.ray.api.id.ActorId;
|
|||
import org.ray.api.id.JobId;
|
||||
import org.ray.api.id.TaskId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
import org.ray.runtime.config.RayConfig;
|
||||
import org.ray.runtime.config.RunMode;
|
||||
import org.ray.runtime.RayRuntimeInternal;
|
||||
import org.ray.runtime.functionmanager.JavaFunctionDescriptor;
|
||||
import org.ray.runtime.functionmanager.RayFunction;
|
||||
import org.ray.runtime.generated.Common.TaskType;
|
||||
import org.ray.runtime.object.NativeRayObject;
|
||||
import org.ray.runtime.object.ObjectSerializer;
|
||||
import org.ray.runtime.task.TaskExecutor.ActorContext;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/**
|
||||
* The task executor, which executes tasks assigned by raylet continuously.
|
||||
*/
|
||||
public abstract class TaskExecutor {
|
||||
public abstract class TaskExecutor<T extends ActorContext> {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(TaskExecutor.class);
|
||||
|
||||
// A helper map to help we get the corresponding executor for the given worker in JNI.
|
||||
private static ConcurrentHashMap<UniqueId, TaskExecutor> taskExecutors
|
||||
= new ConcurrentHashMap<>();
|
||||
protected final RayRuntimeInternal runtime;
|
||||
|
||||
protected final AbstractRayRuntime runtime;
|
||||
private final ConcurrentHashMap<UniqueId, T> actorContextMap = new ConcurrentHashMap<>();
|
||||
|
||||
/**
|
||||
* The current actor object, if this worker is an actor, otherwise null.
|
||||
*/
|
||||
protected Object currentActor = null;
|
||||
static class ActorContext {
|
||||
|
||||
/**
|
||||
* The exception that failed the actor creation task, if any.
|
||||
*/
|
||||
private Exception actorCreationException = null;
|
||||
/**
|
||||
* The current actor object, if this worker is an actor, otherwise null.
|
||||
*/
|
||||
Object currentActor = null;
|
||||
|
||||
protected TaskExecutor(AbstractRayRuntime runtime) {
|
||||
this.runtime = runtime;
|
||||
if (RayConfig.getInstance().runMode == RunMode.CLUSTER) {
|
||||
taskExecutors.put(runtime.getWorkerContext().getCurrentWorkerId(), this);
|
||||
}
|
||||
/**
|
||||
* The exception that failed the actor creation task, if any.
|
||||
*/
|
||||
Throwable actorCreationException = null;
|
||||
}
|
||||
|
||||
public static TaskExecutor get(byte[] workerId) {
|
||||
return taskExecutors.get(new UniqueId(workerId));
|
||||
TaskExecutor(RayRuntimeInternal runtime) {
|
||||
this.runtime = runtime;
|
||||
}
|
||||
|
||||
protected abstract T createActorContext();
|
||||
|
||||
T getActorContext() {
|
||||
return actorContextMap.get(runtime.getWorkerContext().getCurrentWorkerId());
|
||||
}
|
||||
|
||||
void setActorContext(T actorContext) {
|
||||
if (actorContext == null) {
|
||||
// ConcurrentHashMap doesn't allow null values. So just return here.
|
||||
return;
|
||||
}
|
||||
this.actorContextMap.put(runtime.getWorkerContext().getCurrentWorkerId(), actorContext);
|
||||
}
|
||||
|
||||
protected List<NativeRayObject> execute(List<String> rayFunctionInfo,
|
||||
List<NativeRayObject> argsBytes) {
|
||||
runtime.setIsContextSet(true);
|
||||
JobId jobId = runtime.getWorkerContext().getCurrentJobId();
|
||||
TaskType taskType = runtime.getWorkerContext().getCurrentTaskType();
|
||||
TaskId taskId = runtime.getWorkerContext().getCurrentTaskId();
|
||||
LOGGER.debug("Executing task {}", taskId);
|
||||
|
||||
T actorContext = null;
|
||||
if (taskType == TaskType.ACTOR_CREATION_TASK) {
|
||||
actorContext = createActorContext();
|
||||
setActorContext(actorContext);
|
||||
} else if (taskType == TaskType.ACTOR_TASK) {
|
||||
actorContext = getActorContext();
|
||||
Preconditions.checkNotNull(actorContext);
|
||||
}
|
||||
|
||||
List<NativeRayObject> returnObjects = new ArrayList<>();
|
||||
ClassLoader oldLoader = Thread.currentThread().getContextClassLoader();
|
||||
JavaFunctionDescriptor functionDescriptor = parseFunctionDescriptor(rayFunctionInfo);
|
||||
|
@ -74,19 +92,26 @@ public abstract class TaskExecutor {
|
|||
// Get local actor object and arguments.
|
||||
Object actor = null;
|
||||
if (taskType == TaskType.ACTOR_TASK) {
|
||||
if (actorCreationException != null) {
|
||||
throw actorCreationException;
|
||||
if (actorContext.actorCreationException != null) {
|
||||
throw actorContext.actorCreationException;
|
||||
}
|
||||
actor = currentActor;
|
||||
|
||||
actor = actorContext.currentActor;
|
||||
}
|
||||
Object[] args = ArgumentsBuilder.unwrap(argsBytes, rayFunction.classLoader);
|
||||
// Execute the task.
|
||||
Object result;
|
||||
if (!rayFunction.isConstructor()) {
|
||||
result = rayFunction.getMethod().invoke(actor, args);
|
||||
} else {
|
||||
result = rayFunction.getConstructor().newInstance(args);
|
||||
try {
|
||||
if (!rayFunction.isConstructor()) {
|
||||
result = rayFunction.getMethod().invoke(actor, args);
|
||||
} else {
|
||||
result = rayFunction.getConstructor().newInstance(args);
|
||||
}
|
||||
} catch (InvocationTargetException e) {
|
||||
if (e.getCause() != null) {
|
||||
throw e.getCause();
|
||||
} else {
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
// Set result
|
||||
if (taskType != TaskType.ACTOR_CREATION_TASK) {
|
||||
|
@ -100,10 +125,10 @@ public abstract class TaskExecutor {
|
|||
} else {
|
||||
// TODO (kfstorm): handle checkpoint in core worker.
|
||||
maybeLoadCheckpoint(result, runtime.getWorkerContext().getCurrentActorId());
|
||||
currentActor = result;
|
||||
actorContext.currentActor = result;
|
||||
}
|
||||
LOGGER.debug("Finished executing task {}", taskId);
|
||||
} catch (Exception e) {
|
||||
} catch (Throwable e) {
|
||||
LOGGER.error("Error executing task " + taskId, e);
|
||||
if (taskType != TaskType.ACTOR_CREATION_TASK) {
|
||||
boolean hasReturn = rayFunction != null && rayFunction.hasReturn();
|
||||
|
@ -113,11 +138,12 @@ public abstract class TaskExecutor {
|
|||
.serialize(new RayTaskException("Error executing task " + taskId, e)));
|
||||
}
|
||||
} else {
|
||||
actorCreationException = e;
|
||||
actorContext.actorCreationException = e;
|
||||
}
|
||||
} finally {
|
||||
Thread.currentThread().setContextClassLoader(oldLoader);
|
||||
runtime.getWorkerContext().setCurrentClassLoader(null);
|
||||
runtime.setIsContextSet(false);
|
||||
}
|
||||
return returnObjects;
|
||||
}
|
||||
|
|
|
@ -98,12 +98,6 @@ ray {
|
|||
}
|
||||
}
|
||||
|
||||
// ----------------------------
|
||||
// configurations under SINGLE_PROCESS mode
|
||||
// ----------------------------
|
||||
dev-runtime {
|
||||
// Number of threads that you process tasks
|
||||
execution-parallelism: 10
|
||||
}
|
||||
|
||||
// Whether we enable job manager to submit and manage job.
|
||||
enable-job-manager: false
|
||||
}
|
||||
|
|
|
@ -42,6 +42,10 @@ public class TestProgressListener implements IInvokedMethodListener, ITestListen
|
|||
@Override
|
||||
public void onTestFailure(ITestResult result) {
|
||||
printInfo("TEST FAILURE", getFullTestName(result));
|
||||
Throwable throwable = result.getThrowable();
|
||||
if (throwable != null) {
|
||||
throwable.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -1,11 +1,9 @@
|
|||
package org.ray.api;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.io.Serializable;
|
||||
import java.util.function.Supplier;
|
||||
import org.ray.api.runtime.RayRuntime;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
import org.ray.runtime.RayMultiWorkerNativeRuntime;
|
||||
import org.ray.runtime.RayRuntimeInternal;
|
||||
import org.ray.runtime.RayRuntimeProxy;
|
||||
import org.ray.runtime.config.RunMode;
|
||||
import org.testng.Assert;
|
||||
import org.testng.SkipException;
|
||||
|
@ -79,12 +77,13 @@ public class TestUtils {
|
|||
Assert.assertEquals(obj.get(), "hi");
|
||||
}
|
||||
|
||||
public static AbstractRayRuntime getRuntime() {
|
||||
RayRuntime runtime = Ray.internal();
|
||||
if (runtime instanceof RayMultiWorkerNativeRuntime) {
|
||||
runtime = ((RayMultiWorkerNativeRuntime) runtime).getCurrentRuntime();
|
||||
}
|
||||
Preconditions.checkState(runtime instanceof AbstractRayRuntime);
|
||||
return (AbstractRayRuntime) runtime;
|
||||
public static RayRuntimeInternal getRuntime() {
|
||||
return (RayRuntimeInternal) Ray.internal();
|
||||
}
|
||||
|
||||
public static RayRuntimeInternal getUnderlyingRuntime() {
|
||||
RayRuntimeProxy proxy = (RayRuntimeProxy) (java.lang.reflect.Proxy
|
||||
.getInvocationHandler(Ray.internal()));
|
||||
return proxy.getRuntimeObject();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -134,7 +134,7 @@ public class ClassLoaderTest extends BaseTest {
|
|||
FunctionDescriptor.class, Object[].class, ActorCreationOptions.class);
|
||||
createActorMethod.setAccessible(true);
|
||||
return (RayActor<?>) createActorMethod
|
||||
.invoke(TestUtils.getRuntime(), functionDescriptor, new Object[0], null);
|
||||
.invoke(TestUtils.getUnderlyingRuntime(), functionDescriptor, new Object[0], null);
|
||||
}
|
||||
|
||||
private <T> RayObject<T> callActorFunction(RayActor<?> rayActor,
|
||||
|
@ -143,6 +143,6 @@ public class ClassLoaderTest extends BaseTest {
|
|||
BaseActor.class, FunctionDescriptor.class, Object[].class, int.class);
|
||||
callActorFunctionMethod.setAccessible(true);
|
||||
return (RayObject<T>) callActorFunctionMethod
|
||||
.invoke(TestUtils.getRuntime(), rayActor, functionDescriptor, args, numReturns);
|
||||
.invoke(TestUtils.getUnderlyingRuntime(), rayActor, functionDescriptor, args, numReturns);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -29,7 +29,8 @@ public class ClientExceptionTest extends BaseTest {
|
|||
try {
|
||||
TimeUnit.SECONDS.sleep(1);
|
||||
// kill raylet
|
||||
RunManager runManager = ((RayNativeRuntime) TestUtils.getRuntime()).getRunManager();
|
||||
RunManager runManager =
|
||||
((RayNativeRuntime) TestUtils.getUnderlyingRuntime()).getRunManager();
|
||||
for (Process process : runManager.getProcesses("raylet")) {
|
||||
runManager.terminateProcess("raylet", process);
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@ import org.ray.api.RayActor;
|
|||
import org.ray.api.RayObject;
|
||||
import org.ray.api.TestUtils;
|
||||
import org.ray.api.WaitResult;
|
||||
import org.ray.api.exception.RayException;
|
||||
import org.ray.api.id.ActorId;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
@ -135,6 +136,105 @@ public class MultiThreadingTest extends BaseTest {
|
|||
Assert.assertEquals(actorId, actorIdTester.getId());
|
||||
}
|
||||
|
||||
static boolean testMissingWrapRunnable() throws InterruptedException {
|
||||
final RayObject<Integer> fooObject = Ray.put(1);
|
||||
final RayActor<Echo> fooActor = Ray.createActor(Echo::new);
|
||||
final Runnable[] runnables = new Runnable[]{
|
||||
() -> Ray.put(1),
|
||||
() -> Ray.get(fooObject.getId()),
|
||||
fooObject::get,
|
||||
() -> Ray.wait(ImmutableList.of(fooObject)),
|
||||
Ray::getRuntimeContext,
|
||||
() -> Ray.call(MultiThreadingTest::echo, 1),
|
||||
() -> Ray.createActor(Echo::new),
|
||||
() -> fooActor.call(Echo::echo, 1),
|
||||
};
|
||||
|
||||
// It's OK to run them in main thread.
|
||||
for (Runnable runnable : runnables) {
|
||||
runnable.run();
|
||||
}
|
||||
|
||||
Exception[] exception = new Exception[1];
|
||||
|
||||
Thread thread = new Thread(Ray.wrapRunnable(() -> {
|
||||
try {
|
||||
// It would be OK to run them in another thread if wrapped the runnable.
|
||||
for (Runnable runnable : runnables) {
|
||||
runnable.run();
|
||||
}
|
||||
} catch (Exception ex) {
|
||||
exception[0] = ex;
|
||||
}
|
||||
}));
|
||||
thread.start();
|
||||
thread.join();
|
||||
if (exception[0] != null) {
|
||||
throw new RuntimeException("Exception occurred in thread.", exception[0]);
|
||||
}
|
||||
|
||||
thread = new Thread(() -> {
|
||||
try {
|
||||
// It wouldn't be OK to run them in another thread if not wrapped the runnable.
|
||||
for (Runnable runnable : runnables) {
|
||||
Assert.expectThrows(RayException.class, runnable::run);
|
||||
}
|
||||
} catch (Exception ex) {
|
||||
exception[0] = ex;
|
||||
}
|
||||
});
|
||||
thread.start();
|
||||
thread.join();
|
||||
if (exception[0] != null) {
|
||||
throw new RuntimeException("Exception occurred in thread.", exception[0]);
|
||||
}
|
||||
|
||||
Runnable[] wrappedRunnables = new Runnable[runnables.length];
|
||||
for (int i = 0; i < runnables.length; i++) {
|
||||
wrappedRunnables[i] = Ray.wrapRunnable(runnables[i]);
|
||||
}
|
||||
// It would be OK to run the wrapped runnables in the current thread.
|
||||
for (Runnable runnable : wrappedRunnables) {
|
||||
runnable.run();
|
||||
}
|
||||
|
||||
// It would be OK to invoke Ray APIs after executing a wrapped runnable in the current thread.
|
||||
wrappedRunnables[0].run();
|
||||
runnables[0].run();
|
||||
|
||||
// Return true here to make the Ray.call returns an RayObject.
|
||||
return true;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMissingWrapRunnableInDriver() throws InterruptedException {
|
||||
testMissingWrapRunnable();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMissingWrapRunnableInWorker() {
|
||||
Ray.call(MultiThreadingTest::testMissingWrapRunnable).get();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGetAndSetAsyncContext() throws InterruptedException {
|
||||
Object asyncContext = Ray.getAsyncContext();
|
||||
Exception[] exception = new Exception[1];
|
||||
Thread thread = new Thread(() -> {
|
||||
try {
|
||||
Ray.setAsyncContext(asyncContext);
|
||||
Ray.put(1);
|
||||
} catch (Exception ex) {
|
||||
exception[0] = ex;
|
||||
}
|
||||
});
|
||||
thread.start();
|
||||
thread.join();
|
||||
if (exception[0] != null) {
|
||||
throw new RuntimeException("Exception occurred in thread.", exception[0]);
|
||||
}
|
||||
}
|
||||
|
||||
private static void runTestCaseInMultipleThreads(Runnable testCase, int numRepeats) {
|
||||
ExecutorService service = Executors.newFixedThreadPool(NUM_THREADS);
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@ from ray.includes.common cimport (
|
|||
CBuffer,
|
||||
CRayObject
|
||||
)
|
||||
from ray.includes.libcoreworker cimport CCoreWorker
|
||||
from ray.includes.libcoreworker cimport CFiberEvent
|
||||
from ray.includes.unique_ids cimport (
|
||||
CObjectID,
|
||||
CActorID
|
||||
|
@ -72,7 +72,7 @@ cdef class ActorID(BaseID):
|
|||
|
||||
cdef class CoreWorker:
|
||||
cdef:
|
||||
unique_ptr[CCoreWorker] core_worker
|
||||
c_bool is_driver
|
||||
object async_thread
|
||||
object async_event_loop
|
||||
object plasma_event_handler
|
||||
|
@ -85,6 +85,7 @@ cdef class CoreWorker:
|
|||
cdef store_task_outputs(
|
||||
self, worker, outputs, const c_vector[CObjectID] return_ids,
|
||||
c_vector[shared_ptr[CRayObject]] *returns)
|
||||
cdef yield_current_fiber(self, CFiberEvent &fiber_event)
|
||||
|
||||
cdef class FunctionDescriptor:
|
||||
cdef:
|
||||
|
|
|
@ -69,7 +69,8 @@ from ray.includes.unique_ids cimport (
|
|||
)
|
||||
from ray.includes.libcoreworker cimport (
|
||||
CActorCreationOptions,
|
||||
CCoreWorker,
|
||||
CCoreWorkerOptions,
|
||||
CCoreWorkerProcess,
|
||||
CTaskOptions,
|
||||
ResourceMappingType,
|
||||
CFiberEvent,
|
||||
|
@ -312,7 +313,7 @@ cdef execute_task(
|
|||
dict execution_infos = manager.execution_infos
|
||||
CoreWorker core_worker = worker.core_worker
|
||||
JobID job_id = core_worker.get_current_job_id()
|
||||
CTaskID task_id = core_worker.core_worker.get().GetCurrentTaskId()
|
||||
TaskID task_id = core_worker.get_current_task_id()
|
||||
CFiberEvent task_done_event
|
||||
|
||||
# Automatically restrict the GPUs available to this task.
|
||||
|
@ -339,7 +340,7 @@ cdef execute_task(
|
|||
|
||||
function_name = execution_info.function_name
|
||||
extra_data = (b'{"name": ' + function_name.encode("ascii") +
|
||||
b' "task_id": ' + task_id.Hex() + b'}')
|
||||
b' "task_id": ' + task_id.hex().encode("ascii") + b'}')
|
||||
|
||||
if <int>task_type == <int>TASK_TYPE_NORMAL_TASK:
|
||||
title = "ray::{}()".format(function_name)
|
||||
|
@ -396,9 +397,7 @@ cdef execute_task(
|
|||
monitor_state.unregister_coroutine(coroutine)
|
||||
|
||||
future.add_done_callback(callback)
|
||||
with nogil:
|
||||
(core_worker.core_worker.get()
|
||||
.YieldCurrentFiber(task_done_event))
|
||||
core_worker.yield_current_fiber(task_done_event)
|
||||
|
||||
return future.result()
|
||||
|
||||
|
@ -499,8 +498,7 @@ cdef CRayStatus task_execution_handler(
|
|||
const c_vector[shared_ptr[CRayObject]] &c_args,
|
||||
const c_vector[CObjectID] &c_arg_reference_ids,
|
||||
const c_vector[CObjectID] &c_return_ids,
|
||||
c_vector[shared_ptr[CRayObject]] *returns,
|
||||
const CWorkerID &c_worker_id) nogil:
|
||||
c_vector[shared_ptr[CRayObject]] *returns) nogil:
|
||||
|
||||
with gil:
|
||||
try:
|
||||
|
@ -645,43 +643,76 @@ cdef class CoreWorker:
|
|||
|
||||
def __cinit__(self, is_driver, store_socket, raylet_socket,
|
||||
JobID job_id, GcsClientOptions gcs_options, log_dir,
|
||||
node_ip_address, node_manager_port, local_mode):
|
||||
use_driver = is_driver or local_mode
|
||||
self.core_worker.reset(new CCoreWorker(
|
||||
WORKER_TYPE_DRIVER if use_driver else WORKER_TYPE_WORKER,
|
||||
LANGUAGE_PYTHON, store_socket.encode("ascii"),
|
||||
raylet_socket.encode("ascii"), job_id.native(),
|
||||
gcs_options.native()[0], log_dir.encode("utf-8"),
|
||||
node_ip_address.encode("utf-8"), node_manager_port,
|
||||
task_execution_handler, check_signals, gc_collect,
|
||||
get_py_stack, True, local_mode))
|
||||
node_ip_address, node_manager_port, local_mode,
|
||||
driver_name, stdout_file, stderr_file):
|
||||
self.is_driver = is_driver
|
||||
self.is_local_mode = local_mode
|
||||
|
||||
cdef CCoreWorkerOptions options = CCoreWorkerOptions()
|
||||
options.worker_type = (
|
||||
WORKER_TYPE_DRIVER if is_driver else WORKER_TYPE_WORKER)
|
||||
options.language = LANGUAGE_PYTHON
|
||||
options.store_socket = store_socket.encode("ascii")
|
||||
options.raylet_socket = raylet_socket.encode("ascii")
|
||||
options.job_id = job_id.native()
|
||||
options.gcs_options = gcs_options.native()[0]
|
||||
options.log_dir = log_dir.encode("utf-8")
|
||||
options.install_failure_signal_handler = True
|
||||
options.node_ip_address = node_ip_address.encode("utf-8")
|
||||
options.node_manager_port = node_manager_port
|
||||
options.driver_name = driver_name
|
||||
options.stdout_file = stdout_file
|
||||
options.stderr_file = stderr_file
|
||||
options.task_execution_callback = task_execution_handler
|
||||
options.check_signals = check_signals
|
||||
options.gc_collect = gc_collect
|
||||
options.get_lang_stack = get_py_stack
|
||||
options.ref_counting_enabled = True
|
||||
options.is_local_mode = local_mode
|
||||
options.num_workers = 1
|
||||
|
||||
CCoreWorkerProcess.Initialize(options)
|
||||
|
||||
def __dealloc__(self):
|
||||
with nogil:
|
||||
# If it's a worker, the core worker process should have been
|
||||
# shutdown. So we can't call
|
||||
# `CCoreWorkerProcess.GetCoreWorker().GetWorkerType()` here.
|
||||
# Instead, we use the cached `is_driver` flag to test if it's a
|
||||
# driver.
|
||||
if self.is_driver:
|
||||
CCoreWorkerProcess.Shutdown()
|
||||
|
||||
def run_task_loop(self):
|
||||
with nogil:
|
||||
self.core_worker.get().StartExecutingTasks()
|
||||
CCoreWorkerProcess.RunTaskExecutionLoop()
|
||||
|
||||
def get_current_task_id(self):
|
||||
return TaskID(self.core_worker.get().GetCurrentTaskId().Binary())
|
||||
return TaskID(
|
||||
CCoreWorkerProcess.GetCoreWorker().GetCurrentTaskId().Binary())
|
||||
|
||||
def get_current_job_id(self):
|
||||
return JobID(self.core_worker.get().GetCurrentJobId().Binary())
|
||||
return JobID(
|
||||
CCoreWorkerProcess.GetCoreWorker().GetCurrentJobId().Binary())
|
||||
|
||||
def get_actor_id(self):
|
||||
return ActorID(self.core_worker.get().GetActorId().Binary())
|
||||
return ActorID(
|
||||
CCoreWorkerProcess.GetCoreWorker().GetActorId().Binary())
|
||||
|
||||
def set_webui_display(self, key, message):
|
||||
self.core_worker.get().SetWebuiDisplay(key, message)
|
||||
CCoreWorkerProcess.GetCoreWorker().SetWebuiDisplay(key, message)
|
||||
|
||||
def set_actor_title(self, title):
|
||||
self.core_worker.get().SetActorTitle(title)
|
||||
CCoreWorkerProcess.GetCoreWorker().SetActorTitle(title)
|
||||
|
||||
def set_plasma_added_callback(self, plasma_event_handler):
|
||||
self.plasma_event_handler = plasma_event_handler
|
||||
self.core_worker.get().SetPlasmaAddedCallback(async_plasma_callback)
|
||||
CCoreWorkerProcess.GetCoreWorker().SetPlasmaAddedCallback(
|
||||
async_plasma_callback)
|
||||
|
||||
def subscribe_to_plasma_object(self, ObjectID object_id):
|
||||
self.core_worker.get().SubscribeToPlasmaAdd(object_id.native())
|
||||
CCoreWorkerProcess.GetCoreWorker().SubscribeToPlasmaAdd(
|
||||
object_id.native())
|
||||
|
||||
def get_plasma_event_handler(self):
|
||||
return self.plasma_event_handler
|
||||
|
@ -694,7 +725,7 @@ cdef class CoreWorker:
|
|||
c_vector[CObjectID] c_object_ids = ObjectIDsToVector(object_ids)
|
||||
|
||||
with nogil:
|
||||
check_status(self.core_worker.get().Get(
|
||||
check_status(CCoreWorkerProcess.GetCoreWorker().Get(
|
||||
c_object_ids, timeout_ms, &results))
|
||||
|
||||
return RayObjectsToDataMetadataPairs(results)
|
||||
|
@ -705,7 +736,7 @@ cdef class CoreWorker:
|
|||
CObjectID c_object_id = object_id.native()
|
||||
|
||||
with nogil:
|
||||
check_status(self.core_worker.get().Contains(
|
||||
check_status(CCoreWorkerProcess.GetCoreWorker().Contains(
|
||||
c_object_id, &has_object))
|
||||
|
||||
return has_object
|
||||
|
@ -716,13 +747,13 @@ cdef class CoreWorker:
|
|||
CObjectID *c_object_id, shared_ptr[CBuffer] *data):
|
||||
if object_id is None:
|
||||
with nogil:
|
||||
check_status(self.core_worker.get().Create(
|
||||
check_status(CCoreWorkerProcess.GetCoreWorker().Create(
|
||||
metadata, data_size, contained_ids,
|
||||
c_object_id, data))
|
||||
else:
|
||||
c_object_id[0] = object_id.native()
|
||||
with nogil:
|
||||
check_status(self.core_worker.get().Create(
|
||||
check_status(CCoreWorkerProcess.GetCoreWorker().Create(
|
||||
metadata, data_size,
|
||||
c_object_id[0], data))
|
||||
|
||||
|
@ -752,7 +783,7 @@ cdef class CoreWorker:
|
|||
write_serialized_object(serialized_object, data)
|
||||
if self.is_local_mode:
|
||||
c_object_id_vector.push_back(c_object_id)
|
||||
check_status(self.core_worker.get().Put(
|
||||
check_status(CCoreWorkerProcess.GetCoreWorker().Put(
|
||||
CRayObject(data, metadata, c_object_id_vector),
|
||||
c_object_id_vector, c_object_id))
|
||||
else:
|
||||
|
@ -760,7 +791,7 @@ cdef class CoreWorker:
|
|||
# Using custom object IDs is not supported because we can't
|
||||
# track their lifecycle, so we don't pin the object in this
|
||||
# case.
|
||||
check_status(self.core_worker.get().Seal(
|
||||
check_status(CCoreWorkerProcess.GetCoreWorker().Seal(
|
||||
c_object_id,
|
||||
pin_object and object_id is None))
|
||||
|
||||
|
@ -775,7 +806,7 @@ cdef class CoreWorker:
|
|||
|
||||
wait_ids = ObjectIDsToVector(object_ids)
|
||||
with nogil:
|
||||
check_status(self.core_worker.get().Wait(
|
||||
check_status(CCoreWorkerProcess.GetCoreWorker().Wait(
|
||||
wait_ids, num_returns, timeout_ms, &results))
|
||||
|
||||
assert len(results) == len(object_ids)
|
||||
|
@ -795,19 +826,19 @@ cdef class CoreWorker:
|
|||
c_vector[CObjectID] free_ids = ObjectIDsToVector(object_ids)
|
||||
|
||||
with nogil:
|
||||
check_status(self.core_worker.get().Delete(
|
||||
check_status(CCoreWorkerProcess.GetCoreWorker().Delete(
|
||||
free_ids, local_only, delete_creating_tasks))
|
||||
|
||||
def global_gc(self):
|
||||
with nogil:
|
||||
self.core_worker.get().TriggerGlobalGC()
|
||||
CCoreWorkerProcess.GetCoreWorker().TriggerGlobalGC()
|
||||
|
||||
def set_object_store_client_options(self, client_name,
|
||||
int64_t limit_bytes):
|
||||
try:
|
||||
logger.debug("Setting plasma memory limit to {} for {}".format(
|
||||
limit_bytes, client_name))
|
||||
check_status(self.core_worker.get().SetClientOptions(
|
||||
check_status(CCoreWorkerProcess.GetCoreWorker().SetClientOptions(
|
||||
client_name.encode("ascii"), limit_bytes))
|
||||
except RayError as e:
|
||||
self.dump_object_store_memory_usage()
|
||||
|
@ -820,7 +851,7 @@ cdef class CoreWorker:
|
|||
limit_bytes, client_name, e))
|
||||
|
||||
def dump_object_store_memory_usage(self):
|
||||
message = self.core_worker.get().MemoryUsageString()
|
||||
message = CCoreWorkerProcess.GetCoreWorker().MemoryUsageString()
|
||||
logger.warning("Local object store memory usage:\n{}\n".format(
|
||||
message.decode("utf-8")))
|
||||
|
||||
|
@ -847,7 +878,7 @@ cdef class CoreWorker:
|
|||
prepare_args(self, args, &args_vector)
|
||||
|
||||
with nogil:
|
||||
check_status(self.core_worker.get().SubmitTask(
|
||||
check_status(CCoreWorkerProcess.GetCoreWorker().SubmitTask(
|
||||
ray_function, args_vector, task_options, &return_ids,
|
||||
max_retries))
|
||||
|
||||
|
@ -880,7 +911,7 @@ cdef class CoreWorker:
|
|||
prepare_args(self, args, &args_vector)
|
||||
|
||||
with nogil:
|
||||
check_status(self.core_worker.get().CreateActor(
|
||||
check_status(CCoreWorkerProcess.GetCoreWorker().CreateActor(
|
||||
ray_function, args_vector,
|
||||
CActorCreationOptions(
|
||||
max_reconstructions, max_concurrency,
|
||||
|
@ -916,10 +947,11 @@ cdef class CoreWorker:
|
|||
prepare_args(self, args, &args_vector)
|
||||
|
||||
with nogil:
|
||||
check_status(self.core_worker.get().SubmitActorTask(
|
||||
c_actor_id,
|
||||
ray_function,
|
||||
args_vector, task_options, &return_ids))
|
||||
check_status(
|
||||
CCoreWorkerProcess.GetCoreWorker().SubmitActorTask(
|
||||
c_actor_id,
|
||||
ray_function,
|
||||
args_vector, task_options, &return_ids))
|
||||
|
||||
return VectorToObjectIDs(return_ids)
|
||||
|
||||
|
@ -928,13 +960,13 @@ cdef class CoreWorker:
|
|||
CActorID c_actor_id = actor_id.native()
|
||||
|
||||
with nogil:
|
||||
check_status(self.core_worker.get().KillActor(
|
||||
check_status(CCoreWorkerProcess.GetCoreWorker().KillActor(
|
||||
c_actor_id, True, no_reconstruction))
|
||||
|
||||
def resource_ids(self):
|
||||
cdef:
|
||||
ResourceMappingType resource_mapping = (
|
||||
self.core_worker.get().GetResourceIDs())
|
||||
CCoreWorkerProcess.GetCoreWorker().GetResourceIDs())
|
||||
unordered_map[
|
||||
c_string, c_vector[pair[int64_t, double]]
|
||||
].iterator iterator = resource_mapping.begin()
|
||||
|
@ -955,13 +987,14 @@ cdef class CoreWorker:
|
|||
|
||||
def profile_event(self, c_string event_type, object extra_data=None):
|
||||
return ProfileEvent.make(
|
||||
self.core_worker.get().CreateProfileEvent(event_type),
|
||||
CCoreWorkerProcess.GetCoreWorker().CreateProfileEvent(event_type),
|
||||
extra_data)
|
||||
|
||||
def remove_actor_handle_reference(self, ActorID actor_id):
|
||||
cdef:
|
||||
CActorID c_actor_id = actor_id.native()
|
||||
self.core_worker.get().RemoveActorHandleReference(c_actor_id)
|
||||
CCoreWorkerProcess.GetCoreWorker().RemoveActorHandleReference(
|
||||
c_actor_id)
|
||||
|
||||
def deserialize_and_register_actor_handle(self, const c_string &bytes,
|
||||
ObjectID
|
||||
|
@ -974,9 +1007,10 @@ cdef class CoreWorker:
|
|||
worker = ray.worker.global_worker
|
||||
worker.check_connected()
|
||||
manager = worker.function_actor_manager
|
||||
c_actor_id = self.core_worker.get().DeserializeAndRegisterActorHandle(
|
||||
bytes, c_outer_object_id)
|
||||
check_status(self.core_worker.get().GetActorHandle(
|
||||
c_actor_id = (CCoreWorkerProcess.GetCoreWorker()
|
||||
.DeserializeAndRegisterActorHandle(
|
||||
bytes, c_outer_object_id))
|
||||
check_status(CCoreWorkerProcess.GetCoreWorker().GetActorHandle(
|
||||
c_actor_id, &c_actor_handle))
|
||||
actor_id = ActorID(c_actor_id.Binary())
|
||||
job_id = JobID(c_actor_handle.CreationJobID().Binary())
|
||||
|
@ -1017,24 +1051,26 @@ cdef class CoreWorker:
|
|||
cdef:
|
||||
c_string output
|
||||
CObjectID c_actor_handle_id
|
||||
check_status(self.core_worker.get().SerializeActorHandle(
|
||||
check_status(CCoreWorkerProcess.GetCoreWorker().SerializeActorHandle(
|
||||
actor_id.native(), &output, &c_actor_handle_id))
|
||||
return output, ObjectID(c_actor_handle_id.Binary())
|
||||
|
||||
def add_object_id_reference(self, ObjectID object_id):
|
||||
# Note: faster to not release GIL for short-running op.
|
||||
self.core_worker.get().AddLocalReference(object_id.native())
|
||||
CCoreWorkerProcess.GetCoreWorker().AddLocalReference(
|
||||
object_id.native())
|
||||
|
||||
def remove_object_id_reference(self, ObjectID object_id):
|
||||
# Note: faster to not release GIL for short-running op.
|
||||
self.core_worker.get().RemoveLocalReference(object_id.native())
|
||||
CCoreWorkerProcess.GetCoreWorker().RemoveLocalReference(
|
||||
object_id.native())
|
||||
|
||||
def serialize_and_promote_object_id(self, ObjectID object_id):
|
||||
cdef:
|
||||
CObjectID c_object_id = object_id.native()
|
||||
CTaskID c_owner_id = CTaskID.Nil()
|
||||
CAddress c_owner_address = CAddress()
|
||||
self.core_worker.get().PromoteToPlasmaAndGetOwnershipInfo(
|
||||
CCoreWorkerProcess.GetCoreWorker().PromoteToPlasmaAndGetOwnershipInfo(
|
||||
c_object_id, &c_owner_id, &c_owner_address)
|
||||
return (object_id,
|
||||
TaskID(c_owner_id.Binary()),
|
||||
|
@ -1053,11 +1089,12 @@ cdef class CoreWorker:
|
|||
CAddress c_owner_address = CAddress()
|
||||
|
||||
c_owner_address.ParseFromString(serialized_owner_address)
|
||||
self.core_worker.get().RegisterOwnershipInfoAndResolveFuture(
|
||||
(CCoreWorkerProcess.GetCoreWorker()
|
||||
.RegisterOwnershipInfoAndResolveFuture(
|
||||
c_object_id,
|
||||
c_outer_object_id,
|
||||
c_owner_id,
|
||||
c_owner_address)
|
||||
c_owner_address))
|
||||
|
||||
cdef store_task_outputs(
|
||||
self, worker, outputs, const c_vector[CObjectID] return_ids,
|
||||
|
@ -1088,8 +1125,10 @@ cdef class CoreWorker:
|
|||
ObjectIDsToVector(serialized_object.contained_object_ids))
|
||||
|
||||
with nogil:
|
||||
check_status(self.core_worker.get().AllocateReturnObjects(
|
||||
return_ids, data_sizes, metadatas, contained_ids, returns))
|
||||
check_status(CCoreWorkerProcess.GetCoreWorker()
|
||||
.AllocateReturnObjects(
|
||||
return_ids, data_sizes, metadatas, contained_ids,
|
||||
returns))
|
||||
|
||||
for i, serialized_object in enumerate(serialized_objects):
|
||||
# A nullptr is returned if the object already exists.
|
||||
|
@ -1099,7 +1138,7 @@ cdef class CoreWorker:
|
|||
if self.is_local_mode:
|
||||
return_ids_vector.push_back(return_ids[i])
|
||||
check_status(
|
||||
self.core_worker.get().Put(
|
||||
CCoreWorkerProcess.GetCoreWorker().Put(
|
||||
CRayObject(returns[0][i].get().GetData(),
|
||||
returns[0][i].get().GetMetadata(),
|
||||
return_ids_vector),
|
||||
|
@ -1138,7 +1177,7 @@ cdef class CoreWorker:
|
|||
future = asyncio.run_coroutine_threadsafe(coroutine, loop)
|
||||
future.add_done_callback(lambda _: event.Notify())
|
||||
with nogil:
|
||||
(self.core_worker.get()
|
||||
(CCoreWorkerProcess.GetCoreWorker()
|
||||
.YieldCurrentFiber(event))
|
||||
return future.result()
|
||||
|
||||
|
@ -1149,14 +1188,20 @@ cdef class CoreWorker:
|
|||
self.async_thread.join()
|
||||
|
||||
def current_actor_is_asyncio(self):
|
||||
return self.core_worker.get().GetWorkerContext().CurrentActorIsAsync()
|
||||
return (CCoreWorkerProcess.GetCoreWorker().GetWorkerContext()
|
||||
.CurrentActorIsAsync())
|
||||
|
||||
cdef yield_current_fiber(self, CFiberEvent &fiber_event):
|
||||
with nogil:
|
||||
CCoreWorkerProcess.GetCoreWorker().YieldCurrentFiber(fiber_event)
|
||||
|
||||
def get_all_reference_counts(self):
|
||||
cdef:
|
||||
unordered_map[CObjectID, pair[size_t, size_t]] c_ref_counts
|
||||
unordered_map[CObjectID, pair[size_t, size_t]].iterator it
|
||||
|
||||
c_ref_counts = self.core_worker.get().GetAllReferenceCounts()
|
||||
c_ref_counts = (
|
||||
CCoreWorkerProcess.GetCoreWorker().GetAllReferenceCounts())
|
||||
it = c_ref_counts.begin()
|
||||
|
||||
ref_counts = {}
|
||||
|
@ -1170,7 +1215,7 @@ cdef class CoreWorker:
|
|||
return ref_counts
|
||||
|
||||
def in_memory_store_get_async(self, ObjectID object_id, future):
|
||||
self.core_worker.get().GetAsync(
|
||||
CCoreWorkerProcess.GetCoreWorker().GetAsync(
|
||||
object_id.native(),
|
||||
async_set_result_callback,
|
||||
async_retry_with_plasma_callback,
|
||||
|
@ -1178,7 +1223,7 @@ cdef class CoreWorker:
|
|||
|
||||
def push_error(self, JobID job_id, error_type, error_message,
|
||||
double timestamp):
|
||||
check_status(self.core_worker.get().PushError(
|
||||
check_status(CCoreWorkerProcess.GetCoreWorker().PushError(
|
||||
job_id.native(), error_type.encode("ascii"),
|
||||
error_message.encode("ascii"), timestamp))
|
||||
|
||||
|
@ -1190,18 +1235,21 @@ cdef class CoreWorker:
|
|||
# PrepareActorCheckpoint will wait for raylet's reply, release
|
||||
# the GIL so other Python threads can run.
|
||||
with nogil:
|
||||
check_status(self.core_worker.get().PrepareActorCheckpoint(
|
||||
c_actor_id, &checkpoint_id))
|
||||
check_status(
|
||||
CCoreWorkerProcess.GetCoreWorker()
|
||||
.PrepareActorCheckpoint(c_actor_id, &checkpoint_id))
|
||||
return ActorCheckpointID(checkpoint_id.Binary())
|
||||
|
||||
def notify_actor_resumed_from_checkpoint(self, ActorID actor_id,
|
||||
ActorCheckpointID checkpoint_id):
|
||||
check_status(self.core_worker.get().NotifyActorResumedFromCheckpoint(
|
||||
actor_id.native(), checkpoint_id.native()))
|
||||
check_status(
|
||||
CCoreWorkerProcess.GetCoreWorker()
|
||||
.NotifyActorResumedFromCheckpoint(
|
||||
actor_id.native(), checkpoint_id.native()))
|
||||
|
||||
def set_resource(self, basestring resource_name,
|
||||
double capacity, ClientID client_id):
|
||||
self.core_worker.get().SetResource(
|
||||
CCoreWorkerProcess.GetCoreWorker().SetResource(
|
||||
resource_name.encode("ascii"), capacity,
|
||||
CClientID.FromBinary(client_id.binary()))
|
||||
|
||||
|
|
|
@ -17,7 +17,6 @@ from ray.includes.unique_ids cimport (
|
|||
CJobID,
|
||||
CTaskID,
|
||||
CObjectID,
|
||||
CWorkerID,
|
||||
)
|
||||
from ray.includes.common cimport (
|
||||
CAddress,
|
||||
|
@ -80,31 +79,9 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
|
|||
c_string ExtensionData() const
|
||||
|
||||
cdef cppclass CCoreWorker "ray::CoreWorker":
|
||||
CCoreWorker(const CWorkerType worker_type, const CLanguage language,
|
||||
const c_string &store_socket,
|
||||
const c_string &raylet_socket, const CJobID &job_id,
|
||||
const CGcsClientOptions &gcs_options,
|
||||
const c_string &log_dir, const c_string &node_ip_address,
|
||||
int node_manager_port,
|
||||
CRayStatus (
|
||||
CTaskType task_type,
|
||||
const CRayFunction &ray_function,
|
||||
const unordered_map[c_string, double] &resources,
|
||||
const c_vector[shared_ptr[CRayObject]] &args,
|
||||
const c_vector[CObjectID] &arg_reference_ids,
|
||||
const c_vector[CObjectID] &return_ids,
|
||||
c_vector[shared_ptr[CRayObject]] *returns,
|
||||
const CWorkerID &worker_id) nogil,
|
||||
CRayStatus() nogil,
|
||||
void() nogil,
|
||||
void(c_string *stack_out) nogil,
|
||||
c_bool ref_counting_enabled,
|
||||
c_bool local_worker)
|
||||
CWorkerType &GetWorkerType()
|
||||
CLanguage &GetLanguage()
|
||||
|
||||
void StartExecutingTasks()
|
||||
|
||||
CRayStatus SubmitTask(
|
||||
const CRayFunction &function, const c_vector[CTaskArg] &args,
|
||||
const CTaskOptions &options, c_vector[CObjectID] *return_ids,
|
||||
|
@ -206,3 +183,46 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
|
|||
void SetPlasmaAddedCallback(plasma_callback_function callback)
|
||||
|
||||
void SubscribeToPlasmaAdd(const CObjectID &object_id)
|
||||
|
||||
cdef cppclass CCoreWorkerOptions "ray::CoreWorkerOptions":
|
||||
CWorkerType worker_type
|
||||
CLanguage language
|
||||
c_string store_socket
|
||||
c_string raylet_socket
|
||||
CJobID job_id
|
||||
CGcsClientOptions gcs_options
|
||||
c_string log_dir
|
||||
c_bool install_failure_signal_handler
|
||||
c_string node_ip_address
|
||||
int node_manager_port
|
||||
c_string driver_name
|
||||
c_string stdout_file
|
||||
c_string stderr_file
|
||||
(CRayStatus(
|
||||
CTaskType task_type,
|
||||
const CRayFunction &ray_function,
|
||||
const unordered_map[c_string, double] &resources,
|
||||
const c_vector[shared_ptr[CRayObject]] &args,
|
||||
const c_vector[CObjectID] &arg_reference_ids,
|
||||
const c_vector[CObjectID] &return_ids,
|
||||
c_vector[shared_ptr[CRayObject]] *returns) nogil
|
||||
) task_execution_callback
|
||||
(CRayStatus() nogil) check_signals
|
||||
(void() nogil) gc_collect
|
||||
(void(c_string *stack_out) nogil) get_lang_stack
|
||||
c_bool ref_counting_enabled
|
||||
c_bool is_local_mode
|
||||
int num_workers
|
||||
CCoreWorkerOptions()
|
||||
|
||||
cdef cppclass CCoreWorkerProcess "ray::CoreWorkerProcess":
|
||||
@staticmethod
|
||||
void Initialize(const CCoreWorkerOptions &options)
|
||||
# Only call this in CoreWorker.__cinit__,
|
||||
# use CoreWorker.core_worker to access C++ CoreWorker.
|
||||
@staticmethod
|
||||
CCoreWorker &GetCoreWorker()
|
||||
@staticmethod
|
||||
void Shutdown()
|
||||
@staticmethod
|
||||
void RunTaskExecutionLoop()
|
||||
|
|
|
@ -1173,27 +1173,14 @@ def connect(node,
|
|||
ray.state.state._initialize_global_state(
|
||||
node.redis_address, redis_password=node.redis_password)
|
||||
|
||||
# Register the worker with Redis.
|
||||
driver_name = ""
|
||||
log_stdout_file_name = ""
|
||||
log_stderr_file_name = ""
|
||||
if mode == SCRIPT_MODE:
|
||||
# The concept of a driver is the same as the concept of a "job".
|
||||
# Register the driver/job with Redis here.
|
||||
import __main__ as main
|
||||
driver_info = {
|
||||
"node_ip_address": node.node_ip_address,
|
||||
"driver_id": worker.worker_id,
|
||||
"start_time": time.time(),
|
||||
"plasma_store_socket": node.plasma_store_socket_name,
|
||||
"raylet_socket": node.raylet_socket_name,
|
||||
"name": (main.__file__
|
||||
if hasattr(main, "__file__") else "INTERACTIVE MODE")
|
||||
}
|
||||
worker.redis_client.hmset(b"Drivers:" + worker.worker_id, driver_info)
|
||||
driver_name = (main.__file__
|
||||
if hasattr(main, "__file__") else "INTERACTIVE MODE")
|
||||
elif mode == WORKER_MODE:
|
||||
# Register the worker with Redis.
|
||||
worker_dict = {
|
||||
"node_ip_address": node.node_ip_address,
|
||||
"plasma_store_socket": node.plasma_store_socket_name,
|
||||
}
|
||||
# Check the RedirectOutput key in Redis and based on its value redirect
|
||||
# worker output and error to their own files.
|
||||
# This key is set in services.py when Redis is started.
|
||||
|
@ -1224,14 +1211,12 @@ def connect(node,
|
|||
print("Ray worker pid: {}".format(os.getpid()), file=sys.stderr)
|
||||
sys.stdout.flush()
|
||||
sys.stderr.flush()
|
||||
|
||||
worker_dict["stdout_file"] = os.path.abspath(
|
||||
log_stdout_file_name = os.path.abspath(
|
||||
(log_stdout_file
|
||||
if log_stdout_file is not None else sys.stdout).name)
|
||||
worker_dict["stderr_file"] = os.path.abspath(
|
||||
log_stderr_file_name = os.path.abspath(
|
||||
(log_stderr_file
|
||||
if log_stderr_file is not None else sys.stderr).name)
|
||||
worker.redis_client.hmset(b"Workers:" + worker.worker_id, worker_dict)
|
||||
elif not LOCAL_MODE:
|
||||
raise ValueError(
|
||||
"Invalid worker mode. Expected DRIVER, WORKER or LOCAL.")
|
||||
|
@ -1242,9 +1227,19 @@ def connect(node,
|
|||
node.redis_password,
|
||||
)
|
||||
worker.core_worker = ray._raylet.CoreWorker(
|
||||
(mode == SCRIPT_MODE), node.plasma_store_socket_name,
|
||||
node.raylet_socket_name, job_id, gcs_options, node.get_logs_dir_path(),
|
||||
node.node_ip_address, node.node_manager_port, mode == LOCAL_MODE)
|
||||
(mode == SCRIPT_MODE or mode == LOCAL_MODE),
|
||||
node.plasma_store_socket_name,
|
||||
node.raylet_socket_name,
|
||||
job_id,
|
||||
gcs_options,
|
||||
node.get_logs_dir_path(),
|
||||
node.node_ip_address,
|
||||
node.node_manager_port,
|
||||
(mode == LOCAL_MODE),
|
||||
driver_name,
|
||||
log_stdout_file_name,
|
||||
log_stderr_file_name,
|
||||
)
|
||||
|
||||
if driver_object_store_memory is not None:
|
||||
worker.core_worker.set_object_store_client_options(
|
||||
|
|
|
@ -63,10 +63,10 @@ struct WorkerThreadContext {
|
|||
thread_local std::unique_ptr<WorkerThreadContext> WorkerContext::thread_context_ =
|
||||
nullptr;
|
||||
|
||||
WorkerContext::WorkerContext(WorkerType worker_type, const JobID &job_id)
|
||||
WorkerContext::WorkerContext(WorkerType worker_type, const WorkerID &worker_id,
|
||||
const JobID &job_id)
|
||||
: worker_type_(worker_type),
|
||||
worker_id_(worker_type_ == WorkerType::DRIVER ? ComputeDriverIdFromJob(job_id)
|
||||
: WorkerID::FromRandom()),
|
||||
worker_id_(worker_id),
|
||||
current_job_id_(worker_type_ == WorkerType::DRIVER ? job_id : JobID::Nil()),
|
||||
current_actor_id_(ActorID::Nil()),
|
||||
main_thread_id_(boost::this_thread::get_id()) {
|
||||
|
|
|
@ -26,7 +26,7 @@ struct WorkerThreadContext;
|
|||
|
||||
class WorkerContext {
|
||||
public:
|
||||
WorkerContext(WorkerType worker_type, const JobID &job_id);
|
||||
WorkerContext(WorkerType worker_type, const WorkerID &worker_id, const JobID &job_id);
|
||||
|
||||
const WorkerType GetWorkerType() const;
|
||||
|
||||
|
|
|
@ -75,63 +75,214 @@ void GroupObjectIdsByStoreProvider(const std::vector<ObjectID> &object_ids,
|
|||
|
||||
namespace ray {
|
||||
|
||||
CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
|
||||
const std::string &store_socket, const std::string &raylet_socket,
|
||||
const JobID &job_id, const gcs::GcsClientOptions &gcs_options,
|
||||
const std::string &log_dir, const std::string &node_ip_address,
|
||||
int node_manager_port,
|
||||
const TaskExecutionCallback &task_execution_callback,
|
||||
std::function<Status()> check_signals,
|
||||
std::function<void()> gc_collect,
|
||||
std::function<void(std::string *)> get_lang_stack,
|
||||
bool ref_counting_enabled, bool local_mode)
|
||||
: worker_type_(worker_type),
|
||||
language_(language),
|
||||
log_dir_(log_dir),
|
||||
ref_counting_enabled_(ref_counting_enabled),
|
||||
is_local_mode_(local_mode),
|
||||
check_signals_(check_signals),
|
||||
gc_collect_(gc_collect),
|
||||
get_call_site_(RayConfig::instance().record_ref_creation_sites() ? get_lang_stack
|
||||
: nullptr),
|
||||
worker_context_(worker_type, job_id),
|
||||
std::unique_ptr<CoreWorkerProcess> CoreWorkerProcess::instance_;
|
||||
|
||||
thread_local std::weak_ptr<CoreWorker> CoreWorkerProcess::current_core_worker_;
|
||||
|
||||
void CoreWorkerProcess::Initialize(const CoreWorkerOptions &options) {
|
||||
RAY_CHECK(!instance_) << "The process is already initialized for core worker.";
|
||||
instance_ = std::unique_ptr<CoreWorkerProcess>(new CoreWorkerProcess(options));
|
||||
}
|
||||
|
||||
void CoreWorkerProcess::Shutdown() {
|
||||
if (!instance_) {
|
||||
return;
|
||||
}
|
||||
RAY_CHECK(instance_->options_.worker_type == WorkerType::DRIVER)
|
||||
<< "The `Shutdown` interface is for driver only.";
|
||||
RAY_CHECK(instance_->global_worker_);
|
||||
instance_->global_worker_->Disconnect();
|
||||
instance_->global_worker_->Shutdown();
|
||||
instance_->RemoveWorker(instance_->global_worker_);
|
||||
instance_.reset();
|
||||
}
|
||||
|
||||
bool CoreWorkerProcess::IsInitialized() { return instance_ != nullptr; }
|
||||
|
||||
CoreWorkerProcess::CoreWorkerProcess(const CoreWorkerOptions &options)
|
||||
: options_(options),
|
||||
global_worker_id_(
|
||||
options.worker_type == WorkerType::DRIVER
|
||||
? ComputeDriverIdFromJob(options_.job_id)
|
||||
: (options_.num_workers == 1 ? WorkerID::FromRandom() : WorkerID::Nil())) {
|
||||
// Initialize logging if log_dir is passed. Otherwise, it must be initialized
|
||||
// and cleaned up by the caller.
|
||||
if (options_.log_dir != "") {
|
||||
std::stringstream app_name;
|
||||
app_name << LanguageString(options_.language) << "-core-"
|
||||
<< WorkerTypeString(options_.worker_type);
|
||||
if (!global_worker_id_.IsNil()) {
|
||||
app_name << "-" << global_worker_id_;
|
||||
}
|
||||
RayLog::StartRayLog(app_name.str(), RayLogLevel::INFO, options_.log_dir);
|
||||
if (options_.install_failure_signal_handler) {
|
||||
RayLog::InstallFailureSignalHandler();
|
||||
}
|
||||
}
|
||||
|
||||
RAY_CHECK(options_.num_workers > 0);
|
||||
if (options_.worker_type == WorkerType::DRIVER) {
|
||||
// Driver process can only contain one worker.
|
||||
RAY_CHECK(options_.num_workers == 1);
|
||||
}
|
||||
|
||||
RAY_LOG(INFO) << "Constructing CoreWorkerProcess. pid: " << getpid();
|
||||
|
||||
if (options_.num_workers == 1) {
|
||||
// We need to create the worker instance here if:
|
||||
// 1. This is a driver process. In this case, the driver is ready to use right after
|
||||
// the CoreWorkerProcess::Initialize.
|
||||
// 2. This is a Python worker process. In this case, Python will invoke some core
|
||||
// worker APIs before `CoreWorkerProcess::RunTaskExecutionLoop` is called. So we need
|
||||
// to create the worker instance here. One example of invocations is
|
||||
// https://github.com/ray-project/ray/blob/45ce40e5d44801193220d2c546be8de0feeef988/python/ray/worker.py#L1281.
|
||||
if (options_.worker_type == WorkerType::DRIVER ||
|
||||
options_.language == Language::PYTHON) {
|
||||
CreateWorker();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CoreWorkerProcess::~CoreWorkerProcess() {
|
||||
RAY_LOG(INFO) << "Destructing CoreWorkerProcess. pid: " << getpid();
|
||||
{
|
||||
// Check that all `CoreWorker` instances have been removed.
|
||||
absl::ReaderMutexLock lock(&worker_map_mutex_);
|
||||
RAY_CHECK(workers_.empty());
|
||||
}
|
||||
if (options_.log_dir != "") {
|
||||
RayLog::ShutDownRayLog();
|
||||
}
|
||||
}
|
||||
|
||||
void CoreWorkerProcess::EnsureInitialized() {
|
||||
RAY_CHECK(instance_) << "The core worker process is not initialized yet or already "
|
||||
<< "shutdown.";
|
||||
}
|
||||
|
||||
CoreWorker &CoreWorkerProcess::GetCoreWorker() {
|
||||
EnsureInitialized();
|
||||
if (instance_->options_.num_workers == 1) {
|
||||
return *instance_->global_worker_;
|
||||
}
|
||||
auto ptr = current_core_worker_.lock();
|
||||
RAY_CHECK(ptr != nullptr)
|
||||
<< "The current thread is not bound with a core worker instance.";
|
||||
return *ptr;
|
||||
}
|
||||
|
||||
void CoreWorkerProcess::SetCurrentThreadWorkerId(const WorkerID &worker_id) {
|
||||
EnsureInitialized();
|
||||
if (instance_->options_.num_workers == 1) {
|
||||
RAY_CHECK(instance_->global_worker_->GetWorkerID() == worker_id);
|
||||
return;
|
||||
}
|
||||
current_core_worker_ = instance_->GetWorker(worker_id);
|
||||
}
|
||||
|
||||
std::shared_ptr<CoreWorker> CoreWorkerProcess::GetWorker(
|
||||
const WorkerID &worker_id) const {
|
||||
absl::ReaderMutexLock lock(&worker_map_mutex_);
|
||||
auto it = workers_.find(worker_id);
|
||||
RAY_CHECK(it != workers_.end()) << "Worker " << worker_id << " not found.";
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::shared_ptr<CoreWorker> CoreWorkerProcess::CreateWorker() {
|
||||
auto worker = std::make_shared<CoreWorker>(
|
||||
options_,
|
||||
global_worker_id_ != WorkerID::Nil() ? global_worker_id_ : WorkerID::FromRandom());
|
||||
RAY_LOG(INFO) << "Worker " << worker->GetWorkerID() << " is created.";
|
||||
if (options_.num_workers == 1) {
|
||||
global_worker_ = worker;
|
||||
}
|
||||
current_core_worker_ = worker;
|
||||
|
||||
absl::MutexLock lock(&worker_map_mutex_);
|
||||
workers_.emplace(worker->GetWorkerID(), worker);
|
||||
RAY_CHECK(workers_.size() <= static_cast<size_t>(options_.num_workers));
|
||||
return worker;
|
||||
}
|
||||
|
||||
void CoreWorkerProcess::RemoveWorker(std::shared_ptr<CoreWorker> worker) {
|
||||
worker->WaitForShutdown();
|
||||
if (global_worker_) {
|
||||
RAY_CHECK(global_worker_ == worker);
|
||||
} else {
|
||||
RAY_CHECK(current_core_worker_.lock() == worker);
|
||||
}
|
||||
current_core_worker_.reset();
|
||||
{
|
||||
absl::MutexLock lock(&worker_map_mutex_);
|
||||
workers_.erase(worker->GetWorkerID());
|
||||
RAY_LOG(INFO) << "Removed worker " << worker->GetWorkerID();
|
||||
}
|
||||
if (global_worker_ == worker) {
|
||||
global_worker_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
void CoreWorkerProcess::RunTaskExecutionLoop() {
|
||||
EnsureInitialized();
|
||||
RAY_CHECK(instance_->options_.worker_type == WorkerType::WORKER);
|
||||
if (instance_->options_.num_workers == 1) {
|
||||
// Run the task loop in the current thread only if the number of workers is 1.
|
||||
auto worker =
|
||||
instance_->global_worker_ ? instance_->global_worker_ : instance_->CreateWorker();
|
||||
worker->RunTaskExecutionLoop();
|
||||
instance_->RemoveWorker(worker);
|
||||
} else {
|
||||
std::vector<std::thread> worker_threads;
|
||||
for (int i = 0; i < instance_->options_.num_workers; i++) {
|
||||
worker_threads.emplace_back([]() {
|
||||
auto worker = instance_->CreateWorker();
|
||||
worker->RunTaskExecutionLoop();
|
||||
instance_->RemoveWorker(worker);
|
||||
});
|
||||
}
|
||||
for (auto &thread : worker_threads) {
|
||||
thread.join();
|
||||
}
|
||||
}
|
||||
|
||||
instance_.reset();
|
||||
}
|
||||
|
||||
CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_id)
|
||||
: options_(options),
|
||||
get_call_site_(RayConfig::instance().record_ref_creation_sites()
|
||||
? options_.get_lang_stack
|
||||
: nullptr),
|
||||
worker_context_(options_.worker_type, worker_id, options_.job_id),
|
||||
io_work_(io_service_),
|
||||
client_call_manager_(new rpc::ClientCallManager(io_service_)),
|
||||
death_check_timer_(io_service_),
|
||||
internal_timer_(io_service_),
|
||||
core_worker_server_(WorkerTypeString(worker_type), 0 /* let grpc choose a port */),
|
||||
core_worker_server_(WorkerTypeString(options_.worker_type),
|
||||
0 /* let grpc choose a port */),
|
||||
task_queue_length_(0),
|
||||
num_executed_tasks_(0),
|
||||
task_execution_service_work_(task_execution_service_),
|
||||
task_execution_callback_(task_execution_callback),
|
||||
resource_ids_(new ResourceMappingType()),
|
||||
grpc_service_(io_service_, *this) {
|
||||
// Initialize logging if log_dir is passed. Otherwise, it must be initialized
|
||||
// and cleaned up by the caller.
|
||||
if (log_dir_ != "") {
|
||||
std::stringstream app_name;
|
||||
app_name << LanguageString(language_) << "-" << WorkerTypeString(worker_type_) << "-"
|
||||
<< worker_context_.GetWorkerID();
|
||||
RayLog::StartRayLog(app_name.str(), RayLogLevel::INFO, log_dir_);
|
||||
RayLog::InstallFailureSignalHandler();
|
||||
}
|
||||
// Initialize gcs client.
|
||||
if (RayConfig::instance().gcs_service_enabled()) {
|
||||
gcs_client_ = std::make_shared<ray::gcs::ServiceBasedGcsClient>(gcs_options);
|
||||
gcs_client_ = std::make_shared<ray::gcs::ServiceBasedGcsClient>(options_.gcs_options);
|
||||
} else {
|
||||
gcs_client_ = std::make_shared<ray::gcs::RedisGcsClient>(gcs_options);
|
||||
gcs_client_ = std::make_shared<ray::gcs::RedisGcsClient>(options_.gcs_options);
|
||||
}
|
||||
RAY_CHECK_OK(gcs_client_->Connect(io_service_));
|
||||
RegisterToGcs();
|
||||
|
||||
actor_manager_ = std::unique_ptr<ActorManager>(new ActorManager(gcs_client_->Actors()));
|
||||
|
||||
// Initialize profiler.
|
||||
profiler_ = std::make_shared<worker::Profiler>(worker_context_, node_ip_address,
|
||||
io_service_, gcs_client_);
|
||||
profiler_ = std::make_shared<worker::Profiler>(
|
||||
worker_context_, options_.node_ip_address, io_service_, gcs_client_);
|
||||
|
||||
// Initialize task receivers.
|
||||
if (worker_type_ == WorkerType::WORKER || is_local_mode_) {
|
||||
RAY_CHECK(task_execution_callback_ != nullptr);
|
||||
if (options_.worker_type == WorkerType::WORKER || options_.is_local_mode) {
|
||||
RAY_CHECK(options_.task_execution_callback != nullptr);
|
||||
auto execute_task =
|
||||
std::bind(&CoreWorker::ExecuteTask, this, std::placeholders::_1,
|
||||
std::placeholders::_2, std::placeholders::_3, std::placeholders::_4);
|
||||
|
@ -154,17 +305,18 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
|
|||
// so that the worker (java/python .etc) can retrieve and handle the error
|
||||
// instead of crashing.
|
||||
auto grpc_client = rpc::NodeManagerWorkerClient::make(
|
||||
node_ip_address, node_manager_port, *client_call_manager_);
|
||||
options_.node_ip_address, options_.node_manager_port, *client_call_manager_);
|
||||
ClientID local_raylet_id;
|
||||
local_raylet_client_ = std::shared_ptr<raylet::RayletClient>(new raylet::RayletClient(
|
||||
io_service_, std::move(grpc_client), raylet_socket, worker_context_.GetWorkerID(),
|
||||
(worker_type_ == ray::WorkerType::WORKER), worker_context_.GetCurrentJobID(),
|
||||
language_, &local_raylet_id, core_worker_server_.GetPort()));
|
||||
io_service_, std::move(grpc_client), options_.raylet_socket, GetWorkerID(),
|
||||
(options_.worker_type == ray::WorkerType::WORKER),
|
||||
worker_context_.GetCurrentJobID(), options_.language, &local_raylet_id,
|
||||
core_worker_server_.GetPort()));
|
||||
connected_ = true;
|
||||
|
||||
// Set our own address.
|
||||
RAY_CHECK(!local_raylet_id.IsNil());
|
||||
rpc_address_.set_ip_address(node_ip_address);
|
||||
rpc_address_.set_ip_address(options_.node_ip_address);
|
||||
rpc_address_.set_port(core_worker_server_.GetPort());
|
||||
rpc_address_.set_raylet_id(local_raylet_id.Binary());
|
||||
rpc_address_.set_worker_id(worker_context_.GetWorkerID().Binary());
|
||||
|
@ -179,20 +331,21 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
|
|||
new rpc::CoreWorkerClient(addr, *client_call_manager_));
|
||||
});
|
||||
|
||||
if (worker_type_ == ray::WorkerType::WORKER) {
|
||||
if (options_.worker_type == ray::WorkerType::WORKER) {
|
||||
death_check_timer_.expires_from_now(boost::asio::chrono::milliseconds(
|
||||
RayConfig::instance().raylet_death_check_interval_milliseconds()));
|
||||
death_check_timer_.async_wait(boost::bind(&CoreWorker::CheckForRayletFailure, this));
|
||||
death_check_timer_.async_wait(
|
||||
boost::bind(&CoreWorker::CheckForRayletFailure, this, _1));
|
||||
}
|
||||
|
||||
internal_timer_.expires_from_now(
|
||||
boost::asio::chrono::milliseconds(kInternalHeartbeatMillis));
|
||||
internal_timer_.async_wait(boost::bind(&CoreWorker::InternalHeartbeat, this));
|
||||
internal_timer_.async_wait(boost::bind(&CoreWorker::InternalHeartbeat, this, _1));
|
||||
|
||||
io_thread_ = std::thread(&CoreWorker::RunIOService, this);
|
||||
|
||||
plasma_store_provider_.reset(new CoreWorkerPlasmaStoreProvider(
|
||||
store_socket, local_raylet_client_, check_signals_,
|
||||
options_.store_socket, local_raylet_client_, options_.check_signals,
|
||||
/*evict_if_full=*/RayConfig::instance().object_pinning_enabled(),
|
||||
boost::bind(&CoreWorker::TriggerGlobalGC, this),
|
||||
boost::bind(&CoreWorker::CurrentCallSite, this)));
|
||||
|
@ -201,8 +354,8 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
|
|||
RAY_LOG(DEBUG) << "Promoting object to plasma " << obj_id;
|
||||
RAY_CHECK_OK(Put(obj, /*contained_object_ids=*/{}, obj_id, /*pin_object=*/true));
|
||||
},
|
||||
ref_counting_enabled ? reference_counter_ : nullptr, local_raylet_client_,
|
||||
check_signals_));
|
||||
options_.ref_counting_enabled ? reference_counter_ : nullptr, local_raylet_client_,
|
||||
options_.check_signals));
|
||||
|
||||
task_manager_.reset(new TaskManager(
|
||||
memory_store_, reference_counter_, actor_manager_,
|
||||
|
@ -222,16 +375,17 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
|
|||
// driver creates an object that is later evicted, we should notify the
|
||||
// user that we're unable to reconstruct the object, since we cannot
|
||||
// rerun the driver.
|
||||
if (worker_type_ == WorkerType::DRIVER) {
|
||||
if (options_.worker_type == WorkerType::DRIVER) {
|
||||
TaskSpecBuilder builder;
|
||||
const TaskID task_id = TaskID::ForDriverTask(worker_context_.GetCurrentJobID());
|
||||
builder.SetDriverTaskSpec(task_id, language_, worker_context_.GetCurrentJobID(),
|
||||
builder.SetDriverTaskSpec(task_id, options_.language,
|
||||
worker_context_.GetCurrentJobID(),
|
||||
TaskID::ComputeDriverTaskId(worker_context_.GetWorkerID()),
|
||||
GetCallerId(), rpc_address_);
|
||||
|
||||
std::shared_ptr<gcs::TaskTableData> data = std::make_shared<gcs::TaskTableData>();
|
||||
data->mutable_task()->mutable_task_spec()->CopyFrom(builder.Build().GetMessage());
|
||||
if (!is_local_mode_) {
|
||||
if (!options_.is_local_mode) {
|
||||
RAY_CHECK_OK(gcs_client_->Tasks().AsyncAdd(data, nullptr));
|
||||
}
|
||||
SetCurrentTaskId(task_id);
|
||||
|
@ -262,17 +416,9 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
|
|||
}
|
||||
}
|
||||
|
||||
/// Tear down the worker: stop the IO service, wait for the IO thread to finish,
/// and shut down Ray logging when a log directory was configured.
CoreWorker::~CoreWorker() {
  // Stop the event loop before joining; presumably this lets the IO thread
  // return from its run loop (confirm against where io_thread_ is started).
  io_service_.stop();
  io_thread_.join();

  // Ray logging is only shut down when a log directory was provided.
  const bool has_log_dir = !log_dir_.empty();
  if (has_log_dir) {
    RayLog::ShutDownRayLog();
  }
}
|
||||
|
||||
void CoreWorker::Shutdown() {
|
||||
io_service_.stop();
|
||||
if (worker_type_ == WorkerType::WORKER) {
|
||||
if (options_.worker_type == WorkerType::WORKER) {
|
||||
task_execution_service_.stop();
|
||||
}
|
||||
}
|
||||
|
@ -356,6 +502,17 @@ void CoreWorker::RunIOService() {
|
|||
io_service_.run();
|
||||
}
|
||||
|
||||
void CoreWorker::WaitForShutdown() {
|
||||
if (io_thread_.joinable()) {
|
||||
io_thread_.join();
|
||||
}
|
||||
if (options_.worker_type == WorkerType::WORKER) {
|
||||
RAY_CHECK(task_execution_service_.stopped());
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the ID of this worker as stored in its worker context.
const WorkerID &CoreWorker::GetWorkerID() const {
  return worker_context_.GetWorkerID();
}
|
||||
|
||||
void CoreWorker::SetCurrentTaskId(const TaskID &task_id) {
|
||||
worker_context_.SetCurrentTaskId(task_id);
|
||||
main_thread_task_id_ = task_id;
|
||||
|
@ -375,8 +532,41 @@ void CoreWorker::SetCurrentTaskId(const TaskID &task_id) {
|
|||
}
|
||||
}
|
||||
|
||||
void CoreWorker::CheckForRayletFailure() {
|
||||
// If the raylet fails, we will be reassigned to init (PID=1).
|
||||
void CoreWorker::RegisterToGcs() {
|
||||
std::unordered_map<std::string, std::string> worker_info;
|
||||
const auto &worker_id = GetWorkerID();
|
||||
worker_info.emplace("node_ip_address", options_.node_ip_address);
|
||||
worker_info.emplace("plasma_store_socket", options_.store_socket);
|
||||
worker_info.emplace("raylet_socket", options_.raylet_socket);
|
||||
|
||||
if (options_.worker_type == WorkerType::DRIVER) {
|
||||
auto start_time = std::chrono::duration_cast<std::chrono::milliseconds>(
|
||||
std::chrono::system_clock::now().time_since_epoch())
|
||||
.count();
|
||||
worker_info.emplace("driver_id", worker_id.Binary());
|
||||
worker_info.emplace("start_time", std::to_string(start_time));
|
||||
if (!options_.driver_name.empty()) {
|
||||
worker_info.emplace("name", options_.driver_name);
|
||||
}
|
||||
}
|
||||
|
||||
if (!options_.stdout_file.empty()) {
|
||||
worker_info.emplace("stdout_file", options_.stdout_file);
|
||||
}
|
||||
if (!options_.stderr_file.empty()) {
|
||||
worker_info.emplace("stderr_file", options_.stderr_file);
|
||||
}
|
||||
|
||||
RAY_CHECK_OK(gcs_client_->Workers().AsyncRegisterWorker(options_.worker_type, worker_id,
|
||||
worker_info, nullptr));
|
||||
}
|
||||
|
||||
void CoreWorker::CheckForRayletFailure(const boost::system::error_code &error) {
|
||||
if (error == boost::asio::error::operation_aborted) {
|
||||
return;
|
||||
}
|
||||
|
||||
// If the raylet fails, we will be reassigned to init (PID=1).
|
||||
if (getppid() == 1) {
|
||||
RAY_LOG(ERROR) << "Raylet failed. Shutting down.";
|
||||
Shutdown();
|
||||
|
@ -387,10 +577,15 @@ void CoreWorker::CheckForRayletFailure() {
|
|||
death_check_timer_.expiry() +
|
||||
boost::asio::chrono::milliseconds(
|
||||
RayConfig::instance().raylet_death_check_interval_milliseconds()));
|
||||
death_check_timer_.async_wait(boost::bind(&CoreWorker::CheckForRayletFailure, this));
|
||||
death_check_timer_.async_wait(
|
||||
boost::bind(&CoreWorker::CheckForRayletFailure, this, _1));
|
||||
}
|
||||
|
||||
void CoreWorker::InternalHeartbeat() {
|
||||
void CoreWorker::InternalHeartbeat(const boost::system::error_code &error) {
|
||||
if (error == boost::asio::error::operation_aborted) {
|
||||
return;
|
||||
}
|
||||
|
||||
absl::MutexLock lock(&mutex_);
|
||||
while (!to_resubmit_.empty() && current_time_ms() > to_resubmit_.front().first) {
|
||||
RAY_CHECK_OK(direct_task_submitter_->SubmitTask(to_resubmit_.front().second));
|
||||
|
@ -398,7 +593,7 @@ void CoreWorker::InternalHeartbeat() {
|
|||
}
|
||||
internal_timer_.expires_at(internal_timer_.expiry() +
|
||||
boost::asio::chrono::milliseconds(kInternalHeartbeatMillis));
|
||||
internal_timer_.async_wait(boost::bind(&CoreWorker::InternalHeartbeat, this));
|
||||
internal_timer_.async_wait(boost::bind(&CoreWorker::InternalHeartbeat, this, _1));
|
||||
}
|
||||
|
||||
std::unordered_map<ObjectID, std::pair<size_t, size_t>>
|
||||
|
@ -445,7 +640,7 @@ void CoreWorker::RegisterOwnershipInfoAndResolveFuture(
|
|||
reference_counter_->AddBorrowedObject(object_id, outer_object_id, owner_id,
|
||||
owner_address);
|
||||
|
||||
RAY_CHECK(!owner_id.IsNil() || is_local_mode_);
|
||||
RAY_CHECK(!owner_id.IsNil() || options_.is_local_mode);
|
||||
// We will ask the owner about the object until the object is
|
||||
// created or we can no longer reach the owner.
|
||||
future_resolver_->ResolveFutureAsync(object_id, owner_id, owner_address);
|
||||
|
@ -471,7 +666,7 @@ Status CoreWorker::Put(const RayObject &object,
|
|||
const std::vector<ObjectID> &contained_object_ids,
|
||||
const ObjectID &object_id, bool pin_object) {
|
||||
bool object_exists;
|
||||
if (is_local_mode_) {
|
||||
if (options_.is_local_mode) {
|
||||
RAY_CHECK(memory_store_->Put(object, object_id));
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -505,7 +700,7 @@ Status CoreWorker::Create(const std::shared_ptr<Buffer> &metadata, const size_t
|
|||
worker_context_.GetNextPutIndex(),
|
||||
static_cast<uint8_t>(TaskTransportType::DIRECT));
|
||||
|
||||
if (is_local_mode_) {
|
||||
if (options_.is_local_mode) {
|
||||
*data = std::make_shared<LocalMemoryBuffer>(data_size);
|
||||
} else {
|
||||
RAY_RETURN_NOT_OK(
|
||||
|
@ -523,7 +718,7 @@ Status CoreWorker::Create(const std::shared_ptr<Buffer> &metadata, const size_t
|
|||
|
||||
Status CoreWorker::Create(const std::shared_ptr<Buffer> &metadata, const size_t data_size,
|
||||
const ObjectID &object_id, std::shared_ptr<Buffer> *data) {
|
||||
if (is_local_mode_) {
|
||||
if (options_.is_local_mode) {
|
||||
return Status::NotImplemented(
|
||||
"Creating an object with a pre-existing ObjectID is not supported in local mode");
|
||||
} else {
|
||||
|
@ -791,7 +986,7 @@ TaskID CoreWorker::GetCallerId() const {
|
|||
|
||||
Status CoreWorker::PushError(const JobID &job_id, const std::string &type,
|
||||
const std::string &error_message, double timestamp) {
|
||||
if (is_local_mode_) {
|
||||
if (options_.is_local_mode) {
|
||||
RAY_LOG(ERROR) << "Pushed Error with JobID: " << job_id << " of type: " << type
|
||||
<< " with message: " << error_message << " at time: " << timestamp;
|
||||
return Status::OK();
|
||||
|
@ -831,7 +1026,7 @@ Status CoreWorker::SubmitTask(const RayFunction &function,
|
|||
rpc_address_, function, args, task_options.num_returns,
|
||||
task_options.resources, required_resources, return_ids);
|
||||
TaskSpecification task_spec = builder.Build();
|
||||
if (is_local_mode_) {
|
||||
if (options_.is_local_mode) {
|
||||
return ExecuteTaskLocalMode(task_spec);
|
||||
} else {
|
||||
task_manager_->AddPendingTask(GetCallerId(), rpc_address_, task_spec,
|
||||
|
@ -866,7 +1061,7 @@ Status CoreWorker::CreateActor(const RayFunction &function,
|
|||
*return_actor_id = actor_id;
|
||||
TaskSpecification task_spec = builder.Build();
|
||||
Status status;
|
||||
if (is_local_mode_) {
|
||||
if (options_.is_local_mode) {
|
||||
status = ExecuteTaskLocalMode(task_spec);
|
||||
} else {
|
||||
task_manager_->AddPendingTask(
|
||||
|
@ -914,7 +1109,7 @@ Status CoreWorker::SubmitActorTask(const ActorID &actor_id, const RayFunction &f
|
|||
// Submit task.
|
||||
Status status;
|
||||
TaskSpecification task_spec = builder.Build();
|
||||
if (is_local_mode_) {
|
||||
if (options_.is_local_mode) {
|
||||
return ExecuteTaskLocalMode(task_spec, actor_id);
|
||||
}
|
||||
task_manager_->AddPendingTask(GetCallerId(), rpc_address_, task_spec,
|
||||
|
@ -1055,7 +1250,7 @@ std::unique_ptr<worker::ProfileEvent> CoreWorker::CreateProfileEvent(
|
|||
new worker::ProfileEvent(profiler_, event_type));
|
||||
}
|
||||
|
||||
/// Start receiving and executing tasks by running the task execution service on
/// the calling thread.
void CoreWorker::StartExecutingTasks() {
  task_execution_service_.run();
}
|
||||
/// Start receiving and executing tasks by running the task execution service on
/// the calling thread.
void CoreWorker::RunTaskExecutionLoop() {
  task_execution_service_.run();
}
|
||||
|
||||
Status CoreWorker::AllocateReturnObjects(
|
||||
const std::vector<ObjectID> &object_ids, const std::vector<size_t> &data_sizes,
|
||||
|
@ -1066,7 +1261,7 @@ Status CoreWorker::AllocateReturnObjects(
|
|||
RAY_CHECK(object_ids.size() == data_sizes.size());
|
||||
return_objects->resize(object_ids.size(), nullptr);
|
||||
|
||||
rpc::Address owner_address(is_local_mode_
|
||||
rpc::Address owner_address(options_.is_local_mode
|
||||
? rpc::Address()
|
||||
: worker_context_.GetCurrentTask()->CallerAddress());
|
||||
|
||||
|
@ -1083,8 +1278,9 @@ Status CoreWorker::AllocateReturnObjects(
|
|||
}
|
||||
|
||||
// Allocate a buffer for the return object.
|
||||
if (is_local_mode_ || static_cast<int64_t>(data_sizes[i]) <
|
||||
RayConfig::instance().max_direct_call_object_size()) {
|
||||
if (options_.is_local_mode ||
|
||||
static_cast<int64_t>(data_sizes[i]) <
|
||||
RayConfig::instance().max_direct_call_object_size()) {
|
||||
data_buffer = std::make_shared<LocalMemoryBuffer>(data_sizes[i]);
|
||||
} else {
|
||||
RAY_RETURN_NOT_OK(
|
||||
|
@ -1113,7 +1309,7 @@ Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec,
|
|||
resource_ids_ = resource_ids;
|
||||
}
|
||||
|
||||
if (!is_local_mode_) {
|
||||
if (!options_.is_local_mode) {
|
||||
worker_context_.SetCurrentTask(task_spec);
|
||||
SetCurrentTaskId(task_spec.TaskId());
|
||||
}
|
||||
|
@ -1162,13 +1358,17 @@ Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec,
|
|||
SetCallerCreationTimestamp();
|
||||
}
|
||||
|
||||
status = task_execution_callback_(
|
||||
// Because we support concurrent actor calls, we need to update the
|
||||
// worker ID for the current thread.
|
||||
CoreWorkerProcess::SetCurrentThreadWorkerId(GetWorkerID());
|
||||
|
||||
status = options_.task_execution_callback(
|
||||
task_type, func, task_spec.GetRequiredResources().GetResourceMap(), args,
|
||||
arg_reference_ids, return_ids, return_objects, worker_context_.GetWorkerID());
|
||||
arg_reference_ids, return_ids, return_objects);
|
||||
|
||||
absl::optional<rpc::Address> caller_address(
|
||||
is_local_mode_ ? absl::optional<rpc::Address>()
|
||||
: worker_context_.GetCurrentTask()->CallerAddress());
|
||||
options_.is_local_mode ? absl::optional<rpc::Address>()
|
||||
: worker_context_.GetCurrentTask()->CallerAddress());
|
||||
for (size_t i = 0; i < return_objects->size(); i++) {
|
||||
// The object is nullptr if it already existed in the object store.
|
||||
if (!return_objects->at(i)) {
|
||||
|
@ -1196,7 +1396,7 @@ Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec,
|
|||
RAY_LOG(DEBUG) << "Decrementing ref for borrowed ID " << borrowed_id;
|
||||
reference_counter_->RemoveLocalReference(borrowed_id, &deleted);
|
||||
}
|
||||
if (ref_counting_enabled_) {
|
||||
if (options_.ref_counting_enabled) {
|
||||
memory_store_->Delete(deleted);
|
||||
}
|
||||
|
||||
|
@ -1210,7 +1410,7 @@ Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec,
|
|||
"reference counting, and may cause problems in the object store.";
|
||||
}
|
||||
|
||||
if (!is_local_mode_) {
|
||||
if (!options_.is_local_mode) {
|
||||
SetCurrentTaskId(TaskID::Nil());
|
||||
worker_context_.ResetCurrentTask(task_spec);
|
||||
}
|
||||
|
@ -1263,7 +1463,7 @@ Status CoreWorker::BuildArgsForExecutor(const TaskSpecification &task,
|
|||
// Direct call type objects that weren't inlined have been promoted to plasma.
|
||||
// We need to put an OBJECT_IN_PLASMA error here so the subsequent call to Get()
|
||||
// properly redirects to the plasma store.
|
||||
if (task.ArgId(i, 0).IsDirectCallType() && !is_local_mode_) {
|
||||
if (task.ArgId(i, 0).IsDirectCallType() && !options_.is_local_mode) {
|
||||
RAY_UNUSED(memory_store_->Put(RayObject(rpc::ErrorType::OBJECT_IN_PLASMA),
|
||||
task.ArgId(i, 0)));
|
||||
}
|
||||
|
@ -1309,7 +1509,7 @@ Status CoreWorker::BuildArgsForExecutor(const TaskSpecification &task,
|
|||
// Fetch by-reference arguments directly from the plasma store.
|
||||
bool got_exception = false;
|
||||
absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> result_map;
|
||||
if (is_local_mode_) {
|
||||
if (options_.is_local_mode) {
|
||||
RAY_RETURN_NOT_OK(
|
||||
memory_store_->Get(by_ref_ids, -1, worker_context_, &result_map, &got_exception));
|
||||
} else {
|
||||
|
@ -1472,7 +1672,16 @@ void CoreWorker::HandleKillActor(const rpc::KillActorRequest &request,
|
|||
if (request.no_reconstruction()) {
|
||||
RAY_IGNORE_EXPR(local_raylet_client_->Disconnect());
|
||||
}
|
||||
if (log_dir_ != "") {
|
||||
if (options_.num_workers > 1) {
|
||||
// TODO (kfstorm): Should we add some kind of check before sending the killing
|
||||
// request?
|
||||
RAY_LOG(ERROR)
|
||||
<< "Killing an actor which is running in a worker process with multiple "
|
||||
"workers will also kill other actors in this process. To avoid this, "
|
||||
"please create the Java actor with some dynamic options to make it being "
|
||||
"hosted in a dedicated worker process.";
|
||||
}
|
||||
if (options_.log_dir != "") {
|
||||
RayLog::ShutDownRayLog();
|
||||
}
|
||||
exit(1);
|
||||
|
@ -1522,8 +1731,8 @@ void CoreWorker::HandleGetCoreWorkerStats(const rpc::GetCoreWorkerStatsRequest &
|
|||
void CoreWorker::HandleLocalGC(const rpc::LocalGCRequest &request,
|
||||
rpc::LocalGCReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback) {
|
||||
if (gc_collect_ != nullptr) {
|
||||
gc_collect_();
|
||||
if (options_.gc_collect != nullptr) {
|
||||
options_.gc_collect();
|
||||
send_reply_callback(Status::OK(), nullptr, nullptr);
|
||||
} else {
|
||||
send_reply_callback(Status::NotImplemented("GC callback not defined"), nullptr,
|
||||
|
@ -1573,7 +1782,7 @@ void CoreWorker::HandlePlasmaObjectReady(const rpc::PlasmaObjectReadyRequest &re
|
|||
|
||||
void CoreWorker::SetActorId(const ActorID &actor_id) {
|
||||
absl::MutexLock lock(&mutex_);
|
||||
if (!is_local_mode_) {
|
||||
if (!options_.is_local_mode) {
|
||||
RAY_CHECK(actor_id_.IsNil());
|
||||
}
|
||||
actor_id_ = actor_id;
|
||||
|
|
|
@ -50,10 +50,9 @@
|
|||
|
||||
namespace ray {
|
||||
|
||||
/// The root class that contains all the core and language-independent functionalities
|
||||
/// of the worker. This class is supposed to be used to implement app-language (Java,
|
||||
/// Python, etc) workers.
|
||||
class CoreWorker : public rpc::CoreWorkerServiceHandler {
|
||||
class CoreWorker;
|
||||
|
||||
struct CoreWorkerOptions {
|
||||
// Callback that must be implemented and provided by the language-specific worker
|
||||
// frontend to execute tasks and return their results.
|
||||
using TaskExecutionCallback = std::function<Status(
|
||||
|
@ -62,60 +61,242 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
|
|||
const std::vector<std::shared_ptr<RayObject>> &args,
|
||||
const std::vector<ObjectID> &arg_reference_ids,
|
||||
const std::vector<ObjectID> &return_ids,
|
||||
std::vector<std::shared_ptr<RayObject>> *results, const ray::WorkerID &worker_id)>;
|
||||
std::vector<std::shared_ptr<RayObject>> *results)>;
|
||||
|
||||
/// Type of this worker (i.e., DRIVER or WORKER).
|
||||
WorkerType worker_type;
|
||||
/// Application language of this worker (i.e., PYTHON or JAVA).
|
||||
Language language;
|
||||
/// Object store socket to connect to.
|
||||
std::string store_socket;
|
||||
/// Raylet socket to connect to.
|
||||
std::string raylet_socket;
|
||||
/// Job ID of this worker.
|
||||
JobID job_id;
|
||||
/// Options for the GCS client.
|
||||
gcs::GcsClientOptions gcs_options;
|
||||
/// Directory to write logs to. If this is empty, logs won't be written to a file.
|
||||
std::string log_dir;
|
||||
/// If false, will not call `RayLog::InstallFailureSignalHandler()`.
|
||||
bool install_failure_signal_handler;
|
||||
/// IP address of the node.
|
||||
std::string node_ip_address;
|
||||
/// Port of the local raylet.
|
||||
int node_manager_port;
|
||||
/// The name of the driver.
|
||||
std::string driver_name;
|
||||
/// The stdout file of this process.
|
||||
std::string stdout_file;
|
||||
/// The stderr file of this process.
|
||||
std::string stderr_file;
|
||||
/// Language worker callback to execute tasks.
|
||||
TaskExecutionCallback task_execution_callback;
|
||||
/// Application-language callback to check for signals that have been received
|
||||
/// since calling into C++. This will be called periodically (at least every
|
||||
/// 1s) during long-running operations. If the function returns anything but StatusOK,
|
||||
/// any long-running operations in the core worker will short circuit and return that
|
||||
/// status.
|
||||
std::function<Status()> check_signals;
|
||||
/// Application-language callback to trigger garbage collection in the language
|
||||
/// runtime. This is required to free distributed references that may otherwise
|
||||
/// be held up in garbage objects.
|
||||
std::function<void()> gc_collect;
|
||||
/// Language worker callback to get the current call stack.
|
||||
std::function<void(std::string *)> get_lang_stack;
|
||||
/// Whether to enable object ref counting.
|
||||
bool ref_counting_enabled;
|
||||
/// Is local mode being used.
|
||||
bool is_local_mode;
|
||||
/// The number of workers to be started in the current process.
|
||||
int num_workers;
|
||||
};
|
||||
|
||||
/// Lifecycle management of one or more `CoreWorker` instances in a process.
|
||||
///
|
||||
/// To start a driver in the current process:
|
||||
/// CoreWorkerOptions options = {
|
||||
/// WorkerType::DRIVER, // worker_type
|
||||
/// ..., // other arguments
|
||||
/// 1, // num_workers
|
||||
/// };
|
||||
/// CoreWorkerProcess::Initialize(options);
|
||||
///
|
||||
/// To shutdown a driver in the current process:
|
||||
/// CoreWorkerProcess::Shutdown();
|
||||
///
|
||||
/// To start one or more workers in the current process:
|
||||
/// CoreWorkerOptions options = {
|
||||
/// WorkerType::WORKER, // worker_type
|
||||
/// ..., // other arguments
|
||||
/// num_workers, // num_workers
|
||||
/// };
|
||||
/// CoreWorkerProcess::Initialize(options);
|
||||
/// ... // Do other stuff
|
||||
/// CoreWorkerProcess::RunTaskExecutionLoop();
|
||||
///
|
||||
/// To shutdown a worker in the current process, return a system exit status (with status
|
||||
/// code `IntentionalSystemExit` or `UnexpectedSystemExit`) in the task execution
|
||||
/// callback.
|
||||
///
|
||||
/// If more than 1 worker is started, only the threads which invoke the
|
||||
/// `task_execution_callback` will be automatically associated with the corresponding
|
||||
/// worker. If you started your own threads and you want to use core worker APIs in these
|
||||
/// threads, remember to call `CoreWorkerProcess::SetCurrentThreadWorkerId(worker_id)`
|
||||
/// once in the new thread before calling core worker APIs, to associate the current
|
||||
/// thread with a worker. You can obtain the worker ID via
|
||||
/// `CoreWorkerProcess::GetCoreWorker()->GetWorkerID()`. Currently a Java worker process
|
||||
/// starts multiple workers by default, but can be configured to start only 1 worker by
|
||||
/// overwriting the internal config `num_workers_per_process_java`.
|
||||
///
|
||||
/// If only 1 worker is started (either because the worker type is driver, or the
|
||||
/// `num_workers` in `CoreWorkerOptions` is set to 1), all threads will be automatically
|
||||
/// associated to the only worker. Then no need to call `SetCurrentThreadWorkerId` in
|
||||
/// your own threads. Currently a Python worker process starts only 1 worker.
|
||||
class CoreWorkerProcess {
|
||||
public:
|
||||
///
|
||||
/// Public methods used in both DRIVER and WORKER mode.
|
||||
///
|
||||
|
|
||||
/// Initialize core workers at the process level.
|
||||
///
|
||||
/// \param[in] options The various initialization options.
|
||||
static void Initialize(const CoreWorkerOptions &options);
|
||||
|
|
||||
/// Get the core worker associated with the current thread.
|
||||
/// NOTE (kfstorm): Here we return a reference instead of a `shared_ptr` to make sure
|
||||
/// `CoreWorkerProcess` has full control of the destruction timing of `CoreWorker`.
|
||||
static CoreWorker &GetCoreWorker();
|
||||
|
|
||||
/// Set the core worker associated with the current thread by worker ID.
|
||||
/// Currently used by Java worker only.
|
||||
///
|
||||
/// \param[in] worker_id The worker ID of the core worker instance.
|
||||
static void SetCurrentThreadWorkerId(const WorkerID &worker_id);
|
||||
|
|
||||
/// Whether the current process has been initialized for core worker.
|
||||
static bool IsInitialized();
|
||||
|
|
||||
///
|
||||
/// Public methods used in DRIVER mode only.
|
||||
///
|
||||
|
|
||||
/// Shutdown the driver completely at the process level.
|
||||
static void Shutdown();
|
||||
|
|
||||
///
|
||||
/// Public methods used in WORKER mode only.
|
||||
///
|
||||
|
|
||||
/// Start receiving and executing tasks.
|
||||
static void RunTaskExecutionLoop();
|
||||
|
|
||||
// The destructor is not to be used as a public API, but it's required by smart
|
||||
// pointers (the global instance is held in a `std::unique_ptr`).
|
||||
~CoreWorkerProcess();
|
||||
|
|
||||
private:
|
||||
/// Create a `CoreWorkerProcess` with proper options.
|
||||
///
|
||||
/// \param[in] options The various initialization options.
|
||||
CoreWorkerProcess(const CoreWorkerOptions &options);
|
||||
|
|
||||
/// Check that the core worker environment is initialized for this process.
|
||||
///
|
||||
/// \return Void.
|
||||
static void EnsureInitialized();
|
||||
|
|
||||
/// Get the `CoreWorker` instance by worker ID.
|
||||
///
|
||||
/// \param[in] worker_id The worker ID.
|
||||
/// \return The `CoreWorker` instance.
|
||||
std::shared_ptr<CoreWorker> GetWorker(const WorkerID &worker_id) const
|
||||
LOCKS_EXCLUDED(worker_map_mutex_);
|
||||
|
|
||||
/// Create a new `CoreWorker` instance.
|
||||
///
|
||||
/// \return The newly created `CoreWorker` instance.
|
||||
std::shared_ptr<CoreWorker> CreateWorker() LOCKS_EXCLUDED(worker_map_mutex_);
|
||||
|
|
||||
/// Remove an existing `CoreWorker` instance.
|
||||
///
|
||||
/// \param[in] worker The existing `CoreWorker` instance.
|
||||
/// \return Void.
|
||||
void RemoveWorker(std::shared_ptr<CoreWorker> worker) LOCKS_EXCLUDED(worker_map_mutex_);
|
||||
|
|
||||
/// The global instance of `CoreWorkerProcess`.
|
||||
static std::unique_ptr<CoreWorkerProcess> instance_;
|
||||
|
|
||||
/// The various options.
|
||||
const CoreWorkerOptions options_;
|
||||
|
|
||||
/// The core worker instance associated with the current thread.
|
||||
/// Use weak_ptr here to avoid memory leak due to multi-threading.
|
||||
static thread_local std::weak_ptr<CoreWorker> current_core_worker_;
|
||||
|
|
||||
/// The only core worker instance, if the number of workers is 1.
|
||||
std::shared_ptr<CoreWorker> global_worker_;
|
||||
|
|
||||
/// The worker ID of the global worker, if the number of workers is 1.
|
||||
const WorkerID global_worker_id_;
|
||||
|
|
||||
/// Map from worker ID to worker.
|
||||
std::unordered_map<WorkerID, std::shared_ptr<CoreWorker>> workers_
|
||||
GUARDED_BY(worker_map_mutex_);
|
||||
|
|
||||
/// Mutex protecting access to the `workers_` map.
|
||||
mutable absl::Mutex worker_map_mutex_;
|
||||
};
|
||||
|
||||
/// The root class that contains all the core and language-independent functionalities
|
||||
/// of the worker. This class is supposed to be used to implement app-language (Java,
|
||||
/// Python, etc) workers.
|
||||
class CoreWorker : public rpc::CoreWorkerServiceHandler {
|
||||
public:
|
||||
/// Construct a CoreWorker instance.
|
||||
///
|
||||
/// \param[in] worker_type Type of this worker.
|
||||
/// \param[in] language Language of this worker.
|
||||
/// \param[in] store_socket Object store socket to connect to.
|
||||
/// \param[in] raylet_socket Raylet socket to connect to.
|
||||
/// \param[in] job_id Job ID of this worker.
|
||||
/// \param[in] gcs_options Options for the GCS client.
|
||||
/// \param[in] log_dir Directory to write logs to. If this is empty, logs
|
||||
/// won't be written to a file.
|
||||
/// \param[in] node_ip_address IP address of the node.
|
||||
/// \param[in] node_manager_port Port of the local raylet.
|
||||
/// \param[in] task_execution_callback Language worker callback to execute tasks.
|
||||
/// \param[in] check_signals Language worker function to check for signals and handle
|
||||
/// them. If the function returns anything but StatusOK, any long-running
|
||||
/// operations in the core worker will short circuit and return that status.
|
||||
/// \param[in] ref_counting_enabled Whether to enable object ref counting.
|
||||
/// \param[in] options The various initialization options.
|
||||
/// \param[in] worker_id ID of this worker.
|
||||
CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_id);
|
||||
|
||||
CoreWorker(CoreWorker const &) = delete;
|
||||
void operator=(CoreWorker const &other) = delete;
|
||||
|
||||
///
|
||||
/// Public methods used by `CoreWorkerProcess` and `CoreWorker` itself.
|
||||
///
|
||||
/// NOTE(zhijunfu): the constructor would throw if a failure happens.
|
||||
CoreWorker(const WorkerType worker_type, const Language language,
|
||||
const std::string &store_socket, const std::string &raylet_socket,
|
||||
const JobID &job_id, const gcs::GcsClientOptions &gcs_options,
|
||||
const std::string &log_dir, const std::string &node_ip_address,
|
||||
int node_manager_port, const TaskExecutionCallback &task_execution_callback,
|
||||
std::function<Status()> check_signals = nullptr,
|
||||
std::function<void()> gc_collect = nullptr,
|
||||
std::function<void(std::string *)> get_lang_stack = nullptr,
|
||||
bool ref_counting_enabled = false, bool local_mode = false);
|
||||
|
||||
virtual ~CoreWorker();
|
||||
|
||||
void Exit(bool intentional);
|
||||
|
||||
/// Gracefully disconnect the worker from other components of ray. e.g. Raylet.
|
||||
/// If this function is called during shutdown, Raylet will treat it as an intentional
|
||||
/// disconnect.
|
||||
///
|
||||
/// \return Void.
|
||||
void Disconnect();
|
||||
|
||||
WorkerType GetWorkerType() const { return worker_type_; }
|
||||
/// Shut down the worker completely.
|
||||
///
|
||||
/// \return void.
|
||||
void Shutdown();
|
||||
|
||||
Language GetLanguage() const { return language_; }
|
||||
/// Block the current thread until the worker is shut down.
|
||||
void WaitForShutdown();
|
||||
|
||||
/// Start receiving and executing tasks.
|
||||
/// \return void.
|
||||
void RunTaskExecutionLoop();
|
||||
|
||||
const WorkerID &GetWorkerID() const;
|
||||
|
||||
WorkerType GetWorkerType() const { return options_.worker_type; }
|
||||
|
||||
Language GetLanguage() const { return options_.language; }
|
||||
|
||||
WorkerContext &GetWorkerContext() { return worker_context_; }
|
||||
|
||||
raylet::RayletClient &GetRayletClient() { return *local_raylet_client_; }
|
||||
|
||||
const TaskID &GetCurrentTaskId() const { return worker_context_.GetCurrentTaskID(); }
|
||||
|
||||
void SetCurrentTaskId(const TaskID &task_id);
|
||||
|
||||
const JobID &GetCurrentJobId() const { return worker_context_.GetCurrentJobID(); }
|
||||
|
||||
void SetActorId(const ActorID &actor_id);
|
||||
|
||||
void SetWebuiDisplay(const std::string &key, const std::string &message);
|
||||
|
||||
void SetActorTitle(const std::string &title);
|
||||
|
@ -139,7 +320,7 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
|
|||
std::vector<ObjectID> deleted;
|
||||
reference_counter_->RemoveLocalReference(object_id, &deleted);
|
||||
// TODO(ilr): find a better way of keeping an object from being deleted
|
||||
if (ref_counting_enabled_ && !is_local_mode_) {
|
||||
if (options_.ref_counting_enabled && !options_.is_local_mode) {
|
||||
memory_store_->Delete(deleted);
|
||||
}
|
||||
}
|
||||
|
@ -448,10 +629,7 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
|
|||
/// Create a profile event with a reference to the core worker's profiler.
|
||||
std::unique_ptr<worker::ProfileEvent> CreateProfileEvent(const std::string &event_type);
|
||||
|
||||
/// Start receiving and executing tasks.
|
||||
/// \return void.
|
||||
void StartExecutingTasks();
|
||||
|
||||
public:
|
||||
/// Allocate the return objects for an executing task. The caller should write into the
|
||||
/// data buffers of the allocated buffers.
|
||||
///
|
||||
|
@ -566,18 +744,25 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
|
|||
void SubscribeToPlasmaAdd(const ObjectID &object_id);
|
||||
|
||||
private:
|
||||
void SetCurrentTaskId(const TaskID &task_id);
|
||||
|
||||
void SetActorId(const ActorID &actor_id);
|
||||
|
||||
/// Run the io_service_ event loop. This should be called in a background thread.
|
||||
void RunIOService();
|
||||
|
||||
/// Shut down the worker completely.
|
||||
/// \return void.
|
||||
void Shutdown();
|
||||
/// (WORKER mode only) Exit the worker. This is the entrypoint used to shutdown a
|
||||
/// worker.
|
||||
void Exit(bool intentional);
|
||||
|
||||
/// Register this worker or driver to GCS.
|
||||
void RegisterToGcs();
|
||||
|
||||
/// Check if the raylet has failed. If so, shutdown.
|
||||
void CheckForRayletFailure();
|
||||
void CheckForRayletFailure(const boost::system::error_code &error);
|
||||
|
||||
/// Heartbeat for internal bookkeeping.
|
||||
void InternalHeartbeat();
|
||||
void InternalHeartbeat(const boost::system::error_code &error);
|
||||
|
||||
///
|
||||
/// Private methods related to task submission.
|
||||
|
@ -682,30 +867,7 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
|
|||
}
|
||||
}
|
||||
|
||||
/// Type of this worker (i.e., DRIVER or WORKER).
|
||||
const WorkerType worker_type_;
|
||||
|
||||
/// Application language of this worker (i.e., PYTHON or JAVA).
|
||||
const Language language_;
|
||||
|
||||
/// Directory where log files are written.
|
||||
const std::string log_dir_;
|
||||
|
||||
/// Whether local reference counting is enabled.
|
||||
const bool ref_counting_enabled_;
|
||||
|
||||
/// Is local mode being used.
|
||||
const bool is_local_mode_;
|
||||
|
||||
/// Application-language callback to check for signals that have been received
|
||||
/// since calling into C++. This will be called periodically (at least every
|
||||
/// 1s) during long-running operations.
|
||||
std::function<Status()> check_signals_;
|
||||
|
||||
/// Application-language callback to trigger garbage collection in the language
|
||||
/// runtime. This is required to free distributed references that may otherwise
|
||||
/// be held up in garbage objects.
|
||||
std::function<void()> gc_collect_;
|
||||
const CoreWorkerOptions options_;
|
||||
|
||||
/// Callback to get the current language (e.g., Python) call site.
|
||||
std::function<void(std::string *)> get_call_site_;
|
||||
|
@ -843,9 +1005,6 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
|
|||
/// Profiler including a background thread that pushes profiling events to the GCS.
|
||||
std::shared_ptr<worker::Profiler> profiler_;
|
||||
|
||||
/// Task execution callback.
|
||||
TaskExecutionCallback task_execution_callback_;
|
||||
|
||||
/// A map from resource name to the resource IDs that are currently reserved
|
||||
/// for this worker. Each pair consists of the resource ID and the fraction
|
||||
/// of that resource allocated for this worker. This is set on task assignment.
|
||||
|
|
|
@ -82,7 +82,6 @@ jfieldID java_native_ray_object_metadata;
|
|||
|
||||
jclass java_task_executor_class;
|
||||
jmethodID java_task_executor_execute;
|
||||
jmethodID java_task_executor_get;
|
||||
|
||||
JavaVM *jvm;
|
||||
|
||||
|
@ -197,9 +196,6 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) {
|
|||
env->GetMethodID(java_task_executor_class, "execute",
|
||||
"(Ljava/util/List;Ljava/util/List;)Ljava/util/List;");
|
||||
|
||||
java_task_executor_get = env->GetStaticMethodID(
|
||||
java_task_executor_class, "get", "([B)Lorg/ray/runtime/task/TaskExecutor;");
|
||||
|
||||
return CURRENT_JNI_VERSION;
|
||||
}
|
||||
|
||||
|
|
|
@ -141,9 +141,6 @@ extern jclass java_task_executor_class;
|
|||
/// execute method of TaskExecutor class
|
||||
extern jmethodID java_task_executor_execute;
|
||||
|
||||
/// The `get` method in TaskExecutor class
|
||||
extern jmethodID java_task_executor_get;
|
||||
|
||||
#define CURRENT_JNI_VERSION JNI_VERSION_1_8
|
||||
|
||||
extern JavaVM *jvm;
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include "ray/core_worker/lib/java/jni_utils.h"
|
||||
|
||||
thread_local JNIEnv *local_env = nullptr;
|
||||
jobject java_task_executor = nullptr;
|
||||
|
||||
inline ray::gcs::GcsClientOptions ToGcsClientOptions(JNIEnv *env,
|
||||
jobject gcs_client_options) {
|
||||
|
@ -36,15 +37,20 @@ inline ray::gcs::GcsClientOptions ToGcsClientOptions(JNIEnv *env,
|
|||
extern "C" {
|
||||
#endif
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeInitCoreWorker(
|
||||
JNIEnv *env, jclass, jint workerMode, jstring storeSocket, jstring rayletSocket,
|
||||
jstring nodeIpAddress, jint nodeManagerPort, jbyteArray jobId,
|
||||
jobject gcsClientOptions) {
|
||||
auto native_store_socket = JavaStringToNativeString(env, storeSocket);
|
||||
auto native_raylet_socket = JavaStringToNativeString(env, rayletSocket);
|
||||
auto job_id = JavaByteArrayToId<ray::JobID>(env, jobId);
|
||||
auto gcs_client_options = ToGcsClientOptions(env, gcsClientOptions);
|
||||
auto node_ip_address = JavaStringToNativeString(env, nodeIpAddress);
|
||||
JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeInitialize(
|
||||
JNIEnv *env, jclass, jint workerMode, jstring nodeIpAddress, jint nodeManagerPort,
|
||||
jstring driverName, jstring storeSocket, jstring rayletSocket, jbyteArray jobId,
|
||||
jobject gcsClientOptions, jint numWorkersPerProcess, jstring logDir,
|
||||
jobject rayletConfigParameters) {
|
||||
auto raylet_config = JavaMapToNativeMap<std::string, std::string>(
|
||||
env, rayletConfigParameters,
|
||||
[](JNIEnv *env, jobject java_key) {
|
||||
return JavaStringToNativeString(env, (jstring)java_key);
|
||||
},
|
||||
[](JNIEnv *env, jobject java_value) {
|
||||
return JavaStringToNativeString(env, (jstring)java_value);
|
||||
});
|
||||
RayConfig::instance().initialize(raylet_config);
|
||||
|
||||
auto task_execution_callback =
|
||||
[](ray::TaskType task_type, const ray::RayFunction &ray_function,
|
||||
|
@ -52,8 +58,7 @@ JNIEXPORT jlong JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeInitCoreWork
|
|||
const std::vector<std::shared_ptr<ray::RayObject>> &args,
|
||||
const std::vector<ObjectID> &arg_reference_ids,
|
||||
const std::vector<ObjectID> &return_ids,
|
||||
std::vector<std::shared_ptr<ray::RayObject>> *results,
|
||||
const ray::WorkerID &worker_id) {
|
||||
std::vector<std::shared_ptr<ray::RayObject>> *results) {
|
||||
JNIEnv *env = local_env;
|
||||
if (!env) {
|
||||
// Attach the native thread to JVM.
|
||||
|
@ -64,12 +69,7 @@ JNIEXPORT jlong JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeInitCoreWork
|
|||
}
|
||||
|
||||
RAY_CHECK(env);
|
||||
|
||||
auto worker_id_bytes = IdToJavaByteArray<ray::WorkerID>(env, worker_id);
|
||||
jobject local_java_task_executor = env->CallStaticObjectMethod(
|
||||
java_task_executor_class, java_task_executor_get, worker_id_bytes);
|
||||
|
||||
RAY_CHECK(local_java_task_executor);
|
||||
RAY_CHECK(java_task_executor);
|
||||
// convert RayFunction
|
||||
jobject ray_function_array_list = NativeRayFunctionDescriptorToJavaStringList(
|
||||
env, ray_function.GetFunctionDescriptor());
|
||||
|
@ -80,7 +80,7 @@ JNIEXPORT jlong JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeInitCoreWork
|
|||
|
||||
// invoke Java method
|
||||
jobject java_return_objects =
|
||||
env->CallObjectMethod(local_java_task_executor, java_task_executor_execute,
|
||||
env->CallObjectMethod(java_task_executor, java_task_executor_execute,
|
||||
ray_function_array_list, args_array_list);
|
||||
RAY_CHECK_JAVA_EXCEPTION(env);
|
||||
std::vector<std::shared_ptr<ray::RayObject>> return_objects;
|
||||
|
@ -99,81 +99,70 @@ JNIEXPORT jlong JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeInitCoreWork
|
|||
return ray::Status::OK();
|
||||
};
|
||||
|
||||
try {
|
||||
auto core_worker = new ray::CoreWorker(
|
||||
static_cast<ray::WorkerType>(workerMode), ::Language::JAVA, native_store_socket,
|
||||
native_raylet_socket, job_id, gcs_client_options, /*log_dir=*/"", node_ip_address,
|
||||
nodeManagerPort, task_execution_callback);
|
||||
return reinterpret_cast<jlong>(core_worker);
|
||||
} catch (const std::exception &e) {
|
||||
std::ostringstream oss;
|
||||
oss << "Failed to construct core worker: " << e.what();
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, ray::Status::Invalid(oss.str()), 0);
|
||||
return 0; // Keeps the compiler from complaining about a missing return.
|
||||
}
|
||||
ray::CoreWorkerOptions options = {
|
||||
static_cast<ray::WorkerType>(workerMode), // worker_type
|
||||
ray::Language::JAVA, // language
|
||||
JavaStringToNativeString(env, storeSocket), // store_socket
|
||||
JavaStringToNativeString(env, rayletSocket), // raylet_socket
|
||||
JavaByteArrayToId<ray::JobID>(env, jobId), // job_id
|
||||
ToGcsClientOptions(env, gcsClientOptions), // gcs_options
|
||||
JavaStringToNativeString(env, logDir), // log_dir
|
||||
// TODO (kfstorm): JVM would crash if install_failure_signal_handler was set to true
|
||||
false, // install_failure_signal_handler
|
||||
JavaStringToNativeString(env, nodeIpAddress), // node_ip_address
|
||||
static_cast<int>(nodeManagerPort), // node_manager_port
|
||||
JavaStringToNativeString(env, driverName), // driver_name
|
||||
"", // stdout_file
|
||||
"", // stderr_file
|
||||
task_execution_callback, // task_execution_callback
|
||||
nullptr, // check_signals
|
||||
nullptr, // gc_collect
|
||||
nullptr, // get_lang_stack
|
||||
false, // ref_counting_enabled
|
||||
false, // is_local_mode
|
||||
static_cast<int>(numWorkersPerProcess), // num_workers
|
||||
};
|
||||
|
||||
ray::CoreWorkerProcess::Initialize(options);
|
||||
}
|
||||
|
||||
/// JNI entry point for `RayNativeRuntime.nativeRunTaskExecutor`.
///
/// This span comes from a rendered diff, so two revisions of the body appear back to
/// back: the one taking `jlong nativeCoreWorkerPointer` (which appears to be the removed
/// side) and the one taking `jobject javaTaskExecutor`.
///
/// Pointer revision: records `env` in the thread-local `local_env`, runs
/// `StartExecutingTasks()` on the worker the pointer refers to, then clears `local_env`.
///
/// Executor revision: stores `javaTaskExecutor` in the file-level `java_task_executor`
/// for the duration of `CoreWorkerProcess::RunTaskExecutionLoop()`, then resets it to
/// nullptr. The task execution callback reads that global to invoke `execute`.
///
/// NOTE(review): `javaTaskExecutor` is a JNI local reference. The task execution
/// callback contains a path that attaches a native thread to the JVM, which suggests it
/// can run on threads other than this one; local references are only valid on the thread
/// that received them. Consider `NewGlobalRef`/`DeleteGlobalRef` here -- confirm.
JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeRunTaskExecutor(
|
||||
JNIEnv *env, jclass o, jlong nativeCoreWorkerPointer) {
|
||||
local_env = env;
|
||||
auto core_worker = reinterpret_cast<ray::CoreWorker *>(nativeCoreWorkerPointer);
|
||||
core_worker->StartExecutingTasks();
|
||||
local_env = nullptr;
|
||||
JNIEnv *env, jclass o, jobject javaTaskExecutor) {
|
||||
java_task_executor = javaTaskExecutor;
|
||||
|
||||
ray::CoreWorkerProcess::RunTaskExecutionLoop();
|
||||
java_task_executor = nullptr;
|
||||
}
|
||||
|
||||
/// JNI entry point for `RayNativeRuntime.nativeDestroyCoreWorker`.
///
/// Reinterprets `nativeCoreWorkerPointer` as a `ray::CoreWorker *` (the same encoding
/// `nativeInitCoreWorker` uses when it returns the worker as a `jlong`), disconnects the
/// worker and then deletes it.
///
/// The pointer is not validated, and after this call the value held on the Java side no
/// longer refers to a live object.
JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeDestroyCoreWorker(
|
||||
JNIEnv *env, jclass o, jlong nativeCoreWorkerPointer) {
|
||||
auto core_worker = reinterpret_cast<ray::CoreWorker *>(nativeCoreWorkerPointer);
|
||||
core_worker->Disconnect();
|
||||
delete core_worker;
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeSetup(
|
||||
JNIEnv *env, jclass, jstring logDir, jobject rayletConfigParameters) {
|
||||
std::string log_dir = JavaStringToNativeString(env, logDir);
|
||||
ray::RayLog::StartRayLog("java_worker", ray::RayLogLevel::INFO, log_dir);
|
||||
// TODO (kfstorm): We can't InstallFailureSignalHandler here, because JVM already
|
||||
// installed its own signal handler. It's possible to fix this by chaining signal
|
||||
// handlers. But it's not easy. See
|
||||
// https://docs.oracle.com/javase/9/troubleshoot/handle-signals-and-exceptions.htm.
|
||||
auto raylet_config = JavaMapToNativeMap<std::string, std::string>(
|
||||
env, rayletConfigParameters,
|
||||
[](JNIEnv *env, jobject java_key) {
|
||||
return JavaStringToNativeString(env, (jstring)java_key);
|
||||
},
|
||||
[](JNIEnv *env, jobject java_value) {
|
||||
return JavaStringToNativeString(env, (jstring)java_value);
|
||||
});
|
||||
RayConfig::instance().initialize(raylet_config);
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeShutdownHook(JNIEnv *,
|
||||
jclass) {
|
||||
ray::RayLog::ShutDownRayLog();
|
||||
JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeShutdown(JNIEnv *env,
|
||||
jclass o) {
|
||||
ray::CoreWorkerProcess::Shutdown();
|
||||
}
|
||||
|
||||
/// JNI entry point for `RayNativeRuntime.nativeSetResource`.
///
/// This span comes from a rendered diff, so two revisions are shown together: one that
/// receives the core worker as a `jlong` pointer and one that obtains it from
/// `ray::CoreWorkerProcess::GetCoreWorker()`. Both convert `nodeId` to a `ClientID`,
/// call `SetResource(name, capacity, node_id)`, and raise a Java exception through
/// THROW_EXCEPTION_AND_RETURN_IF_NOT_OK when the returned status is not OK.
///
/// The UTF chars are released before the status check, so they are released whether or
/// not `SetResource` succeeded.
///
/// NOTE(review): the second argument of `GetStringUTFChars` is a `jboolean *isCopy`
/// out-parameter; `JNI_FALSE` here is acting as a null pointer, so `nullptr` would state
/// the intent more clearly.
/// NOTE(review): the result of `GetStringUTFChars` is not checked for null before it is
/// passed to `SetResource`.
JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeSetResource(
|
||||
JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jstring resourceName,
|
||||
jdouble capacity, jbyteArray nodeId) {
|
||||
JNIEnv *env, jclass, jstring resourceName, jdouble capacity, jbyteArray nodeId) {
|
||||
const auto node_id = JavaByteArrayToId<ClientID>(env, nodeId);
|
||||
const char *native_resource_name = env->GetStringUTFChars(resourceName, JNI_FALSE);
|
||||
|
||||
auto status =
|
||||
reinterpret_cast<ray::CoreWorker *>(nativeCoreWorkerPointer)
|
||||
->SetResource(native_resource_name, static_cast<double>(capacity), node_id);
|
||||
auto status = ray::CoreWorkerProcess::GetCoreWorker().SetResource(
|
||||
native_resource_name, static_cast<double>(capacity), node_id);
|
||||
env->ReleaseStringUTFChars(resourceName, native_resource_name);
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeKillActor(
|
||||
JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jbyteArray actorId,
|
||||
jboolean noReconstruction) {
|
||||
auto core_worker = reinterpret_cast<ray::CoreWorker *>(nativeCoreWorkerPointer);
|
||||
auto status = core_worker->KillActor(JavaByteArrayToId<ActorID>(env, actorId),
|
||||
/*force_kill=*/true, noReconstruction);
|
||||
JNIEnv *env, jclass, jbyteArray actorId, jboolean noReconstruction) {
|
||||
auto status = ray::CoreWorkerProcess::GetCoreWorker().KillActor(
|
||||
JavaByteArrayToId<ActorID>(env, actorId),
|
||||
/*force_kill=*/true, noReconstruction);
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
|
||||
}
|
||||
|
||||
/// JNI entry point for `RayNativeRuntime.nativeSetCoreWorker`.
///
/// Converts the Java byte array `workerId` into a `ray::WorkerID` and passes it to
/// `ray::CoreWorkerProcess::SetCurrentThreadWorkerId`.
///
/// NOTE(review): presumably this selects which core worker later
/// `CoreWorkerProcess::GetCoreWorker()` calls on the calling thread resolve to --
/// confirm against `CoreWorkerProcess`.
JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeSetCoreWorker(
|
||||
JNIEnv *env, jclass, jbyteArray workerId) {
|
||||
const auto worker_id = JavaByteArrayToId<ray::WorkerID>(env, workerId);
|
||||
ray::CoreWorkerProcess::SetCurrentThreadWorkerId(worker_id);
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -23,61 +23,55 @@ extern "C" {
|
|||
#endif
|
||||
/*
|
||||
* Class: org_ray_runtime_RayNativeRuntime
|
||||
* Method: nativeInitCoreWorker
|
||||
* Method: nativeInitialize
|
||||
* Signature:
|
||||
* (ILjava/lang/String;Ljava/lang/String;Ljava/lang/String;I[BLorg/ray/runtime/gcs/GcsClientOptions;)J
|
||||
* (ILjava/lang/String;ILjava/lang/String;Ljava/lang/String;Ljava/lang/String;[BLorg/ray/runtime/gcs/GcsClientOptions;ILjava/lang/String;Ljava/util/Map;)V
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeInitCoreWorker(
|
||||
JNIEnv *, jclass, jint, jstring, jstring, jstring, jint, jbyteArray, jobject);
|
||||
JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeInitialize(
|
||||
JNIEnv *, jclass, jint, jstring, jint, jstring, jstring, jstring, jbyteArray, jobject,
|
||||
jint, jstring, jobject);
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_RayNativeRuntime
|
||||
* Method: nativeRunTaskExecutor
|
||||
* Signature: (J)V
|
||||
* Signature: (Lorg/ray/runtime/task/TaskExecutor;)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_ray_runtime_RayNativeRuntime_nativeRunTaskExecutor(JNIEnv *, jclass, jlong);
|
||||
Java_org_ray_runtime_RayNativeRuntime_nativeRunTaskExecutor(JNIEnv *, jclass, jobject);
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_RayNativeRuntime
|
||||
* Method: nativeDestroyCoreWorker
|
||||
* Signature: (J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_ray_runtime_RayNativeRuntime_nativeDestroyCoreWorker(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_RayNativeRuntime
|
||||
* Method: nativeSetup
|
||||
* Signature: (Ljava/lang/String;Ljava/util/Map;)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeSetup(JNIEnv *, jclass,
|
||||
jstring,
|
||||
jobject);
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_RayNativeRuntime
|
||||
* Method: nativeShutdownHook
|
||||
* Method: nativeShutdown
|
||||
* Signature: ()V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeShutdownHook(JNIEnv *,
|
||||
jclass);
|
||||
JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeShutdown(JNIEnv *,
|
||||
jclass);
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_RayNativeRuntime
|
||||
* Method: nativeSetResource
|
||||
* Signature: (JLjava/lang/String;D[B)V
|
||||
* Signature: (Ljava/lang/String;D[B)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeSetResource(
|
||||
JNIEnv *, jclass, jlong, jstring, jdouble, jbyteArray);
|
||||
JNIEnv *, jclass, jstring, jdouble, jbyteArray);
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_RayNativeRuntime
|
||||
* Method: nativeKillActor
|
||||
* Signature: (J[BZ)V
|
||||
* Signature: ([BZ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeKillActor(
|
||||
JNIEnv *, jclass, jlong, jbyteArray, jboolean);
|
||||
JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeKillActor(JNIEnv *,
|
||||
jclass,
|
||||
jbyteArray,
|
||||
jboolean);
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_RayNativeRuntime
|
||||
* Method: nativeSetCoreWorker
|
||||
* Signature: ([B)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_ray_runtime_RayNativeRuntime_nativeSetCoreWorker(JNIEnv *, jclass, jbyteArray);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@ -15,47 +15,44 @@
|
|||
#include "ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.h"
|
||||
#include <jni.h>
|
||||
#include "ray/common/id.h"
|
||||
#include "ray/core_worker/actor_handle.h"
|
||||
#include "ray/core_worker/common.h"
|
||||
#include "ray/core_worker/core_worker.h"
|
||||
#include "ray/core_worker/lib/java/jni_utils.h"
|
||||
|
||||
/// Reinterprets a `jlong` handle received from Java as a `ray::CoreWorker` and returns a
/// reference to it. The handle is neither null-checked nor otherwise validated.
inline ray::CoreWorker &GetCoreWorker(jlong nativeCoreWorkerPointer) {
|
||||
return *reinterpret_cast<ray::CoreWorker *>(nativeCoreWorkerPointer);
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
JNIEXPORT jint JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeGetLanguage(
|
||||
JNIEnv *env, jclass o, jlong nativeCoreWorkerPointer, jbyteArray actorId) {
|
||||
JNIEnv *env, jclass o, jbyteArray actorId) {
|
||||
auto actor_id = JavaByteArrayToId<ray::ActorID>(env, actorId);
|
||||
ray::ActorHandle *native_actor_handle = nullptr;
|
||||
auto status = GetCoreWorker(nativeCoreWorkerPointer)
|
||||
.GetActorHandle(actor_id, &native_actor_handle);
|
||||
auto status = ray::CoreWorkerProcess::GetCoreWorker().GetActorHandle(
|
||||
actor_id, &native_actor_handle);
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, false);
|
||||
return native_actor_handle->ActorLanguage();
|
||||
}
|
||||
|
||||
JNIEXPORT jobject JNICALL
|
||||
Java_org_ray_runtime_actor_NativeRayActor_nativeGetActorCreationTaskFunctionDescriptor(
|
||||
JNIEnv *env, jclass o, jlong nativeCoreWorkerPointer, jbyteArray actorId) {
|
||||
JNIEnv *env, jclass o, jbyteArray actorId) {
|
||||
auto actor_id = JavaByteArrayToId<ray::ActorID>(env, actorId);
|
||||
ray::ActorHandle *native_actor_handle = nullptr;
|
||||
auto status = GetCoreWorker(nativeCoreWorkerPointer)
|
||||
.GetActorHandle(actor_id, &native_actor_handle);
|
||||
auto status = ray::CoreWorkerProcess::GetCoreWorker().GetActorHandle(
|
||||
actor_id, &native_actor_handle);
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr);
|
||||
auto function_descriptor = native_actor_handle->ActorCreationTaskFunctionDescriptor();
|
||||
return NativeRayFunctionDescriptorToJavaStringList(env, function_descriptor);
|
||||
}
|
||||
|
||||
JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeSerialize(
|
||||
JNIEnv *env, jclass o, jlong nativeCoreWorkerPointer, jbyteArray actorId) {
|
||||
JNIEnv *env, jclass o, jbyteArray actorId) {
|
||||
auto actor_id = JavaByteArrayToId<ray::ActorID>(env, actorId);
|
||||
std::string output;
|
||||
ObjectID actor_handle_id;
|
||||
ray::Status status = GetCoreWorker(nativeCoreWorkerPointer)
|
||||
.SerializeActorHandle(actor_id, &output, &actor_handle_id);
|
||||
ray::Status status = ray::CoreWorkerProcess::GetCoreWorker().SerializeActorHandle(
|
||||
actor_id, &output, &actor_handle_id);
|
||||
jbyteArray bytes = env->NewByteArray(output.size());
|
||||
env->SetByteArrayRegion(bytes, 0, output.size(),
|
||||
reinterpret_cast<const jbyte *>(output.c_str()));
|
||||
|
@ -63,13 +60,13 @@ JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeSer
|
|||
}
|
||||
|
||||
JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeDeserialize(
|
||||
JNIEnv *env, jclass o, jlong nativeCoreWorkerPointer, jbyteArray data) {
|
||||
JNIEnv *env, jclass o, jbyteArray data) {
|
||||
auto buffer = JavaByteArrayToNativeBuffer(env, data);
|
||||
RAY_CHECK(buffer->Size() > 0);
|
||||
auto binary = std::string(reinterpret_cast<char *>(buffer->Data()), buffer->Size());
|
||||
auto actor_id =
|
||||
GetCoreWorker(nativeCoreWorkerPointer)
|
||||
.DeserializeAndRegisterActorHandle(binary, /*outer_object_id=*/ObjectID::Nil());
|
||||
ray::CoreWorkerProcess::GetCoreWorker().DeserializeAndRegisterActorHandle(
|
||||
binary, /*outer_object_id=*/ObjectID::Nil());
|
||||
|
||||
return IdToJavaByteArray<ray::ActorID>(env, actor_id);
|
||||
}
|
||||
|
|
|
@ -24,35 +24,35 @@ extern "C" {
|
|||
/*
|
||||
* Class: org_ray_runtime_actor_NativeRayActor
|
||||
* Method: nativeGetLanguage
|
||||
* Signature: (J[B)I
|
||||
* Signature: ([B)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeGetLanguage(
|
||||
JNIEnv *, jclass, jlong, jbyteArray);
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_org_ray_runtime_actor_NativeRayActor_nativeGetLanguage(JNIEnv *, jclass, jbyteArray);
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_actor_NativeRayActor
|
||||
* Method: nativeGetActorCreationTaskFunctionDescriptor
|
||||
* Signature: (J[B)Ljava/util/List;
|
||||
* Signature: ([B)Ljava/util/List;
|
||||
*/
|
||||
JNIEXPORT jobject JNICALL
|
||||
Java_org_ray_runtime_actor_NativeRayActor_nativeGetActorCreationTaskFunctionDescriptor(
|
||||
JNIEnv *, jclass, jlong, jbyteArray);
|
||||
JNIEnv *, jclass, jbyteArray);
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_actor_NativeRayActor
|
||||
* Method: nativeSerialize
|
||||
* Signature: (J[B)[B
|
||||
* Signature: ([B)[B
|
||||
*/
|
||||
JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeSerialize(
|
||||
JNIEnv *, jclass, jlong, jbyteArray);
|
||||
JNIEXPORT jbyteArray JNICALL
|
||||
Java_org_ray_runtime_actor_NativeRayActor_nativeSerialize(JNIEnv *, jclass, jbyteArray);
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_actor_NativeRayActor
|
||||
* Method: nativeDeserialize
|
||||
* Signature: (J[B)[B
|
||||
* Signature: ([B)[B
|
||||
*/
|
||||
JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeDeserialize(
|
||||
JNIEnv *, jclass, jlong, jbyteArray);
|
||||
JNIEXPORT jbyteArray JNICALL
|
||||
Java_org_ray_runtime_actor_NativeRayActor_nativeDeserialize(JNIEnv *, jclass, jbyteArray);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@ -19,51 +19,48 @@
|
|||
#include "ray/core_worker/core_worker.h"
|
||||
#include "ray/core_worker/lib/java/jni_utils.h"
|
||||
|
||||
/// Reinterprets a `jlong` handle received from Java as a `ray::CoreWorker *` and returns
/// that worker's `WorkerContext`. The handle is neither null-checked nor otherwise
/// validated.
inline ray::WorkerContext &GetWorkerContextFromPointer(jlong nativeCoreWorkerPointer) {
|
||||
return reinterpret_cast<ray::CoreWorker *>(nativeCoreWorkerPointer)->GetWorkerContext();
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentTaskType(
|
||||
JNIEnv *env, jclass, jlong nativeCoreWorkerPointer) {
|
||||
auto task_spec = GetWorkerContextFromPointer(nativeCoreWorkerPointer).GetCurrentTask();
|
||||
Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentTaskType(JNIEnv *env,
|
||||
jclass) {
|
||||
auto task_spec =
|
||||
ray::CoreWorkerProcess::GetCoreWorker().GetWorkerContext().GetCurrentTask();
|
||||
RAY_CHECK(task_spec) << "Current task is not set.";
|
||||
return static_cast<int>(task_spec->GetMessage().type());
|
||||
}
|
||||
|
||||
JNIEXPORT jobject JNICALL
|
||||
Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentTaskId(
|
||||
JNIEnv *env, jclass, jlong nativeCoreWorkerPointer) {
|
||||
Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentTaskId(JNIEnv *env,
|
||||
jclass) {
|
||||
const ray::TaskID &task_id =
|
||||
GetWorkerContextFromPointer(nativeCoreWorkerPointer).GetCurrentTaskID();
|
||||
ray::CoreWorkerProcess::GetCoreWorker().GetWorkerContext().GetCurrentTaskID();
|
||||
return IdToJavaByteBuffer<ray::TaskID>(env, task_id);
|
||||
}
|
||||
|
||||
JNIEXPORT jobject JNICALL
|
||||
Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentJobId(
|
||||
JNIEnv *env, jclass, jlong nativeCoreWorkerPointer) {
|
||||
Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentJobId(JNIEnv *env,
|
||||
jclass) {
|
||||
const auto &job_id =
|
||||
GetWorkerContextFromPointer(nativeCoreWorkerPointer).GetCurrentJobID();
|
||||
ray::CoreWorkerProcess::GetCoreWorker().GetWorkerContext().GetCurrentJobID();
|
||||
return IdToJavaByteBuffer<ray::JobID>(env, job_id);
|
||||
}
|
||||
|
||||
JNIEXPORT jobject JNICALL
|
||||
Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentWorkerId(
|
||||
JNIEnv *env, jclass, jlong nativeCoreWorkerPointer) {
|
||||
Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentWorkerId(JNIEnv *env,
|
||||
jclass) {
|
||||
const auto &worker_id =
|
||||
GetWorkerContextFromPointer(nativeCoreWorkerPointer).GetWorkerID();
|
||||
ray::CoreWorkerProcess::GetCoreWorker().GetWorkerContext().GetWorkerID();
|
||||
return IdToJavaByteBuffer<ray::WorkerID>(env, worker_id);
|
||||
}
|
||||
|
||||
JNIEXPORT jobject JNICALL
|
||||
Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentActorId(
|
||||
JNIEnv *env, jclass, jlong nativeCoreWorkerPointer) {
|
||||
Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentActorId(JNIEnv *env,
|
||||
jclass) {
|
||||
const auto &actor_id =
|
||||
GetWorkerContextFromPointer(nativeCoreWorkerPointer).GetCurrentActorID();
|
||||
ray::CoreWorkerProcess::GetCoreWorker().GetWorkerContext().GetCurrentActorID();
|
||||
return IdToJavaByteBuffer<ray::ActorID>(env, actor_id);
|
||||
}
|
||||
|
||||
|
|
|
@ -24,47 +24,45 @@ extern "C" {
|
|||
/*
|
||||
* Class: org_ray_runtime_context_NativeWorkerContext
|
||||
* Method: nativeGetCurrentTaskType
|
||||
* Signature: (J)I
|
||||
* Signature: ()I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentTaskType(JNIEnv *,
|
||||
jclass, jlong);
|
||||
jclass);
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_context_NativeWorkerContext
|
||||
* Method: nativeGetCurrentTaskId
|
||||
* Signature: (J)Ljava/nio/ByteBuffer;
|
||||
* Signature: ()Ljava/nio/ByteBuffer;
|
||||
*/
|
||||
JNIEXPORT jobject JNICALL
|
||||
Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentTaskId(JNIEnv *, jclass,
|
||||
jlong);
|
||||
Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentTaskId(JNIEnv *, jclass);
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_context_NativeWorkerContext
|
||||
* Method: nativeGetCurrentJobId
|
||||
* Signature: (J)Ljava/nio/ByteBuffer;
|
||||
* Signature: ()Ljava/nio/ByteBuffer;
|
||||
*/
|
||||
JNIEXPORT jobject JNICALL
|
||||
Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentJobId(JNIEnv *, jclass,
|
||||
jlong);
|
||||
Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentJobId(JNIEnv *, jclass);
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_context_NativeWorkerContext
|
||||
* Method: nativeGetCurrentWorkerId
|
||||
* Signature: (J)Ljava/nio/ByteBuffer;
|
||||
* Signature: ()Ljava/nio/ByteBuffer;
|
||||
*/
|
||||
JNIEXPORT jobject JNICALL
|
||||
Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentWorkerId(JNIEnv *,
|
||||
jclass, jlong);
|
||||
jclass);
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_context_NativeWorkerContext
|
||||
* Method: nativeGetCurrentActorId
|
||||
* Signature: (J)Ljava/nio/ByteBuffer;
|
||||
* Signature: ()Ljava/nio/ByteBuffer;
|
||||
*/
|
||||
JNIEXPORT jobject JNICALL
|
||||
Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentActorId(JNIEnv *, jclass,
|
||||
jlong);
|
||||
Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentActorId(JNIEnv *,
|
||||
jclass);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@ -24,55 +24,51 @@ extern "C" {
|
|||
#endif
|
||||
|
||||
JNIEXPORT jbyteArray JNICALL
|
||||
Java_org_ray_runtime_object_NativeObjectStore_nativePut__JLorg_ray_runtime_object_NativeRayObject_2(
|
||||
JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jobject obj) {
|
||||
Java_org_ray_runtime_object_NativeObjectStore_nativePut__Lorg_ray_runtime_object_NativeRayObject_2(
|
||||
JNIEnv *env, jclass, jobject obj) {
|
||||
auto ray_object = JavaNativeRayObjectToNativeRayObject(env, obj);
|
||||
RAY_CHECK(ray_object != nullptr);
|
||||
ray::ObjectID object_id;
|
||||
auto status = reinterpret_cast<ray::CoreWorker *>(nativeCoreWorkerPointer)
|
||||
->Put(*ray_object, {}, &object_id);
|
||||
auto status = ray::CoreWorkerProcess::GetCoreWorker().Put(*ray_object, {}, &object_id);
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr);
|
||||
return IdToJavaByteArray<ray::ObjectID>(env, object_id);
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_ray_runtime_object_NativeObjectStore_nativePut__J_3BLorg_ray_runtime_object_NativeRayObject_2(
|
||||
JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jbyteArray objectId,
|
||||
jobject obj) {
|
||||
Java_org_ray_runtime_object_NativeObjectStore_nativePut___3BLorg_ray_runtime_object_NativeRayObject_2(
|
||||
JNIEnv *env, jclass, jbyteArray objectId, jobject obj) {
|
||||
auto object_id = JavaByteArrayToId<ray::ObjectID>(env, objectId);
|
||||
auto ray_object = JavaNativeRayObjectToNativeRayObject(env, obj);
|
||||
RAY_CHECK(ray_object != nullptr);
|
||||
auto status = reinterpret_cast<ray::CoreWorker *>(nativeCoreWorkerPointer)
|
||||
->Put(*ray_object, {}, object_id);
|
||||
auto status = ray::CoreWorkerProcess::GetCoreWorker().Put(*ray_object, {}, object_id);
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
|
||||
}
|
||||
|
||||
JNIEXPORT jobject JNICALL Java_org_ray_runtime_object_NativeObjectStore_nativeGet(
|
||||
JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jobject ids, jlong timeoutMs) {
|
||||
JNIEnv *env, jclass, jobject ids, jlong timeoutMs) {
|
||||
std::vector<ray::ObjectID> object_ids;
|
||||
JavaListToNativeVector<ray::ObjectID>(
|
||||
env, ids, &object_ids, [](JNIEnv *env, jobject id) {
|
||||
return JavaByteArrayToId<ray::ObjectID>(env, static_cast<jbyteArray>(id));
|
||||
});
|
||||
std::vector<std::shared_ptr<ray::RayObject>> results;
|
||||
auto status = reinterpret_cast<ray::CoreWorker *>(nativeCoreWorkerPointer)
|
||||
->Get(object_ids, (int64_t)timeoutMs, &results);
|
||||
auto status = ray::CoreWorkerProcess::GetCoreWorker().Get(object_ids,
|
||||
(int64_t)timeoutMs, &results);
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr);
|
||||
return NativeVectorToJavaList<std::shared_ptr<ray::RayObject>>(
|
||||
env, results, NativeRayObjectToJavaNativeRayObject);
|
||||
}
|
||||
|
||||
JNIEXPORT jobject JNICALL Java_org_ray_runtime_object_NativeObjectStore_nativeWait(
|
||||
JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jobject objectIds,
|
||||
jint numObjects, jlong timeoutMs) {
|
||||
JNIEnv *env, jclass, jobject objectIds, jint numObjects, jlong timeoutMs) {
|
||||
std::vector<ray::ObjectID> object_ids;
|
||||
JavaListToNativeVector<ray::ObjectID>(
|
||||
env, objectIds, &object_ids, [](JNIEnv *env, jobject id) {
|
||||
return JavaByteArrayToId<ray::ObjectID>(env, static_cast<jbyteArray>(id));
|
||||
});
|
||||
std::vector<bool> results;
|
||||
auto status = reinterpret_cast<ray::CoreWorker *>(nativeCoreWorkerPointer)
|
||||
->Wait(object_ids, (int)numObjects, (int64_t)timeoutMs, &results);
|
||||
auto status = ray::CoreWorkerProcess::GetCoreWorker().Wait(
|
||||
object_ids, (int)numObjects, (int64_t)timeoutMs, &results);
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr);
|
||||
return NativeVectorToJavaList<bool>(env, results, [](JNIEnv *env, const bool &item) {
|
||||
return env->NewObject(java_boolean_class, java_boolean_init, (jboolean)item);
|
||||
|
@ -80,15 +76,15 @@ JNIEXPORT jobject JNICALL Java_org_ray_runtime_object_NativeObjectStore_nativeWa
|
|||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_org_ray_runtime_object_NativeObjectStore_nativeDelete(
|
||||
JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jobject objectIds,
|
||||
jboolean localOnly, jboolean deleteCreatingTasks) {
|
||||
JNIEnv *env, jclass, jobject objectIds, jboolean localOnly,
|
||||
jboolean deleteCreatingTasks) {
|
||||
std::vector<ray::ObjectID> object_ids;
|
||||
JavaListToNativeVector<ray::ObjectID>(
|
||||
env, objectIds, &object_ids, [](JNIEnv *env, jobject id) {
|
||||
return JavaByteArrayToId<ray::ObjectID>(env, static_cast<jbyteArray>(id));
|
||||
});
|
||||
auto status = reinterpret_cast<ray::CoreWorker *>(nativeCoreWorkerPointer)
|
||||
->Delete(object_ids, (bool)localOnly, (bool)deleteCreatingTasks);
|
||||
auto status = ray::CoreWorkerProcess::GetCoreWorker().Delete(
|
||||
object_ids, (bool)localOnly, (bool)deleteCreatingTasks);
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
|
||||
}
|
||||
|
||||
|
|
|
@ -24,44 +24,44 @@ extern "C" {
|
|||
/*
|
||||
* Class: org_ray_runtime_object_NativeObjectStore
|
||||
* Method: nativePut
|
||||
* Signature: (JLorg/ray/runtime/object/NativeRayObject;)[B
|
||||
* Signature: (Lorg/ray/runtime/object/NativeRayObject;)[B
|
||||
*/
|
||||
JNIEXPORT jbyteArray JNICALL
|
||||
Java_org_ray_runtime_object_NativeObjectStore_nativePut__JLorg_ray_runtime_object_NativeRayObject_2(
|
||||
JNIEnv *, jclass, jlong, jobject);
|
||||
Java_org_ray_runtime_object_NativeObjectStore_nativePut__Lorg_ray_runtime_object_NativeRayObject_2(
|
||||
JNIEnv *, jclass, jobject);
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_object_NativeObjectStore
|
||||
* Method: nativePut
|
||||
* Signature: (J[BLorg/ray/runtime/object/NativeRayObject;)V
|
||||
* Signature: ([BLorg/ray/runtime/object/NativeRayObject;)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_ray_runtime_object_NativeObjectStore_nativePut__J_3BLorg_ray_runtime_object_NativeRayObject_2(
|
||||
JNIEnv *, jclass, jlong, jbyteArray, jobject);
|
||||
Java_org_ray_runtime_object_NativeObjectStore_nativePut___3BLorg_ray_runtime_object_NativeRayObject_2(
|
||||
JNIEnv *, jclass, jbyteArray, jobject);
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_object_NativeObjectStore
|
||||
* Method: nativeGet
|
||||
* Signature: (JLjava/util/List;J)Ljava/util/List;
|
||||
* Signature: (Ljava/util/List;J)Ljava/util/List;
|
||||
*/
|
||||
JNIEXPORT jobject JNICALL Java_org_ray_runtime_object_NativeObjectStore_nativeGet(
|
||||
JNIEnv *, jclass, jlong, jobject, jlong);
|
||||
JNIEXPORT jobject JNICALL
|
||||
Java_org_ray_runtime_object_NativeObjectStore_nativeGet(JNIEnv *, jclass, jobject, jlong);
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_object_NativeObjectStore
|
||||
* Method: nativeWait
|
||||
* Signature: (JLjava/util/List;IJ)Ljava/util/List;
|
||||
* Signature: (Ljava/util/List;IJ)Ljava/util/List;
|
||||
*/
|
||||
JNIEXPORT jobject JNICALL Java_org_ray_runtime_object_NativeObjectStore_nativeWait(
|
||||
JNIEnv *, jclass, jlong, jobject, jint, jlong);
|
||||
JNIEnv *, jclass, jobject, jint, jlong);
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_object_NativeObjectStore
|
||||
* Method: nativeDelete
|
||||
* Signature: (JLjava/util/List;ZZ)V
|
||||
* Signature: (Ljava/util/List;ZZ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_ray_runtime_object_NativeObjectStore_nativeDelete(
|
||||
JNIEnv *, jclass, jlong, jobject, jboolean, jboolean);
|
||||
JNIEnv *, jclass, jobject, jboolean, jboolean);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@ -27,9 +27,9 @@ extern "C" {
|
|||
using ray::ClientID;
|
||||
|
||||
JNIEXPORT jbyteArray JNICALL
|
||||
Java_org_ray_runtime_task_NativeTaskExecutor_nativePrepareCheckpoint(
|
||||
JNIEnv *env, jclass, jlong nativeCoreWorkerPointer) {
|
||||
auto &core_worker = *reinterpret_cast<ray::CoreWorker *>(nativeCoreWorkerPointer);
|
||||
Java_org_ray_runtime_task_NativeTaskExecutor_nativePrepareCheckpoint(JNIEnv *env,
|
||||
jclass) {
|
||||
auto &core_worker = ray::CoreWorkerProcess::GetCoreWorker();
|
||||
const auto &actor_id = core_worker.GetWorkerContext().GetCurrentActorID();
|
||||
const auto &task_spec = core_worker.GetWorkerContext().GetCurrentTask();
|
||||
RAY_CHECK(task_spec->IsActorTask());
|
||||
|
@ -44,11 +44,12 @@ Java_org_ray_runtime_task_NativeTaskExecutor_nativePrepareCheckpoint(
|
|||
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_ray_runtime_task_NativeTaskExecutor_nativeNotifyActorResumedFromCheckpoint(
|
||||
JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jbyteArray checkpointId) {
|
||||
auto &core_worker = *reinterpret_cast<ray::CoreWorker *>(nativeCoreWorkerPointer);
|
||||
const auto &actor_id = core_worker.GetWorkerContext().GetCurrentActorID();
|
||||
JNIEnv *env, jclass, jbyteArray checkpointId) {
|
||||
const auto &actor_id =
|
||||
ray::CoreWorkerProcess::GetCoreWorker().GetWorkerContext().GetCurrentActorID();
|
||||
const auto checkpoint_id = JavaByteArrayToId<ActorCheckpointID>(env, checkpointId);
|
||||
auto status = core_worker.NotifyActorResumedFromCheckpoint(actor_id, checkpoint_id);
|
||||
auto status = ray::CoreWorkerProcess::GetCoreWorker().NotifyActorResumedFromCheckpoint(
|
||||
actor_id, checkpoint_id);
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
|
||||
}
|
||||
|
||||
|
|
|
@ -26,20 +26,19 @@ extern "C" {
|
|||
/*
|
||||
* Class: org_ray_runtime_task_NativeTaskExecutor
|
||||
* Method: nativePrepareCheckpoint
|
||||
* Signature: (J)[B
|
||||
* Signature: ()[B
|
||||
*/
|
||||
JNIEXPORT jbyteArray JNICALL
|
||||
Java_org_ray_runtime_task_NativeTaskExecutor_nativePrepareCheckpoint(JNIEnv *, jclass,
|
||||
jlong);
|
||||
Java_org_ray_runtime_task_NativeTaskExecutor_nativePrepareCheckpoint(JNIEnv *, jclass);
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_task_NativeTaskExecutor
|
||||
* Method: nativeNotifyActorResumedFromCheckpoint
|
||||
* Signature: (J[B)V
|
||||
* Signature: ([B)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_ray_runtime_task_NativeTaskExecutor_nativeNotifyActorResumedFromCheckpoint(
|
||||
JNIEnv *, jclass, jlong, jbyteArray);
|
||||
JNIEnv *, jclass, jbyteArray);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@ -19,10 +19,6 @@
|
|||
#include "ray/core_worker/core_worker.h"
|
||||
#include "ray/core_worker/lib/java/jni_utils.h"
|
||||
|
||||
inline ray::CoreWorker &GetCoreWorker(jlong nativeCoreWorkerPointer) {
|
||||
return *reinterpret_cast<ray::CoreWorker *>(nativeCoreWorkerPointer);
|
||||
}
|
||||
|
||||
inline ray::RayFunction ToRayFunction(JNIEnv *env, jobject functionDescriptor) {
|
||||
std::vector<std::string> function_descriptor_list;
|
||||
jobject list =
|
||||
|
@ -127,17 +123,17 @@ extern "C" {
|
|||
#endif
|
||||
|
||||
JNIEXPORT jobject JNICALL Java_org_ray_runtime_task_NativeTaskSubmitter_nativeSubmitTask(
|
||||
JNIEnv *env, jclass p, jlong nativeCoreWorkerPointer, jobject functionDescriptor,
|
||||
jobject args, jint numReturns, jobject callOptions) {
|
||||
JNIEnv *env, jclass p, jobject functionDescriptor, jobject args, jint numReturns,
|
||||
jobject callOptions) {
|
||||
auto ray_function = ToRayFunction(env, functionDescriptor);
|
||||
auto task_args = ToTaskArgs(env, args);
|
||||
auto task_options = ToTaskOptions(env, numReturns, callOptions);
|
||||
|
||||
std::vector<ObjectID> return_ids;
|
||||
// TODO (kfstorm): Allow setting `max_retries` via `CallOptions`.
|
||||
auto status = GetCoreWorker(nativeCoreWorkerPointer)
|
||||
.SubmitTask(ray_function, task_args, task_options, &return_ids,
|
||||
/*max_retries=*/0);
|
||||
auto status = ray::CoreWorkerProcess::GetCoreWorker().SubmitTask(
|
||||
ray_function, task_args, task_options, &return_ids,
|
||||
/*max_retries=*/0);
|
||||
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr);
|
||||
|
||||
|
@ -146,16 +142,16 @@ JNIEXPORT jobject JNICALL Java_org_ray_runtime_task_NativeTaskSubmitter_nativeSu
|
|||
|
||||
JNIEXPORT jbyteArray JNICALL
|
||||
Java_org_ray_runtime_task_NativeTaskSubmitter_nativeCreateActor(
|
||||
JNIEnv *env, jclass p, jlong nativeCoreWorkerPointer, jobject functionDescriptor,
|
||||
jobject args, jobject actorCreationOptions) {
|
||||
JNIEnv *env, jclass p, jobject functionDescriptor, jobject args,
|
||||
jobject actorCreationOptions) {
|
||||
auto ray_function = ToRayFunction(env, functionDescriptor);
|
||||
auto task_args = ToTaskArgs(env, args);
|
||||
auto actor_creation_options = ToActorCreationOptions(env, actorCreationOptions);
|
||||
|
||||
ray::ActorID actor_id;
|
||||
auto status = GetCoreWorker(nativeCoreWorkerPointer)
|
||||
.CreateActor(ray_function, task_args, actor_creation_options,
|
||||
/*extension_data*/ "", &actor_id);
|
||||
ActorID actor_id;
|
||||
auto status = ray::CoreWorkerProcess::GetCoreWorker().CreateActor(
|
||||
ray_function, task_args, actor_creation_options,
|
||||
/*extension_data*/ "", &actor_id);
|
||||
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr);
|
||||
return IdToJavaByteArray<ray::ActorID>(env, actor_id);
|
||||
|
@ -163,17 +159,16 @@ Java_org_ray_runtime_task_NativeTaskSubmitter_nativeCreateActor(
|
|||
|
||||
JNIEXPORT jobject JNICALL
|
||||
Java_org_ray_runtime_task_NativeTaskSubmitter_nativeSubmitActorTask(
|
||||
JNIEnv *env, jclass p, jlong nativeCoreWorkerPointer, jbyteArray actorId,
|
||||
jobject functionDescriptor, jobject args, jint numReturns, jobject callOptions) {
|
||||
JNIEnv *env, jclass p, jbyteArray actorId, jobject functionDescriptor, jobject args,
|
||||
jint numReturns, jobject callOptions) {
|
||||
auto actor_id = JavaByteArrayToId<ray::ActorID>(env, actorId);
|
||||
auto ray_function = ToRayFunction(env, functionDescriptor);
|
||||
auto task_args = ToTaskArgs(env, args);
|
||||
auto task_options = ToTaskOptions(env, numReturns, callOptions);
|
||||
|
||||
std::vector<ObjectID> return_ids;
|
||||
auto status =
|
||||
GetCoreWorker(nativeCoreWorkerPointer)
|
||||
.SubmitActorTask(actor_id, ray_function, task_args, task_options, &return_ids);
|
||||
auto status = ray::CoreWorkerProcess::GetCoreWorker().SubmitActorTask(
|
||||
actor_id, ray_function, task_args, task_options, &return_ids);
|
||||
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr);
|
||||
return NativeIdVectorToJavaByteArrayList(env, return_ids);
|
||||
|
|
|
@ -25,33 +25,32 @@ extern "C" {
|
|||
* Class: org_ray_runtime_task_NativeTaskSubmitter
|
||||
* Method: nativeSubmitTask
|
||||
* Signature:
|
||||
* (JLorg/ray/runtime/functionmanager/FunctionDescriptor;Ljava/util/List;ILorg/ray/api/options/CallOptions;)Ljava/util/List;
|
||||
* (Lorg/ray/runtime/functionmanager/FunctionDescriptor;Ljava/util/List;ILorg/ray/api/options/CallOptions;)Ljava/util/List;
|
||||
*/
|
||||
JNIEXPORT jobject JNICALL Java_org_ray_runtime_task_NativeTaskSubmitter_nativeSubmitTask(
|
||||
JNIEnv *, jclass, jlong, jobject, jobject, jint, jobject);
|
||||
JNIEnv *, jclass, jobject, jobject, jint, jobject);
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_task_NativeTaskSubmitter
|
||||
* Method: nativeCreateActor
|
||||
* Signature:
|
||||
* (JLorg/ray/runtime/functionmanager/FunctionDescriptor;Ljava/util/List;Lorg/ray/api/options/ActorCreationOptions;)[B
|
||||
* (Lorg/ray/runtime/functionmanager/FunctionDescriptor;Ljava/util/List;Lorg/ray/api/options/ActorCreationOptions;)[B
|
||||
*/
|
||||
JNIEXPORT jbyteArray JNICALL
|
||||
Java_org_ray_runtime_task_NativeTaskSubmitter_nativeCreateActor(JNIEnv *, jclass, jlong,
|
||||
jobject, jobject,
|
||||
jobject);
|
||||
Java_org_ray_runtime_task_NativeTaskSubmitter_nativeCreateActor(JNIEnv *, jclass, jobject,
|
||||
jobject, jobject);
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_task_NativeTaskSubmitter
|
||||
* Method: nativeSubmitActorTask
|
||||
* Signature:
|
||||
* (J[BLorg/ray/runtime/functionmanager/FunctionDescriptor;Ljava/util/List;ILorg/ray/api/options/CallOptions;)Ljava/util/List;
|
||||
* ([BLorg/ray/runtime/functionmanager/FunctionDescriptor;Ljava/util/List;ILorg/ray/api/options/CallOptions;)Ljava/util/List;
|
||||
*/
|
||||
JNIEXPORT jobject JNICALL
|
||||
Java_org_ray_runtime_task_NativeTaskSubmitter_nativeSubmitActorTask(JNIEnv *, jclass,
|
||||
jlong, jbyteArray,
|
||||
jobject, jobject,
|
||||
jint, jobject);
|
||||
jbyteArray, jobject,
|
||||
jobject, jint,
|
||||
jobject);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@ -340,7 +340,7 @@ TEST(MemoryStoreIntegrationTest, TestSimple) {
|
|||
RAY_CHECK(store.Put(buffer, id1));
|
||||
ASSERT_EQ(store.Size(), 1);
|
||||
std::vector<std::shared_ptr<RayObject>> results;
|
||||
WorkerContext ctx(WorkerType::WORKER, JobID::Nil());
|
||||
WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::Nil());
|
||||
RAY_CHECK_OK(store.Get({id1}, /*num_objects*/ 1, /*timeout_ms*/ -1, ctx,
|
||||
/*remove_after_get*/ true, &results));
|
||||
ASSERT_EQ(results.size(), 1);
|
||||
|
|
|
@ -57,8 +57,7 @@ static void flushall_redis(void) {
|
|||
redisFree(context);
|
||||
}
|
||||
|
||||
ActorID CreateActorHelper(CoreWorker &worker,
|
||||
std::unordered_map<std::string, double> &resources,
|
||||
ActorID CreateActorHelper(std::unordered_map<std::string, double> &resources,
|
||||
uint64_t max_reconstructions) {
|
||||
std::unique_ptr<ActorHandle> actor_handle;
|
||||
|
||||
|
@ -78,8 +77,8 @@ ActorID CreateActorHelper(CoreWorker &worker,
|
|||
|
||||
// Create an actor.
|
||||
ActorID actor_id;
|
||||
RAY_CHECK_OK(
|
||||
worker.CreateActor(func, args, actor_options, /*extension_data*/ "", &actor_id));
|
||||
RAY_CHECK_OK(CoreWorkerProcess::GetCoreWorker().CreateActor(
|
||||
func, args, actor_options, /*extension_data*/ "", &actor_id));
|
||||
return actor_id;
|
||||
}
|
||||
|
||||
|
@ -90,7 +89,8 @@ std::string MetadataToString(std::shared_ptr<RayObject> obj) {
|
|||
|
||||
class CoreWorkerTest : public ::testing::Test {
|
||||
public:
|
||||
CoreWorkerTest(int num_nodes) : gcs_options_("127.0.0.1", 6379, "") {
|
||||
CoreWorkerTest(int num_nodes)
|
||||
: num_nodes_(num_nodes), gcs_options_("127.0.0.1", 6379, "") {
|
||||
#ifdef _WIN32
|
||||
RAY_CHECK(false) << "port system() calls to Windows before running this test";
|
||||
#endif
|
||||
|
@ -250,9 +250,39 @@ class CoreWorkerTest : public ::testing::Test {
|
|||
ASSERT_TRUE(system(("rm -f " + gcs_server_pid).c_str()) == 0);
|
||||
}
|
||||
|
||||
void SetUp() {}
|
||||
void SetUp() {
|
||||
if (num_nodes_ > 0) {
|
||||
CoreWorkerOptions options = {
|
||||
WorkerType::DRIVER, // worker_type
|
||||
Language::PYTHON, // language
|
||||
raylet_store_socket_names_[0], // store_socket
|
||||
raylet_socket_names_[0], // raylet_socket
|
||||
NextJobId(), // job_id
|
||||
gcs_options_, // gcs_options
|
||||
"", // log_dir
|
||||
true, // install_failure_signal_handler
|
||||
"127.0.0.1", // node_ip_address
|
||||
node_manager_port, // node_manager_port
|
||||
"core_worker_test", // driver_name
|
||||
"", // stdout_file
|
||||
"", // stderr_file
|
||||
nullptr, // task_execution_callback
|
||||
nullptr, // check_signals
|
||||
nullptr, // gc_collect
|
||||
nullptr, // get_lang_stack
|
||||
true, // ref_counting_enabled
|
||||
false, // is_local_mode
|
||||
1, // num_workers
|
||||
};
|
||||
CoreWorkerProcess::Initialize(options);
|
||||
}
|
||||
}
|
||||
|
||||
void TearDown() {}
|
||||
void TearDown() {
|
||||
if (num_nodes_ > 0) {
|
||||
CoreWorkerProcess::Shutdown();
|
||||
}
|
||||
}
|
||||
|
||||
// Test normal tasks.
|
||||
void TestNormalTask(std::unordered_map<std::string, double> &resources);
|
||||
|
@ -271,13 +301,14 @@ class CoreWorkerTest : public ::testing::Test {
|
|||
void TestActorReconstruction(std::unordered_map<std::string, double> &resources);
|
||||
|
||||
protected:
|
||||
bool WaitForDirectCallActorState(CoreWorker &worker, const ActorID &actor_id,
|
||||
bool wait_alive, int timeout_ms);
|
||||
bool WaitForDirectCallActorState(const ActorID &actor_id, bool wait_alive,
|
||||
int timeout_ms);
|
||||
|
||||
// Get the pid for the worker process that runs the actor.
|
||||
int GetActorPid(CoreWorker &worker, const ActorID &actor_id,
|
||||
int GetActorPid(const ActorID &actor_id,
|
||||
std::unordered_map<std::string, double> &resources);
|
||||
|
||||
int num_nodes_;
|
||||
std::vector<std::string> raylet_socket_names_;
|
||||
std::vector<std::string> raylet_store_socket_names_;
|
||||
std::string raylet_monitor_pid_;
|
||||
|
@ -285,18 +316,19 @@ class CoreWorkerTest : public ::testing::Test {
|
|||
std::string gcs_server_pid_;
|
||||
};
|
||||
|
||||
bool CoreWorkerTest::WaitForDirectCallActorState(CoreWorker &worker,
|
||||
const ActorID &actor_id, bool wait_alive,
|
||||
bool CoreWorkerTest::WaitForDirectCallActorState(const ActorID &actor_id, bool wait_alive,
|
||||
int timeout_ms) {
|
||||
auto condition_func = [&worker, actor_id, wait_alive]() -> bool {
|
||||
bool actor_alive = worker.direct_actor_submitter_->IsActorAlive(actor_id);
|
||||
auto condition_func = [actor_id, wait_alive]() -> bool {
|
||||
bool actor_alive =
|
||||
CoreWorkerProcess::GetCoreWorker().direct_actor_submitter_->IsActorAlive(
|
||||
actor_id);
|
||||
return wait_alive ? actor_alive : !actor_alive;
|
||||
};
|
||||
|
||||
return WaitForCondition(condition_func, timeout_ms);
|
||||
}
|
||||
|
||||
int CoreWorkerTest::GetActorPid(CoreWorker &worker, const ActorID &actor_id,
|
||||
int CoreWorkerTest::GetActorPid(const ActorID &actor_id,
|
||||
std::unordered_map<std::string, double> &resources) {
|
||||
std::vector<TaskArg> args;
|
||||
TaskOptions options{1, resources};
|
||||
|
@ -304,10 +336,11 @@ int CoreWorkerTest::GetActorPid(CoreWorker &worker, const ActorID &actor_id,
|
|||
RayFunction func{Language::PYTHON, ray::FunctionDescriptorBuilder::BuildPython(
|
||||
"GetWorkerPid", "", "", "")};
|
||||
|
||||
RAY_CHECK_OK(worker.SubmitActorTask(actor_id, func, args, options, &return_ids));
|
||||
RAY_CHECK_OK(CoreWorkerProcess::GetCoreWorker().SubmitActorTask(actor_id, func, args,
|
||||
options, &return_ids));
|
||||
|
||||
std::vector<std::shared_ptr<ray::RayObject>> results;
|
||||
RAY_CHECK_OK(worker.Get(return_ids, -1, &results));
|
||||
RAY_CHECK_OK(CoreWorkerProcess::GetCoreWorker().Get(return_ids, -1, &results));
|
||||
|
||||
if (nullptr == results[0]->GetData()) {
|
||||
// If failed to get actor process pid, return -1
|
||||
|
@ -320,9 +353,7 @@ int CoreWorkerTest::GetActorPid(CoreWorker &worker, const ActorID &actor_id,
|
|||
}
|
||||
|
||||
void CoreWorkerTest::TestNormalTask(std::unordered_map<std::string, double> &resources) {
|
||||
CoreWorker driver(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0],
|
||||
raylet_socket_names_[0], NextJobId(), gcs_options_, "", "127.0.0.1",
|
||||
node_manager_port, nullptr);
|
||||
auto &driver = CoreWorkerProcess::GetCoreWorker();
|
||||
|
||||
// Test for tasks with by-value and by-ref args.
|
||||
{
|
||||
|
@ -364,11 +395,9 @@ void CoreWorkerTest::TestNormalTask(std::unordered_map<std::string, double> &res
|
|||
}
|
||||
|
||||
void CoreWorkerTest::TestActorTask(std::unordered_map<std::string, double> &resources) {
|
||||
CoreWorker driver(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0],
|
||||
raylet_socket_names_[0], NextJobId(), gcs_options_, "", "127.0.0.1",
|
||||
node_manager_port, nullptr);
|
||||
auto &driver = CoreWorkerProcess::GetCoreWorker();
|
||||
|
||||
auto actor_id = CreateActorHelper(driver, resources, 1000);
|
||||
auto actor_id = CreateActorHelper(resources, 1000);
|
||||
|
||||
// Test submitting some tasks with by-value args for that actor.
|
||||
{
|
||||
|
@ -452,18 +481,16 @@ void CoreWorkerTest::TestActorTask(std::unordered_map<std::string, double> &reso
|
|||
|
||||
void CoreWorkerTest::TestActorReconstruction(
|
||||
std::unordered_map<std::string, double> &resources) {
|
||||
CoreWorker driver(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0],
|
||||
raylet_socket_names_[0], NextJobId(), gcs_options_, "", "127.0.0.1",
|
||||
node_manager_port, nullptr);
|
||||
auto &driver = CoreWorkerProcess::GetCoreWorker();
|
||||
|
||||
// creating actor.
|
||||
auto actor_id = CreateActorHelper(driver, resources, 1000);
|
||||
auto actor_id = CreateActorHelper(resources, 1000);
|
||||
|
||||
// Wait for actor alive event.
|
||||
ASSERT_TRUE(WaitForDirectCallActorState(driver, actor_id, true, 30 * 1000 /* 30s */));
|
||||
ASSERT_TRUE(WaitForDirectCallActorState(actor_id, true, 30 * 1000 /* 30s */));
|
||||
RAY_LOG(INFO) << "actor has been created";
|
||||
|
||||
auto pid = GetActorPid(driver, actor_id, resources);
|
||||
auto pid = GetActorPid(actor_id, resources);
|
||||
RAY_CHECK(pid != -1);
|
||||
|
||||
// Test submitting some tasks with by-value args for that actor.
|
||||
|
@ -477,9 +504,8 @@ void CoreWorkerTest::TestActorReconstruction(
|
|||
ASSERT_EQ(system("pkill mock_worker"), 0);
|
||||
|
||||
// Wait for actor reconstruction event, and then for alive event.
|
||||
auto check_actor_restart_func = [this, pid, &driver, &actor_id,
|
||||
&resources]() -> bool {
|
||||
auto new_pid = GetActorPid(driver, actor_id, resources);
|
||||
auto check_actor_restart_func = [this, pid, &actor_id, &resources]() -> bool {
|
||||
auto new_pid = GetActorPid(actor_id, resources);
|
||||
return new_pid != -1 && new_pid != pid;
|
||||
};
|
||||
ASSERT_TRUE(WaitForCondition(check_actor_restart_func, 30 * 1000 /* 30s */));
|
||||
|
@ -514,12 +540,10 @@ void CoreWorkerTest::TestActorReconstruction(
|
|||
|
||||
void CoreWorkerTest::TestActorFailure(
|
||||
std::unordered_map<std::string, double> &resources) {
|
||||
CoreWorker driver(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0],
|
||||
raylet_socket_names_[0], NextJobId(), gcs_options_, "", "127.0.0.1",
|
||||
node_manager_port, nullptr);
|
||||
auto &driver = CoreWorkerProcess::GetCoreWorker();
|
||||
|
||||
// creating actor.
|
||||
auto actor_id = CreateActorHelper(driver, resources, 0 /* not reconstructable */);
|
||||
auto actor_id = CreateActorHelper(resources, 0 /* not reconstructable */);
|
||||
|
||||
// Test submitting some tasks with by-value args for that actor.
|
||||
{
|
||||
|
@ -666,16 +690,14 @@ TEST_F(ZeroNodeTest, TestTaskSpecPerf) {
|
|||
}
|
||||
|
||||
TEST_F(SingleNodeTest, TestDirectActorTaskSubmissionPerf) {
|
||||
CoreWorker driver(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0],
|
||||
raylet_socket_names_[0], JobID::FromInt(1), gcs_options_, "",
|
||||
"127.0.0.1", node_manager_port, nullptr);
|
||||
auto &driver = CoreWorkerProcess::GetCoreWorker();
|
||||
std::vector<ObjectID> object_ids;
|
||||
// Create an actor.
|
||||
std::unordered_map<std::string, double> resources;
|
||||
auto actor_id = CreateActorHelper(driver, resources,
|
||||
auto actor_id = CreateActorHelper(resources,
|
||||
/*max_reconstructions=*/0);
|
||||
// wait for actor creation finish.
|
||||
ASSERT_TRUE(WaitForDirectCallActorState(driver, actor_id, true, 30 * 1000 /* 30s */));
|
||||
ASSERT_TRUE(WaitForDirectCallActorState(actor_id, true, 30 * 1000 /* 30s */));
|
||||
// Test submitting some tasks with by-value args for that actor.
|
||||
int64_t start_ms = current_time_ms();
|
||||
const int num_tasks = 100000;
|
||||
|
@ -713,7 +735,7 @@ TEST_F(SingleNodeTest, TestDirectActorTaskSubmissionPerf) {
|
|||
TEST_F(ZeroNodeTest, TestWorkerContext) {
|
||||
auto job_id = NextJobId();
|
||||
|
||||
WorkerContext context(WorkerType::WORKER, job_id);
|
||||
WorkerContext context(WorkerType::WORKER, WorkerID::FromRandom(), job_id);
|
||||
ASSERT_TRUE(context.GetCurrentTaskID().IsNil());
|
||||
ASSERT_EQ(context.GetNextTaskIndex(), 1);
|
||||
ASSERT_EQ(context.GetNextTaskIndex(), 2);
|
||||
|
@ -779,7 +801,7 @@ TEST_F(SingleNodeTest, TestMemoryStoreProvider) {
|
|||
absl::flat_hash_set<ObjectID> wait_results;
|
||||
|
||||
ObjectID nonexistent_id = ObjectID::FromRandom().WithDirectTransportType();
|
||||
WorkerContext ctx(WorkerType::WORKER, JobID::Nil());
|
||||
WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::Nil());
|
||||
wait_ids.insert(nonexistent_id);
|
||||
RAY_CHECK_OK(provider.Wait(wait_ids, ids.size() + 1, 100, ctx, &wait_results));
|
||||
ASSERT_EQ(wait_results.size(), ids.size());
|
||||
|
@ -880,10 +902,7 @@ TEST_F(SingleNodeTest, TestMemoryStoreProvider) {
|
|||
}
|
||||
|
||||
TEST_F(SingleNodeTest, TestObjectInterface) {
|
||||
CoreWorker core_worker(WorkerType::DRIVER, Language::PYTHON,
|
||||
raylet_store_socket_names_[0], raylet_socket_names_[0],
|
||||
JobID::FromInt(1), gcs_options_, "", "127.0.0.1",
|
||||
node_manager_port, nullptr);
|
||||
auto &core_worker = CoreWorkerProcess::GetCoreWorker();
|
||||
|
||||
uint8_t array1[] = {1, 2, 3, 4, 5, 6, 7, 8};
|
||||
uint8_t array2[] = {10, 11, 12, 13, 14, 15};
|
||||
|
|
|
@ -233,7 +233,7 @@ rpc::PushTaskRequest CreatePushTaskRequestHelper(ActorID actor_id, int64_t count
|
|||
class MockWorkerContext : public WorkerContext {
|
||||
public:
|
||||
MockWorkerContext(WorkerType worker_type, const JobID &job_id)
|
||||
: WorkerContext(worker_type, job_id) {
|
||||
: WorkerContext(worker_type, WorkerID::FromRandom(), job_id) {
|
||||
current_actor_is_direct_call_ = true;
|
||||
}
|
||||
};
|
||||
|
|
|
@ -33,13 +33,34 @@ namespace ray {
|
|||
class MockWorker {
|
||||
public:
|
||||
MockWorker(const std::string &store_socket, const std::string &raylet_socket,
|
||||
int node_manager_port, const gcs::GcsClientOptions &gcs_options)
|
||||
: worker_(WorkerType::WORKER, Language::PYTHON, store_socket, raylet_socket,
|
||||
JobID::FromInt(1), gcs_options, /*log_dir=*/"",
|
||||
/*node_id_address=*/"127.0.0.1", node_manager_port,
|
||||
std::bind(&MockWorker::ExecuteTask, this, _1, _2, _3, _4, _5, _6, _7)) {}
|
||||
int node_manager_port, const gcs::GcsClientOptions &gcs_options) {
|
||||
CoreWorkerOptions options = {
|
||||
WorkerType::WORKER, // worker_type
|
||||
Language::PYTHON, // language
|
||||
store_socket, // store_socket
|
||||
raylet_socket, // raylet_socket
|
||||
JobID::FromInt(1), // job_id
|
||||
gcs_options, // gcs_options
|
||||
"", // log_dir
|
||||
true, // install_failure_signal_handler
|
||||
"127.0.0.1", // node_ip_address
|
||||
node_manager_port, // node_manager_port
|
||||
"", // driver_name
|
||||
"", // stdout_file
|
||||
"", // stderr_file
|
||||
std::bind(&MockWorker::ExecuteTask, this, _1, _2, _3, _4, _5, _6,
|
||||
_7), // task_execution_callback
|
||||
nullptr, // check_signals
|
||||
nullptr, // gc_collect
|
||||
nullptr, // get_lang_stack
|
||||
true, // ref_counting_enabled
|
||||
false, // is_local_mode
|
||||
1, // num_workers
|
||||
};
|
||||
CoreWorkerProcess::Initialize(options);
|
||||
}
|
||||
|
||||
void StartExecutingTasks() { worker_.StartExecutingTasks(); }
|
||||
void RunTaskExecutionLoop() { CoreWorkerProcess::RunTaskExecutionLoop(); }
|
||||
|
||||
private:
|
||||
Status ExecuteTask(TaskType task_type, const RayFunction &ray_function,
|
||||
|
@ -112,7 +133,6 @@ class MockWorker {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
CoreWorker worker_;
|
||||
int64_t prev_seq_no_ = 0;
|
||||
};
|
||||
|
||||
|
@ -126,6 +146,6 @@ int main(int argc, char **argv) {
|
|||
|
||||
ray::gcs::GcsClientOptions gcs_options("127.0.0.1", 6379, "");
|
||||
ray::MockWorker worker(store_socket, raylet_socket, node_manager_port, gcs_options);
|
||||
worker.StartExecutingTasks();
|
||||
worker.RunTaskExecutionLoop();
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@ -37,7 +37,7 @@ class MockWaiter : public DependencyWaiter {
|
|||
TEST(SchedulingQueueTest, TestInOrder) {
|
||||
boost::asio::io_service io_service;
|
||||
MockWaiter waiter;
|
||||
WorkerContext context(WorkerType::WORKER, JobID::Nil());
|
||||
WorkerContext context(WorkerType::WORKER, WorkerID::FromRandom(), JobID::Nil());
|
||||
SchedulingQueue queue(io_service, waiter, context);
|
||||
int n_ok = 0;
|
||||
int n_rej = 0;
|
||||
|
@ -58,7 +58,7 @@ TEST(SchedulingQueueTest, TestWaitForObjects) {
|
|||
ObjectID obj3 = ObjectID::FromRandom();
|
||||
boost::asio::io_service io_service;
|
||||
MockWaiter waiter;
|
||||
WorkerContext context(WorkerType::WORKER, JobID::Nil());
|
||||
WorkerContext context(WorkerType::WORKER, WorkerID::FromRandom(), JobID::Nil());
|
||||
SchedulingQueue queue(io_service, waiter, context);
|
||||
int n_ok = 0;
|
||||
int n_rej = 0;
|
||||
|
@ -84,7 +84,7 @@ TEST(SchedulingQueueTest, TestWaitForObjectsNotSubjectToSeqTimeout) {
|
|||
ObjectID obj1 = ObjectID::FromRandom();
|
||||
boost::asio::io_service io_service;
|
||||
MockWaiter waiter;
|
||||
WorkerContext context(WorkerType::WORKER, JobID::Nil());
|
||||
WorkerContext context(WorkerType::WORKER, WorkerID::FromRandom(), JobID::Nil());
|
||||
SchedulingQueue queue(io_service, waiter, context);
|
||||
int n_ok = 0;
|
||||
int n_rej = 0;
|
||||
|
@ -102,7 +102,7 @@ TEST(SchedulingQueueTest, TestWaitForObjectsNotSubjectToSeqTimeout) {
|
|||
TEST(SchedulingQueueTest, TestOutOfOrder) {
|
||||
boost::asio::io_service io_service;
|
||||
MockWaiter waiter;
|
||||
WorkerContext context(WorkerType::WORKER, JobID::Nil());
|
||||
WorkerContext context(WorkerType::WORKER, WorkerID::FromRandom(), JobID::Nil());
|
||||
SchedulingQueue queue(io_service, waiter, context);
|
||||
int n_ok = 0;
|
||||
int n_rej = 0;
|
||||
|
@ -120,7 +120,7 @@ TEST(SchedulingQueueTest, TestOutOfOrder) {
|
|||
TEST(SchedulingQueueTest, TestSeqWaitTimeout) {
|
||||
boost::asio::io_service io_service;
|
||||
MockWaiter waiter;
|
||||
WorkerContext context(WorkerType::WORKER, JobID::Nil());
|
||||
WorkerContext context(WorkerType::WORKER, WorkerID::FromRandom(), JobID::Nil());
|
||||
SchedulingQueue queue(io_service, waiter, context);
|
||||
int n_ok = 0;
|
||||
int n_rej = 0;
|
||||
|
@ -143,7 +143,7 @@ TEST(SchedulingQueueTest, TestSeqWaitTimeout) {
|
|||
TEST(SchedulingQueueTest, TestSkipAlreadyProcessedByClient) {
|
||||
boost::asio::io_service io_service;
|
||||
MockWaiter waiter;
|
||||
WorkerContext context(WorkerType::WORKER, JobID::Nil());
|
||||
WorkerContext context(WorkerType::WORKER, WorkerID::FromRandom(), JobID::Nil());
|
||||
SchedulingQueue queue(io_service, waiter, context);
|
||||
int n_ok = 0;
|
||||
int n_rej = 0;
|
||||
|
|
|
@ -79,7 +79,7 @@ TEST_F(TaskManagerTest, TestTaskSuccess) {
|
|||
ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId()));
|
||||
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 3);
|
||||
auto return_id = spec.ReturnId(0, TaskTransportType::DIRECT);
|
||||
WorkerContext ctx(WorkerType::WORKER, JobID::FromInt(0));
|
||||
WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0));
|
||||
|
||||
rpc::PushTaskReply reply;
|
||||
auto return_object = reply.add_return_objects();
|
||||
|
@ -119,7 +119,7 @@ TEST_F(TaskManagerTest, TestTaskFailure) {
|
|||
ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId()));
|
||||
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 3);
|
||||
auto return_id = spec.ReturnId(0, TaskTransportType::DIRECT);
|
||||
WorkerContext ctx(WorkerType::WORKER, JobID::FromInt(0));
|
||||
WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0));
|
||||
|
||||
auto error = rpc::ErrorType::WORKER_DIED;
|
||||
manager_.PendingTaskFailed(spec.TaskId(), error);
|
||||
|
@ -155,7 +155,7 @@ TEST_F(TaskManagerTest, TestTaskRetry) {
|
|||
ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId()));
|
||||
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 3);
|
||||
auto return_id = spec.ReturnId(0, TaskTransportType::DIRECT);
|
||||
WorkerContext ctx(WorkerType::WORKER, JobID::FromInt(0));
|
||||
WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0));
|
||||
|
||||
auto error = rpc::ErrorType::WORKER_DIED;
|
||||
for (int i = 0; i < num_retries; i++) {
|
||||
|
|
|
@ -568,6 +568,17 @@ class WorkerInfoAccessor {
|
|||
const std::shared_ptr<rpc::WorkerFailureData> &data_ptr,
|
||||
const StatusCallback &callback) = 0;
|
||||
|
||||
/// Register a worker to GCS asynchronously.
|
||||
///
|
||||
/// \param worker_type The type of the worker.
|
||||
/// \param worker_id The ID of the worker.
|
||||
/// \param worker_info The information of the worker.
|
||||
/// \param callback Callback invoked with the resulting status once the registration
/// finishes.
/// \return Status.
|
||||
virtual Status AsyncRegisterWorker(
|
||||
rpc::WorkerType worker_type, const WorkerID &worker_id,
|
||||
const std::unordered_map<std::string, std::string> &worker_info,
|
||||
const StatusCallback &callback) = 0;
|
||||
|
||||
protected:
|
||||
WorkerInfoAccessor() = default;
|
||||
};
|
||||
|
|
|
@ -45,6 +45,8 @@ class GcsClientOptions {
|
|||
password_(password),
|
||||
is_test_client_(is_test_client) {}
|
||||
|
||||
GcsClientOptions() {}
|
||||
|
||||
// GCS server address
|
||||
std::string server_ip_;
|
||||
int server_port_;
|
||||
|
|
|
@ -875,5 +875,25 @@ Status ServiceBasedWorkerInfoAccessor::AsyncReportWorkerFailure(
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Sends a RegisterWorker RPC through the GCS RPC client. The supplied callback, when
// non-null, receives the status delivered with the RPC reply. The function itself
// always returns OK after handing the request to the client.
Status ServiceBasedWorkerInfoAccessor::AsyncRegisterWorker(
    rpc::WorkerType worker_type, const WorkerID &worker_id,
    const std::unordered_map<std::string, std::string> &worker_info,
    const StatusCallback &callback) {
  RAY_LOG(DEBUG) << "Registering the worker. worker id = " << worker_id;

  // Assemble the request payload from the arguments.
  rpc::RegisterWorkerRequest register_request;
  register_request.set_worker_type(worker_type);
  register_request.set_worker_id(worker_id.Binary());
  auto *info_map = register_request.mutable_worker_info();
  info_map->insert(worker_info.begin(), worker_info.end());

  // Forward the RPC outcome to the caller first, then log completion.
  auto on_reply = [worker_id, callback](const Status &rpc_status,
                                        const rpc::RegisterWorkerReply & /*reply*/) {
    if (callback) {
      callback(rpc_status);
    }
    RAY_LOG(DEBUG) << "Finished registering worker. worker id = " << worker_id;
  };
  client_impl_->GetGcsRpcClient().RegisterWorker(register_request, on_reply);

  return Status::OK();
}
|
||||
|
||||
} // namespace gcs
|
||||
} // namespace ray
|
||||
|
|
|
@ -329,6 +329,11 @@ class ServiceBasedWorkerInfoAccessor : public WorkerInfoAccessor {
|
|||
Status AsyncReportWorkerFailure(const std::shared_ptr<rpc::WorkerFailureData> &data_ptr,
|
||||
const StatusCallback &callback) override;
|
||||
|
||||
Status AsyncRegisterWorker(
|
||||
rpc::WorkerType worker_type, const WorkerID &worker_id,
|
||||
const std::unordered_map<std::string, std::string> &worker_info,
|
||||
const StatusCallback &callback) override;
|
||||
|
||||
private:
|
||||
ServiceBasedGcsClient *client_impl_;
|
||||
|
||||
|
|
|
@ -40,5 +40,27 @@ void DefaultWorkerInfoHandler::HandleReportWorkerFailure(
|
|||
RAY_LOG(DEBUG) << "Finished reporting worker failure, " << worker_address.DebugString();
|
||||
}
|
||||
|
||||
void DefaultWorkerInfoHandler::HandleRegisterWorker(
|
||||
const RegisterWorkerRequest &request, RegisterWorkerReply *reply,
|
||||
SendReplyCallback send_reply_callback) {
|
||||
auto worker_type = request.worker_type();
|
||||
auto worker_id = WorkerID::FromBinary(request.worker_id());
|
||||
auto worker_info = MapFromProtobuf(request.worker_info());
|
||||
|
||||
auto on_done = [worker_id, reply, send_reply_callback](Status status) {
|
||||
if (!status.ok()) {
|
||||
RAY_LOG(ERROR) << "Failed to register worker " << worker_id;
|
||||
}
|
||||
GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK());
|
||||
};
|
||||
|
||||
Status status = gcs_client_.Workers().AsyncRegisterWorker(worker_type, worker_id,
|
||||
worker_info, on_done);
|
||||
if (!status.ok()) {
|
||||
on_done(status);
|
||||
}
|
||||
RAY_LOG(DEBUG) << "Finished registering worker " << worker_id;
|
||||
}
|
||||
|
||||
} // namespace rpc
|
||||
} // namespace ray
|
||||
|
|
|
@ -31,6 +31,10 @@ class DefaultWorkerInfoHandler : public rpc::WorkerInfoHandler {
|
|||
ReportWorkerFailureReply *reply,
|
||||
SendReplyCallback send_reply_callback) override;
|
||||
|
||||
void HandleRegisterWorker(const RegisterWorkerRequest &request,
|
||||
RegisterWorkerReply *reply,
|
||||
SendReplyCallback send_reply_callback) override;
|
||||
|
||||
private:
|
||||
gcs::RedisGcsClient &gcs_client_;
|
||||
};
|
||||
|
|
|
@ -747,6 +747,30 @@ Status RedisWorkerInfoAccessor::AsyncReportWorkerFailure(
|
|||
return worker_failure_table.Add(JobID::Nil(), worker_id, data_ptr, on_done);
|
||||
}
|
||||
|
||||
/// Register a worker by issuing a Redis `HMSET` that stores `worker_info` under a
/// key derived from the worker type and ID.
///
/// \param worker_type Selects the key prefix: "Drivers:" for drivers, "Workers:"
/// for every other type.
/// \param worker_id ID of the worker; its binary form is appended to the prefix.
/// \param worker_info Field/value pairs written into the hash.
/// \param callback Optional callback, invoked with the status of submitting the
/// command.
/// \return The status returned by `RunArgvAsync`.
Status RedisWorkerInfoAccessor::AsyncRegisterWorker(
    rpc::WorkerType worker_type, const WorkerID &worker_id,
    const std::unordered_map<std::string, std::string> &worker_info,
    const StatusCallback &callback) {
  const bool is_driver = (worker_type == rpc::WorkerType::DRIVER);
  const std::string key_prefix = is_driver ? "Drivers:" : "Workers:";

  // Command layout: HMSET <key> <field1> <value1> <field2> <value2> ...
  std::vector<std::string> redis_argv;
  redis_argv.reserve(2 + 2 * worker_info.size());
  redis_argv.emplace_back("HMSET");
  redis_argv.emplace_back(key_prefix + worker_id.Binary());
  for (const auto &field_and_value : worker_info) {
    redis_argv.push_back(field_and_value.first);
    redis_argv.push_back(field_and_value.second);
  }

  const auto submit_status = client_impl_->primary_context()->RunArgvAsync(redis_argv);
  // TODO (kfstorm): Invoke the callback asynchronously.
  if (callback) {
    callback(submit_status);
  }
  return submit_status;
}
|
||||
|
||||
} // namespace gcs
|
||||
|
||||
} // namespace ray
|
||||
|
|
|
@ -397,6 +397,11 @@ class RedisWorkerInfoAccessor : public WorkerInfoAccessor {
|
|||
Status AsyncReportWorkerFailure(const std::shared_ptr<WorkerFailureData> &data_ptr,
|
||||
const StatusCallback &callback) override;
|
||||
|
||||
Status AsyncRegisterWorker(
|
||||
rpc::WorkerType worker_type, const WorkerID &worker_id,
|
||||
const std::unordered_map<std::string, std::string> &worker_info,
|
||||
const StatusCallback &callback) override;
|
||||
|
||||
private:
|
||||
RedisGcsClient *client_impl_{nullptr};
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@ syntax = "proto3";
|
|||
|
||||
package ray.rpc;
|
||||
|
||||
import "src/ray/protobuf/common.proto";
|
||||
import "src/ray/protobuf/gcs.proto";
|
||||
|
||||
message GcsStatus {
|
||||
|
@ -345,8 +346,23 @@ message ReportWorkerFailureReply {
|
|||
GcsStatus status = 1;
|
||||
}
|
||||
|
||||
/// Request to register a worker to the GCS Service.
message RegisterWorkerRequest {
|
||||
/// The type of the worker.
|
||||
WorkerType worker_type = 1;
|
||||
/// The ID of the worker.
|
||||
bytes worker_id = 2;
|
||||
/// The information of the worker in a dictionary.
|
||||
map<string, bytes> worker_info = 3;
|
||||
}
|
||||
|
||||
/// Reply to a `RegisterWorkerRequest`.
message RegisterWorkerReply {
|
||||
/// The status of the register-worker operation.
GcsStatus status = 1;
|
||||
}
|
||||
|
||||
// Service for worker info access.
|
||||
service WorkerInfoGcsService {
|
||||
// Report a worker failure to GCS Service.
|
||||
rpc ReportWorkerFailure(ReportWorkerFailureRequest) returns (ReportWorkerFailureReply);
|
||||
// Register a worker to GCS Service.
|
||||
rpc RegisterWorker(RegisterWorkerRequest) returns (RegisterWorkerReply);
|
||||
}
|
||||
|
|
|
@ -182,6 +182,10 @@ class GcsRpcClient {
|
|||
VOID_GCS_RPC_CLIENT_METHOD(WorkerInfoGcsService, ReportWorkerFailure,
|
||||
worker_info_grpc_client_, )
|
||||
|
||||
/// Register a worker to GCS Service.
|
||||
VOID_GCS_RPC_CLIENT_METHOD(WorkerInfoGcsService, RegisterWorker,
|
||||
worker_info_grpc_client_, )
|
||||
|
||||
private:
|
||||
void Init(const std::string &address, const int port,
|
||||
ClientCallManager &client_call_manager) {
|
||||
|
|
|
@ -390,6 +390,10 @@ class WorkerInfoGcsServiceHandler {
|
|||
virtual void HandleReportWorkerFailure(const ReportWorkerFailureRequest &request,
|
||||
ReportWorkerFailureReply *reply,
|
||||
SendReplyCallback send_reply_callback) = 0;
|
||||
|
||||
virtual void HandleRegisterWorker(const RegisterWorkerRequest &request,
|
||||
RegisterWorkerReply *reply,
|
||||
SendReplyCallback send_reply_callback) = 0;
|
||||
};
|
||||
|
||||
/// The `GrpcService` for `WorkerInfoGcsService`.
|
||||
|
@ -409,6 +413,7 @@ class WorkerInfoGrpcService : public GrpcService {
|
|||
const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
|
||||
std::vector<std::unique_ptr<ServerCallFactory>> *server_call_factories) override {
|
||||
WORKER_INFO_SERVICE_RPC_HANDLER(ReportWorkerFailure);
|
||||
WORKER_INFO_SERVICE_RPC_HANDLER(RegisterWorker);
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
package org.ray.streaming.runtime.transfer;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import org.ray.runtime.RayNativeRuntime;
|
||||
import org.ray.runtime.functionmanager.FunctionDescriptor;
|
||||
import org.ray.runtime.functionmanager.JavaFunctionDescriptor;
|
||||
|
@ -24,16 +23,12 @@ public class TransferHandler {
|
|||
private long writerClientNative;
|
||||
private long readerClientNative;
|
||||
|
||||
public TransferHandler(long coreWorkerNative,
|
||||
JavaFunctionDescriptor writerAsyncFunc,
|
||||
public TransferHandler(JavaFunctionDescriptor writerAsyncFunc,
|
||||
JavaFunctionDescriptor writerSyncFunc,
|
||||
JavaFunctionDescriptor readerAsyncFunc,
|
||||
JavaFunctionDescriptor readerSyncFunc) {
|
||||
Preconditions.checkArgument(coreWorkerNative != 0);
|
||||
writerClientNative = createWriterClientNative(
|
||||
coreWorkerNative, writerAsyncFunc, writerSyncFunc);
|
||||
readerClientNative = createReaderClientNative(
|
||||
coreWorkerNative, readerAsyncFunc, readerSyncFunc);
|
||||
writerClientNative = createWriterClientNative(writerAsyncFunc, writerSyncFunc);
|
||||
readerClientNative = createReaderClientNative(readerAsyncFunc, readerSyncFunc);
|
||||
}
|
||||
|
||||
public void onWriterMessage(byte[] buffer) {
|
||||
|
@ -53,12 +48,10 @@ public class TransferHandler {
|
|||
}
|
||||
|
||||
private native long createWriterClientNative(
|
||||
long coreWorkerNative,
|
||||
FunctionDescriptor asyncFunc,
|
||||
FunctionDescriptor syncFunc);
|
||||
|
||||
private native long createReaderClientNative(
|
||||
long coreWorkerNative,
|
||||
FunctionDescriptor asyncFunc,
|
||||
FunctionDescriptor syncFunc);
|
||||
|
||||
|
|
|
@ -2,9 +2,6 @@ package org.ray.streaming.runtime.worker;
|
|||
|
||||
import java.io.Serializable;
|
||||
import java.util.Map;
|
||||
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.runtime.RayMultiWorkerNativeRuntime;
|
||||
import org.ray.runtime.functionmanager.JavaFunctionDescriptor;
|
||||
import org.ray.streaming.runtime.core.graph.ExecutionGraph;
|
||||
import org.ray.streaming.runtime.core.graph.ExecutionNode;
|
||||
|
@ -62,7 +59,6 @@ public class JobWorker implements Serializable {
|
|||
Config.CHANNEL_TYPE, Config.DEFAULT_CHANNEL_TYPE);
|
||||
if (channelType.equals(Config.NATIVE_CHANNEL)) {
|
||||
transferHandler = new TransferHandler(
|
||||
getNativeCoreWorker(),
|
||||
new JavaFunctionDescriptor(JobWorker.class.getName(), "onWriterMessage", "([B)V"),
|
||||
new JavaFunctionDescriptor(JobWorker.class.getName(), "onWriterMessageSync", "([B)[B"),
|
||||
new JavaFunctionDescriptor(JobWorker.class.getName(), "onReaderMessage", "([B)V"),
|
||||
|
@ -148,13 +144,4 @@ public class JobWorker implements Serializable {
|
|||
public byte[] onWriterMessageSync(byte[] buffer) {
|
||||
return transferHandler.onWriterMessageSync(buffer);
|
||||
}
|
||||
|
||||
private static long getNativeCoreWorker() {
|
||||
long pointer = 0;
|
||||
if (Ray.internal() instanceof RayMultiWorkerNativeRuntime) {
|
||||
pointer = ((RayMultiWorkerNativeRuntime) Ray.internal())
|
||||
.getCurrentRuntime().getNativeCoreWorkerPointer();
|
||||
}
|
||||
return pointer;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,7 +10,6 @@ import java.util.Random;
|
|||
import org.ray.api.Ray;
|
||||
import org.ray.api.RayActor;
|
||||
import org.ray.api.id.ActorId;
|
||||
import org.ray.runtime.RayMultiWorkerNativeRuntime;
|
||||
import org.ray.runtime.actor.NativeRayActor;
|
||||
import org.ray.runtime.functionmanager.JavaFunctionDescriptor;
|
||||
import org.ray.streaming.runtime.transfer.ChannelID;
|
||||
|
@ -29,8 +28,7 @@ public class Worker {
|
|||
protected TransferHandler transferHandler = null;
|
||||
|
||||
public Worker() {
|
||||
transferHandler = new TransferHandler(((RayMultiWorkerNativeRuntime) Ray.internal())
|
||||
.getCurrentRuntime().getNativeCoreWorkerPointer(),
|
||||
transferHandler = new TransferHandler(
|
||||
new JavaFunctionDescriptor(Worker.class.getName(),
|
||||
"onWriterMessage", "([B)V"),
|
||||
new JavaFunctionDescriptor(Worker.class.getName(),
|
||||
|
|
|
@ -31,7 +31,6 @@ from ray.includes.unique_ids cimport (
|
|||
CTaskID,
|
||||
CObjectID,
|
||||
)
|
||||
from ray.includes.libcoreworker cimport CCoreWorker
|
||||
|
||||
cdef extern from "status.h" namespace "ray::streaming" nogil:
|
||||
cdef cppclass CStreamingStatus "ray::streaming::StreamingStatus":
|
||||
|
@ -100,15 +99,13 @@ cdef extern from "message/message_bundle.h" namespace "ray::streaming" nogil:
|
|||
|
||||
cdef extern from "queue/queue_client.h" namespace "ray::streaming" nogil:
|
||||
cdef cppclass CReaderClient "ray::streaming::ReaderClient":
|
||||
CReaderClient(CCoreWorker *core_worker,
|
||||
CRayFunction &async_func,
|
||||
CReaderClient(CRayFunction &async_func,
|
||||
CRayFunction &sync_func)
|
||||
void OnReaderMessage(shared_ptr[CLocalMemoryBuffer] buffer);
|
||||
shared_ptr[CLocalMemoryBuffer] OnReaderMessageSync(shared_ptr[CLocalMemoryBuffer] buffer);
|
||||
|
||||
cdef cppclass CWriterClient "ray::streaming::WriterClient":
|
||||
CWriterClient(CCoreWorker *core_worker,
|
||||
CRayFunction &async_func,
|
||||
CWriterClient(CRayFunction &async_func,
|
||||
CRayFunction &sync_func)
|
||||
void OnWriterMessage(shared_ptr[CLocalMemoryBuffer] buffer);
|
||||
shared_ptr[CLocalMemoryBuffer] OnWriterMessageSync(shared_ptr[CLocalMemoryBuffer] buffer);
|
||||
|
|
|
@ -19,14 +19,11 @@ from ray.includes.unique_ids cimport (
|
|||
)
|
||||
from ray._raylet cimport (
|
||||
Buffer,
|
||||
CoreWorker,
|
||||
ActorID,
|
||||
ObjectID,
|
||||
FunctionDescriptor,
|
||||
)
|
||||
|
||||
from ray.includes.libcoreworker cimport CCoreWorker
|
||||
|
||||
cimport ray.streaming.includes.libstreaming as libstreaming
|
||||
from ray.streaming.includes.libstreaming cimport (
|
||||
CStreamingStatus,
|
||||
|
@ -52,16 +49,14 @@ cdef class ReaderClient:
|
|||
CReaderClient *client
|
||||
|
||||
def __cinit__(self,
|
||||
CoreWorker worker,
|
||||
FunctionDescriptor async_func,
|
||||
FunctionDescriptor sync_func):
|
||||
cdef:
|
||||
CCoreWorker *core_worker = worker.core_worker.get()
|
||||
CRayFunction async_native_func
|
||||
CRayFunction sync_native_func
|
||||
async_native_func = CRayFunction(LANGUAGE_PYTHON, async_func.descriptor)
|
||||
sync_native_func = CRayFunction(LANGUAGE_PYTHON, sync_func.descriptor)
|
||||
self.client = new CReaderClient(core_worker, async_native_func, sync_native_func)
|
||||
self.client = new CReaderClient(async_native_func, sync_native_func)
|
||||
|
||||
def __dealloc__(self):
|
||||
del self.client
|
||||
|
@ -91,16 +86,14 @@ cdef class WriterClient:
|
|||
CWriterClient * client
|
||||
|
||||
def __cinit__(self,
|
||||
CoreWorker worker,
|
||||
FunctionDescriptor async_func,
|
||||
FunctionDescriptor sync_func):
|
||||
cdef:
|
||||
CCoreWorker *core_worker = worker.core_worker.get()
|
||||
CRayFunction async_native_func
|
||||
CRayFunction sync_native_func
|
||||
async_native_func = CRayFunction(LANGUAGE_PYTHON, async_func.descriptor)
|
||||
sync_native_func = CRayFunction(LANGUAGE_PYTHON, sync_func.descriptor)
|
||||
self.client = new CWriterClient(core_worker, async_native_func, sync_native_func)
|
||||
self.client = new CWriterClient(async_native_func, sync_native_func)
|
||||
|
||||
def __dealloc__(self):
|
||||
del self.client
|
||||
|
|
|
@ -48,7 +48,6 @@ class JobWorker(object):
|
|||
self.task_id, self.stream_processor))
|
||||
|
||||
if self.config.get(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL):
|
||||
core_worker = ray.worker.global_worker.core_worker
|
||||
reader_async_func = PythonFunctionDescriptor(
|
||||
__name__, self.on_reader_message.__name__,
|
||||
self.__class__.__name__)
|
||||
|
@ -56,7 +55,7 @@ class JobWorker(object):
|
|||
__name__, self.on_reader_message_sync.__name__,
|
||||
self.__class__.__name__)
|
||||
self.reader_client = _streaming.ReaderClient(
|
||||
core_worker, reader_async_func, reader_sync_func)
|
||||
reader_async_func, reader_sync_func)
|
||||
writer_async_func = PythonFunctionDescriptor(
|
||||
__name__, self.on_writer_message.__name__,
|
||||
self.__class__.__name__)
|
||||
|
@ -64,7 +63,7 @@ class JobWorker(object):
|
|||
__name__, self.on_writer_message_sync.__name__,
|
||||
self.__class__.__name__)
|
||||
self.writer_client = _streaming.WriterClient(
|
||||
core_worker, writer_async_func, writer_sync_func)
|
||||
writer_async_func, writer_sync_func)
|
||||
|
||||
self.task = self.create_stream_task()
|
||||
self.task.start()
|
||||
|
|
|
@ -12,21 +12,20 @@ from ray.streaming.config import Config
|
|||
@ray.remote
|
||||
class Worker:
|
||||
def __init__(self):
|
||||
core_worker = ray.worker.global_worker.core_worker
|
||||
writer_async_func = PythonFunctionDescriptor(
|
||||
__name__, self.on_writer_message.__name__, self.__class__.__name__)
|
||||
writer_sync_func = PythonFunctionDescriptor(
|
||||
__name__, self.on_writer_message_sync.__name__,
|
||||
self.__class__.__name__)
|
||||
self.writer_client = _streaming.WriterClient(
|
||||
core_worker, writer_async_func, writer_sync_func)
|
||||
self.writer_client = _streaming.WriterClient(writer_async_func,
|
||||
writer_sync_func)
|
||||
reader_async_func = PythonFunctionDescriptor(
|
||||
__name__, self.on_reader_message.__name__, self.__class__.__name__)
|
||||
reader_sync_func = PythonFunctionDescriptor(
|
||||
__name__, self.on_reader_message_sync.__name__,
|
||||
self.__class__.__name__)
|
||||
self.reader_client = _streaming.ReaderClient(
|
||||
core_worker, reader_async_func, reader_sync_func)
|
||||
self.reader_client = _streaming.ReaderClient(reader_async_func,
|
||||
reader_sync_func)
|
||||
self.writer = None
|
||||
self.output_channel_id = None
|
||||
self.reader = None
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
#include <unordered_set>
|
||||
|
||||
#include "event_service.h"
|
||||
|
||||
namespace ray {
|
||||
namespace streaming {
|
||||
|
||||
|
@ -105,7 +106,11 @@ Event &EventQueue::Front() {
|
|||
}
|
||||
|
||||
EventService::EventService(uint32_t event_size)
|
||||
: event_queue_(std::make_shared<EventQueue>(event_size)), stop_flag_(false) {}
|
||||
: worker_id_(CoreWorkerProcess::IsInitialized()
|
||||
? CoreWorkerProcess::GetCoreWorker().GetWorkerID()
|
||||
: WorkerID::Nil()),
|
||||
event_queue_(std::make_shared<EventQueue>(event_size)),
|
||||
stop_flag_(false) {}
|
||||
EventService::~EventService() {
|
||||
stop_flag_ = true;
|
||||
// No need to join if loop thread has never been created.
|
||||
|
@ -154,6 +159,9 @@ void EventService::Execute(Event &event) {
|
|||
}
|
||||
|
||||
void EventService::LoopThreadHandler() {
|
||||
if (CoreWorkerProcess::IsInitialized()) {
|
||||
CoreWorkerProcess::SetCurrentThreadWorkerId(worker_id_);
|
||||
}
|
||||
while (true) {
|
||||
if (stop_flag_) {
|
||||
break;
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
#include <unordered_map>
|
||||
|
||||
#include "channel.h"
|
||||
#include "ray/core_worker/core_worker.h"
|
||||
#include "ring_buffer.h"
|
||||
#include "util/streaming_util.h"
|
||||
|
||||
|
@ -127,6 +128,7 @@ class EventService {
|
|||
void LoopThreadHandler();
|
||||
|
||||
private:
|
||||
WorkerID worker_id_;
|
||||
std::unordered_map<EventType, Handle, EnumTypeHash> event_handle_map_;
|
||||
std::shared_ptr<EventQueue> event_queue_;
|
||||
std::shared_ptr<std::thread> loop_thread_;
|
||||
|
|
|
@ -14,25 +14,19 @@ static std::shared_ptr<ray::LocalMemoryBuffer> JByteArrayToBuffer(JNIEnv *env,
|
|||
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_ray_streaming_runtime_transfer_TransferHandler_createWriterClientNative(
|
||||
JNIEnv *env, jobject this_obj, jlong core_worker_ptr, jobject async_func,
|
||||
jobject sync_func) {
|
||||
JNIEnv *env, jobject this_obj, jobject async_func, jobject sync_func) {
|
||||
auto ray_async_func = FunctionDescriptorToRayFunction(env, async_func);
|
||||
auto ray_sync_func = FunctionDescriptorToRayFunction(env, sync_func);
|
||||
auto *writer_client =
|
||||
new WriterClient(reinterpret_cast<ray::CoreWorker *>(core_worker_ptr),
|
||||
ray_async_func, ray_sync_func);
|
||||
auto *writer_client = new WriterClient(ray_async_func, ray_sync_func);
|
||||
return reinterpret_cast<jlong>(writer_client);
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_ray_streaming_runtime_transfer_TransferHandler_createReaderClientNative(
|
||||
JNIEnv *env, jobject this_obj, jlong core_worker_ptr, jobject async_func,
|
||||
jobject sync_func) {
|
||||
JNIEnv *env, jobject this_obj, jobject async_func, jobject sync_func) {
|
||||
ray::RayFunction ray_async_func = FunctionDescriptorToRayFunction(env, async_func);
|
||||
ray::RayFunction ray_sync_func = FunctionDescriptorToRayFunction(env, sync_func);
|
||||
auto *reader_client =
|
||||
new ReaderClient(reinterpret_cast<ray::CoreWorker *>(core_worker_ptr),
|
||||
ray_async_func, ray_sync_func);
|
||||
auto *reader_client = new ReaderClient(ray_async_func, ray_sync_func);
|
||||
return reinterpret_cast<jlong>(reader_client);
|
||||
}
|
||||
|
||||
|
|
|
@ -12,48 +12,58 @@ extern "C" {
|
|||
* Method: createWriterClientNative
|
||||
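* Signature: (Lorg/ray/runtime/functionmanager/FunctionDescriptor;Lorg/ray/runtime/functionmanager/FunctionDescriptor;)J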
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_org_ray_streaming_runtime_transfer_TransferHandler_createWriterClientNative
|
||||
(JNIEnv *, jobject, jlong, jobject, jobject);
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_ray_streaming_runtime_transfer_TransferHandler_createWriterClientNative(JNIEnv *,
|
||||
jobject,
|
||||
jobject,
|
||||
jobject);
|
||||
|
||||
/*
|
||||
* Class: org_ray_streaming_runtime_transfer_TransferHandler
|
||||
* Method: createReaderClientNative
|
||||
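* Signature: (Lorg/ray/runtime/functionmanager/FunctionDescriptor;Lorg/ray/runtime/functionmanager/FunctionDescriptor;)J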
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_org_ray_streaming_runtime_transfer_TransferHandler_createReaderClientNative
|
||||
(JNIEnv *, jobject, jlong, jobject, jobject);
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_ray_streaming_runtime_transfer_TransferHandler_createReaderClientNative(JNIEnv *,
|
||||
jobject,
|
||||
jobject,
|
||||
jobject);
|
||||
|
||||
/*
|
||||
* Class: org_ray_streaming_runtime_transfer_TransferHandler
|
||||
* Method: handleWriterMessageNative
|
||||
* Signature: (J[B)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageNative
|
||||
(JNIEnv *, jobject, jlong, jbyteArray);
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageNative(
|
||||
JNIEnv *, jobject, jlong, jbyteArray);
|
||||
|
||||
/*
|
||||
* Class: org_ray_streaming_runtime_transfer_TransferHandler
|
||||
* Method: handleWriterMessageSyncNative
|
||||
* Signature: (J[B)[B
|
||||
*/
|
||||
JNIEXPORT jbyteArray JNICALL Java_org_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageSyncNative
|
||||
(JNIEnv *, jobject, jlong, jbyteArray);
|
||||
JNIEXPORT jbyteArray JNICALL
|
||||
Java_org_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageSyncNative(
|
||||
JNIEnv *, jobject, jlong, jbyteArray);
|
||||
|
||||
/*
|
||||
* Class: org_ray_streaming_runtime_transfer_TransferHandler
|
||||
* Method: handleReaderMessageNative
|
||||
* Signature: (J[B)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_ray_streaming_runtime_transfer_TransferHandler_handleReaderMessageNative
|
||||
(JNIEnv *, jobject, jlong, jbyteArray);
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_ray_streaming_runtime_transfer_TransferHandler_handleReaderMessageNative(
|
||||
JNIEnv *, jobject, jlong, jbyteArray);
|
||||
|
||||
/*
|
||||
* Class: org_ray_streaming_runtime_transfer_TransferHandler
|
||||
* Method: handleReaderMessageSyncNative
|
||||
* Signature: (J[B)[B
|
||||
*/
|
||||
JNIEXPORT jbyteArray JNICALL Java_org_ray_streaming_runtime_transfer_TransferHandler_handleReaderMessageSyncNative
|
||||
(JNIEnv *, jobject, jlong, jbyteArray);
|
||||
JNIEXPORT jbyteArray JNICALL
|
||||
Java_org_ray_streaming_runtime_transfer_TransferHandler_handleReaderMessageSyncNative(
|
||||
JNIEnv *, jobject, jlong, jbyteArray);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@ -14,16 +14,14 @@ namespace streaming {
|
|||
class ReaderClient {
|
||||
public:
|
||||
/// Construct a ReaderClient object.
|
||||
/// \param[in] core_worker CoreWorker C++ pointer of current actor
|
||||
/// \param[in] async_func DataReader's raycall function descriptor to be called by
|
||||
/// DataWriter, asynchronous semantics \param[in] sync_func DataReader's raycall
|
||||
/// function descriptor to be called by DataWriter, synchronous semantics
|
||||
ReaderClient(CoreWorker *core_worker, RayFunction &async_func, RayFunction &sync_func)
|
||||
: core_worker_(core_worker) {
|
||||
ReaderClient(RayFunction &async_func, RayFunction &sync_func) {
|
||||
DownstreamQueueMessageHandler::peer_async_function_ = async_func;
|
||||
DownstreamQueueMessageHandler::peer_sync_function_ = sync_func;
|
||||
downstream_handler_ = ray::streaming::DownstreamQueueMessageHandler::CreateService(
|
||||
core_worker_, core_worker_->GetWorkerContext().GetCurrentActorID());
|
||||
CoreWorkerProcess::GetCoreWorker().GetWorkerContext().GetCurrentActorID());
|
||||
}
|
||||
|
||||
/// Post buffer to downstream queue service, asynchronously.
|
||||
|
@ -34,19 +32,17 @@ class ReaderClient {
|
|||
std::shared_ptr<LocalMemoryBuffer> buffer);
|
||||
|
||||
private:
|
||||
CoreWorker *core_worker_;
|
||||
std::shared_ptr<DownstreamQueueMessageHandler> downstream_handler_;
|
||||
};
|
||||
|
||||
/// Interface of streaming queue for DataWriter. Similar to ReaderClient.
|
||||
class WriterClient {
|
||||
public:
|
||||
WriterClient(CoreWorker *core_worker, RayFunction &async_func, RayFunction &sync_func)
|
||||
: core_worker_(core_worker) {
|
||||
WriterClient(RayFunction &async_func, RayFunction &sync_func) {
|
||||
UpstreamQueueMessageHandler::peer_async_function_ = async_func;
|
||||
UpstreamQueueMessageHandler::peer_sync_function_ = sync_func;
|
||||
upstream_handler_ = ray::streaming::UpstreamQueueMessageHandler::CreateService(
|
||||
core_worker, core_worker_->GetWorkerContext().GetCurrentActorID());
|
||||
CoreWorkerProcess::GetCoreWorker().GetWorkerContext().GetCurrentActorID());
|
||||
}
|
||||
|
||||
void OnWriterMessage(std::shared_ptr<LocalMemoryBuffer> buffer);
|
||||
|
@ -54,7 +50,6 @@ class WriterClient {
|
|||
std::shared_ptr<LocalMemoryBuffer> buffer);
|
||||
|
||||
private:
|
||||
CoreWorker *core_worker_;
|
||||
std::shared_ptr<UpstreamQueueMessageHandler> upstream_handler_;
|
||||
};
|
||||
} // namespace streaming
|
||||
|
|
|
@ -85,8 +85,8 @@ std::shared_ptr<Transport> QueueMessageHandler::GetOutTransport(
|
|||
void QueueMessageHandler::SetPeerActorID(const ObjectID &queue_id,
|
||||
const ActorID &actor_id) {
|
||||
actors_.emplace(queue_id, actor_id);
|
||||
out_transports_.emplace(
|
||||
queue_id, std::make_shared<ray::streaming::Transport>(core_worker_, actor_id));
|
||||
out_transports_.emplace(queue_id,
|
||||
std::make_shared<ray::streaming::Transport>(actor_id));
|
||||
}
|
||||
|
||||
ActorID QueueMessageHandler::GetPeerActorID(const ObjectID &queue_id) {
|
||||
|
@ -113,10 +113,9 @@ void QueueMessageHandler::Stop() {
|
|||
}
|
||||
|
||||
std::shared_ptr<UpstreamQueueMessageHandler> UpstreamQueueMessageHandler::CreateService(
|
||||
CoreWorker *core_worker, const ActorID &actor_id) {
|
||||
const ActorID &actor_id) {
|
||||
if (nullptr == upstream_handler_) {
|
||||
upstream_handler_ =
|
||||
std::make_shared<UpstreamQueueMessageHandler>(core_worker, actor_id);
|
||||
upstream_handler_ = std::make_shared<UpstreamQueueMessageHandler>(actor_id);
|
||||
}
|
||||
return upstream_handler_;
|
||||
}
|
||||
|
@ -247,11 +246,9 @@ void UpstreamQueueMessageHandler::ReleaseAllUpQueues() {
|
|||
}
|
||||
|
||||
std::shared_ptr<DownstreamQueueMessageHandler>
|
||||
DownstreamQueueMessageHandler::CreateService(CoreWorker *core_worker,
|
||||
const ActorID &actor_id) {
|
||||
DownstreamQueueMessageHandler::CreateService(const ActorID &actor_id) {
|
||||
if (nullptr == downstream_handler_) {
|
||||
downstream_handler_ =
|
||||
std::make_shared<DownstreamQueueMessageHandler>(core_worker, actor_id);
|
||||
downstream_handler_ = std::make_shared<DownstreamQueueMessageHandler>(actor_id);
|
||||
}
|
||||
return downstream_handler_;
|
||||
}
|
||||
|
|
|
@ -24,16 +24,9 @@ namespace streaming {
|
|||
class QueueMessageHandler {
|
||||
public:
|
||||
/// Construct a QueueMessageHandler instance.
|
||||
/// \param[in] core_worker CoreWorker C++ pointer of current actor, used to call Core
|
||||
/// Worker's api.
|
||||
/// For Python worker, the pointer can be obtained from
|
||||
/// ray.worker.global_worker.core_worker; For Java worker, obtained from
|
||||
/// RayNativeRuntime object through java reflection.
|
||||
/// \param[in] actor_id actor id of current actor.
|
||||
QueueMessageHandler(CoreWorker *core_worker, const ActorID &actor_id)
|
||||
: core_worker_(core_worker),
|
||||
actor_id_(actor_id),
|
||||
queue_dummy_work_(queue_service_) {
|
||||
QueueMessageHandler(const ActorID &actor_id)
|
||||
: actor_id_(actor_id), queue_dummy_work_(queue_service_) {
|
||||
Start();
|
||||
}
|
||||
|
||||
|
@ -87,8 +80,6 @@ class QueueMessageHandler {
|
|||
void QueueThreadCallback() { queue_service_.run(); }
|
||||
|
||||
protected:
|
||||
/// CoreWorker C++ pointer of current actor
|
||||
CoreWorker *core_worker_;
|
||||
/// actor_id actor id of current actor
|
||||
ActorID actor_id_;
|
||||
/// Helper function, parse message buffer to Message object.
|
||||
|
@ -111,8 +102,7 @@ class QueueMessageHandler {
|
|||
class UpstreamQueueMessageHandler : public QueueMessageHandler {
|
||||
public:
|
||||
/// Construct a UpstreamQueueMessageHandler instance.
|
||||
UpstreamQueueMessageHandler(CoreWorker *core_worker, const ActorID &actor_id)
|
||||
: QueueMessageHandler(core_worker, actor_id) {}
|
||||
UpstreamQueueMessageHandler(const ActorID &actor_id) : QueueMessageHandler(actor_id) {}
|
||||
/// Create a upstream queue.
|
||||
/// \param[in] queue_id queue id of the queue to be created.
|
||||
/// \param[in] peer_actor_id actor id of peer actor.
|
||||
|
@ -140,7 +130,7 @@ class UpstreamQueueMessageHandler : public QueueMessageHandler {
|
|||
std::function<void(std::shared_ptr<LocalMemoryBuffer>)> callback) override;
|
||||
|
||||
static std::shared_ptr<UpstreamQueueMessageHandler> CreateService(
|
||||
CoreWorker *core_worker, const ActorID &actor_id);
|
||||
const ActorID &actor_id);
|
||||
static std::shared_ptr<UpstreamQueueMessageHandler> GetService();
|
||||
|
||||
static RayFunction peer_sync_function_;
|
||||
|
@ -157,8 +147,8 @@ class UpstreamQueueMessageHandler : public QueueMessageHandler {
|
|||
/// UpstreamQueueMessageHandler holds and manages all downstream queues of current actor.
|
||||
class DownstreamQueueMessageHandler : public QueueMessageHandler {
|
||||
public:
|
||||
DownstreamQueueMessageHandler(CoreWorker *core_worker, const ActorID &actor_id)
|
||||
: QueueMessageHandler(core_worker, actor_id) {}
|
||||
DownstreamQueueMessageHandler(const ActorID &actor_id)
|
||||
: QueueMessageHandler(actor_id) {}
|
||||
std::shared_ptr<ReaderQueue> CreateDownstreamQueue(const ObjectID &queue_id,
|
||||
const ActorID &peer_actor_id);
|
||||
bool DownstreamQueueExists(const ObjectID &queue_id);
|
||||
|
@ -178,7 +168,7 @@ class DownstreamQueueMessageHandler : public QueueMessageHandler {
|
|||
std::function<void(std::shared_ptr<LocalMemoryBuffer>)> callback);
|
||||
|
||||
static std::shared_ptr<DownstreamQueueMessageHandler> CreateService(
|
||||
CoreWorker *core_worker, const ActorID &actor_id);
|
||||
const ActorID &actor_id);
|
||||
static std::shared_ptr<DownstreamQueueMessageHandler> GetService();
|
||||
static RayFunction peer_sync_function_;
|
||||
static RayFunction peer_async_function_;
|
||||
|
|
|
@ -28,10 +28,9 @@ void Transport::SendInternal(std::shared_ptr<LocalMemoryBuffer> buffer,
|
|||
args.emplace_back(TaskArg::PassByValue(std::make_shared<RayObject>(
|
||||
std::move(buffer), meta, std::vector<ObjectID>(), true)));
|
||||
|
||||
STREAMING_CHECK(core_worker_ != nullptr);
|
||||
std::vector<std::shared_ptr<RayObject>> results;
|
||||
ray::Status st =
|
||||
core_worker_->SubmitActorTask(peer_actor_id_, function, args, options, &return_ids);
|
||||
ray::Status st = CoreWorkerProcess::GetCoreWorker().SubmitActorTask(
|
||||
peer_actor_id_, function, args, options, &return_ids);
|
||||
if (!st.ok()) {
|
||||
STREAMING_LOG(ERROR) << "SubmitActorTask failed. " << st;
|
||||
}
|
||||
|
@ -50,7 +49,8 @@ std::shared_ptr<LocalMemoryBuffer> Transport::SendForResult(
|
|||
SendInternal(buffer, function, TASK_OPTION_RETURN_NUM_1, return_ids);
|
||||
|
||||
std::vector<std::shared_ptr<RayObject>> results;
|
||||
Status get_st = core_worker_->Get(return_ids, timeout_ms, &results);
|
||||
Status get_st =
|
||||
CoreWorkerProcess::GetCoreWorker().Get(return_ids, timeout_ms, &results);
|
||||
if (!get_st.ok()) {
|
||||
STREAMING_LOG(ERROR) << "Get fail.";
|
||||
return nullptr;
|
||||
|
|
|
@ -13,11 +13,10 @@ namespace streaming {
|
|||
class Transport {
|
||||
public:
|
||||
/// Construct a Transport object.
|
||||
/// \param[in] core_worker CoreWorker C++ pointer of current actor, which we call direct
|
||||
/// actor call interface with.
|
||||
/// \param[in] peer_actor_id actor id of peer actor.
|
||||
Transport(CoreWorker *core_worker, const ActorID &peer_actor_id)
|
||||
: core_worker_(core_worker), peer_actor_id_(peer_actor_id) {}
|
||||
Transport(const ActorID &peer_actor_id)
|
||||
: worker_id_(CoreWorkerProcess::GetCoreWorker().GetWorkerID()),
|
||||
peer_actor_id_(peer_actor_id) {}
|
||||
virtual ~Transport() = default;
|
||||
|
||||
/// Send buffer asynchronously, peer's `function` will be called.
|
||||
|
@ -55,7 +54,7 @@ class Transport {
|
|||
std::vector<ObjectID> &return_ids);
|
||||
|
||||
private:
|
||||
CoreWorker *core_worker_;
|
||||
WorkerID worker_id_;
|
||||
ActorID peer_actor_id_;
|
||||
};
|
||||
} // namespace streaming
|
||||
|
|
|
@ -20,11 +20,9 @@ namespace streaming {
|
|||
|
||||
class StreamingQueueTestSuite {
|
||||
public:
|
||||
StreamingQueueTestSuite(std::shared_ptr<CoreWorker> core_worker, ActorID &peer_actor_id,
|
||||
std::vector<ObjectID> queue_ids,
|
||||
StreamingQueueTestSuite(ActorID &peer_actor_id, std::vector<ObjectID> queue_ids,
|
||||
std::vector<ObjectID> rescale_queue_ids)
|
||||
: core_worker_(core_worker),
|
||||
peer_actor_id_(peer_actor_id),
|
||||
: peer_actor_id_(peer_actor_id),
|
||||
queue_ids_(queue_ids),
|
||||
rescale_queue_ids_(rescale_queue_ids) {}
|
||||
|
||||
|
@ -52,7 +50,6 @@ class StreamingQueueTestSuite {
|
|||
std::string current_test_;
|
||||
bool status_;
|
||||
std::shared_ptr<std::thread> executor_thread_;
|
||||
std::shared_ptr<CoreWorker> core_worker_;
|
||||
ActorID peer_actor_id_;
|
||||
std::vector<ObjectID> queue_ids_;
|
||||
std::vector<ObjectID> rescale_queue_ids_;
|
||||
|
@ -60,11 +57,9 @@ class StreamingQueueTestSuite {
|
|||
|
||||
class StreamingQueueWriterTestSuite : public StreamingQueueTestSuite {
|
||||
public:
|
||||
StreamingQueueWriterTestSuite(std::shared_ptr<CoreWorker> core_worker,
|
||||
ActorID &peer_actor_id, std::vector<ObjectID> queue_ids,
|
||||
StreamingQueueWriterTestSuite(ActorID &peer_actor_id, std::vector<ObjectID> queue_ids,
|
||||
std::vector<ObjectID> rescale_queue_ids)
|
||||
: StreamingQueueTestSuite(core_worker, peer_actor_id, queue_ids,
|
||||
rescale_queue_ids) {
|
||||
: StreamingQueueTestSuite(peer_actor_id, queue_ids, rescale_queue_ids) {
|
||||
test_func_map_ = {
|
||||
{"streaming_writer_exactly_once_test",
|
||||
std::bind(&StreamingQueueWriterTestSuite::StreamingWriterExactlyOnceTest,
|
||||
|
@ -135,11 +130,9 @@ class StreamingQueueWriterTestSuite : public StreamingQueueTestSuite {
|
|||
|
||||
class StreamingQueueReaderTestSuite : public StreamingQueueTestSuite {
|
||||
public:
|
||||
StreamingQueueReaderTestSuite(std::shared_ptr<CoreWorker> core_worker,
|
||||
ActorID peer_actor_id, std::vector<ObjectID> queue_ids,
|
||||
StreamingQueueReaderTestSuite(ActorID peer_actor_id, std::vector<ObjectID> queue_ids,
|
||||
std::vector<ObjectID> rescale_queue_ids)
|
||||
: StreamingQueueTestSuite(core_worker, peer_actor_id, queue_ids,
|
||||
rescale_queue_ids) {
|
||||
: StreamingQueueTestSuite(peer_actor_id, queue_ids, rescale_queue_ids) {
|
||||
test_func_map_ = {
|
||||
{"streaming_writer_exactly_once_test",
|
||||
std::bind(&StreamingQueueReaderTestSuite::StreamingWriterExactlyOnceTest,
|
||||
|
@ -247,7 +240,7 @@ class StreamingQueueReaderTestSuite : public StreamingQueueTestSuite {
|
|||
class TestSuiteFactory {
|
||||
public:
|
||||
static std::shared_ptr<StreamingQueueTestSuite> CreateTestSuite(
|
||||
std::shared_ptr<CoreWorker> worker, std::shared_ptr<TestInitMessage> message) {
|
||||
std::shared_ptr<TestInitMessage> message) {
|
||||
std::shared_ptr<StreamingQueueTestSuite> test_suite = nullptr;
|
||||
std::string suite_name = message->TestSuiteName();
|
||||
queue::protobuf::StreamingQueueTestRole role = message->Role();
|
||||
|
@ -258,14 +251,14 @@ class TestSuiteFactory {
|
|||
if (role == queue::protobuf::StreamingQueueTestRole::WRITER) {
|
||||
if (suite_name == "StreamingWriterTest") {
|
||||
test_suite = std::make_shared<StreamingQueueWriterTestSuite>(
|
||||
worker, peer_actor_id, queue_ids, rescale_queue_ids);
|
||||
peer_actor_id, queue_ids, rescale_queue_ids);
|
||||
} else {
|
||||
STREAMING_CHECK(false) << "unsurported suite_name: " << suite_name;
|
||||
}
|
||||
} else {
|
||||
if (suite_name == "StreamingWriterTest") {
|
||||
test_suite = std::make_shared<StreamingQueueReaderTestSuite>(
|
||||
worker, peer_actor_id, queue_ids, rescale_queue_ids);
|
||||
peer_actor_id, queue_ids, rescale_queue_ids);
|
||||
} else {
|
||||
STREAMING_CHECK(false) << "unsupported suite_name: " << suite_name;
|
||||
}
|
||||
|
@ -280,10 +273,30 @@ class StreamingWorker {
|
|||
StreamingWorker(const std::string &store_socket, const std::string &raylet_socket,
|
||||
int node_manager_port, const gcs::GcsClientOptions &gcs_options)
|
||||
: test_suite_(nullptr), peer_actor_handle_(nullptr) {
|
||||
worker_ = std::make_shared<CoreWorker>(
|
||||
WorkerType::WORKER, Language::PYTHON, store_socket, raylet_socket,
|
||||
JobID::FromInt(1), gcs_options, "", "127.0.0.1", node_manager_port,
|
||||
std::bind(&StreamingWorker::ExecuteTask, this, _1, _2, _3, _4, _5, _6, _7));
|
||||
CoreWorkerOptions options = {
|
||||
WorkerType::WORKER, // worker_type
|
||||
Language::PYTHON, // language
|
||||
store_socket, // store_socket
|
||||
raylet_socket, // raylet_socket
|
||||
JobID::FromInt(1), // job_id
|
||||
gcs_options, // gcs_options
|
||||
"", // log_dir
|
||||
true, // install_failure_signal_handler
|
||||
"127.0.0.1", // node_ip_address
|
||||
node_manager_port, // node_manager_port
|
||||
"", // driver_name
|
||||
"", // stdout_file
|
||||
"", // stderr_file
|
||||
std::bind(&StreamingWorker::ExecuteTask, this, _1, _2, _3, _4, _5, _6,
|
||||
_7), // task_execution_callback
|
||||
nullptr, // check_signals
|
||||
nullptr, // gc_collect
|
||||
nullptr, // get_lang_stack
|
||||
true, // ref_counting_enabled
|
||||
false, // is_local_mode
|
||||
1, // num_workers
|
||||
};
|
||||
CoreWorkerProcess::Initialize(options);
|
||||
|
||||
RayFunction reader_async_call_func{ray::Language::PYTHON,
|
||||
ray::FunctionDescriptorBuilder::BuildPython(
|
||||
|
@ -298,16 +311,16 @@ class StreamingWorker {
|
|||
ray::Language::PYTHON,
|
||||
ray::FunctionDescriptorBuilder::BuildPython("writer_sync_call_func", "", "", "")};
|
||||
|
||||
reader_client_ = std::make_shared<ReaderClient>(worker_.get(), reader_async_call_func,
|
||||
reader_sync_call_func);
|
||||
writer_client_ = std::make_shared<WriterClient>(worker_.get(), writer_async_call_func,
|
||||
writer_sync_call_func);
|
||||
reader_client_ =
|
||||
std::make_shared<ReaderClient>(reader_async_call_func, reader_sync_call_func);
|
||||
writer_client_ =
|
||||
std::make_shared<WriterClient>(writer_async_call_func, writer_sync_call_func);
|
||||
STREAMING_LOG(INFO) << "StreamingWorker constructor";
|
||||
}
|
||||
|
||||
void StartExecutingTasks() {
|
||||
void RunTaskExecutionLoop() {
|
||||
// Start executing tasks.
|
||||
worker_->StartExecutingTasks();
|
||||
CoreWorkerProcess::RunTaskExecutionLoop();
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -403,7 +416,8 @@ class StreamingWorker {
|
|||
|
||||
STREAMING_LOG(INFO) << "Init message: " << message->ToString();
|
||||
std::string actor_handle_serialized = message->ActorHandleSerialized();
|
||||
worker_->DeserializeAndRegisterActorHandle(actor_handle_serialized, ObjectID::Nil());
|
||||
CoreWorkerProcess::GetCoreWorker().DeserializeAndRegisterActorHandle(
|
||||
actor_handle_serialized, ObjectID::Nil());
|
||||
std::shared_ptr<ActorHandle> actor_handle(new ActorHandle(actor_handle_serialized));
|
||||
STREAMING_CHECK(actor_handle != nullptr);
|
||||
STREAMING_LOG(INFO) << " actor id from handle: " << actor_handle->GetActorID();
|
||||
|
@ -421,12 +435,11 @@ class StreamingWorker {
|
|||
STREAMING_LOG(INFO) << "rescale queue: " << qid;
|
||||
}
|
||||
|
||||
test_suite_ = TestSuiteFactory::CreateTestSuite(worker_, message);
|
||||
test_suite_ = TestSuiteFactory::CreateTestSuite(message);
|
||||
STREAMING_CHECK(test_suite_ != nullptr);
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<CoreWorker> worker_;
|
||||
std::shared_ptr<ReaderClient> reader_client_;
|
||||
std::shared_ptr<WriterClient> writer_client_;
|
||||
std::shared_ptr<std::thread> test_thread_;
|
||||
|
@ -446,6 +459,6 @@ int main(int argc, char **argv) {
|
|||
ray::gcs::GcsClientOptions gcs_options("127.0.0.1", 6379, "");
|
||||
ray::streaming::StreamingWorker worker(store_socket, raylet_socket, node_manager_port,
|
||||
gcs_options);
|
||||
worker.StartExecutingTasks();
|
||||
worker.RunTaskExecutionLoop();
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@ -153,11 +153,12 @@ class StreamingQueueTestBase : public ::testing::TestWithParam<uint64_t> {
|
|||
ASSERT_TRUE(system(("rm -rf " + raylet_socket_name + ".pid").c_str()) == 0);
|
||||
}
|
||||
|
||||
void InitWorker(CoreWorker &driver, ActorID &self_actor_id, ActorID &peer_actor_id,
|
||||
void InitWorker(ActorID &self_actor_id, ActorID &peer_actor_id,
|
||||
const queue::protobuf::StreamingQueueTestRole role,
|
||||
const std::vector<ObjectID> &queue_ids,
|
||||
const std::vector<ObjectID> &rescale_queue_ids, std::string suite_name,
|
||||
std::string test_name, uint64_t param) {
|
||||
auto &driver = CoreWorkerProcess::GetCoreWorker();
|
||||
std::string forked_serialized_str;
|
||||
ObjectID actor_handle_id;
|
||||
Status st = driver.SerializeActorHandle(peer_actor_id, &forked_serialized_str,
|
||||
|
@ -179,7 +180,8 @@ class StreamingQueueTestBase : public ::testing::TestWithParam<uint64_t> {
|
|||
RAY_CHECK_OK(driver.SubmitActorTask(self_actor_id, func, args, options, &return_ids));
|
||||
}
|
||||
|
||||
void SubmitTestToActor(CoreWorker &driver, ActorID &actor_id, const std::string test) {
|
||||
void SubmitTestToActor(ActorID &actor_id, const std::string test) {
|
||||
auto &driver = CoreWorkerProcess::GetCoreWorker();
|
||||
uint8_t data[8];
|
||||
auto buffer = std::make_shared<LocalMemoryBuffer>(data, 8, true);
|
||||
std::vector<TaskArg> args;
|
||||
|
@ -194,7 +196,8 @@ class StreamingQueueTestBase : public ::testing::TestWithParam<uint64_t> {
|
|||
RAY_CHECK_OK(driver.SubmitActorTask(actor_id, func, args, options, &return_ids));
|
||||
}
|
||||
|
||||
bool CheckCurTest(CoreWorker &driver, ActorID &actor_id, const std::string test_name) {
|
||||
bool CheckCurTest(ActorID &actor_id, const std::string test_name) {
|
||||
auto &driver = CoreWorkerProcess::GetCoreWorker();
|
||||
uint8_t data[8];
|
||||
auto buffer = std::make_shared<LocalMemoryBuffer>(data, 8, true);
|
||||
std::vector<TaskArg> args;
|
||||
|
@ -255,8 +258,7 @@ class StreamingQueueTestBase : public ::testing::TestWithParam<uint64_t> {
|
|||
return message->Status();
|
||||
}
|
||||
|
||||
ActorID CreateActorHelper(CoreWorker &worker,
|
||||
const std::unordered_map<std::string, double> &resources,
|
||||
ActorID CreateActorHelper(const std::unordered_map<std::string, double> &resources,
|
||||
bool is_direct_call, uint64_t max_reconstructions) {
|
||||
std::unique_ptr<ActorHandle> actor_handle;
|
||||
|
||||
|
@ -277,8 +279,8 @@ class StreamingQueueTestBase : public ::testing::TestWithParam<uint64_t> {
|
|||
|
||||
// Create an actor.
|
||||
ActorID actor_id;
|
||||
RAY_CHECK_OK(
|
||||
worker.CreateActor(func, args, actor_options, /*extension_data*/ "", &actor_id));
|
||||
RAY_CHECK_OK(CoreWorkerProcess::GetCoreWorker().CreateActor(
|
||||
func, args, actor_options, /*extension_data*/ "", &actor_id));
|
||||
return actor_id;
|
||||
}
|
||||
|
||||
|
@ -305,33 +307,54 @@ class StreamingQueueTestBase : public ::testing::TestWithParam<uint64_t> {
|
|||
}
|
||||
STREAMING_LOG(INFO) << "Sub process: writer.";
|
||||
|
||||
CoreWorker driver(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0],
|
||||
raylet_socket_names_[0], NextJobId(), gcs_options_, "", "127.0.0.1",
|
||||
node_manager_port_, nullptr);
|
||||
CoreWorkerOptions options = {
|
||||
WorkerType::DRIVER, // worker_type
|
||||
Language::PYTHON, // language
|
||||
raylet_store_socket_names_[0], // store_socket
|
||||
raylet_socket_names_[0], // raylet_socket
|
||||
NextJobId(), // job_id
|
||||
gcs_options_, // gcs_options
|
||||
"", // log_dir
|
||||
true, // install_failure_signal_handler
|
||||
"127.0.0.1", // node_ip_address
|
||||
node_manager_port_, // node_manager_port
|
||||
"queue_tests", // driver_name
|
||||
"", // stdout_file
|
||||
"", // stderr_file
|
||||
nullptr, // task_execution_callback
|
||||
nullptr, // check_signals
|
||||
nullptr, // gc_collect
|
||||
nullptr, // get_lang_stack
|
||||
true, // ref_counting_enabled
|
||||
false, // is_local_mode
|
||||
1, // num_workers
|
||||
};
|
||||
InitShutdownRAII core_worker_raii(CoreWorkerProcess::Initialize,
|
||||
CoreWorkerProcess::Shutdown, options);
|
||||
|
||||
// Create writer and reader actors
|
||||
std::unordered_map<std::string, double> resources;
|
||||
auto actor_id_writer = CreateActorHelper(driver, resources, true, 0);
|
||||
auto actor_id_reader = CreateActorHelper(driver, resources, true, 0);
|
||||
auto actor_id_writer = CreateActorHelper(resources, true, 0);
|
||||
auto actor_id_reader = CreateActorHelper(resources, true, 0);
|
||||
|
||||
InitWorker(driver, actor_id_writer, actor_id_reader,
|
||||
InitWorker(actor_id_writer, actor_id_reader,
|
||||
queue::protobuf::StreamingQueueTestRole::WRITER, queue_id_vec,
|
||||
rescale_queue_id_vec, suite_name, test_name, GetParam());
|
||||
InitWorker(driver, actor_id_reader, actor_id_writer,
|
||||
InitWorker(actor_id_reader, actor_id_writer,
|
||||
queue::protobuf::StreamingQueueTestRole::READER, queue_id_vec,
|
||||
rescale_queue_id_vec, suite_name, test_name, GetParam());
|
||||
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(2000));
|
||||
|
||||
SubmitTestToActor(driver, actor_id_writer, test_name);
|
||||
SubmitTestToActor(driver, actor_id_reader, test_name);
|
||||
SubmitTestToActor(actor_id_writer, test_name);
|
||||
SubmitTestToActor(actor_id_reader, test_name);
|
||||
|
||||
uint64_t slept_time_ms = 0;
|
||||
while (slept_time_ms < timeout_ms) {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(5 * 1000));
|
||||
STREAMING_LOG(INFO) << "Check test status.";
|
||||
if (CheckCurTest(driver, actor_id_writer, test_name) &&
|
||||
CheckCurTest(driver, actor_id_reader, test_name)) {
|
||||
if (CheckCurTest(actor_id_writer, test_name) &&
|
||||
CheckCurTest(actor_id_reader, test_name)) {
|
||||
STREAMING_LOG(INFO) << "Test Success, Exit.";
|
||||
return;
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue