mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
Optimize lambda copy to improve direct call performance. (#15036)
This commit is contained in:
parent
4ed7a14e23
commit
0ad0839265
17 changed files with 513 additions and 98 deletions
|
@ -13,6 +13,7 @@ all_modules = [
|
||||||
"api",
|
"api",
|
||||||
"runtime",
|
"runtime",
|
||||||
"test",
|
"test",
|
||||||
|
"performance_test",
|
||||||
]
|
]
|
||||||
|
|
||||||
java_import(
|
java_import(
|
||||||
|
@ -106,12 +107,26 @@ define_java_module(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
define_java_module(
|
||||||
|
name = "performance_test",
|
||||||
|
deps = [
|
||||||
|
":io_ray_ray_api",
|
||||||
|
":io_ray_ray_runtime",
|
||||||
|
"@maven//:com_google_code_gson_gson",
|
||||||
|
"@maven//:com_google_guava_guava",
|
||||||
|
"@maven//:commons_io_commons_io",
|
||||||
|
"@maven//:org_apache_commons_commons_lang3",
|
||||||
|
"@maven//:org_slf4j_slf4j_api",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
java_binary(
|
java_binary(
|
||||||
name = "all_tests",
|
name = "all_tests",
|
||||||
args = ["java/testng.xml"],
|
args = ["java/testng.xml"],
|
||||||
data = ["testng.xml"],
|
data = ["testng.xml"],
|
||||||
main_class = "org.testng.TestNG",
|
main_class = "org.testng.TestNG",
|
||||||
runtime_deps = [
|
runtime_deps = [
|
||||||
|
":io_ray_ray_performance_test",
|
||||||
":io_ray_ray_runtime_test",
|
":io_ray_ray_runtime_test",
|
||||||
":io_ray_ray_test",
|
":io_ray_ray_test",
|
||||||
],
|
],
|
||||||
|
@ -207,6 +222,7 @@ genrule(
|
||||||
cp -f $(location //java:io_ray_ray_api_pom) "$$WORK_DIR/java/api/pom.xml"
|
cp -f $(location //java:io_ray_ray_api_pom) "$$WORK_DIR/java/api/pom.xml"
|
||||||
cp -f $(location //java:io_ray_ray_runtime_pom) "$$WORK_DIR/java/runtime/pom.xml"
|
cp -f $(location //java:io_ray_ray_runtime_pom) "$$WORK_DIR/java/runtime/pom.xml"
|
||||||
cp -f $(location //java:io_ray_ray_test_pom) "$$WORK_DIR/java/test/pom.xml"
|
cp -f $(location //java:io_ray_ray_test_pom) "$$WORK_DIR/java/test/pom.xml"
|
||||||
|
cp -f $(location //java:io_ray_ray_performance_test_pom) "$$WORK_DIR/java/performance_test/pom.xml"
|
||||||
date > $@
|
date > $@
|
||||||
""",
|
""",
|
||||||
local = 1,
|
local = 1,
|
||||||
|
|
56
java/performance_test/pom.xml
Executable file
56
java/performance_test/pom.xml
Executable file
|
@ -0,0 +1,56 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<!-- This file is auto-generated by Bazel from pom_template.xml, do not modify it. -->
|
||||||
|
<project xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||||
|
xmlns="http://maven.apache.org/POM/4.0.0"
|
||||||
|
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
|
||||||
|
<parent>
|
||||||
|
<groupId>io.ray</groupId>
|
||||||
|
<artifactId>ray-superpom</artifactId>
|
||||||
|
<version>1.0.0-SNAPSHOT</version>
|
||||||
|
</parent>
|
||||||
|
<modelVersion>4.0.0</modelVersion>
|
||||||
|
|
||||||
|
<artifactId>ray-performance-test</artifactId>
|
||||||
|
<name>java performance test cases for ray</name>
|
||||||
|
<description>java performance test cases for ray</description>
|
||||||
|
|
||||||
|
<packaging>jar</packaging>
|
||||||
|
|
||||||
|
<dependencies>
|
||||||
|
<dependency>
|
||||||
|
<groupId>io.ray</groupId>
|
||||||
|
<artifactId>ray-api</artifactId>
|
||||||
|
<version>${project.version}</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>io.ray</groupId>
|
||||||
|
<artifactId>ray-runtime</artifactId>
|
||||||
|
<version>${project.version}</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.google.code.gson</groupId>
|
||||||
|
<artifactId>gson</artifactId>
|
||||||
|
<version>2.8.5</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.google.guava</groupId>
|
||||||
|
<artifactId>guava</artifactId>
|
||||||
|
<version>27.0.1-jre</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>commons-io</groupId>
|
||||||
|
<artifactId>commons-io</artifactId>
|
||||||
|
<version>2.5</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.apache.commons</groupId>
|
||||||
|
<artifactId>commons-lang3</artifactId>
|
||||||
|
<version>3.4</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.slf4j</groupId>
|
||||||
|
<artifactId>slf4j-api</artifactId>
|
||||||
|
<version>1.7.25</version>
|
||||||
|
</dependency>
|
||||||
|
</dependencies>
|
||||||
|
</project>
|
32
java/performance_test/pom_template.xml
Normal file
32
java/performance_test/pom_template.xml
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
{auto_gen_header}
|
||||||
|
<project xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||||
|
xmlns="http://maven.apache.org/POM/4.0.0"
|
||||||
|
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
|
||||||
|
<parent>
|
||||||
|
<groupId>io.ray</groupId>
|
||||||
|
<artifactId>ray-superpom</artifactId>
|
||||||
|
<version>1.0.0-SNAPSHOT</version>
|
||||||
|
</parent>
|
||||||
|
<modelVersion>4.0.0</modelVersion>
|
||||||
|
|
||||||
|
<artifactId>ray-performance-test</artifactId>
|
||||||
|
<name>java performance test cases for ray</name>
|
||||||
|
<description>java performance test cases for ray</description>
|
||||||
|
|
||||||
|
<packaging>jar</packaging>
|
||||||
|
|
||||||
|
<dependencies>
|
||||||
|
<dependency>
|
||||||
|
<groupId>io.ray</groupId>
|
||||||
|
<artifactId>ray-api</artifactId>
|
||||||
|
<version>${project.version}</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>io.ray</groupId>
|
||||||
|
<artifactId>ray-runtime</artifactId>
|
||||||
|
<version>${project.version}</version>
|
||||||
|
</dependency>
|
||||||
|
{generated_bzl_deps}
|
||||||
|
</dependencies>
|
||||||
|
</project>
|
|
@ -0,0 +1,40 @@
|
||||||
|
package io.ray.performancetest;
|
||||||
|
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
|
|
||||||
|
public class Receiver {
|
||||||
|
private int value = 0;
|
||||||
|
|
||||||
|
public Receiver() {}
|
||||||
|
|
||||||
|
public boolean ping() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void noArgsNoReturn() {
|
||||||
|
value += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int noArgsHasReturn() {
|
||||||
|
value += 1;
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void bytesNoReturn(byte[] data) {
|
||||||
|
value += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int bytesHasReturn(byte[] data) {
|
||||||
|
value += 1;
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void byteBufferNoReturn(ByteBuffer data) {
|
||||||
|
value += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int byteBufferHasReturn(ByteBuffer data) {
|
||||||
|
value += 1;
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,117 @@
|
||||||
|
package io.ray.performancetest;
|
||||||
|
|
||||||
|
import com.google.common.base.Preconditions;
|
||||||
|
import io.ray.api.ActorHandle;
|
||||||
|
import io.ray.api.ObjectRef;
|
||||||
|
import io.ray.api.Ray;
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Random;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
|
public class Source {
|
||||||
|
private static final int BATCH_SIZE;
|
||||||
|
|
||||||
|
private static final Logger LOGGER = LoggerFactory.getLogger(Source.class);
|
||||||
|
|
||||||
|
private final List<ActorHandle<Receiver>> receivers;
|
||||||
|
|
||||||
|
static {
|
||||||
|
String batchSizeString = System.getenv().get("PERF_TEST_BATCH_SIZE");
|
||||||
|
if (batchSizeString != null) {
|
||||||
|
BATCH_SIZE = Integer.valueOf(batchSizeString);
|
||||||
|
} else {
|
||||||
|
BATCH_SIZE = 1000;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public Source(List<ActorHandle<Receiver>> receivers) {
|
||||||
|
this.receivers = receivers;
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean startTest(
|
||||||
|
boolean hasReturn, boolean ignoreReturn, int argSize, boolean useDirectByteBuffer) {
|
||||||
|
LOGGER.info("Source startTest");
|
||||||
|
byte[] bytes = null;
|
||||||
|
ByteBuffer buffer = null;
|
||||||
|
if (argSize > 0) {
|
||||||
|
bytes = new byte[argSize];
|
||||||
|
new Random().nextBytes(bytes);
|
||||||
|
buffer = ByteBuffer.wrap(bytes);
|
||||||
|
} else {
|
||||||
|
Preconditions.checkState(!useDirectByteBuffer);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for actors to be created.
|
||||||
|
for (ActorHandle<Receiver> receiver : receivers) {
|
||||||
|
receiver.task(Receiver::ping).remote().get();
|
||||||
|
}
|
||||||
|
|
||||||
|
LOGGER.info(
|
||||||
|
"Started executing tasks, useDirectByteBuffer: {}, argSize: {}, has return: {}",
|
||||||
|
useDirectByteBuffer,
|
||||||
|
argSize,
|
||||||
|
hasReturn);
|
||||||
|
|
||||||
|
List<List<ObjectRef<Integer>>> returnObjects = new ArrayList<>();
|
||||||
|
returnObjects.add(new ArrayList<>());
|
||||||
|
returnObjects.add(new ArrayList<>());
|
||||||
|
|
||||||
|
long startTime = System.currentTimeMillis();
|
||||||
|
int numTasks = 0;
|
||||||
|
long lastReport = 0;
|
||||||
|
long totalTime = 0;
|
||||||
|
long batchCount = 0;
|
||||||
|
while (true) {
|
||||||
|
numTasks++;
|
||||||
|
boolean batchEnd = numTasks % BATCH_SIZE == 0;
|
||||||
|
for (ActorHandle<Receiver> receiver : receivers) {
|
||||||
|
if (hasReturn || batchEnd) {
|
||||||
|
ObjectRef<Integer> returnObject;
|
||||||
|
if (useDirectByteBuffer) {
|
||||||
|
returnObject = receiver.task(Receiver::byteBufferHasReturn, buffer).remote();
|
||||||
|
} else if (argSize > 0) {
|
||||||
|
returnObject = receiver.task(Receiver::bytesHasReturn, bytes).remote();
|
||||||
|
} else {
|
||||||
|
returnObject = receiver.task(Receiver::noArgsHasReturn).remote();
|
||||||
|
}
|
||||||
|
returnObjects.get(1).add(returnObject);
|
||||||
|
} else {
|
||||||
|
if (useDirectByteBuffer) {
|
||||||
|
receiver.task(Receiver::byteBufferNoReturn, buffer).remote();
|
||||||
|
} else if (argSize > 0) {
|
||||||
|
receiver.task(Receiver::bytesNoReturn, bytes).remote();
|
||||||
|
} else {
|
||||||
|
receiver.task(Receiver::noArgsNoReturn).remote();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (batchEnd) {
|
||||||
|
batchCount++;
|
||||||
|
long getBeginTs = System.currentTimeMillis();
|
||||||
|
Ray.get(returnObjects.get(0));
|
||||||
|
long rt = System.currentTimeMillis() - getBeginTs;
|
||||||
|
totalTime += rt;
|
||||||
|
returnObjects.set(0, returnObjects.get(1));
|
||||||
|
returnObjects.set(1, new ArrayList<>());
|
||||||
|
|
||||||
|
long elapsedTime = System.currentTimeMillis() - startTime;
|
||||||
|
if (elapsedTime / 60000 > lastReport) {
|
||||||
|
lastReport = elapsedTime / 60000;
|
||||||
|
LOGGER.info(
|
||||||
|
"Finished executing {} tasks in {} ms, useDirectByteBuffer: {}, argSize: {}, "
|
||||||
|
+ "has return: {}, avg get rt: {}",
|
||||||
|
numTasks,
|
||||||
|
elapsedTime,
|
||||||
|
useDirectByteBuffer,
|
||||||
|
argSize,
|
||||||
|
hasReturn,
|
||||||
|
totalTime / (float) batchCount);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,76 @@
|
||||||
|
package io.ray.performancetest.test;
|
||||||
|
|
||||||
|
import com.google.common.base.Preconditions;
|
||||||
|
import io.ray.api.ActorHandle;
|
||||||
|
import io.ray.api.ObjectRef;
|
||||||
|
import io.ray.api.Ray;
|
||||||
|
import io.ray.performancetest.Receiver;
|
||||||
|
import io.ray.performancetest.Source;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
|
public class ActorPerformanceTestBase {
|
||||||
|
|
||||||
|
private static final Logger LOGGER = LoggerFactory.getLogger(ActorPerformanceTestBase.class);
|
||||||
|
|
||||||
|
public static void run(
|
||||||
|
String[] args,
|
||||||
|
int[] layers,
|
||||||
|
int[] actorsPerLayer,
|
||||||
|
boolean hasReturn,
|
||||||
|
boolean ignoreReturn,
|
||||||
|
int argSize,
|
||||||
|
boolean useDirectByteBuffer,
|
||||||
|
int numJavaWorkerPerProcess) {
|
||||||
|
System.setProperty(
|
||||||
|
"ray.job.num-java-workers-per-process", String.valueOf(numJavaWorkerPerProcess));
|
||||||
|
Ray.init();
|
||||||
|
try {
|
||||||
|
// TODO: Support more layers.
|
||||||
|
Preconditions.checkState(layers.length == 2);
|
||||||
|
Preconditions.checkState(actorsPerLayer.length == layers.length);
|
||||||
|
for (int i = 0; i < layers.length; i++) {
|
||||||
|
Preconditions.checkState(layers[i] > 0);
|
||||||
|
Preconditions.checkState(actorsPerLayer[i] > 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
List<ActorHandle<Receiver>> receivers = new ArrayList<>();
|
||||||
|
for (int i = 0; i < layers[1]; i++) {
|
||||||
|
int nodeIndex = layers[0] + i;
|
||||||
|
for (int j = 0; j < actorsPerLayer[1]; j++) {
|
||||||
|
receivers.add(Ray.actor(Receiver::new).remote());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
List<ActorHandle<Source>> sources = new ArrayList<>();
|
||||||
|
for (int i = 0; i < layers[0]; i++) {
|
||||||
|
int nodeIndex = i;
|
||||||
|
for (int j = 0; j < actorsPerLayer[0]; j++) {
|
||||||
|
sources.add(Ray.actor(Source::new, receivers).remote());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
List<ObjectRef<Boolean>> results =
|
||||||
|
sources.stream()
|
||||||
|
.map(
|
||||||
|
source ->
|
||||||
|
source
|
||||||
|
.task(
|
||||||
|
Source::startTest,
|
||||||
|
hasReturn,
|
||||||
|
ignoreReturn,
|
||||||
|
argSize,
|
||||||
|
useDirectByteBuffer)
|
||||||
|
.remote())
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
|
Ray.get(results);
|
||||||
|
} catch (Throwable e) {
|
||||||
|
LOGGER.error("Run test failed.", e);
|
||||||
|
throw e;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
package io.ray.performancetest.test;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 1-to-1 ray call, one receiving actor and one sending actor, test the throughput of the engine
|
||||||
|
* itself.
|
||||||
|
*/
|
||||||
|
public class ActorPerformanceTestCase1 {
|
||||||
|
|
||||||
|
public static void main(String[] args) {
|
||||||
|
final int[] layers = new int[] {1, 1};
|
||||||
|
final int[] actorsPerLayer = new int[] {1, 1};
|
||||||
|
final boolean hasReturn = false;
|
||||||
|
final int argSize = 0;
|
||||||
|
final boolean useDirectByteBuffer = false;
|
||||||
|
final boolean ignoreReturn = false;
|
||||||
|
final int numJavaWorkerPerProcess = 1;
|
||||||
|
ActorPerformanceTestBase.run(
|
||||||
|
args,
|
||||||
|
layers,
|
||||||
|
actorsPerLayer,
|
||||||
|
hasReturn,
|
||||||
|
ignoreReturn,
|
||||||
|
argSize,
|
||||||
|
useDirectByteBuffer,
|
||||||
|
numJavaWorkerPerProcess);
|
||||||
|
}
|
||||||
|
}
|
25
java/test.sh
25
java/test.sh
|
@ -41,6 +41,24 @@ run_testng() {
|
||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
|
|
||||||
|
run_timeout() {
|
||||||
|
local pid
|
||||||
|
timeout=$1
|
||||||
|
shift 1
|
||||||
|
"$@" &
|
||||||
|
pid=$!
|
||||||
|
sleep "$timeout"
|
||||||
|
if ps -p $pid > /dev/null
|
||||||
|
then
|
||||||
|
echo "run_timeout process exists, kill it."
|
||||||
|
kill -9 $pid
|
||||||
|
else
|
||||||
|
echo "run_timeout process not exist."
|
||||||
|
cat /tmp/ray/session_latest/logs/java-core-driver-*$pid*
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
pushd "$ROOT_DIR"/..
|
pushd "$ROOT_DIR"/..
|
||||||
echo "Build java maven deps."
|
echo "Build java maven deps."
|
||||||
bazel build //java:gen_maven_deps
|
bazel build //java:gen_maven_deps
|
||||||
|
@ -114,3 +132,10 @@ mvn -Dorg.slf4j.simpleLogger.defaultLogLevel=WARN clean install -DskipTests -Dch
|
||||||
# Ensure mvn test works
|
# Ensure mvn test works
|
||||||
mvn test -pl test -Dtest="io.ray.test.HelloWorldTest"
|
mvn test -pl test -Dtest="io.ray.test.HelloWorldTest"
|
||||||
popd
|
popd
|
||||||
|
|
||||||
|
pushd "$ROOT_DIR"
|
||||||
|
echo "Running performance test."
|
||||||
|
run_timeout 60 java -cp "$ROOT_DIR"/../bazel-bin/java/all_tests_deploy.jar io.ray.performancetest.test.ActorPerformanceTestCase1
|
||||||
|
# The performance process may be killed by run_timeout, so clear ray here.
|
||||||
|
ray stop
|
||||||
|
popd
|
||||||
|
|
|
@ -35,7 +35,10 @@ class MessageWrapper {
|
||||||
/// The input message will be **copied** into this object.
|
/// The input message will be **copied** into this object.
|
||||||
///
|
///
|
||||||
/// \param message The protobuf message.
|
/// \param message The protobuf message.
|
||||||
explicit MessageWrapper(const Message message)
|
explicit MessageWrapper(const Message &message)
|
||||||
|
: message_(std::make_shared<Message>(message)) {}
|
||||||
|
|
||||||
|
explicit MessageWrapper(Message &&message)
|
||||||
: message_(std::make_shared<Message>(std::move(message))) {}
|
: message_(std::make_shared<Message>(std::move(message))) {}
|
||||||
|
|
||||||
/// Construct from a protobuf message shared_ptr.
|
/// Construct from a protobuf message shared_ptr.
|
||||||
|
@ -115,7 +118,8 @@ inline std::vector<ID> IdVectorFromProtobuf(
|
||||||
|
|
||||||
/// Converts a Protobuf map to a `unordered_map`.
|
/// Converts a Protobuf map to a `unordered_map`.
|
||||||
template <class K, class V>
|
template <class K, class V>
|
||||||
inline std::unordered_map<K, V> MapFromProtobuf(::google::protobuf::Map<K, V> pb_map) {
|
inline std::unordered_map<K, V> MapFromProtobuf(
|
||||||
|
const ::google::protobuf::Map<K, V> &pb_map) {
|
||||||
return std::unordered_map<K, V>(pb_map.begin(), pb_map.end());
|
return std::unordered_map<K, V>(pb_map.begin(), pb_map.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -53,24 +53,24 @@ bool TaskSpecification::PlacementGroupCaptureChildTasks() const {
|
||||||
}
|
}
|
||||||
|
|
||||||
void TaskSpecification::ComputeResources() {
|
void TaskSpecification::ComputeResources() {
|
||||||
auto required_resources = MapFromProtobuf(message_->required_resources());
|
auto &required_resources = message_->required_resources();
|
||||||
auto required_placement_resources =
|
|
||||||
MapFromProtobuf(message_->required_placement_resources());
|
|
||||||
if (required_placement_resources.empty()) {
|
|
||||||
required_placement_resources = required_resources;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (required_resources.empty()) {
|
if (required_resources.empty()) {
|
||||||
// A static nil object is used here to avoid allocating the empty object every time.
|
// A static nil object is used here to avoid allocating the empty object every time.
|
||||||
required_resources_ = ResourceSet::Nil();
|
required_resources_ = ResourceSet::Nil();
|
||||||
} else {
|
} else {
|
||||||
required_resources_.reset(new ResourceSet(required_resources));
|
required_resources_.reset(new ResourceSet(MapFromProtobuf(required_resources)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto &required_placement_resources = message_->required_placement_resources().empty()
|
||||||
|
? required_resources
|
||||||
|
: message_->required_placement_resources();
|
||||||
|
|
||||||
if (required_placement_resources.empty()) {
|
if (required_placement_resources.empty()) {
|
||||||
required_placement_resources_ = ResourceSet::Nil();
|
required_placement_resources_ = ResourceSet::Nil();
|
||||||
} else {
|
} else {
|
||||||
required_placement_resources_.reset(new ResourceSet(required_placement_resources));
|
required_placement_resources_.reset(
|
||||||
|
new ResourceSet(MapFromProtobuf(required_placement_resources)));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!IsActorTask()) {
|
if (!IsActorTask()) {
|
||||||
|
@ -174,14 +174,15 @@ std::vector<ObjectID> TaskSpecification::GetDependencyIds() const {
|
||||||
return dependencies;
|
return dependencies;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<rpc::ObjectReference> TaskSpecification::GetDependencies() const {
|
std::vector<rpc::ObjectReference> TaskSpecification::GetDependencies(
|
||||||
|
bool add_dummy_dependency) const {
|
||||||
std::vector<rpc::ObjectReference> dependencies;
|
std::vector<rpc::ObjectReference> dependencies;
|
||||||
for (size_t i = 0; i < NumArgs(); ++i) {
|
for (size_t i = 0; i < NumArgs(); ++i) {
|
||||||
if (ArgByRef(i)) {
|
if (ArgByRef(i)) {
|
||||||
dependencies.push_back(message_->args(i).object_ref());
|
dependencies.push_back(message_->args(i).object_ref());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (IsActorTask()) {
|
if (add_dummy_dependency && IsActorTask()) {
|
||||||
const auto &dummy_ref =
|
const auto &dummy_ref =
|
||||||
GetReferenceForActorDummyObject(PreviousActorTaskDummyObjectId());
|
GetReferenceForActorDummyObject(PreviousActorTaskDummyObjectId());
|
||||||
dependencies.push_back(dummy_ref);
|
dependencies.push_back(dummy_ref);
|
||||||
|
|
|
@ -36,10 +36,15 @@ class TaskSpecification : public MessageWrapper<rpc::TaskSpec> {
|
||||||
TaskSpecification() {}
|
TaskSpecification() {}
|
||||||
|
|
||||||
/// Construct from a protobuf message object.
|
/// Construct from a protobuf message object.
|
||||||
/// The input message will be **copied** into this object.
|
/// The input message will be copied/moved into this object.
|
||||||
///
|
///
|
||||||
/// \param message The protobuf message.
|
/// \param message The protobuf message.
|
||||||
explicit TaskSpecification(rpc::TaskSpec message) : MessageWrapper(message) {
|
explicit TaskSpecification(rpc::TaskSpec &&message)
|
||||||
|
: MessageWrapper(std::move(message)) {
|
||||||
|
ComputeResources();
|
||||||
|
}
|
||||||
|
|
||||||
|
explicit TaskSpecification(const rpc::TaskSpec &message) : MessageWrapper(message) {
|
||||||
ComputeResources();
|
ComputeResources();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -127,9 +132,10 @@ class TaskSpecification : public MessageWrapper<rpc::TaskSpec> {
|
||||||
|
|
||||||
/// Return the dependencies of this task. This is recomputed each time, so it can
|
/// Return the dependencies of this task. This is recomputed each time, so it can
|
||||||
/// be used if the task spec is mutated.
|
/// be used if the task spec is mutated.
|
||||||
///
|
/// \param add_dummy_dependency whether to add a dummy object in the returned objects.
|
||||||
/// \return The recomputed dependencies for the task.
|
/// \return The recomputed dependencies for the task.
|
||||||
std::vector<rpc::ObjectReference> GetDependencies() const;
|
std::vector<rpc::ObjectReference> GetDependencies(
|
||||||
|
bool add_dummy_dependency = true) const;
|
||||||
|
|
||||||
std::string GetDebuggerBreakpoint() const;
|
std::string GetDebuggerBreakpoint() const;
|
||||||
|
|
||||||
|
|
|
@ -2219,7 +2219,7 @@ void CoreWorker::HandlePushTask(const rpc::PushTaskRequest &request,
|
||||||
// execution service.
|
// execution service.
|
||||||
if (request.task_spec().type() == TaskType::ACTOR_TASK) {
|
if (request.task_spec().type() == TaskType::ACTOR_TASK) {
|
||||||
task_execution_service_.post(
|
task_execution_service_.post(
|
||||||
[=] {
|
[this, request, reply, send_reply_callback = std::move(send_reply_callback)] {
|
||||||
// We have posted an exit task onto the main event loop,
|
// We have posted an exit task onto the main event loop,
|
||||||
// so shouldn't bother executing any further work.
|
// so shouldn't bother executing any further work.
|
||||||
if (exiting_) return;
|
if (exiting_) return;
|
||||||
|
|
|
@ -42,12 +42,12 @@ TEST(SchedulingQueueTest, TestInOrder) {
|
||||||
ActorSchedulingQueue queue(io_service, waiter);
|
ActorSchedulingQueue queue(io_service, waiter);
|
||||||
int n_ok = 0;
|
int n_ok = 0;
|
||||||
int n_rej = 0;
|
int n_rej = 0;
|
||||||
auto fn_ok = [&n_ok]() { n_ok++; };
|
auto fn_ok = [&n_ok](rpc::SendReplyCallback callback) { n_ok++; };
|
||||||
auto fn_rej = [&n_rej]() { n_rej++; };
|
auto fn_rej = [&n_rej](rpc::SendReplyCallback callback) { n_rej++; };
|
||||||
queue.Add(0, -1, fn_ok, fn_rej);
|
queue.Add(0, -1, fn_ok, fn_rej, nullptr);
|
||||||
queue.Add(1, -1, fn_ok, fn_rej);
|
queue.Add(1, -1, fn_ok, fn_rej, nullptr);
|
||||||
queue.Add(2, -1, fn_ok, fn_rej);
|
queue.Add(2, -1, fn_ok, fn_rej, nullptr);
|
||||||
queue.Add(3, -1, fn_ok, fn_rej);
|
queue.Add(3, -1, fn_ok, fn_rej, nullptr);
|
||||||
io_service.run();
|
io_service.run();
|
||||||
ASSERT_EQ(n_ok, 4);
|
ASSERT_EQ(n_ok, 4);
|
||||||
ASSERT_EQ(n_rej, 0);
|
ASSERT_EQ(n_rej, 0);
|
||||||
|
@ -62,12 +62,12 @@ TEST(SchedulingQueueTest, TestWaitForObjects) {
|
||||||
ActorSchedulingQueue queue(io_service, waiter);
|
ActorSchedulingQueue queue(io_service, waiter);
|
||||||
int n_ok = 0;
|
int n_ok = 0;
|
||||||
int n_rej = 0;
|
int n_rej = 0;
|
||||||
auto fn_ok = [&n_ok]() { n_ok++; };
|
auto fn_ok = [&n_ok](rpc::SendReplyCallback callback) { n_ok++; };
|
||||||
auto fn_rej = [&n_rej]() { n_rej++; };
|
auto fn_rej = [&n_rej](rpc::SendReplyCallback callback) { n_rej++; };
|
||||||
queue.Add(0, -1, fn_ok, fn_rej);
|
queue.Add(0, -1, fn_ok, fn_rej, nullptr);
|
||||||
queue.Add(1, -1, fn_ok, fn_rej, TaskID::Nil(), ObjectIdsToRefs({obj1}));
|
queue.Add(1, -1, fn_ok, fn_rej, nullptr, TaskID::Nil(), ObjectIdsToRefs({obj1}));
|
||||||
queue.Add(2, -1, fn_ok, fn_rej, TaskID::Nil(), ObjectIdsToRefs({obj2}));
|
queue.Add(2, -1, fn_ok, fn_rej, nullptr, TaskID::Nil(), ObjectIdsToRefs({obj2}));
|
||||||
queue.Add(3, -1, fn_ok, fn_rej, TaskID::Nil(), ObjectIdsToRefs({obj3}));
|
queue.Add(3, -1, fn_ok, fn_rej, nullptr, TaskID::Nil(), ObjectIdsToRefs({obj3}));
|
||||||
ASSERT_EQ(n_ok, 1);
|
ASSERT_EQ(n_ok, 1);
|
||||||
|
|
||||||
waiter.Complete(0);
|
waiter.Complete(0);
|
||||||
|
@ -87,10 +87,10 @@ TEST(SchedulingQueueTest, TestWaitForObjectsNotSubjectToSeqTimeout) {
|
||||||
ActorSchedulingQueue queue(io_service, waiter);
|
ActorSchedulingQueue queue(io_service, waiter);
|
||||||
int n_ok = 0;
|
int n_ok = 0;
|
||||||
int n_rej = 0;
|
int n_rej = 0;
|
||||||
auto fn_ok = [&n_ok]() { n_ok++; };
|
auto fn_ok = [&n_ok](rpc::SendReplyCallback callback) { n_ok++; };
|
||||||
auto fn_rej = [&n_rej]() { n_rej++; };
|
auto fn_rej = [&n_rej](rpc::SendReplyCallback callback) { n_rej++; };
|
||||||
queue.Add(0, -1, fn_ok, fn_rej);
|
queue.Add(0, -1, fn_ok, fn_rej, nullptr);
|
||||||
queue.Add(1, -1, fn_ok, fn_rej, TaskID::Nil(), ObjectIdsToRefs({obj1}));
|
queue.Add(1, -1, fn_ok, fn_rej, nullptr, TaskID::Nil(), ObjectIdsToRefs({obj1}));
|
||||||
ASSERT_EQ(n_ok, 1);
|
ASSERT_EQ(n_ok, 1);
|
||||||
io_service.run();
|
io_service.run();
|
||||||
ASSERT_EQ(n_rej, 0);
|
ASSERT_EQ(n_rej, 0);
|
||||||
|
@ -104,12 +104,12 @@ TEST(SchedulingQueueTest, TestOutOfOrder) {
|
||||||
ActorSchedulingQueue queue(io_service, waiter);
|
ActorSchedulingQueue queue(io_service, waiter);
|
||||||
int n_ok = 0;
|
int n_ok = 0;
|
||||||
int n_rej = 0;
|
int n_rej = 0;
|
||||||
auto fn_ok = [&n_ok]() { n_ok++; };
|
auto fn_ok = [&n_ok](rpc::SendReplyCallback callback) { n_ok++; };
|
||||||
auto fn_rej = [&n_rej]() { n_rej++; };
|
auto fn_rej = [&n_rej](rpc::SendReplyCallback callback) { n_rej++; };
|
||||||
queue.Add(2, -1, fn_ok, fn_rej);
|
queue.Add(2, -1, fn_ok, fn_rej, nullptr);
|
||||||
queue.Add(0, -1, fn_ok, fn_rej);
|
queue.Add(0, -1, fn_ok, fn_rej, nullptr);
|
||||||
queue.Add(3, -1, fn_ok, fn_rej);
|
queue.Add(3, -1, fn_ok, fn_rej, nullptr);
|
||||||
queue.Add(1, -1, fn_ok, fn_rej);
|
queue.Add(1, -1, fn_ok, fn_rej, nullptr);
|
||||||
io_service.run();
|
io_service.run();
|
||||||
ASSERT_EQ(n_ok, 4);
|
ASSERT_EQ(n_ok, 4);
|
||||||
ASSERT_EQ(n_rej, 0);
|
ASSERT_EQ(n_rej, 0);
|
||||||
|
@ -121,18 +121,18 @@ TEST(SchedulingQueueTest, TestSeqWaitTimeout) {
|
||||||
ActorSchedulingQueue queue(io_service, waiter);
|
ActorSchedulingQueue queue(io_service, waiter);
|
||||||
int n_ok = 0;
|
int n_ok = 0;
|
||||||
int n_rej = 0;
|
int n_rej = 0;
|
||||||
auto fn_ok = [&n_ok]() { n_ok++; };
|
auto fn_ok = [&n_ok](rpc::SendReplyCallback callback) { n_ok++; };
|
||||||
auto fn_rej = [&n_rej]() { n_rej++; };
|
auto fn_rej = [&n_rej](rpc::SendReplyCallback callback) { n_rej++; };
|
||||||
queue.Add(2, -1, fn_ok, fn_rej);
|
queue.Add(2, -1, fn_ok, fn_rej, nullptr);
|
||||||
queue.Add(0, -1, fn_ok, fn_rej);
|
queue.Add(0, -1, fn_ok, fn_rej, nullptr);
|
||||||
queue.Add(3, -1, fn_ok, fn_rej);
|
queue.Add(3, -1, fn_ok, fn_rej, nullptr);
|
||||||
ASSERT_EQ(n_ok, 1);
|
ASSERT_EQ(n_ok, 1);
|
||||||
ASSERT_EQ(n_rej, 0);
|
ASSERT_EQ(n_rej, 0);
|
||||||
io_service.run(); // immediately triggers timeout
|
io_service.run(); // immediately triggers timeout
|
||||||
ASSERT_EQ(n_ok, 1);
|
ASSERT_EQ(n_ok, 1);
|
||||||
ASSERT_EQ(n_rej, 2);
|
ASSERT_EQ(n_rej, 2);
|
||||||
queue.Add(4, -1, fn_ok, fn_rej);
|
queue.Add(4, -1, fn_ok, fn_rej, nullptr);
|
||||||
queue.Add(5, -1, fn_ok, fn_rej);
|
queue.Add(5, -1, fn_ok, fn_rej, nullptr);
|
||||||
ASSERT_EQ(n_ok, 3);
|
ASSERT_EQ(n_ok, 3);
|
||||||
ASSERT_EQ(n_rej, 2);
|
ASSERT_EQ(n_rej, 2);
|
||||||
}
|
}
|
||||||
|
@ -143,11 +143,11 @@ TEST(SchedulingQueueTest, TestSkipAlreadyProcessedByClient) {
|
||||||
ActorSchedulingQueue queue(io_service, waiter);
|
ActorSchedulingQueue queue(io_service, waiter);
|
||||||
int n_ok = 0;
|
int n_ok = 0;
|
||||||
int n_rej = 0;
|
int n_rej = 0;
|
||||||
auto fn_ok = [&n_ok]() { n_ok++; };
|
auto fn_ok = [&n_ok](rpc::SendReplyCallback callback) { n_ok++; };
|
||||||
auto fn_rej = [&n_rej]() { n_rej++; };
|
auto fn_rej = [&n_rej](rpc::SendReplyCallback callback) { n_rej++; };
|
||||||
queue.Add(2, 2, fn_ok, fn_rej);
|
queue.Add(2, 2, fn_ok, fn_rej, nullptr);
|
||||||
queue.Add(3, 2, fn_ok, fn_rej);
|
queue.Add(3, 2, fn_ok, fn_rej, nullptr);
|
||||||
queue.Add(1, 2, fn_ok, fn_rej);
|
queue.Add(1, 2, fn_ok, fn_rej, nullptr);
|
||||||
io_service.run();
|
io_service.run();
|
||||||
ASSERT_EQ(n_ok, 1);
|
ASSERT_EQ(n_ok, 1);
|
||||||
ASSERT_EQ(n_rej, 2);
|
ASSERT_EQ(n_rej, 2);
|
||||||
|
@ -158,13 +158,13 @@ TEST(SchedulingQueueTest, TestCancelQueuedTask) {
|
||||||
ASSERT_TRUE(queue->TaskQueueEmpty());
|
ASSERT_TRUE(queue->TaskQueueEmpty());
|
||||||
int n_ok = 0;
|
int n_ok = 0;
|
||||||
int n_rej = 0;
|
int n_rej = 0;
|
||||||
auto fn_ok = [&n_ok]() { n_ok++; };
|
auto fn_ok = [&n_ok](rpc::SendReplyCallback callback) { n_ok++; };
|
||||||
auto fn_rej = [&n_rej]() { n_rej++; };
|
auto fn_rej = [&n_rej](rpc::SendReplyCallback callback) { n_rej++; };
|
||||||
queue->Add(-1, -1, fn_ok, fn_rej);
|
queue->Add(-1, -1, fn_ok, fn_rej, nullptr);
|
||||||
queue->Add(-1, -1, fn_ok, fn_rej);
|
queue->Add(-1, -1, fn_ok, fn_rej, nullptr);
|
||||||
queue->Add(-1, -1, fn_ok, fn_rej);
|
queue->Add(-1, -1, fn_ok, fn_rej, nullptr);
|
||||||
queue->Add(-1, -1, fn_ok, fn_rej);
|
queue->Add(-1, -1, fn_ok, fn_rej, nullptr);
|
||||||
queue->Add(-1, -1, fn_ok, fn_rej);
|
queue->Add(-1, -1, fn_ok, fn_rej, nullptr);
|
||||||
ASSERT_TRUE(queue->CancelTaskIfFound(TaskID::Nil()));
|
ASSERT_TRUE(queue->CancelTaskIfFound(TaskID::Nil()));
|
||||||
ASSERT_FALSE(queue->TaskQueueEmpty());
|
ASSERT_FALSE(queue->TaskQueueEmpty());
|
||||||
queue->ScheduleRequests();
|
queue->ScheduleRequests();
|
||||||
|
|
|
@ -59,14 +59,16 @@ void CoreWorkerDirectActorTaskSubmitter::KillActor(const ActorID &actor_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spec) {
|
Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spec) {
|
||||||
RAY_LOG(DEBUG) << "Submitting task " << task_spec.TaskId();
|
auto task_id = task_spec.TaskId();
|
||||||
|
auto actor_id = task_spec.ActorId();
|
||||||
|
RAY_LOG(DEBUG) << "Submitting task " << task_id;
|
||||||
RAY_CHECK(task_spec.IsActorTask());
|
RAY_CHECK(task_spec.IsActorTask());
|
||||||
|
|
||||||
bool task_queued = false;
|
bool task_queued = false;
|
||||||
uint64_t send_pos = 0;
|
uint64_t send_pos = 0;
|
||||||
{
|
{
|
||||||
absl::MutexLock lock(&mu_);
|
absl::MutexLock lock(&mu_);
|
||||||
auto queue = client_queues_.find(task_spec.ActorId());
|
auto queue = client_queues_.find(actor_id);
|
||||||
RAY_CHECK(queue != client_queues_.end());
|
RAY_CHECK(queue != client_queues_.end());
|
||||||
if (queue->second.state != rpc::ActorTableData::DEAD) {
|
if (queue->second.state != rpc::ActorTableData::DEAD) {
|
||||||
// We must fix the send order prior to resolving dependencies, which may
|
// We must fix the send order prior to resolving dependencies, which may
|
||||||
|
@ -82,7 +84,6 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spe
|
||||||
}
|
}
|
||||||
|
|
||||||
if (task_queued) {
|
if (task_queued) {
|
||||||
const auto actor_id = task_spec.ActorId();
|
|
||||||
// We must release the lock before resolving the task dependencies since
|
// We must release the lock before resolving the task dependencies since
|
||||||
// the callback may get called in the same call stack.
|
// the callback may get called in the same call stack.
|
||||||
resolver_.ResolveDependencies(task_spec, [this, send_pos, actor_id]() {
|
resolver_.ResolveDependencies(task_spec, [this, send_pos, actor_id]() {
|
||||||
|
@ -99,7 +100,7 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spe
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
// Do not hold the lock while calling into task_finisher_.
|
// Do not hold the lock while calling into task_finisher_.
|
||||||
task_finisher_->MarkTaskCanceled(task_spec.TaskId());
|
task_finisher_->MarkTaskCanceled(task_id);
|
||||||
std::shared_ptr<rpc::RayException> creation_task_exception = nullptr;
|
std::shared_ptr<rpc::RayException> creation_task_exception = nullptr;
|
||||||
{
|
{
|
||||||
absl::MutexLock lock(&mu_);
|
absl::MutexLock lock(&mu_);
|
||||||
|
@ -109,9 +110,8 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spe
|
||||||
auto status = Status::IOError("cancelling task of dead actor");
|
auto status = Status::IOError("cancelling task of dead actor");
|
||||||
// No need to increment the number of completed tasks since the actor is
|
// No need to increment the number of completed tasks since the actor is
|
||||||
// dead.
|
// dead.
|
||||||
RAY_UNUSED(!task_finisher_->PendingTaskFailed(task_spec.TaskId(),
|
RAY_UNUSED(!task_finisher_->PendingTaskFailed(task_id, rpc::ErrorType::ACTOR_DIED,
|
||||||
rpc::ErrorType::ACTOR_DIED, &status,
|
&status, creation_task_exception));
|
||||||
creation_task_exception));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the task submission subsequently fails, then the client will receive
|
// If the task submission subsequently fails, then the client will receive
|
||||||
|
@ -423,7 +423,10 @@ void CoreWorkerDirectTaskReceiver::HandleTask(
|
||||||
const rpc::PushTaskRequest &request, rpc::PushTaskReply *reply,
|
const rpc::PushTaskRequest &request, rpc::PushTaskReply *reply,
|
||||||
rpc::SendReplyCallback send_reply_callback) {
|
rpc::SendReplyCallback send_reply_callback) {
|
||||||
RAY_CHECK(waiter_ != nullptr) << "Must call init() prior to use";
|
RAY_CHECK(waiter_ != nullptr) << "Must call init() prior to use";
|
||||||
const TaskSpecification task_spec(request.task_spec());
|
// Use `mutable_task_spec()` here as `task_spec()` returns a const reference
|
||||||
|
// which doesn't work with std::move.
|
||||||
|
TaskSpecification task_spec(
|
||||||
|
std::move(*(const_cast<rpc::PushTaskRequest &>(request).mutable_task_spec())));
|
||||||
|
|
||||||
// If GCS server is restarted after sending an actor creation task to this core worker,
|
// If GCS server is restarted after sending an actor creation task to this core worker,
|
||||||
// the restarted GCS server will send the same actor creation task to the core worker
|
// the restarted GCS server will send the same actor creation task to the core worker
|
||||||
|
@ -455,7 +458,8 @@ void CoreWorkerDirectTaskReceiver::HandleTask(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto accept_callback = [this, reply, send_reply_callback, task_spec, resource_ids]() {
|
auto accept_callback = [this, reply, task_spec,
|
||||||
|
resource_ids](rpc::SendReplyCallback send_reply_callback) {
|
||||||
if (task_spec.GetMessage().skip_execution()) {
|
if (task_spec.GetMessage().skip_execution()) {
|
||||||
send_reply_callback(Status::OK(), nullptr, nullptr);
|
send_reply_callback(Status::OK(), nullptr, nullptr);
|
||||||
return;
|
return;
|
||||||
|
@ -522,11 +526,11 @@ void CoreWorkerDirectTaskReceiver::HandleTask(
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
auto reject_callback = [send_reply_callback]() {
|
auto reject_callback = [](rpc::SendReplyCallback send_reply_callback) {
|
||||||
send_reply_callback(Status::Invalid("client cancelled stale rpc"), nullptr, nullptr);
|
send_reply_callback(Status::Invalid("client cancelled stale rpc"), nullptr, nullptr);
|
||||||
};
|
};
|
||||||
|
|
||||||
auto dependencies = task_spec.GetDependencies();
|
auto dependencies = task_spec.GetDependencies(false);
|
||||||
|
|
||||||
if (task_spec.IsActorTask()) {
|
if (task_spec.IsActorTask()) {
|
||||||
auto it = actor_scheduling_queues_.find(task_spec.CallerWorkerId());
|
auto it = actor_scheduling_queues_.find(task_spec.CallerWorkerId());
|
||||||
|
@ -538,16 +542,15 @@ void CoreWorkerDirectTaskReceiver::HandleTask(
|
||||||
it = result.first;
|
it = result.first;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Pop the dummy actor dependency.
|
|
||||||
// TODO(swang): Remove this with legacy raylet code.
|
|
||||||
dependencies.pop_back();
|
|
||||||
it->second->Add(request.sequence_number(), request.client_processed_up_to(),
|
it->second->Add(request.sequence_number(), request.client_processed_up_to(),
|
||||||
accept_callback, reject_callback, task_spec.TaskId(), dependencies);
|
std::move(accept_callback), std::move(reject_callback),
|
||||||
|
std::move(send_reply_callback), task_spec.TaskId(), dependencies);
|
||||||
} else {
|
} else {
|
||||||
// Add the normal task's callbacks to the non-actor scheduling queue.
|
// Add the normal task's callbacks to the non-actor scheduling queue.
|
||||||
normal_scheduling_queue_->Add(request.sequence_number(),
|
normal_scheduling_queue_->Add(
|
||||||
request.client_processed_up_to(), accept_callback,
|
request.sequence_number(), request.client_processed_up_to(),
|
||||||
reject_callback, task_spec.TaskId(), dependencies);
|
std::move(accept_callback), std::move(reject_callback),
|
||||||
|
std::move(send_reply_callback), task_spec.TaskId(), dependencies);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -276,23 +276,26 @@ class CoreWorkerDirectActorTaskSubmitter
|
||||||
class InboundRequest {
|
class InboundRequest {
|
||||||
public:
|
public:
|
||||||
InboundRequest(){};
|
InboundRequest(){};
|
||||||
InboundRequest(std::function<void()> accept_callback,
|
InboundRequest(std::function<void(rpc::SendReplyCallback)> accept_callback,
|
||||||
std::function<void()> reject_callback, TaskID task_id,
|
std::function<void(rpc::SendReplyCallback)> reject_callback,
|
||||||
|
rpc::SendReplyCallback send_reply_callback, TaskID task_id,
|
||||||
bool has_dependencies)
|
bool has_dependencies)
|
||||||
: accept_callback_(accept_callback),
|
: accept_callback_(std::move(accept_callback)),
|
||||||
reject_callback_(reject_callback),
|
reject_callback_(std::move(reject_callback)),
|
||||||
|
send_reply_callback_(std::move(send_reply_callback)),
|
||||||
task_id(task_id),
|
task_id(task_id),
|
||||||
has_pending_dependencies_(has_dependencies) {}
|
has_pending_dependencies_(has_dependencies) {}
|
||||||
|
|
||||||
void Accept() { accept_callback_(); }
|
void Accept() { accept_callback_(std::move(send_reply_callback_)); }
|
||||||
void Cancel() { reject_callback_(); }
|
void Cancel() { reject_callback_(std::move(send_reply_callback_)); }
|
||||||
bool CanExecute() const { return !has_pending_dependencies_; }
|
bool CanExecute() const { return !has_pending_dependencies_; }
|
||||||
ray::TaskID TaskID() const { return task_id; }
|
ray::TaskID TaskID() const { return task_id; }
|
||||||
void MarkDependenciesSatisfied() { has_pending_dependencies_ = false; }
|
void MarkDependenciesSatisfied() { has_pending_dependencies_ = false; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::function<void()> accept_callback_;
|
std::function<void(rpc::SendReplyCallback)> accept_callback_;
|
||||||
std::function<void()> reject_callback_;
|
std::function<void(rpc::SendReplyCallback)> reject_callback_;
|
||||||
|
rpc::SendReplyCallback send_reply_callback_;
|
||||||
ray::TaskID task_id;
|
ray::TaskID task_id;
|
||||||
bool has_pending_dependencies_;
|
bool has_pending_dependencies_;
|
||||||
};
|
};
|
||||||
|
@ -372,8 +375,10 @@ class BoundedExecutor {
|
||||||
class SchedulingQueue {
|
class SchedulingQueue {
|
||||||
public:
|
public:
|
||||||
virtual void Add(int64_t seq_no, int64_t client_processed_up_to,
|
virtual void Add(int64_t seq_no, int64_t client_processed_up_to,
|
||||||
std::function<void()> accept_request,
|
std::function<void(rpc::SendReplyCallback)> accept_request,
|
||||||
std::function<void()> reject_request, TaskID task_id = TaskID::Nil(),
|
std::function<void(rpc::SendReplyCallback)> reject_request,
|
||||||
|
rpc::SendReplyCallback send_reply_callback,
|
||||||
|
TaskID task_id = TaskID::Nil(),
|
||||||
const std::vector<rpc::ObjectReference> &dependencies = {}) = 0;
|
const std::vector<rpc::ObjectReference> &dependencies = {}) = 0;
|
||||||
virtual void ScheduleRequests() = 0;
|
virtual void ScheduleRequests() = 0;
|
||||||
virtual bool TaskQueueEmpty() const = 0;
|
virtual bool TaskQueueEmpty() const = 0;
|
||||||
|
@ -402,8 +407,9 @@ class ActorSchedulingQueue : public SchedulingQueue {
|
||||||
|
|
||||||
/// Add a new actor task's callbacks to the worker queue.
|
/// Add a new actor task's callbacks to the worker queue.
|
||||||
void Add(int64_t seq_no, int64_t client_processed_up_to,
|
void Add(int64_t seq_no, int64_t client_processed_up_to,
|
||||||
std::function<void()> accept_request, std::function<void()> reject_request,
|
std::function<void(rpc::SendReplyCallback)> accept_request,
|
||||||
TaskID task_id = TaskID::Nil(),
|
std::function<void(rpc::SendReplyCallback)> reject_request,
|
||||||
|
rpc::SendReplyCallback send_reply_callback, TaskID task_id = TaskID::Nil(),
|
||||||
const std::vector<rpc::ObjectReference> &dependencies = {}) {
|
const std::vector<rpc::ObjectReference> &dependencies = {}) {
|
||||||
// A seq_no of -1 means no ordering constraint. Actor tasks must be executed in order.
|
// A seq_no of -1 means no ordering constraint. Actor tasks must be executed in order.
|
||||||
RAY_CHECK(seq_no != -1);
|
RAY_CHECK(seq_no != -1);
|
||||||
|
@ -416,7 +422,8 @@ class ActorSchedulingQueue : public SchedulingQueue {
|
||||||
}
|
}
|
||||||
RAY_LOG(DEBUG) << "Enqueue " << seq_no << " cur seqno " << next_seq_no_;
|
RAY_LOG(DEBUG) << "Enqueue " << seq_no << " cur seqno " << next_seq_no_;
|
||||||
pending_actor_tasks_[seq_no] =
|
pending_actor_tasks_[seq_no] =
|
||||||
InboundRequest(accept_request, reject_request, task_id, dependencies.size() > 0);
|
InboundRequest(std::move(accept_request), std::move(reject_request),
|
||||||
|
std::move(send_reply_callback), task_id, dependencies.size() > 0);
|
||||||
if (dependencies.size() > 0) {
|
if (dependencies.size() > 0) {
|
||||||
waiter_.Wait(dependencies, [seq_no, this]() {
|
waiter_.Wait(dependencies, [seq_no, this]() {
|
||||||
RAY_CHECK(boost::this_thread::get_id() == main_thread_id_);
|
RAY_CHECK(boost::this_thread::get_id() == main_thread_id_);
|
||||||
|
@ -541,15 +548,17 @@ class NormalSchedulingQueue : public SchedulingQueue {
|
||||||
|
|
||||||
/// Add a new task's callbacks to the worker queue.
|
/// Add a new task's callbacks to the worker queue.
|
||||||
void Add(int64_t seq_no, int64_t client_processed_up_to,
|
void Add(int64_t seq_no, int64_t client_processed_up_to,
|
||||||
std::function<void()> accept_request, std::function<void()> reject_request,
|
std::function<void(rpc::SendReplyCallback)> accept_request,
|
||||||
TaskID task_id = TaskID::Nil(),
|
std::function<void(rpc::SendReplyCallback)> reject_request,
|
||||||
|
rpc::SendReplyCallback send_reply_callback, TaskID task_id = TaskID::Nil(),
|
||||||
const std::vector<rpc::ObjectReference> &dependencies = {}) {
|
const std::vector<rpc::ObjectReference> &dependencies = {}) {
|
||||||
absl::MutexLock lock(&mu_);
|
absl::MutexLock lock(&mu_);
|
||||||
// Normal tasks should not have ordering constraints.
|
// Normal tasks should not have ordering constraints.
|
||||||
RAY_CHECK(seq_no == -1);
|
RAY_CHECK(seq_no == -1);
|
||||||
// Create a InboundRequest object for the new task, and add it to the queue.
|
// Create a InboundRequest object for the new task, and add it to the queue.
|
||||||
pending_normal_tasks_.push_back(
|
pending_normal_tasks_.push_back(
|
||||||
InboundRequest(accept_request, reject_request, task_id, dependencies.size() > 0));
|
InboundRequest(std::move(accept_request), std::move(reject_request),
|
||||||
|
std::move(send_reply_callback), task_id, dependencies.size() > 0));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Search for an InboundRequest associated with the task that we are trying to cancel.
|
// Search for an InboundRequest associated with the task that we are trying to cancel.
|
||||||
|
|
|
@ -66,7 +66,8 @@ class ClientCallImpl : public ClientCall {
|
||||||
///
|
///
|
||||||
/// \param[in] callback The callback function to handle the reply.
|
/// \param[in] callback The callback function to handle the reply.
|
||||||
explicit ClientCallImpl(const ClientCallback<Reply> &callback, std::string call_name)
|
explicit ClientCallImpl(const ClientCallback<Reply> &callback, std::string call_name)
|
||||||
: callback_(callback), call_name_(std::move(call_name)) {}
|
: callback_(std::move(const_cast<ClientCallback<Reply> &>(callback))),
|
||||||
|
call_name_(std::move(call_name)) {}
|
||||||
|
|
||||||
Status GetStatus() override {
|
Status GetStatus() override {
|
||||||
absl::MutexLock lock(&mutex_);
|
absl::MutexLock lock(&mutex_);
|
||||||
|
|
|
@ -285,7 +285,9 @@ class CoreWorkerClient : public std::enable_shared_from_this<CoreWorkerClient>,
|
||||||
|
|
||||||
{
|
{
|
||||||
absl::MutexLock lock(&mutex_);
|
absl::MutexLock lock(&mutex_);
|
||||||
send_queue_.push_back(std::make_pair(std::move(request), callback));
|
send_queue_.push_back(std::make_pair(
|
||||||
|
std::move(request),
|
||||||
|
std::move(const_cast<ClientCallback<PushTaskReply> &>(callback))));
|
||||||
}
|
}
|
||||||
SendRequests();
|
SendRequests();
|
||||||
}
|
}
|
||||||
|
@ -311,13 +313,13 @@ class CoreWorkerClient : public std::enable_shared_from_this<CoreWorkerClient>,
|
||||||
send_queue_.pop_front();
|
send_queue_.pop_front();
|
||||||
|
|
||||||
auto request = std::move(pair.first);
|
auto request = std::move(pair.first);
|
||||||
auto callback = pair.second;
|
|
||||||
int64_t task_size = RequestSizeInBytes(*request);
|
int64_t task_size = RequestSizeInBytes(*request);
|
||||||
int64_t seq_no = request->sequence_number();
|
int64_t seq_no = request->sequence_number();
|
||||||
request->set_client_processed_up_to(max_finished_seq_no_);
|
request->set_client_processed_up_to(max_finished_seq_no_);
|
||||||
rpc_bytes_in_flight_ += task_size;
|
rpc_bytes_in_flight_ += task_size;
|
||||||
|
|
||||||
auto rpc_callback = [this, this_ptr, seq_no, task_size, callback](
|
auto rpc_callback = [this, this_ptr, seq_no, task_size,
|
||||||
|
callback = std::move(pair.second)](
|
||||||
Status status, const rpc::PushTaskReply &reply) {
|
Status status, const rpc::PushTaskReply &reply) {
|
||||||
{
|
{
|
||||||
absl::MutexLock lock(&mutex_);
|
absl::MutexLock lock(&mutex_);
|
||||||
|
@ -331,8 +333,8 @@ class CoreWorkerClient : public std::enable_shared_from_this<CoreWorkerClient>,
|
||||||
callback(status, reply);
|
callback(status, reply);
|
||||||
};
|
};
|
||||||
|
|
||||||
RAY_UNUSED(INVOKE_RPC_CALL(CoreWorkerService, PushTask, *request, rpc_callback,
|
RAY_UNUSED(INVOKE_RPC_CALL(CoreWorkerService, PushTask, *request,
|
||||||
grpc_client_));
|
std::move(rpc_callback), grpc_client_));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!send_queue_.empty()) {
|
if (!send_queue_.empty()) {
|
||||||
|
|
Loading…
Add table
Reference in a new issue