From 0ad0839265750b9ee4d142b623e92a38fa72631d Mon Sep 17 00:00:00 2001 From: wanxing Date: Wed, 14 Apr 2021 11:02:49 +0800 Subject: [PATCH] Optimize lambda copy to improve direct call performance. (#15036) --- java/BUILD.bazel | 16 +++ java/performance_test/pom.xml | 56 +++++++++ java/performance_test/pom_template.xml | 32 +++++ .../java/io/ray/performancetest/Receiver.java | 40 ++++++ .../java/io/ray/performancetest/Source.java | 117 ++++++++++++++++++ .../test/ActorPerformanceTestBase.java | 76 ++++++++++++ .../test/ActorPerformanceTestCase1.java | 27 ++++ java/test.sh | 25 ++++ src/ray/common/grpc_util.h | 8 +- src/ray/common/task/task_spec.cc | 21 ++-- src/ray/common/task/task_spec.h | 14 ++- src/ray/core_worker/core_worker.cc | 2 +- .../core_worker/test/scheduling_queue_test.cc | 82 ++++++------ .../transport/direct_actor_transport.cc | 39 +++--- .../transport/direct_actor_transport.h | 41 +++--- src/ray/rpc/client_call.h | 3 +- src/ray/rpc/worker/core_worker_client.h | 12 +- 17 files changed, 513 insertions(+), 98 deletions(-) create mode 100755 java/performance_test/pom.xml create mode 100644 java/performance_test/pom_template.xml create mode 100644 java/performance_test/src/main/java/io/ray/performancetest/Receiver.java create mode 100644 java/performance_test/src/main/java/io/ray/performancetest/Source.java create mode 100644 java/performance_test/src/main/java/io/ray/performancetest/test/ActorPerformanceTestBase.java create mode 100644 java/performance_test/src/main/java/io/ray/performancetest/test/ActorPerformanceTestCase1.java diff --git a/java/BUILD.bazel b/java/BUILD.bazel index 119251c39..75de8b843 100644 --- a/java/BUILD.bazel +++ b/java/BUILD.bazel @@ -13,6 +13,7 @@ all_modules = [ "api", "runtime", "test", + "performance_test", ] java_import( @@ -106,12 +107,26 @@ define_java_module( ], ) +define_java_module( + name = "performance_test", + deps = [ + ":io_ray_ray_api", + ":io_ray_ray_runtime", + "@maven//:com_google_code_gson_gson", + "@maven//:com_google_guava_guava", + "@maven//:commons_io_commons_io", + "@maven//:org_apache_commons_commons_lang3", + "@maven//:org_slf4j_slf4j_api", + ], +) + java_binary( name = "all_tests", args = ["java/testng.xml"], data = ["testng.xml"], main_class = "org.testng.TestNG", runtime_deps = [ + ":io_ray_ray_performance_test", ":io_ray_ray_runtime_test", ":io_ray_ray_test", ], @@ -207,6 +222,7 @@ genrule( cp -f $(location //java:io_ray_ray_api_pom) "$$WORK_DIR/java/api/pom.xml" cp -f $(location //java:io_ray_ray_runtime_pom) "$$WORK_DIR/java/runtime/pom.xml" cp -f $(location //java:io_ray_ray_test_pom) "$$WORK_DIR/java/test/pom.xml" + cp -f $(location //java:io_ray_ray_performance_test_pom) "$$WORK_DIR/java/performance_test/pom.xml" date > $@ """, local = 1, diff --git a/java/performance_test/pom.xml b/java/performance_test/pom.xml new file mode 100755 index 000000000..f58dd123e --- /dev/null +++ b/java/performance_test/pom.xml @@ -0,0 +1,56 @@ + + + + + io.ray + ray-superpom + 1.0.0-SNAPSHOT + + 4.0.0 + + ray-performance-test + java performance test cases for ray + java performance test cases for ray + + jar + + + + io.ray + ray-api + ${project.version} + + + io.ray + ray-runtime + ${project.version} + + + com.google.code.gson + gson + 2.8.5 + + + com.google.guava + guava + 27.0.1-jre + + + commons-io + commons-io + 2.5 + + + org.apache.commons + commons-lang3 + 3.4 + + + org.slf4j + slf4j-api + 1.7.25 + + + diff --git a/java/performance_test/pom_template.xml b/java/performance_test/pom_template.xml new file mode 100644 index 
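# The BUILD.bazel and pom.xml additions above register the new performance_test module and
# pull it into the existing all_tests deploy jar. A minimal sketch of building and launching
# the new test case by hand, assuming the standard Bazel deploy-jar layout (the exact paths
# below are an assumption, mirrored from the java/test.sh change later in this patch):
bazel build //java:gen_maven_deps
bazel build //java:all_tests_deploy.jar
java -cp bazel-bin/java/all_tests_deploy.jar \
    io.ray.performancetest.test.ActorPerformanceTestCase1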
000000000..f7109331e --- /dev/null +++ b/java/performance_test/pom_template.xml @@ -0,0 +1,32 @@ + +{auto_gen_header} + + + io.ray + ray-superpom + 1.0.0-SNAPSHOT + + 4.0.0 + + ray-performance-test + java performance test cases for ray + java performance test cases for ray + + jar + + + + io.ray + ray-api + ${project.version} + + + io.ray + ray-runtime + ${project.version} + + {generated_bzl_deps} + + diff --git a/java/performance_test/src/main/java/io/ray/performancetest/Receiver.java b/java/performance_test/src/main/java/io/ray/performancetest/Receiver.java new file mode 100644 index 000000000..8151ac631 --- /dev/null +++ b/java/performance_test/src/main/java/io/ray/performancetest/Receiver.java @@ -0,0 +1,40 @@ +package io.ray.performancetest; + +import java.nio.ByteBuffer; + +public class Receiver { + private int value = 0; + + public Receiver() {} + + public boolean ping() { + return true; + } + + public void noArgsNoReturn() { + value += 1; + } + + public int noArgsHasReturn() { + value += 1; + return value; + } + + public void bytesNoReturn(byte[] data) { + value += 1; + } + + public int bytesHasReturn(byte[] data) { + value += 1; + return value; + } + + public void byteBufferNoReturn(ByteBuffer data) { + value += 1; + } + + public int byteBufferHasReturn(ByteBuffer data) { + value += 1; + return value; + } +} diff --git a/java/performance_test/src/main/java/io/ray/performancetest/Source.java b/java/performance_test/src/main/java/io/ray/performancetest/Source.java new file mode 100644 index 000000000..bbd1e5be8 --- /dev/null +++ b/java/performance_test/src/main/java/io/ray/performancetest/Source.java @@ -0,0 +1,117 @@ +package io.ray.performancetest; + +import com.google.common.base.Preconditions; +import io.ray.api.ActorHandle; +import io.ray.api.ObjectRef; +import io.ray.api.Ray; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class Source { + private static final int BATCH_SIZE; + + private static final Logger LOGGER = LoggerFactory.getLogger(Source.class); + + private final List> receivers; + + static { + String batchSizeString = System.getenv().get("PERF_TEST_BATCH_SIZE"); + if (batchSizeString != null) { + BATCH_SIZE = Integer.valueOf(batchSizeString); + } else { + BATCH_SIZE = 1000; + } + } + + public Source(List> receivers) { + this.receivers = receivers; + } + + public boolean startTest( + boolean hasReturn, boolean ignoreReturn, int argSize, boolean useDirectByteBuffer) { + LOGGER.info("Source startTest"); + byte[] bytes = null; + ByteBuffer buffer = null; + if (argSize > 0) { + bytes = new byte[argSize]; + new Random().nextBytes(bytes); + buffer = ByteBuffer.wrap(bytes); + } else { + Preconditions.checkState(!useDirectByteBuffer); + } + + // Wait for actors to be created. 
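      // The ping loop that follows is a warm-up: a blocking get on a trivial actor call
      // guarantees every Receiver actor has finished construction before the timed loop
      // starts, so actor-creation time is excluded from the measurement. A minimal
      // standalone sketch of the same pattern (a hypothetical single-receiver setup, not
      // part of this patch):
      //
      //   ActorHandle<Receiver> receiver = Ray.actor(Receiver::new).remote();
      //   boolean alive = receiver.task(Receiver::ping).remote().get();  // blocks until the actor is up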
+ for (ActorHandle receiver : receivers) { + receiver.task(Receiver::ping).remote().get(); + } + + LOGGER.info( + "Started executing tasks, useDirectByteBuffer: {}, argSize: {}, has return: {}", + useDirectByteBuffer, + argSize, + hasReturn); + + List>> returnObjects = new ArrayList<>(); + returnObjects.add(new ArrayList<>()); + returnObjects.add(new ArrayList<>()); + + long startTime = System.currentTimeMillis(); + int numTasks = 0; + long lastReport = 0; + long totalTime = 0; + long batchCount = 0; + while (true) { + numTasks++; + boolean batchEnd = numTasks % BATCH_SIZE == 0; + for (ActorHandle receiver : receivers) { + if (hasReturn || batchEnd) { + ObjectRef returnObject; + if (useDirectByteBuffer) { + returnObject = receiver.task(Receiver::byteBufferHasReturn, buffer).remote(); + } else if (argSize > 0) { + returnObject = receiver.task(Receiver::bytesHasReturn, bytes).remote(); + } else { + returnObject = receiver.task(Receiver::noArgsHasReturn).remote(); + } + returnObjects.get(1).add(returnObject); + } else { + if (useDirectByteBuffer) { + receiver.task(Receiver::byteBufferNoReturn, buffer).remote(); + } else if (argSize > 0) { + receiver.task(Receiver::bytesNoReturn, bytes).remote(); + } else { + receiver.task(Receiver::noArgsNoReturn).remote(); + } + } + } + + if (batchEnd) { + batchCount++; + long getBeginTs = System.currentTimeMillis(); + Ray.get(returnObjects.get(0)); + long rt = System.currentTimeMillis() - getBeginTs; + totalTime += rt; + returnObjects.set(0, returnObjects.get(1)); + returnObjects.set(1, new ArrayList<>()); + + long elapsedTime = System.currentTimeMillis() - startTime; + if (elapsedTime / 60000 > lastReport) { + lastReport = elapsedTime / 60000; + LOGGER.info( + "Finished executing {} tasks in {} ms, useDirectByteBuffer: {}, argSize: {}, " + + "has return: {}, avg get rt: {}", + numTasks, + elapsedTime, + useDirectByteBuffer, + argSize, + hasReturn, + totalTime / (float) batchCount); + } + } + } + } +} diff --git a/java/performance_test/src/main/java/io/ray/performancetest/test/ActorPerformanceTestBase.java b/java/performance_test/src/main/java/io/ray/performancetest/test/ActorPerformanceTestBase.java new file mode 100644 index 000000000..84424a586 --- /dev/null +++ b/java/performance_test/src/main/java/io/ray/performancetest/test/ActorPerformanceTestBase.java @@ -0,0 +1,76 @@ +package io.ray.performancetest.test; + +import com.google.common.base.Preconditions; +import io.ray.api.ActorHandle; +import io.ray.api.ObjectRef; +import io.ray.api.Ray; +import io.ray.performancetest.Receiver; +import io.ray.performancetest.Source; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class ActorPerformanceTestBase { + + private static final Logger LOGGER = LoggerFactory.getLogger(ActorPerformanceTestBase.class); + + public static void run( + String[] args, + int[] layers, + int[] actorsPerLayer, + boolean hasReturn, + boolean ignoreReturn, + int argSize, + boolean useDirectByteBuffer, + int numJavaWorkerPerProcess) { + System.setProperty( + "ray.job.num-java-workers-per-process", String.valueOf(numJavaWorkerPerProcess)); + Ray.init(); + try { + // TODO: Support more layers. 
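      // layers[i] appears to be the number of node groups in stage i and actorsPerLayer[i]
      // the number of actors per group, so stage i contains layers[i] * actorsPerLayer[i]
      // actors in total. An illustrative configuration (values are hypothetical; only the
      // 1/1 case is exercised by ActorPerformanceTestCase1 below):
      //
      //   int[] layers = new int[] {1, 2};
      //   int[] actorsPerLayer = new int[] {1, 4};
      //   // -> 1 * 1 = 1 Source actor, 2 * 4 = 8 Receiver actors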
+ Preconditions.checkState(layers.length == 2); + Preconditions.checkState(actorsPerLayer.length == layers.length); + for (int i = 0; i < layers.length; i++) { + Preconditions.checkState(layers[i] > 0); + Preconditions.checkState(actorsPerLayer[i] > 0); + } + + List> receivers = new ArrayList<>(); + for (int i = 0; i < layers[1]; i++) { + int nodeIndex = layers[0] + i; + for (int j = 0; j < actorsPerLayer[1]; j++) { + receivers.add(Ray.actor(Receiver::new).remote()); + } + } + + List> sources = new ArrayList<>(); + for (int i = 0; i < layers[0]; i++) { + int nodeIndex = i; + for (int j = 0; j < actorsPerLayer[0]; j++) { + sources.add(Ray.actor(Source::new, receivers).remote()); + } + } + + List> results = + sources.stream() + .map( + source -> + source + .task( + Source::startTest, + hasReturn, + ignoreReturn, + argSize, + useDirectByteBuffer) + .remote()) + .collect(Collectors.toList()); + + Ray.get(results); + } catch (Throwable e) { + LOGGER.error("Run test failed.", e); + throw e; + } + } +} diff --git a/java/performance_test/src/main/java/io/ray/performancetest/test/ActorPerformanceTestCase1.java b/java/performance_test/src/main/java/io/ray/performancetest/test/ActorPerformanceTestCase1.java new file mode 100644 index 000000000..c095d39ba --- /dev/null +++ b/java/performance_test/src/main/java/io/ray/performancetest/test/ActorPerformanceTestCase1.java @@ -0,0 +1,27 @@ +package io.ray.performancetest.test; + +/** + * 1-to-1 ray call, one receiving actor and one sending actor, test the throughput of the engine + * itself. + */ +public class ActorPerformanceTestCase1 { + + public static void main(String[] args) { + final int[] layers = new int[] {1, 1}; + final int[] actorsPerLayer = new int[] {1, 1}; + final boolean hasReturn = false; + final int argSize = 0; + final boolean useDirectByteBuffer = false; + final boolean ignoreReturn = false; + final int numJavaWorkerPerProcess = 1; + ActorPerformanceTestBase.run( + args, + layers, + actorsPerLayer, + hasReturn, + ignoreReturn, + argSize, + useDirectByteBuffer, + numJavaWorkerPerProcess); + } +} diff --git a/java/test.sh b/java/test.sh index a3f098f8e..4ae30dced 100755 --- a/java/test.sh +++ b/java/test.sh @@ -41,6 +41,24 @@ run_testng() { fi } +run_timeout() { + local pid + timeout=$1 + shift 1 + "$@" & + pid=$! + sleep "$timeout" + if ps -p $pid > /dev/null + then + echo "run_timeout process exists, kill it." + kill -9 $pid + else + echo "run_timeout process not exist." + cat /tmp/ray/session_latest/logs/java-core-driver-*$pid* + exit 1 + fi +} + pushd "$ROOT_DIR"/.. echo "Build java maven deps." bazel build //java:gen_maven_deps @@ -114,3 +132,10 @@ mvn -Dorg.slf4j.simpleLogger.defaultLogLevel=WARN clean install -DskipTests -Dch # Ensure mvn test works mvn test -pl test -Dtest="io.ray.test.HelloWorldTest" popd + +pushd "$ROOT_DIR" +echo "Running performance test." +run_timeout 60 java -cp "$ROOT_DIR"/../bazel-bin/java/all_tests_deploy.jar io.ray.performancetest.test.ActorPerformanceTestCase1 +# The performance process may be killed by run_timeout, so clear ray here. +ray stop +popd diff --git a/src/ray/common/grpc_util.h b/src/ray/common/grpc_util.h index e6d2e1c90..b3629bd61 100644 --- a/src/ray/common/grpc_util.h +++ b/src/ray/common/grpc_util.h @@ -35,7 +35,10 @@ class MessageWrapper { /// The input message will be **copied** into this object. /// /// \param message The protobuf message. 
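  // The pair of constructors below (copy from a const lvalue, move from an rvalue) is the
  // core of this change: callers that no longer need their message can hand it over
  // without a deep copy. A minimal standalone sketch of the same pattern, assuming a
  // copyable/movable protobuf-style `Message` type (illustrative, not the real class):
  //
  //   class Wrapper {
  //    public:
  //     explicit Wrapper(const Message &m) : msg_(std::make_shared<Message>(m)) {}         // copies
  //     explicit Wrapper(Message &&m) : msg_(std::make_shared<Message>(std::move(m))) {}    // moves
  //    private:
  //     std::shared_ptr<Message> msg_;
  //   };
  //
  //   // Wrapper w(std::move(big_message));   // selects the && overload, no payload copy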
- explicit MessageWrapper(const Message message) + explicit MessageWrapper(const Message &message) + : message_(std::make_shared(message)) {} + + explicit MessageWrapper(Message &&message) : message_(std::make_shared(std::move(message))) {} /// Construct from a protobuf message shared_ptr. @@ -115,7 +118,8 @@ inline std::vector IdVectorFromProtobuf( /// Converts a Protobuf map to a `unordered_map`. template -inline std::unordered_map MapFromProtobuf(::google::protobuf::Map pb_map) { +inline std::unordered_map MapFromProtobuf( + const ::google::protobuf::Map &pb_map) { return std::unordered_map(pb_map.begin(), pb_map.end()); } diff --git a/src/ray/common/task/task_spec.cc b/src/ray/common/task/task_spec.cc index a07f5bade..5a0bf5946 100644 --- a/src/ray/common/task/task_spec.cc +++ b/src/ray/common/task/task_spec.cc @@ -53,24 +53,24 @@ bool TaskSpecification::PlacementGroupCaptureChildTasks() const { } void TaskSpecification::ComputeResources() { - auto required_resources = MapFromProtobuf(message_->required_resources()); - auto required_placement_resources = - MapFromProtobuf(message_->required_placement_resources()); - if (required_placement_resources.empty()) { - required_placement_resources = required_resources; - } + auto &required_resources = message_->required_resources(); if (required_resources.empty()) { // A static nil object is used here to avoid allocating the empty object every time. required_resources_ = ResourceSet::Nil(); } else { - required_resources_.reset(new ResourceSet(required_resources)); + required_resources_.reset(new ResourceSet(MapFromProtobuf(required_resources))); } + auto &required_placement_resources = message_->required_placement_resources().empty() + ? required_resources + : message_->required_placement_resources(); + if (required_placement_resources.empty()) { required_placement_resources_ = ResourceSet::Nil(); } else { - required_placement_resources_.reset(new ResourceSet(required_placement_resources)); + required_placement_resources_.reset( + new ResourceSet(MapFromProtobuf(required_placement_resources))); } if (!IsActorTask()) { @@ -174,14 +174,15 @@ std::vector TaskSpecification::GetDependencyIds() const { return dependencies; } -std::vector TaskSpecification::GetDependencies() const { +std::vector TaskSpecification::GetDependencies( + bool add_dummy_dependency) const { std::vector dependencies; for (size_t i = 0; i < NumArgs(); ++i) { if (ArgByRef(i)) { dependencies.push_back(message_->args(i).object_ref()); } } - if (IsActorTask()) { + if (add_dummy_dependency && IsActorTask()) { const auto &dummy_ref = GetReferenceForActorDummyObject(PreviousActorTaskDummyObjectId()); dependencies.push_back(dummy_ref); diff --git a/src/ray/common/task/task_spec.h b/src/ray/common/task/task_spec.h index 4868e7100..3f2d039fc 100644 --- a/src/ray/common/task/task_spec.h +++ b/src/ray/common/task/task_spec.h @@ -36,10 +36,15 @@ class TaskSpecification : public MessageWrapper { TaskSpecification() {} /// Construct from a protobuf message object. - /// The input message will be **copied** into this object. + /// The input message will be copied/moved into this object. /// /// \param message The protobuf message. 
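  // With the two constructors below, the cost of building a TaskSpecification depends on
  // what the caller passes. A brief illustration of overload selection (hypothetical call
  // sites, BuildSpec is a stand-in):
  //
  //   rpc::TaskSpec proto = BuildSpec();
  //   TaskSpecification copied(proto);             // const& overload: proto is copied
  //   TaskSpecification moved(std::move(proto));   // && overload: proto's storage is stolen
  //   // `proto` must not be read after the move.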
- explicit TaskSpecification(rpc::TaskSpec message) : MessageWrapper(message) { + explicit TaskSpecification(rpc::TaskSpec &&message) + : MessageWrapper(std::move(message)) { + ComputeResources(); + } + + explicit TaskSpecification(const rpc::TaskSpec &message) : MessageWrapper(message) { ComputeResources(); } @@ -127,9 +132,10 @@ class TaskSpecification : public MessageWrapper { /// Return the dependencies of this task. This is recomputed each time, so it can /// be used if the task spec is mutated. - /// + /// \param add_dummy_dependency whether to add a dummy object in the returned objects. /// \return The recomputed dependencies for the task. - std::vector GetDependencies() const; + std::vector GetDependencies( + bool add_dummy_dependency = true) const; std::string GetDebuggerBreakpoint() const; diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index b6a6fdbb1..6de12f144 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -2219,7 +2219,7 @@ void CoreWorker::HandlePushTask(const rpc::PushTaskRequest &request, // execution service. if (request.task_spec().type() == TaskType::ACTOR_TASK) { task_execution_service_.post( - [=] { + [this, request, reply, send_reply_callback = std::move(send_reply_callback)] { // We have posted an exit task onto the main event loop, // so shouldn't bother executing any further work. if (exiting_) return; diff --git a/src/ray/core_worker/test/scheduling_queue_test.cc b/src/ray/core_worker/test/scheduling_queue_test.cc index 9f3066965..ab417a7ed 100644 --- a/src/ray/core_worker/test/scheduling_queue_test.cc +++ b/src/ray/core_worker/test/scheduling_queue_test.cc @@ -42,12 +42,12 @@ TEST(SchedulingQueueTest, TestInOrder) { ActorSchedulingQueue queue(io_service, waiter); int n_ok = 0; int n_rej = 0; - auto fn_ok = [&n_ok]() { n_ok++; }; - auto fn_rej = [&n_rej]() { n_rej++; }; - queue.Add(0, -1, fn_ok, fn_rej); - queue.Add(1, -1, fn_ok, fn_rej); - queue.Add(2, -1, fn_ok, fn_rej); - queue.Add(3, -1, fn_ok, fn_rej); + auto fn_ok = [&n_ok](rpc::SendReplyCallback callback) { n_ok++; }; + auto fn_rej = [&n_rej](rpc::SendReplyCallback callback) { n_rej++; }; + queue.Add(0, -1, fn_ok, fn_rej, nullptr); + queue.Add(1, -1, fn_ok, fn_rej, nullptr); + queue.Add(2, -1, fn_ok, fn_rej, nullptr); + queue.Add(3, -1, fn_ok, fn_rej, nullptr); io_service.run(); ASSERT_EQ(n_ok, 4); ASSERT_EQ(n_rej, 0); @@ -62,12 +62,12 @@ TEST(SchedulingQueueTest, TestWaitForObjects) { ActorSchedulingQueue queue(io_service, waiter); int n_ok = 0; int n_rej = 0; - auto fn_ok = [&n_ok]() { n_ok++; }; - auto fn_rej = [&n_rej]() { n_rej++; }; - queue.Add(0, -1, fn_ok, fn_rej); - queue.Add(1, -1, fn_ok, fn_rej, TaskID::Nil(), ObjectIdsToRefs({obj1})); - queue.Add(2, -1, fn_ok, fn_rej, TaskID::Nil(), ObjectIdsToRefs({obj2})); - queue.Add(3, -1, fn_ok, fn_rej, TaskID::Nil(), ObjectIdsToRefs({obj3})); + auto fn_ok = [&n_ok](rpc::SendReplyCallback callback) { n_ok++; }; + auto fn_rej = [&n_rej](rpc::SendReplyCallback callback) { n_rej++; }; + queue.Add(0, -1, fn_ok, fn_rej, nullptr); + queue.Add(1, -1, fn_ok, fn_rej, nullptr, TaskID::Nil(), ObjectIdsToRefs({obj1})); + queue.Add(2, -1, fn_ok, fn_rej, nullptr, TaskID::Nil(), ObjectIdsToRefs({obj2})); + queue.Add(3, -1, fn_ok, fn_rej, nullptr, TaskID::Nil(), ObjectIdsToRefs({obj3})); ASSERT_EQ(n_ok, 1); waiter.Complete(0); @@ -87,10 +87,10 @@ TEST(SchedulingQueueTest, TestWaitForObjectsNotSubjectToSeqTimeout) { ActorSchedulingQueue queue(io_service, waiter); int n_ok = 0; int n_rej = 0; 
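  // The HandlePushTask change above replaces a blanket [=] capture (which copies every
  // captured object, including the reply callback) with an explicit capture list that
  // moves the callback into the closure; the updated fn_ok/fn_rej signatures below rely
  // on the same move-it-forward idea. A minimal standalone sketch of the C++14
  // init-capture pattern (illustrative names):
  //
  //   using ReplyCallback = std::function<void(bool ok)>;
  //
  //   void Post(std::function<void()> fn);   // stand-in for posting to an event loop
  //
  //   void Handle(ReplyCallback callback) {
  //     Post([callback = std::move(callback)]() mutable {   // moved, not copied
  //       callback(/*ok=*/true);
  //     });
  //   }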
- auto fn_ok = [&n_ok]() { n_ok++; }; - auto fn_rej = [&n_rej]() { n_rej++; }; - queue.Add(0, -1, fn_ok, fn_rej); - queue.Add(1, -1, fn_ok, fn_rej, TaskID::Nil(), ObjectIdsToRefs({obj1})); + auto fn_ok = [&n_ok](rpc::SendReplyCallback callback) { n_ok++; }; + auto fn_rej = [&n_rej](rpc::SendReplyCallback callback) { n_rej++; }; + queue.Add(0, -1, fn_ok, fn_rej, nullptr); + queue.Add(1, -1, fn_ok, fn_rej, nullptr, TaskID::Nil(), ObjectIdsToRefs({obj1})); ASSERT_EQ(n_ok, 1); io_service.run(); ASSERT_EQ(n_rej, 0); @@ -104,12 +104,12 @@ TEST(SchedulingQueueTest, TestOutOfOrder) { ActorSchedulingQueue queue(io_service, waiter); int n_ok = 0; int n_rej = 0; - auto fn_ok = [&n_ok]() { n_ok++; }; - auto fn_rej = [&n_rej]() { n_rej++; }; - queue.Add(2, -1, fn_ok, fn_rej); - queue.Add(0, -1, fn_ok, fn_rej); - queue.Add(3, -1, fn_ok, fn_rej); - queue.Add(1, -1, fn_ok, fn_rej); + auto fn_ok = [&n_ok](rpc::SendReplyCallback callback) { n_ok++; }; + auto fn_rej = [&n_rej](rpc::SendReplyCallback callback) { n_rej++; }; + queue.Add(2, -1, fn_ok, fn_rej, nullptr); + queue.Add(0, -1, fn_ok, fn_rej, nullptr); + queue.Add(3, -1, fn_ok, fn_rej, nullptr); + queue.Add(1, -1, fn_ok, fn_rej, nullptr); io_service.run(); ASSERT_EQ(n_ok, 4); ASSERT_EQ(n_rej, 0); @@ -121,18 +121,18 @@ TEST(SchedulingQueueTest, TestSeqWaitTimeout) { ActorSchedulingQueue queue(io_service, waiter); int n_ok = 0; int n_rej = 0; - auto fn_ok = [&n_ok]() { n_ok++; }; - auto fn_rej = [&n_rej]() { n_rej++; }; - queue.Add(2, -1, fn_ok, fn_rej); - queue.Add(0, -1, fn_ok, fn_rej); - queue.Add(3, -1, fn_ok, fn_rej); + auto fn_ok = [&n_ok](rpc::SendReplyCallback callback) { n_ok++; }; + auto fn_rej = [&n_rej](rpc::SendReplyCallback callback) { n_rej++; }; + queue.Add(2, -1, fn_ok, fn_rej, nullptr); + queue.Add(0, -1, fn_ok, fn_rej, nullptr); + queue.Add(3, -1, fn_ok, fn_rej, nullptr); ASSERT_EQ(n_ok, 1); ASSERT_EQ(n_rej, 0); io_service.run(); // immediately triggers timeout ASSERT_EQ(n_ok, 1); ASSERT_EQ(n_rej, 2); - queue.Add(4, -1, fn_ok, fn_rej); - queue.Add(5, -1, fn_ok, fn_rej); + queue.Add(4, -1, fn_ok, fn_rej, nullptr); + queue.Add(5, -1, fn_ok, fn_rej, nullptr); ASSERT_EQ(n_ok, 3); ASSERT_EQ(n_rej, 2); } @@ -143,11 +143,11 @@ TEST(SchedulingQueueTest, TestSkipAlreadyProcessedByClient) { ActorSchedulingQueue queue(io_service, waiter); int n_ok = 0; int n_rej = 0; - auto fn_ok = [&n_ok]() { n_ok++; }; - auto fn_rej = [&n_rej]() { n_rej++; }; - queue.Add(2, 2, fn_ok, fn_rej); - queue.Add(3, 2, fn_ok, fn_rej); - queue.Add(1, 2, fn_ok, fn_rej); + auto fn_ok = [&n_ok](rpc::SendReplyCallback callback) { n_ok++; }; + auto fn_rej = [&n_rej](rpc::SendReplyCallback callback) { n_rej++; }; + queue.Add(2, 2, fn_ok, fn_rej, nullptr); + queue.Add(3, 2, fn_ok, fn_rej, nullptr); + queue.Add(1, 2, fn_ok, fn_rej, nullptr); io_service.run(); ASSERT_EQ(n_ok, 1); ASSERT_EQ(n_rej, 2); @@ -158,13 +158,13 @@ TEST(SchedulingQueueTest, TestCancelQueuedTask) { ASSERT_TRUE(queue->TaskQueueEmpty()); int n_ok = 0; int n_rej = 0; - auto fn_ok = [&n_ok]() { n_ok++; }; - auto fn_rej = [&n_rej]() { n_rej++; }; - queue->Add(-1, -1, fn_ok, fn_rej); - queue->Add(-1, -1, fn_ok, fn_rej); - queue->Add(-1, -1, fn_ok, fn_rej); - queue->Add(-1, -1, fn_ok, fn_rej); - queue->Add(-1, -1, fn_ok, fn_rej); + auto fn_ok = [&n_ok](rpc::SendReplyCallback callback) { n_ok++; }; + auto fn_rej = [&n_rej](rpc::SendReplyCallback callback) { n_rej++; }; + queue->Add(-1, -1, fn_ok, fn_rej, nullptr); + queue->Add(-1, -1, fn_ok, fn_rej, nullptr); + queue->Add(-1, -1, fn_ok, fn_rej, nullptr); + 
queue->Add(-1, -1, fn_ok, fn_rej, nullptr); + queue->Add(-1, -1, fn_ok, fn_rej, nullptr); ASSERT_TRUE(queue->CancelTaskIfFound(TaskID::Nil())); ASSERT_FALSE(queue->TaskQueueEmpty()); queue->ScheduleRequests(); diff --git a/src/ray/core_worker/transport/direct_actor_transport.cc b/src/ray/core_worker/transport/direct_actor_transport.cc index a8c5ae00a..d6f94fc82 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.cc +++ b/src/ray/core_worker/transport/direct_actor_transport.cc @@ -59,14 +59,16 @@ void CoreWorkerDirectActorTaskSubmitter::KillActor(const ActorID &actor_id, } Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spec) { - RAY_LOG(DEBUG) << "Submitting task " << task_spec.TaskId(); + auto task_id = task_spec.TaskId(); + auto actor_id = task_spec.ActorId(); + RAY_LOG(DEBUG) << "Submitting task " << task_id; RAY_CHECK(task_spec.IsActorTask()); bool task_queued = false; uint64_t send_pos = 0; { absl::MutexLock lock(&mu_); - auto queue = client_queues_.find(task_spec.ActorId()); + auto queue = client_queues_.find(actor_id); RAY_CHECK(queue != client_queues_.end()); if (queue->second.state != rpc::ActorTableData::DEAD) { // We must fix the send order prior to resolving dependencies, which may @@ -82,7 +84,6 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spe } if (task_queued) { - const auto actor_id = task_spec.ActorId(); // We must release the lock before resolving the task dependencies since // the callback may get called in the same call stack. resolver_.ResolveDependencies(task_spec, [this, send_pos, actor_id]() { @@ -99,7 +100,7 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spe }); } else { // Do not hold the lock while calling into task_finisher_. - task_finisher_->MarkTaskCanceled(task_spec.TaskId()); + task_finisher_->MarkTaskCanceled(task_id); std::shared_ptr creation_task_exception = nullptr; { absl::MutexLock lock(&mu_); @@ -109,9 +110,8 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spe auto status = Status::IOError("cancelling task of dead actor"); // No need to increment the number of completed tasks since the actor is // dead. - RAY_UNUSED(!task_finisher_->PendingTaskFailed(task_spec.TaskId(), - rpc::ErrorType::ACTOR_DIED, &status, - creation_task_exception)); + RAY_UNUSED(!task_finisher_->PendingTaskFailed(task_id, rpc::ErrorType::ACTOR_DIED, + &status, creation_task_exception)); } // If the task submission subsequently fails, then the client will receive @@ -423,7 +423,10 @@ void CoreWorkerDirectTaskReceiver::HandleTask( const rpc::PushTaskRequest &request, rpc::PushTaskReply *reply, rpc::SendReplyCallback send_reply_callback) { RAY_CHECK(waiter_ != nullptr) << "Must call init() prior to use"; - const TaskSpecification task_spec(request.task_spec()); + // Use `mutable_task_spec()` here as `task_spec()` returns a const reference + // which doesn't work with std::move. 
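  // Moving the spec out of the (nominally const) request, as done just below, saves one
  // full copy of the task spec per received task; the trade-off is that the request's
  // task_spec field is left in a moved-from, empty state, so nothing later in the handler
  // may read it. A rough sketch of the pattern with a generic protobuf message
  // (illustrative types, not the real ones):
  //
  //   void Handle(const ExampleRequest &request) {
  //     // protobuf only exposes mutable_*() accessors on non-const objects, hence the cast.
  //     ExamplePayload payload =
  //         std::move(*const_cast<ExampleRequest &>(request).mutable_payload());
  //     // request.payload() is now empty; use `payload` from here on.
  //   }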
+ TaskSpecification task_spec( + std::move(*(const_cast(request).mutable_task_spec()))); // If GCS server is restarted after sending an actor creation task to this core worker, // the restarted GCS server will send the same actor creation task to the core worker @@ -455,7 +458,8 @@ void CoreWorkerDirectTaskReceiver::HandleTask( } } - auto accept_callback = [this, reply, send_reply_callback, task_spec, resource_ids]() { + auto accept_callback = [this, reply, task_spec, + resource_ids](rpc::SendReplyCallback send_reply_callback) { if (task_spec.GetMessage().skip_execution()) { send_reply_callback(Status::OK(), nullptr, nullptr); return; @@ -522,11 +526,11 @@ void CoreWorkerDirectTaskReceiver::HandleTask( } }; - auto reject_callback = [send_reply_callback]() { + auto reject_callback = [](rpc::SendReplyCallback send_reply_callback) { send_reply_callback(Status::Invalid("client cancelled stale rpc"), nullptr, nullptr); }; - auto dependencies = task_spec.GetDependencies(); + auto dependencies = task_spec.GetDependencies(false); if (task_spec.IsActorTask()) { auto it = actor_scheduling_queues_.find(task_spec.CallerWorkerId()); @@ -538,16 +542,15 @@ void CoreWorkerDirectTaskReceiver::HandleTask( it = result.first; } - // Pop the dummy actor dependency. - // TODO(swang): Remove this with legacy raylet code. - dependencies.pop_back(); it->second->Add(request.sequence_number(), request.client_processed_up_to(), - accept_callback, reject_callback, task_spec.TaskId(), dependencies); + std::move(accept_callback), std::move(reject_callback), + std::move(send_reply_callback), task_spec.TaskId(), dependencies); } else { // Add the normal task's callbacks to the non-actor scheduling queue. - normal_scheduling_queue_->Add(request.sequence_number(), - request.client_processed_up_to(), accept_callback, - reject_callback, task_spec.TaskId(), dependencies); + normal_scheduling_queue_->Add( + request.sequence_number(), request.client_processed_up_to(), + std::move(accept_callback), std::move(reject_callback), + std::move(send_reply_callback), task_spec.TaskId(), dependencies); } } diff --git a/src/ray/core_worker/transport/direct_actor_transport.h b/src/ray/core_worker/transport/direct_actor_transport.h index 02cf374e8..b483ccfed 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.h +++ b/src/ray/core_worker/transport/direct_actor_transport.h @@ -276,23 +276,26 @@ class CoreWorkerDirectActorTaskSubmitter class InboundRequest { public: InboundRequest(){}; - InboundRequest(std::function accept_callback, - std::function reject_callback, TaskID task_id, + InboundRequest(std::function accept_callback, + std::function reject_callback, + rpc::SendReplyCallback send_reply_callback, TaskID task_id, bool has_dependencies) - : accept_callback_(accept_callback), - reject_callback_(reject_callback), + : accept_callback_(std::move(accept_callback)), + reject_callback_(std::move(reject_callback)), + send_reply_callback_(std::move(send_reply_callback)), task_id(task_id), has_pending_dependencies_(has_dependencies) {} - void Accept() { accept_callback_(); } - void Cancel() { reject_callback_(); } + void Accept() { accept_callback_(std::move(send_reply_callback_)); } + void Cancel() { reject_callback_(std::move(send_reply_callback_)); } bool CanExecute() const { return !has_pending_dependencies_; } ray::TaskID TaskID() const { return task_id; } void MarkDependenciesSatisfied() { has_pending_dependencies_ = false; } private: - std::function accept_callback_; - std::function reject_callback_; + std::function 
accept_callback_; + std::function reject_callback_; + rpc::SendReplyCallback send_reply_callback_; ray::TaskID task_id; bool has_pending_dependencies_; }; @@ -372,8 +375,10 @@ class BoundedExecutor { class SchedulingQueue { public: virtual void Add(int64_t seq_no, int64_t client_processed_up_to, - std::function accept_request, - std::function reject_request, TaskID task_id = TaskID::Nil(), + std::function accept_request, + std::function reject_request, + rpc::SendReplyCallback send_reply_callback, + TaskID task_id = TaskID::Nil(), const std::vector &dependencies = {}) = 0; virtual void ScheduleRequests() = 0; virtual bool TaskQueueEmpty() const = 0; @@ -402,8 +407,9 @@ class ActorSchedulingQueue : public SchedulingQueue { /// Add a new actor task's callbacks to the worker queue. void Add(int64_t seq_no, int64_t client_processed_up_to, - std::function accept_request, std::function reject_request, - TaskID task_id = TaskID::Nil(), + std::function accept_request, + std::function reject_request, + rpc::SendReplyCallback send_reply_callback, TaskID task_id = TaskID::Nil(), const std::vector &dependencies = {}) { // A seq_no of -1 means no ordering constraint. Actor tasks must be executed in order. RAY_CHECK(seq_no != -1); @@ -416,7 +422,8 @@ class ActorSchedulingQueue : public SchedulingQueue { } RAY_LOG(DEBUG) << "Enqueue " << seq_no << " cur seqno " << next_seq_no_; pending_actor_tasks_[seq_no] = - InboundRequest(accept_request, reject_request, task_id, dependencies.size() > 0); + InboundRequest(std::move(accept_request), std::move(reject_request), + std::move(send_reply_callback), task_id, dependencies.size() > 0); if (dependencies.size() > 0) { waiter_.Wait(dependencies, [seq_no, this]() { RAY_CHECK(boost::this_thread::get_id() == main_thread_id_); @@ -541,15 +548,17 @@ class NormalSchedulingQueue : public SchedulingQueue { /// Add a new task's callbacks to the worker queue. void Add(int64_t seq_no, int64_t client_processed_up_to, - std::function accept_request, std::function reject_request, - TaskID task_id = TaskID::Nil(), + std::function accept_request, + std::function reject_request, + rpc::SendReplyCallback send_reply_callback, TaskID task_id = TaskID::Nil(), const std::vector &dependencies = {}) { absl::MutexLock lock(&mu_); // Normal tasks should not have ordering constraints. RAY_CHECK(seq_no == -1); // Create a InboundRequest object for the new task, and add it to the queue. pending_normal_tasks_.push_back( - InboundRequest(accept_request, reject_request, task_id, dependencies.size() > 0)); + InboundRequest(std::move(accept_request), std::move(reject_request), + std::move(send_reply_callback), task_id, dependencies.size() > 0)); } // Search for an InboundRequest associated with the task that we are trying to cancel. diff --git a/src/ray/rpc/client_call.h b/src/ray/rpc/client_call.h index ccfbc0f01..4bc179749 100644 --- a/src/ray/rpc/client_call.h +++ b/src/ray/rpc/client_call.h @@ -66,7 +66,8 @@ class ClientCallImpl : public ClientCall { /// /// \param[in] callback The callback function to handle the reply. 
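  // The scheduling-queue changes above all follow one ownership rule: the reply callback
  // is stored exactly once in the queued InboundRequest and then moved into whichever of
  // the accept/reject handlers actually runs, rather than being copied into both closures
  // up front. A minimal standalone sketch of that model (illustrative names, not the real
  // Ray types):
  //
  //   using SendReply = std::function<void(bool ok)>;
  //   using Handler = std::function<void(SendReply)>;
  //
  //   struct QueuedRequest {
  //     Handler accept;
  //     Handler reject;
  //     SendReply send_reply;   // owned once, by the queued request
  //     void Accept() { accept(std::move(send_reply)); }
  //     void Cancel() { reject(std::move(send_reply)); }
  //   };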
   explicit ClientCallImpl(const ClientCallback<Reply> &callback, std::string call_name)
-      : callback_(callback), call_name_(std::move(call_name)) {}
+      : callback_(std::move(const_cast<ClientCallback<Reply> &>(callback))),
+        call_name_(std::move(call_name)) {}
 
   Status GetStatus() override {
     absl::MutexLock lock(&mutex_);
diff --git a/src/ray/rpc/worker/core_worker_client.h b/src/ray/rpc/worker/core_worker_client.h
index 4264b0caf..659235bb9 100644
--- a/src/ray/rpc/worker/core_worker_client.h
+++ b/src/ray/rpc/worker/core_worker_client.h
@@ -285,7 +285,9 @@ class CoreWorkerClient : public std::enable_shared_from_this<CoreWorkerClient>,
 
     {
       absl::MutexLock lock(&mutex_);
-      send_queue_.push_back(std::make_pair(std::move(request), callback));
+      send_queue_.push_back(std::make_pair(
+          std::move(request),
+          std::move(const_cast<ClientCallback<PushTaskReply> &>(callback))));
     }
     SendRequests();
   }
@@ -311,13 +313,13 @@ class CoreWorkerClient : public std::enable_shared_from_this<CoreWorkerClient>,
       send_queue_.pop_front();
 
       auto request = std::move(pair.first);
-      auto callback = pair.second;
       int64_t task_size = RequestSizeInBytes(*request);
       int64_t seq_no = request->sequence_number();
       request->set_client_processed_up_to(max_finished_seq_no_);
       rpc_bytes_in_flight_ += task_size;
 
-      auto rpc_callback = [this, this_ptr, seq_no, task_size, callback](
+      auto rpc_callback = [this, this_ptr, seq_no, task_size,
+                           callback = std::move(pair.second)](
                               Status status, const rpc::PushTaskReply &reply) {
         {
           absl::MutexLock lock(&mutex_);
@@ -331,8 +333,8 @@ class CoreWorkerClient : public std::enable_shared_from_this<CoreWorkerClient>,
         callback(status, reply);
       };
 
-      RAY_UNUSED(INVOKE_RPC_CALL(CoreWorkerService, PushTask, *request, rpc_callback,
-                                 grpc_client_));
+      RAY_UNUSED(INVOKE_RPC_CALL(CoreWorkerService, PushTask, *request,
+                                 std::move(rpc_callback), grpc_client_));
     }
 
     if (!send_queue_.empty()) {
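// Both const_cast sites above (ClientCallImpl's constructor and the send_queue_ push)
// deliberately steal the caller's std::function through a const reference so that no
// per-task copy of the callback is made. That is only safe because these are internal
// call sites where the caller never touches the callback again after handing it in; a
// rough sketch of the idea and its caveat (illustrative names, not the real Ray API):
//
//   using Callback = std::function<void()>;
//
//   void Submit(const Callback &callback) {
//     // Takes ownership despite the const&: the caller's object is left moved-from.
//     queue_.push_back(std::move(const_cast<Callback &>(callback)));
//   }
//
//   Callback cb = MakeCallback();
//   Submit(cb);   // cb must not be used afterwards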