[Java] Enhancement single process mode (#6795)

* enhancement

* Add ut

* Update java/runtime/src/main/java/org/ray/runtime/task/LocalModeTaskSubmitter.java

Co-Authored-By: Kai Yang <kfstorm@outlook.com>

* Update java/test/src/main/java/org/ray/api/test/RunModeTest.java

Co-Authored-By: Kai Yang <kfstorm@outlook.com>

* Address comments

* Use ExecutorSerivce to replace raw thread

Co-authored-by: Kai Yang <kfstorm@outlook.com>
This commit is contained in:
Qing Wang 2020-01-15 21:38:53 +08:00 committed by GitHub
parent cd5fc81bdd
commit ad90693ca8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 137 additions and 34 deletions

View file

@ -14,8 +14,8 @@ import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
import org.ray.api.RayActor;
import org.ray.api.id.ActorId;
@ -50,7 +50,13 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
private final Object taskAndObjectLock = new Object();
private final RayDevRuntime runtime;
private final LocalModeObjectStore objectStore;
private final ExecutorService exec;
/// The thread pool to execute actor tasks.
private final Map<ActorId, ExecutorService> actorTaskExecutorServices;
/// The thread pool to execute normal tasks.
private final ExecutorService normalTaskExecutorService;
private final Deque<TaskExecutor> idleTaskExecutors = new ArrayDeque<>();
private final Map<ActorId, TaskExecutor> actorTaskExecutors = new HashMap<>();
private final Object taskExecutorLock = new Object();
@ -60,8 +66,10 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
int numberThreads) {
this.runtime = runtime;
this.objectStore = objectStore;
// The thread pool that executes tasks in parallel.
exec = Executors.newFixedThreadPool(numberThreads);
// The thread pool that executes normal tasks in parallel.
normalTaskExecutorService = Executors.newFixedThreadPool(numberThreads);
// The thread pool that executes actor tasks in parallel.
actorTaskExecutorServices = new HashMap<>();
}
public void onObjectPut(ObjectId id) {
@ -211,7 +219,14 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
}
public void shutdown() {
exec.shutdown();
// Shutdown actor task executor service.
synchronized (actorTaskExecutorServices) {
for (Map.Entry<ActorId, ExecutorService> item : actorTaskExecutorServices.entrySet()) {
item.getValue().shutdown();
}
}
// Shutdown normal task executor service.
normalTaskExecutorService.shutdown();
}
public static ActorId getActorId(TaskSpec taskSpec) {
@ -231,37 +246,54 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
LOGGER.debug("Submitting task: {}.", taskSpec);
synchronized (taskAndObjectLock) {
Set<ObjectId> unreadyObjects = getUnreadyObjects(taskSpec);
final Runnable runnable = () -> {
TaskExecutor taskExecutor = getTaskExecutor(taskSpec);
try {
List<NativeRayObject> args = getFunctionArgs(taskSpec).stream()
.map(arg -> arg.id != null ?
objectStore.getRaw(Collections.singletonList(arg.id), -1).get(0)
: arg.value)
.collect(Collectors.toList());
((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(taskSpec);
List<NativeRayObject> returnObjects = taskExecutor
.execute(getJavaFunctionDescriptor(taskSpec).toList(), args);
((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(null);
List<ObjectId> returnIds = getReturnIds(taskSpec);
for (int i = 0; i < returnIds.size(); i++) {
NativeRayObject putObject;
if (i >= returnObjects.size()) {
// If the task is an actor task or an actor creation task,
// put the dummy object in object store, so those tasks which depends on it
// can be executed.
putObject = new NativeRayObject(new byte[]{1}, null);
} else {
putObject = returnObjects.get(i);
}
objectStore.putRaw(putObject, returnIds.get(i));
}
} finally {
returnTaskExecutor(taskExecutor, taskSpec);
}
};
if (unreadyObjects.isEmpty()) {
// If all dependencies are ready, execute this task.
exec.submit(() -> {
TaskExecutor taskExecutor = getTaskExecutor(taskSpec);
try {
List<NativeRayObject> args = getFunctionArgs(taskSpec).stream()
.map(arg -> arg.id != null ?
objectStore.getRaw(Collections.singletonList(arg.id), -1).get(0)
: arg.value)
.collect(Collectors.toList());
((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(taskSpec);
List<NativeRayObject> returnObjects = taskExecutor
.execute(getJavaFunctionDescriptor(taskSpec).toList(), args);
((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(null);
List<ObjectId> returnIds = getReturnIds(taskSpec);
for (int i = 0; i < returnIds.size(); i++) {
NativeRayObject putObject;
if (i >= returnObjects.size()) {
// If the task is an actor task or an actor creation task,
// put the dummy object in object store, so those tasks which depends on it
// can be executed.
putObject = new NativeRayObject(new byte[]{1}, null);
} else {
putObject = returnObjects.get(i);
}
objectStore.putRaw(putObject, returnIds.get(i));
}
} finally {
returnTaskExecutor(taskExecutor, taskSpec);
if (taskSpec.getType() == TaskType.ACTOR_CREATION_TASK) {
ExecutorService actorExecutorService = Executors.newSingleThreadExecutor();
synchronized (actorTaskExecutorServices) {
actorTaskExecutorServices.put(getActorId(taskSpec), actorExecutorService);
}
});
actorExecutorService.submit(runnable);
} else if (taskSpec.getType() == TaskType.ACTOR_TASK) {
synchronized (actorTaskExecutorServices) {
ExecutorService actorExecutorService = actorTaskExecutorServices.get(getActorId(taskSpec));
actorExecutorService.submit(runnable);
}
} else {
// Normal task.
normalTaskExecutorService.submit(runnable);
}
} else {
// If some dependencies aren't ready yet, put this task in waiting list.
for (ObjectId id : unreadyObjects) {

View file

@ -27,6 +27,12 @@ public class TestUtils {
}
}
public static void skipTestUnderClusterMode() {
if (getRuntime().getRayConfig().runMode == RunMode.CLUSTER) {
throw new SkipException("This test doesn't work under cluster mode.");
}
}
public static void skipTestIfDirectActorCallEnabled() {
skipTestIfDirectActorCallEnabled(true);
}

View file

@ -0,0 +1,65 @@
package org.ray.api.test;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.TestUtils;
import org.ray.api.annotation.RayRemote;
import org.ray.api.id.ActorId;
import org.testng.Assert;
import org.testng.annotations.Test;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class SingleProcessModeTest extends BaseTest {
private final static int NUM_ACTOR_INSTANCE = 10;
private final static int TIMES_TO_CALL_PER_ACTOR = 10;
@RayRemote
static class MyActor {
public MyActor() {
}
public long getThreadId() {
return Thread.currentThread().getId();
}
}
@Test
public void testActorTasksInOneThread() {
TestUtils.skipTestUnderClusterMode();
List<RayActor<MyActor>> actors = new ArrayList<>();
Map<ActorId, Long> actorThreadIds = new HashMap<>();
for (int i = 0; i < NUM_ACTOR_INSTANCE; ++i) {
RayActor<MyActor> actor = Ray.createActor(MyActor::new);
actors.add(actor);
actorThreadIds.put(actor.getId(), Ray.call(MyActor::getThreadId, actor).get());
}
Map<ActorId, List<RayObject<Long>>> allResults = new HashMap<>();
for (int i = 0; i < NUM_ACTOR_INSTANCE; ++i) {
final RayActor<MyActor> actor = actors.get(i);
List<RayObject<Long>> thisActorResult = new ArrayList<>();
for (int j = 0; j < TIMES_TO_CALL_PER_ACTOR; ++j) {
thisActorResult.add(Ray.call(MyActor::getThreadId, actor));
}
allResults.put(actor.getId(), thisActorResult);
}
// check result.
for (int i = 0; i < NUM_ACTOR_INSTANCE; ++i) {
final RayActor<MyActor> actor = actors.get(i);
final List<RayObject<Long>> thisActorResult = allResults.get(actor.getId());
// assert
for (RayObject<Long> threadId : thisActorResult) {
Assert.assertEquals(threadId.get(), actorThreadIds.get(actor.getId()));
}
}
}
}