[Java] Enable concurrent calls in local mode. (#14896)

* Enable concurrent calls in local mode.

* Fix submitting actor tasks before actor creation task executed.

Co-authored-by: Qing Wang <jovany.wq@antgroup.com>
This commit is contained in:
Qing Wang 2021-06-10 23:21:11 +08:00 committed by GitHub
parent 9741bc00c9
commit d6d27e9d34
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 62 additions and 13 deletions

View file

@ -32,6 +32,7 @@ import io.ray.runtime.generated.Common.TaskType;
import io.ray.runtime.object.LocalModeObjectStore;
import io.ray.runtime.object.NativeRayObject;
import io.ray.runtime.placementgroup.PlacementGroupImpl;
import io.ray.runtime.util.IdUtil;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
@ -65,6 +66,8 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
/// The thread pool to execute actor tasks.
private final Map<ActorId, ExecutorService> actorTaskExecutorServices;
private final Map<ActorId, Integer> actorMaxConcurrency = new ConcurrentHashMap<>();
/// The thread pool to execute normal tasks.
private final ExecutorService normalTaskExecutorService;
@ -114,13 +117,22 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
}
}
}
if (taskSpec.getType() == TaskType.ACTOR_TASK) {
if (taskSpec.getType() == TaskType.ACTOR_TASK && !isConcurrentActor(taskSpec)) {
ObjectId dummyObjectId =
new ObjectId(
taskSpec.getActorTaskSpec().getPreviousActorTaskDummyObjectId().toByteArray());
if (!objectStore.isObjectReady(dummyObjectId)) {
unreadyObjects.add(dummyObjectId);
}
} else if (taskSpec.getType() == TaskType.ACTOR_TASK) {
// Code path of concurrent actors.
// For concurrent actors, we should make sure the actor created
// before we submit the following actor tasks.
ActorId actorId = ActorId.fromBytes(taskSpec.getActorTaskSpec().getActorId().toByteArray());
ObjectId dummyActorCreationObjectId = IdUtil.getActorCreationDummyObjectId(actorId);
if (!objectStore.isObjectReady(dummyActorCreationObjectId)) {
unreadyObjects.add(dummyActorCreationObjectId);
}
}
return unreadyObjects;
}
@ -198,6 +210,7 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
.setActorCreationTaskSpec(
ActorCreationTaskSpec.newBuilder()
.setActorId(ByteString.copyFrom(actorId.toByteBuffer()))
.setMaxConcurrency(options.maxConcurrency)
.build())
.build();
submitTaskSpec(taskSpec);
@ -320,17 +333,33 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
() -> {
try {
executeTask(taskSpec);
if (taskSpec.getType() == TaskType.ACTOR_CREATION_TASK) {
// Construct a dummy object id for actor creation task so that the following
// actor task can touch if this actor is created.
ObjectId dummy =
IdUtil.getActorCreationDummyObjectId(
ActorId.fromBytes(
taskSpec.getActorCreationTaskSpec().getActorId().toByteArray()));
objectStore.put(new Object(), dummy);
}
} catch (Exception ex) {
LOGGER.error("Unexpected exception when executing a task.", ex);
System.exit(-1);
}
};
if (taskSpec.getType() == TaskType.ACTOR_CREATION_TASK) {
actorMaxConcurrency.put(
getActorId(taskSpec), taskSpec.getActorCreationTaskSpec().getMaxConcurrency());
}
if (unreadyObjects.isEmpty()) {
// If all dependencies are ready, execute this task.
ExecutorService executorService;
if (taskSpec.getType() == TaskType.ACTOR_CREATION_TASK) {
executorService = Executors.newSingleThreadExecutor();
final int maxConcurrency = taskSpec.getActorCreationTaskSpec().getMaxConcurrency();
Preconditions.checkState(maxConcurrency >= 1);
executorService = Executors.newFixedThreadPool(maxConcurrency);
synchronized (actorTaskExecutorServices) {
actorTaskExecutorServices.put(getActorId(taskSpec), executorService);
}
@ -362,11 +391,16 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
private void executeTask(TaskSpec taskSpec) {
TaskExecutor.ActorContext actorContext = null;
UniqueId workerId;
if (taskSpec.getType() == TaskType.ACTOR_TASK) {
actorContext = actorContexts.get(getActorId(taskSpec));
Preconditions.checkNotNull(actorContext);
workerId = ((LocalModeTaskExecutor.LocalActorContext) actorContext).getWorkerId();
} else {
// Actor creation task and normal task will use a new random worker id.
workerId = UniqueId.randomId();
}
taskExecutor.setActorContext(actorContext);
taskExecutor.setActorContext(workerId, actorContext);
List<NativeRayObject> args =
getFunctionArgs(taskSpec).stream()
.map(
@ -377,10 +411,7 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
.collect(Collectors.toList());
runtime.setIsContextSet(true);
((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(taskSpec);
UniqueId workerId =
actorContext != null
? ((LocalModeTaskExecutor.LocalActorContext) actorContext).getWorkerId()
: UniqueId.randomId();
((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentWorkerId(workerId);
List<String> rayFunctionInfo = getJavaFunctionDescriptor(taskSpec).toList();
taskExecutor.checkByteBufferArguments(rayFunctionInfo);
@ -388,7 +419,9 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
if (taskSpec.getType() == TaskType.ACTOR_CREATION_TASK) {
// Update actor context map ASAP in case objectStore.putRaw triggered the next actor task
// on this actor.
actorContexts.put(getActorId(taskSpec), taskExecutor.getActorContext());
final TaskExecutor.ActorContext ac = taskExecutor.getActorContext();
Preconditions.checkNotNull(ac);
actorContexts.put(getActorId(taskSpec), ac);
}
// Set this flag to true is necessary because at the end of `taskExecutor.execute()`,
// this flag will be set to false. And `runtime.getWorkerContext()` requires it to be
@ -460,4 +493,11 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
}
return returnIds;
}
/** Whether this is an actor creation task spec of a concurrent actor. */
private boolean isConcurrentActor(TaskSpec taskSpec) {
final ActorId actorId = getActorId(taskSpec);
Preconditions.checkNotNull(actorId);
return actorMaxConcurrency.containsKey(actorId) && actorMaxConcurrency.get(actorId) > 1;
}
}

View file

@ -34,7 +34,6 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
private final ThreadLocal<RayFunction> localRayFunction = new ThreadLocal<>();
static class ActorContext {
/** The current actor object, if this worker is an actor, otherwise null. */
Object currentActor = null;
}
@ -49,12 +48,12 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
return actorContextMap.get(runtime.getWorkerContext().getCurrentWorkerId());
}
void setActorContext(T actorContext) {
void setActorContext(UniqueId workerId, T actorContext) {
if (actorContext == null) {
// ConcurrentHashMap doesn't allow null values. So just return here.
return;
}
this.actorContextMap.put(runtime.getWorkerContext().getCurrentWorkerId(), actorContext);
this.actorContextMap.put(workerId, actorContext);
}
protected void removeActorContext(UniqueId workerId) {
@ -93,7 +92,7 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
T actorContext = null;
if (taskType == TaskType.ACTOR_CREATION_TASK) {
actorContext = createActorContext();
setActorContext(actorContext);
setActorContext(runtime.getWorkerContext().getCurrentWorkerId(), actorContext);
} else if (taskType == TaskType.ACTOR_TASK) {
actorContext = getActorContext();
Preconditions.checkNotNull(actorContext);

View file

@ -3,6 +3,7 @@ package io.ray.runtime.util;
import io.ray.api.id.ActorId;
import io.ray.api.id.ObjectId;
import io.ray.api.id.TaskId;
import java.util.Arrays;
/**
* Helper method for different Ids. Note: any changes to these methods must be synced with C++
@ -24,4 +25,13 @@ public class IdUtil {
taskId.getBytes(), TaskId.UNIQUE_BYTES_LENGTH, actorIdBytes, 0, ActorId.LENGTH);
return ActorId.fromBytes(actorIdBytes);
}
/** Compute the dummy object id for actor creation task. */
public static ObjectId getActorCreationDummyObjectId(ActorId actorId) {
byte[] objectIdBytes = new byte[ObjectId.LENGTH];
Arrays.fill(objectIdBytes, (byte) 0xFF);
byte[] actorIdBytes = actorId.getBytes();
System.arraycopy(actorIdBytes, 0, objectIdBytes, 0, ActorId.LENGTH);
return new ObjectId(objectIdBytes);
}
}

View file

@ -7,7 +7,7 @@ import java.util.concurrent.CountDownLatch;
import org.testng.Assert;
import org.testng.annotations.Test;
@Test(groups = {"cluster"})
@Test
public class ActorConcurrentCallTest extends BaseTest {
public static class ConcurrentActor {