Optimize lambda copy to improve direct call performance. (#15036)
Parent: 4ed7a14e23
Commit: 0ad0839265
17 changed files with 513 additions and 98 deletions
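The common thread in the C++ changes below is avoiding needless copies of callbacks: the RPC reply callback is threaded through the scheduling queues as an explicit `rpc::SendReplyCallback` argument, and lambdas move-capture their state instead of copying it. A minimal standalone sketch of that move-capture pattern (illustrative names only, not code from this commit):

```cpp
#include <functional>
#include <iostream>
#include <string>
#include <utility>

// A callback that owns non-trivial state; copying it duplicates the captured string.
using ReplyCallback = std::function<void(const std::string &)>;

// Copying capture: every enqueue duplicates the callback's captured state.
std::function<void()> make_task_copying(ReplyCallback cb) {
  return [cb] { cb("done"); };
}

// Move capture (C++14 init-capture): the callback is transferred, not copied.
std::function<void()> make_task_moving(ReplyCallback cb) {
  return [cb = std::move(cb)] { cb("done"); };
}

int main() {
  std::string payload(1 << 20, 'x');  // simulate expensive captured state
  ReplyCallback cb = [payload](const std::string &msg) {
    std::cout << msg << " (" << payload.size() << " bytes captured)\n";
  };
  auto task = make_task_moving(std::move(cb));
  task();
}
```

With the copying variant every enqueue would duplicate the captured payload; with the init-capture variant the callback is transferred exactly once, which is the effect the diff is after.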
@@ -13,6 +13,7 @@ all_modules = [
    "api",
    "runtime",
    "test",
+   "performance_test",
]

java_import(

@@ -106,12 +107,26 @@ define_java_module(
    ],
)

+define_java_module(
+    name = "performance_test",
+    deps = [
+        ":io_ray_ray_api",
+        ":io_ray_ray_runtime",
+        "@maven//:com_google_code_gson_gson",
+        "@maven//:com_google_guava_guava",
+        "@maven//:commons_io_commons_io",
+        "@maven//:org_apache_commons_commons_lang3",
+        "@maven//:org_slf4j_slf4j_api",
+    ],
+)
+
java_binary(
    name = "all_tests",
    args = ["java/testng.xml"],
    data = ["testng.xml"],
    main_class = "org.testng.TestNG",
    runtime_deps = [
+       ":io_ray_ray_performance_test",
        ":io_ray_ray_runtime_test",
        ":io_ray_ray_test",
    ],

@@ -207,6 +222,7 @@ genrule(
        cp -f $(location //java:io_ray_ray_api_pom) "$$WORK_DIR/java/api/pom.xml"
        cp -f $(location //java:io_ray_ray_runtime_pom) "$$WORK_DIR/java/runtime/pom.xml"
        cp -f $(location //java:io_ray_ray_test_pom) "$$WORK_DIR/java/test/pom.xml"
+       cp -f $(location //java:io_ray_ray_performance_test_pom) "$$WORK_DIR/java/performance_test/pom.xml"
        date > $@
    """,
    local = 1,
java/performance_test/pom.xml (new executable file, 56 lines)
@@ -0,0 +1,56 @@
<?xml version="1.0" encoding="UTF-8"?>
<!-- This file is auto-generated by Bazel from pom_template.xml, do not modify it. -->
<project xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xmlns="http://maven.apache.org/POM/4.0.0"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
  <parent>
    <groupId>io.ray</groupId>
    <artifactId>ray-superpom</artifactId>
    <version>1.0.0-SNAPSHOT</version>
  </parent>
  <modelVersion>4.0.0</modelVersion>

  <artifactId>ray-performance-test</artifactId>
  <name>java performance test cases for ray</name>
  <description>java performance test cases for ray</description>

  <packaging>jar</packaging>

  <dependencies>
    <dependency>
      <groupId>io.ray</groupId>
      <artifactId>ray-api</artifactId>
      <version>${project.version}</version>
    </dependency>
    <dependency>
      <groupId>io.ray</groupId>
      <artifactId>ray-runtime</artifactId>
      <version>${project.version}</version>
    </dependency>
    <dependency>
      <groupId>com.google.code.gson</groupId>
      <artifactId>gson</artifactId>
      <version>2.8.5</version>
    </dependency>
    <dependency>
      <groupId>com.google.guava</groupId>
      <artifactId>guava</artifactId>
      <version>27.0.1-jre</version>
    </dependency>
    <dependency>
      <groupId>commons-io</groupId>
      <artifactId>commons-io</artifactId>
      <version>2.5</version>
    </dependency>
    <dependency>
      <groupId>org.apache.commons</groupId>
      <artifactId>commons-lang3</artifactId>
      <version>3.4</version>
    </dependency>
    <dependency>
      <groupId>org.slf4j</groupId>
      <artifactId>slf4j-api</artifactId>
      <version>1.7.25</version>
    </dependency>
  </dependencies>
</project>
java/performance_test/pom_template.xml (new file, 32 lines)
@@ -0,0 +1,32 @@
<?xml version="1.0" encoding="UTF-8"?>
{auto_gen_header}
<project xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xmlns="http://maven.apache.org/POM/4.0.0"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
  <parent>
    <groupId>io.ray</groupId>
    <artifactId>ray-superpom</artifactId>
    <version>1.0.0-SNAPSHOT</version>
  </parent>
  <modelVersion>4.0.0</modelVersion>

  <artifactId>ray-performance-test</artifactId>
  <name>java performance test cases for ray</name>
  <description>java performance test cases for ray</description>

  <packaging>jar</packaging>

  <dependencies>
    <dependency>
      <groupId>io.ray</groupId>
      <artifactId>ray-api</artifactId>
      <version>${project.version}</version>
    </dependency>
    <dependency>
      <groupId>io.ray</groupId>
      <artifactId>ray-runtime</artifactId>
      <version>${project.version}</version>
    </dependency>
    {generated_bzl_deps}
  </dependencies>
</project>
@@ -0,0 +1,40 @@
package io.ray.performancetest;

import java.nio.ByteBuffer;

public class Receiver {
  private int value = 0;

  public Receiver() {}

  public boolean ping() {
    return true;
  }

  public void noArgsNoReturn() {
    value += 1;
  }

  public int noArgsHasReturn() {
    value += 1;
    return value;
  }

  public void bytesNoReturn(byte[] data) {
    value += 1;
  }

  public int bytesHasReturn(byte[] data) {
    value += 1;
    return value;
  }

  public void byteBufferNoReturn(ByteBuffer data) {
    value += 1;
  }

  public int byteBufferHasReturn(ByteBuffer data) {
    value += 1;
    return value;
  }
}
@@ -0,0 +1,117 @@
package io.ray.performancetest;

import com.google.common.base.Preconditions;
import io.ray.api.ActorHandle;
import io.ray.api.ObjectRef;
import io.ray.api.Ray;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class Source {
  private static final int BATCH_SIZE;

  private static final Logger LOGGER = LoggerFactory.getLogger(Source.class);

  private final List<ActorHandle<Receiver>> receivers;

  static {
    String batchSizeString = System.getenv().get("PERF_TEST_BATCH_SIZE");
    if (batchSizeString != null) {
      BATCH_SIZE = Integer.valueOf(batchSizeString);
    } else {
      BATCH_SIZE = 1000;
    }
  }

  public Source(List<ActorHandle<Receiver>> receivers) {
    this.receivers = receivers;
  }

  public boolean startTest(
      boolean hasReturn, boolean ignoreReturn, int argSize, boolean useDirectByteBuffer) {
    LOGGER.info("Source startTest");
    byte[] bytes = null;
    ByteBuffer buffer = null;
    if (argSize > 0) {
      bytes = new byte[argSize];
      new Random().nextBytes(bytes);
      buffer = ByteBuffer.wrap(bytes);
    } else {
      Preconditions.checkState(!useDirectByteBuffer);
    }

    // Wait for actors to be created.
    for (ActorHandle<Receiver> receiver : receivers) {
      receiver.task(Receiver::ping).remote().get();
    }

    LOGGER.info(
        "Started executing tasks, useDirectByteBuffer: {}, argSize: {}, has return: {}",
        useDirectByteBuffer,
        argSize,
        hasReturn);

    List<List<ObjectRef<Integer>>> returnObjects = new ArrayList<>();
    returnObjects.add(new ArrayList<>());
    returnObjects.add(new ArrayList<>());

    long startTime = System.currentTimeMillis();
    int numTasks = 0;
    long lastReport = 0;
    long totalTime = 0;
    long batchCount = 0;
    while (true) {
      numTasks++;
      boolean batchEnd = numTasks % BATCH_SIZE == 0;
      for (ActorHandle<Receiver> receiver : receivers) {
        if (hasReturn || batchEnd) {
          ObjectRef<Integer> returnObject;
          if (useDirectByteBuffer) {
            returnObject = receiver.task(Receiver::byteBufferHasReturn, buffer).remote();
          } else if (argSize > 0) {
            returnObject = receiver.task(Receiver::bytesHasReturn, bytes).remote();
          } else {
            returnObject = receiver.task(Receiver::noArgsHasReturn).remote();
          }
          returnObjects.get(1).add(returnObject);
        } else {
          if (useDirectByteBuffer) {
            receiver.task(Receiver::byteBufferNoReturn, buffer).remote();
          } else if (argSize > 0) {
            receiver.task(Receiver::bytesNoReturn, bytes).remote();
          } else {
            receiver.task(Receiver::noArgsNoReturn).remote();
          }
        }
      }

      if (batchEnd) {
        batchCount++;
        long getBeginTs = System.currentTimeMillis();
        Ray.get(returnObjects.get(0));
        long rt = System.currentTimeMillis() - getBeginTs;
        totalTime += rt;
        returnObjects.set(0, returnObjects.get(1));
        returnObjects.set(1, new ArrayList<>());

        long elapsedTime = System.currentTimeMillis() - startTime;
        if (elapsedTime / 60000 > lastReport) {
          lastReport = elapsedTime / 60000;
          LOGGER.info(
              "Finished executing {} tasks in {} ms, useDirectByteBuffer: {}, argSize: {}, "
                  + "has return: {}, avg get rt: {}",
              numTasks,
              elapsedTime,
              useDirectByteBuffer,
              argSize,
              hasReturn,
              totalTime / (float) batchCount);
        }
      }
    }
  }
}
@@ -0,0 +1,76 @@
package io.ray.performancetest.test;

import com.google.common.base.Preconditions;
import io.ray.api.ActorHandle;
import io.ray.api.ObjectRef;
import io.ray.api.Ray;
import io.ray.performancetest.Receiver;
import io.ray.performancetest.Source;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ActorPerformanceTestBase {

  private static final Logger LOGGER = LoggerFactory.getLogger(ActorPerformanceTestBase.class);

  public static void run(
      String[] args,
      int[] layers,
      int[] actorsPerLayer,
      boolean hasReturn,
      boolean ignoreReturn,
      int argSize,
      boolean useDirectByteBuffer,
      int numJavaWorkerPerProcess) {
    System.setProperty(
        "ray.job.num-java-workers-per-process", String.valueOf(numJavaWorkerPerProcess));
    Ray.init();
    try {
      // TODO: Support more layers.
      Preconditions.checkState(layers.length == 2);
      Preconditions.checkState(actorsPerLayer.length == layers.length);
      for (int i = 0; i < layers.length; i++) {
        Preconditions.checkState(layers[i] > 0);
        Preconditions.checkState(actorsPerLayer[i] > 0);
      }

      List<ActorHandle<Receiver>> receivers = new ArrayList<>();
      for (int i = 0; i < layers[1]; i++) {
        int nodeIndex = layers[0] + i;
        for (int j = 0; j < actorsPerLayer[1]; j++) {
          receivers.add(Ray.actor(Receiver::new).remote());
        }
      }

      List<ActorHandle<Source>> sources = new ArrayList<>();
      for (int i = 0; i < layers[0]; i++) {
        int nodeIndex = i;
        for (int j = 0; j < actorsPerLayer[0]; j++) {
          sources.add(Ray.actor(Source::new, receivers).remote());
        }
      }

      List<ObjectRef<Boolean>> results =
          sources.stream()
              .map(
                  source ->
                      source
                          .task(
                              Source::startTest,
                              hasReturn,
                              ignoreReturn,
                              argSize,
                              useDirectByteBuffer)
                          .remote())
              .collect(Collectors.toList());

      Ray.get(results);
    } catch (Throwable e) {
      LOGGER.error("Run test failed.", e);
      throw e;
    }
  }
}
@@ -0,0 +1,27 @@
package io.ray.performancetest.test;

/**
 * 1-to-1 ray call, one receiving actor and one sending actor, test the throughput of the engine
 * itself.
 */
public class ActorPerformanceTestCase1 {

  public static void main(String[] args) {
    final int[] layers = new int[] {1, 1};
    final int[] actorsPerLayer = new int[] {1, 1};
    final boolean hasReturn = false;
    final int argSize = 0;
    final boolean useDirectByteBuffer = false;
    final boolean ignoreReturn = false;
    final int numJavaWorkerPerProcess = 1;
    ActorPerformanceTestBase.run(
        args,
        layers,
        actorsPerLayer,
        hasReturn,
        ignoreReturn,
        argSize,
        useDirectByteBuffer,
        numJavaWorkerPerProcess);
  }
}
java/test.sh (25 changed lines)
@@ -41,6 +41,24 @@ run_testng() {
  fi
}

+run_timeout() {
+  local pid
+  timeout=$1
+  shift 1
+  "$@" &
+  pid=$!
+  sleep "$timeout"
+  if ps -p $pid > /dev/null
+  then
+    echo "run_timeout process exists, kill it."
+    kill -9 $pid
+  else
+    echo "run_timeout process not exist."
+    cat /tmp/ray/session_latest/logs/java-core-driver-*$pid*
+    exit 1
+  fi
+}
+
pushd "$ROOT_DIR"/..
echo "Build java maven deps."
bazel build //java:gen_maven_deps

@@ -114,3 +132,10 @@ mvn -Dorg.slf4j.simpleLogger.defaultLogLevel=WARN clean install -DskipTests -Dch
# Ensure mvn test works
mvn test -pl test -Dtest="io.ray.test.HelloWorldTest"
popd
+
+pushd "$ROOT_DIR"
+echo "Running performance test."
+run_timeout 60 java -cp "$ROOT_DIR"/../bazel-bin/java/all_tests_deploy.jar io.ray.performancetest.test.ActorPerformanceTestCase1
+# The performance process may be killed by run_timeout, so clear ray here.
+ray stop
+popd
@@ -35,7 +35,10 @@ class MessageWrapper {
  /// The input message will be **copied** into this object.
  ///
  /// \param message The protobuf message.
- explicit MessageWrapper(const Message message)
+ explicit MessageWrapper(const Message &message)
      : message_(std::make_shared<Message>(message)) {}

+ explicit MessageWrapper(Message &&message)
+     : message_(std::make_shared<Message>(std::move(message))) {}
+
  /// Construct from a protobuf message shared_ptr.

@@ -115,7 +118,8 @@ inline std::vector<ID> IdVectorFromProtobuf(

/// Converts a Protobuf map to a `unordered_map`.
template <class K, class V>
- inline std::unordered_map<K, V> MapFromProtobuf(::google::protobuf::Map<K, V> pb_map) {
+ inline std::unordered_map<K, V> MapFromProtobuf(
+     const ::google::protobuf::Map<K, V> &pb_map) {
  return std::unordered_map<K, V>(pb_map.begin(), pb_map.end());
}
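The pair of constructors added above is the standard copy/move overload split: lvalue callers keep their message, rvalue callers hand it over without a deep copy. A small self-contained sketch of the same idea, with placeholder names rather than Ray's `MessageWrapper`/protobuf types:

```cpp
#include <memory>
#include <string>
#include <utility>

// Illustrative wrapper (names are placeholders, not Ray's classes).
class Wrapper {
 public:
  // Copies: used when the caller still needs its message afterwards.
  explicit Wrapper(const std::string &message)
      : message_(std::make_shared<std::string>(message)) {}

  // Moves: used when the caller passes a temporary or std::move()s its message.
  explicit Wrapper(std::string &&message)
      : message_(std::make_shared<std::string>(std::move(message))) {}

 private:
  std::shared_ptr<std::string> message_;
};

int main() {
  std::string payload(1 << 20, 'x');
  Wrapper copied(payload);            // const& overload: payload is still usable
  Wrapper moved(std::move(payload));  // && overload: payload's buffer is stolen
  (void)copied;
  (void)moved;
}
```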
@@ -53,24 +53,24 @@ bool TaskSpecification::PlacementGroupCaptureChildTasks() const {
}

void TaskSpecification::ComputeResources() {
- auto required_resources = MapFromProtobuf(message_->required_resources());
- auto required_placement_resources =
-     MapFromProtobuf(message_->required_placement_resources());
- if (required_placement_resources.empty()) {
-   required_placement_resources = required_resources;
- }
+ auto &required_resources = message_->required_resources();

  if (required_resources.empty()) {
    // A static nil object is used here to avoid allocating the empty object every time.
    required_resources_ = ResourceSet::Nil();
  } else {
-   required_resources_.reset(new ResourceSet(required_resources));
+   required_resources_.reset(new ResourceSet(MapFromProtobuf(required_resources)));
  }

+ auto &required_placement_resources = message_->required_placement_resources().empty()
+                                          ? required_resources
+                                          : message_->required_placement_resources();
+
  if (required_placement_resources.empty()) {
    required_placement_resources_ = ResourceSet::Nil();
  } else {
-   required_placement_resources_.reset(new ResourceSet(required_placement_resources));
+   required_placement_resources_.reset(
+       new ResourceSet(MapFromProtobuf(required_placement_resources)));
  }

  if (!IsActorTask()) {

@@ -174,14 +174,15 @@ std::vector<ObjectID> TaskSpecification::GetDependencyIds() const {
  return dependencies;
}

- std::vector<rpc::ObjectReference> TaskSpecification::GetDependencies() const {
+ std::vector<rpc::ObjectReference> TaskSpecification::GetDependencies(
+     bool add_dummy_dependency) const {
  std::vector<rpc::ObjectReference> dependencies;
  for (size_t i = 0; i < NumArgs(); ++i) {
    if (ArgByRef(i)) {
      dependencies.push_back(message_->args(i).object_ref());
    }
  }
- if (IsActorTask()) {
+ if (add_dummy_dependency && IsActorTask()) {
    const auto &dummy_ref =
        GetReferenceForActorDummyObject(PreviousActorTaskDummyObjectId());
    dependencies.push_back(dummy_ref);
|
|||
TaskSpecification() {}
|
||||
|
||||
/// Construct from a protobuf message object.
|
||||
/// The input message will be **copied** into this object.
|
||||
/// The input message will be copied/moved into this object.
|
||||
///
|
||||
/// \param message The protobuf message.
|
||||
explicit TaskSpecification(rpc::TaskSpec message) : MessageWrapper(message) {
|
||||
explicit TaskSpecification(rpc::TaskSpec &&message)
|
||||
: MessageWrapper(std::move(message)) {
|
||||
ComputeResources();
|
||||
}
|
||||
|
||||
explicit TaskSpecification(const rpc::TaskSpec &message) : MessageWrapper(message) {
|
||||
ComputeResources();
|
||||
}
|
||||
|
||||
|
@ -127,9 +132,10 @@ class TaskSpecification : public MessageWrapper<rpc::TaskSpec> {
|
|||
|
||||
/// Return the dependencies of this task. This is recomputed each time, so it can
|
||||
/// be used if the task spec is mutated.
|
||||
///
|
||||
/// \param add_dummy_dependency whether to add a dummy object in the returned objects.
|
||||
/// \return The recomputed dependencies for the task.
|
||||
std::vector<rpc::ObjectReference> GetDependencies() const;
|
||||
std::vector<rpc::ObjectReference> GetDependencies(
|
||||
bool add_dummy_dependency = true) const;
|
||||
|
||||
std::string GetDebuggerBreakpoint() const;
|
||||
|
||||
|
|
|
@@ -2219,7 +2219,7 @@ void CoreWorker::HandlePushTask(const rpc::PushTaskRequest &request,
  // execution service.
  if (request.task_spec().type() == TaskType::ACTOR_TASK) {
    task_execution_service_.post(
-       [=] {
+       [this, request, reply, send_reply_callback = std::move(send_reply_callback)] {
          // We have posted an exit task onto the main event loop,
          // so shouldn't bother executing any further work.
          if (exiting_) return;
@@ -42,12 +42,12 @@ TEST(SchedulingQueueTest, TestInOrder) {
  ActorSchedulingQueue queue(io_service, waiter);
  int n_ok = 0;
  int n_rej = 0;
- auto fn_ok = [&n_ok]() { n_ok++; };
- auto fn_rej = [&n_rej]() { n_rej++; };
- queue.Add(0, -1, fn_ok, fn_rej);
- queue.Add(1, -1, fn_ok, fn_rej);
- queue.Add(2, -1, fn_ok, fn_rej);
- queue.Add(3, -1, fn_ok, fn_rej);
+ auto fn_ok = [&n_ok](rpc::SendReplyCallback callback) { n_ok++; };
+ auto fn_rej = [&n_rej](rpc::SendReplyCallback callback) { n_rej++; };
+ queue.Add(0, -1, fn_ok, fn_rej, nullptr);
+ queue.Add(1, -1, fn_ok, fn_rej, nullptr);
+ queue.Add(2, -1, fn_ok, fn_rej, nullptr);
+ queue.Add(3, -1, fn_ok, fn_rej, nullptr);
  io_service.run();
  ASSERT_EQ(n_ok, 4);
  ASSERT_EQ(n_rej, 0);

@@ -62,12 +62,12 @@ TEST(SchedulingQueueTest, TestWaitForObjects) {
  ActorSchedulingQueue queue(io_service, waiter);
  int n_ok = 0;
  int n_rej = 0;
- auto fn_ok = [&n_ok]() { n_ok++; };
- auto fn_rej = [&n_rej]() { n_rej++; };
- queue.Add(0, -1, fn_ok, fn_rej);
- queue.Add(1, -1, fn_ok, fn_rej, TaskID::Nil(), ObjectIdsToRefs({obj1}));
- queue.Add(2, -1, fn_ok, fn_rej, TaskID::Nil(), ObjectIdsToRefs({obj2}));
- queue.Add(3, -1, fn_ok, fn_rej, TaskID::Nil(), ObjectIdsToRefs({obj3}));
+ auto fn_ok = [&n_ok](rpc::SendReplyCallback callback) { n_ok++; };
+ auto fn_rej = [&n_rej](rpc::SendReplyCallback callback) { n_rej++; };
+ queue.Add(0, -1, fn_ok, fn_rej, nullptr);
+ queue.Add(1, -1, fn_ok, fn_rej, nullptr, TaskID::Nil(), ObjectIdsToRefs({obj1}));
+ queue.Add(2, -1, fn_ok, fn_rej, nullptr, TaskID::Nil(), ObjectIdsToRefs({obj2}));
+ queue.Add(3, -1, fn_ok, fn_rej, nullptr, TaskID::Nil(), ObjectIdsToRefs({obj3}));
  ASSERT_EQ(n_ok, 1);

  waiter.Complete(0);

@@ -87,10 +87,10 @@ TEST(SchedulingQueueTest, TestWaitForObjectsNotSubjectToSeqTimeout) {
  ActorSchedulingQueue queue(io_service, waiter);
  int n_ok = 0;
  int n_rej = 0;
- auto fn_ok = [&n_ok]() { n_ok++; };
- auto fn_rej = [&n_rej]() { n_rej++; };
- queue.Add(0, -1, fn_ok, fn_rej);
- queue.Add(1, -1, fn_ok, fn_rej, TaskID::Nil(), ObjectIdsToRefs({obj1}));
+ auto fn_ok = [&n_ok](rpc::SendReplyCallback callback) { n_ok++; };
+ auto fn_rej = [&n_rej](rpc::SendReplyCallback callback) { n_rej++; };
+ queue.Add(0, -1, fn_ok, fn_rej, nullptr);
+ queue.Add(1, -1, fn_ok, fn_rej, nullptr, TaskID::Nil(), ObjectIdsToRefs({obj1}));
  ASSERT_EQ(n_ok, 1);
  io_service.run();
  ASSERT_EQ(n_rej, 0);

@@ -104,12 +104,12 @@ TEST(SchedulingQueueTest, TestOutOfOrder) {
  ActorSchedulingQueue queue(io_service, waiter);
  int n_ok = 0;
  int n_rej = 0;
- auto fn_ok = [&n_ok]() { n_ok++; };
- auto fn_rej = [&n_rej]() { n_rej++; };
- queue.Add(2, -1, fn_ok, fn_rej);
- queue.Add(0, -1, fn_ok, fn_rej);
- queue.Add(3, -1, fn_ok, fn_rej);
- queue.Add(1, -1, fn_ok, fn_rej);
+ auto fn_ok = [&n_ok](rpc::SendReplyCallback callback) { n_ok++; };
+ auto fn_rej = [&n_rej](rpc::SendReplyCallback callback) { n_rej++; };
+ queue.Add(2, -1, fn_ok, fn_rej, nullptr);
+ queue.Add(0, -1, fn_ok, fn_rej, nullptr);
+ queue.Add(3, -1, fn_ok, fn_rej, nullptr);
+ queue.Add(1, -1, fn_ok, fn_rej, nullptr);
  io_service.run();
  ASSERT_EQ(n_ok, 4);
  ASSERT_EQ(n_rej, 0);

@@ -121,18 +121,18 @@ TEST(SchedulingQueueTest, TestSeqWaitTimeout) {
  ActorSchedulingQueue queue(io_service, waiter);
  int n_ok = 0;
  int n_rej = 0;
- auto fn_ok = [&n_ok]() { n_ok++; };
- auto fn_rej = [&n_rej]() { n_rej++; };
- queue.Add(2, -1, fn_ok, fn_rej);
- queue.Add(0, -1, fn_ok, fn_rej);
- queue.Add(3, -1, fn_ok, fn_rej);
+ auto fn_ok = [&n_ok](rpc::SendReplyCallback callback) { n_ok++; };
+ auto fn_rej = [&n_rej](rpc::SendReplyCallback callback) { n_rej++; };
+ queue.Add(2, -1, fn_ok, fn_rej, nullptr);
+ queue.Add(0, -1, fn_ok, fn_rej, nullptr);
+ queue.Add(3, -1, fn_ok, fn_rej, nullptr);
  ASSERT_EQ(n_ok, 1);
  ASSERT_EQ(n_rej, 0);
  io_service.run(); // immediately triggers timeout
  ASSERT_EQ(n_ok, 1);
  ASSERT_EQ(n_rej, 2);
- queue.Add(4, -1, fn_ok, fn_rej);
- queue.Add(5, -1, fn_ok, fn_rej);
+ queue.Add(4, -1, fn_ok, fn_rej, nullptr);
+ queue.Add(5, -1, fn_ok, fn_rej, nullptr);
  ASSERT_EQ(n_ok, 3);
  ASSERT_EQ(n_rej, 2);
}

@@ -143,11 +143,11 @@ TEST(SchedulingQueueTest, TestSkipAlreadyProcessedByClient) {
  ActorSchedulingQueue queue(io_service, waiter);
  int n_ok = 0;
  int n_rej = 0;
- auto fn_ok = [&n_ok]() { n_ok++; };
- auto fn_rej = [&n_rej]() { n_rej++; };
- queue.Add(2, 2, fn_ok, fn_rej);
- queue.Add(3, 2, fn_ok, fn_rej);
- queue.Add(1, 2, fn_ok, fn_rej);
+ auto fn_ok = [&n_ok](rpc::SendReplyCallback callback) { n_ok++; };
+ auto fn_rej = [&n_rej](rpc::SendReplyCallback callback) { n_rej++; };
+ queue.Add(2, 2, fn_ok, fn_rej, nullptr);
+ queue.Add(3, 2, fn_ok, fn_rej, nullptr);
+ queue.Add(1, 2, fn_ok, fn_rej, nullptr);
  io_service.run();
  ASSERT_EQ(n_ok, 1);
  ASSERT_EQ(n_rej, 2);

@@ -158,13 +158,13 @@ TEST(SchedulingQueueTest, TestCancelQueuedTask) {
  ASSERT_TRUE(queue->TaskQueueEmpty());
  int n_ok = 0;
  int n_rej = 0;
- auto fn_ok = [&n_ok]() { n_ok++; };
- auto fn_rej = [&n_rej]() { n_rej++; };
- queue->Add(-1, -1, fn_ok, fn_rej);
- queue->Add(-1, -1, fn_ok, fn_rej);
- queue->Add(-1, -1, fn_ok, fn_rej);
- queue->Add(-1, -1, fn_ok, fn_rej);
- queue->Add(-1, -1, fn_ok, fn_rej);
+ auto fn_ok = [&n_ok](rpc::SendReplyCallback callback) { n_ok++; };
+ auto fn_rej = [&n_rej](rpc::SendReplyCallback callback) { n_rej++; };
+ queue->Add(-1, -1, fn_ok, fn_rej, nullptr);
+ queue->Add(-1, -1, fn_ok, fn_rej, nullptr);
+ queue->Add(-1, -1, fn_ok, fn_rej, nullptr);
+ queue->Add(-1, -1, fn_ok, fn_rej, nullptr);
+ queue->Add(-1, -1, fn_ok, fn_rej, nullptr);
  ASSERT_TRUE(queue->CancelTaskIfFound(TaskID::Nil()));
  ASSERT_FALSE(queue->TaskQueueEmpty());
  queue->ScheduleRequests();
@@ -59,14 +59,16 @@ void CoreWorkerDirectActorTaskSubmitter::KillActor(const ActorID &actor_id,
}

Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spec) {
- RAY_LOG(DEBUG) << "Submitting task " << task_spec.TaskId();
+ auto task_id = task_spec.TaskId();
+ auto actor_id = task_spec.ActorId();
+ RAY_LOG(DEBUG) << "Submitting task " << task_id;
  RAY_CHECK(task_spec.IsActorTask());

  bool task_queued = false;
  uint64_t send_pos = 0;
  {
    absl::MutexLock lock(&mu_);
-   auto queue = client_queues_.find(task_spec.ActorId());
+   auto queue = client_queues_.find(actor_id);
    RAY_CHECK(queue != client_queues_.end());
    if (queue->second.state != rpc::ActorTableData::DEAD) {
      // We must fix the send order prior to resolving dependencies, which may

@@ -82,7 +84,6 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spe
  }

  if (task_queued) {
-   const auto actor_id = task_spec.ActorId();
    // We must release the lock before resolving the task dependencies since
    // the callback may get called in the same call stack.
    resolver_.ResolveDependencies(task_spec, [this, send_pos, actor_id]() {

@@ -99,7 +100,7 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spe
    });
  } else {
    // Do not hold the lock while calling into task_finisher_.
-   task_finisher_->MarkTaskCanceled(task_spec.TaskId());
+   task_finisher_->MarkTaskCanceled(task_id);
    std::shared_ptr<rpc::RayException> creation_task_exception = nullptr;
    {
      absl::MutexLock lock(&mu_);

@@ -109,9 +110,8 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spe
      auto status = Status::IOError("cancelling task of dead actor");
      // No need to increment the number of completed tasks since the actor is
      // dead.
-     RAY_UNUSED(!task_finisher_->PendingTaskFailed(task_spec.TaskId(),
-                                                   rpc::ErrorType::ACTOR_DIED, &status,
-                                                   creation_task_exception));
+     RAY_UNUSED(!task_finisher_->PendingTaskFailed(task_id, rpc::ErrorType::ACTOR_DIED,
+                                                   &status, creation_task_exception));
    }

    // If the task submission subsequently fails, then the client will receive
@@ -423,7 +423,10 @@ void CoreWorkerDirectTaskReceiver::HandleTask(
    const rpc::PushTaskRequest &request, rpc::PushTaskReply *reply,
    rpc::SendReplyCallback send_reply_callback) {
  RAY_CHECK(waiter_ != nullptr) << "Must call init() prior to use";
- const TaskSpecification task_spec(request.task_spec());
+ // Use `mutable_task_spec()` here as `task_spec()` returns a const reference
+ // which doesn't work with std::move.
+ TaskSpecification task_spec(
+     std::move(*(const_cast<rpc::PushTaskRequest &>(request).mutable_task_spec())));

  // If GCS server is restarted after sending an actor creation task to this core worker,
  // the restarted GCS server will send the same actor creation task to the core worker

@@ -455,7 +458,8 @@ void CoreWorkerDirectTaskReceiver::HandleTask(
    }
  }

- auto accept_callback = [this, reply, send_reply_callback, task_spec, resource_ids]() {
+ auto accept_callback = [this, reply, task_spec,
+                         resource_ids](rpc::SendReplyCallback send_reply_callback) {
    if (task_spec.GetMessage().skip_execution()) {
      send_reply_callback(Status::OK(), nullptr, nullptr);
      return;

@@ -522,11 +526,11 @@ void CoreWorkerDirectTaskReceiver::HandleTask(
    }
  };

- auto reject_callback = [send_reply_callback]() {
+ auto reject_callback = [](rpc::SendReplyCallback send_reply_callback) {
    send_reply_callback(Status::Invalid("client cancelled stale rpc"), nullptr, nullptr);
  };

- auto dependencies = task_spec.GetDependencies();
+ auto dependencies = task_spec.GetDependencies(false);

  if (task_spec.IsActorTask()) {
    auto it = actor_scheduling_queues_.find(task_spec.CallerWorkerId());

@@ -538,16 +542,15 @@ void CoreWorkerDirectTaskReceiver::HandleTask(
      it = result.first;
    }

-   // Pop the dummy actor dependency.
-   // TODO(swang): Remove this with legacy raylet code.
-   dependencies.pop_back();
    it->second->Add(request.sequence_number(), request.client_processed_up_to(),
-                   accept_callback, reject_callback, task_spec.TaskId(), dependencies);
+                   std::move(accept_callback), std::move(reject_callback),
+                   std::move(send_reply_callback), task_spec.TaskId(), dependencies);
  } else {
    // Add the normal task's callbacks to the non-actor scheduling queue.
-   normal_scheduling_queue_->Add(request.sequence_number(),
-                                 request.client_processed_up_to(), accept_callback,
-                                 reject_callback, task_spec.TaskId(), dependencies);
+   normal_scheduling_queue_->Add(
+       request.sequence_number(), request.client_processed_up_to(),
+       std::move(accept_callback), std::move(reject_callback),
+       std::move(send_reply_callback), task_spec.TaskId(), dependencies);
  }
}
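The comment in the first hunk above is the key subtlety: `std::move` applied through a const accessor still produces a copy, which is why the change reaches for `mutable_task_spec()` (via a `const_cast` of the request parameter). A tiny illustration with placeholder types, not Ray's `rpc::PushTaskRequest`:

```cpp
#include <string>
#include <utility>

// Placeholder request type for illustration only.
struct Request {
  std::string task_spec;
  const std::string &spec() const { return task_spec; }
  std::string *mutable_spec() { return &task_spec; }
};

std::string consume_via_const(const Request &req) {
  // std::move on a const reference yields a const rvalue, which cannot bind to the
  // move constructor, so this quietly falls back to a copy.
  return std::move(req.spec());
}

std::string consume_via_mutable(Request &req) {
  // Going through the mutable accessor lets the buffer actually be stolen.
  return std::string(std::move(*req.mutable_spec()));
}

int main() {
  Request a{std::string(1 << 20, 'x')};
  Request b{std::string(1 << 20, 'y')};
  std::string copied = consume_via_const(a);   // a.task_spec is left intact (copied)
  std::string moved = consume_via_mutable(b);  // b.task_spec is now empty (moved-from)
  return copied.size() == moved.size() ? 0 : 1;
}
```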
@@ -276,23 +276,26 @@ class CoreWorkerDirectActorTaskSubmitter
class InboundRequest {
 public:
  InboundRequest(){};
- InboundRequest(std::function<void()> accept_callback,
-                std::function<void()> reject_callback, TaskID task_id,
+ InboundRequest(std::function<void(rpc::SendReplyCallback)> accept_callback,
+                std::function<void(rpc::SendReplyCallback)> reject_callback,
+                rpc::SendReplyCallback send_reply_callback, TaskID task_id,
                 bool has_dependencies)
-     : accept_callback_(accept_callback),
-       reject_callback_(reject_callback),
+     : accept_callback_(std::move(accept_callback)),
+       reject_callback_(std::move(reject_callback)),
+       send_reply_callback_(std::move(send_reply_callback)),
        task_id(task_id),
        has_pending_dependencies_(has_dependencies) {}

- void Accept() { accept_callback_(); }
- void Cancel() { reject_callback_(); }
+ void Accept() { accept_callback_(std::move(send_reply_callback_)); }
+ void Cancel() { reject_callback_(std::move(send_reply_callback_)); }
  bool CanExecute() const { return !has_pending_dependencies_; }
  ray::TaskID TaskID() const { return task_id; }
  void MarkDependenciesSatisfied() { has_pending_dependencies_ = false; }

 private:
- std::function<void()> accept_callback_;
- std::function<void()> reject_callback_;
+ std::function<void(rpc::SendReplyCallback)> accept_callback_;
+ std::function<void(rpc::SendReplyCallback)> reject_callback_;
+ rpc::SendReplyCallback send_reply_callback_;
  ray::TaskID task_id;
  bool has_pending_dependencies_;
};

@@ -372,8 +375,10 @@ class BoundedExecutor {
class SchedulingQueue {
 public:
  virtual void Add(int64_t seq_no, int64_t client_processed_up_to,
-                  std::function<void()> accept_request,
-                  std::function<void()> reject_request, TaskID task_id = TaskID::Nil(),
+                  std::function<void(rpc::SendReplyCallback)> accept_request,
+                  std::function<void(rpc::SendReplyCallback)> reject_request,
+                  rpc::SendReplyCallback send_reply_callback,
+                  TaskID task_id = TaskID::Nil(),
                   const std::vector<rpc::ObjectReference> &dependencies = {}) = 0;
  virtual void ScheduleRequests() = 0;
  virtual bool TaskQueueEmpty() const = 0;

@@ -402,8 +407,9 @@ class ActorSchedulingQueue : public SchedulingQueue {

  /// Add a new actor task's callbacks to the worker queue.
  void Add(int64_t seq_no, int64_t client_processed_up_to,
-          std::function<void()> accept_request, std::function<void()> reject_request,
-          TaskID task_id = TaskID::Nil(),
+          std::function<void(rpc::SendReplyCallback)> accept_request,
+          std::function<void(rpc::SendReplyCallback)> reject_request,
+          rpc::SendReplyCallback send_reply_callback, TaskID task_id = TaskID::Nil(),
           const std::vector<rpc::ObjectReference> &dependencies = {}) {
    // A seq_no of -1 means no ordering constraint. Actor tasks must be executed in order.
    RAY_CHECK(seq_no != -1);

@@ -416,7 +422,8 @@ class ActorSchedulingQueue : public SchedulingQueue {
    }
    RAY_LOG(DEBUG) << "Enqueue " << seq_no << " cur seqno " << next_seq_no_;
    pending_actor_tasks_[seq_no] =
-       InboundRequest(accept_request, reject_request, task_id, dependencies.size() > 0);
+       InboundRequest(std::move(accept_request), std::move(reject_request),
+                      std::move(send_reply_callback), task_id, dependencies.size() > 0);
    if (dependencies.size() > 0) {
      waiter_.Wait(dependencies, [seq_no, this]() {
        RAY_CHECK(boost::this_thread::get_id() == main_thread_id_);

@@ -541,15 +548,17 @@ class NormalSchedulingQueue : public SchedulingQueue {

  /// Add a new task's callbacks to the worker queue.
  void Add(int64_t seq_no, int64_t client_processed_up_to,
-          std::function<void()> accept_request, std::function<void()> reject_request,
-          TaskID task_id = TaskID::Nil(),
+          std::function<void(rpc::SendReplyCallback)> accept_request,
+          std::function<void(rpc::SendReplyCallback)> reject_request,
+          rpc::SendReplyCallback send_reply_callback, TaskID task_id = TaskID::Nil(),
           const std::vector<rpc::ObjectReference> &dependencies = {}) {
    absl::MutexLock lock(&mu_);
    // Normal tasks should not have ordering constraints.
    RAY_CHECK(seq_no == -1);
    // Create a InboundRequest object for the new task, and add it to the queue.
    pending_normal_tasks_.push_back(
-       InboundRequest(accept_request, reject_request, task_id, dependencies.size() > 0));
+       InboundRequest(std::move(accept_request), std::move(reject_request),
+                      std::move(send_reply_callback), task_id, dependencies.size() > 0));
  }

  // Search for an InboundRequest associated with the task that we are trying to cancel.
@@ -66,7 +66,8 @@ class ClientCallImpl : public ClientCall {
  ///
  /// \param[in] callback The callback function to handle the reply.
  explicit ClientCallImpl(const ClientCallback<Reply> &callback, std::string call_name)
-     : callback_(callback), call_name_(std::move(call_name)) {}
+     : callback_(std::move(const_cast<ClientCallback<Reply> &>(callback))),
+       call_name_(std::move(call_name)) {}

  Status GetStatus() override {
    absl::MutexLock lock(&mutex_);
@@ -285,7 +285,9 @@ class CoreWorkerClient : public std::enable_shared_from_this<CoreWorkerClient>,

    {
      absl::MutexLock lock(&mutex_);
-     send_queue_.push_back(std::make_pair(std::move(request), callback));
+     send_queue_.push_back(std::make_pair(
+         std::move(request),
+         std::move(const_cast<ClientCallback<PushTaskReply> &>(callback))));
    }
    SendRequests();
  }

@@ -311,13 +313,13 @@ class CoreWorkerClient : public std::enable_shared_from_this<CoreWorkerClient>,
      send_queue_.pop_front();

      auto request = std::move(pair.first);
-     auto callback = pair.second;
      int64_t task_size = RequestSizeInBytes(*request);
      int64_t seq_no = request->sequence_number();
      request->set_client_processed_up_to(max_finished_seq_no_);
      rpc_bytes_in_flight_ += task_size;

-     auto rpc_callback = [this, this_ptr, seq_no, task_size, callback](
+     auto rpc_callback = [this, this_ptr, seq_no, task_size,
+                          callback = std::move(pair.second)](
                              Status status, const rpc::PushTaskReply &reply) {
        {
          absl::MutexLock lock(&mutex_);

@@ -331,8 +333,8 @@ class CoreWorkerClient : public std::enable_shared_from_this<CoreWorkerClient>,
        callback(status, reply);
      };

-     RAY_UNUSED(INVOKE_RPC_CALL(CoreWorkerService, PushTask, *request, rpc_callback,
-                                grpc_client_));
+     RAY_UNUSED(INVOKE_RPC_CALL(CoreWorkerService, PushTask, *request,
+                                std::move(rpc_callback), grpc_client_));
    }

    if (!send_queue_.empty()) {