[Java] Enable concurrent calls in local mode. (#14896)

* Enable concurrent calls in local mode.

* Fix submitting actor tasks before actor creation task executed.

Co-authored-by: Qing Wang <jovany.wq@antgroup.com>
This commit is contained in:
Qing Wang 2021-06-10 23:21:11 +08:00 committed by GitHub
parent 9741bc00c9
commit d6d27e9d34
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 62 additions and 13 deletions

View file

@ -32,6 +32,7 @@ import io.ray.runtime.generated.Common.TaskType;
import io.ray.runtime.object.LocalModeObjectStore; import io.ray.runtime.object.LocalModeObjectStore;
import io.ray.runtime.object.NativeRayObject; import io.ray.runtime.object.NativeRayObject;
import io.ray.runtime.placementgroup.PlacementGroupImpl; import io.ray.runtime.placementgroup.PlacementGroupImpl;
import io.ray.runtime.util.IdUtil;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
@ -65,6 +66,8 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
/// The thread pool to execute actor tasks. /// The thread pool to execute actor tasks.
private final Map<ActorId, ExecutorService> actorTaskExecutorServices; private final Map<ActorId, ExecutorService> actorTaskExecutorServices;
private final Map<ActorId, Integer> actorMaxConcurrency = new ConcurrentHashMap<>();
/// The thread pool to execute normal tasks. /// The thread pool to execute normal tasks.
private final ExecutorService normalTaskExecutorService; private final ExecutorService normalTaskExecutorService;
@ -114,13 +117,22 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
} }
} }
} }
if (taskSpec.getType() == TaskType.ACTOR_TASK) { if (taskSpec.getType() == TaskType.ACTOR_TASK && !isConcurrentActor(taskSpec)) {
ObjectId dummyObjectId = ObjectId dummyObjectId =
new ObjectId( new ObjectId(
taskSpec.getActorTaskSpec().getPreviousActorTaskDummyObjectId().toByteArray()); taskSpec.getActorTaskSpec().getPreviousActorTaskDummyObjectId().toByteArray());
if (!objectStore.isObjectReady(dummyObjectId)) { if (!objectStore.isObjectReady(dummyObjectId)) {
unreadyObjects.add(dummyObjectId); unreadyObjects.add(dummyObjectId);
} }
} else if (taskSpec.getType() == TaskType.ACTOR_TASK) {
// Code path of concurrent actors.
// For concurrent actors, we should make sure the actor created
// before we submit the following actor tasks.
ActorId actorId = ActorId.fromBytes(taskSpec.getActorTaskSpec().getActorId().toByteArray());
ObjectId dummyActorCreationObjectId = IdUtil.getActorCreationDummyObjectId(actorId);
if (!objectStore.isObjectReady(dummyActorCreationObjectId)) {
unreadyObjects.add(dummyActorCreationObjectId);
}
} }
return unreadyObjects; return unreadyObjects;
} }
@ -198,6 +210,7 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
.setActorCreationTaskSpec( .setActorCreationTaskSpec(
ActorCreationTaskSpec.newBuilder() ActorCreationTaskSpec.newBuilder()
.setActorId(ByteString.copyFrom(actorId.toByteBuffer())) .setActorId(ByteString.copyFrom(actorId.toByteBuffer()))
.setMaxConcurrency(options.maxConcurrency)
.build()) .build())
.build(); .build();
submitTaskSpec(taskSpec); submitTaskSpec(taskSpec);
@ -320,17 +333,33 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
() -> { () -> {
try { try {
executeTask(taskSpec); executeTask(taskSpec);
if (taskSpec.getType() == TaskType.ACTOR_CREATION_TASK) {
// Construct a dummy object id for actor creation task so that the following
// actor task can touch if this actor is created.
ObjectId dummy =
IdUtil.getActorCreationDummyObjectId(
ActorId.fromBytes(
taskSpec.getActorCreationTaskSpec().getActorId().toByteArray()));
objectStore.put(new Object(), dummy);
}
} catch (Exception ex) { } catch (Exception ex) {
LOGGER.error("Unexpected exception when executing a task.", ex); LOGGER.error("Unexpected exception when executing a task.", ex);
System.exit(-1); System.exit(-1);
} }
}; };
if (taskSpec.getType() == TaskType.ACTOR_CREATION_TASK) {
actorMaxConcurrency.put(
getActorId(taskSpec), taskSpec.getActorCreationTaskSpec().getMaxConcurrency());
}
if (unreadyObjects.isEmpty()) { if (unreadyObjects.isEmpty()) {
// If all dependencies are ready, execute this task. // If all dependencies are ready, execute this task.
ExecutorService executorService; ExecutorService executorService;
if (taskSpec.getType() == TaskType.ACTOR_CREATION_TASK) { if (taskSpec.getType() == TaskType.ACTOR_CREATION_TASK) {
executorService = Executors.newSingleThreadExecutor(); final int maxConcurrency = taskSpec.getActorCreationTaskSpec().getMaxConcurrency();
Preconditions.checkState(maxConcurrency >= 1);
executorService = Executors.newFixedThreadPool(maxConcurrency);
synchronized (actorTaskExecutorServices) { synchronized (actorTaskExecutorServices) {
actorTaskExecutorServices.put(getActorId(taskSpec), executorService); actorTaskExecutorServices.put(getActorId(taskSpec), executorService);
} }
@ -362,11 +391,16 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
private void executeTask(TaskSpec taskSpec) { private void executeTask(TaskSpec taskSpec) {
TaskExecutor.ActorContext actorContext = null; TaskExecutor.ActorContext actorContext = null;
UniqueId workerId;
if (taskSpec.getType() == TaskType.ACTOR_TASK) { if (taskSpec.getType() == TaskType.ACTOR_TASK) {
actorContext = actorContexts.get(getActorId(taskSpec)); actorContext = actorContexts.get(getActorId(taskSpec));
Preconditions.checkNotNull(actorContext); Preconditions.checkNotNull(actorContext);
workerId = ((LocalModeTaskExecutor.LocalActorContext) actorContext).getWorkerId();
} else {
// Actor creation task and normal task will use a new random worker id.
workerId = UniqueId.randomId();
} }
taskExecutor.setActorContext(actorContext); taskExecutor.setActorContext(workerId, actorContext);
List<NativeRayObject> args = List<NativeRayObject> args =
getFunctionArgs(taskSpec).stream() getFunctionArgs(taskSpec).stream()
.map( .map(
@ -377,10 +411,7 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
.collect(Collectors.toList()); .collect(Collectors.toList());
runtime.setIsContextSet(true); runtime.setIsContextSet(true);
((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(taskSpec); ((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(taskSpec);
UniqueId workerId =
actorContext != null
? ((LocalModeTaskExecutor.LocalActorContext) actorContext).getWorkerId()
: UniqueId.randomId();
((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentWorkerId(workerId); ((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentWorkerId(workerId);
List<String> rayFunctionInfo = getJavaFunctionDescriptor(taskSpec).toList(); List<String> rayFunctionInfo = getJavaFunctionDescriptor(taskSpec).toList();
taskExecutor.checkByteBufferArguments(rayFunctionInfo); taskExecutor.checkByteBufferArguments(rayFunctionInfo);
@ -388,7 +419,9 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
if (taskSpec.getType() == TaskType.ACTOR_CREATION_TASK) { if (taskSpec.getType() == TaskType.ACTOR_CREATION_TASK) {
// Update actor context map ASAP in case objectStore.putRaw triggered the next actor task // Update actor context map ASAP in case objectStore.putRaw triggered the next actor task
// on this actor. // on this actor.
actorContexts.put(getActorId(taskSpec), taskExecutor.getActorContext()); final TaskExecutor.ActorContext ac = taskExecutor.getActorContext();
Preconditions.checkNotNull(ac);
actorContexts.put(getActorId(taskSpec), ac);
} }
// Set this flag to true is necessary because at the end of `taskExecutor.execute()`, // Set this flag to true is necessary because at the end of `taskExecutor.execute()`,
// this flag will be set to false. And `runtime.getWorkerContext()` requires it to be // this flag will be set to false. And `runtime.getWorkerContext()` requires it to be
@ -460,4 +493,11 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
} }
return returnIds; return returnIds;
} }
/** Whether this is an actor creation task spec of a concurrent actor. */
private boolean isConcurrentActor(TaskSpec taskSpec) {
final ActorId actorId = getActorId(taskSpec);
Preconditions.checkNotNull(actorId);
return actorMaxConcurrency.containsKey(actorId) && actorMaxConcurrency.get(actorId) > 1;
}
} }

View file

@ -34,7 +34,6 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
private final ThreadLocal<RayFunction> localRayFunction = new ThreadLocal<>(); private final ThreadLocal<RayFunction> localRayFunction = new ThreadLocal<>();
static class ActorContext { static class ActorContext {
/** The current actor object, if this worker is an actor, otherwise null. */ /** The current actor object, if this worker is an actor, otherwise null. */
Object currentActor = null; Object currentActor = null;
} }
@ -49,12 +48,12 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
return actorContextMap.get(runtime.getWorkerContext().getCurrentWorkerId()); return actorContextMap.get(runtime.getWorkerContext().getCurrentWorkerId());
} }
void setActorContext(T actorContext) { void setActorContext(UniqueId workerId, T actorContext) {
if (actorContext == null) { if (actorContext == null) {
// ConcurrentHashMap doesn't allow null values. So just return here. // ConcurrentHashMap doesn't allow null values. So just return here.
return; return;
} }
this.actorContextMap.put(runtime.getWorkerContext().getCurrentWorkerId(), actorContext); this.actorContextMap.put(workerId, actorContext);
} }
protected void removeActorContext(UniqueId workerId) { protected void removeActorContext(UniqueId workerId) {
@ -93,7 +92,7 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
T actorContext = null; T actorContext = null;
if (taskType == TaskType.ACTOR_CREATION_TASK) { if (taskType == TaskType.ACTOR_CREATION_TASK) {
actorContext = createActorContext(); actorContext = createActorContext();
setActorContext(actorContext); setActorContext(runtime.getWorkerContext().getCurrentWorkerId(), actorContext);
} else if (taskType == TaskType.ACTOR_TASK) { } else if (taskType == TaskType.ACTOR_TASK) {
actorContext = getActorContext(); actorContext = getActorContext();
Preconditions.checkNotNull(actorContext); Preconditions.checkNotNull(actorContext);

View file

@ -3,6 +3,7 @@ package io.ray.runtime.util;
import io.ray.api.id.ActorId; import io.ray.api.id.ActorId;
import io.ray.api.id.ObjectId; import io.ray.api.id.ObjectId;
import io.ray.api.id.TaskId; import io.ray.api.id.TaskId;
import java.util.Arrays;
/** /**
* Helper method for different Ids. Note: any changes to these methods must be synced with C++ * Helper method for different Ids. Note: any changes to these methods must be synced with C++
@ -24,4 +25,13 @@ public class IdUtil {
taskId.getBytes(), TaskId.UNIQUE_BYTES_LENGTH, actorIdBytes, 0, ActorId.LENGTH); taskId.getBytes(), TaskId.UNIQUE_BYTES_LENGTH, actorIdBytes, 0, ActorId.LENGTH);
return ActorId.fromBytes(actorIdBytes); return ActorId.fromBytes(actorIdBytes);
} }
/** Compute the dummy object id for actor creation task. */
public static ObjectId getActorCreationDummyObjectId(ActorId actorId) {
byte[] objectIdBytes = new byte[ObjectId.LENGTH];
Arrays.fill(objectIdBytes, (byte) 0xFF);
byte[] actorIdBytes = actorId.getBytes();
System.arraycopy(actorIdBytes, 0, objectIdBytes, 0, ActorId.LENGTH);
return new ObjectId(objectIdBytes);
}
} }

View file

@ -7,7 +7,7 @@ import java.util.concurrent.CountDownLatch;
import org.testng.Assert; import org.testng.Assert;
import org.testng.annotations.Test; import org.testng.annotations.Test;
@Test(groups = {"cluster"}) @Test
public class ActorConcurrentCallTest extends BaseTest { public class ActorConcurrentCallTest extends BaseTest {
public static class ConcurrentActor { public static class ConcurrentActor {