mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[Java] Enhancement single process mode (#6795)
* enhancement * Add ut * Update java/runtime/src/main/java/org/ray/runtime/task/LocalModeTaskSubmitter.java Co-Authored-By: Kai Yang <kfstorm@outlook.com> * Update java/test/src/main/java/org/ray/api/test/RunModeTest.java Co-Authored-By: Kai Yang <kfstorm@outlook.com> * Address comments * Use ExecutorSerivce to replace raw thread Co-authored-by: Kai Yang <kfstorm@outlook.com>
This commit is contained in:
parent
cd5fc81bdd
commit
ad90693ca8
3 changed files with 137 additions and 34 deletions
|
@ -14,8 +14,8 @@ import java.util.List;
|
|||
import java.util.Map;
|
||||
import java.util.Random;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.*;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.stream.Collectors;
|
||||
import org.ray.api.RayActor;
|
||||
import org.ray.api.id.ActorId;
|
||||
|
@ -50,7 +50,13 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
|
|||
private final Object taskAndObjectLock = new Object();
|
||||
private final RayDevRuntime runtime;
|
||||
private final LocalModeObjectStore objectStore;
|
||||
private final ExecutorService exec;
|
||||
|
||||
/// The thread pool to execute actor tasks.
|
||||
private final Map<ActorId, ExecutorService> actorTaskExecutorServices;
|
||||
|
||||
/// The thread pool to execute normal tasks.
|
||||
private final ExecutorService normalTaskExecutorService;
|
||||
|
||||
private final Deque<TaskExecutor> idleTaskExecutors = new ArrayDeque<>();
|
||||
private final Map<ActorId, TaskExecutor> actorTaskExecutors = new HashMap<>();
|
||||
private final Object taskExecutorLock = new Object();
|
||||
|
@ -60,8 +66,10 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
|
|||
int numberThreads) {
|
||||
this.runtime = runtime;
|
||||
this.objectStore = objectStore;
|
||||
// The thread pool that executes tasks in parallel.
|
||||
exec = Executors.newFixedThreadPool(numberThreads);
|
||||
// The thread pool that executes normal tasks in parallel.
|
||||
normalTaskExecutorService = Executors.newFixedThreadPool(numberThreads);
|
||||
// The thread pool that executes actor tasks in parallel.
|
||||
actorTaskExecutorServices = new HashMap<>();
|
||||
}
|
||||
|
||||
public void onObjectPut(ObjectId id) {
|
||||
|
@ -211,7 +219,14 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
|
|||
}
|
||||
|
||||
public void shutdown() {
|
||||
exec.shutdown();
|
||||
// Shutdown actor task executor service.
|
||||
synchronized (actorTaskExecutorServices) {
|
||||
for (Map.Entry<ActorId, ExecutorService> item : actorTaskExecutorServices.entrySet()) {
|
||||
item.getValue().shutdown();
|
||||
}
|
||||
}
|
||||
// Shutdown normal task executor service.
|
||||
normalTaskExecutorService.shutdown();
|
||||
}
|
||||
|
||||
public static ActorId getActorId(TaskSpec taskSpec) {
|
||||
|
@ -231,37 +246,54 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
|
|||
LOGGER.debug("Submitting task: {}.", taskSpec);
|
||||
synchronized (taskAndObjectLock) {
|
||||
Set<ObjectId> unreadyObjects = getUnreadyObjects(taskSpec);
|
||||
|
||||
final Runnable runnable = () -> {
|
||||
TaskExecutor taskExecutor = getTaskExecutor(taskSpec);
|
||||
try {
|
||||
List<NativeRayObject> args = getFunctionArgs(taskSpec).stream()
|
||||
.map(arg -> arg.id != null ?
|
||||
objectStore.getRaw(Collections.singletonList(arg.id), -1).get(0)
|
||||
: arg.value)
|
||||
.collect(Collectors.toList());
|
||||
((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(taskSpec);
|
||||
List<NativeRayObject> returnObjects = taskExecutor
|
||||
.execute(getJavaFunctionDescriptor(taskSpec).toList(), args);
|
||||
((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(null);
|
||||
List<ObjectId> returnIds = getReturnIds(taskSpec);
|
||||
for (int i = 0; i < returnIds.size(); i++) {
|
||||
NativeRayObject putObject;
|
||||
if (i >= returnObjects.size()) {
|
||||
// If the task is an actor task or an actor creation task,
|
||||
// put the dummy object in object store, so those tasks which depends on it
|
||||
// can be executed.
|
||||
putObject = new NativeRayObject(new byte[]{1}, null);
|
||||
} else {
|
||||
putObject = returnObjects.get(i);
|
||||
}
|
||||
objectStore.putRaw(putObject, returnIds.get(i));
|
||||
}
|
||||
} finally {
|
||||
returnTaskExecutor(taskExecutor, taskSpec);
|
||||
}
|
||||
};
|
||||
|
||||
if (unreadyObjects.isEmpty()) {
|
||||
// If all dependencies are ready, execute this task.
|
||||
exec.submit(() -> {
|
||||
TaskExecutor taskExecutor = getTaskExecutor(taskSpec);
|
||||
try {
|
||||
List<NativeRayObject> args = getFunctionArgs(taskSpec).stream()
|
||||
.map(arg -> arg.id != null ?
|
||||
objectStore.getRaw(Collections.singletonList(arg.id), -1).get(0)
|
||||
: arg.value)
|
||||
.collect(Collectors.toList());
|
||||
((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(taskSpec);
|
||||
List<NativeRayObject> returnObjects = taskExecutor
|
||||
.execute(getJavaFunctionDescriptor(taskSpec).toList(), args);
|
||||
((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(null);
|
||||
List<ObjectId> returnIds = getReturnIds(taskSpec);
|
||||
for (int i = 0; i < returnIds.size(); i++) {
|
||||
NativeRayObject putObject;
|
||||
if (i >= returnObjects.size()) {
|
||||
// If the task is an actor task or an actor creation task,
|
||||
// put the dummy object in object store, so those tasks which depends on it
|
||||
// can be executed.
|
||||
putObject = new NativeRayObject(new byte[]{1}, null);
|
||||
} else {
|
||||
putObject = returnObjects.get(i);
|
||||
}
|
||||
objectStore.putRaw(putObject, returnIds.get(i));
|
||||
}
|
||||
} finally {
|
||||
returnTaskExecutor(taskExecutor, taskSpec);
|
||||
if (taskSpec.getType() == TaskType.ACTOR_CREATION_TASK) {
|
||||
ExecutorService actorExecutorService = Executors.newSingleThreadExecutor();
|
||||
synchronized (actorTaskExecutorServices) {
|
||||
actorTaskExecutorServices.put(getActorId(taskSpec), actorExecutorService);
|
||||
}
|
||||
});
|
||||
actorExecutorService.submit(runnable);
|
||||
} else if (taskSpec.getType() == TaskType.ACTOR_TASK) {
|
||||
synchronized (actorTaskExecutorServices) {
|
||||
ExecutorService actorExecutorService = actorTaskExecutorServices.get(getActorId(taskSpec));
|
||||
actorExecutorService.submit(runnable);
|
||||
}
|
||||
} else {
|
||||
// Normal task.
|
||||
normalTaskExecutorService.submit(runnable);
|
||||
}
|
||||
} else {
|
||||
// If some dependencies aren't ready yet, put this task in waiting list.
|
||||
for (ObjectId id : unreadyObjects) {
|
||||
|
|
|
@ -27,6 +27,12 @@ public class TestUtils {
|
|||
}
|
||||
}
|
||||
|
||||
public static void skipTestUnderClusterMode() {
|
||||
if (getRuntime().getRayConfig().runMode == RunMode.CLUSTER) {
|
||||
throw new SkipException("This test doesn't work under cluster mode.");
|
||||
}
|
||||
}
|
||||
|
||||
public static void skipTestIfDirectActorCallEnabled() {
|
||||
skipTestIfDirectActorCallEnabled(true);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
package org.ray.api.test;
|
||||
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.api.RayActor;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.TestUtils;
|
||||
import org.ray.api.annotation.RayRemote;
|
||||
import org.ray.api.id.ActorId;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class SingleProcessModeTest extends BaseTest {
|
||||
|
||||
private final static int NUM_ACTOR_INSTANCE = 10;
|
||||
|
||||
private final static int TIMES_TO_CALL_PER_ACTOR = 10;
|
||||
|
||||
@RayRemote
|
||||
static class MyActor {
|
||||
public MyActor() {
|
||||
}
|
||||
|
||||
public long getThreadId() {
|
||||
return Thread.currentThread().getId();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testActorTasksInOneThread() {
|
||||
TestUtils.skipTestUnderClusterMode();
|
||||
|
||||
List<RayActor<MyActor>> actors = new ArrayList<>();
|
||||
Map<ActorId, Long> actorThreadIds = new HashMap<>();
|
||||
for (int i = 0; i < NUM_ACTOR_INSTANCE; ++i) {
|
||||
RayActor<MyActor> actor = Ray.createActor(MyActor::new);
|
||||
actors.add(actor);
|
||||
actorThreadIds.put(actor.getId(), Ray.call(MyActor::getThreadId, actor).get());
|
||||
}
|
||||
|
||||
Map<ActorId, List<RayObject<Long>>> allResults = new HashMap<>();
|
||||
for (int i = 0; i < NUM_ACTOR_INSTANCE; ++i) {
|
||||
final RayActor<MyActor> actor = actors.get(i);
|
||||
List<RayObject<Long>> thisActorResult = new ArrayList<>();
|
||||
for (int j = 0; j < TIMES_TO_CALL_PER_ACTOR; ++j) {
|
||||
thisActorResult.add(Ray.call(MyActor::getThreadId, actor));
|
||||
}
|
||||
allResults.put(actor.getId(), thisActorResult);
|
||||
}
|
||||
|
||||
// check result.
|
||||
for (int i = 0; i < NUM_ACTOR_INSTANCE; ++i) {
|
||||
final RayActor<MyActor> actor = actors.get(i);
|
||||
final List<RayObject<Long>> thisActorResult = allResults.get(actor.getId());
|
||||
// assert
|
||||
for (RayObject<Long> threadId : thisActorResult) {
|
||||
Assert.assertEquals(threadId.get(), actorThreadIds.get(actor.getId()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Add table
Reference in a new issue