Optimize lambda copy to improve direct call performance. (#15036)

Author: wanxing (committed via GitHub)
Date: 2021-04-14 11:02:49 +08:00
Parent: 4ed7a14e23
Commit: 0ad0839265
17 changed files with 513 additions and 98 deletions
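
The recurring pattern in this commit is to stop copying task specs and reply callbacks when they are captured into lambdas on the direct-call path, and to move them instead. Below is a minimal, self-contained sketch of that pattern; the Callback, post_copy, and post_move names are illustrative stand-ins, not Ray's actual types.

#include <functional>
#include <iostream>
#include <string>
#include <utility>

// Stand-in for an expensive-to-copy reply callback (hypothetical type, not Ray's).
using Callback = std::function<void(const std::string &)>;

// Before: a [=] capture copies every captured object into the closure.
void post_copy(const Callback &cb) {
  auto task = [=] { cb("executed with a copied callback"); };
  task();
}

// After: move-capture transfers ownership into the closure, avoiding the copy.
void post_move(Callback cb) {
  auto task = [cb = std::move(cb)] { cb("executed with a moved callback"); };
  task();
}

int main() {
  Callback cb = [](const std::string &msg) { std::cout << msg << "\n"; };
  post_copy(cb);
  post_move(std::move(cb));
  return 0;
}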

@ -13,6 +13,7 @@ all_modules = [
"api",
"runtime",
"test",
"performance_test",
]
java_import(
@ -106,12 +107,26 @@ define_java_module(
],
)
define_java_module(
name = "performance_test",
deps = [
":io_ray_ray_api",
":io_ray_ray_runtime",
"@maven//:com_google_code_gson_gson",
"@maven//:com_google_guava_guava",
"@maven//:commons_io_commons_io",
"@maven//:org_apache_commons_commons_lang3",
"@maven//:org_slf4j_slf4j_api",
],
)
java_binary(
name = "all_tests",
args = ["java/testng.xml"],
data = ["testng.xml"],
main_class = "org.testng.TestNG",
runtime_deps = [
":io_ray_ray_performance_test",
":io_ray_ray_runtime_test",
":io_ray_ray_test",
],
@ -207,6 +222,7 @@ genrule(
cp -f $(location //java:io_ray_ray_api_pom) "$$WORK_DIR/java/api/pom.xml"
cp -f $(location //java:io_ray_ray_runtime_pom) "$$WORK_DIR/java/runtime/pom.xml"
cp -f $(location //java:io_ray_ray_test_pom) "$$WORK_DIR/java/test/pom.xml"
cp -f $(location //java:io_ray_ray_performance_test_pom) "$$WORK_DIR/java/performance_test/pom.xml"
date > $@
""",
local = 1,

java/performance_test/pom.xml (new executable file, 56 lines)

@ -0,0 +1,56 @@
<?xml version="1.0" encoding="UTF-8"?>
<!-- This file is auto-generated by Bazel from pom_template.xml, do not modify it. -->
<project xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xmlns="http://maven.apache.org/POM/4.0.0"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
<parent>
<groupId>io.ray</groupId>
<artifactId>ray-superpom</artifactId>
<version>1.0.0-SNAPSHOT</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>ray-performance-test</artifactId>
<name>java performance test cases for ray</name>
<description>java performance test cases for ray</description>
<packaging>jar</packaging>
<dependencies>
<dependency>
<groupId>io.ray</groupId>
<artifactId>ray-api</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>io.ray</groupId>
<artifactId>ray-runtime</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>com.google.code.gson</groupId>
<artifactId>gson</artifactId>
<version>2.8.5</version>
</dependency>
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>27.0.1-jre</version>
</dependency>
<dependency>
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
<version>2.5</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>3.4</version>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<version>1.7.25</version>
</dependency>
</dependencies>
</project>

@ -0,0 +1,32 @@
<?xml version="1.0" encoding="UTF-8"?>
{auto_gen_header}
<project xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xmlns="http://maven.apache.org/POM/4.0.0"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
<parent>
<groupId>io.ray</groupId>
<artifactId>ray-superpom</artifactId>
<version>1.0.0-SNAPSHOT</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>ray-performance-test</artifactId>
<name>java performance test cases for ray</name>
<description>java performance test cases for ray</description>
<packaging>jar</packaging>
<dependencies>
<dependency>
<groupId>io.ray</groupId>
<artifactId>ray-api</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>io.ray</groupId>
<artifactId>ray-runtime</artifactId>
<version>${project.version}</version>
</dependency>
{generated_bzl_deps}
</dependencies>
</project>

@ -0,0 +1,40 @@
package io.ray.performancetest;
import java.nio.ByteBuffer;
public class Receiver {
private int value = 0;
public Receiver() {}
public boolean ping() {
return true;
}
public void noArgsNoReturn() {
value += 1;
}
public int noArgsHasReturn() {
value += 1;
return value;
}
public void bytesNoReturn(byte[] data) {
value += 1;
}
public int bytesHasReturn(byte[] data) {
value += 1;
return value;
}
public void byteBufferNoReturn(ByteBuffer data) {
value += 1;
}
public int byteBufferHasReturn(ByteBuffer data) {
value += 1;
return value;
}
}

@ -0,0 +1,117 @@
package io.ray.performancetest;
import com.google.common.base.Preconditions;
import io.ray.api.ActorHandle;
import io.ray.api.ObjectRef;
import io.ray.api.Ray;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class Source {
private static final int BATCH_SIZE;
private static final Logger LOGGER = LoggerFactory.getLogger(Source.class);
private final List<ActorHandle<Receiver>> receivers;
static {
String batchSizeString = System.getenv().get("PERF_TEST_BATCH_SIZE");
if (batchSizeString != null) {
BATCH_SIZE = Integer.valueOf(batchSizeString);
} else {
BATCH_SIZE = 1000;
}
}
public Source(List<ActorHandle<Receiver>> receivers) {
this.receivers = receivers;
}
public boolean startTest(
boolean hasReturn, boolean ignoreReturn, int argSize, boolean useDirectByteBuffer) {
LOGGER.info("Source startTest");
byte[] bytes = null;
ByteBuffer buffer = null;
if (argSize > 0) {
bytes = new byte[argSize];
new Random().nextBytes(bytes);
buffer = ByteBuffer.wrap(bytes);
} else {
Preconditions.checkState(!useDirectByteBuffer);
}
// Wait for actors to be created.
for (ActorHandle<Receiver> receiver : receivers) {
receiver.task(Receiver::ping).remote().get();
}
LOGGER.info(
"Started executing tasks, useDirectByteBuffer: {}, argSize: {}, has return: {}",
useDirectByteBuffer,
argSize,
hasReturn);
List<List<ObjectRef<Integer>>> returnObjects = new ArrayList<>();
returnObjects.add(new ArrayList<>());
returnObjects.add(new ArrayList<>());
long startTime = System.currentTimeMillis();
int numTasks = 0;
long lastReport = 0;
long totalTime = 0;
long batchCount = 0;
while (true) {
numTasks++;
boolean batchEnd = numTasks % BATCH_SIZE == 0;
for (ActorHandle<Receiver> receiver : receivers) {
if (hasReturn || batchEnd) {
ObjectRef<Integer> returnObject;
if (useDirectByteBuffer) {
returnObject = receiver.task(Receiver::byteBufferHasReturn, buffer).remote();
} else if (argSize > 0) {
returnObject = receiver.task(Receiver::bytesHasReturn, bytes).remote();
} else {
returnObject = receiver.task(Receiver::noArgsHasReturn).remote();
}
returnObjects.get(1).add(returnObject);
} else {
if (useDirectByteBuffer) {
receiver.task(Receiver::byteBufferNoReturn, buffer).remote();
} else if (argSize > 0) {
receiver.task(Receiver::bytesNoReturn, bytes).remote();
} else {
receiver.task(Receiver::noArgsNoReturn).remote();
}
}
}
if (batchEnd) {
batchCount++;
long getBeginTs = System.currentTimeMillis();
Ray.get(returnObjects.get(0));
long rt = System.currentTimeMillis() - getBeginTs;
totalTime += rt;
returnObjects.set(0, returnObjects.get(1));
returnObjects.set(1, new ArrayList<>());
long elapsedTime = System.currentTimeMillis() - startTime;
if (elapsedTime / 60000 > lastReport) {
lastReport = elapsedTime / 60000;
LOGGER.info(
"Finished executing {} tasks in {} ms, useDirectByteBuffer: {}, argSize: {}, "
+ "has return: {}, avg get rt: {}",
numTasks,
elapsedTime,
useDirectByteBuffer,
argSize,
hasReturn,
totalTime / (float) batchCount);
}
}
}
}
}

@ -0,0 +1,76 @@
package io.ray.performancetest.test;
import com.google.common.base.Preconditions;
import io.ray.api.ActorHandle;
import io.ray.api.ObjectRef;
import io.ray.api.Ray;
import io.ray.performancetest.Receiver;
import io.ray.performancetest.Source;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class ActorPerformanceTestBase {
private static final Logger LOGGER = LoggerFactory.getLogger(ActorPerformanceTestBase.class);
public static void run(
String[] args,
int[] layers,
int[] actorsPerLayer,
boolean hasReturn,
boolean ignoreReturn,
int argSize,
boolean useDirectByteBuffer,
int numJavaWorkerPerProcess) {
System.setProperty(
"ray.job.num-java-workers-per-process", String.valueOf(numJavaWorkerPerProcess));
Ray.init();
try {
// TODO: Support more layers.
Preconditions.checkState(layers.length == 2);
Preconditions.checkState(actorsPerLayer.length == layers.length);
for (int i = 0; i < layers.length; i++) {
Preconditions.checkState(layers[i] > 0);
Preconditions.checkState(actorsPerLayer[i] > 0);
}
List<ActorHandle<Receiver>> receivers = new ArrayList<>();
for (int i = 0; i < layers[1]; i++) {
int nodeIndex = layers[0] + i;
for (int j = 0; j < actorsPerLayer[1]; j++) {
receivers.add(Ray.actor(Receiver::new).remote());
}
}
List<ActorHandle<Source>> sources = new ArrayList<>();
for (int i = 0; i < layers[0]; i++) {
int nodeIndex = i;
for (int j = 0; j < actorsPerLayer[0]; j++) {
sources.add(Ray.actor(Source::new, receivers).remote());
}
}
List<ObjectRef<Boolean>> results =
sources.stream()
.map(
source ->
source
.task(
Source::startTest,
hasReturn,
ignoreReturn,
argSize,
useDirectByteBuffer)
.remote())
.collect(Collectors.toList());
Ray.get(results);
} catch (Throwable e) {
LOGGER.error("Run test failed.", e);
throw e;
}
}
}

@ -0,0 +1,27 @@
package io.ray.performancetest.test;
/**
* A 1-to-1 Ray call test: one sending actor and one receiving actor, measuring the throughput
* of the engine itself.
*/
public class ActorPerformanceTestCase1 {
public static void main(String[] args) {
final int[] layers = new int[] {1, 1};
final int[] actorsPerLayer = new int[] {1, 1};
final boolean hasReturn = false;
final int argSize = 0;
final boolean useDirectByteBuffer = false;
final boolean ignoreReturn = false;
final int numJavaWorkerPerProcess = 1;
ActorPerformanceTestBase.run(
args,
layers,
actorsPerLayer,
hasReturn,
ignoreReturn,
argSize,
useDirectByteBuffer,
numJavaWorkerPerProcess);
}
}

@ -41,6 +41,24 @@ run_testng() {
fi
}
run_timeout() {
local pid
timeout=$1
shift 1
"$@" &
pid=$!
sleep "$timeout"
if ps -p $pid > /dev/null
then
echo "run_timeout process exists, kill it."
kill -9 $pid
else
echo "run_timeout process not exist."
cat /tmp/ray/session_latest/logs/java-core-driver-*$pid*
exit 1
fi
}
pushd "$ROOT_DIR"/..
echo "Build java maven deps."
bazel build //java:gen_maven_deps
@ -114,3 +132,10 @@ mvn -Dorg.slf4j.simpleLogger.defaultLogLevel=WARN clean install -DskipTests -Dch
# Ensure mvn test works
mvn test -pl test -Dtest="io.ray.test.HelloWorldTest"
popd
pushd "$ROOT_DIR"
echo "Running performance test."
run_timeout 60 java -cp "$ROOT_DIR"/../bazel-bin/java/all_tests_deploy.jar io.ray.performancetest.test.ActorPerformanceTestCase1
# The performance test process may be killed by run_timeout, so stop Ray here.
ray stop
popd

@ -35,7 +35,10 @@ class MessageWrapper {
/// The input message will be **copied** into this object.
///
/// \param message The protobuf message.
explicit MessageWrapper(const Message message)
explicit MessageWrapper(const Message &message)
: message_(std::make_shared<Message>(message)) {}
explicit MessageWrapper(Message &&message)
: message_(std::make_shared<Message>(std::move(message))) {}
/// Construct from a protobuf message shared_ptr.
@ -115,7 +118,8 @@ inline std::vector<ID> IdVectorFromProtobuf(
/// Converts a Protobuf map to a `unordered_map`.
template <class K, class V>
inline std::unordered_map<K, V> MapFromProtobuf(::google::protobuf::Map<K, V> pb_map) {
inline std::unordered_map<K, V> MapFromProtobuf(
const ::google::protobuf::Map<K, V> &pb_map) {
return std::unordered_map<K, V>(pb_map.begin(), pb_map.end());
}
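
For callers that own the protobuf message, the new rvalue overload lets the wrapper take the message without an extra copy. A hedged sketch of the idea using a toy message type (not the generated rpc classes):

#include <memory>
#include <string>
#include <utility>

// Toy stand-in for a protobuf message (hypothetical, not rpc::TaskSpec).
struct ToyMessage {
  std::string payload;
};

template <class Message>
class Wrapper {
 public:
  // Copying overload, as before the change.
  explicit Wrapper(const Message &message)
      : message_(std::make_shared<Message>(message)) {}
  // Moving overload added by the change: the payload is not copied.
  explicit Wrapper(Message &&message)
      : message_(std::make_shared<Message>(std::move(message))) {}

 private:
  std::shared_ptr<Message> message_;
};

int main() {
  ToyMessage msg{std::string(1 << 20, 'x')};    // ~1 MB payload
  Wrapper<ToyMessage> wrapped(std::move(msg));  // selects the && overload; payload is moved
  return 0;
}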

@ -53,24 +53,24 @@ bool TaskSpecification::PlacementGroupCaptureChildTasks() const {
}
void TaskSpecification::ComputeResources() {
auto required_resources = MapFromProtobuf(message_->required_resources());
auto required_placement_resources =
MapFromProtobuf(message_->required_placement_resources());
if (required_placement_resources.empty()) {
required_placement_resources = required_resources;
}
auto &required_resources = message_->required_resources();
if (required_resources.empty()) {
// A static nil object is used here to avoid allocating the empty object every time.
required_resources_ = ResourceSet::Nil();
} else {
required_resources_.reset(new ResourceSet(required_resources));
required_resources_.reset(new ResourceSet(MapFromProtobuf(required_resources)));
}
auto &required_placement_resources = message_->required_placement_resources().empty()
? required_resources
: message_->required_placement_resources();
if (required_placement_resources.empty()) {
required_placement_resources_ = ResourceSet::Nil();
} else {
required_placement_resources_.reset(new ResourceSet(required_placement_resources));
required_placement_resources_.reset(
new ResourceSet(MapFromProtobuf(required_placement_resources)));
}
if (!IsActorTask()) {
@ -174,14 +174,15 @@ std::vector<ObjectID> TaskSpecification::GetDependencyIds() const {
return dependencies;
}
std::vector<rpc::ObjectReference> TaskSpecification::GetDependencies() const {
std::vector<rpc::ObjectReference> TaskSpecification::GetDependencies(
bool add_dummy_dependency) const {
std::vector<rpc::ObjectReference> dependencies;
for (size_t i = 0; i < NumArgs(); ++i) {
if (ArgByRef(i)) {
dependencies.push_back(message_->args(i).object_ref());
}
}
if (IsActorTask()) {
if (add_dummy_dependency && IsActorTask()) {
const auto &dummy_ref =
GetReferenceForActorDummyObject(PreviousActorTaskDummyObjectId());
dependencies.push_back(dummy_ref);

@ -36,10 +36,15 @@ class TaskSpecification : public MessageWrapper<rpc::TaskSpec> {
TaskSpecification() {}
/// Construct from a protobuf message object.
/// The input message will be **copied** into this object.
/// The input message will be copied/moved into this object.
///
/// \param message The protobuf message.
explicit TaskSpecification(rpc::TaskSpec message) : MessageWrapper(message) {
explicit TaskSpecification(rpc::TaskSpec &&message)
: MessageWrapper(std::move(message)) {
ComputeResources();
}
explicit TaskSpecification(const rpc::TaskSpec &message) : MessageWrapper(message) {
ComputeResources();
}
@ -127,9 +132,10 @@ class TaskSpecification : public MessageWrapper<rpc::TaskSpec> {
/// Return the dependencies of this task. This is recomputed each time, so it can
/// be used if the task spec is mutated.
///
/// \param add_dummy_dependency whether to add a dummy object in the returned objects.
/// \return The recomputed dependencies for the task.
std::vector<rpc::ObjectReference> GetDependencies() const;
std::vector<rpc::ObjectReference> GetDependencies(
bool add_dummy_dependency = true) const;
std::string GetDebuggerBreakpoint() const;

@ -2219,7 +2219,7 @@ void CoreWorker::HandlePushTask(const rpc::PushTaskRequest &request,
// execution service.
if (request.task_spec().type() == TaskType::ACTOR_TASK) {
task_execution_service_.post(
[=] {
[this, request, reply, send_reply_callback = std::move(send_reply_callback)] {
// We have posted an exit task onto the main event loop,
// so shouldn't bother executing any further work.
if (exiting_) return;

@ -42,12 +42,12 @@ TEST(SchedulingQueueTest, TestInOrder) {
ActorSchedulingQueue queue(io_service, waiter);
int n_ok = 0;
int n_rej = 0;
auto fn_ok = [&n_ok]() { n_ok++; };
auto fn_rej = [&n_rej]() { n_rej++; };
queue.Add(0, -1, fn_ok, fn_rej);
queue.Add(1, -1, fn_ok, fn_rej);
queue.Add(2, -1, fn_ok, fn_rej);
queue.Add(3, -1, fn_ok, fn_rej);
auto fn_ok = [&n_ok](rpc::SendReplyCallback callback) { n_ok++; };
auto fn_rej = [&n_rej](rpc::SendReplyCallback callback) { n_rej++; };
queue.Add(0, -1, fn_ok, fn_rej, nullptr);
queue.Add(1, -1, fn_ok, fn_rej, nullptr);
queue.Add(2, -1, fn_ok, fn_rej, nullptr);
queue.Add(3, -1, fn_ok, fn_rej, nullptr);
io_service.run();
ASSERT_EQ(n_ok, 4);
ASSERT_EQ(n_rej, 0);
@ -62,12 +62,12 @@ TEST(SchedulingQueueTest, TestWaitForObjects) {
ActorSchedulingQueue queue(io_service, waiter);
int n_ok = 0;
int n_rej = 0;
auto fn_ok = [&n_ok]() { n_ok++; };
auto fn_rej = [&n_rej]() { n_rej++; };
queue.Add(0, -1, fn_ok, fn_rej);
queue.Add(1, -1, fn_ok, fn_rej, TaskID::Nil(), ObjectIdsToRefs({obj1}));
queue.Add(2, -1, fn_ok, fn_rej, TaskID::Nil(), ObjectIdsToRefs({obj2}));
queue.Add(3, -1, fn_ok, fn_rej, TaskID::Nil(), ObjectIdsToRefs({obj3}));
auto fn_ok = [&n_ok](rpc::SendReplyCallback callback) { n_ok++; };
auto fn_rej = [&n_rej](rpc::SendReplyCallback callback) { n_rej++; };
queue.Add(0, -1, fn_ok, fn_rej, nullptr);
queue.Add(1, -1, fn_ok, fn_rej, nullptr, TaskID::Nil(), ObjectIdsToRefs({obj1}));
queue.Add(2, -1, fn_ok, fn_rej, nullptr, TaskID::Nil(), ObjectIdsToRefs({obj2}));
queue.Add(3, -1, fn_ok, fn_rej, nullptr, TaskID::Nil(), ObjectIdsToRefs({obj3}));
ASSERT_EQ(n_ok, 1);
waiter.Complete(0);
@ -87,10 +87,10 @@ TEST(SchedulingQueueTest, TestWaitForObjectsNotSubjectToSeqTimeout) {
ActorSchedulingQueue queue(io_service, waiter);
int n_ok = 0;
int n_rej = 0;
auto fn_ok = [&n_ok]() { n_ok++; };
auto fn_rej = [&n_rej]() { n_rej++; };
queue.Add(0, -1, fn_ok, fn_rej);
queue.Add(1, -1, fn_ok, fn_rej, TaskID::Nil(), ObjectIdsToRefs({obj1}));
auto fn_ok = [&n_ok](rpc::SendReplyCallback callback) { n_ok++; };
auto fn_rej = [&n_rej](rpc::SendReplyCallback callback) { n_rej++; };
queue.Add(0, -1, fn_ok, fn_rej, nullptr);
queue.Add(1, -1, fn_ok, fn_rej, nullptr, TaskID::Nil(), ObjectIdsToRefs({obj1}));
ASSERT_EQ(n_ok, 1);
io_service.run();
ASSERT_EQ(n_rej, 0);
@ -104,12 +104,12 @@ TEST(SchedulingQueueTest, TestOutOfOrder) {
ActorSchedulingQueue queue(io_service, waiter);
int n_ok = 0;
int n_rej = 0;
auto fn_ok = [&n_ok]() { n_ok++; };
auto fn_rej = [&n_rej]() { n_rej++; };
queue.Add(2, -1, fn_ok, fn_rej);
queue.Add(0, -1, fn_ok, fn_rej);
queue.Add(3, -1, fn_ok, fn_rej);
queue.Add(1, -1, fn_ok, fn_rej);
auto fn_ok = [&n_ok](rpc::SendReplyCallback callback) { n_ok++; };
auto fn_rej = [&n_rej](rpc::SendReplyCallback callback) { n_rej++; };
queue.Add(2, -1, fn_ok, fn_rej, nullptr);
queue.Add(0, -1, fn_ok, fn_rej, nullptr);
queue.Add(3, -1, fn_ok, fn_rej, nullptr);
queue.Add(1, -1, fn_ok, fn_rej, nullptr);
io_service.run();
ASSERT_EQ(n_ok, 4);
ASSERT_EQ(n_rej, 0);
@ -121,18 +121,18 @@ TEST(SchedulingQueueTest, TestSeqWaitTimeout) {
ActorSchedulingQueue queue(io_service, waiter);
int n_ok = 0;
int n_rej = 0;
auto fn_ok = [&n_ok]() { n_ok++; };
auto fn_rej = [&n_rej]() { n_rej++; };
queue.Add(2, -1, fn_ok, fn_rej);
queue.Add(0, -1, fn_ok, fn_rej);
queue.Add(3, -1, fn_ok, fn_rej);
auto fn_ok = [&n_ok](rpc::SendReplyCallback callback) { n_ok++; };
auto fn_rej = [&n_rej](rpc::SendReplyCallback callback) { n_rej++; };
queue.Add(2, -1, fn_ok, fn_rej, nullptr);
queue.Add(0, -1, fn_ok, fn_rej, nullptr);
queue.Add(3, -1, fn_ok, fn_rej, nullptr);
ASSERT_EQ(n_ok, 1);
ASSERT_EQ(n_rej, 0);
io_service.run(); // immediately triggers timeout
ASSERT_EQ(n_ok, 1);
ASSERT_EQ(n_rej, 2);
queue.Add(4, -1, fn_ok, fn_rej);
queue.Add(5, -1, fn_ok, fn_rej);
queue.Add(4, -1, fn_ok, fn_rej, nullptr);
queue.Add(5, -1, fn_ok, fn_rej, nullptr);
ASSERT_EQ(n_ok, 3);
ASSERT_EQ(n_rej, 2);
}
@ -143,11 +143,11 @@ TEST(SchedulingQueueTest, TestSkipAlreadyProcessedByClient) {
ActorSchedulingQueue queue(io_service, waiter);
int n_ok = 0;
int n_rej = 0;
auto fn_ok = [&n_ok]() { n_ok++; };
auto fn_rej = [&n_rej]() { n_rej++; };
queue.Add(2, 2, fn_ok, fn_rej);
queue.Add(3, 2, fn_ok, fn_rej);
queue.Add(1, 2, fn_ok, fn_rej);
auto fn_ok = [&n_ok](rpc::SendReplyCallback callback) { n_ok++; };
auto fn_rej = [&n_rej](rpc::SendReplyCallback callback) { n_rej++; };
queue.Add(2, 2, fn_ok, fn_rej, nullptr);
queue.Add(3, 2, fn_ok, fn_rej, nullptr);
queue.Add(1, 2, fn_ok, fn_rej, nullptr);
io_service.run();
ASSERT_EQ(n_ok, 1);
ASSERT_EQ(n_rej, 2);
@ -158,13 +158,13 @@ TEST(SchedulingQueueTest, TestCancelQueuedTask) {
ASSERT_TRUE(queue->TaskQueueEmpty());
int n_ok = 0;
int n_rej = 0;
auto fn_ok = [&n_ok]() { n_ok++; };
auto fn_rej = [&n_rej]() { n_rej++; };
queue->Add(-1, -1, fn_ok, fn_rej);
queue->Add(-1, -1, fn_ok, fn_rej);
queue->Add(-1, -1, fn_ok, fn_rej);
queue->Add(-1, -1, fn_ok, fn_rej);
queue->Add(-1, -1, fn_ok, fn_rej);
auto fn_ok = [&n_ok](rpc::SendReplyCallback callback) { n_ok++; };
auto fn_rej = [&n_rej](rpc::SendReplyCallback callback) { n_rej++; };
queue->Add(-1, -1, fn_ok, fn_rej, nullptr);
queue->Add(-1, -1, fn_ok, fn_rej, nullptr);
queue->Add(-1, -1, fn_ok, fn_rej, nullptr);
queue->Add(-1, -1, fn_ok, fn_rej, nullptr);
queue->Add(-1, -1, fn_ok, fn_rej, nullptr);
ASSERT_TRUE(queue->CancelTaskIfFound(TaskID::Nil()));
ASSERT_FALSE(queue->TaskQueueEmpty());
queue->ScheduleRequests();

@ -59,14 +59,16 @@ void CoreWorkerDirectActorTaskSubmitter::KillActor(const ActorID &actor_id,
}
Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spec) {
RAY_LOG(DEBUG) << "Submitting task " << task_spec.TaskId();
auto task_id = task_spec.TaskId();
auto actor_id = task_spec.ActorId();
RAY_LOG(DEBUG) << "Submitting task " << task_id;
RAY_CHECK(task_spec.IsActorTask());
bool task_queued = false;
uint64_t send_pos = 0;
{
absl::MutexLock lock(&mu_);
auto queue = client_queues_.find(task_spec.ActorId());
auto queue = client_queues_.find(actor_id);
RAY_CHECK(queue != client_queues_.end());
if (queue->second.state != rpc::ActorTableData::DEAD) {
// We must fix the send order prior to resolving dependencies, which may
@ -82,7 +84,6 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spe
}
if (task_queued) {
const auto actor_id = task_spec.ActorId();
// We must release the lock before resolving the task dependencies since
// the callback may get called in the same call stack.
resolver_.ResolveDependencies(task_spec, [this, send_pos, actor_id]() {
@ -99,7 +100,7 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spe
});
} else {
// Do not hold the lock while calling into task_finisher_.
task_finisher_->MarkTaskCanceled(task_spec.TaskId());
task_finisher_->MarkTaskCanceled(task_id);
std::shared_ptr<rpc::RayException> creation_task_exception = nullptr;
{
absl::MutexLock lock(&mu_);
@ -109,9 +110,8 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spe
auto status = Status::IOError("cancelling task of dead actor");
// No need to increment the number of completed tasks since the actor is
// dead.
RAY_UNUSED(!task_finisher_->PendingTaskFailed(task_spec.TaskId(),
rpc::ErrorType::ACTOR_DIED, &status,
creation_task_exception));
RAY_UNUSED(!task_finisher_->PendingTaskFailed(task_id, rpc::ErrorType::ACTOR_DIED,
&status, creation_task_exception));
}
// If the task submission subsequently fails, then the client will receive
@ -423,7 +423,10 @@ void CoreWorkerDirectTaskReceiver::HandleTask(
const rpc::PushTaskRequest &request, rpc::PushTaskReply *reply,
rpc::SendReplyCallback send_reply_callback) {
RAY_CHECK(waiter_ != nullptr) << "Must call init() prior to use";
const TaskSpecification task_spec(request.task_spec());
// Use `mutable_task_spec()` here as `task_spec()` returns a const reference
// which doesn't work with std::move.
TaskSpecification task_spec(
std::move(*(const_cast<rpc::PushTaskRequest &>(request).mutable_task_spec())));
// If GCS server is restarted after sending an actor creation task to this core worker,
// the restarted GCS server will send the same actor creation task to the core worker
@ -455,7 +458,8 @@ void CoreWorkerDirectTaskReceiver::HandleTask(
}
}
auto accept_callback = [this, reply, send_reply_callback, task_spec, resource_ids]() {
auto accept_callback = [this, reply, task_spec,
resource_ids](rpc::SendReplyCallback send_reply_callback) {
if (task_spec.GetMessage().skip_execution()) {
send_reply_callback(Status::OK(), nullptr, nullptr);
return;
@ -522,11 +526,11 @@ void CoreWorkerDirectTaskReceiver::HandleTask(
}
};
auto reject_callback = [send_reply_callback]() {
auto reject_callback = [](rpc::SendReplyCallback send_reply_callback) {
send_reply_callback(Status::Invalid("client cancelled stale rpc"), nullptr, nullptr);
};
auto dependencies = task_spec.GetDependencies();
auto dependencies = task_spec.GetDependencies(false);
if (task_spec.IsActorTask()) {
auto it = actor_scheduling_queues_.find(task_spec.CallerWorkerId());
@ -538,16 +542,15 @@ void CoreWorkerDirectTaskReceiver::HandleTask(
it = result.first;
}
// Pop the dummy actor dependency.
// TODO(swang): Remove this with legacy raylet code.
dependencies.pop_back();
it->second->Add(request.sequence_number(), request.client_processed_up_to(),
accept_callback, reject_callback, task_spec.TaskId(), dependencies);
std::move(accept_callback), std::move(reject_callback),
std::move(send_reply_callback), task_spec.TaskId(), dependencies);
} else {
// Add the normal task's callbacks to the non-actor scheduling queue.
normal_scheduling_queue_->Add(request.sequence_number(),
request.client_processed_up_to(), accept_callback,
reject_callback, task_spec.TaskId(), dependencies);
normal_scheduling_queue_->Add(
request.sequence_number(), request.client_processed_up_to(),
std::move(accept_callback), std::move(reject_callback),
std::move(send_reply_callback), task_spec.TaskId(), dependencies);
}
}
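
The const_cast in HandleTask works around the protobuf accessor model: the plain getter returns a const reference, whose contents can only be copied, while the mutable_* accessor returns a non-const pointer whose contents can be moved from. A small illustration with stand-in Request/Spec types (hypothetical, not the generated protobuf classes):

#include <iostream>
#include <string>
#include <utility>

// Stand-ins for the generated protobuf request/spec types (hypothetical).
struct Spec {
  std::string blob;
};

class Request {
 public:
  const Spec &spec() const { return spec_; }  // const getter: contents can only be copied
  Spec *mutable_spec() { return &spec_; }     // mutable accessor: contents can be moved from
 private:
  Spec spec_{std::string(1 << 20, 'x')};
};

// Handlers receive the request by const reference, so moving the spec out
// requires the const_cast + mutable accessor combination used in HandleTask.
void handle(const Request &request) {
  Spec spec(std::move(*const_cast<Request &>(request).mutable_spec()));
  std::cout << "took " << spec.blob.size() << " bytes without copying\n";
}

int main() {
  Request req;
  handle(req);
  return 0;
}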

@ -276,23 +276,26 @@ class CoreWorkerDirectActorTaskSubmitter
class InboundRequest {
public:
InboundRequest(){};
InboundRequest(std::function<void()> accept_callback,
std::function<void()> reject_callback, TaskID task_id,
InboundRequest(std::function<void(rpc::SendReplyCallback)> accept_callback,
std::function<void(rpc::SendReplyCallback)> reject_callback,
rpc::SendReplyCallback send_reply_callback, TaskID task_id,
bool has_dependencies)
: accept_callback_(accept_callback),
reject_callback_(reject_callback),
: accept_callback_(std::move(accept_callback)),
reject_callback_(std::move(reject_callback)),
send_reply_callback_(std::move(send_reply_callback)),
task_id(task_id),
has_pending_dependencies_(has_dependencies) {}
void Accept() { accept_callback_(); }
void Cancel() { reject_callback_(); }
void Accept() { accept_callback_(std::move(send_reply_callback_)); }
void Cancel() { reject_callback_(std::move(send_reply_callback_)); }
bool CanExecute() const { return !has_pending_dependencies_; }
ray::TaskID TaskID() const { return task_id; }
void MarkDependenciesSatisfied() { has_pending_dependencies_ = false; }
private:
std::function<void()> accept_callback_;
std::function<void()> reject_callback_;
std::function<void(rpc::SendReplyCallback)> accept_callback_;
std::function<void(rpc::SendReplyCallback)> reject_callback_;
rpc::SendReplyCallback send_reply_callback_;
ray::TaskID task_id;
bool has_pending_dependencies_;
};
@ -372,8 +375,10 @@ class BoundedExecutor {
class SchedulingQueue {
public:
virtual void Add(int64_t seq_no, int64_t client_processed_up_to,
std::function<void()> accept_request,
std::function<void()> reject_request, TaskID task_id = TaskID::Nil(),
std::function<void(rpc::SendReplyCallback)> accept_request,
std::function<void(rpc::SendReplyCallback)> reject_request,
rpc::SendReplyCallback send_reply_callback,
TaskID task_id = TaskID::Nil(),
const std::vector<rpc::ObjectReference> &dependencies = {}) = 0;
virtual void ScheduleRequests() = 0;
virtual bool TaskQueueEmpty() const = 0;
@ -402,8 +407,9 @@ class ActorSchedulingQueue : public SchedulingQueue {
/// Add a new actor task's callbacks to the worker queue.
void Add(int64_t seq_no, int64_t client_processed_up_to,
std::function<void()> accept_request, std::function<void()> reject_request,
TaskID task_id = TaskID::Nil(),
std::function<void(rpc::SendReplyCallback)> accept_request,
std::function<void(rpc::SendReplyCallback)> reject_request,
rpc::SendReplyCallback send_reply_callback, TaskID task_id = TaskID::Nil(),
const std::vector<rpc::ObjectReference> &dependencies = {}) {
// A seq_no of -1 means no ordering constraint. Actor tasks must be executed in order.
RAY_CHECK(seq_no != -1);
@ -416,7 +422,8 @@ class ActorSchedulingQueue : public SchedulingQueue {
}
RAY_LOG(DEBUG) << "Enqueue " << seq_no << " cur seqno " << next_seq_no_;
pending_actor_tasks_[seq_no] =
InboundRequest(accept_request, reject_request, task_id, dependencies.size() > 0);
InboundRequest(std::move(accept_request), std::move(reject_request),
std::move(send_reply_callback), task_id, dependencies.size() > 0);
if (dependencies.size() > 0) {
waiter_.Wait(dependencies, [seq_no, this]() {
RAY_CHECK(boost::this_thread::get_id() == main_thread_id_);
@ -541,15 +548,17 @@ class NormalSchedulingQueue : public SchedulingQueue {
/// Add a new task's callbacks to the worker queue.
void Add(int64_t seq_no, int64_t client_processed_up_to,
std::function<void()> accept_request, std::function<void()> reject_request,
TaskID task_id = TaskID::Nil(),
std::function<void(rpc::SendReplyCallback)> accept_request,
std::function<void(rpc::SendReplyCallback)> reject_request,
rpc::SendReplyCallback send_reply_callback, TaskID task_id = TaskID::Nil(),
const std::vector<rpc::ObjectReference> &dependencies = {}) {
absl::MutexLock lock(&mu_);
// Normal tasks should not have ordering constraints.
RAY_CHECK(seq_no == -1);
// Create a InboundRequest object for the new task, and add it to the queue.
pending_normal_tasks_.push_back(
InboundRequest(accept_request, reject_request, task_id, dependencies.size() > 0));
InboundRequest(std::move(accept_request), std::move(reject_request),
std::move(send_reply_callback), task_id, dependencies.size() > 0));
}
// Search for an InboundRequest associated with the task that we are trying to cancel.

@ -66,7 +66,8 @@ class ClientCallImpl : public ClientCall {
///
/// \param[in] callback The callback function to handle the reply.
explicit ClientCallImpl(const ClientCallback<Reply> &callback, std::string call_name)
: callback_(callback), call_name_(std::move(call_name)) {}
: callback_(std::move(const_cast<ClientCallback<Reply> &>(callback))),
call_name_(std::move(call_name)) {}
Status GetStatus() override {
absl::MutexLock lock(&mutex_);

@ -285,7 +285,9 @@ class CoreWorkerClient : public std::enable_shared_from_this<CoreWorkerClient>,
{
absl::MutexLock lock(&mutex_);
send_queue_.push_back(std::make_pair(std::move(request), callback));
send_queue_.push_back(std::make_pair(
std::move(request),
std::move(const_cast<ClientCallback<PushTaskReply> &>(callback))));
}
SendRequests();
}
@ -311,13 +313,13 @@ class CoreWorkerClient : public std::enable_shared_from_this<CoreWorkerClient>,
send_queue_.pop_front();
auto request = std::move(pair.first);
auto callback = pair.second;
int64_t task_size = RequestSizeInBytes(*request);
int64_t seq_no = request->sequence_number();
request->set_client_processed_up_to(max_finished_seq_no_);
rpc_bytes_in_flight_ += task_size;
auto rpc_callback = [this, this_ptr, seq_no, task_size, callback](
auto rpc_callback = [this, this_ptr, seq_no, task_size,
callback = std::move(pair.second)](
Status status, const rpc::PushTaskReply &reply) {
{
absl::MutexLock lock(&mutex_);
@ -331,8 +333,8 @@ class CoreWorkerClient : public std::enable_shared_from_this<CoreWorkerClient>,
callback(status, reply);
};
RAY_UNUSED(INVOKE_RPC_CALL(CoreWorkerService, PushTask, *request, rpc_callback,
grpc_client_));
RAY_UNUSED(INVOKE_RPC_CALL(CoreWorkerService, PushTask, *request,
std::move(rpc_callback), grpc_client_));
}
if (!send_queue_.empty()) {