From d6d27e9d34c2e880743b8cfef856005797d73394 Mon Sep 17 00:00:00 2001 From: Qing Wang Date: Thu, 10 Jun 2021 23:21:11 +0800 Subject: [PATCH] [Java] Enable concurrent calls in local mode. (#14896) * Enable concurrent calls in local mode. * Fix submitting actor tasks before actor creation task executed. Co-authored-by: Qing Wang --- .../runtime/task/LocalModeTaskSubmitter.java | 56 ++++++++++++++++--- .../io/ray/runtime/task/TaskExecutor.java | 7 +-- .../main/java/io/ray/runtime/util/IdUtil.java | 10 ++++ .../io/ray/test/ActorConcurrentCallTest.java | 2 +- 4 files changed, 62 insertions(+), 13 deletions(-) diff --git a/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java b/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java index 0d292e208..ef4eefa4d 100644 --- a/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java +++ b/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java @@ -32,6 +32,7 @@ import io.ray.runtime.generated.Common.TaskType; import io.ray.runtime.object.LocalModeObjectStore; import io.ray.runtime.object.NativeRayObject; import io.ray.runtime.placementgroup.PlacementGroupImpl; +import io.ray.runtime.util.IdUtil; import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Collections; @@ -65,6 +66,8 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { /// The thread pool to execute actor tasks. private final Map actorTaskExecutorServices; + private final Map actorMaxConcurrency = new ConcurrentHashMap<>(); + /// The thread pool to execute normal tasks. 
private final ExecutorService normalTaskExecutorService; @@ -114,13 +117,22 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { } } } - if (taskSpec.getType() == TaskType.ACTOR_TASK) { + if (taskSpec.getType() == TaskType.ACTOR_TASK && !isConcurrentActor(taskSpec)) { ObjectId dummyObjectId = new ObjectId( taskSpec.getActorTaskSpec().getPreviousActorTaskDummyObjectId().toByteArray()); if (!objectStore.isObjectReady(dummyObjectId)) { unreadyObjects.add(dummyObjectId); } + } else if (taskSpec.getType() == TaskType.ACTOR_TASK) { + // Code path of concurrent actors. + // For concurrent actors, we should make sure the actor created + // before we submit the following actor tasks. + ActorId actorId = ActorId.fromBytes(taskSpec.getActorTaskSpec().getActorId().toByteArray()); + ObjectId dummyActorCreationObjectId = IdUtil.getActorCreationDummyObjectId(actorId); + if (!objectStore.isObjectReady(dummyActorCreationObjectId)) { + unreadyObjects.add(dummyActorCreationObjectId); + } } return unreadyObjects; } @@ -198,6 +210,7 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { .setActorCreationTaskSpec( ActorCreationTaskSpec.newBuilder() .setActorId(ByteString.copyFrom(actorId.toByteBuffer())) + .setMaxConcurrency(options.maxConcurrency) .build()) .build(); submitTaskSpec(taskSpec); @@ -320,17 +333,33 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { () -> { try { executeTask(taskSpec); + if (taskSpec.getType() == TaskType.ACTOR_CREATION_TASK) { + // Construct a dummy object id for actor creation task so that the following + // actor task can touch if this actor is created. 
+ ObjectId dummy = + IdUtil.getActorCreationDummyObjectId( + ActorId.fromBytes( + taskSpec.getActorCreationTaskSpec().getActorId().toByteArray())); + objectStore.put(new Object(), dummy); + } } catch (Exception ex) { LOGGER.error("Unexpected exception when executing a task.", ex); System.exit(-1); } }; + if (taskSpec.getType() == TaskType.ACTOR_CREATION_TASK) { + actorMaxConcurrency.put( + getActorId(taskSpec), taskSpec.getActorCreationTaskSpec().getMaxConcurrency()); + } + if (unreadyObjects.isEmpty()) { // If all dependencies are ready, execute this task. ExecutorService executorService; if (taskSpec.getType() == TaskType.ACTOR_CREATION_TASK) { - executorService = Executors.newSingleThreadExecutor(); + final int maxConcurrency = taskSpec.getActorCreationTaskSpec().getMaxConcurrency(); + Preconditions.checkState(maxConcurrency >= 1); + executorService = Executors.newFixedThreadPool(maxConcurrency); synchronized (actorTaskExecutorServices) { actorTaskExecutorServices.put(getActorId(taskSpec), executorService); } @@ -362,11 +391,16 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { private void executeTask(TaskSpec taskSpec) { TaskExecutor.ActorContext actorContext = null; + UniqueId workerId; if (taskSpec.getType() == TaskType.ACTOR_TASK) { actorContext = actorContexts.get(getActorId(taskSpec)); Preconditions.checkNotNull(actorContext); + workerId = ((LocalModeTaskExecutor.LocalActorContext) actorContext).getWorkerId(); + } else { + // Actor creation task and normal task will use a new random worker id. 
+ workerId = UniqueId.randomId(); } - taskExecutor.setActorContext(actorContext); + taskExecutor.setActorContext(workerId, actorContext); List args = getFunctionArgs(taskSpec).stream() .map( @@ -377,10 +411,7 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { .collect(Collectors.toList()); runtime.setIsContextSet(true); ((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(taskSpec); - UniqueId workerId = - actorContext != null - ? ((LocalModeTaskExecutor.LocalActorContext) actorContext).getWorkerId() - : UniqueId.randomId(); + ((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentWorkerId(workerId); List rayFunctionInfo = getJavaFunctionDescriptor(taskSpec).toList(); taskExecutor.checkByteBufferArguments(rayFunctionInfo); @@ -388,7 +419,9 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { if (taskSpec.getType() == TaskType.ACTOR_CREATION_TASK) { // Update actor context map ASAP in case objectStore.putRaw triggered the next actor task // on this actor. - actorContexts.put(getActorId(taskSpec), taskExecutor.getActorContext()); + final TaskExecutor.ActorContext ac = taskExecutor.getActorContext(); + Preconditions.checkNotNull(ac); + actorContexts.put(getActorId(taskSpec), ac); } // Set this flag to true is necessary because at the end of `taskExecutor.execute()`, // this flag will be set to false. And `runtime.getWorkerContext()` requires it to be @@ -460,4 +493,11 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { } return returnIds; } + + /** Whether the actor that this task spec belongs to is a concurrent actor (max concurrency > 1). 
*/ + private boolean isConcurrentActor(TaskSpec taskSpec) { + final ActorId actorId = getActorId(taskSpec); + Preconditions.checkNotNull(actorId); + return actorMaxConcurrency.containsKey(actorId) && actorMaxConcurrency.get(actorId) > 1; + } } diff --git a/java/runtime/src/main/java/io/ray/runtime/task/TaskExecutor.java b/java/runtime/src/main/java/io/ray/runtime/task/TaskExecutor.java index 6d19d39ca..c400acbdf 100644 --- a/java/runtime/src/main/java/io/ray/runtime/task/TaskExecutor.java +++ b/java/runtime/src/main/java/io/ray/runtime/task/TaskExecutor.java @@ -34,7 +34,6 @@ public abstract class TaskExecutor { private final ThreadLocal localRayFunction = new ThreadLocal<>(); static class ActorContext { - /** The current actor object, if this worker is an actor, otherwise null. */ Object currentActor = null; } @@ -49,12 +48,12 @@ public abstract class TaskExecutor { return actorContextMap.get(runtime.getWorkerContext().getCurrentWorkerId()); } - void setActorContext(T actorContext) { + void setActorContext(UniqueId workerId, T actorContext) { if (actorContext == null) { // ConcurrentHashMap doesn't allow null values. So just return here. 
return; } - this.actorContextMap.put(runtime.getWorkerContext().getCurrentWorkerId(), actorContext); + this.actorContextMap.put(workerId, actorContext); } protected void removeActorContext(UniqueId workerId) { @@ -93,7 +92,7 @@ public abstract class TaskExecutor { T actorContext = null; if (taskType == TaskType.ACTOR_CREATION_TASK) { actorContext = createActorContext(); - setActorContext(actorContext); + setActorContext(runtime.getWorkerContext().getCurrentWorkerId(), actorContext); } else if (taskType == TaskType.ACTOR_TASK) { actorContext = getActorContext(); Preconditions.checkNotNull(actorContext); diff --git a/java/runtime/src/main/java/io/ray/runtime/util/IdUtil.java b/java/runtime/src/main/java/io/ray/runtime/util/IdUtil.java index 239568afa..26343e279 100644 --- a/java/runtime/src/main/java/io/ray/runtime/util/IdUtil.java +++ b/java/runtime/src/main/java/io/ray/runtime/util/IdUtil.java @@ -3,6 +3,7 @@ package io.ray.runtime.util; import io.ray.api.id.ActorId; import io.ray.api.id.ObjectId; import io.ray.api.id.TaskId; +import java.util.Arrays; /** * Helper method for different Ids. Note: any changes to these methods must be synced with C++ @@ -24,4 +25,13 @@ public class IdUtil { taskId.getBytes(), TaskId.UNIQUE_BYTES_LENGTH, actorIdBytes, 0, ActorId.LENGTH); return ActorId.fromBytes(actorIdBytes); } + + /** Compute the dummy object id for actor creation task. 
*/ + public static ObjectId getActorCreationDummyObjectId(ActorId actorId) { + byte[] objectIdBytes = new byte[ObjectId.LENGTH]; + Arrays.fill(objectIdBytes, (byte) 0xFF); + byte[] actorIdBytes = actorId.getBytes(); + System.arraycopy(actorIdBytes, 0, objectIdBytes, 0, ActorId.LENGTH); + return new ObjectId(objectIdBytes); + } } diff --git a/java/test/src/main/java/io/ray/test/ActorConcurrentCallTest.java b/java/test/src/main/java/io/ray/test/ActorConcurrentCallTest.java index de7cf6cfc..1fdb53294 100644 --- a/java/test/src/main/java/io/ray/test/ActorConcurrentCallTest.java +++ b/java/test/src/main/java/io/ray/test/ActorConcurrentCallTest.java @@ -7,7 +7,7 @@ import java.util.concurrent.CountDownLatch; import org.testng.Assert; import org.testng.annotations.Test; -@Test(groups = {"cluster"}) +@Test public class ActorConcurrentCallTest extends BaseTest { public static class ConcurrentActor {